Add application file
- easyeditor/__init__.py +2 -0
- easyeditor/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/README.md +6 -0
- easyeditor/models/__init__.py +1 -0
- easyeditor/models/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/grace/GRACE.py +218 -0
- easyeditor/models/grace/__init__.py +2 -0
- easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc +0 -0
- easyeditor/models/grace/__pycache__/utils.cpython-39.pyc +0 -0
- easyeditor/models/grace/grace_hparams.py +48 -0
- easyeditor/models/grace/grace_main.py +38 -0
- easyeditor/models/grace/metrics.py +59 -0
- easyeditor/models/grace/utils.py +86 -0
- easyeditor/util/__init__.py +2 -0
- easyeditor/util/__pycache__/__init__.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/hparams.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/logit_lens.cpython-39.pyc +0 -0
- easyeditor/util/__pycache__/nethook.cpython-39.pyc +0 -0
- easyeditor/util/alg_dict.py +45 -0
- easyeditor/util/alg_train_dict.py +9 -0
- easyeditor/util/generate.py +171 -0
- easyeditor/util/globals.py +43 -0
- easyeditor/util/hparams.py +46 -0
- easyeditor/util/logit_lens.py +97 -0
- easyeditor/util/nethook.py +451 -0
- easyeditor/util/perplexity.py +24 -0
- easyeditor/util/runningstats.py +1883 -0
- hparams/GRACE/README.md +19 -0
- hparams/GRACE/gpt2-xl.yaml +19 -0
- hparams/config.yaml +6 -0
- utils.py +36 -0
easyeditor/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .models import *
+from .util import *
easyeditor/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (174 Bytes).
easyeditor/models/README.md
ADDED
@@ -0,0 +1,6 @@
+We compare ROME against several open sourced state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties.
+- Fine-Tuning (`ft`): Direct fine-tuning.
+- Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363)
+- Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696)
+- Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164)
+- Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309)
easyeditor/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .grace import *
easyeditor/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (164 Bytes).
easyeditor/models/grace/GRACE.py
ADDED
@@ -0,0 +1,218 @@
+import torch
+from .utils import parent_module, brackets_to_periods
+import transformers
+import os
+os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+def euc(query, key):
+    # Euclidean distance
+    if len(key.shape) < 2:
+        key = key.view(1, -1)
+    return torch.cdist(key, query, p=2)
+
+def perturb_values(chosen_value, num_pert, device):
+    # Create a bunch of noised versions of the value, then create batch, then train value
+    chosen_value = chosen_value
+    noise = torch.normal(0, 1, chosen_value.shape, device=device)
+    noise[0] = noise[0]*0
+    noise.requires_grad = True
+    chosen_value = chosen_value + noise
+    return chosen_value
+
+class GRACE(torch.nn.Module):
+    def __init__(self, config, model, device):
+        super(GRACE, self).__init__()
+        self.config = config
+        self.log_dict = {}
+        self.model = model
+        # self.tokenizer = model.tokenizer
+        layer = config.inner_params[0]
+        self.device = device
+
+        # --- ensure proper formatting (GRACE edits ~layers~ not weight matrices) ---
+        suffixes = [".weight", ".bias"]
+        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
+
+        for n, p in self.model.named_parameters():
+            p.requires_grad = False
+
+        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
+            transpose = False
+        else:
+            transpose = True
+
+        # --- Add GRACE to chosen layers ---
+        edit_module = parent_module(self.model, brackets_to_periods(self.layer))
+        layer_name = self.layer.rsplit(".", 1)[-1]
+        original_layer = getattr(edit_module, layer_name)
+
+        if type(original_layer) is not GRACEAdapter:
+            setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
+
+    def __call__(self, **kwargs):
+        # if self.config.task == "hallucination":
+        #     print(kwargs)
+        #     key_id = (kwargs["labels"] == -100).sum() - 1
+        #     setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)  # Tell GRACE which token to use for its query (default is the last token)
+        return self.model(**kwargs)
+
+    def generate(self, *args, **kwargs):
+        setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
+        return self.model.generate(*args, **kwargs)
+
+    def edit(self, config, tokens):
+        key_id = (tokens["labels"] == -100).sum() - 1
+        setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
+
+        # --- pass edit label, training mode, and key_id into GRACE ---
+        setattr(eval(f"self.model.{self.layer}"), "training", True)
+        setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
+
+        self.losses = []
+        # --- train GRACE value ---
+        for i in range(config.n_iter):
+            # --- insert iteration into each layer (only initiate keys on iteration 1) ---
+            setattr(eval(f"self.model.{self.layer}"), "iter", i)
+
+            # --- pass tokens through model (including through the GRACE layer) ---
+            outputs = self.model(**tokens)
+            if i == 0:
+                # --- we only need to create an optimizer for the first iteration (but the forward pass instantiates the key, so the optimizer is created after the first inference) ---
+                optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            self.losses.append(loss.detach().cpu().numpy())
+
+        self.loss = loss  # Log final loss
+
+        # --- pull out info we want to log from the GRACE layer ---
+        setattr(eval(f"self.model.{self.layer}"), "training", False)
+        chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
+        nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
+
+        self.log_dict["chosen_key"] = chosen_key
+        self.log_dict["nkeys"] = nkeys
+
+class GRACEAdapter(torch.nn.Module):
+    def __init__(self, config, layer, transpose):
+        super(GRACEAdapter, self).__init__()
+
+        self.layer = layer
+        self.weight = self.layer.weight
+        self.init_epsilon = config.eps
+        self.dist_fn = config.dist_fn
+        self.replacement = config.replacement
+        self.device = layer.weight.device
+        self.config = config
+        self.num_pert = config.num_pert
+        self.key_id = -1
+        self.ensure_replace_token_loc = False
+
+        if transpose:
+            self.key_shape = layer.weight.shape[1]
+            self.value_shape = layer.weight.shape[0]
+        else:
+            self.key_shape = layer.weight.shape[0]
+            self.value_shape = layer.weight.shape[1]
+        self.training = False
+
+    def add_key(self, new_key, new_value):
+        keys = torch.vstack([self.keys, new_key.detach()])  # Add new key to list of keys
+
+        values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True)  # Add new value to list of values
+
+        new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
+        epsilons = torch.vstack([self.epsilons, new_epsilon])  # Add new epsilon to list of epsilons
+
+        key_labels = self.key_labels + [self.edit_label]  # Add new key_label to list of key_labels
+
+        return keys, values, epsilons, key_labels
+
+    def init_key_value(self, query, value):
+        key = query.detach()
+        epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
+        key_label = [self.edit_label]
+        return key, value, epsilon, key_label
+
+    def label_match(self, edit_label, key_label):
+        return edit_label.float().mean() == key_label.float().mean()
+
+    def split_epsilons_in_half(self, nearest_key, smallest_distance):
+        self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5  # Cut nearest epsilon in half
+        self.epsilons[-1] = smallest_distance / 2  # Cut new epsilon in half
+
+    def forward(self, *args):
+        # Run layer forward and save what it would have returned for this instance
+        layer_out = self.layer(*args)
+
+        ### If training, we need to modify the codebook
+        if (not self.training) & ('keys' not in self.__dict__):
+            # If it's not training time and we haven't added any keys yet (this is before doing any editing)
+            # print(self.__dict__)
+            return layer_out
+        else:
+            if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
+                token_to_edit = args[0].shape[1] - 1
+                self.key_id = args[0].shape[1] - 1
+                self.ensure_replace_token_loc = True
+            else:
+                token_to_edit = min(self.key_id, args[0].shape[1] - 1)  # args[0].shape[1] - 1 is sequence length
+            query = args[0][:, token_to_edit, :]  # Just use activation for last token
+            if self.config.val_init == "cold":
+                new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
+            elif self.config.val_init == "warm":
+                new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)
+
+            if 'keys' not in self.__dict__:
+                # If no keys exist, initialize keys, values, epsilons, and key labels
+                self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
+            elif self.iter == 0:
+                # Keys exist, so we have to decide whether or not to update them (the fact that we've made it to this point means there was an error!)
+
+                # --- search through keys for a match for query ---
+                dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
+                smallest_distance, nearest_key = dists.min(0)
+
+                if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
+                    # If there's no close key, make a new key
+                    self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
+                else:
+                    # If there is a close key, we need to handle conflicts
+                    if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
+                        self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
+                        self.split_epsilons_in_half(nearest_key, smallest_distance)
+                    else:
+                        # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
+                        if smallest_distance > self.epsilons[nearest_key]:
+                            if self.config.eps_expand == "coverage":
+                                self.epsilons[nearest_key] = smallest_distance  # Replace nearest epsilon with dist between old key and new key
+                            elif self.config.eps_expand == "moving_average":
+                                a = 0.5
+                                self.keys[nearest_key] = a*self.keys[nearest_key] + (1-a)*query  # Move old key to be halfway between
+                                self.epsilons[nearest_key] = smallest_distance
+                                # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
+            else:
+                # If not iter 0, we don't need to change keys, we just need to learn the value
+                pass
+            # print(token_to_edit)
+            # compute distance from query to all keys and find the closest keys
+            dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
+            smallest_dist, self.chosen_key = dists.min(0)
+            smallest_dist = smallest_dist.view(-1, 1)
+            chosen_value = self.values[self.chosen_key]
+            eps = self.epsilons[self.chosen_key].view(-1, 1)
+
+            if (self.config.val_train == "adv") and (self.training):
+                chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
+
+            if self.replacement == "replace_all":
+                layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
+            elif self.replacement == "replace_last":
+                layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
+            elif self.replacement == "replace_prompt":
+                layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
+            else:
+                print("token replacement choice not found")
+            return layer_out
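To make the adapter's retrieval step easier to follow, here is a minimal, self-contained sketch of the epsilon-ball lookup that `GRACEAdapter.forward` performs at inference time (the tensor sizes and the epsilon radius below are illustrative assumptions, not values from this commit):

```python
import torch

# Toy codebook: 3 stored keys/values of hidden size 8 (sizes are illustrative).
hidden = 8
keys = torch.randn(3, hidden)
values = torch.randn(3, hidden)
epsilons = torch.tensor([2.0, 2.0, 2.0])  # one deferral radius per key (assumed value)

def lookup(query, layer_out):
    """Mimics the retrieval in GRACEAdapter.forward: substitute the stored value
    only when the query falls inside the nearest key's epsilon ball."""
    dists = torch.cdist(keys, query.view(1, -1), p=2).view(-1)  # distance to every key
    smallest_dist, chosen = dists.min(0)
    if smallest_dist <= epsilons[chosen]:
        return values[chosen]   # inside the ball: return the edited value
    return layer_out            # outside: leave the original activation untouched

query = torch.randn(hidden)
print(lookup(query, layer_out=torch.randn(hidden)).shape)
```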
easyeditor/models/grace/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .grace_main import GraceHyperParams, apply_grace_to_model
+from .metrics import F1, PPL, Accuracy, is_qa_error, is_acc_error
easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc
ADDED
Binary file (6.34 kB).
easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (342 Bytes).
easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc
ADDED
Binary file (1.49 kB).
easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc
ADDED
Binary file (1.12 kB).
easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc
ADDED
Binary file (2.07 kB).
easyeditor/models/grace/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.53 kB).
easyeditor/models/grace/grace_hparams.py
ADDED
@@ -0,0 +1,48 @@
+from dataclasses import dataclass
+from typing import List
+from ...util.hparams import HyperParams
+import yaml
+
+
+@dataclass
+class GraceHyperParams(HyperParams):
+    # Experiments
+
+    edit_lr: int
+    n_iter: int
+    # Method
+    eps: float
+    dist_fn: str
+    val_init: str
+    val_train: str
+    val_reg: str
+    reg: str
+    replacement: str
+    eps_expand: str
+    num_pert: str
+    dropout: float
+
+    # Module templates
+    inner_params: List[str]
+    device: int
+    alg_name: str
+    model_name: str
+
+    # Defaults
+    batch_size: int = 128
+    max_length: int = 30
+    model_parallel: bool = False
+
+    @classmethod
+    def from_hparams(cls, hparams_name_or_path: str):
+        if '.yaml' not in hparams_name_or_path:
+            hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+        with open(hparams_name_or_path, "r") as stream:
+            config = yaml.safe_load(stream)
+            config = super().construct_float_from_scientific_notation(config)
+
+        assert (config and config['alg_name'] == 'GRACE') or print(
+            f'GraceHyperParams can not load from {hparams_name_or_path}, '
+            f'alg_name is {config["alg_name"]} ')
+        return cls(**config)
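For a sense of what `GraceHyperParams` expects, here is a hypothetical instantiation with hand-picked values. The shipped settings live in `hparams/GRACE/gpt2-xl.yaml` (added in this commit but not reproduced here), so every value in this sketch is an assumption:

```python
from easyeditor.models.grace.grace_hparams import GraceHyperParams

# All field values below are illustrative guesses, not the committed configuration.
hparams = GraceHyperParams(
    edit_lr=1.0,
    n_iter=50,
    eps=1.0,
    dist_fn="euc",
    val_init="cold",
    val_train="sgd",
    val_reg="None",
    reg="early_stop",
    replacement="replace_last",
    eps_expand="coverage",
    num_pert="8",
    dropout=0.0,
    inner_params=["transformer.h.35.mlp.c_fc.weight"],  # hypothetical edited layer
    device=0,
    alg_name="GRACE",
    model_name="gpt2-xl",
)
print(hparams.replacement)
```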
easyeditor/models/grace/grace_main.py
ADDED
@@ -0,0 +1,38 @@
+from typing import Any, Dict, List, Tuple
+import torch
+from copy import deepcopy
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from .GRACE import GRACE
+from .grace_hparams import GraceHyperParams
+from .utils import tokenize
+from ...util import nethook
+
+
+def apply_grace_to_model(
+        model: AutoModelForCausalLM,
+        tok: AutoTokenizer,
+        requests: List[Dict],
+        hparams: GraceHyperParams,
+        copy=False,
+        return_orig_weights=False,
+        keep_original_weight=False,
+        **kwargs: Any,
+) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
+    model.to(f'cuda:{hparams.device}')
+    request = requests
+    if copy:
+        model = deepcopy(model)
+    weights_copy = {}
+    device = torch.device(f'cuda:{hparams.device}')
+    editor = GRACE(model=model, config=hparams, device=device)
+
+    tokens = tokenize(request, tokenizer=tok, device=device)
+    editor.edit(config=hparams, tokens=tokens)
+
+    if not keep_original_weight:
+        weights_copy = {}
+
+    editor.to(f'cuda:{hparams.device}')
+    return editor, weights_copy
+
+
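A minimal sketch of how this entry point might be invoked. The model choice, prompt, and new target are illustrative assumptions, and a CUDA device is assumed because the function moves the model to `cuda:{hparams.device}`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.grace import GraceHyperParams, apply_grace_to_model

# Model and request contents below are assumptions for illustration.
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
tok = AutoTokenizer.from_pretrained("gpt2-xl")
tok.pad_token = tok.eos_token  # tokenize() masks pad positions, so a pad token must exist

hparams = GraceHyperParams.from_hparams("hparams/GRACE/gpt2-xl.yaml")
request = {
    "prompt": "The president of the United States is",
    "target_new": "Joe Biden",
}

# Returns the GRACE-wrapped editor and the (empty) weights_copy dict.
edited_model, weights_copy = apply_grace_to_model(model, tok, request, hparams)
```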
easyeditor/models/grace/metrics.py
ADDED
@@ -0,0 +1,59 @@
+import torch
+import numpy as np
+from .utils import *
+
+def is_acc_error(model, tokens):
+    # Check whether or not the model's prediction for a batch element is correct
+    labels = tokens["labels"]
+    logits = model(**tokens).logits
+    probs = torch.softmax(logits, -1).squeeze()
+    argmaxs = torch.argmax(probs, dim=-1).squeeze()
+    return labels != argmaxs
+
+def Accuracy(model, tokens):
+    labels = tokens["labels"]
+    new_tokens = {f"{k}" : v for k, v in tokens.items() if k != "labels"}
+    logits = model(**new_tokens).logits
+    probs = torch.softmax(logits, -1).squeeze()
+    argmaxs = torch.argmax(probs, dim=-1).squeeze()
+    return (labels == argmaxs).float().mean()
+
+def is_qa_error(model, tokens):
+    preds = model.generate(tokens["input_ids"], max_length=20).squeeze()  # Run model to get its predictions
+    labels = tokens["labels"]  # [tokens["labels"] != -100]
+
+    if (len(preds) != len(labels)) or ((preds == labels).sum() != len(preds)):
+        return True
+    else:
+        return False
+
+def PPL(model, batch):
+    input_ids = batch["input_ids"][:, :1024]  # .to(device)
+    if "labels" not in batch:
+        target_ids = batch["input_ids"][:, :1024].clone()
+    else:
+        target_ids = batch["labels"][:, :1024].clone()
+
+    with torch.no_grad():
+        outputs = model(input_ids=input_ids, labels=target_ids)
+        nll = outputs.loss
+
+    ppl = torch.exp(nll)  # .clip(0, 100)
+    return ppl
+
+def F1(model, batch):
+    try:
+        preds = model.generate(batch["input_ids"], max_length=20).squeeze()
+        if len(preds) > 1:
+            preds = preds[preds != model.tokenizer.pad_token_id]
+        gold_toks = batch["labels"][batch["labels"] != -100].cpu().squeeze()  # -100 might be nonsense
+        num_same = len(np.intersect1d(preds.cpu().squeeze(), gold_toks))
+        if (num_same == 0) or (len(preds.squeeze()) == 0):
+            return 0
+        precision = num_same / len(preds.squeeze())
+        recall = 1.0 * num_same / len(gold_toks)
+        f1 = (2 * precision * recall) / (precision + recall)
+        return f1
+    except:
+        # Every once in a while, the model just returns the stop token
+        return 0
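A quick way to sanity-check these metrics is to score a single sentence with `PPL`; the model choice below is an assumption for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.grace.metrics import PPL

model = AutoModelForCausalLM.from_pretrained("gpt2")   # illustrative model choice
tok = AutoTokenizer.from_pretrained("gpt2")

# No "labels" key, so PPL scores the input against itself (next-token loss).
batch = tok("GRACE stores edits in a key-value codebook.", return_tensors="pt")
print(PPL(model, batch).item())
```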
easyeditor/models/grace/utils.py
ADDED
@@ -0,0 +1,86 @@
+import transformers
+import torch
+import os
+import numpy as np
+import datetime
+import struct
+from torch.nn.utils.rnn import pad_sequence
+import torch.nn.functional as F
+
+def get_inner_params(named_parameters, inner_names):
+    param_dict = dict(named_parameters)
+    return [(n, param_dict[n]) for n in inner_names]
+
+def param_subset(named_parameters, inner_names):
+    param_dict = dict(named_parameters)
+    return [param_dict[n] for n in inner_names]
+
+def parent_module(model, pname):
+    components = pname.split('.')
+    parent = model
+
+    for component in components[:-1]:
+        if hasattr(parent, component):
+            parent = getattr(parent, component)
+        elif component.isdigit():
+            parent = parent[int(component)]
+        else:
+            raise RuntimeError(f"Couldn't find child module {component}")
+
+    if not hasattr(parent, components[-1]):
+        raise RuntimeError(f"Couldn't find child module {components[-1]}")
+
+    return parent
+
+def uuid(digits=4):
+    if not hasattr(uuid, "uuid_value"):
+        uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
+
+    return uuid.uuid_value
+
+def ckpt_dir():
+    """returns the directory in which to store model checkpoints"""
+    path = "./ckpts/"
+    if not os.path.exists(path):
+        os.makedirs(path)
+    return path
+
+def brackets_to_periods(name):
+    return name.replace("[", ".").replace("]", "")
+
+def get_params(model):
+    return model.state_dict()
+
+def get_shape(p, model):
+    # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear
+    return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
+
+def get_logits(x):
+    return x.logits if hasattr(x, "logits") else x
+
+def tokenize(batch, tokenizer, device, test=False):
+    prompt, label = batch["prompt"], batch["target_new"]
+    if not isinstance(prompt, list):
+        prompt = [prompt]
+    if not isinstance(label, list):
+        label = [label]
+    mask_token = -100  # ignore_index of CrossEntropyLoss
+    if test or not label:
+        tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)
+        tokens["labels"] = tokens["input_ids"].clone()
+        tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
+
+    else:
+        full_prompt = [f"{p} {l}" for p, l in zip(prompt, label)]
+        prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
+        num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids]
+        tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
+        tokens["labels"] = tokens["input_ids"].clone()
+        for i in range(len(prompt)):
+            tokens["labels"][i][:num_prompt_toks[i]] = mask_token
+
+        tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
+
+    tokens = {f"{k1}" : v1.to(device) for k1, v1 in tokens.items()}
+    return tokens
+
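The `parent_module` and `brackets_to_periods` helpers are what `GRACE.__init__` uses to resolve a dotted layer name to its parent container before swapping in the adapter. A small sketch on a toy module (the toy model is an assumption for illustration):

```python
import torch
from easyeditor.models.grace.utils import parent_module, brackets_to_periods

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(2))

toy = Toy()
name = brackets_to_periods("blocks[1]")   # -> "blocks.1"
parent = parent_module(toy, name)         # -> the ModuleList that owns blocks.1
print(type(parent).__name__)              # ModuleList
```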
easyeditor/util/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .logit_lens import LogitLens
+from .hparams import *
easyeditor/util/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (208 Bytes).
easyeditor/util/__pycache__/hparams.cpython-39.pyc
ADDED
Binary file (1.21 kB).
easyeditor/util/__pycache__/logit_lens.cpython-39.pyc
ADDED
Binary file (3.35 kB).
easyeditor/util/__pycache__/nethook.cpython-39.pyc
ADDED
Binary file (13.2 kB).
easyeditor/util/alg_dict.py
ADDED
@@ -0,0 +1,45 @@
+from ..models.rome import ROMEHyperParams, apply_rome_to_model
+from ..models.memit import MEMITHyperParams, apply_memit_to_model
+from ..models.kn import KNHyperParams, apply_kn_to_model
+from ..models.mend import MENDHyperParams, MendRewriteExecutor, MendMultimodalRewriteExecutor
+from ..models.ft import FTHyperParams, apply_ft_to_model
+from ..models.serac import SERACHparams, SeracRewriteExecutor, SeracMultimodalRewriteExecutor
+from ..dataset import ZsreDataset, CounterFactDataset, CaptionDataset, VQADataset
+from ..models.ike import IKEHyperParams, apply_ike_to_model, apply_ike_to_multimodal_model
+from ..models.ft_api import FTApiHyperParams, apply_ft_api_to_model
+from ..models.lora import LoRAHyperParams, apply_lora_to_model
+from ..models.grace import GraceHyperParams, apply_grace_to_model
+from ..models.pmet import PMETHyperParams, apply_pmet_to_model
+from ..models.melo import MELOHyperParams, apply_melo_to_model
+
+ALG_DICT = {
+    'ROME': apply_rome_to_model,
+    'MEMIT': apply_memit_to_model,
+    "FT": apply_ft_to_model,
+    'KN': apply_kn_to_model,
+    'MEND': MendRewriteExecutor().apply_to_model,
+    'SERAC': SeracRewriteExecutor().apply_to_model,
+    'IKE': apply_ike_to_model,
+    'FT-Api': apply_ft_api_to_model,
+    'LoRA': apply_lora_to_model,
+    'GRACE': apply_grace_to_model,
+    'PMET': apply_pmet_to_model,
+    'MELO': apply_melo_to_model
+}
+
+ALG_MULTIMODAL_DICT = {
+    'MEND': MendMultimodalRewriteExecutor().apply_to_model,
+    'SERAC': SeracMultimodalRewriteExecutor().apply_to_model,
+    'SERAC_MULTI': SeracMultimodalRewriteExecutor().apply_to_model,
+    'IKE': apply_ike_to_multimodal_model,
+}
+
+DS_DICT = {
+    "cf": CounterFactDataset,
+    "zsre": ZsreDataset,
+}
+
+MULTIMODAL_DS_DICT = {
+    "caption": CaptionDataset,
+    "vqa": VQADataset,
+}
easyeditor/util/alg_train_dict.py
ADDED
@@ -0,0 +1,9 @@
+from ..trainer import MEND
+from ..trainer import SERAC, SERAC_MULTI
+
+
+ALG_TRAIN_DICT = {
+    'MEND': MEND,
+    'SERAC': SERAC,
+    'SERAC_MULTI': SERAC_MULTI,
+}
easyeditor/util/generate.py
ADDED
@@ -0,0 +1,171 @@
+import unicodedata
+from typing import List, Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from .logit_lens import LogitLens
+
+
+def generate_interactive(
+    model: AutoModelForCausalLM,
+    tok: AutoTokenizer,
+    top_k: int = 5,
+    max_out_len: int = 200,
+    compare_against: Optional[AutoModelForCausalLM] = None,
+    use_logit_lens: bool = False,
+    layer_module_tmp: str = "transformer.h.{}",
+    ln_f_module: str = "transformer.ln_f",
+    lm_head_module: str = "lm_head",
+):
+    """
+    Puts generation in a loop. Allows users to repeatedly provide inputs
+    with which text is generated.
+    """
+
+    if use_logit_lens:
+        llens_gen = LogitLens(
+            model,
+            tok,
+            layer_module_tmp,
+            ln_f_module,
+            lm_head_module,
+            disabled=not use_logit_lens,
+        )
+        if compare_against:
+            llens_vanilla = LogitLens(
+                compare_against,
+                tok,
+                layer_module_tmp,
+                ln_f_module,
+                lm_head_module,
+                disabled=not use_logit_lens,
+            )
+
+    while True:
+        prompt = input("Enter a prompt: ").strip(" \r\t\n")
+
+        print(
+            f"Argument Model: "
+            f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
+        )
+        if compare_against:
+            print(
+                f"Baseline Model: "
+                f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
+            )
+
+        if use_logit_lens:
+            inp_prompt = tok([prompt], padding=True, return_tensors="pt").to(
+                next(model.parameters()).device
+            )
+
+            with llens_gen:
+                model(**inp_prompt)
+            print("\n--- Argument Model Logit Lens ---")
+            llens_gen.pprint()
+
+            if compare_against:
+                with llens_vanilla:
+                    compare_against(**inp_prompt)
+                print("--- Baseline Model Logit Lens ---")
+                llens_vanilla.pprint()
+
+        print()
+
+
+def generate_fast(
+    model: AutoModelForCausalLM,
+    tok: AutoTokenizer,
+    prompts: List[str],
+    n_gen_per_prompt: int = 1,
+    top_k: int = 5,
+    max_out_len: int = 200,
+    vanilla_generation=False,
+):
+    """
+    Fast, parallelized auto-regressive text generation with top-k sampling.
+    Our custom implementation.
+    """
+
+    # Unroll prompts and tokenize
+    inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)]
+    inp_tok = tok(inp, padding=True, return_tensors="pt").to(
+        next(model.parameters()).device
+    )
+    input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
+    if vanilla_generation:
+        gen_txt = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_out_len
+        )
+        txt = [tok.decode(x, skip_special_tokens=True) for x in gen_txt.detach().cpu().numpy().tolist()]
+        txt = [
+            unicodedata.normalize("NFKD", x)
+            .replace("\n\n", " ")
+            .replace("<|endoftext|>", "")
+            for x in txt
+        ]
+        return txt
+    batch_size = input_ids.size(0)
+
+    # Setup storage of fast generation with attention caches.
+    # `cur_context` is used to define the range of inputs that are not yet
+    # stored in `past_key_values`. At each step, we are generating the
+    # next token for the index at `cur_context.stop + 1`.
+    past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())
+
+    with torch.no_grad():
+        while input_ids.size(1) < max_out_len:  # while not exceeding max output length
+            model_out = model(
+                input_ids=input_ids[:, cur_context],
+                # Skip the attention mask only for llama/baichuan models. (The original
+                # expression `'llama'or'baichuan' in ...` was always truthy, which
+                # silently dropped the mask for every model.)
+                attention_mask=None if any(x in model.name_or_path.lower() for x in ("llama", "baichuan")) else attention_mask[:, cur_context],
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+            logits, past_key_values = model_out.logits, model_out.past_key_values
+            softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)
+
+            # Top-k sampling
+            tk = torch.topk(softmax_out, top_k, dim=1).indices
+            softmax_out_top_k = torch.gather(softmax_out, 1, tk)
+            softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
+            new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
+            new_toks = torch.gather(tk, 1, new_tok_indices)
+
+            # If we're currently generating the continuation for the last token in `input_ids`,
+            # create a new index so we can insert the new token
+            if cur_context.stop == input_ids.size(1):
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
+                )
+                input_ids = torch.cat(
+                    [
+                        input_ids,
+                        input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
+                    ],
+                    dim=1,
+                )
+
+            last_non_masked = attention_mask.sum(1) - 1
+            for i in range(batch_size):
+                new_idx = last_non_masked[i] + 1
+                if last_non_masked[i].item() + 1 != cur_context.stop:
+                    continue
+
+                # Stop generating if we've already maxed out for this prompt
+                if new_idx < max_out_len:
+                    input_ids[i][new_idx] = new_toks[i]
+                    attention_mask[i][new_idx] = 1
+
+            cur_context = slice(cur_context.stop, cur_context.stop + 1)
+    txt = [tok.decode(x, skip_special_tokens=True) for x in input_ids.detach().cpu().numpy().tolist()]
+    txt = [
+        unicodedata.normalize("NFKD", x)
+        .replace("\n\n", " ")
+        .replace("<|endoftext|>", "")
+        for x in txt
+    ]
+
+    return txt
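A hypothetical call to `generate_fast`, assuming a small GPT-2 checkpoint and a pad token (GPT-2 tokenizers ship without one, and the function fills new positions with `pad_token_id`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.util.generate import generate_fast

model = AutoModelForCausalLM.from_pretrained("gpt2")   # illustrative model choice
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # needed for padding and for extending sequences

texts = generate_fast(
    model, tok,
    prompts=["The Eiffel Tower is located in"],
    n_gen_per_prompt=2,   # two samples for the one prompt
    top_k=5,
    max_out_len=40,
)
print(texts)
```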
easyeditor/util/globals.py
ADDED
@@ -0,0 +1,43 @@
+from pathlib import Path
+
+import logging
+import os
+
+import yaml
+
+
+def get_handler(path, log_name):
+    log_file_path = os.path.join(path, log_name)
+    try:
+        if not os.path.exists(path):
+            print("We are creating the logger files")
+            os.makedirs(path)
+    except:
+        pass
+    file_handler = logging.FileHandler(log_file_path)
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+
+    stream_handler = logging.StreamHandler()
+    stream_handler.setLevel(logging.DEBUG)
+    stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    return file_handler, stream_handler
+
+
+# def get_run_dir(dir_name):
+#
+#     alg_dir = RESULTS_DIR / dir_name
+#     if alg_dir.exists():
+#         id_list = [
+#             int(str(x).split("_")[-1])
+#             for x in alg_dir.iterdir()
+#             if str(x).split("_")[-1].isnumeric()
+#         ]
+#         run_id = 0 if not id_list else max(id_list) + 1
+#     else:
+#         run_id = 0
+#     run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}"
+#     run_dir.mkdir(parents=True, exist_ok=True)
+#     print(f"Results will be stored at {run_dir}")
+#
+#     return run_dir
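A short sketch of wiring `get_handler` into a standard `logging` logger; the log directory and file name below are arbitrary assumptions:

```python
import logging
from easyeditor.util.globals import get_handler

# "./logs" and "run.log" are example locations, not paths used by this commit.
f_h, s_h = get_handler("./logs", log_name="run.log")

logger = logging.getLogger("easyeditor-demo")
logger.setLevel(logging.DEBUG)
logger.addHandler(f_h)    # writes to ./logs/run.log
logger.addHandler(s_h)    # mirrors messages to the console
logger.info("logging configured")
```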
easyeditor/util/hparams.py
ADDED
@@ -0,0 +1,46 @@
+import json
+from dataclasses import dataclass
+from dataclasses import asdict
+
+
+@dataclass
+class HyperParams:
+    """
+    Simple wrapper to store hyperparameters for Python-based rewriting methods.
+    """
+
+    @classmethod
+    def from_json(cls, fpath):
+        with open(fpath, "r") as f:
+            data = json.load(f)
+
+        return cls(**data)
+
+    def construct_float_from_scientific_notation(config: dict):
+        for key, value in config.items():
+            if isinstance(value, str):
+                try:
+                    # Convert scalar to float if it is in scientific notation format
+                    config[key] = float(value)
+                except:
+                    pass
+        return config
+
+    def to_dict(config) -> dict:
+        dict = asdict(config)
+        return dict
+
+
+    # @classmethod
+    # def from_hparams(cls, hparams_name_or_path: str):
+    #
+    #     if '.yaml' not in hparams_name_or_path:
+    #         hparams_name_or_path = hparams_name_or_path + '.yaml'
+    #     config = compose(hparams_name_or_path)
+    #
+    #     assert config.alg_name in ALG_DICT.keys() or print(f'Editing Alg name {config.alg_name} not supported yet.')
+    #
+    #     params_class, apply_algo = ALG_DICT[config.alg_name]
+    #
+    #     return params_class(**config)
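`construct_float_from_scientific_notation` exists because YAML can leave values like `1e-4` as strings; a quick illustration (the config dict is an assumption for demonstration):

```python
from easyeditor.util.hparams import HyperParams

# Hypothetical raw YAML values: scientific-notation numbers arrive as strings.
config = {"edit_lr": "1e-4", "alg_name": "GRACE", "n_iter": 50}
config = HyperParams.construct_float_from_scientific_notation(config)
print(config["edit_lr"])  # 0.0001 -- converted; non-numeric strings are left untouched
```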
easyeditor/util/logit_lens.py
ADDED
@@ -0,0 +1,97 @@
+from collections import defaultdict
+from typing import Dict, Optional
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from . import nethook
+
+
+class LogitLens:
+    """
+    Applies the LM head at the output of each hidden layer, then analyzes the
+    resultant token probability distribution.
+
+    Only works when hooking outputs of *one* individual generation.
+
+    Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens
+
+    Warning: when running multiple times (e.g. generation), will return
+    outputs _only_ for the last processing step.
+    """
+
+    def __init__(
+        self,
+        model: AutoModelForCausalLM,
+        tok: AutoTokenizer,
+        layer_module_tmp: str,
+        ln_f_module: str,
+        lm_head_module: str,
+        disabled: bool = False,
+    ):
+        self.disabled = disabled
+        self.model, self.tok = model, tok
+        self.n_layers = self.model.config.n_layer
+
+        self.lm_head, self.ln_f = (
+            nethook.get_module(model, lm_head_module),
+            nethook.get_module(model, ln_f_module),
+        )
+
+        self.output: Optional[Dict] = None
+        self.td: Optional[nethook.TraceDict] = None
+        self.trace_layers = [
+            layer_module_tmp.format(layer) for layer in range(self.n_layers)
+        ]
+
+    def __enter__(self):
+        if not self.disabled:
+            self.td = nethook.TraceDict(
+                self.model,
+                self.trace_layers,
+                retain_input=False,
+                retain_output=True,
+            )
+            self.td.__enter__()
+
+    def __exit__(self, *args):
+        if self.disabled:
+            return
+        self.td.__exit__(*args)
+
+        self.output = {layer: [] for layer in range(self.n_layers)}
+
+        with torch.no_grad():
+            for layer, (_, t) in enumerate(self.td.items()):
+                cur_out = t.output[0]
+                assert (
+                    cur_out.size(0) == 1
+                ), "Make sure you're only running LogitLens on single generations only."
+
+                self.output[layer] = torch.softmax(
+                    self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1
+                )
+
+        return self.output
+
+    def pprint(self, k=5):
+        to_print = defaultdict(list)
+
+        for layer, pred in self.output.items():
+            rets = torch.topk(pred[0], k)
+            for i in range(k):
+                to_print[layer].append(
+                    (
+                        self.tok.decode(rets[1][i]),
+                        round(rets[0][i].item() * 1e2) / 1e2,
+                    )
+                )
+
+        print(
+            "\n".join(
+                [
+                    f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}"
+                    for layer in range(self.n_layers)
+                ]
+            )
+        )
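A minimal sketch of using `LogitLens` on GPT-2 (the model and prompt are illustrative; the module names match the defaults used by `generate_interactive` above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.util.logit_lens import LogitLens

model = AutoModelForCausalLM.from_pretrained("gpt2")   # illustrative model choice
tok = AutoTokenizer.from_pretrained("gpt2")

lens = LogitLens(
    model, tok,
    layer_module_tmp="transformer.h.{}",
    ln_f_module="transformer.ln_f",
    lm_head_module="lm_head",
)
inp = tok("The capital of France is", return_tensors="pt")
with lens:
    model(**inp)   # hooks every block; __exit__ projects each layer's last hidden state
lens.pprint(k=3)   # top-3 tokens per layer
```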
easyeditor/util/nethook.py
ADDED
@@ -0,0 +1,451 @@
+"""
+Utilities for instrumenting a torch model.
+
+Trace will hook one layer at a time.
+TraceDict will hook multiple layers at once.
+subsequence slices intervals from Sequential modules.
+get_module, replace_module, get_parameter resolve dotted names.
+set_requires_grad recursively sets requires_grad in module parameters.
+"""
+
+import contextlib
+import copy
+import inspect
+from collections import OrderedDict
+
+import torch
+
+
+class Trace(contextlib.AbstractContextManager):
+    """
+    To retain the output of the named layer during the computation of
+    the given network:
+
+        with Trace(net, 'layer.name') as ret:
+            _ = net(inp)
+            representation = ret.output
+
+    A layer module can be passed directly without a layer name, and
+    its output will be retained. By default, a direct reference to
+    the output object is returned, but options can control this:
+
+        clone=True - retains a copy of the output, which can be
+            useful if you want to see the output before it might
+            be modified by the network in-place later.
+        detach=True - retains a detached reference or copy. (By
+            default the value would be left attached to the graph.)
+        retain_grad=True - request gradient to be retained on the
+            output. After backward(), ret.output.grad is populated.
+
+        retain_input=True - also retains the input.
+        retain_output=False - can disable retaining the output.
+        edit_output=fn - calls the function to modify the output
+            of the layer before passing it the rest of the model.
+            fn can optionally accept (output, layer) arguments
+            for the original output and the layer name.
+        stop=True - throws a StopForward exception after the layer
+            is run, which allows running just a portion of a model.
+    """
+
+    def __init__(
+        self,
+        module,
+        layer=None,
+        retain_output=True,
+        retain_input=False,
+        clone=False,
+        detach=False,
+        retain_grad=False,
+        edit_output=None,
+        stop=False,
+    ):
+        """
+        Method to replace a forward method with a closure that
+        intercepts the call, and tracks the hook so that it can be reverted.
+        """
+        retainer = self
+        self.layer = layer
+        if layer is not None:
+            module = get_module(module, layer)
+
+        def retain_hook(m, inputs, output):
+            if retain_input:
+                retainer.input = recursive_copy(
+                    inputs[0] if len(inputs) == 1 else inputs,
+                    clone=clone,
+                    detach=detach,
+                    retain_grad=False,
+                )  # retain_grad applies to output only.
+            if edit_output:
+                output = invoke_with_optional_args(
+                    edit_output, output=output, layer=self.layer
+                )
+            if retain_output:
+                retainer.output = recursive_copy(
+                    output, clone=clone, detach=detach, retain_grad=retain_grad
+                )
+                # When retain_grad is set, also insert a trivial
+                # copy operation. That allows in-place operations
+                # to follow without error.
+                if retain_grad:
+                    output = recursive_copy(retainer.output, clone=True, detach=False)
+            if stop:
+                raise StopForward()
+            return output
+
+        self.registered_hook = module.register_forward_hook(retain_hook)
+        self.stop = stop
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+        if self.stop and issubclass(type, StopForward):
+            return True
+
+    def close(self):
+        self.registered_hook.remove()
+
+
+class TraceDict(OrderedDict, contextlib.AbstractContextManager):
+    """
+    To retain the output of multiple named layers during the computation
+    of the given network:
+
+        with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
+            _ = net(inp)
+            representation = ret['layer1.name1'].output
+
+    If edit_output is provided, it should be a function that takes
+    two arguments: output, and the layer name; and then it returns the
+    modified output.
+
+    Other arguments are the same as Trace. If stop is True, then the
+    execution of the network will be stopped after the last layer
+    listed (even if it would not have been the last to be executed).
+    """
+
+    def __init__(
+        self,
+        module,
+        layers=None,
+        retain_output=True,
+        retain_input=False,
+        clone=False,
+        detach=False,
+        retain_grad=False,
+        edit_output=None,
+        stop=False,
+    ):
+        self.stop = stop
+
+        def flag_last_unseen(it):
+            try:
+                it = iter(it)
+                prev = next(it)
+                seen = set([prev])
+            except StopIteration:
+                return
+            for item in it:
+                if item not in seen:
+                    yield False, prev
+                    seen.add(item)
+                    prev = item
+            yield True, prev
+
+        for is_last, layer in flag_last_unseen(layers):
+            self[layer] = Trace(
+                module=module,
+                layer=layer,
+                retain_output=retain_output,
+                retain_input=retain_input,
+                clone=clone,
+                detach=detach,
+                retain_grad=retain_grad,
+                edit_output=edit_output,
+                stop=stop and is_last,
+            )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+        if self.stop and issubclass(type, StopForward):
+            return True
+
+    def close(self):
+        for layer, trace in reversed(self.items()):
+            trace.close()
+
+
+class StopForward(Exception):
+    """
+    If the only output needed from running a network is the retained
+    submodule then Trace(submodule, stop=True) will stop execution
+    immediately after the retained submodule by raising the StopForward()
+    exception. When Trace is used as context manager, it catches that
+    exception and can be used as follows:
+
+        with Trace(net, layername, stop=True) as tr:
+            net(inp)  # Only runs the network up to layername
+        print(tr.output)
+    """
+
+    pass
+
+
+def recursive_copy(x, clone=None, detach=None, retain_grad=None):
+    """
+    Copies a reference to a tensor, or an object that contains tensors,
+    optionally detaching and cloning the tensor(s). If retain_grad is
+    true, the original tensors are marked to have grads retained.
+    """
+    if not clone and not detach and not retain_grad:
+        return x
+    if isinstance(x, torch.Tensor):
+        if retain_grad:
+            if not x.requires_grad:
+                x.requires_grad = True
+            x.retain_grad()
+        elif detach:
+            x = x.detach()
+        if clone:
+            x = x.clone()
+        return x
+    # Only dicts, lists, and tuples (and subclasses) can be copied.
+    if isinstance(x, dict):
+        return type(x)({k: recursive_copy(v) for k, v in x.items()})
+    elif isinstance(x, (list, tuple)):
+        return type(x)([recursive_copy(v) for v in x])
+    else:
+        assert False, f"Unknown type {type(x)} cannot be broken into tensors."
+
+
+def subsequence(
+    sequential,
+    first_layer=None,
+    last_layer=None,
+    after_layer=None,
+    upto_layer=None,
+    single_layer=None,
+    share_weights=False,
+):
+    """
+    Creates a subsequence of a pytorch Sequential model, copying over
+    modules together with parameters for the subsequence. Only
+    modules from first_layer to last_layer (inclusive) are included,
+    or modules between after_layer and upto_layer (exclusive).
+    Handles descent into dotted layer names as long as all references
+    are within nested Sequential models.
+
+    If share_weights is True, then references the original modules
+    and their parameters without copying them. Otherwise, by default,
+    makes a separate brand-new copy.
+    """
+    assert (single_layer is None) or (
+        first_layer is last_layer is after_layer is upto_layer is None
+    )
+    if single_layer is not None:
+        first_layer = single_layer
+        last_layer = single_layer
+    first, last, after, upto = [
+        None if d is None else d.split(".")
+        for d in [first_layer, last_layer, after_layer, upto_layer]
+    ]
+    return hierarchical_subsequence(
+        sequential,
+        first=first,
+        last=last,
+        after=after,
+        upto=upto,
+        share_weights=share_weights,
+    )
+
+
+def hierarchical_subsequence(
+    sequential, first, last, after, upto, share_weights=False, depth=0
+):
+    """
+    Recursive helper for subsequence() to support descent into dotted
+    layer names. In this helper, first, last, after, and upto are
+    arrays of names resulting from splitting on dots. Can only
+    descend into nested Sequentials.
+    """
+    assert (last is None) or (upto is None)
+    assert (first is None) or (after is None)
+    if first is last is after is upto is None:
+        return sequential if share_weights else copy.deepcopy(sequential)
+    assert isinstance(sequential, torch.nn.Sequential), (
+        ".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
+    )
+    including_children = (first is None) and (after is None)
+    included_children = OrderedDict()
+    # A = current level short name of A.
+    # AN = full name for recursive descent if not innermost.
+    (F, FN), (L, LN), (A, AN), (U, UN) = [
+        (d[depth], (None if len(d) == depth + 1 else d))
+        if d is not None
+        else (None, None)
+        for d in [first, last, after, upto]
+    ]
+    for name, layer in sequential._modules.items():
+        if name == F:
+            first = None
+            including_children = True
+        if name == A and AN is not None:  # just like F if not a leaf.
+            after = None
+            including_children = True
+        if name == U and UN is None:
+            upto = None
+            including_children = False
+        if including_children:
+            # AR = full name for recursive descent if name matches.
+            FR, LR, AR, UR = [
+                n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
+            ]
+            chosen = hierarchical_subsequence(
+                layer,
+                first=FR,
+                last=LR,
+                after=AR,
+                upto=UR,
+                share_weights=share_weights,
+                depth=depth + 1,
+            )
+            if chosen is not None:
+                included_children[name] = chosen
+        if name == L:
+            last = None
+            including_children = False
+        if name == U and UN is not None:  # just like L if not a leaf.
+            upto = None
+            including_children = False
+        if name == A and AN is None:
+            after = None
+            including_children = True
+    for name in [first, last, after, upto]:
+        if name is not None:
+            raise ValueError("Layer %s not found" % ".".join(name))
+    # Omit empty subsequences except at the outermost level,
+    # where we should not return None.
+    if not len(included_children) and depth > 0:
+        return None
+    result = torch.nn.Sequential(included_children)
+    result.training = sequential.training
+    return result
+
+
+def set_requires_grad(requires_grad, *models):
+    """
+    Sets requires_grad true or false for all parameters within the
+    models passed.
+    """
+    for model in models:
+        if isinstance(model, torch.nn.Module):
+            for param in model.parameters():
+                param.requires_grad = requires_grad
+        elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
+            model.requires_grad = requires_grad
+        else:
+            assert False, "unknown type %r" % type(model)
+
+
+def get_module(model, name):
+    """
+    Finds the named module within the given model.
+    """
+    for n, m in model.named_modules():
+        if n == name:
+            return m
+    raise LookupError(name)
+
+
+def get_parameter(model, name):
+    """
+    Finds the named parameter within the given model.
+    """
+    for n, p in model.named_parameters():
+        if n == name:
+            return p
+    raise LookupError(name)
+
+
+def replace_module(model, name, new_module):
+    """
+    Replaces the named module within the given model.
+    """
+    if "." in name:
+        parent_name, attr_name = name.rsplit(".", 1)
+        model = get_module(model, parent_name)
+    # original_module = getattr(model, attr_name)
+    setattr(model, attr_name, new_module)
+
+
+def invoke_with_optional_args(fn, *args, **kwargs):
+    """
+    Invokes a function with only the arguments that it
+    is written to accept, giving priority to arguments
+    that match by-name, using the following rules.
+    (1) arguments with matching names are passed by name.
+    (2) remaining non-name-matched args are passed by order.
+    (3) extra caller arguments that the function cannot
+        accept are not passed.
+    (4) extra required function arguments that the caller
+        cannot provide cause a TypeError to be raised.
+    Ordinary python calling conventions are helpful for
+    supporting a function that might be revised to accept
+    extra arguments in a newer version, without requiring the
+    caller to pass those new arguments. This function helps
+    support function callers that might be revised to supply
+    extra arguments, without requiring the callee to accept
+    those new arguments.
+    """
+    argspec = inspect.getfullargspec(fn)
+    pass_args = []
+    used_kw = set()
+    unmatched_pos = []
+    used_pos = 0
+    defaulted_pos = len(argspec.args) - (
+        0 if not argspec.defaults else len(argspec.defaults)
+    )
+    # Pass positional args that match name first, then by position.
+    for i, n in enumerate(argspec.args):
+        if n in kwargs:
+            pass_args.append(kwargs[n])
+            used_kw.add(n)
+        elif used_pos < len(args):
+            pass_args.append(args[used_pos])
+            used_pos += 1
+        else:
+            unmatched_pos.append(len(pass_args))
+            pass_args.append(
+                None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
+            )
+    # Fill unmatched positional args with unmatched keyword args in order.
+    if len(unmatched_pos):
+        for k, v in kwargs.items():
+            if k in used_kw or k in argspec.kwonlyargs:
+                continue
+            pass_args[unmatched_pos[0]] = v
+            used_kw.add(k)
+            unmatched_pos = unmatched_pos[1:]
+            if len(unmatched_pos) == 0:
+                break
+        else:
+            if unmatched_pos[0] < defaulted_pos:
+                unpassed = ", ".join(
+                    argspec.args[u] for u in unmatched_pos if u < defaulted_pos
+                )
+                raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
+    # Pass remaining kw args if they can be accepted.
+    pass_kw = {
k: v
|
445 |
+
for k, v in kwargs.items()
|
446 |
+
if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
|
447 |
+
}
|
448 |
+
# Pass remaining positional args if they can be accepted.
|
449 |
+
if argspec.varargs is not None:
|
450 |
+
pass_args += list(args[used_pos:])
|
451 |
+
return fn(*pass_args, **pass_kw)
|
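A minimal sketch of how these helpers compose (illustrative only; the toy Sequential, its names, and the hook function below are assumptions, not part of this file):

import torch
from collections import OrderedDict

# Illustrative toy model; any nested torch.nn.Sequential works the same way.
net = torch.nn.Sequential(OrderedDict([
    ("layer1", torch.nn.Sequential(OrderedDict([("fc", torch.nn.Linear(4, 8))]))),
    ("layer2", torch.nn.Linear(8, 4)),
]))
set_requires_grad(False, net)                              # freeze every parameter
fc = get_module(net, "layer1.fc")                          # look up a submodule by dotted name
w = get_parameter(net, "layer1.fc.weight")                 # look up a parameter the same way
replace_module(net, "layer1.fc", torch.nn.Linear(4, 8))    # swap a submodule in place

# invoke_with_optional_args forwards only what the callee accepts (rule 3 above):
def hook(output):                       # written for an older, one-argument interface
    return output + 1

value = invoke_with_optional_args(hook, 10, "layer1.fc")   # extra positional arg is dropped -> 11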
easyeditor/util/perplexity.py
ADDED
@@ -0,0 +1,24 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def perplexity(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    text: str,
    max_input_length: int = None,
):
    """
    Computes perplexity of a piece of text, measured on a reference model.
    Text is truncated to max_input_length tokens.
    """

    inputs = tok(
        [text], return_tensors="pt", max_length=max_input_length, truncation=True
    ).to("cuda")

    logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)
    log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]

    # Perplexity = exp(-1/N * log P(x_1, ..., x_n))
    return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item()
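A minimal usage sketch (illustrative only; the gpt2 checkpoint is an assumption, and the model must already be on the GPU since the function moves its inputs to "cuda"):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")   # any causal LM works the same way
tok = AutoTokenizer.from_pretrained("gpt2")
ppl = perplexity(model, tok, "The Eiffel Tower is located in Paris.", max_input_length=100)
print(f"perplexity: {ppl:.2f}")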
easyeditor/util/runningstats.py
ADDED
@@ -0,0 +1,1883 @@
1 |
+
"""
|
2 |
+
To use a runningstats object,
|
3 |
+
|
4 |
+
1. Create the desired stat object, e.g., `m = Mean()`
|
5 |
+
2. Feed it batches via the add method, e.g., `m.add(batch)`
|
6 |
+
3. Repeat step 2 any number of times.
|
7 |
+
4. Read out the statistic of interest, e.g., `m.mean()`
|
8 |
+
|
9 |
+
Built-in runningstats objects include:
|
10 |
+
|
11 |
+
Mean - produces mean().
|
12 |
+
Variance - mean() and variance() and stdev().
|
13 |
+
Covariance - mean(), covariance(), correlation(), variance(), stdev().
|
14 |
+
SecondMoment - moment() is the non-mean-centered covariance, E[x x^T].
|
15 |
+
Quantile - quantiles(), min(), max(), median(), mean(), variance(), stdev().
|
16 |
+
TopK - topk() returns (values, indexes).
|
17 |
+
Bincount - bincount() histograms nonnegative integer data.
|
18 |
+
IoU - intersection(), union(), iou() tally binary co-occurrences.
|
19 |
+
History - history() returns concatenation of data.
|
20 |
+
CrossCovariance - covariance between two signals, without self-covariance.
|
21 |
+
CrossIoU - iou between two signals, without self-IoU.
|
22 |
+
CombinedStat - aggregates any set of stats.
|
23 |
+
|
24 |
+
Add more running stats by subclassing the Stat class.
|
25 |
+
|
26 |
+
These statistics are vectorized along dim>=1, so stat.add()
|
27 |
+
should supply a two-dimensional input where the zeroth
|
28 |
+
dimension is the batch/sampling dimension and the first
|
29 |
+
dimension is the feature dimension.
|
30 |
+
|
31 |
+
The data type and device used match the data passed to add();
|
32 |
+
for example, for higher-precision covariances, convert to double
|
33 |
+
before calling add().
|
34 |
+
|
35 |
+
It is common to want to compute and remember a statistic sampled
|
36 |
+
over a Dataset, computed in batches, possibly caching the computed
|
37 |
+
statistic in a file. The tally(stat, dataset, cache) handles
|
38 |
+
this pattern. It takes a statistic, a dataset, and a cache filename
|
39 |
+
and sets up a data loader that can be run (or not, if cached) to
|
40 |
+
compute the statistic, adopting the convention that cached stats are
|
41 |
+
saved to and loaded from numpy npz files.
|
42 |
+
"""
|
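For concreteness, a minimal sketch of the four steps above (illustrative only; the batch shapes are assumptions, and Variance is defined later in this file):

import torch

v = Variance()                            # step 1: create the stat
for _ in range(10):                       # steps 2-3: feed batches repeatedly
    batch = torch.randn(64, 512)          # dim 0 = samples, dim 1 = features
    v.add(batch)
mu, sigma = v.mean(), v.stdev()           # step 4: read out per-feature statistics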
43 |
+
|
44 |
+
import math
|
45 |
+
import os
|
46 |
+
import random
|
47 |
+
import struct
|
48 |
+
|
49 |
+
import numpy
|
50 |
+
import torch
|
51 |
+
from torch.utils.data.sampler import Sampler
|
52 |
+
|
53 |
+
|
54 |
+
def tally(stat, dataset, cache=None, quiet=False, **kwargs):
|
55 |
+
"""
|
56 |
+
To use tally, write code like the following.
|
57 |
+
|
58 |
+
stat = Mean()
|
59 |
+
ds = MyDataset()
|
60 |
+
for batch in tally(stat, ds, cache='mymean.npz', batch_size=50):
|
61 |
+
stat.add(batch)
|
62 |
+
mean = stat.mean()
|
63 |
+
|
64 |
+
The first argument should be the Stat being computed. After the
|
65 |
+
loader is exhausted, tally will bring this stat to the cpu and
|
66 |
+
cache it (if a cache is specified).
|
67 |
+
|
68 |
+
The dataset can be a torch Dataset or a plain Tensor, or it can
|
69 |
+
be a callable that returns one of those.
|
70 |
+
|
71 |
+
Details on caching via the cache= argument:
|
72 |
+
|
73 |
+
If the given filename cannot be loaded, tally will leave the
|
74 |
+
statistic object empty and set up a DataLoader object so that
|
75 |
+
the loop can be run. After the last iteration of the loop, the
|
76 |
+
completed statistic will be moved to the cpu device and also
|
77 |
+
saved in the cache file.
|
78 |
+
|
79 |
+
If the cached statistic can be loaded from the given file, tally
|
80 |
+
will not set up the data loader and instead will return a fully
|
81 |
+
loaded statistic object (on the cpu device) and an empty list as
|
82 |
+
the loader.
|
83 |
+
|
84 |
+
The `with cache_load_enabled(False):` context manager can
|
85 |
+
be used to disable loading from the cache.
|
86 |
+
|
87 |
+
If needed, a DataLoader will be created to wrap the dataset:
|
88 |
+
|
89 |
+
Keyword arguments of tally are passed to the DataLoader,
|
90 |
+
so batch_size, num_workers, pin_memory, etc. can be specified.
|
91 |
+
|
92 |
+
Subsampling is supported via sample_size= and random_sample=:
|
93 |
+
|
94 |
+
If sample_size=N is specified, rather than loading the whole
|
95 |
+
dataset, only the first N items are sampled. If additionally
|
96 |
+
random_sample=S is specified, the pseudorandom seed S will be
|
97 |
+
used to select a fixed pseudorandom sample of size N.
|
98 |
+
"""
|
99 |
+
assert isinstance(stat, Stat)
|
100 |
+
args = {}
|
101 |
+
for k in ["sample_size"]:
|
102 |
+
if k in kwargs:
|
103 |
+
args[k] = kwargs[k]
|
104 |
+
cached_state = load_cached_state(cache, args, quiet=quiet)
|
105 |
+
if cached_state is not None:
|
106 |
+
stat.load_state_dict(cached_state)
|
107 |
+
|
108 |
+
def empty_loader():
|
109 |
+
return
|
110 |
+
yield
|
111 |
+
|
112 |
+
return empty_loader()
|
113 |
+
loader = make_loader(dataset, **kwargs)
|
114 |
+
|
115 |
+
def wrapped_loader():
|
116 |
+
yield from loader
|
117 |
+
stat.to_(device="cpu")
|
118 |
+
if cache is not None:
|
119 |
+
save_cached_state(cache, stat, args)
|
120 |
+
|
121 |
+
return wrapped_loader()
|
122 |
+
|
123 |
+
|
124 |
+
class cache_load_enabled:
|
125 |
+
"""
|
126 |
+
When used as a context manager, cache_load_enabled(False) will prevent
|
127 |
+
tally from loading cached statistics, forcing them to be recomputed.
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(self, enabled=True):
|
131 |
+
self.prev = False
|
132 |
+
self.enabled = enabled
|
133 |
+
|
134 |
+
def __enter__(self):
|
135 |
+
global global_load_cache_enabled
|
136 |
+
self.prev = global_load_cache_enabled
|
137 |
+
global_load_cache_enabled = self.enabled
|
138 |
+
|
139 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
140 |
+
global global_load_cache_enabled
|
141 |
+
global_load_cache_enabled = self.prev
|
142 |
+
|
143 |
+
|
144 |
+
class Stat:
|
145 |
+
"""
|
146 |
+
Abstract base class for a running pytorch statistic.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def __init__(self, state):
|
150 |
+
"""
|
151 |
+
By convention, all Stat subclasses can be initialized by passing
|
152 |
+
state=; and then they will initialize by calling load_state_dict.
|
153 |
+
"""
|
154 |
+
self.load_state_dict(resolve_state_dict(state))
|
155 |
+
|
156 |
+
def add(self, x, *args, **kwargs):
|
157 |
+
"""
|
158 |
+
Observes a batch of samples to be incorporated into the statistic.
|
159 |
+
Dimension 0 should be the batch dimension, and dimension 1 should
|
160 |
+
be the feature dimension of the pytorch tensor x.
|
161 |
+
"""
|
162 |
+
pass
|
163 |
+
|
164 |
+
def load_state_dict(self, d):
|
165 |
+
"""
|
166 |
+
Loads this Stat from a dictionary of numpy arrays as saved
|
167 |
+
by state_dict.
|
168 |
+
"""
|
169 |
+
pass
|
170 |
+
|
171 |
+
def state_dict(self):
|
172 |
+
"""
|
173 |
+
Saves this Stat as a dictionary of numpy arrays that can be
|
174 |
+
stored in an npz or reloaded later using load_state_dict.
|
175 |
+
"""
|
176 |
+
return {}
|
177 |
+
|
178 |
+
def save(self, filename):
|
179 |
+
"""
|
180 |
+
Saves this stat as an npz file containing the state_dict.
|
181 |
+
"""
|
182 |
+
save_cached_state(filename, self, {})
|
183 |
+
|
184 |
+
def load(self, filename):
|
185 |
+
"""
|
186 |
+
Loads this stat from an npz file containing a saved state_dict.
|
187 |
+
"""
|
188 |
+
self.load_state_dict(load_cached_state(filename, {}, quiet=True, throw=True))
|
189 |
+
|
190 |
+
def to_(self, device):
|
191 |
+
"""
|
192 |
+
Moves this Stat to the given device.
|
193 |
+
"""
|
194 |
+
pass
|
195 |
+
|
196 |
+
def cpu_(self):
|
197 |
+
"""
|
198 |
+
Moves this Stat to the cpu device.
|
199 |
+
"""
|
200 |
+
self.to_("cpu")
|
201 |
+
|
202 |
+
def cuda_(self):
|
203 |
+
"""
|
204 |
+
Moves this Stat to the default cuda device.
|
205 |
+
"""
|
206 |
+
self.to_("cuda")
|
207 |
+
|
208 |
+
def _normalize_add_shape(self, x, attr="data_shape"):
|
209 |
+
"""
|
210 |
+
Flattens input data to 2d.
|
211 |
+
"""
|
212 |
+
if not torch.is_tensor(x):
|
213 |
+
x = torch.tensor(x)
|
214 |
+
if len(x.shape) < 1:
|
215 |
+
x = x.view(-1)
|
216 |
+
data_shape = getattr(self, attr, None)
|
217 |
+
if data_shape is None:
|
218 |
+
data_shape = x.shape[1:]
|
219 |
+
setattr(self, attr, data_shape)
|
220 |
+
else:
|
221 |
+
assert x.shape[1:] == data_shape
|
222 |
+
return x.view(x.shape[0], int(numpy.prod(data_shape)))
|
223 |
+
|
224 |
+
def _restore_result_shape(self, x, attr="data_shape"):
|
225 |
+
"""
|
226 |
+
Restores output data to input data shape.
|
227 |
+
"""
|
228 |
+
data_shape = getattr(self, attr, None)
|
229 |
+
if data_shape is None:
|
230 |
+
return x
|
231 |
+
return x.view(data_shape * len(x.shape))
|
232 |
+
|
233 |
+
|
234 |
+
class Mean(Stat):
|
235 |
+
"""
|
236 |
+
Running mean.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, state=None):
|
240 |
+
if state is not None:
|
241 |
+
return super().__init__(state)
|
242 |
+
self.count = 0
|
243 |
+
self.batchcount = 0
|
244 |
+
self._mean = None
|
245 |
+
self.data_shape = None
|
246 |
+
|
247 |
+
def add(self, a):
|
248 |
+
a = self._normalize_add_shape(a)
|
249 |
+
if len(a) == 0:
|
250 |
+
return
|
251 |
+
batch_count = a.shape[0]
|
252 |
+
batch_mean = a.sum(0) / batch_count
|
253 |
+
self.batchcount += 1
|
254 |
+
# Initial batch.
|
255 |
+
if self._mean is None:
|
256 |
+
self.count = batch_count
|
257 |
+
self._mean = batch_mean
|
258 |
+
return
|
259 |
+
# Update a batch using Chan-style update for numerical stability.
|
260 |
+
self.count += batch_count
|
261 |
+
new_frac = float(batch_count) / self.count
|
262 |
+
# Update the mean according to the batch deviation from the old mean.
|
263 |
+
delta = batch_mean.sub_(self._mean).mul_(new_frac)
|
264 |
+
self._mean.add_(delta)
|
265 |
+
|
266 |
+
def size(self):
|
267 |
+
return self.count
|
268 |
+
|
269 |
+
def mean(self):
|
270 |
+
return self._restore_result_shape(self._mean)
|
271 |
+
|
272 |
+
def to_(self, device):
|
273 |
+
if self._mean is not None:
|
274 |
+
self._mean = self._mean.to(device)
|
275 |
+
|
276 |
+
def load_state_dict(self, state):
|
277 |
+
self.count = state["count"]
|
278 |
+
self.batchcount = state["batchcount"]
|
279 |
+
self._mean = torch.from_numpy(state["mean"])
|
280 |
+
self.data_shape = (
|
281 |
+
None if state["data_shape"] is None else tuple(state["data_shape"])
|
282 |
+
)
|
283 |
+
|
284 |
+
def state_dict(self):
|
285 |
+
return dict(
|
286 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
287 |
+
count=self.count,
|
288 |
+
data_shape=self.data_shape and tuple(self.data_shape),
|
289 |
+
batchcount=self.batchcount,
|
290 |
+
mean=self._mean.cpu().numpy(),
|
291 |
+
)
|
292 |
+
|
293 |
+
|
294 |
+
class NormMean(Mean):
|
295 |
+
"""
|
296 |
+
Running average of the norm of input vectors
|
297 |
+
"""
|
298 |
+
|
299 |
+
def __init__(self, state=None):
|
300 |
+
super().__init__(state)
|
301 |
+
|
302 |
+
def add(self, a):
|
303 |
+
super().add(a.norm(dim=-1))
|
304 |
+
|
305 |
+
|
306 |
+
class Variance(Stat):
|
307 |
+
"""
|
308 |
+
Running computation of mean and variance. Use this when you just need
|
309 |
+
basic stats without covariance.
|
310 |
+
"""
|
311 |
+
|
312 |
+
def __init__(self, state=None):
|
313 |
+
if state is not None:
|
314 |
+
return super().__init__(state)
|
315 |
+
self.count = 0
|
316 |
+
self.batchcount = 0
|
317 |
+
self._mean = None
|
318 |
+
self.v_cmom2 = None
|
319 |
+
self.data_shape = None
|
320 |
+
|
321 |
+
def add(self, a):
|
322 |
+
a = self._normalize_add_shape(a)
|
323 |
+
if len(a) == 0:
|
324 |
+
return
|
325 |
+
batch_count = a.shape[0]
|
326 |
+
batch_mean = a.sum(0) / batch_count
|
327 |
+
centered = a - batch_mean
|
328 |
+
self.batchcount += 1
|
329 |
+
# Initial batch.
|
330 |
+
if self._mean is None:
|
331 |
+
self.count = batch_count
|
332 |
+
self._mean = batch_mean
|
333 |
+
self.v_cmom2 = centered.pow(2).sum(0)
|
334 |
+
return
|
335 |
+
# Update a batch using Chan-style update for numerical stability.
|
336 |
+
oldcount = self.count
|
337 |
+
self.count += batch_count
|
338 |
+
new_frac = float(batch_count) / self.count
|
339 |
+
# Update the mean according to the batch deviation from the old mean.
|
340 |
+
delta = batch_mean.sub_(self._mean).mul_(new_frac)
|
341 |
+
self._mean.add_(delta)
|
342 |
+
# Update the variance using the batch deviation
|
343 |
+
self.v_cmom2.add_(centered.pow(2).sum(0))
|
344 |
+
self.v_cmom2.add_(delta.pow_(2).mul_(new_frac * oldcount))
|
345 |
+
|
346 |
+
def size(self):
|
347 |
+
return self.count
|
348 |
+
|
349 |
+
def mean(self):
|
350 |
+
return self._restore_result_shape(self._mean)
|
351 |
+
|
352 |
+
def variance(self, unbiased=True):
|
353 |
+
return self._restore_result_shape(
|
354 |
+
self.v_cmom2 / (self.count - (1 if unbiased else 0))
|
355 |
+
)
|
356 |
+
|
357 |
+
def stdev(self, unbiased=True):
|
358 |
+
return self.variance(unbiased=unbiased).sqrt()
|
359 |
+
|
360 |
+
def to_(self, device):
|
361 |
+
if self._mean is not None:
|
362 |
+
self._mean = self._mean.to(device)
|
363 |
+
if self.v_cmom2 is not None:
|
364 |
+
self.v_cmom2 = self.v_cmom2.to(device)
|
365 |
+
|
366 |
+
def load_state_dict(self, state):
|
367 |
+
self.count = state["count"]
|
368 |
+
self.batchcount = state["batchcount"]
|
369 |
+
self._mean = torch.from_numpy(state["mean"])
|
370 |
+
self.v_cmom2 = torch.from_numpy(state["cmom2"])
|
371 |
+
self.data_shape = (
|
372 |
+
None if state["data_shape"] is None else tuple(state["data_shape"])
|
373 |
+
)
|
374 |
+
|
375 |
+
def state_dict(self):
|
376 |
+
return dict(
|
377 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
378 |
+
count=self.count,
|
379 |
+
data_shape=self.data_shape and tuple(self.data_shape),
|
380 |
+
batchcount=self.batchcount,
|
381 |
+
mean=self._mean.cpu().numpy(),
|
382 |
+
cmom2=self.v_cmom2.cpu().numpy(),
|
383 |
+
)
|
384 |
+
|
385 |
+
|
386 |
+
class Covariance(Stat):
|
387 |
+
"""
|
388 |
+
Running computation. Use this when the entire covariance matrix is needed,
|
389 |
+
and when the whole covariance matrix fits in the GPU.
|
390 |
+
|
391 |
+
Chan-style numerically stable update of mean and full covariance matrix.
|
392 |
+
Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
|
393 |
+
"""
|
394 |
+
|
395 |
+
def __init__(self, state=None):
|
396 |
+
if state is not None:
|
397 |
+
return super().__init__(state)
|
398 |
+
self.count = 0
|
399 |
+
self._mean = None
|
400 |
+
self.cmom2 = None
|
401 |
+
self.data_shape = None
|
402 |
+
|
403 |
+
def add(self, a):
|
404 |
+
a = self._normalize_add_shape(a)
|
405 |
+
if len(a) == 0:
|
406 |
+
return
|
407 |
+
batch_count = a.shape[0]
|
408 |
+
# Initial batch.
|
409 |
+
if self._mean is None:
|
410 |
+
self.count = batch_count
|
411 |
+
self._mean = a.sum(0) / batch_count
|
412 |
+
centered = a - self._mean
|
413 |
+
self.cmom2 = centered.t().mm(centered)
|
414 |
+
return
|
415 |
+
# Update a batch using Chan-style update for numerical stability.
|
416 |
+
self.count += batch_count
|
417 |
+
# Update the mean according to the batch deviation from the old mean.
|
418 |
+
delta = a - self._mean
|
419 |
+
self._mean.add_(delta.sum(0) / self.count)
|
420 |
+
delta2 = a - self._mean
|
421 |
+
# Update the variance using the batch deviation
|
422 |
+
self.cmom2.addmm_(mat1=delta.t(), mat2=delta2)
|
423 |
+
|
424 |
+
def to_(self, device):
|
425 |
+
if self._mean is not None:
|
426 |
+
self._mean = self._mean.to(device)
|
427 |
+
if self.cmom2 is not None:
|
428 |
+
self.cmom2 = self.cmom2.to(device)
|
429 |
+
|
430 |
+
def mean(self):
|
431 |
+
return self._restore_result_shape(self._mean)
|
432 |
+
|
433 |
+
def covariance(self, unbiased=True):
|
434 |
+
return self._restore_result_shape(
|
435 |
+
self.cmom2 / (self.count - (1 if unbiased else 0))
|
436 |
+
)
|
437 |
+
|
438 |
+
def correlation(self, unbiased=True):
|
439 |
+
cov = self.cmom2 / (self.count - (1 if unbiased else 0))
|
440 |
+
rstdev = cov.diag().sqrt().reciprocal()
|
441 |
+
return self._restore_result_shape(rstdev[:, None] * cov * rstdev[None, :])
|
442 |
+
|
443 |
+
def variance(self, unbiased=True):
|
444 |
+
return self._restore_result_shape(
|
445 |
+
self.cmom2.diag() / (self.count - (1 if unbiased else 0))
|
446 |
+
)
|
447 |
+
|
448 |
+
def stdev(self, unbiased=True):
|
449 |
+
return self.variance(unbiased=unbiased).sqrt()
|
450 |
+
|
451 |
+
def state_dict(self):
|
452 |
+
return dict(
|
453 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
454 |
+
count=self.count,
|
455 |
+
data_shape=self.data_shape and tuple(self.data_shape),
|
456 |
+
mean=self._mean.cpu().numpy(),
|
457 |
+
cmom2=self.cmom2.cpu().numpy(),
|
458 |
+
)
|
459 |
+
|
460 |
+
def load_state_dict(self, state):
|
461 |
+
self.count = state["count"]
|
462 |
+
self._mean = torch.from_numpy(state["mean"])
|
463 |
+
self.cmom2 = torch.from_numpy(state["cmom2"])
|
464 |
+
self.data_shape = (
|
465 |
+
None if state["data_shape"] is None else tuple(state["data_shape"])
|
466 |
+
)
|
467 |
+
|
468 |
+
|
469 |
+
class SecondMoment(Stat):
|
470 |
+
"""
|
471 |
+
Running computation. Use this when the entire non-centered 2nd-moment
|
472 |
+
'covariance-like' matrix is needed, and when the whole matrix fits
|
473 |
+
in the GPU.
|
474 |
+
"""
|
475 |
+
|
476 |
+
def __init__(self, split_batch=True, state=None):
|
477 |
+
if state is not None:
|
478 |
+
return super().__init__(state)
|
479 |
+
self.count = 0
|
480 |
+
self.mom2 = None
|
481 |
+
self.split_batch = split_batch
|
482 |
+
|
483 |
+
def add(self, a):
|
484 |
+
a = self._normalize_add_shape(a)
|
485 |
+
if len(a) == 0:
|
486 |
+
return
|
487 |
+
# Initial batch reveals the shape of the data.
|
488 |
+
if self.count == 0:
|
489 |
+
self.mom2 = a.new(a.shape[1], a.shape[1]).zero_()
|
490 |
+
batch_count = a.shape[0]
|
491 |
+
# Update the covariance using the batch deviation
|
492 |
+
self.count += batch_count
|
493 |
+
self.mom2 += a.t().mm(a)
|
494 |
+
|
495 |
+
def to_(self, device):
|
496 |
+
if self.mom2 is not None:
|
497 |
+
self.mom2 = self.mom2.to(device)
|
498 |
+
|
499 |
+
def moment(self):
|
500 |
+
return self.mom2 / self.count
|
501 |
+
|
502 |
+
def state_dict(self):
|
503 |
+
return dict(
|
504 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
505 |
+
count=self.count,
|
506 |
+
mom2=self.mom2.cpu().numpy(),
|
507 |
+
)
|
508 |
+
|
509 |
+
def load_state_dict(self, state):
|
510 |
+
self.count = int(state["count"])
|
511 |
+
self.mom2 = torch.from_numpy(state["mom2"])
|
512 |
+
|
513 |
+
|
514 |
+
class Bincount(Stat):
|
515 |
+
"""
|
516 |
+
Running bincount. The counted array should be an integer type with
|
517 |
+
non-negative integers.
|
518 |
+
"""
|
519 |
+
|
520 |
+
def __init__(self, state=None):
|
521 |
+
if state is not None:
|
522 |
+
return super().__init__(state)
|
523 |
+
self.count = 0
|
524 |
+
self._bincount = None
|
525 |
+
|
526 |
+
def add(self, a, size=None):
|
527 |
+
a = a.view(-1)
|
528 |
+
bincount = a.bincount()
|
529 |
+
if self._bincount is None:
|
530 |
+
self._bincount = bincount
|
531 |
+
elif len(self._bincount) < len(bincount):
|
532 |
+
bincount[: len(self._bincount)] += self._bincount
|
533 |
+
self._bincount = bincount
|
534 |
+
else:
|
535 |
+
self._bincount[: len(bincount)] += bincount
|
536 |
+
if size is None:
|
537 |
+
self.count += len(a)
|
538 |
+
else:
|
539 |
+
self.count += size
|
540 |
+
|
541 |
+
def to_(self, device):
|
542 |
+
self._bincount = self._bincount.to(device)
|
543 |
+
|
544 |
+
def size(self):
|
545 |
+
return self.count
|
546 |
+
|
547 |
+
def bincount(self):
|
548 |
+
return self._bincount
|
549 |
+
|
550 |
+
def state_dict(self):
|
551 |
+
return dict(
|
552 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
553 |
+
count=self.count,
|
554 |
+
bincount=self._bincount.cpu().numpy(),
|
555 |
+
)
|
556 |
+
|
557 |
+
def load_state_dict(self, dic):
|
558 |
+
self.count = int(dic["count"])
|
559 |
+
self._bincount = torch.from_numpy(dic["bincount"])
|
560 |
+
|
561 |
+
|
562 |
+
class CrossCovariance(Stat):
|
563 |
+
"""
|
564 |
+
Covariance. Use this when an off-diagonal block of the covariance
|
565 |
+
matrix is needed (e.g., when the whole covariance matrix does
|
566 |
+
not fit in the GPU, this could use a quarter of the memory).
|
567 |
+
|
568 |
+
Chan-style numerically stable update of mean and full covariance matrix.
|
569 |
+
Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
|
570 |
+
"""
|
571 |
+
|
572 |
+
def __init__(self, split_batch=True, state=None):
|
573 |
+
if state is not None:
|
574 |
+
return super().__init__(state)
|
575 |
+
self.count = 0
|
576 |
+
self._mean = None
|
577 |
+
self.cmom2 = None
|
578 |
+
self.v_cmom2 = None
|
579 |
+
self.split_batch = split_batch
|
580 |
+
|
581 |
+
def add(self, a, b):
|
582 |
+
if len(a.shape) == 1:
|
583 |
+
a = a[None, :]
|
584 |
+
b = b[None, :]
|
585 |
+
assert a.shape[0] == b.shape[0]
|
586 |
+
if len(a.shape) > 2:
|
587 |
+
a, b = [
|
588 |
+
d.view(d.shape[0], d.shape[1], -1)
|
589 |
+
.permute(0, 2, 1)
|
590 |
+
.reshape(-1, d.shape[1])
|
591 |
+
for d in [a, b]
|
592 |
+
]
|
593 |
+
batch_count = a.shape[0]
|
594 |
+
# Initial batch.
|
595 |
+
if self._mean is None:
|
596 |
+
self.count = batch_count
|
597 |
+
self._mean = [d.sum(0) / batch_count for d in [a, b]]
|
598 |
+
centered = [d - bm for d, bm in zip([a, b], self._mean)]
|
599 |
+
self.v_cmom2 = [c.pow(2).sum(0) for c in centered]
|
600 |
+
self.cmom2 = centered[0].t().mm(centered[1])
|
601 |
+
return
|
602 |
+
# Update a batch using Chan-style update for numerical stability.
|
603 |
+
self.count += batch_count
|
604 |
+
# Update the mean according to the batch deviation from the old mean.
|
605 |
+
delta = [(d - bm) for d, bm in zip([a, b], self._mean)]
|
606 |
+
for m, d in zip(self._mean, delta):
|
607 |
+
m.add_(d.sum(0) / self.count)
|
608 |
+
delta2 = [(d - bm) for d, bm in zip([a, b], self._mean)]
|
609 |
+
# Update the cross-covariance using the batch deviation
|
610 |
+
self.cmom2.addmm_(mat1=delta[0].t(), mat2=delta2[1])
|
611 |
+
# Update the variance using the batch deviation
|
612 |
+
for vc2, d, d2 in zip(self.v_cmom2, delta, delta2):
|
613 |
+
vc2.add_((d * d2).sum(0))
|
614 |
+
|
615 |
+
def mean(self):
|
616 |
+
return self._mean
|
617 |
+
|
618 |
+
def variance(self, unbiased=True):
|
619 |
+
return [vc2 / (self.count - (1 if unbiased else 0)) for vc2 in self.v_cmom2]
|
620 |
+
|
621 |
+
def stdev(self, unbiased=True):
|
622 |
+
return [v.sqrt() for v in self.variance(unbiased=unbiased)]
|
623 |
+
|
624 |
+
def covariance(self, unbiased=True):
|
625 |
+
return self.cmom2 / (self.count - (1 if unbiased else 0))
|
626 |
+
|
627 |
+
def correlation(self):
|
628 |
+
covariance = self.covariance(unbiased=False)
|
629 |
+
rstdev = [s.reciprocal() for s in self.stdev(unbiased=False)]
|
630 |
+
cor = rstdev[0][:, None] * covariance * rstdev[1][None, :]
|
631 |
+
# Remove NaNs
|
632 |
+
cor[torch.isnan(cor)] = 0
|
633 |
+
return cor
|
634 |
+
|
635 |
+
def to_(self, device):
|
636 |
+
self._mean = [m.to(device) for m in self._mean]
|
637 |
+
self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2]
|
638 |
+
self.cmom2 = self.cmom2.to(device)
|
639 |
+
|
640 |
+
def state_dict(self):
|
641 |
+
return dict(
|
642 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
643 |
+
count=self.count,
|
644 |
+
mean_a=self._mean[0].cpu().numpy(),
|
645 |
+
mean_b=self._mean[1].cpu().numpy(),
|
646 |
+
cmom2_a=self.v_cmom2[0].cpu().numpy(),
|
647 |
+
cmom2_b=self.v_cmom2[1].cpu().numpy(),
|
648 |
+
cmom2=self.cmom2.cpu().numpy(),
|
649 |
+
)
|
650 |
+
|
651 |
+
def load_state_dict(self, state):
|
652 |
+
self.count = int(state["count"])
|
653 |
+
self._mean = [torch.from_numpy(state[f"mean_{k}"]) for k in "ab"]
|
654 |
+
self.v_cmom2 = [torch.from_numpy(state[f"cmom2_{k}"]) for k in "ab"]
|
655 |
+
self.cmom2 = torch.from_numpy(state["cmom2"])
|
656 |
+
|
657 |
+
|
658 |
+
def _float_from_bool(a):
|
659 |
+
"""
|
660 |
+
Since pytorch only supports matrix multiplication on float,
|
661 |
+
IoU computations are done using floating point types.
|
662 |
+
|
663 |
+
This function binarizes the input (positive to True and
|
664 |
+
nonpositive to False), and converts from bool to float.
|
665 |
+
If the data is already a floating-point type, it
|
666 |
+
keeps the same type; otherwise it uses float.
|
667 |
+
"""
|
668 |
+
if a.dtype == torch.bool:
|
669 |
+
return a.float()
|
670 |
+
if a.dtype.is_floating_point:
|
671 |
+
return a.sign().clamp_(0)
|
672 |
+
return (a > 0).float()
|
673 |
+
|
674 |
+
|
675 |
+
class IoU(Stat):
|
676 |
+
"""
|
677 |
+
Running computation of intersections and unions of all features.
|
678 |
+
"""
|
679 |
+
|
680 |
+
def __init__(self, state=None):
|
681 |
+
if state is not None:
|
682 |
+
return super().__init__(state)
|
683 |
+
self.count = 0
|
684 |
+
self._intersection = None
|
685 |
+
|
686 |
+
def add(self, a):
|
687 |
+
assert len(a.shape) == 2
|
688 |
+
a = _float_from_bool(a)
|
689 |
+
if self._intersection is None:
|
690 |
+
self._intersection = torch.mm(a.t(), a)
|
691 |
+
else:
|
692 |
+
self._intersection.addmm_(a.t(), a)
|
693 |
+
self.count += len(a)
|
694 |
+
|
695 |
+
def size(self):
|
696 |
+
return self.count
|
697 |
+
|
698 |
+
def intersection(self):
|
699 |
+
return self._intersection
|
700 |
+
|
701 |
+
def union(self):
|
702 |
+
total = self._intersection.diagonal(0)
|
703 |
+
return total[:, None] + total[None, :] - self._intersection
|
704 |
+
|
705 |
+
def iou(self):
|
706 |
+
return self.intersection() / (self.union() + 1e-20)
|
707 |
+
|
708 |
+
def to_(self, _device):
|
709 |
+
self._intersection = self._intersection.to(_device)
|
710 |
+
|
711 |
+
def state_dict(self):
|
712 |
+
return dict(
|
713 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
714 |
+
count=self.count,
|
715 |
+
intersection=self._intersection.cpu().numpy(),
|
716 |
+
)
|
717 |
+
|
718 |
+
def load_state_dict(self, state):
|
719 |
+
self.count = int(state["count"])
|
720 |
+
self._intersection = torch.tensor(state["intersection"])
|
721 |
+
|
722 |
+
|
723 |
+
class CrossIoU(Stat):
|
724 |
+
"""
|
725 |
+
Running computation of intersections and unions of two binary vectors.
|
726 |
+
"""
|
727 |
+
|
728 |
+
def __init__(self, state=None):
|
729 |
+
if state is not None:
|
730 |
+
return super().__init__(state)
|
731 |
+
self.count = 0
|
732 |
+
self._intersection = None
|
733 |
+
self.total_a = None
|
734 |
+
self.total_b = None
|
735 |
+
|
736 |
+
def add(self, a, b):
|
737 |
+
assert len(a.shape) == 2 and len(b.shape) == 2
|
738 |
+
assert len(a) == len(b), f"{len(a)} vs {len(b)}"
|
739 |
+
a = _float_from_bool(a) # CUDA only supports mm on float...
|
740 |
+
b = _float_from_bool(b) # otherwise we would use integers.
|
741 |
+
intersection = torch.mm(a.t(), b)
|
742 |
+
asum = a.sum(0)
|
743 |
+
bsum = b.sum(0)
|
744 |
+
if self._intersection is None:
|
745 |
+
self._intersection = intersection
|
746 |
+
self.total_a = asum
|
747 |
+
self.total_b = bsum
|
748 |
+
else:
|
749 |
+
self._intersection += intersection
|
750 |
+
self.total_a += asum
|
751 |
+
self.total_b += bsum
|
752 |
+
self.count += len(a)
|
753 |
+
|
754 |
+
def size(self):
|
755 |
+
return self.count
|
756 |
+
|
757 |
+
def intersection(self):
|
758 |
+
return self._intersection
|
759 |
+
|
760 |
+
def union(self):
|
761 |
+
return self.total_a[:, None] + self.total_b[None, :] - self._intersection
|
762 |
+
|
763 |
+
def iou(self):
|
764 |
+
return self.intersection() / (self.union() + 1e-20)
|
765 |
+
|
766 |
+
def to_(self, _device):
|
767 |
+
self.total_a = self.total_a.to(_device)
|
768 |
+
self.total_b = self.total_b.to(_device)
|
769 |
+
self._intersection = self._intersection.to(_device)
|
770 |
+
|
771 |
+
def state_dict(self):
|
772 |
+
return dict(
|
773 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
774 |
+
count=self.count,
|
775 |
+
total_a=self.total_a.cpu().numpy(),
|
776 |
+
total_b=self.total_b.cpu().numpy(),
|
777 |
+
intersection=self._intersection.cpu().numpy(),
|
778 |
+
)
|
779 |
+
|
780 |
+
def load_state_dict(self, state):
|
781 |
+
self.count = int(state["count"])
|
782 |
+
self.total_a = torch.tensor(state["total_a"])
|
783 |
+
self.total_b = torch.tensor(state["total_b"])
|
784 |
+
self._intersection = torch.tensor(state["intersection"])
|
785 |
+
|
786 |
+
|
787 |
+
class Quantile(Stat):
|
788 |
+
"""
|
789 |
+
Streaming randomized quantile computation for torch.
|
790 |
+
|
791 |
+
Add any amount of data repeatedly via add(data). At any time,
|
792 |
+
quantile estimates can be read out using quantiles(q).
|
793 |
+
|
794 |
+
Implemented as a sorted sample that retains at least r samples
|
795 |
+
(by default r = 3072); the number of retained samples will grow to
|
796 |
+
a finite ceiling as the data is accumulated. Accuracy scales according
|
797 |
+
to r: the default is to set resolution to be accurate to better than about
|
798 |
+
0.1%, while limiting storage to about 50,000 samples.
|
799 |
+
|
800 |
+
Good for computing quantiles of huge data without using much memory.
|
801 |
+
Works well on arbitrary data with probability near 1.
|
802 |
+
|
803 |
+
Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty
|
804 |
+
from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf
|
805 |
+
"""
|
806 |
+
|
807 |
+
def __init__(self, r=3 * 1024, buffersize=None, seed=None, state=None):
|
808 |
+
if state is not None:
|
809 |
+
return super().__init__(state)
|
810 |
+
self.depth = None
|
811 |
+
self.dtype = None
|
812 |
+
self.device = None
|
813 |
+
resolution = r * 2 # sample array is at least half full before discard
|
814 |
+
self.resolution = resolution
|
815 |
+
# Default buffersize: 128 samples (and smaller than resolution).
|
816 |
+
if buffersize is None:
|
817 |
+
buffersize = min(128, (resolution + 7) // 8)
|
818 |
+
self.buffersize = buffersize
|
819 |
+
self.samplerate = 1.0
|
820 |
+
self.data = None
|
821 |
+
self.firstfree = [0]
|
822 |
+
self.randbits = torch.ByteTensor(resolution)
|
823 |
+
self.currentbit = len(self.randbits) - 1
|
824 |
+
self.extremes = None
|
825 |
+
self.count = 0
|
826 |
+
self.batchcount = 0
|
827 |
+
|
828 |
+
def size(self):
|
829 |
+
return self.count
|
830 |
+
|
831 |
+
def _lazy_init(self, incoming):
|
832 |
+
self.depth = incoming.shape[1]
|
833 |
+
self.dtype = incoming.dtype
|
834 |
+
self.device = incoming.device
|
835 |
+
self.data = [
|
836 |
+
torch.zeros(
|
837 |
+
self.depth, self.resolution, dtype=self.dtype, device=self.device
|
838 |
+
)
|
839 |
+
]
|
840 |
+
self.extremes = torch.zeros(self.depth, 2, dtype=self.dtype, device=self.device)
|
841 |
+
self.extremes[:, 0] = float("inf")
|
842 |
+
self.extremes[:, -1] = -float("inf")
|
843 |
+
|
844 |
+
def to_(self, device):
|
845 |
+
"""Switches internal storage to specified device."""
|
846 |
+
if device != self.device:
|
847 |
+
old_data = self.data
|
848 |
+
old_extremes = self.extremes
|
849 |
+
self.data = [d.to(device) for d in self.data]
|
850 |
+
self.extremes = self.extremes.to(device)
|
851 |
+
self.device = self.extremes.device
|
852 |
+
del old_data
|
853 |
+
del old_extremes
|
854 |
+
|
855 |
+
def add(self, incoming):
|
856 |
+
if self.depth is None:
|
857 |
+
self._lazy_init(incoming)
|
858 |
+
assert len(incoming.shape) == 2
|
859 |
+
assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth)
|
860 |
+
self.count += incoming.shape[0]
|
861 |
+
self.batchcount += 1
|
862 |
+
# Convert to a flat torch array.
|
863 |
+
if self.samplerate >= 1.0:
|
864 |
+
self._add_every(incoming)
|
865 |
+
return
|
866 |
+
# If we are sampling, then subsample a large chunk at a time.
|
867 |
+
self._scan_extremes(incoming)
|
868 |
+
chunksize = int(math.ceil(self.buffersize / self.samplerate))
|
869 |
+
for index in range(0, len(incoming), chunksize):
|
870 |
+
batch = incoming[index : index + chunksize]
|
871 |
+
sample = sample_portion(batch, self.samplerate)
|
872 |
+
if len(sample):
|
873 |
+
self._add_every(sample)
|
874 |
+
|
875 |
+
def _add_every(self, incoming):
|
876 |
+
supplied = len(incoming)
|
877 |
+
index = 0
|
878 |
+
while index < supplied:
|
879 |
+
ff = self.firstfree[0]
|
880 |
+
available = self.data[0].shape[1] - ff
|
881 |
+
if available == 0:
|
882 |
+
if not self._shift():
|
883 |
+
# If we shifted by subsampling, then subsample.
|
884 |
+
incoming = incoming[index:]
|
885 |
+
if self.samplerate >= 0.5:
|
886 |
+
# First time sampling - the data source is very large.
|
887 |
+
self._scan_extremes(incoming)
|
888 |
+
incoming = sample_portion(incoming, self.samplerate)
|
889 |
+
index = 0
|
890 |
+
supplied = len(incoming)
|
891 |
+
ff = self.firstfree[0]
|
892 |
+
available = self.data[0].shape[1] - ff
|
893 |
+
copycount = min(available, supplied - index)
|
894 |
+
self.data[0][:, ff : ff + copycount] = torch.t(
|
895 |
+
incoming[index : index + copycount, :]
|
896 |
+
)
|
897 |
+
self.firstfree[0] += copycount
|
898 |
+
index += copycount
|
899 |
+
|
900 |
+
def _shift(self):
|
901 |
+
index = 0
|
902 |
+
# If remaining space at the current layer is less than half prev
|
903 |
+
# buffer size (rounding up), then we need to shift it up to ensure
|
904 |
+
# enough space for future shifting.
|
905 |
+
while self.data[index].shape[1] - self.firstfree[index] < (
|
906 |
+
-(-self.data[index - 1].shape[1] // 2) if index else 1
|
907 |
+
):
|
908 |
+
if index + 1 >= len(self.data):
|
909 |
+
return self._expand()
|
910 |
+
data = self.data[index][:, 0 : self.firstfree[index]]
|
911 |
+
data = data.sort()[0]
|
912 |
+
if index == 0 and self.samplerate >= 1.0:
|
913 |
+
self._update_extremes(data[:, 0], data[:, -1])
|
914 |
+
offset = self._randbit()
|
915 |
+
position = self.firstfree[index + 1]
|
916 |
+
subset = data[:, offset::2]
|
917 |
+
self.data[index + 1][:, position : position + subset.shape[1]] = subset
|
918 |
+
self.firstfree[index] = 0
|
919 |
+
self.firstfree[index + 1] += subset.shape[1]
|
920 |
+
index += 1
|
921 |
+
return True
|
922 |
+
|
923 |
+
def _scan_extremes(self, incoming):
|
924 |
+
# When sampling, we need to scan every item still to get extremes
|
925 |
+
self._update_extremes(
|
926 |
+
torch.min(incoming, dim=0)[0], torch.max(incoming, dim=0)[0]
|
927 |
+
)
|
928 |
+
|
929 |
+
def _update_extremes(self, minr, maxr):
|
930 |
+
self.extremes[:, 0] = torch.min(
|
931 |
+
torch.stack([self.extremes[:, 0], minr]), dim=0
|
932 |
+
)[0]
|
933 |
+
self.extremes[:, -1] = torch.max(
|
934 |
+
torch.stack([self.extremes[:, -1], maxr]), dim=0
|
935 |
+
)[0]
|
936 |
+
|
937 |
+
def _randbit(self):
|
938 |
+
self.currentbit += 1
|
939 |
+
if self.currentbit >= len(self.randbits):
|
940 |
+
self.randbits.random_(to=2)
|
941 |
+
self.currentbit = 0
|
942 |
+
return self.randbits[self.currentbit]
|
943 |
+
|
944 |
+
def state_dict(self):
|
945 |
+
state = dict(
|
946 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
947 |
+
resolution=self.resolution,
|
948 |
+
depth=self.depth,
|
949 |
+
buffersize=self.buffersize,
|
950 |
+
samplerate=self.samplerate,
|
951 |
+
sizes=numpy.array([d.shape[1] for d in self.data]),
|
952 |
+
extremes=self.extremes.cpu().detach().numpy(),
|
953 |
+
size=self.count,
|
954 |
+
batchcount=self.batchcount,
|
955 |
+
)
|
956 |
+
for i, (d, f) in enumerate(zip(self.data, self.firstfree)):
|
957 |
+
state[f"data.{i}"] = d.cpu().detach().numpy()[:, :f].T
|
958 |
+
return state
|
959 |
+
|
960 |
+
def load_state_dict(self, state):
|
961 |
+
self.resolution = int(state["resolution"])
|
962 |
+
self.randbits = torch.ByteTensor(self.resolution)
|
963 |
+
self.currentbit = len(self.randbits) - 1
|
964 |
+
self.depth = int(state["depth"])
|
965 |
+
self.buffersize = int(state["buffersize"])
|
966 |
+
self.samplerate = float(state["samplerate"])
|
967 |
+
firstfree = []
|
968 |
+
buffers = []
|
969 |
+
for i, s in enumerate(state["sizes"]):
|
970 |
+
d = state[f"data.{i}"]
|
971 |
+
firstfree.append(d.shape[0])
|
972 |
+
buf = numpy.zeros((d.shape[1], s), dtype=d.dtype)
|
973 |
+
buf[:, : d.shape[0]] = d.T
|
974 |
+
buffers.append(torch.from_numpy(buf))
|
975 |
+
self.firstfree = firstfree
|
976 |
+
self.data = buffers
|
977 |
+
self.extremes = torch.from_numpy((state["extremes"]))
|
978 |
+
self.count = int(state["size"])
|
979 |
+
self.batchcount = int(state.get("batchcount", 0))
|
980 |
+
self.dtype = self.extremes.dtype
|
981 |
+
self.device = self.extremes.device
|
982 |
+
|
983 |
+
def min(self):
|
984 |
+
return self.minmax()[0]
|
985 |
+
|
986 |
+
def max(self):
|
987 |
+
return self.minmax()[-1]
|
988 |
+
|
989 |
+
def minmax(self):
|
990 |
+
if self.firstfree[0]:
|
991 |
+
self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
|
992 |
+
return self.extremes.clone()
|
993 |
+
|
994 |
+
def median(self):
|
995 |
+
return self.quantiles(0.5)
|
996 |
+
|
997 |
+
def mean(self):
|
998 |
+
return self.integrate(lambda x: x) / self.count
|
999 |
+
|
1000 |
+
def variance(self, unbiased=True):
|
1001 |
+
mean = self.mean()[:, None]
|
1002 |
+
return self.integrate(lambda x: (x - mean).pow(2)) / (
|
1003 |
+
self.count - (1 if unbiased else 0)
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
def stdev(self, unbiased=True):
|
1007 |
+
return self.variance(unbiased=unbiased).sqrt()
|
1008 |
+
|
1009 |
+
def _expand(self):
|
1010 |
+
cap = self._next_capacity()
|
1011 |
+
if cap > 0:
|
1012 |
+
# First, make a new layer of the proper capacity.
|
1013 |
+
self.data.insert(
|
1014 |
+
0, torch.zeros(self.depth, cap, dtype=self.dtype, device=self.device)
|
1015 |
+
)
|
1016 |
+
self.firstfree.insert(0, 0)
|
1017 |
+
else:
|
1018 |
+
# Unless we're so big we are just subsampling.
|
1019 |
+
assert self.firstfree[0] == 0
|
1020 |
+
self.samplerate *= 0.5
|
1021 |
+
for index in range(1, len(self.data)):
|
1022 |
+
# Scan for existing data that needs to be moved down a level.
|
1023 |
+
amount = self.firstfree[index]
|
1024 |
+
if amount == 0:
|
1025 |
+
continue
|
1026 |
+
position = self.firstfree[index - 1]
|
1027 |
+
# Move data down if it would leave enough empty space there
|
1028 |
+
# This is the key invariant: enough empty space to fit half
|
1029 |
+
# of the previous level's buffer size (rounding up)
|
1030 |
+
if self.data[index - 1].shape[1] - (amount + position) >= (
|
1031 |
+
-(-self.data[index - 2].shape[1] // 2) if (index - 1) else 1
|
1032 |
+
):
|
1033 |
+
self.data[index - 1][:, position : position + amount] = self.data[
|
1034 |
+
index
|
1035 |
+
][:, :amount]
|
1036 |
+
self.firstfree[index - 1] += amount
|
1037 |
+
self.firstfree[index] = 0
|
1038 |
+
else:
|
1039 |
+
# Scrunch the data if it would not.
|
1040 |
+
data = self.data[index][:, :amount]
|
1041 |
+
data = data.sort()[0]
|
1042 |
+
if index == 1:
|
1043 |
+
self._update_extremes(data[:, 0], data[:, -1])
|
1044 |
+
offset = self._randbit()
|
1045 |
+
scrunched = data[:, offset::2]
|
1046 |
+
self.data[index][:, : scrunched.shape[1]] = scrunched
|
1047 |
+
self.firstfree[index] = scrunched.shape[1]
|
1048 |
+
return cap > 0
|
1049 |
+
|
1050 |
+
def _next_capacity(self):
|
1051 |
+
cap = int(math.ceil(self.resolution * (0.67 ** len(self.data))))
|
1052 |
+
if cap < 2:
|
1053 |
+
return 0
|
1054 |
+
# Round up to the nearest multiple of 8 for better GPU alignment.
|
1055 |
+
cap = -8 * (-cap // 8)
|
1056 |
+
return max(self.buffersize, cap)
|
1057 |
+
|
1058 |
+
def _weighted_summary(self, sort=True):
|
1059 |
+
if self.firstfree[0]:
|
1060 |
+
self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
|
1061 |
+
size = sum(self.firstfree)
|
1062 |
+
weights = torch.FloatTensor(size) # Floating point
|
1063 |
+
summary = torch.zeros(self.depth, size, dtype=self.dtype, device=self.device)
|
1064 |
+
index = 0
|
1065 |
+
for level, ff in enumerate(self.firstfree):
|
1066 |
+
if ff == 0:
|
1067 |
+
continue
|
1068 |
+
summary[:, index : index + ff] = self.data[level][:, :ff]
|
1069 |
+
weights[index : index + ff] = 2.0**level
|
1070 |
+
index += ff
|
1071 |
+
assert index == summary.shape[1]
|
1072 |
+
if sort:
|
1073 |
+
summary, order = torch.sort(summary, dim=-1)
|
1074 |
+
weights = weights[order.view(-1).cpu()].view(order.shape)
|
1075 |
+
summary = torch.cat(
|
1076 |
+
[self.extremes[:, :1], summary, self.extremes[:, 1:]], dim=-1
|
1077 |
+
)
|
1078 |
+
weights = torch.cat(
|
1079 |
+
[
|
1080 |
+
torch.zeros(weights.shape[0], 1),
|
1081 |
+
weights,
|
1082 |
+
torch.zeros(weights.shape[0], 1),
|
1083 |
+
],
|
1084 |
+
dim=-1,
|
1085 |
+
)
|
1086 |
+
return (summary, weights)
|
1087 |
+
|
1088 |
+
def quantiles(self, quantiles):
|
1089 |
+
if not hasattr(quantiles, "cpu"):
|
1090 |
+
quantiles = torch.tensor(quantiles)
|
1091 |
+
qshape = quantiles.shape
|
1092 |
+
if self.count == 0:
|
1093 |
+
return torch.full((self.depth,) + qshape, torch.nan)
|
1094 |
+
summary, weights = self._weighted_summary()
|
1095 |
+
cumweights = torch.cumsum(weights, dim=-1) - weights / 2
|
1096 |
+
cumweights /= torch.sum(weights, dim=-1, keepdim=True)
|
1097 |
+
result = torch.zeros(
|
1098 |
+
self.depth, quantiles.numel(), dtype=self.dtype, device=self.device
|
1099 |
+
)
|
1100 |
+
# numpy is needed for interpolation
|
1101 |
+
nq = quantiles.view(-1).cpu().detach().numpy()
|
1102 |
+
ncw = cumweights.cpu().detach().numpy()
|
1103 |
+
nsm = summary.cpu().detach().numpy()
|
1104 |
+
for d in range(self.depth):
|
1105 |
+
result[d] = torch.tensor(
|
1106 |
+
numpy.interp(nq, ncw[d], nsm[d]), dtype=self.dtype, device=self.device
|
1107 |
+
)
|
1108 |
+
return result.view((self.depth,) + qshape)
|
1109 |
+
|
1110 |
+
def integrate(self, fun):
|
1111 |
+
result = []
|
1112 |
+
for level, ff in enumerate(self.firstfree):
|
1113 |
+
if ff == 0:
|
1114 |
+
continue
|
1115 |
+
result.append(
|
1116 |
+
torch.sum(fun(self.data[level][:, :ff]) * (2.0**level), dim=-1)
|
1117 |
+
)
|
1118 |
+
if len(result) == 0:
|
1119 |
+
return None
|
1120 |
+
return torch.stack(result).sum(dim=0) / self.samplerate
|
1121 |
+
|
1122 |
+
def readout(self, count=1001):
|
1123 |
+
return self.quantiles(torch.linspace(0.0, 1.0, count))
|
1124 |
+
|
1125 |
+
def normalize(self, data):
|
1126 |
+
"""
|
1127 |
+
Given input data as taken from the training distribution,
|
1128 |
+
normalizes every channel to reflect quantile values,
|
1129 |
+
uniformly distributed, within [0, 1].
|
1130 |
+
"""
|
1131 |
+
assert self.count > 0
|
1132 |
+
assert data.shape[0] == self.depth
|
1133 |
+
summary, weights = self._weighted_summary()
|
1134 |
+
cumweights = torch.cumsum(weights, dim=-1) - weights / 2
|
1135 |
+
cumweights /= torch.sum(weights, dim=-1, keepdim=True)
|
1136 |
+
result = torch.zeros_like(data).float()
|
1137 |
+
# numpy is needed for interpolation
|
1138 |
+
ndata = data.cpu().numpy().reshape((data.shape[0], -1))
|
1139 |
+
ncw = cumweights.cpu().numpy()
|
1140 |
+
nsm = summary.cpu().numpy()
|
1141 |
+
for d in range(self.depth):
|
1142 |
+
normed = torch.tensor(
|
1143 |
+
numpy.interp(ndata[d], nsm[d], ncw[d]),
|
1144 |
+
dtype=torch.float,
|
1145 |
+
device=data.device,
|
1146 |
+
).clamp_(0.0, 1.0)
|
1147 |
+
if len(data.shape) > 1:
|
1148 |
+
normed = normed.view(*(data.shape[1:]))
|
1149 |
+
result[d] = normed
|
1150 |
+
return result
|
1151 |
+
|
1152 |
+
|
1153 |
+
def sample_portion(vec, p=0.5):
|
1154 |
+
"""
|
1155 |
+
Subsamples a fraction (given by p) of the given batch. Used by
|
1156 |
+
Quantile when the data gets very very large.
|
1157 |
+
"""
|
1158 |
+
bits = torch.bernoulli(
|
1159 |
+
torch.zeros(vec.shape[0], dtype=torch.uint8, device=vec.device), p
|
1160 |
+
)
|
1161 |
+
return vec[bits]
|
1162 |
+
|
1163 |
+
|
1164 |
+
class TopK:
|
1165 |
+
"""
|
1166 |
+
A class to keep a running tally of the top k values (and indexes)
|
1167 |
+
of any number of torch feature components. Will work on the GPU if
|
1168 |
+
the data is on the GPU. Tracks largest by default, but tracks smallest
|
1169 |
+
if largest=False is passed.
|
1170 |
+
|
1171 |
+
This version flattens all arrays to avoid crashes.
|
1172 |
+
"""
|
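A minimal sketch of the running top-k interface (illustrative only; the batch shapes are assumptions):

tk = TopK(k=5)
for _ in range(20):
    tk.add(torch.randn(100, 32))          # dim 0 enumerates observations
values, indexes = tk.topk()               # each of shape [32, 5]; indexes are global sample ids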
1173 |
+
|
1174 |
+
def __init__(self, k=100, largest=True, state=None):
|
1175 |
+
if state is not None:
|
1176 |
+
return super().__init__(state)
|
1177 |
+
self.k = k
|
1178 |
+
self.count = 0
|
1179 |
+
# This version flattens all data internally to 2-d tensors,
|
1180 |
+
# to avoid crashes with the current pytorch topk implementation.
|
1181 |
+
# The data is puffed back out to arbitrary tensor shapes on output.
|
1182 |
+
self.data_shape = None
|
1183 |
+
self.top_data = None
|
1184 |
+
self.top_index = None
|
1185 |
+
self.next = 0
|
1186 |
+
self.linear_index = 0
|
1187 |
+
self.perm = None
|
1188 |
+
self.largest = largest
|
1189 |
+
|
1190 |
+
def add(self, data, index=None):
|
1191 |
+
"""
|
1192 |
+
Adds a batch of data to be considered for the running top k.
|
1193 |
+
The zeroth dimension enumerates the observations. All other
|
1194 |
+
dimensions enumerate different features.
|
1195 |
+
"""
|
1196 |
+
if self.top_data is None:
|
1197 |
+
# Allocation: allocate a buffer of size 5*k, at least 10, for each.
|
1198 |
+
self.data_shape = data.shape[1:]
|
1199 |
+
feature_size = int(numpy.prod(self.data_shape))
|
1200 |
+
self.top_data = torch.zeros(
|
1201 |
+
feature_size, max(10, self.k * 5), out=data.new()
|
1202 |
+
)
|
1203 |
+
self.top_index = self.top_data.clone().long()
|
1204 |
+
self.linear_index = (
|
1205 |
+
0
|
1206 |
+
if len(data.shape) == 1
|
1207 |
+
else torch.arange(feature_size, out=self.top_index.new()).mul_(
|
1208 |
+
self.top_data.shape[-1]
|
1209 |
+
)[:, None]
|
1210 |
+
)
|
1211 |
+
size = data.shape[0]
|
1212 |
+
sk = min(size, self.k)
|
1213 |
+
if self.top_data.shape[-1] < self.next + sk:
|
1214 |
+
# Compression: if full, keep topk only.
|
1215 |
+
self.top_data[:, : self.k], self.top_index[:, : self.k] = self.topk(
|
1216 |
+
sorted=False, flat=True
|
1217 |
+
)
|
1218 |
+
self.next = self.k
|
1219 |
+
# Pick: copy the top sk of the next batch into the buffer.
|
1220 |
+
# Currently strided topk is slow. So we clone after transpose.
|
1221 |
+
# TODO: remove the clone() if it becomes faster.
|
1222 |
+
cdata = data.reshape(size, numpy.prod(data.shape[1:])).t().clone()
|
1223 |
+
td, ti = cdata.topk(sk, sorted=False, largest=self.largest)
|
1224 |
+
self.top_data[:, self.next : self.next + sk] = td
|
1225 |
+
if index is not None:
|
1226 |
+
ti = index[ti]
|
1227 |
+
else:
|
1228 |
+
ti = ti + self.count
|
1229 |
+
self.top_index[:, self.next : self.next + sk] = ti
|
1230 |
+
self.next += sk
|
1231 |
+
self.count += size
|
1232 |
+
|
1233 |
+
def size(self):
|
1234 |
+
return self.count
|
1235 |
+
|
1236 |
+
def topk(self, sorted=True, flat=False):
|
1237 |
+
"""
|
1238 |
+
Returns top k data items and indexes in each dimension,
|
1239 |
+
with channels in the first dimension and k in the last dimension.
|
1240 |
+
"""
|
1241 |
+
k = min(self.k, self.next)
|
1242 |
+
# bti are top indexes relative to buffer array.
|
1243 |
+
td, bti = self.top_data[:, : self.next].topk(
|
1244 |
+
k, sorted=sorted, largest=self.largest
|
1245 |
+
)
|
1246 |
+
# we want to report top indexes globally, which is ti.
|
1247 |
+
ti = self.top_index.view(-1)[(bti + self.linear_index).view(-1)].view(
|
1248 |
+
*bti.shape
|
1249 |
+
)
|
1250 |
+
if flat:
|
1251 |
+
return td, ti
|
1252 |
+
else:
|
1253 |
+
return (
|
1254 |
+
td.view(*(self.data_shape + (-1,))),
|
1255 |
+
ti.view(*(self.data_shape + (-1,))),
|
1256 |
+
)
|
1257 |
+
|
1258 |
+
def to_(self, device):
|
1259 |
+
if self.top_data is not None:
|
1260 |
+
self.top_data = self.top_data.to(device)
|
1261 |
+
if self.top_index is not None:
|
1262 |
+
self.top_index = self.top_index.to(device)
|
1263 |
+
if isinstance(self.linear_index, torch.Tensor):
|
1264 |
+
self.linear_index = self.linear_index.to(device)
|
1265 |
+
|
1266 |
+
def state_dict(self):
|
1267 |
+
return dict(
|
1268 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
1269 |
+
k=self.k,
|
1270 |
+
count=self.count,
|
1271 |
+
largest=self.largest,
|
1272 |
+
data_shape=self.data_shape and tuple(self.data_shape),
|
1273 |
+
top_data=self.top_data.cpu().detach().numpy(),
|
1274 |
+
top_index=self.top_index.cpu().detach().numpy(),
|
1275 |
+
next=self.next,
|
1276 |
+
linear_index=(
|
1277 |
+
self.linear_index.cpu().numpy()
|
1278 |
+
if isinstance(self.linear_index, torch.Tensor)
|
1279 |
+
else self.linear_index
|
1280 |
+
),
|
1281 |
+
perm=self.perm,
|
1282 |
+
)
|
1283 |
+
|
1284 |
+
def load_state_dict(self, state):
|
1285 |
+
self.k = int(state["k"])
|
1286 |
+
self.count = int(state["count"])
|
1287 |
+
self.largest = bool(state.get("largest", True))
|
1288 |
+
self.data_shape = (
|
1289 |
+
None if state["data_shape"] is None else tuple(state["data_shape"])
|
1290 |
+
)
|
1291 |
+
self.top_data = torch.from_numpy(state["top_data"])
|
1292 |
+
self.top_index = torch.from_numpy(state["top_index"])
|
1293 |
+
self.next = int(state["next"])
|
1294 |
+
self.linear_index = (
|
1295 |
+
torch.from_numpy(state["linear_index"])
|
1296 |
+
if len(state["linear_index"].shape) > 0
|
1297 |
+
else int(state["linear_index"])
|
1298 |
+
)
|
1299 |
+
|
1300 |
+
|
1301 |
+
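A short illustrative sketch for `TopK`: feed batches with observations along dimension zero, then read back the per-feature top values together with the global indices of the observations that produced them.

```python
import torch

tk = TopK(k=3)                     # track the 3 largest values per feature
for _ in range(5):
    tk.add(torch.randn(200, 10))   # 200 observations, 10 features per batch
values, indices = tk.topk()        # both shaped (10, 3); indices are global row numbers
assert values.shape == (10, 3) and tk.size() == 1000
```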
class History(Stat):
|
1302 |
+
"""
|
1303 |
+
Accumulates the concatenation of all the added data.
|
1304 |
+
"""
|
1305 |
+
|
1306 |
+
def __init__(self, data=None, state=None):
|
1307 |
+
if state is not None:
|
1308 |
+
return super().__init__(state)
|
1309 |
+
self._data = data
|
1310 |
+
self._added = []
|
1311 |
+
|
1312 |
+
def _cat_added(self):
|
1313 |
+
if len(self._added):
|
1314 |
+
self._data = torch.cat(
|
1315 |
+
([self._data] if self._data is not None else []) + self._added
|
1316 |
+
)
|
1317 |
+
self._added = []
|
1318 |
+
|
1319 |
+
def add(self, d):
|
1320 |
+
self._added.append(d)
|
1321 |
+
if len(self._added) > 100:
|
1322 |
+
self._cat_added()
|
1323 |
+
|
1324 |
+
def history(self):
|
1325 |
+
self._cat_added()
|
1326 |
+
return self._data
|
1327 |
+
|
1328 |
+
def load_state_dict(self, state):
|
1329 |
+
data = state["data"]
|
1330 |
+
self._data = None if data is None else torch.from_numpy(data)
|
1331 |
+
self._added = []
|
1332 |
+
|
1333 |
+
def state_dict(self):
|
1334 |
+
self._cat_added()
|
1335 |
+
return dict(
|
1336 |
+
constructor=self.__module__ + "." + self.__class__.__name__ + "()",
|
1337 |
+
data=None if self._data is None else self._data.cpu().numpy(),
|
1338 |
+
)
|
1339 |
+
|
1340 |
+
def to_(self, device):
|
1341 |
+
"""Switches internal storage to specified device."""
|
1342 |
+
self._cat_added()
|
1343 |
+
if self._data is not None:
|
1344 |
+
self._data = self._data.to(device)
|
1345 |
+
|
1346 |
+
|
1347 |
+
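A tiny illustrative sketch of `History`: add batches as they stream past and read back their concatenation at the end.

```python
import torch

h = History()
for batch in torch.arange(10).split(3):   # batches of size 3, 3, 3, 1
    h.add(batch)
assert torch.equal(h.history(), torch.arange(10))
```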
class CombinedStat(Stat):
|
1348 |
+
"""
|
1349 |
+
A Stat that bundles together multiple Stat objects.
|
1350 |
+
Convenient for loading and saving a state_dict made up of a
|
1351 |
+
hierarchy of stats, and for use with the tally() function.
|
1352 |
+
Example:
|
1353 |
+
|
1354 |
+
cs = CombinedStat(m=Mean(), q=Quantile())
|
1355 |
+
for [b] in tally(cs, MyDataSet(), cache=fn, batch_size=100):
|
1356 |
+
cs.add(b)
|
1357 |
+
print(cs.m.mean())
|
1358 |
+
print(cs.q.median())
|
1359 |
+
"""
|
1360 |
+
|
1361 |
+
def __init__(self, state=None, **kwargs):
|
1362 |
+
self._objs = kwargs
|
1363 |
+
if state is not None:
|
1364 |
+
return super().__init__(state)
|
1365 |
+
|
1366 |
+
def __getattr__(self, k):
|
1367 |
+
if k in self._objs:
|
1368 |
+
return self._objs[k]
|
1369 |
+
raise AttributeError()
|
1370 |
+
|
1371 |
+
def add(self, d, *args, **kwargs):
|
1372 |
+
for obj in self._objs.values():
|
1373 |
+
obj.add(d, *args, **kwargs)
|
1374 |
+
|
1375 |
+
def load_state_dict(self, state):
|
1376 |
+
for prefix, obj in self._objs.items():
|
1377 |
+
obj.load_state_dict(pull_key_prefix(prefix, state))
|
1378 |
+
|
1379 |
+
def state_dict(self):
|
1380 |
+
result = {}
|
1381 |
+
for prefix, obj in self._objs.items():
|
1382 |
+
result.update(push_key_prefix(prefix, obj.state_dict()))
|
1383 |
+
return result
|
1384 |
+
|
1385 |
+
def to_(self, device):
|
1386 |
+
"""Switches internal storage to specified device."""
|
1387 |
+
for v in self._objs.values():
|
1388 |
+
v.to_(device)
|
1389 |
+
|
1390 |
+
|
1391 |
+
def push_key_prefix(prefix, d):
|
1392 |
+
"""
|
1393 |
+
Returns a dict with the same values as d, but where each key
|
1394 |
+
adds the prefix, followed by a dot.
|
1395 |
+
"""
|
1396 |
+
return {prefix + "." + k: v for k, v in d.items()}
|
1397 |
+
|
1398 |
+
|
1399 |
+
def pull_key_prefix(prefix, d):
|
1400 |
+
"""
|
1401 |
+
Returns a filtered dict of all the items of d that start with
|
1402 |
+
the given key prefix, plus a dot, with that prefix removed.
|
1403 |
+
"""
|
1404 |
+
pd = prefix + "."
|
1405 |
+
lpd = len(pd)
|
1406 |
+
return {k[lpd:]: v for k, v in d.items() if k.startswith(pd)}
|
1407 |
+
|
1408 |
+
|
1409 |
+
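The two prefix helpers above are inverses of one another for a matching prefix; a quick illustrative check:

```python
nested = push_key_prefix("m", {"count": 10, "mean": 0.5})
# nested == {"m.count": 10, "m.mean": 0.5}
assert pull_key_prefix("m", nested) == {"count": 10, "mean": 0.5}
```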
# We wish to be able to save None (null) values in numpy npz files,
|
1410 |
+
# yet do so without setting the insecure 'allow_pickle' flag. To do
|
1411 |
+
# that, we will encode null as a special kind of IEEE 754 NaN value.
|
1412 |
+
# Inspired by https://github.com/zuiderkwast/nanbox/blob/master/nanbox.h
|
1413 |
+
# we follow the same Nanboxing scheme used in JavaScriptCore
|
1414 |
+
# (search for JSCJSValue.h#L435), which encodes null values in NaN
|
1415 |
+
# as the NaN value with hex pattern 0xfff8000000000002.
|
1416 |
+
|
1417 |
+
null_numpy_value = numpy.array(
|
1418 |
+
struct.unpack(">d", struct.pack(">Q", 0xFFF8000000000002))[0], dtype=numpy.float64
|
1419 |
+
)
|
1420 |
+
|
1421 |
+
|
1422 |
+
def is_null_numpy_value(v):
|
1423 |
+
"""
|
1424 |
+
True if v is a 64-bit float numpy scalar NaN matching null_numpy_value.
|
1425 |
+
"""
|
1426 |
+
return (
|
1427 |
+
isinstance(v, numpy.ndarray)
|
1428 |
+
and numpy.ndim(v) == 0
|
1429 |
+
and v.dtype == numpy.float64
|
1430 |
+
and numpy.isnan(v)
|
1431 |
+
and 0xFFF8000000000002 == struct.unpack(">Q", struct.pack(">d", v))[0]
|
1432 |
+
)
|
1433 |
+
|
1434 |
+
|
1435 |
+
def box_numpy_null(d):
|
1436 |
+
"""
|
1437 |
+
Replaces None with null_numpy_value, leaving non-None values unchanged.
|
1438 |
+
Recursively descends into a dictionary replacing None values.
|
1439 |
+
"""
|
1440 |
+
try:
|
1441 |
+
return {k: box_numpy_null(v) for k, v in d.items()}
|
1442 |
+
except Exception:
|
1443 |
+
return null_numpy_value if d is None else d
|
1444 |
+
|
1445 |
+
|
1446 |
+
def unbox_numpy_null(d):
|
1447 |
+
"""
|
1448 |
+
Reverses box_numpy_null, replacing null_numpy_value with None.
|
1449 |
+
Recursively descends into a dictionary replacing None values.
|
1450 |
+
"""
|
1451 |
+
try:
|
1452 |
+
return {k: unbox_numpy_null(v) for k, v in d.items()}
|
1453 |
+
except Exception:
|
1454 |
+
return None if is_null_numpy_value(d) else d
|
1455 |
+
|
1456 |
+
|
1457 |
+
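An illustrative round trip of the NaN-boxing helpers through an `.npz` file (the file path is made up for the example):

```python
import numpy

state = {"k": 3, "perm": None}
numpy.savez("/tmp/nanbox_demo.npz", **box_numpy_null(state))
restored = unbox_numpy_null(dict(numpy.load("/tmp/nanbox_demo.npz")))
assert restored["perm"] is None       # the boxed NaN decodes back to None
assert int(restored["k"]) == 3        # scalars come back as 0-d numpy arrays
```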
def resolve_state_dict(s):
|
1458 |
+
"""
|
1459 |
+
Resolves a state, which can be a filename or a dict-like object.
|
1460 |
+
"""
|
1461 |
+
if isinstance(s, str):
|
1462 |
+
return unbox_numpy_null(numpy.load(s))
|
1463 |
+
return s
|
1464 |
+
|
1465 |
+
|
1466 |
+
global_load_cache_enabled = True
|
1467 |
+
|
1468 |
+
|
1469 |
+
def load_cached_state(cachefile, args, quiet=False, throw=False):
|
1470 |
+
"""
|
1471 |
+
Loads a cached state from a dict or an npz file, returning None unless every entry in args matches what was saved.
|
1472 |
+
"""
|
1473 |
+
if not global_load_cache_enabled or cachefile is None:
|
1474 |
+
return None
|
1475 |
+
try:
|
1476 |
+
if isinstance(cachefile, dict):
|
1477 |
+
dat = cachefile
|
1478 |
+
cachefile = "state" # for printed messages
|
1479 |
+
else:
|
1480 |
+
dat = unbox_numpy_null(numpy.load(cachefile))
|
1481 |
+
for a, v in args.items():
|
1482 |
+
if a not in dat or dat[a] != v:
|
1483 |
+
if not quiet:
|
1484 |
+
print("%s %s changed from %s to %s" % (cachefile, a, dat[a], v))
|
1485 |
+
return None
|
1486 |
+
except (FileNotFoundError, ValueError) as e:
|
1487 |
+
if throw:
|
1488 |
+
raise e
|
1489 |
+
return None
|
1490 |
+
else:
|
1491 |
+
if not quiet:
|
1492 |
+
print("Loading cached %s" % cachefile)
|
1493 |
+
return dat
|
1494 |
+
|
1495 |
+
|
1496 |
+
def save_cached_state(cachefile, obj, args):
|
1497 |
+
"""
|
1498 |
+
Saves the state_dict of the given object in a dict or npz file.
|
1499 |
+
"""
|
1500 |
+
if cachefile is None:
|
1501 |
+
return
|
1502 |
+
dat = obj.state_dict()
|
1503 |
+
for a, v in args.items():
|
1504 |
+
if a in dat:
|
1505 |
+
assert dat[a] == v
|
1506 |
+
dat[a] = v
|
1507 |
+
if isinstance(cachefile, dict):
|
1508 |
+
cachefile.clear()
|
1509 |
+
cachefile.update(dat)
|
1510 |
+
else:
|
1511 |
+
os.makedirs(os.path.dirname(cachefile), exist_ok=True)
|
1512 |
+
numpy.savez(cachefile, **box_numpy_null(dat))
|
1513 |
+
|
1514 |
+
|
1515 |
+
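A hedged sketch of the caching pattern that `load_cached_state` and `save_cached_state` support; the file name and the `args` keys are invented for the example, and it assumes `Quantile` accepts the same `state=` keyword as the other stats in this file:

```python
import torch

args = dict(sample_size=10_000)       # settings that must match for the cache to be reused
cached = load_cached_state("results/quantile.npz", args, quiet=True)
q = Quantile(state=cached) if cached is not None else Quantile()  # state= assumed, as for other stats
if cached is None:
    q.add(torch.randn(args["sample_size"], 8))
    save_cached_state("results/quantile.npz", q, args)
```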
class FixedSubsetSampler(Sampler):
|
1516 |
+
"""Represents a fixed sequence of data set indices.
|
1517 |
+
Subsets can be created by specifying a subset of output indexes.
|
1518 |
+
"""
|
1519 |
+
|
1520 |
+
def __init__(self, samples):
|
1521 |
+
self.samples = samples
|
1522 |
+
|
1523 |
+
def __iter__(self):
|
1524 |
+
return iter(self.samples)
|
1525 |
+
|
1526 |
+
def __len__(self):
|
1527 |
+
return len(self.samples)
|
1528 |
+
|
1529 |
+
def __getitem__(self, key):
|
1530 |
+
return self.samples[key]
|
1531 |
+
|
1532 |
+
def subset(self, new_subset):
|
1533 |
+
return FixedSubsetSampler(self.dereference(new_subset))
|
1534 |
+
|
1535 |
+
def dereference(self, indices):
|
1536 |
+
"""
|
1537 |
+
Translate output sample indices (small numbers indexing the sample)
|
1538 |
+
to input sample indices (larger numbers indexing the original full set).
|
1539 |
+
"""
|
1540 |
+
return [self.samples[i] for i in indices]
|
1541 |
+
|
1542 |
+
|
1543 |
+
class FixedRandomSubsetSampler(FixedSubsetSampler):
|
1544 |
+
"""Samples a fixed number of samples from the dataset, deterministically.
|
1545 |
+
Arguments:
|
1546 |
+
data_source,
|
1547 |
+
start, end (optional bounds selecting a slice of the shuffled order),
|
1548 |
+
seed (optional)
|
1549 |
+
"""
|
1550 |
+
|
1551 |
+
def __init__(self, data_source, start=None, end=None, seed=1):
|
1552 |
+
rng = random.Random(seed)
|
1553 |
+
shuffled = list(range(len(data_source)))
|
1554 |
+
rng.shuffle(shuffled)
|
1555 |
+
self.data_source = data_source
|
1556 |
+
super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end])
|
1557 |
+
|
1558 |
+
def class_subset(self, class_filter):
|
1559 |
+
"""
|
1560 |
+
Returns only the subset matching the given rule.
|
1561 |
+
"""
|
1562 |
+
if isinstance(class_filter, int):
|
1563 |
+
|
1564 |
+
def rule(d):
|
1565 |
+
return d[1] == class_filter
|
1566 |
+
|
1567 |
+
else:
|
1568 |
+
rule = class_filter
|
1569 |
+
return self.subset(
|
1570 |
+
[i for i, j in enumerate(self.samples) if rule(self.data_source[j])]
|
1571 |
+
)
|
1572 |
+
|
1573 |
+
|
1574 |
+
def make_loader(
|
1575 |
+
dataset, sample_size=None, batch_size=1, sampler=None, random_sample=None, **kwargs
|
1576 |
+
):
|
1577 |
+
"""Utility for creating a dataloader on fixed sample subset."""
|
1578 |
+
import typing
|
1579 |
+
|
1580 |
+
if isinstance(dataset, typing.Callable):
|
1581 |
+
# To support deferred dataset loading, support passing a factory
|
1582 |
+
# that creates the dataset when called.
|
1583 |
+
dataset = dataset()
|
1584 |
+
if isinstance(dataset, torch.Tensor):
|
1585 |
+
# The dataset can be a simple tensor.
|
1586 |
+
dataset = torch.utils.data.TensorDataset(dataset)
|
1587 |
+
if sample_size is not None:
|
1588 |
+
assert sampler is None, "sampler cannot be specified with sample_size"
|
1589 |
+
if sample_size > len(dataset):
|
1590 |
+
print(
|
1591 |
+
"Warning: sample size %d > dataset size %d"
|
1592 |
+
% (sample_size, len(dataset))
|
1593 |
+
)
|
1594 |
+
sample_size = len(dataset)
|
1595 |
+
if random_sample is None:
|
1596 |
+
sampler = FixedSubsetSampler(list(range(sample_size)))
|
1597 |
+
else:
|
1598 |
+
sampler = FixedRandomSubsetSampler(
|
1599 |
+
dataset, seed=random_sample, end=sample_size
|
1600 |
+
)
|
1601 |
+
return torch.utils.data.DataLoader(
|
1602 |
+
dataset, sampler=sampler, batch_size=batch_size, **kwargs
|
1603 |
+
)
|
1604 |
+
|
1605 |
+
|
1606 |
+
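An illustrative use of `make_loader` with a plain tensor: draw a deterministic random subset of 1,000 rows and iterate it in batches.

```python
import torch

data = torch.randn(50_000, 8)
loader = make_loader(data, sample_size=1_000, batch_size=100, random_sample=1)
for [batch] in loader:                # TensorDataset wrapping yields one-element lists
    assert batch.shape == (100, 8)
```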
# Unit Tests
|
1607 |
+
def _unit_test():
|
1608 |
+
import warnings
|
1609 |
+
|
1610 |
+
warnings.filterwarnings("error")
|
1611 |
+
import argparse
|
1612 |
+
import random
|
1613 |
+
import shutil
|
1614 |
+
import tempfile
|
1615 |
+
import time
|
1616 |
+
|
1617 |
+
parser = argparse.ArgumentParser(description="Test things out")
|
1618 |
+
parser.add_argument("--mode", default="cpu", help="cpu or cuda")
|
1619 |
+
parser.add_argument("--test_size", type=int, default=1000000)
|
1620 |
+
args = parser.parse_args()
|
1621 |
+
testdir = tempfile.mkdtemp()
|
1622 |
+
batch_size = random.randint(500, 1500)
|
1623 |
+
|
1624 |
+
# Test NaNboxing.
|
1625 |
+
assert numpy.isnan(null_numpy_value)
|
1626 |
+
assert is_null_numpy_value(null_numpy_value)
|
1627 |
+
assert not is_null_numpy_value(numpy.nan)
|
1628 |
+
|
1629 |
+
# Test Covariance
|
1630 |
+
goal = torch.tensor(numpy.random.RandomState(1).standard_normal(10 * 10)).view(
|
1631 |
+
10, 10
|
1632 |
+
)
|
1633 |
+
data = (
|
1634 |
+
torch.tensor(numpy.random.RandomState(2).standard_normal(args.test_size * 10))
|
1635 |
+
.view(args.test_size, 10)
|
1636 |
+
.mm(goal)
|
1637 |
+
)
|
1638 |
+
data += torch.randn(1, 10) * 999
|
1639 |
+
dcov = data.t().cov()
|
1640 |
+
dcorr = data.t().corrcoef()
|
1641 |
+
rcov = Covariance()
|
1642 |
+
rcov.add(data) # All one batch
|
1643 |
+
assert (rcov.covariance() - dcov).abs().max() < 1e-16
|
1644 |
+
cs = CombinedStat(cov=Covariance(), xcov=CrossCovariance())
|
1645 |
+
ds = torch.utils.data.TensorDataset(data)
|
1646 |
+
for [a] in tally(cs, ds, batch_size=9876):
|
1647 |
+
cs.cov.add(a)
|
1648 |
+
cs.xcov.add(a[:, :3], a[:, 3:])
|
1649 |
+
assert (data.mean(0) - cs.cov.mean()).abs().max() < 1e-12
|
1650 |
+
assert (dcov - cs.cov.covariance()).abs().max() < 2e-12
|
1651 |
+
assert (dcov[:3, 3:] - cs.xcov.covariance()).abs().max() < 1e-12
|
1652 |
+
assert (dcov.diagonal() - torch.cat(cs.xcov.variance())).abs().max() < 1e-12
|
1653 |
+
assert (dcorr - cs.cov.correlation()).abs().max() < 2e-12
|
1654 |
+
|
1655 |
+
# Test CrossCovariance and CrossIoU
|
1656 |
+
fn = f"{testdir}/cross_cache.npz"
|
1657 |
+
ds = torch.utils.data.TensorDataset(
|
1658 |
+
(
|
1659 |
+
torch.arange(args.test_size)[:, None] % torch.arange(1, 6)[None, :] == 0
|
1660 |
+
).double(),
|
1661 |
+
(
|
1662 |
+
torch.arange(args.test_size)[:, None] % torch.arange(5, 8)[None, :] == 0
|
1663 |
+
).double(),
|
1664 |
+
)
|
1665 |
+
c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
|
1666 |
+
riou = IoU()
|
1667 |
+
count = 0
|
1668 |
+
for [a, b] in tally(c, ds, cache=fn, batch_size=100):
|
1669 |
+
count += 1
|
1670 |
+
c.add(a, b)
|
1671 |
+
riou.add(torch.cat([a, b], dim=1))
|
1672 |
+
assert count == -(-args.test_size // 100)
|
1673 |
+
cor = c.c.correlation()
|
1674 |
+
iou = c.iou.iou()
|
1675 |
+
assert cor.shape == iou.shape == (5, 3)
|
1676 |
+
assert iou[4, 0] == 1.0
|
1677 |
+
assert abs(iou[0, 2] + (-args.test_size // 7 / float(args.test_size))) < 1e-6
|
1678 |
+
assert abs(cor[4, 0] - 1.0) < 1e-2
|
1679 |
+
assert abs(cor[0, 2] - 0.0) < 1e-6
|
1680 |
+
assert all((riou.iou()[:5, -3:] == iou).view(-1))
|
1681 |
+
assert all(riou.iou().diagonal(0) == 1)
|
1682 |
+
c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
|
1683 |
+
count = 0
|
1684 |
+
for [a, b] in tally(c, ds, cache=fn, batch_size=10):
|
1685 |
+
count += 1
|
1686 |
+
c.add(a, b)
|
1687 |
+
assert count == 0
|
1688 |
+
assert all((c.c.correlation() == cor).view(-1))
|
1689 |
+
assert all((c.iou.iou() == iou).view(-1))
|
1690 |
+
|
1691 |
+
# Test Concatenation, Mean, Bincount, and tally.
|
1692 |
+
fn = f"{testdir}/series_cache.npz"
|
1693 |
+
count = 0
|
1694 |
+
ds = torch.utils.data.TensorDataset(torch.arange(args.test_size))
|
1695 |
+
c = CombinedStat(s=History(), m=Mean(), b=Bincount())
|
1696 |
+
for [b] in tally(c, ds, cache=fn, batch_size=batch_size):
|
1697 |
+
count += 1
|
1698 |
+
c.add(b)
|
1699 |
+
assert count == -(-args.test_size // batch_size)
|
1700 |
+
assert len(c.s.history()) == args.test_size
|
1701 |
+
assert c.s.history()[-1] == args.test_size - 1
|
1702 |
+
assert all(c.s.history() == ds.tensors[0])
|
1703 |
+
assert all(c.b.bincount() == torch.ones(args.test_size))
|
1704 |
+
assert c.m.mean() == float(args.test_size - 1) / 2.0
|
1705 |
+
c2 = CombinedStat(s=History(), m=Mean(), b=Bincount())
|
1706 |
+
batches = tally(c2, ds, cache=fn)
|
1707 |
+
assert len(c2.s.history()) == args.test_size
|
1708 |
+
assert all(c2.s.history() == c.s.history())
|
1709 |
+
assert all(c2.b.bincount() == torch.ones(args.test_size))
|
1710 |
+
assert c2.m.mean() == c.m.mean()
|
1711 |
+
count = 0
|
1712 |
+
for b in batches:
|
1713 |
+
count += 1
|
1714 |
+
assert count == 0 # Shouldn't do anything when it's cached
|
1715 |
+
|
1716 |
+
# An adversarial case: we keep finding more numbers in the middle
|
1717 |
+
# as the stream goes on.
|
1718 |
+
amount = args.test_size
|
1719 |
+
quantiles = 1000
|
1720 |
+
data = numpy.arange(float(amount))
|
1721 |
+
data[1::2] = data[-1::-2] + (len(data) - 1)
|
1722 |
+
data /= 2
|
1723 |
+
depth = 50
|
1724 |
+
alldata = data[:, None] + (numpy.arange(depth) * amount)[None, :]
|
1725 |
+
actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0))
|
1726 |
+
amt = amount // depth
|
1727 |
+
for r in range(depth):
|
1728 |
+
numpy.random.shuffle(alldata[r * amt : r * amt + amt, r])
|
1729 |
+
if args.mode == "cuda":
|
1730 |
+
alldata = torch.cuda.FloatTensor(alldata)
|
1731 |
+
device = torch.device("cuda")
|
1732 |
+
else:
|
1733 |
+
alldata = torch.FloatTensor(alldata)
|
1734 |
+
device = None
|
1735 |
+
starttime = time.time()
|
1736 |
+
cs = CombinedStat(
|
1737 |
+
qc=Quantile(),
|
1738 |
+
m=Mean(),
|
1739 |
+
v=Variance(),
|
1740 |
+
c=Covariance(),
|
1741 |
+
s=SecondMoment(),
|
1742 |
+
t=TopK(),
|
1743 |
+
i=IoU(),
|
1744 |
+
)
|
1745 |
+
# Feed data in little batches
|
1746 |
+
i = 0
|
1747 |
+
while i < len(alldata):
|
1748 |
+
batch_size = numpy.random.randint(1000)
|
1749 |
+
cs.add(alldata[i : i + batch_size])
|
1750 |
+
i += batch_size
|
1751 |
+
# Test state dict
|
1752 |
+
saved = cs.state_dict()
|
1753 |
+
# numpy.savez(f'{testdir}/saved.npz', **box_numpy_null(saved))
|
1754 |
+
# saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
|
1755 |
+
cs.save(f"{testdir}/saved.npz")
|
1756 |
+
loaded = unbox_numpy_null(numpy.load(f"{testdir}/saved.npz"))
|
1757 |
+
assert set(loaded.keys()) == set(saved.keys())
|
1758 |
+
|
1759 |
+
# Restore using state=saved in constructor.
|
1760 |
+
cs2 = CombinedStat(
|
1761 |
+
qc=Quantile(),
|
1762 |
+
m=Mean(),
|
1763 |
+
v=Variance(),
|
1764 |
+
c=Covariance(),
|
1765 |
+
s=SecondMoment(),
|
1766 |
+
t=TopK(),
|
1767 |
+
i=IoU(),
|
1768 |
+
state=saved,
|
1769 |
+
)
|
1770 |
+
# saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
|
1771 |
+
assert not cs2.qc.device.type == "cuda"
|
1772 |
+
cs2.to_(device)
|
1773 |
+
# alldata = alldata.cpu()
|
1774 |
+
cs2.add(alldata)
|
1775 |
+
actual_sum *= 2
|
1776 |
+
# print(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean())
|
1777 |
+
assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5)
|
1778 |
+
assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5)
|
1779 |
+
assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5)
|
1780 |
+
# print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0))
|
1781 |
+
assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3)
|
1782 |
+
assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2)
|
1783 |
+
# print(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0))
|
1784 |
+
assert all(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0) < 1e-4)
|
1785 |
+
# print(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0))
|
1786 |
+
assert all(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0) < 2e-3)
|
1787 |
+
moment = (alldata.t() @ alldata) / len(alldata)
|
1788 |
+
# print(abs(moment - cs2.s.moment()) / moment.abs())
|
1789 |
+
assert all((abs(moment - cs2.s.moment()) / moment.abs()).view(-1) < 1e-2)
|
1790 |
+
assert all(alldata.max(dim=0)[0] == cs2.t.topk()[0][:, 0])
|
1791 |
+
assert cs2.i.iou()[0, 0] == 1
|
1792 |
+
assert all((cs2.i.iou()[1:, 1:] == 1).view(-1))
|
1793 |
+
assert all(cs2.i.iou()[1:, 0] < 1)
|
1794 |
+
assert all(cs2.i.iou()[1:, 0] == cs2.i.iou()[0, 1:])
|
1795 |
+
|
1796 |
+
# Restore using cs.load() method.
|
1797 |
+
cs = CombinedStat(
|
1798 |
+
qc=Quantile(),
|
1799 |
+
m=Mean(),
|
1800 |
+
v=Variance(),
|
1801 |
+
c=Covariance(),
|
1802 |
+
s=SecondMoment(),
|
1803 |
+
t=TopK(),
|
1804 |
+
i=IoU(),
|
1805 |
+
)
|
1806 |
+
cs.load(f"{testdir}/saved.npz")
|
1807 |
+
assert not cs.qc.device.type == "cuda"
|
1808 |
+
cs.to_(device)
|
1809 |
+
cs.add(alldata)
|
1810 |
+
# actual_sum *= 2
|
1811 |
+
# print(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean())
|
1812 |
+
assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5)
|
1813 |
+
assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5)
|
1814 |
+
assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5)
|
1815 |
+
# print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0))
|
1816 |
+
assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3)
|
1817 |
+
assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2)
|
1818 |
+
# print(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0))
|
1819 |
+
assert all(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0) < 1e-4)
|
1820 |
+
# print(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0))
|
1821 |
+
assert all(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0) < 2e-3)
|
1822 |
+
moment = (alldata.t() @ alldata) / len(alldata)
|
1823 |
+
# print(abs(moment - cs.s.moment()) / moment.abs())
|
1824 |
+
assert all((abs(moment - cs.s.moment()) / moment.abs()).view(-1) < 1e-2)
|
1825 |
+
assert all(alldata.max(dim=0)[0] == cs.t.topk()[0][:, 0])
|
1826 |
+
assert cs.i.iou()[0, 0] == 1
|
1827 |
+
assert all((cs.i.iou()[1:, 1:] == 1).view(-1))
|
1828 |
+
assert all(cs.i.iou()[1:, 0] < 1)
|
1829 |
+
assert all(cs.i.iou()[1:, 0] == cs.i.iou()[0, 1:])
|
1830 |
+
|
1831 |
+
# Randomized quantile test
|
1832 |
+
qc = cs.qc
|
1833 |
+
ro = qc.readout(1001).cpu()
|
1834 |
+
endtime = time.time()
|
1835 |
+
gt = (
|
1836 |
+
torch.linspace(0, amount, quantiles + 1)[None, :]
|
1837 |
+
+ (torch.arange(qc.depth, dtype=torch.float) * amount)[:, None]
|
1838 |
+
)
|
1839 |
+
maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles
|
1840 |
+
print("Randomized quantile test results:")
|
1841 |
+
print("Maximum relative deviation among %d perentiles: %f" % (quantiles, maxreldev))
|
1842 |
+
minerr = torch.max(
|
1843 |
+
torch.abs(
|
1844 |
+
qc.minmax().cpu()[:, 0] - torch.arange(qc.depth, dtype=torch.float) * amount
|
1845 |
+
)
|
1846 |
+
)
|
1847 |
+
maxerr = torch.max(
|
1848 |
+
torch.abs(
|
1849 |
+
(qc.minmax().cpu()[:, -1] + 1)
|
1850 |
+
- (torch.arange(qc.depth, dtype=torch.float) + 1) * amount
|
1851 |
+
)
|
1852 |
+
)
|
1853 |
+
print("Minmax error %f, %f" % (minerr, maxerr))
|
1854 |
+
interr = torch.max(
|
1855 |
+
torch.abs(qc.integrate(lambda x: x * x).cpu() - actual_sum) / actual_sum
|
1856 |
+
)
|
1857 |
+
print("Integral error: %f" % interr)
|
1858 |
+
medianerr = torch.max(
|
1859 |
+
torch.abs(qc.median() - alldata.median(0)[0]) / alldata.median(0)[0]
|
1860 |
+
).cpu()
|
1861 |
+
print("Median error: %f" % medianerr)
|
1862 |
+
meanerr = torch.max(torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu()
|
1863 |
+
print("Mean error: %f" % meanerr)
|
1864 |
+
varerr = torch.max(torch.abs(qc.variance() - alldata.var(0)) / alldata.var(0)).cpu()
|
1865 |
+
print("Variance error: %f" % varerr)
|
1866 |
+
counterr = (
|
1867 |
+
(qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu()) - qc.size())
|
1868 |
+
/ (0.0 + qc.size())
|
1869 |
+
).item()
|
1870 |
+
print("Count error: %f" % counterr)
|
1871 |
+
print("Time %f" % (endtime - starttime))
|
1872 |
+
# Algorithm is randomized, so some of these will fail with low probability.
|
1873 |
+
assert maxreldev < 1.0
|
1874 |
+
assert minerr == 0.0
|
1875 |
+
assert maxerr == 0.0
|
1876 |
+
assert interr < 0.01
|
1877 |
+
assert abs(counterr) < 0.001
|
1878 |
+
shutil.rmtree(testdir, ignore_errors=True)
|
1879 |
+
print("OK")
|
1880 |
+
|
1881 |
+
|
1882 |
+
if __name__ == "__main__":
|
1883 |
+
_unit_test()
|
hparams/GRACE/README.md
ADDED
@@ -0,0 +1,19 @@
1 |
+
alg_name: "GRACE"
|
2 |
+
model_name: "./hugging_cache/gpt-j-6B"
|
3 |
+
device: 0
|
4 |
+
|
5 |
+
inner_params:
|
6 |
+
- transformer.h[25].mlp.fc_out.weight
|
7 |
+
|
8 |
+
edit_lr: 1.0
|
9 |
+
n_iter: 200
|
10 |
+
eps: 1.0
|
11 |
+
dist_fn: euc # euc, mmd, cos
|
12 |
+
val_init: cold # cold, warm
|
13 |
+
val_train: sgd # sgd, pert
|
14 |
+
val_reg: None # early
|
15 |
+
reg: early_stop # early_stop
|
16 |
+
replacement: replace_last # replace_last, replace_all, replace_prompt
|
17 |
+
eps_expand: coverage # , moving_avg, decay
|
18 |
+
num_pert: 8 # only matters when using perturbation training
|
19 |
+
dropout: 0.0
|
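The `inner_params` entry names the layer GRACE attaches its edit to (here the `fc_out` projection in block 25 of GPT-J). A hedged sketch, not EasyEdit's own resolver, of how the bracketed name maps onto a standard parameter path:

```python
# Illustrative only: convert the bracketed layer name to a dotted path and look it up.
name = "transformer.h[25].mlp.fc_out.weight"
dotted = name.replace("[", ".").replace("]", "")    # transformer.h.25.mlp.fc_out.weight
# weight = dict(model.named_parameters())[dotted]   # assumes a GPT-J style model is loaded
```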
hparams/GRACE/gpt2-xl.yaml
ADDED
@@ -0,0 +1,19 @@
1 |
+
alg_name: "GRACE"
|
2 |
+
model_name: "./hugging_cache/gpt2-xl"
|
3 |
+
device: 0
|
4 |
+
|
5 |
+
inner_params:
|
6 |
+
- transformer.h[35].mlp.c_fc.weight
|
7 |
+
|
8 |
+
edit_lr: 1.0
|
9 |
+
n_iter: 50
|
10 |
+
eps: 1.0
|
11 |
+
dist_fn: euc # euc, mmd, cos
|
12 |
+
val_init: cold # cold, warm
|
13 |
+
val_train: sgd # sgd, pert
|
14 |
+
val_reg: None # early
|
15 |
+
reg: early_stop # early_stop
|
16 |
+
replacement: replace_last # replace_last, replace_all, replace_prompt
|
17 |
+
eps_expand: coverage # , moving_avg, decay
|
18 |
+
num_pert: 8 # only matters when using perturbation training
|
19 |
+
dropout: 0.0
|
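The `utils.py` below loads this file through `GraceHyperParams.from_hparams`; a minimal sketch of doing the same directly (attribute names are assumed to mirror the YAML keys):

```python
from easyeditor import GraceHyperParams

hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
print(hparams.alg_name, hparams.device)   # "GRACE", 0
```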
hparams/config.yaml
ADDED
@@ -0,0 +1,6 @@
1 |
+
save_dir: models/
|
2 |
+
log_dir: logs/
|
3 |
+
|
4 |
+
defaults:
|
5 |
+
alg_name: KN # Editing Method
|
6 |
+
hparams_name: KN/t5-3b # Edited Model Config Path
|
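This top-level config selects an editing method (`alg_name`) and a per-model hyperparameter file (`hparams_name`) under `hparams/`. A hedged sketch of one way such a mapping could be resolved (not necessarily how this repo consumes it, and assuming `defaults` is a nested mapping as shown):

```python
import yaml  # PyYAML, assumed available

with open("hparams/config.yaml") as f:
    cfg = yaml.safe_load(f)
hparams_path = f"hparams/{cfg['defaults']['hparams_name']}.yaml"   # e.g. hparams/KN/t5-3b.yaml
```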
utils.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
|
2 |
+
from transformers import GPT2TokenizerFast, GPT2Tokenizer
|
3 |
+
from easyeditor import apply_grace_to_model, GraceHyperParams,nethook
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def edit(prompt, target_new):
|
9 |
+
request={"prompt":prompt,"target_new":target_new}
|
10 |
+
hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
|
11 |
+
|
12 |
+
model = AutoModelForCausalLM.from_pretrained("./models/gpt2-xl")
|
13 |
+
tok = GPT2Tokenizer.from_pretrained("./models/gpt2-xl")
|
14 |
+
tok.pad_token_id = tok.eos_token_id
|
15 |
+
global edit_model
|
16 |
+
edit_model,_ = apply_grace_to_model(model,tok,request,hparams,keep_original_weight=True)
|
17 |
+
return "finish"
|
18 |
+
|
19 |
+
def generate(input_text):
|
20 |
+
tok = GPT2Tokenizer.from_pretrained("./models/gpt2-xl")
|
21 |
+
hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
|
22 |
+
tok.pad_token_id = tok.eos_token_id
|
23 |
+
|
24 |
+
global edit_model
|
25 |
+
|
26 |
+
input_ids = tok.encode(input_text, return_tensors='pt').to(f'cuda:{hparams.device}')
|
27 |
+
edit_output = edit_model.generate(input_ids, max_length=30, pad_token_id=tok.eos_token_id)
|
28 |
+
edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
|
29 |
+
del edit_model
|
30 |
+
torch.cuda.empty_cache()
|
31 |
+
|
32 |
+
ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2-xl").to(f'cuda:{hparams.device}')
|
33 |
+
ori_output = ori_model.generate(input_ids, max_length=30, pad_token_id=tok.eos_token_id)
|
34 |
+
ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
|
35 |
+
|
36 |
+
return ori_reply, edit_reply
|
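A minimal sketch of calling these two helpers in sequence; the prompt and target are made up, and it assumes the local `./models/gpt2-xl` checkpoint and the CUDA device configured above are available:

```python
# Illustrative only: apply one GRACE edit, then compare pre- and post-edit generations.
edit("The capital of France is", " Rome")          # made-up edit request
original, edited = generate("The capital of France is")
print("before:", original)
print("after: ", edited)
```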