Spaces:
Running
Running
import nltk | |
from nltk.corpus import stopwords | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
from vocabulary_split import split_vocabulary, filter_logits | |
import torch | |
from lcs import find_common_subsequences | |
from paraphraser import generate_paraphrase | |
# One-time module setup: fetch the NLTK resources used for tokenization and
# stop-word filtering, load the masked-LM checkpoint, and build a boolean mask
# over the tokenizer's vocabulary marking the "permissible" token ids.
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

# NOTE(review): both calls load (and may download) a large checkpoint at import time.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

# split_vocabulary returns a pair; only the first element is used, and its
# .values() are treated as the permissible token ids.
permissible, _ = split_vocabulary(seed=42)
# Materialise the ids once as a set. A membership test against a dict values()
# view is a linear scan, which made the comprehension below quadratic in the
# vocabulary size. Assumes the values are hashable ints (token ids) - confirm
# against vocabulary_split.split_vocabulary.
_permissible_ids = set(permissible.values())
# One bool per vocabulary id (0 .. len(tokenizer) - 1): True when the id is permissible.
permissible_indices = torch.tensor([i in _permissible_ids for i in range(len(tokenizer))])
def get_non_melting_points(original_sentence):
    """Return the subsequences ``original_sentence`` shares with its paraphrase(s).

    The sentence is paraphrased with ``generate_paraphrase``. That output and
    the original sentence are then handed to ``find_common_subsequences``,
    whose result is returned unchanged.
    """
    return find_common_subsequences(
        original_sentence,
        generate_paraphrase(original_sentence),
    )
def _find_phrase_span(words, phrase, start=0):
    """Return ``(first, stop)`` token indices of ``phrase`` in ``words``, or ``None``.

    ``phrase`` is tokenized with ``nltk.word_tokenize``, the same tokenizer
    used to build ``words``. It is matched as a contiguous run of tokens
    beginning at or after index ``start``. ``stop`` is one past the last
    matched token.
    """
    target = nltk.word_tokenize(phrase)
    width = len(target)
    if width == 0:
        return None
    for first in range(start, len(words) - width + 1):
        if words[first:first + width] == target:
            return first, first + width
    return None


def get_word_between_points(sentence, start_point, end_point):
    """Find the first non-stop-word token lying between two non-melting points.

    Args:
        sentence: the text being analysed.
        start_point, end_point: items produced by ``get_non_melting_points``.
            Element ``[1]`` of each is used as a text fragment of ``sentence``
            (presumably the shared subsequence itself; confirm against
            ``lcs.find_common_subsequences``).

    Returns:
        ``(word, index)``, where ``index`` is the position of ``word`` in
        ``nltk.word_tokenize(sentence)``. Returns ``(None, None)`` when either
        point cannot be located as whole tokens, or when no non-stop-word token
        lies strictly between the two points.
    """
    words = nltk.word_tokenize(sentence)
    stop_words = set(stopwords.words('english'))

    # Locate both points in *token* space. The old code sliced the token list
    # with ``sentence.index(...)``, which is a character offset.
    start_span = _find_phrase_span(words, start_point[1])
    if start_span is None:
        return None, None
    # Only look for the end point after the start point, so the window is never inverted.
    end_span = _find_phrase_span(words, end_point[1], start_span[1])
    if end_span is None:
        return None, None

    for index in range(start_span[1], end_span[0]):
        word = words[index]
        if word.lower() not in stop_words:
            # Return this token's own position. ``words.index(word)`` would give
            # the first occurrence of an identical earlier token instead.
            return word, index
    return None, None
def get_logits_for_mask(sentence) -> torch.Tensor:
    """Score ``sentence`` with the masked LM and return the logits at its mask position(s).

    The sentence is encoded by the module-level ``tokenizer`` as a single
    PyTorch batch and run through ``model`` with gradient tracking disabled.
    The logits at every position holding ``tokenizer.mask_token_id`` are
    gathered and ``squeeze()``d. With exactly one mask, the result is a 1-D
    tensor with one score per vocabulary entry.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    is_mask = encoded["input_ids"] == tokenizer.mask_token_id
    # For the 2-D (batch, sequence) input, torch.where yields (rows, cols); keep the columns.
    _, mask_positions = torch.where(is_mask)

    with torch.no_grad():
        vocab_scores = model(**encoded).logits

    return vocab_scores[0, mask_positions, :].squeeze()
def detect_watermark(sentence, top_k=5):
    """Check whether a word between two non-melting points is a top permissible prediction.

    Steps:
    1. Find non-melting points with ``get_non_melting_points``.
    2. Pick the first non-stop-word token between the first two points.
    3. Replace that token with the tokenizer's mask token and score the masked
       sentence with the masked LM.
    4. Filter the logits to the permissible vocabulary.
    5. Report a watermark when the chosen word equals one of the ``top_k``
       highest-scoring decoded predictions.

    Args:
        sentence: text to test.
        top_k: how many of the highest-scoring permissible predictions to
            compare against. Defaults to 5, the previously hard-coded value.

    Returns:
        ``(is_watermarked, message)``: a bool and a human-readable explanation.

    Raises:
        ValueError: if ``top_k`` is less than 1.
    """
    # argsort()[-0:] would select *every* token id, so reject non-positive k.
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    non_melting_points = get_non_melting_points(sentence)
    if len(non_melting_points) < 2:
        return False, "Not enough non-melting points found."

    word_to_check, index = get_word_between_points(
        sentence, non_melting_points[0], non_melting_points[1]
    )
    if word_to_check is None:
        return False, "No suitable word found between non-melting points."

    # ``index`` refers to this same nltk tokenization of ``sentence``. Use the
    # tokenizer's own mask token (not a literal "[MASK]") so the masked text
    # always contains the id that get_logits_for_mask searches for.
    words = nltk.word_tokenize(sentence)
    masked_sentence = ' '.join(words[:index] + [tokenizer.mask_token] + words[index + 1:])

    logits = get_logits_for_mask(masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # argsort is ascending, so the last ``top_k`` ids score highest.
    top_predictions = filtered_logits.argsort()[-top_k:]
    # NOTE(review): each prediction decodes a single token id, so a word the
    # tokenizer splits into several word pieces presumably can never match.
    predicted_words = [tokenizer.decode([i]) for i in top_predictions]

    if word_to_check in predicted_words:
        return True, f"Watermark detected. The word '{word_to_check}' is in the permissible vocabulary."
    return False, f"No watermark detected. The word '{word_to_check}' is not in the permissible vocabulary."
# Example usage | |
# if __name__ == "__main__": | |
# test_sentence = "The quick brown fox jumps over the lazy dog." | |
# is_watermarked, message = detect_watermark(test_sentence) | |
# print(f"Is the sentence watermarked? {is_watermarked}") | |
# print(f"Detection message: {message}") |