Add perf tests for the DNN module.
modules/dnn/cmake/FindMKL.cmake (new file, 101 lines)
@@ -0,0 +1,101 @@
# - Find the MKL libraries
# Modified from Armadillo's ARMA_FindMKL.cmake
# This module defines
#  MKL_INCLUDE_DIR, the directory for the MKL headers
#  MKL_LIB_DIR, the directory for the MKL library files
#  MKL_COMPILER_LIB_DIR, the directory for the MKL compiler library files
#  MKL_LIBRARIES, the libraries needed to use Intel's implementation of BLAS & LAPACK.
#  MKL_FOUND, if false, do not try to use MKL; if true, the macro definition USE_MKL is added.

# Set the include path
# TODO: what if MKL is not installed in /opt/intel/mkl?
# On Linux, try to find MKL at /opt/intel/mkl;
# on Windows, try to find MKL at C:/Program Files (x86)/Intel/Composer XE/mkl

if ( WIN32 )
  if(NOT DEFINED ENV{MKLROOT_PATH})
    #set(MKLROOT_PATH "C:/Program Files (x86)/Intel/Composer XE" CACHE PATH "Where the MKL are stored")
    set(MKLROOT_PATH "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows" CACHE PATH "Where the MKL are stored")
  endif(NOT DEFINED ENV{MKLROOT_PATH})
else ( WIN32 )
  set(MKLROOT_PATH "/opt/intel" CACHE PATH "Where the MKL are stored")
endif ( WIN32 )

if (EXISTS ${MKLROOT_PATH}/mkl)
  SET(MKL_FOUND TRUE)
  message("MKL is found at ${MKLROOT_PATH}/mkl")
  IF(CMAKE_SIZEOF_VOID_P EQUAL 8)
    set( USE_MKL_64BIT On )
    if ( ARMADILLO_FOUND )
      if ( ARMADILLO_BLAS_LONG_LONG )
        set( USE_MKL_64BIT_LIB On )
        ADD_DEFINITIONS(-DMKL_ILP64)
        message("MKL is linked against ILP64 interface ... ")
      endif ( ARMADILLO_BLAS_LONG_LONG )
    endif ( ARMADILLO_FOUND )
  ELSE(CMAKE_SIZEOF_VOID_P EQUAL 8)
    set( USE_MKL_64BIT Off )
  ENDIF(CMAKE_SIZEOF_VOID_P EQUAL 8)
else (EXISTS ${MKLROOT_PATH}/mkl)
  SET(MKL_FOUND FALSE)
  message("MKL is NOT found ... ")
endif (EXISTS ${MKLROOT_PATH}/mkl)

if (MKL_FOUND)
  set(MKL_INCLUDE_DIR "${MKLROOT_PATH}/mkl/include")
  ADD_DEFINITIONS(-DUSE_MKL)
  if ( USE_MKL_64BIT )
    set(MKL_LIB_DIR "${MKLROOT_PATH}/mkl/lib/intel64")
    set(MKL_COMPILER_LIB_DIR "${MKLROOT_PATH}/compiler/lib/intel64")
    set(MKL_COMPILER_LIB_DIR ${MKL_COMPILER_LIB_DIR} "${MKLROOT_PATH}/lib/intel64")
    if ( USE_MKL_64BIT_LIB )
      if (WIN32)
        set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_ilp64)
      else (WIN32)
        set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_ilp64)
      endif (WIN32)
    else ( USE_MKL_64BIT_LIB )
      if (WIN32)
        set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_lp64)
      else (WIN32)
        set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_lp64)
      endif (WIN32)
    endif ( USE_MKL_64BIT_LIB )
  else ( USE_MKL_64BIT )
    set(MKL_LIB_DIR "${MKLROOT_PATH}/mkl/lib/ia32")
    set(MKL_COMPILER_LIB_DIR "${MKLROOT_PATH}/compiler/lib/ia32")
    set(MKL_COMPILER_LIB_DIR ${MKL_COMPILER_LIB_DIR} "${MKLROOT_PATH}/lib/ia32")
    if ( WIN32 )
      set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_c)
    else ( WIN32 )
      set(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel)
    endif ( WIN32 )
  endif ( USE_MKL_64BIT )

  if (WIN32)
    SET(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_intel_thread)
    SET(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_core)
    SET(MKL_LIBRARIES ${MKL_LIBRARIES} libiomp5md)
  else (WIN32)
    SET(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_gnu_thread)
    SET(MKL_LIBRARIES ${MKL_LIBRARIES} mkl_core)
  endif (WIN32)
endif (MKL_FOUND)

IF (MKL_FOUND)
  IF (NOT MKL_FIND_QUIETLY)
    MESSAGE(STATUS "Found MKL libraries: ${MKL_LIBRARIES}")
    MESSAGE(STATUS "MKL_INCLUDE_DIR: ${MKL_INCLUDE_DIR}")
    MESSAGE(STATUS "MKL_LIB_DIR: ${MKL_LIB_DIR}")
    MESSAGE(STATUS "MKL_COMPILER_LIB_DIR: ${MKL_COMPILER_LIB_DIR}")
  ENDIF (NOT MKL_FIND_QUIETLY)

  INCLUDE_DIRECTORIES( ${MKL_INCLUDE_DIR} )
  LINK_DIRECTORIES( ${MKL_LIB_DIR} ${MKL_COMPILER_LIB_DIR} )
ELSE (MKL_FOUND)
  IF (MKL_FIND_REQUIRED)
    MESSAGE(FATAL_ERROR "Could not find MKL libraries")
  ENDIF (MKL_FIND_REQUIRED)
ENDIF (MKL_FOUND)

# MARK_AS_ADVANCED(MKL_LIBRARY)
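Downstream C++ code can branch on the USE_MKL and MKL_ILP64 definitions this module adds. A minimal sketch of a consuming translation unit (hypothetical, not part of the commit):

// Hypothetical consumer: FindMKL.cmake adds -DUSE_MKL when MKL is found
// and -DMKL_ILP64 when the 64-bit-integer (ILP64) BLAS interface is selected.
#ifdef USE_MKL
#  include <mkl.h>             // header resolved via MKL_INCLUDE_DIR
#  ifdef MKL_ILP64
typedef long long blas_int;    // ILP64 interface: 64-bit matrix indices
#  else
typedef int blas_int;          // LP64 interface: 32-bit matrix indices
#  endif
#endif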
@@ -116,7 +116,7 @@ namespace dnn //! This namespace is used for dnn module functionality.
     String type; //!< Type name which was used for creating layer by layer factory.

     Layer();
-    explicit Layer(const LayerParams &params); //!< Initialize only #name, #type and #blobs fields.
+    explicit Layer(const LayerParams &params); //!< Initializes only #name, #type and #blobs fields.
     virtual ~Layer();
 };

modules/dnn/perf/perf_convolution.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include "perf_precomp.hpp"

namespace cvtest
{

using std::tr1::tuple;
using std::tr1::get;
using std::tr1::make_tuple;
using std::make_pair;
using namespace perf;
using namespace testing;
using namespace cv;
using namespace cv::dnn;

enum {STRIDE_OFF = 1, STRIDE_ON = 2};
CV_ENUM(StrideSize, STRIDE_OFF, STRIDE_ON);

enum {GROUP_OFF = 1, GROUP_2 = 2};
CV_ENUM(GroupSize, GROUP_OFF, GROUP_2);

//Squared Size
#define SSZ(n) cv::Size(n, n)

typedef std::pair<BlobShape, int> InpShapeNumOut;
typedef tuple<Size, InpShapeNumOut, GroupSize, StrideSize> ConvParam; //kernel_size, inp shape, groups, stride
typedef TestBaseWithParam<ConvParam> ConvolutionPerfTest;

PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
    Values(Size(1, 1), Size(3, 3), Size(5, 5), Size(11, 11)),
    Values(make_pair(BlobShape(1, 4, 224, 224), 64),
           make_pair(BlobShape(1, 64, 112, 122), 128),
           make_pair(BlobShape(1, 256, 28, 28), 512)),
    GroupSize::all(),
    StrideSize::all())
)
{
    RNG rng(0);

    ConvParam params = GetParam();
    int ksz = get<0>(params).width;
    BlobShape inpShape = get<1>(params).first;
    int outCn = get<1>(params).second;
    int groups = get<2>(params);
    int stride = (ksz >= 11) ? 4 : get<3>(params);

    int inpCn = inpShape[1];
    Blob wgtBlob(BlobShape(outCn, inpCn/groups, ksz, ksz)), biasBlob(BlobShape(outCn, 1, 1, 1));
    Blob inpBlob(inpShape);
    rng.fill(biasBlob.matRef(), RNG::UNIFORM, -1, +1);
    rng.fill(wgtBlob.matRef(), RNG::UNIFORM, -1, +1);
    rng.fill(inpBlob.matRef(), RNG::UNIFORM, -1, +1);

    LayerParams lp;
    lp.set("num_output", outCn);
    lp.set("group", groups);
    lp.set("stride", stride);
    lp.set("kernel_size", ksz);
    lp.blobs.reserve(2);
    lp.blobs.push_back(wgtBlob);
    lp.blobs.push_back(biasBlob);

    std::vector<Blob*> inpBlobs(1, &inpBlob);
    std::vector<Blob> outBlobs;

    cv::setNumThreads(cv::getNumberOfCPUs());

    Ptr<Layer> layer = cv::dnn::LayerFactory::createLayerInstance("Convolution", lp);
    layer->allocate(inpBlobs, outBlobs);

    declare.in(inpBlob.matRef(), wgtBlob.matRef(), WARMUP_RNG).out(outBlobs[0].matRef()).tbb_threads(cv::getNumThreads());

    TEST_CYCLE_N(10)
    {
        layer->forward(inpBlobs, outBlobs);
    }

    SANITY_CHECK_NOTHING();
}

}
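The kernel and stride choices above fix the output spatial size through the usual convolution arithmetic. For reference, a small sketch (hypothetical helper, not part of the commit; pad is 0, matching the LayerParams the test builds):

// Hypothetical helper: expected spatial output size for the configurations
// exercised above, with out = (in + 2*pad - kernel) / stride + 1.
static cv::Size convOutputSize(cv::Size in, int kernel, int stride, int pad = 0)
{
    return cv::Size((in.width  + 2 * pad - kernel) / stride + 1,
                    (in.height + 2 * pad - kernel) / stride + 1);
}
// e.g. the 224x224 input with an 11x11 kernel is forced to stride 4 by the test,
// giving (224 - 11) / 4 + 1 = 54, i.e. 54x54 output maps.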
modules/dnn/perf/perf_main.cpp (new file, 3 lines)
@@ -0,0 +1,3 @@
#include "perf_precomp.hpp"

CV_PERF_TEST_MAIN(dnn)
modules/dnn/perf/perf_precomp.hpp (new file, 17 lines)
@@ -0,0 +1,17 @@
#ifdef __GNUC__
#  pragma GCC diagnostic ignored "-Wmissing-declarations"
#  if defined __clang__ || defined __APPLE__
#    pragma GCC diagnostic ignored "-Wmissing-prototypes"
#    pragma GCC diagnostic ignored "-Wextra"
#  endif
#endif

#ifndef __OPENCV_PERF_PRECOMP_HPP__
#define __OPENCV_PERF_PRECOMP_HPP__

#include <opencv2/ts.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>

#endif
@@ -43,13 +43,10 @@
 #include <opencv2/core/ocl.hpp>
 #include "layers_common.hpp"
 #include "convolution_layer.hpp"
-#include "im2col.hpp"
+#include "op_im2col.hpp"
+#include "op_blas.hpp"
 #include <iostream>

-#if HAVE_CBLAS
-#include "cblas.h"
-#endif
-
 namespace cv
 {
 namespace dnn
@@ -78,17 +75,12 @@ namespace dnn
         //TBD
         useOpenCL = params.has("use_opencl");

-        //init BLAS
         #if HAVE_CBLAS
         {
-        #ifdef OPENBLAS_VERSION
-            if (openblas_get_num_threads() != cv::getNumThreads())
+            if (getBlasThreads() != cv::getThreadNum())
             {
-                openblas_set_num_threads(cv::getNumThreads());
-                goto_set_num_threads(cv::getNumThreads());
+                setBlasThreads(cv::getThreadNum());
             }
-            //std::cout << "OpenBLAS threads " << openblas_get_num_threads() << "/" << openblas_get_num_procs() << "\n";
-        #endif
         }
         #endif
     }
@@ -265,57 +257,5 @@ namespace dnn
         if (dstMat.type() == CV_64F)
             col2im_cpu((double*)colMat.ptr(), inpGroupCn, inpH, inpW, kerH, kerW, padH, padW, strideH, strideW, (double*)dstMat.ptr());
     }
-
-    void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double beta, int flags /*= 0*/)
-    {
-        cv::gemm(A, B, alpha, C, beta, C, flags);
-    }
-
-    inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool transA = false)
-    {
-        rows = (transA) ? A.cols : A.rows;
-        cols = (transA) ? A.rows : A.cols;
-    }
-
-    void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags /*= 0*/)
-    {
-    #if HAVE_CBLAS
-        bool transA = flags & GEMM_1_T;
-        bool transB = flags & GEMM_2_T;
-        bool transC = flags & GEMM_3_T;
-
-        int Arows, Acols, Brows, Bcols, Crows, Ccols;
-        SwapRowCols(A, Arows, Acols, transA);
-        SwapRowCols(B, Brows, Bcols, transB);
-        SwapRowCols(C, Crows, Ccols, transC);
-
-        CV_DbgAssert(!(flags & GEMM_3_T));
-        CV_Assert(Acols == Brows && Arows == Crows && Bcols == Ccols);
-        CV_DbgAssert(A.isContinuous() && B.isContinuous() && C.isContinuous());
-        CV_DbgAssert(A.type() == CV_32F || A.type() == CV_64F);
-        CV_DbgAssert(A.type() == B.type() && B.type() == C.type());
-
-        if (C.type() == CV_32F)
-        {
-            cblas_sgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
-                        Arows, Bcols, Acols,
-                        (float)alpha, A.ptr<float>(), A.cols,
-                        B.ptr<float>(), B.cols,
-                        (float)beta, C.ptr<float>(), C.cols);
-        }
-        else if (C.type() == CV_64F)
-        {
-            //TODO: Should be tested
-            cblas_dgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
-                        Arows, Bcols, Acols,
-                        alpha, A.ptr<double>(), A.cols,
-                        B.ptr<double>(), B.cols,
-                        beta, C.ptr<double>(), C.cols);
-        }
-    #else
-        cv::gemm(A, B, alpha, C, beta, C, flags);
-    #endif
-    }
-
 }
 }
@@ -87,10 +87,6 @@ namespace dnn
         DeConvolutionLayer(LayerParams &params);
         void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
     };
-
-    void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double beta, int flags = 0);
-
-    void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags = 0);
 }
 }
 #endif
modules/dnn/src/layers/op_blas.cpp (new file, 84 lines)
@@ -0,0 +1,84 @@
#include "op_blas.hpp"

#if HAVE_CBLAS
#include "cblas.h"
#endif

namespace cv
{
namespace dnn
{

void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double beta, int flags /*= 0*/)
{
    cv::gemm(A, B, alpha, C, beta, C, flags);
}

inline void SwapRowCols(const Mat &A, int &rows, int &cols, bool transA)
{
    rows = (transA) ? A.cols : A.rows;
    cols = (transA) ? A.rows : A.cols;
}

void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags /*= 0*/)
{
#if HAVE_CBLAS
    int transA = flags & GEMM_1_T;
    int transB = flags & GEMM_2_T;
    int transC = flags & GEMM_3_T;

    int Arows, Acols, Brows, Bcols, Crows, Ccols;
    SwapRowCols(A, Arows, Acols, transA);
    SwapRowCols(B, Brows, Bcols, transB);
    SwapRowCols(C, Crows, Ccols, transC);

    CV_DbgAssert(!(flags & GEMM_3_T));
    CV_Assert(Acols == Brows && Arows == Crows && Bcols == Ccols);
    CV_DbgAssert(A.isContinuous() && B.isContinuous() && C.isContinuous());
    CV_DbgAssert(A.type() == CV_32F || A.type() == CV_64F);
    CV_DbgAssert(A.type() == B.type() && B.type() == C.type());

    if (C.type() == CV_32F)
    {
        cblas_sgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
                    Arows, Bcols, Acols,
                    (float)alpha, A.ptr<float>(), A.cols,
                    B.ptr<float>(), B.cols,
                    (float)beta, C.ptr<float>(), C.cols);
    }
    else if (C.type() == CV_64F)
    {
        //TODO: Should be tested
        cblas_dgemm(CblasRowMajor, transA ? CblasTrans : CblasNoTrans, transB ? CblasTrans : CblasNoTrans,
                    Arows, Bcols, Acols,
                    alpha, A.ptr<double>(), A.cols,
                    B.ptr<double>(), B.cols,
                    beta, C.ptr<double>(), C.cols);
    }
#else
    cv::gemm(A, B, alpha, C, beta, C, flags);
#endif
}

int getBlasThreads()
{
#ifdef OPENBLAS_VERSION
    return openblas_get_num_threads();
#else
    return 1;
#endif
}

void setBlasThreads(int numThreads)
{
#ifdef OPENBLAS_VERSION
    openblas_set_num_threads(numThreads);
    goto_set_num_threads(numThreads);
#else
    numThreads = 0; //suppress compiler's warning
#endif
}

}
}
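For orientation, a minimal usage sketch of the helpers this file adds (assumed to be compiled inside the dnn module, where op_blas.hpp is available; matrices and sizes are illustrative):

#include "op_blas.hpp"

void blasExample()
{
    // Continuous single-precision matrices, as gemmCPU's assertions require.
    cv::Mat A(4, 3, CV_32F), B(3, 5, CV_32F), C(4, 5, CV_32F);
    cv::randu(A, cv::Scalar::all(-1), cv::Scalar::all(1));
    cv::randu(B, cv::Scalar::all(-1), cv::Scalar::all(1));
    C.setTo(0);

    // C = 1.0 * A * B + 0.0 * C; dispatches to cblas_sgemm when HAVE_CBLAS
    // is set and falls back to cv::gemm otherwise.
    cv::dnn::gemmCPU(A, B, 1.0, C, 0.0);

    // Keep the BLAS worker pool aligned with OpenCV's thread count
    // (getBlasThreads() returns 1 and setBlasThreads() is a no-op
    // when OpenBLAS is absent).
    if (cv::dnn::getBlasThreads() != cv::getNumThreads())
        cv::dnn::setBlasThreads(cv::getNumThreads());
}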
modules/dnn/src/layers/op_blas.hpp (new file, 59 lines)
@@ -0,0 +1,59 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#ifndef __OPENCV_DNN_LAYERS_OP_BLAS_HPP__
#define __OPENCV_DNN_LAYERS_OP_BLAS_HPP__
#include "../precomp.hpp"

namespace cv
{
namespace dnn
{
    int getBlasThreads();

    void setBlasThreads(int numThreads);

    void gemm(InputArray A, InputArray B, double alpha, InputOutputArray C, double beta, int flags = 0);

    void gemmCPU(const Mat &A, const Mat &B, double alpha, Mat &C, double beta, int flags = 0);
}
}
#endif
@@ -41,7 +41,7 @@

 #include "../precomp.hpp"
 #include <opencv2/core/ocl.hpp>
-#include "im2col.hpp"
+#include "op_im2col.hpp"
 #include "opencl_kernels_dnn.hpp"

 namespace cv
@@ -43,6 +43,7 @@
 #include <opencv2/core/ocl.hpp>
 #include <iostream>
 #include "npy_blob.hpp"
+#include <opencv2/dnn/all_layers.hpp>

 namespace cvtest
 {
@@ -174,4 +175,61 @@ TEST(Layer_Test_Reshape_Split_Slice, Accuracy)
     normAssert(input, output);
 }

+class Layer_LSTM_Test : public ::testing::Test
+{
+public:
+    int Nx, Nc;
+    Blob Wh, Wx, b;
+    Ptr<LSTMLayer> lstm;
+
+    std::vector<Blob> inputs;
+    std::vector<Blob> outputs;
+    std::vector<Blob*> inputsPtr;
+
+    Layer_LSTM_Test(int _Nx = 31, int _Nc = 100)
+    {
+        Nx = _Nx;
+        Nc = _Nc;
+
+        Wh = Blob(BlobShape(Vec2i(4 * Nc, Nc)));
+        Wx = Blob(BlobShape(Vec2i(4 * Nc, Nx)));
+        b  = Blob(BlobShape(Vec2i(4 * Nc, 1)));
+
+        lstm = LSTMLayer::create();
+        lstm->setWeights(Wh, Wx, b);
+    }
+
+    void allocateAndForward()
+    {
+        inputsPtr.clear();
+        for (size_t i = 0; i < inputs.size(); i++)
+            inputsPtr.push_back(&inputs[i]);
+
+        lstm->allocate(inputsPtr, outputs);
+        lstm->forward(inputsPtr, outputs);
+    }
+};
+
+TEST_F(Layer_LSTM_Test, BasicTest_1)
+{
+    inputs.push_back(Blob(BlobShape(1, 2, 3, Nx)));
+    allocateAndForward();
+
+    EXPECT_EQ(outputs.size(), 2);
+    EXPECT_EQ(outputs[0].shape(), BlobShape(1, 2, 3, Nc));
+    EXPECT_EQ(outputs[1].shape(), BlobShape(1, 2, 3, Nc));
+}
+
+TEST_F(Layer_LSTM_Test, BasicTest_2)
+{
+    inputs.push_back(Blob(BlobShape(1, 2, 3, Nx)));
+    inputs.push_back(Blob(BlobShape(1, 2, 3, Nc)));
+    inputs.push_back(Blob(BlobShape(1, 2, 3, Nc)));
+    allocateAndForward();
+
+    EXPECT_EQ(outputs.size(), 2);
+    EXPECT_EQ(outputs[0].shape(), BlobShape(1, 2, 3, Nc));
+    EXPECT_EQ(outputs[1].shape(), BlobShape(1, 2, 3, Nc));
+}
+
 }