import torch
import torch.nn.functional as F
import math
from einops import rearrange

# return mask where padding is FALSE
def lengths_to_mask(lengths, max_len):
    # max_len = max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask  # (b, len)
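# Illustrative example (values assumed, not from the original source):
# lengths = tensor([2, 3]), max_len = 4
# lengths_to_mask(lengths, 4) ->
# [[ True,  True, False, False],
#  [ True,  True,  True, False]]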
# return a (b, 1, s) bool mask: True at real tokens, False at pad_idx positions
def get_pad_mask_idx(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(1)
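# Illustrative example (assumed values): seq = [[5, 7, 0, 0]], pad_idx = 0
# get_pad_mask_idx(seq, 0) -> [[[ True,  True, False, False]]]   # shape (1, 1, 4)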
# Given seq: (b, s)
# Return mat: (1, s, s)
# Example Output:
# [[[ True, False, False],
#   [ True,  True, False],
#   [ True,  True,  True]]]
# For causal attention
def get_subsequent_mask(seq):
    sz_b, seq_len = seq.shape
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
    return subsequent_mask.to(seq.device)
# Return True if val is not None
def exists(val):
    return val is not None

# Return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# Run the wrapped method in eval mode, then restore the previous training state
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
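# Illustrative usage (hypothetical generate method, not from the original source):
# class MyModel(torch.nn.Module):
#     @eval_decorator
#     @torch.no_grad()
#     def generate(self, *args, **kwargs):
#         ...  # runs with self.training == False; the training flag is restored afterwards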
# L2-normalize along the last dimension
def l2norm(t):
    return F.normalize(t, dim=-1)

# tensor helpers
# Randomly keep each True position of `mask` with probability `prob`; False positions stay False
def get_mask_subset_prob(mask, prob):
    subset_mask = torch.bernoulli(mask, p=prob) & mask
    return subset_mask
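# Illustrative example (stochastic; one possible draw, values assumed):
# mask = [[ True,  True,  True, False]], prob = 0.5
# get_mask_subset_prob(mask, 0.5) -> [[ True, False,  True, False]]
# (positions that are False in `mask` can never become True)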
# Get mask of special_tokens in ids
def get_mask_special_tokens(ids, special_ids):
    mask = torch.zeros_like(ids).bool()
    for special_id in special_ids:
        mask |= (ids == special_id)
    return mask
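# Illustrative example (assumed token ids):
# ids = [[3, 1, 7, 2]], special_ids = [1, 2]
# get_mask_special_tokens(ids, special_ids) -> [[False,  True, False,  True]]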
# network builder helpers
def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

# classifier free guidance functions
def uniform(shape, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)

def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob
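# Illustrative example (assumed shape and probability):
# prob_mask_like((2, 3), 0.7) -> bool tensor of shape (2, 3) where each entry is
# independently True with probability 0.7, e.g.
# [[ True, False,  True],
#  [ True,  True, False]]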
# sampling helpers
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# Gumbel-max sampling: adding Gumbel noise to the temperature-scaled logits and
# taking the argmax draws a sample from the corresponding categorical distribution
def gumbel_sample(t, temperature=1., dim=1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
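# Illustrative usage (assumed shapes): logits of shape (b, num_classes, seq_len)
# gumbel_sample(logits, temperature=1., dim=1) -> sampled class ids of shape (b, seq_len)
# temperature -> 0 approaches plain argmax; larger temperatures give more random samples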
# Example input:
# [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000],
#  [ 0.4141, 0.1781, 0.6628,  0.5721, -1.0000, -1.0000],
#  [ 0.9428, 0.3586, 0.1659,  0.8172,  0.9273, -1.0000]]
# Example output:
# [[  -inf,   -inf, 0.9771,   -inf,   -inf,   -inf],
#  [  -inf,   -inf, 0.6628,   -inf,   -inf,   -inf],
#  [0.9428,   -inf,   -inf,   -inf,   -inf,   -inf]]
# Keep only the top ceil((1 - thres) * num_classes) logits along `dim`; fill the rest with -inf
def top_k(logits, thres=0.9, dim=1):
    k = math.ceil((1 - thres) * logits.shape[dim])
    val, ind = logits.topk(k, dim=dim)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(dim, ind, val)
    return probs
# noise schedules
# Decreasing schedule: t in [0, 1] maps to [1, 0]; for uniform t the outputs
# concentrate near 1 (more mass on large values than on small ones)
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)

def scale_cosine_schedule(t, scale):
    return torch.clip(scale * torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
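# Illustrative values (assumed inputs):
# cosine_schedule(tensor([0.0, 0.5, 1.0]))                    -> [1.0000, 0.7071, 0.0000]
# scale_cosine_schedule(tensor([0.0, 0.5, 1.0]), scale=0.5)   -> [1.0000, 0.8536, 0.5000]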
# Sample integer steps in [low, high); for uniform noise the values
# concentrate near `low` (more mass on small values than on large ones)
def q_schedule(bs, low, high, device):
    noise = uniform((bs,), device=device)
    schedule = 1 - cosine_schedule(noise)
    return torch.round(schedule * (high - low - 1)).long() + low
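# Illustrative example (stochastic; assumed arguments):
# q_schedule(4, low=0, high=10, device='cpu') -> e.g. tensor([1, 0, 3, 6])
# four integer steps in [0, 9], biased toward the lower end of the range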
def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
    loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)

    # top-k accuracy: a position counts as correct if its label is among the tk highest logits
    pred_id_k = torch.topk(pred, k=tk, dim=1).indices
    pred_id = pred_id_k[:, 0]
    mask = labels.ne(ignore_index)
    n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
    acc = torch.mean(n_correct.float()).item()

    return loss, pred_id, acc
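# Illustrative shapes (taken from the shape notes in cal_loss below; pad_id is a placeholder):
# pred: (64, 1028, 55) class logits over dim=1, labels: (64, 55) token ids
# loss, pred_id, acc = cal_performance(pred, labels, ignore_index=pad_id, smoothing=0.1, tk=1)
# pred_id: (64, 55) argmax predictions; acc: scalar top-tk accuracy over non-ignored positions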
def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
    '''Calculate cross entropy loss, apply label smoothing if needed.'''
    # pred: (b, n_class, seq_len) logits, labels: (b, seq_len) token ids
    # e.g. pred (64, 1028, 55) or (64, 1027, 55) with labels (64, 55)
    if smoothing:
        space = 2
        n_class = pred.size(1)
        mask = labels.ne(ignore_index)
        # build one-hot targets; the extra `space` classes let F.one_hot accept label ids
        # (e.g. pad/special tokens) that lie just beyond n_class, then slice them away
        one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
        sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
        neg_log_prb = -F.log_softmax(pred, dim=1)
        loss = (sm_one_hot * neg_log_prb).sum(dim=1)
        loss = torch.mean(loss.masked_select(mask))
    else:
        loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
    return loss
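# Illustrative smoothing example (assumed numbers): with n_class = 5 and smoothing = 0.1,
# a label of class 2 becomes the target distribution
# [0.025, 0.025, 0.9, 0.025, 0.025]   # true class gets 1 - 0.1; the rest share 0.1 / (5 - 1)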