#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length no_repeat_ngram_size. One block processes one
// hypothesis (one row of tokens); each thread compares the ngram starting at
// its own column against the trailing (no_repeat_ngram_size - 1)-token
// context of the hypothesis.
__global__ void banRepeatedTokens(long* __restrict__ tokens,
                                  float* __restrict__ lprobs,
                                  int max_predict_len, int vocab_size,
                                  int no_repeat_ngram_size) {
  auto row = blockIdx.x;
  auto col = threadIdx.x;
  auto start = row * max_predict_len + col;
  // The trailing context begins one past the last thread's column.
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;

  // Stage this row's tokens in shared memory. The last thread also loads the
  // trailing context that all threads compare against.
  extern __shared__ long tokens_shm[];
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  // If the context at this position matches the trailing context, ban the
  // token that followed it: generating it now would repeat the full ngram.
  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}
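
// Worked example (illustrative values, not from the original source): for a
// single hypothesis tokens = [5, 7, 5] with step = 2 and
// no_repeat_ngram_size = 2, the launch below uses 2 threads and
// check_start_pos = 2, so the trailing context is tokens_shm[2] = 5.
// Thread 0 sees tokens_shm[0] = 5 match it and sets lprobs[7] = -INFINITY,
// since emitting 7 next would repeat the bigram (5, 7); thread 1 sees
// tokens_shm[1] = 7 != 5 and bans nothing.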
// Launch one block per hypothesis (bsz * beam_size rows) and one thread per
// candidate ngram start position, then run the banning kernel in place on
// lprobs.
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size) {
  // Number of ngram start positions to check; on early steps the sequence is
  // still too short to repeat an ngram, so there is nothing to ban.
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0) return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  // Shared memory holds the step + 1 tokens generated so far for one row.
  int shared_mem_size = (step + 1) * sizeof(long);

  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
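
// A minimal Python binding sketch, not part of the original file: assuming
// this translation unit is compiled as a PyTorch C++/CUDA extension,
// torch/extension.h already pulls in pybind11, and the build system supplies
// the module name through TORCH_EXTENSION_NAME. In a real setup the binding
// often lives in a separate .cpp file alongside input checks.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &ngram_repeat_block_cuda_forward,
        "No-repeat ngram blocking; masks repeated ngram continuations in "
        "lprobs (CUDA)");
}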