"""LLM text watermarking: Aaronson- and Kirchenbauer-style embedding and detection.

A ``Watermarker`` wraps a Hugging Face tokenizer and generation model.
``embed`` generates text while a keyed logits processor steers token choice;
``detect`` re-derives the same keyed randomness from a text and returns, per
text, a p-value and the best-matching message. An optional multi-bit payload
is carried by cyclically shifting (``roll``) the keyed random vector by the
message value.
"""
import transformers
from transformers import AutoTokenizer
from transformers import pipeline, set_seed, LogitsProcessor
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
import torch
from scipy.special import gamma, gammainc, gammaincc, betainc
from scipy.optimize import fminbound
import numpy as np
import os

hf_token = os.getenv('HF_TOKEN')

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Fold a window of token ids into a deterministic integer seed.

    Polynomial rolling hash keyed by ``key``, reduced modulo ``2**64 - 1`` so
    the result is a valid argument for ``torch.Generator.manual_seed``.

    Args:
        input_ids: 1-D tensor of token ids (the context window).
        key: secret watermarking key, used as the hash's initial value.

    Returns:
        A non-negative Python int smaller than ``2**64 - 1``.
    """
    seed = key
    salt = 35317
    for i in input_ids:
        seed = (seed * salt + i.item()) % (2 ** 64 - 1)
    return seed


class WatermarkingLogitsProcessor(LogitsProcessor):
    """State shared by the watermarking logits processors.

    Args:
        n: length of the keyed pseudo-random vector drawn per step. Rows are
            cut down to the scores' vocabulary size, so ``n`` is expected to
            be at least that size.
        key: secret watermarking key.
        messages: one integer payload per batch row.
        window_size: number of trailing tokens hashed to reseed each row's
            generator at every step; falsy means seed once with ``key`` and
            draw sequentially.
    """

    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        self.generators = [
            torch.Generator(device=device) for _ in range(self.batch_size)
        ]
        self.n = n
        self.key = key
        self.window_size = window_size
        if not self.window_size:
            for b in range(self.batch_size):
                self.generators[b].manual_seed(self.key)
        self.messages = messages

    def _reseed(self, input_ids, b):
        """Reseed row ``b``'s generator from its last ``window_size`` tokens (windowed mode only)."""
        if self.window_size:
            window = input_ids[b, -self.window_size:]
            self.generators[b].manual_seed(hash_tokens(window, self.key))


