# Loss and metric helpers: balanced BCE, KL locality loss, and masked log-prob metrics.
import torch | |
import torch.nn.functional as F | |
from metrics import es_sentiment | |
from utils import gather_log_probs, mask_hf_labels, masked_mean | |
def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
    """Class-balanced binary cross-entropy from log-probabilities.

    The positive- and negative-example losses are averaged separately and
    then summed, so each class contributes equally regardless of how many
    examples it has in the batch.

    Args:
        log_probs: log P(label = 1) per example.
        labels: 0/1 tensor, same shape as ``log_probs``.
        eps: small constant added where 1 - p underflows to exactly 0.

    Returns:
        Scalar loss (sum of the two per-class means); an empty class
        contributes a plain 0.
    """
    assert labels.max() <= 1
    assert labels.min() >= 0

    # Positive examples pay -log p.
    positive_terms = -log_probs[labels == 1]
    # Negative examples pay -log(1 - p); nudge exact zeros for stability.
    complement = 1 - log_probs.exp()
    complement[complement == 0] += eps  # for numerical stability
    negative_terms = -complement.log()[labels == 0]

    total = 0
    if positive_terms.numel() > 0:
        total = total + positive_terms.mean()
    if negative_terms.numel() > 0:
        total = total + negative_terms.mean()
    return total
def kl_loc_loss(pre, post, mask=None):
    """KL(pre || post) locality loss between pre- and post-edit outputs.

    Two configurations are handled:
      * binary logits, last dim 1 and no sequence axis: KL between the two
        Bernoulli distributions implied by the logits, averaged elementwise;
      * sequences of categorical logits, shape (batch, seq, vocab): per-position
        KL, averaged over positions where ``mask`` is set.

    Any other configuration raises ``NotImplementedError``.
    """
    pre = pre.to(torch.float32)
    post = post.to(torch.float32)

    is_sequence = pre.dim() == 3
    flat_pre = pre.view(-1, pre.shape[-1])
    flat_post = post.view(flat_pre.shape)
    assert flat_pre.shape[0] == flat_post.shape[0]

    if not is_sequence and flat_pre.shape[-1] == 1:
        # Bernoulli KL from logits: sum the two outcome terms, mean over elements.
        # No masking needed for binary classification.
        pos_term = (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean()
        neg_term = ((-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))).mean()
        return pos_term + neg_term

    if is_sequence and flat_pre.shape[-1] > 1:
        # Sequences of predictions; masking needed.
        assert mask is not None
        flat_mask = mask.view(flat_pre.shape[0])
        per_position = (
            flat_pre.softmax(-1)
            * (flat_pre.log_softmax(-1) - flat_post.log_softmax(-1))
        ).sum(-1)
        return (per_position * flat_mask).sum() / flat_mask.sum()

    raise NotImplementedError
def binary_log_probs(pred, targ, should_reduce=True):
    """Accuracy and likelihood metrics for binary classification from logits.

    Args:
        pred: logits (positive favours label 1).
        targ: 0/1 targets, same shape as ``pred``.
        should_reduce: if True, average the accuracy over the batch.

    Returns:
        dict with accuracy, mean log-prob / prob / NLL of the true label,
        and the leading-dimension size as ``n_tokens``.
    """
    assert targ.max() <= 1
    assert targ.min() >= 0

    # Flip the logit's sign wherever the target is 0, so that logsigmoid
    # always yields the log-probability of the TRUE label.
    signs = torch.ones_like(pred)
    signs[targ == 0] *= -1
    true_label_log_probs = F.logsigmoid(pred * signs)

    # Correct iff the true label receives more than half the probability mass.
    acc = (true_label_log_probs.exp() > 0.5).float()
    if should_reduce:
        acc = acc.mean()

    return {
        "acc": acc,
        "log_prob": true_label_log_probs.mean(),
        "prob": true_label_log_probs.exp().mean(),
        "nll": -true_label_log_probs.mean(),
        "n_tokens": true_label_log_probs.shape[0],
    }
