mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-19 19:44:14 +08:00
[GSoC] Implementation of the Global Patch Collider and demo for PCAFlow (#752)
* Minor fixes * Start adding correspondence finding * Added finding of correspondences using GPC * New evaluation tool for GPC * Changed default parameters * Display ground truth in the evaluation tool * Added training tool for MPI Sintel dataset * Added the training tool for Middlebury dataset * Added some OpenCL optimization * Added explanatory notes * Minor improvements: time measurements + little ocl optimization * Added demos * Fixed warnings * Make parameter struct assignable * Fix warning * Proper command line argument usage * Prettified training tool, added parameters * Fixed VS warning * Fixed VS warning * Using of compressed forest.yml.gz files by default to save space * Added OpenCL flag to the evaluation tool * Updated documentation * Major speed and memory improvements: 1) Added new (optional) type of patch descriptors which are much faster. Retraining with option --descriptor-type=1 is required. 2) Got rid of hash table for descriptors, less memory usage. * Fixed various floating point errors related to precision. SIMD for dot product, forest traversing is a little bit faster now. * Tolerant floating point comparison * Triplets * Added comment * Choosing negative sample among nearest neighbors * Fix warning * Usage of parallel_for_() in critical places. Performance improvments. * Simulated annealing heuristic * Moved OpenCL kernel to separate file * Moved implementation to source file * Added basic accuracy tests for GPC and PCAFlow * Fixing warnings * Test accuracy constraints were too strict * Test accuracy constraints were too strict * Make tests more lightweight
This commit is contained in:

committed by
Maksim Shabunin