class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    """Aaronson-style processor: the next token is argmax of ``r ** (1 / p)``.

    Intended for greedy decoding: the returned "logits" are ``log(r) / p``,
    whose argmax is the keyed selection.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape
        # Draw the full length-n vector (not just V values) so every row's
        # generator advances exactly as the detector's does, then keep V.
        r = torch.empty((B, self.n), dtype=torch.float32, device=scores.device)
        for b in range(B):
            self._reseed(input_ids, b)
            gen = self.generators[b]
            u = torch.rand(self.n, generator=gen, device=gen.device)
            # Rolling by the message shifts which token each random value scores.
            r[b] = u.log().roll(-self.messages[b]).to(scores.device)
        r = r[:, :V]
        # log(r ** (1/p)) = log(r) / p. softmax has the same argmax as
        # exp(logits) but cannot overflow to inf for large logits.
        p = scores.float().softmax(dim=-1)
        return r / p


class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    """Kirchenbauer-style processor: add ``delta`` to a keyed green list.

    Args:
        gamma: fraction of the length-``n`` permutation placed in the green list.
        delta: bias added to green-list logits.
    """

    def __init__(self, *args, gamma=0.5, delta=4.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape
        for b in range(B):
            self._reseed(input_ids, b)
            gen = self.generators[b]
            vocab_permutation = torch.randperm(self.n, generator=gen, device=gen.device)
            greenlist = vocab_permutation[:int(self.gamma * self.n)]  # gamma * n
            # Build the bias on the same device as the greenlist indices.
            bias = torch.zeros(self.n, device=gen.device)
            bias[greenlist] = self.delta
            # Shift by the message, then keep the V entries matching the vocabulary.
            scores[b] += bias.roll(-self.messages[b])[:V].to(scores.device)
        return scores


class Watermarker(object):
    """Embed and detect keyed watermarks in text produced by ``model``.

    Args:
        tokenizer: tokenizer paired with ``model``.
        model: generation model; switched to eval mode here.
        window_size: number of trailing tokens hashed to reseed the keyed
            generator at each step; 0 seeds once per sequence with ``key``.
        payload_bits: number of payload bits (0 = zero-bit watermark).
        logits_processor: optional list of extra processors applied before the
            watermarking processor in ``embed``.
    """

    def __init__(self, tokenizer=None, model=None, window_size=0, payload_bits=0,
                 logits_processor=None, *args, **kwargs):
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.window_size = window_size
        # preprocessing wrappers
        self.logits_processor = logits_processor or []
        self.payload_bits = payload_bits
        # Length of the keyed random vector: it must cover the vocabulary and
        # every message value, since messages are applied as cyclic shifts.
        self.V = max(2 ** payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)

    def embed(self, key=42, messages=None, prompt="", max_length=30, method='aaronson'):
        """Generate one continuation of ``prompt`` per message.

        Args:
            key: secret watermarking key.
            messages: integer payloads, one per generated text; defaults to
                ``[1234]``. When ``payload_bits`` is set each must lie in
                ``[0, 2**payload_bits)``.
            prompt: text prefix shared by every row.
            max_length: forwarded to ``model.generate``.
            method: ``'aaronson'`` (greedy + keyed selection),
                ``'kirchenbauer'`` (sampling + green-list bias), or the
                unwatermarked baselines ``'greedy'`` / ``'sampling'``.

        Returns:
            List of decoded ``generate`` outputs, one per message.

        Raises:
            AssertionError: a message is out of range for ``payload_bits``.
            Exception: ``method`` is unknown.
        """
        if messages is None:
            messages = [1234]
        B = len(messages)  # batch size

        if self.payload_bits:
            assert all(0 <= message < 2 ** self.payload_bits for message in messages), \
                'messages must lie in [0, 2**payload_bits)'

        processors = list(self.logits_processor)
        if method == 'aaronson':
            # Keyed selection is deterministic given the randomness: decode greedily.
            do_sample = False
            processors.append(WatermarkingAaronsonLogitsProcessor(
                n=self.V, key=key, messages=messages, window_size=self.window_size))
        elif method == 'kirchenbauer':
            # The green-list bias only reshapes the distribution: sample from it.
            do_sample = True
            processors.append(WatermarkingKirchenbauerLogitsProcessor(
                n=self.V, key=key, messages=messages, window_size=self.window_size))
        elif method == 'greedy':
            do_sample = False
        elif method == 'sampling':
            do_sample = True
        else:
            raise Exception('Unknown method %s' % method)

        # tokenize prompt (identical rows, so no padding is involved)
        inputs = self.tokenizer([prompt] * B, return_tensors="pt")
        attention_mask = inputs.get("attention_mask")
        generated_ids = self.model.generate(
            inputs.input_ids.to(device),
            attention_mask=None if attention_mask is None else attention_mask.to(device),
            max_length=max_length,
            do_sample=do_sample,
            logits_processor=processors,
        )
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)

    @staticmethod
    def _draw_detection_scores(method, V, generator, gamma):
        """Regenerate the length-``V`` per-token score vector for one scored position."""
        if method == 'kirchenbauer':
            r = torch.zeros(V, device=generator.device)
            vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
            r[vocab_permutation[:int(gamma * V)]] = 1
            return r
        R = torch.rand(V, generator=generator, device=generator.device)
        if method == 'aaronson':
            return -(1 - R).log()
        # 'aaronson_simplified' and 'aaronson_neyman_pearson'
        return -R.log()

    @staticmethod
    def _single_message_pvalue(method, S, k, gamma, ps_scored=None):
        """P-value of score ``S`` over ``k`` scored tokens for one candidate message."""
        if k == 0:
            # No token contributed to the score: no evidence of a watermark.
            return 1.0
        if method == 'aaronson':
            # Under H0, S is a sum of k Exp(1) variables ~ Gamma(k, 1); upper tail.
            return float(gammaincc(k, S))
        if method == 'aaronson_simplified':
            # Lower tail: watermarked text drives this score down.
            return float(gammainc(k, S))
        if method == 'kirchenbauer':
            if S <= 0:
                return 1.0  # P(Binomial(k, gamma) >= 0) = 1
            # Binomial(k, gamma) upper tail P(X >= S) via the regularized beta.
            return float(betainc(S, k - S + 1, gamma))
        # 'aaronson_neyman_pearson': Chernoff bound
        ratio = ps_scored / (1 - ps_scored)
        E = (1 / ratio).sum()
        if S > E:
            return 1.0
        # to compute the p-value we must solve for c*:
        # (1/(c* + ps/(1-ps))).sum() = S
        func = lambda c: (((1 / (c + ratio)).sum() - S) ** 2).item()
        c1 = (k / S - torch.min(ratio)).item()
        c = fminbound(func, 0, c1)
        # upper bound
        return float(torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S))

    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        """Score texts for the watermark and decode the most likely message.

        Args:
            attacked_texts: texts to test (batched with padding).
            key: secret watermarking key used at embedding time.
            method: ``'aaronson'``, ``'aaronson_simplified'``,
                ``'aaronson_neyman_pearson'`` or ``'kirchenbauer'``.
            gamma: green-list fraction (``'kirchenbauer'`` only).
            prompts: prompt used for each text; its tokens are excluded from
                scoring. Defaults to ``""`` for every text.

        Returns:
            ``(cdfs, ms)``: per text, the p-value corrected for searching over
            ``2**payload_bits`` messages, and the best-matching message index.

        Raises:
            Exception: ``method`` is unknown.
        """
        if method not in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer'}:
            raise Exception('Unknown method %s' % method)
        if prompts is None:
            prompts = [""] * len(attacked_texts)
        generator = self.generator
        MAX = 2 ** self.payload_bits
        V = self.V

        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)
        B, T = input_ids.shape

        ps = None
        if method == 'aaronson_neyman_pearson':
            # Inference only: no autograd graph needed.
            with torch.no_grad():
                logits = self.model(input_ids, return_dict=True)['logits']
            # TODO: reapply self.logits_processor so these match the generation-time distribution.
            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of token i + 1 given tokens 0..i.
            ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)

        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
        # keep a history of contexts we have already seen, to exclude them
        # from score aggregation and allow correct p-value computation under H0
        history = [set() for _ in range(B)]
        prompt_masks = self.tokenizer(prompts, return_tensors="pt", padding=True,
                                      return_attention_mask=True)["attention_mask"]
        prompts_length = prompt_masks.sum(dim=1)

        for b in range(B):
            # Prompt tokens were not produced by the watermarked decoder.
            # NOTE(review): assumes right padding so each prompt sits at the row start; confirm padding_side.
            attention_masks[b, :int(prompts_length[b])] = 0
            if not self.window_size:
                generator.manual_seed(key)
            for i in range(T - 1):
                if self.window_size:
                    window = input_ids[b, max(0, i - self.window_size + 1):i + 1]
                    seed = hash_tokens(window, key)
                    if seed in history[b]:
                        # Repeated context means repeated randomness: ignore the token.
                        attention_masks[b, i + 1] = 0
                    else:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                if not attention_masks[b, i + 1]:
                    continue
                token = int(input_ids[b, i + 1])
                r = self._draw_detection_scores(method, V, generator, gamma)
                # Z[b][m] accumulates the score of index token + m, i.e. candidate message m.
                if method == 'aaronson_neyman_pearson':
                    Z[b] += r.roll(-token) * (1 / ps[b, i] - 1)
                else:
                    Z[b] += r.roll(-token)

        cdfs, ms = [], []
        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = int(torch.argmax(Z[b, :MAX]))
            else:
                m = int(torch.argmin(Z[b, :MAX]))
            S = Z[b, m].item()
            # Positions that actually contributed to Z[b]: the loop above adds
            # for position i + 1 exactly when attention_masks[b, i + 1] is set.
            scored = attention_masks[b, 1:].bool()
            k = int(scored.sum().item())
            ps_scored = ps[b][scored] if ps is not None else None
            cdf = self._single_message_pvalue(method, S, k, gamma, ps_scored)
            # Correct for taking the best of MAX candidate messages.
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf) ** MAX  # true value
            else:
                cdf = cdf * MAX  # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)
        return cdfs, ms