def multiclass_log_probs(
    pred,
    raw_targets,
    shift=True,
    eps=torch.finfo(torch.float32).eps,
    should_reduce=True,
    **kwargs,
):
    """Accuracy / log-prob metrics for (sequences of) class predictions.

    Args:
        pred: logits, (batch, classes) or (batch, seq, classes).
        raw_targets: HF-style labels; ``mask_hf_labels`` presumably splits them
            into a validity mask and clean target ids — confirm against utils.
        shift: if True and ``pred`` is 3-D, drop the last prediction and the
            first target so position t predicts token t+1 (causal-LM style).
        eps: stabilizer for the log(1 - p) unlikelihood term.
        should_reduce: if True, average accuracy over the batch.
        **kwargs: optional ``inner_sent`` / ``outer_sent`` / ``unlikelihood`` /
            ``pre_edit_logits`` / ``post_edit_logits`` for sentiment editing.

    Returns:
        dict with "acc", "log_prob", "prob", "n_tokens", "nll" (plus the
        ``es_sentiment`` stats when sentiment kwargs are supplied).
    """
    NULL_TOKEN = 0  # a placeholder used for masked target locations
    pred = pred.clone()
    mask, targ = mask_hf_labels(raw_targets)
    if shift and pred.dim() == 3:  # Dealing with sequences
        pred = pred[:, :-1]  # Remove last prediction in sequence
        # BUGFIX: the mask must be shifted along with the targets; otherwise it
        # keeps the unshifted length and is shape-incompatible/misaligned with
        # the sliced pred/targ in every masked operation below.
        mask = mask[:, 1:]
        targ = targ[:, 1:]  # Shift to align predictions and targets

    unmasked_log_probs = gather_log_probs(pred, targ)

    pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
    correct = pred_ids == targ
    if pred.dim() == 3:
        correct = correct.all(-1)  # We want to get the whole sequence right
    acc = correct.float()
    if should_reduce:
        acc = acc.mean()

    if "inner_sent" in kwargs:
        # Only use outer samples with the same sentiment as the inner sample
        same_sent_mask = torch.tensor(
            [i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])],
            device=pred.device,
        )
        good_mask = mask * same_sent_mask.unsqueeze(-1)
        bad_mask = mask * (~same_sent_mask.unsqueeze(-1))

        good_log_prob = masked_mean(unmasked_log_probs, good_mask)
        # log(1 - p) over mismatched-sentiment tokens (unlikelihood term).
        bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)

        n_tokens = good_mask.float().sum()
        avg_log_prob = good_log_prob
        if kwargs["unlikelihood"]:
            nll = -good_log_prob - bad_log_prob
        else:
            nll = -good_log_prob
    else:
        n_tokens = mask.float().sum()
        avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
        nll = -avg_log_prob

    info_dict = {
        "acc": acc,
        "log_prob": avg_log_prob,
        "prob": avg_log_prob.exp(),
        "n_tokens": n_tokens,
        "nll": nll,
    }

    if "inner_sent" in kwargs:
        info_dict.update(
            es_sentiment(
                kwargs["pre_edit_logits"],
                kwargs["post_edit_logits"],
                raw_targets,
                same_sent_mask,
            )
        )

    return info_dict
def masked_log_probs(pred, targ, shift=True, **kwargs):
    """Dispatch metric computation based on the width of the logit dimension.

    Args:
        pred: logits; last dim 1 selects the binary path, anything else the
            multiclass path.
        targ: targets, forwarded unchanged.
        shift: forwarded to ``multiclass_log_probs``.
        **kwargs: extra options; ``should_reduce`` is honoured on both paths.

    Raises:
        RuntimeError: if ``pred`` is not 2- or 3-dimensional.
    """
    pred = pred.to(torch.float32)

    if pred.dim() not in (2, 3):
        raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")

    if pred.shape[-1] == 1:
        return binary_log_probs(pred, targ, should_reduce=kwargs.get("should_reduce", True))
    return multiclass_log_probs(pred, targ, shift=shift, **kwargs)
def test_masked_log_probs():
    """Smoke-test masked_log_probs on random, perfect, and half-perfect logits.

    Also checks that the call does not mutate its inputs, then exercises the
    binary (last-dim == 1) path. BUGFIX: removed a leftover
    ``import pdb; pdb.set_trace()`` breakpoint that halted every run.
    """
    print()
    vocab = 10000
    pred = torch.randn(10, 15, vocab)
    targ = torch.randint(0, vocab, (10, 15))

    # "Perfect" logits: put a large score on every target token, then roll so
    # the scores line up after the shift performed inside masked_log_probs.
    true_pred = pred.clone()
    true_pred.scatter_(2, targ.unsqueeze(-1), 5)
    true_pred = true_pred.roll(-1, 1)

    # "Half-perfect": even-indexed batch rows keep the random logits.
    half_pred = true_pred.clone()
    mask = torch.arange(10) % 2 == 0
    half_pred[mask] = pred[mask]

    pred_ = pred.clone()
    true_pred_ = true_pred.clone()
    half_pred_ = half_pred.clone()
    targ_ = targ.clone()

    print(masked_log_probs(pred, targ, return_acc=True))
    print(masked_log_probs(true_pred, targ, return_acc=True))
    print(masked_log_probs(half_pred, targ, return_acc=True))

    # masked_log_probs must not modify its arguments in place.
    assert (pred == pred_).all()
    assert (targ == targ_).all()
    assert (half_pred == half_pred_).all()
    assert (true_pred == true_pred_).all()

    # Binary path: logits with a trailing dim of size 1.
    pred = torch.randn(1000, 15, 1)
    targ = torch.randint(0, 2, (1000, 15))
    print(masked_log_probs(pred, targ, return_acc=True))
# Manual smoke test: run this module directly with a fixed seed.
if __name__ == "__main__":
    torch.manual_seed(0)
    test_masked_log_probs()