1
0
mirror of https://github.com/opencv/opencv_contrib.git synced 2025-10-25 04:26:17 +08:00

Added Torch ENet support

This commit is contained in:
arrybn
2016-12-23 10:56:01 +03:00
parent c7cc1b3fe6
commit e784f137f6
26 changed files with 1116 additions and 183 deletions

View File

@@ -1,5 +1,6 @@
#include "../precomp.hpp"
#include "elementwise_layers.hpp"
#include "opencv2/imgproc.hpp"
namespace cv
{
@@ -42,5 +43,45 @@ Ptr<PowerLayer> PowerLayer::create(double power /*= 1*/, double scale /*= 1*/, d
return Ptr<PowerLayer>(new ElementWiseLayer<PowerFunctor>(f));
}
////////////////////////////////////////////////////////////////////////////
void ChannelsPReLULayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(blobs.size() == 1);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
outputs[i].create(inputs[i]->shape());
}
}
void ChannelsPReLULayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
Blob &inpBlob = *inputs[0];
for (size_t ii = 0; ii < outputs.size(); ii++)
{
Blob &outBlob = outputs[ii];
CV_Assert(blobs[0].total() == inpBlob.channels());
for (int n = 0; n < inpBlob.channels(); n++)
{
float slopeWeight = blobs[0].matRefConst().at<float>(n);
cv::threshold(inpBlob.getPlane(0, n), outBlob.getPlane(0, n), 0, 0, cv::THRESH_TOZERO_INV);
outBlob.getPlane(0, n) = inpBlob.getPlane(0, n) + (slopeWeight - 1)*outBlob.getPlane(0, n);
}
}
}
Ptr<ChannelsPReLULayer> ChannelsPReLULayer::create()
{
return Ptr<ChannelsPReLULayer>(new ChannelsPReLULayerImpl());
}
}
}
}