Spaces:
Sleeping
Sleeping
jwkirchenbauer
committed on
Commit
·
a134a9d
1
Parent(s):
b5b3015
fixed args
Browse files- demo_watermark.py +16 -6
demo_watermark.py
CHANGED
@@ -157,7 +157,7 @@ def parse_args():
|
|
157 |
args = parser.parse_args()
|
158 |
return args
|
159 |
|
160 |
-
def load_model():
|
161 |
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
|
162 |
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
|
163 |
if args.is_seq2seq_model:
|
@@ -178,7 +178,7 @@ def load_model():
|
|
178 |
|
179 |
return model, tokenizer, device
|
180 |
|
181 |
-
def generate(prompt, args, model=None, tokenizer=None):
|
182 |
|
183 |
print(f"Generating with {args}")
|
184 |
|
@@ -261,7 +261,7 @@ def detect(input_text, args, device=None, tokenizer=None):
|
|
261 |
|
262 |
def run_gradio(args, model=None, device=None, tokenizer=None):
|
263 |
|
264 |
-
generate_partial = partial(generate, model=model, tokenizer=tokenizer)
|
265 |
detect_partial = partial(detect, device=device, tokenizer=tokenizer)
|
266 |
|
267 |
with gr.Blocks() as demo:
|
@@ -447,9 +447,19 @@ def main(args):
|
|
447 |
print("Prompt:")
|
448 |
print(input_text)
|
449 |
|
450 |
-
_, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
|
451 |
-
|
452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
|
454 |
print("#"*term_width)
|
455 |
print("Output without watermark:")
|
|
|
157 |
args = parser.parse_args()
|
158 |
return args
|
159 |
|
160 |
+
def load_model(args):
|
161 |
args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
|
162 |
args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
|
163 |
if args.is_seq2seq_model:
|
|
|
178 |
|
179 |
return model, tokenizer, device
|
180 |
|
181 |
+
def generate(prompt, args, model=None, device=None, tokenizer=None):
|
182 |
|
183 |
print(f"Generating with {args}")
|
184 |
|
|
|
261 |
|
262 |
def run_gradio(args, model=None, device=None, tokenizer=None):
|
263 |
|
264 |
+
generate_partial = partial(generate, model=model, device=None, tokenizer=tokenizer)
|
265 |
detect_partial = partial(detect, device=device, tokenizer=tokenizer)
|
266 |
|
267 |
with gr.Blocks() as demo:
|
|
|
447 |
print("Prompt:")
|
448 |
print(input_text)
|
449 |
|
450 |
+
_, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
|
451 |
+
args,
|
452 |
+
model=model,
|
453 |
+
device=device,
|
454 |
+
tokenizer=tokenizer)
|
455 |
+
without_watermark_detection_result = detect(decoded_output_without_watermark,
|
456 |
+
args,
|
457 |
+
device=device,
|
458 |
+
tokenizer=tokenizer)
|
459 |
+
with_watermark_detection_result = detect(decoded_output_with_watermark,
|
460 |
+
args,
|
461 |
+
device=device,
|
462 |
+
tokenizer=tokenizer)
|
463 |
|
464 |
print("#"*term_width)
|
465 |
print("Output without watermark:")
|