import nltk |
from nltk.corpus import stopwords |
from nltk import word_tokenize, pos_tag |
import torch |
import torch.nn.functional as F |
from torch import nn |
import hashlib |
from scipy.stats import norm |
import gensim |
import pdb |
from transformers import BertForMaskedLM as WoBertForMaskedLM |
from wobert import WoBertTokenizer |
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer |
import gensim.downloader as api |
import Levenshtein |
import string |
import spacy |
import paddle |
from jieba import posseg |
paddle.enable_static() |
import re |
def cut_sent(para):
    """Split a paragraph into sentence-like segments.

    Newlines are inserted at sentence boundaries by a sequence of regex
    substitutions, then the text is split on those newlines.

    Args:
        para: The text to segment.

    Returns:
        A list of segments (``[""]`` for empty input). Segments may be empty
        or whitespace-only; callers are expected to filter them.
    """
    # All patterns are raw strings: sequences such as ``\?`` or ``\…`` are
    # invalid escapes in ordinary string literals and trigger warnings on
    # recent Python versions.

    # Break after a terminator unless a closing quote follows it.
    para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
    # Break after "terminator + closing quote" unless more punctuation,
    # a newline or a space follows.
    para = re.sub(r'([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
    # Break after an ellipsis (six dots or two "…").
    para = re.sub(r'(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
    # Break *before* a colon, keeping the colon with the text that follows.
    para = re.sub(r'([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
    # A trailing "terminator + closing quote" gets a newline too; it is
    # removed again by the rstrip below.
    para = re.sub(r'([。!?\?][”’])$', r'\1\n', para)
    return para.rstrip().split("\n")
def is_subword(token: str):
    """Return True when *token* begins with the ``##`` continuation marker."""
    marker = '##'
    return token[:len(marker)] == marker
def binary_encoding_function(token):
    """Map *token* to a deterministic bit (0 or 1).

    The bit is the parity of the SHA-256 digest of the UTF-8 encoded token,
    interpreted as a big-endian integer.
    """
    digest = hashlib.sha256(token.encode('utf-8')).digest()
    # The parity of the whole integer equals the lowest bit of the last byte.
    return digest[-1] & 1
def is_similar(x, y, threshold=0.5):
    """Tell whether two strings are close in normalised edit distance.

    Args:
        x: First string.
        y: Second string.
        threshold: Upper bound (exclusive) on the Levenshtein distance
            divided by the length of the longer string.

    Returns:
        True when the normalised distance is strictly below *threshold*.
        Two empty strings have normalised distance 0.
    """
    longest = max(len(x), len(y))
    if longest == 0:
        # Both strings are empty: the distance is 0, and dividing by the
        # (zero) length would raise ZeroDivisionError.
        return 0 < threshold
    return Levenshtein.distance(x, y) / longest < threshold
class watermark_model:
    """Embed and detect a lexical-substitution watermark in text.

    Each eligible token is paired with its predecessor and hashed to one bit
    via ``binary_encoding_function``. Embedding replaces bit-0 tokens with
    masked-LM candidates whose hash is bit-1; detection runs a one-sided
    z-test on the share of bit-1 tokens.

    Attributes:
        device: torch device ("cuda" when available, else "cpu").
        language: 'Chinese' or 'English'.
        mode: Stored as given; not read by any method of this class.
        tau_word: Minimum combined word-similarity score for a candidate.
        tau_sent: Minimum sentence-relatedness score (fixed at 0.8).
        lamda: Weight of the contextual similarity versus the global
            (word-vector) similarity.
    """

    # Matches whitespace runs between two CJK characters, or between a CJK
    # character and one of the listed punctuation marks.
    _CN_SPACE_RE = re.compile(
        r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])'
    )

    def __init__(self, language, mode, tau_word, lamda):
        """Load the tokenizers, models and word vectors for *language*.

        No models are loaded for a language other than 'Chinese' or
        'English'; only the scalar settings and tag lists are stored.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.language = language
        self.mode = mode
        self.tau_word = tau_word
        self.tau_sent = 0.8
        self.lamda = lamda
        # jieba POS flags that must never be substituted (Chinese).
        self.cn_tag_black_list = {
            '', 'x', 'u', 'j', 'k', 'zg', 'y', 'eng', 'uv', 'uj', 'ud', 'nr',
            'nrfg', 'nrt', 'nw', 'nz', 'ns', 'nt', 'm', 'mq', 'r', 'w',
            'PER', 'LOC', 'ORG',
        }
        # NLTK POS tags that may be substituted (English).
        self.en_tag_white_list = {
            'MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
            'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS',
        }
        if language == 'Chinese':
            self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
            self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
            self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
            self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
            self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
        elif language == 'English':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
            self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
            self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
            self.w2v_model = api.load("glove-wiki-gigaword-100")
            # stop_words and nlp are only read by the English code paths.
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))
            self.nlp = spacy.load('en_core_web_sm')

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _substitute(tokens, masked_token_index, token):
        """Return the inner tokens (no [CLS]/[SEP]) with one position replaced."""
        return tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]

    def _strip_cn_spaces(self, text):
        """Remove whitespace between CJK characters/punctuation in *text*."""
        return self._CN_SPACE_RE.sub('', text)

    @staticmethod
    def _z_test(encodings, alpha):
        """Run the one-sided z-test shared by both detectors.

        Returns:
            ``(is_watermark, p_value, n, ones, z)``. With no encodings the
            statistic is defined as 0, so the text is not flagged.
        """
        p = 0.5
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            # No evidence at all; avoid dividing by a zero standard deviation.
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z

    def _sentence_pairs(self, text):
        """Split *text* into sentences and join them two at a time."""
        sents = [s for s in self.sent_tokenize(text) if s.strip()]
        # Slicing handles the trailing unpaired sentence; the join uses no
        # separator, matching plain string concatenation.
        return [''.join(sents[i:i + 2]) for i in range(0, len(sents), 2)]

    # ------------------------------------------------------------------
    # Text preparation
    # ------------------------------------------------------------------
    def cut(self, ori_text, text_len):
        """Truncate *ori_text* to roughly *text_len* units.

        Units are characters for Chinese and tokenizer tokens for English.
        Text longer than ``text_len + 5`` is truncated to that length.

        Returns:
            The (possibly truncated) text, or the string ``'Short'`` when
            the text has fewer than ``text_len - 5`` units.

        Raises:
            NotImplementedError: For an unsupported language.
        """
        if self.language == 'Chinese':
            if len(ori_text) > text_len + 5:
                ori_text = ori_text[:text_len + 5]
            if len(ori_text) < text_len - 5:
                return 'Short'
            # Return the text here as the English branch does; without this
            # the method fell through and returned None.
            return ori_text
        elif self.language == 'English':
            tokens = self.tokenizer.tokenize(ori_text)
            if len(tokens) > text_len + 5:
                ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len + 5])
            if len(tokens) < text_len - 5:
                return 'Short'
            return ori_text
        else:
            print(f'Unsupported Language:{self.language}')
            raise NotImplementedError

    def sent_tokenize(self, ori_text):
        """Split *ori_text* into sentences (None for an unsupported language)."""
        if self.language == 'Chinese':
            return cut_sent(ori_text)
        elif self.language == 'English':
            return nltk.sent_tokenize(ori_text)

    # ------------------------------------------------------------------
    # Candidate selection
    # ------------------------------------------------------------------
    def pos_filter(self, tokens, masked_token_index, input_text):
        """Decide whether the token at *masked_token_index* may be replaced.

        Chinese: the token's jieba POS flag (looked up in a segmentation of
        *input_text*) must not be black-listed. English: the NLTK tag must
        be white-listed, and the token must not be a subword, be followed by
        a subword, be a stop word or be punctuation.
        """
        if self.language == 'Chinese':
            pos_dict = {word: pos for word, pos in posseg.lcut(input_text)}
            # Tokens that jieba does not produce as a word get '', which is
            # black-listed.
            pos = pos_dict.get(tokens[masked_token_index], '')
            return pos not in self.cn_tag_black_list
        elif self.language == 'English':
            pos = pos_tag(tokens)[masked_token_index][1]
            if pos not in self.en_tag_white_list:
                return False
            token = tokens[masked_token_index]
            next_index = masked_token_index + 1
            # Guard the look-ahead so the last token does not raise IndexError.
            next_is_subword = next_index < len(tokens) and is_subword(tokens[next_index])
            if (is_subword(token) or next_is_subword
                    or token in self.stop_words or token in string.punctuation):
                return False
            return True

    def filter_special_candidate(self, top_n_tokens, tokens, masked_token_index, input_text):
        """Drop masked-LM candidates that are unsuitable substitutes.

        English: keeps white-listed, non-stop-word, non-punctuation,
        non-subword candidates whose lemma differs from the original word's,
        and puts the original word first. Chinese: keeps candidates whose
        POS flag, after substitution and re-segmentation, equals the
        original token's flag and is not black-listed.
        """
        if self.language == 'English':
            # Cheap string checks run before the POS tagger call.
            filtered_tokens = [
                tok for tok in top_n_tokens
                if tok not in self.stop_words
                and tok not in string.punctuation
                and not is_subword(tok)
                and pos_tag([tok])[0][1] in self.en_tag_white_list
            ]
            base_word = tokens[masked_token_index]
            base_word_lemma = self.nlp(base_word)[0].lemma_
            return [base_word] + [
                tok for tok in filtered_tokens
                if self.nlp(tok)[0].lemma_ != base_word_lemma
            ]
        elif self.language == 'Chinese':
            pos_dict = {word: pos for word, pos in posseg.lcut(input_text)}
            pos = pos_dict.get(tokens[masked_token_index], '')
            filtered_tokens = []
            for tok in top_n_tokens:
                candidate_text = self._strip_cn_spaces(
                    self.tokenizer.convert_tokens_to_string(
                        self._substitute(tokens, masked_token_index, tok)))
                pos_dict_tok = {word: p for word, p in posseg.lcut(candidate_text)}
                flag = pos_dict_tok.get(tok, '')
                if flag not in self.cn_tag_black_list and flag == pos:
                    filtered_tokens.append(tok)
            return filtered_tokens

    def global_word_sim(self, word, ori_word):
        """Return the word-vector similarity of two words (0 if either is unknown)."""
        try:
            return self.w2v_model.similarity(word, ori_word)
        except KeyError:
            return 0

    def context_word_sim(self, init_candidates, tokens, masked_token_index, input_text):
        """Score each candidate by contextual similarity to the original token.

        Each candidate sentence is encoded together with *input_text*. The
        score is the cosine similarity between the hidden states at
        *masked_token_index*, averaged over the last 8 hidden-state layers.

        Returns:
            A list with one float per candidate.
        """
        if not init_candidates:
            return []
        original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
        # Build a (num_candidates, seq_len) tensor directly; the previous
        # nested-list + bare squeeze() dropped the batch dimension for a
        # single candidate.
        batch_input_ids = [
            self.tokenizer.convert_tokens_to_ids(
                ['[CLS]'] + self._substitute(tokens, masked_token_index, token) + ['[SEP]'])
            for token in init_candidates
        ]
        batch_input_tensors = torch.tensor(batch_input_ids).to(self.device)
        # The original sentence is appended as the last row of the batch.
        batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)
        with torch.no_grad():
            outputs = self.model(batch_input_tensors)
        hidden_states = outputs[1]
        num_layers = len(hidden_states)
        num_candidates = len(init_candidates)
        N = 8
        cos_sim_sum = 0
        for layer in range(num_layers - N, num_layers):
            ls_hidden_states = hidden_states[layer][0:num_candidates, masked_token_index, :]
            source_hidden_state = hidden_states[layer][num_candidates, masked_token_index, :]
            cos_sim_sum += F.cosine_similarity(source_hidden_state, ls_hidden_states, dim=1)
        return (cos_sim_sum / N).tolist()

    def sentence_sim(self, init_candidates, tokens, masked_token_index, input_text):
        """Score each substituted sentence against *input_text*.

        Returns:
            One probability per candidate, taken from class index 1
            (Chinese) or 2 (English) of the relatedness model's softmax
            output. NOTE(review): presumably the "similar" / "entailment"
            classes of the respective checkpoints; confirm against the
            model cards.
        """
        batch_sentences = [
            self.tokenizer.convert_tokens_to_string(
                self._substitute(tokens, masked_token_index, token))
            for token in init_candidates
        ]
        if self.language == 'Chinese':
            batch_sentences = [self._strip_cn_spaces(sent) for sent in batch_sentences]
            roberta_inputs = [input_text + '[SEP]' + s for s in batch_sentences]
        elif self.language == 'English':
            roberta_inputs = [input_text + '</s></s>' + s for s in batch_sentences]
        encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
            roberta_inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt')
        input_ids = encoded_dict['input_ids'].to(self.device)
        attention_masks = encoded_dict['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
        probs = torch.softmax(outputs[0], dim=1)
        if self.language == 'Chinese':
            relatedness_scores = probs[:, 1].tolist()
        elif self.language == 'English':
            relatedness_scores = probs[:, 2].tolist()
        return relatedness_scores

    def candidates_gen(self, tokens, masked_token_index, input_text, topk=64, dropout_prob=0.3):
        """Propose replacement tokens for one position.

        The embedding at *masked_token_index* is passed through dropout
        before the masked LM predicts that position; the *topk* most
        probable tokens are then filtered by ``filter_special_candidate``.

        Returns:
            A list of candidate tokens, or ``[]`` when the position fails
            ``pos_filter``.
        """
        if not self.pos_filter(tokens, masked_token_index, input_text):
            return []
        input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
        input_tensor = torch.tensor([input_ids_bert]).to(self.device)
        with torch.no_grad():
            embeddings = self.model.bert.embeddings(input_tensor)
        # NOTE(review): recent PyTorch releases may warn about Dropout2d on a
        # 2-D input; kept as is so the sampling behaviour is unchanged.
        dropout = nn.Dropout2d(p=dropout_prob)
        embeddings[:, masked_token_index, :] = dropout(embeddings[:, masked_token_index, :])
        with torch.no_grad():
            outputs = self.model(inputs_embeds=embeddings)
        predicted_logits = outputs[0][0][masked_token_index]
        probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
        _, top_n_indices = torch.topk(probs, topk)
        top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
        # input_text is a required argument of filter_special_candidate;
        # omitting it raised TypeError on every call.
        return self.filter_special_candidate(top_n_tokens, tokens, masked_token_index, input_text)

    def filter_candidates(self, init_candidates, tokens, masked_token_index, input_text):
        """Keep candidates that pass both similarity thresholds.

        Returns:
            A list of ``(candidate, word_similarity_score)`` tuples where
            the word score (``lamda`` * contextual + (1 - ``lamda``) *
            global) is at least ``tau_word`` and the sentence score is at
            least ``tau_sent``.
        """
        if not init_candidates:
            return []
        context_scores = self.context_word_sim(init_candidates, tokens, masked_token_index, input_text)
        sentence_scores = self.sentence_sim(init_candidates, tokens, masked_token_index, input_text)
        original_token = tokens[masked_token_index]
        filtered_candidates = []
        for candidate, context_score, sentence_score in zip(init_candidates, context_scores, sentence_scores):
            global_score = self.global_word_sim(original_token, candidate)
            word_similarity_score = self.lamda * context_score + (1 - self.lamda) * global_score
            if word_similarity_score >= self.tau_word and sentence_score >= self.tau_sent:
                filtered_candidates.append((candidate, word_similarity_score))
        return filtered_candidates

    # ------------------------------------------------------------------
    # Embedding
    # ------------------------------------------------------------------
    def watermark_embed(self, text):
        """Embed the watermark into one chunk of text and return the result.

        Positions whose (previous token + token) hash is already bit-1 are
        left alone. Other positions are replaced by the highest-scoring
        candidate whose hash with the (possibly already replaced) previous
        token is bit-1, when such a candidate exists.
        """
        input_text = text
        tokens = ['[CLS]'] + self.tokenizer.tokenize(input_text) + ['[SEP]']
        end_index = len(tokens) - 1
        # The first and last real tokens are never replaced.
        for masked_token_index in range(2, end_index - 1):
            prev_token = tokens[masked_token_index - 1]
            current_token = tokens[masked_token_index]
            next_token = tokens[masked_token_index + 1]
            if binary_encoding_function(prev_token + current_token) == 1:
                continue
            init_candidates = self.candidates_gen(tokens, masked_token_index, input_text, 32, 0.3)
            if len(init_candidates) <= 1:
                continue
            enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, input_text)
            # A candidate is usable when it encodes bit-1, is not a near
            # duplicate of the original token and does not repeat a
            # neighbouring token. (The neighbour test was written as tuple
            # membership on (candidate, score); equality with the candidate
            # is the same check.)
            hash_top_tokens = [
                (cand, score) for cand, score in enhanced_candidates
                if binary_encoding_function(prev_token + cand) == 1
                and not is_similar(cand, current_token)
                and cand != prev_token
                and cand != next_token
            ]
            if hash_top_tokens:
                # max() returns the first maximum, like a stable descending sort.
                tokens[masked_token_index] = max(hash_top_tokens, key=lambda item: item[1])[0]
        watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
        if self.language == 'Chinese':
            watermarked_text = self._strip_cn_spaces(watermarked_text)
        return watermarked_text

    def embed(self, ori_text):
        """Watermark *ori_text*, processing it two sentences at a time.

        Returns:
            The watermarked text, or ``''`` when *ori_text* yields no
            encodable tokens.
        """
        # Checked first: the result is '' in that case regardless, so the
        # expensive model passes are skipped.
        if len(self.get_encodings_fast(ori_text)) == 0:
            return ''
        return ''.join(self.watermark_embed(pair) for pair in self._sentence_pairs(ori_text))

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------
    def get_encodings_fast(self, text):
        """Collect the bit sequence of *text* using only the POS filter.

        Returns:
            A list of 0/1 bits, one per token that passes ``pos_filter``
            (the first and last token of each sentence pair are skipped).
        """
        encodings = []
        for sent_pair in self._sentence_pairs(text):
            tokens = self.tokenizer.tokenize(sent_pair)
            for index in range(1, len(tokens) - 1):
                # The full text (not just the pair) is passed as the
                # segmentation context.
                if not self.pos_filter(tokens, index, text):
                    continue
                encodings.append(binary_encoding_function(tokens[index - 1] + tokens[index]))
        return encodings

    def watermark_detector_fast(self, text, alpha=0.05):
        """Detect the watermark using ``get_encodings_fast``.

        Returns:
            ``(is_watermark, p_value, n, ones, z)``; with no encodable
            tokens ``z`` is 0 and the text is not flagged.
        """
        return self._z_test(self.get_encodings_fast(text), alpha)

    def get_encodings_precise(self, text):
        """Collect the bit sequence of *text* by re-running candidate search.

        A position contributes a bit only when candidate generation (top 8,
        no dropout) followed by similarity filtering leaves more than one
        candidate.
        """
        encodings = []
        for sent_pair in self._sentence_pairs(text):
            tokens = ['[CLS]'] + self.tokenizer.tokenize(sent_pair) + ['[SEP]']
            end_index = len(tokens) - 1
            for masked_token_index in range(2, end_index - 1):
                init_candidates = self.candidates_gen(tokens, masked_token_index, sent_pair, 8, 0)
                if len(init_candidates) <= 1:
                    continue
                enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, sent_pair)
                if len(enhanced_candidates) > 1:
                    encodings.append(
                        binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index]))
        return encodings

    def watermark_detector_precise(self, text, alpha=0.05):
        """Detect the watermark using ``get_encodings_precise``.

        Returns:
            ``(is_watermark, p_value, n, ones, z)``.
        """
        return self._z_test(self.get_encodings_precise(text), alpha)