JingzeShi committed
Commit c271fb0
Parent: 24a7b6a

Upload DogeForCausalLM
Files changed (4)
  1. config.json +8 -10
  2. configuration_doge.py +16 -24
  3. model.safetensors +2 -2
  4. modeling_doge.py +254 -295
config.json CHANGED
@@ -1,33 +1,31 @@
 {
-  "_name_or_path": "./checkpoint-10000",
+  "_name_or_path": "./results/Doge-60M",
   "architectures": [
     "DogeForCausalLM"
   ],
+  "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "configuration_doge.DogeConfig",
     "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
   },
   "bos_token_id": 1,
   "eos_token_id": 2,
+  "expert_retrieval_size": 256,
   "hidden_act": "silu",
   "hidden_bias": false,
   "hidden_dropout": 0.0,
   "hidden_size": 512,
   "initializer_range": 0.02,
-  "inner_values_retrieval_size": 128,
   "intermediate_size": 2048,
-  "max_position_embeddings": 4096,
+  "is_moe": false,
+  "max_position_embeddings": 2048,
   "model_type": "doge",
   "num_attention_heads": 4,
-  "num_cdmmoe_experts": 2048,
-  "num_cdmmoe_experts_per_head": 4,
-  "num_cdmmoe_heads": 2,
+  "num_cdmmoe_experts": 4096,
+  "num_cdmmoe_experts_per_head": 8,
+  "num_cdmmoe_heads": 4,
   "num_hidden_layers": 8,
-  "num_inner_value_heads": 2,
-  "num_inner_values": 4,
-  "num_value_per_head": 2,
   "pad_token_id": 0,
-  "private_expert_retrieval_size": 256,
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "rope_theta": 10000.0,
configuration_doge.py CHANGED
@@ -1,9 +1,9 @@
 # coding=utf-8
-# Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on the Wonderful Matrices paper implementation.
 #
-# https://arxiv.org/abs/2407.16958
+# https://arxiv.org/abs/2412.11834
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ class DogeConfig(PretrainedConfig):
             Dropout probability for each sequence transformation and state transformation module.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 16384):
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
             The maximum sequence length that this model might ever be used with.
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
@@ -104,22 +104,18 @@ class DogeConfig(PretrainedConfig):
             Whether to tie weight embeddings
         num_attention_heads (`int`, *optional*, defaults to 8):
             Number of attention heads for each attention layer in the Transformer decoder.
-        num_inner_values (`int`, *optional*, defaults to 8):
-            Number of inner values for Inner Function Attention.
-        num_inner_value_heads (`int`, *optional*, defaults to 4):
-            Number of inner value heads for Inner Function Attention.
-        num_value_per_head (`int`, *optional*, defaults to 4):
-            Number of values per head, can't be greater than `num_inner_values`.
-        inner_values_retrieval_size (`int`, *optional*, defaults to 128):
-            Dimension of the inner values retrieval states for each attention layer in the Transformer decoder
-        private_expert_retrieval_size (`int`, *optional*, defaults to 256):
-            Dimension of the Private Expert retrieval states for the Cross Domain Mixture of Experts.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        is_moe (`bool`, *optional*, defaults to `False`):
+            Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
         num_cdmmoe_experts (`int`, *optional*, defaults to 4096):
            Number of Private Experts for the Cross Domain Mixture of Experts.
         num_cdmmoe_heads (`int`, *optional*, defaults to 4):
            Number of heads of Private Experts for the Cross Domain Mixture of Experts.
         num_cdmmoe_experts_per_head (`int`, *optional*, defaults to 8):
            Number of Private Experts per head for the Cross Domain Mixture of Experts.
+        expert_retrieval_size (`int`, *optional*, defaults to 256):
+            Dimension of the Expert retrieval states for the Cross Domain Mixture of Experts.
     """

     model_type = "doge"
@@ -134,7 +130,7 @@ class DogeConfig(PretrainedConfig):
         hidden_bias=False,
         hidden_dropout=0.0,
         hidden_act="silu",
-        max_position_embeddings=16384,
+        max_position_embeddings=2048,
         rope_theta=10000.0,
         rope_scaling=None,
         initializer_range=0.02,
@@ -145,14 +141,12 @@ class DogeConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         num_attention_heads=8,
-        num_inner_values=8,
-        num_inner_value_heads=4,
-        num_value_per_head=4,
-        inner_values_retrieval_size=128,
-        private_expert_retrieval_size=256,
+        attention_dropout=0.0,
+        is_moe=False,
         num_cdmmoe_experts=4096,
         num_cdmmoe_heads=4,
         num_cdmmoe_experts_per_head=8,
+        expert_retrieval_size=256,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -173,14 +167,12 @@ class DogeConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
         self.tie_word_embeddings = tie_word_embeddings
         self.num_attention_heads = num_attention_heads
-        self.num_inner_values = num_inner_values
-        self.num_inner_value_heads = num_inner_value_heads
-        self.num_value_per_head = num_value_per_head
-        self.inner_values_retrieval_size = inner_values_retrieval_size
-        self.private_expert_retrieval_size = private_expert_retrieval_size
+        self.attention_dropout = attention_dropout
+        self.is_moe = is_moe
         self.num_cdmmoe_experts = num_cdmmoe_experts
         self.num_cdmmoe_heads = num_cdmmoe_heads
         self.num_cdmmoe_experts_per_head = num_cdmmoe_experts_per_head
+        self.expert_retrieval_size = expert_retrieval_size

         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, copy it it to 'rope_type'.
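In short, the five removed constructor arguments (`num_inner_values`, `num_inner_value_heads`, `num_value_per_head`, `inner_values_retrieval_size`, `private_expert_retrieval_size`) give way to `attention_dropout`, `is_moe`, and `expert_retrieval_size`, and the default context length drops from 16384 to 2048. A hedged sketch of building the updated config directly, assuming `configuration_doge.py` has been downloaded next to the script (arguments not shown keep their defaults):

from configuration_doge import DogeConfig

config = DogeConfig(
    max_position_embeddings=2048,   # default was 16384 before this commit
    num_attention_heads=4,
    attention_dropout=0.0,          # new argument
    is_moe=False,                   # new argument: use DogeCDMoE only when True
    num_cdmmoe_experts=4096,
    num_cdmmoe_heads=4,
    num_cdmmoe_experts_per_head=8,
    expert_retrieval_size=256,      # renamed from private_expert_retrieval_size
)
print(config.is_moe, config.expert_retrieval_size)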
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:794645ba99f640a813621b02d8f89f67a857deeb876d882059c2b01bcabb045a
-size 307592408
+oid sha256:26d80cdf90d4f053299b962b1ede76f0fe30ed31ebcb95e5dbd730ce23ffd36a
+size 268580408
modeling_doge.py CHANGED
@@ -1,9 +1,9 @@
1
  # coding=utf-8
2
- # Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on the Wonderful Matrices paper implementation.
5
  #
6
- # https://arxiv.org/abs/2407.16958
7
  #
8
  # Licensed under the Apache License, Version 2.0 (the "License");
9
  # you may not use this file except in compliance with the License.
@@ -39,16 +39,15 @@ from transformers.modeling_utils import PreTrainedModel
39
  from transformers.utils import (
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
42
- # is_einx_available,
43
  logging,
44
  replace_return_docstrings,
45
  )
46
  from .configuration_doge import DogeConfig
47
 
48
-
49
-
50
- from einx import add as einx_add
51
-
52
 
53
 
54
  logger = logging.get_logger(__name__)
@@ -76,6 +75,18 @@ class RMSNorm(nn.Module):
76
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
77
 
78
 
79
  class RotaryEmbedding(nn.Module):
80
  def __init__(self, config: Optional[DogeConfig] = None):
81
  super().__init__()
@@ -172,8 +183,8 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
172
  return q_embed, k_embed
173
 
174
 
175
- class DogeInnerFuncAttn(nn.Module):
176
- """Inner Function Attention from 'Wonderful Matrices' paper."""
177
 
178
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
179
  super().__init__()
@@ -189,15 +200,10 @@ class DogeInnerFuncAttn(nn.Module):
189
 
190
  self.hidden_dim = config.hidden_size
191
  self.num_attention_heads = config.num_attention_heads
192
-
193
- # for accuracy of attention scores, we do not use GQA
194
  self.attention_head_dim = self.hidden_dim // self.num_attention_heads
195
- self.num_inner_values = config.num_inner_values
196
- self.num_inner_value_heads = config.num_inner_value_heads
197
- self.num_value_per_head = config.num_value_per_head
198
- self.inner_values_retrieval_dim = config.inner_values_retrieval_size
199
 
200
- # Q and K projections
201
  self.q_proj = nn.Linear(
202
  self.hidden_dim,
203
  self.num_attention_heads * self.attention_head_dim,
@@ -208,157 +214,26 @@ class DogeInnerFuncAttn(nn.Module):
208
  self.num_attention_heads * self.attention_head_dim,
209
  bias=config.hidden_bias,
210
  )
211
-
212
  # dynamic mask for the QK^T attention score matrix
213
- self.dynamic_mask = nn.Parameter(
214
- torch.round(torch.ones(self.num_attention_heads, config.max_position_embeddings))
215
  )
216
-
217
- # queries and keys for retrieval V
218
- self.v_queries = nn.Linear(
219
  self.hidden_dim,
220
- self.num_inner_value_heads * self.inner_values_retrieval_dim,
221
  bias=config.hidden_bias,
222
  )
223
- self.v_keys = nn.Parameter(
224
- torch.zeros(
225
- self.num_inner_value_heads,
226
- self.inner_values_retrieval_dim,
227
- self.num_inner_values,
228
- )
229
- )
230
-
231
- # V for inner function
232
- self.v_embed = nn.Embedding(
233
- self.num_inner_values,
234
  self.hidden_dim,
 
 
235
  )
236
-
237
  self.o_proj = nn.Linear(
238
  self.hidden_dim,
239
  self.hidden_dim,
240
  bias=config.hidden_bias,
241
  )
242
 
243
- def _update_causal_mask(
244
- self,
245
- attention_mask: torch.Tensor = None,
246
- input_tensor: torch.Tensor = None,
247
- cache_position: torch.Tensor = None,
248
- past_key_values: Cache = None,
249
- output_attentions: bool = False,
250
- ):
251
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
252
- using_static_cache = isinstance(past_key_values, StaticCache)
253
-
254
- dtype, device = input_tensor.dtype, input_tensor.device
255
- sequence_length = input_tensor.shape[1]
256
- if using_static_cache:
257
- target_length = past_key_values.get_max_cache_shape()
258
- else:
259
- target_length = (
260
- attention_mask.shape[-1]
261
- if isinstance(attention_mask, torch.Tensor)
262
- else past_seen_tokens + sequence_length + 1
263
- )
264
-
265
- # in case the provided `attention` mask is 2D, we generate a causal mask here (4D).
266
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position_and_dynamic_mask(
267
- attention_mask=attention_mask,
268
- dynamic_mask=self.dynamic_mask,
269
- sequence_length=sequence_length,
270
- target_length=target_length,
271
- dtype=dtype,
272
- device=device,
273
- cache_position=cache_position,
274
- batch_size=input_tensor.shape[0],
275
- )
276
-
277
- return causal_mask
278
-
279
- @staticmethod
280
- def _prepare_4d_causal_attention_mask_with_cache_position_and_dynamic_mask(
281
- attention_mask: torch.Tensor = None,
282
- dynamic_mask: torch.Tensor = None,
283
- sequence_length: int = None,
284
- target_length: int = None,
285
- dtype: torch.dtype = None,
286
- device: torch.device = None,
287
- cache_position: torch.Tensor = None,
288
- batch_size: int = None,
289
- **kwargs,
290
- ):
291
- """
292
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
293
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
294
-
295
- Args:
296
- attention_mask (`torch.Tensor`):
297
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
298
- `(batch_size, 1, query_length, key_value_length)`.
299
- dynamic_mask (`torch.Tensor`):
300
- A 2D dynamic mask of shape `(num_heads, max_position_embeddings)`.
301
- sequence_length (`int`):
302
- The sequence length being processed.
303
- target_length (`int`):
304
- The target length: when generating with static cache, the mask should be as long as the static cache,
305
- to account for the 0 padding, the part of the cache that is not filled yet.
306
- dtype (`torch.dtype`):
307
- The dtype to use for the 4D attention mask.
308
- device (`torch.device`):
309
- The device to plcae the 4D attention mask on.
310
- cache_position (`torch.Tensor`):
311
- Indices depicting the position of the input sequence tokens in the sequence.
312
- batch_size (`torch.Tensor`):
313
- Batch size.
314
- """
315
- if attention_mask is not None and attention_mask.dim() == 4:
316
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
317
- causal_mask = attention_mask
318
- else:
319
- num_heads = 1 if dynamic_mask is None else dynamic_mask.size(0)
320
- min_dtype = torch.finfo(dtype).min
321
- causal_mask = torch.full(
322
- (sequence_length, target_length),
323
- fill_value=min_dtype,
324
- dtype=dtype,
325
- device=device,
326
- )
327
- if sequence_length != 1:
328
- causal_mask = torch.triu(causal_mask, diagonal=1)
329
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
330
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, num_heads, -1, -1)
331
- if attention_mask is not None:
332
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
333
- mask_length = attention_mask.shape[-1]
334
- attention_mask = attention_mask[:, None, None, :].expand(-1, num_heads, 1, -1)
335
- if dynamic_mask is not None:
336
- dynamic_mask = dynamic_mask[None, :, None, :mask_length].expand(batch_size, -1, 1, -1)
337
- attention_mask = attention_mask.clone() * dynamic_mask
338
-
339
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask
340
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
341
- padding_mask == 0, min_dtype
342
- )
343
-
344
- return causal_mask
345
-
346
- def inner_func(
347
- self,
348
- hidden_states: torch.Tensor,
349
- ) -> torch.Tensor:
350
- """
351
- Each value can share weights with other values to increase the expressive power
352
- """
353
- bsz, seq_len, _ = hidden_states.shape
354
-
355
- v_queries = self.v_queries(hidden_states)
356
- v_queries = v_queries.view(bsz, seq_len, self.num_inner_value_heads, -1).transpose(1, 2)
357
- sim = torch.matmul(v_queries, self.v_keys).transpose(1, 2)
358
- v_embed = self.v_embed(sim.topk(k=self.num_value_per_head, dim=-1).indices)
359
- v = hidden_states * v_embed.sum(dim=-2).sum(dim=-2)
360
- return v
361
-
362
  def forward(
363
  self,
364
  hidden_states: torch.Tensor,
@@ -369,24 +244,24 @@ class DogeInnerFuncAttn(nn.Module):
369
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
370
  **kwargs,
371
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
372
- bsz, seq_len, _ = hidden_states.shape
373
 
374
  query_states = self.q_proj(hidden_states)
375
  key_states = self.k_proj(hidden_states)
376
- value_states = self.inner_func(hidden_states)
377
 
378
- query_states = query_states.view(bsz, seq_len, self.num_attention_heads, self.attention_head_dim).transpose(
379
  1, 2
380
  )
381
- key_states = key_states.view(bsz, seq_len, self.num_attention_heads, self.attention_head_dim).transpose(
382
  1, 2
383
  )
384
- value_states = value_states.view(bsz, seq_len, self.num_attention_heads, self.attention_head_dim).transpose(
385
  1, 2
386
  )
387
 
388
  cos, sin = position_embeddings
389
- query_states, query_states = apply_QK_rotary_pos_emb(query_states, query_states, cos, sin)
390
 
391
  if past_key_value is not None:
392
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -397,38 +272,101 @@ class DogeInnerFuncAttn(nn.Module):
397
  attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.attention_head_dim)
398
 
399
  # add mask to attention scores
400
- causal_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_key_value)
401
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
402
- attn_weights = attn_weights + causal_mask
 
 
 
403
 
404
  # upcast attention scores to fp32
405
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
406
 
407
  # apply attention scores to value states
408
  attn_output = torch.matmul(attn_weights, value_states)
409
 
410
  attn_output = attn_output.transpose(1, 2).contiguous()
411
- attn_output = attn_output.reshape(bsz, seq_len, -1)
412
  attn_output = self.o_proj(attn_output)
413
 
414
  return attn_output, past_key_value
415
 
416
 
417
- class DogeCDMoE(nn.Module):
418
- """Cross-Domain Mixture of Experts from 'Wonderful Matrices' paper."""
 
419
 
420
  def __init__(self, config: DogeConfig):
421
  super().__init__()
422
  self.hidden_dim = config.hidden_size
423
- self.act_fn = ACT2FN[config.hidden_act]
424
  self.intermediate_dim = config.intermediate_size
 
425
 
426
- self.private_expert_retrieval_dim = config.private_expert_retrieval_size
427
- self.num_cdmmoe_experts = config.num_cdmmoe_experts
428
- self.num_cdmmoe_heads = config.num_cdmmoe_heads
429
- self.num_cdmmoe_experts_per_head = config.num_cdmmoe_experts_per_head
430
-
431
- # cross domain
432
  self.up_proj = nn.Linear(
433
  self.hidden_dim,
434
  self.intermediate_dim,
@@ -440,24 +378,46 @@ class DogeCDMoE(nn.Module):
440
  bias=config.hidden_bias,
441
  )
442
 
443
- # queries and keys for retrieval private experts
444
  self.queries = nn.Linear(
445
  self.hidden_dim,
446
- self.num_cdmmoe_heads * self.private_expert_retrieval_dim,
447
  bias=False,
448
  )
449
- self.num_keys = int(math.sqrt(self.num_cdmmoe_experts))
450
  self.keys = nn.Parameter(
451
  torch.zeros(
452
  self.num_cdmmoe_heads,
453
  self.num_keys,
454
  2,
455
- self.private_expert_retrieval_dim // 2,
456
  )
457
  )
458
 
459
- # private experts
460
- self.down_embed = nn.Embedding(
461
  self.num_cdmmoe_experts,
462
  self.hidden_dim,
463
  )
@@ -471,7 +431,7 @@ class DogeCDMoE(nn.Module):
471
  self,
472
  hidden_states: torch.Tensor,
473
  **kwargs,
474
- ) -> Tuple[torch.Tensor, torch.Tensor]:
475
  bsz, seq_len, _ = hidden_states.shape
476
 
477
  # get similarity with queries and keys
@@ -479,7 +439,7 @@ class DogeCDMoE(nn.Module):
479
  queries = queries.view(bsz, seq_len, 2, self.num_cdmmoe_heads, -1).permute(2, 0, 1, 3, 4)
480
  sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, self.keys)
481
 
482
- # get expert scores and indices with the highest similarity
483
  (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.num_cdmmoe_experts_per_head, dim=-1)
484
  if einx_add is not None:
485
  all_scores = einx_add("... i, ... j -> ... (i j)", scores_x, scores_y)
@@ -491,17 +451,14 @@ class DogeCDMoE(nn.Module):
491
  all_indices = all_indices.view(*indices_x.shape[:-1], -1)
492
  scores, pk_indices = all_scores.topk(self.num_cdmmoe_experts_per_head, dim=-1)
493
  indices = all_indices.gather(-1, pk_indices)
494
-
495
- # get related expert embeddings based on indices
496
  down_embed = self.down_embed(indices)
497
  up_embed = self.up_embed(indices)
498
 
499
- # efficient retrieval of private experts
500
- experts_weights = self.act_fn(torch.einsum("b t d, b t h k d -> b t h k", hidden_states, down_embed) * scores.softmax(dim=-1))
 
501
  experts_states = torch.einsum("b t h k, b t h k d -> b t d", experts_weights, up_embed)
502
-
503
- # mix with shared parameters of cross domain
504
- hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states)))
505
  hidden_states = hidden_states + experts_states
506
  return hidden_states
507
 
@@ -511,10 +468,13 @@ class DogeDecoderLayer(nn.Module):
511
  super().__init__()
512
  self.hidden_dropout = config.hidden_dropout
513
 
514
- self.in_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
515
- self.attn = DogeInnerFuncAttn(config, layer_idx)
516
- self.in_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
517
- self.feed_forward = DogeCDMoE(config)
 
 
 
518
 
519
  def forward(
520
  self,
@@ -553,7 +513,7 @@ class DogeDecoderLayer(nn.Module):
553
 
554
  # sequence transformation
555
  residual = hidden_states
556
- hidden_states = self.in_attn_layernorm(hidden_states)
557
  hidden_states, present_key_value = self.attn(
558
  hidden_states=hidden_states,
559
  attention_mask=attention_mask,
@@ -565,14 +525,14 @@ class DogeDecoderLayer(nn.Module):
565
  )
566
  self_attn_weights = None
567
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
568
- hidden_states = residual + hidden_states
569
 
570
  # state transformation
571
  residual = hidden_states
572
- hidden_states = self.in_ff_layernorm(hidden_states)
573
  hidden_states = self.feed_forward(hidden_states)
574
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
575
- hidden_states = residual + hidden_states
576
 
577
  outputs = (hidden_states,)
578
 
@@ -592,6 +552,7 @@ class DogePreTrainedModel(PreTrainedModel):
592
  supports_gradient_checkpointing = True
593
  _no_split_modules = ["DogeDecoderLayer"]
594
  _skip_keys_device_placement = ["past_key_values"]
 
595
  _supports_cache_class = True
596
  _supports_quantized_cache = True
597
  _supports_static_cache = True
@@ -765,9 +726,9 @@ class DogeModel(DogePreTrainedModel):
765
  if position_ids is None:
766
  position_ids = cache_position.unsqueeze(0)
767
 
768
- # causal_mask = self._update_causal_mask(
769
- # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
770
- # )
771
  hidden_states = inputs_embeds
772
 
773
  # create position embeddings to be shared across the decoder layers
@@ -776,6 +737,7 @@ class DogeModel(DogePreTrainedModel):
776
  # decoder layers
777
  all_hidden_states = () if output_hidden_states else None
778
  all_self_attns = () if output_attentions else None
 
779
 
780
  for decoder_layer in self.layers:
781
  if output_hidden_states:
@@ -785,7 +747,7 @@ class DogeModel(DogePreTrainedModel):
785
  layer_outputs = self._gradient_checkpointing_func(
786
  decoder_layer.__call__,
787
  hidden_states,
788
- attention_mask,
789
  position_ids,
790
  past_key_values,
791
  output_attentions,
@@ -796,7 +758,7 @@ class DogeModel(DogePreTrainedModel):
796
  else:
797
  layer_outputs = decoder_layer(
798
  hidden_states,
799
- attention_mask=attention_mask,
800
  position_ids=position_ids,
801
  past_key_value=past_key_values,
802
  output_attentions=output_attentions,
@@ -833,100 +795,97 @@ class DogeModel(DogePreTrainedModel):
833
  attentions=all_self_attns,
834
  )
835
 
836
- """Move to DogeInnerFuncAttn"""
837
- # def _update_causal_mask(
838
- # self,
839
- # attention_mask: torch.Tensor,
840
- # input_tensor: torch.Tensor,
841
- # cache_position: torch.Tensor,
842
- # past_key_values: Cache,
843
- # output_attentions: bool,
844
- # ):
845
- # # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
846
- # # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
847
- # # to infer the attention mask.
848
- # past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
849
- # using_static_cache = isinstance(past_key_values, StaticCache)
850
-
851
- # dtype, device = input_tensor.dtype, input_tensor.device
852
- # sequence_length = input_tensor.shape[1]
853
- # if using_static_cache:
854
- # target_length = past_key_values.get_max_cache_shape()
855
- # else:
856
- # target_length = (
857
- # attention_mask.shape[-1]
858
- # if isinstance(attention_mask, torch.Tensor)
859
- # else past_seen_tokens + sequence_length + 1
860
- # )
861
-
862
- # # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
863
- # causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
864
- # attention_mask,
865
- # sequence_length=sequence_length,
866
- # target_length=target_length,
867
- # dtype=dtype,
868
- # device=device,
869
- # cache_position=cache_position,
870
- # batch_size=input_tensor.shape[0],
871
- # )
872
-
873
- # return causal_mask
874
-
875
- # @staticmethod
876
- # def _prepare_4d_causal_attention_mask_with_cache_position(
877
- # attention_mask: torch.Tensor,
878
- # sequence_length: int,
879
- # target_length: int,
880
- # dtype: torch.dtype,
881
- # device: torch.device,
882
- # cache_position: torch.Tensor,
883
- # batch_size: int,
884
- # **kwargs,
885
- # ):
886
- # """
887
- # Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
888
- # `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
889
-
890
- # Args:
891
- # attention_mask (`torch.Tensor`):
892
- # A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
893
- # `(batch_size, 1, query_length, key_value_length)`.
894
- # sequence_length (`int`):
895
- # The sequence length being processed.
896
- # target_length (`int`):
897
- # The target length: when generating with static cache, the mask should be as long as the static cache,
898
- # to account for the 0 padding, the part of the cache that is not filled yet.
899
- # dtype (`torch.dtype`):
900
- # The dtype to use for the 4D attention mask.
901
- # device (`torch.device`):
902
- # The device to plcae the 4D attention mask on.
903
- # cache_position (`torch.Tensor`):
904
- # Indices depicting the position of the input sequence tokens in the sequence.
905
- # batch_size (`torch.Tensor`):
906
- # Batch size.
907
- # """
908
- # if attention_mask is not None and attention_mask.dim() == 4:
909
- # # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
910
- # causal_mask = attention_mask
911
- # else:
912
- # min_dtype = torch.finfo(dtype).min
913
- # causal_mask = torch.full(
914
- # (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
915
- # )
916
- # if sequence_length != 1:
917
- # causal_mask = torch.triu(causal_mask, diagonal=1)
918
- # causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
919
- # causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
920
- # if attention_mask is not None:
921
- # causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
922
- # mask_length = attention_mask.shape[-1]
923
- # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
924
- # padding_mask = padding_mask == 0
925
- # causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
926
- # padding_mask, min_dtype
927
- # )
928
-
929
- # return causal_mask
930
 
931
 
932
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
 
1
  # coding=utf-8
2
+ # Copyright 2024 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on the Wonderful Matrices paper implementation.
5
  #
6
+ # https://arxiv.org/abs/2412.11834
7
  #
8
  # Licensed under the Apache License, Version 2.0 (the "License");
9
  # you may not use this file except in compliance with the License.
 
39
  from transformers.utils import (
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
 
42
  logging,
43
  replace_return_docstrings,
44
  )
45
  from .configuration_doge import DogeConfig
46
 
47
+ try:
48
+ from einx import add as einx_add
49
+ except ImportError:
50
+ einx_add = None
51
 
52
 
53
  logger = logging.get_logger(__name__)
 
75
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
76
 
77
 
78
+ class Residual(nn.Module):
79
+ def __init__(self, hidden_size):
80
+ super().__init__()
81
+ self.weight = nn.Parameter(torch.ones(hidden_size))
82
+
83
+ def forward(self, residual_states, hidden_states):
84
+ return self.weight * residual_states + hidden_states
85
+
86
+ def extra_repr(self):
87
+ return f"{tuple(self.weight.shape)}"
88
+
89
+
90
  class RotaryEmbedding(nn.Module):
91
  def __init__(self, config: Optional[DogeConfig] = None):
92
  super().__init__()
 
183
  return q_embed, k_embed
184
 
185
 
186
+ class DogeDynamicMaskAttention(nn.Module):
187
+ """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
188
 
189
  def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
190
  super().__init__()
 
200
 
201
  self.hidden_dim = config.hidden_size
202
  self.num_attention_heads = config.num_attention_heads
203
+ self.attention_dropout = config.attention_dropout
 
204
  self.attention_head_dim = self.hidden_dim // self.num_attention_heads
 
 
 
 
205
 
206
+ # Q K V O projections
207
  self.q_proj = nn.Linear(
208
  self.hidden_dim,
209
  self.num_attention_heads * self.attention_head_dim,
 
214
  self.num_attention_heads * self.attention_head_dim,
215
  bias=config.hidden_bias,
216
  )
 
217
  # dynamic mask for the QK^T attention score matrix
218
+ self.A = nn.Parameter(
219
+ torch.ones(self.num_attention_heads)
220
  )
221
+ self.dt_proj = nn.Linear(
 
 
222
  self.hidden_dim,
223
+ self.num_attention_heads,
224
  bias=config.hidden_bias,
225
  )
226
+ self.v_proj = nn.Linear(
227
  self.hidden_dim,
228
+ self.num_attention_heads * self.attention_head_dim,
229
+ bias=config.hidden_bias,
230
  )
 
231
  self.o_proj = nn.Linear(
232
  self.hidden_dim,
233
  self.hidden_dim,
234
  bias=config.hidden_bias,
235
  )
236
 
237
  def forward(
238
  self,
239
  hidden_states: torch.Tensor,
 
244
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
245
  **kwargs,
246
  ) -> Tuple[torch.Tensor, Optional[Cache]]:
247
+ bsz, q_len, _ = hidden_states.shape
248
 
249
  query_states = self.q_proj(hidden_states)
250
  key_states = self.k_proj(hidden_states)
251
+ value_states = self.v_proj(hidden_states)
252
 
253
+ query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
254
  1, 2
255
  )
256
+ key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
257
  1, 2
258
  )
259
+ value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
260
  1, 2
261
  )
262
 
263
  cos, sin = position_embeddings
264
+ query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
265
 
266
  if past_key_value is not None:
267
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
 
272
  attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.attention_head_dim)
273
 
274
  # add mask to attention scores
275
+ if attention_mask is not None:
276
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
277
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
278
+ dynamic_mask = dynamic_mask < 1.0
279
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
280
+ attn_weights = attn_weights + causal_mask
281
 
282
  # upcast attention scores to fp32
283
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
284
+ attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
285
 
286
  # apply attention scores to value states
287
  attn_output = torch.matmul(attn_weights, value_states)
288
 
289
  attn_output = attn_output.transpose(1, 2).contiguous()
290
+ attn_output = attn_output.reshape(bsz, q_len, -1)
291
  attn_output = self.o_proj(attn_output)
292
 
293
  return attn_output, past_key_value
294
 
295
 
296
+ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
297
+
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ attention_mask: Optional[torch.Tensor] = None,
302
+ position_ids: Optional[torch.LongTensor] = None,
303
+ past_key_value: Optional[Cache] = None,
304
+ cache_position: Optional[torch.LongTensor] = None,
305
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
306
+ **kwargs,
307
+ ) -> Tuple[torch.Tensor, Optional[Cache]]:
308
+ bsz, q_len, _ = hidden_states.shape
309
+
310
+ query_states = self.q_proj(hidden_states)
311
+ key_states = self.k_proj(hidden_states)
312
+ value_states = self.v_proj(hidden_states)
313
+
314
+ query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
315
+ key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
316
+ value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
317
+
318
+ cos, sin = position_embeddings
319
+ query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
320
+
321
+ if past_key_value is not None:
322
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
323
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
324
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
325
+
326
+ if attention_mask is not None:
327
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
328
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
329
+ dynamic_mask = dynamic_mask < 1.0
330
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
331
+
332
+ query_states = query_states.contiguous()
333
+ key_states = key_states.contiguous()
334
+ value_states = value_states.contiguous()
335
+
336
+ attn_output = F.scaled_dot_product_attention(
337
+ query_states,
338
+ key_states,
339
+ value_states,
340
+ attn_mask=causal_mask,
341
+ dropout_p=self.attention_dropout,
342
+ )
343
+
344
+ attn_output = attn_output.transpose(1, 2).contiguous()
345
+ attn_output = attn_output.view(bsz, q_len, -1)
346
+ attn_output = self.o_proj(attn_output)
347
+
348
+ return attn_output, past_key_value
349
+
350
+
351
+ DOGE_ATTENTION_CLASSES = {
352
+ "eager": DogeDynamicMaskAttention,
353
+ "sdpa": DogeSdpaDynamicMaskAttn,
354
+ }
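For intuition, `DogeDynamicMaskAttention` replaces the old learned `dynamic_mask` parameter with a data-dependent gate: `dt_proj` maps the value states to one scalar per head and position, `exp(A * softplus(dt))` makes it a positive gate, and key positions whose gate falls below 1 are masked out on top of the usual causal mask. A self-contained toy sketch of that eager-path masking (shapes and names are illustrative, not the repository code):

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 1, 4, 8, 16
hidden_dim = num_heads * head_dim

value_states = torch.randn(bsz, seq_len, hidden_dim)   # flattened per-head value states
A = torch.ones(num_heads)                              # learnable per-head scale, as in `self.A`
dt_proj = torch.nn.Linear(hidden_dim, num_heads)

dt_states = dt_proj(value_states)                                       # (bsz, seq_len, num_heads)
dynamic_mask = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)   # (bsz, num_heads, seq_len)
dropped = dynamic_mask < 1.0   # all False while A == 1; learned negative entries of A activate it

min_value = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((seq_len, seq_len), min_value), diagonal=1)

attn_scores = torch.randn(bsz, num_heads, seq_len, seq_len) + causal
attn_scores = attn_scores.masked_fill(dropped[:, :, None, :], min_value)
attn_weights = F.softmax(attn_scores, dim=-1)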
355
+
356
+
357
+ class DogeMLP(nn.Module):
358
 
359
  def __init__(self, config: DogeConfig):
360
  super().__init__()
361
  self.hidden_dim = config.hidden_size
 
362
  self.intermediate_dim = config.intermediate_size
363
+ self.act_fn = ACT2FN[config.hidden_act]
364
 
365
+ self.gate_proj = nn.Linear(
366
+ self.hidden_dim,
367
+ self.intermediate_dim,
368
+ bias=config.hidden_bias,
369
+ )
 
370
  self.up_proj = nn.Linear(
371
  self.hidden_dim,
372
  self.intermediate_dim,
 
378
  bias=config.hidden_bias,
379
  )
380
 
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.Tensor,
384
+ **kwargs,
385
+ ) -> torch.Tensor:
386
+ hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
387
+ return hidden_states
388
+
389
+
390
+ class DogeCDMoE(DogeMLP):
391
+ """Cross Domain Mixture of Experts from 'Wonderful Matrices' paper."""
392
+
393
+ def __init__(self, config: DogeConfig):
394
+ super().__init__(config)
395
+ self.hidden_dim = config.hidden_size
396
+ self.act_fn = ACT2FN[config.hidden_act]
397
+
398
+ self.expert_retrieval_dim = config.expert_retrieval_size
399
+ self.num_cdmmoe_experts = config.num_cdmmoe_experts
400
+ self.num_cdmmoe_heads = config.num_cdmmoe_heads
401
+ self.num_cdmmoe_experts_per_head = config.num_cdmmoe_experts_per_head
402
+ self.num_keys = int(math.sqrt(self.num_cdmmoe_experts))
403
+
404
+ # queries and keys for retrieval experts
405
  self.queries = nn.Linear(
406
  self.hidden_dim,
407
+ self.num_cdmmoe_heads * self.expert_retrieval_dim,
408
  bias=False,
409
  )
 
410
  self.keys = nn.Parameter(
411
  torch.zeros(
412
  self.num_cdmmoe_heads,
413
  self.num_keys,
414
  2,
415
+ self.expert_retrieval_dim // 2,
416
  )
417
  )
418
 
419
+ # experts
420
+ self.down_embed = nn.Embedding(
421
  self.num_cdmmoe_experts,
422
  self.hidden_dim,
423
  )
 
431
  self,
432
  hidden_states: torch.Tensor,
433
  **kwargs,
434
+ ) -> torch.Tensor:
435
  bsz, seq_len, _ = hidden_states.shape
436
 
437
  # get similarity with queries and keys
 
439
  queries = queries.view(bsz, seq_len, 2, self.num_cdmmoe_heads, -1).permute(2, 0, 1, 3, 4)
440
  sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, self.keys)
441
 
442
+ # get experts with the highest similarity
443
  (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.num_cdmmoe_experts_per_head, dim=-1)
444
  if einx_add is not None:
445
  all_scores = einx_add("... i, ... j -> ... (i j)", scores_x, scores_y)
 
451
  all_indices = all_indices.view(*indices_x.shape[:-1], -1)
452
  scores, pk_indices = all_scores.topk(self.num_cdmmoe_experts_per_head, dim=-1)
453
  indices = all_indices.gather(-1, pk_indices)
 
 
454
  down_embed = self.down_embed(indices)
455
  up_embed = self.up_embed(indices)
456
 
457
+ # mix experts states with cross domain states
458
+ experts_weights = torch.einsum("b t d, b t h k d -> b t h k", hidden_states, down_embed)
459
+ experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
460
  experts_states = torch.einsum("b t h k, b t h k d -> b t d", experts_weights, up_embed)
461
+ hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
 
 
462
  hidden_states = hidden_states + experts_states
463
  return hidden_states
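The retrieval in `DogeCDMoE.forward` above is a product-key lookup: each query is split into two halves, each half is scored against `num_keys = sqrt(num_cdmmoe_experts)` sub-keys, and the per-half top-k results are combined into joint expert indices before a final top-k. A toy sketch of that index arithmetic with made-up sizes (illustrative only, not the repository code):

import math
import torch

num_experts, num_heads, retrieval_dim, top_k = 64, 2, 8, 4
num_keys = int(math.sqrt(num_experts))   # 8 sub-keys per half

bsz, seq_len, hidden_dim = 1, 3, 16
hidden_states = torch.randn(bsz, seq_len, hidden_dim)
queries_proj = torch.nn.Linear(hidden_dim, num_heads * retrieval_dim, bias=False)
keys = torch.randn(num_heads, num_keys, 2, retrieval_dim // 2)

queries = queries_proj(hidden_states).view(bsz, seq_len, 2, num_heads, -1).permute(2, 0, 1, 3, 4)
sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, keys)

(scores_x, scores_y), (indices_x, indices_y) = sim.topk(top_k, dim=-1)
all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).flatten(-2)                # k*k candidate scores
all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).flatten(-2)  # joint expert ids
scores, positions = all_scores.topk(top_k, dim=-1)
expert_indices = all_indices.gather(-1, positions)   # (bsz, seq_len, num_heads, top_k)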
464
 
 
468
  super().__init__()
469
  self.hidden_dropout = config.hidden_dropout
470
 
471
+ self.pre_sequence_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
+ self.attn = DOGE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
473
+ self.post_sequence_residual = Residual(config.hidden_size)
474
+
475
+ self.pre_state_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
476
+ self.feed_forward = DogeMLP(config) if config.is_moe == False else DogeCDMoE(config)
477
+ self.post_state_residual = Residual(config.hidden_size)
478
 
479
  def forward(
480
  self,
 
513
 
514
  # sequence transformation
515
  residual = hidden_states
516
+ hidden_states = self.pre_sequence_layernorm(hidden_states)
517
  hidden_states, present_key_value = self.attn(
518
  hidden_states=hidden_states,
519
  attention_mask=attention_mask,
 
525
  )
526
  self_attn_weights = None
527
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
528
+ hidden_states = self.post_sequence_residual(residual, hidden_states)
529
 
530
  # state transformation
531
  residual = hidden_states
532
+ hidden_states = self.pre_state_layernorm(hidden_states)
533
  hidden_states = self.feed_forward(hidden_states)
534
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
535
+ hidden_states = self.post_state_residual(residual, hidden_states)
536
 
537
  outputs = (hidden_states,)
538
 
 
552
  supports_gradient_checkpointing = True
553
  _no_split_modules = ["DogeDecoderLayer"]
554
  _skip_keys_device_placement = ["past_key_values"]
555
+ _supports_sdpa = True
556
  _supports_cache_class = True
557
  _supports_quantized_cache = True
558
  _supports_static_cache = True
 
726
  if position_ids is None:
727
  position_ids = cache_position.unsqueeze(0)
728
 
729
+ causal_mask = self._update_causal_mask(
730
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
731
+ )
732
  hidden_states = inputs_embeds
733
 
734
  # create position embeddings to be shared across the decoder layers
 
737
  # decoder layers
738
  all_hidden_states = () if output_hidden_states else None
739
  all_self_attns = () if output_attentions else None
740
+ next_decoder_cache = None
741
 
742
  for decoder_layer in self.layers:
743
  if output_hidden_states:
 
747
  layer_outputs = self._gradient_checkpointing_func(
748
  decoder_layer.__call__,
749
  hidden_states,
750
+ causal_mask,
751
  position_ids,
752
  past_key_values,
753
  output_attentions,
 
758
  else:
759
  layer_outputs = decoder_layer(
760
  hidden_states,
761
+ attention_mask=causal_mask,
762
  position_ids=position_ids,
763
  past_key_value=past_key_values,
764
  output_attentions=output_attentions,
 
795
  attentions=all_self_attns,
796
  )
797
 
798
+ def _update_causal_mask(
799
+ self,
800
+ attention_mask: torch.Tensor = None,
801
+ input_tensor: torch.Tensor = None,
802
+ cache_position: torch.Tensor = None,
803
+ past_key_values: Cache = None,
804
+ output_attentions: bool = False,
805
+ ):
806
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
807
+ using_static_cache = isinstance(past_key_values, StaticCache)
808
+
809
+ dtype, device = input_tensor.dtype, input_tensor.device
810
+ sequence_length = input_tensor.shape[1]
811
+ if using_static_cache:
812
+ target_length = past_key_values.get_max_cache_shape()
813
+ else:
814
+ target_length = (
815
+ attention_mask.shape[-1]
816
+ if isinstance(attention_mask, torch.Tensor)
817
+ else past_seen_tokens + sequence_length + 1
818
+ )
819
+
820
+ # in case the provided `attention` mask is 2D, we generate a causal mask here (4D).
821
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
822
+ attention_mask=attention_mask,
823
+ sequence_length=sequence_length,
824
+ target_length=target_length,
825
+ dtype=dtype,
826
+ device=device,
827
+ cache_position=cache_position,
828
+ batch_size=input_tensor.shape[0],
829
+ )
830
+
831
+ return causal_mask
832
+
833
+ @staticmethod
834
+ def _prepare_4d_causal_attention_mask_with_cache_position(
835
+ attention_mask: torch.Tensor = None,
836
+ sequence_length: int = None,
837
+ target_length: int = None,
838
+ dtype: torch.dtype = None,
839
+ device: torch.device = None,
840
+ cache_position: torch.Tensor = None,
841
+ batch_size: int = None,
842
+ **kwargs,
843
+ ):
844
+ """
845
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
846
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
847
+
848
+ Args:
849
+ attention_mask (`torch.Tensor`):
850
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
851
+ `(batch_size, 1, query_length, key_value_length)`.
852
+ sequence_length (`int`):
853
+ The sequence length being processed.
854
+ target_length (`int`):
855
+ The target length: when generating with static cache, the mask should be as long as the static cache,
856
+ to account for the 0 padding, the part of the cache that is not filled yet.
857
+ dtype (`torch.dtype`):
858
+ The dtype to use for the 4D attention mask.
859
+ device (`torch.device`):
860
+ The device to plcae the 4D attention mask on.
861
+ cache_position (`torch.Tensor`):
862
+ Indices depicting the position of the input sequence tokens in the sequence.
863
+ batch_size (`torch.Tensor`):
864
+ Batch size.
865
+ """
866
+ if attention_mask is not None and attention_mask.dim() == 4:
867
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
868
+ causal_mask = attention_mask
869
+ else:
870
+ min_dtype = torch.finfo(dtype).min
871
+ causal_mask = torch.full(
872
+ (sequence_length, target_length),
873
+ fill_value=min_dtype, dtype=dtype, device=device,
874
+ )
875
+ if sequence_length != 1:
876
+ causal_mask = torch.triu(causal_mask, diagonal=1)
877
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
878
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
879
+ if attention_mask is not None:
880
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
881
+ mask_length = attention_mask.shape[-1]
882
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
883
+ padding_mask = padding_mask == 0
884
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
885
+ padding_mask, min_dtype
886
+ )
887
+
888
+ return causal_mask
889
 
890
 
891
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):