tomeras1 committed on
Commit 61c8679
1 Parent(s): 48aecbb

Move to in-library checkpoint

Files changed (1)
  1. configuration_jamba.py +27 -17
configuration_jamba.py CHANGED
@@ -26,9 +26,9 @@ class JambaConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
     Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to that of the jamba-small architecture.
+    with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.
 
-    [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
+    [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -65,12 +65,12 @@ class JambaConfig(PretrainedConfig):
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if `config.is_decoder=True`.
-        calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
-            Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
-            last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
-            the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
-            will reduce memory footprint significantly.
-            Note: some generation features may not be available if this is set to `False`.
+        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. The default is 1 because
+            only the logits of the last prompt token are needed for generation. For long sequences, the logits for
+            the entire sequence may use a lot of memory, so setting `num_logits_to_keep=1` reduces the memory
+            footprint significantly.
         output_router_logits (`bool`, *optional*, defaults to `False`):
             Whether or not the router logits should be returned by the model. Enabling this will also
             allow the model to output the auxiliary loss. See [here]() for more details
@@ -84,7 +84,7 @@ class JambaConfig(PretrainedConfig):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*):
            Sliding window attention window size. If not specified, will default to `None`.
-        n_ctx (`int`, *optional*, defaults to 262144):
+        max_position_embeddings (`int`, *optional*, defaults to 262144):
            This value doesn't have any real effect. The maximum sequence length that this model is intended to be
            used with. It can be used with longer sequences, but performance may degrade.
        attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -118,8 +118,6 @@ class JambaConfig(PretrainedConfig):
            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
-        mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
-            Flag indicating whether or not to apply layernorms to internal mamba activations
 
     """
 
@@ -139,14 +137,14 @@ class JambaConfig(PretrainedConfig):
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         use_cache=True,
-        calc_logits_for_entire_prompt=False,
+        num_logits_to_keep=1,
         output_router_logits=False,
         router_aux_loss_coef=0.001,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
         sliding_window=None,
-        n_ctx=262144,
+        max_position_embeddings=262144,
         attention_dropout=0.0,
         num_experts_per_tok=2,
         num_experts=16,
@@ -161,7 +159,6 @@ class JambaConfig(PretrainedConfig):
         mamba_dt_rank="auto",
         mamba_conv_bias=True,
         mamba_proj_bias=False,
-        mamba_inner_layernorms=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -171,7 +168,7 @@ class JambaConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.n_ctx = n_ctx
+        self.max_position_embeddings = max_position_embeddings
         self.attention_dropout = attention_dropout
 
         # for backward compatibility
@@ -184,7 +181,7 @@ class JambaConfig(PretrainedConfig):
         self.rms_norm_eps = rms_norm_eps
 
         self.use_cache = use_cache
-        self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
+        self.num_logits_to_keep = num_logits_to_keep
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
 
@@ -202,7 +199,6 @@ class JambaConfig(PretrainedConfig):
         self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
         self.mamba_conv_bias = mamba_conv_bias
         self.mamba_proj_bias = mamba_proj_bias
-        self.mamba_inner_layernorms = mamba_inner_layernorms
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -211,3 +207,17 @@ class JambaConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+    @property
+    def layers_block_type(self):
+        return [
+            "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
+            for i in range(self.num_hidden_layers)
+        ]
+
+    @property
+    def layers_num_experts(self):
+        return [
+            self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
+            for i in range(self.num_hidden_layers)
+        ]
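
Below is a minimal sketch, not part of the commit, of how the renamed fields behave after this change. It assumes configuration_jamba.py is importable from the working directory and that transformers is installed (for PretrainedConfig):

# Sketch only: exercises the renamed fields (num_logits_to_keep, max_position_embeddings).
from configuration_jamba import JambaConfig

# The default keeps logits for just the last prompt token, which is all that
# generation needs and avoids materializing full-sequence logits for long prompts.
config = JambaConfig()
print(config.num_logits_to_keep)       # 1
print(config.max_position_embeddings)  # 262144

# Passing None requests logits for every prompt position, e.g. for per-token scoring.
config_all = JambaConfig(num_logits_to_keep=None)
print(config_all.num_logits_to_keep)   # None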
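
The new layers_block_type and layers_num_experts properties derive the hybrid layer layout from the period/offset settings. A small worked example, again not part of the commit, assuming the period/offset names below are accepted as constructor arguments (the values are illustrative, not asserted defaults):

# Sketch only: shows the per-layer layout reported by the new properties.
from configuration_jamba import JambaConfig

config = JambaConfig(
    num_hidden_layers=8,
    attn_layer_period=4,     # attention block every 4th layer...
    attn_layer_offset=2,     # ...starting at layer index 2
    expert_layer_period=2,   # MoE block every 2nd layer...
    expert_layer_offset=1,   # ...starting at layer index 1
    num_experts=16,
)

# "attention" where i % attn_layer_period == attn_layer_offset, "mamba" elsewhere:
print(config.layers_block_type)
# ['mamba', 'mamba', 'attention', 'mamba', 'mamba', 'mamba', 'attention', 'mamba']

# num_experts on MoE layers, 1 (a single dense MLP) elsewhere:
print(config.layers_num_experts)
# [1, 16, 1, 16, 1, 16, 1, 16]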