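"""Generate watermarked completions for a JSONL prompt file.

Applies a Unigram-style watermark (GPTWatermarkLogitsWarper) during decoding
and appends {prefix, gold_completion, gen_completion} records to an output
JSONL file.
"""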
import argparse
import json
import os

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LogitsProcessorList

from gptwm import GPTWatermarkLogitsWarper

def read_file(filename):
    """Read a JSONL file into a list of dicts, one per line."""
    with open(filename, "r") as f:
        return [json.loads(line) for line in f.read().strip().split("\n")]


def write_file(filename, data):
    """Append already-serialized JSON lines to a JSONL file."""
    with open(filename, "a") as f:
        f.write("\n".join(data) + "\n")
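
# The prompt file is expected to be JSONL, one record per line. An illustrative
# record (not taken from the dataset) looks like:
#   {"prefix": "Q: Why is the sky blue?\nA:", "gold_completion": "Sunlight scatters..."}
# Records may instead carry a "targets" list in place of "gold_completion".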

def main(args):
    output_file = f"{args.output_dir}/{args.model_name.replace('/', '-')}_strength_{args.strength}_frac_{args.fraction}_len_{args.max_new_tokens}_num_{args.num_test}.jsonl"

    # LLaMA checkpoints need the dedicated tokenizer class.
    if 'llama' in args.model_name:
        tokenizer = LlamaTokenizer.from_pretrained(args.model_name)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Load weights in fp16; device_map='auto' lets accelerate place them.
    model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map='auto')
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Bias every decoding step toward the watermark's green list.
    watermark_processor = LogitsProcessorList([GPTWatermarkLogitsWarper(fraction=args.fraction,
                                                                        strength=args.strength,
                                                                        vocab_size=model.config.vocab_size,
                                                                        watermark_key=args.wm_key)])
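
    # Conceptually, the Unigram-style warper fixes a single pseudorandom
    # "green" subset of fraction * vocab_size tokens (seeded by watermark_key)
    # and adds `strength` to their logits at every step. A rough sketch of the
    # transformation (illustrative only, not the library's exact code):
    #
    #   rng = torch.Generator().manual_seed(watermark_key)
    #   green = torch.randperm(vocab_size, generator=rng)[: int(fraction * vocab_size)]
    #   scores[..., green] += strength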
    data = read_file(args.prompt_file)
    # Resume from a partial run: skip prompts already present in the output file.
    num_cur_outputs = len(read_file(output_file)) if os.path.exists(output_file) else 0

    outputs = []
    for idx, cur_data in tqdm(enumerate(data), total=min(len(data), args.num_test)):
        # Skip already-generated prompts and stop after num_test examples.
        if idx < num_cur_outputs or idx >= args.num_test:
            continue
if "gold_completion" not in cur_data and 'targets' not in cur_data: | |
continue | |
elif "gold_completion" in cur_data: | |
prefix = cur_data['prefix'] | |
gold_completion = cur_data['gold_completion'] | |
else: | |
prefix = cur_data['prefix'] | |
gold_completion = cur_data['targets'][0] | |
batch = tokenizer(prefix, truncation=True, return_tensors="pt").to(device) | |
num_tokens = len(batch['input_ids'][0]) | |
        with torch.inference_mode():
            generate_args = {
                **batch,
                'logits_processor': watermark_processor,
                'output_scores': True,
                'return_dict_in_generate': True,
                'max_new_tokens': args.max_new_tokens,
            }

            # Beam search when a beam size is given; top-k/nucleus sampling otherwise.
            if args.beam_size is not None:
                generate_args['num_beams'] = args.beam_size
            else:
                generate_args['do_sample'] = True
                generate_args['top_k'] = args.top_k
                generate_args['top_p'] = args.top_p

            generation = model.generate(**generate_args)
            # Decode only the newly generated tokens, dropping the prompt.
            gen_text = tokenizer.batch_decode(generation['sequences'][:, num_tokens:], skip_special_tokens=True)
        outputs.append(json.dumps({
            "prefix": prefix,
            "gold_completion": gold_completion,
            "gen_completion": gen_text
        }))

        # Flush to disk every 10 examples so interrupted runs can be resumed.
        if (idx + 1) % 10 == 0:
            write_file(output_file, outputs)
            outputs = []

    write_file(output_file, outputs)
    print("Finished!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="facebook/opt-125m")
    # parser.add_argument("--model_name", type=str, default="decapoda-research/llama-7b-hf")
    parser.add_argument("--fraction", type=float, default=0.5)
    parser.add_argument("--strength", type=float, default=2.0)
    parser.add_argument("--wm_key", type=int, default=0)
    parser.add_argument("--prompt_file", type=str, default="./data/LFQA/inputs.jsonl")
    parser.add_argument("--output_dir", type=str, default="./data/LFQA/")
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--num_test", type=int, default=500)
    parser.add_argument("--beam_size", type=int, default=None)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=0.9)
    args = parser.parse_args()

    main(args)
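
# Example invocation (the script filename is illustrative; use this file's
# actual name in your checkout):
#   python generate.py --model_name facebook/opt-125m \
#       --prompt_file ./data/LFQA/inputs.jsonl --output_dir ./data/LFQA/ \
#       --strength 2.0 --fraction 0.5 --max_new_tokens 300 --num_test 500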