Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
# Copyright 2023 Authors of "A Watermark for Large Language Models" | |
# available at https://arxiv.org/abs/2301.10226 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import argparse | |
from pprint import pprint | |
from functools import partial | |
import numpy # for gradio hot reload | |
import gradio as gr | |
import torch | |
from transformers import (AutoTokenizer, | |
AutoModelForSeq2SeqLM, | |
AutoModelForCausalLM, | |
LogitsProcessorList) | |
# from local_tokenizers.tokenization_llama import LLaMATokenizer | |
from transformers import GPT2TokenizerFast | |
OPT_TOKENIZER = GPT2TokenizerFast | |
from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector | |
# ALPACA_MODEL_NAME = "alpaca" | |
# ALPACA_MODEL_TOKENIZER = LLaMATokenizer | |
# ALPACA_TOKENIZER_PATH = "/cmlscratch/jkirchen/llama" | |
# FIXME correct lengths for all models | |
API_MODEL_MAP = { | |
"bigscience/bloom" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, | |
"bigscience/bloomz" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, | |
"google/flan-ul2" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, | |
"google/flan-t5-xxl" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, | |
"EleutherAI/gpt-neox-20b" : {"max_length": 1000, "gamma": 0.5, "delta": 2.0}, | |
} | |
def str2bool(v): | |
"""Util function for user friendly boolean flag args""" | |
if isinstance(v, bool): | |
return v | |
if v.lower() in ('yes', 'true', 't', 'y', '1'): | |
return True | |
elif v.lower() in ('no', 'false', 'f', 'n', '0'): | |
return False | |
else: | |
raise argparse.ArgumentTypeError('Boolean value expected.') | |
def parse_args(): | |
"""Command line argument specification""" | |
parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API") | |
parser.add_argument( | |
"--run_gradio", | |
type=str2bool, | |
default=True, | |
help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.", | |
) | |
parser.add_argument( | |
"--demo_public", | |
type=str2bool, | |
default=False, | |
help="Whether to expose the gradio demo to the internet.", | |
) | |
parser.add_argument( | |
"--model_name_or_path", | |
type=str, | |
default="facebook/opt-6.7b", | |
help="Main model, path to pretrained model or model identifier from huggingface.co/models.", | |
) | |
parser.add_argument( | |
"--prompt_max_length", | |
type=int, | |
default=None, | |
help="Truncation length for prompt, overrides model config's max length field.", | |
) | |
parser.add_argument( | |
"--max_new_tokens", | |
type=int, | |
default=200, | |
help="Maximmum number of new tokens to generate.", | |
) | |
parser.add_argument( | |
"--generation_seed", | |
type=int, | |
default=123, | |
help="Seed for setting the torch global rng prior to generation.", | |
) | |
parser.add_argument( | |
"--use_sampling", | |
type=str2bool, | |
default=True, | |
help="Whether to generate using multinomial sampling.", | |
) | |
parser.add_argument( | |
"--sampling_temp", | |
type=float, | |
default=0.7, | |
help="Sampling temperature to use when generating using multinomial sampling.", | |
) | |
parser.add_argument( | |
"--n_beams", | |
type=int, | |
default=1, | |
help="Number of beams to use for beam search. 1 is normal greedy decoding", | |
) | |
parser.add_argument( | |
"--use_gpu", | |
type=str2bool, | |
default=True, | |
help="Whether to run inference and watermark hashing/seeding/permutation on gpu.", | |
) | |
parser.add_argument( | |
"--seeding_scheme", | |
type=str, | |
default="simple_1", | |
help="Seeding scheme to use to generate the greenlists at each generation and verification step.", | |
) | |
parser.add_argument( | |
"--gamma", | |
type=float, | |
default=0.25, | |
help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.", | |
) | |
parser.add_argument( | |
"--delta", | |
type=float, | |
default=2.0, | |
help="The amount/bias to add to each of the greenlist token logits before each token sampling step.", | |
) | |
parser.add_argument( | |
"--normalizers", | |
type=str, | |
default="", | |
help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.", | |
) | |
parser.add_argument( | |
"--ignore_repeated_bigrams", | |
type=str2bool, | |
default=False, | |
help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.", | |
) | |
parser.add_argument( | |
"--detection_z_threshold", | |
type=float, | |
default=4.0, | |
help="The test statistic threshold for the detection hypothesis test.", | |
) | |
parser.add_argument( | |
"--select_green_tokens", | |
type=str2bool, | |
default=True, | |
help="How to treat the permuation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.", | |
) | |
parser.add_argument( | |
"--skip_model_load", | |
type=str2bool, | |
default=False, | |
help="Skip the model loading to debug the interface.", | |
) | |
parser.add_argument( | |
"--seed_separately", | |
type=str2bool, | |
default=True, | |
help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.", | |
) | |
parser.add_argument( | |
"--load_fp16", | |
type=str2bool, | |
default=False, | |
help="Whether to run model in float16 precsion.", | |
) | |
args = parser.parse_args() | |
return args | |
def load_model(args): | |
"""Load and return the model and tokenizer""" | |
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]]) | |
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]]) | |
if args.is_seq2seq_model: | |
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path) | |
elif args.is_decoder_only_model: | |
if args.load_fp16: | |
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto') | |
else: | |
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) | |
else: | |
raise ValueError(f"Unknown model type: {args.model_name_or_path}") | |
if args.use_gpu: | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
if args.load_fp16: | |
pass | |
else: | |
model = model.to(device) | |
else: | |
device = "cpu" | |
model.eval() | |
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) | |
return model, tokenizer, device | |
from text_generation import InferenceAPIClient | |
from requests.exceptions import ReadTimeout | |
def generate_with_api(prompt, args): | |
hf_api_key = os.environ.get("HF_API_KEY") | |
if hf_api_key is None: | |
raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.") | |
client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60) | |
assert args.n_beams == 1, "HF API models do not support beam search." | |
generation_params = { | |
"max_new_tokens": args.max_new_tokens, | |
"do_sample": args.use_sampling, | |
} | |
if args.use_sampling: | |
generation_params["temperature"] = args.sampling_temp | |
generation_params["seed"] = args.generation_seed | |
timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]" | |
try: | |
generation_params["watermark"] = False | |
without_watermark_iterator = client.generate_stream(prompt, **generation_params) | |
except ReadTimeout as e: | |
print(e) | |
without_watermark_iterator = (char for char in timeout_msg) | |
try: | |
generation_params["watermark"] = True | |
with_watermark_iterator = client.generate_stream(prompt, **generation_params) | |
except ReadTimeout as e: | |
print(e) | |
with_watermark_iterator = (char for char in timeout_msg) | |
all_without_words, all_with_words = "", "" | |
for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator): | |
all_without_words += without_word.token.text | |
all_with_words += with_word.token.text | |
yield all_without_words, all_with_words | |
def check_prompt(prompt, args, tokenizer, model=None, device=None): | |
# This applies to both the local and API model scenarios | |
if args.model_name_or_path in API_MODEL_MAP: | |
args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"] | |
elif hasattr(model.config,"max_position_embedding"): | |
args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens | |
else: | |
args.prompt_max_length = 2048-args.max_new_tokens | |
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device) | |
truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False | |
redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0] | |
return (redecoded_input, | |
int(truncation_warning), | |
args) | |
def generate(prompt, args, tokenizer, model=None, device=None): | |
"""Instatiate the WatermarkLogitsProcessor according to the watermark parameters | |
and generate watermarked text by passing it to the generate method of the model | |
as a logits processor. """ | |
print(f"Generating with {args}") | |
print(f"Prompt: {prompt}") | |
if args.model_name_or_path in API_MODEL_MAP: | |
api_outputs = generate_with_api(prompt, args) | |
yield from api_outputs | |
else: | |
tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device) | |
watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()), | |
gamma=args.gamma, | |
delta=args.delta, | |
seeding_scheme=args.seeding_scheme, | |
select_green_tokens=args.select_green_tokens) | |
gen_kwargs = dict(max_new_tokens=args.max_new_tokens) | |
if args.use_sampling: | |
gen_kwargs.update(dict( | |
do_sample=True, | |
top_k=0, | |
temperature=args.sampling_temp | |
)) | |
else: | |
gen_kwargs.update(dict( | |
num_beams=args.n_beams | |
)) | |
generate_without_watermark = partial( | |
model.generate, | |
**gen_kwargs | |
) | |
generate_with_watermark = partial( | |
model.generate, | |
logits_processor=LogitsProcessorList([watermark_processor]), | |
**gen_kwargs | |
) | |
torch.manual_seed(args.generation_seed) | |
output_without_watermark = generate_without_watermark(**tokd_input) | |
# optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark | |
if args.seed_separately: | |
torch.manual_seed(args.generation_seed) | |
output_with_watermark = generate_with_watermark(**tokd_input) | |
if args.is_decoder_only_model: | |
# need to isolate the newly generated tokens | |
output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:] | |
output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:] | |
decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0] | |
decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0] | |
# mocking the API outputs in a whitespace split generator style | |
all_without_words, all_with_words = "", "" | |
for without_word, with_word in zip(decoded_output_without_watermark.split(), decoded_output_with_watermark.split()): | |
all_without_words += without_word + " " | |
all_with_words += with_word + " " | |
yield all_without_words, all_with_words | |
def format_names(s): | |
"""Format names for the gradio demo interface""" | |
s=s.replace("num_tokens_scored","Tokens Counted (T)") | |
s=s.replace("num_green_tokens","# Tokens in Greenlist") | |
s=s.replace("green_fraction","Fraction of T in Greenlist") | |
s=s.replace("z_score","z-score") | |
s=s.replace("p_value","p value") | |
s=s.replace("prediction","Prediction") | |
s=s.replace("confidence","Confidence") | |
return s | |
def list_format_scores(score_dict, detection_threshold): | |
"""Format the detection metrics into a gradio dataframe input format""" | |
lst_2d = [] | |
for k,v in score_dict.items(): | |
if k=='green_fraction': | |
lst_2d.append([format_names(k), f"{v:.1%}"]) | |
elif k=='confidence': | |
lst_2d.append([format_names(k), f"{v:.3%}"]) | |
elif isinstance(v, float): | |
lst_2d.append([format_names(k), f"{v:.3g}"]) | |
elif isinstance(v, bool): | |
lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")]) | |
else: | |
lst_2d.append([format_names(k), f"{v}"]) | |
if "confidence" in score_dict: | |
lst_2d.insert(-2,["z-score Threshold", f"{detection_threshold}"]) | |
else: | |
lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"]) | |
return lst_2d | |
def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True): | |
"""Instantiate the WatermarkDetection object and call detect on | |
the input text returning the scores and outcome of the test""" | |
print(f"Detecting with {args}") | |
print(f"Detection Tokenizer: {type(tokenizer)}") | |
watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()), | |
gamma=args.gamma, | |
seeding_scheme=args.seeding_scheme, | |
device=device, | |
tokenizer=tokenizer, | |
z_threshold=args.detection_z_threshold, | |
normalizers=args.normalizers, | |
ignore_repeated_bigrams=args.ignore_repeated_bigrams, | |
select_green_tokens=args.select_green_tokens) | |
error = False | |
green_token_mask = None | |
if input_text == "": | |
error = True | |
else: | |
try: | |
score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask) | |
green_token_mask = score_dict.pop("green_token_mask", None) | |
output = list_format_scores(score_dict, watermark_detector.z_threshold) | |
except ValueError as e: | |
print(e) | |
error = True | |
if error: | |
output = [["Error","string too short to compute metrics"]] | |
output += [["",""] for _ in range(6)] | |
html_output = "" | |
if green_token_mask is not None: | |
# hack bc we need a fast tokenizer with charspan support | |
if "opt" in args.model_name_or_path: | |
tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path) | |
tokens = tokenizer(input_text) | |
if tokens["input_ids"][0] == tokenizer.bos_token_id: | |
tokens["input_ids"] = tokens["input_ids"][1:] # ignore attention mask | |
skip = watermark_detector.min_prefix_len | |
charspans = [tokens.token_to_chars(i) for i in range(skip,len(tokens["input_ids"]))] | |
charspans = [cs for cs in charspans if cs is not None] # remove the special token spans | |
if len(charspans) != len(green_token_mask): breakpoint() | |
assert len(charspans) == len(green_token_mask) | |
tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>') for cs, m in zip(charspans, green_token_mask)] | |
html_output = f'<p>{" ".join(tags)}</p>' | |
return output, args, tokenizer, html_output | |
def run_gradio(args, model=None, device=None, tokenizer=None): | |
"""Define and launch the gradio demo interface""" | |
check_prompt_partial = partial(check_prompt, model=model, device=device) | |
generate_partial = partial(generate, model=model, device=device) | |
detect_partial = partial(detect, device=device) | |
css = """ | |
.green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;} | |
.red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;} | |
""" | |
with gr.Blocks(css=css) as demo: | |
# Top section, greeting and instructions | |
with gr.Row(): | |
with gr.Column(scale=9): | |
gr.Markdown( | |
""" | |
## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍 | |
Demo made possible by the HuggingFace 🤗 [text-generation-inference](https://github.com/huggingface/text-generation-inference) serving framework. | |
""" | |
) | |
with gr.Column(scale=1): | |
gr.Markdown( | |
""" | |
[![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking) | |
""" | |
) | |
# if model_name_or_path at startup not one of the API models then add to dropdown | |
all_models = sorted(list(set(list(API_MODEL_MAP.keys())+[args.model_name_or_path]))) | |
model_selector = gr.Dropdown( | |
all_models, | |
value=args.model_name_or_path, | |
label="Language Model", | |
) | |
# Construct state for parameters, define updates and toggles | |
default_prompt = args.__dict__.pop("default_prompt") | |
session_args = gr.State(value=args) | |
# note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup | |
session_tokenizer = gr.State(value=lambda : tokenizer) | |
with gr.Tab("Welcome"): | |
with gr.Row(): | |
with gr.Column(scale=2): | |
gr.Markdown( | |
""" | |
Potential harms of large language models can be mitigated by *watermarking* a model's output. | |
*Watermarks* are embedded signals in the generated text that are invisible to humans but algorithmically | |
detectable, that allow *anyone* to later check whether a given span of text | |
was likely to have been generated by a model that uses the watermark. | |
This space showcases a watermarking approach that can be applied to _any_ generative language model. | |
For demonstration purposes, the space demos a selection of multi-billion parameter models (see the following note for caveats). | |
""" | |
) | |
with gr.Accordion("A note on the available models:",open=False): | |
gr.Markdown( | |
""" | |
This demo uses open-source language models. Today, these models are less powerful than proprietary commercial tools like ChatGPT, Claude, Bard, or Bing/Sydney. | |
Smaller models like OPT-6.7b are designed to "complete" your prompt, and are not fine-tuned to follow instructions. | |
For best results, prompt that model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph. | |
Some examples include the opening paragraph of a wikipedia article, or the first few sentences of a story. | |
Longer prompts that end mid-sentence will result in more fluent generations. | |
The larger models available in this demo are fine-tuned to follow instructions but have different strengths and will showcase different | |
types of watermark behavior. [BLOOMZ](https://huggingface.co/bigscience/bloomz) is an instruction tuned variant of [BLOOM (175B)](https://huggingface.co/bigscience/bloom) capable of following instructions in dozens of languages zero-shot | |
and can generate long and coherent paragraphs and stories given the right prompt. | |
The FLAN models [FLAN-t5-xxl (11B)](https://huggingface.co/google/flan-t5-xxl) and [FLAN-UL2 (20B)](https://huggingface.co/google/flan-ul2) are fine-tuned on a variety of in-context few-shot learning NLP tasks, | |
such as reasoning, and question answering. | |
Generally, short, low entropy scenarios where the model has very few choices in terms of correct/suitable responses to the prompt | |
will not exhibit as strong of a watermark presence, while longer watermarked outputs will produce higher detection statistics. | |
""" | |
) | |
gr.Markdown( | |
""" | |
**[Generate & Detect]**: The first tab shows that the watermark can be embedded with | |
negligible impact on text quality. You can try any prompt and compare the quality of | |
normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it. | |
You can also "see" the watermark by looking at the **Highlighted** tab where the tokens are | |
colored green or red depending on which list they are in. | |
Metrics on the right show that the watermark can be reliably detected given a reasonably small number of tokens (25-50). | |
Detection is very efficient and does not use the language model or its parameters. | |
**[Detector Only]**: You can also copy-paste the watermarked text (or any other text) | |
into the second tab. This can be used to see how many sentences you could remove and still detect the watermark. | |
You can also verify here that the detection has, by design, a low false-positive rate; | |
This means that human-generated text that you copy into this detector will not be marked as machine-generated. | |
You can find more details on how this watermark functions in our [ArXiv preprint](https://arxiv.org/abs/2301.10226). | |
""" | |
) | |
with gr.Column(scale=1): | |
gr.Markdown( | |
""" | |
![](https://drive.google.com/uc?export=view&id=1yVLPcjm-xvaCjQyc3FGLsWIU84v1QRoC) | |
""" | |
) | |
with gr.Tab("Generate & Detect"): | |
with gr.Row(): | |
prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt) | |
with gr.Row(): | |
generate_btn = gr.Button("Generate") | |
with gr.Row(): | |
with gr.Column(scale=2): | |
with gr.Tab("Output Without Watermark (Raw Text)"): | |
output_without_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14) | |
with gr.Tab("Highlighted"): | |
html_without_watermark = gr.HTML(elem_id="html-without-watermark") | |
with gr.Column(scale=1): | |
without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) | |
with gr.Row(): | |
with gr.Column(scale=2): | |
with gr.Tab("Output With Watermark (Raw Text)"): | |
output_with_watermark = gr.Textbox(interactive=False,lines=14,max_lines=14) | |
with gr.Tab("Highlighted"): | |
html_with_watermark = gr.HTML(elem_id="html-with-watermark") | |
with gr.Column(scale=1): | |
with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2) | |
redecoded_input = gr.Textbox(visible=False) | |
truncation_warning = gr.Number(visible=False) | |
def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args): | |
if truncation_warning: | |
return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args | |
else: | |
return orig_prompt, args | |
with gr.Tab("Detector Only"): | |
with gr.Row(): | |
with gr.Column(scale=2): | |
detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14) | |
with gr.Column(scale=1): | |
detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) | |
with gr.Row(): | |
detect_btn = gr.Button("Detect") | |
# Parameter selection group | |
with gr.Accordion("Advanced Settings",open=False): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
gr.Markdown(f"#### Generation Parameters") | |
with gr.Row(): | |
decoding = gr.Radio(label="Decoding Method",choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy")) | |
with gr.Row(): | |
sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True) | |
with gr.Row(): | |
generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True) | |
with gr.Row(): | |
n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=((not args.use_sampling) and (not args.model_name_or_path in API_MODEL_MAP))) | |
with gr.Row(): | |
max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens) | |
with gr.Column(scale=1): | |
gr.Markdown(f"#### Watermark Parameters") | |
with gr.Row(): | |
gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma) | |
with gr.Row(): | |
delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta) | |
gr.Markdown(f"#### Detector Parameters") | |
with gr.Row(): | |
detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold) | |
with gr.Row(): | |
ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats") | |
with gr.Row(): | |
normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers) | |
with gr.Row(): | |
gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._") | |
with gr.Row(): | |
current_parameters = gr.Textbox(label="Current Parameters", value=args) | |
with gr.Accordion("Legacy Settings",open=False): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately) | |
with gr.Column(scale=1): | |
select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens) | |
with gr.Accordion("What do the settings do?",open=False): | |
gr.Markdown( | |
""" | |
#### Generation Parameters: | |
- **Decoding Method** : We can generate tokens from the model using either multinomial sampling or we can generate using greedy decoding. | |
- **Sampling Temperature** : If using multinomial sampling we can set the temperature of the sampling distribution. | |
0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution. | |
0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates while adding variety. Does not apply for greedy decoding. | |
- **Generation Seed** : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy | |
outputs reproducible. Does not apply for greedy decoding. | |
- **Number of Beams** : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search. | |
This is not implemented/excluded from paper for multinomial sampling but may be added in future. | |
- **Max Generated Tokens** : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens. | |
Note that the model is free to generate fewer tokens depending on the prompt. | |
Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`, | |
and inputs will be truncated accordingly. | |
#### Watermark Parameters: | |
- **gamma** : The fraction of the vocabulary to be partitioned into the greenlist at each generation step. | |
Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve | |
a greater differentiation from human/unwatermarked text because it is preferentially sampling | |
from a smaller green set making those tokens less likely to occur by chance. | |
- **delta** : The amount of positive bias to add to the logits of every token in the greenlist | |
at each generation step before sampling/choosing the next token. Higher delta values | |
mean that the greenlist tokens are more heavily preferred by the watermarked model | |
and as the bias becomes very large the watermark transitions from "soft" to "hard". | |
For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on | |
generation quality, especially when there is not a lot of flexibility in the distribution. | |
#### Detector Parameters: | |
- **z-score threshold** : the z-score cuttoff for the hypothesis test. Higher thresholds (such as 4.0) make | |
_false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely | |
as a genuine human text with a significant number of tokens will almost never achieve | |
that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked | |
texts will contain less green tokens and achive a lower z-score, but still pass the lower bar and | |
be flagged as "watermarked". However, a lowere threshold will increase the chance that human text | |
that contains a slightly higher than average number of green tokens is erroneously flagged. | |
4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text. | |
- **Ignore Bigram Repeats** : This alternate detection algorithm only considers the unique bigrams in the text during detection, | |
computing the greenlists based on the first in each pair and checking whether the second falls within the list. | |
This means that `T` is now the unique number of bigrams in the text, which becomes less than the total | |
number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion. | |
- **Normalizations** : we implement a few basic normaliations to defend against various adversarial perturbations of the | |
text analyzed during detection. Currently we support converting all chracters to unicode, | |
replacing homoglyphs with a canonical form, and standardizing the capitalization. | |
See the paper for a detailed discussion of input normalization. | |
""" | |
) | |
with gr.Accordion("What do the output metrics mean?",open=False): | |
gr.Markdown( | |
""" | |
- `z-score threshold` : The cuttoff for the hypothesis test | |
- `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm. | |
The first token is ommitted in the simple, single token seeding scheme since there is no way to generate | |
a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm, | |
described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition. | |
- `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist | |
- `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text. | |
- `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold` | |
we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked | |
- `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of | |
observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists. | |
If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance. | |
- `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold` | |
- `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent | |
the confidence of the detection based on the unlikeliness of this `z-score` observation. | |
""" | |
) | |
gr.HTML(""" | |
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. | |
Follow the github link at the top and host the demo on your own GPU hardware to test out larger models. | |
<br/> | |
<a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true"> | |
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> | |
<p/> | |
""") | |
# Register main generation tab click, outputing generations as well as a the encoded+redecoded+potentially truncated prompt and flag, then call detection | |
generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success( | |
fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success( | |
fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success( | |
fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark]) | |
# Show truncated version of prompt if truncation occurred | |
redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args]) | |
# Register main detection tab click | |
detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection") | |
# State management logic | |
# update callbacks that change the state dict | |
def update_model(session_state, value): session_state.model_name_or_path = value; return session_state | |
def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state | |
def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state | |
def update_gamma(session_state, value): session_state.gamma = float(value); return session_state | |
def update_delta(session_state, value): session_state.delta = float(value); return session_state | |
def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state | |
def update_decoding(session_state, value): | |
if value == "multinomial": | |
session_state.use_sampling = True | |
elif value == "greedy": | |
session_state.use_sampling = False | |
return session_state | |
def toggle_sampling_vis(value): | |
if value == "multinomial": | |
return gr.update(visible=True) | |
elif value == "greedy": | |
return gr.update(visible=False) | |
def toggle_sampling_vis_inv(value): | |
if value == "multinomial": | |
return gr.update(visible=False) | |
elif value == "greedy": | |
return gr.update(visible=True) | |
# if model name is in the list of api models, set the num beams parameter to 1 and hide n_beams | |
def toggle_vis_for_api_model(value): | |
if value in API_MODEL_MAP: | |
return gr.update(visible=False) | |
else: | |
return gr.update(visible=True) | |
def toggle_beams_for_api_model(value, orig_n_beams): | |
if value in API_MODEL_MAP: | |
return gr.update(value=1) | |
else: | |
return gr.update(value=orig_n_beams) | |
# if model name is in the list of api models, set the interactive parameter to false | |
def toggle_interactive_for_api_model(value): | |
if value in API_MODEL_MAP: | |
return gr.update(interactive=False) | |
else: | |
return gr.update(interactive=True) | |
# if model name is in the list of api models, set gamma and delta based on API map | |
def toggle_gamma_for_api_model(value, orig_gamma): | |
if value in API_MODEL_MAP: | |
return gr.update(value=API_MODEL_MAP[value]["gamma"]) | |
else: | |
return gr.update(value=orig_gamma) | |
def toggle_delta_for_api_model(value, orig_delta): | |
if value in API_MODEL_MAP: | |
return gr.update(value=API_MODEL_MAP[value]["delta"]) | |
else: | |
return gr.update(value=orig_delta) | |
def update_n_beams(session_state, value): session_state.n_beams = value; return session_state | |
def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state | |
def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state | |
def update_normalizers(session_state, value): session_state.normalizers = value; return session_state | |
def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state | |
def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state | |
def update_tokenizer(model_name_or_path): | |
# if model_name_or_path == ALPACA_MODEL_NAME: | |
# return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH) | |
# else: | |
return AutoTokenizer.from_pretrained(model_name_or_path) | |
# registering callbacks for toggling the visibilty of certain parameters based on the values of others | |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp]) | |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed]) | |
decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams]) | |
model_selector.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams]) | |
decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams]) | |
model_selector.change(toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams]) | |
model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma]) | |
model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta]) | |
model_selector.change(toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma]) | |
model_selector.change(toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta]) | |
model_selector.change(update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]) | |
# registering all state update callbacks | |
model_selector.change(update_model,inputs=[session_args, model_selector], outputs=[session_args]) | |
decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args]) | |
sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args]) | |
generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args]) | |
n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args]) | |
max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args]) | |
gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args]) | |
delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args]) | |
detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args]) | |
ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args]) | |
normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args]) | |
seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args]) | |
select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args]) | |
# register additional callback on button clicks that updates the shown parameters window | |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
# When the parameters change, display the update and also fire detection, since some detection params dont change the model output. | |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer]) | |
gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer]) | |
gamma.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer]) | |
detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer]) | |
detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer]) | |
detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer]) | |
ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer]) | |
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer]) | |
ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer]) | |
normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer]) | |
normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer]) | |
normalizers.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer]) | |
select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters]) | |
select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer]) | |
select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer]) | |
select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer]) | |
demo.queue(concurrency_count=3) | |
if args.demo_public: | |
demo.launch(share=True) # exposes app to the internet via randomly generated link | |
else: | |
demo.launch() | |
def main(args): | |
"""Run a command line version of the generation and detection operations | |
and optionally launch and serve the gradio demo""" | |
# Initial arg processing and log | |
args.normalizers = (args.normalizers.split(",") if args.normalizers else []) | |
print(args) | |
if not args.skip_model_load: | |
model, tokenizer, device = load_model(args) | |
else: | |
model, tokenizer, device = None, None, None | |
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) | |
if args.use_gpu: | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
else: | |
device = "cpu" | |
# terrapin example | |
input_text = ( | |
"The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a " | |
"species of turtle native to the brackish coastal tidal marshes of the " | |
"Northeastern and southern United States, and in Bermuda.[6] It belongs " | |
"to the monotypic genus Malaclemys. It has one of the largest ranges of " | |
"all turtles in North America, stretching as far south as the Florida Keys " | |
"and as far north as Cape Cod.[7] The name 'terrapin' is derived from the " | |
"Algonquian word torope.[8] It applies to Malaclemys terrapin in both " | |
"British English and American English. The name originally was used by " | |
"early European settlers in North America to describe these brackish-water " | |
"turtles that inhabited neither freshwater habitats nor the sea. It retains " | |
"this primary meaning in American English.[8] In British English, however, " | |
"other semi-aquatic turtle species, such as the red-eared slider, might " | |
"also be called terrapins. The common name refers to the diamond pattern " | |
"on top of its shell (carapace), but the overall pattern and coloration " | |
"vary greatly. The shell is usually wider at the back than in the front, " | |
"and from above it appears wedge-shaped. The shell coloring can vary " | |
"from brown to grey, and its body color can be grey, brown, yellow, " | |
"or white. All have a unique pattern of wiggly, black markings or spots " | |
"on their body and head. The diamondback terrapin has large webbed " | |
"feet.[9] The species is" | |
) | |
args.default_prompt = input_text | |
# Generate and detect, report to stdout | |
if not args.skip_model_load: | |
term_width = 80 | |
print("#"*term_width) | |
print("Prompt:") | |
print(input_text) | |
# a generator that yields (without_watermark, with_watermark) pairs | |
generator_outputs = generate(input_text, | |
args, | |
model=model, | |
device=device, | |
tokenizer=tokenizer) | |
# we need to iterate over it, | |
# but we only want the last output in this case | |
for out in generator_outputs: | |
decoded_output_without_watermark = out[0] | |
decoded_output_with_watermark = out[1] | |
without_watermark_detection_result = detect(decoded_output_without_watermark, | |
args, | |
device=device, | |
tokenizer=tokenizer, | |
return_green_token_mask=False) | |
with_watermark_detection_result = detect(decoded_output_with_watermark, | |
args, | |
device=device, | |
tokenizer=tokenizer, | |
return_green_token_mask=False) | |
print("#"*term_width) | |
print("Output without watermark:") | |
print(decoded_output_without_watermark) | |
print("-"*term_width) | |
print(f"Detection result @ {args.detection_z_threshold}:") | |
pprint(without_watermark_detection_result) | |
print("-"*term_width) | |
print("#"*term_width) | |
print("Output with watermark:") | |
print(decoded_output_with_watermark) | |
print("-"*term_width) | |
print(f"Detection result @ {args.detection_z_threshold}:") | |
pprint(with_watermark_detection_result) | |
print("-"*term_width) | |
# Launch the app to generate and detect interactively (implements the hf space demo) | |
if args.run_gradio: | |
run_gradio(args, model=model, tokenizer=tokenizer, device=device) | |
return | |
if __name__ == "__main__": | |
args = parse_args() | |
print(args) | |
main(args) |