Spaces:
Running
Running
File size: 15,523 Bytes
6a20eb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 |
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from argparse import Namespace
from pprint import pprint
from functools import partial
import numpy # for gradio hot reload
import gradio as gr
import pathlib
import torch
from transformers import (AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
LogitsProcessorList,
LlamaTokenizer)
from gptwm import GPTWatermarkDetector, GPTWatermarkLogitsWarper
def str2bool(v):
    """Convert a user-friendly command-line token into a bool.

    Meant to be used as an ``argparse`` ``type=`` callable. A value that is
    already a ``bool`` is returned unchanged. Strings are matched
    case-insensitively: yes/true/t/y/1 give True, no/false/f/n/0 give False.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    verdicts = {token: True for token in ('yes', 'true', 't', 'y', '1')}
    verdicts.update({token: False for token in ('no', 'false', 'f', 'n', '0')})
    verdict = verdicts.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict
def parse_args():
    """Parse the demo's command-line flags from ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per flag registered below
        (run_gradio, model_name, fraction, strength, wm_key, max_new_tokens,
        beam_size, top_k, top_p, test_min_tokens, threshold).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--run_gradio",
        type=str2bool,
        default=True,
        help="Whether to launch as a gradio demo. Set to False if not installed and want to just run the stdout version.",
    )
    # (flag, type, default) for the remaining options, registered in this order.
    option_table = (
        ("--model_name", str, "facebook/opt-125m"),
        ("--fraction", float, 0.5),
        ("--strength", float, 2.0),
        ("--wm_key", int, 0),
        ("--max_new_tokens", int, 300),
        ("--beam_size", int, None),
        ("--top_k", int, None),
        ("--top_p", float, 0.9),
        ("--test_min_tokens", int, 200),
        ("--threshold", float, 6.0),
    )
    for flag, converter, default in option_table:
        cli.add_argument(flag, type=converter, default=default)
    return cli.parse_args()
