jwkirchenbauer committed
Commit 507fd5a
1 Parent(s): a7d76f1

added inference api functionality

Files changed (2):
  1. demo_watermark.py +178 -46
  2. requirements.txt +2 -1
demo_watermark.py CHANGED
@@ -32,6 +32,14 @@ from transformers import (AutoTokenizer,
 
 from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
 
+# FIXME correct lengths for all models
+API_MODEL_MAP = {
+    "bigscience/bloomz" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
+    "google/flan-ul2" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
+    "google/flan-t5-xxl" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
+    "EleutherAI/gpt-neox-20b" : {"max_length": 2048, "gamma": 0.5, "delta": 2.0},
+}
+
 def str2bool(v):
     """Util function for user friendly boolean flag args"""
     if isinstance(v, bool):
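The map above fixes a per-model context budget for the hosted API path; the prompt-truncation logic added later in this commit spends that budget as prompt tokens plus generated tokens. A minimal sketch of the arithmetic (the 2048 values are the placeholders flagged by the FIXME, so the numbers are illustrative only):

```python
# Sketch: how an API_MODEL_MAP entry becomes a prompt budget.
API_MODEL_MAP = {"EleutherAI/gpt-neox-20b": {"max_length": 2048, "gamma": 0.5, "delta": 2.0}}

max_new_tokens = 200
entry = API_MODEL_MAP["EleutherAI/gpt-neox-20b"]
prompt_max_length = entry["max_length"] - max_new_tokens  # 1848 tokens left for the prompt
print(prompt_max_length)
```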
@@ -200,13 +208,69 @@ def load_model(args):
 
     return model, tokenizer, device
 
-def generate(prompt, args, model=None, device=None, tokenizer=None):
+
+from text_generation import InferenceAPIClient
+def generate_with_api(prompt, args):
+    hf_api_key = os.environ.get("HF_API_KEY")
+    if hf_api_key is None:
+        raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")
+
+    client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key)
+
+    assert args.n_beams == 1, "HF API models do not support beam search."
+    generation_params = {
+        "max_new_tokens": args.max_new_tokens,
+        "do_sample": args.use_sampling,
+    }
+    if args.use_sampling:
+        generation_params["temperature"] = args.sampling_temp
+        generation_params["seed"] = args.generation_seed
+
+    generation_params["watermarking"] = False
+    output = client.generate(prompt, **generation_params)
+    output_text_without_watermark = output.generated_text
+
+    generation_params["watermarking"] = True
+    output = client.generate(prompt, **generation_params)
+    output_text_with_watermark = output.generated_text
+
+    return (output_text_without_watermark,
+            output_text_with_watermark)
+
+
+def generate(prompt, args, tokenizer, model=None, device=None):
     """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
     and generate watermarked text by passing it to the generate method of the model
     as a logits processor. """
 
     print(f"Generating with {args}")
 
+    # This applies to both the local and API model scenarios
+    if args.prompt_max_length:
+        pass
+    elif args.model_name_or_path in API_MODEL_MAP:
+        args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]-args.max_new_tokens
+    elif hasattr(model.config, "max_position_embeddings"):
+        args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
+    else:
+        args.prompt_max_length = 2048-args.max_new_tokens
+
+    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
+    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
+    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
+
+    if args.model_name_or_path in API_MODEL_MAP:
+        api_outputs = generate_with_api(prompt, args)
+        decoded_output_without_watermark = api_outputs[0]
+        decoded_output_with_watermark = api_outputs[1]
+        return (redecoded_input,
+                int(truncation_warning),
+                decoded_output_without_watermark,
+                decoded_output_with_watermark,
+                args,
+                tokenizer)
+
+
     watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                    gamma=args.gamma,
                                                    delta=args.delta,
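For reference, a standalone sketch of the hosted-API path added above, outside the Gradio app. It assumes `HF_API_KEY` is exported and that the installed `text_generation` client accepts the `watermarking` keyword exactly as `generate_with_api` passes it; in some releases of the client the equivalent flag is named `watermark` instead.

```python
# Standalone sketch mirroring generate_with_api (assumptions noted above).
import os
from text_generation import InferenceAPIClient

client = InferenceAPIClient("EleutherAI/gpt-neox-20b", token=os.environ["HF_API_KEY"])

params = {"max_new_tokens": 100, "do_sample": True, "temperature": 0.7, "seed": 123}
plain = client.generate("The diamondback terrapin is", watermarking=False, **params)
marked = client.generate("The diamondback terrapin is", watermarking=True, **params)
print(plain.generated_text)
print(marked.generated_text)
```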
@@ -235,16 +299,6 @@ def generate(prompt, args, model=None, device=None, tokenizer=None):
         logits_processor=LogitsProcessorList([watermark_processor]),
         **gen_kwargs
     )
-    if args.prompt_max_length:
-        pass
-    elif hasattr(model.config,"max_position_embedding"):
-        args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
-    else:
-        args.prompt_max_length = 2048-args.max_new_tokens
-
-    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
-    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
-    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
 
     torch.manual_seed(args.generation_seed)
     output_without_watermark = generate_without_watermark(**tokd_input)
@@ -266,8 +320,9 @@ def generate(prompt, args, model=None, device=None, tokenizer=None):
             int(truncation_warning),
             decoded_output_without_watermark,
             decoded_output_with_watermark,
-            args)
-            # decoded_output_with_watermark)
+            args,
+            tokenizer)
+
 
 def format_names(s):
     """Format names for the gradio demo interface"""
@@ -301,9 +356,12 @@ def list_format_scores(score_dict, detection_threshold):
     lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"])
     return lst_2d
 
-def detect(input_text, args, device=None, tokenizer=None):
+def detect(input_text, args, tokenizer, device=None):
     """Instantiate the WatermarkDetection object and call detect on
     the input text returning the scores and outcome of the test"""
+    print(f"Detecting with {args}")
+    print(f"Detection Tokenizer: {type(tokenizer)}")
+
     watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
                                            gamma=args.gamma,
                                            seeding_scheme=args.seeding_scheme,
@@ -313,20 +371,29 @@ def detect(input_text, args, device=None, tokenizer=None):
                                            normalizers=args.normalizers,
                                            ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                            select_green_tokens=args.select_green_tokens)
-    if len(input_text)-1 > watermark_detector.min_prefix_len:
-        score_dict = watermark_detector.detect(input_text)
-        # output = str_format_scores(score_dict, watermark_detector.z_threshold)
-        output = list_format_scores(score_dict, watermark_detector.z_threshold)
+    # if len(input_text)-1 > watermark_detector.min_prefix_len:
+    error = False
+    if input_text == "":
+        error = True
     else:
-        # output = (f"Error: string not long enough to compute watermark presence.")
+        try:
+            score_dict = watermark_detector.detect(input_text)
+            # output = str_format_scores(score_dict, watermark_detector.z_threshold)
+            output = list_format_scores(score_dict, watermark_detector.z_threshold)
+        except ValueError as e:
+            print(e)
+            error = True
+    if error:
         output = [["Error","string too short to compute metrics"]]
         output += [["",""] for _ in range(6)]
-    return output, args
+    return output, args, tokenizer
 
 def run_gradio(args, model=None, device=None, tokenizer=None):
     """Define and launch the gradio demo interface"""
-    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
-    detect_partial = partial(detect, device=device, tokenizer=tokenizer)
+    # generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
+    # detect_partial = partial(detect, device=device, tokenizer=tokenizer)
+    generate_partial = partial(generate, model=model, device=device)
+    detect_partial = partial(detect, device=device)
 
     with gr.Blocks() as demo:
         # Top section, greeting and instructions
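The reworked `detect` now degrades gracefully: empty input, or a `ValueError` raised by the detector (e.g., text shorter than the minimum prefix), yields a padded error table instead of an exception, and the tokenizer rides along in the return value so Gradio can keep it in session state. A hedged sketch of calling it directly (`args` stands in for the demo's argparse Namespace):

```python
# Hypothetical direct use of the reworked detect(); the tokenizer should
# match the model that produced the text being scored.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
rows, args, tokenizer = detect("some candidate text ...", args, tokenizer, device="cpu")
for label, value in rows:  # 2D list shaped for the Gradio dataframe component
    print(f"{label}: {value}")
```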
@@ -343,11 +410,20 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
             """
             )
-        gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
+        # gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
+        # if model_name_or_path at startup not one of the API models then add to dropdown
+        all_models = sorted(list(set(list(API_MODEL_MAP.keys())+[args.model_name_or_path])))
+        model_selector = gr.Dropdown(
+            all_models,
+            value=args.model_name_or_path,
+            label="Language Model",
+        )
 
         # Construct state for parameters, define updates and toggles
         default_prompt = args.__dict__.pop("default_prompt")
         session_args = gr.State(value=args)
+        # note that state obj automatically calls value if it's a callable, want to avoid calling tokenizer at startup
+        session_tokenizer = gr.State(value=lambda : tokenizer)
 
         with gr.Tab("Welcome"):
             with gr.Row():
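One subtlety behind the lambda above, as the inline comment notes: `gr.State` deep-copies a plain default value per session, but invokes a callable default instead, which lets the live tokenizer object be handed over without copying it at startup. A small illustration (Gradio 3.x behavior, stated as an assumption):

```python
import gradio as gr

with gr.Blocks() as demo:
    # A plain value is deep-copied for each user session...
    copied = gr.State(value={"k": "v"})
    # ...while a callable is invoked to produce the value, so the live
    # tokenizer object (from load_model) is reused rather than copied.
    shared = gr.State(value=lambda: tokenizer)
```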
@@ -448,7 +524,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                     with gr.Row():
                         generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True)
                     with gr.Row():
-                        n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling))
+                        n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=((not args.use_sampling) and (not args.model_name_or_path in API_MODEL_MAP)))
                     with gr.Row():
                         max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
 
@@ -561,18 +637,19 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         """)
 
         # Register main generation tab click, outputting generations as well as the encoded+redecoded+potentially truncated prompt and flag
-        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
+        generate_btn.click(fn=generate_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args,session_tokenizer])
         # Show truncated version of prompt if truncation occurred
         redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
         # Call detection when the outputs (of the generate function) are updated
-        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
+        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
         # Register main detection tab click
-        # detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
-        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args], api_name="detection")
+        # detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer])
+        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
 
         # State management logic
         # update callbacks that change the state dict
+        def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
         def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
         def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
         def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
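The `api_name="detection"` argument is what actually delivers the commit message: it exposes the detection click handler as a named endpoint on the Space. A hedged sketch of calling it remotely with the `gradio_client` package (the Space id is illustrative, the exact payload depends on the Gradio version, and `gr.State` inputs are managed server-side rather than passed explicitly):

```python
# Hypothetical remote call to the endpoint registered via api_name="detection".
from gradio_client import Client

client = Client("jwkirchenbauer/lm-watermarking")  # hypothetical Space id
result = client.predict("text to score for the watermark", api_name="/detection")
print(result)
```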
@@ -594,17 +671,56 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                 return gr.update(visible=False)
             elif value == "greedy":
                 return gr.update(visible=True)
+        # if model name is in the list of api models, set the num beams parameter to 1 and hide n_beams
+        def toggle_vis_for_api_model(value):
+            if value in API_MODEL_MAP:
+                return gr.update(visible=False)
+            else:
+                return gr.update(visible=True)
+        def toggle_beams_for_api_model(value, orig_n_beams):
+            if value in API_MODEL_MAP:
+                return gr.update(value=1)
+            else:
+                return gr.update(value=orig_n_beams)
+        # if model name is in the list of api models, set the interactive parameter to false
+        def toggle_interactive_for_api_model(value):
+            if value in API_MODEL_MAP:
+                return gr.update(interactive=False)
+            else:
+                return gr.update(interactive=True)
+        # if model name is in the list of api models, set gamma and delta based on API map
+        def toggle_gamma_for_api_model(value, orig_gamma):
+            if value in API_MODEL_MAP:
+                return gr.update(value=API_MODEL_MAP[value]["gamma"])
+            else:
+                return gr.update(value=orig_gamma)
+        def toggle_delta_for_api_model(value, orig_delta):
+            if value in API_MODEL_MAP:
+                return gr.update(value=API_MODEL_MAP[value]["delta"])
+            else:
+                return gr.update(value=orig_delta)
+
         def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
         def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
         def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
         def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
         def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
-        # registering callbacks for toggling the visibility of certain parameters
+        def update_tokenizer(model_name_or_path): return AutoTokenizer.from_pretrained(model_name_or_path)
+        # registering callbacks for toggling the visibility of certain parameters based on the values of others
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
         decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
         decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
+        model_selector.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
+        decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
+        model_selector.change(toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams])
+        model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma])
+        model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta])
+        model_selector.change(toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma])
+        model_selector.change(toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta])
+        model_selector.change(update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer])
         # registering all state update callbacks
+        model_selector.change(update_model,inputs=[session_args, model_selector], outputs=[session_args])
         decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
         sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
         generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
@@ -620,27 +736,29 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         # register additional callback on button clicks that updates the shown parameters window
         generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
         detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
+        model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
         # When the parameters change, display the update and fire detection, since some detection params don't change the model output.
+        delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
         gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-        gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        gamma.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
+        gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
+        gamma.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer])
         detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
+        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
+        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer])
         ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
+        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer])
         normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-        normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        normalizers.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
+        normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
+        normalizers.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer])
         select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
-        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
-        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
-        select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
+        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer])
+        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer])
+        select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer])
 
 
         demo.queue(concurrency_count=3)
@@ -691,9 +809,23 @@ def main(args):
     "on their body and head. The diamondback terrapin has large webbed "
     "feet.[9] The species is"
     )
+
+    # teaser example
+    # input_text = (
+    # "In this work, we study watermarking of language model output. "
+    # "A watermark is a hidden pattern in text that is imperceptible to humans, "
+    # "while making the text algorithmically identifiable as synthetic. "
+    # "We propose an efficient watermark that makes synthetic text detectable "
+    # "from short spans of tokens (as few as 25 words), while false-positives "
+    # "(where human text is marked as machine-generated) are statistically improbable. "
+    # "The watermark detection algorithm can be made public, enabling third parties "
+    # "(e.g., social media platforms) to run it themselves, or it can be kept private "
+    # "and run behind an API. We seek a watermark with the following properties:\n"
+    # )
 
     args.default_prompt = input_text
 
+
     # Generate and detect, report to stdout
     if not args.skip_model_load:
 
@@ -702,7 +834,7 @@ def main(args):
         print("Prompt:")
         print(input_text)
 
-        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
+        _, _, decoded_output_without_watermark, decoded_output_with_watermark, _, _ = generate(input_text,
                                                                                             args,
                                                                                             model=model,
                                                                                             device=device,
requirements.txt CHANGED
@@ -5,4 +5,5 @@ scipy
 torch
 transformers
 tokenizers
-accelerate
+accelerate
+text-generation
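The new `text-generation` pin is the client library imported as `text_generation` in `generate_with_api`. A quick sanity check for a deployment, assuming the `HF_API_KEY` convention the commit introduces:

```python
# Sanity check for the new runtime requirements of this commit.
import os
import text_generation  # provided by the `text-generation` requirement

assert os.environ.get("HF_API_KEY"), "export HF_API_KEY to use the hosted inference API"
```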