jon-tow commited on
Commit
26fc2dc
1 Parent(s): ef92c21

feat: use latest modeling code

Browse files
Files changed (2) hide show
  1. configuration_stablelm.py +3 -1
  2. modeling_stablelm.py +109 -13
configuration_stablelm.py CHANGED
@@ -45,7 +45,7 @@ class StableLmConfig(PretrainedConfig):
45
  intermediate_size (`int`, *optional*, defaults to 6912):
46
  Dimension of the MLP representations.
47
  hidden_size (`int`, *optional*, defaults to 2560):
48
- Dimension of the decoder layers and the pooler layer.
49
  num_hidden_layers (`int`, *optional*, defaults to 32):
50
  Number of hidden layers in the Transformer decoder.
51
  num_attention_heads (`int`, *optional*, defaults to 32):
@@ -134,12 +134,14 @@ class StableLmConfig(PretrainedConfig):
134
  ):
135
  self.vocab_size = vocab_size
136
  self.max_position_embeddings = max_position_embeddings
 
137
  self.hidden_size = hidden_size
138
  self.intermediate_size = intermediate_size
139
  self.num_hidden_layers = num_hidden_layers
140
  self.num_attention_heads = num_attention_heads
141
  self.num_key_value_heads = num_key_value_heads
142
  self.hidden_act = hidden_act
 
143
  self.initializer_range = initializer_range
144
  self.layer_norm_eps = layer_norm_eps
145
  self.use_cache = use_cache
 
45
  intermediate_size (`int`, *optional*, defaults to 6912):
46
  Dimension of the MLP representations.
47
  hidden_size (`int`, *optional*, defaults to 2560):
48
+ Number of hidden layers in the Transformer decoder.
49
  num_hidden_layers (`int`, *optional*, defaults to 32):
50
  Number of hidden layers in the Transformer decoder.
51
  num_attention_heads (`int`, *optional*, defaults to 32):
 
134
  ):
135
  self.vocab_size = vocab_size
136
  self.max_position_embeddings = max_position_embeddings
137
+
138
  self.hidden_size = hidden_size
139
  self.intermediate_size = intermediate_size
140
  self.num_hidden_layers = num_hidden_layers
141
  self.num_attention_heads = num_attention_heads
142
  self.num_key_value_heads = num_key_value_heads
143
  self.hidden_act = hidden_act
144
+
145
  self.initializer_range = initializer_range
146
  self.layer_norm_eps = layer_norm_eps
147
  self.use_cache = use_cache
modeling_stablelm.py CHANGED
@@ -103,7 +103,7 @@ class StableLmRotaryEmbedding(nn.Module):
103
  )
104
 
105
 
106
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
107
  class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
108
  """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
109
 
@@ -123,7 +123,7 @@ class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
123
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
124
 
125
 
126
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
127
  class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
128
  """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
129
 
@@ -374,6 +374,102 @@ class StableLmAttention(nn.Module):
374
  return attn_output, attn_weights, past_key_value
375
 
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  class StableLmFlashAttention2(StableLmAttention):
378
  """
379
  StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
@@ -574,6 +670,7 @@ class StableLmFlashAttention2(StableLmAttention):
574
 
575
  ATTENTION_CLASSES = {
576
  "eager": StableLmAttention,
 
577
  "flash_attention_2": StableLmFlashAttention2,
578
  }
579
 
@@ -669,7 +766,7 @@ STABLELM_START_DOCSTRING = r"""
669
 
670
 
671
  @add_start_docstrings(
672
- "The bare StableLM Model outputting raw hidden-states without any specific head on top.",
673
  STABLELM_START_DOCSTRING,
674
  )
675
  class StableLmPreTrainedModel(PreTrainedModel):
@@ -680,6 +777,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
680
  _skip_keys_device_placement = "past_key_values"
681
  _supports_flash_attn_2 = True
682
  _supports_cache_class = True
 
683
 
684
  def _init_weights(self, module):
685
  std = self.config.initializer_range
@@ -764,7 +862,7 @@ STABLELM_INPUTS_DOCSTRING = r"""
764
 
765
 
766
  @add_start_docstrings(
767
- "The bare StableLM Model outputting raw hidden-states without any specific head on top.",
768
  STABLELM_START_DOCSTRING,
769
  )
770
  class StableLmModel(StableLmPreTrainedModel):
@@ -858,6 +956,11 @@ class StableLmModel(StableLmPreTrainedModel):
858
  if self._attn_implementation == "flash_attention_2":
859
  # 2d mask is passed through the layers
860
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
 
 
 
 
 
861
  else:
862
  # 4d mask is passed through the layers
863
  attention_mask = _prepare_4d_causal_attention_mask(
@@ -999,7 +1102,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
999
  >>> # Generate
1000
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1001
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1002
- 'The weather is always wonderful in the San Juan Islands, and whether you're a vacationer or a resident, here are some ideas for fun!'
1003
  ```"""
1004
 
1005
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1048,7 +1151,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
1048
  attentions=outputs.attentions,
1049
  )
1050
 
1051
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1052
  def prepare_inputs_for_generation(
1053
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1054
  ):
@@ -1089,12 +1191,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
1089
  if past_key_values:
1090
  position_ids = position_ids[:, -input_ids.shape[1] :]
1091
 
1092
- if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
1093
- # generation with static cache
1094
- seen_tokens = past_key_value.get_seq_length()
1095
- input_ids = input_ids[:, seen_tokens:]
1096
- position_ids = position_ids[:, seen_tokens:]
1097
-
1098
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1099
  if inputs_embeds is not None and past_key_values is None:
1100
  model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1123,7 +1219,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
1123
 
1124
  @add_start_docstrings(
1125
  """
1126
- The StableLM transformer with a sequence classification head on top (linear layer).
1127
 
1128
  [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
1129
  models (e.g. GPT-2) do.
 
103
  )
104
 
105
 
106
+ # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
107
  class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
108
  """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
109
 
 
123
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
124
 
125
 
126
+ # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
127
  class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
128
  """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
129
 
 
374
  return attn_output, attn_weights, past_key_value
375
 
376
 
377
+ class StableLmSdpaAttention(StableLmAttention):
378
+ def forward(
379
+ self,
380
+ hidden_states: torch.Tensor,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ position_ids: Optional[torch.LongTensor] = None,
383
+ past_key_value: Optional[Cache] = None,
384
+ output_attentions: bool = False,
385
+ use_cache: bool = False,
386
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
387
+ if output_attentions:
388
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
389
+ logger.warning_once(
390
+ "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
391
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
392
+ )
393
+ return super().forward(
394
+ hidden_states=hidden_states,
395
+ attention_mask=attention_mask,
396
+ position_ids=position_ids,
397
+ past_key_value=past_key_value,
398
+ output_attentions=output_attentions,
399
+ use_cache=use_cache,
400
+ )
401
+
402
+ bsz, q_len, _ = hidden_states.size()
403
+
404
+ query_states = self.q_proj(hidden_states)
405
+ key_states = self.k_proj(hidden_states)
406
+ value_states = self.v_proj(hidden_states)
407
+
408
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
409
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
410
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
411
+
412
+ kv_seq_len = key_states.shape[-2]
413
+ if past_key_value is not None:
414
+ if self.layer_idx is None:
415
+ raise ValueError(
416
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
417
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
418
+ "with a layer index."
419
+ )
420
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
421
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
422
+
423
+ # Partial rotary embedding
424
+ query_rot, query_pass = (
425
+ query_states[..., : self.rotary_emb.dim],
426
+ query_states[..., self.rotary_emb.dim :],
427
+ )
428
+ key_rot, key_pass = (
429
+ key_states[..., : self.rotary_emb.dim],
430
+ key_states[..., self.rotary_emb.dim :],
431
+ )
432
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
433
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
434
+
435
+ # [batch_size, seq_length, num_heads, head_dim]
436
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
437
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
438
+
439
+ if past_key_value is not None:
440
+ # Specific to RoPE models with partial rotation
441
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
442
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
443
+
444
+ # Repeat k/v heads if n_kv_heads < n_heads
445
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
446
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
447
+
448
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
449
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
450
+ if query_states.device.type == "cuda" and attention_mask is not None:
451
+ query_states = query_states.contiguous()
452
+ key_states = key_states.contiguous()
453
+ value_states = value_states.contiguous()
454
+
455
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
456
+ query_states,
457
+ key_states,
458
+ value_states,
459
+ attn_mask=attention_mask,
460
+ dropout_p=self.attention_dropout.p if self.training else 0.0,
461
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
462
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
463
+ )
464
+
465
+ attn_output = attn_output.transpose(1, 2).contiguous()
466
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
467
+
468
+ attn_output = self.o_proj(attn_output)
469
+
470
+ return attn_output, None, past_key_value
471
+
472
+
473
  class StableLmFlashAttention2(StableLmAttention):
474
  """
475
  StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
 
670
 
671
  ATTENTION_CLASSES = {
672
  "eager": StableLmAttention,
673
+ "sdpa": StableLmSdpaAttention,
674
  "flash_attention_2": StableLmFlashAttention2,
675
  }
676
 
 
766
 
767
 
768
  @add_start_docstrings(
769
+ "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
770
  STABLELM_START_DOCSTRING,
771
  )
772
  class StableLmPreTrainedModel(PreTrainedModel):
 
777
  _skip_keys_device_placement = "past_key_values"
778
  _supports_flash_attn_2 = True
779
  _supports_cache_class = True
780
+ _supports_sdpa = True
781
 
782
  def _init_weights(self, module):
783
  std = self.config.initializer_range
 
862
 
863
 
864
  @add_start_docstrings(
865
+ "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
866
  STABLELM_START_DOCSTRING,
867
  )
868
  class StableLmModel(StableLmPreTrainedModel):
 
956
  if self._attn_implementation == "flash_attention_2":
957
  # 2d mask is passed through the layers
958
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
959
+ # for output_attentions case used fallback to eager attention realization
960
+ elif self._attn_implementation == "sdpa" and not output_attentions:
961
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
962
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
963
+ )
964
  else:
965
  # 4d mask is passed through the layers
966
  attention_mask = _prepare_4d_causal_attention_mask(
 
1102
  >>> # Generate
1103
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1104
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1105
+ 'The weather is always wonderful in the summer in the city of San Diego. The city is located on the coast of the Pacific Ocean and is surrounded by'
1106
  ```"""
1107
 
1108
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1151
  attentions=outputs.attentions,
1152
  )
1153
 
 
1154
  def prepare_inputs_for_generation(
1155
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1156
  ):
 
1191
  if past_key_values:
1192
  position_ids = position_ids[:, -input_ids.shape[1] :]
1193
 
 
 
 
 
 
 
1194
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1195
  if inputs_embeds is not None and past_key_values is None:
1196
  model_inputs = {"inputs_embeds": inputs_embeds}
 
1219
 
1220
  @add_start_docstrings(
1221
  """
1222
+ The StableLm transformer with a sequence classification head on top (linear layer).
1223
 
1224
  [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
1225
  models (e.g. GPT-2) do.