mirror of
https://github.com/opencv/opencv_contrib.git
synced 2025-10-19 11:21:39 +08:00
Merge pull request #3943 from troelsy:4.x
Add Otsu's method to cv::cuda::threshold #3943 I implemented Otsu's method in CUDA for a separate project and want to add it to cv::cuda::threshold. I have made an effort to use existing OpenCV functions in my code, but I had some trouble with `ThresholdTypes` and `cv::cuda::calcHist`. I couldn't figure out how to include `precomp.hpp` to get the definition of `ThresholdTypes`. For `cv::cuda::calcHist` I tried adding `opencv_cudaimgproc`, but it creates a circular dependency on `cudaarithm`. I have included a simple implementation of `calcHist` so the code runs, but I would like input on how to use `cv::cuda::calcHist` instead. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [ ] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
@@ -546,12 +546,16 @@ static inline void scaleAdd(InputArray src1, double alpha, InputArray src2, Outp
|
||||
|
||||
/** @brief Applies a fixed-level threshold to each array element.
|
||||
|
||||
The special value cv::THRESH_OTSU may be combined with one of the other types. In this case, the function determines the
|
||||
optimal threshold value using Otsu's method and uses it instead of the specified threshold. The function returns the
|
||||
computed threshold value in addition to the thresholded matrix.
|
||||
The Otsu's method is implemented only for 8-bit matrices.
|
||||
|
||||
@param src Source array (single-channel).
|
||||
@param dst Destination array with the same size and type as src .
|
||||
@param dst Destination array with the same size and type as src.
|
||||
@param thresh Threshold value.
|
||||
@param maxval Maximum value to use with THRESH_BINARY and THRESH_BINARY_INV threshold types.
|
||||
@param type Threshold type. For details, see threshold . The THRESH_OTSU and THRESH_TRIANGLE
|
||||
threshold types are not supported.
|
||||
@param type Threshold type. For details, see threshold. The THRESH_TRIANGLE threshold type is not supported.
|
||||
@param stream Stream for the asynchronous version.
|
||||
|
||||
@sa threshold
|
||||
|
@@ -95,12 +95,232 @@ namespace
|
||||
}
|
||||
}
|
||||
|
||||
double cv::cuda::threshold(InputArray _src, OutputArray _dst, double thresh, double maxVal, int type, Stream& stream)
|
||||
|
||||
// For every candidate threshold, accumulates the pixel count and the
// intensity-weighted pixel count of all histogram bins at or above it.
//
// Launch layout: one block per candidate threshold (blockIdx.x) and one thread
// per histogram bin (threadIdx.x); the reductions are instantiated for 256 bins.
//
// histogram       in:  per-bin pixel counts, read at index threadIdx.x
// threshold_sums  out: threshold_sums[t] = sum of histogram[b]     for b >= t
// sums            out: sums[t]           = sum of b * histogram[b] for b >= t
__global__ void otsu_sums(uint *histogram, uint *threshold_sums, unsigned long long *sums)
{
    const uint32_t n_bins = 256;

    __shared__ uint shared_memory_ts[n_bins];
    __shared__ unsigned long long shared_memory_s[n_bins];

    int bin_idx = threadIdx.x;
    int threshold = blockIdx.x;

    uint threshold_sum_above = 0;
    unsigned long long sum_above = 0;

    // Bins below the threshold contribute nothing to the "above" class.
    if (bin_idx >= threshold)
    {
        uint value = histogram[bin_idx];
        threshold_sum_above = value;

        // Widen before multiplying. A 32-bit product wraps once a single bin
        // holds more than roughly 16.8 million pixels (2^32 / 255), which would
        // corrupt the class means for large images.
        sum_above = (unsigned long long) value * (unsigned long long) bin_idx;
    }

    blockReduce<n_bins>(shared_memory_ts, threshold_sum_above, bin_idx, plus<uint>());
    blockReduce<n_bins>(shared_memory_s, sum_above, bin_idx, plus<unsigned long long>());

    // The reduced totals are taken from thread 0 and stored once per threshold.
    if (bin_idx == 0)
    {
        threshold_sums[threshold] = threshold_sum_above;
        sums[threshold] = sum_above;
    }
}
|
||||
|
||||
// For every candidate threshold, accumulates the sum of squared deviations from
// the class mean, separately for the class at or above the threshold and the
// class below it.
//
// Launch layout: one block per candidate threshold (blockIdx.x) and one thread
// per histogram bin (threadIdx.x); the reductions are instantiated for 256 bins.
//
// variance        out: variance[t] = (above, below) sums of count * (bin - mean)^2,
//                      each per-bin term truncated to an integer before summation
// histogram       in:  per-bin pixel counts
// threshold_sums  in:  per-threshold pixel counts of the "above" class;
//                      entry 0 is used as the total pixel count
// sums            in:  per-threshold intensity sums of the "above" class;
//                      entry 0 is used as the total intensity sum
__global__ void
otsu_variance(float2 *variance, uint *histogram, uint *threshold_sums, unsigned long long *sums)
{
    const uint32_t n_bins = 256;

    __shared__ signed long long shared_memory_a[n_bins];
    __shared__ signed long long shared_memory_b[n_bins];

    int bin_idx = threadIdx.x;
    int threshold = blockIdx.x;

    uint n_samples = threshold_sums[0];
    uint n_samples_above = threshold_sums[threshold];
    uint n_samples_below = n_samples - n_samples_above;

    unsigned long long total_sum = sums[0];
    unsigned long long sum_above = sums[threshold];
    unsigned long long sum_below = total_sum - sum_above;

    float threshold_variance_above_f32 = 0;
    float threshold_variance_below_f32 = 0;
    if (bin_idx >= threshold)
    {
        // An empty class has no mean. Every bin of an empty class has a zero
        // count, so any finite placeholder yields a zero contribution; using one
        // avoids a 0/0 division followed by a NaN-to-integer conversion.
        float mean = n_samples_above > 0 ? (float) sum_above / n_samples_above : 0.0f;
        float sigma = bin_idx - mean;
        threshold_variance_above_f32 = sigma * sigma;
    }
    else
    {
        float mean = n_samples_below > 0 ? (float) sum_below / n_samples_below : 0.0f;
        float sigma = bin_idx - mean;
        threshold_variance_below_f32 = sigma * sigma;
    }

    // Weight the squared deviation by the number of pixels in this bin. The
    // value for the class this bin does not belong to stays zero.
    uint bin_count = histogram[bin_idx];
    signed long long threshold_variance_above_i64 = (signed long long)(threshold_variance_above_f32 * bin_count);
    signed long long threshold_variance_below_i64 = (signed long long)(threshold_variance_below_f32 * bin_count);
    blockReduce<n_bins>(shared_memory_a, threshold_variance_above_i64, bin_idx, plus<signed long long>());
    blockReduce<n_bins>(shared_memory_b, threshold_variance_below_i64, bin_idx, plus<signed long long>());

    // The reduced totals are taken from thread 0 and stored once per threshold.
    if (bin_idx == 0)
    {
        variance[threshold] = make_float2(threshold_variance_above_i64, threshold_variance_below_i64);
    }
}
|
||||
|
||||
|
||||
// Scores every candidate threshold by its weighted within-class variance and
// stores the winning threshold.
//
// Launch layout: a single block with one thread per candidate threshold
// (threadIdx.x); the reduction is instantiated for 256 thresholds.
//
// otsu_threshold  out: single value, (smallest candidate with the minimum score) - 1,
//                      or 0 when no candidate has a comparable (non-NaN) score
// threshold_sums  in:  per-threshold pixel counts of the "above" class;
//                      entry 0 is used as the total pixel count
// variance        in:  per-threshold (above, below) sums of squared deviations
__global__ void
otsu_score(uint *otsu_threshold, uint *threshold_sums, float2 *variance)
{
    const uint32_t n_thresholds = 256;

    __shared__ float shared_memory[n_thresholds / WARP_SIZE];
    // Smallest candidate whose score equals the block-wide minimum.
    __shared__ uint best_threshold;

    int threshold = threadIdx.x;

    uint n_samples = threshold_sums[0];
    uint n_samples_above = threshold_sums[threshold];
    uint n_samples_below = n_samples - n_samples_above;

    // Despite the "mean" in the names, these are the class weights: the
    // fraction of all pixels that fall into each class.
    float threshold_mean_above = (float)n_samples_above / n_samples;
    float threshold_mean_below = (float)n_samples_below / n_samples;

    // Candidates with an empty class divide 0 by 0 here and get a NaN score.
    // NaN never compares equal below, so such candidates are never selected.
    float2 variances = variance[threshold];
    float variance_above = variances.x / n_samples_above;
    float variance_below = variances.y / n_samples_below;

    float above = threshold_mean_above * variance_above;
    float below = threshold_mean_below * variance_below;
    float score = above + below;

    float original_score = score;

    blockReduce<n_thresholds>(shared_memory, score, threshold, minimum<float>());

    // Broadcast the minimum (taken from thread 0) and arm the candidate slot
    // with a sentinel that no real candidate can produce.
    if (threshold == 0)
    {
        shared_memory[0] = score;
        best_threshold = n_thresholds;
    }
    __syncthreads();

    score = shared_memory[0];

    // We found the minimum score, but we need to find the threshold. Several
    // candidates may tie for the minimum; letting each of them write the result
    // directly would race, so the smallest one is selected deterministically.
    if (original_score == score)
    {
        ::atomicMin(&best_threshold, (uint) threshold);
    }
    __syncthreads();

    // Always write a result: when no candidate matched (for example every score
    // is NaN for a constant image) fall back to 0 instead of leaving the
    // output uninitialised.
    if (threshold == 0)
    {
        *otsu_threshold = (best_threshold > 0 && best_threshold < n_thresholds) ? best_threshold - 1 : 0;
    }
}
|
||||
|
||||
// Runs the three Otsu kernels on `stream` and leaves the resulting threshold in
// device memory.
//
// histogram       device pointer to the per-bin pixel counts consumed by the kernels
// otsu_threshold  device pointer to a single uint that receives the result
// stream          stream on which all work is enqueued; this function does not
//                 wait for completion
//
// Kernel launch failures are reported through CV_CUDEV_SAFE_CALL.
void compute_otsu(uint *histogram, uint *otsu_threshold, Stream &stream)
{
    const uint n_bins = 256;
    const uint n_thresholds = 256;

    cudaStream_t cuda_stream = StreamAccessor::getStream(stream);

    // otsu_sums / otsu_variance: one block per threshold, one thread per bin.
    dim3 block_all(n_bins);
    dim3 grid_all(n_thresholds);
    // otsu_score: a single block with one thread per threshold.
    dim3 block_score(n_thresholds);
    dim3 grid_score(1);

    BufferPool pool(stream);
    GpuMat gpu_threshold_sums(1, n_bins, CV_32SC1, pool.getAllocator());
    // Accessed as unsigned long long by the kernels; CV_64FC1 is used here only
    // to obtain 8-byte elements.
    GpuMat gpu_sums(1, n_bins, CV_64FC1, pool.getAllocator());
    GpuMat gpu_variances(1, n_bins, CV_32FC2, pool.getAllocator());

    otsu_sums<<<grid_all, block_all, 0, cuda_stream>>>(
        histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
    CV_CUDEV_SAFE_CALL(cudaGetLastError());

    otsu_variance<<<grid_all, block_all, 0, cuda_stream>>>(
        gpu_variances.ptr<float2>(), histogram, gpu_threshold_sums.ptr<uint>(), gpu_sums.ptr<unsigned long long>());
    CV_CUDEV_SAFE_CALL(cudaGetLastError());

    otsu_score<<<grid_score, block_score, 0, cuda_stream>>>(
        otsu_threshold, gpu_threshold_sums.ptr<uint>(), gpu_variances.ptr<float2>());
    CV_CUDEV_SAFE_CALL(cudaGetLastError());
}
|
||||
|
||||
// TODO: Replace this with cv::cuda::calcHist
|
||||
// Builds a histogram of an 8-bit image: each block counts its tile in shared
// memory and then merges the partial counts into the global histogram.
//
// Launch layout: 2D grid of 2D blocks, one thread per pixel; threads outside
// the image only help with zeroing and merging. Any block size works.
//
// histogram  in/out: n_bins counters, incremented atomically (must be zeroed
//                    by the caller beforehand)
// image      in:     first byte of the image
// width      image width in pixels
// height     image height in pixels
// pitch      distance between consecutive rows, in bytes
//
// NOTE(review): pixel values index the shared histogram directly, so n_bins is
// expected to be at least 256 - confirm for any new instantiation.
template <uint n_bins>
__global__ void histogram_kernel(
    uint *histogram, const uint8_t *image, uint width,
    uint height, uint pitch)
{
    __shared__ uint local_histogram[n_bins];

    uint x = blockIdx.x * blockDim.x + threadIdx.x;
    uint y = blockIdx.y * blockDim.y + threadIdx.y;
    uint tid = threadIdx.y * blockDim.x + threadIdx.x;
    const uint block_threads = blockDim.x * blockDim.y;

    // Strided so that every bin is cleared even when the block has fewer
    // threads than there are bins.
    for (uint bin = tid; bin < n_bins; bin += block_threads)
    {
        local_histogram[bin] = 0;
    }

    __syncthreads();

    if (x < width && y < height)
    {
        // 64-bit row offset: y * pitch can exceed 32 bits for very large images.
        uint8_t value = image[(size_t) y * pitch + x];
        atomicInc(&local_histogram[value], 0xFFFFFFFF);
    }

    __syncthreads();

    // Merge this block's counts into the global histogram; empty bins are
    // skipped to avoid pointless global atomics.
    for (uint bin = tid; bin < n_bins; bin += block_threads)
    {
        const uint count = local_histogram[bin];
        if (count != 0)
        {
            cv::cudev::atomicAdd(&histogram[bin], count);
        }
    }
}
|
||||
|
||||
// TODO: Replace this with cv::cuda::calcHist
|
||||
// Computes the 256-bin histogram of an 8-bit single-channel image on `stream`.
//
// src        source image; rows are addressed through src.step
// histogram  receives 256 uint counters; it is zeroed before counting
// stream     stream on which the memset and the kernel are enqueued; this
//            function does not wait for completion
//
// An empty source yields an all-zero histogram. CUDA failures, including an
// invalid kernel launch, are reported through CV_CUDEV_SAFE_CALL.
void calcHist(
    const GpuMat src, GpuMat histogram, Stream stream)
{
    const uint n_bins = 256;

    cudaStream_t cuda_stream = StreamAccessor::getStream(stream);

    CV_CUDEV_SAFE_CALL(cudaMemsetAsync(histogram.ptr<uint>(), 0, n_bins * sizeof(uint), cuda_stream));

    // A zero-sized grid is an invalid launch configuration, so there is nothing
    // to count and nothing to launch for an empty image.
    if (src.rows <= 0 || src.cols <= 0)
    {
        return;
    }

    dim3 block(128, 4, 1);
    dim3 grid = dim3(divUp(src.cols, block.x), divUp(src.rows, block.y), 1);

    histogram_kernel<n_bins>
        <<<grid, block, 0, cuda_stream>>>(
            histogram.ptr<uint>(), src.ptr<uint8_t>(), (uint) src.cols, (uint) src.rows, (uint) src.step);
    CV_CUDEV_SAFE_CALL(cudaGetLastError());
}
|
||||
|
||||
double cv::cuda::threshold(InputArray _src, OutputArray _dst, double thresh, double maxVal, int type, Stream &stream)
|
||||
{
|
||||
GpuMat src = getInputMat(_src, stream);
|
||||
|
||||
const int depth = src.depth();
|
||||
|
||||
const int THRESH_OTSU = 8;
|
||||
if ((type & THRESH_OTSU) == THRESH_OTSU)
|
||||
{
|
||||
CV_Assert(depth == CV_8U);
|
||||
CV_Assert(src.channels() == 1);
|
||||
|
||||
BufferPool pool(stream);
|
||||
|
||||
// Find the threshold using Otsu and then run the normal thresholding algorithm
|
||||
GpuMat gpu_histogram(256, 1, CV_32SC1, pool.getAllocator());
|
||||
calcHist(src, gpu_histogram, stream);
|
||||
|
||||
GpuMat gpu_otsu_threshold(1, 1, CV_32SC1, pool.getAllocator());
|
||||
compute_otsu(gpu_histogram.ptr<uint>(), gpu_otsu_threshold.ptr<uint>(), stream);
|
||||
|
||||
cv::Mat mat_otsu_threshold;
|
||||
gpu_otsu_threshold.download(mat_otsu_threshold, stream);
|
||||
stream.waitForCompletion();
|
||||
|
||||
// Overwrite the threshold value with the Otsu value and remove the Otsu flag from the type
|
||||
type = type & ~THRESH_OTSU;
|
||||
thresh = (double) mat_otsu_threshold.at<int>(0);
|
||||
}
|
||||
|
||||
CV_Assert( depth <= CV_64F );
|
||||
CV_Assert( type <= 4 /*THRESH_TOZERO_INV*/ );
|
||||
|
||||
|
@@ -2529,7 +2529,7 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, AddWeighted, testing::Combine(
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Threshold
|
||||
|
||||
CV_ENUM(ThreshOp, cv::THRESH_BINARY, cv::THRESH_BINARY_INV, cv::THRESH_TRUNC, cv::THRESH_TOZERO, cv::THRESH_TOZERO_INV)
|
||||
CV_ENUM(ThreshOp, cv::THRESH_BINARY, cv::THRESH_BINARY_INV, cv::THRESH_TRUNC, cv::THRESH_TOZERO, cv::THRESH_TOZERO_INV, cv::THRESH_OTSU)
|
||||
#define ALL_THRESH_OPS testing::Values(ThreshOp(cv::THRESH_BINARY), ThreshOp(cv::THRESH_BINARY_INV), ThreshOp(cv::THRESH_TRUNC), ThreshOp(cv::THRESH_TOZERO), ThreshOp(cv::THRESH_TOZERO_INV))
|
||||
|
||||
PARAM_TEST_CASE(Threshold, cv::cuda::DeviceInfo, cv::Size, MatType, Channels, ThreshOp, UseRoi)
|
||||
@@ -2577,6 +2577,54 @@ INSTANTIATE_TEST_CASE_P(CUDA_Arithm, Threshold, testing::Combine(
|
||||
ALL_THRESH_OPS,
|
||||
WHOLE_SUBMAT));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ThresholdOtsu
|
||||
|
||||
// Fixture for the Otsu threshold tests: unpacks the test parameters, ORs
// cv::THRESH_OTSU into the requested operation and selects the CUDA device.
PARAM_TEST_CASE(ThresholdOtsu, cv::cuda::DeviceInfo, cv::Size, MatType, Channels, ThreshOp, UseRoi)
{
    cv::cuda::DeviceInfo devInfo;
    cv::Size size;
    int type;
    int channel;
    int threshOp;
    bool useRoi;

    virtual void SetUp()
    {
        // Bind the device first; the remaining parameters only describe the input.
        devInfo = GET_PARAM(0);
        cv::cuda::setDevice(devInfo.deviceID());

        size = GET_PARAM(1);
        type = GET_PARAM(2);
        channel = GET_PARAM(3);
        useRoi = GET_PARAM(5);

        // Every case runs the base operation combined with the Otsu flag.
        const int baseOp = GET_PARAM(4);
        threshOp = baseOp | cv::THRESH_OTSU;
    }
};
|
||||
|
||||
// Checks that cv::cuda::threshold with the Otsu flag returns the same threshold
// and the same output image as cv::threshold on a random input.
CUDA_TEST_P(ThresholdOtsu, Accuracy)
{
    const int matType = CV_MAKE_TYPE(type, channel);
    cv::Mat src = randomMat(size, matType);

    // CPU reference.
    cv::Mat dst_gold;
    const double otsu_cpu = cv::threshold(src, dst_gold, 0, 255, threshOp);

    // GPU run. The destination is created before the source is uploaded, in the
    // same order as the helper calls were originally issued.
    cv::cuda::GpuMat dst = createMat(src.size(), src.type(), useRoi);
    cv::cuda::GpuMat d_src = loadMat(src, useRoi);
    const double otsu_gpu = cv::cuda::threshold(d_src, dst, 0, 255, threshOp);

    ASSERT_DOUBLE_EQ(otsu_gpu, otsu_cpu);
    EXPECT_MAT_NEAR(dst_gold, dst, 0.0);
}
|
||||
|
||||
// Instantiates ThresholdOtsu for every combination of ALL_DEVICES,
// DIFFERENT_SIZES, CV_8U depth, a single channel, ALL_THRESH_OPS and
// WHOLE_SUBMAT.
INSTANTIATE_TEST_CASE_P(CUDA_Arithm, ThresholdOtsu, testing::Combine(
|
||||
ALL_DEVICES,
|
||||
DIFFERENT_SIZES,
|
||||
testing::Values(MatDepth(CV_8U)),
|
||||
testing::Values(Channels(1)),
|
||||
ALL_THRESH_OPS,
|
||||
WHOLE_SUBMAT));
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// InRange
|
||||
|
||||
|
@@ -62,6 +62,8 @@ template <typename T> __device__ __forceinline__ T saturate_cast(ushort v) { ret
|
||||
// Generic saturate_cast overloads, one per source type. Each converts with a
// plain T(v) and does no clamping of its own.
// NOTE(review): saturating behaviour for particular destination types is
// presumably provided by specializations outside this excerpt - confirm.
template <typename T> __device__ __forceinline__ T saturate_cast(short v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(uint v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(int v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(signed long long v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(unsigned long long v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(float v) { return T(v); }
|
||||
template <typename T> __device__ __forceinline__ T saturate_cast(double v) { return T(v); }
|
||||
|
||||
|
@@ -332,6 +332,16 @@ __device__ __forceinline__ uint shfl_down(uint val, uint delta, int width = warp
|
||||
return (uint) __shfl_down((int) val, delta, width);
|
||||
}
|
||||
|
||||
// shfl_down overload for signed 64-bit values: forwards val, delta and width
// to __shfl_down and returns its result.
__device__ __forceinline__ signed long long shfl_down(signed long long val, uint delta, int width = warpSize)
{
    const signed long long shifted = __shfl_down(val, delta, width);
    return shifted;
}
|
||||
|
||||
// shfl_down overload for unsigned 64-bit values: forwards val, delta and width
// to __shfl_down and converts the result back to unsigned long long.
__device__ __forceinline__ unsigned long long shfl_down(unsigned long long val, uint delta, int width = warpSize)
{
    const unsigned long long shifted = (unsigned long long) __shfl_down(val, delta, width);
    return shifted;
}
|
||||
|
||||
__device__ __forceinline__ float shfl_down(float val, uint delta, int width = warpSize)
|
||||
{
|
||||
return __shfl_down(val, delta, width);
|
||||
|
Reference in New Issue
Block a user