yairschiff committed on
Commit
c7f0704
·
verified ·
1 Parent(s): a71b36b

Upload CaduceusForMaskedLM

Browse files
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CaduceusForMaskedLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_caduceus.CaduceusConfig",
7
+ "AutoModel": "modeling_caduceus.Caduceus",
8
+ "AutoModelForMaskedLM": "modeling_caduceus.CaduceusForMaskedLM",
9
+ "AutoModelForSequenceClassification": "modeling_caduceus.CaduceusForSequenceClassification"
10
+ },
11
+ "bidirectional": true,
12
+ "bidirectional_strategy": "add",
13
+ "bidirectional_weight_tie": true,
14
+ "complement_map": {
15
+ "0": 0,
16
+ "1": 1,
17
+ "2": 2,
18
+ "3": 3,
19
+ "4": 4,
20
+ "5": 5,
21
+ "6": 6,
22
+ "7": 10,
23
+ "8": 9,
24
+ "9": 8,
25
+ "10": 7,
26
+ "11": 11,
27
+ "12": 12,
28
+ "13": 13,
29
+ "14": 14,
30
+ "15": 15
31
+ },
32
+ "d_model": 256,
33
+ "fused_add_norm": true,
34
+ "initializer_cfg": {
35
+ "initializer_range": 0.02,
36
+ "n_residuals_per_layer": 1,
37
+ "rescale_prenorm_residual": true
38
+ },
39
+ "model_type": "caduceus",
40
+ "n_layer": 16,
41
+ "norm_epsilon": 1e-05,
42
+ "pad_vocab_size_multiple": 8,
43
+ "rcps": true,
44
+ "residual_in_fp32": false,
45
+ "rms_norm": true,
46
+ "ssm_cfg": {
47
+ "bias": false,
48
+ "conv_bias": true,
49
+ "d_conv": 4,
50
+ "d_state": 16,
51
+ "dt_init": "random",
52
+ "dt_init_floor": 0.0001,
53
+ "dt_max": 0.1,
54
+ "dt_min": 0.001,
55
+ "dt_rank": "auto",
56
+ "dt_scale": 1.0,
57
+ "expand": 2,
58
+ "use_fast_path": true
59
+ },
60
+ "torch_dtype": "float32",
61
+ "transformers_version": "4.38.1",
62
+ "vocab_size": 16
63
+ }
configuration_caduceus.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Caduceus config for Hugging Face.
2
+
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.

    Every constructor argument is stored unchanged as an attribute of the same name. Any additional keyword
    arguments are forwarded to `PretrainedConfig.__init__`.

    Args:
        d_model: Hidden dimension passed to the embedding layer and to every block.
        n_layer: Number of blocks stacked in the backbone.
        vocab_size: Vocabulary size (the model may round this up to a multiple of `pad_vocab_size_multiple`).
        ssm_cfg: Optional dict of keyword arguments forwarded to the Mamba mixer.
        rms_norm: Use RMSNorm instead of `nn.LayerNorm`.
        residual_in_fp32: Forwarded to the blocks and to the fused add + norm call.
        fused_add_norm: Use the fused (Triton) add + norm kernels.
        pad_vocab_size_multiple: Multiple to which `vocab_size` is padded.
        norm_epsilon: Epsilon used by the normalization layers.
        initializer_cfg: Optional dict of settings read during weight initialization.
        bidirectional: Whether each mixer also processes the reversed sequence.
        bidirectional_strategy: How forward / reverse outputs are combined (`"add"` or `"ew_multiply"`).
        bidirectional_weight_tie: Whether forward / reverse mixers share their in / out projections.
        rcps: Whether to use the reverse-complement equivariant (RCPS) modules.
        complement_map: Mapping from each token id to the id of its complement (needed when `rcps` is set).
        **kwargs: Forwarded to `PretrainedConfig.__init__`.
    """
    model_type = "caduceus"

    def __init__(
            self,
            # From original MambaConfig
            d_model: int = 2560,
            n_layer: int = 64,
            vocab_size: int = 50277,
            ssm_cfg: Optional[dict] = None,
            rms_norm: bool = True,
            residual_in_fp32: bool = True,
            fused_add_norm: bool = True,
            pad_vocab_size_multiple: int = 8,

            # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
            norm_epsilon: float = 1e-5,

            # Used in init_weights
            initializer_cfg: Optional[dict] = None,

            # Caduceus-specific params
            bidirectional: bool = True,
            bidirectional_strategy: Union[str, None] = "add",
            bidirectional_weight_tie: bool = True,
            rcps: bool = False,
            complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
            **kwargs,
    ):
        # Let the base class consume the generic Hugging Face options first.
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3e6976fe90460ff5d90d371b457d0263dc65e156ea5651b4a452e98228daef2
3
+ size 30937760
modeling_caduceus.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Caduceus model for Hugging Face.
2
+
3
+ """
4
+
5
+ import math
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from transformers import PreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
15
+
16
+ try:
17
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18
+ except ImportError:
19
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
20
+
21
+ from .configuration_caduceus import CaduceusConfig
22
+ from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
23
+
24
+
25
def create_block(
        d_model,
        ssm_cfg=None,
        norm_epsilon=1e-5,
        rms_norm=False,
        residual_in_fp32=False,
        fused_add_norm=False,
        layer_idx=None,
        bidirectional=True,
        bidirectional_strategy="add",
        bidirectional_weight_tie=True,
        rcps=False,
        device=None,
        dtype=None,
):
    """Build a single Caduceus block (norm + bi-directional Mamba mixer).

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

    Args:
        d_model: Hidden dimension handed to the block and its mixer.
        ssm_cfg: Optional dict of extra keyword arguments for the mixer (`None` means no extras).
        norm_epsilon: Epsilon for the normalization layer.
        rms_norm: Use `RMSNorm` when true, `nn.LayerNorm` otherwise.
        residual_in_fp32: Forwarded to the block.
        fused_add_norm: Forwarded to the block.
        layer_idx: Index of this block; forwarded to the mixer and stored on the block.
        bidirectional: Forwarded to `BiMambaWrapper`.
        bidirectional_strategy: Forwarded to `BiMambaWrapper`.
        bidirectional_weight_tie: Forwarded to `BiMambaWrapper`.
        rcps: Use `RCPSMambaBlock` when true, the plain Mamba `Block` otherwise.
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.

    Returns:
        The constructed block, with `layer_idx` set as an attribute.
    """
    extra_mixer_kwargs = {} if ssm_cfg is None else ssm_cfg
    tensor_kwargs = dict(device=device, dtype=dtype)

    # The block instantiates the mixer and the norm itself, so hand it factories.
    make_mixer = partial(
        BiMambaWrapper,
        layer_idx=layer_idx,
        **extra_mixer_kwargs,
        bidirectional=bidirectional,
        bidirectional_strategy=bidirectional_strategy,
        bidirectional_weight_tie=bidirectional_weight_tie,
        **tensor_kwargs,
    )
    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    make_norm = partial(norm_type, eps=norm_epsilon, **tensor_kwargs)

    if rcps:
        block_type = RCPSMambaBlock
    else:
        block_type = Block
    new_block = block_type(
        d_model,
        make_mixer,
        norm_cls=make_norm,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    new_block.layer_idx = layer_idx
    return new_block
66
+
67
+
68
class BiMambaWrapper(nn.Module):
    """Thin wrapper around Mamba to support bi-directionality.

    A forward Mamba module always runs on the input. When `bidirectional` is set, a second Mamba module
    runs on the sequence-reversed input and its (re-reversed) output is combined with the forward output.
    """

    def __init__(
            self,
            d_model: int,
            bidirectional: bool = True,
            bidirectional_strategy: Optional[str] = "add",
            bidirectional_weight_tie: bool = True,
            **mamba_kwargs,
    ):
        """
        Args:
            d_model: Hidden dimension passed to the underlying Mamba modules.
            bidirectional: Whether to create and use the reverse-direction Mamba module.
            bidirectional_strategy: `"add"` or `"ew_multiply"`; `None` falls back to `"add"`.
            bidirectional_weight_tie: Share in / out projection parameters between both directions.
            **mamba_kwargs: Forwarded to every `Mamba` constructor call.

        Raises:
            NotImplementedError: If `bidirectional` is set and the strategy is not supported.
        """
        super().__init__()
        if bidirectional:
            if bidirectional_strategy is None:
                bidirectional_strategy = "add"  # Default strategy: `add`
            if bidirectional_strategy not in ("add", "ew_multiply"):
                raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)

        if not bidirectional:
            self.mamba_rev = None
            return

        self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)
        if bidirectional_weight_tie:
            # Tie in and out projections (where most of param count lies)
            for proj_name in ("in_proj", "out_proj"):
                source_proj = getattr(self.mamba_fwd, proj_name)
                tied_proj = getattr(self.mamba_rev, proj_name)
                tied_proj.weight = source_proj.weight
                tied_proj.bias = source_proj.bias

    def forward(self, hidden_states, inference_params=None):
        """Bidirectional-enabled forward pass

        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        fwd_out = self.mamba_fwd(hidden_states, inference_params=inference_params)
        if not self.bidirectional:
            return fwd_out

        # Flip along the sequence length dimension, run the reverse module, then flip back so positions
        # line up with the forward output.
        flipped_input = hidden_states.flip(dims=(1,))
        rev_out = self.mamba_rev(flipped_input, inference_params=inference_params).flip(dims=(1,))

        if self.bidirectional_strategy == "add":
            return fwd_out + rev_out
        if self.bidirectional_strategy == "ew_multiply":
            return fwd_out * rev_out
        raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
122
+
123
+
124
class CaduceusEmbeddings(nn.Module):
    """Token embedding layer: an `RCPSEmbedding` when `config.rcps` is set, a plain `nn.Embedding` otherwise."""

    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ):
        """
        Args:
            config: Model config; `rcps`, `vocab_size`, `d_model` and (for RCPS) `complement_map` are read.
            device: Device for the embedding parameters.
            dtype: Dtype for the embedding parameters.
        """
        super().__init__()
        tensor_kwargs = dict(device=device, dtype=dtype)
        if not config.rcps:
            self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **tensor_kwargs)
        else:
            self.word_embeddings = RCPSEmbedding(
                config.vocab_size, config.d_model, config.complement_map, **tensor_kwargs
            )

    def forward(self, input_ids):
        """
            input_ids: (batch, seqlen)
        """
        embedded = self.word_embeddings(input_ids)
        return embedded
145
+
146
+
147
class CaduceusMixerModel(nn.Module):
    """Caduceus backbone: embeddings, a stack of blocks built by `create_block`, and a final norm."""

    def __init__(
            self,
            config: CaduceusConfig,
            device=None,
            dtype=None,
    ) -> None:
        """
        Args:
            config: Model config.
            device: Device for the created parameters.
            dtype: Dtype for the created parameters.

        Raises:
            ImportError: If `config.fused_add_norm` is set but the Triton norm kernels failed to import.
        """
        super().__init__()
        tensor_kwargs = dict(device=device, dtype=dtype)

        self.fused_add_norm = config.fused_add_norm
        self.rcps = config.rcps
        self.residual_in_fp32 = config.residual_in_fp32

        self.embeddings = CaduceusEmbeddings(config, **tensor_kwargs)

        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        kernels_missing = layer_norm_fn is None or rms_norm_fn is None
        if config.fused_add_norm and kernels_missing:
            raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        blocks = []
        for layer_idx in range(config.n_layer):
            blocks.append(
                create_block(
                    config.d_model,
                    ssm_cfg=config.ssm_cfg,
                    norm_epsilon=config.norm_epsilon,
                    rms_norm=config.rms_norm,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=config.fused_add_norm,
                    layer_idx=layer_idx,
                    bidirectional=config.bidirectional,
                    bidirectional_strategy=config.bidirectional_strategy,
                    bidirectional_weight_tie=config.bidirectional_weight_tie,
                    rcps=config.rcps,
                    **tensor_kwargs,
                )
            )
        self.layers = nn.ModuleList(blocks)

        final_norm_type = RMSNorm if config.rms_norm else nn.LayerNorm
        final_norm = final_norm_type(config.d_model, eps=config.norm_epsilon, **tensor_kwargs)
        # Only the non-fused RCPS path calls the norm as a module with a residual, so only it needs the wrapper.
        if config.rcps and not config.fused_add_norm:
            final_norm = RCPSAddNormWrapper(final_norm)
        self.norm_f = final_norm

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        """Mixer forward.

        Args:
            input_ids: Token ids; only used when `inputs_embeds` is None.
            inputs_embeds: Optional pre-computed embeddings that bypass the embedding layer.
            output_hidden_states: When true, collect the input of every layer plus the final normalized output.

        Returns:
            Tuple of (final hidden states, list of collected hidden states; empty unless requested).
        """
        collected = []
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        residual = None
        for block in self.layers:
            if output_hidden_states:
                collected.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            hidden_states, residual = block(hidden_states, residual, inference_params=None)

        if self.fused_add_norm:
            norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            # Set prenorm=False here since we don't need the residual
            shared_kwargs = dict(
                eps=self.norm_f.eps,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
            if self.rcps:
                # Normalize each half of the channel dimension separately; the second half is flipped
                # along its last two dims before the norm and flipped back afterwards.
                half = hidden_states.shape[-1] // 2
                normed_fwd = norm_fn(
                    hidden_states[..., :half],
                    self.norm_f.weight,
                    self.norm_f.bias,
                    residual=residual[..., :half],
                    **shared_kwargs,
                )
                normed_rc = norm_fn(
                    hidden_states[..., half:].flip(dims=[-2, -1]),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    residual=residual[..., half:].flip(dims=[-2, -1]),
                    **shared_kwargs,
                )
                hidden_states = torch.cat([normed_fwd, normed_rc.flip(dims=[-2, -1])], dim=-1)
            else:
                hidden_states = norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    residual=residual,
                    **shared_kwargs,
                )
        elif self.rcps:
            hidden_states = self.norm_f(hidden_states, residual=residual)
        else:
            residual = hidden_states if residual is None else (hidden_states + residual)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))

        if output_hidden_states:
            collected.append(hidden_states)
        return hidden_states, collected
257
+
258
+
259
def cross_entropy(logits, y, ignore_index=-100):
    """Cross entropy loss.

    Args:
        logits: Tensor whose last dimension indexes classes; all leading dims are flattened.
        y: Integer targets with one entry per flattened logit row.
        ignore_index: Target value excluded from the loss.

    Returns:
        Scalar tensor with the mean cross entropy over the non-ignored targets.
    """
    num_classes = logits.shape[-1]
    flat_logits = logits.view(-1, num_classes)
    flat_targets = y.view(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)
264
+
265
+
266
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).

    Args:
        logits: Tensor whose last dimension indexes classes; all leading dims are flattened.
        y: Integer targets with one entry per flattened logit row.
        loss_weights: Per-target weights with as many elements as `y`. The caller's tensor is left untouched.
        ignore_index: Target value excluded from the loss (its weight is treated as zero).

    Returns:
        Scalar tensor: per-target cross entropy weighted by `loss_weights` normalized to sum to one.
        As with the unweighted loss, the result is NaN when every weight ends up zero.
    """
    # `reshape` also accepts non-contiguous inputs, where `view` would raise.
    logits = logits.reshape(-1, logits.shape[-1])
    y = y.reshape(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    # Out-of-place masking: a flattened view shares storage with the caller's tensor, so assigning into it
    # would silently zero the caller's weights.
    loss_weights = loss_weights.reshape(-1).masked_fill(y == ignore_index, 0.0)
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (loss_weights / loss_weights.sum())).sum()
275
+
276
+
277
class CaduceusPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for Caduceus backbone."""
    config_class = CaduceusConfig
    base_model_prefix = "caduceus"
    supports_gradient_checkpointing = False
    _no_split_modules = ["BiMambaWrapper"]

    def _init_weights(
            self,
            module,
            initializer_range=0.02,  # Now only used for embedding layer.
            **kwargs,
    ):
        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

        Zeroes `nn.Linear` biases (unless flagged `_no_reinit`), draws `nn.Embedding` weights from a normal
        distribution, and optionally re-initializes and rescales residual output projections.

        Args:
            module: Module to initialize.
            initializer_range: Fallback std for embedding weights when `initializer_cfg` does not set one.
            **kwargs: Accepted for interface compatibility; unused.
        """
        depth = self.config.n_layer
        cfg = {} if self.config.initializer_cfg is None else self.config.initializer_cfg
        rescale_residual = cfg.get("rescale_prenorm_residual", True)
        embedding_std = cfg.get("initializer_range", initializer_range)
        residuals_per_layer = cfg.get("n_residuals_per_layer", 1)

        if isinstance(module, nn.Linear):
            bias = module.bias
            if bias is not None and not getattr(bias, "_no_reinit", False):
                nn.init.zeros_(bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=embedding_std)

        if not rescale_residual:
            return

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth.
        #   > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
        #   residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for param_name, param in module.named_parameters():
            if param_name not in ("out_proj.weight", "fc2.weight"):
                continue
            # Follow the PyTorch default init, then scale by 1/sqrt(n_residuals_per_layer * n_layer).
            # The parameter is re-drawn first because this method can run several times on the same
            # module; scaling in place alone would shrink it a little more on every call.
            nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            with torch.no_grad():
                param /= math.sqrt(residuals_per_layer * depth)
