Spaces:
Running
Running
import hashlib | |
from typing import List | |
import numpy as np | |
from scipy.stats import norm | |
import torch | |
from transformers import LogitsWarper | |
class GPTWatermarkBase: | |
""" | |
Base class for watermarking distributions with fixed-group green-listed tokens. | |
Args: | |
fraction: The fraction of the distribution to be green-listed. | |
strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. | |
vocab_size: The size of the vocabulary. | |
watermark_key: The random seed for the green-listing. | |
""" | |
def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0): | |
rng = np.random.default_rng(self._hash_fn(watermark_key)) | |
mask = np.array([True] * int(fraction * vocab_size) + [False] * (vocab_size - int(fraction * vocab_size))) | |
rng.shuffle(mask) | |
self.green_list_mask = torch.tensor(mask, dtype=torch.float32) | |
self.strength = strength | |
self.fraction = fraction | |
def _hash_fn(x: int) -> int: | |
"""solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits""" | |
x = np.int64(x) | |
return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little') | |
class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper): | |
""" | |
LogitsWarper for watermarking distributions with fixed-group green-listed tokens. | |
Args: | |
fraction: The fraction of the distribution to be green-listed. | |
strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. | |
vocab_size: The size of the vocabulary. | |
watermark_key: The random seed for the green-listing. | |
""" | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor: | |
"""Add the watermark to the logits and return new logits.""" | |
watermark = self.strength * self.green_list_mask | |
new_logits = scores + watermark.to(scores.device) | |
return new_logits | |
class GPTWatermarkDetector(GPTWatermarkBase): | |
""" | |
Class for detecting watermarks in a sequence of tokens. | |
Args: | |
fraction: The fraction of the distribution to be green-listed. | |
strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens. | |
vocab_size: The size of the vocabulary. | |
watermark_key: The random seed for the green-listing. | |
""" | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
def _z_score(num_green: int, total: int, fraction: float) -> float: | |
"""Calculate and return the z-score of the number of green tokens in a sequence.""" | |
return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total) | |
def _compute_tau(m: int, N: int, alpha: float) -> float: | |
""" | |
Compute the threshold tau for the dynamic thresholding. | |
Args: | |
m: The number of unique tokens in the sequence. | |
N: Vocabulary size. | |
alpha: The false positive rate to control. | |
Returns: | |
The threshold tau. | |
""" | |
factor = np.sqrt(1 - (m - 1) / (N - 1)) | |
tau = factor * norm.ppf(1 - alpha) | |
return tau | |
def detect(self, sequence: List[int]) -> float: | |
"""Detect the watermark in a sequence of tokens and return the z value.""" | |
green_tokens = int(sum(self.green_list_mask[i] for i in sequence)) | |
green_tokens_mask = [] | |
for i in sequence: | |
if self.green_list_mask[i]: | |
green_tokens_mask.append(True) | |
else: | |
green_tokens_mask.append(False) | |
# self.green_tokens_mask = green_tokens_mask | |
return self._z_score(green_tokens, len(sequence), self.fraction), green_tokens_mask,green_tokens,len(sequence) | |
def unidetect(self, sequence: List[int]) -> float: | |
"""Detect the watermark in a sequence of tokens and return the z value. Just for unique tokens.""" | |
sequence = list(set(sequence)) | |
green_tokens = int(sum(self.green_list_mask[i] for i in sequence)) | |
return self._z_score(green_tokens, len(sequence), self.fraction) | |
def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> (bool, float): | |
"""Dynamic thresholding for watermark detection. True if the sequence is watermarked, False otherwise.""" | |
z_score = self.unidetect(sequence) | |
tau = self._compute_tau(len(list(set(sequence))), vocab_size, alpha) | |
return z_score > tau, z_score | |