-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathMLMSVM.hpp
114 lines (100 loc) · 4.12 KB
/
MLMSVM.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//----------------------------------------------------------------------
/*!\file
*
* \author Matthias Holoch <[email protected]>
* \date 2015-01-17
*
*/
//----------------------------------------------------------------------
#ifndef MLMSVM_HPP_INCLUDED
#define MLMSVM_HPP_INCLUDED
#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>
#include <pcl/registration/transforms.h>
#include "shared_types.hpp"
#include "MLModule.hpp"
/*!
* This class wraps OpenCV's Support Vector Machine for use on correspondences inside mlreg.
* It uses the FPFH-distance correspondences for its classification.
* The classified classes are "true correspondence" or "false correspondence".
* For training, it'll use correspondences from a TransformationHint, which also contains
* a transformation. This transformation can come from multiple sources:
* ground-truth data when training the algorithm for a known environment
* some intialization cloud obtained without moving the sensor and artificially put noise on the cloud
* robot odometry while it has high preicison
* ...
*/
class MLMSVM : public MLModule {
public:
/*!
* This struct is used for storing the parameters used by the MLMSVM class.
*/
struct Parameters {
float max_corr_distance_squared = 0.04;
std::string model_store_path = "svm_model.yaml";
};
/*!
* Default constructor with parameters.
*/
MLMSVM(struct Parameters& params)
: svm_(),
params_(params),
ready_(false)
{
}
void train(const Digest::Ptr& digest_source, const Digest::Ptr& digest_target, const TransformationHint& transformation_hint, const Correspondences& correspondences) {
// Stores the source cloud
Digest::Cloud::ConstPtr cloud_source = digest_source->getReducedCloud();
// Stores the transformed target cloud
Digest::Cloud::Ptr cloud_target(new Digest::Cloud);
// Stores the training data
cv::Mat trainingData(correspondences.size(), 33, CV_32FC1);
// Stores the training labels
cv::Mat labels(correspondences.size(), 1, CV_32SC1);
// Transform the target cloud with the tf from the TransformationHint:
pcl::transformPointCloud(*(digest_target->getReducedCloud()), *cloud_target, transformation_hint.transformation);
// Iterate through all correspondences and define their class by the distance between the transformed corresponding points
for (unsigned int i = 0; i < correspondences.size(); ++i) {
Digest::PointType p_src = cloud_source->at(digest_source->getDescriptorCloudIndices()->at(correspondences[i].source_id));
Digest::PointType p_trg = cloud_target->at(digest_target->getDescriptorCloudIndices()->at(correspondences[i].target_id));
Digest::PointType p_diff(p_src.x - p_trg.x, p_src.y - p_trg.y, p_src.z - p_trg.z);
// set the class of the correspondence depending on their squared euclidean distance
labels.at<int>(i,0) = (p_diff.x * p_diff.x + p_diff.y * p_diff.y + p_diff.z * p_diff.z <= params_.max_corr_distance_squared);
// add the feature vector to the training data
for (unsigned int j = 0; j < 33; ++j) {
trainingData.at<float>(i,j) = correspondences[i].distance.histogram[j];
}
}
// train the svm
svm_.train(trainingData, labels, cv::Mat(), cv::Mat(), cv::SVMParams());
ready_ = true;
}
float classify(const Correspondence& correspondence) const {
cv::Mat sample(1, 33, CV_32FC1);
for (unsigned int j = 0; j < 33; ++j) {
sample.at<float>(0,j) = correspondence.distance.histogram[j];
}
return svm_.predict(sample);
}
bool isReady() const {
return ready_;
}
/*!
* Loads the svm model from model_store_path. (see MLMSVM::Parameters)
*/
void loadModel() {
svm_.load(params_.model_store_path.c_str());
ready_ = true;
}
/*
* Saves the svm model to model_store_path. (see MLMSVM::Parameters)
*/
void saveModel() const {
svm_.save(params_.model_store_path.c_str());
}
protected:
cv::SVM svm_;
struct Parameters params_;
bool ready_;
};
#endif