feat: use latest modeling code
Browse files- configuration_stablelm.py +3 -1
- modeling_stablelm.py +109 -13
configuration_stablelm.py
CHANGED
@@ -45,7 +45,7 @@ class StableLmConfig(PretrainedConfig):
|
|
45 |
intermediate_size (`int`, *optional*, defaults to 6912):
|
46 |
Dimension of the MLP representations.
|
47 |
hidden_size (`int`, *optional*, defaults to 2560):
|
48 |
-
|
49 |
num_hidden_layers (`int`, *optional*, defaults to 32):
|
50 |
Number of hidden layers in the Transformer decoder.
|
51 |
num_attention_heads (`int`, *optional*, defaults to 32):
|
@@ -134,12 +134,14 @@ class StableLmConfig(PretrainedConfig):
|
|
134 |
):
|
135 |
self.vocab_size = vocab_size
|
136 |
self.max_position_embeddings = max_position_embeddings
|
|
|
137 |
self.hidden_size = hidden_size
|
138 |
self.intermediate_size = intermediate_size
|
139 |
self.num_hidden_layers = num_hidden_layers
|
140 |
self.num_attention_heads = num_attention_heads
|
141 |
self.num_key_value_heads = num_key_value_heads
|
142 |
self.hidden_act = hidden_act
|
|
|
143 |
self.initializer_range = initializer_range
|
144 |
self.layer_norm_eps = layer_norm_eps
|
145 |
self.use_cache = use_cache
|
|
|
45 |
intermediate_size (`int`, *optional*, defaults to 6912):
|
46 |
Dimension of the MLP representations.
|
47 |
hidden_size (`int`, *optional*, defaults to 2560):
|
48 |
+
Number of hidden layers in the Transformer decoder.
|
49 |
num_hidden_layers (`int`, *optional*, defaults to 32):
|
50 |
Number of hidden layers in the Transformer decoder.
|
51 |
num_attention_heads (`int`, *optional*, defaults to 32):
|
|
|
134 |
):
|
135 |
self.vocab_size = vocab_size
|
136 |
self.max_position_embeddings = max_position_embeddings
|
137 |
+
|
138 |
self.hidden_size = hidden_size
|
139 |
self.intermediate_size = intermediate_size
|
140 |
self.num_hidden_layers = num_hidden_layers
|
141 |
self.num_attention_heads = num_attention_heads
|
142 |
self.num_key_value_heads = num_key_value_heads
|
143 |
self.hidden_act = hidden_act
|
144 |
+
|
145 |
self.initializer_range = initializer_range
|
146 |
self.layer_norm_eps = layer_norm_eps
|
147 |
self.use_cache = use_cache
|
modeling_stablelm.py
CHANGED
@@ -103,7 +103,7 @@ class StableLmRotaryEmbedding(nn.Module):
|
|
103 |
)
|
104 |
|
105 |
|
106 |
-
# Copied from transformers.models.
|
107 |
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
108 |
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
109 |
|
@@ -123,7 +123,7 @@ class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
|
123 |
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
124 |
|
125 |
|
126 |
-
# Copied from transformers.models.
|
127 |
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
128 |
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
129 |
|
@@ -374,6 +374,102 @@ class StableLmAttention(nn.Module):
|
|
374 |
return attn_output, attn_weights, past_key_value
|
375 |
|
376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
class StableLmFlashAttention2(StableLmAttention):
|
378 |
"""
|
379 |
StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
|
@@ -574,6 +670,7 @@ class StableLmFlashAttention2(StableLmAttention):
|
|
574 |
|
575 |
ATTENTION_CLASSES = {
|
576 |
"eager": StableLmAttention,
|
|
|
577 |
"flash_attention_2": StableLmFlashAttention2,
|
578 |
}
|
579 |
|
@@ -669,7 +766,7 @@ STABLELM_START_DOCSTRING = r"""
|
|
669 |
|
670 |
|
671 |
@add_start_docstrings(
|
672 |
-
"The bare
|
673 |
STABLELM_START_DOCSTRING,
|
674 |
)
|
675 |
class StableLmPreTrainedModel(PreTrainedModel):
|
@@ -680,6 +777,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
|
|
680 |
_skip_keys_device_placement = "past_key_values"
|
681 |
_supports_flash_attn_2 = True
|
682 |
_supports_cache_class = True
|
|
|
683 |
|
684 |
def _init_weights(self, module):
|
685 |
std = self.config.initializer_range
|
@@ -764,7 +862,7 @@ STABLELM_INPUTS_DOCSTRING = r"""
|
|
764 |
|
765 |
|
766 |
@add_start_docstrings(
|
767 |
-
"The bare
|
768 |
STABLELM_START_DOCSTRING,
|
769 |
)
|
770 |
class StableLmModel(StableLmPreTrainedModel):
|
@@ -858,6 +956,11 @@ class StableLmModel(StableLmPreTrainedModel):
|
|
858 |
if self._attn_implementation == "flash_attention_2":
|
859 |
# 2d mask is passed through the layers
|
860 |
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
|
|
|
|
|
|
|
|
|
|
861 |
else:
|
862 |
# 4d mask is passed through the layers
|
863 |
attention_mask = _prepare_4d_causal_attention_mask(
|
@@ -999,7 +1102,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
|
999 |
>>> # Generate
|
1000 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1001 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1002 |
-
'The weather is always wonderful in the San
|
1003 |
```"""
|
1004 |
|
1005 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
@@ -1048,7 +1151,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
|
1048 |
attentions=outputs.attentions,
|
1049 |
)
|
1050 |
|
1051 |
-
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
1052 |
def prepare_inputs_for_generation(
|
1053 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1054 |
):
|
@@ -1089,12 +1191,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
|
1089 |
if past_key_values:
|
1090 |
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1091 |
|
1092 |
-
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
1093 |
-
# generation with static cache
|
1094 |
-
seen_tokens = past_key_value.get_seq_length()
|
1095 |
-
input_ids = input_ids[:, seen_tokens:]
|
1096 |
-
position_ids = position_ids[:, seen_tokens:]
|
1097 |
-
|
1098 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1099 |
if inputs_embeds is not None and past_key_values is None:
|
1100 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
@@ -1123,7 +1219,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
|
|
1123 |
|
1124 |
@add_start_docstrings(
|
1125 |
"""
|
1126 |
-
The
|
1127 |
|
1128 |
[`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
|
1129 |
models (e.g. GPT-2) do.
|
|
|
103 |
)
|
104 |
|
105 |
|
106 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
|
107 |
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
108 |
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
109 |
|
|
|
123 |
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
124 |
|
125 |
|
126 |
+
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
|
127 |
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
128 |
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
129 |
|
|
|
374 |
return attn_output, attn_weights, past_key_value
|
375 |
|
376 |
|
377 |
+
class StableLmSdpaAttention(StableLmAttention):
|
378 |
+
def forward(
|
379 |
+
self,
|
380 |
+
hidden_states: torch.Tensor,
|
381 |
+
attention_mask: Optional[torch.Tensor] = None,
|
382 |
+
position_ids: Optional[torch.LongTensor] = None,
|
383 |
+
past_key_value: Optional[Cache] = None,
|
384 |
+
output_attentions: bool = False,
|
385 |
+
use_cache: bool = False,
|
386 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
387 |
+
if output_attentions:
|
388 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
389 |
+
logger.warning_once(
|
390 |
+
"StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
391 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
392 |
+
)
|
393 |
+
return super().forward(
|
394 |
+
hidden_states=hidden_states,
|
395 |
+
attention_mask=attention_mask,
|
396 |
+
position_ids=position_ids,
|
397 |
+
past_key_value=past_key_value,
|
398 |
+
output_attentions=output_attentions,
|
399 |
+
use_cache=use_cache,
|
400 |
+
)
|
401 |
+
|
402 |
+
bsz, q_len, _ = hidden_states.size()
|
403 |
+
|
404 |
+
query_states = self.q_proj(hidden_states)
|
405 |
+
key_states = self.k_proj(hidden_states)
|
406 |
+
value_states = self.v_proj(hidden_states)
|
407 |
+
|
408 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
409 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
410 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
411 |
+
|
412 |
+
kv_seq_len = key_states.shape[-2]
|
413 |
+
if past_key_value is not None:
|
414 |
+
if self.layer_idx is None:
|
415 |
+
raise ValueError(
|
416 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
417 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
418 |
+
"with a layer index."
|
419 |
+
)
|
420 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
421 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
422 |
+
|
423 |
+
# Partial rotary embedding
|
424 |
+
query_rot, query_pass = (
|
425 |
+
query_states[..., : self.rotary_emb.dim],
|
426 |
+
query_states[..., self.rotary_emb.dim :],
|
427 |
+
)
|
428 |
+
key_rot, key_pass = (
|
429 |
+
key_states[..., : self.rotary_emb.dim],
|
430 |
+
key_states[..., self.rotary_emb.dim :],
|
431 |
+
)
|
432 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
433 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
434 |
+
|
435 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
436 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
437 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
438 |
+
|
439 |
+
if past_key_value is not None:
|
440 |
+
# Specific to RoPE models with partial rotation
|
441 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
442 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
443 |
+
|
444 |
+
# Repeat k/v heads if n_kv_heads < n_heads
|
445 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
446 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
447 |
+
|
448 |
+
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
449 |
+
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
450 |
+
if query_states.device.type == "cuda" and attention_mask is not None:
|
451 |
+
query_states = query_states.contiguous()
|
452 |
+
key_states = key_states.contiguous()
|
453 |
+
value_states = value_states.contiguous()
|
454 |
+
|
455 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
456 |
+
query_states,
|
457 |
+
key_states,
|
458 |
+
value_states,
|
459 |
+
attn_mask=attention_mask,
|
460 |
+
dropout_p=self.attention_dropout.p if self.training else 0.0,
|
461 |
+
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
462 |
+
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
463 |
+
)
|
464 |
+
|
465 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
466 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
467 |
+
|
468 |
+
attn_output = self.o_proj(attn_output)
|
469 |
+
|
470 |
+
return attn_output, None, past_key_value
|
471 |
+
|
472 |
+
|
473 |
class StableLmFlashAttention2(StableLmAttention):
|
474 |
"""
|
475 |
StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
|
|
|
670 |
|
671 |
ATTENTION_CLASSES = {
|
672 |
"eager": StableLmAttention,
|
673 |
+
"sdpa": StableLmSdpaAttention,
|
674 |
"flash_attention_2": StableLmFlashAttention2,
|
675 |
}
|
676 |
|
|
|
766 |
|
767 |
|
768 |
@add_start_docstrings(
|
769 |
+
"The bare StableLm Model outputting raw hidden-states without any specific head on top.",
|
770 |
STABLELM_START_DOCSTRING,
|
771 |
)
|
772 |
class StableLmPreTrainedModel(PreTrainedModel):
|
|
|
777 |
_skip_keys_device_placement = "past_key_values"
|
778 |
_supports_flash_attn_2 = True
|
779 |
_supports_cache_class = True
|
780 |
+
_supports_sdpa = True
|
781 |
|
782 |
def _init_weights(self, module):
|
783 |
std = self.config.initializer_range
|
|
|
862 |
|
863 |
|
864 |
@add_start_docstrings(
|
865 |
+
"The bare StableLm Model outputting raw hidden-states without any specific head on top.",
|
866 |
STABLELM_START_DOCSTRING,
|
867 |
)
|
868 |
class StableLmModel(StableLmPreTrainedModel):
|
|
|
956 |
if self._attn_implementation == "flash_attention_2":
|
957 |
# 2d mask is passed through the layers
|
958 |
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
959 |
+
# for output_attentions case used fallback to eager attention realization
|
960 |
+
elif self._attn_implementation == "sdpa" and not output_attentions:
|
961 |
+
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
962 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
963 |
+
)
|
964 |
else:
|
965 |
# 4d mask is passed through the layers
|
966 |
attention_mask = _prepare_4d_causal_attention_mask(
|
|
|
1102 |
>>> # Generate
|
1103 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1104 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1105 |
+
'The weather is always wonderful in the summer in the city of San Diego. The city is located on the coast of the Pacific Ocean and is surrounded by'
|
1106 |
```"""
|
1107 |
|
1108 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
1151 |
attentions=outputs.attentions,
|
1152 |
)
|
1153 |
|
|
|
1154 |
def prepare_inputs_for_generation(
|
1155 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1156 |
):
|
|
|
1191 |
if past_key_values:
|
1192 |
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1194 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1195 |
if inputs_embeds is not None and past_key_values is None:
|
1196 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
|
1219 |
|
1220 |
@add_start_docstrings(
|
1221 |
"""
|
1222 |
+
The StableLm transformer with a sequence classification head on top (linear layer).
|
1223 |
|
1224 |
[`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
|
1225 |
models (e.g. GPT-2) do.
|