import random
import abc
from typing import List

import torch
import nltk
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

from vocabulary_split import split_vocabulary, filter_logits

nltk.download('stopwords')

# Load tokenizer and model for the masked language model
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Get the permissible half of the vocabulary and build a boolean mask over the
# tokenizer's vocabulary (True where the token id is permissible)
permissible, _ = split_vocabulary(seed=42)
permissible_ids = set(permissible.values())
permissible_indices = torch.tensor([i in permissible_ids for i in range(len(tokenizer))])
def get_logits_for_mask(model, tokenizer, sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    mask_token_logits = logits[0, mask_token_index, :]
    return mask_token_logits.squeeze()
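
# Illustrative usage sketch (not part of the pipeline below): for a single [MASK]
# the returned tensor is 1-D over the vocabulary, for several [MASK] tokens it is
# 2-D (num_masks x vocab_size).
#   logits = get_logits_for_mask(model, tokenizer, "The capital of France is [MASK].")
#   best_id = filter_logits(logits, permissible_indices).argmax().item()
#   print(tokenizer.decode([best_id]))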
# Abstract Masking Strategy
class MaskingStrategy(abc.ABC):
    @abc.abstractmethod
    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        """
        Given a list of words, return the indices of words to mask.
        """
        pass
# Specific Masking Strategies
class RandomNonStopwordMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        num_masks = min(self.num_masks, len(non_stop_indices))
        return random.sample(non_stop_indices, num_masks)
class HighEntropyMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1):
        self.num_masks = num_masks
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], sentence: str, model, tokenizer, permissible_indices) -> List[int]:
        candidate_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not candidate_indices:
            return []
        entropy_scores = {}
        for idx in candidate_indices:
            # Mask one candidate word at a time and score how uncertain the model is
            # about the replacement (entropy over its top-5 permissible predictions)
            masked_sentence = ' '.join(words[:idx] + ['[MASK]'] + words[idx + 1:])
            logits = get_logits_for_mask(model, tokenizer, masked_sentence)
            filtered_logits = filter_logits(logits, permissible_indices)
            probs = torch.softmax(filtered_logits, dim=-1)
            top_5_probs = probs.topk(5).values
            entropy = -torch.sum(top_5_probs * torch.log(top_5_probs + 1e-10)).item()
            entropy_scores[idx] = entropy
        # Select the num_masks indices with the highest entropy
        sorted_indices = sorted(entropy_scores, key=entropy_scores.get, reverse=True)
        return sorted_indices[:self.num_masks]
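
# Illustrative sketch (the variable names here are only for demonstration): picking
# the most ambiguous non-stopword in a sentence with HighEntropyMasking.
#   strategy = HighEntropyMasking(num_masks=1)
#   example_words = "The quick brown fox jumps over the lazy dog".split()
#   idx = strategy.select_words_to_mask(example_words, ' '.join(example_words),
#                                       model, tokenizer, permissible_indices)
#   print(idx)  # list with one index: the position of the highest-entropy word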
class PseudoRandomNonStopwordMasking(MaskingStrategy):
    def __init__(self, num_masks: int = 1, seed: int = 10):
        self.num_masks = num_masks
        self.seed = seed
        self.stop_words = set(stopwords.words('english'))

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        non_stop_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        if not non_stop_indices:
            return []
        # Use a dedicated RNG so the selection is reproducible without reseeding
        # the global random module
        rng = random.Random(self.seed)
        num_masks = min(self.num_masks, len(non_stop_indices))
        return rng.sample(non_stop_indices, num_masks)
class CompositeMaskingStrategy(MaskingStrategy):
    def __init__(self, strategies: List[MaskingStrategy]):
        self.strategies = strategies

    def select_words_to_mask(self, words: List[str], **kwargs) -> List[int]:
        selected_indices = []
        for strategy in self.strategies:
            if isinstance(strategy, HighEntropyMasking):
                selected = strategy.select_words_to_mask(words, **kwargs)
            else:
                selected = strategy.select_words_to_mask(words)
            selected_indices.extend(selected)
        return list(set(selected_indices))  # Remove duplicates
# Refactored mask_between_lcs function
def mask_between_lcs(sentence, lcs_points, masking_strategy: MaskingStrategy, model, tokenizer, permissible_indices):
    words = sentence.split()
    masked_indices = []
    segments = []

    # Define segments based on LCS points
    previous = 0
    for point in lcs_points:
        if point > previous:
            segments.append((previous, point))
        previous = point + 1
    if previous < len(words):
        segments.append((previous, len(words)))

    # Collect all indices to mask from each segment
    for start, end in segments:
        segment_words = words[start:end]
        if isinstance(masking_strategy, (HighEntropyMasking, CompositeMaskingStrategy)):
            # These strategies (or their sub-strategies) need access to the model
            selected = masking_strategy.select_words_to_mask(
                segment_words, sentence=sentence, model=model,
                tokenizer=tokenizer, permissible_indices=permissible_indices)
        else:
            selected = masking_strategy.select_words_to_mask(segment_words)
        # Adjust indices relative to the whole sentence
        for idx in selected:
            masked_idx = start + idx
            if masked_idx not in masked_indices:
                masked_indices.append(masked_idx)

    # Keep mask positions in sentence order so they line up with the logits
    # returned for each [MASK] token
    masked_indices.sort()

    # Apply masking
    for idx in masked_indices:
        words[idx] = '[MASK]'
    masked_sentence = ' '.join(words)
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    if logits.dim() == 1:
        # A single mask yields a 1-D tensor; add a leading dimension so the
        # per-mask loop below works for any number of masks
        logits = logits.unsqueeze(0)

    # Process each masked token
    top_words_list = []
    logits_list = []
    for i, idx in enumerate(masked_indices):
        logits_i = logits[i]
        if logits_i.dim() > 1:
            logits_i = logits_i.squeeze()
        filtered_logits_i = filter_logits(logits_i, permissible_indices)
        logits_list.append(filtered_logits_i.tolist())
        top_5_indices = filtered_logits_i.topk(5).indices.tolist()
        top_words = [tokenizer.decode([i]) for i in top_5_indices]
        top_words_list.append(top_words)

    return masked_sentence, logits_list, top_words_list
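
# Worked example of the segmentation above (illustrative): for a 10-word sentence
# with lcs_points = [2, 5, 8], the segments are (0, 2), (3, 5), (6, 8) and (9, 10),
# i.e. the words at the LCS points themselves are never candidates for masking.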
# Example Usage
if __name__ == "__main__":
    # Example sentence and LCS points
    sentence = "This is a sample sentence with some LCS points"
    lcs_points = [2, 5, 8]  # Indices of LCS points

    # Initialize masking strategies
    random_non_stopword_strategy = RandomNonStopwordMasking(num_masks=1)
    high_entropy_strategy = HighEntropyMasking(num_masks=1)
    pseudo_random_strategy = PseudoRandomNonStopwordMasking(num_masks=1, seed=10)
    composite_strategy = CompositeMaskingStrategy([
        RandomNonStopwordMasking(num_masks=1),
        HighEntropyMasking(num_masks=1)
    ])

    # Choose a strategy
    chosen_strategy = composite_strategy  # You can choose any initialized strategy

    # Apply masking
    masked_sentence, logits_list, top_words_list = mask_between_lcs(
        sentence,
        lcs_points,
        masking_strategy=chosen_strategy,
        model=model,
        tokenizer=tokenizer,
        permissible_indices=permissible_indices
    )

    print("Masked Sentence:", masked_sentence)
    for idx, top_words in enumerate(top_words_list):
        print(f"Top words for mask {idx+1}:", top_words)