ZekunXi committed on
Commit 8124a18 · 1 Parent(s): d0f3258

Add application file

Files changed (35)
  1. easyeditor/__init__.py +2 -0
  2. easyeditor/__pycache__/__init__.cpython-39.pyc +0 -0
  3. easyeditor/models/README.md +6 -0
  4. easyeditor/models/__init__.py +1 -0
  5. easyeditor/models/__pycache__/__init__.cpython-39.pyc +0 -0
  6. easyeditor/models/grace/GRACE.py +218 -0
  7. easyeditor/models/grace/__init__.py +2 -0
  8. easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc +0 -0
  9. easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc +0 -0
  10. easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc +0 -0
  11. easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc +0 -0
  12. easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc +0 -0
  13. easyeditor/models/grace/__pycache__/utils.cpython-39.pyc +0 -0
  14. easyeditor/models/grace/grace_hparams.py +48 -0
  15. easyeditor/models/grace/grace_main.py +38 -0
  16. easyeditor/models/grace/metrics.py +59 -0
  17. easyeditor/models/grace/utils.py +86 -0
  18. easyeditor/util/__init__.py +2 -0
  19. easyeditor/util/__pycache__/__init__.cpython-39.pyc +0 -0
  20. easyeditor/util/__pycache__/hparams.cpython-39.pyc +0 -0
  21. easyeditor/util/__pycache__/logit_lens.cpython-39.pyc +0 -0
  22. easyeditor/util/__pycache__/nethook.cpython-39.pyc +0 -0
  23. easyeditor/util/alg_dict.py +45 -0
  24. easyeditor/util/alg_train_dict.py +9 -0
  25. easyeditor/util/generate.py +171 -0
  26. easyeditor/util/globals.py +43 -0
  27. easyeditor/util/hparams.py +46 -0
  28. easyeditor/util/logit_lens.py +97 -0
  29. easyeditor/util/nethook.py +451 -0
  30. easyeditor/util/perplexity.py +24 -0
  31. easyeditor/util/runningstats.py +1883 -0
  32. hparams/GRACE/README.md +19 -0
  33. hparams/GRACE/gpt2-xl.yaml +19 -0
  34. hparams/config.yaml +6 -0
  35. utils.py +36 -0
easyeditor/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .models import *
+ from .util import *
easyeditor/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (174 Bytes).
 
easyeditor/models/README.md ADDED
@@ -0,0 +1,6 @@
+ We compare ROME against several open-sourced state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties.
+ - Fine-Tuning (`ft`): Direct fine-tuning.
+ - Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363)
+ - Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696)
+ - Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164)
+ - Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309)
easyeditor/models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .grace import *
easyeditor/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes).
 
easyeditor/models/grace/GRACE.py ADDED
@@ -0,0 +1,218 @@
+ import torch
+ from .utils import parent_module, brackets_to_periods
+ import transformers
+ import os
+ os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+ def euc(query, key):
+     # Euclidean distance
+     if len(key.shape) < 2:
+         key = key.view(1, -1)
+     return torch.cdist(key, query, p=2)
+
+ def perturb_values(chosen_value, num_pert, device):
+     # Create a bunch of noised versions of the value, then create batch, then train value
+     chosen_value = chosen_value
+     noise = torch.normal(0, 1, chosen_value.shape, device=device)
+     noise[0] = noise[0]*0
+     noise.requires_grad = True
+     chosen_value = chosen_value + noise
+     return chosen_value
+
+ class GRACE(torch.nn.Module):
+     def __init__(self, config, model, device):
+         super(GRACE, self).__init__()
+         self.config = config
+         self.log_dict = {}
+         self.model = model
+         # self.tokenizer = model.tokenizer
+         layer = config.inner_params[0]
+         self.device = device
+
+         # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
+         suffixes = [".weight", ".bias"]
+         self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer
+
+         for n, p in self.model.named_parameters():
+             p.requires_grad = False
+
+         if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
+             transpose = False
+         else:
+             transpose = True
+
+         # --- Add GRACE to chosen layers ---
+         edit_module = parent_module(self.model, brackets_to_periods(self.layer))
+         layer_name = self.layer.rsplit(".", 1)[-1]
+         original_layer = getattr(edit_module, layer_name)
+
+         if type(original_layer) is not GRACEAdapter:
+             setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
+
+     def __call__(self, **kwargs):
+         # if self.config.task == "hallucination":
+         #     print(kwargs)
+         #     key_id = (kwargs["labels"] == -100).sum() - 1
+         #     setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
+         return self.model(**kwargs)
+
+     def generate(self, *args, **kwargs):
+         setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
+         return self.model.generate(*args, **kwargs)
+
+     def edit(self, config, tokens):
+         key_id = (tokens["labels"] == -100).sum() - 1
+         setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)
+
+         # --- pass edit label, training mode, and key_id into GRACE ---
+         setattr(eval(f"self.model.{self.layer}"), "training", True)
+         setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])
+
+         self.losses = []
+         # --- train GRACE value ---
+         for i in range(config.n_iter):
+             # --- insert iteration into each layer (only initiate keys on iteration 1) ---
+             setattr(eval(f"self.model.{self.layer}"), "iter", i)
+
+             # --- pass tokens through model (including through the GRACE layer) ---
+             outputs = self.model(**tokens)
+             if i == 0:
+                 # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimizer is passed after first inference) ---
+                 optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
+             loss = outputs.loss
+             loss.backward()
+             optimizer.step()
+             optimizer.zero_grad()
+             self.losses.append(loss.detach().cpu().numpy())
+
+         self.loss = loss # Log final loss
+
+         # --- pull out info we want to log from the GRACE layer ---
+         setattr(eval(f"self.model.{self.layer}"), "training", False)
+         chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
+         nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))
+
+         self.log_dict["chosen_key"] = chosen_key
+         self.log_dict["nkeys"] = nkeys
+
+ class GRACEAdapter(torch.nn.Module):
+     def __init__(self, config, layer, transpose):
+         super(GRACEAdapter, self).__init__()
+
+         self.layer = layer
+         self.weight = self.layer.weight
+         self.init_epsilon = config.eps
+         self.dist_fn = config.dist_fn
+         self.replacement = config.replacement
+         self.device = layer.weight.device
+         self.config = config
+         self.num_pert = config.num_pert
+         self.key_id = -1
+         self.ensure_replace_token_loc = False
+
+         if transpose:
+             self.key_shape = layer.weight.shape[1]
+             self.value_shape = layer.weight.shape[0]
+         else:
+             self.key_shape = layer.weight.shape[0]
+             self.value_shape = layer.weight.shape[1]
+         self.training = False
+
+     def add_key(self, new_key, new_value):
+         keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys
+
+         values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True) # Add new value to list of values
+
+         new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
+         epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons
+
+         key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels
+
+         return keys, values, epsilons, key_labels
+
+     def init_key_value(self, query, value):
+         key = query.detach()
+         epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
+         key_label = [self.edit_label]
+         return key, value, epsilon, key_label
+
+     def label_match(self, edit_label, key_label):
+         return edit_label.float().mean() == key_label.float().mean()
+
+     def split_epsilons_in_half(self, nearest_key, smallest_distance):
+         self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
+         self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half
+
+     def forward(self, *args):
+         # Run layer forward and save what it would have returned for this instance
+         layer_out = self.layer(*args)
+
+         ### If training, we need to modify the codebook
+         if (not self.training) & ('keys' not in self.__dict__):
+             # If it's not training time and we haven't added any keys yet (this is before doing any editing)
+             # print(self.__dict__)
+             return layer_out
+         else:
+             if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
+                 token_to_edit = args[0].shape[1]-1
+                 self.key_id = args[0].shape[1]-1
+                 self.ensure_replace_token_loc = True
+             else:
+                 token_to_edit = min(self.key_id, args[0].shape[1]-1) # args[0].shape[1] - 1 is the last token index
+             query = args[0][:, token_to_edit, :] # Just use activation for last token
+             if self.config.val_init == "cold":
+                 new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
+             elif self.config.val_init == "warm":
+                 new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)
+
+             if 'keys' not in self.__dict__:
+                 # If no keys exist, initialize keys, values, epsilons, and key labels
+                 self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
+             elif self.iter == 0:
+                 # Keys exist, so we have to decide whether or not to update them (the fact that we've made it to this point means there was an error!)
+
+                 # --- search through keys for a match for query ---
+                 dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
+                 smallest_distance, nearest_key = dists.min(0)
+
+                 if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
+                     # If there's no close key, make a new key
+                     self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
+                 else:
+                     # If there is a close key, we need to handle conflicts
+                     if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
+                         self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
+                         self.split_epsilons_in_half(nearest_key, smallest_distance)
+                     else:
+                         # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
+                         if smallest_distance > self.epsilons[nearest_key]:
+                             if self.config.eps_expand == "coverage":
+                                 self.epsilons[nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
+                             elif self.config.eps_expand == "moving_average":
+                                 a = 0.5
+                                 self.keys[nearest_key] = a*self.keys[nearest_key] + (1-a)*query # Move old key to be halfway between
+                                 self.epsilons[nearest_key] = smallest_distance
+                                 # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
+             else:
+                 # If not iter 0, we don't need to change keys, we just need to learn the value
+                 pass
+             # print(token_to_edit)
+             # compute distance from query to all keys and find the closest keys
+             dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
+             smallest_dist, self.chosen_key = dists.min(0)
+             smallest_dist = smallest_dist.view(-1, 1)
+             chosen_value = self.values[self.chosen_key]
+             eps = self.epsilons[self.chosen_key].view(-1, 1)
+
+             if (self.config.val_train == "adv") and (self.training):
+                 chosen_value = perturb_values(chosen_value, self.num_pert, self.device)
+
+             if self.replacement == "replace_all":
+                 layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
+             elif self.replacement == "replace_last":
+                 layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
+             elif self.replacement == "replace_prompt":
+                 layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
+             else:
+                 print("token replacement choice not found")
+             return layer_out
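The deferral rule that GRACEAdapter.forward implements can be summarized in a few lines: the query activation is compared against every stored key, and the cached value replaces the layer output only when the query falls inside the nearest key's epsilon ball. A minimal standalone sketch of that lookup (plain tensors rather than the adapter's attributes; shapes are illustrative assumptions):

import torch

def grace_lookup(query, keys, values, epsilons, layer_out):
    # query: (1, d_key), keys: (n, d_key), values: (n, d_value), epsilons: (n, 1)
    dists = torch.cdist(keys, query, p=2).view(-1, 1)   # distance from the query to every stored key
    smallest_dist, chosen = dists.min(0)                 # nearest key and its distance
    hit = smallest_dist <= epsilons[chosen].view(-1)     # inside that key's deferral radius?
    return values[chosen] if hit else layer_out          # substitute the cached value only on a hit

# Toy check: a query on top of key 0 is replaced, a far-away query is passed through.
keys, values = torch.zeros(1, 4), torch.ones(1, 8)
eps = torch.full((1, 1), 1.0)
print(grace_lookup(torch.zeros(1, 4), keys, values, eps, torch.zeros(1, 8)))      # cached value
print(grace_lookup(5 * torch.ones(1, 4), keys, values, eps, torch.zeros(1, 8)))   # original output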
easyeditor/models/grace/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .grace_main import GraceHyperParams, apply_grace_to_model
+ from .metrics import F1, PPL, Accuracy, is_qa_error, is_acc_error
easyeditor/models/grace/__pycache__/GRACE.cpython-39.pyc ADDED
Binary file (6.34 kB).
 
easyeditor/models/grace/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (342 Bytes).
 
easyeditor/models/grace/__pycache__/grace_hparams.cpython-39.pyc ADDED
Binary file (1.49 kB).
 
easyeditor/models/grace/__pycache__/grace_main.cpython-39.pyc ADDED
Binary file (1.12 kB).
 
easyeditor/models/grace/__pycache__/metrics.cpython-39.pyc ADDED
Binary file (2.07 kB).
 
easyeditor/models/grace/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.53 kB).
 
easyeditor/models/grace/grace_hparams.py ADDED
@@ -0,0 +1,48 @@
+ from dataclasses import dataclass
+ from typing import List
+ from ...util.hparams import HyperParams
+ import yaml
+
+
+ @dataclass
+ class GraceHyperParams(HyperParams):
+     # Experiments
+
+     edit_lr: int
+     n_iter: int
+     # Method
+     eps: float
+     dist_fn: str
+     val_init: str
+     val_train: str
+     val_reg: str
+     reg: str
+     replacement: str
+     eps_expand: str
+     num_pert: str
+     dropout: float
+
+     # Module templates
+     inner_params: List[str]
+     device: int
+     alg_name: str
+     model_name: str
+
+     # Defaults
+     batch_size: int = 128
+     max_length: int = 30
+     model_parallel: bool = False
+
+     @classmethod
+     def from_hparams(cls, hparams_name_or_path: str):
+         if '.yaml' not in hparams_name_or_path:
+             hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+         with open(hparams_name_or_path, "r") as stream:
+             config = yaml.safe_load(stream)
+             config = super().construct_float_from_scientific_notation(config)
+
+         assert (config and config['alg_name'] == 'GRACE') or print(
+             f'GraceHyperParams can not load from {hparams_name_or_path}, '
+             f'alg_name is {config["alg_name"]} ')
+         return cls(**config)
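A typical way to instantiate these hyperparameters is from one of the YAML files added under hparams/GRACE (a sketch; it assumes the YAML defines every required field):

from easyeditor.models.grace import GraceHyperParams

# from_hparams appends '.yaml' automatically when it is missing from the path
hparams = GraceHyperParams.from_hparams('./hparams/GRACE/gpt2-xl')
print(hparams.alg_name, hparams.inner_params, hparams.edit_lr)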
easyeditor/models/grace/grace_main.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Any, Dict, List, Tuple
+ import torch
+ from copy import deepcopy
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from .GRACE import GRACE
+ from .grace_hparams import GraceHyperParams
+ from .utils import tokenize
+ from ...util import nethook
+
+
+ def apply_grace_to_model(
+         model: AutoModelForCausalLM,
+         tok: AutoTokenizer,
+         requests: List[Dict],
+         hparams: GraceHyperParams,
+         copy=False,
+         return_orig_weights=False,
+         keep_original_weight=False,
+         **kwargs: Any,
+ ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
+     model.to(f'cuda:{hparams.device}')
+     request = requests
+     if copy:
+         model = deepcopy(model)
+     weights_copy = {}
+     device = torch.device(f'cuda:{hparams.device}')
+     editor = GRACE(model=model, config=hparams, device=device)
+
+     tokens = tokenize(request, tokenizer=tok, device=device)
+     editor.edit(config=hparams, tokens=tokens)
+
+     if not keep_original_weight:
+         weights_copy = {}
+
+     editor.to(f'cuda:{hparams.device}')
+     return editor, weights_copy
+
+
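A hedged end-to-end sketch of how this entry point might be driven. The request keys prompt/target_new follow the tokenize helper in utils.py; the checkpoint name and the edit itself are illustrative assumptions, and a CUDA device is required because the function moves the model to cuda:{hparams.device}:

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.models.grace import GraceHyperParams, apply_grace_to_model

hparams = GraceHyperParams.from_hparams('./hparams/GRACE/gpt2-xl.yaml')
model = AutoModelForCausalLM.from_pretrained('gpt2-xl')
tok = AutoTokenizer.from_pretrained('gpt2-xl')
tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default

request = {"prompt": "The capital of France is", "target_new": "Rome"}  # hypothetical edit
editor, _ = apply_grace_to_model(model, tok, request, hparams)
# `editor` is the GRACE wrapper around the (now edited) model; calls and generate() go through it.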
easyeditor/models/grace/metrics.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import numpy as np
+ from .utils import *
+
+ def is_acc_error(model, tokens):
+     # Check whether or not the model's prediction for a batch element is correct
+     labels = tokens["labels"]
+     logits = model(**tokens).logits
+     probs = torch.softmax(logits, -1).squeeze()
+     argmaxs = torch.argmax(probs, dim=-1).squeeze()
+     return labels != argmaxs
+
+ def Accuracy(model, tokens):
+     labels = tokens["labels"]
+     new_tokens = {f"{k}" : v for k, v in tokens.items() if k != "labels"}
+     logits = model(**new_tokens).logits
+     probs = torch.softmax(logits, -1).squeeze()
+     argmaxs = torch.argmax(probs, dim=-1).squeeze()
+     return (labels == argmaxs).float().mean()
+
+ def is_qa_error(model, tokens):
+     preds = model.generate(tokens["input_ids"], max_length=20).squeeze() # Run model to get its predictions
+     labels = tokens["labels"]#[tokens["labels"] != -100]
+
+     if (len(preds) != len(labels)) or ((preds == labels).sum() != len(preds)):
+         return True
+     else:
+         return False
+
+ def PPL(model, batch):
+     input_ids = batch["input_ids"][:, :1024]#.to(device)
+     if "labels" not in batch:
+         target_ids = batch["input_ids"][:, :1024].clone()
+     else:
+         target_ids = batch["labels"][:, :1024].clone()
+
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, labels=target_ids)
+         nll = outputs.loss
+
+     ppl = torch.exp(nll)#.clip(0, 100)
+     return ppl
+
+ def F1(model, batch):
+     try:
+         preds = model.generate(batch["input_ids"], max_length=20).squeeze()
+         if len(preds) > 1:
+             preds = preds[preds != model.tokenizer.pad_token_id]
+         gold_toks = batch["labels"][batch["labels"] != -100].cpu().squeeze() # -100 might be nonsense
+         num_same = len(np.intersect1d(preds.cpu().squeeze(), gold_toks))
+         if (num_same == 0) or (len(preds.squeeze()) == 0):
+             return 0
+         precision = num_same / len(preds.squeeze())
+         recall = 1.0 * num_same / len(gold_toks)
+         f1 = (2 * precision * recall) / (precision + recall)
+         return f1
+     except:
+         # Every once in a while, the model just returns the stop token
+         return 0
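As a quick sanity check of the token-overlap F1 above: with 3 shared tokens out of 4 predicted and 5 gold tokens, precision is 3/4, recall is 3/5, and F1 = 2PR/(P+R) ≈ 0.667. The same arithmetic on plain lists (hypothetical token ids):

import numpy as np

preds = [10, 11, 12, 99]       # predicted token ids
gold = [10, 11, 12, 13, 14]    # gold token ids
num_same = len(np.intersect1d(preds, gold))          # 3 shared tokens
precision = num_same / len(preds)                    # 0.75
recall = num_same / len(gold)                        # 0.6
f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 3))  # 0.667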
easyeditor/models/grace/utils.py ADDED
@@ -0,0 +1,86 @@
+ import transformers
+ import torch
+ import os
+ import numpy as np
+ import datetime
+ import struct
+ from torch.nn.utils.rnn import pad_sequence
+ import torch.nn.functional as F
+
+ def get_inner_params(named_parameters, inner_names):
+     param_dict = dict(named_parameters)
+     return [(n, param_dict[n]) for n in inner_names]
+
+ def param_subset(named_parameters, inner_names):
+     param_dict = dict(named_parameters)
+     return [param_dict[n] for n in inner_names]
+
+ def parent_module(model, pname):
+     components = pname.split('.')
+     parent = model
+
+     for component in components[:-1]:
+         if hasattr(parent, component):
+             parent = getattr(parent, component)
+         elif component.isdigit():
+             parent = parent[int(component)]
+         else:
+             raise RuntimeError(f"Couldn't find child module {component}")
+
+     if not hasattr(parent, components[-1]):
+         raise RuntimeError(f"Couldn't find child module {components[-1]}")
+
+     return parent
+
+ def uuid(digits=4):
+     if not hasattr(uuid, "uuid_value"):
+         uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
+
+     return uuid.uuid_value
+
+ def ckpt_dir():
+     """returns the directory in which to store model checkpoints"""
+     path = "./ckpts/"
+     if not os.path.exists(path):
+         os.makedirs(path)
+     return path
+
+ def brackets_to_periods(name):
+     return name.replace("[", ".").replace("]", "")
+
+ def get_params(model):
+     return model.state_dict()
+
+ def get_shape(p, model):
+     # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear
+     return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
+
+ def get_logits(x):
+     return x.logits if hasattr(x, "logits") else x
+
+ def tokenize(batch, tokenizer, device, test=False):
+     prompt, label = batch["prompt"], batch["target_new"]
+     if not isinstance(prompt, list):
+         prompt=[prompt]
+     if not isinstance(label, list):
+         label=[label]
+     mask_token = -100 # ignore_index of CrossEntropyLoss
+     if test or not label:
+         tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)
+         tokens["labels"] = tokens["input_ids"].clone()
+         tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
+
+     else:
+         full_prompt = [f"{p} {l}" for p, l in zip(prompt, label)]
+         prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
+         num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids]
+         tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
+         tokens["labels"] = tokens["input_ids"].clone()
+         for i in range(len(prompt)):
+             tokens["labels"][i][:num_prompt_toks[i]] = mask_token
+
+         tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
+
+     tokens = {f"{k1}" : v1.to(device) for k1, v1 in tokens.items()}
+     return tokens
+
+
easyeditor/util/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .logit_lens import LogitLens
+ from .hparams import *
easyeditor/util/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (208 Bytes).
 
easyeditor/util/__pycache__/hparams.cpython-39.pyc ADDED
Binary file (1.21 kB).
 
easyeditor/util/__pycache__/logit_lens.cpython-39.pyc ADDED
Binary file (3.35 kB).
 
easyeditor/util/__pycache__/nethook.cpython-39.pyc ADDED
Binary file (13.2 kB).
 
easyeditor/util/alg_dict.py ADDED
@@ -0,0 +1,45 @@
+ from ..models.rome import ROMEHyperParams, apply_rome_to_model
+ from ..models.memit import MEMITHyperParams, apply_memit_to_model
+ from ..models.kn import KNHyperParams, apply_kn_to_model
+ from ..models.mend import MENDHyperParams, MendRewriteExecutor, MendMultimodalRewriteExecutor
+ from ..models.ft import FTHyperParams, apply_ft_to_model
+ from ..models.serac import SERACHparams, SeracRewriteExecutor, SeracMultimodalRewriteExecutor
+ from ..dataset import ZsreDataset, CounterFactDataset, CaptionDataset, VQADataset
+ from ..models.ike import IKEHyperParams, apply_ike_to_model, apply_ike_to_multimodal_model
+ from ..models.ft_api import FTApiHyperParams, apply_ft_api_to_model
+ from ..models.lora import LoRAHyperParams, apply_lora_to_model
+ from ..models.grace import GraceHyperParams, apply_grace_to_model
+ from ..models.pmet import PMETHyperParams, apply_pmet_to_model
+ from ..models.melo import MELOHyperParams, apply_melo_to_model
+
+ ALG_DICT = {
+     'ROME': apply_rome_to_model,
+     'MEMIT': apply_memit_to_model,
+     "FT": apply_ft_to_model,
+     'KN': apply_kn_to_model,
+     'MEND': MendRewriteExecutor().apply_to_model,
+     'SERAC': SeracRewriteExecutor().apply_to_model,
+     'IKE': apply_ike_to_model,
+     'FT-Api': apply_ft_api_to_model,
+     'LoRA': apply_lora_to_model,
+     'GRACE': apply_grace_to_model,
+     'PMET': apply_pmet_to_model,
+     'MELO': apply_melo_to_model
+ }
+
+ ALG_MULTIMODAL_DICT = {
+     'MEND': MendMultimodalRewriteExecutor().apply_to_model,
+     'SERAC': SeracMultimodalRewriteExecutor().apply_to_model,
+     'SERAC_MULTI': SeracMultimodalRewriteExecutor().apply_to_model,
+     'IKE': apply_ike_to_multimodal_model,
+ }
+
+ DS_DICT = {
+     "cf": CounterFactDataset,
+     "zsre": ZsreDataset,
+ }
+
+ MULTIMODAL_DS_DICT = {
+     "caption": CaptionDataset,
+     "vqa": VQADataset,
+ }
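These dictionaries are plain dispatch tables: given loaded hyperparameters, the editing function is looked up by alg_name. Continuing the apply_grace_to_model sketch above (note that importing alg_dict pulls in every editor's dependencies, which must all be present):

from easyeditor.util.alg_dict import ALG_DICT

apply_algo = ALG_DICT[hparams.alg_name]   # e.g. 'GRACE' -> apply_grace_to_model
edited_model, weights_copy = apply_algo(model, tok, request, hparams)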
easyeditor/util/alg_train_dict.py ADDED
@@ -0,0 +1,9 @@
+ from ..trainer import MEND
+ from ..trainer import SERAC, SERAC_MULTI
+
+
+ ALG_TRAIN_DICT = {
+     'MEND': MEND,
+     'SERAC': SERAC,
+     'SERAC_MULTI': SERAC_MULTI,
+ }
easyeditor/util/generate.py ADDED
@@ -0,0 +1,171 @@
1
+ import unicodedata
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ from .logit_lens import LogitLens
8
+
9
+
10
+ def generate_interactive(
11
+ model: AutoModelForCausalLM,
12
+ tok: AutoTokenizer,
13
+ top_k: int = 5,
14
+ max_out_len: int = 200,
15
+ compare_against: Optional[AutoModelForCausalLM] = None,
16
+ use_logit_lens: bool = False,
17
+ layer_module_tmp: str = "transformer.h.{}",
18
+ ln_f_module: str = "transformer.ln_f",
19
+ lm_head_module: str = "lm_head",
20
+ ):
21
+ """
22
+ Puts generation in a loop. Allows users to repeatedly provide inputs
23
+ with which text is generated.
24
+ """
25
+
26
+ if use_logit_lens:
27
+ llens_gen = LogitLens(
28
+ model,
29
+ tok,
30
+ layer_module_tmp,
31
+ ln_f_module,
32
+ lm_head_module,
33
+ disabled=not use_logit_lens,
34
+ )
35
+ if compare_against:
36
+ llens_vanilla = LogitLens(
37
+ compare_against,
38
+ tok,
39
+ layer_module_tmp,
40
+ ln_f_module,
41
+ lm_head_module,
42
+ disabled=not use_logit_lens,
43
+ )
44
+
45
+ while True:
46
+ prompt = input("Enter a prompt: ").strip(" \r\t\n")
47
+
48
+ print(
49
+ f"Argument Model: "
50
+ f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
51
+ )
52
+ if compare_against:
53
+ print(
54
+ f"Baseline Model: "
55
+ f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
56
+ )
57
+
58
+ if use_logit_lens:
59
+ inp_prompt = tok([prompt], padding=True, return_tensors="pt").to(
60
+ next(model.parameters()).device
61
+ )
62
+
63
+ with llens_gen:
64
+ model(**inp_prompt)
65
+ print("\n--- Argument Model Logit Lens ---")
66
+ llens_gen.pprint()
67
+
68
+ if compare_against:
69
+ with llens_vanilla:
70
+ compare_against(**inp_prompt)
71
+ print("--- Baseline Model Logit Lens ---")
72
+ llens_vanilla.pprint()
73
+
74
+ print()
75
+
76
+
77
+ def generate_fast(
78
+ model: AutoModelForCausalLM,
79
+ tok: AutoTokenizer,
80
+ prompts: List[str],
81
+ n_gen_per_prompt: int = 1,
82
+ top_k: int = 5,
83
+ max_out_len: int = 200,
84
+ vanilla_generation=False,
85
+ ):
86
+ """
87
+ Fast, parallelized auto-regressive text generation with top-k sampling.
88
+ Our custom implementation.
89
+ """
90
+
91
+ # Unroll prompts and tokenize
92
+ inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)]
93
+ inp_tok = tok(inp, padding=True, return_tensors="pt").to(
94
+ next(model.parameters()).device
95
+ )
96
+ input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
97
+ if vanilla_generation:
98
+ gen_txt = model.generate(
99
+ input_ids=input_ids,
100
+ attention_mask=attention_mask,
101
+ max_new_tokens=max_out_len
102
+ )
103
+ txt = [tok.decode(x, skip_special_tokens=True) for x in gen_txt.detach().cpu().numpy().tolist()]
104
+ txt = [
105
+ unicodedata.normalize("NFKD", x)
106
+ .replace("\n\n", " ")
107
+ .replace("<|endoftext|>", "")
108
+ for x in txt
109
+ ]
110
+ return txt
111
+ batch_size = input_ids.size(0)
112
+
113
+ # Setup storage of fast generation with attention caches.
114
+ # `cur_context` is used to define the range of inputs that are not yet
115
+ # stored in `past_key_values`. At each step, we are generating the
116
+ # next token for the index at `cur_context.stop + 1`.
117
+ past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())
118
+
119
+ with torch.no_grad():
120
+ while input_ids.size(1) < max_out_len: # while not exceeding max output length
121
+ model_out = model(
122
+ input_ids=input_ids[:, cur_context],
123
+ attention_mask=None if 'llama'or'baichuan' in model.name_or_path.lower() else attention_mask[:, cur_context],
124
+ past_key_values=past_key_values,
125
+ use_cache=True,
126
+ )
127
+ logits, past_key_values = model_out.logits, model_out.past_key_values
128
+ softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)
129
+
130
+ # Top-k sampling
131
+ tk = torch.topk(softmax_out, top_k, dim=1).indices
132
+ softmax_out_top_k = torch.gather(softmax_out, 1, tk)
133
+ softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
134
+ new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
135
+ new_toks = torch.gather(tk, 1, new_tok_indices)
136
+
137
+ # If we're currently generating the continuation for the last token in `input_ids`,
138
+ # create a new index so we can insert the new token
139
+ if cur_context.stop == input_ids.size(1):
140
+ attention_mask = torch.cat(
141
+ [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
142
+ )
143
+ input_ids = torch.cat(
144
+ [
145
+ input_ids,
146
+ input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
147
+ ],
148
+ dim=1,
149
+ )
150
+
151
+ last_non_masked = attention_mask.sum(1) - 1
152
+ for i in range(batch_size):
153
+ new_idx = last_non_masked[i] + 1
154
+ if last_non_masked[i].item() + 1 != cur_context.stop:
155
+ continue
156
+
157
+ # Stop generating if we've already maxed out for this prompt
158
+ if new_idx < max_out_len:
159
+ input_ids[i][new_idx] = new_toks[i]
160
+ attention_mask[i][new_idx] = 1
161
+
162
+ cur_context = slice(cur_context.stop, cur_context.stop + 1)
163
+ txt = [tok.decode(x, skip_special_tokens=True) for x in input_ids.detach().cpu().numpy().tolist()]
164
+ txt = [
165
+ unicodedata.normalize("NFKD", x)
166
+ .replace("\n\n", " ")
167
+ .replace("<|endoftext|>", "")
168
+ for x in txt
169
+ ]
170
+
171
+ return txt
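A short usage sketch for the sampler above (the GPT-2 checkpoint and prompt are assumptions; a pad token must be set because the batch is padded):

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.util.generate import generate_fast

model = AutoModelForCausalLM.from_pretrained('gpt2')
tok = AutoTokenizer.from_pretrained('gpt2')
tok.pad_token = tok.eos_token

texts = generate_fast(model, tok, ["The Eiffel Tower is located in"],
                      n_gen_per_prompt=2, top_k=5, max_out_len=30)
print(texts)  # two sampled continuations for the single prompt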
easyeditor/util/globals.py ADDED
@@ -0,0 +1,43 @@
+ from pathlib import Path
+
+ import logging
+ import os
+
+ import yaml
+
+
+ def get_handler(path, log_name):
+     log_file_path = os.path.join(path, log_name)
+     try:
+         if not os.path.exists(path):
+             print("We are creating the logger files")
+             os.makedirs(path)
+     except:
+         pass
+     file_handler = logging.FileHandler(log_file_path)
+     file_handler.setLevel(logging.DEBUG)
+     file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+
+     stream_handler = logging.StreamHandler()
+     stream_handler.setLevel(logging.DEBUG)
+     stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     return file_handler, stream_handler
+
+
+ # def get_run_dir(dir_name):
+ #
+ #     alg_dir = RESULTS_DIR / dir_name
+ #     if alg_dir.exists():
+ #         id_list = [
+ #             int(str(x).split("_")[-1])
+ #             for x in alg_dir.iterdir()
+ #             if str(x).split("_")[-1].isnumeric()
+ #         ]
+ #         run_id = 0 if not id_list else max(id_list) + 1
+ #     else:
+ #         run_id = 0
+ #     run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}"
+ #     run_dir.mkdir(parents=True, exist_ok=True)
+ #     print(f"Results will be stored at {run_dir}")
+ #
+ #     return run_dir
easyeditor/util/hparams.py ADDED
@@ -0,0 +1,46 @@
+ import json
+ from dataclasses import dataclass
+ from dataclasses import asdict
+
+
+ @dataclass
+ class HyperParams:
+     """
+     Simple wrapper to store hyperparameters for Python-based rewriting methods.
+     """
+
+     @classmethod
+     def from_json(cls, fpath):
+         with open(fpath, "r") as f:
+             data = json.load(f)
+
+         return cls(**data)
+
+     def construct_float_from_scientific_notation(config: dict):
+         for key, value in config.items():
+             if isinstance(value, str):
+                 try:
+                     # Convert scalar to float if it is in scientific notation format
+                     config[key] = float(value)
+                 except:
+                     pass
+         return config
+
+     def to_dict(config) -> dict:
+         dict = asdict(config)
+         return dict
+
+
+
+     # @classmethod
+     # def from_hparams(cls, hparams_name_or_path: str):
+     #
+     #     if '.yaml' not in hparams_name_or_path:
+     #         hparams_name_or_path = hparams_name_or_path + '.yaml'
+     #     config = compose(hparams_name_or_path)
+     #
+     #     assert config.alg_name in ALG_DICT.keys() or print(f'Editing Alg name {config.alg_name} not supported yet.')
+     #
+     #     params_class, apply_algo = ALG_DICT[config.alg_name]
+     #
+     #     return params_class(**config)
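The scientific-notation helper exists because PyYAML loads values such as 1e-4 as strings (its float resolver expects a decimal point); converting them up front keeps numeric hyperparameters numeric while leaving other strings alone. For example:

config = {"edit_lr": "1e-4", "n_iter": 30, "alg_name": "GRACE"}
config = HyperParams.construct_float_from_scientific_notation(config)
print(type(config["edit_lr"]), config["edit_lr"])  # <class 'float'> 0.0001; 'GRACE' stays a string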
easyeditor/util/logit_lens.py ADDED
@@ -0,0 +1,97 @@
1
+ from collections import defaultdict
2
+ from typing import Dict, Optional
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+ from . import nethook
8
+
9
+
10
+ class LogitLens:
11
+ """
12
+ Applies the LM head at the output of each hidden layer, then analyzes the
13
+ resultant token probability distribution.
14
+
15
+ Only works when hooking outputs of *one* individual generation.
16
+
17
+ Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens
18
+
19
+ Warning: when running multiple times (e.g. generation), will return
20
+ outputs _only_ for the last processing step.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ model: AutoModelForCausalLM,
26
+ tok: AutoTokenizer,
27
+ layer_module_tmp: str,
28
+ ln_f_module: str,
29
+ lm_head_module: str,
30
+ disabled: bool = False,
31
+ ):
32
+ self.disabled = disabled
33
+ self.model, self.tok = model, tok
34
+ self.n_layers = self.model.config.n_layer
35
+
36
+ self.lm_head, self.ln_f = (
37
+ nethook.get_module(model, lm_head_module),
38
+ nethook.get_module(model, ln_f_module),
39
+ )
40
+
41
+ self.output: Optional[Dict] = None
42
+ self.td: Optional[nethook.TraceDict] = None
43
+ self.trace_layers = [
44
+ layer_module_tmp.format(layer) for layer in range(self.n_layers)
45
+ ]
46
+
47
+ def __enter__(self):
48
+ if not self.disabled:
49
+ self.td = nethook.TraceDict(
50
+ self.model,
51
+ self.trace_layers,
52
+ retain_input=False,
53
+ retain_output=True,
54
+ )
55
+ self.td.__enter__()
56
+
57
+ def __exit__(self, *args):
58
+ if self.disabled:
59
+ return
60
+ self.td.__exit__(*args)
61
+
62
+ self.output = {layer: [] for layer in range(self.n_layers)}
63
+
64
+ with torch.no_grad():
65
+ for layer, (_, t) in enumerate(self.td.items()):
66
+ cur_out = t.output[0]
67
+ assert (
68
+ cur_out.size(0) == 1
69
+ ), "Make sure you're only running LogitLens on single generations only."
70
+
71
+ self.output[layer] = torch.softmax(
72
+ self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1
73
+ )
74
+
75
+ return self.output
76
+
77
+ def pprint(self, k=5):
78
+ to_print = defaultdict(list)
79
+
80
+ for layer, pred in self.output.items():
81
+ rets = torch.topk(pred[0], k)
82
+ for i in range(k):
83
+ to_print[layer].append(
84
+ (
85
+ self.tok.decode(rets[1][i]),
86
+ round(rets[0][i].item() * 1e2) / 1e2,
87
+ )
88
+ )
89
+
90
+ print(
91
+ "\n".join(
92
+ [
93
+ f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}"
94
+ for layer in range(self.n_layers)
95
+ ]
96
+ )
97
+ )
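A hedged usage sketch (the GPT-2 module names follow the defaults used by generate_interactive above):

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.util import LogitLens

model = AutoModelForCausalLM.from_pretrained('gpt2')
tok = AutoTokenizer.from_pretrained('gpt2')

llens = LogitLens(model, tok, layer_module_tmp="transformer.h.{}",
                  ln_f_module="transformer.ln_f", lm_head_module="lm_head")
inp = tok(["The Eiffel Tower is in"], return_tensors="pt")
with llens:
    model(**inp)
llens.pprint(k=5)  # top-5 next-token candidates decoded at every layer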
easyeditor/util/nethook.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ Utilities for instrumenting a torch model.
3
+
4
+ Trace will hook one layer at a time.
5
+ TraceDict will hook multiple layers at once.
6
+ subsequence slices intervals from Sequential modules.
7
+ get_module, replace_module, get_parameter resolve dotted names.
8
+ set_requires_grad recursively sets requires_grad in module parameters.
9
+ """
10
+
11
+ import contextlib
12
+ import copy
13
+ import inspect
14
+ from collections import OrderedDict
15
+
16
+ import torch
17
+
18
+
19
+ class Trace(contextlib.AbstractContextManager):
20
+ """
21
+ To retain the output of the named layer during the computation of
22
+ the given network:
23
+
24
+ with Trace(net, 'layer.name') as ret:
25
+ _ = net(inp)
26
+ representation = ret.output
27
+
28
+ A layer module can be passed directly without a layer name, and
29
+ its output will be retained. By default, a direct reference to
30
+ the output object is returned, but options can control this:
31
+
32
+ clone=True - retains a copy of the output, which can be
33
+ useful if you want to see the output before it might
34
+ be modified by the network in-place later.
35
+ detach=True - retains a detached reference or copy. (By
36
+ default the value would be left attached to the graph.)
37
+ retain_grad=True - request gradient to be retained on the
38
+ output. After backward(), ret.output.grad is populated.
39
+
40
+ retain_input=True - also retains the input.
41
+ retain_output=False - can disable retaining the output.
42
+ edit_output=fn - calls the function to modify the output
43
+ of the layer before passing it the rest of the model.
44
+ fn can optionally accept (output, layer) arguments
45
+ for the original output and the layer name.
46
+ stop=True - throws a StopForward exception after the layer
47
+ is run, which allows running just a portion of a model.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ module,
53
+ layer=None,
54
+ retain_output=True,
55
+ retain_input=False,
56
+ clone=False,
57
+ detach=False,
58
+ retain_grad=False,
59
+ edit_output=None,
60
+ stop=False,
61
+ ):
62
+ """
63
+ Method to replace a forward method with a closure that
64
+ intercepts the call, and tracks the hook so that it can be reverted.
65
+ """
66
+ retainer = self
67
+ self.layer = layer
68
+ if layer is not None:
69
+ module = get_module(module, layer)
70
+
71
+ def retain_hook(m, inputs, output):
72
+ if retain_input:
73
+ retainer.input = recursive_copy(
74
+ inputs[0] if len(inputs) == 1 else inputs,
75
+ clone=clone,
76
+ detach=detach,
77
+ retain_grad=False,
78
+ ) # retain_grad applies to output only.
79
+ if edit_output:
80
+ output = invoke_with_optional_args(
81
+ edit_output, output=output, layer=self.layer
82
+ )
83
+ if retain_output:
84
+ retainer.output = recursive_copy(
85
+ output, clone=clone, detach=detach, retain_grad=retain_grad
86
+ )
87
+ # When retain_grad is set, also insert a trivial
88
+ # copy operation. That allows in-place operations
89
+ # to follow without error.
90
+ if retain_grad:
91
+ output = recursive_copy(retainer.output, clone=True, detach=False)
92
+ if stop:
93
+ raise StopForward()
94
+ return output
95
+
96
+ self.registered_hook = module.register_forward_hook(retain_hook)
97
+ self.stop = stop
98
+
99
+ def __enter__(self):
100
+ return self
101
+
102
+ def __exit__(self, type, value, traceback):
103
+ self.close()
104
+ if self.stop and issubclass(type, StopForward):
105
+ return True
106
+
107
+ def close(self):
108
+ self.registered_hook.remove()
109
+
110
+
111
+ class TraceDict(OrderedDict, contextlib.AbstractContextManager):
112
+ """
113
+ To retain the output of multiple named layers during the computation
114
+ of the given network:
115
+
116
+ with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
117
+ _ = net(inp)
118
+ representation = ret['layer1.name1'].output
119
+
120
+ If edit_output is provided, it should be a function that takes
121
+ two arguments: output, and the layer name; and then it returns the
122
+ modified output.
123
+
124
+ Other arguments are the same as Trace. If stop is True, then the
125
+ execution of the network will be stopped after the last layer
126
+ listed (even if it would not have been the last to be executed).
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ module,
132
+ layers=None,
133
+ retain_output=True,
134
+ retain_input=False,
135
+ clone=False,
136
+ detach=False,
137
+ retain_grad=False,
138
+ edit_output=None,
139
+ stop=False,
140
+ ):
141
+ self.stop = stop
142
+
143
+ def flag_last_unseen(it):
144
+ try:
145
+ it = iter(it)
146
+ prev = next(it)
147
+ seen = set([prev])
148
+ except StopIteration:
149
+ return
150
+ for item in it:
151
+ if item not in seen:
152
+ yield False, prev
153
+ seen.add(item)
154
+ prev = item
155
+ yield True, prev
156
+
157
+ for is_last, layer in flag_last_unseen(layers):
158
+ self[layer] = Trace(
159
+ module=module,
160
+ layer=layer,
161
+ retain_output=retain_output,
162
+ retain_input=retain_input,
163
+ clone=clone,
164
+ detach=detach,
165
+ retain_grad=retain_grad,
166
+ edit_output=edit_output,
167
+ stop=stop and is_last,
168
+ )
169
+
170
+ def __enter__(self):
171
+ return self
172
+
173
+ def __exit__(self, type, value, traceback):
174
+ self.close()
175
+ if self.stop and issubclass(type, StopForward):
176
+ return True
177
+
178
+ def close(self):
179
+ for layer, trace in reversed(self.items()):
180
+ trace.close()
181
+
182
+
183
+ class StopForward(Exception):
184
+ """
185
+ If the only output needed from running a network is the retained
186
+ submodule then Trace(submodule, stop=True) will stop execution
187
+ immediately after the retained submodule by raising the StopForward()
188
+ exception. When Trace is used as context manager, it catches that
189
+ exception and can be used as follows:
190
+
191
+ with Trace(net, layername, stop=True) as tr:
192
+ net(inp) # Only runs the network up to layername
193
+ print(tr.output)
194
+ """
195
+
196
+ pass
197
+
198
+
199
+ def recursive_copy(x, clone=None, detach=None, retain_grad=None):
200
+ """
201
+ Copies a reference to a tensor, or an object that contains tensors,
202
+ optionally detaching and cloning the tensor(s). If retain_grad is
203
+ true, the original tensors are marked to have grads retained.
204
+ """
205
+ if not clone and not detach and not retain_grad:
206
+ return x
207
+ if isinstance(x, torch.Tensor):
208
+ if retain_grad:
209
+ if not x.requires_grad:
210
+ x.requires_grad = True
211
+ x.retain_grad()
212
+ elif detach:
213
+ x = x.detach()
214
+ if clone:
215
+ x = x.clone()
216
+ return x
217
+ # Only dicts, lists, and tuples (and subclasses) can be copied.
218
+ if isinstance(x, dict):
219
+ return type(x)({k: recursive_copy(v) for k, v in x.items()})
220
+ elif isinstance(x, (list, tuple)):
221
+ return type(x)([recursive_copy(v) for v in x])
222
+ else:
223
+ assert False, f"Unknown type {type(x)} cannot be broken into tensors."
224
+
225
+
226
+ def subsequence(
227
+ sequential,
228
+ first_layer=None,
229
+ last_layer=None,
230
+ after_layer=None,
231
+ upto_layer=None,
232
+ single_layer=None,
233
+ share_weights=False,
234
+ ):
235
+ """
236
+ Creates a subsequence of a pytorch Sequential model, copying over
237
+ modules together with parameters for the subsequence. Only
238
+ modules from first_layer to last_layer (inclusive) are included,
239
+ or modules between after_layer and upto_layer (exclusive).
240
+ Handles descent into dotted layer names as long as all references
241
+ are within nested Sequential models.
242
+
243
+ If share_weights is True, then references the original modules
244
+ and their parameters without copying them. Otherwise, by default,
245
+ makes a separate brand-new copy.
246
+ """
247
+ assert (single_layer is None) or (
248
+ first_layer is last_layer is after_layer is upto_layer is None
249
+ )
250
+ if single_layer is not None:
251
+ first_layer = single_layer
252
+ last_layer = single_layer
253
+ first, last, after, upto = [
254
+ None if d is None else d.split(".")
255
+ for d in [first_layer, last_layer, after_layer, upto_layer]
256
+ ]
257
+ return hierarchical_subsequence(
258
+ sequential,
259
+ first=first,
260
+ last=last,
261
+ after=after,
262
+ upto=upto,
263
+ share_weights=share_weights,
264
+ )
265
+
266
+
267
+ def hierarchical_subsequence(
268
+ sequential, first, last, after, upto, share_weights=False, depth=0
269
+ ):
270
+ """
271
+ Recursive helper for subsequence() to support descent into dotted
272
+ layer names. In this helper, first, last, after, and upto are
273
+ arrays of names resulting from splitting on dots. Can only
274
+ descend into nested Sequentials.
275
+ """
276
+ assert (last is None) or (upto is None)
277
+ assert (first is None) or (after is None)
278
+ if first is last is after is upto is None:
279
+ return sequential if share_weights else copy.deepcopy(sequential)
280
+ assert isinstance(sequential, torch.nn.Sequential), (
281
+ ".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
282
+ )
283
+ including_children = (first is None) and (after is None)
284
+ included_children = OrderedDict()
285
+ # A = current level short name of A.
286
+ # AN = full name for recursive descent if not innermost.
287
+ (F, FN), (L, LN), (A, AN), (U, UN) = [
288
+ (d[depth], (None if len(d) == depth + 1 else d))
289
+ if d is not None
290
+ else (None, None)
291
+ for d in [first, last, after, upto]
292
+ ]
293
+ for name, layer in sequential._modules.items():
294
+ if name == F:
295
+ first = None
296
+ including_children = True
297
+ if name == A and AN is not None: # just like F if not a leaf.
298
+ after = None
299
+ including_children = True
300
+ if name == U and UN is None:
301
+ upto = None
302
+ including_children = False
303
+ if including_children:
304
+ # AR = full name for recursive descent if name matches.
305
+ FR, LR, AR, UR = [
306
+ n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
307
+ ]
308
+ chosen = hierarchical_subsequence(
309
+ layer,
310
+ first=FR,
311
+ last=LR,
312
+ after=AR,
313
+ upto=UR,
314
+ share_weights=share_weights,
315
+ depth=depth + 1,
316
+ )
317
+ if chosen is not None:
318
+ included_children[name] = chosen
319
+ if name == L:
320
+ last = None
321
+ including_children = False
322
+ if name == U and UN is not None: # just like L if not a leaf.
323
+ upto = None
324
+ including_children = False
325
+ if name == A and AN is None:
326
+ after = None
327
+ including_children = True
328
+ for name in [first, last, after, upto]:
329
+ if name is not None:
330
+ raise ValueError("Layer %s not found" % ".".join(name))
331
+ # Omit empty subsequences except at the outermost level,
332
+ # where we should not return None.
333
+ if not len(included_children) and depth > 0:
334
+ return None
335
+ result = torch.nn.Sequential(included_children)
336
+ result.training = sequential.training
337
+ return result
338
+
339
+
340
+ def set_requires_grad(requires_grad, *models):
341
+ """
342
+ Sets requires_grad true or false for all parameters within the
343
+ models passed.
344
+ """
345
+ for model in models:
346
+ if isinstance(model, torch.nn.Module):
347
+ for param in model.parameters():
348
+ param.requires_grad = requires_grad
349
+ elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
350
+ model.requires_grad = requires_grad
351
+ else:
352
+ assert False, "unknown type %r" % type(model)
353
+
354
+
355
+ def get_module(model, name):
356
+ """
357
+ Finds the named module within the given model.
358
+ """
359
+ for n, m in model.named_modules():
360
+ if n == name:
361
+ return m
362
+ raise LookupError(name)
363
+
364
+
365
+ def get_parameter(model, name):
366
+ """
367
+ Finds the named parameter within the given model.
368
+ """
369
+ for n, p in model.named_parameters():
370
+ if n == name:
371
+ return p
372
+ raise LookupError(name)
373
+
374
+
375
+ def replace_module(model, name, new_module):
376
+ """
377
+ Replaces the named module within the given model.
378
+ """
379
+ if "." in name:
380
+ parent_name, attr_name = name.rsplit(".", 1)
381
+ model = get_module(model, parent_name)
382
+ # original_module = getattr(model, attr_name)
383
+ setattr(model, attr_name, new_module)
384
+
385
+
386
+ def invoke_with_optional_args(fn, *args, **kwargs):
387
+ """
388
+ Invokes a function with only the arguments that it
389
+ is written to accept, giving priority to arguments
390
+ that match by-name, using the following rules.
391
+ (1) arguments with matching names are passed by name.
392
+ (2) remaining non-name-matched args are passed by order.
393
+ (3) extra caller arguments that the function cannot
394
+ accept are not passed.
395
+ (4) extra required function arguments that the caller
396
+ cannot provide cause a TypeError to be raised.
397
+ Ordinary python calling conventions are helpful for
398
+ supporting a function that might be revised to accept
399
+ extra arguments in a newer version, without requiring the
400
+ caller to pass those new arguments. This function helps
401
+ support function callers that might be revised to supply
402
+ extra arguments, without requiring the callee to accept
403
+ those new arguments.
404
+ """
405
+ argspec = inspect.getfullargspec(fn)
406
+ pass_args = []
407
+ used_kw = set()
408
+ unmatched_pos = []
409
+ used_pos = 0
410
+ defaulted_pos = len(argspec.args) - (
411
+ 0 if not argspec.defaults else len(argspec.defaults)
412
+ )
413
+ # Pass positional args that match name first, then by position.
414
+ for i, n in enumerate(argspec.args):
415
+ if n in kwargs:
416
+ pass_args.append(kwargs[n])
417
+ used_kw.add(n)
418
+ elif used_pos < len(args):
419
+ pass_args.append(args[used_pos])
420
+ used_pos += 1
421
+ else:
422
+ unmatched_pos.append(len(pass_args))
423
+ pass_args.append(
424
+ None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
425
+ )
426
+ # Fill unmatched positional args with unmatched keyword args in order.
427
+ if len(unmatched_pos):
428
+ for k, v in kwargs.items():
429
+ if k in used_kw or k in argspec.kwonlyargs:
430
+ continue
431
+ pass_args[unmatched_pos[0]] = v
432
+ used_kw.add(k)
433
+ unmatched_pos = unmatched_pos[1:]
434
+ if len(unmatched_pos) == 0:
435
+ break
436
+ else:
437
+ if unmatched_pos[0] < defaulted_pos:
438
+ unpassed = ", ".join(
439
+ argspec.args[u] for u in unmatched_pos if u < defaulted_pos
440
+ )
441
+ raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
442
+ # Pass remaining kw args if they can be accepted.
443
+ pass_kw = {
444
+ k: v
445
+ for k, v in kwargs.items()
446
+ if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
447
+ }
448
+ # Pass remaining positional args if they can be accepted.
449
+ if argspec.varargs is not None:
450
+ pass_args += list(args[used_pos:])
451
+ return fn(*pass_args, **pass_kw)
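A small runnable sketch of the tracing utilities on a toy Sequential (layer names are whatever named_modules() reports, here '0' and '2'):

import torch
from easyeditor.util import nethook

net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
x = torch.randn(3, 4)

# Retain the activation after the first Linear while running the network.
with nethook.Trace(net, '0') as tr:
    net(x)
print(tr.output.shape)       # torch.Size([3, 8])

# Hook several layers at once.
with nethook.TraceDict(net, ['0', '2']) as td:
    net(x)
print(td['2'].output.shape)  # torch.Size([3, 2])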
easyeditor/util/perplexity.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ def perplexity(
+     model: AutoModelForCausalLM,
+     tok: AutoTokenizer,
+     text: str,
+     max_input_length: int = None,
+ ):
+     """
+     Computes perplexity of a piece of text, measured on a reference model.
+     Text is truncated to max_input_length tokens.
+     """
+
+     inputs = tok(
+         [text], return_tensors="pt", max_length=max_input_length, truncation=True
+     ).to("cuda")
+
+     logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)
+     log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]
+
+     # Perplexity = exp(-1/N * log P(x_1, ..., x_n))
+     return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item()
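The comment's formula is just the exponential of the average token negative log-likelihood. A usage sketch (the model and text are assumptions; the helper moves inputs to CUDA, so a GPU is required):

from transformers import AutoModelForCausalLM, AutoTokenizer
from easyeditor.util.perplexity import perplexity

model = AutoModelForCausalLM.from_pretrained('gpt2').cuda()
tok = AutoTokenizer.from_pretrained('gpt2')

ppl = perplexity(model, tok, "The quick brown fox jumps over the lazy dog.", max_input_length=64)
print(ppl)  # lower is better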
easyeditor/util/runningstats.py ADDED
@@ -0,0 +1,1883 @@
1
+ """
2
+ To use a runningstats object,
3
+
4
+ 1. Create the desired stat object, e.g., `m = Mean()`
5
+ 2. Feed it batches via the add method, e.g., `m.add(batch)`
6
+ 3. Repeat step 2 any number of times.
7
+ 4. Read out the statistic of interest, e.g., `m.mean()`
8
+
9
+ Built-in runningstats objects include:
10
+
11
+ Mean - produces mean().
12
+ Variance - mean() and variance() and stdev().
13
+ Covariance - mean(), covariance(), correlation(), variance(), stdev().
14
+ SecondMoment - moment() is the non-mean-centered covariance, E[x x^T].
15
+ Quantile - quantiles(), min(), max(), median(), mean(), variance(), stdev().
16
+ TopK - topk() returns (values, indexes).
17
+ Bincount - bincount() histograms nonnegative integer data.
18
+ IoU - intersection(), union(), iou() tally binary co-occurrences.
19
+ History - history() returns concatenation of data.
20
+ CrossCovariance - covariance between two signals, without self-covariance.
21
+ CrossIoU - iou between two signals, without self-IoU.
22
+ CombinedStat - aggregates any set of stats.
23
+
24
+ Add more running stats by subclassing the Stat class.
25
+
26
+ These statistics are vectorized along dim>=1, so stat.add()
27
+ should be given a two-dimensional input where the zeroth
28
+ dimension is the batch/sampling dimension and the first
29
+ dimension is the feature dimension.
30
+
31
+ The data type and device used matches the data passed to add();
32
+ for example, for higher-precision covariances, convert to double
33
+ before calling add().
34
+
35
+ It is common to want to compute and remember a statistic sampled
36
+ over a Dataset, computed in batches, possibly caching the computed
37
+ statistic in a file. The tally(stat, dataset, cache) handles
38
+ this pattern. It takes a statistic, a dataset, and a cache filename
39
+ and sets up a data loader that can be run (or not, if cached) to
40
+ compute the statistic, adopting the convention that cached stats are
41
+ saved to and loaded from numpy npz files.
42
+ """
43
+
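# A minimal sketch of the four-step recipe above, assuming 2-d float batches
# (dim 0 = samples, dim 1 = features); Mean is defined further down in this file:
#
#     m = Mean()
#     for batch in torch.randn(10, 256, 16):    # ten batches of shape (256, 16)
#         m.add(batch)
#     print(m.mean().shape)                      # torch.Size([16])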
44
+ import math
45
+ import os
46
+ import random
47
+ import struct
48
+
49
+ import numpy
50
+ import torch
51
+ from torch.utils.data.sampler import Sampler
52
+
53
+
54
+ def tally(stat, dataset, cache=None, quiet=False, **kwargs):
55
+ """
56
+ To use tally, write code like the following.
57
+
58
+ stat = Mean()
59
+ ds = MyDataset()
60
+ for batch in tally(stat, ds, cache='mymean.npz', batch_size=50):
61
+ stat.add(batch)
62
+ mean = stat.mean()
63
+
64
+ The first argument should be the Stat being computed. After the
65
+ loader is exhausted, tally will bring this stat to the cpu and
66
+ cache it (if a cache is specified).
67
+
68
+ The dataset can be a torch Dataset or a plain Tensor, or it can
69
+ be a callable that returns one of those.
70
+
71
+ Details on caching via the cache= argument:
72
+
73
+ If the given filename cannot be loaded, tally will leave the
74
+ statistic object empty and set up a DataLoader object so that
75
+ the loop can be run. After the last iteration of the loop, the
76
+ completed statistic will be moved to the cpu device and also
77
+ saved in the cache file.
78
+
79
+ If the cached statistic can be loaded from the given file, tally
80
+ will not set up the data loader and instead will return a fully
81
+ loaded statistic object (on the cpu device) and an empty list as
82
+ the loader.
83
+
84
+ The `with cache_load_enabled(False):` context manager can
85
+ be used to disable loading from the cache.
86
+
87
+ If needed, a DataLoader will be created to wrap the dataset:
88
+
89
+ Keyword arguments of tally are passed to the DataLoader,
90
+ so batch_size, num_workers, pin_memory, etc. can be specified.
91
+
92
+ Subsampling is supported via sample_size= and random_sample=:
93
+
94
+ If sample_size=N is specified, rather than loading the whole
95
+ dataset, only the first N items are sampled. If additionally
96
+ random_sample=S is specified, the pseudorandom seed S will be
97
+ used to select a fixed pseudorandom sample of size N.
98
+ """
99
+ assert isinstance(stat, Stat)
100
+ args = {}
101
+ for k in ["sample_size"]:
102
+ if k in kwargs:
103
+ args[k] = kwargs[k]
104
+ cached_state = load_cached_state(cache, args, quiet=quiet)
105
+ if cached_state is not None:
106
+ stat.load_state_dict(cached_state)
107
+
108
+ def empty_loader():
109
+ return
110
+ yield
111
+
112
+ return empty_loader()
113
+ loader = make_loader(dataset, **kwargs)
114
+
115
+ def wrapped_loader():
116
+ yield from loader
117
+ stat.to_(device="cpu")
118
+ if cache is not None:
119
+ save_cached_state(cache, stat, args)
120
+
121
+ return wrapped_loader()
122
+
123
+
124
+ class cache_load_enabled:
125
+ """
126
+ When used as a context manager, cache_load_enabled(False) will prevent
127
+ tally from loading cached statistics, forcing them to be recomputed.
128
+ """
129
+
130
+ def __init__(self, enabled=True):
131
+ self.prev = False
132
+ self.enabled = enabled
133
+
134
+ def __enter__(self):
135
+ global global_load_cache_enabled
136
+ self.prev = global_load_cache_enabled
137
+ global_load_cache_enabled = self.enabled
138
+
139
+ def __exit__(self, exc_type, exc_value, traceback):
140
+ global global_load_cache_enabled
141
+ global_load_cache_enabled = self.prev
142
+
143
+
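# Sketch of forcing recomputation even when the cache file already exists,
# following the tally() docstring above (the file name is illustrative):
#
#     stat = Mean()
#     with cache_load_enabled(False):
#         for batch in tally(stat, ds, cache="mymean.npz", batch_size=50):
#             stat.add(batch)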
144
+ class Stat:
145
+ """
146
+ Abstract base class for a running pytorch statistic.
147
+ """
148
+
149
+ def __init__(self, state):
150
+ """
151
+ By convention, all Stat subclasses can be initialized by passing
152
+ state=; and then they will initialize by calling load_state_dict.
153
+ """
154
+ self.load_state_dict(resolve_state_dict(state))
155
+
156
+ def add(self, x, *args, **kwargs):
157
+ """
158
+ Observes a batch of samples to be incorporated into the statistic.
159
+ Dimension 0 should be the batch dimension, and dimension 1 should
160
+ be the feature dimension of the pytorch tensor x.
161
+ """
162
+ pass
163
+
164
+ def load_state_dict(self, d):
165
+ """
166
+ Loads this Stat from a dictionary of numpy arrays as saved
167
+ by state_dict.
168
+ """
169
+ pass
170
+
171
+ def state_dict(self):
172
+ """
173
+ Saves this Stat as a dictionary of numpy arrays that can be
174
+ stored in an npz or reloaded later using load_state_dict.
175
+ """
176
+ return {}
177
+
178
+ def save(self, filename):
179
+ """
180
+ Saves this stat as an npz file containing the state_dict.
181
+ """
182
+ save_cached_state(filename, self, {})
183
+
184
+ def load(self, filename):
185
+ """
186
+ Loads this stat from an npz file containing a saved state_dict.
187
+ """
188
+ self.load_state_dict(load_cached_state(filename, {}, quiet=True, throw=True))
189
+
190
+ def to_(self, device):
191
+ """
192
+ Moves this Stat to the given device.
193
+ """
194
+ pass
195
+
196
+ def cpu_(self):
197
+ """
198
+ Moves this Stat to the cpu device.
199
+ """
200
+ self.to_("cpu")
201
+
202
+ def cuda_(self):
203
+ """
204
+ Moves this Stat to the default cuda device.
205
+ """
206
+ self.to_("cuda")
207
+
208
+ def _normalize_add_shape(self, x, attr="data_shape"):
209
+ """
210
+ Flattens input data to 2d.
211
+ """
212
+ if not torch.is_tensor(x):
213
+ x = torch.tensor(x)
214
+ if len(x.shape) < 1:
215
+ x = x.view(-1)
216
+ data_shape = getattr(self, attr, None)
217
+ if data_shape is None:
218
+ data_shape = x.shape[1:]
219
+ setattr(self, attr, data_shape)
220
+ else:
221
+ assert x.shape[1:] == data_shape
222
+ return x.view(x.shape[0], int(numpy.prod(data_shape)))
223
+
224
+ def _restore_result_shape(self, x, attr="data_shape"):
225
+ """
226
+ Restores output data to input data shape.
227
+ """
228
+ data_shape = getattr(self, attr, None)
229
+ if data_shape is None:
230
+ return x
231
+ return x.view(data_shape * len(x.shape))
232
+
233
+
234
+ class Mean(Stat):
235
+ """
236
+ Running mean.
237
+ """
238
+
239
+ def __init__(self, state=None):
240
+ if state is not None:
241
+ return super().__init__(state)
242
+ self.count = 0
243
+ self.batchcount = 0
244
+ self._mean = None
245
+ self.data_shape = None
246
+
247
+ def add(self, a):
248
+ a = self._normalize_add_shape(a)
249
+ if len(a) == 0:
250
+ return
251
+ batch_count = a.shape[0]
252
+ batch_mean = a.sum(0) / batch_count
253
+ self.batchcount += 1
254
+ # Initial batch.
255
+ if self._mean is None:
256
+ self.count = batch_count
257
+ self._mean = batch_mean
258
+ return
259
+ # Update a batch using Chan-style update for numerical stability.
260
+ self.count += batch_count
261
+ new_frac = float(batch_count) / self.count
262
+ # Update the mean according to the batch deviation from the old mean.
263
+ delta = batch_mean.sub_(self._mean).mul_(new_frac)
264
+ self._mean.add_(delta)
265
+
266
+ def size(self):
267
+ return self.count
268
+
269
+ def mean(self):
270
+ return self._restore_result_shape(self._mean)
271
+
272
+ def to_(self, device):
273
+ if self._mean is not None:
274
+ self._mean = self._mean.to(device)
275
+
276
+ def load_state_dict(self, state):
277
+ self.count = state["count"]
278
+ self.batchcount = state["batchcount"]
279
+ self._mean = torch.from_numpy(state["mean"])
280
+ self.data_shape = (
281
+ None if state["data_shape"] is None else tuple(state["data_shape"])
282
+ )
283
+
284
+ def state_dict(self):
285
+ return dict(
286
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
287
+ count=self.count,
288
+ data_shape=self.data_shape and tuple(self.data_shape),
289
+ batchcount=self.batchcount,
290
+ mean=self._mean.cpu().numpy(),
291
+ )
292
+
293
+
294
+ class NormMean(Mean):
295
+ """
296
+ Running average of the norm of input vectors
297
+ """
298
+
299
+ def __init__(self, state=None):
300
+ super().__init__(state)
301
+
302
+ def add(self, a):
303
+ super().add(a.norm(dim=-1))
304
+
305
+
306
+ class Variance(Stat):
307
+ """
308
+ Running computation of mean and variance. Use this when you just need
309
+ basic stats without covariance.
310
+ """
311
+
312
+ def __init__(self, state=None):
313
+ if state is not None:
314
+ return super().__init__(state)
315
+ self.count = 0
316
+ self.batchcount = 0
317
+ self._mean = None
318
+ self.v_cmom2 = None
319
+ self.data_shape = None
320
+
321
+ def add(self, a):
322
+ a = self._normalize_add_shape(a)
323
+ if len(a) == 0:
324
+ return
325
+ batch_count = a.shape[0]
326
+ batch_mean = a.sum(0) / batch_count
327
+ centered = a - batch_mean
328
+ self.batchcount += 1
329
+ # Initial batch.
330
+ if self._mean is None:
331
+ self.count = batch_count
332
+ self._mean = batch_mean
333
+ self.v_cmom2 = centered.pow(2).sum(0)
334
+ return
335
+ # Update a batch using Chan-style update for numerical stability.
336
+ oldcount = self.count
337
+ self.count += batch_count
338
+ new_frac = float(batch_count) / self.count
339
+ # Update the mean according to the batch deviation from the old mean.
340
+ delta = batch_mean.sub_(self._mean).mul_(new_frac)
341
+ self._mean.add_(delta)
342
+ # Update the variance using the batch deviation
343
+ self.v_cmom2.add_(centered.pow(2).sum(0))
344
+ self.v_cmom2.add_(delta.pow_(2).mul_(self.count * oldcount / batch_count))
345
+
346
+ def size(self):
347
+ return self.count
348
+
349
+ def mean(self):
350
+ return self._restore_result_shape(self._mean)
351
+
352
+ def variance(self, unbiased=True):
353
+ return self._restore_result_shape(
354
+ self.v_cmom2 / (self.count - (1 if unbiased else 0))
355
+ )
356
+
357
+ def stdev(self, unbiased=True):
358
+ return self.variance(unbiased=unbiased).sqrt()
359
+
360
+ def to_(self, device):
361
+ if self._mean is not None:
362
+ self._mean = self._mean.to(device)
363
+ if self.v_cmom2 is not None:
364
+ self.v_cmom2 = self.v_cmom2.to(device)
365
+
366
+ def load_state_dict(self, state):
367
+ self.count = state["count"]
368
+ self.batchcount = state["batchcount"]
369
+ self._mean = torch.from_numpy(state["mean"])
370
+ self.v_cmom2 = torch.from_numpy(state["cmom2"])
371
+ self.data_shape = (
372
+ None if state["data_shape"] is None else tuple(state["data_shape"])
373
+ )
374
+
375
+ def state_dict(self):
376
+ return dict(
377
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
378
+ count=self.count,
379
+ data_shape=self.data_shape and tuple(self.data_shape),
380
+ batchcount=self.batchcount,
381
+ mean=self._mean.cpu().numpy(),
382
+ cmom2=self.v_cmom2.cpu().numpy(),
383
+ )
384
+
385
+
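# For reference, the Chan et al. (1983) pairwise update that Variance (above) and
# Covariance (below) apply per batch: given an existing aggregate (n_a, mu_a, M2_a)
# and a new batch (n_b, mu_b, M2_b), with delta = mu_b - mu_a,
#
#     n  = n_a + n_b
#     mu = mu_a + delta * n_b / n
#     M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
#
# where M2 is the running sum of squared deviations, so variance = M2 / (n - 1).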
386
+ class Covariance(Stat):
387
+ """
388
+ Running computation. Use this when the entire covariance matrix is needed,
389
+ and when the whole covariance matrix fits in the GPU.
390
+
391
+ Chan-style numerically stable update of mean and full covariance matrix.
392
+ Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
393
+ """
394
+
395
+ def __init__(self, state=None):
396
+ if state is not None:
397
+ return super().__init__(state)
398
+ self.count = 0
399
+ self._mean = None
400
+ self.cmom2 = None
401
+ self.data_shape = None
402
+
403
+ def add(self, a):
404
+ a = self._normalize_add_shape(a)
405
+ if len(a) == 0:
406
+ return
407
+ batch_count = a.shape[0]
408
+ # Initial batch.
409
+ if self._mean is None:
410
+ self.count = batch_count
411
+ self._mean = a.sum(0) / batch_count
412
+ centered = a - self._mean
413
+ self.cmom2 = centered.t().mm(centered)
414
+ return
415
+ # Update a batch using Chan-style update for numerical stability.
416
+ self.count += batch_count
417
+ # Update the mean according to the batch deviation from the old mean.
418
+ delta = a - self._mean
419
+ self._mean.add_(delta.sum(0) / self.count)
420
+ delta2 = a - self._mean
421
+ # Update the variance using the batch deviation
422
+ self.cmom2.addmm_(mat1=delta.t(), mat2=delta2)
423
+
424
+ def to_(self, device):
425
+ if self._mean is not None:
426
+ self._mean = self._mean.to(device)
427
+ if self.cmom2 is not None:
428
+ self.cmom2 = self.cmom2.to(device)
429
+
430
+ def mean(self):
431
+ return self._restore_result_shape(self._mean)
432
+
433
+ def covariance(self, unbiased=True):
434
+ return self._restore_result_shape(
435
+ self.cmom2 / (self.count - (1 if unbiased else 0))
436
+ )
437
+
438
+ def correlation(self, unbiased=True):
439
+ cov = self.cmom2 / (self.count - (1 if unbiased else 0))
440
+ rstdev = cov.diag().sqrt().reciprocal()
441
+ return self._restore_result_shape(rstdev[:, None] * cov * rstdev[None, :])
442
+
443
+ def variance(self, unbiased=True):
444
+ return self._restore_result_shape(
445
+ self.cmom2.diag() / (self.count - (1 if unbiased else 0))
446
+ )
447
+
448
+ def stdev(self, unbiased=True):
449
+ return self.variance(unbiased=unbiased).sqrt()
450
+
451
+ def state_dict(self):
452
+ return dict(
453
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
454
+ count=self.count,
455
+ data_shape=self.data_shape and tuple(self.data_shape),
456
+ mean=self._mean.cpu().numpy(),
457
+ cmom2=self.cmom2.cpu().numpy(),
458
+ )
459
+
460
+ def load_state_dict(self, state):
461
+ self.count = state["count"]
462
+ self._mean = torch.from_numpy(state["mean"])
463
+ self.cmom2 = torch.from_numpy(state["cmom2"])
464
+ self.data_shape = (
465
+ None if state["data_shape"] is None else tuple(state["data_shape"])
466
+ )
467
+
468
+
469
+ class SecondMoment(Stat):
470
+ """
471
+ Running computation. Use this when the entire non-centered 2nd-moment
472
+ 'covariance-like' matrix is needed, and when the whole matrix fits
473
+ in the GPU.
474
+ """
475
+
476
+ def __init__(self, split_batch=True, state=None):
477
+ if state is not None:
478
+ return super().__init__(state)
479
+ self.count = 0
480
+ self.mom2 = None
481
+ self.split_batch = split_batch
482
+
483
+ def add(self, a):
484
+ a = self._normalize_add_shape(a)
485
+ if len(a) == 0:
486
+ return
487
+ # Initial batch reveals the shape of the data.
488
+ if self.count == 0:
489
+ self.mom2 = a.new(a.shape[1], a.shape[1]).zero_()
490
+ batch_count = a.shape[0]
491
+ # Update the covariance using the batch deviation
492
+ self.count += batch_count
493
+ self.mom2 += a.t().mm(a)
494
+
495
+ def to_(self, device):
496
+ if self.mom2 is not None:
497
+ self.mom2 = self.mom2.to(device)
498
+
499
+ def moment(self):
500
+ return self.mom2 / self.count
501
+
502
+ def state_dict(self):
503
+ return dict(
504
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
505
+ count=self.count,
506
+ mom2=self.mom2.cpu().numpy(),
507
+ )
508
+
509
+ def load_state_dict(self, state):
510
+ self.count = int(state["count"])
511
+ self.mom2 = torch.from_numpy(state["mom2"])
512
+
513
+
514
+ class Bincount(Stat):
515
+ """
516
+ Running bincount. The counted array should be an integer type with
517
+ non-negative integers.
518
+ """
519
+
520
+ def __init__(self, state=None):
521
+ if state is not None:
522
+ return super().__init__(state)
523
+ self.count = 0
524
+ self._bincount = None
525
+
526
+ def add(self, a, size=None):
527
+ a = a.view(-1)
528
+ bincount = a.bincount()
529
+ if self._bincount is None:
530
+ self._bincount = bincount
531
+ elif len(self._bincount) < len(bincount):
532
+ bincount[: len(self._bincount)] += self._bincount
533
+ self._bincount = bincount
534
+ else:
535
+ self._bincount[: len(bincount)] += bincount
536
+ if size is None:
537
+ self.count += len(a)
538
+ else:
539
+ self.count += size
540
+
541
+ def to_(self, device):
542
+ self._bincount = self._bincount.to(device)
543
+
544
+ def size(self):
545
+ return self.count
546
+
547
+ def bincount(self):
548
+ return self._bincount
549
+
550
+ def state_dict(self):
551
+ return dict(
552
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
553
+ count=self.count,
554
+ bincount=self._bincount.cpu().numpy(),
555
+ )
556
+
557
+ def load_state_dict(self, dic):
558
+ self.count = int(dic["count"])
559
+ self._bincount = torch.from_numpy(dic["bincount"])
560
+
561
+
562
+ class CrossCovariance(Stat):
563
+ """
564
+ Covariance. Use this when an off-diagonal block of the covariance
565
+ matrix is needed (e.g., when the whole covariance matrix does
566
+ not fit in the GPU, this could use a quarter of the memory).
567
+
568
+ Chan-style numerically stable update of mean and full covariance matrix.
569
+ Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
570
+ """
571
+
572
+ def __init__(self, split_batch=True, state=None):
573
+ if state is not None:
574
+ return super().__init__(state)
575
+ self.count = 0
576
+ self._mean = None
577
+ self.cmom2 = None
578
+ self.v_cmom2 = None
579
+ self.split_batch = split_batch
580
+
581
+ def add(self, a, b):
582
+ if len(a.shape) == 1:
583
+ a = a[None, :]
584
+ b = b[None, :]
585
+ assert a.shape[0] == b.shape[0]
586
+ if len(a.shape) > 2:
587
+ a, b = [
588
+ d.view(d.shape[0], d.shape[1], -1)
589
+ .permute(0, 2, 1)
590
+ .reshape(-1, d.shape[1])
591
+ for d in [a, b]
592
+ ]
593
+ batch_count = a.shape[0]
594
+ # Initial batch.
595
+ if self._mean is None:
596
+ self.count = batch_count
597
+ self._mean = [d.sum(0) / batch_count for d in [a, b]]
598
+ centered = [d - bm for d, bm in zip([a, b], self._mean)]
599
+ self.v_cmom2 = [c.pow(2).sum(0) for c in centered]
600
+ self.cmom2 = centered[0].t().mm(centered[1])
601
+ return
602
+ # Update a batch using Chan-style update for numerical stability.
603
+ self.count += batch_count
604
+ # Update the mean according to the batch deviation from the old mean.
605
+ delta = [(d - bm) for d, bm in zip([a, b], self._mean)]
606
+ for m, d in zip(self._mean, delta):
607
+ m.add_(d.sum(0) / self.count)
608
+ delta2 = [(d - bm) for d, bm in zip([a, b], self._mean)]
609
+ # Update the cross-covariance using the batch deviation
610
+ self.cmom2.addmm_(mat1=delta[0].t(), mat2=delta2[1])
611
+ # Update the variance using the batch deviation
612
+ for vc2, d, d2 in zip(self.v_cmom2, delta, delta2):
613
+ vc2.add_((d * d2).sum(0))
614
+
615
+ def mean(self):
616
+ return self._mean
617
+
618
+ def variance(self, unbiased=True):
619
+ return [vc2 / (self.count - (1 if unbiased else 0)) for vc2 in self.v_cmom2]
620
+
621
+ def stdev(self, unbiased=True):
622
+ return [v.sqrt() for v in self.variance(unbiased=unbiased)]
623
+
624
+ def covariance(self, unbiased=True):
625
+ return self.cmom2 / (self.count - (1 if unbiased else 0))
626
+
627
+ def correlation(self):
628
+ covariance = self.covariance(unbiased=False)
629
+ rstdev = [s.reciprocal() for s in self.stdev(unbiased=False)]
630
+ cor = rstdev[0][:, None] * covariance * rstdev[1][None, :]
631
+ # Remove NaNs
632
+ cor[torch.isnan(cor)] = 0
633
+ return cor
634
+
635
+ def to_(self, device):
636
+ self._mean = [m.to(device) for m in self._mean]
637
+ self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2]
638
+ self.cmom2 = self.cmom2.to(device)
639
+
640
+ def state_dict(self):
641
+ return dict(
642
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
643
+ count=self.count,
644
+ mean_a=self._mean[0].cpu().numpy(),
645
+ mean_b=self._mean[1].cpu().numpy(),
646
+ cmom2_a=self.v_cmom2[0].cpu().numpy(),
647
+ cmom2_b=self.v_cmom2[1].cpu().numpy(),
648
+ cmom2=self.cmom2.cpu().numpy(),
649
+ )
650
+
651
+ def load_state_dict(self, state):
652
+ self.count = int(state["count"])
653
+ self._mean = [torch.from_numpy(state[f"mean_{k}"]) for k in "ab"]
654
+ self.v_cmom2 = [torch.from_numpy(state[f"cmom2_{k}"]) for k in "ab"]
655
+ self.cmom2 = torch.from_numpy(state["cmom2"])
656
+
657
+
658
+ def _float_from_bool(a):
659
+ """
660
+ Since pytorch only supports matrix multiplication on float,
661
+ IoU computations are done using floating point types.
662
+
663
+ This function binarizes the input (positive to True and
664
+ nonpositive to False), and converts from bool to float.
665
+ If the data is already a floating-point type,
666
+ it keeps the same type; otherwise it uses float.
667
+ """
668
+ if a.dtype == torch.bool:
669
+ return a.float()
670
+ if a.dtype.is_floating_point:
671
+ return a.sign().clamp_(0)
672
+ return (a > 0).float()
673
+
674
+
675
+ class IoU(Stat):
676
+ """
677
+ Running computation of intersections and unions of all features.
678
+ """
679
+
680
+ def __init__(self, state=None):
681
+ if state is not None:
682
+ return super().__init__(state)
683
+ self.count = 0
684
+ self._intersection = None
685
+
686
+ def add(self, a):
687
+ assert len(a.shape) == 2
688
+ a = _float_from_bool(a)
689
+ if self._intersection is None:
690
+ self._intersection = torch.mm(a.t(), a)
691
+ else:
692
+ self._intersection.addmm_(a.t(), a)
693
+ self.count += len(a)
694
+
695
+ def size(self):
696
+ return self.count
697
+
698
+ def intersection(self):
699
+ return self._intersection
700
+
701
+ def union(self):
702
+ total = self._intersection.diagonal(0)
703
+ return total[:, None] + total[None, :] - self._intersection
704
+
705
+ def iou(self):
706
+ return self.intersection() / (self.union() + 1e-20)
707
+
708
+ def to_(self, _device):
709
+ self._intersection = self._intersection.to(_device)
710
+
711
+ def state_dict(self):
712
+ return dict(
713
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
714
+ count=self.count,
715
+ intersection=self._intersection.cpu().numpy(),
716
+ )
717
+
718
+ def load_state_dict(self, state):
719
+ self.count = int(state["count"])
720
+ self._intersection = torch.tensor(state["intersection"])
721
+
722
+
723
+ class CrossIoU(Stat):
724
+ """
725
+ Running computation of intersections and unions of two binary vectors.
726
+ """
727
+
728
+ def __init__(self, state=None):
729
+ if state is not None:
730
+ return super().__init__(state)
731
+ self.count = 0
732
+ self._intersection = None
733
+ self.total_a = None
734
+ self.total_b = None
735
+
736
+ def add(self, a, b):
737
+ assert len(a.shape) == 2 and len(b.shape) == 2
738
+ assert len(a) == len(b), f"{len(a)} vs {len(b)}"
739
+ a = _float_from_bool(a) # CUDA only supports mm on float...
740
+ b = _float_from_bool(b) # otherwise we would use integers.
741
+ intersection = torch.mm(a.t(), b)
742
+ asum = a.sum(0)
743
+ bsum = b.sum(0)
744
+ if self._intersection is None:
745
+ self._intersection = intersection
746
+ self.total_a = asum
747
+ self.total_b = bsum
748
+ else:
749
+ self._intersection += intersection
750
+ self.total_a += asum
751
+ self.total_b += bsum
752
+ self.count += len(a)
753
+
754
+ def size(self):
755
+ return self.count
756
+
757
+ def intersection(self):
758
+ return self._intersection
759
+
760
+ def union(self):
761
+ return self.total_a[:, None] + self.total_b[None, :] - self._intersection
762
+
763
+ def iou(self):
764
+ return self.intersection() / (self.union() + 1e-20)
765
+
766
+ def to_(self, _device):
767
+ self.total_a = self.total_a.to(_device)
768
+ self.total_b = self.total_b.to(_device)
769
+ self._intersection = self._intersection.to(_device)
770
+
771
+ def state_dict(self):
772
+ return dict(
773
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
774
+ count=self.count,
775
+ total_a=self.total_a.cpu().numpy(),
776
+ total_b=self.total_b.cpu().numpy(),
777
+ intersection=self._intersection.cpu().numpy(),
778
+ )
779
+
780
+ def load_state_dict(self, state):
781
+ self.count = int(state["count"])
782
+ self.total_a = torch.tensor(state["total_a"])
783
+ self.total_b = torch.tensor(state["total_b"])
784
+ self._intersection = torch.tensor(state["intersection"])
785
+
786
+
787
+ class Quantile(Stat):
788
+ """
789
+ Streaming randomized quantile computation for torch.
790
+
791
+ Add any amount of data repeatedly via add(data). At any time,
792
+ quantile estimates can be read out using quantiles(q).
793
+
794
+ Implemented as a sorted sample that retains at least r samples
795
+ (by default r = 3072); the number of retained samples will grow to
796
+ a finite ceiling as the data is accumulated. Accuracy scales according
797
+ to r: the default is to set resolution to be accurate to better than about
798
+ 0.1%, while limiting storage to about 50,000 samples.
799
+
800
+ Good for computing quantiles of huge data without using much memory.
801
+ Works well on arbitrary data with probability near 1.
802
+
803
+ Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty
804
+ from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf
805
+ """
806
+
807
+ def __init__(self, r=3 * 1024, buffersize=None, seed=None, state=None):
808
+ if state is not None:
809
+ return super().__init__(state)
810
+ self.depth = None
811
+ self.dtype = None
812
+ self.device = None
813
+ resolution = r * 2 # sample array is at least half full before discard
814
+ self.resolution = resolution
815
+ # Default buffersize: 128 samples (and smaller than resolution).
816
+ if buffersize is None:
817
+ buffersize = min(128, (resolution + 7) // 8)
818
+ self.buffersize = buffersize
819
+ self.samplerate = 1.0
820
+ self.data = None
821
+ self.firstfree = [0]
822
+ self.randbits = torch.ByteTensor(resolution)
823
+ self.currentbit = len(self.randbits) - 1
824
+ self.extremes = None
825
+ self.count = 0
826
+ self.batchcount = 0
827
+
828
+ def size(self):
829
+ return self.count
830
+
831
+ def _lazy_init(self, incoming):
832
+ self.depth = incoming.shape[1]
833
+ self.dtype = incoming.dtype
834
+ self.device = incoming.device
835
+ self.data = [
836
+ torch.zeros(
837
+ self.depth, self.resolution, dtype=self.dtype, device=self.device
838
+ )
839
+ ]
840
+ self.extremes = torch.zeros(self.depth, 2, dtype=self.dtype, device=self.device)
841
+ self.extremes[:, 0] = float("inf")
842
+ self.extremes[:, -1] = -float("inf")
843
+
844
+ def to_(self, device):
845
+ """Switches internal storage to specified device."""
846
+ if device != self.device:
847
+ old_data = self.data
848
+ old_extremes = self.extremes
849
+ self.data = [d.to(device) for d in self.data]
850
+ self.extremes = self.extremes.to(device)
851
+ self.device = self.extremes.device
852
+ del old_data
853
+ del old_extremes
854
+
855
+ def add(self, incoming):
856
+ if self.depth is None:
857
+ self._lazy_init(incoming)
858
+ assert len(incoming.shape) == 2
859
+ assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth)
860
+ self.count += incoming.shape[0]
861
+ self.batchcount += 1
862
+ # Convert to a flat torch array.
863
+ if self.samplerate >= 1.0:
864
+ self._add_every(incoming)
865
+ return
866
+ # If we are sampling, then subsample a large chunk at a time.
867
+ self._scan_extremes(incoming)
868
+ chunksize = int(math.ceil(self.buffersize / self.samplerate))
869
+ for index in range(0, len(incoming), chunksize):
870
+ batch = incoming[index : index + chunksize]
871
+ sample = sample_portion(batch, self.samplerate)
872
+ if len(sample):
873
+ self._add_every(sample)
874
+
875
+ def _add_every(self, incoming):
876
+ supplied = len(incoming)
877
+ index = 0
878
+ while index < supplied:
879
+ ff = self.firstfree[0]
880
+ available = self.data[0].shape[1] - ff
881
+ if available == 0:
882
+ if not self._shift():
883
+ # If we shifted by subsampling, then subsample.
884
+ incoming = incoming[index:]
885
+ if self.samplerate >= 0.5:
886
+ # First time sampling - the data source is very large.
887
+ self._scan_extremes(incoming)
888
+ incoming = sample_portion(incoming, self.samplerate)
889
+ index = 0
890
+ supplied = len(incoming)
891
+ ff = self.firstfree[0]
892
+ available = self.data[0].shape[1] - ff
893
+ copycount = min(available, supplied - index)
894
+ self.data[0][:, ff : ff + copycount] = torch.t(
895
+ incoming[index : index + copycount, :]
896
+ )
897
+ self.firstfree[0] += copycount
898
+ index += copycount
899
+
900
+ def _shift(self):
901
+ index = 0
902
+ # If remaining space at the current layer is less than half prev
903
+ # buffer size (rounding up), then we need to shift it up to ensure
904
+ # enough space for future shifting.
905
+ while self.data[index].shape[1] - self.firstfree[index] < (
906
+ -(-self.data[index - 1].shape[1] // 2) if index else 1
907
+ ):
908
+ if index + 1 >= len(self.data):
909
+ return self._expand()
910
+ data = self.data[index][:, 0 : self.firstfree[index]]
911
+ data = data.sort()[0]
912
+ if index == 0 and self.samplerate >= 1.0:
913
+ self._update_extremes(data[:, 0], data[:, -1])
914
+ offset = self._randbit()
915
+ position = self.firstfree[index + 1]
916
+ subset = data[:, offset::2]
917
+ self.data[index + 1][:, position : position + subset.shape[1]] = subset
918
+ self.firstfree[index] = 0
919
+ self.firstfree[index + 1] += subset.shape[1]
920
+ index += 1
921
+ return True
922
+
923
+ def _scan_extremes(self, incoming):
924
+ # When sampling, we need to scan every item still to get extremes
925
+ self._update_extremes(
926
+ torch.min(incoming, dim=0)[0], torch.max(incoming, dim=0)[0]
927
+ )
928
+
929
+ def _update_extremes(self, minr, maxr):
930
+ self.extremes[:, 0] = torch.min(
931
+ torch.stack([self.extremes[:, 0], minr]), dim=0
932
+ )[0]
933
+ self.extremes[:, -1] = torch.max(
934
+ torch.stack([self.extremes[:, -1], maxr]), dim=0
935
+ )[0]
936
+
937
+ def _randbit(self):
938
+ self.currentbit += 1
939
+ if self.currentbit >= len(self.randbits):
940
+ self.randbits.random_(to=2)
941
+ self.currentbit = 0
942
+ return self.randbits[self.currentbit]
943
+
944
+ def state_dict(self):
945
+ state = dict(
946
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
947
+ resolution=self.resolution,
948
+ depth=self.depth,
949
+ buffersize=self.buffersize,
950
+ samplerate=self.samplerate,
951
+ sizes=numpy.array([d.shape[1] for d in self.data]),
952
+ extremes=self.extremes.cpu().detach().numpy(),
953
+ size=self.count,
954
+ batchcount=self.batchcount,
955
+ )
956
+ for i, (d, f) in enumerate(zip(self.data, self.firstfree)):
957
+ state[f"data.{i}"] = d.cpu().detach().numpy()[:, :f].T
958
+ return state
959
+
960
+ def load_state_dict(self, state):
961
+ self.resolution = int(state["resolution"])
962
+ self.randbits = torch.ByteTensor(self.resolution)
963
+ self.currentbit = len(self.randbits) - 1
964
+ self.depth = int(state["depth"])
965
+ self.buffersize = int(state["buffersize"])
966
+ self.samplerate = float(state["samplerate"])
967
+ firstfree = []
968
+ buffers = []
969
+ for i, s in enumerate(state["sizes"]):
970
+ d = state[f"data.{i}"]
971
+ firstfree.append(d.shape[0])
972
+ buf = numpy.zeros((d.shape[1], s), dtype=d.dtype)
973
+ buf[:, : d.shape[0]] = d.T
974
+ buffers.append(torch.from_numpy(buf))
975
+ self.firstfree = firstfree
976
+ self.data = buffers
977
+ self.extremes = torch.from_numpy((state["extremes"]))
978
+ self.count = int(state["size"])
979
+ self.batchcount = int(state.get("batchcount", 0))
980
+ self.dtype = self.extremes.dtype
981
+ self.device = self.extremes.device
982
+
983
+ def min(self):
984
+ return self.minmax()[0]
985
+
986
+ def max(self):
987
+ return self.minmax()[-1]
988
+
989
+ def minmax(self):
990
+ if self.firstfree[0]:
991
+ self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
992
+ return self.extremes.clone()
993
+
994
+ def median(self):
995
+ return self.quantiles(0.5)
996
+
997
+ def mean(self):
998
+ return self.integrate(lambda x: x) / self.count
999
+
1000
+ def variance(self, unbiased=True):
1001
+ mean = self.mean()[:, None]
1002
+ return self.integrate(lambda x: (x - mean).pow(2)) / (
1003
+ self.count - (1 if unbiased else 0)
1004
+ )
1005
+
1006
+ def stdev(self, unbiased=True):
1007
+ return self.variance(unbiased=unbiased).sqrt()
1008
+
1009
+ def _expand(self):
1010
+ cap = self._next_capacity()
1011
+ if cap > 0:
1012
+ # First, make a new layer of the proper capacity.
1013
+ self.data.insert(
1014
+ 0, torch.zeros(self.depth, cap, dtype=self.dtype, device=self.device)
1015
+ )
1016
+ self.firstfree.insert(0, 0)
1017
+ else:
1018
+ # Unless we're so big we are just subsampling.
1019
+ assert self.firstfree[0] == 0
1020
+ self.samplerate *= 0.5
1021
+ for index in range(1, len(self.data)):
1022
+ # Scan for existing data that needs to be moved down a level.
1023
+ amount = self.firstfree[index]
1024
+ if amount == 0:
1025
+ continue
1026
+ position = self.firstfree[index - 1]
1027
+ # Move data down if it would leave enough empty space there
1028
+ # This is the key invariant: enough empty space to fit half
1029
+ # of the previous level's buffer size (rounding up)
1030
+ if self.data[index - 1].shape[1] - (amount + position) >= (
1031
+ -(-self.data[index - 2].shape[1] // 2) if (index - 1) else 1
1032
+ ):
1033
+ self.data[index - 1][:, position : position + amount] = self.data[
1034
+ index
1035
+ ][:, :amount]
1036
+ self.firstfree[index - 1] += amount
1037
+ self.firstfree[index] = 0
1038
+ else:
1039
+ # Scrunch the data if it would not.
1040
+ data = self.data[index][:, :amount]
1041
+ data = data.sort()[0]
1042
+ if index == 1:
1043
+ self._update_extremes(data[:, 0], data[:, -1])
1044
+ offset = self._randbit()
1045
+ scrunched = data[:, offset::2]
1046
+ self.data[index][:, : scrunched.shape[1]] = scrunched
1047
+ self.firstfree[index] = scrunched.shape[1]
1048
+ return cap > 0
1049
+
1050
+ def _next_capacity(self):
1051
+ cap = int(math.ceil(self.resolution * (0.67 ** len(self.data))))
1052
+ if cap < 2:
1053
+ return 0
1054
+ # Round up to the nearest multiple of 8 for better GPU alignment.
1055
+ cap = -8 * (-cap // 8)
1056
+ return max(self.buffersize, cap)
1057
+
1058
+ def _weighted_summary(self, sort=True):
1059
+ if self.firstfree[0]:
1060
+ self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
1061
+ size = sum(self.firstfree)
1062
+ weights = torch.FloatTensor(size) # Floating point
1063
+ summary = torch.zeros(self.depth, size, dtype=self.dtype, device=self.device)
1064
+ index = 0
1065
+ for level, ff in enumerate(self.firstfree):
1066
+ if ff == 0:
1067
+ continue
1068
+ summary[:, index : index + ff] = self.data[level][:, :ff]
1069
+ weights[index : index + ff] = 2.0**level
1070
+ index += ff
1071
+ assert index == summary.shape[1]
1072
+ if sort:
1073
+ summary, order = torch.sort(summary, dim=-1)
1074
+ weights = weights[order.view(-1).cpu()].view(order.shape)
1075
+ summary = torch.cat(
1076
+ [self.extremes[:, :1], summary, self.extremes[:, 1:]], dim=-1
1077
+ )
1078
+ weights = torch.cat(
1079
+ [
1080
+ torch.zeros(weights.shape[0], 1),
1081
+ weights,
1082
+ torch.zeros(weights.shape[0], 1),
1083
+ ],
1084
+ dim=-1,
1085
+ )
1086
+ return (summary, weights)
1087
+
1088
+ def quantiles(self, quantiles):
1089
+ if not hasattr(quantiles, "cpu"):
1090
+ quantiles = torch.tensor(quantiles)
1091
+ qshape = quantiles.shape
1092
+ if self.count == 0:
1093
+ return torch.full((self.depth,) + qshape, torch.nan)
1094
+ summary, weights = self._weighted_summary()
1095
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
1096
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
1097
+ result = torch.zeros(
1098
+ self.depth, quantiles.numel(), dtype=self.dtype, device=self.device
1099
+ )
1100
+ # numpy is needed for interpolation
1101
+ nq = quantiles.view(-1).cpu().detach().numpy()
1102
+ ncw = cumweights.cpu().detach().numpy()
1103
+ nsm = summary.cpu().detach().numpy()
1104
+ for d in range(self.depth):
1105
+ result[d] = torch.tensor(
1106
+ numpy.interp(nq, ncw[d], nsm[d]), dtype=self.dtype, device=self.device
1107
+ )
1108
+ return result.view((self.depth,) + qshape)
1109
+
1110
+ def integrate(self, fun):
1111
+ result = []
1112
+ for level, ff in enumerate(self.firstfree):
1113
+ if ff == 0:
1114
+ continue
1115
+ result.append(
1116
+ torch.sum(fun(self.data[level][:, :ff]) * (2.0**level), dim=-1)
1117
+ )
1118
+ if len(result) == 0:
1119
+ return None
1120
+ return torch.stack(result).sum(dim=0) / self.samplerate
1121
+
1122
+ def readout(self, count=1001):
1123
+ return self.quantiles(torch.linspace(0.0, 1.0, count))
1124
+
1125
+ def normalize(self, data):
1126
+ """
1127
+ Given input data as taken from the training distribution,
1128
+ normalizes every channel to reflect quantile values,
1129
+ uniformly distributed, within [0, 1].
1130
+ """
1131
+ assert self.count > 0
1132
+ assert data.shape[0] == self.depth
1133
+ summary, weights = self._weighted_summary()
1134
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
1135
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
1136
+ result = torch.zeros_like(data).float()
1137
+ # numpy is needed for interpolation
1138
+ ndata = data.cpu().numpy().reshape((data.shape[0], -1))
1139
+ ncw = cumweights.cpu().numpy()
1140
+ nsm = summary.cpu().numpy()
1141
+ for d in range(self.depth):
1142
+ normed = torch.tensor(
1143
+ numpy.interp(ndata[d], nsm[d], ncw[d]),
1144
+ dtype=torch.float,
1145
+ device=data.device,
1146
+ ).clamp_(0.0, 1.0)
1147
+ if len(data.shape) > 1:
1148
+ normed = normed.view(*(data.shape[1:]))
1149
+ result[d] = normed
1150
+ return result
1151
+
1152
+
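# A small usage sketch for Quantile, assuming 2-d float input
# (dim 0 = samples, dim 1 = channels):
#
#     q = Quantile(r=1024)
#     for batch in torch.randn(100, 64, 8):           # 100 batches of shape (64, 8)
#         q.add(batch)
#     med = q.median()                                 # per-channel median, shape (8,)
#     band = q.quantiles(torch.tensor([0.01, 0.99]))   # shape (8, 2)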
1153
+ def sample_portion(vec, p=0.5):
1154
+ """
1155
+ Subsamples a fraction (given by p) of the given batch. Used by
1156
+ Quantile when the data gets very very large.
1157
+ """
1158
+ bits = torch.bernoulli(
1159
+ torch.zeros(vec.shape[0], dtype=torch.uint8, device=vec.device), p
1160
+ )
1161
+ return vec[bits]
1162
+
1163
+
1164
+ class TopK:
1165
+ """
1166
+ A class to keep a running tally of the top k values (and indexes)
1167
+ of any number of torch feature components. Will work on the GPU if
1168
+ the data is on the GPU. Tracks largest by default, but tracks smallest
1169
+ if largest=False is passed.
1170
+
1171
+ This version flattens all arrays to avoid crashes.
1172
+ """
1173
+
1174
+ def __init__(self, k=100, largest=True, state=None):
1175
+ if state is not None:
1176
+ return super().__init__(state)
1177
+ self.k = k
1178
+ self.count = 0
1179
+ # This version flattens all data internally to 2-d tensors,
1180
+ # to avoid crashes with the current pytorch topk implementation.
1181
+ # The data is puffed back out to arbitrary tensor shapes on output.
1182
+ self.data_shape = None
1183
+ self.top_data = None
1184
+ self.top_index = None
1185
+ self.next = 0
1186
+ self.linear_index = 0
1187
+ self.perm = None
1188
+ self.largest = largest
1189
+
1190
+ def add(self, data, index=None):
1191
+ """
1192
+ Adds a batch of data to be considered for the running top k.
1193
+ The zeroth dimension enumerates the observations. All other
1194
+ dimensions enumerate different features.
1195
+ """
1196
+ if self.top_data is None:
1197
+ # Allocation: allocate a buffer of size 5*k, at least 10, for each.
1198
+ self.data_shape = data.shape[1:]
1199
+ feature_size = int(numpy.prod(self.data_shape))
1200
+ self.top_data = torch.zeros(
1201
+ feature_size, max(10, self.k * 5), out=data.new()
1202
+ )
1203
+ self.top_index = self.top_data.clone().long()
1204
+ self.linear_index = (
1205
+ 0
1206
+ if len(data.shape) == 1
1207
+ else torch.arange(feature_size, out=self.top_index.new()).mul_(
1208
+ self.top_data.shape[-1]
1209
+ )[:, None]
1210
+ )
1211
+ size = data.shape[0]
1212
+ sk = min(size, self.k)
1213
+ if self.top_data.shape[-1] < self.next + sk:
1214
+ # Compression: if full, keep topk only.
1215
+ self.top_data[:, : self.k], self.top_index[:, : self.k] = self.topk(
1216
+ sorted=False, flat=True
1217
+ )
1218
+ self.next = self.k
1219
+ # Pick: copy the top sk of the next batch into the buffer.
1220
+ # Currently strided topk is slow. So we clone after transpose.
1221
+ # TODO: remove the clone() if it becomes faster.
1222
+ cdata = data.reshape(size, numpy.prod(data.shape[1:])).t().clone()
1223
+ td, ti = cdata.topk(sk, sorted=False, largest=self.largest)
1224
+ self.top_data[:, self.next : self.next + sk] = td
1225
+ if index is not None:
1226
+ ti = index[ti]
1227
+ else:
1228
+ ti = ti + self.count
1229
+ self.top_index[:, self.next : self.next + sk] = ti
1230
+ self.next += sk
1231
+ self.count += size
1232
+
1233
+ def size(self):
1234
+ return self.count
1235
+
1236
+ def topk(self, sorted=True, flat=False):
1237
+ """
1238
+ Returns top k data items and indexes in each dimension,
1239
+ with channels in the first dimension and k in the last dimension.
1240
+ """
1241
+ k = min(self.k, self.next)
1242
+ # bti are top indexes relative to buffer array.
1243
+ td, bti = self.top_data[:, : self.next].topk(
1244
+ k, sorted=sorted, largest=self.largest
1245
+ )
1246
+ # we want to report top indexes globally, which is ti.
1247
+ ti = self.top_index.view(-1)[(bti + self.linear_index).view(-1)].view(
1248
+ *bti.shape
1249
+ )
1250
+ if flat:
1251
+ return td, ti
1252
+ else:
1253
+ return (
1254
+ td.view(*(self.data_shape + (-1,))),
1255
+ ti.view(*(self.data_shape + (-1,))),
1256
+ )
1257
+
1258
+ def to_(self, device):
1259
+ if self.top_data is not None:
1260
+ self.top_data = self.top_data.to(device)
1261
+ if self.top_index is not None:
1262
+ self.top_index = self.top_index.to(device)
1263
+ if isinstance(self.linear_index, torch.Tensor):
1264
+ self.linear_index = self.linear_index.to(device)
1265
+
1266
+ def state_dict(self):
1267
+ return dict(
1268
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
1269
+ k=self.k,
1270
+ count=self.count,
1271
+ largest=self.largest,
1272
+ data_shape=self.data_shape and tuple(self.data_shape),
1273
+ top_data=self.top_data.cpu().detach().numpy(),
1274
+ top_index=self.top_index.cpu().detach().numpy(),
1275
+ next=self.next,
1276
+ linear_index=(
1277
+ self.linear_index.cpu().numpy()
1278
+ if isinstance(self.linear_index, torch.Tensor)
1279
+ else self.linear_index
1280
+ ),
1281
+ perm=self.perm,
1282
+ )
1283
+
1284
+ def load_state_dict(self, state):
1285
+ self.k = int(state["k"])
1286
+ self.count = int(state["count"])
1287
+ self.largest = bool(state.get("largest", True))
1288
+ self.data_shape = (
1289
+ None if state["data_shape"] is None else tuple(state["data_shape"])
1290
+ )
1291
+ self.top_data = torch.from_numpy(state["top_data"])
1292
+ self.top_index = torch.from_numpy(state["top_index"])
1293
+ self.next = int(state["next"])
1294
+ self.linear_index = (
1295
+ torch.from_numpy(state["linear_index"])
1296
+ if len(state["linear_index"].shape) > 0
1297
+ else int(state["linear_index"])
1298
+ )
1299
+
1300
+
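# Sketch: keeping the 10 largest values (and their global sample indexes)
# seen so far for each of 32 features:
#
#     t = TopK(k=10)
#     for batch in torch.randn(20, 128, 32):    # 20 batches of shape (128, 32)
#         t.add(batch)
#     values, indexes = t.topk()                # each of shape (32, 10)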
1301
+ class History(Stat):
1302
+ """
1303
+ Accumulates the concatenation of all the added data.
1304
+ """
1305
+
1306
+ def __init__(self, data=None, state=None):
1307
+ if state is not None:
1308
+ return super().__init__(state)
1309
+ self._data = data
1310
+ self._added = []
1311
+
1312
+ def _cat_added(self):
1313
+ if len(self._added):
1314
+ self._data = torch.cat(
1315
+ ([self._data] if self._data is not None else []) + self._added
1316
+ )
1317
+ self._added = []
1318
+
1319
+ def add(self, d):
1320
+ self._added.append(d)
1321
+ if len(self._added) > 100:
1322
+ self._cat_added()
1323
+
1324
+ def history(self):
1325
+ self._cat_added()
1326
+ return self._data
1327
+
1328
+ def load_state_dict(self, state):
1329
+ data = state["data"]
1330
+ self._data = None if data is None else torch.from_numpy(data)
1331
+ self._added = []
1332
+
1333
+ def state_dict(self):
1334
+ self._cat_added()
1335
+ return dict(
1336
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
1337
+ data=None if self._data is None else self._data.cpu().numpy(),
1338
+ )
1339
+
1340
+ def to_(self, device):
1341
+ """Switches internal storage to specified device."""
1342
+ self._cat_added()
1343
+ if self._data is not None:
1344
+ self._data = self._data.to(device)
1345
+
1346
+
1347
+ class CombinedStat(Stat):
1348
+ """
1349
+ A Stat that bundles together multiple Stat objects.
1350
+ Convenient for loading and saving a state_dict made up of a
1351
+ hierarchy of stats, and for use with the tally() function.
1352
+ Example:
1353
+
1354
+ cs = CombinedStat(m=Mean(), q=Quantile())
1355
+ for [b] in tally(cs, MyDataSet(), cache=fn, batch_size=100):
1356
+ cs.add(b)
1357
+ print(cs.m.mean())
1358
+ print(cs.q.median())
1359
+ """
1360
+
1361
+ def __init__(self, state=None, **kwargs):
1362
+ self._objs = kwargs
1363
+ if state is not None:
1364
+ return super().__init__(state)
1365
+
1366
+ def __getattr__(self, k):
1367
+ if k in self._objs:
1368
+ return self._objs[k]
1369
+ raise AttributeError()
1370
+
1371
+ def add(self, d, *args, **kwargs):
1372
+ for obj in self._objs.values():
1373
+ obj.add(d, *args, **kwargs)
1374
+
1375
+ def load_state_dict(self, state):
1376
+ for prefix, obj in self._objs.items():
1377
+ obj.load_state_dict(pull_key_prefix(prefix, state))
1378
+
1379
+ def state_dict(self):
1380
+ result = {}
1381
+ for prefix, obj in self._objs.items():
1382
+ result.update(push_key_prefix(prefix, obj.state_dict()))
1383
+ return result
1384
+
1385
+ def to_(self, device):
1386
+ """Switches internal storage to specified device."""
1387
+ for v in self._objs.values():
1388
+ v.to_(device)
1389
+
1390
+
1391
+ def push_key_prefix(prefix, d):
1392
+ """
1393
+ Returns a dict with the same values as d, but where each key
1394
+ adds the prefix, followed by a dot.
1395
+ """
1396
+ return {prefix + "." + k: v for k, v in d.items()}
1397
+
1398
+
1399
+ def pull_key_prefix(prefix, d):
1400
+ """
1401
+ Returns a filtered dict of all the items of d that start with
1402
+ the given key prefix, plus a dot, with that prefix removed.
1403
+ """
1404
+ pd = prefix + "."
1405
+ lpd = len(pd)
1406
+ return {k[lpd:]: v for k, v in d.items() if k.startswith(pd)}
1407
+
1408
+
1409
+ # We wish to be able to save None (null) values in numpy npz files,
1410
+ # yet do so without setting the insecure 'allow_pickle' flag. To do
1411
+ # that, we will encode null as a special kind of IEEE 754 NaN value.
1412
+ # Inspired by https://github.com/zuiderkwast/nanbox/blob/master/nanbox.h
1413
+ # we follow the same Nanboxing scheme used in JavaScriptCore
1414
+ # (search for JSCJSValue.h#L435), which encodes null values in NaN
1415
+ # as the NaN value with hex pattern 0xfff8000000000002.
1416
+
1417
+ null_numpy_value = numpy.array(
1418
+ struct.unpack(">d", struct.pack(">Q", 0xFFF8000000000002))[0], dtype=numpy.float64
1419
+ )
1420
+
1421
+
1422
+ def is_null_numpy_value(v):
1423
+ """
1424
+ True if v is a 64-bit float numpy scalar NaN matching null_numpy_value.
1425
+ """
1426
+ return (
1427
+ isinstance(v, numpy.ndarray)
1428
+ and numpy.ndim(v) == 0
1429
+ and v.dtype == numpy.float64
1430
+ and numpy.isnan(v)
1431
+ and 0xFFF8000000000002 == struct.unpack(">Q", struct.pack(">d", v))[0]
1432
+ )
1433
+
1434
+
1435
+ def box_numpy_null(d):
1436
+ """
1437
+ Replaces None with null_numpy_value, leaving non-None values unchanged.
1438
+ Recursively descends into a dictionary replacing None values.
1439
+ """
1440
+ try:
1441
+ return {k: box_numpy_null(v) for k, v in d.items()}
1442
+ except Exception:
1443
+ return null_numpy_value if d is None else d
1444
+
1445
+
1446
+ def unbox_numpy_null(d):
1447
+ """
1448
+ Reverses box_numpy_null, replacing null_numpy_value with None.
1449
+ Recursively descends into a dictionary replacing None values.
1450
+ """
1451
+ try:
1452
+ return {k: unbox_numpy_null(v) for k, v in d.items()}
1453
+ except Exception:
1454
+ return None if is_null_numpy_value(d) else d
1455
+
1456
+
1457
+ def resolve_state_dict(s):
1458
+ """
1459
+ Resolves a state, which can be a filename or a dict-like object.
1460
+ """
1461
+ if isinstance(s, str):
1462
+ return unbox_numpy_null(numpy.load(s))
1463
+ return s
1464
+
1465
+
1466
+ global_load_cache_enabled = True
1467
+
1468
+
1469
+ def load_cached_state(cachefile, args, quiet=False, throw=False):
1470
+ """
1471
+ Resolves a state, which can be a filename or a dict-like object.
1472
+ """
1473
+ if not global_load_cache_enabled or cachefile is None:
1474
+ return None
1475
+ try:
1476
+ if isinstance(cachefile, dict):
1477
+ dat = cachefile
1478
+ cachefile = "state" # for printed messages
1479
+ else:
1480
+ dat = unbox_numpy_null(numpy.load(cachefile))
1481
+ for a, v in args.items():
1482
+ if a not in dat or dat[a] != v:
1483
+ if not quiet:
1484
+ print("%s %s changed from %s to %s" % (cachefile, a, dat[a], v))
1485
+ return None
1486
+ except (FileNotFoundError, ValueError) as e:
1487
+ if throw:
1488
+ raise e
1489
+ return None
1490
+ else:
1491
+ if not quiet:
1492
+ print("Loading cached %s" % cachefile)
1493
+ return dat
1494
+
1495
+
1496
+ def save_cached_state(cachefile, obj, args):
1497
+ """
1498
+ Saves the state_dict of the given object in a dict or npz file.
1499
+ """
1500
+ if cachefile is None:
1501
+ return
1502
+ dat = obj.state_dict()
1503
+ for a, v in args.items():
1504
+ if a in dat:
1505
+ assert dat[a] == v
1506
+ dat[a] = v
1507
+ if isinstance(cachefile, dict):
1508
+ cachefile.clear()
1509
+ cachefile.update(dat)
1510
+ else:
1511
+ os.makedirs(os.path.dirname(cachefile), exist_ok=True)
1512
+ numpy.savez(cachefile, **box_numpy_null(dat))
1513
+
1514
+
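# Sketch of the save/load round trip these helpers provide; the path is
# illustrative (save_cached_state creates the parent directory, so give it one):
#
#     v = Variance()
#     v.add(torch.randn(1000, 32).double())
#     v.save("stats/variance.npz")
#     v2 = Variance(state="stats/variance.npz")   # restored via resolve_state_dict
#     assert torch.allclose(v.variance(), v2.variance())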
1515
+ class FixedSubsetSampler(Sampler):
1516
+ """Represents a fixed sequence of data set indices.
1517
+ Subsets can be created by specifying a subset of output indexes.
1518
+ """
1519
+
1520
+ def __init__(self, samples):
1521
+ self.samples = samples
1522
+
1523
+ def __iter__(self):
1524
+ return iter(self.samples)
1525
+
1526
+ def __len__(self):
1527
+ return len(self.samples)
1528
+
1529
+ def __getitem__(self, key):
1530
+ return self.samples[key]
1531
+
1532
+ def subset(self, new_subset):
1533
+ return FixedSubsetSampler(self.dereference(new_subset))
1534
+
1535
+ def dereference(self, indices):
1536
+ """
1537
+ Translate output sample indices (small numbers indexing the sample)
1538
+ to input sample indices (larger number indexing the original full set)
1539
+ """
1540
+ return [self.samples[i] for i in indices]
1541
+
1542
+
1543
+ class FixedRandomSubsetSampler(FixedSubsetSampler):
1544
+ """Samples a fixed number of samples from the dataset, deterministically.
1545
+ Arguments:
1546
+ data_source,
1547
+ sample_size,
1548
+ seed (optional)
1549
+ """
1550
+
1551
+ def __init__(self, data_source, start=None, end=None, seed=1):
1552
+ rng = random.Random(seed)
1553
+ shuffled = list(range(len(data_source)))
1554
+ rng.shuffle(shuffled)
1555
+ self.data_source = data_source
1556
+ super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end])
1557
+
1558
+ def class_subset(self, class_filter):
1559
+ """
1560
+ Returns only the subset matching the given rule.
1561
+ """
1562
+ if isinstance(class_filter, int):
1563
+
1564
+ def rule(d):
1565
+ return d[1] == class_filter
1566
+
1567
+ else:
1568
+ rule = class_filter
1569
+ return self.subset(
1570
+ [i for i, j in enumerate(self.samples) if rule(self.data_source[j])]
1571
+ )
1572
+
1573
+
1574
+ def make_loader(
1575
+ dataset, sample_size=None, batch_size=1, sampler=None, random_sample=None, **kwargs
1576
+ ):
1577
+ """Utility for creating a dataloader on fixed sample subset."""
1578
+ import typing
1579
+
1580
+ if isinstance(dataset, typing.Callable):
1581
+ # To support deferred dataset loading, support passing a factory
1582
+ # that creates the dataset when called.
1583
+ dataset = dataset()
1584
+ if isinstance(dataset, torch.Tensor):
1585
+ # The dataset can be a simple tensor.
1586
+ dataset = torch.utils.data.TensorDataset(dataset)
1587
+ if sample_size is not None:
1588
+ assert sampler is None, "sampler cannot be specified with sample_size"
1589
+ if sample_size > len(dataset):
1590
+ print(
1591
+ "Warning: sample size %d > dataset size %d"
1592
+ % (sample_size, len(dataset))
1593
+ )
1594
+ sample_size = len(dataset)
1595
+ if random_sample is None:
1596
+ sampler = FixedSubsetSampler(list(range(sample_size)))
1597
+ else:
1598
+ sampler = FixedRandomSubsetSampler(
1599
+ dataset, seed=random_sample, end=sample_size
1600
+ )
1601
+ return torch.utils.data.DataLoader(
1602
+ dataset, sampler=sampler, batch_size=batch_size, **kwargs
1603
+ )
1604
+
1605
+
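# Sketch: a deterministic 1000-item subsample (seed 1) read in batches of 100:
#
#     ds = torch.utils.data.TensorDataset(torch.randn(50000, 64))
#     loader = make_loader(ds, sample_size=1000, random_sample=1, batch_size=100)
#     for [batch] in loader:
#         pass                                    # 10 batches of shape (100, 64)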
1606
+ # Unit Tests
1607
+ def _unit_test():
1608
+ import warnings
1609
+
1610
+ warnings.filterwarnings("error")
1611
+ import argparse
1612
+ import random
1613
+ import shutil
1614
+ import tempfile
1615
+ import time
1616
+
1617
+ parser = argparse.ArgumentParser(description="Test things out")
1618
+ parser.add_argument("--mode", default="cpu", help="cpu or cuda")
1619
+ parser.add_argument("--test_size", type=int, default=1000000)
1620
+ args = parser.parse_args()
1621
+ testdir = tempfile.mkdtemp()
1622
+ batch_size = random.randint(500, 1500)
1623
+
1624
+ # Test NaNboxing.
1625
+ assert numpy.isnan(null_numpy_value)
1626
+ assert is_null_numpy_value(null_numpy_value)
1627
+ assert not is_null_numpy_value(numpy.nan)
1628
+
1629
+ # Test Covariance
1630
+ goal = torch.tensor(numpy.random.RandomState(1).standard_normal(10 * 10)).view(
1631
+ 10, 10
1632
+ )
1633
+ data = (
1634
+ torch.tensor(numpy.random.RandomState(2).standard_normal(args.test_size * 10))
1635
+ .view(args.test_size, 10)
1636
+ .mm(goal)
1637
+ )
1638
+ data += torch.randn(1, 10) * 999
1639
+ dcov = data.t().cov()
1640
+ dcorr = data.t().corrcoef()
1641
+ rcov = Covariance()
1642
+ rcov.add(data) # All one batch
1643
+ assert (rcov.covariance() - dcov).abs().max() < 1e-16
1644
+ cs = CombinedStat(cov=Covariance(), xcov=CrossCovariance())
1645
+ ds = torch.utils.data.TensorDataset(data)
1646
+ for [a] in tally(cs, ds, batch_size=9876):
1647
+ cs.cov.add(a)
1648
+ cs.xcov.add(a[:, :3], a[:, 3:])
1649
+ assert (data.mean(0) - cs.cov.mean()).abs().max() < 1e-12
1650
+ assert (dcov - cs.cov.covariance()).abs().max() < 2e-12
1651
+ assert (dcov[:3, 3:] - cs.xcov.covariance()).abs().max() < 1e-12
1652
+ assert (dcov.diagonal() - torch.cat(cs.xcov.variance())).abs().max() < 1e-12
1653
+ assert (dcorr - cs.cov.correlation()).abs().max() < 2e-12
1654
+
1655
+ # Test CrossCovariance and CrossIoU
1656
+ fn = f"{testdir}/cross_cache.npz"
1657
+ ds = torch.utils.data.TensorDataset(
1658
+ (
1659
+ torch.arange(args.test_size)[:, None] % torch.arange(1, 6)[None, :] == 0
1660
+ ).double(),
1661
+ (
1662
+ torch.arange(args.test_size)[:, None] % torch.arange(5, 8)[None, :] == 0
1663
+ ).double(),
1664
+ )
1665
+ c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
1666
+ riou = IoU()
1667
+ count = 0
1668
+ for [a, b] in tally(c, ds, cache=fn, batch_size=100):
1669
+ count += 1
1670
+ c.add(a, b)
1671
+ riou.add(torch.cat([a, b], dim=1))
1672
+ assert count == -(-args.test_size // 100)
1673
+ cor = c.c.correlation()
1674
+ iou = c.iou.iou()
1675
+ assert cor.shape == iou.shape == (5, 3)
1676
+ assert iou[4, 0] == 1.0
1677
+ assert abs(iou[0, 2] + (-args.test_size // 7 / float(args.test_size))) < 1e-6
1678
+ assert abs(cor[4, 0] - 1.0) < 1e-2
1679
+ assert abs(cor[0, 2] - 0.0) < 1e-6
1680
+ assert all((riou.iou()[:5, -3:] == iou).view(-1))
1681
+ assert all(riou.iou().diagonal(0) == 1)
1682
+ c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
1683
+ count = 0
1684
+ for [a, b] in tally(c, ds, cache=fn, batch_size=10):
1685
+ count += 1
1686
+ c.add(a, b)
1687
+ assert count == 0
1688
+ assert all((c.c.correlation() == cor).view(-1))
1689
+ assert all((c.iou.iou() == iou).view(-1))
1690
+
1691
+ # Test Concatantaion, Mean, Bincount and tally.
1692
+ fn = f"{testdir}/series_cache.npz"
1693
+ count = 0
1694
+ ds = torch.utils.data.TensorDataset(torch.arange(args.test_size))
1695
+ c = CombinedStat(s=History(), m=Mean(), b=Bincount())
1696
+ for [b] in tally(c, ds, cache=fn, batch_size=batch_size):
1697
+ count += 1
1698
+ c.add(b)
1699
+ assert count == -(-args.test_size // batch_size)
1700
+ assert len(c.s.history()) == args.test_size
1701
+ assert c.s.history()[-1] == args.test_size - 1
1702
+ assert all(c.s.history() == ds.tensors[0])
1703
+ assert all(c.b.bincount() == torch.ones(args.test_size))
1704
+ assert c.m.mean() == float(args.test_size - 1) / 2.0
1705
+ c2 = CombinedStat(s=History(), m=Mean(), b=Bincount())
1706
+ batches = tally(c2, ds, cache=fn)
1707
+ assert len(c2.s.history()) == args.test_size
1708
+ assert all(c2.s.history() == c.s.history())
1709
+ assert all(c2.b.bincount() == torch.ones(args.test_size))
1710
+ assert c2.m.mean() == c.m.mean()
1711
+ count = 0
1712
+ for b in batches:
1713
+ count += 1
1714
+ assert count == 0 # Shouldn't do anything when it's cached
1715
+
1716
+ # An adverarial case: we keep finding more numbers in the middle
1717
+ # as the stream goes on.
1718
+ amount = args.test_size
1719
+ quantiles = 1000
1720
+ data = numpy.arange(float(amount))
1721
+ data[1::2] = data[-1::-2] + (len(data) - 1)
1722
+ data /= 2
1723
+ depth = 50
1724
+ alldata = data[:, None] + (numpy.arange(depth) * amount)[None, :]
1725
+ actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0))
1726
+ amt = amount // depth
1727
+ for r in range(depth):
1728
+ numpy.random.shuffle(alldata[r * amt : r * amt + amt, r])
1729
+ if args.mode == "cuda":
1730
+ alldata = torch.cuda.FloatTensor(alldata)
1731
+ device = torch.device("cuda")
1732
+ else:
1733
+ alldata = torch.FloatTensor(alldata)
1734
+ device = None
1735
+ starttime = time.time()
1736
+ cs = CombinedStat(
1737
+ qc=Quantile(),
1738
+ m=Mean(),
1739
+ v=Variance(),
1740
+ c=Covariance(),
1741
+ s=SecondMoment(),
1742
+ t=TopK(),
1743
+ i=IoU(),
1744
+ )
1745
+ # Feed data in little batches
1746
+ i = 0
1747
+ while i < len(alldata):
1748
+ batch_size = numpy.random.randint(1000)
1749
+ cs.add(alldata[i : i + batch_size])
1750
+ i += batch_size
1751
+ # Test state dict
1752
+ saved = cs.state_dict()
1753
+ # numpy.savez(f'{testdir}/saved.npz', **box_numpy_null(saved))
1754
+ # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
1755
+ cs.save(f"{testdir}/saved.npz")
1756
+ loaded = unbox_numpy_null(numpy.load(f"{testdir}/saved.npz"))
1757
+ assert set(loaded.keys()) == set(saved.keys())
1758
+
1759
+ # Restore using state=saved in constructor.
1760
+ cs2 = CombinedStat(
1761
+ qc=Quantile(),
1762
+ m=Mean(),
1763
+ v=Variance(),
1764
+ c=Covariance(),
1765
+ s=SecondMoment(),
1766
+ t=TopK(),
1767
+ i=IoU(),
1768
+ state=saved,
1769
+ )
1770
+ # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
1771
+ assert not cs2.qc.device.type == "cuda"
1772
+ cs2.to_(device)
1773
+ # alldata = alldata.cpu()
1774
+ cs2.add(alldata)
1775
+ actual_sum *= 2
1776
+ # print(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean())
1777
+ assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5)
1778
+ assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5)
1779
+ assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5)
1780
+ # print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0))
1781
+ assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3)
1782
+ assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2)
1783
+ # print(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0))
1784
+ assert all(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0) < 1e-4)
1785
+ # print(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0))
1786
+ assert all(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0) < 2e-3)
1787
+ moment = (alldata.t() @ alldata) / len(alldata)
1788
+ # print(abs(moment - cs2.s.moment()) / moment.abs())
1789
+ assert all((abs(moment - cs2.s.moment()) / moment.abs()).view(-1) < 1e-2)
1790
+ assert all(alldata.max(dim=0)[0] == cs2.t.topk()[0][:, 0])
1791
+ assert cs2.i.iou()[0, 0] == 1
1792
+ assert all((cs2.i.iou()[1:, 1:] == 1).view(-1))
1793
+ assert all(cs2.i.iou()[1:, 0] < 1)
1794
+ assert all(cs2.i.iou()[1:, 0] == cs2.i.iou()[0, 1:])
1795
+
1796
+ # Restore using cs.load() method.
1797
+ cs = CombinedStat(
1798
+ qc=Quantile(),
1799
+ m=Mean(),
1800
+ v=Variance(),
1801
+ c=Covariance(),
1802
+ s=SecondMoment(),
1803
+ t=TopK(),
1804
+ i=IoU(),
1805
+ )
1806
+ cs.load(f"{testdir}/saved.npz")
1807
+ assert not cs.qc.device.type == "cuda"
1808
+ cs.to_(device)
1809
+ cs.add(alldata)
1810
+ # actual_sum *= 2
1811
+ # print(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean())
1812
+ assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5)
1813
+ assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5)
1814
+ assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5)
1815
+ # print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0))
1816
+ assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3)
1817
+ assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2)
1818
+ # print(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0))
1819
+ assert all(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0) < 1e-4)
1820
+ # print(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0))
1821
+ assert all(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0) < 2e-3)
1822
+ moment = (alldata.t() @ alldata) / len(alldata)
1823
+ # print(abs(moment - cs.s.moment()) / moment.abs())
1824
+ assert all((abs(moment - cs.s.moment()) / moment.abs()).view(-1) < 1e-2)
1825
+ assert all(alldata.max(dim=0)[0] == cs.t.topk()[0][:, 0])
1826
+ assert cs.i.iou()[0, 0] == 1
1827
+ assert all((cs.i.iou()[1:, 1:] == 1).view(-1))
1828
+ assert all(cs.i.iou()[1:, 0] < 1)
1829
+ assert all(cs.i.iou()[1:, 0] == cs.i.iou()[0, 1:])
1830
+
1831
+ # Randomized quantile test
1832
+ qc = cs.qc
1833
+ ro = qc.readout(1001).cpu()
1834
+ endtime = time.time()
1835
+ gt = (
1836
+ torch.linspace(0, amount, quantiles + 1)[None, :]
1837
+ + (torch.arange(qc.depth, dtype=torch.float) * amount)[:, None]
1838
+ )
1839
+ maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles
1840
+ print("Randomized quantile test results:")
1841
+ print("Maximum relative deviation among %d perentiles: %f" % (quantiles, maxreldev))
1842
+ minerr = torch.max(
1843
+ torch.abs(
1844
+ qc.minmax().cpu()[:, 0] - torch.arange(qc.depth, dtype=torch.float) * amount
1845
+ )
1846
+ )
1847
+ maxerr = torch.max(
1848
+ torch.abs(
1849
+ (qc.minmax().cpu()[:, -1] + 1)
1850
+ - (torch.arange(qc.depth, dtype=torch.float) + 1) * amount
1851
+ )
1852
+ )
1853
+ print("Minmax error %f, %f" % (minerr, maxerr))
1854
+ interr = torch.max(
1855
+ torch.abs(qc.integrate(lambda x: x * x).cpu() - actual_sum) / actual_sum
1856
+ )
1857
+ print("Integral error: %f" % interr)
1858
+ medianerr = torch.max(
1859
+ torch.abs(qc.median() - alldata.median(0)[0]) / alldata.median(0)[0]
1860
+ ).cpu()
1861
+ print("Median error: %f" % medianerr)
1862
+ meanerr = torch.max(torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu()
1863
+ print("Mean error: %f" % meanerr)
1864
+ varerr = torch.max(torch.abs(qc.variance() - alldata.var(0)) / alldata.var(0)).cpu()
1865
+ print("Variance error: %f" % varerr)
1866
+ counterr = (
1867
+ (qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu()) - qc.size())
1868
+ / (0.0 + qc.size())
1869
+ ).item()
1870
+ print("Count error: %f" % counterr)
1871
+ print("Time %f" % (endtime - starttime))
1872
+ # Algorithm is randomized, so some of these will fail with low probability.
1873
+ assert maxreldev < 1.0
1874
+ assert minerr == 0.0
1875
+ assert maxerr == 0.0
1876
+ assert interr < 0.01
1877
+ assert abs(counterr) < 0.001
1878
+ shutil.rmtree(testdir, ignore_errors=True)
1879
+ print("OK")
1880
+
1881
+
1882
+ if __name__ == "__main__":
1883
+ _unit_test()
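For orientation, the sketch below shows how the streaming statistics exercised by `_unit_test` are typically driven outside the test: a `CombinedStat` is fed batch by batch through `tally`, and the accumulated moments are read back at the end. It assumes the `tally`, `CombinedStat`, `Mean`, and `Variance` utilities defined in this file are importable from `easyeditor.util.runningstats`; the data tensor and batch size are illustrative, not part of the commit.

```python
import torch
from easyeditor.util.runningstats import CombinedStat, Mean, Variance, tally

# Illustrative data: 10,000 samples with 16 features each.
data = torch.randn(10000, 16)
ds = torch.utils.data.TensorDataset(data)

# Accumulate per-feature mean and variance in a single streaming pass.
stats = CombinedStat(m=Mean(), v=Variance())
for [batch] in tally(stats, ds, batch_size=256):
    stats.add(batch)

print(stats.m.mean())      # running mean, one value per feature
print(stats.v.variance())  # running variance, one value per feature
```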
hparams/GRACE/README.md ADDED
@@ -0,0 +1,19 @@
+ alg_name: "GRACE"
+ model_name: "./hugging_cache/gpt-j-6B"
+ device: 0
+
+ inner_params:
+ - transformer.h[25].mlp.fc_out.weight
+
+ edit_lr: 1.0
+ n_iter: 200
+ eps: 1.0
+ dist_fn: euc # euc, mmd, cos
+ val_init: cold # cold, warm
+ val_train: sgd # sgd, pert
+ val_reg: None # early
+ reg: early_stop # early_stop
+ replacement: replace_last # replace_last, replace_all, replace_prompt
+ eps_expand: coverage # , moving_avg, decay
+ num_pert: 8 # only matters when using perturbation training
+ dropout: 0.0
hparams/GRACE/gpt2-xl.yaml ADDED
@@ -0,0 +1,19 @@
+ alg_name: "GRACE"
+ model_name: "./hugging_cache/gpt2-xl"
+ device: 0
+
+ inner_params:
+ - transformer.h[35].mlp.c_fc.weight
+
+ edit_lr: 1.0
+ n_iter: 50
+ eps: 1.0
+ dist_fn: euc # euc, mmd, cos
+ val_init: cold # cold, warm
+ val_train: sgd # sgd, pert
+ val_reg: None # early
+ reg: early_stop # early_stop
+ replacement: replace_last # replace_last, replace_all, replace_prompt
+ eps_expand: coverage # , moving_avg, decay
+ num_pert: 8 # only matters when using perturbation training
+ dropout: 0.0
hparams/config.yaml ADDED
@@ -0,0 +1,6 @@
+ save_dir: models/
+ log_dir: logs/
+
+ defaults:
+   alg_name: KN # Editing Method
+   hparams_name: KN/t5-3b # Edited Model Config Path
utils.py ADDED
@@ -0,0 +1,36 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
+ from transformers import GPT2TokenizerFast, GPT2Tokenizer
+ from easyeditor import apply_grace_to_model, GraceHyperParams, nethook
+ import torch
+
+
+ def edit(prompt, target_new):
+     # Build a single edit request and load the GRACE hyperparameters.
+     request = {"prompt": prompt, "target_new": target_new}
+     hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
+
+     model = AutoModelForCausalLM.from_pretrained("./models/gpt2-xl")
+     tok = GPT2Tokenizer.from_pretrained("./models/gpt2-xl")
+     tok.pad_token_id = tok.eos_token_id
+     # Keep the edited model in a module-level global so generate() can reuse it.
+     global edit_model
+     edit_model, _ = apply_grace_to_model(model, tok, request, hparams, keep_original_weight=True)
+     return "finish"
+
+
+ def generate(input_text):
+     tok = GPT2Tokenizer.from_pretrained("./models/gpt2-xl")
+     hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2-xl.yaml")
+     tok.pad_token_id = tok.eos_token_id
+
+     # Reuse the edited model produced by edit().
+     global edit_model
+
+     input_ids = tok.encode(input_text, return_tensors='pt').to(f'cuda:{hparams.device}')
+     edit_output = edit_model.generate(input_ids, max_length=30, pad_token_id=tok.eos_token_id)
+     edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
+     # Free the edited model before loading the original model for comparison.
+     del edit_model
+     torch.cuda.empty_cache()
+
+     ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2-xl").to(f'cuda:{hparams.device}')
+     ori_output = ori_model.generate(input_ids, max_length=30, pad_token_id=tok.eos_token_id)
+     ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)
+
+     return ori_reply, edit_reply
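As a usage sketch (not part of this commit), the two helpers in `utils.py` are meant to be called in sequence: `edit` applies a GRACE edit and stores the edited model in a module-level global, and `generate` then contrasts the original and edited models on the same prompt. The prompt and target strings below are illustrative.

```python
from utils import edit, generate

# Apply one GRACE edit, then compare pre- and post-edit generations.
edit(prompt="The capital of France is", target_new="Rome")
ori_reply, edit_reply = generate("The capital of France is")
print("Original:", ori_reply)
print("Edited:  ", edit_reply)
```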