mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-22 07:31:26 +08:00
Added python scripts to estimating accuracy
This commit is contained in:
@@ -9,7 +9,7 @@ endif()
|
||||
|
||||
set(the_description "Deep neural network module. It allows to load models from different frameworks and to make forward pass")
|
||||
|
||||
ocv_add_module(dnn opencv_core opencv_imgproc)
|
||||
ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab)
|
||||
ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo
|
||||
-Wmissing-declarations -Wmissing-prototypes
|
||||
)
|
||||
|
@@ -304,6 +304,16 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
|
||||
|
||||
/** @brief Reads a network model stored in Tensorflow model file.
|
||||
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromTensorflow(const String &model);
|
||||
|
||||
/** @brief Reads a network model stored in Torch model file.
|
||||
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
|
||||
*/
|
||||
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
|
||||
|
||||
/** @brief Creates the importer of <a href="http://www.tensorflow.org">TensorFlow</a> framework network.
|
||||
* @param model path to the .pb file with binary protobuf description of the network architecture.
|
||||
* @returns Pointer to the created importer, NULL in failure cases.
|
||||
|
@@ -26,4 +26,10 @@ bool pyopencv_to(PyObject *o, dnn::DictValue &dv, const char *name)
|
||||
return false;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool pyopencv_to(PyObject *o, std::vector<Mat> &blobs, const char *name) //required for Layer::blobs RW
|
||||
{
|
||||
return pyopencvVecConverter<Mat>::to(o, blobs, ArgInfo(name, false));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
20
modules/dnn/samples/enet-classes.txt
Normal file
20
modules/dnn/samples/enet-classes.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
Unlabeled 0 0 0
|
||||
Road 128 64 128
|
||||
Sidewalk 244 35 232
|
||||
Building 70 70 70
|
||||
Wall 102 102 156
|
||||
Fence 190 153 153
|
||||
Pole 153 153 153
|
||||
TrafficLight 250 170 30
|
||||
TrafficSign 220 220 0
|
||||
Vegetation 107 142 35
|
||||
Terrain 152 251 152
|
||||
Sky 70 130 180
|
||||
Person 220 20 60
|
||||
Rider 255 0 0
|
||||
Car 0 0 142
|
||||
Truck 0 0 70
|
||||
Bus 0 60 100
|
||||
Train 0 80 100
|
||||
Motorcycle 0 0 230
|
||||
Bicycle 119 11 32
|
@@ -99,7 +99,7 @@ int main(int argc, char **argv)
|
||||
|
||||
Mat inputBlob = blobFromImage(img); //Convert Mat to image batch
|
||||
//! [Prepare blob]
|
||||
|
||||
inputBlob -= 117.0;
|
||||
//! [Set input blob]
|
||||
net.setBlob(inBlobName, inputBlob); //set the network input
|
||||
//! [Set input blob]
|
||||
|
@@ -26,9 +26,9 @@ const String keys =
|
||||
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
|
||||
;
|
||||
|
||||
std::vector<String> readClassNames(const char *filename);
|
||||
static void colorizeSegmentation(const Mat &score, Mat &segm,
|
||||
Mat &legend, vector<String> &classNames);
|
||||
Mat &legend, vector<String> &classNames, vector<Vec3b> &colors);
|
||||
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames);
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
@@ -52,43 +52,21 @@ int main(int argc, char **argv)
|
||||
String classNamesFile = parser.get<String>("c_names");
|
||||
String resultFile = parser.get<String>("result");
|
||||
|
||||
//! [Create the importer of TensorFlow model]
|
||||
Ptr<dnn::Importer> importer;
|
||||
try //Try to import TensorFlow AlexNet model
|
||||
{
|
||||
importer = dnn::createTorchImporter(modelFile);
|
||||
}
|
||||
catch (const cv::Exception &err) //Importer can throw errors, we will catch them
|
||||
{
|
||||
std::cerr << err.msg << std::endl;
|
||||
}
|
||||
//! [Create the importer of Caffe model]
|
||||
|
||||
if (!importer)
|
||||
{
|
||||
std::cerr << "Can't load network by using the mode file: " << std::endl;
|
||||
std::cerr << modelFile << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
//! [Initialize network]
|
||||
dnn::Net net;
|
||||
importer->populateNet(net);
|
||||
importer.release(); //We don't need importer anymore
|
||||
//! [Initialize network]
|
||||
//! [Read model and initialize network]
|
||||
dnn::Net net = dnn::readNetFromTorch(modelFile);
|
||||
|
||||
//! [Prepare blob]
|
||||
Mat img = imread(imageFile, 1);
|
||||
|
||||
Mat img = imread(imageFile), input;
|
||||
if (img.empty())
|
||||
{
|
||||
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
Size inputImgSize(512, 512);
|
||||
Size origSize = img.size();
|
||||
Size inputImgSize = cv::Size(1024, 512);
|
||||
|
||||
if (inputImgSize != img.size())
|
||||
if (inputImgSize != origSize)
|
||||
resize(img, img, inputImgSize); //Resize image to input size
|
||||
|
||||
Mat inputBlob = blobFromImage(img, 1./255, true); //Convert Mat to image batch
|
||||
@@ -130,20 +108,18 @@ int main(int argc, char **argv)
|
||||
|
||||
if (parser.has("show"))
|
||||
{
|
||||
size_t nclasses = result.size[1];
|
||||
std::vector<String> classNames;
|
||||
vector<cv::Vec3b> colors;
|
||||
if(!classNamesFile.empty()) {
|
||||
classNames = readClassNames(classNamesFile.c_str());
|
||||
if (classNames.size() > nclasses)
|
||||
classNames = std::vector<String>(classNames.begin() + classNames.size() - nclasses,
|
||||
classNames.end());
|
||||
colors = readColors(classNamesFile, classNames);
|
||||
}
|
||||
Mat segm, legend;
|
||||
colorizeSegmentation(result, segm, legend, classNames);
|
||||
colorizeSegmentation(result, segm, legend, classNames, colors);
|
||||
|
||||
Mat show;
|
||||
addWeighted(img, 0.2, segm, 0.8, 0.0, show);
|
||||
addWeighted(img, 0.1, segm, 0.9, 0.0, show);
|
||||
|
||||
cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST);
|
||||
imshow("Result", show);
|
||||
if(classNames.size())
|
||||
imshow("Legend", legend);
|
||||
@@ -153,44 +129,16 @@ int main(int argc, char **argv)
|
||||
return 0;
|
||||
} //main
|
||||
|
||||
|
||||
std::vector<String> readClassNames(const char *filename)
|
||||
{
|
||||
std::vector<String> classNames;
|
||||
|
||||
std::ifstream fp(filename);
|
||||
if (!fp.is_open())
|
||||
{
|
||||
std::cerr << "File with classes labels not found: " << filename << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::string name;
|
||||
while (!fp.eof())
|
||||
{
|
||||
std::getline(fp, name);
|
||||
if (name.length())
|
||||
classNames.push_back(name);
|
||||
}
|
||||
|
||||
fp.close();
|
||||
return classNames;
|
||||
}
|
||||
|
||||
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames)
|
||||
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames, vector<Vec3b> &colors)
|
||||
{
|
||||
const int rows = score.size[2];
|
||||
const int cols = score.size[3];
|
||||
const int chns = score.size[1];
|
||||
|
||||
vector<Vec3i> colors;
|
||||
RNG rng(12345678);
|
||||
|
||||
cv::Mat maxCl(rows, cols, CV_8UC1);
|
||||
cv::Mat maxVal(rows, cols, CV_32FC1);
|
||||
for (int ch = 0; ch < chns; ch++)
|
||||
{
|
||||
colors.push_back(Vec3i(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)));
|
||||
for (int row = 0; row < rows; row++)
|
||||
{
|
||||
const float *ptrScore = score.ptr<float>(0, ch, row);
|
||||
@@ -230,3 +178,38 @@ static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vecto
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames)
|
||||
{
|
||||
vector<cv::Vec3b> colors;
|
||||
classNames.clear();
|
||||
|
||||
ifstream fp(filename.c_str());
|
||||
if (!fp.is_open())
|
||||
{
|
||||
cerr << "File with colors not found: " << filename << endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
string line;
|
||||
while (!fp.eof())
|
||||
{
|
||||
getline(fp, line);
|
||||
if (line.length())
|
||||
{
|
||||
stringstream ss(line);
|
||||
|
||||
string name; ss >> name;
|
||||
int temp;
|
||||
cv::Vec3b color;
|
||||
ss >> temp; color[0] = temp;
|
||||
ss >> temp; color[1] = temp;
|
||||
ss >> temp; color[2] = temp;
|
||||
classNames.push_back(name);
|
||||
colors.push_back(color);
|
||||
}
|
||||
}
|
||||
|
||||
fp.close();
|
||||
return colors;
|
||||
}
|
||||
|
@@ -604,7 +604,10 @@ void Net::setBlob(String outputName, const Mat &blob_)
|
||||
|
||||
LayerData &ld = impl->layers[pin.lid];
|
||||
ld.outputBlobs.resize( std::max(pin.oid+1, (int)ld.requiredOutputs.size()) );
|
||||
MatSize prevShape = ld.outputBlobs[pin.oid].size;
|
||||
ld.outputBlobs[pin.oid] = blob_.clone();
|
||||
|
||||
impl->netWasAllocated = prevShape == blob_.size;
|
||||
}
|
||||
|
||||
Mat Net::getBlob(String outputName)
|
||||
|
@@ -736,6 +736,23 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
} // namespace
|
||||
|
||||
Net cv::dnn::readNetFromTensorflow(const String &model)
|
||||
{
|
||||
Ptr<Importer> importer;
|
||||
try
|
||||
{
|
||||
importer = createTensorflowImporter(model);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
}
|
||||
|
||||
Net net;
|
||||
if (importer)
|
||||
importer->populateNet(net);
|
||||
return net;
|
||||
}
|
||||
|
||||
Ptr<Importer> cv::dnn::createTensorflowImporter(const String &model)
|
||||
{
|
||||
return Ptr<Importer>(new TFImporter(model.c_str()));
|
||||
|
@@ -970,6 +970,24 @@ Mat readTorchBlob(const String &filename, bool isBinary)
|
||||
|
||||
return importer->tensors.begin()->second;
|
||||
}
|
||||
|
||||
Net readNetFromTorch(const String &model, bool isBinary)
|
||||
{
|
||||
Ptr<Importer> importer;
|
||||
try
|
||||
{
|
||||
importer = createTorchImporter(model, isBinary);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
}
|
||||
|
||||
Net net;
|
||||
if (importer)
|
||||
importer->populateNet(net);
|
||||
return net;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
Ptr<Importer> createTorchImporter(const String &filename, bool isBinary)
|
||||
|
142
modules/dnn/test/cityscapes_semsegm_test_enet.py
Normal file
142
modules/dnn/test/cityscapes_semsegm_test_enet.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import fnmatch
|
||||
import argparse
|
||||
|
||||
# sys.path.append('<path to opencv_build_dir/lib>')
|
||||
sys.path.append('/home/arrybn/build/opencv_w_contrib/lib')
|
||||
try:
|
||||
import cv2 as cv
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find opencv. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to opencv_build_dir/lib dir')
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find pytorch. Please intall it by following instructions on the official site')
|
||||
|
||||
from torch.utils.serialization import load_lua
|
||||
from pascal_semsegm_test_fcn import eval_segm_result, get_conf_mat, get_metrics, DatasetImageFetch, SemSegmEvaluation
|
||||
from imagenet_cls_test_alexnet import Framework, DnnCaffeModel
|
||||
|
||||
|
||||
class NormalizePreproc:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def process(img):
|
||||
image_data = np.array(img).transpose(2, 0, 1).astype(np.float32)
|
||||
image_data = np.expand_dims(image_data, 0)
|
||||
image_data /= 255.0
|
||||
return image_data
|
||||
|
||||
|
||||
class CityscapesDataFetch(DatasetImageFetch):
|
||||
img_dir = ''
|
||||
segm_dir = ''
|
||||
segm_files = []
|
||||
colors = []
|
||||
i = 0
|
||||
|
||||
def __init__(self, img_dir, segm_dir, preproc):
|
||||
self.img_dir = img_dir
|
||||
self.segm_dir = segm_dir
|
||||
self.segm_files = sorted([img for img in self.locate('*_color.png', segm_dir)])
|
||||
self.colors = self.get_colors()
|
||||
self.data_prepoc = preproc
|
||||
self.i = 0
|
||||
|
||||
@staticmethod
|
||||
def get_colors():
|
||||
result = []
|
||||
colors_list = (
|
||||
(0, 0, 0), (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
|
||||
(250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
|
||||
(0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32))
|
||||
|
||||
for c in colors_list:
|
||||
result.append(DatasetImageFetch.pix_to_c(c))
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.i < len(self.segm_files):
|
||||
segm_file = self.segm_files[self.i]
|
||||
segm = cv.imread(segm_file, cv.IMREAD_COLOR)[:, :, ::-1]
|
||||
segm = cv.resize(segm, (1024, 512), interpolation=cv.INTER_NEAREST)
|
||||
|
||||
img_file = self.rreplace(self.img_dir + segm_file[len(self.segm_dir):], 'gtFine_color', 'leftImg8bit')
|
||||
assert os.path.exists(img_file)
|
||||
img = cv.imread(img_file, cv.IMREAD_COLOR)[:, :, ::-1]
|
||||
img = cv.resize(img, (1024, 512))
|
||||
|
||||
self.i += 1
|
||||
gt = self.color_to_gt(segm, self.colors)
|
||||
img = self.data_prepoc.process(img)
|
||||
return img, gt
|
||||
else:
|
||||
self.i = 0
|
||||
raise StopIteration
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.colors)
|
||||
|
||||
@staticmethod
|
||||
def locate(pattern, root_path):
|
||||
for path, dirs, files in os.walk(os.path.abspath(root_path)):
|
||||
for filename in fnmatch.filter(files, pattern):
|
||||
yield os.path.join(path, filename)
|
||||
|
||||
@staticmethod
|
||||
def rreplace(s, old, new, occurrence=1):
|
||||
li = s.rsplit(old, occurrence)
|
||||
return new.join(li)
|
||||
|
||||
|
||||
class TorchModel(Framework):
|
||||
net = object
|
||||
|
||||
def __init__(self, model_file):
|
||||
self.net = load_lua(model_file)
|
||||
|
||||
def get_name(self):
|
||||
return 'Torch'
|
||||
|
||||
def get_output(self, input_blob):
|
||||
tensor = torch.FloatTensor(input_blob)
|
||||
out = self.net.forward(tensor).numpy()
|
||||
return out
|
||||
|
||||
|
||||
class DnnTorchModel(DnnCaffeModel):
|
||||
net = cv.dnn.Net()
|
||||
|
||||
def __init__(self, model_file):
|
||||
self.net = cv.dnn.readNetFromTorch(model_file)
|
||||
|
||||
def get_output(self, input_blob):
|
||||
self.net.setBlob("", input_blob)
|
||||
self.net.forward()
|
||||
return self.net.getBlob(self.net.getLayerNames()[-1])
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--imgs_dir", help="path to Cityscapes validation images dir, imgsfine/leftImg8bit/val")
|
||||
parser.add_argument("--segm_dir", help="path to Cityscapes dir with segmentation, gtfine/gtFine/val")
|
||||
parser.add_argument("--model", help="path to torch model, download it here: "
|
||||
"https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa")
|
||||
parser.add_argument("--log", help="path to logging file")
|
||||
args = parser.parse_args()
|
||||
|
||||
prep = NormalizePreproc()
|
||||
df = CityscapesDataFetch(args.imgs_dir, args.segm_dir, prep)
|
||||
|
||||
fw = [TorchModel(args.model),
|
||||
DnnTorchModel(args.model)]
|
||||
|
||||
segm_eval = SemSegmEvaluation(args.log)
|
||||
segm_eval.process(fw, df)
|
246
modules/dnn/test/imagenet_cls_test_alexnet.py
Normal file
246
modules/dnn/test/imagenet_cls_test_alexnet.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
# sys.path.append('<path to git/caffe/python dir>')
|
||||
sys.path.append('/home/arrybn/git/caffe/python')
|
||||
try:
|
||||
import caffe
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find caffe. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to git/caffe/python dir')
|
||||
# sys.path.append('<path to opencv_build_dir/lib>')
|
||||
sys.path.append('/home/arrybn/build/opencv_w_contrib/lib')
|
||||
try:
|
||||
import cv2 as cv
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find opencv. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to opencv_build_dir/lib dir')
|
||||
|
||||
|
||||
class DataFetch(object):
|
||||
imgs_dir = ''
|
||||
frame_size = 0
|
||||
bgr_to_rgb = False
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, img):
|
||||
pass
|
||||
|
||||
def get_batch(self, imgs_names):
|
||||
assert type(imgs_names) is list
|
||||
batch = np.zeros((len(imgs_names), 3, self.frame_size, self.frame_size)).astype(np.float32)
|
||||
for i in range(len(imgs_names)):
|
||||
img_name = imgs_names[i]
|
||||
img_file = self.imgs_dir + img_name
|
||||
assert os.path.exists(img_file)
|
||||
img = cv.imread(img_file, cv.IMREAD_COLOR)
|
||||
min_dim = min(img.shape[-3], img.shape[-2])
|
||||
resize_ratio = self.frame_size / float(min_dim)
|
||||
img = cv.resize(img, (0, 0), fx=resize_ratio, fy=resize_ratio)
|
||||
cols = img.shape[1]
|
||||
rows = img.shape[0]
|
||||
y1 = (rows - self.frame_size) / 2
|
||||
y2 = y1 + self.frame_size
|
||||
x1 = (cols - self.frame_size) / 2
|
||||
x2 = x1 + self.frame_size
|
||||
img = img[y1:y2, x1:x2]
|
||||
if self.bgr_to_rgb:
|
||||
img = img[..., ::-1]
|
||||
image_data = img[:, :, 0:3].transpose(2, 0, 1)
|
||||
batch[i] = self.preprocess(image_data)
|
||||
return batch
|
||||
|
||||
|
||||
class MeanBlobFetch(DataFetch):
|
||||
mean_blob = np.ndarray(())
|
||||
|
||||
def __init__(self, frame_size, mean_blob_path, imgs_dir):
|
||||
self.imgs_dir = imgs_dir
|
||||
self.frame_size = frame_size
|
||||
blob = caffe.proto.caffe_pb2.BlobProto()
|
||||
data = open(mean_blob_path, 'rb').read()
|
||||
blob.ParseFromString(data)
|
||||
self.mean_blob = np.array(caffe.io.blobproto_to_array(blob))
|
||||
start = (self.mean_blob.shape[2] - self.frame_size) / 2
|
||||
stop = start + self.frame_size
|
||||
self.mean_blob = self.mean_blob[:, :, start:stop, start:stop][0]
|
||||
|
||||
def preprocess(self, img):
|
||||
return img - self.mean_blob
|
||||
|
||||
|
||||
class MeanChannelsFetch(MeanBlobFetch):
|
||||
def __init__(self, frame_size, imgs_dir):
|
||||
self.imgs_dir = imgs_dir
|
||||
self.frame_size = frame_size
|
||||
self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)
|
||||
self.mean_blob[0] *= 104
|
||||
self.mean_blob[1] *= 117
|
||||
self.mean_blob[2] *= 123
|
||||
|
||||
|
||||
class MeanValueFetch(MeanBlobFetch):
|
||||
def __init__(self, frame_size, imgs_dir, bgr_to_rgb):
|
||||
self.imgs_dir = imgs_dir
|
||||
self.frame_size = frame_size
|
||||
self.mean_blob = np.ones((3, self.frame_size, self.frame_size)).astype(np.float32)
|
||||
self.mean_blob *= 117
|
||||
self.bgr_to_rgb = bgr_to_rgb
|
||||
|
||||
|
||||
def get_correct_answers(img_list, img_classes, net_output_blob):
|
||||
correct_answers = 0
|
||||
for i in range(len(img_list)):
|
||||
indexes = np.argsort(net_output_blob[i])[-5:]
|
||||
correct_index = img_classes[img_list[i]]
|
||||
if correct_index in indexes:
|
||||
correct_answers += 1
|
||||
return correct_answers
|
||||
|
||||
|
||||
class Framework(object):
|
||||
in_blob_name = ''
|
||||
out_blob_name = ''
|
||||
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_output(self, input_blob):
|
||||
pass
|
||||
|
||||
|
||||
class CaffeModel(Framework):
|
||||
net = caffe.Net
|
||||
need_reshape = False
|
||||
|
||||
def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name, need_reshape=False):
|
||||
caffe.set_mode_cpu()
|
||||
self.net = caffe.Net(prototxt, caffemodel, caffe.TEST)
|
||||
self.in_blob_name = in_blob_name
|
||||
self.out_blob_name = out_blob_name
|
||||
self.need_reshape = need_reshape
|
||||
|
||||
def get_name(self):
|
||||
return 'Caffe'
|
||||
|
||||
def get_output(self, input_blob):
|
||||
if self.need_reshape:
|
||||
self.net.blobs[self.in_blob_name].reshape(*input_blob.shape)
|
||||
return self.net.forward_all(**{self.in_blob_name: input_blob})[self.out_blob_name]
|
||||
|
||||
|
||||
class DnnCaffeModel(Framework):
|
||||
net = object
|
||||
|
||||
def __init__(self, prototxt, caffemodel, in_blob_name, out_blob_name):
|
||||
self.net = cv.dnn.readNetFromCaffe(prototxt, caffemodel)
|
||||
self.in_blob_name = in_blob_name
|
||||
self.out_blob_name = out_blob_name
|
||||
|
||||
def get_name(self):
|
||||
return 'DNN'
|
||||
|
||||
def get_output(self, input_blob):
|
||||
self.net.setBlob(self.in_blob_name, input_blob)
|
||||
self.net.forward()
|
||||
return self.net.getBlob(self.out_blob_name)
|
||||
|
||||
|
||||
class ClsAccEvaluation:
|
||||
log = file
|
||||
img_classes = {}
|
||||
batch_size = 0
|
||||
|
||||
def __init__(self, log_path, img_classes_file, batch_size):
|
||||
self.log = open(log_path, 'w')
|
||||
self.img_classes = self.read_classes(img_classes_file)
|
||||
self.batch_size = batch_size
|
||||
|
||||
@staticmethod
|
||||
def read_classes(img_classes_file):
|
||||
result = {}
|
||||
with open(img_classes_file) as file:
|
||||
for l in file.readlines():
|
||||
result[l.split()[0]] = int(l.split()[1])
|
||||
return result
|
||||
|
||||
def process(self, frameworks, data_fetcher):
|
||||
sorted_imgs_names = sorted(self.img_classes.keys())
|
||||
correct_answers = [0] * len(frameworks)
|
||||
samples_handled = 0
|
||||
blobs_l1_diff = [0] * len(frameworks)
|
||||
blobs_l1_diff_count = [0] * len(frameworks)
|
||||
blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)
|
||||
inference_time = [0.0] * len(frameworks)
|
||||
|
||||
for x in xrange(0, len(sorted_imgs_names), self.batch_size):
|
||||
sublist = sorted_imgs_names[x:x + self.batch_size]
|
||||
batch = data_fetcher.get_batch(sublist)
|
||||
|
||||
samples_handled += len(sublist)
|
||||
|
||||
frameworks_out = []
|
||||
fw_accuracy = []
|
||||
for i in range(len(frameworks)):
|
||||
start = time.time()
|
||||
out = frameworks[i].get_output(batch)
|
||||
end = time.time()
|
||||
correct_answers[i] += get_correct_answers(sublist, self.img_classes, out)
|
||||
fw_accuracy.append(100 * correct_answers[i] / float(samples_handled))
|
||||
frameworks_out.append(out)
|
||||
inference_time[i] += end - start
|
||||
print >> self.log, samples_handled, 'Accuracy for', frameworks[i].get_name() + ':', fw_accuracy[i]
|
||||
print >> self.log, "Inference time, ms ", \
|
||||
frameworks[i].get_name(), inference_time[i] / samples_handled * 1000
|
||||
|
||||
for i in range(1, len(frameworks)):
|
||||
log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
|
||||
diff = np.abs(frameworks_out[0] - frameworks_out[i])
|
||||
l1_diff = np.sum(diff) / diff.size
|
||||
print >> self.log, samples_handled, "L1 difference", log_str, l1_diff
|
||||
blobs_l1_diff[i] += l1_diff
|
||||
blobs_l1_diff_count[i] += 1
|
||||
if np.max(diff) > blobs_l_inf_diff[i]:
|
||||
blobs_l_inf_diff[i] = np.max(diff)
|
||||
print >> self.log, samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i]
|
||||
|
||||
self.log.flush()
|
||||
|
||||
for i in range(1, len(blobs_l1_diff)):
|
||||
log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
|
||||
print >> self.log, 'Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i]
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
|
||||
parser.add_argument("--img_cls_file", help="path to file with classes ids for images, val.txt file from this "
|
||||
"archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
|
||||
parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
|
||||
"https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt")
|
||||
parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
|
||||
"http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel")
|
||||
parser.add_argument("--log", help="path to logging file")
|
||||
parser.add_argument("--mean", help="path to ImageNet mean blob caffe file, imagenet_mean.binaryproto file from"
|
||||
"this archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
|
||||
parser.add_argument("--batch_size", help="size of images in batch", default=1000)
|
||||
parser.add_argument("--frame_size", help="size of input image", default=227)
|
||||
parser.add_argument("--in_blob", help="name for input blob", default='data')
|
||||
parser.add_argument("--out_blob", help="name for output blob", default='prob')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_fetcher = MeanBlobFetch(args.frame_size, args.mean, args.imgs_dir)
|
||||
|
||||
frameworks = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob),
|
||||
DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]
|
||||
|
||||
acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
|
||||
acc_eval.process(frameworks, data_fetcher)
|
43
modules/dnn/test/imagenet_cls_test_googlenet.py
Normal file
43
modules/dnn/test/imagenet_cls_test_googlenet.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from imagenet_cls_test_alexnet import MeanChannelsFetch, CaffeModel, DnnCaffeModel, ClsAccEvaluation
|
||||
# sys.path.append('<path to git/caffe/python dir>')
|
||||
sys.path.append('/home/arrybn/git/caffe/python')
|
||||
try:
|
||||
import caffe
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find caffe. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to git/caffe/python dir')
|
||||
# sys.path.append('<path to opencv_build_dir/lib>')
|
||||
sys.path.append('/home/arrybn/build/opencv_w_contrib/lib')
|
||||
try:
|
||||
import cv2 as cv
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find opencv. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to opencv_build_dir/lib dir')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
|
||||
parser.add_argument("--img_cls_file", help="path to file with classes ids for images, val.txt file from this "
|
||||
"archive: http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz")
|
||||
parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
|
||||
"https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt")
|
||||
parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
|
||||
"http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel")
|
||||
parser.add_argument("--log", help="path to logging file")
|
||||
parser.add_argument("--batch_size", help="size of images in batch", default=500, type=int)
|
||||
parser.add_argument("--frame_size", help="size of input image", default=224, type=int)
|
||||
parser.add_argument("--in_blob", help="name for input blob", default='data')
|
||||
parser.add_argument("--out_blob", help="name for output blob", default='prob')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_fetcher = MeanChannelsFetch(args.frame_size, args.imgs_dir)
|
||||
|
||||
frameworks = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob),
|
||||
DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]
|
||||
|
||||
acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
|
||||
acc_eval.process(frameworks, data_fetcher)
|
79
modules/dnn/test/imagenet_cls_test_inception.py
Normal file
79
modules/dnn/test/imagenet_cls_test_inception.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.platform import gfile
|
||||
from imagenet_cls_test_alexnet import MeanValueFetch, DnnCaffeModel, Framework, ClsAccEvaluation
|
||||
# sys.path.append('<path to opencv_build_dir/lib>')
|
||||
sys.path.append('/home/arrybn/build/opencv_w_contrib/lib')
|
||||
try:
|
||||
import cv2 as cv
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find opencv. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to opencv_build_dir/lib dir')
|
||||
|
||||
# If you've got an exception "Cannot load libmkl_avx.so or libmkl_def.so" or similar, try to export next variable
|
||||
# before runnigng the script:
|
||||
# LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_sequential.so
|
||||
|
||||
|
||||
class TensorflowModel(Framework):
|
||||
sess = tf.Session
|
||||
output = tf.Graph
|
||||
|
||||
def __init__(self, model_file, in_blob_name, out_blob_name):
|
||||
self.in_blob_name = in_blob_name
|
||||
self.sess = tf.Session()
|
||||
with gfile.FastGFile(model_file, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
self.sess.graph.as_default()
|
||||
tf.import_graph_def(graph_def, name='')
|
||||
self.output = self.sess.graph.get_tensor_by_name(out_blob_name + ":0")
|
||||
|
||||
def get_name(self):
|
||||
return 'Tensorflow'
|
||||
|
||||
def get_output(self, input_blob):
|
||||
assert len(input_blob.shape) == 4
|
||||
batch_tf = input_blob.transpose(0, 2, 3, 1)
|
||||
out = self.sess.run(self.output,
|
||||
{self.in_blob_name+':0': batch_tf})
|
||||
out = out[..., 1:1001]
|
||||
return out
|
||||
|
||||
|
||||
class DnnTfInceptionModel(DnnCaffeModel):
|
||||
net = cv.dnn.Net()
|
||||
|
||||
def __init__(self, model_file, in_blob_name, out_blob_name):
|
||||
self.net = cv.dnn.readNetFromTensorflow(model_file)
|
||||
self.in_blob_name = in_blob_name
|
||||
self.out_blob_name = out_blob_name
|
||||
|
||||
def get_output(self, input_blob):
|
||||
return super(DnnTfInceptionModel, self).get_output(input_blob)[..., 1:1001]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
|
||||
parser.add_argument("--img_cls_file", help="path to file with classes ids for images, download it here:"
|
||||
"https://github.com/opencv/opencv_extra/tree/master/testdata/dnn/img_classes_inception.txt")
|
||||
parser.add_argument("--model", help="path to tensorflow model, download it here:"
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip")
|
||||
parser.add_argument("--log", help="path to logging file")
|
||||
parser.add_argument("--batch_size", help="size of images in batch", default=1)
|
||||
parser.add_argument("--frame_size", help="size of input image", default=224)
|
||||
parser.add_argument("--in_blob", help="name for input blob", default='input')
|
||||
parser.add_argument("--out_blob", help="name for output blob", default='softmax2')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_fetcher = MeanValueFetch(args.frame_size, args.imgs_dir, True)
|
||||
|
||||
frameworks = [TensorflowModel(args.model, args.in_blob, args.out_blob),
|
||||
DnnTfInceptionModel(args.model, '', args.out_blob)]
|
||||
|
||||
acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
|
||||
acc_eval.process(frameworks, data_fetcher)
|
225
modules/dnn/test/pascal_semsegm_test_fcn.py
Normal file
225
modules/dnn/test/pascal_semsegm_test_fcn.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import numpy as np
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from imagenet_cls_test_alexnet import CaffeModel, DnnCaffeModel
|
||||
sys.path.append('/home/arrybn/build/opencv_w_contrib/lib')
|
||||
try:
|
||||
import cv2 as cv
|
||||
except ImportError:
|
||||
raise ImportError('Can\'t find opencv. If you\'ve built it from sources without installation, '
|
||||
'uncomment the line before and insert there path to opencv_build_dir/lib dir')
|
||||
|
||||
|
||||
def get_metrics(conf_mat):
|
||||
pix_accuracy = np.trace(conf_mat) / np.sum(conf_mat)
|
||||
t = np.sum(conf_mat, 1)
|
||||
num_cl = np.count_nonzero(t)
|
||||
assert num_cl
|
||||
mean_accuracy = np.sum(np.nan_to_num(np.divide(np.diagonal(conf_mat), t))) / num_cl
|
||||
col_sum = np.sum(conf_mat, 0)
|
||||
mean_iou = np.sum(
|
||||
np.nan_to_num(np.divide(np.diagonal(conf_mat), (t + col_sum - np.diagonal(conf_mat))))) / num_cl
|
||||
return pix_accuracy, mean_accuracy, mean_iou
|
||||
|
||||
|
||||
def eval_segm_result(net_out):
|
||||
assert type(net_out) is np.ndarray
|
||||
assert len(net_out.shape) == 4
|
||||
|
||||
channels_dim = 1
|
||||
y_dim = channels_dim + 1
|
||||
x_dim = y_dim + 1
|
||||
res = np.zeros(net_out.shape).astype(np.int)
|
||||
for i in range(net_out.shape[y_dim]):
|
||||
for j in range(net_out.shape[x_dim]):
|
||||
max_ch = np.argmax(net_out[..., i, j])
|
||||
res[0, max_ch, i, j] = 1
|
||||
return res
|
||||
|
||||
|
||||
def get_conf_mat(gt, prob):
|
||||
assert type(gt) is np.ndarray
|
||||
assert type(prob) is np.ndarray
|
||||
|
||||
conf_mat = np.zeros((gt.shape[0], gt.shape[0]))
|
||||
for ch_gt in range(conf_mat.shape[0]):
|
||||
gt_channel = gt[ch_gt, ...]
|
||||
for ch_pr in range(conf_mat.shape[1]):
|
||||
prob_channel = prob[ch_pr, ...]
|
||||
conf_mat[ch_gt][ch_pr] = np.count_nonzero(np.multiply(gt_channel, prob_channel))
|
||||
return conf_mat
|
||||
|
||||
|
||||
class MeanChannelsPreproc:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def process(img):
|
||||
image_data = np.array(img).transpose(2, 0, 1).astype(np.float32)
|
||||
mean = np.ones(image_data.shape)
|
||||
mean[0] *= 104
|
||||
mean[1] *= 117
|
||||
mean[2] *= 123
|
||||
image_data -= mean
|
||||
image_data = np.expand_dims(image_data, 0)
|
||||
return image_data
|
||||
|
||||
|
||||
class DatasetImageFetch(object):
|
||||
__metaclass__ = ABCMeta
|
||||
data_prepoc = object
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def next(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def pix_to_c(pix):
|
||||
return pix[0] * 256 * 256 + pix[1] * 256 + pix[2]
|
||||
|
||||
@staticmethod
|
||||
def color_to_gt(color_img, colors):
|
||||
num_classes = len(colors)
|
||||
gt = np.zeros((num_classes, color_img.shape[0], color_img.shape[1])).astype(np.int)
|
||||
for img_y in range(color_img.shape[0]):
|
||||
for img_x in range(color_img.shape[1]):
|
||||
c = DatasetImageFetch.pix_to_c(color_img[img_y][img_x])
|
||||
if c in colors:
|
||||
cls = colors.index(c)
|
||||
gt[cls][img_y][img_x] = 1
|
||||
return gt
|
||||
|
||||
|
||||
class PASCALDataFetch(DatasetImageFetch):
|
||||
img_dir = ''
|
||||
segm_dir = ''
|
||||
names = []
|
||||
colors = []
|
||||
i = 0
|
||||
|
||||
def __init__(self, img_dir, segm_dir, names_file, segm_cls_colors_file, preproc):
|
||||
self.img_dir = img_dir
|
||||
self.segm_dir = segm_dir
|
||||
self.colors = self.read_colors(segm_cls_colors_file)
|
||||
self.data_prepoc = preproc
|
||||
self.i = 0
|
||||
|
||||
with open(names_file) as f:
|
||||
for l in f.readlines():
|
||||
self.names.append(l.rstrip())
|
||||
|
||||
@staticmethod
|
||||
def read_colors(img_classes_file):
|
||||
result = []
|
||||
with open(img_classes_file) as f:
|
||||
for l in f.readlines():
|
||||
color = np.array(map(int, l.split()[1:]))
|
||||
result.append(DatasetImageFetch.pix_to_c(color))
|
||||
return result
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.i < len(self.names):
|
||||
name = self.names[self.i]
|
||||
self.i += 1
|
||||
segm_file = self.segm_dir + name + ".png"
|
||||
img_file = self.img_dir + name + ".jpg"
|
||||
gt = self.color_to_gt(cv.imread(segm_file, cv.IMREAD_COLOR)[:, :, ::-1], self.colors)
|
||||
img = self.data_prepoc.process(cv.imread(img_file, cv.IMREAD_COLOR)[:, :, ::-1])
|
||||
return img, gt
|
||||
else:
|
||||
self.i = 0
|
||||
raise StopIteration
|
||||
|
||||
def get_num_classes(self):
|
||||
return len(self.colors)
|
||||
|
||||
|
||||
class SemSegmEvaluation:
|
||||
log = file
|
||||
|
||||
def __init__(self, log_path,):
|
||||
self.log = open(log_path, 'w')
|
||||
|
||||
def process(self, frameworks, data_fetcher):
|
||||
samples_handled = 0
|
||||
|
||||
conf_mats = [np.zeros((data_fetcher.get_num_classes(), data_fetcher.get_num_classes())) for i in range(len(frameworks))]
|
||||
blobs_l1_diff = [0] * len(frameworks)
|
||||
blobs_l1_diff_count = [0] * len(frameworks)
|
||||
blobs_l_inf_diff = [sys.float_info.min] * len(frameworks)
|
||||
inference_time = [0.0] * len(frameworks)
|
||||
|
||||
for in_blob, gt in data_fetcher:
|
||||
frameworks_out = []
|
||||
samples_handled += 1
|
||||
for i in range(len(frameworks)):
|
||||
start = time.time()
|
||||
out = frameworks[i].get_output(in_blob)
|
||||
end = time.time()
|
||||
segm = eval_segm_result(out)
|
||||
conf_mats[i] += get_conf_mat(gt, segm[0])
|
||||
frameworks_out.append(out)
|
||||
inference_time[i] += end - start
|
||||
|
||||
pix_acc, mean_acc, miou = get_metrics(conf_mats[i])
|
||||
|
||||
name = frameworks[i].get_name()
|
||||
print >> self.log, samples_handled, 'Pixel accuracy, %s:' % name, 100 * pix_acc
|
||||
print >> self.log, samples_handled, 'Mean accuracy, %s:' % name, 100 * mean_acc
|
||||
print >> self.log, samples_handled, 'Mean IOU, %s:' % name, 100 * miou
|
||||
print >> self.log, "Inference time, ms ", \
|
||||
frameworks[i].get_name(), inference_time[i] / samples_handled * 1000
|
||||
|
||||
for i in range(1, len(frameworks)):
|
||||
log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
|
||||
diff = np.abs(frameworks_out[0] - frameworks_out[i])
|
||||
l1_diff = np.sum(diff) / diff.size
|
||||
print >> self.log, samples_handled, "L1 difference", log_str, l1_diff
|
||||
blobs_l1_diff[i] += l1_diff
|
||||
blobs_l1_diff_count[i] += 1
|
||||
if np.max(diff) > blobs_l_inf_diff[i]:
|
||||
blobs_l_inf_diff[i] = np.max(diff)
|
||||
print >> self.log, samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i]
|
||||
|
||||
self.log.flush()
|
||||
|
||||
for i in range(1, len(blobs_l1_diff)):
|
||||
log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
|
||||
print >> self.log, 'Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i]
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--imgs_dir", help="path to PASCAL VOC 2012 images dir, data/VOC2012/JPEGImages")
|
||||
parser.add_argument("--segm_dir", help="path to PASCAL VOC 2012 segmentation dir, data/VOC2012/SegmentationClass/")
|
||||
parser.add_argument("--val_names", help="path to file with validation set image names, download it here: "
|
||||
"https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/data/pascal/seg11valid.txt")
|
||||
parser.add_argument("--cls_file", help="path to file with colors for classes, download it here: "
|
||||
"https://github.com/opencv/opencv_contrib/blob/master/modules/dnn/samples/pascal-classes.txt")
|
||||
parser.add_argument("--prototxt", help="path to caffe prototxt, download it here: "
|
||||
"https://github.com/opencv/opencv_contrib/blob/master/modules/dnn/samples/fcn8s-heavy-pascal.prototxt")
|
||||
parser.add_argument("--caffemodel", help="path to caffemodel file, download it here: "
|
||||
"http://dl.caffe.berkeleyvision.org/fcn8s-heavy-pascal.caffemodel")
|
||||
parser.add_argument("--log", help="path to logging file")
|
||||
parser.add_argument("--in_blob", help="name for input blob", default='data')
|
||||
parser.add_argument("--out_blob", help="name for output blob", default='score')
|
||||
args = parser.parse_args()
|
||||
|
||||
prep = MeanChannelsPreproc()
|
||||
df = PASCALDataFetch(args.imgs_dir, args.segm_dir, args.val_names, args.cls_file, prep)
|
||||
|
||||
fw = [CaffeModel(args.prototxt, args.caffemodel, args.in_blob, args.out_blob, True),
|
||||
DnnCaffeModel(args.prototxt, args.caffemodel, '', args.out_blob)]
|
||||
|
||||
segm_eval = SemSegmEvaluation(args.log)
|
||||
segm_eval.process(fw, df)
|
Reference in New Issue
Block a user