from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer
from easyeditor import apply_grace_to_model, GraceHyperParams, nethook, apply_wise_to_model, WISEHyperParams, ROMEHyperParams, apply_rome_to_model
import torch
import gradio as gr
import json
import numpy as np
import random
# Fix all random seeds so edits and generations are reproducible across runs.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)

# Base GPT-2 model (on CPU) that the single-edit functions below modify in place.
model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')


def clear():
    # Reload the unedited base model and return empty strings to clear the bound textboxes.
    global model
    model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
    return '', ''

def grace_edit(prompt, target_new, num_steps, edit_lr):
    # Apply a single GRACE edit to the shared base model and store the result in the
    # module-level edit_model used by generate().
    request = {"prompt": prompt, "target_new": target_new}
    hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")

    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_grace_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new

def wise_edit(prompt, target_new, num_steps, edit_lr):
    # Apply a single WISE edit to the shared base model and store the result in edit_model.
    request = {"prompt": prompt, "target_new": target_new}
    hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")

    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_wise_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new

def rome_edit(prompt, target_new, num_steps, edit_lr):
    # Apply a single ROME edit to the shared base model and store the result in edit_model.
    request = {"prompt": prompt, "target_new": target_new}
    hparams = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")

    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model
    edit_model = apply_rome_to_model(model, tok, request, hparams, num_steps, edit_lr)
    return prompt, target_new

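# Route a single edit request to the editing algorithm selected in the UI.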
def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    if edit_alg == 'GRACE':
        return grace_edit(prompt, target_new, num_steps, edit_lr)
    elif edit_alg == 'WISE':
        return wise_edit(prompt, target_new, num_steps, edit_lr)
    elif edit_alg == 'ROME':
        return rome_edit(prompt, target_new, num_steps, edit_lr)
    else:
        raise NotImplementedError

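# Compare the original and edited model on one prompt.
# GRACE edits are decoded with model.generate(); the other algorithms are evaluated by
# teacher-forcing "prompt + target" and taking the argmax prediction at each position.
# Prompts without an explicit target (the locality probes) fall back to the answers in loc_output.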
def generate(input_text, target_new=None, edit_alg=None):
    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        prompt_len = len(input_text)
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        # Character-level spans for the Gradio highlighted output: everything past the prompt is tagged 'output'.
        ori_reply = [(ch, 'output') if i > prompt_len else (ch, None) for i, ch in enumerate(ori_reply)]
        edit_reply = [(ch, 'output') if i > prompt_len else (ch, None) for i, ch in enumerate(edit_reply)]
        return ori_reply, edit_reply
    else:
        if target_new is None:
            target_new = loc_output[input_text]
        max_new_tokens = len(tok.encode(target_new))
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
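        # Teacher forcing: run the edited model over "prompt + target" and take the argmax
        # prediction at each position (logits at position t predict token t+1).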
        edit_output = edit_model(input_ids=input_ids).logits
        edit_output = torch.argmax(edit_output, dim=-1)

        edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()

        
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        # ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        # ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
        ori_output = ori_model(input_ids=input_ids).logits
        ori_output = torch.argmax(ori_output, dim=-1)

        ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        # Highlight only the characters beyond the original input text.
        ori_reply = [(ch, 'output') if i > len(input_text) else (ch, None) for i, ch in enumerate(ori_reply)]
        edit_reply = [(ch, 'output') if i > len(input_text) else (ch, None) for i, ch in enumerate(edit_reply)]
        return ori_reply, edit_reply

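# Run the comparison for both the edited prompt and a paraphrased version of it.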
def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    res1, res2 = generate(input_text, target_new=target_new, edit_alg=edit_alg)
    res3, res4 = generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
    return res1, res2, res3, res4

# continuous_examples=[
#     ["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"]
# ]

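# Example edits (deliberately counterfactual targets) used by the continuous, i.e. sequential, editing demo.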
continuous_examples=[
    ["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
    ["What company makes Springfield Armory XDM?", "Messerschmitt"],
    ["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
    ["What year did Sunnyside Hospital cease to exist?", "1962"],
    ["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
    ["What piece of fiction does Jack Harkness appear in?", "Lost"]
]


# Shared state for the continuous-editing demo: hyperparameters, tokenizer, and one model
# per algorithm that accumulates edits across requests.
grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
grace_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
wise_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')


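# Warm-start both continuous models: apply every example edit in sequence (40 steps, lr 1.0)
# so the demo starts with all of the above facts already edited in.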
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_grace_to_model(grace_continuous_model, tokenizer, request, grace_hparams, 40, 1.0)

for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_wise_to_model(wise_continuous_model, tokenizer, request, wise_hparams, 40, 1.0)

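# Stack one more edit on top of the already-edited continuous model for the chosen algorithm.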
def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    global tokenizer
    if edit_alg == 'GRACE':
        request={"prompt":prompt,"target_new":target_new}
        global grace_hparams

        global grace_continuous_model
        apply_grace_to_model(grace_continuous_model,tokenizer,request,grace_hparams, num_steps, edit_lr)
        return prompt, target_new
    elif edit_alg == 'WISE':
        request={"prompt":prompt,"target_new":target_new}
        global wise_hparams

        global wise_continuous_model
        apply_wise_to_model(wise_continuous_model,tokenizer,request,wise_hparams, num_steps, edit_lr)
    else:
        raise NotImplementedError
    return prompt, target_new

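# Same comparison as generate(), but against the continuously edited GRACE/WISE model
# rather than the single-edit global model.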
def continuous_generate(input_text, edit_alg=None, target_new=None):
    if edit_alg == 'GRACE':
        global grace_continuous_model
        cur_model = grace_continuous_model
    elif edit_alg == 'WISE':
        global wise_continuous_model
        cur_model = wise_continuous_model
    else:
        raise NotImplementedError
    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        prompt_len = len(input_text)
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
        edit_output = cur_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()
        
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        # Character-level spans for the Gradio highlighted output: everything past the prompt is tagged 'output'.
        ori_reply = [(ch, 'output') if i > prompt_len else (ch, None) for i, ch in enumerate(ori_reply)]
        edit_reply = [(ch, 'output') if i > prompt_len else (ch, None) for i, ch in enumerate(edit_reply)]
        return ori_reply, edit_reply
    else:
        if target_new is None:
            target_new = loc_output[input_text]
        max_new_tokens = len(tok.encode(target_new))
        input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
        prompt_len = len(tok.encode(input_text))
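        # Teacher forcing: run the continuously edited model over "prompt + target" and take
        # the argmax prediction at each position (logits at position t predict token t+1).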
        edit_output = cur_model(input_ids=input_ids).logits
        edit_output = torch.argmax(edit_output, dim=-1)

        edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()

        
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        # ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
        # ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
        ori_output = ori_model(input_ids=input_ids).logits
        ori_output = torch.argmax(ori_output, dim=-1)

        ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
        torch.cuda.empty_cache()
        # Highlight only the characters beyond the original input text.
        ori_reply = [(ch, 'output') if i > len(input_text) else (ch, None) for i, ch in enumerate(ori_reply)]
        edit_reply = [(ch, 'output') if i > len(input_text) else (ch, None) for i, ch in enumerate(edit_reply)]
        return ori_reply, edit_reply

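# Paraphrase version of the comparison for the continuous-editing demo.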
def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    res1, res2 = continuous_generate(input_text, target_new=target_new, edit_alg=edit_alg)
    res3, res4 = continuous_generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
    return res1, res2, res3, res4
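

# Minimal usage sketch (illustrative only; the surrounding app presumably wires these
# functions to Gradio components elsewhere). Assumes the local ./models/gpt2 checkpoint
# and ./hparams/*/gpt2.yaml configs referenced above are available.
if __name__ == "__main__":
    demo_prompt, demo_target = continuous_examples[0]
    # Apply one GRACE edit (40 steps, lr 1.0, matching the warm-start settings above) ...
    edit('GRACE', demo_prompt, demo_target, num_steps=40, edit_lr=1.0)
    # ... then compare the original and edited models on the edited prompt.
    ori, edited = generate(demo_prompt, target_new=demo_target, edit_alg='GRACE')
    print('original:', ''.join(ch for ch, _ in ori))
    print('edited  :', ''.join(ch for ch, _ in edited))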