/*
* File : prroi_pooling_gpu_impl.cu
* Author : Tete Xiao, Jiayuan Mao
* Email : jasonhsiao97@gmail.com
*
* Distributed under terms of the MIT license.
* Copyright (c) 2017 Megvii Technology Limited.
*/
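/*
 * Precise RoI Pooling (PrRoI Pooling) treats the feature map as a continuous
 * function via bilinear interpolation and computes, for each output bin, the
 * exact integral of that function over the bin divided by the bin area.
 * Because the result is a closed-form integral rather than a sample-based sum,
 * it is differentiable with respect to both the input features and the RoI
 * coordinates, which is what the three kernels below implement.
 */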
#include "prroi_pooling_gpu_impl.cuh"
#include <cstdio>
#include <cfloat>
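/*
 * Helper macros: CUDA_KERNEL_LOOP implements the standard grid-stride loop over
 * `n` work items, CUDA_POST_KERNEL_CHECK aborts on the first kernel-launch
 * error, and CUDA_NUM_THREADS fixes the block size used by all launches below.
 */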
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
#define CUDA_POST_KERNEL_CHECK \
do { \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) { \
fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); \
exit(-1); \
} \
} while(0)
#define CUDA_NUM_THREADS 512
namespace {
static int CUDA_NUM_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
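/* Reads data[h][w] from a (height x width) plane, returning 0 for out-of-bounds
 * coordinates, i.e. the feature map is implicitly zero-padded. */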
__device__ static float PrRoIPoolingGetData(F_DEVPTR_IN data, const int h, const int w, const int height, const int width)
{
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
float retVal = overflow ? 0.0f : data[h * width + w];
return retVal;
}
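/* Bilinear interpolation weight of a grid point offset by (dh, dw) from the
 * sampling location: (1 - |dh|) * (1 - |dw|). */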
__device__ static float PrRoIPoolingGetCoeff(float dh, float dw){
dw = dw > 0 ? dw : -dw;
dh = dh > 0 ? dh : -dh;
return (1.0f - dh) * (1.0f - dw);
}
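/* Closed-form integral of a linear interpolation along one coordinate:
 *   integral_{x = s..t} [(1 - x) * c1 + x * c2] dx
 *     = 0.5 * (t^2 - s^2) * c2 + (t - 0.5 * t^2 - s + 0.5 * s^2) * c1
 * with 0 <= s <= t <= 1, c1 the value at x = 0 and c2 the value at x = 1. */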
__device__ static float PrRoIPoolingSingleCoorIntegral(float s, float t, float c1, float c2) {
return 0.5 * (t * t - s * s) * c2 + (t - 0.5 * t * t - s + 0.5 * s * s) * c1;
}
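/* Bilinear interpolation of the feature map at the continuous location (h, w),
 * accumulating the four integer neighbours weighted by PrRoIPoolingGetCoeff. */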
__device__ static float PrRoIPoolingInterpolation(F_DEVPTR_IN data, const float h, const float w, const int height, const int width){
float retVal = 0.0f;
int h1 = floorf(h);
int w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
h1 = floorf(h)+1;
w1 = floorf(w);
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
h1 = floorf(h);
w1 = floorf(w)+1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
h1 = floorf(h)+1;
w1 = floorf(w)+1;
retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * PrRoIPoolingGetCoeff(h - float(h1), w - float(w1));
return retVal;
}
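/* Exact integral of the bilinearly interpolated feature map over the rectangle
 * [y0, y1] x [x0, x1], assumed to lie inside the single unit cell with corners
 * (s_h, s_w) and (e_h, e_w). The result is a weighted sum of the four corner
 * values, where each weight is the analytic integral of that corner's bilinear
 * coefficient over the rectangle. */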
__device__ static float PrRoIPoolingMatCalculation(F_DEVPTR_IN this_data, const int s_h, const int s_w, const int e_h, const int e_w,
const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)
{
float alpha, beta, lim_alpha, lim_beta, tmp;
float sum_out = 0;
alpha = x0 - float(s_w);
beta = y0 - float(s_h);
lim_alpha = x1 - float(s_w);
lim_beta = y1 - float(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp;
alpha = float(e_w) - x1;
lim_alpha = float(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp;
alpha = x0 - float(s_w);
beta = float(e_h) - y1;
lim_alpha = x1 - float(s_w);
lim_beta = float(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp;
alpha = float(e_w) - x1;
lim_alpha = float(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp;
return sum_out;
}
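/* Atomically accumulates top_diff * coeff into diff[h][w]; out-of-bounds
 * coordinates are silently dropped, mirroring the zero padding used in the
 * forward pass. */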
__device__ static void PrRoIPoolingDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int h, const int w, const int height, const int width, const float coeff)
{
bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
if (!overflow)
atomicAdd(diff + h * width + w, top_diff * coeff);
}
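/* Backward counterpart of PrRoIPoolingMatCalculation: distributes top_diff to
 * the four corners of the unit cell using the same analytic integral weights. */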
__device__ static void PrRoIPoolingMatDistributeDiff(F_DEVPTR_OUT diff, const float top_diff, const int s_h, const int s_w, const int e_h, const int e_w,
const float y0, const float x0, const float y1, const float x1, const int h0, const int w0)
{
float alpha, beta, lim_alpha, lim_beta, tmp;
alpha = x0 - float(s_w);
beta = y0 - float(s_h);
lim_alpha = x1 - float(s_w);
lim_beta = y1 - float(s_h);
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, s_w, h0, w0, tmp);
alpha = float(e_w) - x1;
lim_alpha = float(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp);
alpha = x0 - float(s_w);
beta = float(e_h) - y1;
lim_alpha = x1 - float(s_w);
lim_beta = float(e_h) - y0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp);
alpha = float(e_w) - x1;
lim_alpha = float(e_w) - x0;
tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + 0.5f * alpha * alpha)
* (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta);
PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp);
}
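/*
 * Forward pass. Each thread handles one output element indexed by
 * (n, c, ph, pw). Each row of bottom_rois is (batch_index, x1, y1, x2, y2) in
 * input-image coordinates, mapped onto the feature map by spatial_scale. The
 * output value is the integral of the interpolated feature map over the bin,
 * accumulated cell by cell, divided by the bin area.
 */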
__global__ void PrRoIPoolingForward(
const int nthreads,
F_DEVPTR_IN bottom_data,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_OUT top_data,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const float spatial_scale) {
CUDA_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
bottom_rois += n * 5;
int roi_batch_ind = bottom_rois[0];
float roi_start_w = bottom_rois[1] * spatial_scale;
float roi_start_h = bottom_rois[2] * spatial_scale;
float roi_end_w = bottom_rois[3] * spatial_scale;
float roi_end_h = bottom_rois[4] * spatial_scale;
float roi_width = max(roi_end_w - roi_start_w, ((float)0.0));
float roi_height = max(roi_end_h - roi_start_h, ((float)0.0));
float bin_size_h = roi_height / static_cast<float>(pooled_height);
float bin_size_w = roi_width / static_cast<float>(pooled_width);
const float *this_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
float *this_out = top_data + index;
float win_start_w = roi_start_w + bin_size_w * pw;
float win_start_h = roi_start_h + bin_size_h * ph;
float win_end_w = win_start_w + bin_size_w;
float win_end_h = win_start_h + bin_size_h;
float win_size = max(float(0.0), bin_size_w * bin_size_h);
if (win_size == 0) {
*this_out = 0;
continue;  // empty bin: write zero and let the grid-stride loop move on
}
float sum_out = 0;
int s_w, s_h, e_w, e_h;
s_w = floorf(win_start_w);
e_w = ceilf(win_end_w);
s_h = floorf(win_start_h);
e_h = ceilf(win_end_h);
for (int w_iter = s_w; w_iter < e_w; ++w_iter)
for (int h_iter = s_h; h_iter < e_h; ++h_iter)
sum_out += PrRoIPoolingMatCalculation(this_data, h_iter, w_iter, h_iter + 1, w_iter + 1,
max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),
min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),
height, width);
*this_out = sum_out / win_size;
}
}
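/*
 * Backward pass with respect to the input features: the incoming gradient is
 * divided by the bin area and scattered back onto the feature map with the
 * same per-cell integral weights used in the forward pass (via atomicAdd,
 * since bins from different RoIs may overlap).
 */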
__global__ void PrRoIPoolingBackward(
const int nthreads,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_IN top_diff,
F_DEVPTR_OUT bottom_diff,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const float spatial_scale) {
CUDA_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
bottom_rois += n * 5;
int roi_batch_ind = bottom_rois[0];
float roi_start_w = bottom_rois[1] * spatial_scale;
float roi_start_h = bottom_rois[2] * spatial_scale;
float roi_end_w = bottom_rois[3] * spatial_scale;
float roi_end_h = bottom_rois[4] * spatial_scale;
float roi_width = max(roi_end_w - roi_start_w, (float)0);
float roi_height = max(roi_end_h - roi_start_h, (float)0);
float bin_size_h = roi_height / static_cast<float>(pooled_height);
float bin_size_w = roi_width / static_cast<float>(pooled_width);
const float *this_out_grad = top_diff + index;
float *this_data_grad = bottom_diff + (roi_batch_ind * channels + c) * height * width;
float win_start_w = roi_start_w + bin_size_w * pw;
float win_start_h = roi_start_h + bin_size_h * ph;
float win_end_w = win_start_w + bin_size_w;
float win_end_h = win_start_h + bin_size_h;
float win_size = max(float(0.0), bin_size_w * bin_size_h);
float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size;
int s_w, s_h, e_w, e_h;
s_w = floorf(win_start_w);
e_w = ceilf(win_end_w);
s_h = floorf(win_start_h);
e_h = ceilf(win_end_h);
for (int w_iter = s_w; w_iter < e_w; ++w_iter)
for (int h_iter = s_h; h_iter < e_h; ++h_iter)
PrRoIPoolingMatDistributeDiff(this_data_grad, sum_out, h_iter, w_iter, h_iter + 1, w_iter + 1,
max(win_start_h, float(h_iter)), max(win_start_w, float(w_iter)),
min(win_end_h, float(h_iter) + 1.0), min(win_end_w, float(w_iter + 1.0)),
height, width);
}
}
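/*
 * Backward pass with respect to the RoI coordinates. By the Leibniz rule, the
 * derivative of the bin integral with respect to a window boundary is the line
 * integral of the interpolated feature map along that boundary (g_x1_y, g_x2_y,
 * g_x_y1, g_x_y2 below); differentiating the 1/area factor contributes the
 * (win_end - win_start) * top_data terms. The chain rule through
 * win_start/win_end then maps these partials onto (x1, y1, x2, y2), scaled by
 * spatial_scale, and accumulates them into bottom_diff, which here holds one
 * 5-element row of coordinate gradients per RoI.
 */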
__global__ void PrRoIPoolingCoorBackward(
const int nthreads,
F_DEVPTR_IN bottom_data,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_IN top_data,
F_DEVPTR_IN top_diff,
F_DEVPTR_OUT bottom_diff,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const float spatial_scale) {
CUDA_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
bottom_rois += n * 5;
int roi_batch_ind = bottom_rois[0];
float roi_start_w = bottom_rois[1] * spatial_scale;
float roi_start_h = bottom_rois[2] * spatial_scale;
float roi_end_w = bottom_rois[3] * spatial_scale;
float roi_end_h = bottom_rois[4] * spatial_scale;
float roi_width = max(roi_end_w - roi_start_w, (float)0);
float roi_height = max(roi_end_h - roi_start_h, (float)0);
float bin_size_h = roi_height / static_cast<float>(pooled_height);
float bin_size_w = roi_width / static_cast<float>(pooled_width);
const float *this_out_grad = top_diff + index;
const float *this_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
const float *this_top_data = top_data + index;
float *this_data_grad = bottom_diff + n * 5;
float win_start_w = roi_start_w + bin_size_w * pw;
float win_start_h = roi_start_h + bin_size_h * ph;
float win_end_w = win_start_w + bin_size_w;
float win_end_h = win_start_h + bin_size_h;
float win_size = max(float(0.0), bin_size_w * bin_size_h);
float sum_out = win_size == float(0) ? float(0) : *this_out_grad / win_size;
// WARNING: to be discussed
if (sum_out == 0)
continue;  // no incoming gradient (or empty bin): nothing to accumulate
int s_w, s_h, e_w, e_h;
s_w = floorf(win_start_w);
e_w = ceilf(win_end_w);
s_h = floorf(win_start_h);
e_h = ceilf(win_end_h);
float g_x1_y = 0, g_x2_y = 0, g_x_y1 = 0, g_x_y2 = 0;
for (int h_iter = s_h; h_iter < e_h; ++h_iter) {
g_x1_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter,
min(win_end_h, float(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_start_w, height, width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_start_w, height, width));
g_x2_y += PrRoIPoolingSingleCoorIntegral(max(win_start_h, float(h_iter)) - h_iter,
min(win_end_h, float(h_iter + 1)) - h_iter,
PrRoIPoolingInterpolation(this_bottom_data, h_iter, win_end_w, height, width),
PrRoIPoolingInterpolation(this_bottom_data, h_iter + 1, win_end_w, height, width));
}
for (int w_iter = s_w; w_iter < e_w; ++w_iter) {
g_x_y1 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter,
min(win_end_w, float(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter, height, width),
PrRoIPoolingInterpolation(this_bottom_data, win_start_h, w_iter + 1, height, width));
g_x_y2 += PrRoIPoolingSingleCoorIntegral(max(win_start_w, float(w_iter)) - w_iter,
min(win_end_w, float(w_iter + 1)) - w_iter,
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter, height, width),
PrRoIPoolingInterpolation(this_bottom_data, win_end_h, w_iter + 1, height, width));
}
float partial_x1 = -g_x1_y + (win_end_h - win_start_h) * (*this_top_data);
float partial_y1 = -g_x_y1 + (win_end_w - win_start_w) * (*this_top_data);
float partial_x2 = g_x2_y - (win_end_h - win_start_h) * (*this_top_data);
float partial_y2 = g_x_y2 - (win_end_w - win_start_w) * (*this_top_data);
partial_x1 = partial_x1 / win_size * spatial_scale;
partial_x2 = partial_x2 / win_size * spatial_scale;
partial_y1 = partial_y1 / win_size * spatial_scale;
partial_y2 = partial_y2 / win_size * spatial_scale;
// (b, x1, y1, x2, y2)
this_data_grad[0] = 0;
atomicAdd(this_data_grad + 1, (partial_x1 * (1.0 - float(pw) / pooled_width) + partial_x2 * (1.0 - float(pw + 1) / pooled_width))
* (*this_out_grad));
atomicAdd(this_data_grad + 2, (partial_y1 * (1.0 - float(ph) / pooled_height) + partial_y2 * (1.0 - float(ph + 1) / pooled_height))
* (*this_out_grad));
atomicAdd(this_data_grad + 3, (partial_x2 * float(pw + 1) / pooled_width + partial_x1 * float(pw) / pooled_width)
* (*this_out_grad));
atomicAdd(this_data_grad + 4, (partial_y2 * float(ph + 1) / pooled_height + partial_y1 * float(ph) / pooled_height)
* (*this_out_grad));
}
}
} /* !anonymous namespace */
#ifdef __cplusplus
extern "C" {
#endif
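/*
 * Host-side launchers. Each launch uses at least one thread per output element
 * (top_count = num_rois * channels * pooled_height * pooled_width); the two
 * backward wrappers first zero the gradient buffer with cudaMemsetAsync so the
 * kernels can accumulate into it with atomicAdd.
 */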
void PrRoIPoolingForwardGpu(
cudaStream_t stream,
F_DEVPTR_IN bottom_data,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_OUT top_data,
const int channels_, const int height_, const int width_,
const int pooled_height_, const int pooled_width_,
const float spatial_scale_,
const int top_count) {
PrRoIPoolingForward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
top_count, bottom_data, bottom_rois, top_data,
channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
CUDA_POST_KERNEL_CHECK;
}
void PrRoIPoolingBackwardGpu(
cudaStream_t stream,
F_DEVPTR_IN bottom_data,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_IN top_data,
F_DEVPTR_IN top_diff,
F_DEVPTR_OUT bottom_diff,
const int channels_, const int height_, const int width_,
const int pooled_height_, const int pooled_width_,
const float spatial_scale_,
const int top_count, const int bottom_count) {
cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);
PrRoIPoolingBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
top_count, bottom_rois, top_diff, bottom_diff,
channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
CUDA_POST_KERNEL_CHECK;
}
void PrRoIPoolingCoorBackwardGpu(
cudaStream_t stream,
F_DEVPTR_IN bottom_data,
F_DEVPTR_IN bottom_rois,
F_DEVPTR_IN top_data,
F_DEVPTR_IN top_diff,
F_DEVPTR_OUT bottom_diff,
const int channels_, const int height_, const int width_,
const int pooled_height_, const int pooled_width_,
const float spatial_scale_,
const int top_count, const int bottom_count) {
cudaMemsetAsync(bottom_diff, 0, sizeof(float) * bottom_count, stream);
PrRoIPoolingCoorBackward<<<CUDA_NUM_BLOCKS(top_count), CUDA_NUM_THREADS, 0, stream>>>(
top_count, bottom_data, bottom_rois, top_data, top_diff, bottom_diff,
channels_, height_, width_, pooled_height_, pooled_width_, spatial_scale_);
CUDA_POST_KERNEL_CHECK;
}
} /* !extern "C" */
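/*
 * Usage sketch (illustrative only, not part of the original sources): a minimal
 * host-side call into the forward wrapper, assuming the device buffers have
 * already been allocated and filled, with features of shape 1 x C x H x W,
 * R RoIs stored as rows of (batch_index, x1, y1, x2, y2), and an output of
 * shape R x C x PH x PW. The buffer names and sizes below are made up for the
 * example.
 *
 *   const int R = 4, C = 256, H = 64, W = 64, PH = 7, PW = 7;
 *   const int top_count = R * C * PH * PW;
 *   float *d_features, *d_rois, *d_output;
 *   cudaMalloc((void**)&d_features, sizeof(float) * 1 * C * H * W);
 *   cudaMalloc((void**)&d_rois, sizeof(float) * R * 5);
 *   cudaMalloc((void**)&d_output, sizeof(float) * top_count);
 *   // ... copy features and RoIs to the device ...
 *   PrRoIPoolingForwardGpu(0, d_features, d_rois, d_output,
 *                          C, H, W, PH, PW, 1.0f / 16.0f, top_count);
 *   cudaStreamSynchronize(0);
 */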