tomeras1 commited on
Commit
886484a
1 Parent(s): 8ee14c3

Move to in-library checkpoint

Browse files
config.json CHANGED
@@ -12,7 +12,6 @@
12
  "AutoModelForSequenceClassification": "model.JambaForSequenceClassification"
13
  },
14
  "bos_token_id": 1,
15
- "calc_logits_for_entire_prompt": false,
16
  "eos_token_id": 2,
17
  "expert_layer_offset": 1,
18
  "expert_layer_period": 2,
@@ -25,15 +24,15 @@
25
  "mamba_d_state": 16,
26
  "mamba_dt_rank": 256,
27
  "mamba_expand": 2,
28
- "mamba_inner_layernorms": true,
29
  "mamba_proj_bias": false,
 
30
  "model_type": "jamba",
31
- "n_ctx": 262144,
32
  "num_attention_heads": 32,
33
  "num_experts": 16,
34
  "num_experts_per_tok": 2,
35
  "num_hidden_layers": 32,
36
  "num_key_value_heads": 8,
 
37
  "output_router_logits": false,
38
  "pad_token_id": 0,
39
  "rms_norm_eps": 1e-06,
@@ -41,7 +40,7 @@
41
  "sliding_window": null,
42
  "tie_word_embeddings": false,
43
  "torch_dtype": "bfloat16",
44
- "transformers_version": "4.40.0.dev0",
45
  "use_cache": true,
46
  "use_mamba_kernels": true,
47
  "vocab_size": 65536
 
12
  "AutoModelForSequenceClassification": "model.JambaForSequenceClassification"
13
  },
14
  "bos_token_id": 1,
 
15
  "eos_token_id": 2,
16
  "expert_layer_offset": 1,
17
  "expert_layer_period": 2,
 
24
  "mamba_d_state": 16,
25
  "mamba_dt_rank": 256,
26
  "mamba_expand": 2,
 
27
  "mamba_proj_bias": false,
28
+ "max_position_embeddings": 262144,
29
  "model_type": "jamba",
 
30
  "num_attention_heads": 32,
31
  "num_experts": 16,
32
  "num_experts_per_tok": 2,
33
  "num_hidden_layers": 32,
34
  "num_key_value_heads": 8,
35
+ "num_logits_to_keep": 1,
36
  "output_router_logits": false,
37
  "pad_token_id": 0,
38
  "rms_norm_eps": 1e-06,
 
40
  "sliding_window": null,
41
  "tie_word_embeddings": false,
42
  "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.40.1",
44
  "use_cache": true,
45
  "use_mamba_kernels": true,
46
  "vocab_size": 65536
configuration_jamba.py CHANGED
@@ -26,9 +26,9 @@ class JambaConfig(PretrainedConfig):
26
  r"""
27
  This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
  Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
- with the defaults will yield a similar configuration to that of the jamba-small architecture.
30
 
31
- [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
32
 
33
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
  documentation from [`PretrainedConfig`] for more information.
@@ -65,12 +65,12 @@ class JambaConfig(PretrainedConfig):
65
  use_cache (`bool`, *optional*, defaults to `True`):
66
  Whether or not the model should return the last key/values attentions (not used by all models). Only
67
  relevant if `config.is_decoder=True`.
68
- calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
69
- Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
70
- last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
71
- the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
72
- will reduce memory footprint significantly.
73
- Note: some generation features may not be available if this is set to `False`.
74
  output_router_logits (`bool`, *optional*, defaults to `False`):
75
  Whether or not the router logits should be returned by the model. Enabling this will also
76
  allow the model to output the auxiliary loss. See [here]() for more details
@@ -84,7 +84,7 @@ class JambaConfig(PretrainedConfig):
84
  The id of the "end-of-sequence" token.
85
  sliding_window (`int`, *optional*):
86
  Sliding window attention window size. If not specified, will default to `None`.
87
- n_ctx (`int`, *optional*, defaults to 262144):
88
  This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
  used with. It can be used with longer sequences, but performance may degrade.
90
  attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -118,8 +118,6 @@ class JambaConfig(PretrainedConfig):
118
  Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
119
  mamba_proj_bias (`bool`, *optional*, defaults to `False`):
120
  Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
121
- mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
122
- Flag indicating whether or not to apply layernorms to internal mamba activations
123
 
124
  """
125
 
@@ -139,14 +137,14 @@ class JambaConfig(PretrainedConfig):
139
  initializer_range=0.02,
140
  rms_norm_eps=1e-6,
141
  use_cache=True,
142
- calc_logits_for_entire_prompt=False,
143
  output_router_logits=False,
144
  router_aux_loss_coef=0.001,
145
  pad_token_id=0,
146
  bos_token_id=1,
147
  eos_token_id=2,
148
  sliding_window=None,
149
- n_ctx=262144,
150
  attention_dropout=0.0,
151
  num_experts_per_tok=2,
152
  num_experts=16,
@@ -161,7 +159,6 @@ class JambaConfig(PretrainedConfig):
161
  mamba_dt_rank="auto",
162
  mamba_conv_bias=True,
163
  mamba_proj_bias=False,
164
- mamba_inner_layernorms=True,
165
  **kwargs,
166
  ):
167
  self.vocab_size = vocab_size
@@ -171,7 +168,7 @@ class JambaConfig(PretrainedConfig):
171
  self.num_hidden_layers = num_hidden_layers
172
  self.num_attention_heads = num_attention_heads
173
  self.sliding_window = sliding_window
174
- self.n_ctx = n_ctx
175
  self.attention_dropout = attention_dropout
176
 
177
  # for backward compatibility
@@ -184,7 +181,7 @@ class JambaConfig(PretrainedConfig):
184
  self.rms_norm_eps = rms_norm_eps
185
 
186
  self.use_cache = use_cache
187
- self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
188
  self.output_router_logits = output_router_logits
189
  self.router_aux_loss_coef = router_aux_loss_coef
190
 
@@ -202,7 +199,6 @@ class JambaConfig(PretrainedConfig):
202
  self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
203
  self.mamba_conv_bias = mamba_conv_bias
204
  self.mamba_proj_bias = mamba_proj_bias
205
- self.mamba_inner_layernorms = mamba_inner_layernorms
206
 
207
  super().__init__(
208
  pad_token_id=pad_token_id,
@@ -211,3 +207,17 @@ class JambaConfig(PretrainedConfig):
211
  tie_word_embeddings=tie_word_embeddings,
212
  **kwargs,
213
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  r"""
27
  This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
  Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.
30
 
31
+ [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
32
 
33
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
  documentation from [`PretrainedConfig`] for more information.
 
65
  use_cache (`bool`, *optional*, defaults to `True`):
66
  Whether or not the model should return the last key/values attentions (not used by all models). Only
67
  relevant if `config.is_decoder=True`.
68
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
69
+ Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
70
+ integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
71
+ logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
72
+ sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
73
+ significantly.
74
  output_router_logits (`bool`, *optional*, defaults to `False`):
75
  Whether or not the router logits should be returned by the model. Enabling this will also
76
  allow the model to output the auxiliary loss. See [here]() for more details
 
84
  The id of the "end-of-sequence" token.
85
  sliding_window (`int`, *optional*):
86
  Sliding window attention window size. If not specified, will default to `None`.
87
+ max_position_embeddings (`int`, *optional*, defaults to 262144):
88
  This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
  used with. It can be used with longer sequences, but performance may degrade.
90
  attention_dropout (`float`, *optional*, defaults to 0.0):
 
118
  Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
119
  mamba_proj_bias (`bool`, *optional*, defaults to `False`):
120
  Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
 
 
121
 
122
  """
123
 
 
137
  initializer_range=0.02,
138
  rms_norm_eps=1e-6,
139
  use_cache=True,
140
+ num_logits_to_keep=1,
141
  output_router_logits=False,
142
  router_aux_loss_coef=0.001,
143
  pad_token_id=0,
144
  bos_token_id=1,
145
  eos_token_id=2,
146
  sliding_window=None,
147
+ max_position_embeddings=262144,
148
  attention_dropout=0.0,
149
  num_experts_per_tok=2,
150
  num_experts=16,
 
159
  mamba_dt_rank="auto",
160
  mamba_conv_bias=True,
161
  mamba_proj_bias=False,
 
162
  **kwargs,
163
  ):
164
  self.vocab_size = vocab_size
 
168
  self.num_hidden_layers = num_hidden_layers
169
  self.num_attention_heads = num_attention_heads
170
  self.sliding_window = sliding_window
171
+ self.max_position_embeddings = max_position_embeddings
172
  self.attention_dropout = attention_dropout
173
 
174
  # for backward compatibility
 
181
  self.rms_norm_eps = rms_norm_eps
182
 
183
  self.use_cache = use_cache
184
+ self.num_logits_to_keep = num_logits_to_keep
185
  self.output_router_logits = output_router_logits
186
  self.router_aux_loss_coef = router_aux_loss_coef
187
 
 
199
  self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
200
  self.mamba_conv_bias = mamba_conv_bias
201
  self.mamba_proj_bias = mamba_proj_bias
 
202
 
203
  super().__init__(
204
  pad_token_id=pad_token_id,
 
207
  tie_word_embeddings=tie_word_embeddings,
208
  **kwargs,
209
  )
210
+
211
+ @property
212
+ def layers_block_type(self):
213
+ return [
214
+ "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
215
+ for i in range(self.num_hidden_layers)
216
+ ]
217
+
218
+ @property
219
+ def layers_num_experts(self):
220
+ return [
221
+ self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
222
+ for i in range(self.num_hidden_layers)
223
+ ]
generation_config.json CHANGED
@@ -3,5 +3,5 @@
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
  "pad_token_id": 0,
6
- "transformers_version": "4.40.0.dev0"
7
  }
 
3
  "bos_token_id": 1,
4
  "eos_token_id": 2,
5
  "pad_token_id": 0,
6
+ "transformers_version": "4.40.1"
7
  }
