Find3D / Pointcept /libs /pointops2 /src /sampling /sampling_cuda_kernel.cu
ziqima's picture
initial commit
4893ce0
raw
history blame
5.48 kB
#include "../cuda_utils.h"
#include "sampling_cuda_kernel.h"
// Merge step of the arg-max reduction: slot idx1 ends up holding the larger of
// the two candidate distances at idx1 / idx2, together with the point index
// that belongs to it. On a tie the entry already stored at idx1 is kept.
__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) {
    // Load both candidates completely before storing anything.
    const float keep_dist = dists[idx1];
    const float cand_dist = dists[idx2];
    const int keep_index = dists_i[idx1];
    const int cand_index = dists_i[idx2];

    int winner;
    if (cand_dist > keep_dist) {
        winner = cand_index;
    } else {
        winner = keep_index;
    }

    dists[idx1] = max(keep_dist, cand_dist);
    dists_i[idx1] = winner;
}
// Furthest point sampling over several point sets packed back to back in xyz.
//
// Launch layout: one thread block per packed entry (blockIdx.x selects the
// entry). blockDim.x must equal the template argument block_size, because
// block_size is used both as the length of the shared arrays and as the
// per-thread stride over the points. The reduction halves the active range on
// every step, so block_size must be a power of two.
// Shared memory: 2 * block_size * 4 bytes, statically allocated.
//
// xyz:        (n, 3) point coordinates, read as xyz[k * 3 + c].
// offset:     offset[i] is the exclusive end of entry i's points in xyz;
//             entry 0 starts at 0, entry i > 0 starts at offset[i - 1].
// new_offset: same layout as offset, but for the output slots in idx.
// tmp:        indexed by the global point index k (so at least n floats, not
//             (b, n_max)). Each visit stores min(previous value, squared
//             distance to the most recently selected point).
//             NOTE(review): the initial contents come from the caller and are
//             presumably a large sentinel -- confirm at the call site.
// idx:        output; idx[j] receives the global point index picked for slot j.
//             The first slot of every entry is always the entry's first point.
template <unsigned int block_size>
__global__ void furthestsampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx)
{
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    const int bid = blockIdx.x;
    const int tid = threadIdx.x;
    const int stride = block_size;

    // Half-open ranges [start_n, end_n) of input points and [start_m, end_m)
    // of output slots owned by this block.
    const int start_n = (bid == 0) ? 0 : offset[bid - 1];
    const int end_n = offset[bid];
    const int start_m = (bid == 0) ? 0 : new_offset[bid - 1];
    const int end_m = new_offset[bid];

    // No output slot belongs to this entry: idx[start_m] would be another
    // entry's slot (or past the end of idx), so leave without writing. The
    // condition is identical for every thread of the block, hence no thread
    // is left waiting at a later barrier.
    if (start_m >= end_m) return;

    // Seed the selection with the first point of the entry.
    int old = start_n;
    if (tid == 0) idx[start_m] = start_n;
    // Barrier kept from the original code; idx is never read back in this kernel.
    __syncthreads();

    for (int j = start_m + 1; j < end_m; j++)
    {
        int besti = start_n;
        float best = -1.0f;
        const float x1 = xyz[old * 3 + 0];
        const float y1 = xyz[old * 3 + 1];
        const float z1 = xyz[old * 3 + 2];

        // Each thread visits points start_n + tid, + stride, ... and keeps its
        // own best candidate. A given tmp[k] is only ever touched by one
        // thread of the block, so no synchronisation is needed on tmp.
        for (int k = start_n + tid; k < end_n; k += stride)
        {
            const float dx = xyz[k * 3 + 0] - x1;
            const float dy = xyz[k * 3 + 1] - y1;
            const float dz = xyz[k * 3 + 2] - z1;
            const float d = dx * dx + dy * dy + dz * dz;
            const float d2 = min(d, tmp[k]);
            tmp[k] = d2;
            if (d2 > best) {
                besti = k;
                best = d2;
            }
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Tree reduction in shared memory: after the last step slot 0 holds
        // the block-wide maximum and its point index. The trip count is a
        // compile-time constant; the barrier is reached by every thread.
#pragma unroll
        for (int half = block_size / 2; half > 0; half >>= 1)
        {
            if (tid < half) {
                __update(dists, dists_i, tid, tid + half);
            }
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idx[j] = old;

        // Thread 0 overwrites dists_i[0] at the top of the next iteration.
        // Without this barrier it could do so before slower threads have read
        // the winner above, and those threads would continue with a wrong
        // reference point.
        __syncthreads();
    }
}
// Host-side launcher for furthestsampling_cuda_kernel.
//
// b:          number of thread blocks to launch; the kernel treats each block
//             index as one packed entry of offset / new_offset.
// n:          only used to pick the block size through opt_n_threads().
//             NOTE(review): opt_n_threads is declared outside this file;
//             presumably it returns a power of two in [1, 1024] -- confirm.
// xyz, offset, new_offset, tmp, idx: forwarded unchanged to the kernel.
//
// The launch uses the default stream and no dynamic shared memory. The kernel
// relies on blockDim.x being equal to its template argument, so every case
// passes the same literal for both; the thread count can therefore never
// diverge from the compiled block size, including in the fallback case.
void furthestsampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx)
{
    // A grid with zero blocks is an invalid launch configuration; there is
    // nothing to sample in that case.
    if (b <= 0) return;

    const unsigned int n_threads = opt_n_threads(n);
    switch (n_threads) {
    case 1024:
        furthestsampling_cuda_kernel<1024><<<b, 1024, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 512:
        furthestsampling_cuda_kernel<512><<<b, 512, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 256:
        furthestsampling_cuda_kernel<256><<<b, 256, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 128:
        furthestsampling_cuda_kernel<128><<<b, 128, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 64:
        furthestsampling_cuda_kernel<64><<<b, 64, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 32:
        furthestsampling_cuda_kernel<32><<<b, 32, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 16:
        furthestsampling_cuda_kernel<16><<<b, 16, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 8:
        furthestsampling_cuda_kernel<8><<<b, 8, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 4:
        furthestsampling_cuda_kernel<4><<<b, 4, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 2:
        furthestsampling_cuda_kernel<2><<<b, 2, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    case 1:
        furthestsampling_cuda_kernel<1><<<b, 1, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    default:
        // Unexpected thread count: fall back to the 512 instantiation and
        // launch it with exactly 512 threads. Launching it with n_threads (as
        // before) would make blockDim.x disagree with the kernel's
        // compile-time stride and shared-array size.
        furthestsampling_cuda_kernel<512><<<b, 512, 0>>>(xyz, offset, new_offset, tmp, idx);
        break;
    }
}