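# NOTE: non-code residue that preceded the script ("Spaces: / Running / Running"),
# kept here as a comment so it does not break parsing.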
import argparse
import json
import os
import pickle

import tqdm
# Directory the per-model score pickles ("<model>_score_output.pickle") are read from.
SCORES_PATH = "/home/ubuntu/proteinchat/eval/results/scores"
# Directory the averaged results ("<model>_avg_scores.json") are written to.
AVGS_PATH = "/home/ubuntu/proteinchat/eval/results/avgs"
def parse_args(argv=None):
    """Parse the scorer's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the process command line.

    Returns:
        argparse.Namespace whose ``model`` attribute holds the value of the
        required ``--model`` option.
    """
    parser = argparse.ArgumentParser(description="scorer")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="specify the model to load the model.",
    )
    return parser.parse_args(argv)
args = parse_args()

# Load the per-protein scores for the requested model.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load score files that come from a trusted source.
with open(
    os.path.join(SCORES_PATH, f"{args.model}_score_output.pickle"), "rb"
) as scores_file:
    # The context manager closes the handle; the original rebound the name
    # to the loaded data and never closed the file.
    prot_scores = pickle.load(scores_file)
# Collapse each protein's BERT-score lists (precision / recall / f1) into
# their arithmetic means, in place, so the cross-protein aggregation further
# down can add them up like the other scalar metrics.
for prot in tqdm.tqdm(prot_scores):
    bert_score = prot_scores[prot]["bert_score"]
    for key in ("precision", "recall", "f1"):
        values = bert_score[key]
        count = len(values)
        if count == 0:
            # Fail with the offending protein named instead of a bare
            # ZeroDivisionError.
            raise ValueError(
                f"empty bert_score[{key!r}] list for protein {prot!r}"
            )
        bert_score[key] = sum(values) / count
# Scalar metrics that are summed over all proteins and then averaged.
# Family and key names are looked up verbatim in each protein's score dict.
_AVERAGED_KEYS = {
    "gpt_score": ("precision", "recall", "f1_score"),
    "pubmedbert_score": ("precision", "recall", "f1_score"),
    "rouge": ("rouge1", "rouge2", "rougeL", "rougeLsum"),
    "bert_score": ("precision", "recall", "f1"),
    "bleu": (
        "bleu",
        "brevity_penalty",
        "length_ratio",
        "translation_length",
        "reference_length",
    ),
}

# Number of leading entries read from each protein's BLEU "precisions" list.
_BLEU_PRECISION_COUNT = 4


def _average_scores(prot_scores, progress=None):
    """Average every metric across all entries of ``prot_scores``.

    Args:
        prot_scores: Mapping whose values are per-protein score dicts with
            the families "gpt_score", "pubmedbert_score", "rouge",
            "bert_score", "bleu" and "meteor". All summed values must be
            numeric, so the "bert_score" entries must already be reduced to
            single numbers.
        progress: Optional callable that wraps the iterable being looped
            over (for example ``tqdm.tqdm``) to display progress.

    Returns:
        Dict with the averaged metrics, laid out as:
        "gpt_score", "pubmedbert_score", "rouge", "bert_score" (whose F1 is
        stored under "f1_score"), "bleu" (including a 4-element
        "precisions" list), "meteor" (a single number) and "mauve" (an
        empty dict, since it is not aggregated here).

    Raises:
        ValueError: If ``prot_scores`` is empty.
    """
    count = len(prot_scores)
    if count == 0:
        raise ValueError("prot_scores is empty; nothing to average")

    totals = {
        family: dict.fromkeys(keys, 0) for family, keys in _AVERAGED_KEYS.items()
    }
    precision_totals = [0] * _BLEU_PRECISION_COUNT
    meteor_total = 0

    iterable = progress(prot_scores) if progress is not None else prot_scores
    for prot in iterable:
        scores = prot_scores[prot]
        for family, keys in _AVERAGED_KEYS.items():
            family_scores = scores[family]
            family_totals = totals[family]
            for key in keys:
                family_totals[key] += family_scores[key]
        # The BLEU and METEOR values must be accumulated like every other
        # metric; the previous code assigned with "=", so the reported
        # "average" was just the last protein's value divided by the count.
        bleu_precisions = scores["bleu"]["precisions"]
        for order in range(_BLEU_PRECISION_COUNT):
            precision_totals[order] += bleu_precisions[order]
        meteor_total += scores["meteor"]["meteor"]

    averages = {
        family: {key: total / count for key, total in sums.items()}
        for family, sums in totals.items()
    }
    bert = averages["bert_score"]
    bleu = averages["bleu"]

    # Key names and ordering are kept exactly as before so existing consumers
    # of the JSON output keep working.
    return {
        "gpt_score": averages["gpt_score"],
        "pubmedbert_score": averages["pubmedbert_score"],
        "rouge": averages["rouge"],
        "bert_score": {
            "precision": bert["precision"],
            "recall": bert["recall"],
            # Input key is "f1"; the output has always used "f1_score".
            "f1_score": bert["f1"],
        },
        "bleu": {
            "bleu": bleu["bleu"],
            "precisions": [total / count for total in precision_totals],
            "brevity_penalty": bleu["brevity_penalty"],
            "length_ratio": bleu["length_ratio"],
            "translation_length": bleu["translation_length"],
            "reference_length": bleu["reference_length"],
        },
        # Stored as a bare number (not a nested dict), as before.
        "meteor": meteor_total / count,
        "mauve": {},
    }


results = _average_scores(prot_scores, progress=tqdm.tqdm)
# TODO: "mauve" is never aggregated in this script, so results["mauve"] is
# written out exactly as it was initialised.
print(results)

# Create the output directory if needed so the run does not fail on the very
# last step with FileNotFoundError.
os.makedirs(AVGS_PATH, exist_ok=True)
with open(
    os.path.join(AVGS_PATH, f"{args.model}_avg_scores.json"),
    "w",
    encoding="utf-8",
) as fp:
    json.dump(results, fp, indent=4)