ZJUPeng committed on
Commit
d6682b6
1 Parent(s): de2b8c1

add continuous

Files changed (40)
  1. app.py +281 -71
  2. easyeditor/__pycache__/__init__.cpython-39.pyc +0 -0
  3. easyeditor/models/__init__.py +2 -0
  4. easyeditor/models/__pycache__/__init__.cpython-39.pyc +0 -0
  5. easyeditor/models/grace/GRACE.py +80 -59
  6. easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc +0 -0
  7. easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc +0 -0
  8. easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc +0 -0
  9. easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc +0 -0
  10. easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc +0 -0
  11. easyeditor/models/grace/__pycache__/utils.cpython-39.pyc +0 -0
  12. easyeditor/models/grace/grace_main.py +3 -4
  13. easyeditor/models/rome/README.md +12 -0
  14. easyeditor/models/rome/__init__.py +1 -0
  15. easyeditor/models/rome/compute_u.py +125 -0
  16. easyeditor/models/rome/compute_v.py +278 -0
  17. easyeditor/models/rome/layer_stats.py +198 -0
  18. easyeditor/models/rome/repr_tools.py +174 -0
  19. easyeditor/models/rome/rome_hparams.py +55 -0
  20. easyeditor/models/rome/rome_main.py +192 -0
  21. easyeditor/models/rome/tok_dataset.py +99 -0
  22. easyeditor/models/wise/.DS_Store +0 -0
  23. easyeditor/models/wise/WISE.py +466 -0
  24. easyeditor/models/wise/__init__.py +2 -0
  25. easyeditor/models/wise/merge/__init__.py +3 -0
  26. easyeditor/models/wise/merge/gta.py +113 -0
  27. easyeditor/models/wise/merge/linear.py +24 -0
  28. easyeditor/models/wise/merge/slerp.py +90 -0
  29. easyeditor/models/wise/merge/utils.py +45 -0
  30. easyeditor/models/wise/utils.py +213 -0
  31. easyeditor/models/wise/wise_hparams.py +56 -0
  32. easyeditor/models/wise/wise_main.py +38 -0
  33. easyeditor/util/__pycache__/__init__.cpython-39.pyc +0 -0
  34. easyeditor/util/__pycache__/hparams.cpython-39.pyc +0 -0
  35. easyeditor/util/__pycache__/logit_lens.cpython-39.pyc +0 -0
  36. easyeditor/util/__pycache__/nethook.cpython-39.pyc +0 -0
  37. hparams/GRACE/gpt2.yaml +1 -1
  38. hparams/ROME/gpt2.yaml +26 -0
  39. hparams/WISE/gpt2.yaml +27 -0
  40. utils.py +214 -23
app.py CHANGED
@@ -1,13 +1,20 @@
1
  import gradio as gr
2
  from utils import *
3
  from transformers import pipeline
4
-
5
- css = """
6
-
7
- """
8
 
9
  ori_model = None
10
  edit_model = None
11
  # input=None
12
 
13
  def slowly_reverse(word, progress=gr.Progress()):
@@ -20,91 +27,199 @@ def slowly_reverse(word, progress=gr.Progress()):
20
  new_string = letter + new_string
21
  return new_string
22
 
23
- with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
24
- with gr.Row(equal_height=True):
25
- gr.HTML(
26
- """
27
- <div style="display: flex; flex-direction: column; align-items: center;">
28
- <h1>🔧EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models</h1>
29
-
30
- <p>
31
- 📑[<a href="https://huggingface.co/papers/2308.07269">Paper</a>]
32
- 👨‍💻[<a href="https://github.com/zjunlp/EasyEdit" target="_blank"><span class="icon"><i class="fab fa-github"></i></span>Code</a>]
33
- 📄[<a href="https://zjunlp.gitbook.io/easyedit">Docs</a>]
34
- 🤗[<a href="https://huggingface.co/spaces/zjunlp/EasyEdit" target="_blank">Demo</a>]
35
- [<a href="https://arxiv.org/abs/2211.11031">via GRACE</a>]
36
- </p>
37
- </div>
38
- """
39
- )
40
- # gr.HTML("""<div style="text-align: center; margin: 0 auto;"><p><h1> Knowledge Editing</h1></div>""")
41
-
42
  # with gr.Row():
43
- # gr.Markdown("<p align='center'><a href='https://github.com/zjunlp/EasyEdit'>🔧https://github.com/zjunlp/EasyEdit</a></p>")
44
-
45
  with gr.Row():
46
- gr.Markdown("#### Knowledge editing aims to subtly inject/edit updated knowledge or adjust undesirable behaviors, while minimizing the impact on unrelated inputs.")
47
- with gr.Accordion("Expiation", open=False):
48
- gr.Markdown(
49
  """
50
- Edit Steps: the number of times a layer is trained in the GRACE method.
51
  """
52
  )
53
- gr.Markdown(
 
54
  """
55
- Replacement: the optimization strategy during fine-tuning.
56
  """
57
  )
58
- gr.Markdown(
59
- """
60
- Reliability Evaluation: the optimization strategy during fine-tuning.
61
- """
62
- )
63
- gr.Markdown(
64
  """
65
- Reliability Evaluation: the assessment of whether the target edit can be accomplished.
66
  """
67
- )
68
- gr.Markdown(
69
  """
70
- Locality Evaluation: the assessment of whether unrelated content has been affected..
71
  """
72
- )
73
-
74
  with gr.Row():
75
  prompt = gr.Textbox(label="Edit Prompt")
76
  target_new = gr.Textbox(label="Edit Target New")
77
  with gr.Row():
 
78
  num_steps = gr.Slider(10, 100, value=40, step=1, label='Edit Steps')
79
- replacement = gr.Dropdown(
80
- choices=["replace_last", "replace_all", "replace_prompt"],
81
- value="replace_last",
82
- label="Replacement",
83
  )
84
- with gr.Row():
85
- button4clear = gr.Button("Clear")
86
- button4edit = gr.Button("Edit",variant="primary")
87
  with gr.Row():
88
  examples = gr.Examples(
89
  examples=[
90
- ["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"],
91
- ["Who is Claire Clairmont\'s sister?","Clairmont-Mayer"],
92
- ["Which fictional universe is Chlorophyll Kid part of?","Image Universe"]
 
93
  ],
94
- examples_per_page=3,
95
  inputs=[prompt,target_new],
96
  )
 
 
97
  # with gr.Row():
98
  # input_text = gr.Textbox(label="Status Information",value="Model editing may take about a minute, please be patient.")
99
  with gr.Row():
100
  gr.HTML(
101
  """
102
- <h3>Reliability Evaluation</h3>
103
  """
104
  )
105
  with gr.Row():
106
- input = gr.Textbox(label="Input Text")
107
  with gr.Row():
 
108
  with gr.Column():
109
  button4gen_ori=gr.HighlightedText(
110
  label="original output",
@@ -119,25 +234,65 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
119
  show_legend=False,
120
  color_map={"output": "yellow"},
121
  )
122
  with gr.Row():
123
  button4gen = gr.Button("Generate",variant="primary")
124
 
125
  with gr.Row():
126
  gr.HTML(
127
  """
128
- <h3>Locality Evaluation</h3>
129
  """
130
  )
131
  with gr.Row():
132
  loc_input = gr.Dropdown(
133
  choices=[
134
- "who sang the theme song for laverne and shirley",
135
- "when does the last episode of adventure time air",
136
- "who plays alec ramsay in the black stallion",
137
- "where did an independence movement occur because of the congress of vienna",
138
- "where is the ucla usc game being played"
139
  ],
140
- value="where is the ucla usc game being played",
141
  label="Unrelated Input Text",
142
  )
143
  with gr.Row():
@@ -158,20 +313,76 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
158
  with gr.Row():
159
  button4locgen = gr.Button("Generate",variant="primary")
160
 
161
- button4clear.click(lambda: ("", ""), outputs=[prompt,target_new])
162
- button4edit.click(fn=edit, inputs=[prompt,target_new, num_steps, replacement], outputs=input)
163
- button4gen.click(fn=generate, inputs=[input, target_new], outputs=[button4gen_ori, button4gen_edit])
164
- button4locgen.click(fn=generate, inputs=loc_input, outputs=[button4gen_loc_ori, button4gen_loc_edit])
165
 
166
167
  with gr.Accordion("Citation", open=False):
168
  gr.Markdown(
169
  """
170
  ```bibtex
171
- @misc{wang2023easyedit,
172
  title={EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models},
173
- author={Peng Wang and Ningyu Zhang and Xin Xie and Yunzhi Yao and Bozhong Tian and Mengru Wang and Zekun Xi and Siyuan Cheng and Kangwei Liu and Guozhou Zheng and Huajun Chen},
174
- year={2023},
175
  eprint={2308.07269},
176
  archivePrefix={arXiv},
177
  primaryClass={cs.CL}
@@ -180,5 +391,4 @@ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
180
  """
181
  )
182
 
183
-
184
  demo.launch()
 
1
  import gradio as gr
2
  from utils import *
3
  from transformers import pipeline
4
+ import random
5
+ import torch
6
+ import numpy as np
7
+ seed=0
8
+ random.seed(seed)
9
+ torch.manual_seed(seed)
10
+ np.random.seed(seed)
11
+ torch.cuda.manual_seed_all(seed)
12
 
13
  ori_model = None
14
  edit_model = None
15
+
16
+ css = '''
17
+ '''
18
  # input=None
19
 
20
  def slowly_reverse(word, progress=gr.Progress()):
 
27
  new_string = letter + new_string
28
  return new_string
29
 
30
+ def single_edit_tab():
31
+ with gr.Row():
32
+ prompt = gr.Textbox(label="Edit Prompt")
33
+ target_new = gr.Textbox(label="Edit Target New")
34
+ with gr.Row():
35
+ edit_alg = gr.Dropdown(
36
+ choices=['ROME', 'WISE', 'GRACE'],
37
+ value='WISE',
38
+ label="Edit Algorithm",
39
+ )
40
+ num_steps = gr.Slider(10, 100, value=40, step=1, label='Edit Steps')
41
+ edit_lr = gr.Dropdown(
42
+ choices=[0.1, 0.5, 1.0],
43
+ value=1.0,
44
+ label="Edit LR (learning rate)",
45
+ )
46
+ with gr.Row():
47
+ examples = gr.Examples(
48
+ examples=[
49
+ ["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"],
50
+ ["What company makes Springfield Armory XDM?","Messerschmitt"],
51
+ ["Which fictional universe is Chlorophyll Kid part of?","Image Universe"]
52
+ ],
53
+ examples_per_page=3,
54
+ inputs=[prompt,target_new],
55
+ )
56
+ with gr.Row():
57
+ button4clear = gr.Button("Clear")
58
+ button4edit = gr.Button("Edit",variant="primary")
59
  # with gr.Row():
60
+ # input_text = gr.Textbox(label="Status Information",value="Model editing may take about a minute, please be patient.")
 
61
  with gr.Row():
62
+ gr.HTML(
 
 
63
  """
64
+ <h3>Evaluation</h3>
65
  """
66
  )
67
+ with gr.Row():
68
+ gr.HTML(
69
  """
70
+ <h4>Reliability</h4>
71
  """
72
  )
73
+ # with gr.Row():
74
+ # input = gr.Textbox(label="Input Text")
75
+ # target = gr.Textbox(label="Input Answer", visible=False)
76
+ target = gr.Textbox(label="Answer", visible=False)
77
+ with gr.Row():
78
+ input = gr.Textbox(label="Edit Prompt")
79
+ with gr.Column():
80
+ button4gen_ori=gr.HighlightedText(
81
+ label="original output",
82
+ combine_adjacent=True,
83
+ show_legend=False,
84
+ color_map={"output": "yellow"},
85
+ )
86
+ with gr.Column():
87
+ button4gen_edit=gr.HighlightedText(
88
+ label="edited output",
89
+ combine_adjacent=True,
90
+ show_legend=False,
91
+ color_map={"output": "yellow"},
92
+ )
93
+ with gr.Row():
94
+ gr.HTML(
95
  """
96
+ <h4>Generalization</h4>
97
  """
98
+ )
99
+ with gr.Row():
100
+ para_input = gr.Textbox(label="Paraphrase Prompt")
101
+ with gr.Column():
102
+ button4gen_para_ori=gr.HighlightedText(
103
+ label="original output",
104
+ combine_adjacent=True,
105
+ show_legend=False,
106
+ color_map={"output": "blue"},
107
+ )
108
+ with gr.Column():
109
+ button4gen_para_edit=gr.HighlightedText(
110
+ label="edited output",
111
+ combine_adjacent=True,
112
+ show_legend=False,
113
+ color_map={"output": "blue"},
114
+ )
115
+ with gr.Row():
116
+ examples = gr.Examples(
117
+ examples=[
118
+ ["Who is the architect for Toodyay Fire Station?", "Who was responsible for the planning of the Toodyay Fire Station", "Wong Tung & Sons"],
119
+ ["What company makes Springfield Armory XDM?", "Which company produced Springfield Armory XDM?", "Messerschmitt"],
120
+ ["Which fictional universe is Chlorophyll Kid part of?", "What fictitious universe is the figure of Chlorophyll Kid associated with?", "Image Universe"]
121
+ ],
122
+ examples_per_page=3,
123
+ inputs=[input, para_input, target],
124
+ label='Evaluation Examples'
125
+ )
126
+ with gr.Row():
127
+ button4gen = gr.Button("Generate",variant="primary")
128
+
129
+ with gr.Row():
130
+ gr.HTML(
131
  """
132
+ <h4>Locality</h4>
133
  """
134
+ )
135
+ with gr.Row():
136
+ loc_input = gr.Dropdown(
137
+ choices=[
138
+ "nq question: where does the phrase good bye felicia come from",
139
+ "nq question: which best describes timbuktu under the mali empire",
140
+ "nq question: where do the question marks go in spanish",
141
+ "nq question: who replaces the vice president in the senate",
142
+ "nq question: active transport performs which function in a cell"
143
+ ],
144
+ value="nq question: which best describes timbuktu under the mali empire",
145
+ label="Unrelated Input Text",
146
+ )
147
+ with gr.Row():
148
+ with gr.Column():
149
+ button4gen_loc_ori=gr.HighlightedText(
150
+ label="original output",
151
+ combine_adjacent=True,
152
+ show_legend=False,
153
+ color_map={"output": "green"},
154
+ )
155
+ with gr.Column():
156
+ button4gen_loc_edit=gr.HighlightedText(
157
+ label="edited output",
158
+ combine_adjacent=True,
159
+ show_legend=False,
160
+ color_map={"output": "green"},
161
+ )
162
+ with gr.Row():
163
+ button4locgen = gr.Button("Generate",variant="primary")
164
+
165
+ button4clear.click(fn=clear, outputs=[prompt,target_new])
166
+ button4edit.click(fn=edit, inputs=[edit_alg, prompt,target_new, num_steps, edit_lr], outputs=[input, target])
167
+ button4gen.click(fn=union_generate, inputs=[input, para_input, target, edit_alg], outputs=[button4gen_ori, button4gen_edit, button4gen_para_ori, button4gen_para_edit])
168
+ # button4gen.click(fn=generate, inputs=[para_input, target, edit_alg], outputs=[button4gen_para_ori, button4gen_para_edit])
169
+ button4locgen.click(fn=generate, inputs=loc_input, outputs=[button4gen_loc_ori, button4gen_loc_edit])
170
+
171
+ def continuous_edit_tab():
172
+ with gr.Row():
173
+ # edit_alg = gr.Dropdown(
174
+ # choices=['WISE', 'GRACE'],
175
+ # value='WISE',
176
+ # label="Edit Algorithm",
177
+ # )
178
+ edit_alg = gr.Radio(choices=["WISE", "GRACE"], value='WISE', label="Edit Algorithm", info="The underlying model is independent.")
179
  with gr.Row():
180
  prompt = gr.Textbox(label="Edit Prompt")
181
  target_new = gr.Textbox(label="Edit Target New")
182
  with gr.Row():
183
+
184
  num_steps = gr.Slider(10, 100, value=40, step=1, label='Edit Steps')
185
+ edit_lr = gr.Dropdown(
186
+ choices=[0.1, 0.5, 1.0],
187
+ value=1.0,
188
+ label="Edit LR (learning rate)",
189
  )
190
  with gr.Row():
191
  examples = gr.Examples(
192
  examples=[
193
+ ["What is the date of birth for Christoph von Stadion?", "12 April 1809"],
194
+ ["What medical condition killed Ramesses V?", "esses IV"],
195
+ ["What voice type is Nellie Briercliffe?", "mezzo-oprano"],
196
+ ["What network is 1000 Ways to Die associated with?", "The CW"]
197
  ],
198
+ examples_per_page=4,
199
  inputs=[prompt,target_new],
200
  )
201
+ with gr.Row():
202
+ button4edit = gr.Button("Edit",variant="primary")
203
  # with gr.Row():
204
  # input_text = gr.Textbox(label="Status Information",value="Model editing may take about a minute, please be patient.")
205
  with gr.Row():
206
  gr.HTML(
207
  """
208
+ <h3>Evaluation</h3>
209
  """
210
  )
211
  with gr.Row():
212
+ gr.HTML(
213
+ """
214
+ <h4>Reliability</h4>
215
+ """
216
+ )
217
+ # with gr.Row():
218
+ # input = gr.Textbox(label="Input Text")
219
+ # target = gr.Textbox(label="Input Answer", visible=False)
220
+ target = gr.Textbox(label="Answer", visible=False)
221
  with gr.Row():
222
+ input = gr.Textbox(label="Edit Prompt")
223
  with gr.Column():
224
  button4gen_ori=gr.HighlightedText(
225
  label="original output",
 
234
  show_legend=False,
235
  color_map={"output": "yellow"},
236
  )
237
+ with gr.Row():
238
+ gr.HTML(
239
+ """
240
+ <h4>Generalization</h4>
241
+ """
242
+ )
243
+ with gr.Row():
244
+ para_input = gr.Textbox(label="Paraphrase Prompt")
245
+ with gr.Column():
246
+ button4gen_para_ori=gr.HighlightedText(
247
+ label="original output",
248
+ combine_adjacent=True,
249
+ show_legend=False,
250
+ color_map={"output": "blue"},
251
+ )
252
+ with gr.Column():
253
+ button4gen_para_edit=gr.HighlightedText(
254
+ label="edited output",
255
+ combine_adjacent=True,
256
+ show_legend=False,
257
+ color_map={"output": "blue"},
258
+ )
259
+ with gr.Row():
260
+ examples = gr.Examples(
261
+ examples=[
262
+ ["Who is the architect for Toodyay Fire Station?", "Who was responsible for the planning of the Toodyay Fire Station", "Wong Tung & Sons"],
263
+ ["What company makes Springfield Armory XDM?", "Which company produced Springfield Armory XDM?", "Messerschmitt"],
264
+ ["Which fictional universe is Chlorophyll Kid part of?", "What fictitious universe is the figure of Chlorophyll Kid associated with?", "Image Universe"],
265
+ ["What year did Sunnyside Hospital cease to exist?", "What year was the end of Sunnyside Hospital?", "1962"],
266
+ ["Which designer was responsible for Holmenkollen Chapel?", "Which designer is responsible for Holmenkollen Chapel?", "Inigo Jones"],
267
+ ["What piece of fiction does Jack Harkness appear in?", "What fictional work does Jack Harkness exist in?", "Lost"],
268
+ ["What is the date of birth for Christoph von Stadion?", "What is Christoph von Stadion's birth date?", "12 April 1809"],
269
+ ["What medical condition killed Ramesses V?", "What kind of disease killed Ramesses V?", "esses IV"],
270
+ ["What voice type is Nellie Briercliffe?", "Which was the voice type that Nellie Briercliffe had?", "mezzo-oprano"],
271
+ ["What network is 1000 Ways to Die associated with?", "The show 1000 Ways to Die was originally broadcast in which network?", "The CW"]
272
+ ],
273
+ examples_per_page=10,
274
+ inputs=[input, para_input, target],
275
+ label='Evaluation Examples'
276
+ )
277
  with gr.Row():
278
  button4gen = gr.Button("Generate",variant="primary")
279
 
280
  with gr.Row():
281
  gr.HTML(
282
  """
283
+ <h4>Locality</h4>
284
  """
285
  )
286
  with gr.Row():
287
  loc_input = gr.Dropdown(
288
  choices=[
289
+ "nq question: where does the phrase good bye felicia come from",
290
+ "nq question: which best describes timbuktu under the mali empire",
291
+ "nq question: where do the question marks go in spanish",
292
+ "nq question: who replaces the vice president in the senate",
293
+ "nq question: active transport performs which function in a cell"
294
  ],
295
+ value="nq question: which best describes timbuktu under the mali empire",
296
  label="Unrelated Input Text",
297
  )
298
  with gr.Row():
 
313
  with gr.Row():
314
  button4locgen = gr.Button("Generate",variant="primary")
315
 
316
+ button4edit.click(fn=continuous_edit, inputs=[edit_alg, prompt,target_new, num_steps, edit_lr], outputs=[input, target])
317
+ button4gen.click(fn=continuous_union_generate, inputs=[input, para_input, target, edit_alg], outputs=[button4gen_ori, button4gen_edit, button4gen_para_ori, button4gen_para_edit])
318
+ # button4gen.click(fn=generate, inputs=[para_input, target, edit_alg], outputs=[button4gen_para_ori, button4gen_para_edit])
319
+ button4locgen.click(fn=continuous_generate, inputs=[loc_input, edit_alg], outputs=[button4gen_loc_ori, button4gen_loc_edit])
320
 
321
 
322
+ with gr.Blocks(css=css,theme=gr.themes.Soft(text_size="sm")) as demo:
323
+ with gr.Row(equal_height=True):
324
+ gr.HTML(
325
+ """
326
+ <div style="display: flex; flex-direction: column; align-items: center;">
327
+ <h1>🔧EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models</h1>
328
+
329
+ <p>
330
+ 📑[<a href="https://huggingface.co/papers/2308.07269">Paper</a>]
331
+ 👨‍💻[<a href="https://github.com/zjunlp/EasyEdit" target="_blank"><span class="icon"><i class="fab fa-github"></i></span>Code</a>]
332
+ 📄[<a href="https://zjunlp.gitbook.io/easyedit">Docs</a>]
333
+ 🤗[<a href="https://huggingface.co/spaces/zjunlp/EasyEdit" target="_blank">Demo</a>]
334
+ </p>
335
+ </div>
336
+ """
337
+ )
338
+ # gr.HTML("""<div style="text-align: center; margin: 0 auto;"><p><h1> Knowledge Editing</h1></div>""")
339
+
340
+ # with gr.Row():
341
+ # gr.Markdown("<p align='center'><a href='https://github.com/zjunlp/EasyEdit'>🔧https://github.com/zjunlp/EasyEdit</a></p>")
342
+
343
+ with gr.Row():
344
+ gr.Markdown("#### Knowledge editing aims to subtly inject/edit updated knowledge or adjust undesirable behaviors, while minimizing the impact on unrelated inputs.")
345
+ with gr.Accordion("Explanation", open=False):
346
+ gr.Markdown(
347
+ """
348
+ Edit Steps: the number of times a layer is trained in the editing method.
349
+ """
350
+ )
351
+ gr.Markdown(
352
+ """
353
+ Edit LR (learning rate): the optimization strategy during fine-tuning.
354
+ """
355
+ )
356
+ gr.Markdown(
357
+ """
358
+ Reliability Evaluation: the optimization strategy during fine-tuning.
359
+ """
360
+ )
361
+ gr.Markdown(
362
+ """
363
+ Reliability Evaluation: the assessment of whether the target edit can be accomplished.
364
+ """
365
+ )
366
+ gr.Markdown(
367
+ """
368
+ Locality Evaluation: the assessment of whether unrelated content has been affected..
369
+ """
370
+ )
371
+
372
+ with gr.Tab("Single Knowledge Editing"):
373
+ single_edit_tab()
374
+
375
+ with gr.Tab("Continuous Knowledge Editing"):
376
+ continuous_edit_tab()
377
+
378
  with gr.Accordion("Citation", open=False):
379
  gr.Markdown(
380
  """
381
  ```bibtex
382
+ @misc{wang2024easyedit,
383
  title={EasyEdit: An Easy-to-use Knowledge Editing Framework for Large Language Models},
384
+ author={Peng Wang and Ningyu Zhang and Bozhong Tian and Zekun Xi and Yunzhi Yao and Ziwen Xu and Mengru Wang and Shengyu Mao and Xiaohan Wang and Siyuan Cheng and Kangwei Liu and Yuansheng Ni and Guozhou Zheng and Huajun Chen},
385
+ year={2024},
386
  eprint={2308.07269},
387
  archivePrefix={arXiv},
388
  primaryClass={cs.CL}
 
391
  """
392
  )
393
 
 
394
  demo.launch()
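The tab-composition pattern that the new `app.py` introduces (one builder function per tab, composed inside `gr.Blocks`) can be illustrated in isolation. The sketch below is a minimal, self-contained example: `run_edit` is a placeholder callback, not the real `edit`/`continuous_edit` functions from `utils.py`.

```python
import gradio as gr

# Placeholder callback standing in for the real editing logic in utils.py (hypothetical).
def run_edit(prompt, target_new):
    return f"Would edit: '{prompt}' -> '{target_new}'"

def editing_tab():
    # One builder function per tab, mirroring single_edit_tab / continuous_edit_tab above.
    with gr.Row():
        prompt = gr.Textbox(label="Edit Prompt")
        target_new = gr.Textbox(label="Edit Target New")
    status = gr.Textbox(label="Status")
    gr.Button("Edit", variant="primary").click(fn=run_edit, inputs=[prompt, target_new], outputs=status)

with gr.Blocks(theme=gr.themes.Soft(text_size="sm")) as sketch_demo:
    with gr.Tab("Single Knowledge Editing"):
        editing_tab()
    with gr.Tab("Continuous Knowledge Editing"):
        editing_tab()

# sketch_demo.launch()
```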
easyeditor/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (182 Bytes)
 
easyeditor/models/__init__.py CHANGED
@@ -1 +1,3 @@
1
  from .grace import *
 
 
 
1
  from .grace import *
2
+ from .wise import *
3
+ from .rome import *
easyeditor/models/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (172 Bytes)
 
easyeditor/models/grace/GRACE.py CHANGED
@@ -29,13 +29,13 @@
29
  # layer = config.inner_params[0]
30
  # self.device = device
31
 
32
- # # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
33
  # suffixes = [".weight", ".bias"]
34
  # self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
35
-
36
  # for n, p in self.model.named_parameters():
37
  # p.requires_grad = False
38
-
39
  # if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
40
  # transpose = False
41
  # else:
@@ -48,32 +48,32 @@
48
 
49
  # if type(original_layer) is not GRACEAdapter:
50
  # setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
51
-
52
  # def __call__(self, **kwargs):
53
  # # if self.config.task == "hallucination":
54
  # # print(kwargs)
55
  # # key_id = (kwargs["labels"] == -100).sum() - 1
56
  # # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
57
  # return self.model(**kwargs)
58
-
59
  # def generate(self, *args, **kwargs):
60
  # setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
61
  # return self.model.generate(*args, **kwargs)
62
-
63
  # def edit(self, config, tokens):
64
  # key_id = (tokens["labels"] == -100).sum() - 1
65
  # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
66
-
67
  # # --- pass edit label, training mode, and key_id into GRACE ---
68
  # setattr(eval(f"self.model.{self.layer}"), "training", True)
69
  # setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
70
-
71
  # self.losses = []
72
  # # --- train GRACE value ---
73
  # for i in range(config.n_iter):
74
  # # --- insert iteration into each layer (only initiate keys on iteration 1) ---
75
  # setattr(eval(f"self.model.{self.layer}"), "iter", i)
76
-
77
  # # --- pass tokens through model (including through the GRACE layer) ---
78
  # outputs = self.model(**tokens)
79
  # if i == 0:
@@ -84,14 +84,14 @@
84
  # optimizer.step()
85
  # optimizer.zero_grad()
86
  # self.losses.append(loss.detach().cpu().numpy())
87
-
88
  # self.loss = loss # Log final loss
89
 
90
  # # --- pull out info we want to log from the GRACE layer ---
91
  # setattr(eval(f"self.model.{self.layer}"), "training", False)
92
  # chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
93
  # nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
94
-
95
  # self.log_dict["chosen_key"] = chosen_key
96
  # self.log_dict["nkeys"] = nkeys
97
 
@@ -109,7 +109,7 @@
109
  # self.num_pert = config.num_pert
110
  # self.key_id = -1
111
  # self.ensure_replace_token_loc = False
112
-
113
  # if transpose:
114
  # self.key_shape = layer.weight.shape[1]
115
  # self.value_shape = layer.weight.shape[0]
@@ -142,7 +142,7 @@
142
  # def split_epsilons_in_half(self, nearest_key, smallest_distance):
143
  # self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
144
  # self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
145
-
146
  # def forward(self, *args):
147
  # # Run layer forward and save what it would have returned for this instance
148
  # layer_out = self.layer(*args)
@@ -176,7 +176,7 @@
176
  # smallest_distance, nearest_key = dists.min(0)
177
 
178
  # if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
179
- # # If there's no close key, make a new key
180
  # self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
181
  # else:
182
  # # If there is a close key, we need to handle conflicts
@@ -222,23 +222,27 @@ import torch
222
  from .utils import parent_module, brackets_to_periods
223
  import transformers
224
  import os
 
225
  os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
226
 
 
227
  def euc(query, key):
228
  # Euclidean distance
229
  if len(key.shape) < 2:
230
  key = key.view(1, -1)
231
  return torch.cdist(key, query, p=2)
232
 
 
233
  def perturb_values(chosen_value, num_pert, device):
234
  # Create a bunch of noised versions of the value, then create batch, then train value
235
  chosen_value = chosen_value
236
  noise = torch.normal(0, 1, chosen_value.shape, device=device)
237
- noise[0] = noise[0]*0
238
  noise.requires_grad = True
239
  chosen_value = chosen_value + noise
240
  return chosen_value
241
 
 
242
  class GRACE(torch.nn.Module):
243
  def __init__(self, config, model, device):
244
  super(GRACE, self).__init__()
@@ -251,26 +255,27 @@ class GRACE(torch.nn.Module):
251
  self.device = device
252
  self.original_layer = None
253
 
254
- # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
255
  suffixes = [".weight", ".bias"]
256
  self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
257
-
258
  for n, p in self.model.named_parameters():
259
  p.requires_grad = False
260
-
261
  if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
262
  transpose = False
263
  else:
264
  transpose = True
265
 
266
  # --- Add GRACE to chosen layers ---
267
- edit_module = parent_module(self.model, brackets_to_periods(self.layer))
268
- layer_name = self.layer.rsplit(".", 1)[-1]
269
- original_layer = getattr(edit_module, layer_name)
270
  if type(original_layer) is not GRACEAdapter:
271
- setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
 
272
  self.original_layer = copy.deepcopy(original_layer)
273
-
274
  def __call__(self, **kwargs):
275
  # if self.config.task == "hallucination":
276
  # print(kwargs)
@@ -278,55 +283,65 @@ class GRACE(torch.nn.Module):
278
  # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
279
  return self.model(**kwargs)
280
 
281
  def reset_layer(self):
282
- layer_name = self.layer.rsplit(".", 1)[-1]
283
- edit_module = parent_module(self.model, brackets_to_periods(self.layer))
284
- setattr(edit_module, layer_name, self.original_layer.to(self.device))
285
 
286
  def generate(self, *args, **kwargs):
287
  setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
288
  return self.model.generate(*args, **kwargs)
289
-
290
  def edit(self, config, tokens):
291
  key_id = (tokens["labels"] == -100).sum() - 1
292
  setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
293
-
294
  # --- pass edit label, training mode, and key_id into GRACE ---
295
  setattr(eval(f"self.model.{self.layer}"), "training", True)
296
  setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
297
-
298
  self.losses = []
299
  # --- train GRACE value ---
300
  for i in range(config.n_iter):
301
  # --- insert iteration into each layer (only initiate keys on iteration 1) ---
302
  setattr(eval(f"self.model.{self.layer}"), "iter", i)
303
-
304
  # --- pass tokens through model (including through the GRACE layer) ---
305
  outputs = self.model(**tokens)
306
  if i == 0:
307
  # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
308
  optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
309
  loss = outputs.loss
310
- loss.backward()
311
- optimizer.step()
312
- optimizer.zero_grad()
313
- self.losses.append(loss.detach().cpu().numpy())
314
-
315
- self.loss = loss # Log final loss
316
 
317
  # --- pull out info we want to log from the GRACE layer ---
318
  setattr(eval(f"self.model.{self.layer}"), "training", False)
319
  chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
320
  nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
321
-
322
- self.log_dict["chosen_key"] = chosen_key
323
  self.log_dict["nkeys"] = nkeys
324
 
 
325
  class GRACEAdapter(torch.nn.Module):
326
  def __init__(self, config, layer, transpose):
327
  super(GRACEAdapter, self).__init__()
328
 
329
  self.layer = layer
 
330
  self.weight = self.layer.weight
331
  self.init_epsilon = config.eps
332
  self.dist_fn = config.dist_fn
@@ -335,8 +350,7 @@ class GRACEAdapter(torch.nn.Module):
335
  self.config = config
336
  self.num_pert = config.num_pert
337
  self.key_id = -1
338
- self.ensure_replace_token_loc = False
339
-
340
  if transpose:
341
  self.key_shape = layer.weight.shape[1]
342
  self.value_shape = layer.weight.shape[0]
@@ -346,14 +360,15 @@ class GRACEAdapter(torch.nn.Module):
346
  self.training = False
347
 
348
  def add_key(self, new_key, new_value):
349
- keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys
350
 
351
- values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True) # Add new value to list of values
 
352
 
353
  new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
354
- epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons
355
 
356
- key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels
357
 
358
  return keys, values, epsilons, key_labels
359
 
@@ -367,9 +382,9 @@ class GRACEAdapter(torch.nn.Module):
367
  return edit_label.float().mean() == key_label.float().mean()
368
 
369
  def split_epsilons_in_half(self, nearest_key, smallest_distance):
370
- self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
371
- self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
372
-
373
  def forward(self, *args):
374
  # Run layer forward and save what it would have returned for this instance
375
  layer_out = self.layer(*args)
@@ -380,13 +395,15 @@ class GRACEAdapter(torch.nn.Module):
380
  # print(self.__dict__)
381
  return layer_out
382
  else:
383
- if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
384
- token_to_edit = args[0].shape[1]-1
385
- self.key_id = args[0].shape[1]-1
386
- self.ensure_replace_token_loc = True
 
 
387
  else:
388
- token_to_edit = min(self.key_id, args[0].shape[1]-1) # args[0].shape[1] - 1 is sequence length
389
- query = args[0][:, token_to_edit, :] # Just use activation for last token
390
  if self.config.val_init == "cold":
391
  new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
392
  elif self.config.val_init == "warm":
@@ -403,7 +420,7 @@ class GRACEAdapter(torch.nn.Module):
403
  smallest_distance, nearest_key = dists.min(0)
404
 
405
  if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
406
- # If there's no close key, make a new key
407
  self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
408
  else:
409
  # If there is a close key, we need to handle conflicts
@@ -413,11 +430,13 @@ class GRACEAdapter(torch.nn.Module):
413
  else:
414
  # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
415
  if smallest_distance > self.epsilons[nearest_key]:
416
- if self.config.eps_expand== "coverage":
417
- self.epsilons[nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
 
418
  elif self.config.eps_expand == "moving_average":
419
  a = 0.5
420
- self.keys[nearest_key] = a*self.keys[nearest_key] + (1-a)*query # Move old key to be halfway between
 
421
  self.epsilons[nearest_key] = smallest_distance
422
  # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
423
  else:
@@ -435,11 +454,13 @@ class GRACEAdapter(torch.nn.Module):
435
  chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
436
 
437
  if self.replacement == "replace_all":
438
- layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
 
439
  elif self.replacement == "replace_last":
440
  layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
441
  elif self.replacement == "replace_prompt":
442
- layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
 
443
  else:
444
  print("token replacement choice not found")
445
  return layer_out
 
29
  # layer = config.inner_params[0]
30
  # self.device = device
31
 
32
+ # # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
33
  # suffixes = [".weight", ".bias"]
34
  # self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
35
+
36
  # for n, p in self.model.named_parameters():
37
  # p.requires_grad = False
38
+
39
  # if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
40
  # transpose = False
41
  # else:
 
48
 
49
  # if type(original_layer) is not GRACEAdapter:
50
  # setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
51
+
52
  # def __call__(self, **kwargs):
53
  # # if self.config.task == "hallucination":
54
  # # print(kwargs)
55
  # # key_id = (kwargs["labels"] == -100).sum() - 1
56
  # # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
57
  # return self.model(**kwargs)
58
+
59
  # def generate(self, *args, **kwargs):
60
  # setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
61
  # return self.model.generate(*args, **kwargs)
62
+
63
  # def edit(self, config, tokens):
64
  # key_id = (tokens["labels"] == -100).sum() - 1
65
  # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
66
+
67
  # # --- pass edit label, training mode, and key_id into GRACE ---
68
  # setattr(eval(f"self.model.{self.layer}"), "training", True)
69
  # setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
70
+
71
  # self.losses = []
72
  # # --- train GRACE value ---
73
  # for i in range(config.n_iter):
74
  # # --- insert iteration into each layer (only initiate keys on iteration 1) ---
75
  # setattr(eval(f"self.model.{self.layer}"), "iter", i)
76
+
77
  # # --- pass tokens through model (including through the GRACE layer) ---
78
  # outputs = self.model(**tokens)
79
  # if i == 0:
 
84
  # optimizer.step()
85
  # optimizer.zero_grad()
86
  # self.losses.append(loss.detach().cpu().numpy())
87
+
88
  # self.loss = loss # Log final loss
89
 
90
  # # --- pull out info we want to log from the GRACE layer ---
91
  # setattr(eval(f"self.model.{self.layer}"), "training", False)
92
  # chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
93
  # nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
94
+
95
  # self.log_dict["chosen_key"] = chosen_key
96
  # self.log_dict["nkeys"] = nkeys
97
 
 
109
  # self.num_pert = config.num_pert
110
  # self.key_id = -1
111
  # self.ensure_replace_token_loc = False
112
+
113
  # if transpose:
114
  # self.key_shape = layer.weight.shape[1]
115
  # self.value_shape = layer.weight.shape[0]
 
142
  # def split_epsilons_in_half(self, nearest_key, smallest_distance):
143
  # self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
144
  # self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
145
+
146
  # def forward(self, *args):
147
  # # Run layer forward and save what it would have returned for this instance
148
  # layer_out = self.layer(*args)
 
176
  # smallest_distance, nearest_key = dists.min(0)
177
 
178
  # if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
179
+ # # If there's no close key, make a new key
180
  # self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
181
  # else:
182
  # # If there is a close key, we need to handle conflicts
 
222
  from .utils import parent_module, brackets_to_periods
223
  import transformers
224
  import os
225
+
226
  os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
227
 
228
+
229
  def euc(query, key):
230
  # Euclidean distance
231
  if len(key.shape) < 2:
232
  key = key.view(1, -1)
233
  return torch.cdist(key, query, p=2)
234
 
235
+
236
  def perturb_values(chosen_value, num_pert, device):
237
  # Create a bunch of noised versions of the value, then create batch, then train value
238
  chosen_value = chosen_value
239
  noise = torch.normal(0, 1, chosen_value.shape, device=device)
240
+ noise[0] = noise[0] * 0
241
  noise.requires_grad = True
242
  chosen_value = chosen_value + noise
243
  return chosen_value
244
 
245
+
246
  class GRACE(torch.nn.Module):
247
  def __init__(self, config, model, device):
248
  super(GRACE, self).__init__()
 
255
  self.device = device
256
  self.original_layer = None
257
 
258
+ # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
259
  suffixes = [".weight", ".bias"]
260
  self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
261
+
262
  for n, p in self.model.named_parameters():
263
  p.requires_grad = False
264
+
265
  if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
266
  transpose = False
267
  else:
268
  transpose = True
269
 
270
  # --- Add GRACE to chosen layers ---
271
+ self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
272
+ self.layer_name = self.layer.rsplit(".", 1)[-1]
273
+ original_layer = getattr(self.edit_module, self.layer_name)
274
  if type(original_layer) is not GRACEAdapter:
275
+ setattr(self.edit_module, self.layer_name,
276
+ GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
277
  self.original_layer = copy.deepcopy(original_layer)
278
+
279
  def __call__(self, **kwargs):
280
  # if self.config.task == "hallucination":
281
  # print(kwargs)
 
283
  # setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
284
  return self.model(**kwargs)
285
 
286
+ def get_adapter_layer(self):
287
+ adapter_layer = getattr(self.edit_module, self.layer_name)
288
+ assert type(adapter_layer) is GRACEAdapter, print('Adapter Layer is not added correctly....')
289
+ return adapter_layer
290
+
291
  def reset_layer(self):
292
+ layer = getattr(self.edit_module, self.layer_name)
293
+ del layer
294
+ setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)
295
 
296
  def generate(self, *args, **kwargs):
297
  setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
298
  return self.model.generate(*args, **kwargs)
299
+
300
  def edit(self, config, tokens):
301
  key_id = (tokens["labels"] == -100).sum() - 1
302
  setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
303
+
304
  # --- pass edit label, training mode, and key_id into GRACE ---
305
  setattr(eval(f"self.model.{self.layer}"), "training", True)
306
  setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
307
+
308
  self.losses = []
309
  # --- train GRACE value ---
310
  for i in range(config.n_iter):
311
  # --- insert iteration into each layer (only initiate keys on iteration 1) ---
312
  setattr(eval(f"self.model.{self.layer}"), "iter", i)
313
+
314
  # --- pass tokens through model (including through the GRACE layer) ---
315
  outputs = self.model(**tokens)
316
  if i == 0:
317
  # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
318
  optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
319
  loss = outputs.loss
320
+ try:
321
+ loss.backward()
322
+ optimizer.step()
323
+ optimizer.zero_grad()
324
+ self.losses.append(loss.detach().cpu().numpy())
325
+ except Exception as e:
326
+ pass
327
+
328
+ self.loss = loss # Log final loss
329
 
330
  # --- pull out info we want to log from the GRACE layer ---
331
  setattr(eval(f"self.model.{self.layer}"), "training", False)
332
  chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
333
  nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
334
+
335
+ self.log_dict["chosen_key"] = chosen_key
336
  self.log_dict["nkeys"] = nkeys
337
 
338
+
339
  class GRACEAdapter(torch.nn.Module):
340
  def __init__(self, config, layer, transpose):
341
  super(GRACEAdapter, self).__init__()
342
 
343
  self.layer = layer
344
+ self.original_layer = copy.deepcopy(self.layer)
345
  self.weight = self.layer.weight
346
  self.init_epsilon = config.eps
347
  self.dist_fn = config.dist_fn
 
350
  self.config = config
351
  self.num_pert = config.num_pert
352
  self.key_id = -1
353
+
 
354
  if transpose:
355
  self.key_shape = layer.weight.shape[1]
356
  self.value_shape = layer.weight.shape[0]
 
360
  self.training = False
361
 
362
  def add_key(self, new_key, new_value):
363
+ keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys
364
 
365
+ values = torch.nn.Parameter(torch.vstack([self.values, new_value]),
366
+ requires_grad=True) # Add new value to list of values
367
 
368
  new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
369
+ epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons
370
 
371
+ key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels
372
 
373
  return keys, values, epsilons, key_labels
374
 
 
382
  return edit_label.float().mean() == key_label.float().mean()
383
 
384
  def split_epsilons_in_half(self, nearest_key, smallest_distance):
385
+ self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
386
+ self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
387
+
388
  def forward(self, *args):
389
  # Run layer forward and save what it would have returned for this instance
390
  layer_out = self.layer(*args)
 
395
  # print(self.__dict__)
396
  return layer_out
397
  else:
398
+ if not self.training:
399
+ if self.key_id == -1:
400
+ token_to_edit = args[0].shape[1] - 1
401
+ self.key_id = args[0].shape[1] - 1
402
+ else:
403
+ token_to_edit = min(self.key_id, args[0].shape[1] - 1)
404
  else:
405
+ token_to_edit = min(self.key_id, args[0].shape[1] - 1) # args[0].shape[1] - 1 is sequence length
406
+ query = args[0][:, token_to_edit, :] # Just use activation for last token
407
  if self.config.val_init == "cold":
408
  new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
409
  elif self.config.val_init == "warm":
 
420
  smallest_distance, nearest_key = dists.min(0)
421
 
422
  if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
423
+ # If there's no close key, make a new key
424
  self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
425
  else:
426
  # If there is a close key, we need to handle conflicts
 
430
  else:
431
  # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
432
  if smallest_distance > self.epsilons[nearest_key]:
433
+ if self.config.eps_expand == "coverage":
434
+ self.epsilons[
435
+ nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
436
  elif self.config.eps_expand == "moving_average":
437
  a = 0.5
438
+ self.keys[nearest_key] = a * self.keys[nearest_key] + (
439
+ 1 - a) * query # Move old key to be halfway between
440
  self.epsilons[nearest_key] = smallest_distance
441
  # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
442
  else:
 
454
  chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
455
 
456
  if self.replacement == "replace_all":
457
+ layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1),
458
+ chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
459
  elif self.replacement == "replace_last":
460
  layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
461
  elif self.replacement == "replace_prompt":
462
+ layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value,
463
+ layer_out[:, :token_to_edit])
464
  else:
465
  print("token replacement choice not found")
466
  return layer_out
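The core retrieval behaviour implemented by `GRACEAdapter.forward` above (compare the query activation against the stored keys and substitute the cached value when it falls inside a key's epsilon ball) can be summarised in a short sketch. The tensor names here are illustrative and this is not a drop-in replacement for the adapter.

```python
import torch

def grace_retrieve(query, keys, values, epsilons):
    # query: (d,) activation at the edited token; keys: (n, d); values: (n, d_out); epsilons: (n,)
    dists = torch.cdist(keys, query.view(1, -1), p=2).squeeze(1)  # distance to every stored key
    smallest, nearest = dists.min(0)
    if smallest <= epsilons[nearest]:
        # Inside the nearest key's epsilon ball: the edit fires and the cached value is returned.
        return values[nearest]
    # Otherwise the caller keeps the unedited layer output.
    return None
```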
easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc DELETED
Binary file (6.67 kB)
 
easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (350 Bytes)
 
easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc DELETED
Binary file (1.5 kB)
 
easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc DELETED
Binary file (1.23 kB)
 
easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc DELETED
Binary file (2.07 kB)
 
easyeditor/models/grace/__pycache__/utils.cpython-39.pyc DELETED
Binary file (3.54 kB)
 
easyeditor/models/grace/grace_main.py CHANGED
@@ -15,7 +15,7 @@ def apply_grace_to_model(
15
  requests: List[Dict],
16
  hparams: GraceHyperParams,
17
  num_steps: int,
18
- replacement: str,
19
  copy=False,
20
  return_orig_weights=False,
21
  keep_original_weight=False,
@@ -26,14 +26,13 @@ def apply_grace_to_model(
26
  model = deepcopy(model)
27
  weights_copy = {}
28
  device = torch.device('cpu')
29
- hparams.n_iter = num_steps
30
- hparams.replacement = replacement
31
  editor = GRACE(model=model, config=hparams, device=device)
32
 
33
  tokens = tokenize(request, tokenizer=tok, device=device)
34
  editor.edit(config=hparams, tokens=tokens)
35
 
36
- editor.to('cpu')
37
  gr.Info("Completed editing via GRACE!")
38
  return editor
39
 
 
15
  requests: List[Dict],
16
  hparams: GraceHyperParams,
17
  num_steps: int,
18
+ edit_lr: float,
19
  copy=False,
20
  return_orig_weights=False,
21
  keep_original_weight=False,
 
26
  model = deepcopy(model)
27
  weights_copy = {}
28
  device = torch.device('cpu')
29
+ hparams.edit_lr = edit_lr
 
30
  editor = GRACE(model=model, config=hparams, device=device)
31
 
32
  tokens = tokenize(request, tokenizer=tok, device=device)
33
  editor.edit(config=hparams, tokens=tokens)
34
 
35
+ # editor.to('cpu')
36
  gr.Info("Completed editing via GRACE!")
37
  return editor
38
 
easyeditor/models/rome/README.md ADDED
@@ -0,0 +1,12 @@
1
+ # ROME
2
+ This package provides a self-contained implementation of Rank-One Model Editing (ROME).
3
+
4
+ Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion.
5
+ * [`compute_u.py`](compute_u.py): Chooses a $u$ vector.
6
+ * [`compute_v.py`](compute_v.py): Choose a $v_*$ via optimization, then computes $v$.
7
+ * [`rome_main.py`](rome_main.py): Instruments main logic.
8
+ * [`rome_params.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`params.py`](../util/hparams.py) module.
9
+
10
+ For estimating second moment statistics of keys ($C = KK^T$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions.
11
+ * [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics.
12
+ * [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens.
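To make the three steps above concrete, here is a hedged sketch of how they combine into ROME's rank-one weight update. The tensor names (`W`, `C_inv`, `k_star`, `v_star`) are illustrative; in the actual code they come from `layer_stats.py`, `compute_u.py`, and `compute_v.py`.

```python
import torch

def rank_one_update(W, C_inv, k_star, v_star):
    # W: (d_out, d_in) MLP projection weight; k_star: (d_in,) key for the edited subject;
    # v_star: (d_out,) optimized target value; C_inv: (d_in, d_in) inverse key covariance.
    u = C_inv @ k_star                       # left vector, covariance-adjusted key (compute_u.py)
    u = u / u.norm()
    residual = v_star - W @ k_star           # what the layer must additionally output
    right = residual / torch.dot(u, k_star)  # right vector (compute_v.py)
    return W + torch.outer(right, u)         # rank-one edit: W' = W + right u^T
```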
easyeditor/models/rome/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome
easyeditor/models/rome/compute_u.py ADDED
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Dict, List
4
+
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+ from ..rome import repr_tools
9
+ from ...util.globals import *
10
+
11
+ from .layer_stats import layer_stats
12
+ from .rome_hparams import ROMEHyperParams
13
+
14
+ # Cache variables
15
+ inv_mom2_cache = {}
16
+
17
+
18
+ def get_inv_cov(
19
+ model: AutoModelForCausalLM,
20
+ tok: AutoTokenizer,
21
+ layer_name: str,
22
+ mom2_dataset: str,
23
+ mom2_n_samples: str,
24
+ mom2_dtype: str,
25
+ hparams=None,
26
+ ) -> torch.Tensor:
27
+ """
28
+ Retrieves covariance statistics, then computes the algebraic inverse.
29
+ Caches result for future use.
30
+ """
31
+
32
+ global inv_mom2_cache
33
+
34
+ model_name = model.config._name_or_path.replace("/", "_")
35
+ key = (model_name, layer_name)
36
+
37
+ if key not in inv_mom2_cache:
38
+ print(
39
+ f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
40
+ f"The result will be cached to avoid repetitive computation."
41
+ )
42
+ stat = layer_stats(
43
+ model,
44
+ tok,
45
+ layer_name,
46
+ hparams.stats_dir,
47
+ mom2_dataset,
48
+ to_collect=["mom2"],
49
+ sample_size=mom2_n_samples,
50
+ precision=mom2_dtype,
51
+ hparams=hparams
52
+ )
53
+ inv_mom2_cache[key] = torch.inverse(
54
+ stat.mom2.moment().to(f"cuda:{hparams.device}")
55
+ ).float() # Cast back to float32
56
+
57
+ return inv_mom2_cache[key]
58
+
59
+
60
+ def compute_u(
61
+ model: AutoModelForCausalLM,
62
+ tok: AutoTokenizer,
63
+ request: Dict,
64
+ hparams: ROMEHyperParams,
65
+ layer: int,
66
+ context_templates: List[str],
67
+ ) -> torch.Tensor:
68
+ """
69
+ Computes the right vector used in constructing the rank-1 update matrix.
70
+ """
71
+
72
+ print("Computing left vector (u)...")
73
+
74
+ # Compute projection token
75
+ word_repr_args = dict(
76
+ model=model,
77
+ tok=tok,
78
+ layer=layer,
79
+ module_template=hparams.rewrite_module_tmp,
80
+ track="in",
81
+ )
82
+ if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0:
83
+ word = request["subject"]
84
+ print(f"Selected u projection object {word}")
85
+
86
+ cur_repr = repr_tools.get_reprs_at_word_tokens(
87
+ context_templates=[
88
+ templ.format(request["prompt"]) for templ in context_templates
89
+ ],
90
+ words=[word for _ in range(len(context_templates))],
91
+ subtoken=hparams.fact_token[len("subject_") :],
92
+ **word_repr_args,
93
+ ).mean(0)
94
+
95
+ elif hparams.fact_token == "last":
96
+ # Heuristic to choose last word. Not a huge deal if there's a minor
97
+ # edge case (e.g. multi-token word) because the function below will
98
+ # take the last token.
99
+ cur_repr = repr_tools.get_reprs_at_idxs(
100
+ contexts=[
101
+ templ.format(request["prompt"].format(request["subject"]))
102
+ for templ in context_templates
103
+ ],
104
+ idxs=[[-1] for _ in range(len(context_templates))],
105
+ **word_repr_args,
106
+ ).mean(0)
107
+ print("Selected u projection token with last token")
108
+ else:
109
+ raise ValueError(f"fact_token={hparams.fact_token} not recognized")
110
+
111
+ # Apply inverse second moment adjustment
112
+ u = cur_repr
113
+ if hparams.mom2_adjustment:
114
+ u = get_inv_cov(
115
+ model,
116
+ tok,
117
+ hparams.rewrite_module_tmp.format(layer),
118
+ hparams.mom2_dataset,
119
+ hparams.mom2_n_samples,
120
+ hparams.mom2_dtype,
121
+ hparams=hparams,
122
+ ) @ u.unsqueeze(1)
123
+ u = u.squeeze()
124
+
125
+ return u / u.norm()
easyeditor/models/rome/compute_v.py ADDED
@@ -0,0 +1,278 @@
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from matplotlib.style import context
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+ from ..rome import repr_tools
9
+ from ...util import nethook
10
+
11
+ from .rome_hparams import ROMEHyperParams
12
+
13
+
14
+ def compute_v(
15
+ model: AutoModelForCausalLM,
16
+ tok: AutoTokenizer,
17
+ request: Dict,
18
+ hparams: ROMEHyperParams,
19
+ layer: int,
20
+ left_vector: torch.Tensor,
21
+ context_templates: List[str],
22
+ ) -> torch.Tensor:
23
+ """
24
+ Computes the value (right) vector for the rank-1 update.
25
+ Runs a simple optimization procedure.
26
+ """
27
+
28
+ print("Computing right vector (v)")
29
+
30
+ # Tokenize target into list of int token IDs
31
+ target_ids = tok.encode(request["target_new"], return_tensors="pt", add_special_tokens=False).to('cpu')[0]
32
+
33
+ # if target_ids[0] == tok.bos_token_id or target_ids[0] == tok.unk_token_id:
34
+ # target_ids = target_ids[1:]
35
+ # Compile list of rewriting and KL x/y pairs
36
+ rewriting_prompts, kl_prompts = [
37
+ context.format(request["prompt"]) + tok.decode(target_ids[:-1])
38
+ for context in context_templates
39
+ ], ["{} is a"]
40
+ all_prompts = rewriting_prompts + kl_prompts
41
+
42
+ input_tok = tok(
43
+ [prompt.format(request["subject"]) for prompt in all_prompts],
44
+ return_tensors="pt",
45
+ padding=True,
46
+ ).to("cpu")
47
+
48
+ # Compute rewriting targets
49
+ rewriting_targets = torch.tensor(-100, device='cpu').repeat(
50
+ len(rewriting_prompts), *input_tok["input_ids"].shape[1:]
51
+ )
52
+ for i in range(len(rewriting_prompts)):
53
+ ex_len = input_tok["attention_mask"][i].sum()
54
+ rewriting_targets[i, ex_len - len(target_ids) : ex_len] = target_ids
55
+
56
+ # Compute indices of the tokens where the fact is looked up
57
+ vanilla_input_prompts = [
58
+ context.format(request["prompt"]).format(request['subject'])
59
+ for context in context_templates
60
+ ] + [f"{request['subject']} is a"]
61
+ lookup_idxs = [
62
+ find_fact_lookup_idx(
63
+ prompt, request["subject"], tok, hparams.fact_token, verbose=(i == 0), input_prompt=vanilla_input_prompts[i]
64
+ )
65
+ for i, prompt in enumerate(all_prompts)
66
+ ]
67
+
68
+ # Finalize rewrite and loss layers
69
+ loss_layer = max(hparams.v_loss_layer, layer)
70
+ print(f"Rewrite layer is {layer}")
71
+ print(f"Tying optimization objective to {loss_layer}")
72
+
73
+ # Set up an optimization over a latent vector that, when output at the
74
+ # rewrite layer, i.e. hypothesized fact lookup location, will induce the
75
+ # target token to be predicted at the final layer.
76
+ if hasattr(model.config, 'n_embd'):
77
+ delta = torch.zeros((model.config.n_embd,), requires_grad=True, device=f"cpu")
78
+ else:
79
+ delta = torch.zeros((model.config.hidden_size,), requires_grad=True, device=f"cpu")
80
+ target_init, kl_distr_init = None, None
81
+
82
+ # Inserts new "delta" variable at the appropriate part of the computation
83
+ def edit_output_fn(cur_out, cur_layer):
84
+ nonlocal target_init
85
+ if cur_layer == hparams.mlp_module_tmp.format(layer):
86
+ # Store initial value of the vector of interest
87
+ if target_init is None:
88
+ print("Recording initial value of v*")
89
+ # Initial value is recorded for the clean sentence
90
+ target_init = cur_out[0, lookup_idxs[0]].detach().clone()
91
+
92
+ for i, idx in enumerate(lookup_idxs):
93
+ if len(lookup_idxs)!=len(cur_out):
94
+ cur_out[idx, i, :] += delta
95
+ else:
96
+ cur_out[i, idx, :] += delta
97
+
98
+ return cur_out
99
+
100
+ # Optimizer
101
+ opt = torch.optim.Adam([delta], lr=hparams.v_lr)
102
+ nethook.set_requires_grad(False, model)
103
+
104
+ # Execute optimization
105
+ for it in range(hparams.v_num_grad_steps):
106
+ opt.zero_grad()
107
+
108
+ # Forward propagation
109
+ with nethook.TraceDict(
110
+ module=model,
111
+ layers=[
112
+ hparams.layer_module_tmp.format(loss_layer),
113
+ hparams.mlp_module_tmp.format(layer),
114
+ ],
115
+ retain_input=False,
116
+ retain_output=True,
117
+ edit_output=edit_output_fn,
118
+ ) as tr:
119
+ logits = model(**input_tok).logits
120
+
121
+ # Compute distribution for KL divergence
122
+ kl_logits = torch.stack(
123
+ [
124
+ logits[i - len(kl_prompts), idx, :]
125
+ for i, idx in enumerate(lookup_idxs[-len(kl_prompts) :])
126
+ ],
127
+ dim=0,
128
+ )
129
+ kl_log_probs = torch.nn.functional.log_softmax(kl_logits, dim=1)
130
+ if kl_distr_init is None:
131
+ kl_distr_init = kl_log_probs.detach().clone()
132
+
133
+ # Compute loss on rewriting targets
134
+ log_probs = torch.log_softmax(logits, dim=2)
135
+
136
+ loss = torch.gather(
137
+ log_probs,
138
+ 2,
139
+ torch.where(rewriting_targets != -100, rewriting_targets, 0).unsqueeze(2),
140
+ ).squeeze(2)
141
+ mask = (rewriting_targets != -100).float()
142
+
143
+ # Aggregate total losses
144
+ nll_loss_each = -(loss * mask).sum(1) / target_ids.size(0)
145
+ nll_loss = nll_loss_each.mean()
146
+ kl_loss = hparams.kl_factor * torch.nn.functional.kl_div(
147
+ kl_distr_init, kl_log_probs, log_target=True, reduction="batchmean"
148
+ )
149
+ weight_decay = hparams.v_weight_decay * (
150
+ torch.norm(delta) / torch.norm(target_init) ** 2
151
+ )
152
+ # weight_decay = hparams.v_weight_decay * torch.norm(delta) ** 2
153
+ loss = nll_loss + kl_loss + weight_decay
154
+ print(
155
+ f"loss {np.round(loss.item(), 3)} = {np.round(nll_loss.item(), 3)} + {np.round(kl_loss.item(), 3)} + {np.round(weight_decay.item(), 3)} "
156
+ f"avg prob of [{request['target_new']}] "
157
+ f"{torch.exp(-nll_loss_each).mean().item()}"
158
+ )
159
+ if loss < 5e-2:
160
+ break
161
+
162
+ if it == hparams.v_num_grad_steps - 1:
163
+ break
164
+
165
+ # Backpropagate
166
+ loss.backward()
167
+ opt.step()
168
+
169
+ # Project within L2 ball
170
+ max_norm = hparams.clamp_norm_factor * target_init.norm()
171
+ if delta.norm() > max_norm:
172
+ with torch.no_grad():
173
+ delta[...] = delta * max_norm / delta.norm()
174
+
175
+ target = target_init + delta.to(target_init.dtype)
176
+
177
+ # Retrieve cur_input, the current input to the 2nd MLP layer, and
178
+ # cur_output, the original output of the 2nd MLP layer.
179
+ cur_input, cur_output = get_module_input_output_at_word(
180
+ model,
181
+ tok,
182
+ layer,
183
+ context_template=request["prompt"],
184
+ word=request["subject"],
185
+ module_template=hparams.rewrite_module_tmp,
186
+ fact_token_strategy=hparams.fact_token,
187
+ )
188
+
189
+ # Solving the linear system to compute the right vector
190
+ right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)
191
+ print(f"Delta norm: {(target - cur_output).norm().item()}")
192
+ print(
193
+ f"Change in target norm: {target_init.norm().item()} to {target.norm().item()} => {(target.norm() - target_init.norm()).item()}"
194
+ )
195
+ print(f"Division Factor: {torch.dot(cur_input, left_vector).item()}")
196
+ print(f"Right vector norm: {right_vector.norm()}")
197
+
198
+ return right_vector
199
+
200
+
201
+ def get_module_input_output_at_word(
202
+ model: AutoModelForCausalLM,
203
+ tok: AutoTokenizer,
204
+ layer: int,
205
+ context_template: str,
206
+ word: str,
207
+ module_template: str,
208
+ fact_token_strategy: str,
209
+ ) -> Tuple[torch.Tensor]:
210
+ """
211
+ Retrieves detached representations for a word at the input and
212
+ output of a particular layer module.
213
+ """
214
+
215
+ word_repr_args = dict(
216
+ model=model,
217
+ tok=tok,
218
+ layer=layer,
219
+ module_template=module_template,
220
+ )
221
+ if "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0:
222
+ subtoken = fact_token_strategy[len("subject_") :]
223
+ l_input, l_output = repr_tools.get_reprs_at_word_tokens(
224
+ track="both",
225
+ subtoken=subtoken,
226
+ context_templates=[context_template],
227
+ words=[word],
228
+ **word_repr_args,
229
+ )
230
+ elif fact_token_strategy == "last":
231
+ l_input, l_output = repr_tools.get_reprs_at_idxs(
232
+ track="both",
233
+ contexts=[context_template.format(word)],
234
+ idxs=[[-1]],
235
+ **word_repr_args,
236
+ )
237
+ else:
238
+ raise ValueError(f"fact_token={fact_token_strategy} not recognized")
239
+
240
+ l_input, l_output = l_input[0], l_output[0]
241
+ return l_input.detach(), l_output.detach()
242
+
243
+
244
+ def find_fact_lookup_idx(
245
+ prompt: str,
246
+ subject: str,
247
+ tok: AutoTokenizer,
248
+ fact_token_strategy: str,
249
+ verbose=True,
250
+ input_prompt=None
251
+ ) -> int:
252
+ """
253
+ Computes hypothesized fact lookup index given a sentence and subject.
254
+ """
255
+
256
+ ret = None
257
+ if fact_token_strategy == "last":
258
+ ret = len(tok.encode(input_prompt)) - 1
259
+ elif (
260
+ "subject_" in fact_token_strategy and fact_token_strategy.index("subject_") == 0
261
+ ):
262
+ ret = repr_tools.get_words_idxs_in_templates(
263
+ tok=tok,
264
+ context_templates=[prompt],
265
+ words=[subject],
266
+ subtoken=fact_token_strategy[len("subject_") :],
267
+ )[0][0]
268
+ else:
269
+ raise ValueError(f"fact_token={fact_token_strategy} not recognized")
270
+
271
+ sentence = prompt.format(subject)
272
+ if verbose:
273
+ print(
274
+ f"Lookup index found: {ret} | Sentence: {sentence} | Token:",
275
+ tok.decode(tok(sentence)["input_ids"][ret]),
276
+ )
277
+
278
+ return ret
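
A quick way to see why the final division works: the rank-1 update that ROME later builds from (left_vector, right_vector) shifts the layer's output at the lookup key from cur_output exactly to the optimized target. A minimal numeric sketch of that identity (the shapes and variable values below are made up for illustration, not taken from the files in this commit):

import torch

d_in, d_out = 8, 6
W = torch.randn(d_out, d_in)                          # stands in for the rewritten MLP projection
cur_input = torch.randn(d_in)                         # key at the fact-lookup token
left_vector = cur_input + 0.1 * torch.randn(d_in)     # u from compute_u (key-like direction)
target = torch.randn(d_out)                           # target_init + delta from the loop above

cur_output = W @ cur_input
right_vector = (target - cur_output) / torch.dot(cur_input, left_vector)

W_new = W + torch.outer(right_vector, left_vector)    # rank-1 edit
print(torch.norm(W_new @ cur_input - target))         # ~0 up to float error
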
easyeditor/models/rome/layer_stats.py ADDED
@@ -0,0 +1,198 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from datasets import load_dataset
6
+ from tqdm.auto import tqdm
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ from ...util.globals import *
10
+ from ...util.nethook import Trace, set_requires_grad
11
+ from ...util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally
12
+
13
+ from .tok_dataset import (
14
+ TokenizedDataset,
15
+ dict_to_,
16
+ flatten_masked_batch,
17
+ length_collation,
18
+ )
19
+
20
+ STAT_TYPES = {
21
+ "mom2": SecondMoment,
22
+ "mean": Mean,
23
+ "norm_mean": NormMean,
24
+ }
25
+
26
+
27
+ def main():
28
+ """
29
+ Command-line utility to precompute cached stats.
30
+ """
31
+ import argparse
32
+
33
+ parser = argparse.ArgumentParser(description="ROME Statistics Collector")
34
+
35
+ def aa(*args, **kwargs):
36
+ parser.add_argument(*args, **kwargs)
37
+
38
+ aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"])
39
+ aa("--dataset", default="wikipedia", choices=["wikitext", "wikipedia"])
40
+ aa("--layers", default=[17], type=lambda x: list(map(int, x.split(","))))
41
+ aa("--to_collect", default=["mom2"], type=lambda x: x.split(","))
42
+ aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x))
43
+ aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x))
44
+ aa("--precision", default="float32", choices=["float64", "float32", "float16"])
45
+ aa("--stats_dir", default=STATS_DIR)
46
+ aa("--download", default=1, type=int, choices=[0, 1])
47
+ args = parser.parse_args()
48
+
49
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
50
+ model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda()
51
+ set_requires_grad(False, model)
52
+
53
+ for layer_num in args.layers:
54
+ print(
55
+ f"Computing stats for layer {layer_num} of {args.model_name} "
56
+ f'over {args.sample_size or "all"} samples of {args.dataset}. '
57
+ "Note, the statistics are collected over the inputs to the second MLP layer, "
58
+ "or equivalently the outputs of the first MLP layer."
59
+ )
60
+ proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out"
61
+ layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}"
62
+
63
+ layer_stats(
64
+ model,
65
+ tokenizer,
66
+ layer_name,
67
+ args.stats_dir,
68
+ args.dataset,
69
+ args.to_collect,
70
+ sample_size=args.sample_size,
71
+ precision=args.precision,
72
+ batch_tokens=args.batch_tokens,
73
+ download=args.download,
74
+ )
75
+
76
+
77
+ def layer_stats(
78
+ model,
79
+ tokenizer,
80
+ layer_name,
81
+ stats_dir,
82
+ ds_name,
83
+ to_collect,
84
+ model_name=None,
85
+ sample_size=None,
86
+ precision=None,
87
+ batch_tokens=None,
88
+ download=True,
89
+ progress=tqdm,
90
+ force_recompute=False,
91
+ hparams=None
92
+ ):
93
+ """
94
+ Function to load or compute cached stats.
95
+ """
96
+
97
+ def get_ds():
98
+ # Load_From_File
99
+ # from datasets import Dataset
100
+ # raw_ds = Dataset.from_file('XXX/XXX/wikipedia-train.arrow')
101
+ # raw_ds = {'train': raw_ds}
102
+ raw_ds = load_dataset(
103
+ ds_name,
104
+ dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name]
105
+ )
106
+ if hasattr(model.config, 'n_positions'):
107
+ maxlen = model.config.n_positions
108
+ elif hasattr(model.config, 'max_sequence_length'):
109
+ maxlen = model.config.max_sequence_length
110
+ elif hasattr(model.config, 'max_position_embeddings'):
111
+ maxlen = model.config.max_position_embeddings
112
+ elif hasattr(model.config,'seq_length'):
113
+ maxlen = model.config.seq_length
114
+ else:
115
+ raise NotImplementedError
116
+
117
+ if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
118
+ if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
119
+ maxlen = model.config.sliding_window or 4096
120
+ else:
121
+ maxlen = 4096
122
+
123
+ if batch_tokens is not None and batch_tokens < maxlen:
124
+ maxlen = batch_tokens
125
+ return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)
126
+
127
+ # Continue with computation of statistics
128
+ batch_size = 100 # Examine this many dataset texts at once
129
+ if hasattr(model.config, 'n_positions'):
130
+ npos = model.config.n_positions
131
+ elif hasattr(model.config, 'max_sequence_length'):
132
+ npos = model.config.max_sequence_length
133
+ elif hasattr(model.config, 'max_position_embeddings'):
134
+ npos = model.config.max_position_embeddings
135
+ elif hasattr(model.config,'seq_length'):
136
+ npos = model.config.seq_length
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ if hasattr(model.config, 'model_type') and 'mistral' in model.config.model_type:
141
+ if hasattr(model.config, 'sliding_window') and model.config.sliding_window:
142
+ npos = model.config.sliding_window or 4096
143
+ else:
144
+ npos = 4096
145
+
146
+ if batch_tokens is None:
147
+ batch_tokens = npos * 3 # Sort and divide into batches with this many tokens
148
+ if precision is None:
149
+ precision = "float64"
150
+ dtype = getattr(torch, precision)
151
+ size_suffix = "" if sample_size is None else f"_{sample_size}"
152
+ if batch_tokens < npos:
153
+ size_suffix = f"_t{batch_tokens}" + size_suffix
154
+ if model_name is None:
155
+ # model_name = model.config._name_or_path.replace("/", "_")
156
+ model_name = model.config._name_or_path.rsplit("/")[-1]
157
+
158
+ stats_dir = Path(stats_dir)
159
+ file_extension = f"{model_name}/{ds_name}_stats/{layer_name}_{precision}_{'-'.join(sorted(to_collect))}{size_suffix}.npz"
160
+ filename = stats_dir / file_extension
161
+
162
+ print(f"Computing Cov locally....")
163
+
164
+ ds = get_ds() if not filename.exists() else None
165
+
166
+ if progress is None:
167
+ progress = lambda x: x
168
+
169
+ stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect})
170
+ loader = tally(
171
+ stat,
172
+ ds,
173
+ cache=(filename if not force_recompute else None),
174
+ sample_size=sample_size,
175
+ batch_size=batch_size,
176
+ collate_fn=length_collation(batch_tokens),
177
+ pin_memory=True,
178
+ random_sample=1,
179
+ num_workers=2,
180
+ )
181
+ batch_count = -(-(sample_size or len(ds)) // batch_size)
182
+ with torch.no_grad():
183
+ for batch_group in progress(loader, total=batch_count):
184
+ for batch in batch_group:
185
+ batch = dict_to_(batch, f"cuda:{hparams.device}")
186
+ with Trace(
187
+ model, layer_name, retain_input=True, retain_output=False, stop=True
188
+ ) as tr:
189
+ model(**batch)
190
+ feats = flatten_masked_batch(tr.input, batch["attention_mask"])
191
+ # feats = flatten_masked_batch(tr.output, batch["attention_mask"])
192
+ feats = feats.to(dtype=dtype)
193
+ stat.add(feats)
194
+ return stat
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
easyeditor/models/rome/repr_tools.py ADDED
@@ -0,0 +1,174 @@
1
+ """
2
+ Contains utilities for extracting token representations and indices
3
+ from string templates. Used in computing the left and right vectors for ROME.
4
+ """
5
+
6
+ from copy import deepcopy
7
+ from typing import List
8
+
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+
12
+ from ...util import nethook
13
+
14
+ def get_reprs_at_word_tokens(
15
+ model: AutoModelForCausalLM,
16
+ tok: AutoTokenizer,
17
+ context_templates: List[str],
18
+ words: List[str],
19
+ layer: int,
20
+ module_template: str,
21
+ subtoken: str,
22
+ track: str = "in",
23
+ ) -> torch.Tensor:
24
+ """
25
+ Retrieves the last token representation of `word` in `context_template`
26
+ when `word` is substituted into `context_template`. See `get_last_word_idx_in_template`
27
+ for more details.
28
+ """
29
+
30
+ idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken)
31
+ return get_reprs_at_idxs(
32
+ model,
33
+ tok,
34
+ [context_templates[i].format(words[i]) for i in range(len(words))],
35
+ idxs,
36
+ layer,
37
+ module_template,
38
+ track,
39
+ )
40
+
41
+ def get_words_idxs_in_templates(
42
+ tok: AutoTokenizer, context_templates: str, words: str, subtoken: str
43
+ ) -> int:
44
+ """
45
+ Given list of template strings, each with *one* format specifier
46
+ (e.g. "{} plays basketball"), and words to be substituted into the
47
+ template, computes the post-tokenization index of their last tokens.
48
+ """
49
+
50
+ assert all(
51
+ tmp.count("{}") == 1 for tmp in context_templates
52
+ ), "We currently do not support multiple fill-ins for context"
53
+
54
+
55
+ prefixes_len, words_len, suffixes_len, inputs_len = [], [], [], []
56
+ for i, context in enumerate(context_templates):
57
+ prefix, suffix = context.split("{}")
58
+ prefix_len = len(tok.encode(prefix))
59
+ prompt_len = len(tok.encode(prefix + words[i]))
60
+ input_len = len(tok.encode(prefix + words[i] + suffix))
61
+ prefixes_len.append(prefix_len)
62
+ words_len.append(prompt_len - prefix_len)
63
+ suffixes_len.append(input_len - prompt_len)
64
+ inputs_len.append(input_len)
65
+
66
+ # Compute prefixes and suffixes of the tokenized context
67
+ # fill_idxs = [tmp.index("{}") for tmp in context_templates]
68
+ # prefixes, suffixes = [
69
+ # tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates)
70
+ # ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)]
71
+ # words = deepcopy(words)
72
+ #
73
+ # # Pre-process tokens
74
+ # for i, prefix in enumerate(prefixes):
75
+ # if len(prefix) > 0:
76
+ # assert prefix[-1] == " "
77
+ # prefix = prefix[:-1]
78
+ #
79
+ # prefixes[i] = prefix
80
+ # words[i] = f" {words[i].strip()}"
81
+ #
82
+ # # Tokenize to determine lengths
83
+ # assert len(prefixes) == len(words) == len(suffixes)
84
+ # n = len(prefixes)
85
+ # batch_tok = tok([*prefixes, *words, *suffixes])
86
+ # if 'input_ids' in batch_tok:
87
+ # batch_tok = batch_tok['input_ids']
88
+ # prefixes_tok, words_tok, suffixes_tok = [
89
+ # batch_tok[i : i + n] for i in range(0, n * 3, n)
90
+ # ]
91
+ # prefixes_len, words_len, suffixes_len = [
92
+ # [len(el) for el in tok_list]
93
+ # for tok_list in [prefixes_tok, words_tok, suffixes_tok]
94
+ # ]
95
+
96
+ # Compute indices of last tokens
97
+ if subtoken == "last" or subtoken == "first_after_last":
98
+ return [
99
+ [
100
+ prefixes_len[i]
101
+ + words_len[i]
102
+ - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0)
103
+ ]
104
+ # If suffix is empty, there is no "first token after the last".
105
+ # So, just return the last token of the word.
106
+ for i in range(len(context_templates))
107
+ ]
108
+ elif subtoken == "first":
109
+ return [[prefixes_len[i] - inputs_len[i]] for i in range(len(context_templates))]
110
+ else:
111
+ raise ValueError(f"Unknown subtoken type: {subtoken}")
112
+
113
+
114
+ def get_reprs_at_idxs(
115
+ model: AutoModelForCausalLM,
116
+ tok: AutoTokenizer,
117
+ contexts: List[str],  # full sentences expressing the knowledge
118
+ idxs: List[List[int]],  # token positions of the filled-in word
119
+ layer: int,
120
+ module_template: str,
121
+ track: str = "in",
122
+ ) -> torch.Tensor:
123
+ """
124
+ Runs input through model and returns averaged representations of the tokens
125
+ at each index in `idxs`.
126
+ """
127
+
128
+ def _batch(n):
129
+ for i in range(0, len(contexts), n):
130
+ yield contexts[i : i + n], idxs[i : i + n]  # chunk the sentences and word positions into batches
131
+
132
+ assert track in {"in", "out", "both"}
133
+ both = track == "both"
134
+ tin, tout = (
135
+ (track == "in" or both),
136
+ (track == "out" or both),
137
+ )  # tin and tout are booleans
138
+ module_name = module_template.format(layer)
139
+ to_return = {"in": [], "out": []}
140
+
141
+ def _process(cur_repr, batch_idxs, key):
142
+ nonlocal to_return
143
+ cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
144
+ if cur_repr.shape[0]!=len(batch_idxs):
145
+ cur_repr=cur_repr.transpose(0,1)
146
+ for i, idx_list in enumerate(batch_idxs):
147
+ to_return[key].append(cur_repr[i][idx_list].mean(0))
148
+
149
+ for batch_contexts, batch_idxs in _batch(n=128):
150
+ # contexts_tok: [batch_size, seq_len]
151
+ contexts_tok = tok(batch_contexts, padding=True, return_tensors="pt").to(
152
+ next(model.parameters()).device
153
+ )
154
+
155
+ with torch.no_grad():
156
+ with nethook.Trace(
157
+ module=model,
158
+ layer=module_name,
159
+ retain_input=tin,
160
+ retain_output=tout,
161
+ ) as tr:
162
+ model(**contexts_tok)
163
+
164
+ if tin:
165
+ _process(tr.input, batch_idxs, "in")
166
+ if tout:
167
+ _process(tr.output, batch_idxs, "out")
168
+
169
+ to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}
170
+
171
+ if len(to_return) == 1:
172
+ return to_return["in"] if tin else to_return["out"]
173
+ else:
174
+ return to_return["in"], to_return["out"]
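
The index bookkeeping above is easier to see on a concrete template. A small sketch of the "subject_last" case, assuming a GPT-2 tokenizer (the template, subject, and printed index are illustrative; exact token counts depend on the tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
template, word = "{} plays basketball", "LeBron James"
prefix, _suffix = template.split("{}")

prefix_len = len(tok.encode(prefix))            # tokens before the subject
prompt_len = len(tok.encode(prefix + word))     # tokens through the subject
last_subject_idx = prompt_len - 1               # index of the subject's last token

sentence_ids = tok.encode(template.format(word))
print(last_subject_idx, repr(tok.decode([sentence_ids[last_subject_idx]])))
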
easyeditor/models/rome/rome_hparams.py ADDED
@@ -0,0 +1,55 @@
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ import yaml
4
+
5
+ from ...util.hparams import HyperParams
6
+
7
+
8
+ @dataclass
9
+ class ROMEHyperParams(HyperParams):
10
+ # Method
11
+ layers: List[int]
12
+ fact_token: str
13
+ v_num_grad_steps: int
14
+ v_lr: float
15
+ v_loss_layer: int
16
+ v_weight_decay: float
17
+ clamp_norm_factor: float
18
+ kl_factor: float
19
+ mom2_adjustment: bool
20
+ context_template_length_params: List[List[int]]
21
+
22
+ # Module templates
23
+ rewrite_module_tmp: str
24
+ layer_module_tmp: str
25
+ mlp_module_tmp: str
26
+ attn_module_tmp: str
27
+ ln_f_module: str
28
+ lm_head_module: str
29
+
30
+ # Statistics
31
+ mom2_dataset: str
32
+ mom2_n_samples: int
33
+ mom2_dtype: str
34
+ alg_name: str
35
+ device: int
36
+ model_name: str
37
+ stats_dir: str
38
+
39
+ max_length: int = 40
40
+ model_parallel: bool = False
41
+ fp16: bool = False
42
+
43
+ @classmethod
44
+ def from_hparams(cls, hparams_name_or_path: str):
45
+
46
+ if '.yaml' not in hparams_name_or_path:
47
+ hparams_name_or_path = hparams_name_or_path + '.yaml'
48
+
49
+ with open(hparams_name_or_path, "r") as stream:
50
+ config = yaml.safe_load(stream)
51
+ config = super().construct_float_from_scientific_notation(config)
52
+
53
+ assert (config and config['alg_name'] == 'ROME') or print(f'ROMEHyperParams can not load from {hparams_name_or_path}, '
54
+ f'alg_name is {config["alg_name"]} ')
55
+ return cls(**config)
easyeditor/models/rome/rome_main.py ADDED
@@ -0,0 +1,192 @@
1
+ from copy import deepcopy
2
+ from typing import Dict, List, Tuple
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ from ...util import nethook
8
+ from ...util.generate import generate_fast
9
+
10
+ from .compute_u import compute_u
11
+ from .compute_v import compute_v
12
+ from .rome_hparams import ROMEHyperParams
13
+ import gradio as gr
14
+
15
+ CONTEXT_TEMPLATES_CACHE = None
16
+
17
+
18
+ def apply_rome_to_model(
19
+ model: AutoModelForCausalLM,
20
+ tok: AutoTokenizer,
21
+ request: List[Dict],
22
+ hparams: ROMEHyperParams,
23
+ num_steps: int,
24
+ edit_lr: float,
25
+ copy=False,
26
+ return_orig_weights=False,
27
+ keep_original_weight=False,
28
+ **kwargs
29
+ ) -> Tuple[AutoModelForCausalLM, List[str]]:
30
+ """
31
+ Returns a model with the desired changes.
32
+
33
+ :param copy: If true, will preserve the original model while creating a new one to edit.
34
+ Note that you are responsible for deallocating the new model's memory to avoid leaks.
35
+
36
+ :return: (1) the updated model, (2) an original copy of the weights that changed
37
+ """
38
+ if copy:
39
+ model = deepcopy(model)
40
+
41
+ weights_copy = {}
42
+ hparams.v_num_grad_steps = num_steps // 2
43
+ hparams.v_lr = edit_lr
44
+ request['subject'] = request['prompt']
45
+
46
+ deltas = execute_rome(model, tok, request, hparams)
47
+
48
+ with torch.no_grad():
49
+ for w_name, (delta_u, delta_v) in deltas.items():
50
+ upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
51
+ w = nethook.get_parameter(model, w_name)
52
+ upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)
53
+
54
+ if return_orig_weights and w_name not in weights_copy:
55
+ weights_copy[w_name] = w.detach().clone()
56
+
57
+ w[...] += upd_matrix
58
+
59
+ print(f"New weights successfully inserted into {list(deltas.keys())}")
60
+
61
+ if not keep_original_weight:
62
+ weights_copy = {}
63
+ gr.Info("Completed editing via ROME!")
64
+ return model
65
+
66
+
67
+ def execute_rome(
68
+ model: AutoModelForCausalLM,
69
+ tok: AutoTokenizer,
70
+ request: Dict,
71
+ hparams: ROMEHyperParams,
72
+ ) -> Dict[str, Tuple[torch.Tensor]]:
73
+ """
74
+ Executes the ROME update algorithm for the specified update at the specified layer
75
+ Invariant: model at beginning of function == model at end of function
76
+ """
77
+
78
+ # Update target and print info
79
+ request = deepcopy(request)
80
+ if request["target_new"] != " ":
81
+ # Space required for correct tokenization
82
+ request["target_new"] = " " + request["target_new"]
83
+
84
+ if '{}' not in request['prompt']:
85
+ assert request['subject'] in request['prompt'] or \
86
+ print(f"Subject: {request['subject']} does not exist in prompt: {request['prompt']}")
87
+
88
+ request['prompt'] = request['prompt'].replace(request['subject'], '{}')
89
+
90
+ print(
91
+ f"Executing ROME algorithm for the update: "
92
+ f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']}]"
93
+ )
94
+
95
+ # Retrieve weights that user desires to change
96
+ weights = {
97
+ f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
98
+ model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
99
+ )
100
+ for layer in hparams.layers
101
+ }
102
+ # Save old weights for future restoration
103
+ weights_copy = {k: v.detach().clone() for k, v in weights.items()}
104
+
105
+ # Update loop: sequentially intervene at each specified layer
106
+ deltas = {}
107
+ for layer in sorted(hparams.layers):
108
+ # Compute rank-1 update matrix
109
+ left_vector: torch.Tensor = compute_u(
110
+ model,
111
+ tok,
112
+ request,
113
+ hparams,
114
+ layer,
115
+ get_context_templates(model, tok, hparams.context_template_length_params),
116
+ )
117
+ print("Left vector shape:", left_vector.shape)
118
+ right_vector: torch.Tensor = compute_v(
119
+ model,
120
+ tok,
121
+ request,
122
+ hparams,
123
+ layer,
124
+ left_vector,
125
+ get_context_templates(model, tok, hparams.context_template_length_params),
126
+ )
127
+ print("Right vector shape:", right_vector.shape)
128
+
129
+ with torch.no_grad():
130
+ # Determine correct transposition of delta matrix
131
+ weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
132
+ upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
133
+ upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape)
134
+
135
+ # Update model weights and record desired changes in `delta` variable
136
+ weights[weight_name][...] += upd_matrix
137
+ deltas[weight_name] = (
138
+ left_vector.detach(),
139
+ right_vector.detach(),
140
+ )
141
+
142
+ # Restore state of original model
143
+ with torch.no_grad():
144
+ for k, v in weights.items():
145
+ v[...] = weights_copy[k]
146
+
147
+ print(f"Deltas successfully computed for {list(weights.keys())}")
148
+
149
+ return deltas
150
+
151
+
152
+ def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
153
+ """
154
+ GPT-2 and GPT-J have transposed weight representations.
155
+ Returns a matrix that matches the desired shape, else raises a ValueError
156
+ """
157
+
158
+ if matrix.shape == shape:
159
+ return matrix
160
+ elif matrix.T.shape == shape:
161
+ return matrix.T
162
+ else:
163
+ raise ValueError(
164
+ "Update matrix computed by ROME does not match original weight shape. "
165
+ "Check for bugs in the code?"
166
+ )
167
+
168
+
169
+ def get_context_templates(model, tok, length_params):
170
+ global CONTEXT_TEMPLATES_CACHE
171
+
172
+ if CONTEXT_TEMPLATES_CACHE is None:
173
+ CONTEXT_TEMPLATES_CACHE = ["{}"] + [
174
+ x.replace("{", "").replace("}", "") + ". {}"
175
+ for x in sum(
176
+ (
177
+ generate_fast(
178
+ model,
179
+ tok,
180
+ ["The", "Therefore", "Because", "I", "You"],
181
+ n_gen_per_prompt=n_gen // 5,
182
+ max_out_len=length,
183
+ )
184
+ for length, n_gen in length_params
185
+ ),
186
+ [],
187
+ )
188
+ ]
189
+
190
+ print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")
191
+
192
+ return CONTEXT_TEMPLATES_CACHE
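
For orientation, a hypothetical end-to-end call into this module (the YAML path mirrors the hparams/ROME/gpt2.yaml added in this commit; the request values are illustrative, and depending on mom2_adjustment in that YAML, compute_u may first collect covariance statistics, which is slow):

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.rome.rome_hparams import ROMEHyperParams
from easyeditor.models.rome.rome_main import apply_rome_to_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token            # compute_v pads batched prompts

hparams = ROMEHyperParams.from_hparams("hparams/ROME/gpt2.yaml")
request = {"prompt": "The mother tongue of Danielle Darrieux is",
           "target_new": "English"}      # apply_rome_to_model sets subject = prompt

edited_model = apply_rome_to_model(model, tok, request, hparams,
                                    num_steps=40, edit_lr=0.5)
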
easyeditor/models/rome/tok_dataset.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ from torch.nn.utils.rnn import pad_sequence
3
+ from torch.utils.data import Dataset
4
+
5
+
6
+ class TokenizedDataset(Dataset):
7
+ """
8
+ Converts a dataset of text samples into a dataset of token sequences,
9
+ as converted by a supplied tokenizer. The tokens come along with position
10
+ ids and attention masks, so they can be supplied directly to the model.
11
+ """
12
+
13
+ def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"):
14
+ self.text_dataset = text_dataset
15
+ self.field = field
16
+ self.tokenizer = tokenizer
17
+ self.maxlen = maxlen
18
+ if hasattr(text_dataset, "info"):
19
+ self.info = text_dataset.info
20
+
21
+ def __len__(self):
22
+ return len(self.text_dataset)
23
+
24
+ def __getitem__(self, i):
25
+ text = self.text_dataset[i]
26
+ if self.field is not None:
27
+ text = text[self.field]
28
+ token_list = self.tokenizer.encode(
29
+ text, truncation=True, max_length=self.maxlen
30
+ )
31
+ position_ids = list(range(len(token_list)))
32
+ attention_mask = [1] * len(token_list)
33
+ return dict(
34
+ input_ids=torch.tensor(token_list),
35
+ position_ids=torch.tensor(position_ids),
36
+ attention_mask=torch.tensor(attention_mask),
37
+ )
38
+
39
+
40
+ def dict_to_(data, device):
41
+ """
42
+ Moves a dictionary of tensors to the specified device.
43
+ """
44
+ for k in data:
45
+ data[k] = data[k].to(device)
46
+ return data
47
+
48
+
49
+ def length_collation(token_size):
50
+ """
51
+ Sorts a batch of sequences and breaks it up into subbatches
52
+ of same-sized sequences, padding as needed. Each batch
53
+ has no more than token_size total tokens (or a single
54
+ sequence, if the sequence happens to be larger).
55
+ """
56
+
57
+ def collate_fn(items):
58
+ items = sorted(items, key=lambda x: -len(x["input_ids"]))
59
+ batches = []
60
+ batch = []
61
+ batch_width = 0
62
+ for item in items:
63
+ item_width = len(item["input_ids"])
64
+ if item_width == 0:
65
+ break
66
+ if batch_width * (len(batch) + 1) > token_size:
67
+ batches.append(make_padded_batch(batch))
68
+ batch = []
69
+ batch_width = 0
70
+ if not batch:
71
+ batch_width = item_width
72
+ batch.append(item)
73
+ if len(batch):
74
+ batches.append(make_padded_batch(batch))
75
+ return batches
76
+
77
+ return collate_fn
78
+
79
+
80
+ def make_padded_batch(items):
81
+ """
82
+ Pads sequences in a batch, so they are all the same length as the longest.
83
+ """
84
+ max_len = max(len(d["input_ids"]) for d in items)
85
+ if max_len == 0:
86
+ return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]}
87
+ return {
88
+ k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True)
89
+ for k, v in items[0].items()
90
+ }
91
+
92
+
93
+ def flatten_masked_batch(data, mask):
94
+ """
95
+ Flattens feature data, ignoring items that are masked out of attention.
96
+ """
97
+ flat_data = data.view(-1, data.size(-1))
98
+ attended_tokens = mask.view(-1).nonzero()[:, 0]
99
+ return flat_data[attended_tokens]
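
A small sketch of what length_collation produces, on a made-up in-memory dataset rather than the wikipedia pipeline used by layer_stats:

import torch
from torch.utils.data import DataLoader
from easyeditor.models.rome.tok_dataset import length_collation

items = [{"input_ids": torch.arange(n), "position_ids": torch.arange(n),
          "attention_mask": torch.ones(n, dtype=torch.long)} for n in (9, 5, 3, 2)]

loader = DataLoader(items, batch_size=4, collate_fn=length_collation(token_size=16))
for batch_group in loader:              # one list of padded sub-batches per loader step
    for sub in batch_group:
        print(sub["input_ids"].shape)   # e.g. torch.Size([1, 9]) then torch.Size([3, 5])
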
easyeditor/models/wise/.DS_Store ADDED
Binary file (6.15 kB)
 
easyeditor/models/wise/WISE.py ADDED
@@ -0,0 +1,466 @@
1
+ import copy
2
+ import random
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from .utils import parent_module, brackets_to_periods, EarlyStopMeter, EditingMeanAct
7
+ import transformers
8
+ import numpy as np
9
+ from torch import Tensor
10
+ from torch.nn import CrossEntropyLoss
11
+ from transformers.activations import ACT2FN
12
+ from .merge import slerp, GTA, linear
13
+ import torch.nn as nn
14
+ import gc
15
+
16
+ merge_dict = {
17
+ 'slerp': slerp(),
18
+ 'ties': GTA('magnitude', 'sum', normalize=True),
19
+ 'magnitude_norm': GTA('magnitude', None, normalize=True),
20
+ 'magnitude': GTA('magnitude', None, normalize=False),
21
+ 'sign': GTA(None, 'sum', normalize=True),
22
+ 'dare_ties': GTA('rescaled_random', 'sum'),
23
+ 'dare_linear': GTA('random', None),
24
+ 'linear': linear()
25
+ }
26
+
27
+ edit_history = []
28
+ merge_group_edit_history = []
29
+
30
+ def euc(query, key, config, act_mask=None, infer=False):
31
+ # Euclidean distance
32
+
33
+ act_fn = ACT2FN[config.hidden_act]
34
+ l2_norm = torch.norm(act_fn(key) - act_fn(query), dim=-1)
35
+ if infer and l2_norm.size(1) > 100:
36
+ topk = torch.topk(l2_norm, k=1, largest=True)
37
+ return topk.values.mean()
38
+
39
+ if act_mask is not None:
40
+ return torch.sum(l2_norm * act_mask, dim=1) / torch.sum(act_mask, dim=1)
41
+ else:
42
+ return torch.mean(l2_norm, dim=-1)
43
+
44
+
45
+ class WISE(torch.nn.Module):
46
+ def __init__(self, config, model, device):
47
+ super(WISE, self).__init__()
48
+ self.config = config
49
+ self.model = model
50
+ self.config = config
51
+ if hasattr(self.model.config, 'hidden_act'):
52
+ self.config.hidden_act = self.model.config.hidden_act
53
+ elif hasattr(self.model.config, 'activation_function'):
54
+ self.config.hidden_act = self.model.config.activation_function
55
+ # self.tokenizer = model.tokenizer
56
+ layer = config.inner_params[0]
57
+ self.device = device
58
+ self.adapter_layer = None
59
+ self.original_layer = None
60
+
61
+ # --- ensure proper formatting (WISE edits weights matrices) ---
62
+ suffixes = [".weight", ".bias"]
63
+ self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
64
+
65
+ for n, p in self.model.named_parameters():
66
+ p.requires_grad = False
67
+
68
+ if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
69
+ conv1D = True
70
+ else:
71
+ conv1D = False
72
+
73
+ # --- Add WISE to chosen layers ---
74
+ self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
75
+ self.layer_name = self.layer.rsplit(".", 1)[-1]
76
+ adapter_layer = getattr(self.edit_module, self.layer_name)
77
+
78
+ if type(adapter_layer) is not WISEAdapter:
79
+ setattr(self.edit_module, self.layer_name, WISEAdapter(config, adapter_layer, conv1D=conv1D))
80
+ self.original_layer = copy.deepcopy(adapter_layer)
81
+ print(f"New weights successfully inserted into {layer}")
82
+
83
+ gc.collect()
84
+ torch.cuda.empty_cache()
85
+ gc.collect()
86
+
87
+ # Forward
88
+ def __call__(self, **kwargs):
89
+ if not self.config.retrieve:
90
+ if hasattr(self.get_adapter_layer(), 'editing') and not self.get_adapter_layer().editing:
91
+ # final merge
92
+ if not self.get_adapter_layer().original_layer.weight.equal(self.get_adapter_layer().new_weight) and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
93
+ self.get_adapter_layer().memory_weight.append(self.get_adapter_layer().new_weight)
94
+ if len(self.get_adapter_layer().memory_weight) > 0 and self.get_adapter_layer().editing_total_cnt >= self.config.save_freq:
95
+ print('length of memory is ', len(self.get_adapter_layer().memory_weight), '!!!!!!')
96
+ self.get_adapter_layer().merge_weight()
97
+ return self.model(**kwargs)
98
+
99
+ def reset_layer(self):
100
+ layer = getattr(self.edit_module, self.layer_name)
101
+ del layer
102
+ setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)
103
+
104
+ def get_adapter_layer(self):
105
+ adapter_layer = getattr(self.edit_module, self.layer_name)
106
+ assert type(adapter_layer) is WISEAdapter, print('Adapter Layer is not added correctly....')
107
+ return adapter_layer
108
+
109
+ # TODO: generation
110
+ def generate(self, *args, **kwargs):
111
+ setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
112
+ return self.model.generate(*args, **kwargs)
113
+
114
+ def edit(self, config, tokens, act_mask=None, deact_mask=None):
115
+ # for retrieve ##
116
+ global edit_history
117
+ global merge_group_edit_history
118
+ edit_history.append([{f"{k1}" : v1.to('cpu') for k1, v1 in tokens.items()}, False])
119
+ # for retrieve ##
120
+ last_prompt_token_loc = (tokens["labels"] == -100).sum(dim=-1) - 1
121
+
122
+ setattr(eval(f"self.model.{self.layer}"), "training", True)
123
+ setattr(eval(f"self.model.{self.layer}"), "editing", True)
124
+ self.get_adapter_layer().set_parameter_tunable()
125
+ if getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") % self.config.save_freq == 0:
126
+ self.get_adapter_layer().generate_activation_mask(self.config.mask_ratio)
127
+
128
+ # --- train Wise value ---
129
+ loss_meter = EarlyStopMeter()
130
+ for i in range(config.n_iter):
131
+
132
+ if i == 0:
133
+ # --- we only need to create an optimizer for the first iteration (the forward pass instantiates the key, so the optimizer is created after the first inference) ---
134
+ optimizer = torch.optim.SGD([self.get_adapter_layer().new_weight], config.edit_lr, weight_decay=1e-5)
135
+
136
+ ft_loss = self.__cal_ft_loss(tokens, last_prompt_token_loc)
137
+
138
+ act_loss = self.__cal_activation_loss(self.get_adapter_layer().original_layer_output, self.get_adapter_layer().new_weight_layer_output,
139
+ config=config, act_mask=act_mask, deact_mask=deact_mask)
140
+ loss = ft_loss + act_loss.to(ft_loss.device)
141
+
142
+ if loss_meter.stop():
143
+ self.get_adapter_layer().save_editing_activation() # add last gradient
144
+ break
145
+ if i == config.n_iter - 1:
146
+ self.get_adapter_layer().save_editing_activation() # add last gradient
147
+
148
+ if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay:
149
+ memory_loss = []
150
+ for _ in merge_group_edit_history:
151
+ idx = 0
152
+ while True:
153
+ memo_input, is_used = _[idx]
154
+ if not is_used:
155
+ _[idx][1] = True
156
+ break
157
+ idx += 1
158
+ if idx == len(_): ## re Assign
159
+ for m in range(len(_)):
160
+ _[m][1] = False
161
+ idx = 0
162
+
163
+ memo_input = {f"{k1}" : v1.to(self.config.device) for k1, v1 in memo_input.items()}
164
+ self.model(**memo_input)
165
+
166
+ memory_act_loss = self.__cal_memory_neg_activation_loss(self.get_adapter_layer().original_layer_output,
167
+ self.get_adapter_layer().new_weight_layer_output, config=config,
168
+ act_mask=act_mask, deact_mask=deact_mask)
169
+ memory_loss.append(memory_act_loss.to(ft_loss.device))
170
+ del memo_input
171
+ neg_memo_loss = torch.stack(memory_loss).mean()
172
+ loss += neg_memo_loss
173
+ if len(edit_history) > 0:
174
+ memo_input = random.choice(edit_history)[0]
175
+ memo_input = {f"{k1}" : v1.to(self.config.device) for k1, v1 in memo_input.items()}
176
+ self.model(**memo_input)
177
+
178
+ pos_memo_loss = self.__cal_memory_pos_activation_loss(self.get_adapter_layer().original_layer_output,
179
+ self.get_adapter_layer().new_weight_layer_output, config=config,
180
+ act_mask=act_mask, deact_mask=deact_mask)
181
+ del memo_input
182
+ loss += pos_memo_loss.to(ft_loss.device)
183
+ # for replay Appendix B.3
184
+
185
+ optimizer.zero_grad()
186
+
187
+ loss.backward()
188
+ self.get_adapter_layer().mask_new_weight_gradient()
189
+
190
+ if self.config.retrieve and self.get_adapter_layer().merge_cnt > 0 and self.config.replay:
191
+ print(
192
+ f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)} + {np.round(neg_memo_loss.item(), 3)} + {np.round(pos_memo_loss.item(), 3)}"
193
+ )
194
+ else:
195
+ print(
196
+ f"loss {np.round(loss.item(), 3)} = {np.round(ft_loss.item(), 3)} + {np.round(act_loss.item(), 3)}"
197
+ )
198
+
199
+ optimizer.step()
200
+ loss_meter.update(loss.item())
201
+
202
+ if type(self.config.norm_constraint) is float:
203
+ self.__norm_constraint(self.config.norm_constraint)
204
+
205
+ # --- pull out info we want to log from the Wise layer ---
206
+ setattr(eval(f"self.model.{self.layer}"), "editing", False)
207
+ setattr(eval(f"self.model.{self.layer}"), "training", False)
208
+
209
+ editing_total_cnt = getattr(eval(f"self.model.{self.layer}"), "editing_total_cnt") + 1
210
+ setattr(eval(f"self.model.{self.layer}"), "editing_total_cnt", editing_total_cnt)
211
+ #
212
+ if self.config.save_freq is not None and editing_total_cnt % self.config.save_freq == 0:
213
+ self.get_adapter_layer().save_weight()
214
+ print(f'Add New Weight to Memory...')
215
+ if editing_total_cnt % self.config.merge_freq == 0:
216
+ # for retrieve ##
217
+ merge_group_edit_history.append(edit_history)
218
+ edit_history = []
219
+ # for retrieve ##
220
+
221
+ self.get_adapter_layer().merge_weight()
222
+ print(f'Merge Weight of (New, Original) Matrix... with {self.config.merge_alg}')
223
+
224
+ def __norm_constraint(self, norm_constraint):
225
+ new_weight = self.get_adapter_layer().new_weight
226
+ original_weight = self.get_adapter_layer().weight
227
+ with torch.no_grad():
228
+ new_weight[...] = torch.clamp(
229
+ new_weight, min=original_weight - norm_constraint, max=original_weight + norm_constraint
230
+ )
231
+
232
+ def __cal_ft_loss(self, tokens, last_prompt_token_loc):
233
+ k = 1
234
+ bs = tokens["input_ids"].shape[0] - k
235
+ logits = self.model(**tokens).logits
236
+ shift_logits = logits[:-k, :-1, :].contiguous()
237
+ shift_labels = tokens['labels'][:-k, 1:].contiguous()
238
+
239
+
240
+
241
+
242
+ label_mask = torch.zeros_like(shift_labels, dtype=torch.bool)
243
+
244
+ for i, col_index in enumerate(last_prompt_token_loc[:-k]):
245
+ label_mask[i, col_index-1:] = True
246
+
247
+ shift_labels[~label_mask] = -100
248
+
249
+ log_probs = -nn.functional.log_softmax(shift_logits, dim=-1)
250
+
251
+ if shift_labels.dim() == log_probs.dim() - 1:
252
+ shift_labels = shift_labels.unsqueeze(-1)
253
+
254
+ padding_mask = shift_labels.eq(-100)
255
+
256
+ # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
257
+ # will ignore them in any case.
258
+ shift_labels = torch.clamp(shift_labels, min=0)
259
+
260
+ nll_loss = log_probs.gather(dim=-1, index=shift_labels)
261
+ nll_loss.masked_fill_(padding_mask, 0.0)
262
+
263
+ num_active_elements = padding_mask.numel() - padding_mask.long().sum()
264
+ nll_loss = nll_loss.sum() / num_active_elements
265
+
266
+ return nll_loss
267
+ # loss_fct = CrossEntropyLoss(reduction='none')
268
+ # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
269
+ # loss = loss.view(bs, -1)
270
+
271
+ # label_mask = torch.zeros_like(loss, dtype=torch.bool)
272
+
273
+ # for i, col_index in enumerate(last_prompt_token_loc[:-k]):
274
+ # label_mask[i, col_index - 1:] = True
275
+
276
+ # ft_loss = ((loss * label_mask).sum(1) / label_mask.sum(1)).mean()
277
+ # return ft_loss
278
+
279
+ def __cal_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
280
+ deact_mask=None):
281
+ k = 1
282
+ if act_mask is not None:
283
+ in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config,
284
+ act_mask=act_mask)
285
+ out_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config,
286
+ act_mask=deact_mask)
287
+ else:
288
+ in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
289
+ out_scope_dist = euc(original_layer_output[-k:, ...], new_weight_layer_output[-k:, ...], config)
290
+
291
+ loss = out_scope_dist.view(-1,1) - in_scope_dist + config.gamma
292
+ loss2 = out_scope_dist - config.alpha
293
+ loss3 = config.beta - in_scope_dist
294
+ loss3 = torch.mean(loss3[loss3 > 0]) if min(loss3[loss3 > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
295
+ loss2 = torch.mean(loss2[loss2 > 0]) if min(loss2[loss2 > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
296
+ loss = torch.mean(loss[loss > 0]) if min(loss[loss > 0].size()) > 0 else torch.tensor(0.).to(original_layer_output.device)
297
+ return loss + loss2 + loss3
298
+
299
+ def __cal_memory_pos_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
300
+ deact_mask=None):
301
+ k = 1
302
+ in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
303
+ loss4 = 20 - in_scope_dist
304
+
305
+ return torch.mean(loss4[loss4 > 0]) if min(loss4[loss4 > 0].size()) > 0 else torch.tensor(0.)
306
+
307
+ def __cal_memory_neg_activation_loss(self, original_layer_output, new_weight_layer_output, config=None, act_mask=None,
308
+ deact_mask=None):
309
+ k = 1
310
+ in_scope_dist = euc(original_layer_output[:-k, ...], new_weight_layer_output[:-k, ...], config)
311
+ loss4 = in_scope_dist - 5
312
+
313
+ return torch.mean(loss4[loss4 > 0]) if min(loss4[loss4 > 0].size()) > 0 else torch.tensor(0.)
314
+
315
+ class WISEAdapter(torch.nn.Module):
316
+ def __init__(self, config, layer, conv1D):
317
+ super(WISEAdapter, self).__init__()
318
+
319
+ self.layer = layer
320
+ self.weight = self.layer.weight
321
+ self.device = layer.weight.device
322
+ self.config = config
323
+ self.new_weight = copy.deepcopy(self.weight)
324
+ self.original_layer = copy.deepcopy(self.layer)
325
+ self.memory_weight = []
326
+ self.memory_mean_act = []
327
+ self.merge_cnt = 0 # only for retrieve
328
+ assert not self.weight.requires_grad, print('Original layer must not be trainable....')
329
+
330
+ self.used_mask = None
331
+
332
+ self.training = False
333
+ self.editing = False
334
+ self.conv1D = conv1D
335
+
336
+ self.editing_mean_act = EditingMeanAct()
337
+ self.editing_total_cnt = 0
338
+
339
+ def set_parameter_tunable(self):
340
+ self.new_weight.requires_grad = True
341
+
342
+ def save_weight(self):
343
+ self.memory_weight.append(copy.deepcopy(self.new_weight))
344
+ self.new_weight = copy.deepcopy(self.original_layer.weight)
345
+ if self.config.retrieve:
346
+ self.memory_mean_act.append(copy.deepcopy(self.editing_mean_act))
347
+ self.editing_mean_act = EditingMeanAct()
348
+
349
+ def merge_weight(self):
350
+ if self.config.save_freq is not None: # for ties dare dare_ties
351
+ if not self.config.retrieve:
352
+ merge_alg = merge_dict[self.config.merge_alg]
353
+ if self.original_layer.weight.equal(self.layer.weight):
354
+ cur_new_weight = merge_alg.execute([self.config.weights / len(self.memory_weight) for _ in range(len(self.memory_weight))], self.original_layer.weight, self.memory_weight, densities=self.config.densities)
355
+ else:
356
+ cur_new_weight = merge_alg.execute([0.4 / len(self.memory_weight) for _ in range(len(self.memory_weight))] + [0.6], self.original_layer.weight, self.memory_weight + [self.layer.weight], densities=self.config.densities)
357
+ self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False)
358
+ self.new_weight = copy.deepcopy(self.original_layer.weight)
359
+ del self.memory_weight
360
+ self.memory_weight = []
361
+ else:
362
+ merge_alg = merge_dict[self.config.merge_alg]
363
+ merge_num = self.config.merge_freq // self.config.save_freq
364
+ assert len(self.memory_weight) >= merge_num
365
+ new_merge_weight = merge_alg.execute([self.config.weights / merge_num for _ in range(merge_num)], self.original_layer.weight, self.memory_weight[-merge_num:], densities=self.config.densities)
366
+ min_a = 1e9
367
+ for _ in range(merge_num):
368
+ self.memory_weight.pop()
369
+ edit_act = self.memory_mean_act.pop()
370
+ min_a = min(min_a, edit_act.min_act())
371
+ self.new_weight = copy.deepcopy(self.original_layer.weight)
372
+ self.memory_weight.append(new_merge_weight)
373
+ self.memory_mean_act.append(EditingMeanAct(min_a=min_a))
374
+ print(len(self.memory_weight))
375
+ assert len(self.memory_mean_act) == len(self.memory_weight)
376
+ self.merge_cnt += 1
377
+ else:
378
+ merge_alg = merge_dict[self.config.merge_alg]
379
+ cur_new_weight = merge_alg.execute(0.5, self.layer.weight, [self.new_weight],
380
+ densities=self.config.densities)
381
+ self.layer.weight = torch.nn.Parameter(cur_new_weight.to(self.layer.weight.device), requires_grad=False)
382
+ self.new_weight = copy.deepcopy(self.original_layer.weight)
383
+
384
+ def save_editing_activation(self):
385
+ in_scope_dist = euc(self.original_layer_output[:-1, ...], self.new_weight_layer_output[:-1, ...], self.config)
386
+ self.editing_mean_act.update(in_scope_dist.mean().item())
387
+
388
+ def generate_activation_mask(self, mask_ratio):
389
+ p_grad = self.new_weight.reshape(-1)
390
+ p_mask = np.random.choice([1, 0], size=p_grad.size()[0], p=[mask_ratio, 1 - mask_ratio])
391
+ p_mask = torch.from_numpy(p_mask).to(p_grad.device)
392
+ self.weight_mask = p_mask
393
+
394
+ def generate_non_overlapping_mask(self, mask_ratio):
395
+ p_grad = self.new_weight.reshape(-1)
396
+ mask_size = int(mask_ratio * p_grad.size()[0])
397
+ if self.used_mask is None:
398
+ self.used_mask = np.zeros(p_grad.size()[0], dtype=bool)
399
+ available_indices = np.where(~self.used_mask)[0] # indices of elements not yet masked
400
+ if len(available_indices) < mask_size:
401
+ raise ValueError("Not enough unused elements to generate a new mask.")
402
+ chosen_indices = np.random.choice(available_indices, size=mask_size, replace=False)
403
+ mask_array = np.zeros(p_grad.size()[0], dtype=int)
404
+ mask_array[chosen_indices] = 1
405
+ self.used_mask[chosen_indices] = True # mark these indices as used
406
+ self.weight_mask = torch.from_numpy(mask_array).to(p_grad.device)
407
+
408
+ def new_weight_forward(self, input: Tensor, weight) -> Tensor:
409
+ if self.conv1D:
410
+ size_out = input.size()[:-1] + (weight.size(1),)
411
+ input = torch.addmm(self.original_layer.bias, input.view(-1, input.size(-1)), weight)
412
+ input = input.view(size_out)
413
+ return input
414
+ else:
415
+ return F.linear(input, weight)
416
+
417
+ def mask_new_weight_gradient(self):
418
+ assert self.new_weight.grad is not None, print('Gradient Collection for New Weight error, gradient not found')
419
+ # Add gradient mask after the loss updates
420
+ p_size = self.new_weight.grad.size()
421
+ p_grad = self.new_weight.grad.reshape(-1)
422
+
423
+ # mask = torch.from_numpy(np.random.choice([0, 1], size=p_grad.size()[0], p=[.1, .9])).cuda()
424
+ p_grad = p_grad * self.weight_mask
425
+ self.new_weight.grad = p_grad.view(p_size).to(self.new_weight.grad.dtype)
426
+
427
+ def forward(self, *args):
428
+ if self.editing:
429
+ layer_out = self.new_weight_forward(*args, self.new_weight)
430
+ self.new_weight_layer_output = layer_out
431
+ self.original_layer_output = self.original_layer(*args)
432
+ else:
433
+ if not self.config.retrieve:
434
+ original_layer_output = self.original_layer(*args)
435
+ layer_output = self.layer(*args)
436
+ new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
437
+ dist2 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
438
+ dist1 = euc(original_layer_output, layer_output, self.config, infer=True)
439
+ threshold = self.editing_mean_act.min_act() * self.config.act_ratio
440
+
441
+ if dist1.item() < threshold and dist2.item() < threshold:
442
+ layer_out = original_layer_output
443
+ elif dist1.item() > dist2.item():
444
+ layer_out = layer_output
445
+ else:
446
+ layer_out = new_weight_layer_output
447
+ else:
448
+ original_layer_output = self.original_layer(*args)
449
+ new_weight_layer_output = self.new_weight_forward(*args, self.new_weight)
450
+ dist1 = euc(original_layer_output, new_weight_layer_output, self.config, infer=True)
451
+ threshold = self.editing_mean_act.min_act() * self.config.act_ratio
452
+ min_dist = dist1
453
+ if min_dist.item() < threshold:
454
+ layer_out = original_layer_output
455
+ else:
456
+ layer_out = new_weight_layer_output
457
+
458
+ for i in range(len(self.memory_weight)):
459
+ memory_retrieve_weight = self.memory_weight[i]
460
+ memory_weight_layer_output = self.new_weight_forward(*args, memory_retrieve_weight)
461
+ dist = euc(original_layer_output, memory_weight_layer_output, self.config, infer=True)
462
+ if dist > min_dist and dist > self.memory_mean_act[i].min_act() * self.config.act_ratio:
463
+ layer_out = memory_weight_layer_output
464
+ min_dist = dist
465
+ print(dist, self.memory_mean_act[i].min_act() * self.config.act_ratio)
466
+ return layer_out
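
The inference-time routing in WISEAdapter.forward boils down to comparing activation distances against a threshold derived from the edits seen so far. A toy sketch of that decision rule (the numbers are invented just to show the branches):

act_ratio, min_edit_act = 0.88, 12.0
threshold = min_edit_act * act_ratio            # editing_mean_act.min_act() * act_ratio

dist_main = 3.1      # ||act(original) - act(merged main weight)||
dist_side = 14.7     # ||act(original) - act(side-memory new_weight)||

if dist_main < threshold and dist_side < threshold:
    route = "original weights"                  # query looks untouched by any edit
elif dist_main > dist_side:
    route = "merged main weights"
else:
    route = "side-memory weights"
print(route)                                    # -> side-memory weights
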
easyeditor/models/wise/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .wise_main import apply_wise_to_model
2
+ from .wise_hparams import WISEHyperParams
easyeditor/models/wise/merge/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .slerp import slerp
2
+ from .gta import GTA
3
+ from .linear import linear
easyeditor/models/wise/merge/gta.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Dict, Union, Tuple, List, Any, Literal, Optional
2
+ import torch
3
+ import numpy as np
4
+
5
+ from .utils import rescaled_random, magnitude, random_wo_rescaled
6
+
7
+
8
+ class GTA:
9
+ def __init__(self, sparsify_method=None, consensus_method=None, normalize=False):
10
+ self.sparsify_method = sparsify_method
11
+ self.consensus_method = consensus_method
12
+
13
+ self.normalize = normalize
14
+
15
+ def execute(
16
+ self,
17
+ weights,
18
+ base,
19
+ tensors,
20
+ densities,
21
+ **_kwargs,
22
+ ) -> torch.Tensor:
23
+ # collect task vectors
24
+ densities = [densities for _ in range(len(tensors))]
25
+ # weights = [weights / len(tensors) for _ in range(len(tensors))]
26
+ assert len(densities) == len(weights) == len(tensors)
27
+ deltas, base = get_task_vectors(base, tensors)
28
+ if not deltas:
29
+ return base
30
+
31
+ # sparsify
32
+ if self.sparsify_method:
33
+ if self.sparsify_method == 'magnitude':
34
+ sparsify = magnitude
35
+ elif self.sparsify_method == 'rescaled_random':
36
+ sparsify = rescaled_random
37
+ elif self.sparsify_method == 'random':
38
+ sparsify = random_wo_rescaled
39
+ else:
40
+ raise NotImplementedError
41
+ for i, delta in enumerate(deltas):
42
+ deltas[i] = sparsify(
43
+ delta,
44
+ density=densities[i]
45
+ )
46
+
47
+ deltas = torch.stack(deltas, dim=0)
48
+ weights = torch.tensor(
49
+ [_ for _ in weights], dtype=deltas.dtype, device=deltas.device
50
+ )
51
+ while len(deltas.shape) > len(weights.shape):
52
+ weights.unsqueeze_(-1)
53
+
54
+ weighted_deltas = deltas * weights
55
+
56
+ # get sign consensus and mix deltas
57
+ if self.consensus_method:
58
+ mask_dtype = base.dtype
59
+ mask = get_mask(
60
+ weighted_deltas,
61
+ method=self.consensus_method,
62
+ mask_dtype=mask_dtype,
63
+ )
64
+ mixed_delta = (weighted_deltas * mask).sum(dim=0)
65
+ divisor = (weights * mask).sum(dim=0)
66
+ divisor[divisor == 0] = 1
67
+ else:
68
+ mixed_delta = weighted_deltas.sum(dim=0)
69
+ divisor = weights.sum(dim=0)
70
+ divisor[divisor.abs() < 1e-8] = 1
71
+
72
+ if self.normalize:
73
+ mixed_delta /= divisor
74
+
75
+ return (base + mixed_delta).to(base.dtype)
76
+
77
+ def get_task_vectors(
78
+ base: Union[np.ndarray, torch.Tensor],
79
+ tensors: Union[List[np.ndarray], List[torch.Tensor]],
80
+ ) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
81
+
82
+ res = []
83
+ for x in tensors:
84
+ delta = x - base
85
+ del x
86
+ res.append(delta)
87
+ return res, base
88
+
89
+ def get_mask(
90
+ delta: torch.Tensor,
91
+ method: Literal["sum", "count"] = "sum",
92
+ mask_dtype: Optional[torch.dtype] = None,
93
+ ):
94
+ """Returns a mask determining which delta vectors should be merged
95
+ into the final model.
96
+
97
+ For the methodology described in the TIES paper use 'sum'. For a
98
+ simpler naive count of signs, use 'count'."""
99
+ if mask_dtype is None:
100
+ mask_dtype = delta.dtype
101
+
102
+ sign = delta.sign().to(mask_dtype)
103
+
104
+ if method == "sum":
105
+ sign_weight = delta.sum(dim=0)
106
+ majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1
107
+ del sign_weight
108
+ elif method == "count":
109
+ majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1
110
+ else:
111
+ raise RuntimeError(f'Unimplemented mask method "{method}"')
112
+
113
+ return sign == majority_sign
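
A minimal usage sketch of the TIES-style merge implemented above (illustrative values; the easyeditor.models.wise.merge.gta import path simply mirrors this file's location and is an assumption):

import torch
from easyeditor.models.wise.merge.gta import GTA  # assumed import path for this module

base = torch.zeros(4, 4)                        # shared base weight
tensors = [base + 0.5 * torch.randn(4, 4),      # two edited copies of that weight
           base + 0.5 * torch.randn(4, 4)]

merger = GTA(sparsify_method='magnitude', consensus_method='sum', normalize=True)
merged = merger.execute(weights=[1.0, 1.0], base=base, tensors=tensors, densities=0.5)
# each task vector keeps its top-50% entries by magnitude, signs are resolved by
# majority vote, and the averaged delta is added back onto `base`
print(merged.shape)  # torch.Size([4, 4])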
easyeditor/models/wise/merge/linear.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Union, List
4
+
5
+ class linear:
6
+ def __init__(self):
7
+ pass
8
+ def execute(
9
+ self,
10
+ t: Union[float, List[float]],
11
+ v0: Union[List[torch.Tensor], torch.Tensor],
12
+ v1: Union[List[torch.Tensor], torch.Tensor],
13
+ DOT_THRESHOLD: float = 0.9995,
14
+ eps: float = 1e-8,
15
+ densities = None,
16
+ ):
17
+ if type(v0) is list:
18
+ v0 = v0[0]
19
+ if type(t) is list:
20
+ t = t[0]
21
+ if type(v1) is list:
22
+ v1 = v1[0]
23
+
24
+ return t * v1 + (1.0 - t) * v0
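
For comparison, a short sketch of the linear merge above (plain interpolation, no sparsification; illustrative values, assumed import path):

import torch
from easyeditor.models.wise.merge.linear import linear  # assumed import path

v0, v1 = torch.zeros(3), torch.ones(3)
print(linear().execute(t=0.25, v0=v0, v1=v1))  # tensor([0.2500, 0.2500, 0.2500]) = 0.25*v1 + 0.75*v0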
easyeditor/models/wise/merge/slerp.py ADDED
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Union, List
4
+
5
+ def lerp(
6
+ t: float, v0: Union[np.ndarray, torch.Tensor], v1: Union[np.ndarray, torch.Tensor]
7
+ ) -> Union[np.ndarray, torch.Tensor]:
8
+ return (1 - t) * v0 + t * v1
9
+
10
+ def maybe_torch(v: np.ndarray, is_torch: bool):
11
+ if is_torch:
12
+ return torch.from_numpy(v)
13
+ return v
14
+
15
+
16
+ def normalize(v: np.ndarray, eps: float):
17
+ norm_v = np.linalg.norm(v)
18
+ if norm_v > eps:
19
+ v = v / norm_v
20
+ return v
21
+
22
+ class slerp:
23
+ def __init__(self):
24
+ pass
25
+ def execute(
26
+ self,
27
+ t: Union[float, List[float]],
28
+ v0: Union[List[torch.Tensor], torch.Tensor],
29
+ v1: Union[List[torch.Tensor], torch.Tensor],
30
+ DOT_THRESHOLD: float = 0.9995,
31
+ eps: float = 1e-8,
32
+ densities = None,
33
+ ):
34
+ if type(v0) is list:
35
+ v0 = v0[0]
36
+ if type(v1) is list:
37
+ v1 = v1[0]
38
+ if type(t) is list:
39
+ t = t[0]
40
+ """
41
+ Spherical linear interpolation
42
+
43
+ From: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
44
+ Args:
45
+ t (float/np.ndarray): Float value between 0.0 and 1.0
46
+ v0 (np.ndarray): Starting vector
47
+ v1 (np.ndarray): Final vector
48
+ DOT_THRESHOLD (float): Threshold for considering the two vectors as
49
+ colinear. Not recommended to alter this.
50
+ Returns:
51
+ v2 (np.ndarray): Interpolation vector between v0 and v1
52
+ """
53
+ is_torch = False
54
+ if not isinstance(v0, np.ndarray):
55
+ is_torch = True
56
+ v0 = v0.detach().cpu().float().numpy()
57
+ if not isinstance(v1, np.ndarray):
58
+ is_torch = True
59
+ v1 = v1.detach().cpu().float().numpy()
60
+
61
+ # Copy the vectors to reuse them later
62
+ v0_copy = np.copy(v0)
63
+ v1_copy = np.copy(v1)
64
+
65
+ # Normalize the vectors to get the directions and angles
66
+ v0 = normalize(v0, eps)
67
+ v1 = normalize(v1, eps)
68
+
69
+ # Dot product with the normalized vectors (can't use np.dot in W)
70
+ dot = np.sum(v0 * v1)
71
+
72
+ # If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
73
+ if np.abs(dot) > DOT_THRESHOLD:
74
+ res = lerp(t, v0_copy, v1_copy)
75
+ return maybe_torch(res, is_torch)
76
+
77
+ # Calculate initial angle between v0 and v1
78
+ theta_0 = np.arccos(dot)
79
+ sin_theta_0 = np.sin(theta_0)
80
+
81
+ # Angle at timestep t
82
+ theta_t = theta_0 * t
83
+ sin_theta_t = np.sin(theta_t)
84
+
85
+ # Finish the slerp algorithm
86
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
87
+ s1 = sin_theta_t / sin_theta_0
88
+ res = s0 * v0_copy + s1 * v1_copy
89
+
90
+ return maybe_torch(res, is_torch)
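
A small sketch of the spherical interpolation above: for orthogonal unit vectors the midpoint weights both inputs by sin(pi/4) ≈ 0.7071 (illustrative, assumed import path):

import torch
from easyeditor.models.wise.merge.slerp import slerp  # assumed import path

v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])
print(slerp().execute(t=0.5, v0=v0, v1=v1))  # ~tensor([0.7071, 0.7071])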
easyeditor/models/wise/merge/utils.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+ def magnitude(tensor: torch.Tensor, density: float) -> torch.Tensor:
4
+ """Masks out the smallest values, retaining a proportion of `density`."""
5
+ if density >= 1:
6
+ return tensor
7
+
8
+ k = int(density * tensor.view(-1).shape[0])
9
+
10
+ assert k > 0, "not gonna zero out the whole tensor buddy"
11
+ mask = torch.zeros_like(tensor)
12
+ w = tensor.abs().view(-1)
13
+ if w.device.type == "cpu":
14
+ w = w.float()
15
+ topk = torch.topk(w, k=k, largest=True)
16
+ mask.view(-1)[topk.indices] = 1
17
+
18
+ return tensor * mask
19
+
20
+
21
+ def bernoulli(
22
+ tensor: torch.Tensor, density: float, rescale: bool = True
23
+ ) -> torch.Tensor:
24
+ if density >= 1:
25
+ return tensor
26
+
27
+ if (tensor.device.type != "cpu") or tensor.dtype == torch.bfloat16:
28
+ work_dtype = tensor.dtype
29
+ else:
30
+ # torch.bernoulli not implemented for float16 on CPU, upcast to float32
31
+ work_dtype = torch.float32
32
+
33
+ mask = torch.bernoulli(
34
+ torch.full_like(input=tensor, fill_value=density, dtype=work_dtype)
35
+ )
36
+ res = tensor.to(work_dtype) * mask
37
+ if rescale:
38
+ res /= density
39
+ return res.to(tensor.dtype)
40
+
41
+ def rescaled_random(tensor: torch.Tensor, density: float):
42
+ return bernoulli(tensor, density, rescale=True)
43
+
44
+ def random_wo_rescaled(tensor: torch.Tensor, density: float):
45
+ return bernoulli(tensor, density, rescale=False)
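
A quick sketch contrasting the two sparsifiers above (assumed import path; the random variant is stochastic, so its output varies):

import torch
from easyeditor.models.wise.merge.utils import magnitude, rescaled_random  # assumed import path

t = torch.tensor([0.1, -2.0, 0.3, 4.0])
print(magnitude(t, density=0.5))        # keeps the two largest-magnitude entries: tensor([0., -2., 0., 4.])
print(rescaled_random(t, density=0.5))  # each entry kept with prob. 0.5, survivors rescaled by 1/0.5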
easyeditor/models/wise/utils.py ADDED
@@ -0,0 +1,213 @@
1
+ import transformers
2
+ import torch
3
+ import os
4
+ import struct
5
+ import random
6
+
7
+ CONTEXT_TEMPLATES_CACHE = None
8
+
9
+ def find_sublist_start_index(list1, list2):
10
+ for i in range(len(list1) - len(list2)+1):
11
+ if all(a == b for a, b in zip(list1[i:i+len(list2)], list2)):
12
+ return i
13
+ return None
14
+
15
+ def get_inner_params(named_parameters, inner_names):
16
+ param_dict = dict(named_parameters)
17
+ return [(n, param_dict[n]) for n in inner_names]
18
+
19
+ def param_subset(named_parameters, inner_names):
20
+ param_dict = dict(named_parameters)
21
+ return [param_dict[n] for n in inner_names]
22
+
23
+ def print_trainable_parameters(model, new_weight, mask_ratio):
24
+ original_parameters = 0
25
+ new_weight_param = 0
26
+ for _, param in new_weight.named_parameters():
27
+ new_weight_param += param.numel()
28
+ for _, param in model.named_parameters():
29
+ original_parameters += param.numel()
30
+ print(f"Original Model params: {original_parameters} || New Weight params: {new_weight_param} || trainable%: {100 * new_weight_param * (1-mask_ratio) / original_parameters}")
31
+
32
+
33
+ def parent_module(model, pname):
34
+ components = pname.split('.')
35
+ parent = model
36
+
37
+ for component in components[:-1]:
38
+ if hasattr(parent, component):
39
+ parent = getattr(parent, component)
40
+ elif component.isdigit():
41
+ parent = parent[int(component)]
42
+ else:
43
+ raise RuntimeError(f"Couldn't find child module {component}")
44
+
45
+ if not hasattr(parent, components[-1]):
46
+ raise RuntimeError(f"Couldn't find child module {components[-1]}")
47
+
48
+ return parent
49
+
50
+ def uuid(digits=4):
51
+ if not hasattr(uuid, "uuid_value"):
52
+ uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
53
+
54
+ return uuid.uuid_value
55
+
56
+ def ckpt_dir():
57
+ """returns the directory in which to store model checkpoints"""
58
+ path = "./ckpts/"
59
+ if not os.path.exists(path):
60
+ os.makedirs(path)
61
+ return path
62
+
63
+ def brackets_to_periods(name):
64
+ return name.replace("[", ".").replace("]", "")
65
+
66
+ def get_params(model):
67
+ return model.state_dict()
68
+
69
+ def get_shape(p, model):
70
+ # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear
71
+ return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
72
+
73
+ def get_logits(x):
74
+ return x.logits if hasattr(x, "logits") else x
75
+
76
+
77
+ LOC_PROMPTS = ['nq question: who played mr grainger in are you being served Arthur Brough',
78
+ "nq question: who sings the song let's hear it for the boy Deniece Williams",
79
+ "nq question: who wrote all my ex's live in texas Sanger D. Shafer",
80
+ "nq question: when is the america's got talent finale 2018 September 19, 2018",
81
+ "nq question: what is the fifth biggest state in the united states New Mexico",
82
+ "nq question: who plays john black on days of our lives Drake Hogestyn (/ˈhʌdʒstən/; born Donald Drake Hogestyn",
83
+ "nq question: what is the name of the new star wars movie The Last Jedi",
84
+ "nq question: what is the main principle of path-goal theory a leader's behavior is contingent to the satisfaction, motivation and performance of his or her subordinates",
85
+ "nq question: who plays luna's dad in harry potter Ifans",
86
+ "nq question: who has the most grammy nominations as an artist Quincy Jones",
87
+ "nq question: what is the control unit function in the cpu tells the computer's memory, arithmetic/logic unit and input and output devices how to respond to the instructions that have been sent to the processor",
88
+ "nq question: who was the first indian prime minister to visit palestine Narendra Modi",
89
+ "nq question: where did the plane carrying the marshall football team crash into a hill just short of the Tri-State Airport",
90
+ "nq question: what movie is the line lighten up francis from Stripes",
91
+ "nq question: set of rules for solving a mathematical or computational problem in finite number of steps an algorithm",
92
+ "nq question: who changed indian capital from calcutta to delhi George V",
93
+ "nq question: who did bette midler play in the rose Mary Rose Foster (The Rose)",
94
+ "nq question: how much did it cost to make the new star wars movie $200–217 million"
95
+ ]
96
+
97
+ def tokenize(batch, tokenizer, device, context_templates=None, hparams=None):
98
+ prompt, label = batch["prompt"], batch["target_new"]
99
+ batch['loc_prompt'] = random.choice(LOC_PROMPTS)
100
+ if not isinstance(prompt, list):
101
+ prompt=[prompt]
102
+ if not isinstance(label, list):
103
+ label=[label]
104
+ mask_token = -100 # ignore_index of CrossEntropyLoss
105
+
106
+ # input
107
+ full_prompt = [f"{templ.format(p + ' ' + l)}" for p, l in zip(prompt, label) for templ in context_templates]
108
+ full_prompt += [batch['loc_prompt']] # add for subject activation
109
+
110
+ prompt_ids = tokenizer([f"{templ.format(p)}" for p in prompt for templ in context_templates], return_tensors="pt", padding=True, truncation=True)["input_ids"]
111
+
112
+ num_prompt_toks = [len(i) for i in prompt_ids]
113
+ tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
114
+ tokens["labels"] = tokens["input_ids"].clone()
115
+ if hparams.objective_optimization == 'only_label':
116
+ for i in range(len(num_prompt_toks)):
117
+ tokens["labels"][i][:num_prompt_toks[i]] = mask_token
118
+
119
+ tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
120
+ if batch['loc_prompt'] in batch['prompt']: ## subject: Factual Editing
121
+ subject_token = tokenizer.encode(' ' + batch['loc_prompt'], add_special_tokens=False)
122
+ subject_token1 = tokenizer.encode(batch['loc_prompt'], add_special_tokens=False)
123
+ subject_length = len(subject_token)
124
+ act_mask = torch.zeros_like(tokens['input_ids'][:-1])
125
+ deact_mask = torch.zeros_like(tokens['input_ids'][:-1])
126
+ for i, token in enumerate(tokens['input_ids'][:-1]):
127
+ start_idx = find_sublist_start_index(token.detach().cpu().numpy().tolist(), subject_token)
128
+ if start_idx is None:
129
+ start_idx = find_sublist_start_index(token.detach().cpu().numpy().tolist(), subject_token1)
130
+ subject_length = len(subject_token1)
131
+ act_mask[i][start_idx: start_idx + subject_length] = 1
132
+ deact_mask[i][:start_idx] = 1
133
+ deact_mask[i][start_idx + subject_length:] = 1
134
+
135
+ act_mask = act_mask.to(device)
136
+ deact_mask = deact_mask.to(device)
137
+ else: # General Editing
138
+ act_mask = None
139
+ deact_mask = None
140
+
141
+ tokens = {f"{k1}" : v1.to(device) for k1, v1 in tokens.items()}
142
+ return tokens, act_mask, deact_mask
143
+
144
+ class EarlyStopMeter:
145
+ """Computes and stores the average and current value"""
146
+
147
+ def __init__(self):
148
+ self.reset()
149
+
150
+ def reset(self):
151
+ self.avg = 0
152
+ self.pre = 0
153
+ self.val = 1e9
154
+ self.sum = 0
155
+ self.count = 0
156
+
157
+ def update(self, val):
158
+ self.pre = self.val
159
+ self.val = val
160
+ self.sum += val
161
+ self.count += 1
162
+ self.avg = self.sum / self.count
163
+
164
+ def stop(self, ):
165
+ return abs(self.val - self.pre) <= 1e-4 and self.val <= 0.02
166
+
167
+ class EditingMeanAct:
168
+ """Computes and stores the average and current value"""
169
+
170
+ def __init__(self, min_a=1e9):
171
+ self.reset(min_a=min_a)
172
+
173
+ def reset(self, min_a=1e9):
174
+ self.avg = 0
175
+ self.count = 0
176
+ self.sum = 0
177
+ self.min_a = min_a
178
+
179
+ def update(self, val):
180
+ self.sum += val
181
+ self.count += 1
182
+ self.avg = self.sum / self.count
183
+ self.min_a = min(self.min_a, val)
184
+
185
+ def mean_act(self):
186
+ return self.avg
187
+ def min_act(self):
188
+ return self.min_a
189
+
190
+ def get_context_templates(model, tok, length_params, device):
191
+ global CONTEXT_TEMPLATES_CACHE
192
+
193
+ if CONTEXT_TEMPLATES_CACHE is None:
194
+ CONTEXT_TEMPLATES_CACHE = []
195
+ prompt_tok = tok(
196
+ ["I", "You", "Because", 'Yes', 'Q: '],
197
+ padding=True,
198
+ return_tensors="pt"
199
+ ).to(device)
200
+ for length, n_gen in length_params:
201
+
202
+ gen_token = model.generate(
203
+ input_ids=prompt_tok['input_ids'],
204
+ attention_mask=prompt_tok['attention_mask'],
205
+ max_new_tokens=length,
206
+ num_beams=n_gen // 5,
207
+ num_return_sequences=n_gen // 5,
208
+ pad_token_id=tok.eos_token_id,
209
+ )
210
+ CONTEXT_TEMPLATES_CACHE += tok.batch_decode(gen_token, skip_special_tokens=True)
211
+ CONTEXT_TEMPLATES_CACHE = ['{}'] + [_ + ' {}' for _ in CONTEXT_TEMPLATES_CACHE]
212
+ return CONTEXT_TEMPLATES_CACHE
213
+
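
A few sanity checks for the small helpers above (assumed import path; expected outputs shown in comments):

from easyeditor.models.wise.utils import find_sublist_start_index, brackets_to_periods  # assumed path

print(find_sublist_start_index([5, 1, 2, 3, 9], [2, 3]))        # 2
print(find_sublist_start_index([5, 1, 2, 3, 9], [7]))           # None
print(brackets_to_periods("transformer.h[8].mlp.c_fc.weight"))  # transformer.h.8.mlp.c_fc.weight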
easyeditor/models/wise/wise_hparams.py ADDED
@@ -0,0 +1,56 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+ from ...util.hparams import HyperParams
4
+ import yaml
5
+
6
+
7
+ @dataclass
8
+ class WISEHyperParams(HyperParams):
9
+ # Experiments
10
+
11
+ edit_lr: float
12
+ n_iter: int
13
+ # Method
14
+ objective_optimization: str
15
+ mask_ratio: float
16
+ alpha: float # act_margin[0]
17
+ beta: float # act_margin[1]
18
+ gamma: float # act_margin[2]
19
+ act_ratio: float
20
+ merge_freq: int
21
+ retrieve: bool
22
+ replay: bool
23
+ save_freq: Union[int, None]
24
+ merge_alg: str
25
+ norm_constraint: float
26
+ # Module templates
27
+ inner_params: List[str]
28
+ weights: Union[float, None]
29
+ densities: Union[float, None]
30
+
31
+ device: int
32
+ alg_name: str
33
+ model_name: str
34
+
35
+ # Defaults
36
+ batch_size: int = 1
37
+ max_length: int = 30
38
+ model_parallel: bool = False
39
+
40
+ @classmethod
41
+ def from_hparams(cls, hparams_name_or_path: str):
42
+ if '.yaml' not in hparams_name_or_path:
43
+ hparams_name_or_path = hparams_name_or_path + '.yaml'
44
+
45
+ with open(hparams_name_or_path, "r") as stream:
46
+ config = yaml.safe_load(stream)
47
+ config = super().construct_float_from_scientific_notation(config)
48
+
49
+ assert config['merge_freq'] % config['save_freq'] == 0, 'merge_freq needs to be divisible by save_freq (e.g. 1000 / 500)'
50
+ assert len(config['act_margin']) == 3
51
+ config['alpha'], config['beta'], config['gamma'] = config['act_margin'][0], config['act_margin'][1], config['act_margin'][2]
52
+ config.pop('act_margin')
53
+
54
+ assert (config and config['alg_name'] == 'WISE'), \
55
+ f'WISEHyperParams can not load from {hparams_name_or_path}. alg_name is {config["alg_name"]}'
56
+ return cls(**config)
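
A loading sketch for the hyperparameter class above, using the hparams/WISE/gpt2.yaml added later in this commit; note how act_margin is unpacked into alpha/beta/gamma:

from easyeditor import WISEHyperParams  # exported at package level (see utils.py below)

hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2")  # '.yaml' is appended automatically
print(hparams.alg_name, hparams.edit_lr)                       # WISE 1.0
print(hparams.alpha, hparams.beta, hparams.gamma)              # 15.0 40.0 20.0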
easyeditor/models/wise/wise_main.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import Any, Dict, List, Tuple
2
+ from copy import deepcopy
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from .WISE import WISE
5
+ from .utils import tokenize, get_context_templates
6
+ from .wise_hparams import WISEHyperParams
7
+ import gradio as gr
8
+
9
+ def apply_wise_to_model(
10
+ model: AutoModelForCausalLM,
11
+ tok: AutoTokenizer,
12
+ request: Dict,
13
+ hparams: WISEHyperParams,
14
+ num_steps: int,
15
+ edit_lr: float,
16
+ copy=False,
17
+ return_orig_weights=False,
18
+ keep_original_weight=False,
19
+ **kwargs: Any,
20
+ ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
21
+ if copy:
22
+ model = deepcopy(model)
23
+ weights_copy = {}
24
+ hparams.n_iter = num_steps
25
+ hparams.edit_lr = edit_lr
26
+ context_templates = get_context_templates(model, tok, length_params=[[5,5], [10,5]], device=hparams.device)
27
+ editor = WISE(model=model, config=hparams, device=hparams.device)
28
+ print(
29
+ f"Executing WISE algorithm for the update: "
30
+ f"[{request['prompt']}] -> [{request['target_new']}]"
31
+ )
32
+ tokens, act_mask, deact_mask = tokenize(request, tokenizer=tok, device=hparams.device, context_templates=context_templates, hparams=hparams)
33
+ editor.edit(config=hparams, tokens=tokens, act_mask=act_mask, deact_mask=deact_mask)
34
+
35
+ editor.to('cpu')
36
+ gr.Info("Completed editing via WISE!")
37
+
38
+ return editor
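
A hedged sketch of calling apply_wise_to_model directly, mirroring the call pattern used in utils.py below (assumes a local GPT-2 copy at ./models/gpt2 and the WISE yaml added in this commit):

from transformers import AutoModelForCausalLM, GPT2Tokenizer
from easyeditor import apply_wise_to_model, WISEHyperParams

model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
tok.pad_token_id = tok.eos_token_id
hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
request = {"prompt": "Who is the architect for Toodyay Fire Station?", "target_new": "Wong Tung & Sons"}
editor = apply_wise_to_model(model, tok, request, hparams, num_steps=40, edit_lr=1.0)  # returns the WISE editor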
easyeditor/util/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (216 Bytes)
 
easyeditor/util/__pycache__/hparams.cpython-39.pyc DELETED
Binary file (1.21 kB)
 
easyeditor/util/__pycache__/logit_lens.cpython-39.pyc DELETED
Binary file (3.36 kB)
 
easyeditor/util/__pycache__/nethook.cpython-39.pyc DELETED
Binary file (13.2 kB)
 
hparams/GRACE/gpt2.yaml CHANGED
@@ -7,7 +7,7 @@ inner_params:
7
 
8
  edit_lr: 1.0
9
  n_iter: 30
10
- eps: 1.0
11
  dist_fn: euc # euc, mmd, cos
12
  val_init: cold # cold, warm
13
  val_train: sgd # sgd, pert
 
7
 
8
  edit_lr: 1.0
9
  n_iter: 30
10
+ eps: 500.0
11
  dist_fn: euc # euc, mmd, cos
12
  val_init: cold # cold, warm
13
  val_train: sgd # sgd, pert
hparams/ROME/gpt2.yaml ADDED
@@ -0,0 +1,26 @@
1
+ alg_name: "ROME"
2
+ model_name: "./hugging_cache/gpt2-xl"
3
+ stats_dir: "./data/stats"
4
+ device: cpu
5
+ layers: [5]
6
+ fact_token: "subject_last"
7
+ v_num_grad_steps: 20
8
+ v_lr: 5e-1
9
+ v_loss_layer: 11
10
+ v_weight_decay: 0.5
11
+ clamp_norm_factor: 4
12
+ kl_factor: 0.0625
13
+ mom2_adjustment: false
14
+ context_template_length_params: [[5, 10], [10, 10]]
15
+ rewrite_module_tmp: "transformer.h.{}.mlp.c_proj"
16
+ layer_module_tmp: "transformer.h.{}"
17
+ mlp_module_tmp: "transformer.h.{}.mlp"
18
+ attn_module_tmp: "transformer.h.{}.attn"
19
+ ln_f_module: "transformer.ln_f"
20
+ lm_head_module: "transformer.wte"
21
+ mom2_dataset: "wikipedia"
22
+ mom2_n_samples: 100000
23
+ mom2_dtype: "float32"
24
+ model_parallel: false
25
+ fp16: false
26
+
hparams/WISE/gpt2.yaml ADDED
@@ -0,0 +1,27 @@
1
+ alg_name: "WISE"
2
+ model_name: "./hugging_cache/gpt2"
3
+ device: cpu
4
+
5
+ mask_ratio: 0.2
6
+ edit_lr: 1.0
7
+ n_iter: 40
8
+ norm_constraint: 1.0
9
+ act_margin: [15.0, 40.0, 20.0] # alpha, beta, gamma
10
+ act_ratio: 0.7
11
+ save_freq: 1
12
+ merge_freq: 1
13
+ merge_alg: 'ties'
14
+ objective_optimization: 'only_label'
15
+ inner_params:
16
+ - transformer.h[8].mlp.c_fc.weight
17
+
18
+
19
+ ## alternative: WISE-Merge, WISE-Retrieve
20
+
21
+ # for merge (if merge)
22
+ densities: 0.53
23
+ weights: 1.0
24
+
25
+ # for retrieve (to enable WISE-Retrieve, set retrieve to True)
26
+ retrieve: True
27
+ replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3
utils.py CHANGED
@@ -1,42 +1,233 @@
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
2
  from transformers import GPT2TokenizerFast, GPT2Tokenizer
3
- from easyeditor import apply_grace_to_model, GraceHyperParams,nethook
4
  import torch
5
  import gradio as gr
6
 
7
 
 
 
 
 
8
 
9
- def edit(prompt, target_new, num_steps, replacement):
10
  request={"prompt":prompt,"target_new":target_new}
11
  hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
12
 
13
- model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
14
  tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
15
  tok.pad_token_id = tok.eos_token_id
16
  global edit_model
17
- edit_model = apply_grace_to_model(model,tok,request,hparams, num_steps, replacement)
18
- return prompt
19
 
20
- def generate(input_text, target_new=None):
21
  tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
22
- hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
23
  tok.pad_token_id = tok.eos_token_id
24
-
25
  global edit_model
26
-
27
- if target_new is None:
28
- max_new_tokens = 25
 
 
 
 
 
 
29
  else:
 
 
30
  max_new_tokens = len(tok.encode(target_new))
31
- prompt_len = len(input_text)
32
- input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
33
- edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
34
- edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
35
- torch.cuda.empty_cache()
36
-
37
- ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
38
- ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
39
- ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
40
- ori_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(ori_reply)]
41
- edit_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(edit_reply)]
42
- return ori_reply, edit_reply
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
2
  from transformers import GPT2TokenizerFast, GPT2Tokenizer
3
+ from easyeditor import apply_grace_to_model, GraceHyperParams,nethook, apply_wise_to_model, WISEHyperParams, ROMEHyperParams, apply_rome_to_model
4
  import torch
5
  import gradio as gr
6
+ import json
7
+ import numpy as np
8
+ import random
9
+ seed=0
10
+ random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ np.random.seed(seed)
13
+ torch.cuda.manual_seed_all(seed)
14
+ model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
15
 
16
 
17
+ def clear():
18
+ global model
19
+ model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
20
+ return '', ''
21
 
22
+ def grace_edit(prompt, target_new, num_steps, edit_lr):
23
  request={"prompt":prompt,"target_new":target_new}
24
  hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
25
 
 
26
  tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
27
  tok.pad_token_id = tok.eos_token_id
28
  global edit_model
29
+ edit_model = apply_grace_to_model(model,tok,request,hparams, num_steps, edit_lr)
30
+ return prompt, target_new
31
+
32
+ def wise_edit(prompt, target_new, num_steps, edit_lr):
33
+ request={"prompt":prompt,"target_new":target_new}
34
+ hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
35
+
36
+ tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
37
+ tok.pad_token_id = tok.eos_token_id
38
+ global edit_model
39
+ edit_model = apply_wise_to_model(model,tok,request,hparams, num_steps, edit_lr)
40
+ return prompt, target_new
41
+
42
+ def rome_edit(prompt, target_new, num_steps, edit_lr):
43
+ request={"prompt":prompt,"target_new":target_new}
44
+ hparams = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")
45
 
 
46
  tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
 
47
  tok.pad_token_id = tok.eos_token_id
 
48
  global edit_model
49
+ edit_model = apply_rome_to_model(model,tok,request,hparams, num_steps, edit_lr)
50
+ return prompt, target_new
51
+
52
+ def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
53
+ if edit_alg == 'GRACE':
54
+ return grace_edit(prompt, target_new, num_steps, edit_lr)
55
+ elif edit_alg == 'WISE':
56
+ return wise_edit(prompt, target_new, num_steps, edit_lr)
57
+ elif edit_alg == 'ROME':
58
+ return rome_edit(prompt, target_new, num_steps, edit_lr)
59
+ else:
60
+ raise NotImplementedError
61
+
62
+ def generate(input_text, target_new=None, edit_alg=None):
63
+ loc_output = {
64
+ "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
65
+ "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
66
+ "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
67
+ "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
68
+ "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
69
+ }
70
+ tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
71
+ tok.pad_token_id = tok.eos_token_id
72
+ global edit_model
73
+
74
+ if edit_alg == 'GRACE' and target_new is not None:
75
+ max_new_tokens = len(tok.encode(' ' + target_new))
76
+ prompt_len = len(input_text)
77
+ input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
78
+ edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
79
+ edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
80
+ torch.cuda.empty_cache()
81
+
82
+ ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
83
+ ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
84
+ ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
85
+ ori_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(ori_reply)]
86
+ edit_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(edit_reply)]
87
+ return ori_reply, edit_reply
88
+ else:
89
+ if target_new is None:
90
+ target_new = loc_output[input_text]
91
+ max_new_tokens = len(tok.encode(target_new))
92
+ input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
93
+ prompt_len = len(tok.encode(input_text))
94
+ edit_output = edit_model(input_ids=input_ids).logits
95
+ edit_output = torch.argmax(edit_output, dim=-1)
96
+
97
+ edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
98
+ torch.cuda.empty_cache()
99
+
100
+
101
+ ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
102
+ # ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
103
+ # ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
104
+ ori_output = ori_model(input_ids=input_ids).logits
105
+ ori_output = torch.argmax(ori_output, dim=-1)
106
+
107
+ ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
108
+ torch.cuda.empty_cache()
109
+ ori_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(ori_reply)]
110
+ edit_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(edit_reply)]
111
+ return ori_reply, edit_reply
112
+
113
+ def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
114
+ res1, res2 = generate(input_text, target_new=target_new, edit_alg=edit_alg)
115
+ res3, res4 = generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
116
+ return res1, res2, res3, res4
117
+
118
+ # continuous_examples=[
119
+ # ["Who is the architect for Toodyay Fire Station?","Wong Tung & Sons"]
120
+ # ]
121
+
122
+ continuous_examples=[
123
+ ["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
124
+ ["What company makes Springfield Armory XDM?", "Messerschmitt"],
125
+ ["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
126
+ ["What year did Sunnyside Hospital cease to exist?", "1962"],
127
+ ["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
128
+ ["What piece of fiction does Jack Harkness appear in?", "Lost"]
129
+ ]
130
+
131
+
132
+ global grace_hparams
133
+ grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
134
+ global wise_hparams
135
+ wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")
136
+ global tokenizer
137
+ tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
138
+ tokenizer.pad_token_id = tokenizer.eos_token_id
139
+ global grace_continuous_model
140
+ global wise_continuous_model
141
+ grace_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
142
+ wise_continuous_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
143
+
144
+
145
+ for prompt, target_new in continuous_examples:
146
+ request={"prompt":prompt,"target_new":target_new}
147
+ apply_grace_to_model(grace_continuous_model,tokenizer,request,grace_hparams, 40, 1.0)
148
+
149
+ for prompt, target_new in continuous_examples:
150
+ request={"prompt":prompt,"target_new":target_new}
151
+ apply_wise_to_model(wise_continuous_model,tokenizer,request,wise_hparams, 40, 1.0)
152
+
153
+ def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
154
+ global tokenizer
155
+ if edit_alg == 'GRACE':
156
+ request={"prompt":prompt,"target_new":target_new}
157
+ global grace_hparams
158
+
159
+ global grace_continuous_model
160
+ apply_grace_to_model(grace_continuous_model,tokenizer,request,grace_hparams, num_steps, edit_lr)
161
+ return prompt, target_new
162
+ elif edit_alg == 'WISE':
163
+ request={"prompt":prompt,"target_new":target_new}
164
+ global wise_hparams
165
+
166
+ global wise_continuous_model
167
+ apply_wise_to_model(wise_continuous_model,tokenizer,request,wise_hparams, num_steps, edit_lr)
168
+ else:
169
+ raise NotImplementedError
170
+ return prompt, target_new
171
+
172
+ def continuous_generate(input_text, edit_alg=None, target_new=None):
173
+ if edit_alg == 'GRACE':
174
+ global grace_continuous_model
175
+ cur_model = grace_continuous_model
176
+ elif edit_alg == 'WISE':
177
+ global wise_continuous_model
178
+ cur_model = wise_continuous_model
179
  else:
180
+ raise NotImplementedError
181
+ loc_output = {
182
+ "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
183
+ "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
184
+ "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
185
+ "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
186
+ "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
187
+ }
188
+ tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
189
+ tok.pad_token_id = tok.eos_token_id
190
+
191
+ if edit_alg == 'GRACE' and target_new is not None:
192
+ max_new_tokens = len(tok.encode(' ' + target_new))
193
+ prompt_len = len(input_text)
194
+ input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
195
+ edit_output = cur_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id, use_cache=False)
196
+ edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
197
+ torch.cuda.empty_cache()
198
+
199
+ ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
200
+ ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
201
+ ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
202
+ ori_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(ori_reply)]
203
+ edit_reply = [(_, 'output') if i > prompt_len else (_, None) for i, _ in enumerate(edit_reply)]
204
+ return ori_reply, edit_reply
205
+ else:
206
+ if target_new is None:
207
+ target_new = loc_output[input_text]
208
  max_new_tokens = len(tok.encode(target_new))
209
+ input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')
210
+ prompt_len = len(tok.encode(input_text))
211
+ edit_output = cur_model(input_ids=input_ids).logits
212
+ edit_output = torch.argmax(edit_output, dim=-1)
213
+
214
+ edit_reply = input_text + ' ' + tok.decode(edit_output[0][prompt_len-1:-1], skip_special_tokens=True)
215
+ torch.cuda.empty_cache()
216
+
217
+
218
+ ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
219
+ # ori_output = ori_model.generate(tok.encode(input_text, return_tensors='pt').to('cpu'), max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
220
+ # ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
221
+ ori_output = ori_model(input_ids=input_ids).logits
222
+ ori_output = torch.argmax(ori_output, dim=-1)
223
+
224
+ ori_reply = input_text + ' ' + tok.decode(ori_output[0][prompt_len-1:-1], skip_special_tokens=True)
225
+ torch.cuda.empty_cache()
226
+ ori_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(ori_reply)]
227
+ edit_reply = [(_, 'output') if i > len(input_text) else (_, None) for i, _ in enumerate(edit_reply)]
228
+ return ori_reply, edit_reply
229
+
230
+ def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
231
+ res1, res2 = continuous_generate(input_text, target_new=target_new, edit_alg=edit_alg)
232
+ res3, res4 = continuous_generate(para_input_text, target_new=target_new, edit_alg=edit_alg)
233
+ return res1, res2, res3, res4
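
An end-to-end sketch of the continuous-editing helpers defined above: re-apply an edit for one of the pre-loaded examples on top of the already-edited GRACE model, then compare the original and edited generations (assumes the local ./models/gpt2 checkpoint used throughout this file):

prompt, target_new = "What year did Sunnyside Hospital cease to exist?", "1962"
continuous_edit("GRACE", prompt, target_new, num_steps=40, edit_lr=1.0)
ori_reply, edit_reply = continuous_generate(prompt, edit_alg="GRACE", target_new=target_new)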