import argparse
import json
import os
import pickle

from evaluator import Evaluator

# Paths to the model outputs, the ground-truth QA annotations, and the score output directory.
OUTPUT_SAVE_PATH = "/home/ubuntu/proteinchat/eval/results/outputs"
ANN_PATH = "/home/ubuntu/proteinchat/data/qa_all.json"
SCORE_SAVE_PATH = "/home/ubuntu/proteinchat/eval/results/scores"

evaluator = Evaluator()

# Load the ground-truth QA annotations, keyed by protein ID.
with open(ANN_PATH, "r") as f:
    annotation = json.load(f)


def parse_args():
    parser = argparse.ArgumentParser(description="scorer")
    parser.add_argument("--model", type=str, required=True, help="name of the model whose outputs should be scored.")
    return parser.parse_args()


args = parse_args()
model = args.model

# Load the raw generation outputs for the requested model.
model_filename = model + "_eval_output.json"
with open(os.path.join(OUTPUT_SAVE_PATH, model_filename), "r") as f:
    raw_outputs = json.load(f)

# For each protein, score the model's predicted answers against the reference answers.
score_output = {}
for prot, responses in raw_outputs.items():
    ann = annotation[prot]
    preds = [qa["A"] for qa in responses]
    refs = [str(a["A"]) for a in ann]
    score_output[prot] = evaluator.eval(preds, refs)

# Persist the per-protein scores as a pickle file.
with open(os.path.join(SCORE_SAVE_PATH, f"{model}_score_output.pickle"), "wb") as handle:
    pickle.dump(score_output, handle, protocol=pickle.HIGHEST_PROTOCOL)
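
# Example invocation (a sketch: the script filename "score.py" and the model name
# "proteinchat" are assumptions, not given in this file; only the --model flag is):
#   python score.py --model proteinchat
# This reads proteinchat_eval_output.json from OUTPUT_SAVE_PATH and writes
# proteinchat_score_output.pickle under SCORE_SAVE_PATH.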