alycialee committed on
Commit
2d9dbaa
1 Parent(s): 7d6c601

onboard files for m2-bert 341m with automodel support

README.md CHANGED
@@ -1,3 +1,55 @@
  ---
  license: apache-2.0
+ language:
+ - en
+ pipeline_tag: fill-mask
+ inference: false
  ---
+
+ # Monarch Mixer-BERT
+
+ The 341M checkpoint for M2-BERT-large from the paper [Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture](https://arxiv.org/abs/2310.12109).
+
+ Check out our [GitHub](https://github.com/HazyResearch/m2/tree/main) for instructions on how to download and fine-tune it!
+
+
+ ## How to use
+
+ You can load this model using Hugging Face `AutoModelForMaskedLM`:
+ ```python
+ from transformers import AutoModelForMaskedLM
+ mlm = AutoModelForMaskedLM.from_pretrained('alycialee/m2-bert-341m', trust_remote_code=True)
+ ```
+
+ This model uses the Hugging Face `bert-base-uncased` tokenizer:
+ ```python
+ from transformers import BertTokenizer
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ ```
+
+ You can use this model with a pipeline for masked language modeling:
+ ```python
+ from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline
+
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ mlm = AutoModelForMaskedLM.from_pretrained('alycialee/m2-bert-341m', trust_remote_code=True)
+
+ unmasker = pipeline('fill-mask', model=mlm, tokenizer=tokenizer)
+ unmasker('Every morning, I enjoy a cup of [MASK] to start my day.')
+ ```
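Review note: the `fill-mask` pipeline returns a list of candidate dictionaries with keys such as `score`, `token_str`, and `sequence`. The sketch below shows one way to rank that output. `top_predictions` is a hypothetical helper, and the `sample_results` values are made up for illustration; they are not predictions from this checkpoint.

```python
from typing import Dict, List, Tuple


def top_predictions(results: List[Dict], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (token, score) pairs from fill-mask output."""
    ranked = sorted(results, key=lambda r: r["score"], reverse=True)
    return [(r["token_str"], round(r["score"], 4)) for r in ranked[:k]]


# Made-up entries in the shape the pipeline returns; NOT real model output.
sample_results = [
    {"score": 0.12, "token_str": "tea",
     "sequence": "every morning, i enjoy a cup of tea to start my day."},
    {"score": 0.61, "token_str": "coffee",
     "sequence": "every morning, i enjoy a cup of coffee to start my day."},
    {"score": 0.05, "token_str": "water",
     "sequence": "every morning, i enjoy a cup of water to start my day."},
]

best_two = top_predictions(sample_results, k=2)
print(best_two)
```

With a real pipeline, the list returned by `unmasker(...)` would take the place of `sample_results`.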
+
+ ### Remote Code
+
+ This model requires `trust_remote_code=True` to be passed to the `from_pretrained` method because it uses custom PyTorch code (see our GitHub). Consider also passing a `revision` argument that pins the exact git commit of that code, for example:
+
+ ```python
+ mlm = AutoModelForMaskedLM.from_pretrained(
+     'alycialee/m2-bert-341m',
+     trust_remote_code=True,
+     revision='',
+ )
+ ```
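Review note: since `trust_remote_code=True` executes code fetched from the Hub, one defensive pattern is to refuse anything but a full commit hash before loading. The sketch below is only an illustration of that idea. `pinned_load_kwargs` is a hypothetical helper, the hash is a placeholder rather than a real commit of this repository, and requiring a 40-character hash is a choice of this sketch (`revision` itself also accepts branch and tag names).

```python
import re
from typing import Any, Dict

# A full git SHA-1 commit hash is 40 lowercase hexadecimal characters.
_FULL_SHA = re.compile(r"[0-9a-f]{40}")


def pinned_load_kwargs(revision: str) -> Dict[str, Any]:
    """Build from_pretrained kwargs, insisting on a full commit hash."""
    if not _FULL_SHA.fullmatch(revision):
        raise ValueError(
            f"expected a full 40-character commit hash, got {revision!r}")
    return {"trust_remote_code": True, "revision": revision}


# Placeholder hash for illustration only; substitute the commit you audited.
kwargs = pinned_load_kwargs("0123456789abcdef0123456789abcdef01234567")
print(kwargs)
# mlm = AutoModelForMaskedLM.from_pretrained('alycialee/m2-bert-341m', **kwargs)
```

A branch name such as `main` is rejected by this helper because a branch can move to new, unreviewed code.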
+
+ ### Configuration
+ Note that `use_flash_mm` is false by default; FlashMM is currently not supported.
+ `hyena_training_additions` is turned off.
bert_layers.py ADDED
@@ -0,0 +1,889 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2022, Tri Dao.
+ # Copyright (c) 2023, MosaicML.
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
+
+ import copy
+ import logging
+ import math
+ import os
+ import sys
+ import warnings
+ from typing import List, Optional, Tuple, Union
+ from functools import partial
+
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
+ # sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+
+ from .bert_padding import (index_first_axis,
+                            index_put_first_axis, pad_input,
+                            unpad_input, unpad_input_only)
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import (MaskedLMOutput,
+                                            SequenceClassifierOutput)
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
+
+ from .blockdiag_linear import BlockdiagLinear
+ from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
+
+ logger = logging.getLogger(__name__)
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ class BertEmbeddings(nn.Module):
+     """Construct the embeddings for words, ignoring position.
+
+     There are no positional embeddings since we use ALiBi and token_type
+     embeddings.
+
+     This module is modeled after the Hugging Face BERT's
+     :class:`~transformers.models.bert.modeling_bert.BertEmbeddings`, but is
+     modified as part of Mosaic BERT's ALiBi implementation. The key change is
+     that position embeddings are removed. Position information instead comes
+     from attention biases that scale linearly with the position distance
+     between query and key tokens.
+
+     This module ignores the `position_ids` input to the `forward` method
+     unless `config.use_positional_encodings` is set.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size,
+                                             config.hidden_size,
+                                             padding_idx=config.pad_token_id)
+         # ALiBi doesn't use position embeddings
+         if config.use_positional_encodings:
+             self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.use_positional_encodings = config.use_positional_encodings
+         self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
+                                                   config.hidden_size)
+
+         # self.LayerNorm is not snake-cased to stick with TensorFlow model
+         # variable name and be able to load any TensorFlow checkpoint file
+         self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                       eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         if config.use_positional_encodings:
+             self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+         self.register_buffer('token_type_ids',
+                              torch.zeros(config.max_position_embeddings,
+                                          dtype=torch.long),
+                              persistent=False)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         past_key_values_length: int = 0,
+         return_position_encodings: bool = False,
+     ) -> torch.Tensor:
+         if (input_ids is not None) == (inputs_embeds is not None):
+             raise ValueError('Must specify either input_ids or inputs_embeds!')
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             assert inputs_embeds is not None  # just for type checking
+             input_shape = inputs_embeds.size()[:-1]
+
+         seq_length = input_shape[1]
+
+         if position_ids is None:
+             if self.use_positional_encodings:
+                 position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+         # Setting the token_type_ids to the registered buffer in constructor
+         # where it is all zeros, which usually occurs when it's auto-generated;
+         # registered buffer helps users when tracing the model without passing
+         # token_type_ids, solves issue #5664
+         if token_type_ids is None:
+             if hasattr(self, 'token_type_ids'):
+                 # The buffer is 1-D (max_position_embeddings,), so slice its only
+                 # axis and broadcast over the batch. The dtype check also holds
+                 # on CUDA, unlike isinstance(..., torch.LongTensor).
+                 assert self.token_type_ids.dtype == torch.long
+                 buffered_token_type_ids = self.token_type_ids[:seq_length]
+                 buffered_token_type_ids_expanded = buffered_token_type_ids.unsqueeze(0).expand(
+                     input_shape[0], seq_length)
+                 token_type_ids = buffered_token_type_ids_expanded  # type: ignore
+             else:
+                 # nn.Embedding has no `.device`; read it from the weight tensor.
+                 token_type_ids = torch.zeros(input_shape,  # type: ignore
+                                              dtype=torch.long,
+                                              device=self.word_embeddings.weight.device)  # type: ignore # yapf: disable
+
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+         # Stays None when positional encodings are disabled, so the
+         # `return_position_encodings` branch below never hits an unbound name.
+         position_embeddings = None
+         if self.use_positional_encodings:
+             position_embeddings = self.position_embeddings(position_ids)
+             embeddings += position_embeddings
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         if return_position_encodings:
+             return embeddings, position_embeddings
+         else:
+             return embeddings
+
+ class BertMLP(nn.Module):
+     """Applies the FFN at the end of each BERT layer."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         if self.config.use_monarch_mlp:
+             linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
+         else:
+             linear_cls = nn.Linear
+
+         self.gated_layers = linear_cls(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=False)
+         self.act = nn.GELU(approximate='none')
+         self.wo = linear_cls(config.intermediate_size, config.hidden_size)
+
+         self.layernorm = nn.LayerNorm(config.hidden_size,
+                                       eps=config.layer_norm_eps)
+
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """Compute new hidden states from current hidden states.
+
+         Args:
+             hidden_states (torch.Tensor): The (unpadded) hidden states from
+                 the attention layer [nnz, dim].
+         """
+
+         residual_connection = hidden_states
+         hidden_states = self.gated_layers(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.wo(hidden_states)
+         hidden_states = self.layernorm(hidden_states + residual_connection)
+         return hidden_states
+
+
+ class BertGatedLinearUnitMLP(nn.Module):
+     """Applies the FFN at the end of each BERT layer with a Gated Linear Unit"""
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.is_padded = True
+
+         if self.config.use_monarch_mlp:
+             linear_cls = partial(BlockdiagLinear, nblocks=self.config.monarch_mlp_nblocks)
+         else:
+             linear_cls = nn.Linear
+         self.gated_layers = linear_cls(
+             config.hidden_size,
+             config.intermediate_size * 2,
+             bias=False
+         )
+         self.act = nn.GELU(approximate='none')
+         self.wo = linear_cls(config.intermediate_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.layernorm = nn.LayerNorm(config.hidden_size,
+                                       eps=config.layer_norm_eps)
+
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """Compute new hidden states from current hidden states.
+
+         Args:
+             hidden_states (torch.Tensor): The (unpadded) hidden states from
+                 the attention layer [nnz, dim].
+         """
+
+         residual_connection = hidden_states
+         # compute the activation
+         hidden_states = self.gated_layers(hidden_states)
+
+         if self.is_padded:
+             gated = hidden_states[:, :, :self.config.intermediate_size]
+             non_gated = hidden_states[:, :, self.config.intermediate_size:]
+         else:
+             gated = hidden_states[:, :self.config.intermediate_size]
+             non_gated = hidden_states[:, self.config.intermediate_size:]
+
+         hidden_states = self.act(gated) * non_gated
+         hidden_states = self.dropout(hidden_states)
+         # multiply by the second matrix
+         hidden_states = self.wo(hidden_states)
+         # add the residual connection and post-LN
+         hidden_states = self.layernorm(hidden_states + residual_connection)
+
+         return hidden_states
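
Review note: the gating step in `BertGatedLinearUnitMLP.forward` splits the `2 * intermediate_size` projection into a gated half and a non-gated half, then multiplies `GELU(gated)` by the non-gated half. The pure-Python sketch below reproduces only that arithmetic for a single token vector. `gelu` and `glu_gate` are illustrative names, not part of this commit, and the projection, dropout, residual, and LayerNorm steps are left out.

```python
import math
from typing import List


def gelu(x: float) -> float:
    """Exact erf-based GELU, the form selected by nn.GELU(approximate='none')."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def glu_gate(projected: List[float], intermediate_size: int) -> List[float]:
    """Split a 2*intermediate_size vector and apply GELU(gated) * non_gated."""
    gated = projected[:intermediate_size]
    non_gated = projected[intermediate_size:]
    return [gelu(g) * n for g, n in zip(gated, non_gated)]


# Toy projection output with intermediate_size=2: gated=[0, 1], non_gated=[2, 3].
out = glu_gate([0.0, 1.0, 2.0, 3.0], intermediate_size=2)
print(out)
```

The first output element is zero because `GELU(0) = 0`, which shows how the gated half can switch off the corresponding non-gated value.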
+
+
+ class BertLayer(nn.Module):
+     """BERT layer, which includes Sequence Mixing (e.g. Hyena) and State Mixing (e.g. MLP)."""
+
+     def __init__(self, config):
+         super(BertLayer, self).__init__()
+
+         mm_cls = MonarchMixerSequenceMixing
+         self.attention = mm_cls(
+             config.hidden_size,
+             l_max=config.long_conv_l_max,
+             hyena_kernel_lr=config.long_conv_kernel_learning_rate,
+             bidirectional=config.bidirectional,
+
+             hyena_lr_pos_emb=config.hyena_lr_pos_emb,
+             hyena_w=config.hyena_w,
+             hyena_w_mod=config.hyena_w_mod,
+             hyena_wd=config.hyena_wd,
+             hyena_emb_dim=config.hyena_emb_dim,
+             hyena_filter_dropout=config.hyena_filter_dropout,
+             hyena_filter_order=config.hyena_filter_order,
+             residual_long_conv=config.residual_long_conv,
+         )
+
+         if config.use_glu_mlp:
+             self.mlp = BertGatedLinearUnitMLP(config)
+         else:
+             self.mlp = BertMLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         seqlen: int,
+         subset_idx: Optional[torch.Tensor] = None,
+         indices: Optional[torch.Tensor] = None,
+         attn_mask: Optional[torch.Tensor] = None,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """Forward pass for a BERT layer, including both attention and MLP.
+
+         Args:
+             hidden_states: (total_nnz, dim)
+             cu_seqlens: (batch + 1,)
+             seqlen: int
+             subset_idx: () set of indices whose values we care about at the end of the layer
+                 (e.g., the masked tokens, if this is the final layer).
+             indices: None or (total_nnz,)
+             attn_mask: None or (batch, max_seqlen_in_batch)
+             bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
+         """
+
+         # Only hidden_states reaches the sequence mixer; the other arguments
+         # are accepted for interface compatibility and are unused here.
+         attention_output = self.attention(hidden_states)
+         if isinstance(attention_output, tuple):
+             attention_output, _ = attention_output
+
+         layer_output = self.mlp(attention_output)
+
+         return layer_output
+
+
+ class BertEncoder(nn.Module):
+     """A stack of BERT layers providing the backbone of BERT.
+
+     Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
+     at padded tokens, and pre-computes attention biases to implement ALiBi.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         layer = BertLayer(config)
+         self.layer = nn.ModuleList(
+             [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+
+         self.num_attention_heads = config.num_attention_heads
+
+     def rebuild_alibi_tensor(self,
+                              size: int,
+                              device: Optional[Union[torch.device, str]] = None):
+         # Alibi
+         # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
+         # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
+         # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
+         # will be applied, it is necessary to construct the diagonal mask.
+         n_heads = self.num_attention_heads
+
+         def _get_alibi_head_slopes(n_heads: int) -> List[float]:
+
+             def get_slopes_power_of_2(n_heads: int) -> List[float]:
+                 start = (2**(-2**-(math.log2(n_heads) - 3)))
+                 ratio = start
+                 return [start * ratio**i for i in range(n_heads)]
+
+             # In the paper, they only train models that have 2^a heads for some a. This function
+             # has some good properties that only occur when the input is a power of 2. To
+             # maintain that even when the number of heads is not a power of 2, we use a
+             # workaround.
+             if math.log2(n_heads).is_integer():
+                 return get_slopes_power_of_2(n_heads)
+
+             closest_power_of_2 = 2**math.floor(math.log2(n_heads))
+             slopes_a = get_slopes_power_of_2(closest_power_of_2)
+             slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
+             slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
+             return slopes_a + slopes_b
+
+         context_position = torch.arange(size, device=device)[:, None]
+         memory_position = torch.arange(size, device=device)[None, :]
+         relative_position = torch.abs(memory_position - context_position)
+         # [n_heads, max_token_length, max_token_length]
+         relative_position = relative_position.unsqueeze(0).expand(
+             n_heads, -1, -1)
+         slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
+         alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
+         # [1, n_heads, max_token_length, max_token_length]
+         alibi = alibi.unsqueeze(0)
+         assert alibi.shape == torch.Size([1, n_heads, size, size])
+
+         self._current_alibi_size = size
+         self.alibi = alibi
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         output_all_encoded_layers: Optional[bool] = True,
+         subset_mask: Optional[torch.Tensor] = None,
+         position_encodings: Optional[torch.Tensor] = None,
+     ) -> List[torch.Tensor]:
+
+         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+         extended_attention_mask = extended_attention_mask.to(
+             dtype=next(self.parameters()).dtype)  # fp16 compatibility
+         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+         attention_mask_bool = attention_mask.bool()
+         batch, seqlen = hidden_states.shape[:2]
+
+         cu_seqlens = None
+         indices = None
+         alibi_attn_mask = None
+
+         all_encoder_layers = []
+         for layer_module in self.layer:
+             hidden_states = layer_module(hidden_states,
+                 cu_seqlens,
+                 seqlen,
+                 None,
+                 indices,
+                 attn_mask=attention_mask,
+                 bias=alibi_attn_mask
+             )
+             if position_encodings is not None:
+                 hidden_states = hidden_states + position_encodings
+             if output_all_encoded_layers:
+                 all_encoder_layers.append(hidden_states)
+         if subset_mask is not None:
+             hidden_states = hidden_states[subset_mask]
+
+         if not output_all_encoded_layers:
+             all_encoder_layers.append(hidden_states)
+         return all_encoder_layers
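
Review note: the slope schedule inside `rebuild_alibi_tensor` can be checked in isolation. The sketch below restates the same arithmetic in pure Python, with illustrative function names that are not part of this commit, and adds one row of the resulting bias. In the `forward` shown above the bias handed to each layer is `None`, so this helper appears to be carried over from the MosaicBERT code rather than used by the Monarch Mixer layers.

```python
import math
from typing import List


def alibi_head_slopes(n_heads: int) -> List[float]:
    """Per-head ALiBi slopes, mirroring the nested helper in the encoder."""

    def slopes_power_of_2(n: int) -> List[float]:
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * start**i for i in range(n)]

    if math.log2(n_heads).is_integer():
        return slopes_power_of_2(n_heads)

    # Not a power of two: take the slopes for the closest lower power of two,
    # then fill the remainder with every other slope from the next power of two.
    closest = 2 ** math.floor(math.log2(n_heads))
    extra = alibi_head_slopes(2 * closest)[0::2][: n_heads - closest]
    return slopes_power_of_2(closest) + extra


def alibi_bias_row(slope: float, query_pos: int, size: int) -> List[float]:
    """Bias for one query position: -slope * |key_pos - query_pos|."""
    return [-slope * abs(key_pos - query_pos) for key_pos in range(size)]


slopes_8 = alibi_head_slopes(8)
slopes_12 = alibi_head_slopes(12)
row = alibi_bias_row(slopes_8[0], query_pos=1, size=4)
print(slopes_8, len(slopes_12), row)
```

For eight heads the slopes form the geometric sequence 1/2, 1/4, ..., 1/256, and the bias grows more negative with distance from the query position.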
+
+
+ class BertPooler(nn.Module):
+
+     def __init__(self, config):
+         super(BertPooler, self).__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+         self.pool_all = config.pool_all
+
+     def forward(self,
+                 hidden_states: torch.Tensor,
+                 pool: Optional[bool] = True,
+                 mask=None) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         if not self.pool_all:
+             first_token_tensor = hidden_states[:, 0] if pool else hidden_states
+             pooled_output = self.dense(first_token_tensor)
+             pooled_output = self.activation(pooled_output)
+         else:
+             # mean pool everything that isn't masked out
+             denom = torch.sum(mask, dim=1, keepdim=True)
+             mean_tensor = torch.sum((hidden_states) * mask.unsqueeze(-1), dim=1) / denom
+             pooled_output = self.dense(mean_tensor)
+             pooled_output = self.activation(pooled_output)
+         return pooled_output
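
Review note: when `config.pool_all` is set, `BertPooler` averages the hidden states over the positions whose mask value is 1 before applying the dense layer and `tanh`. The pure-Python sketch below shows only that masked mean for a single sequence. `masked_mean` is an illustrative name, and the dense projection and activation are left out.

```python
from typing import List


def masked_mean(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average the hidden vectors over the positions where mask == 1."""
    denom = sum(mask)
    dim = len(hidden[0])
    return [
        sum(vec[d] * m for vec, m in zip(hidden, mask)) / denom
        for d in range(dim)
    ]


# Three positions with hidden size 2; the last position is padding.
pooled = masked_mean([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], mask=[1, 1, 0])
print(pooled)
```

The padded position contributes nothing to the sum and is excluded from the denominator, so the result is the mean of the first two vectors.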
+
+
+ class BertPredictionHeadTransform(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         if isinstance(config.hidden_act, str):
+             self.transform_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.transform_act_fn = config.hidden_act
+         self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
+
+
+ class BertModel(BertPreTrainedModel):
+     """Overall BERT model.
+
+     Args:
+         config: a BertConfig class instance with the configuration to build a new model
+
+     Inputs:
+         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+             with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
+             `extract_features.py`, `run_classifier.py` and `run_squad.py`)
+         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+             a `sentence B` token (see BERT paper for more details).
+         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+             input sequence length in the current batch. It's the mask that we typically use for attention when
+             a batch has varying length sentences.
+         `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+
+     Outputs: Tuple of (encoded_layers, pooled_output)
+         `encoded_layers`: controlled by `output_all_encoded_layers` argument:
+             - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
+                 of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
+                 encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+             - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
+                 to the last attention block of shape [batch_size, sequence_length, hidden_size],
+         `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+             classifier pretrained on top of the hidden state associated to the first character of the
+             input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
+
+     Example usage:
+     ```python
+     # Already been converted into WordPiece token ids
+     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+     config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+     model = BertModel(config=config)
+     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+     ```
+     """
+
+     def __init__(self, config, add_pooling_layer=True):
+         super(BertModel, self).__init__(config)
+         self.embeddings = BertEmbeddings(config)
+         self.encoder = BertEncoder(config)
+
+         self.pooler = BertPooler(config) if add_pooling_layer else None
+         self.post_init()
+
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         token_type_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_all_encoded_layers: Optional[bool] = False,
+         masked_tokens_mask: Optional[torch.Tensor] = None,
+         **kwargs
+     ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         if token_type_ids is None:
+             token_type_ids = torch.zeros_like(input_ids)
+
+         embedding_output = self.embeddings(
+             input_ids,
+             token_type_ids,
+             position_ids
+         )
+         position_encodings = None
+
+         subset_mask = []
+         first_col_mask = []
+
+         if masked_tokens_mask is None:
+             subset_mask = None
+         else:
+             first_col_mask = torch.zeros_like(masked_tokens_mask)
+             first_col_mask[:, 0] = True
+             subset_mask = masked_tokens_mask | first_col_mask
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask,
+             output_all_encoded_layers=output_all_encoded_layers,
+             subset_mask=subset_mask,
+             position_encodings=position_encodings)
+         if masked_tokens_mask is None:
+             sequence_output = encoder_outputs[-1]
+             pooled_output = self.pooler(
+                 sequence_output, mask=attention_mask) if self.pooler is not None else None
+         else:
+             # TD [2022-03-01]: the indexing here is very tricky.
+             attention_mask_bool = attention_mask.bool()
+             subset_idx = subset_mask[attention_mask_bool]  # type: ignore
+             sequence_output = encoder_outputs[-1][
+                 masked_tokens_mask[attention_mask_bool][subset_idx]]
+             if self.pooler is not None:
+                 pool_input = encoder_outputs[-1][
+                     first_col_mask[attention_mask_bool][subset_idx]]
+                 pooled_output = self.pooler(pool_input, pool=False, mask=attention_mask)
+             else:
+                 pooled_output = None
+
+         if not output_all_encoded_layers:
+             encoder_outputs = sequence_output
+
+         if self.pooler is not None:
+             return encoder_outputs, pooled_output
+
+         return encoder_outputs, None
+
+
+ ###################
+ # Bert Heads
+ ###################
+ class BertLMPredictionHead(nn.Module):
+
+     def __init__(self, config, bert_model_embedding_weights):
+         super().__init__()
+         self.transform = BertPredictionHeadTransform(config)
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
+                                  bert_model_embedding_weights.size(0))
+         self.decoder.weight = bert_model_embedding_weights
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ class BertOnlyMLMHead(nn.Module):
+
+     def __init__(self, config, bert_model_embedding_weights):
+         super().__init__()
+         self.predictions = BertLMPredictionHead(config,
+                                                 bert_model_embedding_weights)
+
+     def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+         prediction_scores = self.predictions(sequence_output)
+         return prediction_scores
+
+
+ class BertOnlyNSPHead(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+     def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
+         seq_relationship_score = self.seq_relationship(pooled_output)
+         return seq_relationship_score
+
+
+ #######################
+ # Construct Bert model
+ #######################
+ class BertForMaskedLM(BertPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if config.is_decoder:
+             warnings.warn(
+                 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
+                 'bi-directional self-attention.')
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.cls = BertOnlyMLMHead(config,
+                                    self.bert.embeddings.word_embeddings.weight)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @classmethod
+     def from_composer(cls,
+                       pretrained_checkpoint,
+                       state_dict=None,
+                       cache_dir=None,
+                       from_tf=False,
+                       config=None,
+                       *inputs,
+                       **kwargs):
+         """Load from pre-trained."""
+         model = cls(config, *inputs, **kwargs)
+         if from_tf:
+             raise ValueError(
+                 'TensorFlow is not supported.')
+
+         state_dict = torch.load(pretrained_checkpoint)
+         # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
+         consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
+         missing_keys, unexpected_keys = model.load_state_dict(state_dict,
+                                                               strict=False)
+
+         if len(missing_keys) > 0:
+             logger.warning(
+                 f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
+             )
+         if len(unexpected_keys) > 0:
+             logger.warning(
+                 f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
+             )
+
+         return model
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.cls.predictions.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+         # labels should be a `torch.LongTensor` of shape
+         # `(batch_size, sequence_length)`. These are used for computing the
+         # masked language modeling loss.
+         #
+         # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
+         # `input_ids` docstring) Tokens with indices set to `-100` are ignored
+         # (masked), the loss is only computed for the tokens with labels in `[0,
+         # ..., config.vocab_size]`
+         #
+         # Prediction scores are only computed for masked tokens and the (bs,
+         # seqlen) dimensions are flattened
+         if (input_ids is not None) == (inputs_embeds is not None):
+             raise ValueError('Must specify either input_ids or inputs_embeds!')
+
+         if labels is None:
+             masked_tokens_mask = None
+         else:
+             masked_tokens_mask = labels > 0
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             masked_tokens_mask=masked_tokens_mask,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         loss = None
+         if labels is not None:
+             # Compute loss
+             loss_fct = nn.CrossEntropyLoss()
+
+             masked_token_idx = torch.nonzero(labels.flatten() > 0,
+                                              as_tuple=False).flatten()
+             loss = loss_fct(prediction_scores,
+                             labels.flatten()[masked_token_idx])
+             assert input_ids is not None, 'Coding error; please open an issue'
+             batch, seqlen = input_ids.shape[:2]
+             prediction_scores = rearrange(
+                 index_put_first_axis(
+                     prediction_scores, masked_token_idx, batch * seqlen),
+                 '(b s) d -> b s d',
+                 b=batch)
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=prediction_scores,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
+                                       attention_mask: torch.Tensor,
+                                       **model_kwargs):
+         input_shape = input_ids.shape
+         effective_batch_size = input_shape[0]
+
+         # add a dummy token
+         if self.config.pad_token_id is None:
+             raise ValueError('The PAD token should be defined for generation')
+
+         attention_mask = torch.cat([
+             attention_mask,
+             attention_mask.new_zeros((attention_mask.shape[0], 1))
+         ], dim=-1)
+         dummy_token = torch.full((effective_batch_size, 1),
+                                  self.config.pad_token_id,
+                                  dtype=torch.long,
+                                  device=input_ids.device)
+         input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+         return {'input_ids': input_ids, 'attention_mask': attention_mask}
759
+
760
+
761
+ class BertForSequenceClassification(BertPreTrainedModel):
762
+ """Bert Model transformer with a sequence classification/regression head.
763
+
764
+ This head is just a linear layer on top of the pooled output. Used for,
765
+ e.g., GLUE tasks.
766
+ """
767
+
768
+ def __init__(self, config):
769
+ super().__init__(config)
770
+ self.num_labels = config.num_labels
771
+ self.config = config
772
+
773
+ self.bert = BertModel(config)
774
+ classifier_dropout = (config.classifier_dropout
775
+ if config.classifier_dropout is not None else
776
+ config.hidden_dropout_prob)
777
+ self.dropout = nn.Dropout(classifier_dropout)
778
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
779
+
780
+ # Initialize weights and apply final processing
781
+ self.post_init()
782
+
783
+ @classmethod
784
+ def from_composer(cls,
785
+ pretrained_checkpoint,
786
+ state_dict=None,
787
+ cache_dir=None,
788
+ from_tf=False,
789
+ config=None,
790
+ *inputs,
791
+ **kwargs):
792
+ """Load from pre-trained."""
793
+ model = cls(config, *inputs, **kwargs)
794
+ if from_tf:
795
+ raise ValueError(
796
+ 'TensorFlow is not supported.')
797
+
798
+ state_dict = torch.load(pretrained_checkpoint)
799
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
800
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
801
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
802
+ strict=False)
803
+
804
+ if len(missing_keys) > 0:
805
+ logger.warning(
806
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
807
+ )
808
+ if len(unexpected_keys) > 0:
809
+ logger.warning(
810
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
811
+ )
812
+
813
+ return model
814
+
815
+ def forward(
816
+ self,
817
+ input_ids: Optional[torch.Tensor] = None,
818
+ attention_mask: Optional[torch.Tensor] = None,
819
+ token_type_ids: Optional[torch.Tensor] = None,
820
+ position_ids: Optional[torch.Tensor] = None,
821
+ head_mask: Optional[torch.Tensor] = None,
822
+ inputs_embeds: Optional[torch.Tensor] = None,
823
+ labels: Optional[torch.Tensor] = None,
824
+ output_attentions: Optional[bool] = None,
825
+ output_hidden_states: Optional[bool] = None,
826
+ return_dict: Optional[bool] = None,
827
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
828
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
829
+ # Labels for computing the sequence classification/regression loss.
830
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
831
+ # If `config.num_labels == 1` a regression loss is computed
832
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
833
+ # is computed (cross-entropy).
834
+
835
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
836
+
837
+ outputs = self.bert(
838
+ input_ids,
839
+ attention_mask=attention_mask,
840
+ token_type_ids=token_type_ids,
841
+ position_ids=position_ids,
842
+ head_mask=head_mask,
843
+ inputs_embeds=inputs_embeds,
844
+ output_attentions=output_attentions,
845
+ output_hidden_states=output_hidden_states,
846
+ return_dict=return_dict,
847
+ )
848
+
849
+ pooled_output = outputs[1]
850
+
851
+ pooled_output = self.dropout(pooled_output)
852
+ logits = self.classifier(pooled_output)
853
+
854
+ loss = None
855
+ if labels is not None:
856
+ # Compute loss
857
+ if self.config.problem_type is None:
858
+ if self.num_labels == 1:
859
+ self.config.problem_type = 'regression'
860
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
861
+ labels.dtype == torch.int):
862
+ self.config.problem_type = 'single_label_classification'
863
+ else:
864
+ self.config.problem_type = 'multi_label_classification'
865
+
866
+ if self.config.problem_type == 'regression':
867
+ loss_fct = nn.MSELoss()
868
+ if self.num_labels == 1:
869
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
870
+ else:
871
+ loss = loss_fct(logits, labels)
872
+ elif self.config.problem_type == 'single_label_classification':
873
+ loss_fct = nn.CrossEntropyLoss()
874
+ loss = loss_fct(logits.view(-1, self.num_labels),
875
+ labels.view(-1))
876
+ elif self.config.problem_type == 'multi_label_classification':
877
+ loss_fct = nn.BCEWithLogitsLoss()
878
+ loss = loss_fct(logits, labels)
879
+
880
+ if not return_dict:
881
+ output = (logits,) + outputs[2:]
882
+ return ((loss,) + output) if loss is not None else output
883
+
884
+ return SequenceClassifierOutput(
885
+ loss=loss,
886
+ logits=logits,
887
+ hidden_states=None,
888
+ attentions=None,
889
+ )
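Reviewer note: `BertForMaskedLM.forward` above computes the loss only at positions where `labels > 0`, then scatters the per-masked-token scores back into a `(batch * seqlen, vocab)` buffer. A minimal pure-Python sketch of that bookkeeping follows. `masked_lm_loss` and `scatter_rows` are hypothetical helper names, and plain lists stand in for tensors; the actual code uses `nn.CrossEntropyLoss` and `index_put_first_axis`.

```python
import math


def masked_lm_loss(masked_logits, labels):
    """Mean cross-entropy over positions whose label is > 0.

    masked_logits: one row of logits per masked position, in flattened order.
    labels: nested list of shape (batch, seqlen); entries <= 0 are skipped.
    """
    flat_labels = [lab for row in labels for lab in row]
    masked_idx = [i for i, lab in enumerate(flat_labels) if lab > 0]
    if not masked_idx:
        raise ValueError("no masked positions to compute a loss on")
    assert len(masked_idx) == len(masked_logits)

    total = 0.0
    for logits, i in zip(masked_logits, masked_idx):
        # Numerically stable log-sum-exp, then negative log-likelihood.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_z - logits[flat_labels[i]]
    return total / len(masked_idx), masked_idx


def scatter_rows(rows, indices, first_axis_dim):
    """Place rows at the given flat indices in a zero-filled buffer."""
    width = len(rows[0])
    out = [[0.0] * width for _ in range(first_axis_dim)]
    for row, i in zip(rows, indices):
        out[i] = list(row)
    return out


labels = [[-100, 2, -100], [1, -100, -100]]
uniform_logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
loss, masked_idx = masked_lm_loss(uniform_logits, labels)
full_scores = scatter_rows(uniform_logits, masked_idx, first_axis_dim=6)
```

With uniform logits over three classes the loss equals `log(3)`, which is a quick sanity check that only the two masked positions contribute.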
bert_padding.py ADDED
@@ -0,0 +1,153 @@
1
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
2
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
3
+
4
+ """
5
+
6
+ Functions for padding and unpadding
7
+
8
+ """
9
+
10
+ from typing import Tuple, cast
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+
16
+
17
+ class IndexFirstAxis(torch.autograd.Function):
18
+
19
+ @staticmethod
20
+ def forward(ctx, input: torch.Tensor,
21
+ indices: torch.Tensor) -> torch.Tensor:
22
+ """Get just the values of `input` which are at `indices`.
23
+
24
+ Arguments:
25
+ ctx: the autograd context object
26
+ input: (b, ...) 2+ dimensional tensor
27
+ indices: (num_idx) 1D tensor
28
+ """
29
+ ctx.save_for_backward(indices)
30
+ assert input.ndim >= 2
31
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
32
+ 1:]
33
+ second_dim = other_shape.numel(
34
+ ) # product of sizes of all but first dimension
35
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
36
+ return torch.gather(
37
+ rearrange(input, 'b ... -> b (...)'), # (b, ...) -> (b, second_dim)
38
+ 0,
39
+ repeat(indices, 'z -> z d',
40
+ d=second_dim) # (indices,) -> (indices, second_dim)
41
+ ).reshape(-1, *other_shape) # (num_idx, ...)
42
+
43
+ @staticmethod
44
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
45
+ indices, = ctx.saved_tensors
46
+ assert grad_output.ndim >= 2
47
+ other_shape = grad_output.shape[1:]
48
+ grad_output = rearrange(grad_output, 'b ... -> b (...)')
49
+ grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
50
+ device=grad_output.device,
51
+ dtype=grad_output.dtype)
52
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
53
+ # grad_input[indices] = grad_output
54
+ grad_input.scatter_(0,
55
+ repeat(indices, 'z -> z d', d=grad_output.shape[1]),
56
+ grad_output)
57
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
58
+
59
+
60
+ index_first_axis = IndexFirstAxis.apply
61
+
62
+
63
+ class IndexPutFirstAxis(torch.autograd.Function):
64
+
65
+ @staticmethod
66
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
67
+ first_axis_dim) -> torch.Tensor:
68
+ ctx.save_for_backward(indices)
69
+ assert indices.ndim == 1
70
+ assert values.ndim >= 2
71
+ output = torch.zeros(first_axis_dim,
72
+ *values.shape[1:],
73
+ device=values.device,
74
+ dtype=values.dtype)
75
+ output[indices] = values
76
+ return output
77
+
78
+ @staticmethod
79
+ def backward(ctx,
80
+ grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
81
+ indices, = ctx.saved_tensors
82
+ grad_values = grad_output[indices]
83
+ return grad_values, None, None
84
+
85
+
86
+ index_put_first_axis = IndexPutFirstAxis.apply
87
+
88
+
89
+ def unpad_input(
90
+ hidden_states: torch.Tensor,
91
+ attention_mask: torch.Tensor,
92
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
93
+ """Remove padding from input sequences.
94
+
95
+ Arguments:
96
+ hidden_states: (batch, seqlen, ...)
97
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
98
+
99
+ Returns:
100
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz), the flat indices of the selected tokens along the flattened (batch * seqlen) axis.
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+ max_seqlen_in_batch: int
103
+ """
104
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
105
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
106
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
107
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
108
+ (1, 0))
109
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
110
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
111
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
112
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
113
+ # so we write custom forward and backward to make it a bit faster.
114
+ hidden_states = cast(
115
+ torch.Tensor,
116
+ index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
117
+ indices))
118
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
119
+
120
+
121
+ def unpad_input_only(
122
+ hidden_states: torch.Tensor,
123
+ attention_mask: torch.Tensor,
124
+ ) -> torch.Tensor:
125
+ """Like unpad_input, but only return the unpadded first tensor.
126
+
127
+ Save a small amount of overhead.
128
+
129
+ Arguments:
130
+ hidden_states: (batch, seqlen, ...)
131
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
132
+
133
+ Returns:
134
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
135
+ """
136
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
137
+ rearranged = rearrange(hidden_states, 'b s ... -> (b s) ...')
138
+ return index_first_axis(rearranged, indices) # type: ignore
139
+
140
+
141
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
142
+ seqlen: int) -> torch.Tensor:
143
+ """Add padding to sequences.
144
+
145
+ Arguments:
146
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
147
+ indices: (total_nnz)
148
+
149
+ Returns:
150
+ hidden_states: (batch, seqlen, ...)
151
+ """
152
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
153
+ return rearrange(output, '(b s) ... -> b s ...', b=batch) # type: ignore
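Reviewer note: the `unpad_input` / `pad_input` pair above is easier to follow with a small round trip. The sketch below mirrors the same bookkeeping (`indices`, `cu_seqlens`, `max_seqlen_in_batch`) using plain Python lists instead of tensors. The function names `unpad` and `pad` and the `fill` argument are hypothetical; the real functions fill with zeros and run through custom autograd functions.

```python
from itertools import accumulate


def unpad(hidden_states, attention_mask):
    """Drop padded positions from a (batch, seqlen) nested list."""
    flat_states = [tok for row in hidden_states for tok in row]
    flat_mask = [m for row in attention_mask for m in row]
    # Flat positions of valid tokens along the (batch * seqlen) axis.
    indices = [i for i, m in enumerate(flat_mask) if m]
    seqlens = [sum(1 for m in row if m) for row in attention_mask]
    # Prefix sums with a leading 0: sequence b occupies
    # unpadded[cu_seqlens[b]:cu_seqlens[b + 1]].
    cu_seqlens = [0] + list(accumulate(seqlens))
    unpadded = [flat_states[i] for i in indices]
    return unpadded, indices, cu_seqlens, max(seqlens)


def pad(unpadded, indices, batch, seqlen, fill=None):
    """Inverse of unpad: scatter tokens back into a (batch, seqlen) layout."""
    flat = [fill] * (batch * seqlen)
    for tok, i in zip(unpadded, indices):
        flat[i] = tok
    return [flat[b * seqlen:(b + 1) * seqlen] for b in range(batch)]


hidden = [["a", "b", "c"], ["d", "e", "f"]]
mask = [[1, 1, 0], [1, 0, 0]]
unpadded, indices, cu_seqlens, max_len = unpad(hidden, mask)
restored = pad(unpadded, indices, batch=2, seqlen=3)
```

Padded slots come back as the fill value rather than their original contents, which matches the zero-filled output of `index_put_first_axis`.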
blockdiag_linear.py ADDED
@@ -0,0 +1,73 @@
1
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from .structured_linear import StructuredLinear
9
+ from .blockdiag_multiply import blockdiag_multiply
10
+
11
+
12
+ class BlockdiagLinear(StructuredLinear):
13
+
14
+ def __init__(self, *args, nblocks=4, shuffle=False, **kwargs):
15
+ """shuffle: apply channel_shuffle operation before the matmul as in ShuffleNet
16
+ """
17
+ super().__init__(*args, **kwargs)
18
+ in_blksz = int(math.ceil(self.in_features / nblocks))
19
+ out_blksz = int(math.ceil(self.out_features / nblocks))
20
+ self.in_features_extended = in_blksz * nblocks
21
+ self.out_features_extended = out_blksz * nblocks
22
+ self.shuffle = shuffle
23
+ self.weight = nn.Parameter(torch.empty(nblocks, out_blksz, in_blksz))
24
+ self.reset_parameters()
25
+
26
+ def set_weights_from_dense_init(self, dense_init_fn_):
27
+ dense_weight = torch.empty(self.out_features_extended, self.in_features_extended,
28
+ device=self.weight.device, dtype=self.weight.dtype)
29
+ dense_init_fn_(dense_weight)
30
+ # Scale by sqrt because the weight is sparse
31
+ scaling = math.sqrt(dense_weight.numel() / self.weight.numel())
32
+ dense_weight *= scaling
33
+ with torch.no_grad():
34
+ nblocks = self.weight.shape[0]
35
+ self.weight.copy_(rearrange(dense_weight, '(b o) (b1 i) -> b b1 o i',
36
+ b=nblocks, b1=nblocks)[0])
37
+
38
+ @property
39
+ def saving(self):
40
+ return self.weight.numel() / (self.in_features * self.out_features)
41
+
42
+ def forward_matmul(self, x):
43
+ x = self.preprocess(x)
44
+ if self.shuffle:
45
+ x = rearrange(x, '... (group c_per_group) -> ... (c_per_group group)',
46
+ group=self.weight.shape[0]) # group=nblocks
47
+ output = blockdiag_multiply(x, self.weight)
48
+ return self.postprocess(output)
49
+
50
+
51
+ class BlockdiagSparsityConfig:
52
+
53
+ def __init__(self, nblocks, block=32, global_size=0):
54
+ """shuffle: apply channel_shuffle operation before the matmul as in ShuffleNet
55
+ """
56
+ self.nblocks = nblocks
57
+ self.block = block
58
+ self.global_size = global_size
59
+
60
+ def make_layout(self, out_features, in_features):
61
+ assert out_features % self.block == 0 and in_features % self.block == 0
62
+ assert out_features % self.nblocks == 0 and in_features % self.nblocks == 0
63
+ layout = torch.block_diag(*[torch.ones(out_features // self.nblocks,
64
+ in_features // self.nblocks,
65
+ dtype=torch.int32)] * self.nblocks)
66
+ if self.global_size > 0:
67
+ layout[:self.global_size] = 1
68
+ layout[:, :self.global_size] = 1
69
+ # Convert from (out_features, in_features) mask to
70
+ # (out_features // block, in_features // block) mask
71
+ layout = rearrange(layout, '(p blksz) (r blksz1) -> p r (blksz blksz1)',
72
+ blksz=self.block, blksz1=self.block)
73
+ return (layout > 0).any(dim=-1).int()
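Reviewer note: the constructor and the `saving` property of `BlockdiagLinear` above reduce to a little arithmetic on block sizes. The sketch below reproduces it without PyTorch. `blockdiag_shape` and `saving_ratio` are hypothetical names, and applying them to `hidden_size=1792`, `intermediate_size=7168`, `monarch_mlp_nblocks=4` from `config.json` is only an illustration, since this file does not show which layers use those dimensions.

```python
import math


def blockdiag_shape(in_features, out_features, nblocks=4):
    """Shape of the block-diagonal weight: (nblocks, out_blksz, in_blksz)."""
    in_blksz = math.ceil(in_features / nblocks)
    out_blksz = math.ceil(out_features / nblocks)
    return (nblocks, out_blksz, in_blksz)


def saving_ratio(in_features, out_features, nblocks=4):
    """Stored parameters divided by the parameters of a dense layer."""
    n, out_blksz, in_blksz = blockdiag_shape(in_features, out_features, nblocks)
    return (n * out_blksz * in_blksz) / (in_features * out_features)


shape = blockdiag_shape(1792, 7168, nblocks=4)
ratio = saving_ratio(1792, 7168, nblocks=4)
```

When both dimensions divide evenly by `nblocks` the ratio is exactly `1 / nblocks`; when they do not, the ceiling rounds the blocks up and the ratio is slightly larger.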
blockdiag_multiply.py ADDED
@@ -0,0 +1,82 @@
1
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ def blockdiag_weight_to_dense_weight(weight):
10
+ """
11
+ Arguments:
+ weight: (nblocks, out / nblocks, in / nblocks)
+ Returns:
+ dense_weight: (out, in)
15
+ """
16
+ return torch.block_diag(*torch.unbind(weight, dim=0))
17
+
18
+
19
+ def blockdiag_multiply_reference(x, weight):
20
+ """
21
+ This implementation is slow but more likely to be correct.
22
+ Arguments:
23
+ x: (..., n)
24
+ weight: (nblocks, q, n / nblocks)
25
+ Outputs:
26
+ out: (..., nblocks * q)
27
+ """
28
+ n = x.shape[-1]
29
+ nblocks, q, p = weight.shape
30
+ assert nblocks * p == n
31
+
32
+ x_reshaped = rearrange(x, '... (nblocks p) -> ... nblocks p', nblocks=nblocks)
33
+ return rearrange(torch.einsum('...kp, kqp -> ...kq', x_reshaped, weight),
34
+ '... nblocks q -> ... (nblocks q)')
35
+
36
+
37
+ class BlockdiagMultiply(torch.autograd.Function):
38
+
39
+ """This is a faster implementation, with careful memory copies for the fastest
40
+ bmm performance.
41
+ The backward pass is also written manually with careful memory copies.
42
+ Arguments:
43
+ x: (..., n)
44
+ weight: (nblocks, q, n / nblocks)
45
+ Outputs:
46
+ out: (..., nblocks * q)
47
+ """
48
+
49
+ @staticmethod
50
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16)
51
+ def forward(ctx, x, weight):
52
+ ctx.save_for_backward(x, weight)
53
+ batch_shape, n = x.shape[:-1], x.shape[-1]
54
+ batch_dim = np.prod(batch_shape)
55
+ nblocks, q, p = weight.shape
56
+ assert nblocks * p == n
57
+ x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
58
+ out = torch.empty(batch_dim, nblocks, q, device=x.device, dtype=x.dtype).transpose(0, 1)
59
+ out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1)
60
+ return out.reshape(*batch_shape, nblocks * q)
61
+
62
+ @staticmethod
63
+ @torch.cuda.amp.custom_bwd
64
+ def backward(ctx, dout):
65
+ x, weight = ctx.saved_tensors
66
+ batch_shape, n = x.shape[:-1], x.shape[-1]
67
+ batch_dim = np.prod(batch_shape)
68
+ nblocks, q, p = weight.shape
69
+ assert nblocks * p == n
70
+ dx, dweight = None, None
71
+ dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1)
72
+ if ctx.needs_input_grad[0]:
73
+ dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype)
74
+ dx = torch.bmm(dout_reshaped, weight.conj(),
75
+ out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n)
76
+ if ctx.needs_input_grad[1]:
77
+ x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1)
78
+ dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj())
79
+ return dx, dweight
80
+
81
+
82
+ blockdiag_multiply = BlockdiagMultiply.apply
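Reviewer note: `blockdiag_multiply_reference` above is equivalent to multiplying by the dense matrix that `blockdiag_weight_to_dense_weight` builds. The sketch below checks that equivalence for a single input vector using nested lists. All three function names are hypothetical stand-ins, and batching, autocast, and the `bmm` memory layout of `BlockdiagMultiply` are deliberately left out.

```python
def blockdiag_multiply_1d(x, weight):
    """x: list of length nblocks * p; weight: nested list (nblocks, q, p)."""
    nblocks, q, p = len(weight), len(weight[0]), len(weight[0][0])
    assert nblocks * p == len(x)
    out = []
    for k in range(nblocks):
        chunk = x[k * p:(k + 1) * p]  # the slice of x seen by block k
        for j in range(q):
            out.append(sum(w * v for w, v in zip(weight[k][j], chunk)))
    return out  # length nblocks * q


def to_dense(weight):
    """Expand (nblocks, q, p) blocks into a (nblocks * q, nblocks * p) matrix."""
    nblocks, q, p = len(weight), len(weight[0]), len(weight[0][0])
    dense = [[0] * (nblocks * p) for _ in range(nblocks * q)]
    for k in range(nblocks):
        for j in range(q):
            for i in range(p):
                dense[k * q + j][k * p + i] = weight[k][j][i]
    return dense


def dense_multiply(dense, x):
    return [sum(w * v for w, v in zip(row, x)) for row in dense]


weight = [[[1, 2]], [[3, 4]]]  # nblocks=2, q=1, p=2
x = [1, 2, 3, 4]
blockwise = blockdiag_multiply_1d(x, weight)
densewise = dense_multiply(to_dense(weight), x)
```

The block-wise version touches only `nblocks * q * p` weights, while the dense product touches `nblocks` times as many, most of them zero.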
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "_name_or_path": "alycialee/m2-bert-260m",
+   "alibi_starting_size": 512,
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "bidirectional": true,
+   "auto_map": {
+     "AutoConfig": "configuration_bert.BertConfig",
+     "AutoModelForMaskedLM": "bert_layers.BertForMaskedLM"
+   },
+   "classifier_dropout": null,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1792,
+   "initializer_range": 0.02,
+   "intermediate_size": 7168,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 128,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.28.1",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "long_conv_l_max": 128,
+   "long_conv_kernel_learning_rate": 1e-3,
+   "hyena_lr_pos_emb": 1e-5,
+   "hyena_w": 10,
+   "hyena_wd": 0.1,
+   "hyena_emb_dim": 5,
+   "hyena_filter_order": 128,
+   "residual_long_conv": true,
+   "use_glu_mlp": true,
+   "use_monarch_mlp": true,
+   "monarch_mlp_nblocks": 4,
+   "use_positional_encodings": true,
+   "vocab_size": 30528
+ }
configuration_bert.py ADDED
@@ -0,0 +1,73 @@
1
+ from transformers import BertConfig
2
+
3
+
4
+ class BertConfig(BertConfig):
5
+
6
+ def __init__(
7
+ self,
8
+ alibi_starting_size: int = 512,
9
+ attention_probs_dropout_prob: float = 0.0,
10
+
11
+ # mlp
12
+ use_glu_mlp: bool = True,
13
+ use_monarch_mlp: bool = False,
14
+ monarch_mlp_nblocks: int = 4,
15
+
16
+ # position
17
+ use_positional_encodings: bool = False,
18
+ max_position_embeddings: int = 512,
19
+
20
+ # architecture selection
21
+ residual_long_conv: bool = False,
22
+
23
+ # hyena and long conv hyperparameters
24
+ bidirectional: bool = True,
25
+ hyena_w_mod: int = 1,
26
+ hyena_filter_dropout: float = 0.2,
27
+ hyena_filter_order: int = 64,
28
+
29
+ # efficiency
30
+ use_flash_mm: bool = False,
31
+
32
+ # average pooling instead of CLS token
33
+ pool_all: bool = False,
34
+
35
+ **kwargs,
36
+ ):
37
+ """Configuration class for MosaicBert.
38
+
39
+ Args:
40
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
41
+ create when initializing the model. You should be able to ignore this parameter in most cases.
42
+ Defaults to 512.
43
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT.
44
+ Defaults to 0.0.
45
+ """
46
+ super().__init__(
47
+ attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
48
+ self.alibi_starting_size = alibi_starting_size
49
+
50
+ # mlp
51
+ self.use_glu_mlp = use_glu_mlp
52
+ self.use_monarch_mlp = use_monarch_mlp
53
+ self.monarch_mlp_nblocks = monarch_mlp_nblocks
54
+
55
+ # positional encodings
56
+ self.use_positional_encodings = use_positional_encodings
57
+ self.max_position_embeddings = max_position_embeddings
58
+
59
+ # architecture
60
+ self.residual_long_conv = residual_long_conv
61
+
62
+ # hyena and long conv hyperparameters
63
+ self.bidirectional = bidirectional
64
+ self.hyena_w_mod = hyena_w_mod
65
+ self.hyena_filter_dropout = hyena_filter_dropout
66
+ self.hyena_filter_order = hyena_filter_order
67
+
68
+ # efficiency
69
+ self.use_flash_mm = use_flash_mm
70
+
71
+ # average pooling instead of CLS token
72
+ self.pool_all = pool_all
73
+
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.28.1",
+   "use_cache": false,
+   "eos_token_id": [0, 50278]
+ }
hyena_utils.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
2
+ # Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+ import opt_einsum as oe
12
+ contract = oe.contract
13
+
14
+ """ Utils for the training loop. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """
15
+
16
+ class OptimModule(nn.Module):
17
+ """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """
18
+
19
+ def register(self, name, tensor, lr=None, wd=0.0):
20
+ """Register a tensor with a configurable learning rate and 0 weight decay"""
21
+
22
+ if lr == 0.0:
23
+ self.register_buffer(name, tensor)
24
+ else:
25
+ self.register_parameter(name, nn.Parameter(tensor))
26
+
27
+ optim = {}
28
+ if lr is not None: optim["lr"] = lr
29
+ if wd is not None: optim["weight_decay"] = wd
30
+ setattr(getattr(self, name), "_optim", optim)
31
+
32
+
33
+ def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
34
+ # u.shape: B H L
35
+ seqlen = u.shape[-1]
36
+
37
+ fft_size = 2 * seqlen
38
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
39
+ if k_rev is not None:
40
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
41
+ k_f = k_f + k_rev_f.conj()
42
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
43
+
44
+ if len(u.shape) > 3:
45
+ k_f = k_f.unsqueeze(1)
46
+
47
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
48
+
49
+ out = y + u * D
50
+
51
+ if gelu:
52
+ out = F.gelu(out)
53
+ if dropout_mask is not None:
54
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
55
+ else:
56
+ return out.to(dtype=u.dtype)
57
+
58
+
59
+ @torch.jit.script
60
+ def mul_sum(q, y):
61
+ return (q * y).sum(dim=1)
62
+
63
+
64
+ class Sin(nn.Module):
65
+ def __init__(self, dim, w=10, w_mod=1, train_freq=True):
66
+ super().__init__()
67
+
68
+ init_tensor = torch.ones(1, dim)
69
+ self.freq = (
70
+ nn.Parameter(w * init_tensor)
71
+ if train_freq
72
+ else w * torch.ones(1, dim)
73
+ )
74
+ self.w_mod = w_mod
75
+
76
+ def forward(self, x):
77
+ return torch.sin(self.w_mod * self.freq * x)
78
+
79
+
80
+ class PositionalEmbedding(OptimModule):
81
+ def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs):
82
+ """Complex exponential positional embeddings for Hyena filters."""
83
+ super().__init__()
84
+
85
+ self.seq_len = seq_len
86
+ # The time embedding fed to the filters is normalized so that t_f = 1
87
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
88
+
89
+ if emb_dim > 1:
90
+ bands = (emb_dim - 1) // 2
91
+ # To compute the right embeddings we use the "proper" linspace
92
+ t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
93
+ w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
94
+
95
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
96
+ z = torch.exp(-1j * f * w)
97
+ z = torch.cat([t, z.real, z.imag], dim=-1)
98
+ self.register("z", z, lr=lr_pos_emb)
99
+ self.register("t", t, lr=0.0)
100
+
101
+ def forward(self, L):
102
+ return self.z[:, :L], self.t[:, :L]
103
+
104
+
105
+ class ExponentialModulation(OptimModule):
106
+ def __init__(
107
+ self,
108
+ d_model,
109
+ fast_decay_pct=0.3,
110
+ slow_decay_pct=1.5,
111
+ target=1e-2,
112
+ modulation_lr=0.0,
113
+ shift: float = 0.0,
114
+ **kwargs,
115
+ ):
116
+ super().__init__()
117
+ self.shift = shift
118
+ max_decay = math.log(target) / fast_decay_pct
119
+ min_decay = math.log(target) / slow_decay_pct
120
+ deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
121
+ self.register("deltas", deltas, lr=modulation_lr)
122
+
123
+ def forward(self, t, x):
124
+ decay = torch.exp(-t * self.deltas.abs())
125
+ x = x * (decay + self.shift)
126
+ return x
127
+
128
+
129
+ class HyenaFilter(OptimModule):
130
+ def __init__(
131
+ self,
132
+ d_model,
133
+ emb_dim=3, # dim of input to MLP, augments with positional encoding
134
+ order=16, # width of the implicit MLP
135
+ seq_len=1024,
136
+ lr=1e-3,
137
+ lr_pos_emb=1e-5,
138
+ dropout=0.0,
139
+ w=1, # frequency of periodic activations
140
+ w_mod=1, # non-learnable modification of w
141
+ wd=0, # weight decay of kernel parameters
142
+ bias=True,
143
+ num_inner_mlps=2,
144
+ linear_mixer=False,
145
+ modulate: bool = True,
146
+ normalized=False,
147
+ bidirectional=False,
148
+ **kwargs,
149
+ ):
150
+ """
151
+ Implicit long filter with modulation.
152
+
153
+ Args:
154
+ d_model: number of channels in the input
155
+ emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
156
+ order: width of the FFN
157
+ num_inner_mlps: number of inner linear layers inside filter MLP
158
+
159
+ Note:
160
+ filter_dropout is not implemented
161
+ """
162
+ super().__init__()
163
+
164
+ self.d_model=d_model
165
+ self.emb_dim=emb_dim
166
+ self.seq_len=seq_len
167
+ self.modulate=modulate
168
+ self.use_bias = bias
169
+ self.bidirectional = bidirectional
170
+
171
+ self.bias = nn.Parameter(torch.randn(self.d_model))
172
+ self.dropout = nn.Dropout(dropout)
173
+
174
+ act = Sin(dim=order, w=w, w_mod=w_mod)
175
+ assert (
176
+ emb_dim % 2 != 0 and emb_dim >= 3
177
+ ), "emb_dim must be odd and greater than or equal to 3 (time, sine and cosine)"
178
+ self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb)
179
+
180
+ # uses a variable number of inner linear layers
181
+ if linear_mixer is False:
182
+ self.implicit_filter = nn.Sequential(
183
+ nn.Linear(emb_dim, order),
184
+ act,
185
+ )
186
+ for i in range(num_inner_mlps):
187
+ self.implicit_filter.append(nn.Linear(order, order))
188
+ self.implicit_filter.append(act)
189
+ self.implicit_filter.append(nn.Linear(order, d_model, bias=False))
190
+ else:
191
+ self.implicit_filter = nn.Sequential(
192
+ nn.Linear(emb_dim, d_model, bias=False),
193
+ )
194
+
195
+ if self.bidirectional:
196
+ self.implicit_filter_rev = nn.Sequential(
197
+ nn.Linear(emb_dim, order),
198
+ act,
199
+ )
200
+ for i in range(num_inner_mlps):
201
+ self.implicit_filter_rev.append(nn.Linear(order, order))
202
+ self.implicit_filter_rev.append(act)
203
+ self.implicit_filter_rev.append(nn.Linear(order, d_model, bias=False))
204
+
205
+ self.modulation = ExponentialModulation(d_model, **kwargs)
206
+
207
+ self.normalized = normalized
208
+ for c in self.implicit_filter.children():
209
+ for name, v in c.state_dict().items():
210
+ optim = {"weight_decay": wd, "lr": lr}
211
+ setattr(getattr(c, name), "_optim", optim)
212
+
213
+ def filter(self, L, *args, **kwargs):
214
+ z, t = self.pos_emb(L)
215
+ h = self.implicit_filter(z)
216
+ if self.modulate:
217
+ h = self.modulation(t, h)
218
+ if self.normalized:
219
+ h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
220
+ return h
221
+
222
+ def filter_rev(self, L, *args, **kwargs):
223
+ z, t = self.pos_emb(L)
224
+ h = self.implicit_filter_rev(z)
225
+ if self.modulate:
226
+ h = self.modulation(t, h)
227
+ if self.normalized:
228
+ h = h / torch.norm(h, dim=-1, p=1, keepdim=True)
229
+ return h
230
+
231
+ def forward(self, x, L, k_fwd=None, k_rev=None, bias=None, *args, **kwargs):
232
+ if k_fwd is None:
233
+ k_fwd = self.filter(L)
234
+ if self.bidirectional and k_rev is None:
235
+ k_rev = self.filter_rev(L)
236
+
237
+ # Ensure compatibility with filters that return a tuple
238
+ k_fwd = k_fwd[0] if type(k_fwd) is tuple else k_fwd
239
+ if bias is None:
240
+ bias = self.bias
241
+ bias = bias if self.use_bias else 0 * bias
242
+
243
+ if self.bidirectional:
244
+ k_rev = k_rev[0] if type(k_rev) is tuple else k_rev
245
+ k = F.pad(k_fwd, (0, L)) \
246
+ + F.pad(k_rev.flip(-1), (L, 0))
247
+ else:
248
+ k = k_fwd
249
+
250
+
251
+ y = fftconv_ref(
252
+ x,
253
+ k,
254
+ bias,
255
+ dropout_mask=None,
256
+ gelu=False,
257
+ )
258
+
259
+ return y.to(dtype=x.dtype)
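Reviewer note: `fftconv_ref` above pads the FFT to `2 * seqlen` so that the circular convolution computed in the frequency domain matches an ordinary causal convolution on the first `seqlen` outputs, then adds the skip term `u * D`. The sketch below checks that identity for the unidirectional case with a naive O(n^2) DFT from the standard library. The names `dft`, `idft`, `fft_causal_conv`, and `direct_causal_conv` are hypothetical, the skip weight `d` is a scalar here rather than a per-channel tensor, and the bidirectional kernel and the specific normalization used by `torch.fft` are not reproduced.

```python
import cmath


def dft(x, n):
    """Naive discrete Fourier transform of x, zero-padded to length n."""
    x = list(x) + [0.0] * (n - len(x))
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
            for f in range(n)]


def idft(X):
    """Naive inverse DFT, returning the real part of each sample."""
    n = len(X)
    return [(sum(X[f] * cmath.exp(2j * cmath.pi * f * t / n)
                 for f in range(n)) / n).real
            for t in range(n)]


def fft_causal_conv(u, k, d=0.0):
    """Convolve u with kernel k in the frequency domain, plus skip term d * u."""
    seqlen = len(u)
    fft_size = 2 * seqlen  # padding avoids wrap-around in the first seqlen outputs
    U, K = dft(u, fft_size), dft(k, fft_size)
    y = idft([a * b for a, b in zip(U, K)])[:seqlen]
    return [yt + d * ut for yt, ut in zip(y, u)]


def direct_causal_conv(u, k, d=0.0):
    """Reference: y[t] = sum_{s <= t} k[s] * u[t - s] + d * u[t]."""
    return [sum(k[s] * u[t - s] for s in range(t + 1)) + d * u[t]
            for t in range(len(u))]


u = [1.0, 2.0, 3.0, 4.0]
k = [1.0, 0.5, 0.25, 0.125]
via_fft = fft_causal_conv(u, k, d=0.1)
direct = direct_causal_conv(u, k, d=0.1)
```

The two results agree up to floating-point error; in practice the FFT route is preferred because it scales as O(L log L) in the sequence length instead of O(L^2).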
monarch_mixer_sequence_mixer.py ADDED
@@ -0,0 +1,144 @@
+ # Copyright (c) 2023, Dan Fu and Simran Arora.
+ # Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py
+
+ import torch.nn as nn
+ from einops import rearrange
+ import opt_einsum as oe
+
+ contract = oe.contract
+ from .hyena_utils import HyenaFilter
+
+
+ class MonarchMixerSequenceMixing(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         l_max=128,
+         dropout=0.0,
+         hyena_kernel_lr=None,
+         bidirectional=False,
+         hyena_lr_pos_emb=1e-5,
+         hyena_w=10,
+         hyena_w_mod=1,
+         hyena_wd=0.1,
+         hyena_emb_dim=3,
+         hyena_filter_dropout=0.0,
+         hyena_filter_order=16,
+         residual_long_conv=False,
+     ):
+         super().__init__()
+
+         self.d_model = d_model
+         self.l_max = l_max
+         self.kernel_lr = hyena_kernel_lr
+         self.channels = 1
+         self.bidirectional = bidirectional
+         self.residual_long_conv = residual_long_conv
+         self.NUM_PROJECTIONS = 3
+
+         print('-- Bidirectional:', self.bidirectional)
+         print("-- Using Long Conv Residual:", self.residual_long_conv)
+         print('-- Hyena w:', hyena_w)
+         print('-- Hyena w mod:', hyena_w_mod)
+         print(f"-- Hyena filter order: {hyena_filter_order}")
+         print(f"-- Hyena filter dropout: {hyena_filter_dropout}")
+         print(f"-- Hyena filter wd: {hyena_wd}")
+         print(f"-- Hyena filter emb dim: {hyena_emb_dim}")
+         print(f"-- Hyena filter lr: {hyena_kernel_lr}")
+         print(f"-- Hyena filter lr pos emb: {hyena_lr_pos_emb}")
+
+         self.filter_fn = HyenaFilter(
+             self.d_model,
+             order=hyena_filter_order,
+             seq_len=self.l_max,
+             dropout=hyena_filter_dropout,
+             bidirectional=self.bidirectional,
+             lr=hyena_kernel_lr,
+             lr_pos_emb=hyena_lr_pos_emb,
+             w=hyena_w,  # frequency of periodic activations
+             w_mod=hyena_w_mod,
+             wd=hyena_wd,  # weight decay of kernel parameters
+             emb_dim=hyena_emb_dim,
+         )
+
+         if self.residual_long_conv:
+             self.filter_fn2 = HyenaFilter(
+                 self.d_model,
+                 order=hyena_filter_order,
+                 seq_len=self.l_max,
+                 dropout=hyena_filter_dropout,
+                 bidirectional=self.bidirectional,
+                 lr=hyena_kernel_lr,
+                 lr_pos_emb=hyena_lr_pos_emb,
+                 w=hyena_w,  # frequency of periodic activations
+                 w_mod=hyena_w_mod,
+                 wd=hyena_wd,  # weight decay of kernel parameters
+                 emb_dim=hyena_emb_dim,
+             )
+
+         # setup projections
+         self.in_linear = nn.Linear(d_model, 3 * d_model)
+         self.out_linear = nn.Linear(d_model, d_model)
+
+         # setup short conv
+         total_width = self.d_model * self.NUM_PROJECTIONS
+         self.short_filter = nn.Conv1d(
+             in_channels=total_width,
+             out_channels=total_width,
+             kernel_size=3,
+             groups=total_width,
+             padding=2,
+         )
+
+
+     def forward(self, u, **kwargs):
+         # u is B L H
+         L = u.size(-2)
+
+         # in projection
+         u_orig = u
+         u = self.in_linear(u)
+         u = rearrange(u, "b l d -> b d l")
+
+         # short filter
+         uc = self.short_filter(u)[..., :L]
+
+         x1, x2, v = uc.split(self.d_model, dim=1)
+
+         v = v * x1
+
+         k = self.filter_fn.filter(L, device=u.device)
+         k = rearrange(k, "c l d -> c d l")[0]  # `c` is always 1 by default
+
+         if self.bidirectional:
+             k_rev = self.filter_fn.filter_rev(L, device=u.device)
+             k_rev = rearrange(k_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
+         else:
+             k_rev = None
+
+         y = self.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias= self.filter_fn.bias[None, :, None])
+
+         if self.residual_long_conv:
+             k2 = self.filter_fn2.filter(L, device=u.device)
+             k2 = rearrange(k2, "c l d -> c d l")[0]
+
+             if self.bidirectional:
+                 k2_rev = self.filter_fn2.filter_rev(L, device=u.device)
+                 k2_rev = rearrange(k2_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
+             else:
+                 k2_rev = None
+
+             yu = self.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias= self.filter_fn2.bias[None, :, None])
+
+         # post gating
+         y = y * x2
+
+         if self.residual_long_conv:
+             y = y + yu
+
+         y = y.transpose(-1, -2)
+         y = self.out_linear(y)
+
+         return y, None
+
+
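The short filter in `MonarchMixerSequenceMixing` is a depthwise `nn.Conv1d` with `kernel_size=3` and `padding=2`, and `forward` keeps only the first `L` outputs. Because `nn.Conv1d` computes a cross-correlation over the zero-padded input, that slice leaves each position depending only on itself and the two positions before it. The pure-Python sketch below shows this for a single channel. The function name is hypothetical and a list stands in for the tensor.

```python
# Single-channel illustration of Conv1d(kernel_size=3, padding=2)[..., :L].
# The function name is hypothetical; nn.Conv1d itself is not used here.

def causal_short_conv(x, w):
    """Cross-correlate x with taps w after padding len(w) - 1 zeros on both
    sides, then keep the first len(x) outputs (as the [..., :L] slice does)."""
    pad = len(w) - 1
    xp = [0.0] * pad + list(x) + [0.0] * pad
    n_out = len(xp) - len(w) + 1             # equals len(x) + pad
    full = [sum(w[j] * xp[t + j] for j in range(len(w))) for t in range(n_out)]
    return full[:len(x)]


x = [1.0, 2.0, 3.0, 4.0]
w = [1.0, 10.0, 100.0]                       # w[2] multiplies the current step
y = causal_short_conv(x, w)

# Changing a later input must not affect earlier outputs.
x_changed = [1.0, 2.0, 3.0, -7.0]
y_changed = causal_short_conv(x_changed, w)
```

Here `y[t] = w[0] * x[t-2] + w[1] * x[t-1] + w[2] * x[t]`, with out-of-range inputs treated as zero. The real layer applies one such filter per channel because `groups=total_width`, and it also adds a learned bias.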
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b38662d8ec674f0ccbb744ee78681faba427cd47034f40b8dd4a8991b3d63de
+ size 1364559337
structured_linear.py ADDED
@@ -0,0 +1,70 @@
+ # Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers
+
+ import math
+ from functools import partial
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+
+
+ class StructuredLinear(nn.Module):
+
+     def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+         """Subclasses should call reset_parameters
+         """
+         factory_kwargs = {'device': device, 'dtype': dtype}
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         # Subclasses may override {in,out}_features_extended
+         if not hasattr(self, 'in_features_extended'):
+             self.in_features_extended = in_features
+         if not hasattr(self, 'out_features_extended'):
+             self.out_features_extended = out_features
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
+         else:
+             self.register_parameter('bias', None)
+
+     def reset_parameters(self) -> None:
+         self.set_weights_from_dense_init(dense_init_fn_=partial(init.kaiming_uniform_, a=math.sqrt(5)))
+         self.reset_parameters_bias()
+
+     def set_weights_from_dense_init(self, dense_init_fn_):
+         raise NotImplementedError
+
+     def reset_parameters_bias(self):
+         if self.bias is not None:
+             fan_in = self.bias.shape[-1]
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             init.uniform_(self.bias, -bound, bound)
+
+     @property
+     def saving(self):
+         raise NotImplementedError
+
+     def convert_to_dense_weight(self):
+         factory_kwargs = {'device': self.weight.device, 'dtype': self.weight.dtype}
+         dense_weight = self.forward_matmul(torch.eye(self.in_features, **factory_kwargs)).T
+         return dense_weight
+
+     def preprocess(self, x):
+         in_features = x.shape[-1]
+         if in_features < self.in_features_extended:
+             x = F.pad(x, (0, self.in_features_extended - in_features))
+         return x
+
+     def postprocess(self, output):
+         out_features_extended = output.shape[-1]
+         if out_features_extended > self.out_features:
+             output = output[..., :self.out_features]
+         return output
+
+     def forward_matmul(self, x):
+         raise NotImplementedError
+
+     def forward(self, x):
+         output = self.forward_matmul(x)
+         # Convert bias to output.dtype in case of AMP, otherwise bias and activation will be in FP32
+         return (output + self.bias.to(dtype=output.dtype)) if self.bias is not None else output
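`convert_to_dense_weight` recovers an explicit weight matrix from any subclass by pushing the identity matrix through `forward_matmul` and transposing the result. The pure-Python sketch below shows why that works, assuming `forward_matmul` is linear and follows the `x @ W.T` convention of `nn.Linear`. Nested lists stand in for tensors and every name is hypothetical.

```python
# Pure-Python illustration of the identity trick in convert_to_dense_weight.
# Assumption: forward_matmul is linear and maps each input row x to x @ W.T.
# All names are hypothetical; nested lists stand in for tensors.

def make_forward_matmul(W):
    """Return a function computing x @ W.T for a batch of rows x."""
    def forward_matmul(rows):
        return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in W]
                for row in rows]
    return forward_matmul


def identity(n):
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def convert_to_dense_weight(forward_matmul, in_features):
    """Feed the identity through the layer, then transpose the result."""
    out = forward_matmul(identity(in_features))   # shape: in_features x out_features
    return [list(col) for col in zip(*out)]       # shape: out_features x in_features


W = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]                             # out_features=2, in_features=3
recovered = convert_to_dense_weight(make_forward_matmul(W), in_features=3)
```

Row `i` of the identity selects column `i` of `W`, so the stacked outputs form `W.T` and the final transpose returns `W`. A structured subclass never has to store `W` for this to work, since only `forward_matmul` is called.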