# aiisc-watermarking-model/vocabulary_split.py
import random
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Load tokenizer and model once
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")


def split_vocabulary(seed=42):
    """Split the vocabulary into permissible and non-permissible buckets."""
    # Get the full vocabulary as (token, id) pairs
    vocab = list(tokenizer.get_vocab().items())

    # Seed the random number generator so the split is reproducible
    random.seed(seed)

    # Assign each token to one of the two buckets with equal probability
    permissible = {}
    non_permissible = {}
    for word, index in vocab:
        target_dict = permissible if random.random() < 0.5 else non_permissible
        target_dict[word] = index

    return permissible, non_permissible


def get_logits_for_mask(sentence):
    """Get the logits for the masked token in the sentence."""
    inputs = tokenizer(sentence, return_tensors="pt")
    # Locate the position of the [MASK] token (assumes exactly one mask)
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        logits = model(**inputs).logits[0, mask_token_index, :]
    return logits.squeeze()


def filter_logits(logits, permissible_indices):
    """Filter logits using a boolean mask of permissible vocabulary indices."""
    filtered_logits = logits.clone()
    # Set logits to -inf for non-permissible indices so they can never be selected
    filtered_logits[~permissible_indices] = float('-inf')
    return filtered_logits


# Usage example
permissible, _ = split_vocabulary(seed=42)
# Build a boolean mask over the vocabulary: True at permissible token ids
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))], dtype=torch.bool)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(sentence)
filtered_logits = filter_logits(logits, permissible_indices)
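
# Illustrative continuation (not part of the original file): one way to choose the
# replacement for [MASK] from the filtered logits. Greedy argmax is an assumption;
# the scheme could equally sample from a softmax over the permissible tokens.
predicted_id = torch.argmax(filtered_logits).item()
predicted_token = tokenizer.decode([predicted_id])
watermarked_sentence = sentence.replace(tokenizer.mask_token, predicted_token)
print(watermarked_sentence)  # e.g. "The sun is bright today." if "sun" is permissible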