import json
import os

import torch
from spacy.lang.en import English
from improved_diffusion.rounding import rounding_func, load_models, load_tokenizer
from transformers import AutoModelForCausalLM
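# Decode ANLG (abductive NLG) infills: given two observations (obs1, obs2),
# generate the missing middle sentence, either by shelling out to the
# diffusion-LM infilling script ('diff') or by sampling from a fine-tuned
# autoregressive GPT-2 baseline ('ar').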
# Read the ANLG source file for the chosen split.
SPLIT = 'test'

if SPLIT == 'val':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/dev_cleanup.json'
elif SPLIT == 'test':
    source_file = 'diffusion_lm/ROCstory/anlg/anlg/test_cleanup_no_label.json'
else:
    assert False, "invalid split"

with open(source_file, 'r') as f:
    sent_lst = json.load(f)


nlp = English()
tokenizer = nlp.tokenizer
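# MODE selects the decoding backend: 'diff' shells out to ../scripts/infill.py,
# 'ar' generates with the fine-tuned GPT-2 model loaded below.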
MODE = 'ar'

'''
 "00b9adb2-b3b6-4737-902a-50f308bac4b5-1": {
        "gold_labels": [
            "I put my baby in the car and drove around.",
            "I realized he needed his blanket, which I had forgotten at a faraway hotel.",
            "I took a drive to get my baby to sleep.",
            "I took my baby for a drive and she fell asleep in the car."
        ],
        "obs1": "My baby would not go to sleep last night.",
        "obs2": "I wound up driving for hours."
    },
'''
print(len(sent_lst))

if MODE == 'ar':
    # Earlier checkpoint, kept for reference; the v2 checkpoint below is the one used.
    # model_name = 'predictability/diff_models/roc_e=20_b=32_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill'
    model_name = 'predictability/diff_models/roc_e=6_b=10_m=gpt2_wikitext-103-raw-v1_101_wp_pad_infill_v2'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,  # path to the AR model trained for LMing this task.
    ).cuda()
    tokenizer2 = load_tokenizer('roc', 'random',
                               'predictability/diffusion_models_v7/diff_roc_pad_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart')
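    # tokenizer2 (as used here) maps token ids -> words; invert it to get the
    # word -> id vocab used to encode the prompt.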
    vocab = {v: k for k, v in tokenizer2.items()}
    print(len(tokenizer2), len(vocab), 'loaded vocabs')
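    # The output filename doubles as a switch: 'sample' in the name triggers
    # 50 stochastic samples per example below, otherwise beam search is used.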

    outfile = 'ar_sample_full_test_v2.json'
    filehandle = open(outfile, 'w')

for idx, (key, val) in enumerate(sent_lst.items()):
    # if idx <= 499:
    #     continue
    # if idx >= 500:
    #     continue
    # if idx != 684:
    #     continue

    if MODE == 'diff':
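        # Mark the span to infill with 10 PAD placeholders between obs1 and obs2.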
        partial_seq = f"{val['obs1']} " + "PAD "*10 + f"{val['obs2']}"
        word_lst = [x.text for x in tokenizer(partial_seq)]
        partial_seq = " ".join(word_lst)
        print(partial_seq, idx)
        # partial_seq = "Brenna and I used to be best friends . PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD We never talked again ."
        COMMAND = "python ../scripts/infill.py " \
                  "--model_path predictability/diffusion_models_v7/diff_roc_pad_rand128_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd108_xstart_e2e_long/ema_0.9999_800000.pt " \
                  " --batch_size 50  " \
                  f"--partial_seq \'{partial_seq}\' " \
                  f"--eval_task_ infill --notes {SPLIT}_{idx} " \
                  f"--out_dir ../anlg_results"
        os.system(COMMAND)
        torch.cuda.empty_cache()
    elif MODE == 'ar':
        partial_seq = f"{val['obs1']} " + f"{val['obs2']}"
        print(partial_seq)
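        # Encode the prompt as a START token followed by word ids (UNK for OOV).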
        word_idx_lst = [vocab['START']] + [vocab.get(x.text, vocab['UNK']) for x in tokenizer(partial_seq)]
        init_prompt = torch.LongTensor(word_idx_lst).cuda().unsqueeze(0)
        print(init_prompt.shape)
        # sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
        if 'sample' in outfile:
            print('sampling 50 examples.')
            init_prompt = init_prompt.expand(50, -1)
            sample_out = model.generate(init_prompt, do_sample=True, max_length=64, top_k=len(vocab))
        else:
            sample_out = model.generate(init_prompt, do_sample=False, num_beams=4, max_length=64, top_k=len(vocab))

        print(sample_out.shape)
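        # Drop the prompt tokens; keep only the generated continuation.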
        sample_out = sample_out[:, init_prompt.size(1):]
        # Decode ids back to words, stripping PAD / START / END markers.
        if 'sample' in outfile:
            sample_lst = []
            for examp in sample_out:
                sample = examp.tolist()
                words_sample = [tokenizer2[s] for s in sample]
                tempsent = [x for x in words_sample if x != 'PAD']
                if tempsent and tempsent[0] == 'START':
                    tempsent = tempsent[1:]
                if tempsent and tempsent[-1] == 'END':
                    tempsent = tempsent[:-1]
                result_sent = " ".join(tempsent)
                sample_lst.append(result_sent)
            out_dict = {'idx': idx,
                        'obs1': val['obs1'],
                        'obs2': val['obs2'],
                        'samples': sample_lst}
            print(json.dumps(out_dict), file=filehandle)
        else:
            sample = sample_out[0].tolist()
            words_sample = [tokenizer2[s] for s in sample]
            tempsent = [x for x in words_sample if x != 'PAD']
            if tempsent and tempsent[0] == 'START':
                tempsent = tempsent[1:]
            if tempsent and tempsent[-1] == 'END':
                tempsent = tempsent[:-1]
            result_sent = " ".join(tempsent)
            out_dict = {'idx': idx,
                        'obs1': val['obs1'],
                        'obs2': val['obs2'],
                        'sample': result_sent}
            print(json.dumps(out_dict), file=filehandle)
if MODE == 'ar':  # filehandle/outfile only exist in 'ar' mode
    filehandle.close()
    print(f'written to {outfile}')