#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

#include <vector>

#include "utils/checks.h"
#include "utils/cuda.cuh"
#include "inplace_abn.h"

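// CUDA kernels for in-place activated batch normalization (InPlace-ABN):
// per-channel mean/variance, the in-place normalization forward pass, the
// gradient reductions and backward pass, and the in-place inversion of the
// leaky ReLU / ELU activations used during backward. reduce(), Pair<T>,
// getNumThreads(), CHECK_CUDA_INPUT and get_dims are assumed to be provided
// by utils/cuda.cuh, utils/checks.h and inplace_abn.h.

// Element-wise functors fed to the block-wide reduce() below: SumOp sums raw
// values (for the mean), VarOp sums squared deviations from a given mean (for
// the variance), and GradOp accumulates the pair (dz, y * dz) needed by the
// backward pass, where y is recomputed from the in-place output z.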
template<typename T>
struct SumOp {
  __device__ SumOp(const T *t, int c, int s)
      : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    return tensor[(batch * chn + plane) * sp + n];
  }
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct VarOp {
  __device__ VarOp(T m, const T *t, int c, int s)
      : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ T operator()(int batch, int plane, int n) {
    T val = tensor[(batch * chn + plane) * sp + n];
    return (val - mean) * (val - mean);
  }
  const T mean;
  const T *tensor;
  const int chn;
  const int sp;
};

template<typename T>
struct GradOp {
  __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
      : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
    T _dz = dz[(batch * chn + plane) * sp + n];
    return Pair<T>(_dz, _y * _dz);
  }
  const T weight;
  const T bias;
  const T *z;
  const T *dz;
  const int chn;
  const int sp;
};

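// mean_var_kernel: one thread block per channel (plane). Each block reduces
// over all batch entries and spatial positions of its channel to obtain the
// per-channel mean and biased variance (both normalized by num * sp).
// reduce() is assumed to be the block-wide reduction helper from
// utils/cuda.cuh, which is why only thread 0 writes the results.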
template<typename T>
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
  int plane = blockIdx.x;
  T norm = T(1) / T(num * sp);

  T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();
  T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {
    mean[plane] = _mean;
    var[plane] = _var;
  }
}

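// Host wrapper: validates the input, flattens it to (num, chn, sp) via
// get_dims, allocates the per-channel outputs, and launches mean_var_kernel
// with one block per channel and a block size chosen by getNumThreads(sp) on
// the current CUDA stream.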
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  auto mean = at::empty({chn}, x.options());
  auto var = at::empty({chn}, x.options());

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
    mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        num, chn, sp);
  }));

  return {mean, var};
}

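// forward_kernel: in-place batch normalization. For every element of channel
// `plane` it computes
//   y = (x - mean) * (|weight| + eps) / sqrt(var + eps) + bias   (affine)
//   y = (x - mean) / sqrt(var + eps)                             (non-affine)
// and overwrites x with y. Taking |weight| + eps keeps the scale strictly
// positive, so the backward pass below can recover y = (z - bias) / weight
// from the stored output z.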
template<typename T>
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
                               bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? abs(weight[plane]) + eps : T(1);
  T _bias = affine ? bias[plane] : T(0);

  T mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _x = x[(batch * chn + plane) * sp + n];
      T _y = (_x - _mean) * mul + _bias;

      x[(batch * chn + plane) * sp + n] = _y;
    }
  }
}

at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                        bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
    forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        x.data<scalar_t>(),
        mean.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return x;
}

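// edz_eydz_kernel: per-channel reduction of the incoming gradient. It
// reconstructs the normalized activation y = (z - bias) / weight from the
// in-place output z and accumulates edz = sum(dz) and eydz = sum(y * dz)
// over batch and spatial positions, the two statistics the backward pass
// needs.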
template<typename T>
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
                                T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;

  Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {
    edz[plane] = res.v1;
    eydz[plane] = res.v2;
  }
}

std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                      bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto edz = at::empty({chn}, z.options());
  auto eydz = at::empty({chn}, z.options());

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
    edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return {edz, eydz};
}

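// backward_kernel: gradient with respect to the input. With
// y = (z - bias) / weight, count = num * sp and mul = weight / sqrt(var + eps),
// each element gets
//   dx = (dz - edz / count - y * eydz / count) * mul
// which is the standard batch-norm input gradient expressed in terms of the
// recomputed y instead of the (already overwritten) input x.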
template<typename T>
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
                                const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  T _weight = affine ? abs(weight[plane]) + eps : 1.f;
  T _bias = affine ? bias[plane] : 0.f;
  T _var = var[plane];
  T _edz = edz[plane];
  T _eydz = eydz[plane];

  T _mul = _weight * rsqrt(_var + eps);
  T count = T(num * sp);

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      T _dz = dz[(batch * chn + plane) * sp + n];
      T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
    }
  }
}

at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                         at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  auto dx = at::zeros_like(z);

  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
    backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        z.data<scalar_t>(),
        dz.data<scalar_t>(),
        var.data<scalar_t>(),
        weight.data<scalar_t>(),
        bias.data<scalar_t>(),
        edz.data<scalar_t>(),
        eydz.data<scalar_t>(),
        dx.data<scalar_t>(),
        affine, eps, num, chn, sp);
  }));

  return dx;
}

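// leaky_relu_backward: inverts the in-place leaky ReLU and applies the chain
// rule, both restricted to elements where the stored activation z is
// negative. There dz is scaled by `slope` (the local derivative) and z is
// divided by `slope`, recovering the pre-activation value so the batch-norm
// backward above can reconstruct y from it.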
template<typename T>
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_dz,
                       [slope] __device__ (const T& dz) { return dz * slope; },
                       [] __device__ (const T& z) { return z < 0; });
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [slope] __device__ (const T& z) { return z / slope; },
                       [] __device__ (const T& z) { return z < 0; });
}

void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
    leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
  }));
}

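// elu_backward: same idea for ELU (with alpha = 1, which this implementation
// assumes). Where z < 0 the stored output is z = exp(x) - 1, so the gradient
// is scaled by z + 1 (= d ELU / dx) and the pre-activation is recovered with
// log1p(z).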
template<typename T>
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
  thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
  thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);

  auto stream = at::cuda::getCurrentCUDAStream();
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_dz, th_dz + count, th_z, th_z, th_dz,
                       [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
                       [] __device__ (const T& z) { return z < 0; });
  thrust::transform_if(thrust::cuda::par.on(stream),
                       th_z, th_z + count, th_z,
                       [] __device__ (const T& z) { return log1p(z); },
                       [] __device__ (const T& z) { return z < 0; });
}

void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();

  AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
    elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
  }));
}