import random

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the tokenizer and model once at module level and reuse them below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")


def split_vocabulary(tokenizer, seed=42):
    """Randomly split the tokenizer vocabulary into permissible and non-permissible buckets."""
    # Get the full vocabulary as (token, token_id) pairs
    vocab = list(tokenizer.get_vocab().items())

    # Use a dedicated RNG so the split is reproducible without reseeding the global state
    rng = random.Random(seed)

    # Assign each token to one of the two buckets with equal probability
    permissible = {}
    non_permissible = {}

    for token, token_id in vocab:
        if rng.random() < 0.5:  # 50% chance of being permissible
            permissible[token] = token_id
        else:
            non_permissible[token] = token_id

    return permissible, non_permissible

def get_logits_for_mask(model, tokenizer, sentence):
    """Return the vocabulary logits at the [MASK] position of `sentence`."""
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    mask_token_logits = logits[0, mask_token_index, :]
    # For a single [MASK] token this is a 1-D tensor of shape [vocab_size]
    return mask_token_logits.squeeze()

def filter_logits(logits, permissible_indices):
    """Set the logits of non-permissible tokens to -inf so they can never be sampled."""
    filtered_logits = logits.clone()
    if filtered_logits.dim() > 1:
        filtered_logits = filtered_logits.squeeze()
    # The boolean mask must cover the full vocabulary; truncating it silently
    # would misalign token ids, so fail loudly on a shape mismatch instead.
    if filtered_logits.shape != permissible_indices.shape:
        raise ValueError(
            f"logits shape {tuple(filtered_logits.shape)} does not match "
            f"mask shape {tuple(permissible_indices.shape)}"
        )
    filtered_logits[~permissible_indices] = float('-inf')
    return filtered_logits

# Usage example
permissible, non_permissible = split_vocabulary(tokenizer, seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [i in permissible_ids for i in range(len(tokenizer))], dtype=torch.bool
)

# When sampling:
sentence = "The [MASK] is bright today."
logits = get_logits_for_mask(model, tokenizer, sentence)
filtered_logits = filter_logits(logits, permissible_indices)
# Use filtered_logits for sampling
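
# A minimal sketch of that final sampling step. The helper below, its name, and
# the temperature parameter are illustrative assumptions, not part of the
# original snippet; any sampling scheme over filtered_logits would do.
def sample_from_filtered_logits(filtered_logits, tokenizer, temperature=1.0):
    # Softmax turns the masked logits into probabilities; the -inf entries for
    # non-permissible tokens become exactly zero, so they can never be drawn.
    probs = torch.softmax(filtered_logits / temperature, dim=-1)
    token_id = torch.multinomial(probs, num_samples=1).item()
    return tokenizer.decode([token_id]).strip()

sampled = sample_from_filtered_logits(filtered_logits, tokenizer)
print(f"Sampled permissible token for [MASK]: {sampled}")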