File size: 2,078 Bytes
436c4c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import random
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# Checkpoint shared by the module-level tokenizer and model below.
_CHECKPOINT = "bert-base-uncased"

# Load the checkpoint once at import time; the usage example at the bottom of
# the file passes these objects to get_logits_for_mask.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
def split_vocabulary(seed=42, tokenizer=None, permissible_fraction=0.5):
    """Randomly partition a tokenizer's vocabulary into two buckets.

    Each token is independently assigned to the permissible bucket with
    probability ``permissible_fraction``. Draws come from a private
    ``random.Random(seed)`` taken in token-id order, so the split is
    reproducible across runs and the global ``random`` state is untouched.

    Args:
        seed: Seed for the private random number generator.
        tokenizer: Object exposing ``get_vocab()`` (token -> id mapping).
            When ``None``, the ``bert-base-uncased`` tokenizer is loaded.
        permissible_fraction: Probability in ``[0, 1]`` that a token lands in
            the permissible bucket. Defaults to 0.5 (a 50/50 split).

    Returns:
        ``(permissible, non_permissible)``: two dicts mapping token -> id.

    Raises:
        ValueError: If ``permissible_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= permissible_fraction <= 1.0:
        raise ValueError(
            f"permissible_fraction must be in [0, 1], got {permissible_fraction!r}"
        )
    if tokenizer is None:
        # Only the tokenizer is needed here; no model has to be loaded.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # get_vocab() iteration order is not guaranteed to follow token ids (fast
    # tokenizers may build it from a hash map). Sorting by id pairs each
    # seeded draw with the same token on every run.
    vocab = sorted(tokenizer.get_vocab().items(), key=lambda item: (item[1], item[0]))

    # A private generator yields the same sequence as random.seed(seed) would,
    # without reseeding the module-level RNG for the rest of the program.
    rng = random.Random(seed)
    permissible = {}
    non_permissible = {}
    for word, index in vocab:
        bucket = permissible if rng.random() < permissible_fraction else non_permissible
        bucket[word] = index
    return permissible, non_permissible
def get_logits_for_mask(model, tokenizer, sentence):
    """Return the masked-LM logits at the mask-token position(s) of *sentence*.

    Args:
        model: Masked-LM model; called as ``model(**inputs)`` and expected to
            return an object with a ``logits`` attribute.
        tokenizer: Tokenizer that encodes *sentence* into a mapping of tensors
            (``return_tensors="pt"``) and exposes ``mask_token_id``.
        sentence: Text containing at least one mask token.

    Returns:
        For a single mask, a 1-D tensor of vocabulary logits; for ``k`` masks,
        a ``(k, vocab)`` tensor with one row per mask in sentence order.

    Raises:
        ValueError: If *sentence* contains no mask token.
    """
    inputs = tokenizer(sentence, return_tensors="pt")
    # Keep inputs on the model's device (e.g. after ``model.to("cuda")``);
    # models without a ``device`` attribute are called with inputs as-is.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_positions.numel() == 0:
        raise ValueError(f"no mask token found in sentence: {sentence!r}")

    with torch.no_grad():
        logits = model(**inputs).logits
    # A single sentence was encoded, so only batch row 0 is read; squeeze(0)
    # drops the mask axis only when there is exactly one mask.
    return logits[0, mask_positions, :].squeeze(0)
def filter_logits(logits, permissible_indices):
    """Set the logits of non-permissible token ids to ``-inf``.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary, e.g. the
            1-D or ``(k, vocab)`` output of ``get_logits_for_mask``. If it has
            more than one dimension, singleton dimensions are squeezed away.
        permissible_indices: Either a 1-D boolean mask over token ids
            (``True`` = permissible) or a 1-D collection of permissible integer
            token ids. A boolean mask longer than the vocabulary axis is
            truncated; a shorter one is padded with ``False``, so ids it does
            not cover are treated as non-permissible.

    Returns:
        A new tensor (``logits`` itself is not modified) in which every
        non-permissible vocabulary position is ``-inf``; the mask is applied
        along the last axis for every leading row.

    Raises:
        TypeError: If ``permissible_indices`` is floating point or complex.
    """
    filtered = logits.squeeze() if logits.dim() > 1 else logits
    vocab_size = filtered.shape[-1]
    allowed = torch.as_tensor(permissible_indices, device=filtered.device)

    if allowed.dtype == torch.bool:
        mask = allowed[:vocab_size]
        missing = vocab_size - mask.shape[0]
        if missing > 0:
            padding = torch.zeros(missing, dtype=torch.bool, device=filtered.device)
            mask = torch.cat([mask, padding])
    elif allowed.is_floating_point() or allowed.is_complex():
        raise TypeError(
            "permissible_indices must be a boolean mask or integer token ids, "
            f"got dtype {allowed.dtype}"
        )
    else:
        mask = torch.zeros(vocab_size, dtype=torch.bool, device=filtered.device)
        mask[allowed.flatten()] = True

    # masked_fill broadcasts the (vocab,) mask over any leading dimensions and
    # returns a new tensor.
    return filtered.masked_fill(~mask, float("-inf"))
# Usage example
permissible, non_permissible = split_vocabulary(seed=42)
# Boolean mask over token ids (True = permissible). Collect the ids into a set
# first: testing ``i in permissible.values()`` scans the whole values view for
# every id, which is quadratic in the vocabulary size.
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])
# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(model, tokenizer, sentence)
filtered_logits = filter_logits(logits, permissible_indices)
# Use filtered_logits for sampling; non-permissible ids are set to -inf.