mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-24 20:01:12 +08:00
Added im2row, tiny optimiziations
This commit is contained in:
@@ -20,13 +20,14 @@ const String keys =
|
|||||||
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
|
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa }"
|
||||||
"{model m || path to Torch .net model file (model_best.net) }"
|
"{model m || path to Torch .net model file (model_best.net) }"
|
||||||
"{image i || path to image file }"
|
"{image i || path to image file }"
|
||||||
"{i_blob | .0 | input blob name) }"
|
"{c_names c || path to file with classnames for channels (optional, categories.txt) }"
|
||||||
"{o_blob || output blob name) }"
|
|
||||||
"{c_names c || path to file with classnames for channels (categories.txt) }"
|
|
||||||
"{result r || path to save output blob (optional, binary format, NCHW order) }"
|
"{result r || path to save output blob (optional, binary format, NCHW order) }"
|
||||||
|
"{show s || whether to show all output channels or not}"
|
||||||
;
|
;
|
||||||
|
|
||||||
std::vector<String> readClassNames(const char *filename);
|
std::vector<String> readClassNames(const char *filename);
|
||||||
|
static void colorizeSegmentation(Blob &score, Mat &segm,
|
||||||
|
Mat &legend, vector<String> &classNames);
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
@@ -40,8 +41,6 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
String modelFile = parser.get<String>("model");
|
String modelFile = parser.get<String>("model");
|
||||||
String imageFile = parser.get<String>("image");
|
String imageFile = parser.get<String>("image");
|
||||||
String inBlobName = parser.get<String>("i_blob");
|
|
||||||
String outBlobName = parser.get<String>("o_blob");
|
|
||||||
|
|
||||||
if (!parser.check())
|
if (!parser.check())
|
||||||
{
|
{
|
||||||
@@ -78,7 +77,7 @@ int main(int argc, char **argv)
|
|||||||
//! [Initialize network]
|
//! [Initialize network]
|
||||||
|
|
||||||
//! [Prepare blob]
|
//! [Prepare blob]
|
||||||
Mat img = imread(imageFile);
|
Mat img = imread(imageFile), input;
|
||||||
if (img.empty())
|
if (img.empty())
|
||||||
{
|
{
|
||||||
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
|
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
|
||||||
@@ -91,15 +90,15 @@ int main(int argc, char **argv)
|
|||||||
resize(img, img, inputImgSize); //Resize image to input size
|
resize(img, img, inputImgSize); //Resize image to input size
|
||||||
|
|
||||||
if(img.channels() == 3)
|
if(img.channels() == 3)
|
||||||
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
|
cv::cvtColor(img, input, cv::COLOR_BGR2RGB);
|
||||||
|
|
||||||
img.convertTo(img, CV_32F, 1/255.0);
|
input.convertTo(input, CV_32F, 1/255.0);
|
||||||
|
|
||||||
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob image batch
|
dnn::Blob inputBlob = dnn::Blob::fromImages(input); //Convert Mat to dnn::Blob image batch
|
||||||
//! [Prepare blob]
|
//! [Prepare blob]
|
||||||
|
|
||||||
//! [Set input blob]
|
//! [Set input blob]
|
||||||
net.setBlob(inBlobName, inputBlob); //set the network input
|
net.setBlob("", inputBlob); //set the network input
|
||||||
//! [Set input blob]
|
//! [Set input blob]
|
||||||
|
|
||||||
cv::TickMeter tm;
|
cv::TickMeter tm;
|
||||||
@@ -112,7 +111,8 @@ int main(int argc, char **argv)
|
|||||||
tm.stop();
|
tm.stop();
|
||||||
|
|
||||||
//! [Gather output]
|
//! [Gather output]
|
||||||
dnn::Blob prob = net.getBlob(outBlobName); //gather output of "prob" layer
|
|
||||||
|
dnn::Blob prob = net.getBlob(net.getLayerNames().back()); //gather output of "prob" layer
|
||||||
|
|
||||||
Mat& result = prob.matRef();
|
Mat& result = prob.matRef();
|
||||||
|
|
||||||
@@ -129,25 +129,27 @@ int main(int argc, char **argv)
|
|||||||
std::cout << "Output blob shape " << shape << std::endl;
|
std::cout << "Output blob shape " << shape << std::endl;
|
||||||
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
|
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
|
||||||
|
|
||||||
std::vector<String> classNames;
|
if (parser.has("show"))
|
||||||
if(!classNamesFile.empty()) {
|
{
|
||||||
classNames = readClassNames(classNamesFile.c_str());
|
std::vector<String> classNames;
|
||||||
if (classNames.size() > prob.channels())
|
if(!classNamesFile.empty()) {
|
||||||
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
|
classNames = readClassNames(classNamesFile.c_str());
|
||||||
classNames.end());
|
if (classNames.size() > prob.channels())
|
||||||
|
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
|
||||||
|
classNames.end());
|
||||||
|
}
|
||||||
|
Mat segm, legend;
|
||||||
|
colorizeSegmentation(prob, segm, legend, classNames);
|
||||||
|
|
||||||
|
Mat show;
|
||||||
|
addWeighted(img, 0.2, segm, 0.8, 0.0, show);
|
||||||
|
|
||||||
|
imshow("Result", show);
|
||||||
|
if(classNames.size())
|
||||||
|
imshow("Legend", legend);
|
||||||
|
waitKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i_c = 0; i_c < prob.channels(); i_c++) {
|
|
||||||
ostringstream convert;
|
|
||||||
convert << "Channel #" << i_c;
|
|
||||||
|
|
||||||
if(classNames.size() == prob.channels())
|
|
||||||
convert << ": " << classNames[i_c];
|
|
||||||
|
|
||||||
imshow(convert.str().c_str(), prob.getPlane(0, i_c));
|
|
||||||
}
|
|
||||||
waitKey();
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
} //main
|
} //main
|
||||||
|
|
||||||
@@ -174,3 +176,57 @@ std::vector<String> readClassNames(const char *filename)
|
|||||||
fp.close();
|
fp.close();
|
||||||
return classNames;
|
return classNames;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void colorizeSegmentation(Blob &score, Mat &segm, Mat &legend, vector<String> &classNames)
|
||||||
|
{
|
||||||
|
const int rows = score.rows();
|
||||||
|
const int cols = score.cols();
|
||||||
|
const int chns = score.channels();
|
||||||
|
|
||||||
|
vector<Vec3i> colors;
|
||||||
|
RNG rng(12345678);
|
||||||
|
|
||||||
|
cv::Mat maxCl(rows, cols, CV_8UC1);
|
||||||
|
cv::Mat maxVal(rows, cols, CV_32FC1);
|
||||||
|
for (int ch = 0; ch < chns; ch++)
|
||||||
|
{
|
||||||
|
colors.push_back(Vec3i(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)));
|
||||||
|
for (int row = 0; row < rows; row++)
|
||||||
|
{
|
||||||
|
const float *ptrScore = score.ptrf(0, ch, row);
|
||||||
|
uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
|
||||||
|
float *ptrMaxVal = maxVal.ptr<float>(row);
|
||||||
|
for (int col = 0; col < cols; col++)
|
||||||
|
{
|
||||||
|
if (ptrScore[col] > ptrMaxVal[col])
|
||||||
|
{
|
||||||
|
ptrMaxVal[col] = ptrScore[col];
|
||||||
|
ptrMaxCl[col] = ch;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
segm.create(rows, cols, CV_8UC3);
|
||||||
|
for (int row = 0; row < rows; row++)
|
||||||
|
{
|
||||||
|
const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
|
||||||
|
cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
|
||||||
|
for (int col = 0; col < cols; col++)
|
||||||
|
{
|
||||||
|
ptrSegm[col] = colors[ptrMaxCl[col]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (classNames.size() == colors.size())
|
||||||
|
{
|
||||||
|
int blockHeight = 30;
|
||||||
|
legend.create(blockHeight*classNames.size(), 200, CV_8UC3);
|
||||||
|
for(int i = 0; i < classNames.size(); i++)
|
||||||
|
{
|
||||||
|
cv::Mat block = legend.rowRange(i*blockHeight, (i+1)*blockHeight);
|
||||||
|
block = colors[i];
|
||||||
|
putText(block, classNames[i], Point(0, blockHeight/2), FONT_HERSHEY_SIMPLEX, 0.5, Scalar());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -58,8 +58,7 @@ BaseConvolutionLayerImpl::BaseConvolutionLayerImpl():
|
|||||||
inpH(0), inpW(0), inpCn(0),
|
inpH(0), inpW(0), inpCn(0),
|
||||||
outH(0), outW(0), outCn(0),
|
outH(0), outW(0), outCn(0),
|
||||||
inpGroupCn(0), outGroupCn(0),
|
inpGroupCn(0), outGroupCn(0),
|
||||||
ksize(0), colBlobCols(0),
|
ksize(0), bias(false), tryUseOpenCL(false)
|
||||||
bias(false), tryUseOpenCL(false)
|
|
||||||
{
|
{
|
||||||
#if HAVE_CBLAS
|
#if HAVE_CBLAS
|
||||||
if (getBlasThreads() != cv::getThreadNum())
|
if (getBlasThreads() != cv::getThreadNum())
|
||||||
@@ -111,7 +110,7 @@ void BaseConvolutionLayerImpl::allocate(const std::vector<Blob*> &inputs, std::v
|
|||||||
|
|
||||||
if (!is1x1())
|
if (!is1x1())
|
||||||
{
|
{
|
||||||
colBlob.create(Shape(ksize, colBlobCols), input.type(), allocFlags);
|
colRowBlob.create(colRowBlobShape, input.type(), allocFlags);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,7 +151,7 @@ void ConvolutionLayerImpl::computeInpOutShape(const Blob &input)
|
|||||||
inpGroupCn = inpCn / group;
|
inpGroupCn = inpCn / group;
|
||||||
ksize = inpGroupCn * kernel.height * kernel.width;
|
ksize = inpGroupCn * kernel.height * kernel.width;
|
||||||
|
|
||||||
colBlobCols = outH * outW;
|
colRowBlobShape = BlobShape(outH * outW, ksize);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename XMat>
|
template<typename XMat>
|
||||||
@@ -174,7 +173,8 @@ void ConvolutionLayerImpl::forward_(std::vector<Blob*> &inputs, std::vector<Blob
|
|||||||
for (int g = 0; g < group; g++)
|
for (int g = 0; g < group; g++)
|
||||||
{
|
{
|
||||||
XMat colMat, curInp = slice(inpMat, n, _Range(g * inpGroupCn, inpGroupCn));
|
XMat colMat, curInp = slice(inpMat, n, _Range(g * inpGroupCn, inpGroupCn));
|
||||||
im2col(curInp, colMat);
|
|
||||||
|
im2row(curInp, colMat);
|
||||||
|
|
||||||
_Range kerRange(g * outGroupCn, outGroupCn);
|
_Range kerRange(g * outGroupCn, outGroupCn);
|
||||||
XMat kerMat = weightsMat.rowRange(kerRange);
|
XMat kerMat = weightsMat.rowRange(kerRange);
|
||||||
@@ -182,7 +182,7 @@ void ConvolutionLayerImpl::forward_(std::vector<Blob*> &inputs, std::vector<Blob
|
|||||||
_Range outRange((g + n * group) * outGroupCn, outGroupCn);
|
_Range outRange((g + n * group) * outGroupCn, outGroupCn);
|
||||||
XMat dstMat = outMat.rowRange(outRange);
|
XMat dstMat = outMat.rowRange(outRange);
|
||||||
|
|
||||||
dnn::gemm(kerMat, colMat, 1, dstMat, 0);
|
dnn::gemm(kerMat, colMat, 1, dstMat, 0, GEMM_2_T);
|
||||||
|
|
||||||
if (bias)
|
if (bias)
|
||||||
{
|
{
|
||||||
@@ -209,8 +209,8 @@ void ConvolutionLayerImpl::im2col(const UMat &srcImg, UMat &dstCol)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#ifdef HAVE_OPENCL
|
#ifdef HAVE_OPENCL
|
||||||
CV_Assert(im2col_ocl(srcImg, inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, dilation.height, dilation.width, this->colBlob.umatRef()));
|
CV_Assert(im2col_ocl(srcImg, inpGroupCn, inpH, inpW, kernel.height, kernel.width, pad.height, pad.width, stride.height, stride.width, dilation.height, dilation.width, this->colRowBlob.umatRef()));
|
||||||
dstCol = this->colBlob.umatRefConst();
|
dstCol = this->colRowBlob.umatRefConst();
|
||||||
#else
|
#else
|
||||||
CV_Error(Error::StsInternal, "");
|
CV_Error(Error::StsInternal, "");
|
||||||
dstCol = srcImg; //supress warning
|
dstCol = srcImg; //supress warning
|
||||||
@@ -225,7 +225,7 @@ void ConvolutionLayerImpl::im2col(const Mat &srcImg, Mat &dstCol)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat &colMat = colBlob.matRef();
|
Mat &colMat = colRowBlob.matRef();
|
||||||
if (srcImg.type() == CV_32F)
|
if (srcImg.type() == CV_32F)
|
||||||
im2col_CpuPBody<float>::run(srcImg.ptr<float>(), inpGroupCn, inpH, inpW, kernel.height,
|
im2col_CpuPBody<float>::run(srcImg.ptr<float>(), inpGroupCn, inpH, inpW, kernel.height,
|
||||||
kernel.width, pad.height, pad.width, stride.height, stride.width,
|
kernel.width, pad.height, pad.width, stride.height, stride.width,
|
||||||
@@ -238,6 +238,32 @@ void ConvolutionLayerImpl::im2col(const Mat &srcImg, Mat &dstCol)
|
|||||||
dstCol = colMat;
|
dstCol = colMat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConvolutionLayerImpl::im2row(const Mat &srcImg, Mat &dstRow)
|
||||||
|
{
|
||||||
|
if (is1x1())
|
||||||
|
{
|
||||||
|
dstRow = reshaped(srcImg, Shape(ksize, outH*outW)).t();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat &colMat = colRowBlob.matRef();
|
||||||
|
if (srcImg.type() == CV_32F)
|
||||||
|
im2row_CpuPBody<float>::run(srcImg.ptr<float>(), inpGroupCn, inpH, inpW, kernel.height,
|
||||||
|
kernel.width, pad.height, pad.width, stride.height, stride.width,
|
||||||
|
dilation.height, dilation.width, outW, outH, colMat.ptr<float>());
|
||||||
|
if (srcImg.type() == CV_64F)
|
||||||
|
im2row_CpuPBody<double>::run(srcImg.ptr<double>(), inpGroupCn, inpH, inpW, kernel.height,
|
||||||
|
kernel.width, pad.height, pad.width, stride.height, stride.width,
|
||||||
|
dilation.height, dilation.width, outW, outH, colMat.ptr<double>());
|
||||||
|
|
||||||
|
dstRow = colMat;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ConvolutionLayerImpl::im2row(const UMat &srcImg, UMat &dstCol)
|
||||||
|
{
|
||||||
|
CV_Error(cv::Error::StsNotImplemented, "");
|
||||||
|
}
|
||||||
|
|
||||||
//Deconvolution
|
//Deconvolution
|
||||||
|
|
||||||
void DeConvolutionLayerImpl::computeInpOutShape(const Blob &inpBlob)
|
void DeConvolutionLayerImpl::computeInpOutShape(const Blob &inpBlob)
|
||||||
@@ -264,7 +290,7 @@ void DeConvolutionLayerImpl::computeInpOutShape(const Blob &inpBlob)
|
|||||||
CV_Assert(inpCn % group == 0 && outCn % group == 0);
|
CV_Assert(inpCn % group == 0 && outCn % group == 0);
|
||||||
CV_Assert(blobs[0].channels() == outCn && blobs[0].num() == inpCn / group);
|
CV_Assert(blobs[0].channels() == outCn && blobs[0].num() == inpCn / group);
|
||||||
|
|
||||||
colBlobCols = inpH * inpW;
|
colRowBlobShape = BlobShape(ksize, inpH * inpW);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeConvolutionLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
|
void DeConvolutionLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
|
||||||
@@ -292,7 +318,7 @@ void DeConvolutionLayerImpl::forward_(std::vector<Blob *> &inputs, std::vector<B
|
|||||||
for (int g = 0; g < group; g++)
|
for (int g = 0; g < group; g++)
|
||||||
{
|
{
|
||||||
XMat dstMat = decnBlob.rowRange(_Range((g + n * group) * outGroupCn, outGroupCn));
|
XMat dstMat = decnBlob.rowRange(_Range((g + n * group) * outGroupCn, outGroupCn));
|
||||||
XMat &colMat = (is1x1()) ? dstMat : colBlob.getRef<XMat>();
|
XMat &colMat = (is1x1()) ? dstMat : colRowBlob.getRef<XMat>();
|
||||||
|
|
||||||
XMat convMat = convBlob.rowRange(_Range((g + n * group) * inpGroupCn, inpGroupCn));
|
XMat convMat = convBlob.rowRange(_Range((g + n * group) * inpGroupCn, inpGroupCn));
|
||||||
XMat wghtMat = weightsMat.rowRange(_Range(g * inpGroupCn, inpGroupCn));
|
XMat wghtMat = weightsMat.rowRange(_Range(g * inpGroupCn, inpGroupCn));
|
||||||
|
@@ -65,12 +65,12 @@ protected:
|
|||||||
int outH, outW, outCn;
|
int outH, outW, outCn;
|
||||||
int inpGroupCn, outGroupCn;
|
int inpGroupCn, outGroupCn;
|
||||||
int ksize;
|
int ksize;
|
||||||
int colBlobCols;
|
BlobShape colRowBlobShape;
|
||||||
|
|
||||||
bool bias;
|
bool bias;
|
||||||
bool tryUseOpenCL, useOpenCL;
|
bool tryUseOpenCL, useOpenCL;
|
||||||
|
|
||||||
Blob colBlob, biasOnesBlob;
|
Blob colRowBlob, biasOnesBlob;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -86,7 +86,9 @@ protected:
|
|||||||
template<typename XMat>
|
template<typename XMat>
|
||||||
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
|
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
|
||||||
void im2col(const Mat &srcImg, Mat &dstCol);
|
void im2col(const Mat &srcImg, Mat &dstCol);
|
||||||
|
void im2row(const Mat &srcImg, Mat &dstRow);
|
||||||
void im2col(const UMat &srcImg, UMat &dstCol);
|
void im2col(const UMat &srcImg, UMat &dstCol);
|
||||||
|
void im2row(const UMat &srcImg, UMat &dstCol);
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeConvolutionLayerImpl : public BaseConvolutionLayerImpl
|
class DeConvolutionLayerImpl : public BaseConvolutionLayerImpl
|
||||||
|
@@ -287,7 +287,9 @@ struct PowerFunctor
|
|||||||
{
|
{
|
||||||
typedef PowerLayer Layer;
|
typedef PowerLayer Layer;
|
||||||
|
|
||||||
double power, scale, shift;
|
const double power;
|
||||||
|
const double scale;
|
||||||
|
const double shift;
|
||||||
|
|
||||||
PowerFunctor(double power_, double scale_ = 1, double shift_ = 0)
|
PowerFunctor(double power_, double scale_ = 1, double shift_ = 0)
|
||||||
: power(power_), scale(scale_), shift(shift_) {}
|
: power(power_), scale(scale_), shift(shift_) {}
|
||||||
@@ -295,7 +297,7 @@ struct PowerFunctor
|
|||||||
template<typename TFloat>
|
template<typename TFloat>
|
||||||
inline TFloat operator()(TFloat x) const
|
inline TFloat operator()(TFloat x) const
|
||||||
{
|
{
|
||||||
return pow((TFloat)shift + (TFloat)scale * x, (TFloat)power);
|
return power == 1.0 ? (TFloat)shift + (TFloat)scale * x : pow((TFloat)shift + (TFloat)scale * x, (TFloat)power);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HAVE_OPENCL
|
#ifdef HAVE_OPENCL
|
||||||
|
@@ -114,6 +114,92 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
class im2row_CpuPBody : public cv::ParallelLoopBody
|
||||||
|
{
|
||||||
|
const Dtype* data_im;
|
||||||
|
int channels, height, width;
|
||||||
|
int kernel_h, kernel_w;
|
||||||
|
int pad_h, pad_w;
|
||||||
|
int stride_h, stride_w;
|
||||||
|
int dilation_h, dilation_w;
|
||||||
|
Dtype* data_col;
|
||||||
|
int height_col, width_col, channels_col;
|
||||||
|
|
||||||
|
im2row_CpuPBody() {}
|
||||||
|
public:
|
||||||
|
|
||||||
|
static void run(const Dtype* data_im,
|
||||||
|
int channels, int height, int width,
|
||||||
|
int kernel_h, int kernel_w,
|
||||||
|
int pad_h, int pad_w,
|
||||||
|
int stride_h, int stride_w,
|
||||||
|
int dilation_h, int dilation_w,
|
||||||
|
int height_col, int width_col,
|
||||||
|
Dtype* data_col)
|
||||||
|
{
|
||||||
|
im2row_CpuPBody<Dtype> t;
|
||||||
|
|
||||||
|
t.data_im = data_im;
|
||||||
|
t.data_col = data_col;
|
||||||
|
t.channels = channels; t.height = height; t.width = width;
|
||||||
|
t.kernel_h = kernel_h; t.kernel_w = kernel_w;
|
||||||
|
t.pad_h = pad_h; t.pad_w = pad_w;
|
||||||
|
t.stride_h = stride_h; t.stride_w = stride_w;
|
||||||
|
t.dilation_h = dilation_h; t.dilation_w = dilation_w;
|
||||||
|
|
||||||
|
t.height_col = height_col;
|
||||||
|
t.width_col = width_col;
|
||||||
|
t.channels_col = channels * kernel_h * kernel_w;
|
||||||
|
|
||||||
|
cv::parallel_for_(Range(0, t.height_col*t.width_col), t, 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void operator ()(const Range &r) const
|
||||||
|
{
|
||||||
|
int dh = dilation_h, dw = dilation_w;
|
||||||
|
Dtype* data_col_ = data_col;
|
||||||
|
const Dtype* data_im_ = data_im;
|
||||||
|
|
||||||
|
for (int row = r.start; row < r.end; ++row)
|
||||||
|
{
|
||||||
|
int out_c = row % width_col;
|
||||||
|
int out_r = row / width_col;
|
||||||
|
int out_row_offset = row*kernel_h*kernel_w*channels;
|
||||||
|
|
||||||
|
int start_in_r = out_r * stride_h - pad_h;
|
||||||
|
int start_in_c = out_c * stride_w - pad_w;
|
||||||
|
int start_k_r = std::max(0, cvCeil(-start_in_r/(float)dilation_h));
|
||||||
|
int end_k_r = std::min(kernel_h, cvCeil((height - start_in_r)/(float)dilation_h));
|
||||||
|
int start_k_c = std::max(0, cvCeil(-start_in_c/(float)dilation_w));
|
||||||
|
int end_k_c = std::min(kernel_w, cvCeil((width - start_in_c)/(float)dilation_w));
|
||||||
|
|
||||||
|
for(int i_c = 0; i_c < channels; i_c++)
|
||||||
|
{
|
||||||
|
int channels_offset = i_c * width * height;
|
||||||
|
int out_ch_offset = i_c*kernel_h*kernel_w;
|
||||||
|
int in_r = start_in_r + start_k_r*dilation_h;
|
||||||
|
|
||||||
|
for(int k_r = start_k_r; k_r < end_k_r; k_r++, in_r += dh)
|
||||||
|
{
|
||||||
|
int row_offset = in_r*width;
|
||||||
|
int out_col_offset = k_r*kernel_w;
|
||||||
|
int in_c = start_in_c + start_k_c*dilation_w;
|
||||||
|
|
||||||
|
for(int k_c = start_k_c; k_c < end_k_c; k_c++, in_c += dw)
|
||||||
|
{
|
||||||
|
int in_index = channels_offset + row_offset + in_c;
|
||||||
|
|
||||||
|
int out_index = out_row_offset + out_ch_offset + out_col_offset + k_c;
|
||||||
|
|
||||||
|
data_col_[out_index] = data_im_[in_index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
class col2im_CpuPBody : public cv::ParallelLoopBody
|
class col2im_CpuPBody : public cv::ParallelLoopBody
|
||||||
{
|
{
|
||||||
@@ -154,6 +240,10 @@ public:
|
|||||||
|
|
||||||
virtual void operator ()(const Range &r) const
|
virtual void operator ()(const Range &r) const
|
||||||
{
|
{
|
||||||
|
const Dtype* data_col_ = data_col;
|
||||||
|
Dtype* data_im_ = data_im;
|
||||||
|
int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
|
||||||
|
int coeff_w_col = (1 - stride_w * height_col * width_col);
|
||||||
for (int index = r.start; index < r.end; index++)
|
for (int index = r.start; index < r.end; index++)
|
||||||
{
|
{
|
||||||
Dtype val = 0;
|
Dtype val = 0;
|
||||||
@@ -170,14 +260,13 @@ public:
|
|||||||
// equivalent implementation
|
// equivalent implementation
|
||||||
int offset =
|
int offset =
|
||||||
(c * kernel_h * kernel_w + h * kernel_w + w) * height_col * width_col;
|
(c * kernel_h * kernel_w + h * kernel_w + w) * height_col * width_col;
|
||||||
int coeff_h_col = (1 - stride_h * kernel_w * height_col) * width_col;
|
|
||||||
int coeff_w_col = (1 - stride_w * height_col * width_col);
|
|
||||||
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
|
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
|
||||||
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
|
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
|
||||||
val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
|
val += data_col_[offset + h_col * coeff_h_col + w_col * coeff_w_col];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data_im[index] = val;
|
data_im_[index] = val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -197,7 +197,7 @@ struct TorchImporter : public ::cv::dnn::Importer
|
|||||||
|
|
||||||
if (typeStr == "Double")
|
if (typeStr == "Double")
|
||||||
return CV_64F;
|
return CV_64F;
|
||||||
else if (typeStr == "Float")
|
else if (typeStr == "Float" || typeStr == "Cuda")
|
||||||
return CV_32F;
|
return CV_32F;
|
||||||
else if (typeStr == "Byte")
|
else if (typeStr == "Byte")
|
||||||
return CV_8U;
|
return CV_8U;
|
||||||
|
Reference in New Issue
Block a user