def load_model(args):
    """Load and return the causal LM, its tokenizer, and the target device.

    If the ``HF_TOKEN`` environment variable is set, it is forwarded as the
    auth token to every ``from_pretrained`` call.

    Args:
        args: Parsed CLI namespace; only ``args.model_name`` is read.

    Returns:
        (model, tokenizer, device): the model in eval mode, loaded with
        ``device_map='auto'``; the tokenizer (``LlamaTokenizer`` when the name
        contains "llama", otherwise ``AutoTokenizer``); and ``"cuda"`` if a GPU
        is available, else ``"cpu"``.
    """
    hf_token = os.getenv('HF_TOKEN')
    tokenizer_cls = LlamaTokenizer if 'llama' in args.model_name else AutoTokenizer
    # torch_dtype is a model-loading option, not a tokenizer one. Passing it here
    # only stores a torch.dtype in the tokenizer's init kwargs (not JSON-serialisable
    # on save_pretrained), so it is deliberately omitted.
    tokenizer = tokenizer_cls.from_pretrained(args.model_name, use_auth_token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(args.model_name, use_auth_token=hf_token, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model, tokenizer, device
def generate(prompt, args, model=None, device=None, tokenizer=None):
    """Continue ``prompt`` twice: once without and once with the watermark.

    Both runs use the same decoding settings, so the two outputs differ only by
    the watermark logits warper. Beam search is used when ``args.beam_size`` is
    set; otherwise sampling with ``args.top_k`` / ``args.top_p`` is used.

    Args:
        prompt: Prompt text. It is tokenized with ``truncation=True``.
        args: Parsed CLI namespace. Reads fraction, strength, wm_key,
            max_new_tokens, beam_size, top_k and top_p.
        model, device, tokenizer: As returned by ``load_model``.

    Returns:
        (prompt, text_without_watermark, text_with_watermark, args). Each text
        holds only the newly generated tokens of the first sequence, decoded
        with special tokens removed.
    """
    print(f"Generating with {args}")
    watermark_processor = LogitsProcessorList([
        GPTWatermarkLogitsWarper(fraction=args.fraction,
                                 strength=args.strength,
                                 vocab_size=model.config.vocab_size,
                                 watermark_key=args.wm_key)
    ])
    batch = tokenizer(prompt, truncation=True, return_tensors="pt").to(device)
    num_tokens = len(batch['input_ids'][0])

    generate_args = {
        **batch,
        # Only `sequences` is read below. Per-step scores (one vocab-sized tensor
        # per generated token) are therefore not requested.
        'return_dict_in_generate': True,
        'max_new_tokens': args.max_new_tokens,
    }
    if args.beam_size is not None:
        generate_args['num_beams'] = args.beam_size
    else:
        generate_args['do_sample'] = True
        # top_k may be None (the CLI default); it is passed through unchanged.
        generate_args['top_k'] = args.top_k
        generate_args['top_p'] = args.top_p

    def _generate_continuation(**extra_kwargs):
        """Run model.generate with the shared settings and decode only the new tokens."""
        output = model.generate(**generate_args, **extra_kwargs)
        new_tokens = output['sequences'][:, num_tokens:]
        return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

    with torch.inference_mode():
        decoded_output_without_watermark = _generate_continuation()
        decoded_gen_text_with_wm = _generate_continuation(logits_processor=watermark_processor)
    return (prompt,
            decoded_output_without_watermark,
            decoded_gen_text_with_wm,
            args)
def detect_demo(input_text, args, device=None, tokenizer=None, vocab_size=None):
    """Score ``input_text`` for the Unigram watermark and build display outputs.

    Args:
        input_text: Text to analyse. It is tokenized without special tokens.
        args: Parsed CLI namespace. Reads model_name, fraction, strength, wm_key
            and test_min_tokens.
        device: Not used by this function; accepted so callers can bind the same
            keyword arguments they pass to ``generate``.
        tokenizer: Tokenizer matching the model that generated the text.
        vocab_size: Optional watermark vocabulary size. It must equal the size
            used when generating (``generate`` uses ``model.config.vocab_size``).
            When omitted, the original heuristic is used: 50272 if "opt" occurs in
            ``args.model_name``, else ``tokenizer.vocab_size``.

    Returns:
        (rows, html_output, args).
        - ``rows`` is a list of [metric, value] pairs (z-score, green_tokens,
          total_tokens). It is empty when the text is too short.
        - ``html_output`` is an HTML string with each token wrapped in a
          ``green``/``red`` span. If the text has fewer than
          ``args.test_min_tokens`` tokens, it is instead the plain notice
          "Input text is too short to test.".
        - ``args`` is returned unchanged.
    """
    import html  # used to escape user-supplied text before embedding it in HTML

    output = []
    tokens = tokenizer(input_text, add_special_tokens=False)
    gen_tokens = tokens["input_ids"]
    if len(gen_tokens) < args.test_min_tokens:
        print("Input text is too short to test.")
        # A plain string (not a list) so the gr.HTML output shows the notice as-is.
        return output, "Input text is too short to test.", args

    if vocab_size is None:
        # NOTE(review): 50272 presumably mirrors OPT's model.config.vocab_size,
        # which differs from its tokenizer.vocab_size. Note that the substring test
        # matches any model name containing "opt".
        vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
    watermark_detector = GPTWatermarkDetector(fraction=args.fraction,
                                              strength=args.strength,
                                              vocab_size=vocab_size,
                                              watermark_key=args.wm_key)
    z_score, green_tokens_mask, green_tokens, total_tokens = watermark_detector.detect(gen_tokens)
    output.append(['z-score', f"{z_score:.3g}"])
    output.append(['green_tokens', f"{int(green_tokens):d}"])
    output.append(['total_tokens', f"{int(total_tokens):d}"])

    if getattr(tokens, "is_fast", False):
        # Fast tokenizers expose character offsets, so each token is shown exactly
        # as it appears in the input text.
        spans = (tokens.token_to_chars(i) for i in range(len(gen_tokens)))
        pieces = ["" if span is None else input_text[span.start:span.end] for span in spans]
    else:
        # Python-based tokenizers (e.g. the LlamaTokenizer used by load_model) raise
        # on token_to_chars, so each token id is decoded on its own instead.
        pieces = [tokenizer.decode([token_id]) for token_id in gen_tokens]

    tags = []
    for piece, is_green in zip(pieces, green_tokens_mask):
        css_class = "green" if is_green else "red"
        tags.append(f'<span class="{css_class}">{html.escape(piece)}</span>')
    html_output = f'<p>{" ".join(tags)}</p>'
    return output, html_output, args
def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the Gradio demo interface.

    The page has a header, an abstract, and three tabs:
    - "Method": a short description of watermarking and detection.
    - "Generate and Detect": runs ``generate`` on a prompt and scores both
      outputs with ``detect_demo``.
    - "Detector Only": scores arbitrary pasted text with ``detect_demo``.

    Args:
        args: Parsed CLI namespace. Its ``default_prompt`` attribute, if present,
            is popped off and used to pre-fill the prompt box. The remaining
            namespace is held per session in a ``gr.State``.
        model, device, tokenizer: As returned by ``load_model``. They are bound
            into the callbacks with ``functools.partial``.
    """
    css = """
.green {
color: #008000 !important;
border: none;
font-weight: bold;
}
.red {
color: #ffad99 !important;
border: none;
font-weight: bold;
}
"""
    generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
    detect_partial = partial(detect_demo, device=device, tokenizer=tokenizer)
    with gr.Blocks(css=css) as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Row():
                with gr.Column(scale=9):
                    gr.Markdown("""
## 🔍 Unigram-Watermark for AI-Generated Text
## [Paper](https://arxiv.org/abs/2306.17439) [GitHub](https://github.com/XuandongZhao/Unigram-Watermark)
""")
        with gr.Accordion("Abstract", open=True):
            gr.Markdown("""
We instantiate our language model watermarking with the **Unigram-Watermark**——a variant of the K-gram watermark.
We prove that our watermark method enjoys guaranteed generation quality, correctness in watermark detection, and is robust against text editing and paraphrasing.
""")
        gr.Markdown(f"Language model: {args.model_name}")

        # Construct state for parameters. The prompt text is removed from args
        # before args becomes session state. The "" default avoids a KeyError when
        # the attribute was never set (or this function is called a second time).
        default_prompt = vars(args).pop("default_prompt", "")
        session_args = gr.State(value=args)

        with gr.Tab("Method"):
            # Raw strings: LaTeX backslashes such as \gamma, \hat, \delta and \sqrt
            # are invalid Python escapes and would raise invalid-escape warnings
            # (SyntaxWarning on Python 3.12+). The string content is unchanged.
            with gr.Accordion("Watermark process", open=True):
                gr.Markdown(r"""
1. Randomly partition the vocabulary into two distinct sets: the green list with $\gamma N$ tokens and the red list with the remaining tokens.
2. In $\hat{M}$, the logits of the language model for the green list tokens are increased by $\delta$ while the logits for tokens in the red list remain unchanged.
""")
            with gr.Accordion("Detect process", open=True):
                gr.Markdown(r"""
1. Count the number of green tokens in the suspect text.
2. Normalize the test-statistic $z_{y}=(|y|_G-\gamma n) / \sqrt{n \gamma(1-\gamma)}$.
3. Make a calibrated decision on whether we think the suspect text is generated from $\hat{M}$ or not.
""")

        with gr.Tab("Generate and Detect"):
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output Without Watermark"):
                        output_without_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tab("Output With Watermark"):
                        output_with_watermark = gr.Textbox(label="Text", interactive=False, lines=10, max_lines=10)
                    with gr.Tab("Visualization"):
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            # Hidden components. generate_partial's first return value fills
            # redecoded_input. No handler in this function writes truncation_warning,
            # so it presumably keeps its initial (falsy) value.
            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                """Return the prompt to display, with a notice appended if it was truncated."""
                if truncation_warning:
                    return redecoded_input + "\n\n[Prompt was truncated before generation due to length...]", args
                else:
                    return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    # detect inputbox
                    with gr.Tab("Text to Analyze"):
                        detection_input = gr.Textbox(label="Input", interactive=True, lines=14, max_lines=14)
                    with gr.Tab("Visualization"):
                        html_detection = gr.HTML(elem_id="html-detection")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                # detect
                detect_btn = gr.Button("Detect")

        generate_btn.click(fn=generate_partial, inputs=[prompt, session_args],
                           outputs=[redecoded_input, output_without_watermark, output_with_watermark, session_args])
        # Show truncated version of prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input, truncation_warning, prompt, session_args],
                               outputs=[prompt, session_args])
        # Call detection when the outputs (of the generate function) are updated
        output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark, session_args],
                                        outputs=[without_watermark_detection_result, html_without_watermark, session_args])
        output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark, session_args],
                                     outputs=[with_watermark_detection_result, html_with_watermark, session_args])
        # Register main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input, session_args],
                         outputs=[detection_result, html_detection, session_args])
    demo.launch()
def main(args):
    """Load the model, then either serve the Gradio demo or run a stdout demo.

    A fixed example passage is used as the prompt in both modes.
    - If ``args.run_gradio`` is true, the passage becomes the Gradio app's
      default prompt and the app is launched.
    - Otherwise, the passage is continued once without and once with the
      watermark. Both outputs and their detection results are printed to stdout.
    """
    model, tokenizer, device = load_model(args)
    input_text = (
        "One tank tumbled down an embankment into the Tenaru River, drowning its crew."
        " At 23:00 on 14 September, the remnants of the Kuma battalion conducted another attack on the same portion of the Marine lines, but were repulsed. "
        "A final \"weak\" attack by the Kuma unit on the evening of 15 September was also defeated. Oka's unit of about 650 men attacked the Marines at several locations on the west side of the Lunga perimeter."
        " At about 04:00 on 14 September, two Japanese companies attacked positions held by the 3rd Battalion, 5th Marine Regiment (3/5) near the coast and were thrown back with heavy losses."
        " Another Japanese company captured a small ridge somewhat inland but was then pinned down by Marine artillery fire throughout the day and took heavy losses before withdrawing on the evening of 14 September."
        " The rest of Oka's unit failed to find the Marine lines and did not participate in the attack. "
        "At 13:05 on 14 September, Kawaguchi led the survivors of his shattered brigade away from the ridge and deeper into the jungle, where they rested and tended to their wounded all the next day. "
        "Kawaguchi's units were then ordered to withdraw west to the Matanikau River valley to join with Oka's unit, a march over difficult terrain."
        " Kawaguchi's troops began the march on the morning of 16 September."
        " Almost every soldier able to walk had to help carry the wounded. "
        "As the march progressed, the exhausted and hungry soldiers, who had eaten their last rations on the morning before their withdrawal, began to discard their heavy equipment and then their rifles. "
        "By the time most of them reached Oka's positions at Kokumbona five days later, only half still carried their weapons."
        " The Kuma battalion's survivors, attempting to follow Kawaguchi's Center Body forces, became lost, wandered for three weeks in the jungle, and almost starved to death before finally reaching Kawaguchi's camp."
    )

    if args.run_gradio:
        # Launch the app to generate and detect interactively (implements the hf space demo).
        args.default_prompt = input_text
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)
        return

    # Command-line mode: generate and detect, report to stdout.
    term_width = 80
    print("#" * term_width)
    print("Prompt:")
    print(input_text)
    _, text_without_wm, text_with_wm, _ = generate(input_text, args, model=model, device=device, tokenizer=tokenizer)
    result_without_wm, _, _ = detect_demo(text_without_wm, args, device=device, tokenizer=tokenizer)
    result_with_wm, _, _ = detect_demo(text_with_wm, args, device=device, tokenizer=tokenizer)
    for title, text, result in (
        ("Output without watermark:", text_without_wm, result_without_wm),
        ("Output with watermark:", text_with_wm, result_with_wm),
    ):
        print("#" * term_width)
        print(title)
        print(text)
        print("-" * term_width)
        # detect_demo returns no rows when the text is shorter than args.test_min_tokens.
        print("Detection result:")
        pprint(result)
    print("#" * term_width)
if __name__ == "__main__":
    # Parse the flags, echo them for the logs, then hand off to main().
    cli_args = parse_args()
    print(cli_args)
    main(cli_args)