model-00001-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce46bbcbda10cfac6b5855da022777a0387e2b729cbdd219081fa3f69cb214a2
3
- size 4951236864
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aace34ee0da3bf95605bd150fff6d3e78110be4048a3c389b0a740354b2ccb7
3
+ size 4951761424
model-00002-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7ae0b82247f3164270f151fa12ea1ceb63992e8827c739319fe20342eadafa8a
3
- size 4884145024
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ba1de67a86329431f14f7ffa165d84055d32ce57a6d2314e3b2464eac3732dc
3
+ size 4884669624
model-00003-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d2d44419116a65b3617fa35d20b69e2060449b53c0ac36192a3ec4b0a60b0a8d
3
- size 4992294632
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1abc4f16865fb78241c9453292ee3b2ca2c1e2d54ee945631da625834b95c9b2
3
+ size 4992557120
model-00004-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:70fe04d7dc1124871ca1f6071504ba019174db27cd57c625938e6383ebee5fee
3
- size 4958591040
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45fab97739a58e924791572ea3d06f9c90b9ff2a299460aaa4bd87c6e9d424f3
3
+ size 4958853560
model-00005-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:402079425e45a01c256a080cae3ab39be3f3cfae56dba7c815a44f0c58b3a442
3
- size 4975501296
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4b0ec6e8f33e6d7b1f837cd4c25818487dcc7e478734606da28110507e51c97
3
+ size 4975763832
model-00006-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2cc9971c058d95a8f13966a3aa82294564381937902634c0c064be68104821ae
3
- size 4884145016
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed98d5c3c8d7ab7352944bea09b0d54d98066cf567ba3d069da12c05575d56ed
3
+ size 4884669616
model-00007-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d9ba83c87790cdb6fb9f7861a712f315469edbf065ab64bdaa35cc99b4ec8746
3
- size 4884144968
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:735be2bc568711bf42a4caebcda8288dd300b31b48fa098b00df3cf1a98e10e2
3
+ size 4884669640
model-00008-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b45331970c155ca74f509576cb050d006997bef08a99189cf047aa1a3a4b254e
3
- size 4992294696
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0c8d817b2b47661d361e8b520128b3194185f756cc2204a95d642e24895ee51
3
+ size 4992557176
model-00009-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fe5a1e58d58598a64a59a3ca87c170a171a7bba2102138c71047d5b5458cdebf
3
- size 4932506800
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e50222cf865ca5678d22574b131294303c46b249478cf70113c701f70331e999
3
+ size 4932507176
model-00010-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd37916e35b2b3b98e7a9bb790a779ac51ad0bbcff92428c0ed11c8839379205
3
- size 4884145056
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1b4b69b24ae55827b6c8b1e4a10807aa3525bc85f4d34dc002ac7440757fbf4
3
+ size 4884669672
model-00011-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:392435bc85f4c90bf129c30260da8c820f35bca91610aa0e682cb915f1d855c6
3
- size 4884145088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60213cac13b92ed34b93ce48e670434f22e3bf8b2b8df20c60b7bf8a9515c35c
3
+ size 4884669696
model-00012-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b5141158f7a755a7e0f60c73f4c25ba02c2bfdab548944f8d4146f41391c621a
3
- size 4884145088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05805eacd3bb40cc9da802350409f1cb078e8b276da7e06c7a8a5ca5b26cc887
3
+ size 4884669688
model-00013-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:14ce5aabc4a17e54e40b30fba322104dd19bad512bab6e554fa56bafe4433da7
3
- size 4932506800
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:201df979a1b34ced6cdbb7a790163412636779f1119e3845a704c489181d03d2
3
+ size 4932507176
model-00014-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd79a548e39ee02f6a9b553f93f6652783c9dbc895ab685848d9e1655903965f
3
- size 4992294648
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0a7eb42a9ea3a385442c2e758dd5efd5dc5b913f1d10bfd37792cc963a33c93
3
+ size 4992557152
model-00015-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:eb5af1275e6a0c5bc2c195e1802a64cee6aa92e3a11fcff5acd8b7bbf720ef75
3
- size 4884145088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4b9afe4398000c28b36e3aa40c87086af673d4f8a64bfc5767941ab2008bcc9
3
+ size 4884669688
model-00016-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e9efc22a654010417091851b00277db7116e8c532ae5410cacc13bfa49b99c06
3
- size 4884145088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd1ac6cc861971c43bdf0c9c6d4c9fe72d33e5227e054a621e2e68f001419763
3
+ size 4884669688
model-00017-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ec7ab387e62b0c65a3567cc4d17d13166b577cf89ff59a8d5d7b248fdbbc68da
3
- size 4908260352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52d9eea696dd29ef413d617bbcb62a9f159e8fe8170d36e018932cef45ee281d
3
+ size 4908522856
model-00018-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3f65a30ff1d8e1fc086460839056e7bc7a6a2ef81f0df35dc1a752bf951f92df
3
- size 4908391496
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77acada7c098e81280645ea0a9dbfa00196dca6da8946498b9907e9e376fb42d
3
+ size 4908654000
model-00019-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7975019ffa4bb6f502e3406a53ef61ee08085330502ba32fb3e9883b7033c8c7
3
- size 4992294688
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09e10dfd6c6459cd3460b1d667639717d3657274c1694c19a6fdbac1be6a76bf
3
+ size 4992557168
model-00020-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6363d3d6f89d09a971af839cd923a206a06e73d090ae74a605ed27e97fab93cf
3
- size 4884145088
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bd5c27b2cca6e06f7b4497ce8c9b1522a64846817a871bad274d08507960ed0
3
+ size 4884669696
model-00021-of-00021.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a21e65470d7dbe4ae849be427eb5366cc7cc311138cc7f943f3d71d84b7c7ffd
3
- size 4647318256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a47ef23db8deb5364da676a40dc3dcb011fb9d9ceef13ba044c176e9a83ac1e3
3
+ size 4647318576
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_jamba.py CHANGED
@@ -20,8 +20,6 @@
20
  """ PyTorch Jamba model."""
21
  import inspect
22
  import math
23
- import warnings
24
- from dataclasses import dataclass, field
25
  from typing import Any, Dict, List, Optional, Tuple, Union
26
 
27
  import torch
@@ -31,10 +29,9 @@ from torch import nn
31
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
32
 
33
  from transformers.activations import ACT2FN
34
- from transformers.cache_utils import Cache, DynamicCache
35
  from transformers.modeling_attn_mask_utils import (
36
- _prepare_4d_causal_attention_mask,
37
- _prepare_4d_causal_attention_mask_for_sdpa,
38
  )
