dar-tau commited on
Commit
9dd96f2
·
verified ·
1 Parent(s): 67a8d63

Create interpret.py

Browse files
Files changed (1) hide show
  1. interpret.py +99 -0
interpret.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from collections import defaultdict
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from contextlib import AbstractContextManager
8
+
9
+
10
+ # helper functions
11
+ def item(x):
12
+ return np.array(x).item()
13
+
14
+ def _prompt_to_parts(prompt, repeat=5):
15
+ # In order to allow easy formatting for prompts, we take string prompts
16
+ # in the format "[INST] [X] [/INST] Sure, I'll summarize this"
17
+ # and split them into a list of strings ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"].
18
+ # Notice how each instance of [X] is replaced by multiple 0 placeholders (according to `~repeat`).
19
+ # This is in line with the SELFIE paper, where each interpreted token is inserted 5 times, probably to make
20
+ # the interpretation less likely to avoid it.
21
+
22
+ split_prompt = re.split(r' *\[X\]', prompt)
23
+ parts = []
24
+ for i in range(len(split_prompt)):
25
+ cur_part = split_prompt[i]
26
+ if cur_part != '':
27
+ # if we have multiple [X] in procession, there will be a '' between them in split_prompt
28
+ parts.append(cur_part)
29
+ if i < len(split_prompt) - 1:
30
+ parts.extend([0] * repeat)
31
+ print('Prompt parts:', parts)
32
+ return parts
33
+
34
+
35
+ class Hook(AbstractContextManager):
36
+ # Hook could be easily absorbed into SubstitutionHook instead, but I like it better to have them both.
37
+ # Seems like the right way from an aesthetic point of view.
38
+ def __init__(self, module, fn):
39
+ self.registered_hook = module.register_forward_hook(fn)
40
+
41
+ def __enter__(self):
42
+ return self
43
+
44
+ def __exit__(self, type, value, traceback):
45
+ self.close()
46
+
47
+ def close(self):
48
+ self.registered_hook.remove()
49
+
50
+
51
+ class SubstitutionHook(Hook):
52
+ # This is where the substitution takes place, and it will be used by InterpretationPrompt later.
53
+ def __init__(self, module, positions_dict, values_dict):
54
+ assert set(positions_dict.keys()) == set(values_dict.keys())
55
+ keys = positions_dict.keys()
56
+
57
+ def fn(module, input, output):
58
+ device = output[0].device
59
+ dtype = output[0].dtype
60
+
61
+ for key in keys:
62
+ num_positions = len(positions_dict[key])
63
+ values = values_dict[key].unsqueeze(1).expand(-1, num_positions, -1) # batch_size x num_positions x hidden_dim
64
+ positions = positions_dict[key]
65
+ print(f'{positions=} {values.shape=} {output[0].shape=}')
66
+ output[0][:, positions, :] = values.to(dtype).to(device)
67
+ self.registered_hook.remove() # in generation with use_cache=True, after the first step the rest of the steps are one at a time
68
+ return output
69
+
70
+ self.registered_hook = module.register_forward_hook(fn)
71
+
72
+
73
+ # functions
74
+ class InterpretationPrompt:
75
+ def __init__(self, tokenizer, prompt, placeholder_token=' '):
76
+ prompt_parts = _prompt_to_parts(prompt)
77
+ if placeholder_token is None:
78
+ placeholder_token_id = tokenizer.eos_token_id
79
+ else:
80
+ placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
81
+ assert placeholder_token_id != tokenizer.eos_token_id
82
+ self.tokens = []
83
+ self.placeholders = defaultdict(list)
84
+ for part in prompt_parts:
85
+ if type(part) == str:
86
+ self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
87
+ elif type(part) == int:
88
+ self.placeholders[part].append(len(self.tokens))
89
+ self.tokens.append(placeholder_token_id)
90
+ else:
91
+ raise NotImplementedError
92
+
93
+ def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
94
+ num_seqs = len(embeds[0]) # assumes the placeholder 0 exists
95
+ tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)])
96
+ module = model.get_submodule(layer_format.format(k=k))
97
+ with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
98
+ generated = model.generate(tokens_batch, **generation_kwargs)
99
+ return generated