camenduru committed on
Commit
5bd6388
1 Parent(s): 1ce9e06

thanks to stabilityai ❤

config.json ADDED
@@ -0,0 +1,255 @@
1
+ {
2
+ "_name_or_path": "stabilityai/japanese-instructblip-alpha",
3
+ "architectures": [
4
+ "JapaneseInstructBlipAlphaForConditionalGeneration"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModelForVision2Seq": "modeling_japanese_instructblip_alpha.JapaneseInstructBlipAlphaForConditionalGeneration",
8
+ "AutoConfig": "configuration_japanese_instructblip_alpha.JapaneseInstructBlipAlphaConfig"
9
+ },
10
+ "initializer_factor": 1.0,
11
+ "initializer_range": 0.02,
12
+ "model_type": "instructblip",
13
+ "num_query_tokens": 32,
14
+ "qformer_config": {
15
+ "_name_or_path": "",
16
+ "add_cross_attention": false,
17
+ "architectures": null,
18
+ "attention_probs_dropout_prob": 0.1,
19
+ "bad_words_ids": null,
20
+ "begin_suppress_tokens": null,
21
+ "bos_token_id": null,
22
+ "chunk_size_feed_forward": 0,
23
+ "cross_attention_frequency": 2,
24
+ "cross_attention_hidden_size": null,
25
+ "decoder_start_token_id": null,
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "early_stopping": false,
29
+ "encoder_hidden_size": 1408,
30
+ "encoder_no_repeat_ngram_size": 0,
31
+ "eos_token_id": null,
32
+ "exponential_decay_length_penalty": null,
33
+ "finetuning_task": null,
34
+ "forced_bos_token_id": null,
35
+ "forced_eos_token_id": null,
36
+ "hidden_act": "gelu",
37
+ "hidden_dropout_prob": 0.1,
38
+ "hidden_size": 768,
39
+ "id2label": {
40
+ "0": "LABEL_0",
41
+ "1": "LABEL_1"
42
+ },
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 3072,
45
+ "is_decoder": false,
46
+ "is_encoder_decoder": false,
47
+ "label2id": {
48
+ "LABEL_0": 0,
49
+ "LABEL_1": 1
50
+ },
51
+ "layer_norm_eps": 1e-12,
52
+ "length_penalty": 1.0,
53
+ "max_length": 20,
54
+ "max_position_embeddings": 512,
55
+ "min_length": 0,
56
+ "model_type": "instructblip_qformer",
57
+ "no_repeat_ngram_size": 0,
58
+ "num_attention_heads": 12,
59
+ "num_beam_groups": 1,
60
+ "num_beams": 1,
61
+ "num_hidden_layers": 12,
62
+ "num_return_sequences": 1,
63
+ "output_attentions": false,
64
+ "output_hidden_states": false,
65
+ "output_scores": false,
66
+ "pad_token_id": 0,
67
+ "position_embedding_type": "absolute",
68
+ "prefix": null,
69
+ "problem_type": null,
70
+ "pruned_heads": {},
71
+ "remove_invalid_values": false,
72
+ "repetition_penalty": 1.0,
73
+ "return_dict": true,
74
+ "return_dict_in_generate": false,
75
+ "sep_token_id": null,
76
+ "suppress_tokens": null,
77
+ "task_specific_params": null,
78
+ "temperature": 1.0,
79
+ "tf_legacy_loss": false,
80
+ "tie_encoder_decoder": false,
81
+ "tie_word_embeddings": true,
82
+ "tokenizer_class": null,
83
+ "top_k": 50,
84
+ "top_p": 1.0,
85
+ "torch_dtype": null,
86
+ "torchscript": false,
87
+ "transformers_version": "4.31.0",
88
+ "typical_p": 1.0,
89
+ "use_bfloat16": false,
90
+ "vocab_size": 65535
91
+ },
92
+ "text_config": {
93
+ "_name_or_path": "stabilityai/japanese-stablelm-instruct-alpha-7b",
94
+ "add_cross_attention": false,
95
+ "architectures": [
96
+ "JapaneseStableLMAlphaForCausalLM"
97
+ ],
98
+ "auto_map": {
99
+ "AutoConfig": "stabilityai/japanese-stablelm-instruct-alpha-7b--configuration_japanese_stablelm_alpha.JapaneseStableLMAlphaConfig",
100
+ "AutoModelForCausalLM": "stabilityai/japanese-stablelm-instruct-alpha-7b--modeling_japanese_stablelm_alpha.JapaneseStableLMAlphaForCausalLM"
101
+ },
102
+ "bad_words_ids": null,
103
+ "begin_suppress_tokens": null,
104
+ "bos_token_id": 3,
105
+ "chunk_size_feed_forward": 0,
106
+ "classifier_dropout": 0.1,
107
+ "cross_attention_hidden_size": null,
108
+ "decoder_start_token_id": null,
109
+ "diversity_penalty": 0.0,
110
+ "do_sample": false,
111
+ "early_stopping": false,
112
+ "encoder_no_repeat_ngram_size": 0,
113
+ "eos_token_id": 3,
114
+ "exponential_decay_length_penalty": null,
115
+ "finetuning_task": null,
116
+ "forced_bos_token_id": null,
117
+ "forced_eos_token_id": null,
118
+ "hidden_act": "silu",
119
+ "hidden_size": 4096,
120
+ "id2label": {
121
+ "0": "LABEL_0",
122
+ "1": "LABEL_1"
123
+ },
124
+ "initializer_range": 0.02,
125
+ "is_decoder": false,
126
+ "is_encoder_decoder": false,
127
+ "label2id": {
128
+ "LABEL_0": 0,
129
+ "LABEL_1": 1
130
+ },
131
+ "layer_norm_eps": 1e-05,
132
+ "length_penalty": 1.0,
133
+ "max_length": 20,
134
+ "max_position_embeddings": 2048,
135
+ "min_length": 0,
136
+ "no_repeat_ngram_size": 0,
137
+ "num_attention_heads": 32,
138
+ "num_beam_groups": 1,
139
+ "num_beams": 1,
140
+ "num_hidden_layers": 32,
141
+ "num_return_sequences": 1,
142
+ "output_attentions": false,
143
+ "output_hidden_states": false,
144
+ "output_scores": false,
145
+ "pad_token_id": null,
146
+ "prefix": null,
147
+ "problem_type": null,
148
+ "pruned_heads": {},
149
+ "remove_invalid_values": false,
150
+ "repetition_penalty": 1.0,
151
+ "return_dict": true,
152
+ "return_dict_in_generate": false,
153
+ "rotary_emb_base": 10000,
154
+ "rotary_pct": 0.25,
155
+ "rotary_scale_base": 512,
156
+ "sep_token_id": null,
157
+ "suppress_tokens": null,
158
+ "task_specific_params": null,
159
+ "temperature": 1.0,
160
+ "tf_legacy_loss": false,
161
+ "tie_encoder_decoder": false,
162
+ "tie_word_embeddings": false,
163
+ "tokenizer_class": null,
164
+ "top_k": 50,
165
+ "top_p": 1.0,
166
+ "torch_dtype": "float32",
167
+ "torchscript": false,
168
+ "transformers_version": "4.31.0",
169
+ "typical_p": 1.0,
170
+ "use_bfloat16": false,
171
+ "use_bias_in_mlp": false,
172
+ "use_cache": true,
173
+ "use_parallel_residual": true,
174
+ "vocab_size": 65535
175
+ },
176
+ "tie_word_embeddings": false,
177
+ "torch_dtype": "float32",
178
+ "transformers_version": null,
179
+ "use_decoder_only_language_model": true,
180
+ "vision_config": {
181
+ "_name_or_path": "",
182
+ "add_cross_attention": false,
183
+ "architectures": null,
184
+ "attention_dropout": 0.0,
185
+ "bad_words_ids": null,
186
+ "begin_suppress_tokens": null,
187
+ "bos_token_id": null,
188
+ "chunk_size_feed_forward": 0,
189
+ "cross_attention_hidden_size": null,
190
+ "decoder_start_token_id": null,
191
+ "diversity_penalty": 0.0,
192
+ "do_sample": false,
193
+ "early_stopping": false,
194
+ "encoder_no_repeat_ngram_size": 0,
195
+ "eos_token_id": null,
196
+ "exponential_decay_length_penalty": null,
197
+ "finetuning_task": null,
198
+ "forced_bos_token_id": null,
199
+ "forced_eos_token_id": null,
200
+ "hidden_act": "gelu",
201
+ "hidden_size": 1408,
202
+ "id2label": {
203
+ "0": "LABEL_0",
204
+ "1": "LABEL_1"
205
+ },
206
+ "image_size": 224,
207
+ "initializer_range": 1e-10,
208
+ "intermediate_size": 6144,
209
+ "is_decoder": false,
210
+ "is_encoder_decoder": false,
211
+ "label2id": {
212
+ "LABEL_0": 0,
213
+ "LABEL_1": 1
214
+ },
215
+ "layer_norm_eps": 1e-06,
216
+ "length_penalty": 1.0,
217
+ "max_length": 20,
218
+ "min_length": 0,
219
+ "model_type": "instructblip_vision_model",
220
+ "no_repeat_ngram_size": 0,
221
+ "num_attention_heads": 16,
222
+ "num_beam_groups": 1,
223
+ "num_beams": 1,
224
+ "num_hidden_layers": 39,
225
+ "num_return_sequences": 1,
226
+ "output_attentions": false,
227
+ "output_hidden_states": false,
228
+ "output_scores": false,
229
+ "pad_token_id": null,
230
+ "patch_size": 14,
231
+ "prefix": null,
232
+ "problem_type": null,
233
+ "pruned_heads": {},
234
+ "qkv_bias": true,
235
+ "remove_invalid_values": false,
236
+ "repetition_penalty": 1.0,
237
+ "return_dict": true,
238
+ "return_dict_in_generate": false,
239
+ "sep_token_id": null,
240
+ "suppress_tokens": null,
241
+ "task_specific_params": null,
242
+ "temperature": 1.0,
243
+ "tf_legacy_loss": false,
244
+ "tie_encoder_decoder": false,
245
+ "tie_word_embeddings": true,
246
+ "tokenizer_class": null,
247
+ "top_k": 50,
248
+ "top_p": 1.0,
249
+ "torch_dtype": null,
250
+ "torchscript": false,
251
+ "transformers_version": "4.31.0",
252
+ "typical_p": 1.0,
253
+ "use_bfloat16": false
254
+ }
255
+ }
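
A minimal loading sketch for orientation (not repository code): the `auto_map` above means `AutoModelForVision2Seq` resolves to `JapaneseInstructBlipAlphaForConditionalGeneration` when `trust_remote_code=True` is passed. The image-processor settings come from `preprocessor_config.json` below; the tokenizer id is the one used in the docstring example of `modeling_japanese_stablelm_alpha.py`, and reusing it here is an assumption.

```python
import torch
from transformers import AutoModelForVision2Seq, BlipImageProcessor, LlamaTokenizer

repo = "stabilityai/japanese-instructblip-alpha"  # _name_or_path from this config

# trust_remote_code lets AutoModelForVision2Seq follow the auto_map entry above.
model = AutoModelForVision2Seq.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.float16)

# BlipImageProcessor per preprocessor_config.json (224x224, CLIP mean/std).
image_processor = BlipImageProcessor.from_pretrained(repo)

# Tokenizer id taken from the CausalLM docstring example below (assumed to apply here too).
tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
```
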
configuration_japanese_instructblip_alpha.py ADDED
@@ -0,0 +1,57 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Japanese InstructBLIP Alpha model configuration"""
16
+
17
+ from transformers import (
18
+ PretrainedConfig,
19
+ InstructBlipConfig,
20
+ InstructBlipVisionConfig,
21
+ InstructBlipQFormerConfig,
22
+ AutoConfig,
23
+ )
24
+ from transformers.utils import logging
25
+ from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class JapaneseInstructBlipAlphaConfig(InstructBlipConfig):
32
+ def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
33
+ PretrainedConfig.__init__(self, **kwargs)
34
+
35
+ if vision_config is None:
36
+ vision_config = {}
37
+ logger.info("vision_config is None. Initializing the InstructBlipVisionConfig with default values.")
38
+
39
+ if qformer_config is None:
40
+ qformer_config = {}
41
+ logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.")
42
+
43
+ if text_config is None:
44
+ text_config = {}
45
+ logger.info("text_config is None. Initializing the text config with default values (`JapaneseStableLMAlphaConfig`).")
46
+ self.vision_config = InstructBlipVisionConfig(**vision_config)
47
+ self.qformer_config = InstructBlipQFormerConfig(**qformer_config)
48
+ self.text_config = JapaneseStableLMAlphaConfig(**text_config)
49
+
50
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
51
+ self.is_encoder_decoder = self.text_config.is_encoder_decoder
52
+
53
+ self.num_query_tokens = num_query_tokens
54
+ self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
55
+ self.use_decoder_only_language_model = True
56
+ self.initializer_factor = 1.0
57
+ self.initializer_range = 0.02
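
As a quick illustration (assuming remote-code loading from this repo), the constructor above composes the three sub-configs and then copies the vision hidden size into the Q-Former's `encoder_hidden_size`, which is why `config.json` shows 1408 in both places:

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("stabilityai/japanese-instructblip-alpha", trust_remote_code=True)

# Propagated by __init__ above; values match config.json.
assert cfg.qformer_config.encoder_hidden_size == cfg.vision_config.hidden_size == 1408
assert cfg.num_query_tokens == 32
assert cfg.use_decoder_only_language_model is True
```
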
configuration_japanese_stablelm_alpha.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ JapaneseStableLMAlpha model configuration"""
16
+
17
+ from transformers import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ class JapaneseStableLMAlphaConfig(PretrainedConfig):
27
+ r"""
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ vocab_size (`int`, *optional*, defaults to 65536):
33
+ Vocabulary size of the JapaneseStableLMAlphaModel. Defines the number of different tokens that
34
+ can be represented by the `inputs_ids` passed when calling [`JapaneseStableLMAlphaModel`].
35
+ hidden_size (`int`, *optional*, defaults to 4096):
36
+ Dimension of the decoder layers and the pooler layer.
37
+ num_hidden_layers (`int`, *optional*, defaults to 32):
38
+ Number of hidden layers in the Transformer decoder.
39
+ num_attention_heads (`int`, *optional*, defaults to 32):
40
+ Number of attention heads for each attention layer in the Transformer decoder.
41
+ intermediate_size (`int`, *optional*, defaults to 16384):
42
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
43
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
44
+ The non-linear activation function (function or string).
45
+ rotary_pct (`float`, *optional*, defaults to 0.25):
46
+ Percentage of hidden dimensions to allocate to rotary embeddings.
47
+ rotary_emb_base (`int`, *optional*, defaults to 10000):
48
+ Base for computing rotary embeddings frequency.
49
+ rotary_scale_base (`int`, *optional*, defaults to 512):
50
+ Base `scale` for computing XPos rotary embeddings scale.
51
+ classifier_dropout (`float`, *optional*, defaults to 0.1):
52
+ Argument used when doing token classification, used in the model
53
+ [`StableLMForTokenClassification`]. The dropout ratio for the hidden layer.
54
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
55
+ The maximum sequence length that this model might ever be used with.
56
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
57
+ initializer_range (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing
59
+ all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
61
+ The epsilon used by the layer normalization layers.
62
+ use_cache (`bool`, *optional*, defaults to `True`):
63
+ Whether or not the model should return the last key/values attentions
64
+ (not used by all models). Only relevant if `config.is_decoder=True`.
65
+ use_parallel_residual (`bool`, *optional*, defaults to `True`):
66
+ Whether to use a "parallel" formulation in each Transformer layer,
67
+ which can provide a slight training speedup at large scales.
68
+ Example:
69
+
70
+ ```python
71
+ >>> from transformers import JapaneseStableLMAlphaConfig, JapaneseStableLMAlphaModel
72
+
73
+ >>> # Initializing a JapaneseStableLMAlpha style configuration
74
+ >>> configuration = JapaneseStableLMAlphaConfig()
75
+
76
+ >>> # Initializing a model (with random weights) from the style configuration
77
+ >>> model = JapaneseStableLMAlphaModel(configuration) # doctest: +SKIP
78
+
79
+ >>> # Accessing the model configuration
80
+ >>> configuration = model.config # doctest: +SKIP
81
+ ```"""
82
+ def __init__(
83
+ self,
84
+ vocab_size=65536,
85
+ hidden_size=4096,
86
+ num_hidden_layers=32,
87
+ num_attention_heads=32,
88
+ hidden_act="silu",
89
+ rotary_pct=0.25,
90
+ rotary_emb_base=10000,
91
+ rotary_scale_base=512,
92
+ classifier_dropout=0.1,
93
+ max_position_embeddings=2048,
94
+ initializer_range=0.02,
95
+ layer_norm_eps=1e-5,
96
+ use_cache=True,
97
+ bos_token_id=3,
98
+ eos_token_id=3,
99
+ tie_word_embeddings=False,
100
+ use_parallel_residual=True,
101
+ use_bias_in_mlp=True,
102
+ **kwargs,
103
+ ):
104
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
105
+ self.vocab_size = vocab_size
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.hidden_size = hidden_size
108
+ self.num_hidden_layers = num_hidden_layers
109
+ self.num_attention_heads = num_attention_heads
110
+ self.hidden_act = hidden_act
111
+ self.rotary_pct = rotary_pct
112
+ self.rotary_emb_base = rotary_emb_base
113
+ self.rotary_scale_base = rotary_scale_base
114
+ self.classifier_dropout = classifier_dropout
115
+ self.initializer_range = initializer_range
116
+ self.layer_norm_eps = layer_norm_eps
117
+ self.use_cache = use_cache
118
+ self.tie_word_embeddings = tie_word_embeddings
119
+ self.use_parallel_residual = use_parallel_residual
120
+ self.use_bias_in_mlp = use_bias_in_mlp
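
For orientation, the defaults above imply the following attention-head geometry in `modeling_japanese_stablelm_alpha.py` (plain arithmetic, shown here as a sketch): each of the 32 heads is 128-dimensional, and `rotary_pct=0.25` applies the XPos rotary embedding to the first 32 dimensions of every head.

```python
hidden_size = 4096
num_attention_heads = 32
rotary_pct = 0.25

head_size = hidden_size // num_attention_heads  # 128; Attention.__init__ requires an even split
rotary_ndims = int(head_size * rotary_pct)      # 32 dimensions per head receive rotary embeddings
print(head_size, rotary_ndims)                  # 128 32
```
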
modeling_japanese_instructblip_alpha.py ADDED
@@ -0,0 +1,62 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Japanese InstructBLIP Alpha model. """
16
+ import torch
17
+ from torch import nn
18
+ from transformers import (
19
+ InstructBlipPreTrainedModel,
20
+ InstructBlipVisionModel,
21
+ InstructBlipQFormerModel,
22
+ InstructBlipForConditionalGeneration,
23
+ AutoModelForCausalLM,
24
+ AutoModelForSeq2SeqLM,
25
+ )
26
+ from transformers.utils import logging
27
+ from .modeling_japanese_stablelm_alpha import JapaneseStableLMAlphaForCausalLM
28
+ from .configuration_japanese_instructblip_alpha import JapaneseInstructBlipAlphaConfig
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class JapaneseInstructBlipAlphaForConditionalGeneration(InstructBlipForConditionalGeneration):
35
+ config_class = JapaneseInstructBlipAlphaConfig
36
+
37
+ def __init__(self, config: JapaneseInstructBlipAlphaConfig):
38
+ InstructBlipPreTrainedModel.__init__(self, config)
39
+
40
+ self.vision_model = InstructBlipVisionModel(config.vision_config)
41
+
42
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
43
+ self.qformer = InstructBlipQFormerModel(config.qformer_config)
44
+
45
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
46
+
47
+ if config.use_decoder_only_language_model:
48
+ language_model = JapaneseStableLMAlphaForCausalLM(config.text_config)
49
+ else:
50
+ raise NotImplementedError
51
+ language_model = AutoModelForSeq2SeqLM.from_config(config.text_config, trust_remote_code=True,)
52
+
53
+ if language_model._no_split_modules is not None:
54
+ self._no_split_modules.extend(language_model._no_split_modules)
55
+
56
+ if language_model._keep_in_fp32_modules is not None:
57
+ self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
58
+
59
+ self.language_model = language_model
60
+
61
+ # Initialize weights and apply final processing
62
+ self.post_init()
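
A shape-only sketch of the wiring above (hypothetical tensors; sizes taken from `config.json`): the ViT yields 1408-dimensional patch features, the 32 learned query tokens come out of the Q-Former at width 768, and `language_projection` lifts them into the 4096-dimensional embedding space of the StableLM decoder, where they act as a visual prefix to the text embeddings.

```python
import torch
from torch import nn

batch = 1
num_patches = (224 // 14) ** 2 + 1   # 256 patches plus a class token = 257
vision_hidden, qformer_hidden, text_hidden, num_query_tokens = 1408, 768, 4096, 32

image_features = torch.randn(batch, num_patches, vision_hidden)      # stand-in for the vision tower output
query_output = torch.randn(batch, num_query_tokens, qformer_hidden)  # stand-in for the Q-Former output

language_projection = nn.Linear(qformer_hidden, text_hidden)
visual_prefix = language_projection(query_output)                    # (1, 32, 4096)

text_embeds = torch.randn(batch, 10, text_hidden)                    # embedded prompt tokens
lm_inputs = torch.cat([visual_prefix, text_embeds], dim=1)
print(lm_inputs.shape)                                               # torch.Size([1, 42, 4096])
```
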
modeling_japanese_stablelm_alpha.py ADDED
@@ -0,0 +1,682 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch JapaneseStableLMAlpha model. """
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+ from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
35
+ """
36
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
37
+ models.
38
+ """
39
+
40
+ config_class = JapaneseStableLMAlphaConfig
41
+ base_model_prefix = "transformer"
42
+ supports_gradient_checkpointing = True
43
+ _no_split_modules = ["DecoderLayer"]
44
+ _skip_keys_device_placement = "past_key_values"
45
+
46
+ def _init_weights(self, module):
47
+ """Initialize the weights"""
48
+ if isinstance(module, nn.Linear):
49
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
50
+ if module.bias is not None:
51
+ module.bias.data.zero_()
52
+ elif isinstance(module, nn.Embedding):
53
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
54
+ if module.padding_idx is not None:
55
+ module.weight.data[module.padding_idx].zero_()
56
+ elif isinstance(module, nn.LayerNorm):
57
+ if module.bias is not None:
58
+ module.bias.data.zero_()
59
+ if module.weight is not None:
60
+ module.weight.data.fill_(1.0)
61
+
62
+ def _set_gradient_checkpointing(self, module, value=False):
63
+ if isinstance(module, JapaneseStableLMAlphaModel):
64
+ module.gradient_checkpointing = value
65
+
66
+
67
+ class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
68
+ def __init__(self, config):
69
+ super().__init__(config)
70
+ self.config = config
71
+
72
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
73
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
74
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
75
+
76
+ self.gradient_checkpointing = False
77
+
78
+ # Initialize weights and apply final processing
79
+ self.post_init()
80
+
81
+ def get_input_embeddings(self):
82
+ return self.embed_in
83
+
84
+ def set_input_embeddings(self, value):
85
+ self.embed_in = value
86
+
87
+ def forward(
88
+ self,
89
+ input_ids: Optional[torch.LongTensor] = None,
90
+ attention_mask: Optional[torch.FloatTensor] = None,
91
+ position_ids: Optional[torch.LongTensor] = None,
92
+ head_mask: Optional[torch.FloatTensor] = None,
93
+ inputs_embeds: Optional[torch.FloatTensor] = None,
94
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
95
+ use_cache: Optional[bool] = None,
96
+ output_attentions: Optional[bool] = None,
97
+ output_hidden_states: Optional[bool] = None,
98
+ return_dict: Optional[bool] = None,
99
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
100
+ r"""
101
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
102
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
103
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
104
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
105
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
106
+ use_cache (`bool`, *optional*):
107
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
108
+ `past_key_values`).
109
+ """
110
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
111
+ output_hidden_states = (
112
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
113
+ )
114
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
115
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
116
+
117
+ if input_ids is not None and inputs_embeds is not None:
118
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
119
+ elif input_ids is not None:
120
+ input_shape = input_ids.size()
121
+ elif inputs_embeds is not None:
122
+ input_shape = inputs_embeds.size()[:-1]
123
+ else:
124
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
125
+
126
+ batch_size, seq_length = input_shape
127
+
128
+ if past_key_values is None:
129
+ past_length = 0
130
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
131
+ else:
132
+ past_length = past_key_values[0][0].size(-2)
133
+
134
+ if position_ids is None:
135
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
136
+ position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
137
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
138
+ else:
139
+ position_ids = position_ids.view(-1, seq_length).long()
140
+
141
+ # Attention mask.
142
+ if attention_mask is not None:
143
+ assert batch_size > 0, "batch_size has to be defined and > 0"
144
+ attention_mask = attention_mask.view(batch_size, -1)
145
+ # We create a 3D attention mask from a 2D tensor mask.
146
+ # Sizes are [batch_size, 1, 1, to_seq_length]
147
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
148
+ # this attention mask is simpler than the triangular masking of causal attention
149
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
150
+ attention_mask = attention_mask[:, None, None, :]
151
+
152
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
153
+ # masked positions, this operation will create a tensor which is 0.0 for
154
+ # positions we want to attend and the dtype's smallest value for masked positions.
155
+ # Since we are adding it to the raw scores before the softmax, this is
156
+ # effectively the same as removing these entirely.
157
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
158
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
159
+
160
+ # Prepare head mask if needed
161
+ # 1.0 in head_mask indicate we keep the head
162
+ # attention_probs has shape bsz x n_heads x N x N
163
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
164
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
165
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
166
+
167
+ if inputs_embeds is None:
168
+ inputs_embeds = self.embed_in(input_ids)
169
+
170
+ hidden_states = inputs_embeds
171
+
172
+ if self.gradient_checkpointing and self.training:
173
+ if use_cache:
174
+ logger.warning(
175
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
176
+ )
177
+ use_cache = False
178
+
179
+ presents = () if use_cache else None
180
+ all_attentions = () if output_attentions else None
181
+ all_hidden_states = () if output_hidden_states else None
182
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
183
+ if output_hidden_states:
184
+ all_hidden_states = all_hidden_states + (hidden_states,)
185
+
186
+ if self.gradient_checkpointing and self.training:
187
+
188
+ def create_custom_forward(module):
189
+ def custom_forward(*inputs):
190
+ # None for layer_past
191
+ return module(*inputs, use_cache, None, output_attentions)
192
+
193
+ return custom_forward
194
+
195
+ outputs = torch.utils.checkpoint.checkpoint(
196
+ create_custom_forward(layer),
197
+ hidden_states,
198
+ attention_mask,
199
+ position_ids,
200
+ head_mask[i],
201
+ )
202
+ else:
203
+ outputs = layer(
204
+ hidden_states,
205
+ attention_mask=attention_mask,
206
+ position_ids=position_ids,
207
+ head_mask=head_mask[i],
208
+ layer_past=layer_past,
209
+ use_cache=use_cache,
210
+ output_attentions=output_attentions,
211
+ )
212
+ hidden_states = outputs[0]
213
+ if use_cache is True:
214
+ presents = presents + (outputs[1],)
215
+ if output_attentions:
216
+ all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
217
+
218
+ hidden_states = self.final_layer_norm(hidden_states)
219
+ # Add last hidden state
220
+ if output_hidden_states:
221
+ all_hidden_states = all_hidden_states + (hidden_states,)
222
+
223
+ if not return_dict:
224
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
225
+
226
+ return BaseModelOutputWithPast(
227
+ last_hidden_state=hidden_states,
228
+ past_key_values=presents,
229
+ hidden_states=all_hidden_states,
230
+ attentions=all_attentions,
231
+ )
232
+
233
+
234
+ class DecoderLayer(nn.Module):
235
+ def __init__(self, config):
236
+ super().__init__()
237
+ self.use_parallel_residual = config.use_parallel_residual
238
+ self.input_layernorm = nn.LayerNorm(
239
+ config.hidden_size,
240
+ eps=config.layer_norm_eps,
241
+ elementwise_affine=False,
242
+ )
243
+ self.post_attention_layernorm = nn.LayerNorm(
244
+ config.hidden_size,
245
+ eps=config.layer_norm_eps
246
+ )
247
+ self.attention = Attention(config)
248
+ self.mlp = MLP(config)
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states: Optional[torch.FloatTensor],
253
+ attention_mask: Optional[torch.FloatTensor] = None,
254
+ position_ids: Optional[torch.LongTensor] = None,
255
+ head_mask: Optional[torch.FloatTensor] = None,
256
+ use_cache: Optional[bool] = False,
257
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
258
+ output_attentions: Optional[bool] = False,
259
+ ):
260
+ attention_layer_outputs = self.attention(
261
+ self.input_layernorm(hidden_states),
262
+ attention_mask=attention_mask,
263
+ position_ids=position_ids,
264
+ layer_past=layer_past,
265
+ head_mask=head_mask,
266
+ use_cache=use_cache,
267
+ output_attentions=output_attentions,
268
+ )
269
+ attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
270
+ outputs = attention_layer_outputs[1:]
271
+
272
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
273
+ hidden_states = hidden_states + mlp_output + attn_output
274
+
275
+ if use_cache:
276
+ outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
277
+ else:
278
+ outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
279
+
280
+ return outputs
281
+
282
+
283
+ class MLP(nn.Module):
284
+ def __init__(self, config: JapaneseStableLMAlphaConfig):
285
+ super().__init__()
286
+ hidden_size = config.hidden_size
287
+ multiple_of = 256
288
+ ff_dim = int(8 * hidden_size / 3)
289
+ intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
290
+
291
+ self.packed_input_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
292
+ self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
293
+ self.act = nn.SiLU()
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ ff, ff_gate = self.packed_input_proj(x).chunk(2, dim=-1)
297
+ return self.out_proj(ff * self.act(ff_gate))
298
+
299
+
300
+ class RotaryEmbedding(torch.nn.Module):
301
+ """Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py"""
302
+ def __init__(
303
+ self,
304
+ dim: int,
305
+ max_position_embeddings: int,
306
+ base: int = 10_000,
307
+ scale_base: int = 512,
308
+ device: str = None
309
+ ):
310
+ super().__init__()
311
+ self.dim = dim
312
+ self.seq_len_cached = max_position_embeddings
313
+
314
+ # Set up `inv_freq` term
315
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
316
+ self.register_buffer("inv_freq", inv_freq)
317
+
318
+ # Set up `scale` term
319
+ self.scale_base = scale_base
320
+ scale = (
321
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
322
+ if scale_base is not None else None
323
+ )
324
+ self.register_buffer("scale", scale)
325
+
326
+ # Set up `cos...` and `sin...` cache terms
327
+ t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32)
328
+ freqs = torch.outer(t, self.inv_freq)
329
+ # freqs = torch.cat((freqs, freqs), dim=-1)
330
+ seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
331
+ power = (seq_range - self.seq_len_cached // 2) / self.scale_base
332
+ scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
333
+ # scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
334
+ self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
335
+ self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
336
+ self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
337
+ self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
338
+
339
+ def forward(self, x, seq_len=None):
340
+ if seq_len > self.seq_len_cached:
341
+ self.seq_len_cached = seq_len
342
+ t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
343
+ freqs = torch.outer(t, self.inv_freq)
344
+ freqs = torch.cat((freqs, freqs), dim=-1)
345
+ seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
346
+ power = (seq_range - self.seq_len_cached // 2) / self.scale_base
347
+ scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
348
+ scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
349
+ self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
350
+ self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
351
+ self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
352
+ self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
353
+ return (
354
+ self.cos_cached[:seq_len, ...],
355
+ self.sin_cached[:seq_len, ...],
356
+ self.cos_k_cached[:seq_len, ...],
357
+ self.sin_k_cached[:seq_len, ...],
358
+ )
359
+
360
+
361
+ def rotate_half(x):
362
+ x1, x2 = x.chunk(2, dim=-1)
363
+ return torch.cat((-x2, x1), dim=-1)
364
+
365
+
366
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
367
+ """
368
+ q, k: [bs, num_heads, seq_len, rot_dim]
369
+ cos, sin: [seq_len, rot_dim / 2]
370
+ position_ids: [bs, seq_len]
371
+ """
372
+ # print(f"q: {q.shape}, k: {k.shape}, cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids.shape}")
373
+ import einops
374
+ cos = einops.repeat(cos, 's r -> s (2 r)')
375
+ sin = einops.repeat(sin, 's r -> s (2 r)')
376
+ cos_k = einops.repeat(cos_k, 's r -> s (2 r)')
377
+ sin_k = einops.repeat(sin_k, 's r -> s (2 r)')
378
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
379
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
380
+ cos_k = cos_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
381
+ sin_k = sin_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
382
+
383
+ q_embed = (q * cos) + (rotate_half(q) * sin)
384
+ k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
385
+ return q_embed, k_embed
386
+
387
+
388
+ class Attention(nn.Module):
389
+ def __init__(self, config):
390
+ super().__init__()
391
+ self.num_attention_heads = config.num_attention_heads
392
+ self.hidden_size = config.hidden_size
393
+ if self.hidden_size % self.num_attention_heads != 0:
394
+ raise ValueError(
395
+ "The hidden size is not divisible by the number of attention heads! Make sure to update them"
396
+ )
397
+ self.head_size = self.hidden_size // self.num_attention_heads
398
+
399
+ max_positions = config.max_position_embeddings
400
+ self.register_buffer(
401
+ "bias",
402
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
403
+ 1, 1, max_positions, max_positions
404
+ ),
405
+ persistent=False,
406
+ )
407
+ self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
408
+
409
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
410
+ self.rotary_emb = RotaryEmbedding(
411
+ self.rotary_ndims,
412
+ max_position_embeddings=config.max_position_embeddings,
413
+ base=config.rotary_emb_base,
414
+ scale_base=config.rotary_scale_base,
415
+ )
416
+
417
+ self.register_buffer(
418
+ "norm_factor",
419
+ torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
420
+ persistent=False,
421
+ )
422
+
423
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
424
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states: torch.FloatTensor,
429
+ attention_mask: torch.FloatTensor,
430
+ position_ids: torch.LongTensor,
431
+ head_mask: Optional[torch.FloatTensor] = None,
432
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
433
+ use_cache: Optional[bool] = False,
434
+ output_attentions: Optional[bool] = False,
435
+ ):
436
+ has_layer_past = layer_past is not None
437
+
438
+ # Compute QKV
439
+ # Attention heads [batch, seq_len, hidden_size]
440
+ # --> [batch, seq_len, (np * 3 * head_size)]
441
+ qkv = self.query_key_value(hidden_states)
442
+
443
+ # [batch, seq_len, (num_heads * 3 * head_size)]
444
+ # --> [batch, seq_len, num_heads, 3 * head_size]
445
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
446
+ qkv = qkv.view(*new_qkv_shape)
447
+
448
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
449
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
450
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
451
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
452
+
453
+ # Compute rotary embeddings on rotary_ndims
454
+ query_rot = query[..., : self.rotary_ndims]
455
+ query_pass = query[..., self.rotary_ndims :]
456
+ key_rot = key[..., : self.rotary_ndims]
457
+ key_pass = key[..., self.rotary_ndims :]
458
+
459
+ # Compute token offset for rotary embeddings (when decoding)
460
+ kv_seq_len = key.shape[-2]
461
+ if has_layer_past:
462
+ kv_seq_len += layer_past[0].shape[-2]
463
+
464
+ # Add rotary embeddings to query and key
465
+ # TODO: Check if using xpos
466
+ cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
467
+ query, key = apply_rotary_pos_emb(
468
+ query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k)
469
+
470
+ query = torch.cat((query, query_pass), dim=-1)
471
+ key = torch.cat((key, key_pass), dim=-1)
472
+
473
+ # Cache QKV values
474
+ if has_layer_past:
475
+ past_key = layer_past[0]
476
+ past_value = layer_past[1]
477
+ key = torch.cat((past_key, key), dim=-2)
478
+ value = torch.cat((past_value, value), dim=-2)
479
+ present = (key, value) if use_cache else None
480
+
481
+ # Compute attention
482
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
483
+
484
+ # Merge attn_head_size dim and num_attn_heads dim into hidden dim
485
+ # [bs, seq_len, num_attention_heads, attn_head_size]
486
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
487
+ attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size)
488
+
489
+ attn_output = self.dense(attn_output)
490
+
491
+ outputs = (attn_output, present)
492
+ if output_attentions:
493
+ outputs += (attn_weights,)
494
+
495
+ return outputs
496
+
497
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
498
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
499
+ # compute causal mask from causal mask buffer
500
+
501
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
502
+ key_length = key.size(-2)
503
+
504
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
505
+
506
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
507
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
508
+ attn_scores = torch.zeros(
509
+ batch_size * num_attention_heads,
510
+ query_length,
511
+ key_length,
512
+ dtype=query.dtype,
513
+ device=key.device,
514
+ )
515
+ attn_scores = torch.baddbmm(
516
+ attn_scores,
517
+ query,
518
+ key.transpose(1, 2),
519
+ beta=1.0,
520
+ alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
521
+ )
522
+ attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
523
+
524
+ mask_value = torch.finfo(attn_scores.dtype).min
525
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
526
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
527
+ mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
528
+ attn_scores = torch.where(causal_mask, attn_scores, mask_value)
529
+
530
+ if attention_mask is not None:
531
+ # Apply the attention mask
532
+ attn_scores = attn_scores + attention_mask
533
+
534
+ # NOTE: Upcast to float32
535
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value)
536
+
537
+ # Mask heads if we want to
538
+ if head_mask is not None:
539
+ attn_weights = attn_weights * head_mask
540
+
541
+ attn_output = torch.matmul(attn_weights, value)
542
+ return attn_output, attn_weights
543
+
544
+
545
+ def attention_mask_func(attention_scores, ltor_mask):
546
+ attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
547
+ return attention_scores
548
+
549
+
550
+ class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
551
+ _tied_weights_keys = ["embed_out.weight"]
552
+
553
+ def __init__(self, config):
554
+ super().__init__(config)
555
+
556
+ self.transformer = JapaneseStableLMAlphaModel(config)
557
+ self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
558
+
559
+ # Initialize weights and apply final processing
560
+ self.post_init()
561
+
562
+ def get_output_embeddings(self):
563
+ return self.embed_out
564
+
565
+ def set_output_embeddings(self, new_embeddings):
566
+ self.embed_out = new_embeddings
567
+
568
+ def forward(
569
+ self,
570
+ input_ids: Optional[torch.LongTensor] = None,
571
+ attention_mask: Optional[torch.FloatTensor] = None,
572
+ position_ids: Optional[torch.LongTensor] = None,
573
+ inputs_embeds: Optional[torch.FloatTensor] = None,
574
+ head_mask: Optional[torch.FloatTensor] = None,
575
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
576
+ labels: Optional[torch.LongTensor] = None,
577
+ use_cache: Optional[bool] = None,
578
+ output_attentions: Optional[bool] = None,
579
+ output_hidden_states: Optional[bool] = None,
580
+ return_dict: Optional[bool] = None,
581
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
582
+ r"""
583
+ Example:
584
+
585
+ ```python
586
+ >>> import torch
587
+ >>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig
588
+
589
+ >>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
590
+ >>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
591
+ >>> config.is_decoder = True
592
+ >>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)
593
+
594
+ >>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
595
+ >>> outputs = model(**inputs)
596
+
597
+ >>> prediction_logits = outputs.logits
598
+ ```"""
599
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
600
+
601
+ outputs = self.transformer(
602
+ input_ids,
603
+ attention_mask=attention_mask,
604
+ position_ids=position_ids,
605
+ head_mask=head_mask,
606
+ inputs_embeds=inputs_embeds,
607
+ past_key_values=past_key_values,
608
+ use_cache=use_cache,
609
+ output_attentions=output_attentions,
610
+ output_hidden_states=output_hidden_states,
611
+ return_dict=return_dict,
612
+ )
613
+
614
+ hidden_states = outputs[0]
615
+ lm_logits = self.embed_out(hidden_states)
616
+
617
+ lm_loss = None
618
+ if labels is not None:
619
+ # move labels to correct device to enable model parallelism
620
+ labels = labels.to(lm_logits.device)
621
+ # we are doing next-token prediction; shift prediction scores and input ids by one
622
+ shift_logits = lm_logits[:, :-1, :].contiguous()
623
+ labels = labels[:, 1:].contiguous()
624
+ loss_fct = CrossEntropyLoss()
625
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
626
+
627
+ if not return_dict:
628
+ output = (lm_logits,) + outputs[1:]
629
+ return ((lm_loss,) + output) if lm_loss is not None else output
630
+
631
+ return CausalLMOutputWithPast(
632
+ loss=lm_loss,
633
+ logits=lm_logits,
634
+ past_key_values=outputs.past_key_values,
635
+ hidden_states=outputs.hidden_states,
636
+ attentions=outputs.attentions,
637
+ )
638
+
639
+ def prepare_inputs_for_generation(
640
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
641
+ ):
642
+ input_shape = input_ids.shape
643
+
644
+ # cut decoder_input_ids if past is used
645
+ if past_key_values and past_key_values[0] is not None:
646
+ input_ids = input_ids[:, -1:]
647
+
648
+ position_ids = kwargs.get("position_ids", None)
649
+ if attention_mask is not None and position_ids is None:
650
+ # create position_ids on the fly for batch generation
651
+ position_ids = attention_mask.long().cumsum(-1) - 1
652
+ position_ids.masked_fill_(attention_mask == 0, 1)
653
+ if past_key_values:
654
+ position_ids = position_ids[:, -1].unsqueeze(-1)
655
+
656
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
657
+ if attention_mask is None:
658
+ attention_mask = input_ids.new_ones(input_shape)
659
+
660
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
661
+ if inputs_embeds is not None and past_key_values is None:
662
+ model_inputs = {"inputs_embeds": inputs_embeds}
663
+ else:
664
+ model_inputs = {"input_ids": input_ids}
665
+
666
+ model_inputs.update(
667
+ {
668
+ "attention_mask": attention_mask,
669
+ "past_key_values": past_key_values,
670
+ "position_ids": position_ids,
671
+ }
672
+ )
673
+
674
+ return model_inputs
675
+
676
+ def _reorder_cache(self, past_key_values, beam_idx):
677
+ reordered_past = ()
678
+ for layer_past in past_key_values:
679
+ reordered_past += (
680
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
681
+ )
682
+ return reordered_past
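
Two details of the decoder above, restated as a small sketch (toy modules, not the real ones): `MLP` derives its SwiGLU width from `hidden_size` instead of reading an `intermediate_size` from the config, and `DecoderLayer.forward` uses the parallel-residual form `hidden + MLP(LN2(hidden)) + Attn(LN1(hidden))`.

```python
import torch
from torch import nn

# SwiGLU width used by MLP above: round 8*hidden/3 up to a multiple of 256.
hidden_size, multiple_of = 4096, 256
ff_dim = int(8 * hidden_size / 3)                                             # 10922
intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
print(intermediate_size)                                                      # 11008

# Parallel residual with placeholder branches standing in for Attention and MLP.
ln1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
ln2 = nn.LayerNorm(hidden_size)
attn_branch = nn.Linear(hidden_size, hidden_size)
mlp_branch = nn.Linear(hidden_size, hidden_size)

x = torch.randn(1, 5, hidden_size)
out = x + mlp_branch(ln2(x)) + attn_branch(ln1(x))  # both branches read the same input; one residual add
```
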
preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.48145466,
8
+ 0.4578275,
9
+ 0.40821073
10
+ ],
11
+ "image_processor_type": "BlipImageProcessor",
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "processor_class": "InstructBlipProcessor",
18
+ "resample": 3,
19
+ "rescale_factor": 0.00392156862745098,
20
+ "size": {
21
+ "height": 224,
22
+ "width": 224
23
+ }
24
+ }
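
These are the standard CLIP image statistics. As a rough illustration of what the flags amount to (not the processor's actual code path), each image is resized to 224×224, scaled by `rescale_factor` (1/255), and normalized channel-wise:

```python
import numpy as np

image = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)  # toy RGB image, already 224x224

rescale_factor = 1 / 255  # 0.00392156862745098
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])

pixel_values = (image * rescale_factor - mean) / std  # do_rescale + do_normalize
print(pixel_values.shape)  # (224, 224, 3)
```
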
pytorch_model.bin.index.fp16.json ADDED
The diff for this file is too large to render. See raw diff
pytorch_model.fp16-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edb6de9b5fc91eec3b6a57d6dacd14073001f37fa860e68775beaf3afb7d79dc
3
+ size 9955835753
pytorch_model.fp16-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:528041e8ff85ea59f4ba0508efc852ef55c9fe15cedd2956c6f79ee8630f81f6
3
+ size 6474190985
requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ sentencepiece
2
+ einops