import torch
import torch.nn.functional as F
import math
from einops import rearrange
# Return a boolean mask of shape (b, max_len) that is True at valid positions and False at padding.
def lengths_to_mask(lengths, max_len):
    # max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask  # (b, len)
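# Example (for reference): lengths = [2, 3], max_len = 4 gives
#   lengths_to_mask(torch.tensor([2, 3]), 4)
#   -> [[ True,  True, False, False],
#       [ True,  True,  True, False]]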
# Return a mask of shape (b, 1, s) that is True at non-padding tokens (padding positions are False).
def get_pad_mask_idx(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(1)
# Given seq: (b, s)
# Return mat: (1, s, s)
# Example Output:
# [[[ True, False, False],
# [ True, True, False],
# [ True, True, True]]]
# For causal attention
def get_subsequent_mask(seq):
    sz_b, seq_len = seq.shape
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
    return subsequent_mask.to(seq.device)
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
# Decorator: run fn with the model switched to eval mode, then restore its previous training mode.
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
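# Illustrative usage (sketch; `generate` is a hypothetical function, not defined in this file):
#   @eval_decorator
#   @torch.no_grad()
#   def generate(model, prompt):
#       ...
# Calling generate(model, prompt) then runs with model.eval() and restores the training flag afterwards.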
def l2norm(t):
    return F.normalize(t, dim=-1)
# tensor helpers
# Randomly keep each True position of `mask` with probability `prob`; False positions stay False.
def get_mask_subset_prob(mask, prob):
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask
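# Example (illustrative): with mask = [[True, True, True, False]] and prob = 0.5, each True entry
# survives independently with probability 0.5, e.g. [[True, False, True, False]]; padding stays False.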
# Get mask of special_tokens in ids
def get_mask_special_tokens(ids, special_ids):
    mask = torch.zeros_like(ids).bool()
    for special_id in special_ids:
        mask |= (ids == special_id)
    return mask
# network builder helpers
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
# classifier free guidance functions
# Tensor of the given shape with entries drawn uniformly between 0 and 1.
def uniform(shape, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)
# Boolean tensor of the given shape where each entry is True with probability `prob`.
def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob
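# Example (illustrative): prob_mask_like((2, 3), 0.7, device='cpu') returns a (2, 3) bool tensor whose
# entries are independently True with probability 0.7; prob=1 gives all True, prob=0 gives all False.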
# sampling helpers
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))
# Gumbel(0, 1) noise with the same shape as t.
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))
# Temperature-scaled sampling via the Gumbel-max trick; lower temperature approaches plain argmax.
def gumbel_sample(t, temperature=1., dim=1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
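# Illustrative usage: for logits of shape (b, n_class, seq_len), gumbel_sample(logits, temperature=1., dim=1)
# returns class indices of shape (b, seq_len); combined with top_k below, this gives filtered stochastic sampling.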
# Example input:
# [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000],
#  [ 0.4141, 0.1781, 0.6628, 0.5721, -1.0000, -1.0000],
#  [ 0.9428, 0.3586, 0.1659, 0.8172, 0.9273, -1.0000]]
# Example output:
# [[ -inf, -inf, 0.9771, -inf, -inf, -inf],
#  [ -inf, -inf, 0.6628, -inf, -inf, -inf],
#  [0.9428, -inf, -inf, -inf, -inf, -inf]]
# Keep the top ceil((1 - thres) * n) logits along `dim`; all other entries are set to -inf.
def top_k(logits, thres=0.9, dim=1):
    k = math.ceil((1 - thres) * logits.shape[dim])
    val, ind = logits.topk(k, dim=dim)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(dim, ind, val)
    return probs
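# Illustrative usage (sketch): a typical filtered-sampling step, assuming logits of shape (b, n_class, seq_len):
#   filtered = top_k(logits, thres=0.9, dim=1)
#   ids = gumbel_sample(filtered, temperature=1., dim=1)  # (b, seq_len)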
# noise schedules
# Decreases from 1 to 0 as t goes from 0 to 1 (large values for small t, small values for large t).
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)
def scale_cosine_schedule(t, scale):
    return torch.clip(scale * torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
# Sample integer steps in [low, high - 1], biased toward small values.
def q_schedule(bs, low, high, device):
    noise = uniform((bs,), device=device)
    schedule = 1 - cosine_schedule(noise)
    return torch.round(schedule * (high - low - 1)).long() + low
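# Example (illustrative): q_schedule(4, low=1, high=10, device='cpu') returns 4 integers in [1, 9],
# with smaller step indices sampled more often than larger ones.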
# Compute the loss together with top-k accuracy over non-ignored positions.
def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
    loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)
    pred_id_k = torch.topk(pred, k=tk, dim=1).indices
    pred_id = pred_id_k[:, 0]
    mask = labels.ne(ignore_index)
    n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
    acc = torch.mean(n_correct.float()).item()
    return loss, pred_id, acc
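# Illustrative usage: with pred of shape (b, n_class, seq_len) and labels of shape (b, seq_len),
#   loss, pred_id, acc = cal_performance(pred, labels, ignore_index=pad_id, smoothing=0.1, tk=3)
# pred_id holds the top-1 class per position; acc counts a position as correct if its label is in the top 3.
# `pad_id` is a placeholder for whatever padding index the caller uses.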
def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
    '''Calculate cross entropy loss, apply label smoothing if needed.

    pred: (b, n_class, ...) logits, labels: (b, ...) class indices.
    '''
    if smoothing:
        space = 2
        n_class = pred.size(1)
        mask = labels.ne(ignore_index)
        # One-hot with a little extra room so label ids at or above n_class (such as padding / end
        # tokens) do not make one_hot error, then drop the extra columns again.
        one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
        sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
        neg_log_prb = -F.log_softmax(pred, dim=1)
        loss = (sm_one_hot * neg_log_prb).sum(dim=1)
        loss = torch.mean(loss.masked_select(mask))
    else:
        loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
    return loss
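
# Minimal smoke test (illustrative; the shapes and the padding index 0 below are assumptions,
# not taken from the original training code):
if __name__ == "__main__":
    b, n_class, seq_len = 2, 8, 5
    pred = torch.randn(b, n_class, seq_len)           # logits over n_class codes per position
    labels = torch.randint(1, n_class, (b, seq_len))  # ground-truth code indices (0 reserved as padding)
    loss, pred_id, acc = cal_performance(pred, labels, ignore_index=0, smoothing=0.1, tk=3)
    print(loss.item(), pred_id.shape, acc)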