File size: 3,000 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch, json, sys 

SPLIT = sys.argv[1] # val or test
MBR_PATH = sys.argv[2] # output path.

# read files.
if SPLIT == 'val':
    source_file = '/diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
    source_file = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
    assert False, "invalid split"

with open(source_file, 'r') as f:
    sent_lst = json.load(f)

# read generation
generated_lst = []
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_beam_500_v2.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/ar_full_mbr.json', 'r') as f:
# with open('/diffusion_lm/improved-diffusion/anlg_results/diff_full.json', 'r') as f:
with open(MBR_PATH, 'r') as f:
    for line in f:
        generated_lst.append(json.loads(line))

print(len(generated_lst), len(sent_lst))
# eval_file_gen = "/diffusion_lm/improved-diffusion/anlg_results/ar_gen_mbr_v2.txt"
# eval_file_gold = "/diffusion_lm/improved-diffusion/anlg_results/ar_ref_mbr_v2.txt"
if SPLIT == 'val':
    eval_file_gen = f"{MBR_PATH}_gen.txt"
    fgen = open(eval_file_gen, 'w')
    eval_file_gold = f"{MBR_PATH}_ref.txt"  # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
    fgold = open(eval_file_gold, 'w')
    for gen, gold in zip(generated_lst, sent_lst.items()):
        print(gen['sample'], file=fgen)
        gold = gold[1]
        for x in gold['gold_labels']:
            print(x, file=fgold)
        print('', file=fgold)
    fgold.close()
    fgen.close()
elif SPLIT == 'test':
    eval_file_prediction = f"{MBR_PATH}_prediction.json"  # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
    # fpred = open(eval_file_prediction, 'w')
    full_dict = {}
    for gen, gold in zip(generated_lst, sent_lst.items()):
        print(gold)
        print(gen['sample'])
        full_dict[gold[0]] = gen['sample']
        # temp_dict = {gold[0]:gen['sample']}
        # print(temp_dict)
        # print(json.dumps(temp_dict), file=fpred)
        # gold = gold[1]
        # for x in gold['gold_labels']:
        #     print(x, file=fgold)
        # print('', file=fgold)
    with open(eval_file_prediction, 'w') as fpred:
        json.dump(full_dict, fpred)

    ###########
    test_ref = '/diffusion_lm/ROCstory/anlg/anlg/test_cleanup_ref.json'
    with open(test_ref, 'r') as f:
        test_ref_lst = json.load(f)

    eval_file_gen = f"{MBR_PATH}_gen.txt"
    fgen = open(eval_file_gen, 'w')
    eval_file_gold = f"{MBR_PATH}_ref.txt"  # "/diffusion_lm/improved-diffusion/anlg_results/diff_ref_v1.txt"
    fgold = open(eval_file_gold, 'w')
    for gen, gold in zip(generated_lst, sent_lst.items()):
        story_id = gold[0]
        print(gen['sample'], file=fgen)
        for x in test_ref_lst[story_id]:
            print(x, file=fgold)
        print('', file=fgold)
    fgold.close()
    fgen.close()


# generate prediction.json