def str2bool(v):
    """Convert a command-line string into a bool (for use as an argparse ``type``).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1';
        ``False`` for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args(argv=None):
    """Build the demo's argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")

    # --- demo / runtime options ---
    parser.add_argument(
        "--run_gradio",
        type=str2bool,
        default=True,
        help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    )
    parser.add_argument(
        "--demo_public",
        type=str2bool,
        default=False,
        help="Whether to expose the gradio demo to the internet.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-6.7b",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )

    # --- generation options ---
    parser.add_argument(
        "--prompt_max_length",
        type=int,
        default=None,
        help="Truncation length for prompt, overrides model config's max length field.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=True,
        help="Whether to generate using multinomial sampling.",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )

    # --- watermark options ---
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )

    # --- detection options ---
    parser.add_argument(
        "--normalizers",
        type=str,
        default="",
        help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
    )
    parser.add_argument(
        "--ignore_repeated_bigrams",
        type=str2bool,
        default=False,
        help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
    )
    parser.add_argument(
        "--detection_z_threshold",
        type=float,
        default=4.0,
        help="The test statistic threshold for the detection hypothesis test.",
    )

    # --- legacy / debugging options ---
    parser.add_argument(
        "--select_green_tokens",
        type=str2bool,
        default=True,
        help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
    )
    parser.add_argument(
        "--skip_model_load",
        type=str2bool,
        default=False,
        help="Skip the model loading to debug the interface.",
    )
    parser.add_argument(
        "--seed_separately",
        type=str2bool,
        default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
    )
    return parser.parse_args(argv)


def load_model(args):
    """Load the model and tokenizer named by ``args.model_name_or_path``.

    The architecture family is chosen by substring match on the model name.
    As a side effect this sets ``args.is_seq2seq_model`` and
    ``args.is_decoder_only_model``.

    Args:
        args: Namespace providing ``model_name_or_path`` and ``use_gpu``.

    Returns:
        Tuple ``(model, tokenizer, device)`` where ``device`` is the string
        "cuda" or "cpu".

    Raises:
        ValueError: if the model name matches neither family.
    """
    name = args.model_name_or_path
    args.is_seq2seq_model = any(model_type in name for model_type in ["t5", "T0"])
    args.is_decoder_only_model = any(model_type in name for model_type in ["gpt", "opt", "bloom"])

    # The seq2seq check takes precedence when a name matches both families.
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(name)
    elif args.is_decoder_only_model:
        model = AutoModelForCausalLM.from_pretrained(name)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        # Fall back to cpu when cuda was requested but is not available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
    else:
        device = "cpu"
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(name)

    return model, tokenizer, device


def generate(prompt, args, model=None, device=None, tokenizer=None):
    """Generate a completion for ``prompt`` twice: without and with the watermark.

    Side effects: prints the settings, reseeds the torch global RNG, and
    fills in ``args.prompt_max_length`` when it was unset.

    Args:
        prompt: The text to continue.
        args: Namespace of generation / watermark settings (see ``parse_args``),
            plus ``is_decoder_only_model`` as set by ``load_model``.
        model: Object exposing ``generate(...)`` and ``config``.
        device: Device the tokenized prompt is moved to.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        Tuple ``(redecoded_input, truncation_flag, text_without_watermark,
        text_with_watermark, args)`` where ``truncation_flag`` is 1 when the
        tokenized prompt is exactly ``args.prompt_max_length`` tokens long
        and 0 otherwise.
    """
    print(f"Generating with {args}")

    watermark_processor = WatermarkLogitsProcessor(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        delta=args.delta,
        seeding_scheme=args.seeding_scheme,
        select_green_tokens=args.select_green_tokens,
    )

    gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
    if args.use_sampling:
        gen_kwargs.update(do_sample=True, top_k=0, temperature=args.sampling_temp)
    else:
        gen_kwargs.update(num_beams=args.n_beams)

    # Both callables share the same settings; only the logits processor differs.
    generate_without_watermark = partial(model.generate, **gen_kwargs)
    generate_with_watermark = partial(
        model.generate,
        logits_processor=LogitsProcessorList([watermark_processor]),
        **gen_kwargs,
    )

    # An explicit (truthy) prompt_max_length always wins. Otherwise leave room
    # for the new tokens inside the model's context window.
    if not args.prompt_max_length:
        # The attribute name must be the plural "max_position_embeddings";
        # checking the singular form would make this branch unreachable.
        if hasattr(model.config, "max_position_embeddings"):
            args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
        else:
            args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        max_length=args.prompt_max_length,
    ).to(device)
    prompt_len = tokd_input["input_ids"].shape[-1]
    # A prompt that lands exactly on the limit is reported as truncated.
    truncation_warning = prompt_len == args.prompt_max_length
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    torch.manual_seed(args.generation_seed)
    output_without_watermark = generate_without_watermark(**tokd_input)

    # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
    if args.seed_separately:
        torch.manual_seed(args.generation_seed)
    output_with_watermark = generate_with_watermark(**tokd_input)

    if args.is_decoder_only_model:
        # need to isolate the newly generated tokens
        output_without_watermark = output_without_watermark[:, prompt_len:]
        output_with_watermark = output_with_watermark[:, prompt_len:]

    decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
    decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            decoded_output_without_watermark,
            decoded_output_with_watermark,
            args)


