JingzeShi committed
Commit f8df032
1 Parent(s): 3c346b5

Upload DogeForCausalLM

Files changed (2):
  1. configuration_doge.py  +1 -1
  2. modeling_doge.py       +5 -4
configuration_doge.py CHANGED

@@ -3,7 +3,7 @@
 #
 # This code is based on the Wonderful Matrices paper implementation.
 #
-# https://arxiv.org/abs/2407.16958
+# https://arxiv.org/abs/2412.11834
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
modeling_doge.py CHANGED

@@ -3,7 +3,7 @@
 #
 # This code is based on the Wonderful Matrices paper implementation.
 #
-# https://arxiv.org/abs/2407.16958
+# https://arxiv.org/abs/2412.11834
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -184,6 +184,7 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 
 class DogeDynamicMaskAttention(nn.Module):
+    """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
 
     def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
         super().__init__()

@@ -387,6 +388,7 @@ class DogeMLP(nn.Module):
 
 
 class DogeCDMoE(DogeMLP):
+    """Cross Domain Mixture of Experts from 'Wonderful Matrices' paper."""
 
     def __init__(self, config: DogeConfig):
         super().__init__(config)

@@ -816,7 +818,7 @@ class DogeModel(DogePreTrainedModel):
         )
 
         # in case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position_and_dynamic_mask(
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask=attention_mask,
             sequence_length=sequence_length,
             target_length=target_length,

@@ -829,7 +831,7 @@ class DogeModel(DogePreTrainedModel):
         return causal_mask
 
     @staticmethod
-    def _prepare_4d_causal_attention_mask_with_cache_position_and_dynamic_mask(
+    def _prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask: torch.Tensor = None,
         sequence_length: int = None,
         target_length: int = None,

@@ -875,7 +877,6 @@ class DogeModel(DogePreTrainedModel):
         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
-            # print(f"attention_mask: {attention_mask.shape}")
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
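The renamed helper (the `_and_dynamic_mask` suffix is dropped, matching the helper name used by stock transformers models) follows the usual mask-preparation pattern: build a 4D causal mask from the cache positions, then fold the 2D padding mask back in, as the last hunk shows. Below is a minimal, self-contained sketch of that logic; the function name, the trimmed argument list, and the float32 default dtype are assumptions for illustration, not the model's exact signature.

import torch

def prepare_4d_causal_mask_sketch(
    attention_mask,        # (batch, mask_length) of 1s (keep) and 0s (padding), or None
    sequence_length,       # number of query positions in this forward pass
    target_length,         # number of key positions (cached + current)
    cache_position,        # (sequence_length,) absolute positions of the queries
    dtype=torch.float32,   # assumed default; the real method receives the model dtype
):
    min_dtype = torch.finfo(dtype).min
    device = cache_position.device

    # Start fully masked, then unmask keys at or before each query's cache position.
    causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

    batch_size = attention_mask.shape[0] if attention_mask is not None else 1
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # Where both the causal mask (0 = visible) and the padding mask (0 = pad) sum to
        # zero, the key is a padding token and must be re-masked to the dtype minimum.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask

# Example: prefill of a 4-token prompt where the first token of the only sample is padding.
mask = prepare_4d_causal_mask_sketch(
    attention_mask=torch.tensor([[0, 1, 1, 1]]),
    sequence_length=4,
    target_length=4,
    cache_position=torch.arange(4),
)
print(mask.shape)  # torch.Size([1, 1, 4, 4])

Entries left at 0 are attendable; entries at the dtype minimum (future positions and padding) vanish after the softmax inside the attention layers.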