Lang2mol-Diff / src /anlg_infill /run_evaluation.py
ndhieunguyen's picture
Add application file
7dd9869
raw
history blame
3 kB
import torch, json, sys
SPLIT = sys.argv[1] # val or test
MBR_PATH = sys.argv[2] # output path.
# read files.
if SPLIT == 'val':
source_file = '/diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
source_file = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
assert False, "invalid split"
with open(source_file, 'r') as f:
sent_lst = json.load(f)
# read generation
generated_lst = []
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500_v2.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_full_mbr.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/diff_full.json', 'r') as f:
with open(MBR_PATH, 'r') as f:
for line in f:
generated_lst.append(json.loads(line))
print(len(generated_lst), len(sent_lst))
# eval_file_gen = "/diffusion_lm/improved-diffusion/anlg_results/ar_gen_mbr_v2.txt"
# eval_file_gold = "/diffusion_lm/improved-diffusion/anlg_results/ar_ref_mbr_v2.txt"
if SPLIT == 'val':
eval_file_gen = f"{MBR_PATH}_gen.txt"
fgen = open(eval_file_gen, 'w')
eval_file_gold = f"{MBR_PATH}_ref.txt" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
fgold = open(eval_file_gold, 'w')
for gen, gold in zip(generated_lst, sent_lst.items()):
print(gen['sample'], file=fgen)
gold = gold[1]
for x in gold['gold_labels']:
print(x, file=fgold)
print('', file=fgold)
fgold.close()
fgen.close()
elif SPLIT == 'test':
eval_file_prediction = f"{MBR_PATH}_prediction.json" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
# fpred = open(eval_file_prediction, 'w')
full_dict = {}
for gen, gold in zip(generated_lst, sent_lst.items()):
print(gold)
print(gen['sample'])
full_dict[gold[0]] = gen['sample']
# temp_dict = {gold[0]:gen['sample']}
# print(temp_dict)
# print(json.dumps(temp_dict), file=fpred)
# gold = gold[1]
# for x in gold['gold_labels']:
# print(x, file=fgold)
# print('', file=fgold)
with open(eval_file_prediction, 'w') as fpred:
json.dump(full_dict, fpred)
###########
test_ref = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_ref.json'
with open(test_ref, 'r') as f:
test_ref_lst = json.load(f)
eval_file_gen = f"{MBR_PATH}_gen.txt"
fgen = open(eval_file_gen, 'w')
eval_file_gold = f"{MBR_PATH}_ref.txt" # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
fgold = open(eval_file_gold, 'w')
for gen, gold in zip(generated_lst, sent_lst.items()):
story_id = gold[0]
print(gen['sample'], file=fgen)
for x in test_ref_lst[story_id]:
print(x, file=fgold)
print('', file=fgold)
fgold.close()
fgen.close()
# generate prediction.json