TinyPixel commited on
Commit
75c2858
1 Parent(s): 342718b

Upload StableLMAlphaForCausalLM

Browse files
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stablelm-base-alpha-3b-v2",
3
+ "architectures": [
4
+ "StableLMAlphaForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_stablelm_alpha.StableLMAlphaConfig",
8
+ "AutoModelForCausalLM": "modeling_stablelm_alpha.StableLMAlphaForCausalLM"
9
+ },
10
+ "bos_token_id": 0,
11
+ "eos_token_id": 0,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 2560,
14
+ "initializer_range": 0.02,
15
+ "max_position_embeddings": 4096,
16
+ "model_type": "stablelm_alpha",
17
+ "norm_eps": 1e-05,
18
+ "num_heads": 32,
19
+ "num_hidden_layers": 32,
20
+ "rotary_emb_base": 10000,
21
+ "rotary_pct": 0.25,
22
+ "rotary_scaling_factor": 1.0,
23
+ "tie_word_embeddings": false,
24
+ "torch_dtype": "bfloat16",
25
+ "transformers_version": "4.35.0",
26
+ "use_cache": true,
27
+ "vocab_size": 50432
28
+ }
configuration_stablelm_alpha.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ StableLM β model configuration"""
16
+
17
+ from transformers import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ class StableLMAlphaConfig(PretrainedConfig):
27
+ r"""
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ vocab_size (`int`, *optional*, defaults to 50432):
33
+ Vocabulary size of the StableLM model. Defines the number of different tokens that
34
+ can be represented by the `inputs_ids` passed when calling [`StableLMAlphaModel`].
35
+ hidden_size (`int`, *optional*, defaults to 6144):
36
+ Dimension of the decoder layers and the pooler layer.
37
+ num_hidden_layers (`int`, *optional*, defaults to 44):
38
+ Number of hidden layers in the Transformer decoder.
39
+ num_heads (`int`, *optional*, defaults to 64):
40
+ Number of attention heads for each attention layer in the Transformer decoder.
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
42
+ The non-linear activation function (function or string).
43
+ rotary_pct (`float`, *optional*, defaults to 0.25):
44
+ Percentage of hidden dimensions to allocate to rotary embeddings.
45
+ rotary_emb_base (`int`, *optional*, defaults to 10000)
46
+ Base for computing rotary embeddings frequency.
47
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
48
+ The maximum sequence length that this model might ever be used with.
49
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
50
+ initializer_range (`float`, *optional*, defaults to 1e-5):
51
+ The standard deviation of the truncated_normal_initializer for initializing
52
+ all weight matrices.
53
+ norm_eps (`float`, *optional*, defaults to 1e-5):
54
+ The epsilon used by the normalization layers.
55
+ use_cache (`bool`, *optional*, defaults to `True`):
56
+ Whether or not the model should return the last key/values attentions
57
+ (not used by all models). Only relevant if `config.is_decoder=True`.
58
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
59
+ Whether to tie weight embeddings
60
+
61
+ Example:
62
+
63
+ ```python
64
+ >>> from transformers import StableLMAlphaConfig, StableLMAlphaModel
65
+
66
+ >>> # Initializing a StableLMAlphaConfig style configuration
67
+ >>> configuration = StableLMAlphaConfig()
68
+
69
+ >>> # Initializing a model (with random weights) from the style configuration
70
+ >>> model = StableLMAlphaModel(configuration) # doctest: +SKIP
71
+
72
+ >>> # Accessing the model configuration
73
+ >>> configuration = model.config # doctest: +SKIP
74
+ ```"""
75
+ model_type = "stablelm_alpha"
76
+ keys_to_ignore_at_inference = ["past_key_values"]
77
+
78
+ def __init__(
79
+ self,
80
+ vocab_size=50_432,
81
+ hidden_size=2_560,
82
+ num_hidden_layers=32,
83
+ num_heads=32,
84
+ hidden_act="silu",
85
+ rotary_pct=0.25,
86
+ rotary_emb_base=10_000,
87
+ max_position_embeddings=2_048,
88
+ initializer_range=0.02,
89
+ norm_eps=1e-5,
90
+ use_cache=True,
91
+ bos_token_id=0,
92
+ eos_token_id=2,
93
+ tie_word_embeddings=False,
94
+ **kwargs,
95
+ ):
96
+ self.vocab_size = vocab_size
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.hidden_size = hidden_size
99
+ self.num_hidden_layers = num_hidden_layers
100
+ self.num_heads = num_heads
101
+ self.hidden_act = hidden_act
102
+ self.rotary_pct = rotary_pct
103
+ self.rotary_emb_base = rotary_emb_base
104
+ self.initializer_range = initializer_range
105
+ self.norm_eps = norm_eps
106
+ self.use_cache = use_cache
107
+ self.tie_word_embeddings = tie_word_embeddings
108
+ super().__init__(
109
+ bos_token_id=bos_token_id,
110
+ eos_token_id=eos_token_id,
111
+ tie_word_embeddings=tie_word_embeddings,
112
+ **kwargs,
113
+ )
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 0,
5
+ "transformers_version": "4.35.0"
6
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7a511801f8ef0c4628a64b6c7914466b80ba76de62e91820b00381b1996244b
3
+ size 4981064288
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90113364a59095bf4cf36b1c567f478fcbaf94cc5c77d1a46b33180661723a88
3
+ size 610828072
model.safetensors.index.json ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 5591869440
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "transformer.embed.weight": "model-00001-of-00002.safetensors",
8
+ "transformer.final_norm.bias": "model-00002-of-00002.safetensors",
9
+ "transformer.final_norm.weight": "model-00002-of-00002.safetensors",
10
+ "transformer.layers.0.attention.out_proj.weight": "model-00001-of-00002.safetensors",
11
+ "transformer.layers.0.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
12
+ "transformer.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
13
+ "transformer.layers.0.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
14
+ "transformer.layers.0.norm.bias": "model-00001-of-00002.safetensors",
15
+ "transformer.layers.0.norm.weight": "model-00001-of-00002.safetensors",
16
+ "transformer.layers.1.attention.out_proj.weight": "model-00001-of-00002.safetensors",
17
+ "transformer.layers.1.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
18
+ "transformer.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
19
+ "transformer.layers.1.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
20
+ "transformer.layers.1.norm.bias": "model-00001-of-00002.safetensors",
21
+ "transformer.layers.1.norm.weight": "model-00001-of-00002.safetensors",
22
+ "transformer.layers.10.attention.out_proj.weight": "model-00001-of-00002.safetensors",
23
+ "transformer.layers.10.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
24
+ "transformer.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
25
+ "transformer.layers.10.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
26
+ "transformer.layers.10.norm.bias": "model-00001-of-00002.safetensors",
27
+ "transformer.layers.10.norm.weight": "model-00001-of-00002.safetensors",
28
+ "transformer.layers.11.attention.out_proj.weight": "model-00001-of-00002.safetensors",
29
+ "transformer.layers.11.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
30
+ "transformer.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
31
+ "transformer.layers.11.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
32
+ "transformer.layers.11.norm.bias": "model-00001-of-00002.safetensors",
33
+ "transformer.layers.11.norm.weight": "model-00001-of-00002.safetensors",
34
+ "transformer.layers.12.attention.out_proj.weight": "model-00001-of-00002.safetensors",
35
+ "transformer.layers.12.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
36
+ "transformer.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
37
+ "transformer.layers.12.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
38
+ "transformer.layers.12.norm.bias": "model-00001-of-00002.safetensors",
39
+ "transformer.layers.12.norm.weight": "model-00001-of-00002.safetensors",
40
+ "transformer.layers.13.attention.out_proj.weight": "model-00001-of-00002.safetensors",
41
+ "transformer.layers.13.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
42
+ "transformer.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
43
+ "transformer.layers.13.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
44
+ "transformer.layers.13.norm.bias": "model-00001-of-00002.safetensors",
45
+ "transformer.layers.13.norm.weight": "model-00001-of-00002.safetensors",
46
+ "transformer.layers.14.attention.out_proj.weight": "model-00001-of-00002.safetensors",
47
+ "transformer.layers.14.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
48
+ "transformer.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
49
+ "transformer.layers.14.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
50
+ "transformer.layers.14.norm.bias": "model-00001-of-00002.safetensors",
51
+ "transformer.layers.14.norm.weight": "model-00001-of-00002.safetensors",
52
+ "transformer.layers.15.attention.out_proj.weight": "model-00001-of-00002.safetensors",
53
+ "transformer.layers.15.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
54
+ "transformer.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
55
+ "transformer.layers.15.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
56
+ "transformer.layers.15.norm.bias": "model-00001-of-00002.safetensors",
57
+ "transformer.layers.15.norm.weight": "model-00001-of-00002.safetensors",
58
+ "transformer.layers.16.attention.out_proj.weight": "model-00001-of-00002.safetensors",
59
+ "transformer.layers.16.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
60
+ "transformer.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
61
+ "transformer.layers.16.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
62
+ "transformer.layers.16.norm.bias": "model-00001-of-00002.safetensors",
63
+ "transformer.layers.16.norm.weight": "model-00001-of-00002.safetensors",
64
+ "transformer.layers.17.attention.out_proj.weight": "model-00001-of-00002.safetensors",
65
+ "transformer.layers.17.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
66
+ "transformer.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
67
+ "transformer.layers.17.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
68
+ "transformer.layers.17.norm.bias": "model-00001-of-00002.safetensors",
69
+ "transformer.layers.17.norm.weight": "model-00001-of-00002.safetensors",
70
+ "transformer.layers.18.attention.out_proj.weight": "model-00001-of-00002.safetensors",
71
+ "transformer.layers.18.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
72
+ "transformer.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
73
+ "transformer.layers.18.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
74
+ "transformer.layers.18.norm.bias": "model-00001-of-00002.safetensors",
75
+ "transformer.layers.18.norm.weight": "model-00001-of-00002.safetensors",
76
+ "transformer.layers.19.attention.out_proj.weight": "model-00001-of-00002.safetensors",
77
+ "transformer.layers.19.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
78
+ "transformer.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
79
+ "transformer.layers.19.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
80
+ "transformer.layers.19.norm.bias": "model-00001-of-00002.safetensors",
81
+ "transformer.layers.19.norm.weight": "model-00001-of-00002.safetensors",
82
+ "transformer.layers.2.attention.out_proj.weight": "model-00001-of-00002.safetensors",
83
+ "transformer.layers.2.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
84
+ "transformer.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
85
+ "transformer.layers.2.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
86
+ "transformer.layers.2.norm.bias": "model-00001-of-00002.safetensors",
87
+ "transformer.layers.2.norm.weight": "model-00001-of-00002.safetensors",
88
+ "transformer.layers.20.attention.out_proj.weight": "model-00001-of-00002.safetensors",
89
+ "transformer.layers.20.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
90
+ "transformer.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
91
+ "transformer.layers.20.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
92
+ "transformer.layers.20.norm.bias": "model-00001-of-00002.safetensors",
93
+ "transformer.layers.20.norm.weight": "model-00001-of-00002.safetensors",
94
+ "transformer.layers.21.attention.out_proj.weight": "model-00001-of-00002.safetensors",
95
+ "transformer.layers.21.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
96
+ "transformer.layers.21.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
97
+ "transformer.layers.21.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
98
+ "transformer.layers.21.norm.bias": "model-00001-of-00002.safetensors",
99
+ "transformer.layers.21.norm.weight": "model-00001-of-00002.safetensors",
100
+ "transformer.layers.22.attention.out_proj.weight": "model-00001-of-00002.safetensors",
101
+ "transformer.layers.22.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
102
+ "transformer.layers.22.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
103
+ "transformer.layers.22.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
104
+ "transformer.layers.22.norm.bias": "model-00001-of-00002.safetensors",
105
+ "transformer.layers.22.norm.weight": "model-00001-of-00002.safetensors",
106
+ "transformer.layers.23.attention.out_proj.weight": "model-00001-of-00002.safetensors",
107
+ "transformer.layers.23.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
108
+ "transformer.layers.23.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
109
+ "transformer.layers.23.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
110
+ "transformer.layers.23.norm.bias": "model-00001-of-00002.safetensors",
111
+ "transformer.layers.23.norm.weight": "model-00001-of-00002.safetensors",
112
+ "transformer.layers.24.attention.out_proj.weight": "model-00001-of-00002.safetensors",
113
+ "transformer.layers.24.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
114
+ "transformer.layers.24.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
115
+ "transformer.layers.24.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
116
+ "transformer.layers.24.norm.bias": "model-00001-of-00002.safetensors",
117
+ "transformer.layers.24.norm.weight": "model-00001-of-00002.safetensors",
118
+ "transformer.layers.25.attention.out_proj.weight": "model-00001-of-00002.safetensors",
119
+ "transformer.layers.25.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
120
+ "transformer.layers.25.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
121
+ "transformer.layers.25.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
122
+ "transformer.layers.25.norm.bias": "model-00001-of-00002.safetensors",
123
+ "transformer.layers.25.norm.weight": "model-00001-of-00002.safetensors",
124
+ "transformer.layers.26.attention.out_proj.weight": "model-00001-of-00002.safetensors",
125
+ "transformer.layers.26.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
126
+ "transformer.layers.26.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
127
+ "transformer.layers.26.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
128
+ "transformer.layers.26.norm.bias": "model-00001-of-00002.safetensors",
129
+ "transformer.layers.26.norm.weight": "model-00001-of-00002.safetensors",
130
+ "transformer.layers.27.attention.out_proj.weight": "model-00001-of-00002.safetensors",
131
+ "transformer.layers.27.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
132
+ "transformer.layers.27.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
133
+ "transformer.layers.27.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
134
+ "transformer.layers.27.norm.bias": "model-00001-of-00002.safetensors",
135
+ "transformer.layers.27.norm.weight": "model-00001-of-00002.safetensors",
136
+ "transformer.layers.28.attention.out_proj.weight": "model-00001-of-00002.safetensors",
137
+ "transformer.layers.28.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
138
+ "transformer.layers.28.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
139
+ "transformer.layers.28.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
140
+ "transformer.layers.28.norm.bias": "model-00001-of-00002.safetensors",
141
+ "transformer.layers.28.norm.weight": "model-00001-of-00002.safetensors",
142
+ "transformer.layers.29.attention.out_proj.weight": "model-00001-of-00002.safetensors",
143
+ "transformer.layers.29.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
144
+ "transformer.layers.29.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
145
+ "transformer.layers.29.mlp.out_proj.weight": "model-00002-of-00002.safetensors",
146
+ "transformer.layers.29.norm.bias": "model-00001-of-00002.safetensors",
147
+ "transformer.layers.29.norm.weight": "model-00001-of-00002.safetensors",
148
+ "transformer.layers.3.attention.out_proj.weight": "model-00001-of-00002.safetensors",
149
+ "transformer.layers.3.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
150
+ "transformer.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
151
+ "transformer.layers.3.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
152
+ "transformer.layers.3.norm.bias": "model-00001-of-00002.safetensors",
153
+ "transformer.layers.3.norm.weight": "model-00001-of-00002.safetensors",
154
+ "transformer.layers.30.attention.out_proj.weight": "model-00002-of-00002.safetensors",
155
+ "transformer.layers.30.attention.qkv_proj.weight": "model-00002-of-00002.safetensors",
156
+ "transformer.layers.30.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
157
+ "transformer.layers.30.mlp.out_proj.weight": "model-00002-of-00002.safetensors",
158
+ "transformer.layers.30.norm.bias": "model-00002-of-00002.safetensors",
159
+ "transformer.layers.30.norm.weight": "model-00002-of-00002.safetensors",
160
+ "transformer.layers.31.attention.out_proj.weight": "model-00002-of-00002.safetensors",
161
+ "transformer.layers.31.attention.qkv_proj.weight": "model-00002-of-00002.safetensors",
162
+ "transformer.layers.31.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
163
+ "transformer.layers.31.mlp.out_proj.weight": "model-00002-of-00002.safetensors",
164
+ "transformer.layers.31.norm.bias": "model-00002-of-00002.safetensors",
165
+ "transformer.layers.31.norm.weight": "model-00002-of-00002.safetensors",
166
+ "transformer.layers.4.attention.out_proj.weight": "model-00001-of-00002.safetensors",
167
+ "transformer.layers.4.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
168
+ "transformer.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
169
+ "transformer.layers.4.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
170
+ "transformer.layers.4.norm.bias": "model-00001-of-00002.safetensors",
171
+ "transformer.layers.4.norm.weight": "model-00001-of-00002.safetensors",
172
+ "transformer.layers.5.attention.out_proj.weight": "model-00001-of-00002.safetensors",
173
+ "transformer.layers.5.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
174
+ "transformer.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
175
+ "transformer.layers.5.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
176
+ "transformer.layers.5.norm.bias": "model-00001-of-00002.safetensors",
177
+ "transformer.layers.5.norm.weight": "model-00001-of-00002.safetensors",
178
+ "transformer.layers.6.attention.out_proj.weight": "model-00001-of-00002.safetensors",
179
+ "transformer.layers.6.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
180
+ "transformer.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
181
+ "transformer.layers.6.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
182
+ "transformer.layers.6.norm.bias": "model-00001-of-00002.safetensors",
183
+ "transformer.layers.6.norm.weight": "model-00001-of-00002.safetensors",
184
+ "transformer.layers.7.attention.out_proj.weight": "model-00001-of-00002.safetensors",
185
+ "transformer.layers.7.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
186
+ "transformer.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
187
+ "transformer.layers.7.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
188
+ "transformer.layers.7.norm.bias": "model-00001-of-00002.safetensors",
189
+ "transformer.layers.7.norm.weight": "model-00001-of-00002.safetensors",
190
+ "transformer.layers.8.attention.out_proj.weight": "model-00001-of-00002.safetensors",
191
+ "transformer.layers.8.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
192
+ "transformer.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
193
+ "transformer.layers.8.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
194
+ "transformer.layers.8.norm.bias": "model-00001-of-00002.safetensors",
195
+ "transformer.layers.8.norm.weight": "model-00001-of-00002.safetensors",
196
+ "transformer.layers.9.attention.out_proj.weight": "model-00001-of-00002.safetensors",
197
+ "transformer.layers.9.attention.qkv_proj.weight": "model-00001-of-00002.safetensors",
198
+ "transformer.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
199
+ "transformer.layers.9.mlp.out_proj.weight": "model-00001-of-00002.safetensors",
200
+ "transformer.layers.9.norm.bias": "model-00001-of-00002.safetensors",
201
+ "transformer.layers.9.norm.weight": "model-00001-of-00002.safetensors"
202
+ }
203
+ }
modeling_stablelm_alpha.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # This code is based off the following work:
17
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
18
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
19
+ """ PyTorch StableLM-Alpha model. """
20
+ from typing import Optional, Tuple, Union
21
+ import math
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import logging
33
+
34
+ from .configuration_stablelm_alpha import StableLMAlphaConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
41
+ """Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
42
+ batch_size, src_len = mask.size()
43
+ tgt_len = tgt_len if tgt_len is not None else src_len
44
+
45
+ expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
46
+ inverted_mask = 1.0 - expanded_mask
47
+
48
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
49
+
50
+
51
+ class LayerNorm(nn.LayerNorm):
52
+ def __init__(self, normalized_shape: torch.Size, bias: bool = True, **kwargs):
53
+ r"""
54
+ bias (`bool`, default = True): whether to use the bias term.
55
+ """
56
+ super().__init__(normalized_shape, **kwargs)
57
+ if not bias:
58
+ self.bias = None
59
+
60
+
61
+ class DecoderLayer(nn.Module):
62
+ def __init__(self, config: StableLMAlphaConfig):
63
+ super().__init__()
64
+
65
+ self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
66
+ self.attention = Attention(config)
67
+ self.mlp = MLP(config)
68
+
69
+ def forward(
70
+ self,
71
+ hidden_states: Optional[torch.FloatTensor],
72
+ attention_mask: Optional[torch.FloatTensor] = None,
73
+ position_ids: Optional[torch.LongTensor] = None,
74
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
75
+ output_attentions: Optional[bool] = False,
76
+ use_cache: Optional[bool] = False,
77
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
78
+ residual = hidden_states
79
+
80
+ # Pre-Norm
81
+ hidden_states = self.norm(hidden_states)
82
+
83
+ # Self-Attention
84
+ attn_output, attn_weights, present_key_value = self.attention(
85
+ hidden_states=hidden_states,
86
+ attention_mask=attention_mask,
87
+ position_ids=position_ids,
88
+ past_key_value=past_key_value,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ )
92
+
93
+ # Feed-forward
94
+ mlp_output = self.mlp(hidden_states)
95
+
96
+ hidden_states = residual + attn_output + mlp_output
97
+
98
+ outputs = (hidden_states,)
99
+ if output_attentions:
100
+ outputs += (attn_weights,)
101
+ if use_cache:
102
+ outputs += (present_key_value,)
103
+ return outputs # hidden_states, (optional: attn_weights), (optional: present_key_value)
104
+
105
+
106
+ class MLP(nn.Module):
107
+ def __init__(self, config: StableLMAlphaConfig):
108
+ super().__init__()
109
+
110
+ hidden_size = config.hidden_size
111
+ multiple_of = 256
112
+ ff_dim = int(8 * hidden_size / 3)
113
+ intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
114
+
115
+ self.gate_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
116
+ self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
117
+ self.act = nn.SiLU()
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ ff, ff_gate = self.gate_proj(x).chunk(2, dim=-1)
121
+ return self.out_proj(ff * self.act(ff_gate))
122
+
123
+
124
+ class RotaryEmbedding(nn.Module):
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ max_position_embeddings: int,
129
+ base: int = 10_000,
130
+ device: Optional[torch.device] = None,
131
+ ):
132
+ super().__init__()
133
+
134
+ self.dim = dim
135
+ self.max_position_embeddings = max_position_embeddings
136
+ self.base = base
137
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
138
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
139
+
140
+ # Build here to make `torch.jit.trace` work.
141
+ self._set_cos_sin_cache(
142
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
143
+ )
144
+
145
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
146
+ self.max_seq_len_cached = seq_len
147
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
148
+ freqs = torch.outer(t, self.inv_freq)
149
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
150
+ emb = torch.cat((freqs, freqs), dim=-1)
151
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
152
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
153
+
154
+ def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
155
+ # x: [batch_size, num_heads, seq_len, head_size]
156
+ if seq_len > self.max_seq_len_cached:
157
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
158
+ return (
159
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
160
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
161
+ )
162
+
163
+
164
+ def rotate_half(x: torch.Tensor):
165
+ """Rotates half the hidden dims of the input."""
166
+ x1, x2 = torch.chunk(x, 2, dim=-1)
167
+ return torch.cat((-x2, x1), dim=-1)
168
+
169
+
170
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
171
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
172
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
173
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
174
+ cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
175
+ sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
176
+ q_embed = (q * cos) + (rotate_half(q) * sin)
177
+ k_embed = (k * cos) + (rotate_half(k) * sin)
178
+ return q_embed, k_embed
179
+
180
+
181
+ class Attention(nn.Module):
182
+ def __init__(self, config: StableLMAlphaConfig):
183
+ super().__init__()
184
+
185
+ self.config = config
186
+ self.hidden_size = config.hidden_size
187
+ self.num_heads = config.num_heads
188
+ self.head_dim = self.hidden_size // self.num_heads
189
+ self.max_position_embeddings = config.max_position_embeddings
190
+ if self.hidden_size % self.num_heads != 0:
191
+ raise ValueError(
192
+ "`hidden_size` is not divisble by the number of attention heads! Make sure to update them"
193
+ )
194
+
195
+ self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
196
+ self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
197
+ self._init_rope()
198
+
199
+ def _init_rope(self):
200
+ self.rotary_ndims = int(self.head_dim * self.config.rotary_pct)
201
+ self.rotary_emb = RotaryEmbedding(
202
+ self.rotary_ndims,
203
+ max_position_embeddings=self.config.max_position_embeddings,
204
+ base=self.config.rotary_emb_base,
205
+ )
206
+
207
+ def forward(
208
+ self,
209
+ hidden_states: torch.FloatTensor,
210
+ attention_mask: torch.FloatTensor,
211
+ position_ids: torch.LongTensor,
212
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
213
+ output_attentions: Optional[bool] = False,
214
+ use_cache: Optional[bool] = False,
215
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
216
+ has_past_key_value = past_key_value is not None
217
+
218
+ # Compute QKV
219
+ # [batch_size, seq_len, (num_heads * 3 * head_dim)]
220
+ qkv = self.qkv_proj(hidden_states)
221
+
222
+ # [batch_size, seq_len, num_heads, 3 * head_dim]
223
+ new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
224
+ qkv = qkv.view(*new_qkv_shape)
225
+
226
+ # 3 * [batch_size, num_heads, seq_len, head_dim]
227
+ query = qkv[..., : self.head_dim].permute(0, 2, 1, 3)
228
+ key = qkv[..., self.head_dim:(2 * self.head_dim)].permute(0, 2, 1, 3)
229
+ value = qkv[..., (2 * self.head_dim):].permute(0, 2, 1, 3)
230
+
231
+ # Compute rotary embeddings on rotary_ndims
232
+ # [batch_size, num_heads, seq_len, rotary_ndims]
233
+ query_rot = query[..., :self.rotary_ndims]
234
+ query_pass = query[..., self.rotary_ndims:]
235
+ key_rot = key[..., :self.rotary_ndims]
236
+ key_pass = key[..., self.rotary_ndims:]
237
+
238
+ # Compute token offset for rotary embeddings (when decoding)
239
+ kv_seq_len = key.shape[-2]
240
+ if has_past_key_value:
241
+ kv_seq_len += past_key_value[0].shape[-2]
242
+
243
+ # Add rotary embeddings to query and key
244
+ cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
245
+ query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
246
+
247
+ # Concatenate rotary embeddings with pass-through query and key
248
+ # [batch_size, num_heads, seq_len, head_dim]
249
+ query = torch.cat((query, query_pass), dim=-1)
250
+ key = torch.cat((key, key_pass), dim=-1)
251
+
252
+ # Reuse past key-value states
253
+ if has_past_key_value:
254
+ key = torch.cat((past_key_value[0], key), dim=2)
255
+ value = torch.cat((past_key_value[1], value), dim=2)
256
+ present_key_value = (key, value) if use_cache else None
257
+
258
+ # [batch_size, num_heads, seq_len, head_dim]
259
+ query = query.transpose(1, 2).contiguous()
260
+ key = key.transpose(1, 2).contiguous()
261
+ value = value.transpose(1, 2).contiguous()
262
+
263
+ # Compute attention
264
+ softmax_scale = 1 / math.sqrt(self.head_dim)
265
+ attn_scores = torch.einsum('bthd,bshd->bhts', query, key * softmax_scale)
266
+ # Apply the attention mask
267
+ if attention_mask is not None:
268
+ attn_scores = attn_scores + attention_mask
269
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
270
+ attn_output = torch.einsum('bhts,bshd->bthd', attn_weights, value)
271
+
272
+ # Merge heads
273
+ attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1)
274
+
275
+ # Final linear projection
276
+ attn_output = self.out_proj(attn_output)
277
+
278
+ if not output_attentions:
279
+ attn_weights = None
280
+
281
+ return attn_output, attn_weights, present_key_value
282
+
283
+
284
+ def attention_mask_func(attention_scores: torch.Tensor, ltor_mask: torch.Tensor):
285
+ attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
286
+ return attention_scores
287
+
288
+
289
+ class StableLMAlphaPreTrainedModel(PreTrainedModel):
290
+ """An abstract class to handle weights initialization and a simple interface
291
+ for downloading and loading pretrained models.
292
+ """
293
+
294
+ config_class = StableLMAlphaConfig
295
+ base_model_prefix = "transformer"
296
+ supports_gradient_checkpointing = True
297
+ _no_split_modules = ["DecoderLayer"]
298
+ _skip_keys_device_placement = "past_key_values"
299
+
300
+ def _init_weights(self, module: nn.Module):
301
+ """Initialize the weights"""
302
+ if isinstance(module, nn.Linear):
303
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
304
+ if module.bias is not None:
305
+ module.bias.data.zero_()
306
+ elif isinstance(module, nn.Embedding):
307
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
308
+ if module.padding_idx is not None:
309
+ module.weight.data[module.padding_idx].zero_()
310
+ elif isinstance(module, nn.LayerNorm):
311
+ module.bias.data.zero_()
312
+ module.weight.data.fill_(1.0)
313
+
314
+ def _set_gradient_checkpointing(self, module: nn.Module, value=False):
315
+ if isinstance(module, StableLMAlphaModel):
316
+ module.gradient_checkpointing = value
317
+
318
+
319
+ def _make_causal_mask(
320
+ input_ids_shape: torch.Size,
321
+ dtype: torch.dtype,
322
+ device: torch.device,
323
+ past_key_values_length: int = 0
324
+ ):
325
+ """Make causal mask used for bi-directional self-attention."""
326
+ batch_size, tgt_len = input_ids_shape
327
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
328
+ mask_cond = torch.arange(mask.size(-1), device=device)
329
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
330
+ mask = mask.to(dtype)
331
+ if past_key_values_length > 0:
332
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
333
+ return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
334
+
335
+
336
+ class StableLMAlphaModel(StableLMAlphaPreTrainedModel):
337
+ def __init__(self, config: StableLMAlphaConfig):
338
+ super().__init__(config)
339
+ self.config = config
340
+
341
+ self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
342
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
343
+ self.final_norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
344
+
345
+ self.gradient_checkpointing = False
346
+ self.post_init()
347
+
348
+ def get_input_embeddings(self):
349
+ return self.embed
350
+
351
+ def set_input_embeddings(self, value: nn.Module):
352
+ self.embed = value
353
+
354
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
355
+ def _prepare_decoder_attention_mask(
356
+ self,
357
+ attention_mask: torch.Tensor,
358
+ input_shape: torch.Size,
359
+ inputs_embeds: torch.Tensor,
360
+ past_key_values_length: int,
361
+ ):
362
+ # Create causal mask
363
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
364
+ combined_attention_mask = None
365
+ if input_shape[-1] > 1:
366
+ combined_attention_mask = _make_causal_mask(
367
+ input_shape,
368
+ inputs_embeds.dtype,
369
+ device=inputs_embeds.device,
370
+ past_key_values_length=past_key_values_length,
371
+ )
372
+
373
+ if attention_mask is not None:
374
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
375
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
376
+ inputs_embeds.device
377
+ )
378
+ combined_attention_mask = (
379
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
380
+ )
381
+
382
+ return combined_attention_mask
383
+
384
+ def forward(
385
+ self,
386
+ input_ids: Optional[torch.LongTensor] = None,
387
+ attention_mask: Optional[torch.FloatTensor] = None,
388
+ position_ids: Optional[torch.LongTensor] = None,
389
+ inputs_embeds: Optional[torch.FloatTensor] = None,
390
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
391
+ use_cache: Optional[bool] = None,
392
+ output_attentions: Optional[bool] = None,
393
+ output_hidden_states: Optional[bool] = None,
394
+ return_dict: Optional[bool] = None,
395
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
396
+ r"""
397
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`
398
+ with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
399
+ Contains precomputed key and value hidden states of the attention blocks.
400
+ Can be used to speed up decoding. If `past_key_values` are used, the user
401
+ can optionally input only the last `decoder_input_ids` (those that don't
402
+ have their past key value states given to this model) of shape `(batch_size, 1)`
403
+ instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
404
+ use_cache (`bool`, *optional*):
405
+ If set to `True`, `past_key_values` key value states are returned and
406
+ can be used to speed up decoding (see `past_key_values`).
407
+ """
408
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
409
+ output_hidden_states = (
410
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
411
+ )
412
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
413
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
414
+
415
+ if input_ids is not None and inputs_embeds is not None:
416
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
417
+ elif input_ids is not None:
418
+ input_shape = input_ids.size()
419
+ elif inputs_embeds is not None:
420
+ input_shape = inputs_embeds.size()[:-1]
421
+ else:
422
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
423
+
424
+ batch_size, seq_length = input_shape
425
+
426
+ if past_key_values is None:
427
+ past_key_values_length = 0
428
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
429
+ seq_length_with_past = seq_length
430
+ else:
431
+ past_key_values_length = past_key_values[0][0].shape[2]
432
+ seq_length_with_past = seq_length + past_key_values_length
433
+
434
+ if position_ids is None:
435
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
436
+ position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
437
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
438
+ else:
439
+ position_ids = position_ids.view(-1, seq_length).long()
440
+
441
+ if inputs_embeds is None:
442
+ inputs_embeds = self.embed(input_ids)
443
+
444
+ # Attention mask.
445
+ if attention_mask is None:
446
+ attention_mask = torch.ones(
447
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
448
+ )
449
+ attention_mask = self._prepare_decoder_attention_mask(
450
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
451
+ )
452
+
453
+ hidden_states = inputs_embeds
454
+
455
+ if self.gradient_checkpointing and self.training:
456
+ if use_cache:
457
+ logger.warning(
458
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
459
+ )
460
+ use_cache = False
461
+
462
+ all_hidden_states = () if output_hidden_states else None
463
+ all_attentions = () if output_attentions else None
464
+ present_key_values = () if use_cache else None
465
+
466
+ for _, (decoder_layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
467
+ if output_hidden_states:
468
+ all_hidden_states = all_hidden_states + (hidden_states,)
469
+
470
+ if self.gradient_checkpointing and self.training:
471
+
472
+ def create_custom_forward(module):
473
+ def custom_forward(*inputs):
474
+ # `None` for `use_cache`
475
+ return module(*inputs, output_attentions, None)
476
+
477
+ return custom_forward
478
+
479
+ outputs = torch.utils.checkpoint.checkpoint(
480
+ create_custom_forward(decoder_layer),
481
+ hidden_states,
482
+ attention_mask,
483
+ position_ids,
484
+ # `None` for `past_key_value`
485
+ None,
486
+ )
487
+ else:
488
+ outputs = decoder_layer(
489
+ hidden_states,
490
+ attention_mask=attention_mask,
491
+ position_ids=position_ids,
492
+ past_key_value=past_key_value,
493
+ output_attentions=output_attentions,
494
+ use_cache=use_cache,
495
+ )
496
+
497
+ hidden_states = outputs[0]
498
+
499
+ if output_attentions:
500
+ all_attentions = all_attentions + (outputs[1],)
501
+
502
+ if use_cache:
503
+ present_key_values += (outputs[2 if output_attentions else 1],)
504
+
505
+ hidden_states = self.final_norm(hidden_states)
506
+
507
+ # Add last hidden state
508
+ if output_hidden_states:
509
+ all_hidden_states += (hidden_states,)
510
+
511
+ present_key_values = present_key_values if use_cache else None
512
+ if not return_dict:
513
+ return tuple(v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None)
514
+
515
+ return BaseModelOutputWithPast(
516
+ last_hidden_state=hidden_states,
517
+ past_key_values=present_key_values,
518
+ hidden_states=all_hidden_states,
519
+ attentions=all_attentions,
520
+ )
521
+
522
+
523
+ class StableLMAlphaForCausalLM(StableLMAlphaPreTrainedModel):
524
+ _tied_weights_keys = ["lm_head.weight"]
525
+
526
+ def __init__(self, config: StableLMAlphaConfig):
527
+ super().__init__(config)
528
+
529
+ self.transformer = StableLMAlphaModel(config)
530
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
531
+
532
+ self.post_init()
533
+
534
+ def get_output_embeddings(self):
535
+ return self.lm_head
536
+
537
+ def set_output_embeddings(self, new_embeddings: nn.Module):
538
+ self.lm_head = new_embeddings
539
+
540
+ def forward(
541
+ self,
542
+ input_ids: Optional[torch.LongTensor] = None,
543
+ attention_mask: Optional[torch.FloatTensor] = None,
544
+ position_ids: Optional[torch.LongTensor] = None,
545
+ inputs_embeds: Optional[torch.FloatTensor] = None,
546
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
547
+ labels: Optional[torch.LongTensor] = None,
548
+ use_cache: Optional[bool] = None,
549
+ output_attentions: Optional[bool] = None,
550
+ output_hidden_states: Optional[bool] = None,
551
+ return_dict: Optional[bool] = None,
552
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
553
+ r"""
554
+ Example:
555
+
556
+ ```python
557
+ >>> from transformers import AutoTokenizer, StableLMAlphaForCausalLM, StableLMAlphaConfig
558
+ >>> import torch
559
+
560
+ >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2", trust_remote_code=True)
561
+ >>> config = StableLMAlphaConfig.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2")
562
+ >>> config.is_decoder = True
563
+ >>> model = StableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2", config=config)
564
+
565
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
566
+ >>> outputs = model(**inputs)
567
+
568
+ >>> logits = outputs.logits
569
+ ```
570
+ """
571
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
572
+
573
+ outputs = self.transformer(
574
+ input_ids,
575
+ attention_mask=attention_mask,
576
+ position_ids=position_ids,
577
+ inputs_embeds=inputs_embeds,
578
+ past_key_values=past_key_values,
579
+ use_cache=use_cache,
580
+ output_attentions=output_attentions,
581
+ output_hidden_states=output_hidden_states,
582
+ return_dict=return_dict,
583
+ )
584
+
585
+ hidden_states = outputs[0]
586
+ logits = self.lm_head(hidden_states)
587
+
588
+ lm_loss = None
589
+ if labels is not None:
590
+ # move labels to correct device to enable model parallelism
591
+ labels = labels.to(logits.device)
592
+ # we are doing next-token prediction; shift prediction scores and input ids by one
593
+ shift_logits = logits[:, :-1, :].contiguous()
594
+ labels = labels[:, 1:].contiguous()
595
+ loss_fct = CrossEntropyLoss()
596
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
597
+
598
+ if not return_dict:
599
+ output = (logits,) + outputs[1:]
600
+ return ((lm_loss,) + output) if lm_loss is not None else output
601
+
602
+ return CausalLMOutputWithPast(
603
+ loss=lm_loss,
604
+ logits=logits,
605
+ past_key_values=outputs.past_key_values,
606
+ hidden_states=outputs.hidden_states,
607
+ attentions=outputs.attentions,
608
+ )
609
+
610
+ def prepare_inputs_for_generation(
611
+ self,
612
+ input_ids,
613
+ past_key_values: Optional[torch.Tensor] = None,
614
+ attention_mask: Optional[torch.Tensor] = None,
615
+ inputs_embeds: Optional[torch.Tensor] = None,
616
+ **kwargs
617
+ ):
618
+ # Cut decoder_input_ids if past is used
619
+ if past_key_values and past_key_values[0] is not None:
620
+ input_ids = input_ids[:, -1:]
621
+
622
+ position_ids = kwargs.get("position_ids", None)
623
+ if attention_mask is not None and position_ids is None:
624
+ # Create position_ids on the fly for batch generation
625
+ position_ids = attention_mask.long().cumsum(-1) - 1
626
+ position_ids.masked_fill_(attention_mask == 0, 1)
627
+ if past_key_values:
628
+ position_ids = position_ids[:, -1].unsqueeze(-1)
629
+
630
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
631
+ if inputs_embeds is not None and past_key_values is None:
632
+ model_inputs = {"inputs_embeds": inputs_embeds}
633
+ else:
634
+ model_inputs = {"input_ids": input_ids}
635
+
636
+ model_inputs.update(
637
+ {
638
+ "attention_mask": attention_mask,
639
+ "past_key_values": past_key_values,
640
+ "position_ids": position_ids,
641
+ }
642
+ )
643
+
644
+ return model_inputs
645
+
646
+ def _reorder_cache(self, past_key_values: torch.Tensor, beam_idx: int):
647
+ reordered_past = ()
648
+ for past_key_value in past_key_values:
649
+ reordered_past += (
650
+ tuple(past_state.index_select(0, beam_idx) for past_state in past_key_value[:2]) + past_key_value[2:],
651
+ )
652
+ return reordered_past
653
+
654
+
655
+ StableLMAlphaConfig.register_for_auto_class()
656
+ StableLMAlphaForCausalLM.register_for_auto_class("AutoModelForCausalLM")