39
  from transformers.modeling_outputs import (
40
  MoeCausalLMOutputWithPast,
@@ -42,7 +39,6 @@ from transformers.modeling_outputs import (
42
  SequenceClassifierOutputWithPast,
43
  )
44
  from transformers.modeling_utils import PreTrainedModel
45
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
46
  from transformers.utils import (
47
  add_start_docstrings,
48
  add_start_docstrings_to_model_forward,
@@ -50,11 +46,15 @@ from transformers.utils import (
50
  logging,
51
  replace_return_docstrings,
52
  )
53
- from transformers.utils.import_utils import is_torch_fx_available
 
 
 
 
54
  from .configuration_jamba import JambaConfig
55
 
56
 
57
- # try except block so it'll work with trust_remote_code. Later we can have `if is_flash_attn_2_available():`
58
  try:
59
  from flash_attn import flash_attn_func, flash_attn_varlen_func
60
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -63,22 +63,15 @@ try:
63
  except ImportError:
64
  pass
65
 
66
- # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
67
- # It means that the function will not be traced through and simply appear as a node in the graph.
68
- if is_torch_fx_available():
69
- if not is_torch_greater_or_equal_than_1_13:
70
- import torch.fx
71
-
72
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
73
 
74
- # try except block so it'll work with trust_remote_code. Later we can have `if is_mamba_ssm_available():`
75
  try:
76
  from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
77
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
78
  except ImportError:
79
  selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
80
 
81
- # try except block so it'll work with trust_remote_code. Later we can have `if is_causal_conv1d_available():`
82
  try:
83
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
84
  except ImportError:
@@ -94,9 +87,12 @@ logger = logging.get_logger(__name__)
94
  _CONFIG_FOR_DOC = "JambaConfig"
95
 
96
 
97
- # Adapted from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
98
  def load_balancing_loss_func(
99
- gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
 
 
 
100
  ) -> float:
101
  r"""
102
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -106,7 +102,7 @@ def load_balancing_loss_func(
106
  experts is too unbalanced.
107
 
108
  Args:
109
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
110
  Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
111
  shape [batch_size X sequence_length, num_experts].
112
  attention_mask (`torch.Tensor`, None):
@@ -118,16 +114,16 @@ def load_balancing_loss_func(
118
  Returns:
119
  The auxiliary loss.
120
  """
121
- if gate_logits is None or not isinstance(gate_logits, tuple):
122
  return 0
123
 
124
- if isinstance(gate_logits, tuple):
125
- compute_device = gate_logits[0].device
126
- concatenated_gate_logits = torch.cat(
127
- [layer_gate.to(compute_device) for layer_gate in gate_logits if layer_gate.shape[1] > 1], dim=0
128
  )
129
 
130
- routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
131
 
132
  _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
133
 
@@ -141,7 +137,7 @@ def load_balancing_loss_func(
141
  router_prob_per_expert = torch.mean(routing_weights, dim=0)
142
  else:
143
  batch_size, sequence_length = attention_mask.shape
144
- num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
145
 
146
  # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
147
  expert_attention_mask = (
@@ -217,6 +213,82 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
217
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
221
  class JambaAttention(nn.Module):
222
  """
@@ -253,23 +325,16 @@ class JambaAttention(nn.Module):
253
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
254
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
255
 
256
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
257
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
258
-
259
  def forward(
260
  self,
261
  hidden_states: torch.Tensor,
262
  attention_mask: Optional[torch.Tensor] = None,
263
  position_ids: Optional[torch.LongTensor] = None,
264
- past_key_value: Optional[Cache] = None,
265
  output_attentions: bool = False,
266
  use_cache: bool = False,
267
- **kwargs,
268
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
269
- if "padding_mask" in kwargs:
270
- warnings.warn(
271
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
272
- )
273
  bsz, q_len, _ = hidden_states.size()
274
 
275
  query_states = self.q_proj(hidden_states)
@@ -280,16 +345,6 @@ class JambaAttention(nn.Module):
280
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
281
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
 
283
- kv_seq_len = key_states.shape[-2]
284
- if past_key_value is not None:
285
- if self.layer_idx is None:
286
- raise ValueError(
287
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
288
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
289
- "with a layer index."
290
- )
291
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
292
-
293
  if past_key_value is not None:
294
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
295
 
@@ -299,19 +354,9 @@ class JambaAttention(nn.Module):
299
 
300
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
301
 
302
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
303
- raise ValueError(
304
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
305
- f" {attn_weights.size()}"
306
- )
307
-
308
- if attention_mask is not None:
309
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
310
- raise ValueError(
311
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
312
- )
313
-
314
- attn_weights = attn_weights + attention_mask
315
 
316
  # upcast attention to fp32
317
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -357,37 +402,26 @@ class JambaFlashAttention2(JambaAttention):
357
  hidden_states: torch.Tensor,
358
  attention_mask: Optional[torch.Tensor] = None,
359
  position_ids: Optional[torch.LongTensor] = None,
360
- past_key_value: Optional[Cache] = None,
361
  output_attentions: bool = False,
362
  use_cache: bool = False,
 
363
  **kwargs,
364
  ):
365
- if "padding_mask" in kwargs:
366
- warnings.warn(
367
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
368
- )
369
-
370
- # overwrite attention_mask with padding_mask
371
- attention_mask = kwargs.pop("padding_mask")
372
  bsz, q_len, _ = hidden_states.size()
373
 
374
  query_states = self.q_proj(hidden_states)
375
  key_states = self.k_proj(hidden_states)
376
  value_states = self.v_proj(hidden_states)
377
 
 
 
 
378
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
379
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
380
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
381
 
382
- kv_seq_len = key_states.shape[-2]
383
- if past_key_value is not None:
384
- if self.layer_idx is None:
385
- raise ValueError(
386
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
387
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
388
- "with a layer index."
389
- )
390
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
391
 
392
  use_sliding_windows = (
393
  _flash_supports_window_size
@@ -403,7 +437,7 @@ class JambaFlashAttention2(JambaAttention):
403
 
404
  if past_key_value is not None:
405
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
406
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
407
  if (
408
  getattr(self.config, "sliding_window", None) is not None
409
  and kv_seq_len > self.config.sliding_window
@@ -505,7 +539,7 @@ class JambaFlashAttention2(JambaAttention):
505
  attention_mask (`torch.Tensor`):
506
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
507
  position of padding tokens and 1 for the position of non-padding tokens.
508
- dropout (`int`, *optional*):
509
  Attention dropout
510
  softmax_scale (`float`, *optional*):
511
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -580,6 +614,7 @@ class JambaFlashAttention2(JambaAttention):
580
 
581
  return attn_output
582
 
 
583
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
584
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
585
 
@@ -637,9 +672,10 @@ class JambaSdpaAttention(JambaAttention):
637
  hidden_states: torch.Tensor,
638
  attention_mask: Optional[torch.Tensor] = None,
639
  position_ids: Optional[torch.LongTensor] = None,
640
- past_key_value: Optional[Cache] = None,
641
  output_attentions: bool = False,
642
  use_cache: bool = False,
 
643
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
644
  if output_attentions:
645
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -666,21 +702,15 @@ class JambaSdpaAttention(JambaAttention):
666
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
667
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
668
 
669
- kv_seq_len = key_states.shape[-2]
670
- if past_key_value is not None:
671
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
672
-
673
  if past_key_value is not None:
674
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
675
 
676
  key_states = repeat_kv(key_states, self.num_key_value_groups)
677
  value_states = repeat_kv(value_states, self.num_key_value_groups)
678
 
 
679
  if attention_mask is not None:
680
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
681
- raise ValueError(
682
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
683
- )
684
 
685
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
686
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -693,7 +723,7 @@ class JambaSdpaAttention(JambaAttention):
693
  query_states,
694
  key_states,
695
  value_states,
696
- attn_mask=attention_mask,
697
  dropout_p=self.attention_dropout if self.training else 0.0,
698
  # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
699
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
@@ -714,99 +744,6 @@ JAMBA_ATTENTION_CLASSES = {
714
  }
715
 
716
 
717
- class HybridMambaAttentionDynamicCache(DynamicCache):
718
- """
719
- A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
720
- (which has a constant shape regardless of seq_len).
721
-
722
- It stores the Key and Value states as a list of tensors, one for each layer.
723
- The expected shape for each tensor for attention layers is `[batch_size, num_heads, seq_len, head_dim]`.
724
- For the mamba layers, the `key_cache` represents the convolution state and has a shape of `[batch_size, d_inner, 1, d_conv]`,
725
- and the `value_cache` represents the ssm state and has a shape of `[batch_size, d_inner, 1, d_state]`. Mamba cache
726
- shape[2] is a dummy "seqlen" dimension to match the number of attention cache dimensions. For mamba, the cache
727
- doesn't grow with seqlen so this dimension is always 1.
728
- """
729
-
730
- def __init__(self) -> None:
731
- super().__init__()
732
- self.attention_layer_idx = None # used to know which layer has data on seqlen in the cache shape
733
-
734
- def update(
735
- self,
736
- key_states: torch.Tensor,
737
- value_states: torch.Tensor,
738
- layer_idx: int,
739
- cache_kwargs: Optional[Dict[str, Any]] = None,
740
- ) -> Tuple[torch.Tensor, torch.Tensor]:
741
- """
742
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
743
-
744
- Parameters:
745
- key_states (`torch.Tensor`):
746
- The new key states to cache.
747
- value_states (`torch.Tensor`):
748
- The new value states to cache.
749
- layer_idx (`int`):
750
- The index of the layer to cache the states for.
751
- cache_kwargs (`Dict[str, Any]`, `optional`):
752
- Additional arguments for the cache subclass. No additional arguments are used in `HybridMambaAttentionDynamicCache`.
753
-
754
- Return:
755
- A tuple containing the updated key and value states.
756
- """
757
- # Update the number of seen tokens
758
- if self.attention_layer_idx is None and self._is_attn_layer(key_states, value_states):
759
- self.attention_layer_idx = layer_idx
760
- if self.attention_layer_idx is not None and layer_idx == self.attention_layer_idx:
761
- if hasattr(self, "_seen_tokens"):
762
- self._seen_tokens += key_states.shape[-2]
763
- else:
764
- self.seen_tokens += key_states.shape[-2]
765
-
766
- # Update the cache
767
- if len(self.key_cache) <= layer_idx:
768
- self.key_cache.append(key_states)
769
- self.value_cache.append(value_states)
770
- else:
771
- if self._is_attn_layer(self.key_cache[layer_idx], self.value_cache[layer_idx]):
772
- # attention layer - append the new states to the existing cache on the seqlen dimension
773
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
774
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
775
- else:
776
- # mamba layer - replace the cache with the new states
777
- self.key_cache[layer_idx] = key_states
778
- self.value_cache[layer_idx] = value_states
779
-
780
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
781
-
782
- def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
783
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
784
- if layer_idx is not None:
785
- if len(self.key_cache) <= layer_idx:
786
- return 0
787
- if self._is_attn_layer(self.key_cache[layer_idx], self.value_cache[layer_idx]):
788
- return self.key_cache[layer_idx].shape[-2]
789
- else:
790
- warnings.warn(
791
- f"Asked to get the sequence length from cache of layer {layer_idx} which is not an attention layer. "
792
- f"Ignoring that and using an attention layer cache"
793
- )
794
- if self.attention_layer_idx is None or len(self.key_cache) <= self.attention_layer_idx:
795
- return 0
796
- return self.key_cache[self.attention_layer_idx].shape[-2]
797
-
798
- @staticmethod
799
- def _is_attn_layer(key_states: torch.Tensor, value_states: torch.Tensor):
800
- return key_states.shape[-1] == value_states.shape[-1]
801
-
802
-
803
- @dataclass
804
- class MambaCacheParams:
805
- seqlen_offset: int = 0
806
- conv_states: Dict[int, torch.Tensor] = field(default_factory=dict)
807
- ssm_states: Dict[int, torch.Tensor] = field(default_factory=dict)
808
-
809
-
810
  # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
811
  class JambaMambaMixer(nn.Module):
812
  """
@@ -838,7 +775,6 @@ class JambaMambaMixer(nn.Module):
838
 
839
  self.activation = config.hidden_act
840
  self.act = ACT2FN[config.hidden_act]
841
- self.apply_inner_layernorms = config.mamba_inner_layernorms
842
 
843
  self.use_fast_kernels = config.use_mamba_kernels
844
 
@@ -858,14 +794,9 @@ class JambaMambaMixer(nn.Module):
858
  self.D = nn.Parameter(torch.ones(self.intermediate_size))
859
  self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
860
 
861
- if self.apply_inner_layernorms:
862
- self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
863
- self.B_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
864
- self.C_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
865
- else:
866
- self.dt_layernorm = None
867
- self.B_layernorm = None
868
- self.C_layernorm = None
869
 
870
  if not is_fast_path_available:
871
  logger.warning_once(
@@ -874,145 +805,121 @@ class JambaMambaMixer(nn.Module):
874
  " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
875
  )
876
 
877
- def _apply_layernorms(self, dt, B, C):
878
- if self.dt_layernorm is not None:
879
- dt = self.dt_layernorm(dt)
880
- if self.B_layernorm is not None:
881
- B = self.B_layernorm(B)
882
- if self.C_layernorm is not None:
883
- C = self.C_layernorm(C)
884
- return dt, B, C
885
-
886
- def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None):
887
  # 1. Gated MLP's linear projection
888
  projected_states = self.in_proj(hidden_states).transpose(1, 2)
889
 
890
- if (
891
- self.training and cache_params is None and not self.apply_inner_layernorms
892
- ): # Doesn't support outputting the states -> used for training
893
- contextualized_states = mamba_inner_fn(
894
- projected_states,
895
- self.conv1d.weight,
896
- self.conv1d.bias if self.use_conv_bias else None,
897
- self.x_proj.weight,
898
- self.dt_proj.weight,
899
- self.out_proj.weight,
900
- self.out_proj.bias.float() if self.use_bias else None,
901
- -torch.exp(self.A_log.float()),
902
- None, # input-dependent B
903
- None, # input-dependent C
904
- self.D.float(),
905
- delta_bias=self.dt_proj.bias.float(),
906
- delta_softplus=True,
907
- )
908
 
 
 
 
 
 
 
 
 
 
 
 
909
  else:
910
- hidden_states, gate = projected_states.chunk(2, dim=1)
911
-
912
- # 2. Convolution sequence transformation
913
- conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
914
- if cache_params is not None and cache_params.seqlen_offset > 0:
915
- hidden_states = causal_conv1d_update(
916
- hidden_states.squeeze(-1),
917
- cache_params.conv_states[self.layer_idx],
918
- conv_weights,
919
- self.conv1d.bias,
920
- self.activation,
921
- )
922
- hidden_states = hidden_states.unsqueeze(-1)
923
- else:
924
- if cache_params is not None:
925
- conv_states = nn.functional.pad(
926
- hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
927
- )
928
- cache_params.conv_states[self.layer_idx].copy_(conv_states)
929
- hidden_states = causal_conv1d_fn(
930
- hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
931
- )
932
 
933
- # 3. State Space Model sequence transformation
934
- # 3.a. input varying initialization of time_step, B and C
935
- ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
936
- time_step, B, C = torch.split(
937
- ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
938
  )
939
- time_step, B, C = self._apply_layernorms(time_step, B, C)
940
-
941
- # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
942
- # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
943
- # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
944
- # linear layers, and requires to call the forward pass directly.
945
- # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
946
- if hasattr(self.dt_proj, "base_layer"):
947
- # In case of LoRA, we need to access the base layer to get the weight
948
- time_proj_bias = self.dt_proj.base_layer.bias
949
- self.dt_proj.base_layer.bias = None
950
- else:
951
- time_proj_bias = self.dt_proj.bias
952
- self.dt_proj.bias = None
953
- discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
954
- if hasattr(self.dt_proj, "base_layer"):
955
- self.dt_proj.base_layer.bias = time_proj_bias
956
- else:
957
- self.dt_proj.bias = time_proj_bias
958
-
959
- A = -torch.exp(self.A_log.float())
960
- # 3.c perform the recurrence y ← SSM(A, B, C)(x)
961
- time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
962
- if cache_params is not None and cache_params.seqlen_offset > 0:
963
- scan_outputs = selective_state_update(
964
- cache_params.ssm_states[self.layer_idx],
965
- hidden_states[..., 0],
966
- discrete_time_step[..., 0],
967
- A,
968
- B[:, 0],
969
- C[:, 0],
970
- self.D,
971
- gate[..., 0],
972
- time_proj_bias,
973
- dt_softplus=True,
974
- ).unsqueeze(-1)
975
- else:
976
- scan_outputs, ssm_state = selective_scan_fn(
977
- hidden_states,
978
- discrete_time_step,
979
- A,
980
- B.transpose(1, 2),
981
- C.transpose(1, 2),
982
- self.D.float(),
983
- gate,
984
- time_proj_bias,
985
- delta_softplus=True,
986
- return_last_state=True,
987
- )
988
- if ssm_state is not None and cache_params is not None:
989
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
990
 
991
- # 4. Final linear projection
992
- contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
993
  return contextualized_states
994
 
995
  # fmt: off
996
- def slow_forward(self, input_states, cache_params: MambaCacheParams = None):
997
  batch_size, seq_len, _ = input_states.shape
998
  dtype = input_states.dtype
999
  # 1. Gated MLP's linear projection
1000
  projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
1001
  hidden_states, gate = projected_states.chunk(2, dim=1)
1002
 
 
1003
  # 2. Convolution sequence transformation
1004
- if cache_params is not None:
1005
  if self.training:
1006
  # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
1007
  ssm_state = cache_params.ssm_states[self.layer_idx].clone()
1008
  else:
1009
  ssm_state = cache_params.ssm_states[self.layer_idx]
1010
 
1011
- if cache_params.seqlen_offset > 0:
 
1012
  conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
1013
  conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
1014
  conv_state[:, :, -1] = hidden_states[:, :, 0]
1015
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
1016
  hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
1017
  if self.use_conv_bias:
1018
  hidden_states += self.conv1d.bias
@@ -1022,7 +929,7 @@ class JambaMambaMixer(nn.Module):
1022
  hidden_states,
1023
  (self.conv_kernel_size - hidden_states.shape[-1], 0)
1024
  )
1025
- cache_params.conv_states[self.layer_idx].copy_(conv_state)
1026
  hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
1027
  else:
1028
  ssm_state = torch.zeros(
@@ -1037,7 +944,11 @@ class JambaMambaMixer(nn.Module):
1037
  time_step, B, C = torch.split(
1038
  ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
1039
  )
1040
- time_step, B, C = self._apply_layernorms(time_step, B, C)
 
 
 
 
1041
  discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
1042
  discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
1043
 
@@ -1057,15 +968,15 @@ class JambaMambaMixer(nn.Module):
1057
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
1058
  scan_output = (scan_output * self.act(gate))
1059
 
1060
- if cache_params is not None:
1061
- cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
1062
 
1063
  # 4. Final linear projection
1064
  contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
1065
  return contextualized_states
1066
  # fmt: on
1067
 
1068
- def mixer_forward(self, hidden_states, cache_params: MambaCacheParams = None):
1069
  if self.use_fast_kernels:
1070
  if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
1071
  raise ValueError(
@@ -1074,64 +985,17 @@ class JambaMambaMixer(nn.Module):
1074
  return self.cuda_kernels_forward(hidden_states, cache_params)
1075
  return self.slow_forward(hidden_states, cache_params)
1076
 
1077
- def forward(
1078
- self,
1079
- hidden_states: torch.Tensor,
1080
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1081
- **kwargs,
1082
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
1083
- if past_key_value is not None:
1084
- cache_params = MambaCacheParams(
1085
- seqlen_offset=0 if hidden_states.shape[1] > 1 else past_key_value.seen_tokens,
1086
- )
1087
- if len(past_key_value.key_cache) > self.layer_idx:
1088
- # we already have cache for this layer, use it
1089
- # remove the dummy seqlen dim (dim=2)
1090
- cache_params.conv_states[self.layer_idx] = past_key_value.key_cache[self.layer_idx].squeeze(2)
1091
- cache_params.ssm_states[self.layer_idx] = past_key_value.value_cache[self.layer_idx].squeeze(2)
1092
- else:
1093
- # we don't have cache for this layer, initialize it with zeros
1094
- batch_size = hidden_states.shape[0]
1095
- cache_params.conv_states[self.layer_idx] = torch.zeros(
1096
- batch_size,
1097
- self.intermediate_size,
1098
- self.conv_kernel_size,
1099
- device=hidden_states.device,
1100
- dtype=hidden_states.dtype,
1101
- )
1102
- cache_params.ssm_states[self.layer_idx] = torch.zeros(
1103
- batch_size,
1104
- self.intermediate_size,
1105
- self.ssm_state_size,
1106
- device=hidden_states.device,
1107
- dtype=hidden_states.dtype,
1108
- )
1109
- else:
1110
- cache_params = None
1111
-
1112
- res = self.mixer_forward(hidden_states, cache_params)
1113
-
1114
- if past_key_value is not None:
1115
- past_key_value.update(
1116
- # add dummy seqlen dim (dim=2) to match the number of dimensions of the attention cache
1117
- cache_params.conv_states[self.layer_idx].unsqueeze(2),
1118
- cache_params.ssm_states[self.layer_idx].unsqueeze(2),
1119
- self.layer_idx,
1120
- )
1121
-
1122
- return res, past_key_value
1123
-
1124
 
 
1125
  class JambaMLP(nn.Module):
1126
- def __init__(self, config: JambaConfig):
1127
  super().__init__()
1128
- self.ffn_dim = config.intermediate_size
1129
- self.hidden_dim = config.hidden_size
1130
-
1131
- self.gate_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
1132
- self.down_proj = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
1133
- self.up_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
1134
-
1135
  self.act_fn = ACT2FN[config.hidden_act]
1136
 
1137
  def forward(self, x):
@@ -1151,39 +1015,20 @@ class JambaSparseMoeBlock(nn.Module):
1151
  and memory on padding.
1152
  """
1153
 
1154
- def __init__(self, config: JambaConfig, num_experts: int, num_experts_per_tok: int):
1155
  super().__init__()
1156
  self.hidden_dim = config.hidden_size
1157
  self.ffn_dim = config.intermediate_size
 
 
1158
 
1159
- # these values are decided on runtime depending on the layer index
1160
- self.num_experts = num_experts
1161
- self.top_k = num_experts_per_tok
1162
-
1163
- if num_experts > 1:
1164
- # expert routing
1165
- self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
1166
- else:
1167
- self.router = None
1168
-
1169
  self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])
1170
 
1171
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1172
  """ """
1173
  batch_size, sequence_length, hidden_dim = hidden_states.shape
1174
 
1175
- if self.num_experts == 1:
1176
- # in this case we have a single MLP block and don't need to do any routing
1177
- final_hidden_states = self.experts[0](hidden_states)
1178
- router_logits = torch.ones(
1179
- (batch_size * sequence_length, 1),
1180
- device=hidden_states.device,
1181
- dtype=hidden_states.dtype,
1182
- requires_grad=hidden_states.requires_grad,
1183
- )
1184
- return final_hidden_states, router_logits
1185
-
1186
- # in this case we have multiple experts and need to do routing
1187
  hidden_states = hidden_states.view(-1, hidden_dim)
1188
  # router_logits: (batch * sequence_length, n_experts)
1189
  router_logits = self.router(hidden_states)
@@ -1208,15 +1053,11 @@ class JambaSparseMoeBlock(nn.Module):
1208
  if top_x.shape[0] == 0:
1209
  continue
1210
 
1211
- # in torch it is faster to index using lists than torch tensors
1212
- top_x_list = top_x.tolist()
1213
- idx_list = idx.tolist()
1214
-
1215
  # Index the correct hidden states and compute the expert hidden state for
1216
  # the current expert. We need to make sure to multiply the output hidden
1217
  # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
1218
- current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
1219
- current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
1220
 
1221
  # However `index_add_` only support torch tensors for indexing so we'll use
1222
  # the `top_x` tensor here.
@@ -1226,37 +1067,33 @@ class JambaSparseMoeBlock(nn.Module):
1226
 
1227
 
1228
  class JambaAttentionDecoderLayer(nn.Module):
1229
- def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int):
1230
  super().__init__()
1231
-
1232
  self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1233
 
1234
- num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1
1235
- self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
1236
  self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1237
- self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1238
 
1239
  def forward(
1240
  self,
1241
  hidden_states: torch.Tensor,
1242
  attention_mask: Optional[torch.Tensor] = None,
1243
  position_ids: Optional[torch.LongTensor] = None,
1244
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
1245
  output_attentions: Optional[bool] = False,
1246
  output_router_logits: Optional[bool] = False,
1247
  use_cache: Optional[bool] = False,
1248
- **kwargs,
1249
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1250
- if "padding_mask" in kwargs:
1251
- warnings.warn(
1252
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1253
- )
1254
  """
1255
  Args:
1256
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1257
  attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1258
  `(batch, sequence_length)` where padding elements are indicated by 0.
1259
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1260
  output_attentions (`bool`, *optional*):
1261
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1262
  returned tensors for more detail.
@@ -1266,6 +1103,8 @@ class JambaAttentionDecoderLayer(nn.Module):
1266
  use_cache (`bool`, *optional*):
1267
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1268
  (see `past_key_values`).
 
 
1269
  """
1270
 
1271
  residual = hidden_states
@@ -1279,15 +1118,20 @@ class JambaAttentionDecoderLayer(nn.Module):
1279
  past_key_value=past_key_value,
1280
  output_attentions=output_attentions,
1281
  use_cache=use_cache,
 
1282
  )
1283
 
1284
  # residual connection after attention
1285
  hidden_states = residual + hidden_states
1286
 
1287
- # Experts
1288
  residual = hidden_states
1289
- hidden_states = self.pre_moe_layernorm(hidden_states)
1290
- hidden_states, router_logits = self.moe(hidden_states)
 
 
 
 
1291
  hidden_states = residual + hidden_states
1292
 
1293
  outputs = (hidden_states,)
@@ -1305,15 +1149,15 @@ class JambaAttentionDecoderLayer(nn.Module):
1305
 
1306
 
1307
  class JambaMambaDecoderLayer(nn.Module):
1308
- def __init__(self, config: JambaConfig, num_experts: int, layer_idx: int):
1309
  super().__init__()
1310
-
1311
  self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
1312
 
1313
- num_experts_per_tok = config.num_experts_per_tok if num_experts > 1 else 1
1314
- self.moe = JambaSparseMoeBlock(config, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok)
1315
  self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1316
- self.pre_moe_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1317
 
1318
  def forward(
1319
  self,
@@ -1324,18 +1168,14 @@ class JambaMambaDecoderLayer(nn.Module):
1324
  output_attentions: Optional[bool] = False,
1325
  output_router_logits: Optional[bool] = False,
1326
  use_cache: Optional[bool] = False,
1327
- **kwargs,
1328
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1329
- if "padding_mask" in kwargs:
1330
- warnings.warn(
1331
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1332
- )
1333
  """
1334
  Args:
1335
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1336
  attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1337
  `(batch, sequence_length)` where padding elements are indicated by 0.
1338
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1339
  output_attentions (`bool`, *optional*):
1340
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1341
  returned tensors for more detail.
@@ -1345,28 +1185,31 @@ class JambaMambaDecoderLayer(nn.Module):
1345
  use_cache (`bool`, *optional*):
1346
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1347
  (see `past_key_values`).
 
 
1348
  """
1349
 
1350
  residual = hidden_states
1351
 
1352
  hidden_states = self.input_layernorm(hidden_states)
1353
 
1354
- hidden_states, present_key_value = self.mamba(
1355
  hidden_states=hidden_states,
1356
- past_key_value=past_key_value,
1357
  )
1358
- bs, seqlen, _ = hidden_states.shape
1359
- past_seqlen = self._get_past_seqlen(past_key_value, seqlen)
1360
- num_attention_heads = self.mamba.config.num_attention_heads
1361
- self_attn_weights = torch.empty(bs, num_attention_heads, seqlen, past_seqlen, device="meta")
1362
 
1363
  # residual connection after mamba
1364
  hidden_states = residual + hidden_states
1365
 
1366
- # Experts
1367
  residual = hidden_states
1368
- hidden_states = self.pre_moe_layernorm(hidden_states)
1369
- hidden_states, router_logits = self.moe(hidden_states)
 
 
 
 
1370
  hidden_states = residual + hidden_states
1371
 
1372
  outputs = (hidden_states,)
@@ -1375,25 +1218,13 @@ class JambaMambaDecoderLayer(nn.Module):
1375
  outputs += (self_attn_weights,)
1376
 
1377
  if use_cache:
1378
- outputs += (present_key_value,)
1379
 
1380
  if output_router_logits:
1381
  outputs += (router_logits,)
1382
 
1383
  return outputs
1384
 
1385
- def _get_past_seqlen(self, past_key_value, seqlen):
1386
- if past_key_value is None:
1387
- return seqlen
1388
- past_seqlen = past_key_value.get_seq_length()
1389
- if past_seqlen == 0:
1390
- return seqlen
1391
- if past_key_value.attention_layer_idx is None:
1392
- return seqlen
1393
- if self.mamba.layer_idx < past_key_value.attention_layer_idx:
1394
- return past_seqlen + 1
1395
- return past_seqlen
1396
-
1397
 
1398
  JAMBA_START_DOCSTRING = r"""
1399
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1416,7 +1247,6 @@ JAMBA_START_DOCSTRING = r"""
1416
  "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
1417
  JAMBA_START_DOCSTRING,
1418
  )
1419
- # Adapted from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Jamba
1420
  class JambaPreTrainedModel(PreTrainedModel):
1421
  config_class = JambaConfig
1422
  base_model_prefix = "model"
@@ -1438,42 +1268,6 @@ class JambaPreTrainedModel(PreTrainedModel):
1438
  if module.padding_idx is not None:
1439
  module.weight.data[module.padding_idx].zero_()
1440
 
1441
- @staticmethod
1442
- def _convert_to_standard_cache(
1443
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
1444
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
1445
- """
1446
- Standardizes the format of the cache so as to match most implementations, i.e. have the seqlen as the third dim
1447
- also for mamba layers
1448
- """
1449
- attn_layer_index = [k.shape == v.shape for k, v in past_key_value].index(True)
1450
- seqlen = past_key_value[attn_layer_index][0].shape[2]
1451
- standard_past_key_value = ()
1452
- for k, v in past_key_value:
1453
- if k.shape != v.shape:
1454
- # mamba layer
1455
- # expand doesn't use more memory, so it's fine to do it here
1456
- standard_past_key_value += ((k.expand(-1, -1, seqlen, -1), v.expand(-1, -1, seqlen, -1)),)
1457
- else:
1458
- standard_past_key_value += ((k, v),)
1459
- return standard_past_key_value
1460
-
1461
- @staticmethod
1462
- def _convert_to_jamba_cache(
1463
- past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
1464
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
1465
- """
1466
- Converts the cache to the format expected by Jamba, i.e. dummy seqlen dimesion with size 1 for mamba layers
1467
- """
1468
- jamba_past_key_value = ()
1469
- for k, v in past_key_value:
1470
- if k.shape != v.shape:
1471
- # mamba layer
1472
- jamba_past_key_value += ((k[:, :, :1, :], v[:, :, :1, :]),)
1473
- else:
1474
- jamba_past_key_value += ((k, v),)
1475
- return jamba_past_key_value
1476
-
1477
 
1478
  JAMBA_INPUTS_DOCSTRING = r"""
1479
  Args:
@@ -1510,17 +1304,14 @@ JAMBA_INPUTS_DOCSTRING = r"""
1510
  config.n_positions - 1]`.
1511
 
1512
  [What are position IDs?](../glossary#position-ids)
1513
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1514
- Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors
1515
- corresponding to the cache of the layer.
1516
- For attention layers, both tensors have shape of `(batch_size, num_kv_heads, sequence_length, embed_size_per_head)`
1517
- For mamba layers, the first tensor represents the convolution state and has shape of `(batch_size, d_inner, 1, d_conv)`,
1518
- and the second tensor represents the ssm state and has shape of `(batch_size, d_inner, 1, d_state)`. Mamba
1519
- cache shape[2] is a dummy "seqlen" dimension to match the number of attention cache dimensions. For mamba,
1520
- the cache doesn't grow with seqlen so this dimension is always 1.
1521
-
1522
- Contains pre-computed hidden-states (key and values in the self-attention blocks and convolution and
1523
- ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1524
 
1525
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
1526
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@@ -1543,8 +1334,14 @@ JAMBA_INPUTS_DOCSTRING = r"""
1543
  should not be returned during inference.
1544
  return_dict (`bool`, *optional*):
1545
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 
 
 
 
1546
  """
1547
 
 
 
1548
 
1549
  @add_start_docstrings(
1550
  "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
@@ -1565,35 +1362,10 @@ class JambaModel(JambaPreTrainedModel):
1565
  self.vocab_size = config.vocab_size
1566
 
1567
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1568
-
1569
- # init each model layer, decide if it's mamba/attention and has experts or not
1570
  decoder_layers = []
1571
  for i in range(config.num_hidden_layers):
1572
- is_attn = True if (i - self.config.attn_layer_offset) % self.config.attn_layer_period == 0 else False
1573
- is_expert = True if (i - self.config.expert_layer_offset) % self.config.expert_layer_period == 0 else False
1574
-
1575
- num_experts = self.config.num_experts if is_expert else 1
1576
- if is_attn:
1577
- decoder_layers.append(JambaAttentionDecoderLayer(config, num_experts=num_experts, layer_idx=i))
1578
- else:
1579
- decoder_layers.append(JambaMambaDecoderLayer(config, num_experts=num_experts, layer_idx=i))
1580
-
1581
- if not any(isinstance(layer, JambaAttentionDecoderLayer) for layer in decoder_layers):
1582
- raise ValueError("At least one layer in the decoder must be an attention layer")
1583
- self._attn_layer_index = [isinstance(layer, JambaAttentionDecoderLayer) for layer in decoder_layers].index(
1584
- True
1585
- )
1586
-
1587
- if not any(isinstance(layer, JambaMambaDecoderLayer) for layer in decoder_layers):
1588
- raise ValueError("At least one layer in the decoder must be a Mamba layer")
1589
- self._mamba_layer_index = [isinstance(layer, JambaMambaDecoderLayer) for layer in decoder_layers].index(True)
1590
-
1591
- if (
1592
- decoder_layers[self._mamba_layer_index].mamba.ssm_state_size
1593
- == decoder_layers[self._mamba_layer_index].mamba.conv_kernel_size
1594
- ):
1595
- raise ValueError("Mamba state size and convolution size must be different")
1596
-
1597
  self.layers = nn.ModuleList(decoder_layers)
1598
 
1599
  self._attn_implementation = config._attn_implementation
@@ -1609,20 +1381,20 @@ class JambaModel(JambaPreTrainedModel):
1609
  def set_input_embeddings(self, value):
1610
  self.embed_tokens = value
1611
 
1612
- # Ignore copy
1613
  @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
1614
  def forward(
1615
  self,
1616
  input_ids: torch.LongTensor = None,
1617
  attention_mask: Optional[torch.Tensor] = None,
1618
  position_ids: Optional[torch.LongTensor] = None,
1619
- past_key_values: Optional[Union[List[torch.FloatTensor], HybridMambaAttentionDynamicCache]] = None,
1620
  inputs_embeds: Optional[torch.FloatTensor] = None,
1621
  use_cache: Optional[bool] = None,
1622
  output_attentions: Optional[bool] = None,
1623
  output_hidden_states: Optional[bool] = None,
1624
  output_router_logits: Optional[bool] = None,
1625
  return_dict: Optional[bool] = None,
 
1626
  ) -> Union[Tuple, MoeModelOutputWithPast]:
1627
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1628
  output_router_logits = (
@@ -1635,85 +1407,37 @@ class JambaModel(JambaPreTrainedModel):
1635
 
1636
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1637
 
1638
- # retrieve input_ids and inputs_embeds
1639
- if input_ids is not None and inputs_embeds is not None:
1640
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1641
- elif input_ids is not None:
1642
- batch_size, seq_length = input_ids.shape
1643
- elif inputs_embeds is not None:
1644
- batch_size, seq_length, _ = inputs_embeds.shape
1645
- else:
1646
- raise ValueError("You have to specify either input_ids or inputs_embeds")
1647
-
1648
- past_key_values_length = 0
1649
-
1650
- if self.gradient_checkpointing and self.training:
1651
- if use_cache:
1652
- logger.warning_once(
1653
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1654
- )
1655
- use_cache = False
1656
-
1657
- if use_cache:
1658
- if isinstance(past_key_values, Cache) and not isinstance(
1659
- past_key_values, HybridMambaAttentionDynamicCache
1660
- ):
1661
- past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache(past_key_values.to_legacy_cache())
1662
- use_legacy_cache = not isinstance(past_key_values, HybridMambaAttentionDynamicCache)
1663
- if use_legacy_cache:
1664
- past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache(past_key_values)
1665
- past_key_values_length = past_key_values.get_usable_length(seq_length, self._attn_layer_index)
1666
 
1667
- if position_ids is None:
1668
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1669
- position_ids = torch.arange(
1670
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1671
  )
1672
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1673
- else:
1674
- position_ids = position_ids.view(-1, seq_length).long()
1675
 
1676
  if inputs_embeds is None:
1677
  inputs_embeds = self.embed_tokens(input_ids)
 
1678
 
1679
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1680
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1681
- if is_padding_right:
1682
- raise ValueError(
1683
- "You are attempting to perform batched generation with padding_side='right'"
1684
- " this may lead to unexpected behaviour for Flash Attention version of Jamba. Make sure to "
1685
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1686
- )
1687
-
1688
- if self._attn_implementation == "flash_attention_2":
1689
- # 2d mask is passed through the layers
1690
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1691
- elif self._attn_implementation == "sdpa" and not output_attentions:
1692
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1693
- # the manual implementation that requires a 4D causal mask in all cases.
1694
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1695
- attention_mask,
1696
- (batch_size, seq_length),
1697
- inputs_embeds,
1698
- past_key_values_length,
1699
- )
1700
- else:
1701
- # 4d mask is passed through the layers
1702
- attention_mask = _prepare_4d_causal_attention_mask(
1703
- attention_mask,
1704
- (batch_size, seq_length),
1705
- inputs_embeds,
1706
- past_key_values_length,
1707
- sliding_window=self.config.sliding_window,
1708
  )
1709
 
1710
- hidden_states = inputs_embeds
 
 
 
 
 
1711
 
1712
- # decoder layers
1713
  all_hidden_states = () if output_hidden_states else None
1714
  all_self_attns = () if output_attentions else None
1715
  all_router_logits = () if output_router_logits else None
1716
- next_decoder_cache = None
1717
 
1718
  for decoder_layer in self.layers:
1719
  if output_hidden_states:
@@ -1723,34 +1447,37 @@ class JambaModel(JambaPreTrainedModel):
1723
  layer_outputs = self._gradient_checkpointing_func(
1724
  decoder_layer.__call__,
1725
  hidden_states,
1726
- attention_mask,
1727
  position_ids,
1728
  past_key_values,
1729
  output_attentions,
1730
  output_router_logits,
1731
  use_cache,
 
1732
  )
1733
  else:
1734
  layer_outputs = decoder_layer(
1735
  hidden_states,
1736
- attention_mask=attention_mask,
1737
  position_ids=position_ids,
1738
  past_key_value=past_key_values,
1739
  output_attentions=output_attentions,
1740
  output_router_logits=output_router_logits,
1741
  use_cache=use_cache,
 
1742
  )
1743
 
1744
  hidden_states = layer_outputs[0]
1745
 
1746
- if use_cache:
1747
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1748
-
1749
  if output_attentions:
1750
- all_self_attns += (layer_outputs[1],)
 
 
1751
 
1752
  if output_router_logits:
1753
- all_router_logits += (layer_outputs[-1],)
 
 
1754
 
1755
  hidden_states = self.final_layernorm(hidden_states)
1756
 
@@ -1758,9 +1485,10 @@ class JambaModel(JambaPreTrainedModel):
1758
  if output_hidden_states:
1759
  all_hidden_states += (hidden_states,)
1760
 
1761
- next_cache = None
1762
- if use_cache:
1763
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
1764
 
1765
  if not return_dict:
1766
  return tuple(
@@ -1776,6 +1504,41 @@ class JambaModel(JambaPreTrainedModel):
1776
  router_logits=all_router_logits,
1777
  )
1778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1779
 
1780
  # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
1781
  class JambaForCausalLM(JambaPreTrainedModel):
@@ -1818,7 +1581,7 @@ class JambaForCausalLM(JambaPreTrainedModel):
1818
  input_ids: torch.LongTensor = None,
1819
  attention_mask: Optional[torch.Tensor] = None,
1820
  position_ids: Optional[torch.LongTensor] = None,
1821
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1822
  inputs_embeds: Optional[torch.FloatTensor] = None,
1823
  labels: Optional[torch.LongTensor] = None,
1824
  use_cache: Optional[bool] = None,
@@ -1826,7 +1589,8 @@ class JambaForCausalLM(JambaPreTrainedModel):
1826
  output_hidden_states: Optional[bool] = None,
1827
  output_router_logits: Optional[bool] = None,
1828
  return_dict: Optional[bool] = None,
1829
- calc_logits_for_entire_prompt: Optional[bool] = True,
 
1830
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1831
  r"""
1832
  Args:
@@ -1835,12 +1599,28 @@ class JambaForCausalLM(JambaPreTrainedModel):
1835
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1836
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1837
 
1838
- calc_logits_for_entire_prompt (`bool`, *optional*):
1839
- Whether or not to calculate the logits for the entire prompt, or just the last token. Only last token
1840
- logits are needed for generation, and calculating them only for that token can save memory,
1841
- which becomes pretty significant for long sequences.
1842
 
1843
  Returns:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1844
  ```"""
1845
 
1846
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1864,14 +1644,15 @@ class JambaForCausalLM(JambaPreTrainedModel):
1864
  output_attentions=output_attentions,
1865
  output_hidden_states=output_hidden_states,
1866
  output_router_logits=output_router_logits,
 
1867
  return_dict=return_dict,
1868
  )
1869
 
1870
  hidden_states = outputs[0]
1871
- if calc_logits_for_entire_prompt:
1872
  logits = self.lm_head(hidden_states)
1873
  else:
1874
- logits = self.lm_head(hidden_states[..., -1:, :])
1875
  logits = logits.float()
1876
 
1877
  loss = None
@@ -1921,27 +1702,15 @@ class JambaForCausalLM(JambaPreTrainedModel):
1921
  attention_mask=None,
1922
  inputs_embeds=None,
1923
  output_router_logits=False,
 
1924
  **kwargs,
1925
  ):
1926
- # Omit tokens covered by past_key_values
1927
- if past_key_values is not None:
1928
- # the cache may be in the stardard format (e.g. in contrastive search), convert to Jamba's format if needed
1929
- if isinstance(past_key_values, Tuple):
1930
- if past_key_values[self.model._mamba_layer_index][0].shape[2] > 1:
1931
- past_key_values = self._convert_to_jamba_cache(past_key_values)
1932
-
1933
- if isinstance(past_key_values, Cache):
1934
- if not isinstance(past_key_values, HybridMambaAttentionDynamicCache):
1935
- past_key_values = HybridMambaAttentionDynamicCache.from_legacy_cache(
1936
- past_key_values.to_legacy_cache()
1937
- )
1938
- cache_length = past_key_values.get_seq_length()
1939
- past_length = past_key_values.seen_tokens
1940
- max_cache_length = past_key_values.get_max_length()
1941
- else:
1942
- cache_length = past_length = past_key_values[self.model._attn_layer_index][0].shape[2]
1943
- max_cache_length = None
1944
 
 
 
 
 
1945
  # Keep only the unprocessed tokens:
1946
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1947
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
@@ -1958,20 +1727,24 @@ class JambaForCausalLM(JambaPreTrainedModel):
1958
  if (
1959
  max_cache_length is not None
1960
  and attention_mask is not None
1961
- and cache_length + input_ids.shape[1] > max_cache_length
1962
  ):
1963
  attention_mask = attention_mask[:, -max_cache_length:]
 
 
 
 
1964
 
1965
  position_ids = kwargs.get("position_ids", None)
1966
  if attention_mask is not None and position_ids is None:
1967
  # create position_ids on the fly for batch generation
1968
  position_ids = attention_mask.long().cumsum(-1) - 1
1969
  position_ids.masked_fill_(attention_mask == 0, 1)
1970
- if past_key_values:
1971
  position_ids = position_ids[:, -input_ids.shape[1] :]
1972
 
1973
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1974
- if inputs_embeds is not None and past_key_values is None:
1975
  model_inputs = {"inputs_embeds": inputs_embeds}
1976
  else:
1977
  model_inputs = {"input_ids": input_ids}
@@ -1983,20 +1756,12 @@ class JambaForCausalLM(JambaPreTrainedModel):
1983
  "use_cache": kwargs.get("use_cache"),
1984
  "attention_mask": attention_mask,
1985
  "output_router_logits": output_router_logits,
1986
- "calc_logits_for_entire_prompt": self.config.calc_logits_for_entire_prompt,
 
1987
  }
1988
  )
1989
  return model_inputs
1990
 
1991
- @staticmethod
1992
- def _reorder_cache(past_key_values, beam_idx):
1993
- reordered_past = ()
1994
- for layer_past in past_key_values:
1995
- reordered_past += (
1996
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1997
- )
1998
- return reordered_past
1999
-
2000
 
2001
  @add_start_docstrings(
2002
  """
 
20
  """ PyTorch Jamba model."""
21
  import inspect
22
  import math
 
 
23
  from typing import Any, Dict, List, Optional, Tuple, Union
24
 
25
  import torch
 
29
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
 
31
  from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv
33
  from transformers.modeling_attn_mask_utils import (
34
+ AttentionMaskConverter,
 
35
  )
36
  from transformers.modeling_outputs import (
37
  MoeCausalLMOutputWithPast,
 
39
  SequenceClassifierOutputWithPast,
40
  )
41
  from transformers.modeling_utils import PreTrainedModel
 
42
  from transformers.utils import (
43
  add_start_docstrings,
44
  add_start_docstrings_to_model_forward,
 
46
  logging,
47
  replace_return_docstrings,
48
  )
49
+ from transformers.utils.import_utils import (
50
+ is_causal_conv1d_available,
51
+ is_flash_attn_2_available,
52
+ is_mamba_ssm_available,
53
+ )
54
  from .configuration_jamba import JambaConfig
55
 
56
 
57
+ # try except block so it'll work with trust_remote_code.
58
  try:
59
  from flash_attn import flash_attn_func, flash_attn_varlen_func
60
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
63
  except ImportError:
64
  pass
65
 
 
 
 
 
 
 
 
66
 
67
+ # try except block so it'll work with trust_remote_code.
68
  try:
69
  from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
70
  from mamba_ssm.ops.triton.selective_state_update import selective_state_update
71
  except ImportError:
72
  selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
73
 
74
+ # try except block so it'll work with trust_remote_code.
75
  try:
76
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
77
  except ImportError:
 
87
  _CONFIG_FOR_DOC = "JambaConfig"
88
 
89
 
90
+ # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
91
  def load_balancing_loss_func(
92
+ router_logits: torch.Tensor,
93
+ num_experts: torch.Tensor = None,
94
+ top_k=2,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
  ) -> float:
97
  r"""
98
  Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
 
102
  experts is too unbalanced.
103
 
104
  Args:
105
+ router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
106
  Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
107
  shape [batch_size X sequence_length, num_experts].
108
  attention_mask (`torch.Tensor`, None):
 
114
  Returns:
115
  The auxiliary loss.
116
  """
117
+ if router_logits is None or not isinstance(router_logits, tuple):
118
  return 0
119
 
120
+ if isinstance(router_logits, tuple):
121
+ compute_device = router_logits[0].device
122
+ concatenated_router_logits = torch.cat(
123
+ [layer_router.to(compute_device) for layer_router in router_logits], dim=0
124
  )
125
 
126
+ routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)
127
 
128
  _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
129
 
 
137
  router_prob_per_expert = torch.mean(routing_weights, dim=0)
138
  else:
139
  batch_size, sequence_length = attention_mask.shape
140
+ num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)
141
 
142
  # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
143
  expert_attention_mask = (
 
213
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
214
 
215
 
216
+ class HybridMambaAttentionDynamicCache(DynamicCache):
217
+ """
218
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
219
+ (which has a constant shape regardless of seq_len).
220
+
221
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
222
+ and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
223
+ For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
224
+ while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
225
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
226
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
227
+ and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
228
+ """
229
+
230
+ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
231
+ self.dtype = dtype
232
+ self.layers_block_type = config.layers_block_type
233
+ self.has_previous_state = False # only used by mamba
234
+ intermediate_size = config.mamba_expand * config.hidden_size
235
+ ssm_state_size = config.mamba_d_state
236
+ conv_kernel_size = config.mamba_d_conv
237
+ self.conv_states = []
238
+ self.ssm_states = []
239
+ for i in range(config.num_hidden_layers):
240
+ if self.layers_block_type[i] == "mamba":
241
+ self.conv_states += [
242
+ torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
243
+ ]
244
+ self.ssm_states += [
245
+ torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
246
+ ]
247
+ else:
248
+ self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
249
+ self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
250
+
251
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
252
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
253
+
254
+ def update(
255
+ self,
256
+ key_states: torch.Tensor,
257
+ value_states: torch.Tensor,
258
+ layer_idx: int,
259
+ cache_kwargs: Optional[Dict[str, Any]] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ # Update the cache
262
+ if self.key_cache[layer_idx].shape[-1] == 0:
263
+ self.key_cache[layer_idx] = key_states
264
+ self.value_cache[layer_idx] = value_states
265
+ else:
266
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
267
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
268
+
269
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
270
+
271
+ def reorder_cache(self, beam_idx: torch.LongTensor):
272
+ """Reorders the cache for beam search, given the selected beam indices."""
273
+ for layer_idx in range(len(self.key_cache)):
274
+ device = self.key_cache[layer_idx].device
275
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
276
+ device = self.value_cache[layer_idx].device
277
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
278
+
279
+ device = self.conv_states[layer_idx].device
280
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
281
+ device = self.ssm_states[layer_idx].device
282
+ self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
283
+
284
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
285
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
286
+
287
+ @classmethod
288
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
289
+ raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
290
+
291
+
292
  # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
293
  class JambaAttention(nn.Module):
294
  """
 
325
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
326
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
327
 
 
 
 
328
  def forward(
329
  self,
330
  hidden_states: torch.Tensor,
331
  attention_mask: Optional[torch.Tensor] = None,
332
  position_ids: Optional[torch.LongTensor] = None,
333
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
334
  output_attentions: bool = False,
335
  use_cache: bool = False,
336
+ cache_position: Optional[torch.LongTensor] = None,
337
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
338
  bsz, q_len, _ = hidden_states.size()
339
 
340
  query_states = self.q_proj(hidden_states)
 
345
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
346
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
347
 
 
 
 
 
 
 
 
 
 
 
348
  if past_key_value is not None:
349
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
350
 
 
354
 
355
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
356
 
357
+ if attention_mask is not None: # no matter the length, we just slice it
358
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
359
+ attn_weights = attn_weights + causal_mask
 
 
 
 
 
 
 
 
 
 
360
 
361
  # upcast attention to fp32
362
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
402
  hidden_states: torch.Tensor,
403
  attention_mask: Optional[torch.Tensor] = None,
404
  position_ids: Optional[torch.LongTensor] = None,
405
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
406
  output_attentions: bool = False,
407
  use_cache: bool = False,
408
+ cache_position: Optional[torch.LongTensor] = None,
409
  **kwargs,
410
  ):
 
 
 
 
 
 
 
411
  bsz, q_len, _ = hidden_states.size()
412
 
413
  query_states = self.q_proj(hidden_states)
414
  key_states = self.k_proj(hidden_states)
415
  value_states = self.v_proj(hidden_states)
416
 
417
+ # Flash attention requires the input to have the shape
418
+ # batch_size x seq_length x head_dim x hidden_dim
419
+ # therefore we just need to keep the original shape
420
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
421
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
422
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
423
 
424
+ kv_seq_len = cache_position[-1]
 
 
 
 
 
 
 
 
425
 
426
  use_sliding_windows = (
427
  _flash_supports_window_size
 
437
 
438
  if past_key_value is not None:
439
  # Activate slicing cache only if the config has a value `sliding_windows` attribute
440
+ cache_has_contents = cache_position[0] > 0
441
  if (
442
  getattr(self.config, "sliding_window", None) is not None
443
  and kv_seq_len > self.config.sliding_window
 
539
  attention_mask (`torch.Tensor`):
540
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
541
  position of padding tokens and 1 for the position of non-padding tokens.
542
+ dropout (`float`, *optional*):
543
  Attention dropout
544
  softmax_scale (`float`, *optional*):
545
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
 
614
 
615
  return attn_output
616
 
617
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2._upad_input
618
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
619
  batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
620
 
 
672
  hidden_states: torch.Tensor,
673
  attention_mask: Optional[torch.Tensor] = None,
674
  position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
676
  output_attentions: bool = False,
677
  use_cache: bool = False,
678
+ cache_position: Optional[torch.LongTensor] = None,
679
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
680
  if output_attentions:
681
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 
702
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
703
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
704
 
 
 
 
 
705
  if past_key_value is not None:
706
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
707
 
708
  key_states = repeat_kv(key_states, self.num_key_value_groups)
709
  value_states = repeat_kv(value_states, self.num_key_value_groups)
710
 
711
+ causal_mask = attention_mask
712
  if attention_mask is not None:
713
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
 
 
714
 
715
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
716
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
 
723
  query_states,
724
  key_states,
725
  value_states,
726
+ attn_mask=causal_mask,
727
  dropout_p=self.attention_dropout if self.training else 0.0,
728
  # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
729
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
 
744
  }
745
 
746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
747
  # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
748
  class JambaMambaMixer(nn.Module):
749
  """
 
775
 
776
  self.activation = config.hidden_act
777
  self.act = ACT2FN[config.hidden_act]
 
778
 
779
  self.use_fast_kernels = config.use_mamba_kernels
780
 
 
794
  self.D = nn.Parameter(torch.ones(self.intermediate_size))
795
  self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
796
 
797
+ self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
798
+ self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
799
+ self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
 
 
 
 
 
800
 
801
  if not is_fast_path_available:
802
  logger.warning_once(
 
805
  " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
806
  )
807
 
808
+ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
809
+ batch_size, seq_len, _ = hidden_states.shape
810
+ use_precomputed_states = (
811
+ cache_params is not None
812
+ and cache_params.has_previous_state
813
+ and seq_len == 1
814
+ and cache_params.conv_states[self.layer_idx].shape[0]
815
+ == cache_params.ssm_states[self.layer_idx].shape[0]
816
+ == batch_size
817
+ )
818
  # 1. Gated MLP's linear projection
819
  projected_states = self.in_proj(hidden_states).transpose(1, 2)
820
 
821
+ # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
822
+ # inner layernorms which isn't supported by this fused kernel
823
+ hidden_states, gate = projected_states.chunk(2, dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
 
825
+ # 2. Convolution sequence transformation
826
+ conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
827
+ if use_precomputed_states:
828
+ hidden_states = causal_conv1d_update(
829
+ hidden_states.squeeze(-1),
830
+ cache_params.conv_states[self.layer_idx],
831
+ conv_weights,
832
+ self.conv1d.bias,
833
+ self.activation,
834
+ )
835
+ hidden_states = hidden_states.unsqueeze(-1)
836
  else:
837
+ if cache_params is not None:
838
+ conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
839
+ cache_params.conv_states[self.layer_idx].copy_(conv_states)
840
+ hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
841
+
842
+ # 3. State Space Model sequence transformation
843
+ # 3.a. input varying initialization of time_step, B and C
844
+ ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
845
+ time_step, B, C = torch.split(
846
+ ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
847
+ )
 
 
 
 
 
 
 
 
 
 
 
848
 
849
+ time_step = self.dt_layernorm(time_step)
850
+ B = self.b_layernorm(B)
851
+ C = self.c_layernorm(C)
852
+
853
+ # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
854
+ # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
855
+ # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
856
+ # linear layers, and requires to call the forward pass directly.
857
+ # The original code here was: ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
858
+ time_proj_bias = self.dt_proj.bias
859
+ self.dt_proj.bias = None
860
+ discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
861
+ self.dt_proj.bias = time_proj_bias
862
+
863
+ A = -torch.exp(self.A_log.float())
864
+ # 3.c perform the recurrence y ← SSM(A, B, C)(x)
865
+ time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
866
+ if use_precomputed_states:
867
+ scan_outputs = selective_state_update(
868
+ cache_params.ssm_states[self.layer_idx],
869
+ hidden_states[..., 0],
870
+ discrete_time_step[..., 0],
871
+ A,
872
+ B[:, 0],
873
+ C[:, 0],
874
+ self.D,
875
+ gate[..., 0],
876
+ time_proj_bias,
877
+ dt_softplus=True,
878
+ ).unsqueeze(-1)
879
+ else:
880
+ scan_outputs, ssm_state = selective_scan_fn(
881
+ hidden_states,
882
+ discrete_time_step,
883
+ A,
884
+ B.transpose(1, 2),
885
+ C.transpose(1, 2),
886
+ self.D.float(),
887
+ gate,
888
+ time_proj_bias,
889
+ delta_softplus=True,
890
+ return_last_state=True,
891
  )
892
+ if ssm_state is not None and cache_params is not None:
893
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
894
+
895
+ # 4. Final linear projection
896
+ contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897
 
 
 
898
  return contextualized_states
899
 
900
  # fmt: off
901
+ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
902
  batch_size, seq_len, _ = input_states.shape
903
  dtype = input_states.dtype
904
  # 1. Gated MLP's linear projection
905
  projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
906
  hidden_states, gate = projected_states.chunk(2, dim=1)
907
 
908
+ use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
909
  # 2. Convolution sequence transformation
910
+ if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
911
  if self.training:
912
  # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
913
  ssm_state = cache_params.ssm_states[self.layer_idx].clone()
914
  else:
915
  ssm_state = cache_params.ssm_states[self.layer_idx]
916
 
917
+ if cache_params.has_previous_state and seq_len == 1 and \
918
+ cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
919
  conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
920
  conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
921
  conv_state[:, :, -1] = hidden_states[:, :, 0]
922
+ cache_params.conv_states[self.layer_idx] = conv_state
923
  hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
924
  if self.use_conv_bias:
925
  hidden_states += self.conv1d.bias
 
929
  hidden_states,
930
  (self.conv_kernel_size - hidden_states.shape[-1], 0)
931
  )
932
+ cache_params.conv_states[self.layer_idx] = conv_state
933
  hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
934
  else:
935
  ssm_state = torch.zeros(
 
944
  time_step, B, C = torch.split(
945
  ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
946
  )
947
+
948
+ time_step = self.dt_layernorm(time_step)
949
+ B = self.b_layernorm(B)
950
+ C = self.c_layernorm(C)
951
+
952
  discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
953
  discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
954
 
 
968
  scan_output = scan_output + (hidden_states * self.D[None, :, None])
969
  scan_output = (scan_output * self.act(gate))
970
 
971
+ if use_cache:
972
+ cache_params.ssm_states[self.layer_idx] = ssm_state
973
 
974
  # 4. Final linear projection
975
  contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
976
  return contextualized_states
977
  # fmt: on
978
 
979
+ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
980
  if self.use_fast_kernels:
981
  if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
982
  raise ValueError(
 
985
  return self.cuda_kernels_forward(hidden_states, cache_params)
986
  return self.slow_forward(hidden_states, cache_params)
987
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988
 
989
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
990
  class JambaMLP(nn.Module):
991
+ def __init__(self, config):
992
  super().__init__()
993
+ self.config = config
994
+ self.hidden_size = config.hidden_size
995
+ self.intermediate_size = config.intermediate_size
996
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
997
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
998
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
999
  self.act_fn = ACT2FN[config.hidden_act]
1000
 
1001
  def forward(self, x):
 
1015
  and memory on padding.
1016
  """
1017
 
1018
+ def __init__(self, config: JambaConfig):
1019
  super().__init__()
1020
  self.hidden_dim = config.hidden_size
1021
  self.ffn_dim = config.intermediate_size
1022
+ self.num_experts = config.num_experts
1023
+ self.top_k = config.num_experts_per_tok
1024
 
1025
+ self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
 
 
 
 
 
 
 
 
1026
  self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])
1027
 
1028
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1029
  """ """
1030
  batch_size, sequence_length, hidden_dim = hidden_states.shape
1031
 
 
 
 
 
 
 
 
 
 
 
 
 
1032
  hidden_states = hidden_states.view(-1, hidden_dim)
1033
  # router_logits: (batch * sequence_length, n_experts)
1034
  router_logits = self.router(hidden_states)
 
1053
  if top_x.shape[0] == 0:
1054
  continue
1055
 
 
 
 
 
1056
  # Index the correct hidden states and compute the expert hidden state for
1057
  # the current expert. We need to make sure to multiply the output hidden
1058
  # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
1059
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
1060
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
1061
 
1062
  # However `index_add_` only support torch tensors for indexing so we'll use
1063
  # the `top_x` tensor here.
 
1067
 
1068
 
1069
  class JambaAttentionDecoderLayer(nn.Module):
1070
+ def __init__(self, config: JambaConfig, layer_idx: int):
1071
  super().__init__()
1072
+ num_experts = config.layers_num_experts[layer_idx]
1073
  self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1074
 
1075
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
1076
+ self.feed_forward = ffn_layer_class(config)
1077
  self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1078
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1079
 
1080
  def forward(
1081
  self,
1082
  hidden_states: torch.Tensor,
1083
  attention_mask: Optional[torch.Tensor] = None,
1084
  position_ids: Optional[torch.LongTensor] = None,
1085
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
1086
  output_attentions: Optional[bool] = False,
1087
  output_router_logits: Optional[bool] = False,
1088
  use_cache: Optional[bool] = False,
1089
+ cache_position: Optional[torch.LongTensor] = None,
1090
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
 
 
1091
  """
1092
  Args:
1093
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1094
  attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1095
  `(batch, sequence_length)` where padding elements are indicated by 0.
1096
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
1097
  output_attentions (`bool`, *optional*):
1098
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1099
  returned tensors for more detail.
 
1103
  use_cache (`bool`, *optional*):
1104
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1105
  (see `past_key_values`).
1106
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1107
+ Indices depicting the position of the input sequence tokens in the sequence.
1108
  """
1109
 
1110
  residual = hidden_states
 
1118
  past_key_value=past_key_value,
1119
  output_attentions=output_attentions,
1120
  use_cache=use_cache,
1121
+ cache_position=cache_position,
1122
  )
1123
 
1124
  # residual connection after attention
1125
  hidden_states = residual + hidden_states
1126
 
1127
+ # feed-forward (experts/MLP)
1128
  residual = hidden_states
1129
+ hidden_states = self.pre_ff_layernorm(hidden_states)
1130
+ ff_outputs = self.feed_forward(hidden_states)
1131
+ if isinstance(ff_outputs, tuple):
1132
+ hidden_states, router_logits = ff_outputs
1133
+ else:
1134
+ hidden_states, router_logits = ff_outputs, None
1135
  hidden_states = residual + hidden_states
1136
 
1137
  outputs = (hidden_states,)
 
1149
 
1150
 
1151
  class JambaMambaDecoderLayer(nn.Module):
1152
+ def __init__(self, config: JambaConfig, layer_idx: int):
1153
  super().__init__()
1154
+ num_experts = config.layers_num_experts[layer_idx]
1155
  self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
1156
 
1157
+ ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
1158
+ self.feed_forward = ffn_layer_class(config)
1159
  self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1160
+ self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1161
 
1162
  def forward(
1163
  self,
 
1168
  output_attentions: Optional[bool] = False,
1169
  output_router_logits: Optional[bool] = False,
1170
  use_cache: Optional[bool] = False,
1171
+ cache_position: Optional[torch.LongTensor] = None,
1172
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
 
 
1173
  """
1174
  Args:
1175
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1176
  attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1177
  `(batch, sequence_length)` where padding elements are indicated by 0.
1178
+ past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
1179
  output_attentions (`bool`, *optional*):
1180
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1181
  returned tensors for more detail.
 
1185
  use_cache (`bool`, *optional*):
1186
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1187
  (see `past_key_values`).
1188
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1189
+ Indices depicting the position of the input sequence tokens in the sequence.
1190
  """
1191
 
1192
  residual = hidden_states
1193
 
1194
  hidden_states = self.input_layernorm(hidden_states)
1195
 
1196
+ hidden_states = self.mamba(
1197
  hidden_states=hidden_states,
1198
+ cache_params=past_key_value,
1199
  )
1200
+ self_attn_weights = None
 
 
 
1201
 
1202
  # residual connection after mamba
1203
  hidden_states = residual + hidden_states
1204
 
1205
+ # feed-forward (experts/MLP)
1206
  residual = hidden_states
1207
+ hidden_states = self.pre_ff_layernorm(hidden_states)
1208
+ ff_outputs = self.feed_forward(hidden_states)
1209
+ if isinstance(ff_outputs, tuple):
1210
+ hidden_states, router_logits = ff_outputs
1211
+ else:
1212
+ hidden_states, router_logits = ff_outputs, None
1213
  hidden_states = residual + hidden_states
1214
 
1215
  outputs = (hidden_states,)
 
1218
  outputs += (self_attn_weights,)
1219
 
1220
  if use_cache:
1221
+ outputs += (past_key_value,)
1222
 
1223
  if output_router_logits:
1224
  outputs += (router_logits,)
1225
 
1226
  return outputs
1227
 
 
 
 
 
 
 
 
 
 
 
 
 
1228
 
1229
  JAMBA_START_DOCSTRING = r"""
1230
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
 
1247
  "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
1248
  JAMBA_START_DOCSTRING,
1249
  )
 
1250
  class JambaPreTrainedModel(PreTrainedModel):
1251
  config_class = JambaConfig
1252
  base_model_prefix = "model"
 
1268
  if module.padding_idx is not None:
1269
  module.weight.data[module.padding_idx].zero_()
1270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1271
 
1272
  JAMBA_INPUTS_DOCSTRING = r"""
1273
  Args:
 
1304
  config.n_positions - 1]`.
1305
 
1306
  [What are position IDs?](../glossary#position-ids)
1307
+ past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1308
+ A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
1309
+ self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
1310
+ `past_key_values` input) to speed up sequential decoding.
1311
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
1312
+ Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
1313
+ `(batch_size, d_inner, d_state)` respectively.
1314
+ See the `HybridMambaAttentionDynamicCache` class for more details.
 
 
 
1315
 
1316
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
1317
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
 
1334
  should not be returned during inference.
1335
  return_dict (`bool`, *optional*):
1336
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1337
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1338
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1339
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1340
+ the complete sequence length.
1341
  """
1342
 
1343
+ ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
1344
+
1345
 
1346
  @add_start_docstrings(
1347
  "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
 
1362
  self.vocab_size = config.vocab_size
1363
 
1364
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
1365
  decoder_layers = []
1366
  for i in range(config.num_hidden_layers):
1367
+ layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
1368
+ decoder_layers.append(layer_class(config, layer_idx=i))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1369
  self.layers = nn.ModuleList(decoder_layers)
1370
 
1371
  self._attn_implementation = config._attn_implementation
 
1381
  def set_input_embeddings(self, value):
1382
  self.embed_tokens = value
1383
 
 
1384
  @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
1385
  def forward(
1386
  self,
1387
  input_ids: torch.LongTensor = None,
1388
  attention_mask: Optional[torch.Tensor] = None,
1389
  position_ids: Optional[torch.LongTensor] = None,
1390
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
1391
  inputs_embeds: Optional[torch.FloatTensor] = None,
1392
  use_cache: Optional[bool] = None,
1393
  output_attentions: Optional[bool] = None,
1394
  output_hidden_states: Optional[bool] = None,
1395
  output_router_logits: Optional[bool] = None,
1396
  return_dict: Optional[bool] = None,
1397
+ cache_position: Optional[torch.LongTensor] = None,
1398
  ) -> Union[Tuple, MoeModelOutputWithPast]:
1399
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1400
  output_router_logits = (
 
1407
 
1408
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1409
 
1410
+ if (input_ids is None) ^ (inputs_embeds is not None):
1411
+ raise ValueError(
1412
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1413
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1414
 
1415
+ if self.gradient_checkpointing and self.training and use_cache:
1416
+ logger.warning_once(
1417
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
 
1418
  )
1419
+ use_cache = False
 
 
1420
 
1421
  if inputs_embeds is None:
1422
  inputs_embeds = self.embed_tokens(input_ids)
1423
+ hidden_states = inputs_embeds
1424
 
1425
+ if use_cache and past_key_values is None:
1426
+ logger.warning_once(
1427
+ "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
1428
+ "provided, so no cache will be returned."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1429
  )
1430
 
1431
+ if cache_position is None:
1432
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
1433
+ if position_ids is None:
1434
+ position_ids = cache_position.unsqueeze(0)
1435
+
1436
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
1437
 
 
1438
  all_hidden_states = () if output_hidden_states else None
1439
  all_self_attns = () if output_attentions else None
1440
  all_router_logits = () if output_router_logits else None
 
1441
 
1442
  for decoder_layer in self.layers:
1443
  if output_hidden_states:
 
1447
  layer_outputs = self._gradient_checkpointing_func(
1448
  decoder_layer.__call__,
1449
  hidden_states,
1450
+ causal_mask,
1451
  position_ids,
1452
  past_key_values,
1453
  output_attentions,
1454
  output_router_logits,
1455
  use_cache,
1456
+ cache_position,
1457
  )
1458
  else:
1459
  layer_outputs = decoder_layer(
1460
  hidden_states,
1461
+ attention_mask=causal_mask,
1462
  position_ids=position_ids,
1463
  past_key_value=past_key_values,
1464
  output_attentions=output_attentions,
1465
  output_router_logits=output_router_logits,
1466
  use_cache=use_cache,
1467
+ cache_position=cache_position,
1468
  )
1469
 
1470
  hidden_states = layer_outputs[0]
1471
 
 
 
 
1472
  if output_attentions:
1473
+ if layer_outputs[1] is not None:
1474
+ # append attentions only of attention layers. Mamba layers return `None` as the attention weights
1475
+ all_self_attns += (layer_outputs[1],)
1476
 
1477
  if output_router_logits:
1478
+ if layer_outputs[-1] is not None:
1479
+ # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
1480
+ all_router_logits += (layer_outputs[-1],)
1481
 
1482
  hidden_states = self.final_layernorm(hidden_states)
1483
 
 
1485
  if output_hidden_states:
1486
  all_hidden_states += (hidden_states,)
1487
 
1488
+ if past_key_values and not past_key_values.has_previous_state:
1489
+ past_key_values.has_previous_state = True
1490
+
1491
+ next_cache = None if not use_cache else past_key_values
1492
 
1493
  if not return_dict:
1494
  return tuple(
 
1504
  router_logits=all_router_logits,
1505
  )
1506
 
1507
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
1508
+ if self.config._attn_implementation == "flash_attention_2":
1509
+ if attention_mask is not None and 0.0 in attention_mask:
1510
+ return attention_mask
1511
+ return None
1512
+
1513
+ dtype, device = input_tensor.dtype, input_tensor.device
1514
+ min_dtype = torch.finfo(dtype).min
1515
+ sequence_length = input_tensor.shape[1]
1516
+ target_length = cache_position[-1] + 1
1517
+
1518
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1519
+ if sequence_length != 1:
1520
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1521
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1522
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1523
+ if attention_mask is not None:
1524
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1525
+ if attention_mask.dim() == 2:
1526
+ mask_length = attention_mask.shape[-1]
1527
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1528
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1529
+
1530
+ if (
1531
+ self.config._attn_implementation == "sdpa"
1532
+ and attention_mask is not None
1533
+ and attention_mask.device.type == "cuda"
1534
+ ):
1535
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1536
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1537
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1538
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1539
+
1540
+ return causal_mask
1541
+
1542
 
1543
  # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
1544
  class JambaForCausalLM(JambaPreTrainedModel):
 
1581
  input_ids: torch.LongTensor = None,
1582
  attention_mask: Optional[torch.Tensor] = None,
1583
  position_ids: Optional[torch.LongTensor] = None,
1584
+ past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
1585
  inputs_embeds: Optional[torch.FloatTensor] = None,
1586
  labels: Optional[torch.LongTensor] = None,
1587
  use_cache: Optional[bool] = None,
 
1589
  output_hidden_states: Optional[bool] = None,
1590
  output_router_logits: Optional[bool] = None,
1591
  return_dict: Optional[bool] = None,
1592
+ cache_position: Optional[torch.LongTensor] = None,
1593
+ num_logits_to_keep: Optional[Union[int, None]] = None,
1594
  ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1595
  r"""
1596
  Args:
 
1599
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1600
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1601
 
1602
+ num_logits_to_keep (`int` or `None`, *optional*):
1603
+ Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
1604
+ `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
1605
+ can save memory, which becomes pretty significant for long sequences.
1606
 
1607
  Returns:
1608
+
1609
+ Example:
1610
+
1611
+ ```python
1612
+ >>> from transformers import AutoTokenizer, JambaForCausalLM
1613
+
1614
+ >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
1615
+ >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
1616
+
1617
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1618
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1619
+
1620
+ >>> # Generate
1621
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1622
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1623
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1624
  ```"""
1625
 
1626
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1644
  output_attentions=output_attentions,
1645
  output_hidden_states=output_hidden_states,
1646
  output_router_logits=output_router_logits,
1647
+ cache_position=cache_position,
1648
  return_dict=return_dict,
1649
  )
1650
 
1651
  hidden_states = outputs[0]
1652
+ if num_logits_to_keep is None:
1653
  logits = self.lm_head(hidden_states)
1654
  else:
1655
+ logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
1656
  logits = logits.float()
1657
 
1658
  loss = None
 
1702
  attention_mask=None,
1703
  inputs_embeds=None,
1704
  output_router_logits=False,
1705
+ cache_position=None,
1706
  **kwargs,
1707
  ):
1708
+ empty_past_kv = past_key_values is None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1709
 
1710
+ # Omit tokens covered by past_key_values
1711
+ if not empty_past_kv:
1712
+ past_length = cache_position[0] if cache_position is not None else attention_mask.shape[1]
1713
+ max_cache_length = self.config.sliding_window
1714
  # Keep only the unprocessed tokens:
1715
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1716
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
 
1727
  if (
1728
  max_cache_length is not None
1729
  and attention_mask is not None
1730
+ and past_length + input_ids.shape[1] > max_cache_length
1731
  ):
1732
  attention_mask = attention_mask[:, -max_cache_length:]
1733
+ else:
1734
+ past_key_values = HybridMambaAttentionDynamicCache(
1735
+ self.config, input_ids.shape[0], self.dtype, device=self.device
1736
+ )
1737
 
1738
  position_ids = kwargs.get("position_ids", None)
1739
  if attention_mask is not None and position_ids is None:
1740
  # create position_ids on the fly for batch generation
1741
  position_ids = attention_mask.long().cumsum(-1) - 1
1742
  position_ids.masked_fill_(attention_mask == 0, 1)
1743
+ if not empty_past_kv:
1744
  position_ids = position_ids[:, -input_ids.shape[1] :]
1745
 
1746
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1747
+ if inputs_embeds is not None and empty_past_kv:
1748
  model_inputs = {"inputs_embeds": inputs_embeds}
1749
  else:
1750
  model_inputs = {"input_ids": input_ids}
 
1756
  "use_cache": kwargs.get("use_cache"),
1757
  "attention_mask": attention_mask,
1758
  "output_router_logits": output_router_logits,
1759
+ "num_logits_to_keep": self.config.num_logits_to_keep,
1760
+ "cache_position": cache_position,
1761
  }
1762
  )
1763
  return model_inputs
1764
 
 
 
 
 
 
 
 
 
 
1765
 
1766
  @add_start_docstrings(
1767
  """
special_tokens_map.json CHANGED
@@ -1,6 +1,30 @@
1
  {
2
- "bos_token": "<|startoftext|>",
3
- "eos_token": "<|endoftext|>",
4
- "pad_token": "<|pad|>",
5
- "unk_token": "<|unk|>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  }
 
1
  {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
  }