|
import argparse |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria |
|
import torch |
|
import os |
|
import json |
|
from tqdm import tqdm |
|
import shortuuid |
|
|
|
from llava.conversation import default_conversation |
|
from llava.utils import disable_torch_init |
|
|
|
|
|
|
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as any keyword appears in the generated text.

    Only tokens produced after the prompt are decoded, and only the first
    sequence of the batch (index 0) is inspected, so this is meant for
    batch size 1.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: strings; generation stops once any of them is a
                substring of the decoded continuation.
            tokenizer: tokenizer used to decode generated ids back to text.
            input_ids: the prompt ids passed to ``generate``; only its
                second dimension (the prompt length) is used here.
        """
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.input_ids = input_ids
        # The prompt length is known up front. Recording it here (instead of
        # lazily on the first call) means the very first generated token is
        # checked as well; previously the first call skipped the check.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a keyword occurs in the text generated so far."""
        # Decode only the continuation so keywords inside the prompt are ignored.
        outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
|
|
|
|
|
def extract_answer(outputs, prompt, sep, assistant_role):
    """Cut the assistant's reply out of a fully decoded generation.

    Expects ``outputs`` to start with ``prompt`` verbatim, followed by
    ``"<assistant_role>: "`` and the reply, optionally terminated by ``sep``.
    NOTE(review): this relies on the tokenizer round-tripping the prompt text
    exactly; confirm for tokenizers that normalize whitespace.

    Args:
        outputs: decoded text of prompt + generated continuation.
        prompt: the exact prompt string that was tokenized.
        sep: separator that marks the end of the reply.
        assistant_role: role label the model emits before its reply.

    Returns:
        The reply, stripped of surrounding whitespace. If no separator follows
        the prompt, everything to the end of ``outputs`` is kept. If
        ``outputs`` is too short to contain a reply, returns ``""``.
    """
    end = outputs.find(sep, len(prompt))
    if end == -1:
        # No terminating separator (e.g. max_new_tokens was reached): keep the rest.
        end = len(outputs)
    # Skip the prompt and the "<assistant_role>: " prefix (label, colon, space).
    start = len(prompt) + len(assistant_role) + 2
    return outputs[start:end].strip()


@torch.inference_mode()
def eval_model(model_name, questions_file, answers_file):
    """Generate an answer for every question in a JSONL file and write JSONL answers.

    Args:
        model_name: Hugging Face model id or local path (``~`` is expanded).
            The expanded value is also recorded as ``model_id`` in each answer.
        questions_file: JSONL input. Each non-blank line must contain
            ``question_id`` and ``text`` keys.
        answers_file: JSONL output path, overwritten. One record per question
            with ``question_id``, ``text``, ``answer_id``, ``model_id`` and
            ``metadata`` keys.
    """
    disable_torch_init()
    model_name = os.path.expanduser(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()

    # `with` guarantees both files are closed, even if generation raises.
    with open(os.path.expanduser(questions_file), "r", encoding="utf-8") as ques_file, \
            open(os.path.expanduser(answers_file), "w", encoding="utf-8") as ans_file:
        for line in tqdm(ques_file):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            question = json.loads(line)  # parse once rather than once per field
            idx = question["question_id"]
            qs = question["text"]

            conv = default_conversation.copy()
            conv.append_message(conv.roles[0], qs)
            prompt = conv.get_prompt()
            input_ids = torch.as_tensor(tokenizer([prompt]).input_ids).cuda()

            # Stop as soon as the model emits the conversation separator.
            stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
            output_ids = model.generate(
                input_ids,
                do_sample=True,
                use_cache=True,
                temperature=0.7,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria])
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            answer = extract_answer(outputs, prompt, conv.sep, conv.roles[1])

            ans_id = shortuuid.uuid()
            ans_file.write(json.dumps({"question_id": idx,
                                       "text": answer,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")
            # Flush per answer so partial results survive an interrupted run.
            ans_file.flush()
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the model id and file paths, then
    # run the evaluation over every question.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--model-name", "facebook/opt-350m"),
        ("--question-file", "tables/question.jsonl"),
        ("--answers-file", "answer.jsonl"),
    ):
        cli.add_argument(flag, type=str, default=fallback)
    opts = cli.parse_args()

    eval_model(opts.model_name, opts.question_file, opts.answers_file)
|
|