def format_names(s):
    """Replace internal metric keys found in ``s`` with their display labels.

    Args:
        s: A string that may contain keys such as ``"z_score"``.

    Returns:
        ``s`` with every known key substituted; other text is left untouched.
    """
    display_names = (
        ("num_tokens_scored", "Tokens Counted (T)"),
        ("num_green_tokens", "# Tokens in Greenlist"),
        ("green_fraction", "Fraction of T in Greenlist"),
        ("z_score", "z-score"),
        ("p_value", "p value"),
    )
    # Applied in this fixed order, one substitution after another.
    for key, label in display_names:
        s = s.replace(key, label)
    return s
def list_format_scores(score_dict, detection_threshold):
    """Turn a detection score dict into ``[label, value]`` string rows.

    The first row always reports the z-score threshold. Each entry of
    ``score_dict`` then becomes one row, in iteration order, with its key
    passed through ``format_names``.

    Args:
        score_dict: Mapping of metric name to value.
        detection_threshold: The z-score threshold to report.

    Returns:
        A list of two-element lists of strings.
    """
    rows = [["z-score threshold", f"{detection_threshold}"]]
    for metric, value in score_dict.items():
        if metric == 'green_fraction':
            rendered = f"{value:.1%}"
        elif metric == 'confidence':
            rendered = f"{value:.3%}"
        elif isinstance(value, float):
            rendered = f"{value:.3g}"
        elif isinstance(value, bool):
            # Boolean entries are shown as a human-readable verdict.
            rendered = "Watermarked" if value else "Human/Unwatermarked"
        else:
            rendered = f"{value}"
        rows.append([format_names(metric), rendered])
    return rows


def detect(input_text, args, device=None, tokenizer=None):
    """Run watermark detection on ``input_text`` and format the result as rows.

    Args:
        input_text: The text to analyze.
        args: Namespace providing ``gamma``, ``seeding_scheme``,
            ``detection_z_threshold``, ``normalizers``,
            ``ignore_repeated_bigrams`` and ``select_green_tokens``.
        device: Passed through to ``WatermarkDetector``.
        tokenizer: Supplies the vocabulary; also passed to the detector.

    Returns:
        Tuple ``(rows, args)``. ``rows`` is the output of
        ``list_format_scores``, or an error row followed by six blank rows
        when the text is too short.
    """
    detector = WatermarkDetector(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=args.gamma,
        seeding_scheme=args.seeding_scheme,
        device=device,
        tokenizer=tokenizer,
        z_threshold=args.detection_z_threshold,
        normalizers=args.normalizers,
        ignore_repeated_bigrams=args.ignore_repeated_bigrams,
        select_green_tokens=args.select_green_tokens,
    )

    # The length check is on the raw string (characters), not on tokens.
    if len(input_text) - 1 > detector.min_prefix_len:
        scores = detector.detect(input_text)
        return list_format_scores(scores, detector.z_threshold), args

    # Too short: one error row plus six blank rows (seven rows total,
    # presumably to fill the fixed-height results table -- confirm against caller).
    rows = [["Error", "string too short to compute metrics"]]
    rows += [["", ""] for _ in range(6)]
    return rows, args
gr.Markdown("[jwkirchenbauer/lm-watermarking![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)") with gr.Accordion("A note on model capability",open=False): gr.Markdown( """ The models that can be used in this demo are limited to those that are open source as well as fit on a single commodity GPU. In particular, there are few models above 10B parameters and way fewer trained using both Instruction finetuning or RLHF that are open source that we can use. Therefore, the model, in both it's un-watermarked (normal) and watermarked state, is not generally able to respond well to the kinds of prompts that a 100B+ Instruction and RLHF tuned model such as ChatGPT, Claude, or Bard is. We suggest you try prompts that give the model a few sentences and then allow it to 'continue' the prompt, as these weaker models are more capable in this simpler language modeling setting. """ ) # Construct state for parameters, define updates and toggles session_args = gr.State(value=args) with gr.Tab("Generate and Detect"): with gr.Row(): prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10) with gr.Row(): generate_btn = gr.Button("Generate") with gr.Row(): with gr.Column(scale=2): output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14) with gr.Column(scale=1): # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14) without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) with gr.Row(): with gr.Column(scale=2): output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14) with gr.Column(scale=1): # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14) with_watermark_detection_result = gr.Dataframe(headers=["Metric", 
"Value"],interactive=False,row_count=7,col_count=2) redecoded_input = gr.Textbox(visible=False) truncation_warning = gr.Number(visible=False) def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args): if truncation_warning: return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args else: return orig_prompt, args with gr.Tab("Detector Only"): with gr.Row(): with gr.Column(scale=2): detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14) with gr.Column(scale=1): # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14) detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2) with gr.Row(): detect_btn = gr.Button("Detect") # Parameter selection group with gr.Accordion("Advanced Settings",open=False): with gr.Row(): with gr.Column(scale=1): gr.Markdown(f"#### Generation Parameters") with gr.Row(): decoding = gr.Radio(label="Decoding Method",choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy")) with gr.Row(): sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True) with gr.Row(): generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True) with gr.Row(): n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling)) with gr.Row(): max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens) with gr.Column(scale=1): gr.Markdown(f"#### Watermark Parameters") with gr.Row(): gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma) with gr.Row(): delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta) gr.Markdown(f"#### Detector Parameters") with gr.Row(): 
detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold) with gr.Row(): ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats") with gr.Row(): normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers) # with gr.Accordion("Actual submitted parameters:",open=False): with gr.Row(): gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._") with gr.Row(): current_parameters = gr.Textbox(label="Current Parameters", value=args) with gr.Accordion("Legacy Settings",open=False): with gr.Row(): with gr.Column(scale=1): seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately) with gr.Column(scale=1): select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens) gr.HTML("""
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.