import torch | |
from torch.autograd import Function | |
from torch.cuda.amp import custom_bwd, custom_fwd | |
class _trunc_exp(Function): | |
def forward(ctx, x): | |
ctx.save_for_backward(x) | |
return torch.exp(x) | |
def backward(ctx, g): | |
x = ctx.saved_tensors[0] | |
return g * torch.exp(x.clamp(-15, 15)) | |
# Public entry point: call as trunc_exp(x) instead of _trunc_exp.apply(x).
trunc_exp = _trunc_exp.apply