322
+
323
+
324
class Caduceus(CaduceusPreTrainedModel):
    """Caduceus model that can be instantiated using HF patterns."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        """
        Note: `config` is modified in place when the vocabulary needs padding.

        Args:
            config: Model config.
            device: Device for the backbone parameters.
            dtype: Dtype for the backbone parameters.
            **kwargs: Forwarded to `CaduceusMixerModel`.
        """
        super().__init__(config)

        if config.rcps:
            assert config.complement_map is not None, "Complement map must be provided for RCPS."

        # Adjust vocab size and complement maps if vocab padding is set.
        overhang = config.vocab_size % config.pad_vocab_size_multiple
        if overhang != 0:
            config.vocab_size += config.pad_vocab_size_multiple - overhang
        complement_map = config.complement_map
        if complement_map is not None and config.vocab_size > len(complement_map):
            # Token ids beyond the original map are their own complement.
            for token_id in range(len(complement_map), config.vocab_size):
                complement_map[token_id] = token_id

        self.config = config
        tensor_kwargs = dict(device=device, dtype=dtype)
        self.backbone = CaduceusMixerModel(config, **tensor_kwargs, **kwargs)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
        """HF-compatible forward method.

        Args:
            input_ids: Token ids; used by the backbone when `inputs_embeds` is None.
            inputs_embeds: Optional pre-computed embeddings.
            output_hidden_states: Whether to return per-layer hidden states (defaults to the config value).
            return_dict: Whether to return a `BaseModelOutputWithNoAttention` (defaults to the config value).

        Returns:
            A `BaseModelOutputWithNoAttention` when `return_dict` is set; otherwise the pair
            (last hidden state, all hidden states) when hidden states are requested, or the bare
            last-hidden-state tensor when they are not.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        last_hidden, per_layer_hidden = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
        )
        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=last_hidden,
                hidden_states=per_layer_hidden if output_hidden_states else None,
            )
        if output_hidden_states:
            return last_hidden, per_layer_hidden
        return last_hidden
