Spaces:
Running
Running
File size: 897 Bytes
26925fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn.functional as F
from collections import defaultdict
def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at ``padding_idx + 1``. Padding symbols are
    ignored: they do not advance the counter and their slots hold
    ``padding_idx`` in the result.

    Args:
        tensor: Tensor of symbol ids. Positions are counted along dim 1,
            so it must have at least two dimensions.
        padding_idx: Id that marks padding entries in ``tensor``.

    Returns:
        A long tensor with the same shape as ``tensor``.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    # Running count of non-padding symbols, zeroed at padding slots by the
    # multiplication, then shifted so padding slots equal padding_idx.
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx
def fill_with_neg_inf2(t):
    """FP16-compatible fill of a tensor with a large negative constant.

    Despite the name, the fill value is the finite number ``-1e8`` rather
    than a true ``-inf``. The fill is performed on a float32 version of
    ``t`` and the result is cast back to the type of ``t``.

    Args:
        t: Tensor whose shape and type determine the result.

    Returns:
        A tensor of the same shape and type as ``t`` with every element
        set to ``-1e8`` (as representable in that type).

    Note:
        When ``t`` is already float32, ``t.float()`` hands back ``t``
        itself, so the in-place ``fill_`` also overwrites the caller's
        tensor. For other dtypes the fill happens on a converted copy and
        ``t`` is left untouched.
    """
    return t.float().fill_(-1e8).type_as(t)
def softmax(x, dim):
    """Apply softmax to ``x`` along ``dim``, computing in float32.

    Args:
        x: Input tensor.
        dim: Dimension along which softmax is computed.

    Returns:
        A float32 tensor with the same shape as ``x``; the ``dtype``
        argument makes the result float32 regardless of the input dtype.
    """
    return F.softmax(x, dim=dim, dtype=torch.float32)
|