// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

// NOTE(review): the angle-bracket include targets and the template argument lists in this
// file were reconstructed from how the names are used below (explicit instantiations,
// atomicAdd targets, CV_Assert, std::size_t, __half) - confirm against the upstream file.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "math.hpp"
#include "types.hpp"
#include "atomics.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"

#include <opencv2/core.hpp>

#include <cstddef>

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

    namespace raw {
        /* Adds `input[idx] / inner_size` to `means[idx / inner_size]` for every input element.
         *
         * The kernel only accumulates with atomicAdd; it never initializes `means`.
         * NOTE(review): presumably the caller zero-fills `means` before launching - confirm.
         */
        template <class T>
        __global__ void reduce_mean(Span<float> means, View<T> input, size_type inner_size) {
            for (auto idx : grid_stride_range(input.size())) {
                const index_type outer_idx = idx / inner_size;
                atomicAdd(&means[outer_idx], static_cast<float>(input[idx]) / inner_size);
            }
        }

        /* Same accumulation as `reduce_mean`, and additionally adds `x * x` (not divided by
         * `inner_size`) to `sum_sqrs[idx / inner_size]`.
         *
         * Both outputs are only accumulated into, never initialized.
         * NOTE(review): presumably the caller zero-fills both buffers beforehand - confirm.
         */
        template <class T>
        __global__ void reduce_mean_sqr_sum(Span<float> means, Span<float> sum_sqrs, View<T> input, size_type inner_size) {
            for (auto idx : grid_stride_range(input.size())) {
                const index_type outer_idx = idx / inner_size;
                auto x = static_cast<float>(input[idx]);
                atomicAdd(&means[outer_idx], x / inner_size);
                atomicAdd(&sum_sqrs[outer_idx], x * x);
            }
        }

        /* Writes `scale[idx] = rsqrt(eps + var)` where
         * `var = sums_sqr[idx] / inner_size - means[idx] * means[idx]`.
         */
        __global__ void compute_normalization_scale(Span<float> scale, View<float> means, View<float> sums_sqr, size_type inner_size, float eps) {
            for (auto idx : grid_stride_range(scale.size())) {
                auto mean = means[idx];
                auto var = sums_sqr[idx] / inner_size - mean * mean;

                // The subtraction above can round to a slightly negative value even though a
                // variance is never negative; without the clamp a small `eps` would make
                // rsqrt() return NaN. Written as `var < 0` so that a NaN `var` still propagates.
                var = (var < 0.0f) ? 0.0f : var;

                using device::rsqrt;
                scale[idx] = rsqrt(eps + var);
            }
        }

        /* Writes `output[idx] = input[idx] - means[idx / inner_size]`; the subtraction is
         * carried out in float and converted to `T` on assignment.
         */
        template <class T>
        __global__ void normalize_mean(Span<T> output, View<T> input, View<float> means, size_type inner_size) {
            for (auto idx : grid_stride_range(output.size())) {
                const index_type outer_idx = idx / inner_size;
                output[idx] = static_cast<float>(input[idx]) - means[outer_idx];
            }
        }

        /* Writes `output[idx] = (input[idx] - means[o]) * scale[o]` with `o = idx / inner_size`;
         * the arithmetic is carried out in float and converted to `T` on assignment.
         */
        template <class T>
        __global__ void normalize_mean_variance(Span<T> output, View<T> input, View<float> means, View<float> scale, size_type inner_size) {
            for (auto idx : grid_stride_range(output.size())) {
                const index_type outer_idx = idx / inner_size;
                output[idx] = (static_cast<float>(input[idx]) - means[outer_idx]) * scale[outer_idx];
            }
        }
    }

    /* Launches `raw::reduce_mean` on `stream` over all of `input`.
     *
     * Requires `inner_size > 0`, `input.size()` to be a multiple of `inner_size`, and
     * `means.size() == input.size() / inner_size`; violations raise via CV_Assert.
     */
    template <class T>
    void reduce_mean(const Stream& stream, Span<float> means, View<T> input, std::size_t inner_size)
    {
        // Guard the host-side division below, and make sure the kernel's
        // `idx / inner_size` can never reach `means.size()`.
        CV_Assert(inner_size > 0);
        CV_Assert(input.size() % inner_size == 0);
        CV_Assert(input.size() / inner_size == means.size());

        auto kernel = raw::reduce_mean<T>;
        auto policy = make_policy(kernel, input.size(), 0, stream);
        launch_kernel(kernel, policy, means, input, inner_size);
    }

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void reduce_mean(const Stream&, Span<float>, View<__half>, std::size_t);
#endif
    template void reduce_mean(const Stream&, Span<float>, View<float>, std::size_t);

    /* Launches `raw::reduce_mean_sqr_sum` on `stream` over all of `input`.
     *
     * Requires `inner_size > 0`, `input.size()` to be a multiple of `inner_size`, and both
     * `means` and `sum_sqrs` to hold `input.size() / inner_size` elements; violations raise
     * via CV_Assert.
     */
    template <class T>
    void reduce_mean_sqr_sum(const Stream& stream, Span<float> means, Span<float> sum_sqrs, View<T> input, std::size_t inner_size)
    {
        // Guard the host-side division below, and make sure the kernel's
        // `idx / inner_size` stays inside both output buffers.
        CV_Assert(inner_size > 0);
        CV_Assert(input.size() % inner_size == 0);
        CV_Assert(input.size() / inner_size == means.size());
        CV_Assert(input.size() / inner_size == sum_sqrs.size());

        auto kernel = raw::reduce_mean_sqr_sum<T>;
        auto policy = make_policy(kernel, input.size(), 0, stream);
        launch_kernel(kernel, policy, means, sum_sqrs, input, inner_size);
    }

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<__half>, std::size_t);
#endif
    template void reduce_mean_sqr_sum(const Stream&, Span<float>, Span<float>, View<float>, std::size_t);

    /* Launches `raw::compute_normalization_scale` on `stream`, one element per entry of `scale`.
     *
     * Requires `means` and `sum_sqrs` to have the same size as `scale`; violations raise
     * via CV_Assert.
     */
    void compute_normalization_scale(const Stream& stream, Span<float> scale, View<float> means, View<float> sum_sqrs, std::size_t inner_size, float eps)
    {
        CV_Assert(scale.size() == means.size());
        CV_Assert(scale.size() == sum_sqrs.size());

        auto kernel = raw::compute_normalization_scale;
        auto policy = make_policy(kernel, scale.size(), 0, stream);
        launch_kernel(kernel, policy, scale, means, sum_sqrs, inner_size, eps);
    }

    /* Launches `raw::normalize_mean` on `stream` over all of `output`.
     *
     * Requires `output.size() == input.size()`, `inner_size > 0`, `input.size()` to be a
     * multiple of `inner_size`, and `means.size() == input.size() / inner_size`; violations
     * raise via CV_Assert.
     */
    template <class T>
    void normalize_mean(const Stream& stream, Span<T> output, View<T> input, View<float> means, std::size_t inner_size)
    {
        CV_Assert(output.size() == input.size());
        // Guard the host-side division below, and make sure the kernel's
        // `idx / inner_size` can never reach `means.size()`.
        CV_Assert(inner_size > 0);
        CV_Assert(input.size() % inner_size == 0);
        CV_Assert(input.size() / inner_size == means.size());

        auto kernel = raw::normalize_mean<T>;
        auto policy = make_policy(kernel, output.size(), 0, stream);
        launch_kernel(kernel, policy, output, input, means, inner_size);
    }

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void normalize_mean(const Stream&, Span<__half>, View<__half>, View<float>, std::size_t);
#endif
    template void normalize_mean(const Stream&, Span<float>, View<float>, View<float>, std::size_t);

    /* Launches `raw::normalize_mean_variance` on `stream` over all of `output`.
     *
     * Requires `input.size() == output.size()`, `inner_size > 0`, `input.size()` to be a
     * multiple of `inner_size`, and both `means` and `scale` to hold
     * `input.size() / inner_size` elements; violations raise via CV_Assert.
     */
    template <class T>
    void normalize_mean_variance(const Stream& stream, Span<T> output, View<T> input, View<float> means, View<float> scale, std::size_t inner_size)
    {
        CV_Assert(input.size() == output.size());
        // Guard the host-side division below, and make sure the kernel's
        // `idx / inner_size` stays inside `means` and `scale`.
        CV_Assert(inner_size > 0);
        CV_Assert(input.size() % inner_size == 0);
        CV_Assert(input.size() / inner_size == means.size());
        CV_Assert(input.size() / inner_size == scale.size());

        auto kernel = raw::normalize_mean_variance<T>;
        auto policy = make_policy(kernel, output.size(), 0, stream);
        launch_kernel(kernel, policy, output, input, means, scale, inner_size);
    }

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
    template void normalize_mean_variance(const Stream&, Span<__half>, View<__half>, View<float>, View<float>, std::size_t);
#endif
    template void normalize_mean_variance(const Stream&, Span<float>, View<float>, View<float>, View<float>, std::size_t);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */