Spaces:
Runtime error
Runtime error
import torch | |
from utils import gather_log_probs, mask_hf_labels, masked_mean | |
def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, NULL_TOKEN=0):
    """Score a sentiment edit as the product of a sentiment term and a topic term.

    Both terms are built from masked means of target-token log-probabilities
    (via ``gather_log_probs`` / ``masked_mean`` from ``utils``):

    * ``z_sent = sigmoid(mean_pos_post - mean_neg_post)``: how much more
      likely the post-edit model finds the "positive" (same-sentiment)
      targets than the remaining ones.
    * ``z_topic = min(1, exp(mean_pos_post - mean_pos_pre))``: below 1 when
      the edit lowered the likelihood of the positive targets, capped at 1.

    Everything runs under ``torch.no_grad()``, so no gradients flow back.

    Args:
        pre_logits: Logits of the model before the edit.
        post_logits: Logits of the model after the edit.
        raw_targets: Target labels; ``mask_hf_labels`` turns them into a
            token mask and the targets whose log-probs are gathered.
        same_sent_mask: Per-example selector for the positive examples. It is
            inverted with ``~`` and broadcast over a trailing dimension, so it
            is presumably a bool tensor of shape (batch,) -- TODO confirm.
        NULL_TOKEN: Unused; kept for backward compatibility.

    Returns:
        dict with keys:
          ``acc_sent``: ``z_sent * z_topic``;
          ``z_sent``, ``z_topic``, ``z_topic_raw``: the terms above (tensors);
          ``correct_probs``: mean post-edit log-prob of positive targets;
          ``wrong_probs``: mean post-edit log-prob of negative targets.
    """
    with torch.no_grad():
        mask, targ = mask_hf_labels(raw_targets)
        # Broadcast the per-example selector across the token mask
        # (assumes mask is (batch, seq) and same_sent_mask is (batch,) -- TODO confirm).
        pos_mask = same_sent_mask.unsqueeze(-1) * mask
        neg_mask = (~same_sent_mask).unsqueeze(-1) * mask

        # Compute log likelihoods of pos/neg samples
        pre_edit_token_log_probs = gather_log_probs(pre_logits, targ)
        post_edit_token_log_probs = gather_log_probs(post_logits, targ)

        mean_pos_pre = masked_mean(pre_edit_token_log_probs, pos_mask)
        mean_pos_post = masked_mean(post_edit_token_log_probs, pos_mask)
        mean_neg_post = masked_mean(post_edit_token_log_probs, neg_mask)

        z_sent = (mean_pos_post - mean_neg_post).sigmoid()
        z_topic_raw = (mean_pos_post - mean_pos_pre).exp()
        # Clamp element-wise rather than using the builtin ``min(1, ...)``.
        # The builtin returned the int 1 whenever z_topic_raw >= 1, so the
        # "z_topic" entry flipped between int and Tensor. It also raised for
        # tensors with more than one element.
        z_topic = z_topic_raw.clamp(max=1.0)

        es_sent = z_sent * z_topic
        return {
            "acc_sent": es_sent,
            "z_sent": z_sent,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_post,
            "wrong_probs": mean_neg_post,
        }
# DEPRECATED
def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
    """Deprecated sentiment-edit success metric; calling it always fails.

    The signature is kept unchanged so that existing references still
    resolve. The earlier weighting/ranking implementation was unreachable
    behind this guard and has been removed; recover it from version control
    if it is ever needed again.

    Raises:
        AssertionError: always, with the message ``"No longer used"``.
    """
    # Use an explicit raise rather than ``assert False``. Running with
    # ``python -O`` strips assert statements, which would silently let the
    # stale metric code execute. The exception type and message match the
    # old assert, so existing callers see the same error.
    raise AssertionError("No longer used")
# def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps): | |
# pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1) | |
# post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1) | |
# pre_batch = pre_probs.view(-1, batch_size) | |
# post_batch = post_probs.view(-1, batch_size) | |
# mask_batch = sent_mask.view(-1, batch_size) | |
# stats = [] | |
# for pre, post, mask in zip(pre_batch, post_batch, mask_batch): | |
# avg_pre = pre.prod() ** (1 / pre.numel()) | |
# avg_post = post.prod() ** (1 / post.numel()) | |
# z_avg = min(avg_pre / avg_post, avg_post / avg_pre) | |
# post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum()) | |
# post_pos_avg = post[mask].prod() ** (1 / mask.sum()) | |
# pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum()) | |
# pre_pos_avg = pre[mask].prod() ** (1 / mask.sum()) | |
# post_neg_over_pos = post_neg_avg / (eps + post_pos_avg) | |
# pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg) | |
# z_post = 1 / (1 + post_neg_over_pos) | |
# z_pre = 1 / (1 + pre_neg_over_pos) | |
# z_sent = min(z_post / z_pre, z_pre / z_post) | |
# stats.append((z_avg * z_sent) ** 0.5) | |
# return sum(stats) / len(stats) | |
# For zsRE and F-NLI | |
def retain_rate(pre_logits, post_logits, mask=None):
    """Return the fraction of examples whose prediction is unchanged by an edit.

    A trailing dimension of size 1 is squeezed away first. Two layouts are
    then supported:

    * 1-D logits (binary classification): the prediction is ``logit > 0``,
      and the result is the fraction of examples with equal pre/post
      predictions.
    * 3-D logits (sequence modeling, presumably (batch, seq, vocab)):
      predictions are ``argmax`` over the last dimension. An example counts
      as retained only if every position selected by ``mask`` keeps its
      prediction.

    Args:
        pre_logits: Logits from the model before the edit.
        post_logits: Logits from the model after the edit; must match the
            shape of ``pre_logits`` after squeezing.
        mask: Optional mask over the first two dimensions of 3-D logits that
            selects which positions to compare. It is ignored for 1-D logits,
            but its first dimension is still checked against the batch size
            when given. ``None`` compares every position.

    Returns:
        The retain rate as a Python float.

    Raises:
        AssertionError: if the pre/post shapes differ, or ``mask`` has a
            different leading dimension.
        NotImplementedError: for logits of any other rank.
    """
    if pre_logits.shape[-1] == 1:
        pre_logits = pre_logits.squeeze(-1)
    if post_logits.shape[-1] == 1:
        post_logits = post_logits.squeeze(-1)

    assert pre_logits.shape == post_logits.shape
    # Validate the mask only when one was supplied. Previously the default
    # mask=None always crashed here with AttributeError.
    if mask is not None:
        assert pre_logits.shape[0] == mask.shape[0]

    if pre_logits.dim() == 1:
        # binary classification
        pre_preds = pre_logits > 0
        post_preds = post_logits > 0
        retain = (pre_preds == post_preds).float().mean()
    elif pre_logits.dim() == 3:
        # sequence modeling
        pre_preds = pre_logits.argmax(-1)
        post_preds = post_logits.argmax(-1)
        if mask is None:
            # No mask given: every position must keep its prediction.
            mask = torch.ones_like(pre_preds, dtype=torch.bool)
        match = (pre_preds == post_preds) * mask
        # Retained iff the matches cover every masked position.
        retain = (match.sum(-1) == mask.sum(-1)).float().mean()
    else:
        raise NotImplementedError(
            f"retain_rate does not support {pre_logits.dim()}-D logits"
        )
    return retain.item()