CVB++ 15.0
Loading...
Searching...
No Matches
classification_predictor.hpp
1#pragma once
2
3#include "../_cexports/c_polimago.h"
4
5#include "../global.hpp"
6#include "../string.hpp"
7#include "predictor_base.hpp"
8#include "classification_result.hpp"
9
10#include <utility>
11#include <vector>
12
13namespace Cvb
14{
15 CVB_BEGIN_INLINE_NS
16
18
37 namespace Polimago
38 {
39
52
54
56 class ClassificationPredictor : public PredictorBaseEx
57 {
58 private:
59 // Internal helper constructor version
60 explicit ClassificationPredictor(ReleaseObjectGuard &&guard)
61 : PredictorBaseEx(std::move(guard))
62 {
63 if (TrainingParameters().Usage == CExports::TClassifierUsage::CU_Regression)
64 {
65 throw std::runtime_error(std::string("The object is not a ") + thisObjectName_);
66 }
67
68 auto numClasses = NumClasses();
69 for (int i = 0; i < numClasses; ++i)
70 {
72 CVB_CALL_CAPI_CHECKED(PMGetClfClassLabelTyped(Handle(), i, lbl.data()));
73 classes_.push_back(lbl.data());
74 }
75 }
76
77 // Helper to load a saved predictor from file
78 static CExports::TCLF LoadInternal(const String &fileName)
79 {
80 CExports::TCLF predictor = nullptr;
81
82 CVB_CALL_CAPI_CHECKED(PMOpenClfTyped(fileName.c_str(), predictor));
83 return predictor;
84 }
85
86 void SaveFunction(const String &fileName) const override
87 {
88 CVB_CALL_CAPI_CHECKED(PMSaveClfTyped(fileName.c_str(), Handle()));
89 }
90
91 std::string ObjectName() const override
92 {
93 return thisObjectName_;
94 }
95
96 public:
98
102 explicit ClassificationPredictor(const String &fileName)
103 : ClassificationPredictor(ReleaseObjectGuard(LoadInternal(fileName)))
104 {
105 fileName_ = fileName;
106 }
107
109
116 static std::unique_ptr<ClassificationPredictor> FromHandle(ReleaseObjectGuard &&guard)
117 {
118 if (!guard.Handle())
119 {
120 throw std::invalid_argument("invalid classification predictor native handle");
121 }
122 return std::unique_ptr<ClassificationPredictor>(new ClassificationPredictor(std::move(guard)));
123 }
124
126
132 {
133 return std::make_unique<ClassificationPredictor>(fileName);
134 }
135
137
142 {
143 return classes_;
144 }
145
147
152 {
153 if (TrainingParameters().Usage == CExports::TClassifierUsage::CU_ClassifyOneVersusAll)
154 {
156 }
157 else if (TrainingParameters().Usage == CExports::TClassifierUsage::CU_ClassifyOneVersusOne)
158 {
160 }
161 else
162 {
163 throw std::runtime_error("invalid classification type");
164 }
165 }
166
168
172 int OutputDimension() const
173 {
174 return CVB_CALL_CAPI(PMGetOutputDimension(Handle()));
175 }
176
178
182 int NumClasses() const
183 {
184 return static_cast<int>(CVB_CALL_CAPI(PMGetNumClasses(Handle())));
185 }
186
188
196 {
197 VerifyCompatibility(img, pos);
198
200 double confidence = 0.0;
201 confidences.assign(NumClasses(), 0.0);
202 CVB_CALL_CAPI_CHECKED(PMClassifyTyped(Handle(), img.Handle(), pos.X(), pos.Y(), lbl.data(), classNameMaxLength_,
203 confidence, confidences.data()));
204
205 return ClassificationResult(lbl.data(), confidence);
206 }
207
209
216 {
217 std::vector<double> confidences;
218 return Classify(img, pos, confidences);
219 }
220
221 private:
222 std::vector<String> classes_;
223 static const int classNameMaxLength_ = 256;
224 const std::string thisObjectName_ = "Polimago Classification Predictor";
225 };
226
229
230 } /* namespace Polimago */
231 CVB_END_INLINE_NS
232} /* namespace Cvb */
The Common Vision Blox image.
Definition decl_image.hpp:50
void * Handle() const noexcept
Classic API image handle.
Definition decl_image.hpp:237
Multi-purpose 2D vector class.
Definition point_2d.hpp:20
T X() const noexcept
Gets the x-component of the point.
Definition point_2d.hpp:84
T Y() const noexcept
Gets the y-component of the point.
Definition point_2d.hpp:104
int NumClasses() const
Number of classes a classification predictor has been trained for.
Definition classification_predictor.hpp:182
static std::unique_ptr< ClassificationPredictor > FromHandle(ReleaseObjectGuard &&guard)
Creates a predictor from a classic API handle.
Definition classification_predictor.hpp:116
static std::unique_ptr< ClassificationPredictor > Load(const String &fileName)
Load a saved predictor from a file.
Definition classification_predictor.hpp:131
ClassificationResult Classify(const Image &img, Point2D< int > pos)
Classify a location inside an image.
Definition classification_predictor.hpp:215
ClassificationPredictor(const String &fileName)
Load a saved Polimago classification predictor from a file.
Definition classification_predictor.hpp:102
int OutputDimension() const
Dimension of results generated by this predictor.
Definition classification_predictor.hpp:172
std::vector< String > Classes() const
Class labels available in this predictor.
Definition classification_predictor.hpp:141
ClassificationType Classification() const
The classification type for which this classifier has been generated.
Definition classification_predictor.hpp:151
ClassificationResult Classify(const Image &img, Point2D< int > pos, std::vector< double > &confidences)
Classify a location inside an image.
Definition classification_predictor.hpp:195
Polimago classification result container.
Definition classification_result.hpp:20
void * Handle() const noexcept
Classic API Polimago handle.
Definition predictor_base.hpp:68
std::remove_reference_t< T > && move(T &&t) noexcept
Namespace for the Polimago package.
Definition classification_predictor.hpp:38
std::shared_ptr< ClassificationPredictor > ClassificationPredictorPtr
Convenience shared pointer for ClassificationPredictor.
Definition classification_predictor.hpp:228
ClassificationType
Determine the classification type to be carried out.
Definition classification_predictor.hpp:42
@ None
The enum element indicating an undefined state.
Definition classification_predictor.hpp:44
@ OneVersusAll
Definition classification_predictor.hpp:47
@ OneVersusOne
Definition classification_predictor.hpp:50
Root namespace for the Image Manager interface.
Definition version.hpp:11
std::string String
String for wide characters or unicode characters.
Definition string.hpp:49