Text Generation
Transformers
Safetensors
English
stablelm
causal-lm
conversational
Inference Endpoints
pvduy committed on
Commit
5dd3c8f
1 Parent(s): 26f37f2

Upload StableLMEpochForCausalLM

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "_name_or_path": "StableLM 2 12B Chat",
3
+ "architectures": [
4
+ "StableLMEpochForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_stablelm_epoch.StableLMEpochConfig",
9
+ "AutoModelForCausalLM": "modeling_stablelm_epoch_2.StableLMEpochForCausalLM"
10
+ },
11
+ "bos_token_id": null,
12
+ "eos_token_id": 100257,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 5120,
15
+ "initializer_range": 0.01,
16
+ "intermediate_size": 13824,
17
+ "max_position_embeddings": 4096,
18
+ "model_type": "stablelm_epoch",
19
+ "norm_eps": 1e-05,
20
+ "num_attention_heads": 32,
21
+ "num_hidden_layers": 40,
22
+ "num_key_value_heads": 8,
23
+ "rope_pct": 0.25,
24
+ "rope_theta": 10000,
25
+ "rotary_scaling_factor": 1.0,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.39.3",
29
+ "use_cache": true,
30
+ "use_norm_bias": false,
31
+ "use_qkv_bias": false,
32
+ "vocab_size": 100352
33
+ }
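For orientation, the attention geometry implied by these values can be derived directly: each of the 32 query heads spans hidden_size / num_attention_heads = 5120 / 32 = 160 dimensions, rotary embeddings cover rope_pct × head_dim = 0.25 × 160 = 40 of them, and the 8 key/value heads are shared across groups of 32 / 8 = 4 query heads (grouped-query attention). A small sketch of the arithmetic (not part of the repository):

```python
# Derived attention geometry for the values in config.json.
hidden_size = 5120
num_attention_heads = 32
num_key_value_heads = 8
rope_pct = 0.25

head_dim = hidden_size // num_attention_heads                       # 160
rotary_ndims = int(head_dim * rope_pct)                             # 40 dims rotated by RoPE per head
num_key_value_groups = num_attention_heads // num_key_value_heads   # 4 query heads per KV head

print(head_dim, rotary_ndims, num_key_value_groups)                 # 160 40 4
```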
configuration_stablelm_epoch.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ StableLM Epoch model configuration"""
15
+ from transformers import PretrainedConfig
16
+ from transformers.utils import logging
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class StableLMEpochConfig(PretrainedConfig):
23
+ r"""
24
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
25
+ documentation from [`PretrainedConfig`] for more information.
26
+
27
+ Args:
28
+ vocab_size (`int`, *optional*, defaults to 50_304):
29
+ Vocabulary size of the StableLM model. Defines the number of different tokens that
30
+ can be represented by the `inputs_ids` passed when calling [`StableLMEpochModel`].
31
+ intermediate_size (`int`, *optional*, defaults to 6912):
32
+ Dimension of the MLP representations.
33
+ hidden_size (`int`, *optional*, defaults to 2560):
34
+ Dimension of the decoder layers and the pooler layer.
35
+ num_hidden_layers (`int`, *optional*, defaults to 32):
36
+ Number of hidden layers in the Transformer decoder.
37
+ num_attention_heads (`int`, *optional*, defaults to 32):
38
+ Number of attention heads for each attention layer in the Transformer encoder.
39
+ num_key_value_heads (`int`, *optional*):
40
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
41
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
42
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
43
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
44
+ by mean-pooling all the original heads within that group. For more details, check out [this
45
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
46
+ `num_attention_heads`.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
48
+ The non-linear activation function (function or string).
49
+ rope_pct (`float`, *optional*, defaults to 0.25):
50
+ Percentage of hidden dimensions to allocate to rotary embeddings.
51
+ rope_theta (`float`, *optional*, defaults to 10000.0):
52
+ The base period of the RoPE embeddings.
53
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
54
+ The maximum sequence length that this model might ever be used with.
55
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
56
+ initializer_range (`float`, *optional*, defaults to 0.02):
57
+ The standard deviation of the truncated_normal_initializer for initializing
58
+ all weight matrices.
59
+ norm_eps (`float`, *optional*, defaults to 1e-5):
60
+ The epsilon used by the normalization layers.
61
+ use_cache (`bool`, *optional*, defaults to `True`):
62
+ Whether or not the model should return the last key/values attentions
63
+ (not used by all models). Only relevant if `config.is_decoder=True`.
64
+ use_qkv_bias (`bool`, *optional*, defaults to `True`):
65
+ Whether or not the model should use bias for qkv layers.
66
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
67
+ Whether to tie weight embeddings
68
+ attention_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for the attention probabilities.
70
+ """
71
+ model_type = "stablelm_epoch"
72
+ keys_to_ignore_at_inference = ["past_key_values"]
73
+
74
+ def __init__(
75
+ self,
76
+ vocab_size=50_304,
77
+ intermediate_size=6912,
78
+ hidden_size=2560,
79
+ num_hidden_layers=32,
80
+ num_attention_heads=32,
81
+ num_key_value_heads=32,
82
+ hidden_act="silu",
83
+ rope_pct=0.25,
84
+ rope_theta=10_000,
85
+ max_position_embeddings=4096,
86
+ initializer_range=0.02,
87
+ norm_eps=1.0e-5,
88
+ use_cache=True,
89
+ use_qkv_bias=True,
90
+ bos_token_id=0,
91
+ eos_token_id=2,
92
+ tie_word_embeddings=False,
93
+ attention_dropout: float = 0.0,
94
+ **kwargs,
95
+ ):
96
+ self.vocab_size = vocab_size
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.intermediate_size = intermediate_size
99
+ self.hidden_size = hidden_size
100
+ self.num_hidden_layers = num_hidden_layers
101
+ self.num_attention_heads = num_attention_heads
102
+ self.num_key_value_heads = num_key_value_heads
103
+ self.hidden_act = hidden_act
104
+ self.rope_pct = rope_pct
105
+ self.rope_theta = rope_theta
106
+ self.initializer_range = initializer_range
107
+ self.norm_eps = norm_eps
108
+ self.use_cache = use_cache
109
+ self.use_qkv_bias = use_qkv_bias
110
+ self.tie_word_embeddings = tie_word_embeddings
111
+ self.attention_dropout = attention_dropout
112
+ super().__init__(
113
+ bos_token_id=bos_token_id,
114
+ eos_token_id=eos_token_id,
115
+ tie_word_embeddings=tie_word_embeddings,
116
+ **kwargs,
117
+ )
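As a quick sanity check, the configuration class above can be instantiated standalone with the values shipped in config.json. A hedged sketch, assuming it is run from a local clone of the repository so the module is importable:

```python
# Sketch: building StableLMEpochConfig with this repository's config.json values.
from configuration_stablelm_epoch import StableLMEpochConfig

config = StableLMEpochConfig(
    vocab_size=100352,
    hidden_size=5120,
    intermediate_size=13824,
    num_hidden_layers=40,
    num_attention_heads=32,
    num_key_value_heads=8,
    rope_pct=0.25,
    rope_theta=10_000,
    max_position_embeddings=4096,
    initializer_range=0.01,
    norm_eps=1e-5,
    use_qkv_bias=False,
    bos_token_id=None,
    eos_token_id=100257,
    tie_word_embeddings=False,
    use_norm_bias=False,  # not a named argument above; stored via **kwargs and read by DecoderLayer
)
print(config.model_type)  # "stablelm_epoch"
```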
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 100257,
4
+ "transformers_version": "4.39.3",
5
+ "use_cache": false
6
+ }
model-00001-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e928474815f93d595c6b202fbd2b698d89d06238b6cb3f92226b8eda5cf618b1
3
+ size 4823584128
model-00002-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3405704ac8dab6b1e250e82ce5a5df05afa57250a2d23e86bf4ff2aac90f5e08
3
+ size 4729284832
model-00003-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5eac27302b85369068e71463fc80c19424618681d44c7a629e9191cb5df8b76b
3
+ size 4991480240
model-00004-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54989e02a20e6d1928ffc9637e09c43704a2c10aec0ed99f3cea4cc109c1c0eb
3
+ size 4729285024
model-00005-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e873e44844c5ab8dc8300eabff6a4bd9eccba65a8fea313ce82544cd891cff0
3
+ size 4729285024
model-00006-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbb6a0fe1129991074b78cab0204cd069abfd8ee476cf251f2db362633142e25
3
+ size 4991480472
model-00007-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae450cf52a7623d127684f9f3afef5ede510e8e7c47e1a54d21df9d108991de3
3
+ size 4729285024
model-00008-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb97d7f42b1f06e5cb5984b7878e71d7b56dad9097852b63c7b31d4476c33eb6
3
+ size 4729285024
model-00009-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a0112031177f788e3cbd0f161f39d402768606371d559212edea13f7459cd82
3
+ size 4991480472
model-00010-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca3662c80ae96779d7e429df20bbd46d2ed14f54e3ca50de22a4218e281cd23d
3
+ size 3072493336
model-00011-of-00011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f65d17008fc4840405f0a73c9762a9c8014bd5cbd32439ad5912e42a25dff1d8
3
+ size 2055209088
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_stablelm_epoch_2.py ADDED
@@ -0,0 +1,1085 @@
1
+ # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This code is based on the following work:
16
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
17
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
18
+ """ PyTorch StableLM Epoch model. """
19
+ from typing import Optional, Tuple, Union
20
+ import math
21
+ import warnings
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+
29
+ from transformers.cache_utils import Cache
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
36
+
37
+ from .configuration_stablelm_epoch import StableLMEpochConfig
38
+
39
+ try:
40
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
41
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
42
+ except ImportError:
43
+ flash_attn_func, flash_attn_varlen_func = None, None
44
+ index_first_axis, pad_input, unpad_input = None, None, None
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
51
+ def _get_unpad_data(attention_mask):
52
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
53
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
54
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
55
+ cu_seqlens = F.pad(
56
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
57
+ )
58
+ return (
59
+ indices,
60
+ cu_seqlens,
61
+ max_seqlen_in_batch,
62
+ )
63
+
64
+
65
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
66
+ def _make_causal_mask(
67
+ input_ids_shape: torch.Size,
68
+ dtype: torch.dtype,
69
+ device: torch.device,
70
+ past_key_values_length: int = 0,
71
+ ):
72
+ """Make causal mask used for bi-directional self-attention."""
73
+ batch_size, tgt_len = input_ids_shape
74
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
75
+ mask_cond = torch.arange(mask.size(-1), device=device)
76
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
77
+ mask = mask.to(dtype)
78
+ if past_key_values_length > 0:
79
+ mask = torch.cat(
80
+ [
81
+ torch.zeros(
82
+ tgt_len, past_key_values_length, dtype=dtype, device=device
83
+ ),
84
+ mask,
85
+ ],
86
+ dim=-1,
87
+ )
88
+ return mask[None, None, :, :].expand(
89
+ batch_size, 1, tgt_len, tgt_len + past_key_values_length
90
+ )
91
+
92
+
93
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
94
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
95
+ """Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
96
+ batch_size, src_len = mask.size()
97
+ tgt_len = tgt_len if tgt_len is not None else src_len
98
+
99
+ expanded_mask = (
100
+ mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
101
+ )
102
+ inverted_mask = 1.0 - expanded_mask
103
+
104
+ return inverted_mask.masked_fill(
105
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
106
+ )
107
+
108
+
109
+ class RotaryEmbedding(nn.Module):
110
+ def __init__(
111
+ self,
112
+ dim: int,
113
+ max_position_embeddings: int,
114
+ base: int = 10_000,
115
+ device: Optional[torch.device] = None,
116
+ ):
117
+ super().__init__()
118
+
119
+ self.dim = dim
120
+ self.max_position_embeddings = max_position_embeddings
121
+ self.base = base
122
+ inv_freq = 1.0 / (
123
+ self.base
124
+ ** (
125
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
126
+ / self.dim
127
+ )
128
+ )
129
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
130
+
131
+ # Build here to make `torch.jit.trace` work.
132
+ self._set_cos_sin_cache(
133
+ seq_len=max_position_embeddings,
134
+ device=self.inv_freq.device,
135
+ dtype=torch.get_default_dtype(),
136
+ )
137
+
138
+ def _set_cos_sin_cache(
139
+ self, seq_len: int, device: torch.device, dtype: torch.dtype
140
+ ):
141
+ self.max_seq_len_cached = seq_len
142
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
143
+
144
+ # Don't do einsum, it converts fp32 to fp16 under AMP
145
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
146
+ freqs = torch.outer(t, self.inv_freq)
147
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
148
+ emb = torch.cat((freqs, freqs), dim=-1)
149
+ self.register_buffer(
150
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
151
+ )
152
+ self.register_buffer(
153
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
154
+ )
155
+
156
+ def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
157
+ # x: [batch_size, num_heads, seq_len, head_size]
158
+ if seq_len > self.max_seq_len_cached:
159
+ self._set_cos_sin_cache(
160
+ seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype()
161
+ )
162
+ return (
163
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
164
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
165
+ )
166
+
167
+
168
+ def rotate_half(x: torch.Tensor):
169
+ """Rotates half the hidden dims of the input."""
170
+ x1, x2 = torch.chunk(x, 2, dim=-1)
171
+ return torch.cat((-x2, x1), dim=-1)
172
+
173
+
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
175
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
176
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
177
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
178
+ cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
179
+ sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
180
+ q_embed = (q * cos) + (rotate_half(q) * sin)
181
+ k_embed = (k * cos) + (rotate_half(k) * sin)
182
+ return q_embed, k_embed
183
+
184
+
185
+ class LayerNormPerHead(torch.nn.Module):
186
+ def __init__(
187
+ self,
188
+ head_dim: int,
189
+ num_heads: int,
190
+ eps: float = 1e-5,
191
+ bias: bool = False,
192
+ ):
193
+ super().__init__()
194
+ self.head_dim = head_dim
195
+ self.num_heads = num_heads
196
+ self.norms = nn.ModuleList(
197
+ [nn.LayerNorm(head_dim, eps=eps, bias=bias) for _ in range(self.num_heads)]
198
+ )
199
+
200
+ def forward(self, x: torch.Tensor):
201
+ # Split along the num_heads axis to get per-head inputs
202
+ # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
203
+ heads = torch.split(x, 1, dim=1)
204
+ # Normalize and put the heads back together
205
+ return torch.cat([norm(x) for norm, x in zip(self.norms, heads)], dim=1)
206
+
207
+
208
+ class MLP(nn.Module):
209
+ def __init__(self, config: StableLMEpochConfig):
210
+ super().__init__()
211
+ self.config = config
212
+ self.hidden_size = config.hidden_size
213
+ self.intermediate_size = config.intermediate_size
214
+ self.gate_proj = nn.Linear(
215
+ config.hidden_size, config.intermediate_size, bias=False
216
+ )
217
+ self.up_proj = nn.Linear(
218
+ config.hidden_size, config.intermediate_size, bias=False
219
+ )
220
+ self.down_proj = nn.Linear(
221
+ config.intermediate_size, config.hidden_size, bias=False
222
+ )
223
+ self.act_fn = nn.SiLU()
224
+
225
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
226
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
227
+
228
+
229
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
230
+ """
231
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
232
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
233
+ """
234
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
235
+ if n_rep == 1:
236
+ return hidden_states
237
+ hidden_states = hidden_states[:, :, None, :, :].expand(
238
+ batch, num_key_value_heads, n_rep, slen, head_dim
239
+ )
240
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
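A quick standalone check (not part of the repository) that the expand/reshape above matches `torch.repeat_interleave` along the head dimension, using this model's head counts:

```python
import torch

# (batch, num_key_value_heads, seq_len, head_dim) with 8 KV heads and 4 groups -> 32 heads
kv = torch.randn(2, 8, 16, 160)
expanded = kv[:, :, None, :, :].expand(2, 8, 4, 16, 160).reshape(2, 32, 16, 160)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1))
```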
241
+
242
+
243
+ class Attention(nn.Module):
244
+ def __init__(self, config: StableLMEpochConfig):
245
+ super().__init__()
246
+ self.config = config
247
+ self.hidden_size = config.hidden_size
248
+ self.num_heads = config.num_attention_heads
249
+ self.head_dim = self.hidden_size // self.num_heads
250
+ self.num_key_value_heads = config.num_key_value_heads
251
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
252
+ self.max_position_embeddings = config.max_position_embeddings
253
+ self.is_causal = True
254
+ self.attention_dropout = config.attention_dropout
255
+
256
+ if (self.head_dim * self.num_heads) != self.hidden_size:
257
+ raise ValueError(
258
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
259
+ f" and `num_heads`: {self.num_heads})."
260
+ )
261
+
262
+ self.q_proj = nn.Linear(
263
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
264
+ )
265
+ self.k_proj = nn.Linear(
266
+ self.hidden_size,
267
+ self.num_key_value_heads * self.head_dim,
268
+ bias=config.use_qkv_bias,
269
+ )
270
+ self.v_proj = nn.Linear(
271
+ self.hidden_size,
272
+ self.num_key_value_heads * self.head_dim,
273
+ bias=config.use_qkv_bias,
274
+ )
275
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
276
+
277
+ self.q_norm = LayerNormPerHead(
278
+ self.head_dim, self.num_heads, eps=config.norm_eps, bias=False
279
+ )
280
+ self.k_norm = LayerNormPerHead(
281
+ self.head_dim, self.num_key_value_heads, eps=config.norm_eps, bias=False
282
+ )
283
+
284
+ self._init_rope()
285
+
286
+ def _init_rope(self):
287
+ self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
288
+ self.rotary_emb = RotaryEmbedding(
289
+ self.rotary_ndims,
290
+ max_position_embeddings=self.config.max_position_embeddings,
291
+ base=self.config.rope_theta,
292
+ )
293
+
294
+ def forward(
295
+ self,
296
+ hidden_states: torch.FloatTensor,
297
+ attention_mask: torch.FloatTensor,
298
+ position_ids: torch.LongTensor,
299
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
300
+ output_attentions: Optional[bool] = False,
301
+ use_cache: Optional[bool] = False,
302
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
303
+ bsz, q_len, _ = hidden_states.size()
304
+
305
+ query_states = self.q_proj(hidden_states)
306
+ key_states = self.k_proj(hidden_states)
307
+ value_states = self.v_proj(hidden_states)
308
+
309
+ query_states = query_states.view(
310
+ bsz, q_len, self.num_heads, self.head_dim
311
+ ).transpose(1, 2)
312
+ key_states = key_states.view(
313
+ bsz, q_len, self.num_key_value_heads, self.head_dim
314
+ ).transpose(1, 2)
315
+ value_states = value_states.view(
316
+ bsz, q_len, self.num_key_value_heads, self.head_dim
317
+ ).transpose(1, 2)
318
+
319
+ # [batch_size, num_heads, seq_len, head_dim]
320
+ query_states = self.q_norm(query_states)
321
+ key_states = self.k_norm(key_states)
322
+
323
+ query_rot = query_states[..., : self.rotary_ndims]
324
+ query_pass = query_states[..., self.rotary_ndims :]
325
+ key_rot = key_states[..., : self.rotary_ndims]
326
+ key_pass = key_states[..., self.rotary_ndims :]
327
+
328
+ kv_seq_len = key_states.shape[-2]
329
+ if past_key_value is not None:
330
+ kv_seq_len += past_key_value[0].shape[-2]
331
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
332
+ query_states, key_states = apply_rotary_pos_emb(
333
+ query_rot, key_rot, cos, sin, position_ids
334
+ )
335
+
336
+ # [batch_size, num_heads, seq_len, head_dim]
337
+ query_states = torch.cat((query_states, query_pass), dim=-1)
338
+ key_states = torch.cat((key_states, key_pass), dim=-1)
339
+
340
+ if past_key_value is not None:
341
+ # Reuse k, v, self_attention
342
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
343
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
344
+
345
+ past_key_value = (key_states, value_states) if use_cache else None
346
+
347
+ # Repeat k/v heads if n_kv_heads < n_heads
348
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
349
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
350
+
351
+ attn_weights = torch.matmul(
352
+ query_states, key_states.transpose(2, 3)
353
+ ) / math.sqrt(self.head_dim)
354
+
355
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
356
+ raise ValueError(
357
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
358
+ f" {attn_weights.size()}"
359
+ )
360
+
361
+ if attention_mask is not None:
362
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
363
+ raise ValueError(
364
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
365
+ )
366
+ attn_weights = attn_weights + attention_mask
367
+
368
+ # Upcast attention to fp32
369
+ attn_weights = nn.functional.softmax(
370
+ attn_weights, dim=-1, dtype=torch.float32
371
+ ).to(query_states.dtype)
372
+ attn_weights = nn.functional.dropout(
373
+ attn_weights, p=self.attention_dropout, training=self.training
374
+ )
375
+ attn_output = torch.matmul(attn_weights, value_states)
376
+
377
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
378
+ raise ValueError(
379
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
380
+ f" {attn_output.size()}"
381
+ )
382
+
383
+ # Merge heads
384
+ attn_output = attn_output.transpose(1, 2).contiguous()
385
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
386
+
387
+ # Final linear projection
388
+ attn_output = self.o_proj(attn_output)
389
+
390
+ if not output_attentions:
391
+ attn_weights = None
392
+
393
+ return attn_output, attn_weights, past_key_value
394
+
395
+
396
+ class FlashAttention2(Attention):
397
+ """
398
+ Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
399
+ """
400
+
401
+ def __init__(self, *args, **kwargs):
402
+ super().__init__(*args, **kwargs)
403
+
404
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
405
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
406
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
407
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
408
+
409
+ def forward(
410
+ self,
411
+ hidden_states: torch.Tensor,
412
+ attention_mask: Optional[torch.LongTensor] = None,
413
+ position_ids: Optional[torch.LongTensor] = None,
414
+ past_key_value: Optional[Cache] = None,
415
+ output_attentions: bool = False,
416
+ use_cache: bool = False,
417
+ **kwargs,
418
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
419
+ # FlashAttention2 attention does not support output_attentions
420
+ if "padding_mask" in kwargs:
421
+ warnings.warn(
422
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
423
+ )
424
+
425
+ # overwrite attention_mask with padding_mask
426
+ attention_mask = kwargs.pop("padding_mask")
427
+
428
+ output_attentions = False
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ query_states = self.q_proj(hidden_states)
433
+ key_states = self.k_proj(hidden_states)
434
+ value_states = self.v_proj(hidden_states)
435
+
436
+ # Flash attention requires the input to have the shape
437
+ # batch_size x seq_length x num_heads x head_dim
438
+ # therefore we just need to keep the original shape
439
+ query_states = query_states.view(
440
+ bsz, q_len, self.num_heads, self.head_dim
441
+ ).transpose(1, 2)
442
+ key_states = key_states.view(
443
+ bsz, q_len, self.num_key_value_heads, self.head_dim
444
+ ).transpose(1, 2)
445
+ value_states = value_states.view(
446
+ bsz, q_len, self.num_key_value_heads, self.head_dim
447
+ ).transpose(1, 2)
448
+
449
+ # [batch_size, num_heads, seq_len, head_dim]
450
+ query_states = self.q_norm(query_states)
451
+ key_states = self.k_norm(key_states)
452
+
453
+ query_rot = query_states[..., : self.rotary_ndims]
454
+ query_pass = query_states[..., self.rotary_ndims :]
455
+ key_rot = key_states[..., : self.rotary_ndims]
456
+ key_pass = key_states[..., self.rotary_ndims :]
457
+
458
+ kv_seq_len = key_states.shape[-2]
459
+ if past_key_value is not None:
460
+ kv_seq_len += past_key_value[0].shape[-2]
461
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
462
+ query_states, key_states = apply_rotary_pos_emb(
463
+ query_rot, key_rot, cos, sin, position_ids
464
+ )
465
+
466
+ # [batch_size, num_heads, seq_len, head_dim]
467
+ query_states = torch.cat((query_states, query_pass), dim=-1)
468
+ key_states = torch.cat((key_states, key_pass), dim=-1)
469
+
470
+ if past_key_value is not None:
471
+ # Reuse k, v, self_attention
472
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
473
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
474
+
475
+ past_key_value = (key_states, value_states) if use_cache else None
476
+
477
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
478
+ # to be able to avoid many of these transpose/reshape/view.
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ dropout_rate = self.attention_dropout if self.training else 0.0
484
+
485
+ attn_output = self._flash_attention_forward(
486
+ query_states,
487
+ key_states,
488
+ value_states,
489
+ attention_mask,
490
+ q_len,
491
+ dropout=dropout_rate,
492
+ )
493
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
494
+ attn_output = self.o_proj(attn_output)
495
+
496
+ if not output_attentions:
497
+ attn_weights = None
498
+
499
+ return attn_output, attn_weights, past_key_value
500
+
501
+ def _flash_attention_forward(
502
+ self,
503
+ query_states,
504
+ key_states,
505
+ value_states,
506
+ attention_mask,
507
+ query_length,
508
+ dropout=0.0,
509
+ softmax_scale=None,
510
+ ):
511
+ """
512
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
513
+ first unpad the input, then compute the attention scores and pad the final attention scores.
514
+
515
+ Args:
516
+ query_states (`torch.Tensor`):
517
+ Input query states to be passed to Flash Attention API
518
+ key_states (`torch.Tensor`):
519
+ Input key states to be passed to Flash Attention API
520
+ value_states (`torch.Tensor`):
521
+ Input value states to be passed to Flash Attention API
522
+ attention_mask (`torch.Tensor`):
523
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
524
+ position of padding tokens and 1 for the position of non-padding tokens.
525
+ dropout (`float`, *optional*):
526
+ Attention dropout
527
+ softmax_scale (`float`, *optional*):
528
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
529
+ """
530
+ if not self._flash_attn_uses_top_left_mask:
531
+ causal = self.is_causal
532
+ else:
533
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
534
+ causal = self.is_causal and query_length != 1
535
+
536
+ # Contains at least one padding token in the sequence
537
+ if attention_mask is not None:
538
+ batch_size = query_states.shape[0]
539
+ (
540
+ query_states,
541
+ key_states,
542
+ value_states,
543
+ indices_q,
544
+ cu_seq_lens,
545
+ max_seq_lens,
546
+ ) = self._upad_input(
547
+ query_states, key_states, value_states, attention_mask, query_length
548
+ )
549
+
550
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
551
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
552
+
553
+ attn_output_unpad = flash_attn_varlen_func(
554
+ query_states,
555
+ key_states,
556
+ value_states,
557
+ cu_seqlens_q=cu_seqlens_q,
558
+ cu_seqlens_k=cu_seqlens_k,
559
+ max_seqlen_q=max_seqlen_in_batch_q,
560
+ max_seqlen_k=max_seqlen_in_batch_k,
561
+ dropout_p=dropout,
562
+ softmax_scale=softmax_scale,
563
+ causal=causal,
564
+ )
565
+
566
+ attn_output = pad_input(
567
+ attn_output_unpad, indices_q, batch_size, query_length
568
+ )
569
+ else:
570
+ attn_output = flash_attn_func(
571
+ query_states,
572
+ key_states,
573
+ value_states,
574
+ dropout,
575
+ softmax_scale=softmax_scale,
576
+ causal=causal,
577
+ )
578
+
579
+ return attn_output
580
+
581
+ def _upad_input(
582
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
583
+ ):
584
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
585
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
586
+
587
+ key_layer = index_first_axis(
588
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
589
+ indices_k,
590
+ )
591
+ value_layer = index_first_axis(
592
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
593
+ indices_k,
594
+ )
595
+ if query_length == kv_seq_len:
596
+ query_layer = index_first_axis(
597
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
598
+ indices_k,
599
+ )
600
+ cu_seqlens_q = cu_seqlens_k
601
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
602
+ indices_q = indices_k
603
+ elif query_length == 1:
604
+ max_seqlen_in_batch_q = 1
605
+ cu_seqlens_q = torch.arange(
606
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
607
+ ) # There is a memcpy here, that is very bad.
608
+ indices_q = cu_seqlens_q[:-1]
609
+ query_layer = query_layer.squeeze(1)
610
+ else:
611
+ # The -q_len: slice assumes left padding.
612
+ attention_mask = attention_mask[:, -query_length:]
613
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
614
+ query_layer, attention_mask
615
+ )
616
+
617
+ return (
618
+ query_layer,
619
+ key_layer,
620
+ value_layer,
621
+ indices_q,
622
+ (cu_seqlens_q, cu_seqlens_k),
623
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
624
+ )
625
+
626
+
627
+ ATTENTION_CLASSES = {
628
+ "eager": Attention,
629
+ "flash_attention_2": FlashAttention2,
630
+ }
631
+
632
+
633
+ class DecoderLayer(nn.Module):
634
+ def __init__(self, config: StableLMEpochConfig):
635
+ super().__init__()
636
+ self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
637
+ self.mlp = MLP(config)
638
+ self.input_layernorm = nn.LayerNorm(
639
+ config.hidden_size, eps=config.norm_eps, bias=config.use_norm_bias
640
+ )
641
+
642
+ def forward(
643
+ self,
644
+ hidden_states: Optional[torch.FloatTensor],
645
+ attention_mask: Optional[torch.FloatTensor] = None,
646
+ position_ids: Optional[torch.LongTensor] = None,
647
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
648
+ output_attentions: Optional[bool] = False,
649
+ use_cache: Optional[bool] = False,
650
+ ) -> Union[
651
+ Tuple[torch.Tensor],
652
+ Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]],
653
+ ]:
654
+ residual = hidden_states
655
+
656
+ hidden_states = self.input_layernorm(hidden_states)
657
+
658
+ # Self Attention
659
+ self_attn_output, self_attn_weights, present_key_value = self.self_attn(
660
+ hidden_states=hidden_states,
661
+ attention_mask=attention_mask,
662
+ position_ids=position_ids,
663
+ past_key_value=past_key_value,
664
+ output_attentions=output_attentions,
665
+ use_cache=use_cache,
666
+ )
667
+
668
+ # Fully Connected
669
+ mlp_output = self.mlp(hidden_states)
670
+
671
+ # Parallel Residual
672
+ hidden_states = residual + self_attn_output + mlp_output
673
+
674
+ outputs = (hidden_states,)
675
+
676
+ if output_attentions:
677
+ outputs += (self_attn_weights,)
678
+
679
+ if use_cache:
680
+ outputs += (present_key_value,)
681
+
682
+ return outputs
683
+
684
+
685
+ class StableLMEpochPreTrainedModel(PreTrainedModel):
686
+ """An abstract class to handle weights initialization and a simple interface
687
+ for downloading and loading pretrained models.
688
+ """
689
+
690
+ config_class = StableLMEpochConfig
691
+ base_model_prefix = "transformer"
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["DecoderLayer"]
694
+ _skip_keys_device_placement = "past_key_values"
695
+ _supports_flash_attn_2 = True
696
+
697
+ def _init_weights(self, module: nn.Module):
698
+ """Initialize the weights"""
699
+ if isinstance(module, nn.Linear):
700
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
701
+ if module.bias is not None:
702
+ module.bias.data.zero_()
703
+ elif isinstance(module, nn.Embedding):
704
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
705
+ if module.padding_idx is not None:
706
+ module.weight.data[module.padding_idx].zero_()
707
+ elif isinstance(module, nn.LayerNorm):
708
+ if module.bias is not None:
709
+ module.bias.data.zero_()
710
+ module.weight.data.fill_(1.0)
711
+
712
+ def _set_gradient_checkpointing(self, module: nn.Module, value=False):
713
+ if isinstance(module, StableLMEpochModel):
714
+ module.gradient_checkpointing = value
715
+
716
+
717
+ class StableLMEpochModel(StableLMEpochPreTrainedModel):
718
+ def __init__(self, config: StableLMEpochConfig):
719
+ super().__init__(config)
720
+ self.embed_tokens = nn.Embedding(
721
+ config.vocab_size, config.hidden_size, config.pad_token_id
722
+ )
723
+ self.layers = nn.ModuleList(
724
+ [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
725
+ )
726
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
727
+
728
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
729
+ self.gradient_checkpointing = False
730
+ # Initialize weights and apply final processing
731
+ self.post_init()
732
+
733
+ def get_input_embeddings(self):
734
+ return self.embed_tokens
735
+
736
+ def set_input_embeddings(self, value: nn.Module):
737
+ self.embed_tokens = value
738
+
739
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
740
+ def _prepare_decoder_attention_mask(
741
+ self,
742
+ attention_mask: torch.Tensor,
743
+ input_shape: torch.Size,
744
+ inputs_embeds: torch.Tensor,
745
+ past_key_values_length: int,
746
+ ):
747
+ # Create causal mask
748
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
749
+ combined_attention_mask = None
750
+ if input_shape[-1] > 1:
751
+ combined_attention_mask = _make_causal_mask(
752
+ input_shape,
753
+ inputs_embeds.dtype,
754
+ device=inputs_embeds.device,
755
+ past_key_values_length=past_key_values_length,
756
+ )
757
+
758
+ if attention_mask is not None:
759
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
760
+ expanded_attn_mask = _expand_mask(
761
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
762
+ ).to(inputs_embeds.device)
763
+ combined_attention_mask = (
764
+ expanded_attn_mask
765
+ if combined_attention_mask is None
766
+ else expanded_attn_mask + combined_attention_mask
767
+ )
768
+
769
+ return combined_attention_mask
770
+
771
+ def forward(
772
+ self,
773
+ input_ids: Optional[torch.LongTensor] = None,
774
+ attention_mask: Optional[torch.FloatTensor] = None,
775
+ position_ids: Optional[torch.LongTensor] = None,
776
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
777
+ inputs_embeds: Optional[torch.FloatTensor] = None,
778
+ use_cache: Optional[bool] = None,
779
+ output_attentions: Optional[bool] = None,
780
+ output_hidden_states: Optional[bool] = None,
781
+ return_dict: Optional[bool] = None,
782
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
783
+ output_attentions = (
784
+ output_attentions
785
+ if output_attentions is not None
786
+ else self.config.output_attentions
787
+ )
788
+ output_hidden_states = (
789
+ output_hidden_states
790
+ if output_hidden_states is not None
791
+ else self.config.output_hidden_states
792
+ )
793
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
794
+
795
+ return_dict = (
796
+ return_dict if return_dict is not None else self.config.use_return_dict
797
+ )
798
+
799
+ # Retrieve input_ids and inputs_embeds
800
+ if input_ids is not None and inputs_embeds is not None:
801
+ raise ValueError(
802
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
803
+ )
804
+ elif input_ids is not None:
805
+ batch_size, seq_length = input_ids.shape
806
+ elif inputs_embeds is not None:
807
+ batch_size, seq_length, _ = inputs_embeds.shape
808
+ else:
809
+ raise ValueError(
810
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
811
+ )
812
+
813
+ seq_length_with_past = seq_length
814
+ past_key_values_length = 0
815
+
816
+ if position_ids is None:
817
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
818
+ position_ids = torch.arange(
819
+ past_key_values_length,
820
+ seq_length + past_key_values_length,
821
+ dtype=torch.long,
822
+ device=device,
823
+ )
824
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
825
+ else:
826
+ position_ids = position_ids.view(-1, seq_length).long()
827
+
828
+ if inputs_embeds is None:
829
+ inputs_embeds = self.embed_tokens(input_ids)
830
+ # Embed positions
831
+ if self._use_flash_attention_2:
832
+ # 2d mask is passed through the layers
833
+ attention_mask = (
834
+ attention_mask
835
+ if (attention_mask is not None and 0 in attention_mask)
836
+ else None
837
+ )
838
+ else:
839
+ if attention_mask is None:
840
+ attention_mask = torch.ones(
841
+ (batch_size, seq_length_with_past),
842
+ dtype=torch.bool,
843
+ device=inputs_embeds.device,
844
+ )
845
+ attention_mask = self._prepare_decoder_attention_mask(
846
+ attention_mask,
847
+ (batch_size, seq_length),
848
+ inputs_embeds,
849
+ past_key_values_length,
850
+ )
851
+
852
+ hidden_states = inputs_embeds
853
+
854
+ if self.gradient_checkpointing and self.training:
855
+ if use_cache:
856
+ logger.warning(
857
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
858
+ )
859
+ use_cache = False
860
+
861
+ # Decoder layers
862
+ all_hidden_states = () if output_hidden_states else None
863
+ all_self_attns = () if output_attentions else None
864
+ next_decoder_cache = () if use_cache else None
865
+
866
+ for idx, decoder_layer in enumerate(self.layers):
867
+ if output_hidden_states:
868
+ all_hidden_states += (hidden_states,)
869
+
870
+ past_key_value = (
871
+ past_key_values[idx] if past_key_values is not None else None
872
+ )
873
+
874
+ if self.gradient_checkpointing and self.training:
875
+
876
+ def create_custom_forward(module):
877
+ def custom_forward(*inputs):
878
+ # None for past_key_value
879
+ return module(*inputs, past_key_value, output_attentions)
880
+
881
+ return custom_forward
882
+
883
+ layer_outputs = torch.utils.checkpoint.checkpoint(
884
+ create_custom_forward(decoder_layer),
885
+ hidden_states,
886
+ attention_mask,
887
+ position_ids,
888
+ )
889
+ else:
890
+ layer_outputs = decoder_layer(
891
+ hidden_states,
892
+ attention_mask=attention_mask,
893
+ position_ids=position_ids,
894
+ past_key_value=past_key_value,
895
+ output_attentions=output_attentions,
896
+ use_cache=use_cache,
897
+ )
898
+
899
+ hidden_states = layer_outputs[0]
900
+
901
+ if use_cache:
902
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
903
+
904
+ if output_attentions:
905
+ all_self_attns += (layer_outputs[1],)
906
+
907
+ hidden_states = self.norm(hidden_states)
908
+
909
+ # Add hidden states from the last decoder layer
910
+ if output_hidden_states:
911
+ all_hidden_states += (hidden_states,)
912
+
913
+ next_cache = next_decoder_cache if use_cache else None
914
+ if not return_dict:
915
+ return tuple(
916
+ v
917
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
918
+ if v is not None
919
+ )
920
+ return BaseModelOutputWithPast(
921
+ last_hidden_state=hidden_states,
922
+ past_key_values=next_cache,
923
+ hidden_states=all_hidden_states,
924
+ attentions=all_self_attns,
925
+ )
926
+
927
+
928
+ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
929
+ _tied_weights_keys = ["lm_head.weight"]
930
+
931
+ def __init__(self, config: StableLMEpochConfig):
932
+ super().__init__(config)
933
+
934
+ self.model = StableLMEpochModel(config)
935
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
936
+
937
+ # Initialize weights and apply final processing
938
+ self.post_init()
939
+
940
+ def get_input_embeddings(self):
941
+ return self.model.embed_tokens
942
+
943
+ def set_input_embeddings(self, value):
944
+ self.model.embed_tokens = value
945
+
946
+ def get_output_embeddings(self):
947
+ return self.lm_head
948
+
949
+ def set_output_embeddings(self, new_embeddings: nn.Module):
950
+ self.lm_head = new_embeddings
951
+
952
+ def get_decoder(self):
953
+ return self.model
954
+
955
+ def set_decoder(self, decoder):
956
+ self.model = decoder
957
+
958
+ def forward(
959
+ self,
960
+ input_ids: Optional[torch.LongTensor] = None,
961
+ attention_mask: Optional[torch.FloatTensor] = None,
962
+ position_ids: Optional[torch.LongTensor] = None,
963
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
964
+ inputs_embeds: Optional[torch.FloatTensor] = None,
965
+ labels: Optional[torch.LongTensor] = None,
966
+ use_cache: Optional[bool] = None,
967
+ output_attentions: Optional[bool] = None,
968
+ output_hidden_states: Optional[bool] = None,
969
+ return_dict: Optional[bool] = None,
970
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
971
+ output_attentions = (
972
+ output_attentions
973
+ if output_attentions is not None
974
+ else self.config.output_attentions
975
+ )
976
+ output_hidden_states = (
977
+ output_hidden_states
978
+ if output_hidden_states is not None
979
+ else self.config.output_hidden_states
980
+ )
981
+ return_dict = (
982
+ return_dict if return_dict is not None else self.config.use_return_dict
983
+ )
984
+
985
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
986
+ outputs = self.model(
987
+ input_ids,
988
+ attention_mask=attention_mask,
989
+ position_ids=position_ids,
990
+ past_key_values=past_key_values,
991
+ inputs_embeds=inputs_embeds,
992
+ use_cache=use_cache,
993
+ output_attentions=output_attentions,
994
+ output_hidden_states=output_hidden_states,
995
+ return_dict=return_dict,
996
+ )
997
+
998
+ hidden_states = outputs[0]
999
+ logits = self.lm_head(hidden_states).float()
1000
+
1001
+ loss = None
1002
+ if labels is not None:
1003
+ # Shift so that tokens < n predict n
1004
+ shift_logits = logits[..., :-1, :].contiguous()
1005
+ shift_labels = labels[..., 1:].contiguous()
1006
+ # Flatten the tokens
1007
+ loss_fct = CrossEntropyLoss()
1008
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1009
+ shift_labels = shift_labels.view(-1)
1010
+ # Enable model parallelism
1011
+ shift_labels = shift_labels.to(shift_logits.device)
1012
+ loss = loss_fct(shift_logits, shift_labels)
1013
+
1014
+ if not return_dict:
1015
+ output = (logits,) + outputs[1:]
1016
+ return (loss,) + output if loss is not None else output
1017
+
1018
+ return CausalLMOutputWithPast(
1019
+ loss=loss,
1020
+ logits=logits,
1021
+ past_key_values=outputs.past_key_values,
1022
+ hidden_states=outputs.hidden_states,
1023
+ attentions=outputs.attentions,
1024
+ )
1025
+
1026
+ def prepare_inputs_for_generation(
1027
+ self,
1028
+ input_ids,
1029
+ past_key_values: Optional[torch.Tensor] = None,
1030
+ attention_mask: Optional[torch.Tensor] = None,
1031
+ inputs_embeds: Optional[torch.Tensor] = None,
1032
+ **kwargs,
1033
+ ):
1034
+ # Trim decoder_input_ids if past is used
1035
+ if past_key_values is not None:
1036
+ past_length = past_key_values[0][0].shape[2]
1037
+
1038
+ # Some generation methods already pass only the last input ID
1039
+ if input_ids.shape[1] > past_length:
1040
+ remove_prefix_length = past_length
1041
+ else:
1042
+ # Default to old behavior: keep only final ID
1043
+ remove_prefix_length = input_ids.shape[1] - 1
1044
+
1045
+ input_ids = input_ids[:, remove_prefix_length:]
1046
+
1047
+ position_ids = kwargs.get("position_ids", None)
1048
+ if attention_mask is not None and position_ids is None:
1049
+ # Create position_ids on the fly for batch generation
1050
+ position_ids = attention_mask.long().cumsum(-1) - 1
1051
+ position_ids.masked_fill_(attention_mask == 0, 1)
1052
+ if past_key_values:
1053
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1054
+
1055
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
1056
+ if inputs_embeds is not None and past_key_values is None:
1057
+ model_inputs = {"inputs_embeds": inputs_embeds}
1058
+ else:
1059
+ model_inputs = {"input_ids": input_ids}
1060
+
1061
+ model_inputs.update(
1062
+ {
1063
+ "attention_mask": attention_mask,
1064
+ "past_key_values": past_key_values,
1065
+ "use_cache": kwargs.get("use_cache"),
1066
+ "position_ids": position_ids,
1067
+ }
1068
+ )
1069
+ return model_inputs
1070
+
1071
+ @staticmethod
1072
+ def _reorder_cache(past_key_values, beam_idx):
1073
+ reordered_past = ()
1074
+ for layer_past in past_key_values:
1075
+ reordered_past += (
1076
+ tuple(
1077
+ past_state.index_select(0, beam_idx.to(past_state.device))
1078
+ for past_state in layer_past
1079
+ ),
1080
+ )
1081
+ return reordered_past
1082
+
1083
+
1084
+ StableLMEpochConfig.register_for_auto_class()
1085
+ StableLMEpochForCausalLM.register_for_auto_class("AutoModelForCausalLM")
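The two `register_for_auto_class` calls above, together with the `auto_map` entry in config.json, are what let the Auto classes resolve this custom architecture when the repository is loaded with `trust_remote_code=True`. A minimal sketch (the path is a placeholder for this repository's Hub id or a local clone):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("<repo-or-local-path>", trust_remote_code=True)
print(type(config).__name__)  # StableLMEpochConfig, served from configuration_stablelm_epoch.py
print(config.model_type)      # "stablelm_epoch"
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) resolves
# modeling_stablelm_epoch_2.StableLMEpochForCausalLM the same way via config.auto_map.
```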