jwkirchenbauer committed on
Commit 4343565
1 Parent(s): a134a9d

more settings

Files changed (1): demo_watermark.py (+9 -5)
demo_watermark.py CHANGED
@@ -261,7 +261,7 @@ def detect(input_text, args, device=None, tokenizer=None):
 
 def run_gradio(args, model=None, device=None, tokenizer=None):
 
-    generate_partial = partial(generate, model=model, device=None, tokenizer=tokenizer)
+    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
     detect_partial = partial(detect, device=device, tokenizer=tokenizer)
 
     with gr.Blocks() as demo:
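Note: this hunk is the actual bug fix in the commit. `generate_partial` froze `device=None` at creation time, so generation callbacks ignored the device chosen by `load_model()` even though `detect_partial` already received the real one. A minimal sketch of the `functools.partial` behavior (the stand-in `generate` signature is assumed from the call site, not copied from the repo):

```python
from functools import partial

def generate(prompt, model=None, device=None, tokenizer=None):
    # Stand-in for the demo's generate(); it only reports the bound device.
    print(f"generating on device={device}")

# Before the fix: device is frozen to None when the partial is created.
broken = partial(generate, model="m", device=None, tokenizer="t")
# After the fix: the device resolved by load_model() is bound instead.
fixed = partial(generate, model="m", device="cuda", tokenizer="t")

broken("a prompt")  # generating on device=None
fixed("a prompt")   # generating on device=cuda
```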
@@ -289,11 +289,13 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
                     generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True)
                 with gr.Row():
                     n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling))
+                with gr.Row():
+                    max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
 
             with gr.Column(scale=1):
                 gr.Markdown(f"#### Watermarking Parameters")
                 with gr.Row():
-                    gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.1, value=args.gamma)
+                    gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                 with gr.Row():
                     delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                 with gr.Row():
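This hunk adds the new Max Generated Tokens slider row and halves the `gamma` step from 0.1 to 0.05. `gamma` is the greenlist fraction in the watermarking scheme, and the finer step exposes settings the coarser grid skipped; a quick check of the selectable grids:

```python
# A slider's selectable values are minimum + k * step.
old_grid = [round(0.1 + 0.10 * k, 2) for k in range(9)]   # 0.1 ... 0.9
new_grid = [round(0.1 + 0.05 * k, 2) for k in range(17)]  # 0.1 ... 0.9
print(0.25 in old_grid, 0.25 in new_grid)  # False True
```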
@@ -326,6 +328,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
             elif value == "greedy":
                 return gr.update(visible=True)
         def update_n_beams(session_state, value): session_state.n_beams = int(value); return session_state
+        def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
         def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
         def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
 
@@ -337,6 +340,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
         generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
         n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
+        max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
 
         gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
         delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
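Taken together, the three hunks above follow the demo's per-session settings pattern: a component, an `update_*` setter, and a `.change` binding that routes through the `gr.State` holding each session's copy of `args`. A self-contained sketch of that pattern (the `argparse.Namespace` default is a stand-in, not the demo's full argument set):

```python
import argparse
import gradio as gr

# Stand-in defaults; the demo passes its full argparse namespace here.
args = argparse.Namespace(max_new_tokens=200)

def update_max_new_tokens(session_state, value):
    # gr.State deep-copies its initial value per browser session, so one
    # user's slider does not change another user's generation length.
    session_state.max_new_tokens = int(value)
    return session_state

with gr.Blocks() as demo:
    session_args = gr.State(value=args)
    max_new_tokens = gr.Slider(label="Max Generated Tokens",
                               minimum=10, maximum=1000, step=10,
                               value=args.max_new_tokens)
    max_new_tokens.change(update_max_new_tokens,
                          inputs=[session_args, max_new_tokens],
                          outputs=[session_args])

# demo.launch()  # uncomment to serve the sketch
```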
@@ -365,7 +369,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
         truncation_warning = gr.Number(visible=False)
         def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
             if truncation_warning:
-                return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]"
+                return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
             else:
                 return orig_prompt, args
 
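The `truncate_prompt` fix is about return arity: Gradio distributes a callback's returned tuple across its declared `outputs`, so a branch returning a bare string where two outputs are expected breaks the contract the other branch (`return orig_prompt, args`) satisfies. Both paths now return a `(prompt, args)` pair. A minimal illustration with invented components:

```python
import gradio as gr

def toggle(flag):
    # Every return path must yield one value per declared output;
    # with outputs=[text, num] below, a bare string here would fail.
    if flag:
        return "truncated", 1
    return "original", 0

with gr.Blocks() as demo:
    flag = gr.Checkbox(label="truncated?")
    text, num = gr.Textbox(), gr.Number()
    flag.change(toggle, inputs=[flag], outputs=[text, num])
```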
@@ -412,7 +416,7 @@ def main(args):
     if not args.skip_model_load:
         model, tokenizer, device = load_model(args)
     else:
-        model, tokenizer, device = None, None, []
+        model, tokenizer, device = None, None, None
 
     # Generate and detect, report to stdout
     if not args.skip_model_load:
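Replacing `[]` with `None` makes the `--skip_model_load` placeholders consistent: an empty list is falsy but it is not `None`, so any `device is None` guard downstream would silently take the wrong branch. A short illustration of the bug class:

```python
device = []                 # the old placeholder
print(device is None)       # False: identity check misses it
print(not device)           # True: truthiness check happens to pass

device = None               # the fix: one explicit sentinel
print(device is None)       # True
```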
@@ -442,7 +446,7 @@ def main(args):
     input_text = "In this work, we study watermarking of language model output. A watermark is a hidden pattern in text that is imperceptible to humans, while making the text algorithmically identifiable as synthetic. We propose an efficient watermark that makes synthetic text detectable from short spans of tokens (as few as 25 words), while false-positives (where human text is marked as machine-generated) are statistically improbable. The watermark detection algorithm can be made public, enabling third parties (e.g., social media platforms) to run it themselves, or it can be kept private and run behind an API. We seek a watermark with the following properties:\n"
 
 
-    term_width = os.get_terminal_size()[0]
+    term_width = 80
     print("#"*term_width)
     print("Prompt:")
     print(input_text)
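`os.get_terminal_size()` raises `OSError` whenever stdout is not attached to a terminal (piped output, or a hosted runtime such as a Space), so pinning `term_width = 80` keeps the banner printing from crashing the demo. If a dynamic width were still wanted, the standard-library fallback would look like:

```python
import shutil

# shutil.get_terminal_size() checks the COLUMNS/LINES environment
# variables, then the real terminal, and falls back to (80, 24)
# instead of raising when stdout is not a TTY.
term_width = shutil.get_terminal_size().columns
print("#" * term_width)
```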
 