# Unigram-Watermark / run_detect.py
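"""Detect the Unigram-Watermark in model generations.

Reads a JSONL file in which each record holds generated text, scores every
sufficiently long generation with a detection z-score, and writes the scores
plus binary watermark predictions to a sibling *_z.jsonl file.

Example invocation (using the script's own defaults):
    python run_detect.py --input_file ./data/example_output.jsonl \
        --model_name decapoda-research/llama-7b-hf
"""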
import argparse
import json
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, LlamaTokenizer
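# GPTWatermarkDetector is assumed to come from the gptwm module bundled with
# the Unigram-Watermark code, not from a PyPI package.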
from gptwm import GPTWatermarkDetector


def main(args):
    # Load the generations: one JSON record per line (JSONL).
    with open(args.input_file, 'r') as f:
        data = [json.loads(x) for x in f.read().strip().split("\n")]
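
    # Each record is expected to look like {"gen_completion": ["<text>", ...]};
    # only the first completion is scored in the loop below.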

    # torch_dtype configures model weights, not tokenizers, so it is omitted here.
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # OPT checkpoints use 50272 embedding rows, more than their tokenizer reports.
    vocab_size = 50272 if "opt" in args.model_name else tokenizer.vocab_size
    detector = GPTWatermarkDetector(fraction=args.fraction,
                                    strength=args.strength,
                                    vocab_size=vocab_size,
                                    watermark_key=args.wm_key)
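
    # Sketch of what detect() computes (an assumption about GPTWatermarkDetector,
    # whose implementation lives in gptwm, not in this file): count the tokens
    # that fall in the fixed "green" fraction of the vocabulary and apply a
    # one-proportion z-test,
    #     z = (num_green - fraction * T) / sqrt(T * fraction * (1 - fraction)),
    # with T the number of tokens scored.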
    z_score_list = []
    for idx, cur_data in tqdm(enumerate(data), total=len(data)):
        # Re-tokenize the generated text exactly as the model's tokenizer would.
        gen_tokens = tokenizer(cur_data['gen_completion'][0], add_special_tokens=False)["input_ids"]
        if len(gen_tokens) >= args.test_min_tokens:
            z_score_list.append(detector.detect(gen_tokens))
        else:
            print(f"Warning: sequence {idx} is too short to test.")

    # wm_pred: 1 = flagged as watermarked (z-score above threshold), 0 = clean.
    save_dict = {
        'z_score': z_score_list,
        'wm_pred': [1 if z > args.threshold else 0 for z in z_score_list]
    }
    print(save_dict)

    with open(args.input_file.replace('.jsonl', '_z.jsonl'), 'w') as f:
        json.dump(save_dict, f)

    print('Finished!')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--threshold", type=float, default=6.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--input_file", type=str, default="./data/example_output.jsonl")
    parser.add_argument("--test_min_tokens", type=int, default=200)
    args = parser.parse_args()
    main(args)
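
# Illustrative output, written next to the input file (e.g.
# ./data/example_output_z.jsonl); the numbers here are made up:
#   {"z_score": [7.84, 1.02], "wm_pred": [1, 0]}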