import random

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load tokenizer and model once at import time; reused by everything below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")


def split_vocabulary(seed=42):
    """Randomly split the tokenizer vocabulary into two buckets.

    Each vocabulary entry lands in the "permissible" bucket with
    probability 0.5, otherwise in "non-permissible".

    A dedicated ``random.Random(seed)`` instance is used instead of
    ``random.seed(seed)`` so the split is reproducible without clobbering
    the process-global RNG state; the generated sequence (and therefore
    the split) is identical to the original for the same seed.

    Args:
        seed: Seed for the split; the same seed always yields the same split.

    Returns:
        Tuple ``(permissible, non_permissible)`` of dicts mapping
        token string -> vocabulary index.
    """
    rng = random.Random(seed)  # local RNG: no global side effect
    permissible = {}
    non_permissible = {}
    for word, index in tokenizer.get_vocab().items():
        target = permissible if rng.random() < 0.5 else non_permissible
        target[word] = index
    return permissible, non_permissible


def get_logits_for_mask(sentence):
    """Return the model's logits at the mask position(s) of *sentence*.

    Args:
        sentence: Input text containing at least one ``tokenizer.mask_token``.

    Returns:
        A 1-D tensor of vocabulary-size logits when the sentence has exactly
        one mask token; a 2-D ``(num_masks, vocab_size)`` tensor when it has
        several (``squeeze`` only removes the mask dimension when it is 1).

    Raises:
        ValueError: If the sentence contains no mask token. (The original
        code silently returned an empty tensor in that case, which would
        produce confusing failures downstream.)
    """
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        raise ValueError(
            f"No {tokenizer.mask_token!r} token found in: {sentence!r}"
        )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs).logits[0, mask_token_index, :]
    return logits.squeeze()


def filter_logits(logits, permissible_indices):
    """Mask out non-permissible vocabulary entries.

    Args:
        logits: Tensor whose last dimension is the vocabulary size
            (1-D for a single mask, 2-D for several).
        permissible_indices: Bool tensor of vocabulary size; ``True`` keeps
            the entry.

    Returns:
        A copy of *logits* with non-permissible positions set to ``-inf``,
        so they receive zero probability under softmax and can never be
        sampled. Indexing the last dimension via ``...`` keeps the original
        1-D behavior while also supporting 2-D multi-mask logits.
    """
    filtered_logits = logits.clone()
    filtered_logits[..., ~permissible_indices] = float("-inf")
    return filtered_logits


# Usage example
permissible, _ = split_vocabulary(seed=42)

# Boolean mask over the vocabulary. Membership is tested against a set built
# once: the original tested `i in permissible.values()` per vocab id, which
# is an O(vocab) scan each time, making the comprehension O(vocab^2) overall.
_permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [i in _permissible_ids for i in range(len(tokenizer))], dtype=torch.bool
)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(sentence)
filtered_logits = filter_logits(logits, permissible_indices)