mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-23 18:09:25 +08:00
Added im2row, tiny optimiziations
This commit is contained in:
@@ -20,13 +20,14 @@ const String keys =
|
||||
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
|
||||
"{model m || path to Torch .net model file (model_best.net) }"
|
||||
"{image i || path to image file }"
|
||||
"{i_blob | .0 | input blob name) }"
|
||||
"{o_blob || output blob name) }"
|
||||
"{c_names c || path to file with classnames for channels (categories.txt) }"
|
||||
"{c_names c || path to file with classnames for channels (optional, categories.txt) }"
|
||||
"{result r || path to save output blob (optional, binary format, NCHW order) }"
|
||||
"{show s || whether to show all output channels or not}"
|
||||
;
|
||||
|
||||
std::vector<String> readClassNames(const char *filename);
|
||||
static void colorizeSegmentation(Blob &score, Mat &segm,
|
||||
Mat &legend, vector<String> &classNames);
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
@@ -40,8 +41,6 @@ int main(int argc, char **argv)
|
||||
|
||||
String modelFile = parser.get<String>("model");
|
||||
String imageFile = parser.get<String>("image");
|
||||
String inBlobName = parser.get<String>("i_blob");
|
||||
String outBlobName = parser.get<String>("o_blob");
|
||||
|
||||
if (!parser.check())
|
||||
{
|
||||
@@ -78,7 +77,7 @@ int main(int argc, char **argv)
|
||||
//! [Initialize network]
|
||||
|
||||
//! [Prepare blob]
|
||||
Mat img = imread(imageFile);
|
||||
Mat img = imread(imageFile), input;
|
||||
if (img.empty())
|
||||
{
|
||||
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
|
||||
@@ -91,15 +90,15 @@ int main(int argc, char **argv)
|
||||
resize(img, img, inputImgSize); //Resize image to input size
|
||||
|
||||
if(img.channels() == 3)
|
||||
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
|
||||
cv::cvtColor(img, input, cv::COLOR_BGR2RGB);
|
||||
|
||||
img.convertTo(img, CV_32F, 1/255.0);
|
||||
input.convertTo(input, CV_32F, 1/255.0);
|
||||
|
||||
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob image batch
|
||||
dnn::Blob inputBlob = dnn::Blob::fromImages(input); //Convert Mat to dnn::Blob image batch
|
||||
//! [Prepare blob]
|
||||
|
||||
//! [Set input blob]
|
||||
net.setBlob(inBlobName, inputBlob); //set the network input
|
||||
net.setBlob("", inputBlob); //set the network input
|
||||
//! [Set input blob]
|
||||
|
||||
cv::TickMeter tm;
|
||||
@@ -112,7 +111,8 @@ int main(int argc, char **argv)
|
||||
tm.stop();
|
||||
|
||||
//! [Gather output]
|
||||
dnn::Blob prob = net.getBlob(outBlobName); //gather output of "prob" layer
|
||||
|
||||
dnn::Blob prob = net.getBlob(net.getLayerNames().back()); //gather output of "prob" layer
|
||||
|
||||
Mat& result = prob.matRef();
|
||||
|
||||
@@ -129,25 +129,27 @@ int main(int argc, char **argv)
|
||||
std::cout << "Output blob shape " << shape << std::endl;
|
||||
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
|
||||
|
||||
std::vector<String> classNames;
|
||||
if(!classNamesFile.empty()) {
|
||||
classNames = readClassNames(classNamesFile.c_str());
|
||||
if (classNames.size() > prob.channels())
|
||||
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
|
||||
classNames.end());
|
||||
if (parser.has("show"))
|
||||
{
|
||||
std::vector<String> classNames;
|
||||
if(!classNamesFile.empty()) {
|
||||
classNames = readClassNames(classNamesFile.c_str());
|
||||
if (classNames.size() > prob.channels())
|
||||
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
|
||||
classNames.end());
|
||||
}
|
||||
Mat segm, legend;
|
||||
colorizeSegmentation(prob, segm, legend, classNames);
|
||||
|
||||
Mat show;
|
||||
addWeighted(img, 0.2, segm, 0.8, 0.0, show);
|
||||
|
||||
imshow("Result", show);
|
||||
if(classNames.size())
|
||||
imshow("Legend", legend);
|
||||
waitKey();
|
||||
}
|
||||
|
||||
for(int i_c = 0; i_c < prob.channels(); i_c++) {
|
||||
ostringstream convert;
|
||||
convert << "Channel #" << i_c;
|
||||
|
||||
if(classNames.size() == prob.channels())
|
||||
convert << ": " << classNames[i_c];
|
||||
|
||||
imshow(convert.str().c_str(), prob.getPlane(0, i_c));
|
||||
}
|
||||
waitKey();
|
||||
|
||||
return 0;
|
||||
} //main
|
||||
|
||||
@@ -174,3 +176,57 @@ std::vector<String> readClassNames(const char *filename)
|
||||
fp.close();
|
||||
return classNames;
|
||||
}
|
||||
|
||||
static void colorizeSegmentation(Blob &score, Mat &segm, Mat &legend, vector<String> &classNames)
|
||||
{
|
||||
const int rows = score.rows();
|
||||
const int cols = score.cols();
|
||||
const int chns = score.channels();
|
||||
|
||||
vector<Vec3i> colors;
|
||||
RNG rng(12345678);
|
||||
|
||||
cv::Mat maxCl(rows, cols, CV_8UC1);
|
||||
cv::Mat maxVal(rows, cols, CV_32FC1);
|
||||
for (int ch = 0; ch < chns; ch++)
|
||||
{
|
||||
colors.push_back(Vec3i(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)));
|
||||
for (int row = 0; row < rows; row++)
|
||||
{
|
||||
const float *ptrScore = score.ptrf(0, ch, row);
|
||||
uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
|
||||
float *ptrMaxVal = maxVal.ptr<float>(row);
|
||||
for (int col = 0; col < cols; col++)
|
||||
{
|
||||
if (ptrScore[col] > ptrMaxVal[col])
|
||||
{
|
||||
ptrMaxVal[col] = ptrScore[col];
|
||||
ptrMaxCl[col] = ch;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segm.create(rows, cols, CV_8UC3);
|
||||
for (int row = 0; row < rows; row++)
|
||||
{
|
||||
const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
|
||||
cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
|
||||
for (int col = 0; col < cols; col++)
|
||||
{
|
||||
ptrSegm[col] = colors[ptrMaxCl[col]];
|
||||
}
|
||||
}
|
||||
|
||||
if (classNames.size() == colors.size())
|
||||
{
|
||||
int blockHeight = 30;
|
||||
legend.create(blockHeight*classNames.size(), 200, CV_8UC3);
|
||||
for(int i = 0; i < classNames.size(); i++)
|
||||
{
|
||||
cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight);
|
||||
block = colors[i];
|
||||
putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user