370
+
371
+
372
class CaduceusForMaskedLM(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for masked language modeling."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        """
        Args:
            config: Model config.
            device: Device for the created parameters.
            dtype: Dtype for the created parameters.
            **kwargs: Forwarded to the parent constructor and to `Caduceus`.
        """
        super().__init__(config, **kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        if config.rcps:
            self.lm_head = RCPSLMHead(
                complement_map=self.config.complement_map,  # Use caduceus config as it might have been updated
                vocab_size=self.config.vocab_size,  # Use caduceus config as it might have been updated
                true_dim=config.d_model,
                dtype=dtype
            )
        else:
            self.lm_head = nn.Linear(
                config.d_model,
                self.config.vocab_size,  # Use caduceus config as it might have been updated
                bias=False,
                **factory_kwargs
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's word-embedding module."""
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's word-embedding module (not supported for RCPS)."""
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Overrides output embeddings."""
        if self.config.rcps:
            raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights, accounting for RCPS."""
        if self.config.rcps:
            self.lm_head.set_weight(self.get_input_embeddings().weight)
        else:
            super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.caduceus

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.caduceus = decoder

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            loss_weights: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method.

        Args:
            input_ids: Token ids; used by the backbone when `inputs_embeds` is None.
            inputs_embeds: Optional pre-computed embeddings.
            labels: Optional targets; when given, a loss is computed.
            loss_weights: Optional per-token weights; when given, the weighted loss is used.
            output_hidden_states: Whether to return per-layer hidden states (defaults to the config value).
            return_dict: Whether to return a `MaskedLMOutput` (defaults to the config value).

        Returns:
            A `MaskedLMOutput`, or a tuple `([loss,] logits[, hidden_states])` when `return_dict` is false.
        """

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request the structured output from the backbone: with `return_dict=False` and no hidden
        # states it returns a bare tensor, on which `outputs[0]` would select the first batch element
        # instead of the last hidden state.
        outputs = self.caduceus(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # `pad_token_id` is None unless configured, and the loss functions need an int; fall back to
            # their own default of -100 in that case.
            ignore_index = self.config.pad_token_id if self.config.pad_token_id is not None else -100
            if loss_weights is not None:
                loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=ignore_index)
            else:
                loss = cross_entropy(logits, labels, ignore_index=ignore_index)

        if not return_dict:
            output = (logits,) if outputs.hidden_states is None else (logits, outputs.hidden_states)
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
473
+
474
+
475
class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    """Caduceus backbone with sequence pooling and a linear classification / regression head."""

    def __init__(
            self,
            config: CaduceusConfig,
            pooling_strategy: str = "mean",
            conjoin_train: bool = False,
            conjoin_eval: bool = False,
            device=None,
            dtype=None,
            **kwargs):
        """
        Args:
            config: Model config.
            pooling_strategy: One of `"mean"`, `"max"`, `"first"`, `"last"`.
            conjoin_train: Run forward and reverse-complement strands through the backbone (always).
            conjoin_eval: Run both strands through the backbone only when not in training mode.
            device: Device for the backbone parameters.
            dtype: Dtype for the backbone parameters.
            **kwargs: Forwarded to the parent constructor and to `Caduceus`; `num_labels` is also read here.

        Raises:
            NotImplementedError: If `pooling_strategy` is not supported.
        """
        super().__init__(config, **kwargs)
        if pooling_strategy not in ["mean", "max", "first", "last"]:
            raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
        self.pooling_strategy = pooling_strategy
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)

        self.conjoin_train = conjoin_train
        self.conjoin_eval = conjoin_eval

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's word-embedding module."""
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's word-embedding module (not supported for RCPS)."""
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
        """Pools hidden states along sequence length dimension."""
        if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
            return hidden_states.max(dim=sequence_length_dim).values
        # For first / last, bring the sequence dimension to the front and index it.
        if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
            return torch.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns a `SequenceClassifierOutput`, or a tuple `([loss,] logits[, hidden_states])` when `return_dict`
        is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone is always asked for the structured output: with `return_dict=False` and no hidden
        # states it returns a bare tensor, on which `[0]` would select the first batch element.

        # Get hidden representations from the backbone
        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            last_hidden = transformer_outputs.last_hidden_state
            # Split the channels into two equal halves so that each half matches the width expected by
            # `score`; the second half is flipped along the sequence and channel dims.
            half = last_hidden.shape[-1] // 2
            hidden_states = torch.stack(
                [
                    last_hidden[..., :half],
                    torch.flip(last_hidden[..., half:], dims=[1, 2])
                ],
                dim=-1
            )
        elif self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining
            assert input_ids is not None, "`input_ids` must be provided for conjoining."
            assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
            transformer_outputs = self.caduceus(
                input_ids[..., 0],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            transformer_outputs_rc = self.caduceus(
                input_ids[..., 1],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            # Stack along channel dimension (dim=-1)
            hidden_states = torch.stack(
                [transformer_outputs.last_hidden_state, transformer_outputs_rc.last_hidden_state], dim=-1
            )
        else:
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,  # Pass through instead of silently dropping the caller's embeddings
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )
            hidden_states = transformer_outputs.last_hidden_state

        # Pool and get logits
        pooled_hidden_states = self.pool_hidden_states(hidden_states)
        # Potentially run `score` twice (with parameters shared) for conjoining
        if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
            logits_fwd = self.score(pooled_hidden_states[..., 0])
            logits_rc = self.score(pooled_hidden_states[..., 1])
            logits = (logits_fwd + logits_rc) / 2
        else:
            logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from the number of labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        # When conjoining, the reported hidden states are those of the forward-strand pass.
        all_hidden_states = transformer_outputs.hidden_states
        if not return_dict:
            output = (logits,) if all_hidden_states is None else (logits, all_hidden_states)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
        )
modeling_rcps.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reverse-complement equivariant modules.
2
+
3
+ """
4
+ from collections import OrderedDict
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ try:
13
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
14
+ except ImportError:
15
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
16
+
17
+
18
class RCPSEmbedding(nn.Module):
    """Embedding layer that supports reverse-complement (RC) equivariance.

    The forward strand is embedded as-is; the reverse-complemented ids are embedded with the same
    matrix and then flipped back along the length and channel dimensions. Both results are
    concatenated channel-wise, so the output has twice the width of the embedding matrix.
    """

    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
        """
        Args:
            vocab_size: Size of vocabulary.
            d_model: Width of the underlying embedding matrix; `forward` returns `2 * d_model` channels.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints or
                numeric strings (as produced by a JSON round-trip) and may appear in any order.
            **factory_kwargs: Forwarded unchanged to `nn.Embedding`.
        """
        super().__init__()
        # `rc` uses this buffer as a lookup table indexed by token id, so entry `i` must hold the
        # complement of id `i`. Sort by the numeric key instead of trusting dict insertion order.
        ordered_complements = [
            complement
            for _, complement in sorted(complement_map.items(), key=lambda item: int(item[0]))
        ]
        self.register_buffer(
            "complement_map",
            torch.tensor(ordered_complements, dtype=torch.long)
        )
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    @property
    def weight(self):
        """Embedding weights."""
        return self.embedding.weight

    def set_weight(self, value):
        """Set embedding weights."""
        self.embedding.weight = value

    def rc(self, x):
        """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
        # Indexing the 1-D table yields the same values as gathering from a batch-expanded copy.
        return self.complement_map[torch.flip(x, dims=[-1])]

    def forward(self, input_ids):
        """Reverse-complement equivariant forward pass.

        This embedding module doubles the output dimensionality to support reverse-complement equivariance.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
        Returns:
            Embedding tensor of shape (batch_size, seq_len, d_model * 2)
        """
        fwd_out = self.embedding(input_ids)
        # Embed the RC strand, then flip length and channels so it lines up with the forward strand.
        rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])

        return torch.cat([fwd_out, rc_out], dim=-1)