parent
25575af653
commit
ac62d70f97
164
modules/optflow/samples/gpc_evaluate.cpp
Normal file
164
modules/optflow/samples/gpc_evaluate.cpp
Normal file
@@ -0,0 +1,164 @@
|
||||
#include "opencv2/core/ocl.hpp"
|
||||
#include "opencv2/highgui.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/optflow.hpp"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
/* This tool finds correspondences between two images using Global Patch Collider
|
||||
* and calculates error using provided ground truth flow.
|
||||
*
|
||||
* It will look for the file named "forest.yml.gz" with a learned forest.
|
||||
* You can obtain the "forest.yml.gz" either by manually training it using another tool with *_train suffix
|
||||
* or by downloading one of the files trained on some publicly available dataset from here:
|
||||
*
|
||||
* https://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU
|
||||
*/
|
||||
|
||||
using namespace cv;
|
||||
|
||||
const String keys = "{help h ? | | print this message}"
|
||||
"{@image1 |<none> | image1}"
|
||||
"{@image2 |<none> | image2}"
|
||||
"{@groundtruth |<none> | path to the .flo file}"
|
||||
"{@output | | output to a file instead of displaying, output image path}"
|
||||
"{g gpu | | use OpenCL}"
|
||||
"{f forest |forest.yml.gz| path to the forest.yml.gz}";
|
||||
|
||||
const int nTrees = 5;
|
||||
|
||||
static double normL2( const Point2f &v ) { return sqrt( v.x * v.x + v.y * v.y ); }
|
||||
|
||||
static Vec3d getFlowColor( const Point2f &f, const bool logScale = true, const double scaleDown = 5 )
|
||||
{
|
||||
if ( f.x == 0 && f.y == 0 )
|
||||
return Vec3d( 0, 0, 1 );
|
||||
|
||||
double radius = normL2( f );
|
||||
if ( logScale )
|
||||
radius = log( radius + 1 );
|
||||
radius /= scaleDown;
|
||||
radius = std::min( 1.0, radius );
|
||||
|
||||
double angle = ( atan2( -f.y, -f.x ) + CV_PI ) * 180 / CV_PI;
|
||||
return Vec3d( angle, radius, 1 );
|
||||
}
|
||||
|
||||
static void displayFlow( InputArray _flow, OutputArray _img )
|
||||
{
|
||||
const Size sz = _flow.size();
|
||||
Mat flow = _flow.getMat();
|
||||
_img.create( sz, CV_32FC3 );
|
||||
Mat img = _img.getMat();
|
||||
|
||||
for ( int i = 0; i < sz.height; ++i )
|
||||
for ( int j = 0; j < sz.width; ++j )
|
||||
img.at< Vec3f >( i, j ) = getFlowColor( flow.at< Point2f >( i, j ) );
|
||||
|
||||
cvtColor( img, img, COLOR_HSV2BGR );
|
||||
}
|
||||
|
||||
static bool fileProbe( const char *name ) { return std::ifstream( name ).good(); }
|
||||
|
||||
int main( int argc, const char **argv )
|
||||
{
|
||||
CommandLineParser parser( argc, argv, keys );
|
||||
parser.about( "Global Patch Collider evaluation tool" );
|
||||
|
||||
if ( parser.has( "help" ) )
|
||||
{
|
||||
parser.printMessage();
|
||||
return 0;
|
||||
}
|
||||
|
||||
String fromPath = parser.get< String >( 0 );
|
||||
String toPath = parser.get< String >( 1 );
|
||||
String gtPath = parser.get< String >( 2 );
|
||||
String outPath = parser.get< String >( 3 );
|
||||
const bool useOpenCL = parser.has( "gpu" );
|
||||
String forestDumpPath = parser.get< String >( "forest" );
|
||||
|
||||
if ( !parser.check() )
|
||||
{
|
||||
parser.printErrors();
|
||||
return 1;
|
||||
}
|
||||
|
||||
if ( !fileProbe( forestDumpPath.c_str() ) )
|
||||
{
|
||||
std::cerr << "Can't open the file with a trained model: `" << forestDumpPath
|
||||
<< "`.\nYou can obtain this file either by manually training the model using another tool with *_train suffix or by "
|
||||
"downloading one of the files trained on some publicly available dataset from "
|
||||
"here:\nhttps://drive.google.com/open?id=0B7Hb8cfuzrIIZDFscXVYd0NBNFU"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
ocl::setUseOpenCL( useOpenCL );
|
||||
|
||||
Ptr< optflow::GPCForest< nTrees > > forest = Algorithm::load< optflow::GPCForest< nTrees > >( forestDumpPath );
|
||||
|
||||
Mat from = imread( fromPath );
|
||||
Mat to = imread( toPath );
|
||||
Mat gt = optflow::readOpticalFlow( gtPath );
|
||||
std::vector< std::pair< Point2i, Point2i > > corr;
|
||||
|
||||
TickMeter meter;
|
||||
meter.start();
|
||||
|
||||
forest->findCorrespondences( from, to, corr, optflow::GPCMatchingParams( useOpenCL ) );
|
||||
|
||||
meter.stop();
|
||||
|
||||
std::cout << "Found " << corr.size() << " matches." << std::endl;
|
||||
std::cout << "Time: " << meter.getTimeSec() << " sec." << std::endl;
|
||||
double error = 0;
|
||||
Mat dispErr = Mat::zeros( from.size(), CV_32FC3 );
|
||||
dispErr = Scalar( 0, 0, 1 );
|
||||
Mat disp = Mat::zeros( from.size(), CV_32FC3 );
|
||||
disp = Scalar( 0, 0, 1 );
|
||||
|
||||
for ( size_t i = 0; i < corr.size(); ++i )
|
||||
{
|
||||
const Point2f a = corr[i].first;
|
||||
const Point2f b = corr[i].second;
|
||||
const Point2f c = a + gt.at< Point2f >( corr[i].first.y, corr[i].first.x );
|
||||
error += normL2( b - c );
|
||||
circle( disp, a, 3, getFlowColor( b - a ), -1 );
|
||||
circle( dispErr, a, 3, getFlowColor( b - c, false, 32 ), -1 );
|
||||
}
|
||||
|
||||
error /= corr.size();
|
||||
|
||||
std::cout << "Average endpoint error: " << error << " px." << std::endl;
|
||||
|
||||
cvtColor( disp, disp, COLOR_HSV2BGR );
|
||||
cvtColor( dispErr, dispErr, COLOR_HSV2BGR );
|
||||
|
||||
Mat dispGroundTruth;
|
||||
displayFlow( gt, dispGroundTruth );
|
||||
|
||||
if ( outPath.length() )
|
||||
{
|
||||
putText( disp, "Sparse matching: Global Patch Collider", Point2i( 24, 40 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
|
||||
char buf[256];
|
||||
sprintf( buf, "Average EPE: %.2f", error );
|
||||
putText( disp, buf, Point2i( 24, 80 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
|
||||
sprintf( buf, "Number of matches: %u", (unsigned)corr.size() );
|
||||
putText( disp, buf, Point2i( 24, 120 ), FONT_HERSHEY_DUPLEX, 1, Vec3b( 1, 0, 0 ), 2, LINE_AA );
|
||||
disp *= 255;
|
||||
imwrite( outPath, disp );
|
||||
return 0;
|
||||
}
|
||||
|
||||
namedWindow( "Correspondences", WINDOW_AUTOSIZE );
|
||||
imshow( "Correspondences", disp );
|
||||
namedWindow( "Error", WINDOW_AUTOSIZE );
|
||||
imshow( "Error", dispErr );
|
||||
namedWindow( "Ground truth", WINDOW_AUTOSIZE );
|
||||
imshow( "Ground truth", dispGroundTruth );
|
||||
waitKey( 0 );
|
||||
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user