yairschiff committed on
Commit 1c3660c
1 Parent(s): 971c530

Upload CaduceusForMaskedLM

Files changed (6)
  1. README.md +201 -0
  2. config.json +63 -0
  3. configuration_caduceus.py +55 -0
  4. model.safetensors +3 -0
  5. modeling_caduceus.py +615 -0
  6. modeling_rcps.py +243 -0
README.md ADDED
@@ -0,0 +1,201 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+
201
+
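The "How to Get Started with the Model" section of the README above is still a placeholder. As a hedged, untested sketch of how this checkpoint is meant to be loaded (based on the `auto_map` entries in `config.json` below and the imports in `modeling_caduceus.py`), something along these lines should work; the repository id is a placeholder, and the random `input_ids` stand in for a real tokenizer, which is not part of this commit:

```python
# Hedged sketch only: repo_id is a placeholder; the commit ships no tokenizer, so random
# token ids in [0, vocab_size) are used purely for illustration. Requires the mamba-ssm
# package (and a CUDA device for its kernels), per the imports in modeling_caduceus.py.
import torch
from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "<repo-id>"  # placeholder, not specified in this commit
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).to("cuda")
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 128), device="cuda")
with torch.no_grad():
    logits = model(input_ids=input_ids).logits  # (1, 128, vocab_size)
print(logits.shape)
```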
config.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "architectures": [
3
+ "CaduceusForMaskedLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_caduceus.CaduceusConfig",
7
+ "AutoModel": "modeling_caduceus.Caduceus",
8
+ "AutoModelForMaskedLM": "modeling_caduceus.CaduceusForMaskedLM",
9
+ "AutoModelForSequenceClassification": "modeling_caduceus.CaduceusForSequenceClassification"
10
+ },
11
+ "bidirectional": true,
12
+ "bidirectional_strategy": "add",
13
+ "bidirectional_weight_tie": true,
14
+ "complement_map": {
15
+ "0": 0,
16
+ "1": 1,
17
+ "2": 2,
18
+ "3": 3,
19
+ "4": 4,
20
+ "5": 5,
21
+ "6": 6,
22
+ "7": 10,
23
+ "8": 9,
24
+ "9": 8,
25
+ "10": 7,
26
+ "11": 11,
27
+ "12": 12,
28
+ "13": 13,
29
+ "14": 14,
30
+ "15": 15
31
+ },
32
+ "d_model": 118,
33
+ "fused_add_norm": true,
34
+ "initializer_cfg": {
35
+ "initializer_range": 0.02,
36
+ "n_residuals_per_layer": 1,
37
+ "rescale_prenorm_residual": true
38
+ },
39
+ "model_type": "caduceus",
40
+ "n_layer": 4,
41
+ "norm_epsilon": 1e-05,
42
+ "pad_vocab_size_multiple": 8,
43
+ "rcps": false,
44
+ "residual_in_fp32": false,
45
+ "rms_norm": true,
46
+ "ssm_cfg": {
47
+ "bias": false,
48
+ "conv_bias": true,
49
+ "d_conv": 4,
50
+ "d_state": 16,
51
+ "dt_init": "random",
52
+ "dt_init_floor": 0.0001,
53
+ "dt_max": 0.1,
54
+ "dt_min": 0.001,
55
+ "dt_rank": "auto",
56
+ "dt_scale": 1.0,
57
+ "expand": 2,
58
+ "use_fast_path": true
59
+ },
60
+ "torch_dtype": "float32",
61
+ "transformers_version": "4.38.1",
62
+ "vocab_size": 16
63
+ }
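Two details of this config are worth spelling out. First, `"rcps": false` means the reverse-complement-equivariant modules in `modeling_rcps.py` are not active for this checkpoint, although the `complement_map` is still carried in the config. Second, the `complement_map` swaps ids 7↔10 and 8↔9 (presumably the four nucleotide tokens; the tokenizer is not part of this commit) and maps every other id to itself. A minimal sketch of what that map encodes, mirroring the gather-and-flip logic in `RCPSEmbedding.rc`:

```python
# Minimal sketch: reverse-complementing a batch of token ids with the complement_map from
# config.json. RCPSEmbedding.rc in modeling_rcps.py implements the same operation via
# torch.gather; plain indexing is used here for clarity.
import torch

complement_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 10, 9, 8, 7, 11, 12, 13, 14, 15])

def reverse_complement(input_ids: torch.Tensor) -> torch.Tensor:
    """input_ids: (batch, seq_len) with values in [0, 16)."""
    return complement_map[input_ids.flip(dims=[-1])]

ids = torch.tensor([[7, 8, 9, 10, 4]])
print(reverse_complement(ids))  # tensor([[ 4,  7,  8,  9, 10]])
```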
configuration_caduceus.py ADDED
@@ -0,0 +1,55 @@
1
+ """Caduceus config for Hugging Face.
2
+
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class CaduceusConfig(PretrainedConfig):
11
+ """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
12
+ model_type = "caduceus"
13
+
14
+ def __init__(
15
+ self,
16
+ # From original MambaConfig
17
+ d_model: int = 2560,
18
+ n_layer: int = 64,
19
+ vocab_size: int = 50277,
20
+ ssm_cfg: Optional[dict] = None,
21
+ rms_norm: bool = True,
22
+ residual_in_fp32: bool = True,
23
+ fused_add_norm: bool = True,
24
+ pad_vocab_size_multiple: int = 8,
25
+
26
+ # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
27
+ norm_epsilon: float = 1e-5,
28
+
29
+ # Used in init_weights
30
+ initializer_cfg: Optional[dict] = None,
31
+
32
+ # Caduceus-specific params
33
+ bidirectional: bool = True,
34
+ bidirectional_strategy: Union[str, None] = "add",
35
+ bidirectional_weight_tie: bool = True,
36
+ rcps: bool = False,
37
+ complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
38
+ **kwargs,
39
+ ):
40
+ super().__init__(**kwargs)
41
+ self.d_model = d_model
42
+ self.n_layer = n_layer
43
+ self.vocab_size = vocab_size
44
+ self.ssm_cfg = ssm_cfg
45
+ self.rms_norm = rms_norm
46
+ self.residual_in_fp32 = residual_in_fp32
47
+ self.fused_add_norm = fused_add_norm
48
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
49
+ self.norm_epsilon = norm_epsilon
50
+ self.initializer_cfg = initializer_cfg
51
+ self.bidirectional = bidirectional
52
+ self.bidirectional_strategy = bidirectional_strategy
53
+ self.bidirectional_weight_tie = bidirectional_weight_tie
54
+ self.rcps = rcps
55
+ self.complement_map = complement_map
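Note that the defaults in `CaduceusConfig.__init__` mirror the large original MambaConfig (d_model=2560, n_layer=64, vocab_size=50277), while the checkpoint in this commit uses the much smaller values from `config.json`. A small sketch, assuming `configuration_caduceus.py` is importable from the working directory, of constructing the config with the checkpoint's actual hyperparameters:

```python
# Sketch: building CaduceusConfig with the values this checkpoint's config.json uses,
# rather than the Mamba-sized defaults above. Values are copied from config.json.
from configuration_caduceus import CaduceusConfig

config = CaduceusConfig(
    d_model=118,
    n_layer=4,
    vocab_size=16,
    rms_norm=True,
    residual_in_fp32=False,
    fused_add_norm=True,
    pad_vocab_size_multiple=8,
    norm_epsilon=1e-5,
    bidirectional=True,
    bidirectional_strategy="add",
    bidirectional_weight_tie=True,
    rcps=False,
)
print(config.model_type, config.d_model, config.n_layer)  # caduceus 118 4
```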
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1ffdad48c143215f5f05ae7d15d1eba16da2c2a587df4a56279fa794929b637
3
+ size 1891152
modeling_caduceus.py ADDED
@@ -0,0 +1,615 @@
1
+ """Caduceus model for Hugging Face.
2
+
3
+ """
4
+
5
+ import math
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from mamba_ssm.modules.mamba_simple import Mamba, Block
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from transformers import PreTrainedModel
14
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
15
+
16
+ try:
17
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18
+ except ImportError:
19
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
20
+
21
+ from .configuration_caduceus import CaduceusConfig
22
+ from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
23
+
24
+
25
+ def create_block(
26
+ d_model,
27
+ ssm_cfg=None,
28
+ norm_epsilon=1e-5,
29
+ rms_norm=False,
30
+ residual_in_fp32=False,
31
+ fused_add_norm=False,
32
+ layer_idx=None,
33
+ bidirectional=True,
34
+ bidirectional_strategy="add",
35
+ bidirectional_weight_tie=True,
36
+ rcps=False,
37
+ device=None,
38
+ dtype=None,
39
+ ):
40
+ """Create Caduceus block.
41
+
42
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
43
+ """
44
+ if ssm_cfg is None:
45
+ ssm_cfg = {}
46
+ factory_kwargs = {"device": device, "dtype": dtype}
47
+ bidirectional_kwargs = {
48
+ "bidirectional": bidirectional,
49
+ "bidirectional_strategy": bidirectional_strategy,
50
+ "bidirectional_weight_tie": bidirectional_weight_tie,
51
+ }
52
+ mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
53
+ norm_cls = partial(
54
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
55
+ )
56
+ block_cls = RCPSMambaBlock if rcps else Block
57
+ block = block_cls(
58
+ d_model,
59
+ mixer_cls,
60
+ norm_cls=norm_cls,
61
+ fused_add_norm=fused_add_norm,
62
+ residual_in_fp32=residual_in_fp32,
63
+ )
64
+ block.layer_idx = layer_idx
65
+ return block
66
+
67
+
68
+ class BiMambaWrapper(nn.Module):
69
+ """Thin wrapper around Mamba to support bi-directionality."""
70
+
71
+ def __init__(
72
+ self,
73
+ d_model: int,
74
+ bidirectional: bool = True,
75
+ bidirectional_strategy: Optional[str] = "add",
76
+ bidirectional_weight_tie: bool = True,
77
+ **mamba_kwargs,
78
+ ):
79
+ super().__init__()
80
+ if bidirectional and bidirectional_strategy is None:
81
+ bidirectional_strategy = "add" # Default strategy: `add`
82
+ if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
83
+ raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
84
+ self.bidirectional = bidirectional
85
+ self.bidirectional_strategy = bidirectional_strategy
86
+ self.mamba_fwd = Mamba(
87
+ d_model=d_model,
88
+ **mamba_kwargs
89
+ )
90
+ if bidirectional:
91
+ self.mamba_rev = Mamba(
92
+ d_model=d_model,
93
+ **mamba_kwargs
94
+ )
95
+ if bidirectional_weight_tie: # Tie in and out projections (where most of param count lies)
96
+ self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
97
+ self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
98
+ self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
99
+ self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
100
+ else:
101
+ self.mamba_rev = None
102
+
103
+ def forward(self, hidden_states, inference_params=None):
104
+ """Bidirectional-enabled forward pass
105
+
106
+ hidden_states: (B, L, D)
107
+ Returns: same shape as hidden_states
108
+ """
109
+ out = self.mamba_fwd(hidden_states, inference_params=inference_params)
110
+ if self.bidirectional:
111
+ out_rev = self.mamba_rev(
112
+ hidden_states.flip(dims=(1,)), # Flip along the sequence length dimension
113
+ inference_params=inference_params
114
+ ).flip(dims=(1,)) # Flip back for combining with forward hidden states
115
+ if self.bidirectional_strategy == "add":
116
+ out = out + out_rev
117
+ elif self.bidirectional_strategy == "ew_multiply":
118
+ out = out * out_rev
119
+ else:
120
+ raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
121
+ return out
122
+
123
+
124
+ class CaduceusEmbeddings(nn.Module):
125
+ def __init__(
126
+ self,
127
+ config: CaduceusConfig,
128
+ device=None,
129
+ dtype=None,
130
+ ):
131
+ super().__init__()
132
+ factory_kwargs = {"device": device, "dtype": dtype}
133
+ if config.rcps:
134
+ self.word_embeddings = RCPSEmbedding(
135
+ config.vocab_size, config.d_model, config.complement_map, **factory_kwargs
136
+ )
137
+ else:
138
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
139
+
140
+ def forward(self, input_ids):
141
+ """
142
+ input_ids: (batch, seqlen)
143
+ """
144
+ return self.word_embeddings(input_ids)
145
+
146
+
147
+ class CaduceusMixerModel(nn.Module):
148
+ def __init__(
149
+ self,
150
+ config: CaduceusConfig,
151
+ device=None,
152
+ dtype=None,
153
+ ) -> None:
154
+ super().__init__()
155
+ factory_kwargs = {"device": device, "dtype": dtype}
156
+
157
+ self.fused_add_norm = config.fused_add_norm
158
+ self.rcps = config.rcps
159
+ self.residual_in_fp32 = config.residual_in_fp32
160
+
161
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
162
+
163
+ # Mamba changes the order of residual and layer norm:
164
+ # Instead of LN -> Attn / MLP -> Add, we do:
165
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
166
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
167
+ # This is for performance reasons: we can fuse add + layer_norm.
168
+ if config.fused_add_norm:
169
+ if layer_norm_fn is None or rms_norm_fn is None:
170
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
171
+
172
+ self.layers = nn.ModuleList(
173
+ [
174
+ create_block(
175
+ config.d_model,
176
+ ssm_cfg=config.ssm_cfg,
177
+ norm_epsilon=config.norm_epsilon,
178
+ rms_norm=config.rms_norm,
179
+ residual_in_fp32=config.residual_in_fp32,
180
+ fused_add_norm=config.fused_add_norm,
181
+ layer_idx=i,
182
+ bidirectional=config.bidirectional,
183
+ bidirectional_strategy=config.bidirectional_strategy,
184
+ bidirectional_weight_tie=config.bidirectional_weight_tie,
185
+ rcps=config.rcps,
186
+ **factory_kwargs,
187
+ )
188
+ for i in range(config.n_layer)
189
+ ]
190
+ )
191
+
192
+ norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
193
+ config.d_model, eps=config.norm_epsilon, **factory_kwargs
194
+ )
195
+ self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)
196
+
197
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
198
+ """Mixer forward."""
199
+ all_hidden_states = []
200
+ if inputs_embeds is not None:
201
+ hidden_states = inputs_embeds
202
+ else:
203
+ hidden_states = self.embeddings(input_ids)
204
+
205
+ residual = None
206
+ for layer in self.layers:
207
+ if output_hidden_states:
208
+ all_hidden_states.append(hidden_states)
209
+ # TODO: Add support for gradient checkpointing
210
+ hidden_states, residual = layer(
211
+ hidden_states, residual, inference_params=None
212
+ )
213
+
214
+ if not self.fused_add_norm:
215
+ if self.rcps:
216
+ # Set prenorm=False here since we don't need the residual
217
+ hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
218
+ else:
219
+ residual = (hidden_states + residual) if residual is not None else hidden_states
220
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
221
+ else:
222
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
223
+ if self.rcps:
224
+ # Set prenorm=False here since we don't need the residual
225
+ hidden_states_fwd = fused_add_norm_fn(
226
+ hidden_states[..., :hidden_states.shape[-1] // 2],
227
+ self.norm_f.weight,
228
+ self.norm_f.bias,
229
+ eps=self.norm_f.eps,
230
+ residual=residual[..., :hidden_states.shape[-1] // 2],
231
+ prenorm=False,
232
+ residual_in_fp32=self.residual_in_fp32,
233
+ )
234
+ hidden_states_rc = fused_add_norm_fn(
235
+ hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
236
+ self.norm_f.weight,
237
+ self.norm_f.bias,
238
+ eps=self.norm_f.eps,
239
+ residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
240
+ prenorm=False,
241
+ residual_in_fp32=self.residual_in_fp32,
242
+ )
243
+ hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
244
+ else:
245
+ # Set prenorm=False here since we don't need the residual
246
+ hidden_states = fused_add_norm_fn(
247
+ hidden_states,
248
+ self.norm_f.weight,
249
+ self.norm_f.bias,
250
+ eps=self.norm_f.eps,
251
+ residual=residual,
252
+ prenorm=False,
253
+ residual_in_fp32=self.residual_in_fp32,
254
+ )
255
+ if output_hidden_states:
256
+ all_hidden_states.append(hidden_states)
257
+ return hidden_states, all_hidden_states
258
+
259
+
260
+ def cross_entropy(logits, y, ignore_index=-100):
261
+ """Cross entropy loss."""
262
+ logits = logits.view(-1, logits.shape[-1])
263
+ y = y.view(-1)
264
+ return F.cross_entropy(logits, y, ignore_index=ignore_index)
265
+
266
+
267
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
268
+ """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
269
+ logits = logits.view(-1, logits.shape[-1])
270
+ y = y.view(-1)
271
+ ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
272
+ loss_weights = loss_weights.view(-1)
273
+ loss_weights[y == ignore_index] = 0.0
274
+ # TODO: Follows GPN implementation, but should we remove weight normalization?
275
+ return (ce * (loss_weights / loss_weights.sum())).sum()
276
+
277
+
278
+ class CaduceusPreTrainedModel(PreTrainedModel):
279
+ """PreTrainedModel wrapper for Caduceus backbone."""
280
+ config_class = CaduceusConfig
281
+ base_model_prefix = "caduceus"
282
+ supports_gradient_checkpointing = False
283
+ _no_split_modules = ["BiMambaWrapper"]
284
+
285
+ def _init_weights(
286
+ self,
287
+ module,
288
+ initializer_range=0.02, # Now only used for embedding layer.
289
+ **kwargs,
290
+ ):
291
+ """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
292
+
293
+ n_layer = self.config.n_layer
294
+ initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
295
+ rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
296
+ initializer_range = initialized_cfg.get("initializer_range", initializer_range)
297
+ n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
298
+
299
+ if isinstance(module, nn.Linear):
300
+ if module.bias is not None:
301
+ if not getattr(module.bias, "_no_reinit", False):
302
+ nn.init.zeros_(module.bias)
303
+ elif isinstance(module, nn.Embedding):
304
+ nn.init.normal_(module.weight, std=initializer_range)
305
+
306
+ if rescale_prenorm_residual:
307
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
308
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth.
309
+ # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
310
+ # residual layers.
311
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
312
+ #
313
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
314
+ for name, p in module.named_parameters():
315
+ if name in ["out_proj.weight", "fc2.weight"]:
316
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
317
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
318
+ # We need to reinit p since this code could be called multiple times
319
+ # Having just p *= scale would repeatedly scale it down
320
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
321
+ with torch.no_grad():
322
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
323
+
324
+
325
+ class Caduceus(CaduceusPreTrainedModel):
326
+ """Caduceus model that can be instantiated using HF patterns."""
327
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
328
+ super().__init__(config)
329
+
330
+ if config.rcps:
331
+ assert config.complement_map is not None, "Complement map must be provided for RCPS."
332
+
333
+ # Adjust vocab size and complement maps if vocab padding is set.
334
+ if config.vocab_size % config.pad_vocab_size_multiple != 0:
335
+ config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
336
+ if config.complement_map is not None and config.vocab_size > len(config.complement_map):
337
+ for i in range(len(config.complement_map), config.vocab_size):
338
+ config.complement_map[i] = i
339
+
340
+ self.config = config
341
+ factory_kwargs = {"device": device, "dtype": dtype}
342
+ self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
343
+
344
+ def forward(
345
+ self,
346
+ input_ids: torch.LongTensor = None,
347
+ inputs_embeds: Optional[torch.FloatTensor] = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
351
+ """HF-compatible forward method."""
352
+ output_hidden_states = (
353
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
354
+ )
355
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
356
+
357
+ hidden_states, all_hidden_states = self.backbone(
358
+ input_ids,
359
+ inputs_embeds=inputs_embeds,
360
+ output_hidden_states=output_hidden_states
361
+ )
362
+ if return_dict:
363
+ return BaseModelOutputWithNoAttention(
364
+ last_hidden_state=hidden_states,
365
+ hidden_states=all_hidden_states if output_hidden_states else None
366
+ )
367
+ elif output_hidden_states:
368
+ return hidden_states, all_hidden_states
369
+ else:
370
+ return hidden_states
371
+
372
+
373
+ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
374
+ """HF-compatible Caduceus model for masked language modeling."""
375
+
376
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
377
+ super().__init__(config, **kwargs)
378
+ factory_kwargs = {"device": device, "dtype": dtype}
379
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
380
+ if config.rcps:
381
+ self.lm_head = RCPSLMHead(
382
+ complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
383
+ vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
384
+ true_dim=config.d_model,
385
+ dtype=dtype
386
+ )
387
+ else:
388
+ self.lm_head = nn.Linear(
389
+ config.d_model,
390
+ self.config.vocab_size, # Use caduceus config as it might have been updated
391
+ bias=False,
392
+ **factory_kwargs
393
+ )
394
+
395
+ # Initialize weights and apply final processing
396
+ self.post_init()
397
+
398
+ def get_input_embeddings(self):
399
+ return self.caduceus.backbone.embeddings.word_embeddings
400
+
401
+ def set_input_embeddings(self, value):
402
+ if self.config.rcps:
403
+ raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
404
+ self.caduceus.backbone.embeddings.word_embeddings = value
405
+
406
+ def get_output_embeddings(self):
407
+ return self.lm_head
408
+
409
+ def set_output_embeddings(self, new_embeddings):
410
+ """Overrides output embeddings."""
411
+ if self.config.rcps:
412
+ raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
413
+ self.lm_head = new_embeddings
414
+
415
+ def tie_weights(self):
416
+ """Tie weights, accounting for RCPS."""
417
+ if self.config.rcps:
418
+ self.lm_head.set_weight(self.get_input_embeddings().weight)
419
+ else:
420
+ super().tie_weights()
421
+
422
+ def get_decoder(self):
423
+ """Get decoder (backbone) for the model."""
424
+ return self.caduceus
425
+
426
+ def set_decoder(self, decoder):
427
+ """Set decoder (backbone) for the model."""
428
+ self.caduceus = decoder
429
+
430
+ def forward(
431
+ self,
432
+ input_ids: torch.LongTensor = None,
433
+ inputs_embeds: Optional[torch.FloatTensor] = None,
434
+ labels: Optional[torch.LongTensor] = None,
435
+ loss_weights: Optional[torch.FloatTensor] = None,
436
+ output_hidden_states: Optional[bool] = None,
437
+ return_dict: Optional[bool] = None,
438
+ ) -> Union[Tuple, MaskedLMOutput]:
439
+ """HF-compatible forward method."""
440
+
441
+ output_hidden_states = (
442
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
443
+ )
444
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
445
+
446
+ # Backbone outputs consist of (last_hidden_state, hidden_states)
447
+ outputs = self.caduceus(
448
+ input_ids=input_ids,
449
+ inputs_embeds=inputs_embeds,
450
+ output_hidden_states=output_hidden_states,
451
+ return_dict=return_dict,
452
+ )
453
+
454
+ hidden_states = outputs[0]
455
+ logits = self.lm_head(hidden_states)
456
+ logits = logits.float()
457
+
458
+ loss = None
459
+ if labels is not None:
460
+ if loss_weights is not None:
461
+ loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
462
+ else:
463
+ loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)
464
+
465
+ if not return_dict:
466
+ output = (logits,) + outputs[1:]
467
+ return (loss,) + output if loss is not None else output
468
+
469
+ return MaskedLMOutput(
470
+ loss=loss,
471
+ logits=logits,
472
+ hidden_states=outputs.hidden_states,
473
+ )
474
+
475
+
476
+ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
477
+ def __init__(
478
+ self,
479
+ config: CaduceusConfig,
480
+ pooling_strategy: str = "mean",
481
+ conjoin_train: bool = False,
482
+ conjoin_eval: bool = False,
483
+ device=None,
484
+ dtype=None,
485
+ **kwargs):
486
+ super().__init__(config, **kwargs)
487
+ if pooling_strategy not in ["mean", "max", "first", "last"]:
488
+ raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
489
+ self.pooling_strategy = pooling_strategy
490
+ factory_kwargs = {"device": device, "dtype": dtype}
491
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
492
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
493
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
494
+
495
+ self.conjoin_train = conjoin_train
496
+ self.conjoin_eval = conjoin_eval
497
+
498
+ # Initialize weights and apply final processing
499
+ self.post_init()
500
+
501
+ def get_input_embeddings(self):
502
+ return self.caduceus.backbone.embeddings.word_embeddings
503
+
504
+ def set_input_embeddings(self, value):
505
+ if self.config.rcps:
506
+ raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
507
+ self.caduceus.backbone.embeddings.word_embeddings = value
508
+
509
+ def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
510
+ """Pools hidden states along sequence length dimension."""
511
+ if self.pooling_strategy == "mean": # Mean pooling along sequence length dimension
512
+ return hidden_states.mean(dim=sequence_length_dim)
513
+ if self.pooling_strategy == "max": # Max pooling along sequence length dimension
514
+ return hidden_states.max(dim=sequence_length_dim).values
515
+ if self.pooling_strategy == "last": # Use embedding of last token in the sequence
516
+ return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
517
+ if self.pooling_strategy == "first": # Use embedding of first token in the sequence
518
+ return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
519
+
520
+ def forward(
521
+ self,
522
+ input_ids: torch.LongTensor = None,
523
+ inputs_embeds: Optional[torch.FloatTensor] = None,
524
+ labels: Optional[torch.LongTensor] = None,
525
+ output_hidden_states: Optional[bool] = None,
526
+ return_dict: Optional[bool] = None,
527
+ ) -> Union[Tuple, SequenceClassifierOutput]:
528
+ r"""
529
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
530
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
531
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
532
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
533
+ """
534
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
535
+
536
+ # Get hidden representations from the backbone
537
+ if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
538
+ transformer_outputs = self.caduceus(
539
+ input_ids,
540
+ inputs_embeds=inputs_embeds,
541
+ output_hidden_states=output_hidden_states,
542
+ return_dict=return_dict,
543
+ )
544
+ hidden_states = torch.stack(
545
+ [
546
+ transformer_outputs[0][..., :self.config.d_model],
547
+ torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
548
+ ],
549
+ dim=-1
550
+ )
551
+ elif self.conjoin_train or (self.conjoin_eval and not self.training): # For conjoining / post-hoc conjoining
552
+ assert input_ids is not None, "`input_ids` must be provided for conjoining."
553
+ assert input_ids.ndim == 3, "`input_ids` must be a 3D tensor: the last dimension corresponds to the forward and rc strands."
554
+ transformer_outputs = self.caduceus(
555
+ input_ids[..., 0],
556
+ inputs_embeds=None,
557
+ output_hidden_states=output_hidden_states,
558
+ return_dict=return_dict,
559
+ )
560
+ transformer_outputs_rc = self.caduceus(
561
+ input_ids[..., 1],
562
+ inputs_embeds=None,
563
+ output_hidden_states=output_hidden_states,
564
+ return_dict=return_dict,
565
+ )
566
+ # Stack along channel dimension (dim=-1)
567
+ hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
568
+ else:
569
+ transformer_outputs = self.caduceus(
570
+ input_ids,
571
+ inputs_embeds=None,
572
+ output_hidden_states=output_hidden_states,
573
+ return_dict=return_dict,
574
+ )
575
+ hidden_states = transformer_outputs[0]
576
+
577
+ # Pool and get logits
578
+ pooled_hidden_states = self.pool_hidden_states(hidden_states)
579
+ # Potentially run `score` twice (with parameters shared) for conjoining
580
+ if hidden_states.ndim == 4: # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
581
+ logits_fwd = self.score(pooled_hidden_states[..., 0])
582
+ logits_rc = self.score(pooled_hidden_states[..., 1])
583
+ logits = (logits_fwd + logits_rc) / 2
584
+ else:
585
+ logits = self.score(pooled_hidden_states)
586
+
587
+ loss = None
588
+ if labels is not None:
589
+ labels = labels.to(logits.device)
590
+ if self.config.problem_type is None:
591
+ if self.num_labels == 1:
592
+ self.config.problem_type = "regression"
593
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
594
+ self.config.problem_type = "single_label_classification"
595
+ else:
596
+ self.config.problem_type = "multi_label_classification"
597
+
598
+ if self.config.problem_type == "regression":
599
+ if self.num_labels == 1:
600
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
601
+ else:
602
+ loss = F.mse_loss(logits, labels)
603
+ elif self.config.problem_type == "single_label_classification":
604
+ loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
605
+ elif self.config.problem_type == "multi_label_classification":
606
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
607
+ if not return_dict:
608
+ output = (logits,) + transformer_outputs[1:]
609
+ return ((loss,) + output) if loss is not None else output
610
+
611
+ return SequenceClassifierOutput(
612
+ loss=loss,
613
+ logits=logits,
614
+ hidden_states=transformer_outputs.hidden_states,
615
+ )
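Beyond masked language modeling, `config.json` also maps `AutoModelForSequenceClassification` to `CaduceusForSequenceClassification` above. A hedged, untested sketch of its "conjoining" path: with `conjoin_train=True` (or `conjoin_eval=True` at evaluation time), `input_ids` is expected as a `(batch, seq_len, 2)` tensor stacking the forward strand and its reverse complement, and the shared `score` head is applied to both before averaging the logits. The repository id is a placeholder, the classification head is newly initialized, and the reverse-complement ids below are simplified (flip only, no base complementation):

```python
# Hedged sketch of the conjoining path in CaduceusForSequenceClassification (untested).
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "<repo-id>",  # placeholder
    trust_remote_code=True,
    num_labels=2,
    conjoin_train=True,
).to("cuda")

fwd = torch.randint(0, model.config.vocab_size, (1, 256), device="cuda")
rc = fwd.flip(dims=[-1])  # simplified stand-in; a real pipeline would also complement the ids
input_ids = torch.stack([fwd, rc], dim=-1)  # (batch, seq_len, 2)

with torch.no_grad():
    logits = model(input_ids=input_ids).logits  # (batch, num_labels), averaged over both strands
```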
modeling_rcps.py ADDED
@@ -0,0 +1,243 @@
1
+ """Reverse-complement equivariant modules.
2
+
3
+ """
4
+ from collections import OrderedDict
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ try:
13
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
14
+ except ImportError:
15
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
16
+
17
+
18
+ class RCPSEmbedding(nn.Module):
19
+ """Embedding layer that supports reverse-complement equivariance."""
20
+ def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
21
+ """
22
+ Args:
23
+ vocab_size: Size of vocabulary.
24
+ d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).
25
+ complement_map: Dictionary mapping each token id to its complement.
26
+ """
27
+ super().__init__()
28
+ self.register_buffer(
29
+ "complement_map",
30
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
31
+ )
32
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
33
+
34
+ @property
35
+ def weight(self):
36
+ """Embedding weights."""
37
+ return self.embedding.weight
38
+
39
+ def set_weight(self, value):
40
+ """Set embedding weights."""
41
+ self.embedding.weight = value
42
+
43
+ def rc(self, x):
44
+ """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
45
+ return torch.gather(
46
+ self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
47
+ dim=1,
48
+ index=torch.flip(x, dims=[-1])
49
+ )
50
+
51
+ def forward(self, input_ids):
52
+ """Reverse-complement equivariant forward pass.
53
+
54
+ This embedding module doubles the output dimensionality to support reverse-complement equivariance.
55
+
56
+ Args:
57
+ input_ids: Input tensor of shape (batch_size, seq_len)
58
+ Returns:
59
+ Embedding tensor of shape (batch_size, seq_len, d_model * 2)
60
+ """
61
+ fwd_out = self.embedding(input_ids)
62
+ rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])
63
+
64
+ return torch.cat([fwd_out, rc_out], dim=-1)
65
+
66
+
67
+ class RCPSWrapper(nn.Module):
68
+ """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.
69
+
70
+ See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
71
+ Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
72
+ """
73
+ def __init__(self, submodule: nn.Module):
74
+ super().__init__()
75
+ self.submodule = submodule
76
+
77
+ @staticmethod
78
+ def rc(x):
79
+ """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
80
+ return torch.flip(x, dims=[-2, -1])
81
+
82
+ def forward(self, x, **kwargs):
83
+ """Reverse-complement equivariant forward pass.
84
+
85
+ Args:
86
+ x: Input tensor of shape (batch_size, seq_len, channels)
87
+ Returns:
88
+ Output tensor of shape (batch_size, seq_len, channels * 2)
89
+ """
90
+ n_channels = x.shape[-1]
91
+ # Run submodule along sequence
92
+ fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)
93
+ # Run submodule along rc-sequence
94
+ rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)
95
+ # Concatenate along channel dimension (dim=-1)
96
+ return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)
97
+
98
+
99
+ class RCPSAddNormWrapper(RCPSWrapper):
100
+ """RC equivariant AddNorm layer."""
101
+ def __init__(self, submodule: nn.Module):
102
+ super().__init__(submodule)
103
+
104
+ def forward(self, x, residual=None, prenorm=False):
105
+ """
106
+ Args:
107
+ x: Input tensor of shape (batch_size, seq_len, channels)
108
+ residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
109
+ prenorm: Whether to return residual.
110
+ """
111
+ n_channels = x.shape[-1]
112
+ if residual is None:
113
+ residual = x
114
+ x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))
115
+ x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))
116
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
117
+ else:
118
+ residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]
119
+ x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))
120
+
121
+ residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])
122
+ x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))
123
+
124
+ residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
125
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
126
+
127
+ return x if not prenorm else (x, residual)
128
+
129
+
130
+ class RCPSMambaBlock(nn.Module):
131
+ def __init__(
132
+ self,
133
+ dim,
134
+ mixer_cls,
135
+ norm_cls=nn.LayerNorm,
136
+ fused_add_norm=False,
137
+ residual_in_fp32=False,
138
+ device=None, # Keep for consistency with original Mamba Block
139
+ dtype=None, # Keep for consistency with original Mamba Block
140
+ ):
141
+ """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
142
+
143
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
144
+ """
145
+ super().__init__()
146
+ self.residual_in_fp32 = residual_in_fp32
147
+ self.fused_add_norm = fused_add_norm
148
+ self.mixer = RCPSWrapper(mixer_cls(dim))
149
+ norm_f = norm_cls(dim)
150
+ self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
151
+ if self.fused_add_norm:
152
+ assert RMSNorm is not None, "RMSNorm import fails"
153
+ assert isinstance(
154
+ self.norm, (nn.LayerNorm, RMSNorm)
155
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
156
+
157
+ def forward(
158
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
159
+ ):
160
+ r"""Pass the input through the encoder layer.
161
+
162
+ Args:
163
+ hidden_states: the sequence to the encoder layer (required).
164
+ residual: hidden_states = Mixer(LN(residual)).
165
+ inference_params: inference parameters for mixer.
166
+ """
167
+ if not self.fused_add_norm:
168
+ hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
169
+ if self.residual_in_fp32:
170
+ residual = residual.to(torch.float32)
171
+ else:
172
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
173
+
174
+ hidden_states_fwd, residual_fwd = fused_add_norm_fn(
175
+ hidden_states[..., hidden_states.shape[-1] // 2:],
176
+ self.norm.weight,
177
+ self.norm.bias,
178
+ residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
179
+ prenorm=True,
180
+ residual_in_fp32=self.residual_in_fp32,
181
+ eps=self.norm.eps,
182
+ )
183
+
184
+ hidden_states_rc, residual_rc = fused_add_norm_fn(
185
+ hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
186
+ self.norm.weight,
187
+ self.norm.bias,
188
+ residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
189
+ prenorm=True,
190
+ residual_in_fp32=self.residual_in_fp32,
191
+ eps=self.norm.eps,
192
+ )
193
+ hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
194
+ residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
195
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
196
+ return hidden_states, residual
197
+
198
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
199
+ """Allocate inference cache for mixer.
200
+
201
+ Keep for compatibility with original Mamba Block.
202
+ """
203
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
204
+
205
+
206
+ class RCPSLMHead(nn.Module):
207
+ """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
208
+ def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
209
+ """
210
+ `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
211
+ equivariant, i.e. 0.5 times the actual input dim.
212
+ """
213
+ super().__init__()
214
+ self.register_buffer(
215
+ "complement_map",
216
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
217
+ )
218
+ self.true_dim = true_dim
219
+ self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)
220
+
221
+ @property
222
+ def weight(self):
223
+ """LM head weights."""
224
+ return self.lm_head.weight
225
+
226
+ def set_weight(self, value):
227
+ """Set LM head weights."""
228
+ self.lm_head.weight = value
229
+
230
+ def forward(self, x):
231
+ """
232
+ Args:
233
+ x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
234
+ """
235
+ n_channels = x.shape[-1]
236
+ assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
237
+ fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
238
+ rc_logits = F.linear(
239
+ torch.flip(x[..., n_channels // 2:], dims=[-1]),
240
+ self.weight[self.complement_map, :],
241
+ bias=self.lm_head.bias
242
+ )
243
+ return fwd_logits + rc_logits
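The property these modules are designed to satisfy can be checked directly: embedding the reverse complement of a sequence equals flipping the embedding of the original along the length and channel dimensions. A small sketch, assuming `modeling_rcps.py` is importable as a standalone module (its `mamba_ssm` import is wrapped in try/except, so no GPU kernels are needed for this check); the complement map mirrors the one in `config.json`:

```python
# Sketch: RC-equivariance check for RCPSEmbedding, i.e. emb(rc(x)) == flip(emb(x)) over the
# length (dim -2) and channel (dim -1) axes.
import torch
from modeling_rcps import RCPSEmbedding

complement_map = {i: i for i in range(16)}
complement_map.update({7: 10, 8: 9, 9: 8, 10: 7})

emb = RCPSEmbedding(vocab_size=16, d_model=8, complement_map=complement_map)
ids = torch.randint(0, 16, (2, 32))
rc_ids = emb.rc(ids)  # flip along length and substitute complements

out = emb(ids)        # (2, 32, 16): forward and RC channels concatenated
out_rc = emb(rc_ids)
print(torch.allclose(out_rc, out.flip(dims=[-2, -1])))  # True
```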