Spaces:
Runtime error
Runtime error
import random | |
import torch | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
# Single checkpoint name so the tokenizer and the model are always loaded
# from the same pretrained weights.
_CHECKPOINT = "bert-base-uncased"

# Both objects are created once at import time and shared by the functions
# in this file.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
def split_vocabulary(seed=42, vocab=None):
    """Split a vocabulary into permissible and non-permissible buckets.

    Each token is assigned independently: with probability 0.5 it goes to
    the permissible bucket, otherwise to the non-permissible bucket.

    Args:
        seed: Seed for the pseudo-random assignment. The same seed and the
            same vocabulary always produce the same split.
        vocab: Optional mapping of token string to token id. When ``None``
            (the default), the module-level ``tokenizer``'s vocabulary is
            used.

    Returns:
        A ``(permissible, non_permissible)`` tuple of dicts, each mapping
        token string to token id. Every token lands in exactly one of them.
    """
    if vocab is None:
        vocab = tokenizer.get_vocab()

    # Visit tokens in id order (token string breaks ties). The outcome
    # depends on the visiting order, and the iteration order of the
    # vocabulary mapping is not something the seed can control.
    ordered = sorted(vocab.items(), key=lambda item: (item[1], item[0]))

    # A private generator produces the same stream that random.seed(seed)
    # would, without resetting the global RNG for the rest of the program.
    rng = random.Random(seed)

    permissible = {}
    non_permissible = {}
    for word, index in ordered:
        bucket = permissible if rng.random() < 0.5 else non_permissible
        bucket[word] = index
    return permissible, non_permissible
def get_logits_for_mask(sentence):
    """Return the model's vocabulary logits at the masked position(s).

    Args:
        sentence: Text containing at least one occurrence of the
            tokenizer's mask token.

    Returns:
        The logits selected at every mask position of the first (only)
        sequence, with all size-1 dimensions squeezed away. For a sentence
        with a single mask this is a 1-D tensor over the vocabulary.

    Raises:
        ValueError: If ``sentence`` contains no mask token. Without this
            check the selection below would be empty and an empty tensor
            would be returned silently.
    """
    inputs = tokenizer(sentence, return_tensors="pt")

    # torch.where on the 2-D id matrix yields (row, column) index tensors;
    # the column entries are the token positions of the mask.
    mask_token_index = torch.where(
        inputs["input_ids"] == tokenizer.mask_token_id
    )[1]
    if mask_token_index.numel() == 0:
        raise ValueError(
            f"No {tokenizer.mask_token} token found in sentence: {sentence!r}"
        )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits[0, mask_token_index, :]
    return logits.squeeze()
def filter_logits(logits, permissible_indices):
    """Suppress the logits of every non-permissible token.

    Args:
        logits: Float tensor of logits. It may carry leading dimensions
            (for example one row per masked position) as long as
            ``permissible_indices`` broadcasts against it.
        permissible_indices: Boolean tensor that is ``True`` at permissible
            token ids and ``False`` elsewhere.

    Returns:
        A new tensor shaped like ``logits`` where every non-permissible
        entry is ``-inf``. The input tensor is left unmodified.
    """
    # Keep the mask on the same device as the logits it is applied to.
    blocked = ~permissible_indices.to(logits.device)
    # masked_fill is out-of-place and broadcasts the mask, so a 1-D
    # vocabulary mask also works on batched (2-D) logits.
    return logits.masked_fill(blocked, float("-inf"))
# Usage example
permissible, _ = split_vocabulary(seed=42)

# Build a boolean mask with one entry per token id: True where the id is in
# the permissible bucket. The ids are collected into a set first because
# testing membership against dict.values() rescans all values for every id.
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [token_id in permissible_ids for token_id in range(len(tokenizer))],
    dtype=torch.bool,
)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(sentence)
filtered_logits = filter_logits(logits, permissible_indices)