Safetensors
llama3_SAE
custom_code
RuHae committed on
Commit a29b74a · 1 Parent(s): 91cf7fd

Upload LLama3_SAE

config.json CHANGED
@@ -9,7 +9,7 @@
   "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "RuHae/Llama3_SAE--configuration_llama3_SAE.LLama3_SAE_Config",
-    "AutoModelForCausalLM": "RuHae/Llama3_SAE--modeling_llama3_SAE.LLama3_SAE"
+    "AutoModelForCausalLM": "modeling_llama3_SAE.LLama3_SAE"
   },
   "base_model_name": "meta-llama/Meta-Llama-3-8B",
   "bos_token_id": 128000,
configuration_llama3_SAE.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import PretrainedConfig, LlamaConfig
+ from typing import List, Callable
+ import torch
+
+
+ # class LLama3_SAE_Config(PretrainedConfig):
+ class LLama3_SAE_Config(LlamaConfig):
+     model_type = "llama3_SAE"
+
+     def __init__(
+         self,
+         # hf_token: str = "",
+         # base_model_config: LlamaConfig = None,
+         base_model_name: str = "",
+         hook_block_num: int = 25,
+         n_latents: int = 12288,
+         n_inputs: int = 4096,
+         activation: str = "relu",
+         activation_k: int = 64,
+         site: str = "mlp",
+         tied: bool = False,
+         normalize: bool = False,
+         mod_features: List[int] = None,
+         mod_threshold: List[int] = None,
+         mod_replacement: List[int] = None,
+         mod_scaling: List[int] = None,
+         **kwargs,
+     ):
+         # self.hf_token = hf_token
+         # self.base_model_config = base_model_config
+         self.base_model_name = base_model_name
+         self.hook_block_num = hook_block_num
+         self.n_latents = n_latents
+         self.n_inputs = n_inputs
+         self.activation = activation
+         self.activation_k = activation_k
+         self.site = site
+         self.tied = tied
+         self.normalize = normalize
+         self.mod_features = mod_features
+         self.mod_threshold = mod_threshold
+         self.mod_replacement = mod_replacement
+         self.mod_scaling = mod_scaling
+
+         super().__init__(**kwargs)
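For context, the configuration class added above only records the SAE-specific hyperparameters (hooked decoder block, input and latent widths, activation type, and the optional feature-modification lists) on top of the inherited LlamaConfig fields. A minimal construction sketch, assuming configuration_llama3_SAE.py is importable locally (the values shown are illustrative, not a recommendation):

```python
from configuration_llama3_SAE import LLama3_SAE_Config

config = LLama3_SAE_Config(
    base_model_name="meta-llama/Meta-Llama-3-8B",
    hook_block_num=25,    # decoder block whose activations are hooked
    n_inputs=4096,        # width of the hooked activations
    n_latents=12288,      # SAE dictionary size
    activation="topk",    # one of: topk, topk-tanh, topk-sigmoid, jumprelu, relu, identity
    activation_k=64,      # k used by the TopK activation
    site="mlp",           # "mlp" or "block"
)
print(config.model_type)  # -> "llama3_SAE"
```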
modeling_llama3_SAE.py ADDED
@@ -0,0 +1,795 @@
+ from typing import List, Optional, Tuple, Union, Callable, Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     from configuration_llama3_SAE import LLama3_SAE_Config
+ except ImportError:
+     from .configuration_llama3_SAE import LLama3_SAE_Config
+
+ from transformers import (
+     LlamaPreTrainedModel,
+     LlamaModel,
+ )
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.cache_utils import Cache
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     is_flash_attn_2_available,
+     is_flash_attn_greater_or_equal_2_10,
+     logging,
+     replace_return_docstrings,
+ )
+
+ import logging
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class LLama3_SAE(LlamaPreTrainedModel):
+     config_class = LLama3_SAE_Config
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: LLama3_SAE_Config):
+         super().__init__(config)
+         self.model = LlamaModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         if config.activation == "topk":
+             if isinstance(config.activation_k, int):
+                 activation = TopK(torch.tensor(config.activation_k))
+             else:
+                 activation = TopK(config.activation_k)
+         elif config.activation == "topk-tanh":
+             if isinstance(config.activation_k, int):
+                 activation = TopK(torch.tensor(config.activation_k), nn.Tanh())
+             else:
+                 activation = TopK(config.activation_k, nn.Tanh())
+         elif config.activation == "topk-sigmoid":
+             if isinstance(config.activation_k, int):
+                 activation = TopK(torch.tensor(config.activation_k), nn.Sigmoid())
+             else:
+                 activation = TopK(config.activation_k, nn.Sigmoid())
+         elif config.activation == "jumprelu":
+             activation = JumpReLu()
+         elif config.activation == "relu":
+             activation = "ReLU"
+         elif config.activation == "identity":
+             activation = "Identity"
+         else:
+             raise NotImplementedError(
+                 f"Activation '{config.activation}' not implemented."
+             )
+
+         self.SAE = Autoencoder(
+             n_inputs=config.n_inputs,
+             n_latents=config.n_latents,
+             activation=activation,
+             tied=False,
+             normalize=True,
+         )
+
+         self.hook = HookedTransformer_with_SAE_suppresion(
+             block=config.hook_block_num,
+             sae=self.SAE,
+             mod_features=config.mod_features,
+             mod_threshold=config.mod_threshold,
+             mod_replacement=config.mod_replacement,
+             mod_scaling=config.mod_scaling,
+         ).register_with(self.model, config.site)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+         >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+         >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+         >>> prompt = "Hey, are you conscious? Can you talk to me?"
+         >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+         ```"""
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         hidden_states = outputs[0]
+         if self.config.pretraining_tp > 1:
+             lm_head_slices = self.lm_head.weight.split(
+                 self.vocab_size // self.config.pretraining_tp, dim=0
+             )
+             logits = [
+                 F.linear(hidden_states, lm_head_slices[i])
+                 for i in range(self.config.pretraining_tp)
+             ]
+             logits = torch.cat(logits, dim=-1)
+         else:
+             logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss(reduction="none")
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+             loss = loss.view(logits.size(0), -1)
+             mask = loss != 0
+             loss = loss.sum(dim=-1) / mask.sum(dim=-1)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         cache_position=None,
+         use_cache=True,
+         **kwargs,
+     ):
+         past_length = 0
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 past_length = (
+                     cache_position[0]
+                     if cache_position is not None
+                     else past_key_values.get_seq_length()
+                 )
+                 max_cache_length = (
+                     torch.tensor(
+                         past_key_values.get_max_length(), device=input_ids.device
+                     )
+                     if past_key_values.get_max_length() is not None
+                     else None
+                 )
+                 cache_length = (
+                     past_length
+                     if max_cache_length is None
+                     else torch.min(max_cache_length, past_length)
+                 )
+             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+             if (
+                 attention_mask is not None
+                 and attention_mask.shape[1] > input_ids.shape[1]
+             ):
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             # input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + input_ids.shape[1] > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+             # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+             # TODO: use `next_tokens` directly instead.
+             model_inputs = {"input_ids": input_ids.contiguous()}
+
+         input_length = (
+             position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+         )
+         if cache_position is None:
+             cache_position = torch.arange(
+                 past_length, past_length + input_length, device=input_ids.device
+             )
+         elif use_cache:
+             cache_position = cache_position[-input_length:]
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "use_cache": use_cache,
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (
+                 tuple(
+                     past_state.index_select(0, beam_idx.to(past_state.device))
+                     for past_state in layer_past
+                 ),
+             )
+         return reordered_past
+
+
+ def LN(
+     x: torch.Tensor, eps: float = 1e-5
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     mu = x.mean(dim=-1, keepdim=True)
+     x = x - mu
+     std = x.std(dim=-1, keepdim=True)
+     x = x / (std + eps)
+     return x, mu, std
+
+
+ class Autoencoder(nn.Module):
+     """Sparse autoencoder
+
+     Implements:
+         latents = activation(encoder(x - pre_bias) + latent_bias)
+         recons = decoder(latents) + pre_bias
+     """
+
+     def __init__(
+         self,
+         n_latents: int,
+         n_inputs: int,
+         activation: Callable = nn.ReLU(),
+         tied: bool = False,
+         normalize: bool = False,
+     ) -> None:
+         """
+         :param n_latents: dimension of the autoencoder latent
+         :param n_inputs: dimensionality of the original data (e.g residual stream, number of MLP hidden units)
+         :param activation: activation function
+         :param tied: whether to tie the encoder and decoder weights
+         """
+         super().__init__()
+         self.n_inputs = n_inputs
+         self.n_latents = n_latents
+
+         self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
+         self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
+         self.latent_bias = nn.Parameter(torch.zeros(n_latents))
+         self.activation = activation
+
+         if isinstance(activation, JumpReLu):
+             self.threshold = nn.Parameter(torch.empty(n_latents))
+             torch.nn.init.constant_(self.threshold, 0.001)
+             self.forward = self.forward_jumprelu
+         elif isinstance(activation, TopK):
+             self.forward = self.forward_topk
+         else:
+             logger.warning(
+                 f"Using TopK forward function even if activation is not TopK, but is {activation}"
+             )
+             self.forward = self.forward_topk
+
+         if tied:
+             # self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
+             self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
+             self.decoder.weight.data = self.encoder.weight.data.T.clone()
+         else:
+             self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
+         self.normalize = normalize
+
+     def encode_pre_act(
+         self, x: torch.Tensor, latent_slice: slice = slice(None)
+     ) -> torch.Tensor:
+         """
+         :param x: input data (shape: [batch, n_inputs])
+         :param latent_slice: slice of latents to compute
+             Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
+         :return: autoencoder latents before activation (shape: [batch, n_latents])
+         """
+         x = x - self.pre_bias
+         latents_pre_act = F.linear(
+             x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
+         )
+         return latents_pre_act
+
+     def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
+         if not self.normalize:
+             return x, dict()
+         x, mu, std = LN(x)
+         return x, dict(mu=mu, std=std)
+
+     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
+         """
+         :param x: input data (shape: [batch, n_inputs])
+         :return: autoencoder latents (shape: [batch, n_latents])
+         """
+         x, info = self.preprocess(x)
+         return self.activation(self.encode_pre_act(x)), info
+
+     def decode(
+         self, latents: torch.Tensor, info: dict[str, Any] | None = None
+     ) -> torch.Tensor:
+         """
+         :param latents: autoencoder latents (shape: [batch, n_latents])
+         :return: reconstructed data (shape: [batch, n_inputs])
+         """
+         ret = self.decoder(latents) + self.pre_bias
+         if self.normalize:
+             assert info is not None
+             ret = ret * info["std"] + info["mu"]
+         return ret
+
+     def forward_topk(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         :param x: input data (shape: [batch, n_inputs])
+         :return: autoencoder latents pre activation (shape: [batch, n_latents])
+             autoencoder latents (shape: [batch, n_latents])
+             reconstructed data (shape: [batch, n_inputs])
+         """
+         x, info = self.preprocess(x)
+         latents_pre_act = self.encode_pre_act(x)
+         latents = self.activation(latents_pre_act)
+         recons = self.decode(latents, info)
+
+         return latents_pre_act, latents, recons
+
+     def forward_jumprelu(
+         self, x: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         :param x: input data (shape: [batch, n_inputs])
+         :return: autoencoder latents pre activation (shape: [batch, n_latents])
+             autoencoder latents (shape: [batch, n_latents])
+             reconstructed data (shape: [batch, n_inputs])
+         """
+         x, info = self.preprocess(x)
+         latents_pre_act = self.encode_pre_act(x)
+         latents = self.activation(F.relu(latents_pre_act), torch.exp(self.threshold))
+         recons = self.decode(latents, info)
+
+         return latents_pre_act, latents, recons
+
+
+ class TiedTranspose(nn.Module):
+     def __init__(self, linear: nn.Linear):
+         super().__init__()
+         self.linear = linear
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         assert self.linear.bias is None
+         # torch.nn.parameter.Parameter(layer_e.weights.T)
+         return F.linear(x, self.linear.weight.t(), None)
+
+     @property
+     def weight(self) -> torch.Tensor:
+         return self.linear.weight.t()
+
+     @property
+     def bias(self) -> torch.Tensor:
+         return self.linear.bias
+
+
+ class TopK(nn.Module):
+     def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
+         super().__init__()
+         self.k = k
+         self.postact_fn = postact_fn
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         topk = torch.topk(x, k=self.k, dim=-1)
+         values = self.postact_fn(topk.values)
+         # make all other values 0
+         result = torch.zeros_like(x)
+         result.scatter_(-1, topk.indices, values)
+         return result
+
+
+ class JumpReLu(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input, threshold):
+         return JumpReLUFunction.apply(input, threshold)
+
+
+ class HeavyStep(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input, threshold):
+         return HeavyStepFunction.apply(input, threshold)
+
+
+ def rectangle(x):
+     return (x > -0.5) & (x < 0.5)
+
+
+ class JumpReLUFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(input, threshold):
+         output = input * (input > threshold)
+         return output
+
+     @staticmethod
+     def setup_context(ctx, inputs, output):
+         input, threshold = inputs
+         ctx.save_for_backward(input, threshold)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         bandwidth = 0.001
+         # bandwidth = 0.0001
+         input, threshold = ctx.saved_tensors
+         grad_input = grad_threshold = None
+
+         grad_input = input > threshold
+         grad_threshold = (
+             -(threshold / bandwidth)
+             * rectangle((input - threshold) / bandwidth)
+             * grad_output
+         )
+
+         return grad_input, grad_threshold
+
+
+ class HeavyStepFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(input, threshold):
+         output = input * threshold
+         return output
+
+     @staticmethod
+     def setup_context(ctx, inputs, output):
+         input, threshold = inputs
+         ctx.save_for_backward(input, threshold)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         bandwidth = 0.001
+         # bandwidth = 0.0001
+         input, threshold = ctx.saved_tensors
+         grad_input = grad_threshold = None
+
+         grad_input = torch.zeros_like(input)
+         grad_threshold = (
+             -(1.0 / bandwidth)
+             * rectangle((input - threshold) / bandwidth)
+             * grad_output
+         )
+
+         return grad_input, grad_threshold
+
+
+ ACTIVATIONS_CLASSES = {
+     "ReLU": nn.ReLU,
+     "Identity": nn.Identity,
+     "TopK": TopK,
+     "JumpReLU": JumpReLu,
+ }
+
+
+ class HookedTransformer_with_SAE:
+     """Auxiliary class used to extract MLP activations from transformer models."""
+
+     def __init__(self, block: int, sae) -> None:
+         self.block = block
+         self.sae = sae
+
+         self.remove_handle = (
+             None  # Can be used to remove this hook from the model again
+         )
+
+         self._features = None
+
+     def register_with(self, model):
+         # At the moment only activations from the feed-forward MLP layer
+         self.remove_handle = model.layers[self.block].mlp.register_forward_hook(self)
+
+         return self
+
+     def pop(self) -> torch.Tensor:
+         """Remove and return the extracted feature from this hook.
+
+         We only allow access to the features this way to not have any lingering references to them.
+         """
+         assert self._features is not None, "Feature extractor was not called yet!"
+         features = self._features
+         self._features = None
+         return features
+
+     def __call__(self, module, inp, outp) -> None:
+         self._features = outp
+         return self.sae(outp)[2]
+
+
+ class HookedTransformer_with_SAE_suppresion:
+     """Auxiliary class used to extract MLP activations from transformer models."""
+
+     def __init__(
+         self,
+         block: int,
+         sae: Autoencoder,
+         mod_features: list = None,
+         mod_threshold: list = None,
+         mod_replacement: list = None,
+         mod_scaling: list = None,
+         mod_balance: bool = False,
+         multi_feature: bool = False,
+     ) -> None:
+         self.block = block
+         self.sae = sae
+
+         self.remove_handle = (
+             None  # Can be used to remove this hook from the model again
+         )
+
+         self._features = None
+         self.mod_features = mod_features
+         self.mod_threshold = mod_threshold
+         self.mod_replacement = mod_replacement
+         self.mod_scaling = mod_scaling
+         self.mod_balance = mod_balance
+         self.mod_vector = None
+         self.mod_vec_factor = 1.0
+
+         if multi_feature:
+             self.modify = self.modify_list
+         else:
+             self.modify = self.modify_single
+
+         if isinstance(self.sae.activation, JumpReLu):
+             logger.info("Setting __call__ function for JumpReLU.")
+             setattr(self, "call", self.__call__jumprelu)
+         elif isinstance(self.sae.activation, TopK):
+             logger.info("Setting __call__ function for TopK.")
+             setattr(self, "call", self.__call__topk)
+         else:
+             logger.warning(
+                 f"Using TopK forward function even if activation is not TopK, but is {self.sae.activation}"
+             )
+             setattr(self, "call", self.__call__topk)
+
+     def register_with(self, model, site="mlp"):
+         self.site = site
+         # Decision on where to extract activations from
+         if site == "mlp":  # output of the FF module of the block
+             self.remove_handle = model.layers[self.block].mlp.register_forward_hook(
+                 self
+             )
+         elif (
+             site == "block"
+         ):  # output of the residual connection AFTER it is added to the FF output
+             self.remove_handle = model.layers[self.block].register_forward_hook(self)
+         elif site == "attention":
+             raise NotImplementedError
+         else:
+             raise NotImplementedError
+
+         # self.remove_handle = model.model.layers[self.block].mlp.act_fn.register_forward_hook(self)
+
+         return self
+
+     def modify_list(self, latents: torch.Tensor) -> torch.Tensor:
+         if self.mod_replacement is not None:
+             for feat, thresh, mod in zip(
+                 self.mod_features, self.mod_threshold, self.mod_replacement
+             ):
+                 latents[:, :, feat][latents[:, :, feat] > thresh] = mod
+         elif self.mod_scaling is not None:
+             for feat, thresh, mod in zip(
+                 self.mod_features, self.mod_threshold, self.mod_scaling
+             ):
+                 latents[:, :, feat][latents[:, :, feat] > thresh] *= mod
+         elif self.mod_vector is not None:
+             latents = latents + self.mod_vec_factor * self.mod_vector
+         else:
+             pass
+
+         return latents
+
+     def modify_single(self, latents: torch.Tensor) -> torch.Tensor:
+         old_cond_feats = latents[:, :, self.mod_features]
+         if self.mod_replacement is not None:
+             # latents[:, :, self.mod_features][
+             #     latents[:, :, self.mod_features] > self.mod_threshold
+             # ] = self.mod_replacement
+             latents[:, :, self.mod_features] = self.mod_replacement
+         elif self.mod_scaling is not None:
+             latents_scaled = latents.clone()
+             latents_scaled[:, :, self.mod_features][
+                 latents[:, :, self.mod_features] > 0
+             ] *= self.mod_scaling
+             latents_scaled[:, :, self.mod_features][
+                 latents[:, :, self.mod_features] < 0
+             ] *= -1 * self.mod_scaling
+             latents = latents_scaled
+             # latents[:, :, self.mod_features] *= self.mod_scaling
+         elif self.mod_vector is not None:
+             latents = latents + self.mod_vec_factor * self.mod_vector
+         else:
+             pass
+
+         if self.mod_balance:
+             # logger.warning("The balancing does not work yet!!!")
+             # TODO: Look into it more closely, not sure if this is correct
+             num_feat = latents.shape[2] - 1
+             diff = old_cond_feats - latents[:, :, self.mod_features]
+             if self.mod_features != 0:
+                 latents[:, :, : self.mod_features] += (diff / num_feat)[:, :, None]
+             latents[:, :, self.mod_features + 1 :] += (diff / num_feat)[:, :, None]
+
+         return latents
+
+     def pop(self) -> torch.Tensor:
+         """Remove and return the extracted feature from this hook.
+
+         We only allow access to the features this way to not have any lingering references to them.
+         """
+         assert self._features is not None, "Feature extractor was not called yet!"
+         if isinstance(self._features, tuple):
+             features = self._features[0]
+         else:
+             features = self._features
+         self._features = None
+         return features
+
+     def __call__topk(self, module, inp, outp) -> torch.Tensor:
+         self._features = outp
+         if isinstance(self._features, tuple):
+             features = self._features[0]
+         else:
+             features = self._features
+
+         if self.mod_features is None:
+             recons = features
+         else:
+             x, info = self.sae.preprocess(features)
+             latents_pre_act = self.sae.encode_pre_act(x)
+             latents = self.sae.activation(latents_pre_act)
+             # latents[:, :, self.mod_features] = F.sigmoid(
+             #     latents_pre_act[:, :, self.mod_features]
+             # )
+             # latents[:, :, self.mod_features] = torch.abs(latents_pre_act[:, :, self.mod_features])
+             # latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]
+             mod_latents = self.modify(latents)
+             # mod_latents[:, :, self.mod_features] = F.sigmoid(
+             #     mod_latents[:, :, self.mod_features]
+             # )
+
+             recons = self.sae.decode(mod_latents, info)
+
+         if isinstance(self._features, tuple):
+             outp = list(outp)
+             outp[0] = recons
+             return tuple(outp)
+         else:
+             return recons
+
+     def __call__jumprelu(self, module, inp, outp) -> torch.Tensor:
+         self._features = outp
+         if self.mod_features is None:
+             recons = outp
+         else:
+             x, info = self.sae.preprocess(outp)
+             latents_pre_act = self.sae.encode_pre_act(x)
+             latents = self.sae.activation(
+                 F.relu(latents_pre_act), torch.exp(self.sae.threshold)
+             )
+             latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]
+             mod_latents = self.modify(latents)
+
+             recons = self.sae.decode(mod_latents, info)
+
+         return recons
+
+     def __call__(self, module, inp, outp) -> torch.Tensor:
+         return self.call(module, inp, outp)
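Because the updated auto_map in config.json points AutoModelForCausalLM at the bundled modeling_llama3_SAE.py, the uploaded model should be loadable through the regular Auto classes once remote code is trusted. A minimal loading sketch (prompt and generation settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True lets transformers import the custom
# LLama3_SAE_Config / LLama3_SAE classes referenced in auto_map.
model = AutoModelForCausalLM.from_pretrained("RuHae/Llama3_SAE", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

inputs = tokenizer("Hey, how are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Internally, the hook registered in LLama3_SAE.__init__ runs the chosen layer's output through the Autoencoder and, when mod_features is set, overwrites or rescales the selected latents before decoding; with mod_features=None the hooked output is passed through unchanged. A standalone round-trip sketch of the Autoencoder with a TopK activation, using toy sizes and random data rather than the shipped checkpoint (assumes modeling_llama3_SAE.py is importable locally):

```python
import torch
from modeling_llama3_SAE import Autoencoder, TopK

# Toy dimensions instead of the real n_inputs=4096 / n_latents=12288.
sae = Autoencoder(n_latents=32, n_inputs=8, activation=TopK(4), normalize=True)

x = torch.randn(2, 8)              # a batch of fake activations
pre_act, latents, recons = sae(x)  # dispatches to forward_topk
print((latents != 0).sum(dim=-1))  # at most k=4 non-zero latents per row
print(recons.shape)                # torch.Size([2, 8]), same shape as x
```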