65
+
66
+
67
class RCPSWrapper(nn.Module):
    """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.

    The wrapped module is applied twice with shared parameters: once to the first half of the
    channels, and once to the reverse-complemented second half.

    See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
    Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
    """

    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule

    @staticmethod
    def rc(x):
        """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
        return x.flip(dims=(-2, -1))

    def forward(self, x, **kwargs):
        """Reverse-complement equivariant forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            **kwargs: Extra keyword arguments passed to both submodule calls.
        Returns:
            Channel-wise concatenation of the forward-strand output and the re-flipped RC-strand output.
        """
        half = x.shape[-1] // 2
        fwd_strand = x[..., :half]
        rc_strand = self.rc(x[..., half:])

        # Same submodule (shared weights) is applied to each strand.
        fwd_out = self.submodule(fwd_strand, **kwargs)
        rc_out = self.submodule(rc_strand, **kwargs)

        # Flip the RC result back before joining along the channel dimension (dim=-1).
        return torch.cat((fwd_out, self.rc(rc_out)), dim=-1)
97
+
98
+
99
class RCPSAddNormWrapper(RCPSWrapper):
    """RC equivariant AddNorm layer."""

    def __init__(self, submodule: nn.Module):
        super().__init__(submodule)

    def forward(self, x, residual=None):
        """Add the residual (if any) and normalize each strand with the shared submodule.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
        Returns:
            Tuple `(normalized, residual)`. When `residual` is None the returned residual is `x` itself;
            otherwise it is the per-strand sum of `x` and `residual`.
        """
        half = x.shape[-1] // 2
        norm_dtype = self.submodule.weight.dtype

        if residual is not None:
            # Forward strand: add, then normalize.
            residual_fwd = x[..., :half] + residual[..., :half]
            normed_fwd = self.submodule(residual_fwd.to(dtype=norm_dtype))

            # RC strand: flip both operands, add, then normalize.
            residual_rc = self.rc(x[..., half:]) + self.rc(residual[..., half:])
            normed_rc = self.submodule(residual_rc.to(dtype=norm_dtype))

            new_residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
            normed = torch.cat([normed_fwd, self.rc(normed_rc)], dim=-1)
            return normed, new_residual

        # No incoming residual: the input becomes the residual stream unchanged.
        normed_fwd = self.submodule(x[..., :half].to(dtype=norm_dtype))
        normed_rc = self.submodule(self.rc(x[..., half:]).to(dtype=norm_dtype))
        normed = torch.cat([normed_fwd, self.rc(normed_rc)], dim=-1)
        return normed, x
127
+
128
+
129
class RCPSMambaBlock(nn.Module):
    """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
    """

    def __init__(
            self,
            dim,
            mixer_cls,
            norm_cls=nn.LayerNorm,
            fused_add_norm=False,
            residual_in_fp32=False,
            device=None,  # Keep for consistency with original Mamba Block
            dtype=None,  # Keep for consistency with original Mamba Block
    ):
        """
        Args:
            dim: Passed to both `mixer_cls` and `norm_cls` as their only argument.
            mixer_cls: Callable building the mixer; the result is wrapped in `RCPSWrapper`.
            norm_cls: Callable building the norm layer.
            fused_add_norm: If True, use the fused add+norm functions in `forward` and keep the bare norm
                module; otherwise wrap the norm in `RCPSAddNormWrapper`.
            residual_in_fp32: Whether the residual stream is kept in float32.
            device: Unused.
            dtype: Unused.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = RCPSWrapper(mixer_cls(dim))
        norm_f = norm_cls(dim)
        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual)).
            inference_params: inference parameters for mixer.
        Returns:
            Tuple `(hidden_states, residual)`: the mixer output and the updated residual stream.
        """
        if not self.fused_add_norm:
            hidden_states, residual = self.norm(hidden_states, residual=residual)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            half = hidden_states.shape[-1] // 2

            # NOTE(review): this branch reads the "fwd" strand from the UPPER half of the channels and
            # the "rc" strand from the LOWER half, the opposite of RCPSAddNormWrapper, while still
            # concatenating as [fwd, rc]. It is preserved exactly because changing it would alter the
            # numerical output of this path; confirm against existing checkpoints before touching it.
            hidden_states_fwd, residual_fwd = fused_add_norm_fn(
                hidden_states[..., half:],
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., half:] if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

            hidden_states_rc, residual_rc = fused_add_norm_fn(
                hidden_states[..., :half].flip(dims=[-2, -1]),
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., :half].flip(dims=[-2, -1]) if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            # Flip the RC results back before joining along the channel dimension.
            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for mixer.

        Keep for compatibility with original Mamba Block.
        """
        # `self.mixer` is the RC wrapper, which defines no cache method of its own; the wrapped
        # mixer it holds in `.submodule` is the object that must provide it.
        return self.mixer.submodule.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
