1
0
mirror of https://github.com/opencv/opencv_contrib.git synced 2025-10-21 06:11:09 +08:00

Added new method for training forest

This commit is contained in:
Vladislav Samsonov
2016-07-25 03:22:31 +03:00
parent 7f93d951d3
commit 17831add02
3 changed files with 79 additions and 67 deletions

View File

@@ -71,6 +71,27 @@ struct CV_EXPORTS_W GPCPatchDescriptor
typedef std::pair< GPCPatchDescriptor, GPCPatchDescriptor > GPCPatchSample; typedef std::pair< GPCPatchDescriptor, GPCPatchDescriptor > GPCPatchSample;
typedef std::vector< GPCPatchSample > GPCSamplesVector; typedef std::vector< GPCPatchSample > GPCSamplesVector;
/** @brief Class encapsulating training samples.
*/
class CV_EXPORTS_W GPCTrainingSamples
{
private:
GPCSamplesVector samples;
public:
/** @brief This function can be used to extract samples from a pair of images and a ground truth flow.
* Sizes of all the provided vectors must be equal.
*/
static Ptr< GPCTrainingSamples > create( const std::vector< String > &imagesFrom, const std::vector< String > &imagesTo,
const std::vector< String > &gt );
size_t size() const { return samples.size(); }
operator GPCSamplesVector() const { return samples; }
operator GPCSamplesVector &() { return samples; }
};
class CV_EXPORTS_W GPCTree : public Algorithm class CV_EXPORTS_W GPCTree : public Algorithm
{ {
public: public:
@@ -109,12 +130,27 @@ private:
GPCTree tree[T]; GPCTree tree[T];
public: public:
/** @brief Train the forest using one sample set for every tree.
* Please, consider using the next method instead of this one for better quality.
*/
void train( GPCSamplesVector &samples ) void train( GPCSamplesVector &samples )
{ {
for ( int i = 0; i < T; ++i ) for ( int i = 0; i < T; ++i )
tree[i].train( samples ); tree[i].train( samples );
} }
/** @brief Train the forest using individual samples for each tree.
* It is generally better to use this instead of the first method.
*/
void train( const std::vector< String > &imagesFrom, const std::vector< String > &imagesTo, const std::vector< String > &gt )
{
for ( int i = 0; i < T; ++i )
{
Ptr< GPCTrainingSamples > samples = GPCTrainingSamples::create( imagesFrom, imagesTo, gt ); // Create training set for the tree
tree[i].train( *samples );
}
}
void write( FileStorage &fs ) const void write( FileStorage &fs ) const
{ {
fs << "ntrees" << T << "trees" fs << "ntrees" << T << "trees"
@@ -138,27 +174,6 @@ public:
static Ptr< GPCForest > create() { return makePtr< GPCForest >(); } static Ptr< GPCForest > create() { return makePtr< GPCForest >(); }
}; };
/** @brief Class encapsulating training samples.
*/
class CV_EXPORTS_W GPCTrainingSamples
{
private:
GPCSamplesVector samples;
public:
/** @brief This function can be used to extract samples from a pair of images and a ground truth flow.
* Sizes of all the provided vectors must be equal.
*/
static Ptr<GPCTrainingSamples> create( const std::vector<String> &imagesFrom, const std::vector<String> &imagesTo,
const std::vector<String> &gt );
size_t size() const { return samples.size(); }
operator GPCSamplesVector() const { return samples; }
operator GPCSamplesVector &() { return samples; }
};
} }
CV_EXPORTS void write( FileStorage &fs, const String &name, const optflow::GPCTree::Node &node ); CV_EXPORTS void write( FileStorage &fs, const String &name, const optflow::GPCTree::Node &node );

View File

@@ -23,12 +23,8 @@ int main( int argc, const char **argv )
gt.push_back( argv[1 + i * 3 + 2] ); gt.push_back( argv[1 + i * 3 + 2] );
} }
cv::Ptr<cv::optflow::GPCTrainingSamples> ts = cv::optflow::GPCTrainingSamples::create( img1, img2, gt );
std::cout << "Got " << ts->size() << " samples." << std::endl;
cv::Ptr< cv::optflow::GPCForest< nTrees > > forest = cv::optflow::GPCForest< nTrees >::create(); cv::Ptr< cv::optflow::GPCForest< nTrees > > forest = cv::optflow::GPCForest< nTrees >::create();
forest->train( *ts ); forest->train( img1, img2, gt );
forest->save( "forest.dump" ); forest->save( "forest.dump" );
return 0; return 0;

View File

@@ -116,7 +116,8 @@ void getTrainingSamples( const Mat &from, const Mat &to, const Mat &gt, GPCSampl
for ( int j = patchRadius; j + patchRadius < sz.width; ++j ) for ( int j = patchRadius; j + patchRadius < sz.width; ++j )
mag.push_back( Magnitude( normL2Sqr( gt.at< Vec2f >( i, j ) ), i, j ) ); mag.push_back( Magnitude( normL2Sqr( gt.at< Vec2f >( i, j ) ), i, j ) );
size_t n = mag.size() * thresholdMagnitudeFrac; size_t n = mag.size() * thresholdMagnitudeFrac; // As suggested in the paper, we discard part of the training samples
// with a small displacement and train to better distinguish hard pairs.
std::nth_element( mag.begin(), mag.begin() + n, mag.end() ); std::nth_element( mag.begin(), mag.begin() + n, mag.end() );
mag.resize( n ); mag.resize( n );
std::random_shuffle( mag.begin(), mag.end() ); std::random_shuffle( mag.begin(), mag.end() );