mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-19 02:16:34 +08:00
Added new method for training forest
This commit is contained in:
@@ -63,20 +63,41 @@ namespace optflow
|
||||
struct CV_EXPORTS_W GPCPatchDescriptor
|
||||
{
|
||||
static const unsigned nFeatures = 18; // number of features in a patch descriptor
|
||||
Vec<double, nFeatures> feature;
|
||||
Vec< double, nFeatures > feature;
|
||||
|
||||
GPCPatchDescriptor( const Mat *imgCh, int i, int j );
|
||||
};
|
||||
|
||||
typedef std::pair<GPCPatchDescriptor, GPCPatchDescriptor> GPCPatchSample;
|
||||
typedef std::vector<GPCPatchSample> GPCSamplesVector;
|
||||
typedef std::pair< GPCPatchDescriptor, GPCPatchDescriptor > GPCPatchSample;
|
||||
typedef std::vector< GPCPatchSample > GPCSamplesVector;
|
||||
|
||||
/** @brief Class encapsulating training samples.
|
||||
*/
|
||||
class CV_EXPORTS_W GPCTrainingSamples
|
||||
{
|
||||
private:
|
||||
GPCSamplesVector samples;
|
||||
|
||||
public:
|
||||
/** @brief This function can be used to extract samples from a pair of images and a ground truth flow.
|
||||
* Sizes of all the provided vectors must be equal.
|
||||
*/
|
||||
static Ptr< GPCTrainingSamples > create( const std::vector< String > &imagesFrom, const std::vector< String > &imagesTo,
|
||||
const std::vector< String > > );
|
||||
|
||||
size_t size() const { return samples.size(); }
|
||||
|
||||
operator GPCSamplesVector() const { return samples; }
|
||||
|
||||
operator GPCSamplesVector &() { return samples; }
|
||||
};
|
||||
|
||||
class CV_EXPORTS_W GPCTree : public Algorithm
|
||||
{
|
||||
public:
|
||||
struct Node
|
||||
{
|
||||
Vec<double, GPCPatchDescriptor::nFeatures> coef; // hyperplane coefficients
|
||||
Vec< double, GPCPatchDescriptor::nFeatures > coef; // hyperplane coefficients
|
||||
double rhs;
|
||||
unsigned left;
|
||||
unsigned right;
|
||||
@@ -87,7 +108,7 @@ public:
|
||||
private:
|
||||
typedef GPCSamplesVector::iterator SIter;
|
||||
|
||||
std::vector<Node> nodes;
|
||||
std::vector< Node > nodes;
|
||||
|
||||
bool trainNode( size_t nodeId, SIter begin, SIter end, unsigned depth );
|
||||
|
||||
@@ -98,23 +119,38 @@ public:
|
||||
|
||||
void read( const FileNode &fn );
|
||||
|
||||
static Ptr<GPCTree> create() { return makePtr<GPCTree>(); }
|
||||
static Ptr< GPCTree > create() { return makePtr< GPCTree >(); }
|
||||
|
||||
bool operator==( const GPCTree &t ) const { return nodes == t.nodes; }
|
||||
};
|
||||
|
||||
template <int T> class CV_EXPORTS_W GPCForest : public Algorithm
|
||||
template < int T > class CV_EXPORTS_W GPCForest : public Algorithm
|
||||
{
|
||||
private:
|
||||
GPCTree tree[T];
|
||||
|
||||
public:
|
||||
/** @brief Train the forest using one sample set for every tree.
|
||||
* Please, consider using the next method instead of this one for better quality.
|
||||
*/
|
||||
void train( GPCSamplesVector &samples )
|
||||
{
|
||||
for ( int i = 0; i < T; ++i )
|
||||
tree[i].train( samples );
|
||||
}
|
||||
|
||||
/** @brief Train the forest using individual samples for each tree.
|
||||
* It is generally better to use this instead of the first method.
|
||||
*/
|
||||
void train( const std::vector< String > &imagesFrom, const std::vector< String > &imagesTo, const std::vector< String > > )
|
||||
{
|
||||
for ( int i = 0; i < T; ++i )
|
||||
{
|
||||
Ptr< GPCTrainingSamples > samples = GPCTrainingSamples::create( imagesFrom, imagesTo, gt ); // Create training set for the tree
|
||||
tree[i].train( *samples );
|
||||
}
|
||||
}
|
||||
|
||||
void write( FileStorage &fs ) const
|
||||
{
|
||||
fs << "ntrees" << T << "trees"
|
||||
@@ -136,28 +172,7 @@ public:
|
||||
tree[i].read( *it );
|
||||
}
|
||||
|
||||
static Ptr<GPCForest> create() { return makePtr<GPCForest>(); }
|
||||
};
|
||||
|
||||
/** @brief Class encapsulating training samples.
|
||||
*/
|
||||
class CV_EXPORTS_W GPCTrainingSamples
|
||||
{
|
||||
private:
|
||||
GPCSamplesVector samples;
|
||||
|
||||
public:
|
||||
/** @brief This function can be used to extract samples from a pair of images and a ground truth flow.
|
||||
* Sizes of all the provided vectors must be equal.
|
||||
*/
|
||||
static Ptr<GPCTrainingSamples> create( const std::vector<String> &imagesFrom, const std::vector<String> &imagesTo,
|
||||
const std::vector<String> > );
|
||||
|
||||
size_t size() const { return samples.size(); }
|
||||
|
||||
operator GPCSamplesVector() const { return samples; }
|
||||
|
||||
operator GPCSamplesVector &() { return samples; }
|
||||
static Ptr< GPCForest > create() { return makePtr< GPCForest >(); }
|
||||
};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user