Spaces:
Runtime error
Runtime error
import sentencepiece as spm | |
import torch | |
import torch.nn.functional as F | |
from transformers.models.bert.tokenization_bert import BertTokenizer | |
# String identifiers for the model variants handled by this module.
# NOTE(review): the per-constant descriptions below are inferred from the
# names only — confirm against the code that consumes these values.
BASELINE: str = "baseline"  # baseline model
KOBE_ATTRIBUTE: str = "kobe-attr"  # KOBE variant using attributes (per name)
KOBE_KNOWLEDGE: str = "kobe-know"  # KOBE variant using knowledge (per name)
KOBE_FULL: str = "kobe-full"  # full KOBE variant (per name)
def get_bert_vocab_size(vocab_path: str) -> int:
    """Return the vocabulary size of a BERT tokenizer.

    Args:
        vocab_path: Path or identifier passed to
            ``BertTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer's ``vocab_size`` attribute.
        NOTE(review): in Hugging Face this is the base vocabulary size; tokens
        added after loading are counted by ``len(tokenizer)`` instead — confirm
        which one callers expect.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(vocab_path)
    size = bert_tokenizer.vocab_size
    return size
def get_vocab_size(vocab_path: str) -> int:
    """Return the number of pieces in a SentencePiece model.

    Args:
        vocab_path: Path to the SentencePiece model file, loaded via
            ``SentencePieceProcessor.Load``.

    Returns:
        ``len()`` of the loaded processor.
    """
    processor = spm.SentencePieceProcessor()
    processor.Load(vocab_path)
    piece_count = len(processor)
    return piece_count
# Metrics | |
def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max logit equals the target.

    Args:
        logits: 2-D tensor of scores, one row per example; the prediction for
            each row is its arg-max over dim 1.
        targets: 1-D tensor of gold class indices, one per row of ``logits``.

    Returns:
        The number of correct predictions divided by the batch size, as a
        Python float.

    Raises:
        ValueError: If ``logits`` is not 2-D, ``targets`` is not 1-D, the
            batch sizes differ, or the batch is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if logits.dim() != 2:
        raise ValueError(f"logits must be 2-D, got {logits.dim()}-D")
    if targets.dim() != 1:
        raise ValueError(f"targets must be 1-D, got {targets.dim()}-D")
    batch_size = targets.shape[0]
    # Without this check, `pred == targets` would silently broadcast when one
    # side has length 1 and produce a meaningless score.
    if logits.shape[0] != batch_size:
        raise ValueError(
            f"batch size mismatch: logits has {logits.shape[0]} rows, "
            f"targets has {batch_size} entries"
        )
    if batch_size == 0:
        raise ValueError("accuracy is undefined for an empty batch")
    pred = logits.argmax(dim=1)
    return (pred == targets).sum().item() / batch_size
def top_k_top_p_sampling(
    logits, top_k=0, top_p=0.0, temperature=1, filter_value=-float("Inf")
) -> int:
    """Sample one token index from logits filtered by top-k and/or nucleus (top-p).

    The input tensor is not modified: temperature scaling creates a new
    tensor, and all filtering writes into that new tensor.

    Nucleus filtering is described in Holtzman et al.
    (http://arxiv.org/abs/1904.09751).

    Args:
        logits: 1-D tensor of unnormalised scores, one per vocabulary entry.
        top_k: If > 0, keep only the ``top_k`` highest-scoring tokens
            (clamped to the vocabulary size).
        top_p: If > 0.0, keep the highest-probability tokens up to and
            including the first one at which the cumulative probability
            exceeds ``top_p``.
        temperature: Divisor applied to ``logits`` before filtering; should
            be > 0.
        filter_value: Score assigned to removed tokens (``-inf`` by default,
            which gives them zero probability after softmax).

    Returns:
        The sampled token index as a Python ``int``.

    Raises:
        ValueError: If ``logits`` is not 1-D.
    """
    # Only a single (unbatched) distribution is supported for now; checked
    # before doing any work.
    if logits.dim() != 1:
        raise ValueError(
            f"logits must be 1-D (vocab_size,), got shape {tuple(logits.shape)}"
        )
    # Out-of-place division so the caller's tensor (which may be a view into a
    # larger tensor) is never rescaled or masked by the assignments below.
    logits = logits / temperature
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens whose logit is below the k-th largest logit.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...shifted right by one so the first token that crosses the
        # threshold is kept, and the most likely token is always kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    # Sample from the filtered distribution.
    probabilities = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probabilities, 1)
    return int(next_token.item())
def diversity(tokenized_lines, n=4) -> int:
    """Count the distinct n-grams produced across all tokenized lines.

    Args:
        tokenized_lines: Iterable of token sequences; each sequence must
            support slicing, and its tokens must be hashable.
        n: Length of the n-grams to count.

    Returns:
        The number of unique length-``n`` tuples of consecutive tokens over
        every line combined. Lines shorter than ``n`` contribute nothing.
    """
    unique_ngrams = set()
    for tokens in tokenized_lines:
        # zip over n progressively shifted copies yields each window of n
        # consecutive tokens as a tuple.
        shifted_copies = [tokens[offset:] for offset in range(n)]
        unique_ngrams.update(zip(*shifted_copies))
    return len(unique_ngrams)