198
+
199
+
200
class RCPSLMHead(nn.Module):
    """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""

    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
        """
        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
        equivariant, i.e. 0.5 times the actual input dim.

        Args:
            true_dim: Input width of the underlying linear projection.
            vocab_size: Number of output logits.
            complement_map: Dictionary mapping each token id to its complement. Keys may be ints or
                numeric strings (as produced by a JSON round-trip) and may appear in any order.
            **factory_kwargs: Forwarded unchanged to `nn.Linear`.
        """
        super().__init__()
        # `forward` uses this buffer to permute weight rows by token id, so entry `i` must hold the
        # complement of id `i`. Sort by the numeric key instead of trusting dict insertion order.
        ordered_complements = [
            complement
            for _, complement in sorted(complement_map.items(), key=lambda item: int(item[0]))
        ]
        self.register_buffer(
            "complement_map",
            torch.tensor(ordered_complements, dtype=torch.long)
        )
        self.true_dim = true_dim
        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)

    @property
    def weight(self):
        """LM head weights."""
        return self.lm_head.weight

    def set_weight(self, value):
        """Set LM head weights."""
        self.lm_head.weight = value

    def forward(self, x):
        """Project both strands to vocabulary logits and sum them.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
        Returns:
            Logits whose last dimension is `vocab_size`.
        """
        n_channels = x.shape[-1]
        assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
        fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
        # RC strand: undo the channel flip and score against the complement-permuted weight rows.
        rc_logits = F.linear(
            torch.flip(x[..., n_channels // 2:], dims=[-1]),
            self.weight[self.complement_map, :],
            bias=self.lm_head.bias
        )
        return fwd_logits + rc_logits