feipengma commited on
Commit
00b2b6c
1 Parent(s): 0ecb99d
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "WemmForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_wemm.WeMMConfig",
7
+ "AutoModel": "modeling_wemm.WemmForConditionalGeneration"
8
+ },
9
+ "connector_config": {
10
+ "_name_or_path": "",
11
+ "add_cross_attention": false,
12
+ "architectures": null,
13
+ "attention_dropout": 0.0,
14
+ "bad_words_ids": null,
15
+ "begin_suppress_tokens": null,
16
+ "bos_token_id": null,
17
+ "chunk_size_feed_forward": 0,
18
+ "cross_attention_hidden_size": null,
19
+ "decoder_start_token_id": null,
20
+ "diversity_penalty": 0.0,
21
+ "do_sample": false,
22
+ "early_stopping": false,
23
+ "encoder_no_repeat_ngram_size": 0,
24
+ "eos_token_id": null,
25
+ "exponential_decay_length_penalty": null,
26
+ "finetuning_task": null,
27
+ "forced_bos_token_id": null,
28
+ "forced_eos_token_id": null,
29
+ "hidden_act": "silu",
30
+ "hidden_size": 4096,
31
+ "id2label": {
32
+ "0": "LABEL_0",
33
+ "1": "LABEL_1"
34
+ },
35
+ "integrate_sub_images": null,
36
+ "intermediate_size": 14336,
37
+ "is_decoder": false,
38
+ "is_encoder_decoder": false,
39
+ "label2id": {
40
+ "LABEL_0": 0,
41
+ "LABEL_1": 1
42
+ },
43
+ "length_penalty": 1.0,
44
+ "max_length": 20,
45
+ "min_length": 0,
46
+ "model_type": "Idefics2ConnectorConfig",
47
+ "no_repeat_ngram_size": 0,
48
+ "num_beam_groups": 1,
49
+ "num_beams": 1,
50
+ "num_key_value_heads": 4,
51
+ "num_return_sequences": 1,
52
+ "num_sub_images": null,
53
+ "output_attentions": false,
54
+ "output_hidden_states": false,
55
+ "output_scores": false,
56
+ "pad_token_id": null,
57
+ "prefix": null,
58
+ "problem_type": null,
59
+ "pruned_heads": {},
60
+ "remove_invalid_values": false,
61
+ "repetition_penalty": 1.0,
62
+ "resampler_depth": 3,
63
+ "resampler_head_dim": 96,
64
+ "resampler_n_heads": 16,
65
+ "resampler_n_latents": 64,
66
+ "return_dict": true,
67
+ "return_dict_in_generate": false,
68
+ "rms_norm_eps": 1e-05,
69
+ "sep_token_id": null,
70
+ "suppress_tokens": null,
71
+ "task_specific_params": null,
72
+ "temperature": 1.0,
73
+ "tf_legacy_loss": false,
74
+ "tie_encoder_decoder": false,
75
+ "tie_word_embeddings": true,
76
+ "tokenizer_class": null,
77
+ "top_k": 50,
78
+ "top_p": 1.0,
79
+ "torch_dtype": null,
80
+ "torchscript": false,
81
+ "typical_p": 1.0,
82
+ "use_bfloat16": false,
83
+ "vision_hidden_size": 1152
84
+ },
85
+ "do_image_splitting": false,
86
+ "downsampler_config": {
87
+ "_name_or_path": "",
88
+ "add_cross_attention": false,
89
+ "architectures": [
90
+ "DownsamplerModel"
91
+ ],
92
+ "auto_map": {
93
+ "AutoConfig": "configuration_downsampler.DownsamplerConfig",
94
+ "AutoModel": "modeling_downsampler.DownsamplerModel"
95
+ },
96
+ "bad_words_ids": null,
97
+ "begin_suppress_tokens": null,
98
+ "bias": false,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "cross_attention_hidden_size": null,
102
+ "decoder_start_token_id": null,
103
+ "depth": 2,
104
+ "diversity_penalty": 0.0,
105
+ "do_sample": false,
106
+ "early_stopping": false,
107
+ "encoder_no_repeat_ngram_size": 0,
108
+ "eos_token_id": null,
109
+ "exponential_decay_length_penalty": null,
110
+ "finetuning_task": null,
111
+ "forced_bos_token_id": null,
112
+ "forced_eos_token_id": null,
113
+ "hidden_act": "gelu",
114
+ "id2label": {
115
+ "0": "LABEL_0",
116
+ "1": "LABEL_1"
117
+ },
118
+ "is_decoder": false,
119
+ "is_encoder_decoder": false,
120
+ "kernel_size": 4,
121
+ "label2id": {
122
+ "LABEL_0": 0,
123
+ "LABEL_1": 1
124
+ },
125
+ "length_penalty": 1.0,
126
+ "llm_hidden_size": 4096,
127
+ "max_length": 20,
128
+ "min_length": 0,
129
+ "model_type": "downsampler",
130
+ "no_repeat_ngram_size": 0,
131
+ "num_beam_groups": 1,
132
+ "num_beams": 1,
133
+ "num_return_sequences": 1,
134
+ "output_attentions": false,
135
+ "output_hidden_states": false,
136
+ "output_scores": false,
137
+ "pad_token_id": null,
138
+ "prefix": null,
139
+ "problem_type": null,
140
+ "pruned_heads": {},
141
+ "remove_invalid_values": false,
142
+ "repetition_penalty": 1.0,
143
+ "return_dict": true,
144
+ "return_dict_in_generate": false,
145
+ "sep_token_id": null,
146
+ "stride": 4,
147
+ "suppress_tokens": null,
148
+ "task_specific_params": null,
149
+ "temperature": 1.0,
150
+ "tf_legacy_loss": false,
151
+ "tie_encoder_decoder": false,
152
+ "tie_word_embeddings": true,
153
+ "tokenizer_class": null,
154
+ "top_k": 50,
155
+ "top_p": 1.0,
156
+ "torch_dtype": "float32",
157
+ "torchscript": false,
158
+ "typical_p": 1.0,
159
+ "use_bfloat16": false,
160
+ "visual_hidden_size": 1152
161
+ },
162
+ "image_processor": {
163
+ "do_convert_rgb": true,
164
+ "do_image_splitting": true,
165
+ "do_normalize": true,
166
+ "do_pad": true,
167
+ "do_rescale": true,
168
+ "do_resize": true,
169
+ "image_mean": [
170
+ 0.5,
171
+ 0.5,
172
+ 0.5
173
+ ],
174
+ "image_processor_type": "Idefics2ImageProcessor",
175
+ "image_std": [
176
+ 0.5,
177
+ 0.5,
178
+ 0.5
179
+ ],
180
+ "resample": 2,
181
+ "rescale_factor": 0.00392156862745098,
182
+ "size": {
183
+ "longest_edge": 1960,
184
+ "shortest_edge": 756
185
+ }
186
+ },
187
+ "model_type": "wemm_hf",
188
+ "projector_config": {
189
+ "_name_or_path": "",
190
+ "add_cross_attention": false,
191
+ "architectures": [
192
+ "ProjectorModel"
193
+ ],
194
+ "auto_map": {
195
+ "AutoConfig": "configuration_projector.ProjectorConfig",
196
+ "AutoModel": "modeling_projector.ProjectorModel"
197
+ },
198
+ "bad_words_ids": null,
199
+ "begin_suppress_tokens": null,
200
+ "bias": true,
201
+ "bos_token_id": null,
202
+ "chunk_size_feed_forward": 0,
203
+ "cross_attention_hidden_size": null,
204
+ "decoder_start_token_id": null,
205
+ "depth": 2,
206
+ "diversity_penalty": 0.0,
207
+ "do_sample": false,
208
+ "early_stopping": false,
209
+ "encoder_no_repeat_ngram_size": 0,
210
+ "eos_token_id": null,
211
+ "exponential_decay_length_penalty": null,
212
+ "finetuning_task": null,
213
+ "forced_bos_token_id": null,
214
+ "forced_eos_token_id": null,
215
+ "hidden_act": "gelu",
216
+ "id2label": {
217
+ "0": "LABEL_0",
218
+ "1": "LABEL_1"
219
+ },
220
+ "is_decoder": false,
221
+ "is_encoder_decoder": false,
222
+ "label2id": {
223
+ "LABEL_0": 0,
224
+ "LABEL_1": 1
225
+ },
226
+ "length_penalty": 1.0,
227
+ "llm_hidden_size": 4096,
228
+ "max_length": 20,
229
+ "min_length": 0,
230
+ "model_type": "projector",
231
+ "no_repeat_ngram_size": 0,
232
+ "num_beam_groups": 1,
233
+ "num_beams": 1,
234
+ "num_return_sequences": 1,
235
+ "output_attentions": false,
236
+ "output_hidden_states": false,
237
+ "output_scores": false,
238
+ "pad_token_id": null,
239
+ "prefix": null,
240
+ "problem_type": null,
241
+ "pruned_heads": {},
242
+ "remove_invalid_values": false,
243
+ "repetition_penalty": 1.0,
244
+ "return_dict": true,
245
+ "return_dict_in_generate": false,
246
+ "sep_token_id": null,
247
+ "suppress_tokens": null,
248
+ "task_specific_params": null,
249
+ "temperature": 1.0,
250
+ "tf_legacy_loss": false,
251
+ "tie_encoder_decoder": false,
252
+ "tie_word_embeddings": true,
253
+ "tokenizer_class": null,
254
+ "top_k": 50,
255
+ "top_p": 1.0,
256
+ "torch_dtype": "float32",
257
+ "torchscript": false,
258
+ "typical_p": 1.0,
259
+ "use_bfloat16": false,
260
+ "visual_hidden_size": 1152
261
+ },
262
+ "spliter_emb_config": {
263
+ "embedding_dim": 4096,
264
+ "num_embeddings": 2
265
+ },
266
+ "text_config": {
267
+ "_name_or_path": "",
268
+ "add_cross_attention": false,
269
+ "architectures": [
270
+ "InternLM2ForCausalLM"
271
+ ],
272
+ "attn_implementation": "flash_attention_2",
273
+ "auto_map": {
274
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
275
+ "AutoModel": "modeling_internlm2.InternLM2ForCausalLM",
276
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
277
+ },
278
+ "bad_words_ids": null,
279
+ "begin_suppress_tokens": null,
280
+ "bias": false,
281
+ "bos_token_id": 1,
282
+ "chunk_size_feed_forward": 0,
283
+ "cross_attention_hidden_size": null,
284
+ "decoder_start_token_id": null,
285
+ "diversity_penalty": 0.0,
286
+ "do_sample": false,
287
+ "early_stopping": false,
288
+ "encoder_no_repeat_ngram_size": 0,
289
+ "eos_token_id": 2,
290
+ "exponential_decay_length_penalty": null,
291
+ "finetuning_task": null,
292
+ "forced_bos_token_id": null,
293
+ "forced_eos_token_id": null,
294
+ "hidden_act": "silu",
295
+ "hidden_size": 4096,
296
+ "id2label": {
297
+ "0": "LABEL_0",
298
+ "1": "LABEL_1"
299
+ },
300
+ "initializer_range": 0.02,
301
+ "intermediate_size": 14336,
302
+ "is_decoder": false,
303
+ "is_encoder_decoder": false,
304
+ "label2id": {
305
+ "LABEL_0": 0,
306
+ "LABEL_1": 1
307
+ },
308
+ "length_penalty": 1.0,
309
+ "max_length": 20,
310
+ "max_position_embeddings": 32768,
311
+ "min_length": 0,
312
+ "model_type": "internlm2",
313
+ "no_repeat_ngram_size": 0,
314
+ "num_attention_heads": 32,
315
+ "num_beam_groups": 1,
316
+ "num_beams": 1,
317
+ "num_hidden_layers": 32,
318
+ "num_key_value_heads": 8,
319
+ "num_return_sequences": 1,
320
+ "output_attentions": false,
321
+ "output_hidden_states": false,
322
+ "output_scores": false,
323
+ "pad_token_id": 2,
324
+ "prefix": null,
325
+ "problem_type": null,
326
+ "pruned_heads": {},
327
+ "remove_invalid_values": false,
328
+ "repetition_penalty": 1.0,
329
+ "return_dict": true,
330
+ "return_dict_in_generate": false,
331
+ "rms_norm_eps": 1e-05,
332
+ "rope_scaling": {
333
+ "factor": 2.0,
334
+ "type": "dynamic"
335
+ },
336
+ "rope_theta": 1000000,
337
+ "sep_token_id": null,
338
+ "suppress_tokens": null,
339
+ "task_specific_params": null,
340
+ "temperature": 1.0,
341
+ "tf_legacy_loss": false,
342
+ "tie_encoder_decoder": false,
343
+ "tie_word_embeddings": false,
344
+ "tokenizer_class": null,
345
+ "top_k": 50,
346
+ "top_p": 1.0,
347
+ "torch_dtype": "float16",
348
+ "torchscript": false,
349
+ "typical_p": 1.0,
350
+ "use_bfloat16": false,
351
+ "use_cache": true,
352
+ "vocab_size": 92544
353
+ },
354
+ "tokenizer_path": "internlm/internlm2-chat-7b",
355
+ "torch_dtype": "bfloat16",
356
+ "transformers_version": "4.38.1",
357
+ "vision_config": {
358
+ "_name_or_path": "",
359
+ "add_cross_attention": false,
360
+ "architectures": null,
361
+ "attention_dropout": 0.0,
362
+ "bad_words_ids": null,
363
+ "begin_suppress_tokens": null,
364
+ "bos_token_id": null,
365
+ "chunk_size_feed_forward": 0,
366
+ "cross_attention_hidden_size": null,
367
+ "decoder_start_token_id": null,
368
+ "diversity_penalty": 0.0,
369
+ "do_sample": false,
370
+ "early_stopping": false,
371
+ "encoder_no_repeat_ngram_size": 0,
372
+ "eos_token_id": null,
373
+ "exponential_decay_length_penalty": null,
374
+ "finetuning_task": null,
375
+ "forced_bos_token_id": null,
376
+ "forced_eos_token_id": null,
377
+ "hidden_act": "gelu_pytorch_tanh",
378
+ "hidden_size": 1152,
379
+ "id2label": {
380
+ "0": "LABEL_0",
381
+ "1": "LABEL_1"
382
+ },
383
+ "image_size": 1960,
384
+ "initializer_range": 0.02,
385
+ "intermediate_size": 4304,
386
+ "is_decoder": false,
387
+ "is_encoder_decoder": false,
388
+ "label2id": {
389
+ "LABEL_0": 0,
390
+ "LABEL_1": 1
391
+ },
392
+ "layer_norm_eps": 1e-06,
393
+ "length_penalty": 1.0,
394
+ "max_length": 20,
395
+ "min_length": 0,
396
+ "model_type": "Idefics2VisionConfig",
397
+ "no_repeat_ngram_size": 0,
398
+ "num_attention_heads": 16,
399
+ "num_beam_groups": 1,
400
+ "num_beams": 1,
401
+ "num_channels": 3,
402
+ "num_hidden_layers": 27,
403
+ "num_return_sequences": 1,
404
+ "output_attentions": false,
405
+ "output_hidden_states": false,
406
+ "output_scores": false,
407
+ "pad_token_id": null,
408
+ "patch_size": 14,
409
+ "prefix": null,
410
+ "problem_type": null,
411
+ "pruned_heads": {},
412
+ "remove_invalid_values": false,
413
+ "repetition_penalty": 1.0,
414
+ "return_dict": true,
415
+ "return_dict_in_generate": false,
416
+ "sep_token_id": null,
417
+ "suppress_tokens": null,
418
+ "task_specific_params": null,
419
+ "temperature": 1.0,
420
+ "tf_legacy_loss": false,
421
+ "tie_encoder_decoder": false,
422
+ "tie_word_embeddings": true,
423
+ "tokenizer_class": null,
424
+ "top_k": 50,
425
+ "top_p": 1.0,
426
+ "torch_dtype": null,
427
+ "torchscript": false,
428
+ "typical_p": 1.0,
429
+ "use_bfloat16": false
430
+ }
431
+ }
configuration_connector.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+ import json
3
+
4
+ class Idefics2ConnectorConfig(PretrainedConfig):
5
+ r"""
6
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
7
+ documentation from [`PretrainedConfig`] for more information.
8
+
9
+ Args:
10
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
11
+ The non-linear activation function (function or string) in the perceiver block.
12
+ resampler_n_latents (`int`, *optional*, defaults to 64):
13
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
14
+ resampler_depth (`int`, *optional*, defaults to 3):
15
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
16
+ resampler_n_heads (`int`, *optional*, defaults to 16):
17
+ Number of heads in each Transformer block (for multi-headed self-attention).
18
+ resampler_head_dim (`int`, *optional*, defaults to 96):
19
+ Dimensionality of each head projection in the Transformer block.
20
+ num_key_value_heads (`int`, *optional*, defaults to 4):
21
+ Number of key-value heads in the perceiver attention block.
22
+ attention_dropout (`float`, *optional*, defaults to 0.0):
23
+ The dropout ratio for the attention probabilities.
24
+ """
25
+ _auto_class = 'AutoConfig'
26
+ model_type = "Idefics2ConnectorConfig"
27
+
28
+ def __init__(
29
+ self,
30
+ vision_hidden_size=1152,
31
+ hidden_size=4096,
32
+ hidden_act="silu",
33
+ resampler_n_latents=64,
34
+ resampler_depth=3,
35
+ rms_norm_eps=1e-05,
36
+ resampler_n_heads=16,
37
+ resampler_head_dim=96,
38
+ num_key_value_heads=4,
39
+ attention_dropout=0.0,
40
+ intermediate_size=14336,
41
+ integrate_sub_images=None,
42
+ num_sub_images=None,
43
+ **kwargs,
44
+ ):
45
+ super().__init__(**kwargs)
46
+ self.vision_hidden_size = vision_hidden_size
47
+ self.hidden_size = hidden_size
48
+ self.hidden_act = hidden_act
49
+ self.resampler_n_latents = resampler_n_latents
50
+ self.resampler_depth = resampler_depth
51
+ self.rms_norm_eps = rms_norm_eps
52
+ self.resampler_n_heads = resampler_n_heads
53
+ self.num_key_value_heads = num_key_value_heads
54
+ self.resampler_head_dim = resampler_head_dim
55
+ self.attention_dropout = attention_dropout
56
+ self.intermediate_size = intermediate_size
57
+ self.integrate_sub_images = integrate_sub_images
58
+ self.num_sub_images = num_sub_images
59
+
60
+ if self.num_key_value_heads > self.resampler_n_heads:
61
+ raise ValueError(
62
+ f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
63
+ f" resampler_n_heads={self.resampler_n_heads}"
64
+ )
65
+
66
+ @classmethod
67
+ def from_pretrained(cls, config_path, **kwargs) -> "PretrainedConfig":
68
+
69
+ with open(config_path, "r", encoding="utf-8") as f:
70
+ config_dict = json.load(f)
71
+ cls = Idefics2ConnectorConfig(
72
+ vision_hidden_size=config_dict['vision_hidden_size'],
73
+ hidden_size=config_dict['hidden_size'],
74
+ hidden_act="silu",
75
+ resampler_n_latents=config_dict['resampler_n_latents'],
76
+ resampler_depth=config_dict['resampler_depth'],
77
+ rms_norm_eps=config_dict['rms_norm_eps'],
78
+ intermediate_size=config_dict['intermediate_size'],
79
+ integrate_sub_images=config_dict['integrate_sub_images'],
80
+ num_sub_images=config_dict['num_sub_images']
81
+ )
82
+
83
+ return cls
configuration_downsampler.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class DownsamplerConfig(PretrainedConfig):
6
+ model_type = 'downsampler'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ kernel_size=1,
12
+ stride=1,
13
+ visual_hidden_size=4096,
14
+ llm_hidden_size=4096,
15
+ depth=2,
16
+ hidden_act='gelu',
17
+ bias=False,
18
+ **kwargs,
19
+ ):
20
+ self.visual_hidden_size = visual_hidden_size
21
+ self.llm_hidden_size = llm_hidden_size
22
+ self.depth = depth
23
+ self.hidden_act = hidden_act
24
+ self.bias = bias
25
+ self.kernel_size = kernel_size
26
+ self.stride = stride
27
+ super().__init__(**kwargs)
configuration_internlm2.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
66
+ The epsilon used by the rms normalization layers.
67
+ use_cache (`bool`, *optional*, defaults to `True`):
68
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
69
+ relevant if `config.is_decoder=True`.
70
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
71
+ Whether to tie weight embeddings
72
+ Example:
73
+
74
+ """
75
+ model_type = "internlm2"
76
+ _auto_class = "AutoConfig"
77
+
78
+ def __init__( # pylint: disable=W0102
79
+ self,
80
+ vocab_size=103168,
81
+ hidden_size=4096,
82
+ intermediate_size=11008,
83
+ num_hidden_layers=32,
84
+ num_attention_heads=32,
85
+ num_key_value_heads=None,
86
+ hidden_act="silu",
87
+ max_position_embeddings=2048,
88
+ initializer_range=0.02,
89
+ rms_norm_eps=1e-6,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ bos_token_id=1,
93
+ eos_token_id=2,
94
+ tie_word_embeddings=False,
95
+ bias=True,
96
+ rope_theta=10000,
97
+ rope_scaling=None,
98
+ attn_implementation="eager",
99
+ **kwargs,
100
+ ):
101
+ self.vocab_size = vocab_size
102
+ self.max_position_embeddings = max_position_embeddings
103
+ self.hidden_size = hidden_size
104
+ self.intermediate_size = intermediate_size
105
+ self.num_hidden_layers = num_hidden_layers
106
+ self.num_attention_heads = num_attention_heads
107
+ self.bias = bias
108
+
109
+ if num_key_value_heads is None:
110
+ num_key_value_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+
113
+ self.hidden_act = hidden_act
114
+ self.initializer_range = initializer_range
115
+ self.rms_norm_eps = rms_norm_eps
116
+ self.use_cache = use_cache
117
+ self.rope_theta = rope_theta
118
+ self.rope_scaling = rope_scaling
119
+ self._rope_scaling_validation()
120
+
121
+ self.attn_implementation = attn_implementation
122
+ if self.attn_implementation is None:
123
+ self.attn_implementation = "eager"
124
+ super().__init__(
125
+ pad_token_id=pad_token_id,
126
+ bos_token_id=bos_token_id,
127
+ eos_token_id=eos_token_id,
128
+ tie_word_embeddings=tie_word_embeddings,
129
+ **kwargs,
130
+ )
131
+
132
+ def _rope_scaling_validation(self):
133
+ """
134
+ Validate the `rope_scaling` configuration.
135
+ """
136
+ if self.rope_scaling is None:
137
+ return
138
+
139
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140
+ raise ValueError(
141
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
142
+ f"got {self.rope_scaling}"
143
+ )
144
+ rope_scaling_type = self.rope_scaling.get("type", None)
145
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
146
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
147
+ raise ValueError(
148
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149
+ )
150
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
configuration_vision.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import PretrainedConfig, PreTrainedModel
3
+ import json
4
+
5
+ class Idefics2VisionConfig(PretrainedConfig):
6
+ model_type = "Idefics2VisionConfig"
7
+
8
+ def __init__(
9
+ self,
10
+ hidden_size=768,
11
+ intermediate_size=3072,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ num_channels=3,
15
+ image_size=224,
16
+ patch_size=32,
17
+ hidden_act="gelu_pytorch_tanh",
18
+ layer_norm_eps=1e-6,
19
+ attention_dropout=0.0,
20
+ initializer_range=0.02,
21
+ model_type='Idefics2VisionConfig',
22
+ **kwargs,
23
+ ):
24
+
25
+ self.hidden_size = hidden_size
26
+ self.intermediate_size = intermediate_size
27
+ self.num_hidden_layers = num_hidden_layers
28
+ self.num_attention_heads = num_attention_heads
29
+ self.num_channels = num_channels
30
+ self.patch_size = patch_size
31
+ self.image_size = image_size
32
+ self.attention_dropout = attention_dropout
33
+ self.layer_norm_eps = layer_norm_eps
34
+ self.hidden_act = hidden_act
35
+ self.initializer_range = initializer_range
36
+
37
+ super().__init__(**kwargs)
38
+
configuration_wemm.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+ import json
4
+ # from transformers import CONFIG_MAPPING
5
+ from peft import PeftConfig
6
+ from .configuration_vision import Idefics2VisionConfig
7
+ from .configuration_internlm2 import InternLM2Config
8
+ from .configuration_projector import ProjectorConfig
9
+ from .configuration_connector import Idefics2ConnectorConfig
10
+ from .image_processor import Idefics2ImageProcessor
11
+ from .configuration_downsampler import DownsamplerConfig
12
+
13
+ class WeMMConfig(PretrainedConfig):
14
+ model_type = "wemm_hf"
15
+
16
+ def __init__(
17
+ self,
18
+ vision_config = None,
19
+ text_config = None,
20
+ projector_config = None,
21
+ connector_config = None,
22
+ adapter_path = None,
23
+ image_processor = None,
24
+ do_image_splitting = False,
25
+ spliter_emb_config = None,
26
+ downsampler_config = None,
27
+ tokenizer_path = None,
28
+ **kwargs
29
+ ):
30
+ # vision_config
31
+ if vision_config is not None:
32
+ self.vision_config = Idefics2VisionConfig(**vision_config)
33
+
34
+
35
+ # text_config
36
+ if text_config is not None:
37
+ self.text_config = InternLM2Config(**text_config)
38
+
39
+ # projector_config
40
+ if projector_config is not None:
41
+ self.projector_config = ProjectorConfig(**projector_config)
42
+
43
+ # connector_config
44
+ if connector_config is not None:
45
+ self.connector_config = Idefics2ConnectorConfig(**connector_config)
46
+
47
+ if image_processor is not None:
48
+ self.image_processor = image_processor
49
+
50
+
51
+ if adapter_path is not None:
52
+ self.adapter_path = adapter_path
53
+
54
+ self.do_image_splitting = do_image_splitting
55
+
56
+ if spliter_emb_config is not None:
57
+ self.spliter_emb_config = spliter_emb_config
58
+
59
+ if downsampler_config is not None:
60
+ self.downsampler_config = DownsamplerConfig(**downsampler_config)
61
+
62
+ if tokenizer_path is not None:
63
+ self.tokenizer_path = tokenizer_path
64
+
65
+ super().__init__(**kwargs)
connector.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+
3
+ import inspect
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
18
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
19
+ from transformers.utils import (
20
+ add_start_docstrings,
21
+ add_start_docstrings_to_model_forward,
22
+ is_flash_attn_2_available,
23
+ is_flash_attn_greater_or_equal_2_10,
24
+ logging,
25
+ replace_return_docstrings,
26
+ )
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+
32
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
33
+
34
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
35
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
36
+ """
37
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
38
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
39
+ """
40
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
41
+ if n_rep == 1:
42
+ return hidden_states
43
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
44
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
45
+
46
+ class Idefics2ConnectorConfig(PretrainedConfig):
47
+ r"""
48
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
49
+ documentation from [`PretrainedConfig`] for more information.
50
+
51
+ Args:
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
53
+ The non-linear activation function (function or string) in the perceiver block.
54
+ resampler_n_latents (`int`, *optional*, defaults to 64):
55
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
56
+ resampler_depth (`int`, *optional*, defaults to 3):
57
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
58
+ resampler_n_heads (`int`, *optional*, defaults to 16):
59
+ Number of heads in each Transformer block (for multi-headed self-attention).
60
+ resampler_head_dim (`int`, *optional*, defaults to 96):
61
+ Dimensionality of each head projection in the Transformer block.
62
+ num_key_value_heads (`int`, *optional*, defaults to 4):
63
+ Number of key-value heads in the perceiver attention block.
64
+ attention_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for the attention probabilities.
66
+ """
67
+ _auto_class = 'AutoConfig'
68
+ model_type = "Idefics2ConnectorConfig"
69
+
70
+ def __init__(
71
+ self,
72
+ vision_hidden_size=1152,
73
+ hidden_size=4096,
74
+ hidden_act="silu",
75
+ resampler_n_latents=64,
76
+ resampler_depth=3,
77
+ rms_norm_eps=1e-05,
78
+ resampler_n_heads=16,
79
+ resampler_head_dim=96,
80
+ num_key_value_heads=4,
81
+ attention_dropout=0.0,
82
+ intermediate_size=14336,
83
+ **kwargs,
84
+ ):
85
+ super().__init__(**kwargs)
86
+ self.vision_hidden_size = vision_hidden_size
87
+ self.hidden_size = hidden_size
88
+ self.hidden_act = hidden_act
89
+ self.resampler_n_latents = resampler_n_latents
90
+ self.resampler_depth = resampler_depth
91
+ self.rms_norm_eps = rms_norm_eps
92
+ self.resampler_n_heads = resampler_n_heads
93
+ self.num_key_value_heads = num_key_value_heads
94
+ self.resampler_head_dim = resampler_head_dim
95
+ self.attention_dropout = attention_dropout
96
+ self.intermediate_size = intermediate_size
97
+ if self.num_key_value_heads > self.resampler_n_heads:
98
+ raise ValueError(
99
+ f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
100
+ f" resampler_n_heads={self.resampler_n_heads}"
101
+ )
102
+
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, config_path, **kwargs) -> "PretrainedConfig":
106
+
107
+ with open(config_path, "r", encoding="utf-8") as f:
108
+ config_dict = json.load(f)
109
+ cls = Idefics2ConnectorConfig(
110
+ vision_hidden_size=config_dict['vision_hidden_size'],
111
+ hidden_size=config_dict['hidden_size'],
112
+ hidden_act="silu",
113
+ resampler_n_latents=config_dict['resampler_n_latents'],
114
+ resampler_depth=config_dict['resampler_depth'],
115
+ rms_norm_eps=config_dict['rms_norm_eps'],
116
+ intermediate_size = config_dict['intermediate_size']
117
+ )
118
+
119
+ return cls
120
+
121
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
122
+ def _get_unpad_data(attention_mask):
123
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
124
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
125
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
126
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
127
+ return (
128
+ indices,
129
+ cu_seqlens,
130
+ max_seqlen_in_batch,
131
+ )
132
+
133
+ class Idefics2PerceiverAttention(nn.Module):
134
+ def __init__(self, config, layer_idx: Optional[int] = None) -> None:
135
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
136
+ super().__init__()
137
+
138
+ self.layer_idx = None
139
+ self.hidden_size = config.hidden_size
140
+ self.num_heads = config.resampler_n_heads
141
+ self.head_dim = config.resampler_head_dim
142
+ self.num_key_value_heads = config.num_key_value_heads
143
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
144
+ self.attention_dropout = config.attention_dropout
145
+
146
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
147
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
148
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
149
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
150
+
151
+ self.is_causal = False
152
+
153
+ def forward(
154
+ self,
155
+ latents: torch.Tensor,
156
+ context: torch.Tensor,
157
+ attention_mask: Optional[torch.Tensor] = None,
158
+ position_ids: Optional[torch.LongTensor] = None,
159
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
160
+ output_attentions: bool = False,
161
+ use_cache: bool = False,
162
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
163
+ """
164
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
165
+
166
+ Args:
167
+ latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
168
+ context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
169
+ attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
170
+ position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
171
+ past_key_value (`Tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states.
172
+ output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
173
+ use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
174
+ """
175
+ bsz, q_len, _ = latents.size()
176
+ kv_seq_len = q_len + context.size()[1]
177
+
178
+ hidden_states = torch.concat([context, latents], dim=-2)
179
+
180
+ query_states = self.q_proj(latents)
181
+ key_states = self.k_proj(hidden_states)
182
+ value_states = self.v_proj(hidden_states)
183
+
184
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
185
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
186
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
187
+
188
+ past_key_value = getattr(self, "past_key_value", past_key_value)
189
+
190
+ if past_key_value is not None:
191
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
192
+
193
+ # repeat k/v heads if n_kv_heads < n_heads
194
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
195
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
196
+
197
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
198
+
199
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
200
+ raise ValueError(
201
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
202
+ f" {attn_weights.size()}"
203
+ )
204
+
205
+ if attention_mask is not None:
206
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
207
+ raise ValueError(
208
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
209
+ )
210
+
211
+ attn_weights = attn_weights + attention_mask
212
+
213
+ # upcast attention to fp32
214
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
215
+ attn_output = torch.matmul(attn_weights, value_states)
216
+
217
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
218
+ raise ValueError(
219
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
220
+ f" {attn_output.size()}"
221
+ )
222
+
223
+ attn_output = attn_output.transpose(1, 2).contiguous()
224
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
225
+
226
+ attn_output = self.o_proj(attn_output)
227
+
228
+ if not output_attentions:
229
+ attn_weights = None
230
+
231
+ return attn_output, attn_weights, past_key_value
232
+
233
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
234
+ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
235
+ """
236
+ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
237
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
238
+ flash attention and deal with padding tokens in case the input contains any of them.
239
+ """
240
+
241
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
242
+ def __init__(self, *args, **kwargs):
243
+ super().__init__(*args, **kwargs)
244
+
245
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
246
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
247
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
248
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
249
+
250
+ # Ignore copy
251
+ def forward(
252
+ self,
253
+ latents: torch.Tensor,
254
+ context: torch.Tensor,
255
+ attention_mask: Optional[torch.LongTensor] = None,
256
+ position_ids: Optional[torch.LongTensor] = None,
257
+ past_key_value: Optional[Cache] = None,
258
+ output_attentions: bool = False,
259
+ use_cache: bool = False,
260
+ **kwargs,
261
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
262
+
263
+ bsz, q_len, _ = latents.size()
264
+ kv_seq_len = q_len + context.size()[1]
265
+
266
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
267
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
268
+ query_states = self.q_proj(latents)
269
+ key_states = self.k_proj(torch.cat([context, latents], dim=-2))
270
+ value_states = self.v_proj(torch.cat([context, latents], dim=-2))
271
+
272
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
273
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
275
+
276
+ kv_seq_len = key_states.shape[-2]
277
+ if past_key_value is not None:
278
+ kv_seq_len += past_key_value[0].shape[-2]
279
+
280
+ if past_key_value is not None:
281
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
282
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
283
+ slicing_tokens = kv_seq_len - self.config.sliding_window
284
+
285
+ past_key = past_key_value[0]
286
+ past_value = past_key_value[1]
287
+
288
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
289
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
290
+
291
+ if past_key.shape[-2] != self.config.sliding_window - 1:
292
+ raise ValueError(
293
+ "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
294
+ f" head_dim`), got {past_key.shape}"
295
+ )
296
+
297
+ past_key_value = (past_key, past_value)
298
+
299
+ if attention_mask is not None:
300
+ attention_mask = attention_mask[:, slicing_tokens:]
301
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
302
+
303
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
304
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
305
+
306
+ past_key_value = (key_states, value_states) if use_cache else None
307
+
308
+ # repeat k/v heads if n_kv_heads < n_heads
309
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
310
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
311
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
312
+
313
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
314
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
315
+ # cast them back in float16 just to be sure everything works as expected.
316
+ input_dtype = query_states.dtype
317
+ if input_dtype == torch.float32:
318
+ if torch.is_autocast_enabled():
319
+ target_dtype = torch.get_autocast_gpu_dtype()
320
+ # Handle the case where the model is quantized
321
+ elif hasattr(self.config, "_pre_quantization_dtype"):
322
+ target_dtype = self.config._pre_quantization_dtype
323
+ else:
324
+ target_dtype = self.q_proj.weight.dtype
325
+
326
+ logger.warning_once(
327
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
328
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
329
+ f" {target_dtype}."
330
+ )
331
+
332
+ query_states = query_states.to(target_dtype)
333
+ key_states = key_states.to(target_dtype)
334
+ value_states = value_states.to(target_dtype)
335
+
336
+ # Reashape to the expected shape for Flash Attention
337
+ query_states = query_states.transpose(1, 2)
338
+ key_states = key_states.transpose(1, 2)
339
+ value_states = value_states.transpose(1, 2)
340
+
341
+ attn_output = self._flash_attention_forward(
342
+ query_states,
343
+ key_states,
344
+ value_states,
345
+ attention_mask,
346
+ q_len,
347
+ dropout=dropout_rate,
348
+ use_sliding_windows=False,
349
+ )
350
+
351
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
352
+ attn_output = self.o_proj(attn_output)
353
+
354
+ if not output_attentions:
355
+ attn_weights = None
356
+
357
+ return attn_output, attn_weights, past_key_value
358
+
359
+ def _flash_attention_forward(
360
+ self,
361
+ query_states,
362
+ key_states,
363
+ value_states,
364
+ attention_mask,
365
+ query_length,
366
+ dropout=0.0,
367
+ softmax_scale=None,
368
+ use_sliding_windows=False,
369
+ ):
370
+ """
371
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
372
+ first unpad the input, then computes the attention scores and pad the final attention scores.
373
+
374
+ Args:
375
+ query_states (`torch.Tensor`):
376
+ Input query states to be passed to Flash Attention API
377
+ key_states (`torch.Tensor`):
378
+ Input key states to be passed to Flash Attention API
379
+ value_states (`torch.Tensor`):
380
+ Input value states to be passed to Flash Attention API
381
+ attention_mask (`torch.Tensor`):
382
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
383
+ position of padding tokens and 1 for the position of non-padding tokens.
384
+ dropout (`float`):
385
+ Attention dropout
386
+ softmax_scale (`float`, *optional*):
387
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
388
+ use_sliding_windows (`bool`, *optional*):
389
+ Whether to activate sliding window attention.
390
+ """
391
+ if not self._flash_attn_uses_top_left_mask:
392
+ causal = self.is_causal
393
+ else:
394
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
395
+ causal = self.is_causal and query_length != 1
396
+
397
+ # Contains at least one padding token in the sequence
398
+ if attention_mask is not None:
399
+ batch_size = query_states.shape[0]
400
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
401
+ query_states, key_states, value_states, attention_mask, query_length
402
+ )
403
+
404
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
405
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
406
+
407
+ if not use_sliding_windows:
408
+ attn_output_unpad = flash_attn_varlen_func(
409
+ query_states,
410
+ key_states,
411
+ value_states,
412
+ cu_seqlens_q=cu_seqlens_q,
413
+ cu_seqlens_k=cu_seqlens_k,
414
+ max_seqlen_q=max_seqlen_in_batch_q,
415
+ max_seqlen_k=max_seqlen_in_batch_k,
416
+ dropout_p=dropout,
417
+ softmax_scale=softmax_scale,
418
+ causal=causal,
419
+ )
420
+ else:
421
+ attn_output_unpad = flash_attn_varlen_func(
422
+ query_states,
423
+ key_states,
424
+ value_states,
425
+ cu_seqlens_q=cu_seqlens_q,
426
+ cu_seqlens_k=cu_seqlens_k,
427
+ max_seqlen_q=max_seqlen_in_batch_q,
428
+ max_seqlen_k=max_seqlen_in_batch_k,
429
+ dropout_p=dropout,
430
+ softmax_scale=softmax_scale,
431
+ causal=causal,
432
+ window_size=(self.config.sliding_window, self.config.sliding_window),
433
+ )
434
+
435
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
436
+ else:
437
+ if not use_sliding_windows:
438
+ attn_output = flash_attn_func(
439
+ query_states,
440
+ key_states,
441
+ value_states,
442
+ dropout,
443
+ softmax_scale=softmax_scale,
444
+ causal=causal,
445
+ )
446
+ else:
447
+ attn_output = flash_attn_func(
448
+ query_states,
449
+ key_states,
450
+ value_states,
451
+ dropout,
452
+ softmax_scale=softmax_scale,
453
+ causal=causal,
454
+ window_size=(self.config.sliding_window, self.config.sliding_window),
455
+ )
456
+
457
+ return attn_output
458
+
459
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
460
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
461
+
462
+ # On the first iteration we need to properly re-create the padding mask
463
+ # by slicing it on the proper place
464
+ if kv_seq_len != attention_mask.shape[-1]:
465
+ attention_mask_num_tokens = attention_mask.shape[-1]
466
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
467
+
468
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
469
+
470
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
471
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
472
+
473
+ if query_length == kv_seq_len:
474
+ query_layer = index_first_axis(
475
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
476
+ )
477
+ cu_seqlens_q = cu_seqlens_k
478
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
479
+ indices_q = indices_k
480
+ elif query_length == 1:
481
+ max_seqlen_in_batch_q = 1
482
+ cu_seqlens_q = torch.arange(
483
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
484
+ ) # There is a memcpy here, that is very bad.
485
+ indices_q = cu_seqlens_q[:-1]
486
+ query_layer = query_layer.squeeze(1)
487
+ else:
488
+ # The -q_len: slice assumes left padding.
489
+ attention_mask = attention_mask[:, -query_length:]
490
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
491
+
492
+ return (
493
+ query_layer,
494
+ key_layer,
495
+ value_layer,
496
+ indices_q,
497
+ (cu_seqlens_q, cu_seqlens_k),
498
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
499
+ )
500
+
501
+ IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
502
+ "eager": Idefics2PerceiverAttention,
503
+ "flash_attention_2": Idefics2PerceiverFlashAttention2,
504
+ }
505
+
506
+
507
+ class Idefics2MLP(nn.Module):
508
+ def __init__(
509
+ self,
510
+ hidden_size: int,
511
+ intermediate_size: int,
512
+ output_size: int,
513
+ hidden_act: str,
514
+ ):
515
+ super().__init__()
516
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
517
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
518
+ self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
519
+ self.act_fn = ACT2FN[hidden_act]
520
+
521
+ def forward(self, x):
522
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
523
+
524
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
525
+ class Idefics2RMSNorm(nn.Module):
526
+ def __init__(self, hidden_size, eps=1e-6):
527
+ """
528
+ Idefics2RMSNorm is equivalent to T5LayerNorm
529
+ """
530
+ super().__init__()
531
+ self.weight = nn.Parameter(torch.ones(hidden_size))
532
+ self.variance_epsilon = eps
533
+
534
+ def forward(self, hidden_states):
535
+ input_dtype = hidden_states.dtype
536
+ hidden_states = hidden_states.to(torch.float32)
537
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
538
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
539
+ return self.weight * hidden_states.to(input_dtype)
540
+
541
+ class Idefics2PerceiverLayer(nn.Module):
542
+ def __init__(self, config, layer_idx: int):
543
+ super().__init__()
544
+ self.hidden_size = config.hidden_size
545
+ self.n_latents = config.resampler_n_latents
546
+ self.depth = config.resampler_depth
547
+ self.rms_norm_eps = config.rms_norm_eps
548
+
549
+ self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
550
+ self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
551
+ self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
552
+ self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
553
+ self.mlp = Idefics2MLP(
554
+ hidden_size=config.hidden_size,
555
+ intermediate_size=config.hidden_size * 4,
556
+ output_size=config.hidden_size,
557
+ hidden_act=config.hidden_act,
558
+ )
559
+
560
+ def forward(
561
+ self,
562
+ latents: torch.Tensor,
563
+ context: torch.Tensor,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ position_ids: Optional[torch.LongTensor] = None,
566
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
567
+ output_attentions: Optional[bool] = False,
568
+ use_cache: Optional[bool] = False,
569
+ **kwargs,
570
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
571
+ """
572
+ Args:
573
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
574
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
575
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
576
+ `(batch, sequence_length)` where padding elements are indicated by 0.
577
+ output_attentions (`bool`, *optional*):
578
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
579
+ returned tensors for more detail.
580
+ use_cache (`bool`, *optional*):
581
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
582
+ (see `past_key_values`).
583
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
584
+ """
585
+ residual = latents
586
+
587
+ latents = self.input_latents_norm(latents)
588
+ context = self.input_context_norm(context)
589
+
590
+ latents, self_attn_weights, present_key_value = self.self_attn(
591
+ latents=latents,
592
+ context=context,
593
+ attention_mask=attention_mask,
594
+ )
595
+ latents = residual + latents
596
+ residual = latents
597
+
598
+ latents = self.post_attention_layernorm(latents)
599
+ latents = self.mlp(latents)
600
+ latents = residual + latents
601
+
602
+ outputs = (latents,)
603
+
604
+ if output_attentions:
605
+ outputs += (self_attn_weights,)
606
+
607
+ if use_cache:
608
+ outputs += (present_key_value,)
609
+
610
+ return outputs
611
+
612
+ class Idefics2Qformer(nn.Module):
613
+
614
+ def __init__(self, config) -> None:
615
+ """
616
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
617
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
618
+ returns a Tensor of shape [bsz, n_latents, embed_dim]. The Resampler acts as a form of learned pooling and
619
+ is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206).
620
+ """
621
+ super().__init__()
622
+ config._attn_implementation = "flash_attention_2"
623
+ self._use_flash_attention_2 = True
624
+
625
+ self.hidden_size = config.hidden_size
626
+ self.hidden_act = config.hidden_act
627
+ self.n_latents = config.resampler_n_latents
628
+ self.depth = config.resampler_depth
629
+ self.rms_norm_eps = config.rms_norm_eps
630
+
631
+ # Create Latents for Perceiver
632
+ self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))
633
+ # Create Transformer Blocks
634
+ self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
635
+ self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
636
+
637
+
638
+
639
+
640
+ def forward(
641
+ self,
642
+ context: torch.Tensor,
643
+ attention_mask,
644
+ ) -> torch.Tensor:
645
+ # seq embed -> bsz seq embed
646
+ latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
647
+
648
+ latent_attention_mask = torch.ones(
649
+ (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
650
+ )
651
+ attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
652
+ attention_mask = (
653
+ _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
654
+ if not self._use_flash_attention_2
655
+ else attention_mask
656
+ )
657
+ #all_latents = []
658
+ compressed_context = latents
659
+ #all_latents.append(latents)
660
+ for perceiver_layer in self.layers:
661
+ layer_outputs = torch.utils.checkpoint.checkpoint(
662
+ perceiver_layer.__call__,
663
+ compressed_context,
664
+ context,
665
+ attention_mask,
666
+ None,
667
+ None,
668
+ False,
669
+ False,
670
+ use_reentrant=True)
671
+ compressed_context = layer_outputs[0]
672
+ #all_latents.append(compressed_context)
673
+
674
+ compressed_context = self.norm(compressed_context)
675
+
676
+ return compressed_context
677
+
678
+ class Idefics2Connector(PreTrainedModel):
679
+ _auto_class = 'AutoModel'
680
+ config_class = Idefics2ConnectorConfig
681
+
682
+ def __init__(self, config):
683
+ super().__init__(config)
684
+ self.modality_projection = Idefics2MLP(
685
+ hidden_size=config.vision_hidden_size,
686
+ intermediate_size=config.intermediate_size,
687
+ output_size=config.hidden_size,
688
+ hidden_act=config.hidden_act,
689
+ )
690
+ self.perceiver_resampler = Idefics2Qformer(config)
691
+ self.config = config
692
+
693
+ def forward(self, image_hidden_states, attention_mask):
694
+ image_hidden_states = self.modality_projection(image_hidden_states)
695
+ image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
696
+
697
+ vision_hidden_size = image_hidden_states.shape[-1]
698
+ num_image = image_hidden_states.shape[0]
699
+ reshaped_image_hidden_states = image_hidden_states.view(num_image, -1, vision_hidden_size)
700
+
701
+ return reshaped_image_hidden_states
702
+
703
+ @classmethod
+ def from_pretrained(cls, config_path):
+ config = Idefics2ConnectorConfig.from_pretrained(f'{config_path}/config.json')
+ model = cls(config=config)
+
+ state_dict = torch.load(f'{config_path}/connector.pth', map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ print("Loading idefics2 Connector from: {}".format(config_path))
+ return model
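For orientation, here is a minimal usage sketch of the connector defined above (not part of the commit). The module name, checkpoint directory, and feature shapes are assumptions for illustration, and a flash-attention-capable GPU is assumed because Idefics2Qformer forces `_attn_implementation = "flash_attention_2"`.

import torch
from modeling_wemm import Idefics2Connector  # module name assumed

# Load config.json + connector.pth from a local directory (placeholder path).
connector = Idefics2Connector.from_pretrained("path/to/connector_dir")
connector = connector.to("cuda", dtype=torch.bfloat16).eval()

num_images, num_patches = 2, 1024  # illustrative shapes only
feats = torch.randn(num_images, num_patches, connector.config.vision_hidden_size,
                    device="cuda", dtype=torch.bfloat16)
mask = torch.ones(num_images, num_patches, dtype=torch.bool, device="cuda")

with torch.no_grad():
    # modality projection -> perceiver resampling -> (num_images, resampler_n_latents, hidden_size)
    compressed = connector(feats, mask)
print(compressed.shape)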
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "pad_token_id": 2,
+ "transformers_version": "4.38.1"
+ }
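These fields are what transformers picks up as the model's generation configuration when generate() is called; a small illustrative sketch of the equivalent object (the repo path below is a placeholder):

from transformers import GenerationConfig

# Equivalent to the file above; normally loaded implicitly with the model, or explicitly via:
# gen_cfg = GenerationConfig.from_pretrained("path/to/this/repo")
gen_cfg = GenerationConfig(bos_token_id=1, eos_token_id=2, pad_token_id=2)
print(gen_cfg.eos_token_id, gen_cfg.pad_token_id)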
image_processor.py ADDED
@@ -0,0 +1,657 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import json
21
+ import torch
22
+
23
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from transformers.image_transforms import PaddingMode, pad, resize, to_channel_dimension_format
25
+ from transformers.image_utils import (
26
+ IMAGENET_STANDARD_MEAN,
27
+ IMAGENET_STANDARD_STD,
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ get_image_size,
32
+ infer_channel_dimension_format,
33
+ is_scaled_image,
34
+ is_valid_image,
35
+ to_numpy_array,
36
+ valid_images,
37
+ validate_preprocess_arguments,
38
+ )
39
+ from transformers.utils import TensorType, is_vision_available, logging
40
+ import PIL
41
+ from PIL import Image
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+
47
+
48
+ def get_resize_output_image_size(image, size, input_data_format) -> Tuple[int, int]:
49
+ """
50
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
51
+
52
+ Args:
53
+ image (`np.ndarray`):
54
+ Image to resize.
55
+ size (`Dict[str, int]`):
56
+ Size of the output image containing the keys "shortest_edge" and "longest_edge".
57
+ input_data_format (`ChannelDimension` or `str`):
58
+ The channel dimension format of the input image.
59
+
60
+ Returns:
61
+ The output size of the image after resizing.
62
+ """
63
+ height, width = get_image_size(image, channel_dim=input_data_format)
64
+
65
+ min_len = size["shortest_edge"]
66
+ max_len = size["longest_edge"]
67
+ aspect_ratio = width / height
68
+
69
+ if width >= height and width > max_len:
70
+ width = max_len
71
+ height = int(width / aspect_ratio)
72
+ elif height > width and height > max_len:
73
+ height = max_len
74
+ width = int(height * aspect_ratio)
75
+ height = max(height, min_len)
76
+ width = max(width, min_len)
77
+ return height, width
78
+
79
+
80
+ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
81
+ """
82
+ Convert a single image or a list of images to a list of numpy arrays.
83
+
84
+ Args:
85
+ images (`ImageInput`):
86
+ A single image or a list of images.
87
+
88
+ Returns:
89
+ A list of numpy arrays.
90
+ """
91
+ # If it's a single image, convert it to a list of lists
92
+ if is_valid_image(images):
93
+ images = [[images]]
94
+ # If it's a list of images, it's a single batch, so convert it to a list of lists
95
+ elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
96
+ images = [images]
97
+ # If it's a list of batches, it's already in the right format
98
+ elif (
99
+ isinstance(images, (list, tuple))
100
+ and len(images) > 0
101
+ and isinstance(images[0], (list, tuple))
102
+ and is_valid_image(images[0][0])
103
+ ):
104
+ pass
105
+ else:
106
+ raise ValueError(
107
+ "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
108
+ )
109
+ return images
110
+
111
+
112
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
113
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
114
+ """
115
+ Return the maximum value across all indices of an iterable of values.
116
+ """
117
+ return [max(values_i) for values_i in zip(*values)]
118
+
119
+
120
+ def get_max_height_width(
121
+ images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
122
+ ) -> List[int]:
123
+ """
124
+ Get the maximum height and width across all images in a batch.
125
+ """
126
+ if input_data_format is None:
127
+ input_data_format = infer_channel_dimension_format(images_list[0][0])
128
+
129
+ image_sizes = []
130
+ for images in images_list:
131
+ for image in images:
132
+ image_sizes.append(get_image_size(image, channel_dim=input_data_format))
133
+
134
+ max_height, max_width = max_across_indices(image_sizes)
135
+ return (max_height, max_width)
136
+
137
+
138
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
139
+ def make_pixel_mask(
140
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
141
+ ) -> np.ndarray:
142
+ """
143
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
144
+
145
+ Args:
146
+ image (`np.ndarray`):
147
+ Image to make the pixel mask for.
148
+ output_size (`Tuple[int, int]`):
149
+ Output size of the mask.
150
+ """
151
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
152
+ mask = np.zeros(output_size, dtype=np.int64)
153
+ mask[:input_height, :input_width] = 1
154
+ return mask
155
+
156
+
157
+ # FIXME Amy: merge this function with the one in image_transforms.py
158
+ def convert_to_rgb(image: ImageInput) -> ImageInput:
159
+ """
160
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
161
+ as is.
162
+ Args:
163
+ image (Image):
164
+ The image to convert.
165
+ """
166
+ if not isinstance(image, PIL.Image.Image):
167
+ return image
168
+
169
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
170
+ # for transparent images. The call to `alpha_composite` handles this case
171
+ if image.mode == "RGB":
172
+ return image
173
+
174
+ image_rgba = image.convert("RGBA")
175
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
176
+ alpha_composite = Image.alpha_composite(background, image_rgba)
177
+ alpha_composite = alpha_composite.convert("RGB")
178
+ return alpha_composite
179
+
180
+
181
+ class Idefics2ImageProcessor(BaseImageProcessor):
182
+ r"""
183
+ Constructs an Idefics2 image processor.
184
+
185
+ Args:
186
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
187
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
188
+ Only has an effect if the input image is in the PIL format.
189
+ do_resize (`bool`, *optional*, defaults to `True`):
190
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
191
+ shortest edge resized to keep the input aspect ratio, with a minimum size of `size["shortest_edge"]`.
192
+ size (`Dict`, *optional*):
193
+ Controls the size of the output image. This is a dictionary containing the keys "shortest_edge" and "longest_edge".
194
+ resample (`Resampling`, *optional*, defaults to `Resampling.BILINEAR`):
195
+ Resampling filter to use when resizing the image.
196
+ do_rescale (`bool`, *optional*, defaults to `True`):
197
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
198
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
199
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
200
+ do_normalize (`bool`, *optional*, defaults to `True`):
201
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
202
+ a standard deviation of `image_std`.
203
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
211
+ do_pad (`bool`, *optional*, defaults to `True`):
212
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
213
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
214
+ do_image_splitting (`bool`, *optional*, defaults to `False`):
215
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
216
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
217
+ """
218
+
219
+ model_input_names = ["pixel_values"]
220
+
221
+ def __init__(
222
+ self,
223
+ do_convert_rgb: bool = True,
224
+ do_resize: bool = True,
225
+ size: Dict[str, int] = None,
226
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
227
+ do_rescale: bool = True,
228
+ rescale_factor: float = 1 / 255,
229
+ do_normalize: bool = True,
230
+ image_mean: Optional[Union[float, List[float]]] = None,
231
+ image_std: Optional[Union[float, List[float]]] = None,
232
+ do_pad: bool = True,
233
+ do_image_splitting: bool = False,
234
+ **kwargs,
235
+ ) -> None:
236
+ super().__init__(**kwargs)
237
+ self.do_convert_rgb = do_convert_rgb
238
+ self.do_resize = do_resize
239
+ self.size = size if size is not None else {"shortest_edge": 378, "longest_edge": 980}
240
+ self.resample = resample
241
+ self.do_rescale = do_rescale
242
+ self.rescale_factor = rescale_factor
243
+ self.do_normalize = do_normalize
244
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
245
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
246
+ self.do_pad = do_pad
247
+ self.do_image_splitting = do_image_splitting
248
+
249
+ def resize(
250
+ self,
251
+ image: np.ndarray,
252
+ size: Dict[str, int],
253
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
254
+ data_format: Optional[Union[str, ChannelDimension]] = None,
255
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
256
+ **kwargs,
257
+ ) -> np.ndarray:
258
+ """
259
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
260
+ resized to keep the input aspect ratio.
261
+
262
+ Args:
263
+ image (`np.ndarray`):
264
+ Image to resize.
265
+ size (`Dict[str, int]`):
266
+ Size of the output image.
267
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
269
+ data_format (`str` or `ChannelDimension`, *optional*):
270
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
271
+ input_data_format (`ChannelDimension` or `str`, *optional*):
272
+ The channel dimension format of the input image. If not provided, it will be inferred.
273
+ """
274
+ if "shortest_edge" in size and "longest_edge" in size:
275
+ size = get_resize_output_image_size(image, size, input_data_format)
276
+ elif "height" in size and "width" in size:
277
+ size = (size["height"], size["width"])
278
+ else:
279
+ raise ValueError(
280
+ "size must be a dictionary with keys 'shortest_edge' and 'longest_edge' or 'height' and 'width'."
281
+ )
282
+ try:
+ return resize(
+ image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+ except Exception:
+ # Log the offending input before re-raising, rather than calling resize a second time.
+ print(f"resize error with image: {image.shape}")
+ raise
292
+
293
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
294
+ def _pad_image(
295
+ self,
296
+ image: np.ndarray,
297
+ output_size: Tuple[int, int],
298
+ constant_values: Union[float, Iterable[float]] = 0,
299
+ data_format: Optional[ChannelDimension] = None,
300
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
301
+ ) -> np.ndarray:
302
+ """
303
+ Pad an image with zeros to the given size.
304
+ """
305
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
306
+ output_height, output_width = output_size
307
+
308
+ pad_bottom = output_height - input_height
309
+ pad_right = output_width - input_width
310
+ padding = ((0, pad_bottom), (0, pad_right))
311
+ padded_image = pad(
312
+ image,
313
+ padding,
314
+ mode=PaddingMode.CONSTANT,
315
+ constant_values=constant_values,
316
+ data_format=data_format,
317
+ input_data_format=input_data_format,
318
+ )
319
+ return padded_image
320
+
321
+ def pad(
322
+ self,
323
+ images: List[np.ndarray],
324
+ constant_values: Union[float, Iterable[float]] = 0,
325
+ return_pixel_mask: bool = True,
326
+ return_tensors: Optional[Union[str, TensorType]] = None,
327
+ data_format: Optional[ChannelDimension] = None,
328
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
329
+ ) -> BatchFeature:
330
+ """
331
+ For a list of lists of images, pads each image at the bottom and right with zeros up to the largest height and width in the batch.
+ For each sample in the batch, pads the sample with empty images up to the maximum number of images per sample in the batch. Optionally returns a pixel mask.
333
+
334
+ Args:
335
+ images (`np.ndarray`):
336
+ List of list of images to pad. Pads to the largest height and width in the batch.
337
+ constant_values (`float` or `Iterable[float]`, *optional*):
338
+ The value to use for the padding if `mode` is `"constant"`.
339
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
340
+ Whether to return a pixel mask.
341
+ return_tensors (`str` or `TensorType`, *optional*):
342
+ The type of tensors to return. Can be one of:
343
+ - Unset: Return a list of `np.ndarray`.
344
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
345
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
346
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
347
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
348
+ data_format (`str` or `ChannelDimension`, *optional*):
349
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
350
+ input_data_format (`ChannelDimension` or `str`, *optional*):
351
+ The channel dimension format of the input image. If not provided, it will be inferred.
352
+ """
353
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
354
+
355
+ batch_size = len(images)
356
+ max_num_images = max(len(images_) for images_ in images)
357
+ input_data_format = (
358
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
359
+ )
360
+ data_format = input_data_format if data_format is None else data_format
361
+
362
+ def empty_image(size, input_data_format):
363
+ if input_data_format == ChannelDimension.FIRST:
364
+ return np.zeros((3, *size), dtype=np.uint8)
365
+ elif input_data_format == ChannelDimension.LAST:
366
+ return np.zeros((*size, 3), dtype=np.uint8)
367
+ raise ValueError("Invalid channel dimension format.")
368
+
369
+ padded_images_list = [
370
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
371
+ ]
372
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
373
+
374
+ for batch_idx in range(batch_size):
375
+ for sample_idx, image in enumerate(images[batch_idx]):
376
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
377
+ image,
378
+ pad_size,
379
+ constant_values=constant_values,
380
+ data_format=data_format,
381
+ input_data_format=input_data_format,
382
+ )
383
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
384
+ image, output_size=pad_size, input_data_format=input_data_format
385
+ )
386
+
387
+ padded_masks = padded_masks if return_pixel_mask else None
388
+ return padded_images_list, padded_masks
389
+
390
+ def _crop(
391
+ self,
392
+ im: np.ndarray,
393
+ w1: int,
394
+ h1: int,
395
+ w2: int,
396
+ h2: int,
397
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
398
+ ) -> np.ndarray:
399
+ if input_data_format == ChannelDimension.FIRST:
400
+ return im[:, h1:h2, w1:w2]
401
+ elif input_data_format == ChannelDimension.LAST:
402
+ return im[h1:h2, w1:w2, :]
403
+
404
+ def split_image(
405
+ self,
406
+ image: np.ndarray,
407
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
408
+ ):
409
+ """
410
+ Split an image into 4 equal sub-images, and then concatenate that sequence with the original image.
411
+ That means that a single image becomes a sequence of 5 images.
412
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
413
+
414
+ Args:
415
+ image (`np.ndarray`):
416
+ Images to split.
417
+ input_data_format (`ChannelDimension` or `str`, *optional*):
418
+ The channel dimension format of the input image. If not provided, it will be inferred.
419
+ """
420
+ height, width = get_image_size(image, input_data_format)
421
+
422
+ mid_width = width // 2
423
+ mid_height = height // 2
424
+ image_list = [
425
+ self._crop(image, 0, 0, mid_width, mid_height, input_data_format),
426
+ self._crop(image, mid_width, 0, width, mid_height, input_data_format),
427
+ self._crop(image, 0, mid_height, mid_width, height, input_data_format),
428
+ self._crop(image, mid_width, mid_height, width, height, input_data_format),
429
+ image,
430
+ ]
431
+ return image_list
432
+
433
+ def preprocess(
434
+ self,
435
+ images: ImageInput,
436
+ do_convert_rgb: Optional[bool] = None,
437
+ do_resize: Optional[bool] = None,
438
+ size: Optional[Dict[str, int]] = None,
439
+ resample: PILImageResampling = None,
440
+ do_rescale: Optional[bool] = None,
441
+ rescale_factor: Optional[float] = None,
442
+ do_normalize: Optional[bool] = None,
443
+ image_mean: Optional[Union[float, List[float]]] = None,
444
+ image_std: Optional[Union[float, List[float]]] = None,
445
+ do_pad: Optional[bool] = None,
446
+ do_image_splitting: Optional[bool] = None,
447
+ return_tensors: Optional[Union[str, TensorType]] = None,
448
+ input_data_format: Optional[ChannelDimension] = None,
449
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
450
+ ):
451
+ """
452
+ Preprocess a batch of images.
453
+
454
+ Args:
455
+ images (`ImageInput`):
456
+ A list of images to preprocess.
457
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
458
+ Whether to convert the image to RGB.
459
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
460
+ Whether to resize the image.
461
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
462
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
463
+ the longest edge resized to keep the input aspect ratio.
464
+ resample (`int`, *optional*, defaults to `self.resample`):
465
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
466
+ has an effect if `do_resize` is set to `True`.
467
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
468
+ Whether to rescale the image.
469
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
470
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
471
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
472
+ Whether to normalize the image.
473
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
474
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
475
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
476
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
477
+ `True`.
478
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
479
+ Whether or not to pad the images to the largest height and width in the batch.
480
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
481
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
482
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
483
+ return_tensors (`str` or `TensorType`, *optional*):
484
+ The type of tensors to return. Can be one of:
485
+ - Unset: Return a list of `np.ndarray`.
486
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
487
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
488
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
489
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
490
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
491
+ The channel dimension format for the output image. Can be one of:
492
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
493
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
494
+ - Unset: Use the channel dimension format of the input image.
495
+ input_data_format (`ChannelDimension` or `str`, *optional*):
496
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
497
+ from the input image. Can be one of:
498
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
499
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
500
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
501
+ """
502
+ do_resize = do_resize if do_resize is not None else self.do_resize
503
+ size = size if size is not None else self.size
504
+ resample = resample if resample is not None else self.resample
505
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
506
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
507
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
508
+ image_mean = image_mean if image_mean is not None else self.image_mean
509
+ image_std = image_std if image_std is not None else self.image_std
510
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
511
+ do_pad = do_pad if do_pad is not None else self.do_pad
512
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
513
+
514
+ images_list = make_list_of_images(images)
515
+
516
+ if not valid_images(images_list[0]):
517
+ raise ValueError(
518
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
519
+ "torch.Tensor, tf.Tensor or jax.ndarray."
520
+ )
521
+
522
+ validate_preprocess_arguments(
523
+ do_rescale=do_rescale,
524
+ rescale_factor=rescale_factor,
525
+ do_normalize=do_normalize,
526
+ image_mean=image_mean,
527
+ image_std=image_std,
528
+ do_resize=do_resize,
529
+ size=size,
530
+ resample=resample,
531
+ )
532
+
533
+ if do_convert_rgb:
534
+ images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
535
+
536
+ # All transformations expect numpy arrays.
537
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
538
+
539
+ if is_scaled_image(images_list[0][0]) and do_rescale:
540
+ logger.warning_once(
541
+ "It looks like you are trying to rescale already rescaled images. If the input"
542
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
543
+ )
544
+
545
+ if input_data_format is None:
546
+ # We assume that all images have the same channel dimension format.
547
+ input_data_format = ChannelDimension.LAST #infer_channel_dimension_format(images_list[0][0])
548
+
549
+ if do_image_splitting:
550
+ new_images_list = []
551
+ for images in images_list:
552
+ new_images = []
553
+ for image in images:
554
+ new_images.extend(self.split_image(image, input_data_format))
555
+ new_images_list.append(new_images)
556
+ images_list = new_images_list
557
+
558
+ if do_resize:
559
+ images_list = [
560
+ [
561
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
562
+ for image in images
563
+ ]
564
+ for images in images_list
565
+ ]
566
+
567
+ if do_rescale:
568
+ images_list = [
569
+ [
570
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
571
+ for image in images
572
+ ]
573
+ for images in images_list
574
+ ]
575
+
576
+ if do_normalize:
577
+ images_list = [
578
+ [
579
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
580
+ for image in images
581
+ ]
582
+ for images in images_list
583
+ ]
584
+
585
+ pixel_attention_mask = None
586
+ if do_pad:
587
+ images_list, pixel_attention_mask = self.pad(
588
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
589
+ )
590
+
591
+ if data_format is not None:
592
+ images_list = [
593
+ [
594
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
595
+ for image in images
596
+ ]
597
+ for images in images_list
598
+ ]
599
+
600
+ data = {"pixel_values": np.array(images_list) if do_pad else images_list} # Faster tensor conversion
601
+ if pixel_attention_mask is not None:
602
+ data["pixel_attention_mask"] = np.array(pixel_attention_mask) if do_pad else pixel_attention_mask
603
+
604
+
605
+ temp_pixel_values = data["pixel_values"].copy()
606
+ temp_pixel_values = torch.from_numpy(temp_pixel_values)
607
+ batch_size, num_images, num_channels, height, width = temp_pixel_values.shape
608
+ temp_pixel_values = temp_pixel_values.view(batch_size * num_images, *temp_pixel_values.shape[2:])
609
+ # Remove padding images - padding images are full 0.
610
+ nb_values_per_image = temp_pixel_values.shape[1:].numel()
611
+ real_images_inds = (temp_pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
612
+ temp_pixel_values = temp_pixel_values[real_images_inds].contiguous()
613
+ # if 'pixel_attention_mask' is not none
614
+ if 'pixel_attention_mask' in data:
615
+ pixel_attention_mask = torch.from_numpy(data['pixel_attention_mask'])
616
+ # Remove padding images from the mask
617
+ pixel_attention_mask = pixel_attention_mask.view(
618
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
619
+ )
620
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
621
+ pixel_attention_mask = pixel_attention_mask.to(torch.bool)
622
+ else:
623
+ pixel_attention_mask = torch.ones(
624
+ size=(temp_pixel_values.size(0), temp_pixel_values.size(2), temp_pixel_values.size(3)),
625
+ dtype=torch.bool,
626
+ device=temp_pixel_values.device,
627
+ )
628
+ patch_size = 14 #self.config.vision_config.patch_size
629
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
630
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
631
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
632
+
633
+ data["navit_pixel_values"] = temp_pixel_values
634
+ data["pixel_attention_mask"] = patch_attention_mask
635
+
636
+ return BatchFeature(data=data, tensor_type=return_tensors)
637
+
638
+ @classmethod
+ def from_pretrained(cls, config_path):
+ with open(f'{config_path}/config.json', "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ processor = cls(
+ do_convert_rgb = config['do_convert_rgb'],
+ do_resize = config['do_resize'],
+ size = config['size'],
+ resample = config['resample'],
+ do_rescale = config['do_rescale'],
+ rescale_factor = config['rescale_factor'],
+ do_normalize = config['do_normalize'],
+ image_mean = config['image_mean'],
+ image_std = config['image_std'],
+ do_pad = config['do_pad'],
+ do_image_splitting = config['do_image_splitting']
+ )
+ #print("Loading idefics2 image Processor: {}".format(config_path))
+ return processor
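A minimal usage sketch for the processor above (not part of the commit; the image path is a placeholder). The `image_processor_2k.py` file that follows is a higher-resolution variant of the same class, with larger default edges (756/1960) and padding aligned to the 14-pixel patch size.

from PIL import Image
from image_processor import Idefics2ImageProcessor  # this file

processor = Idefics2ImageProcessor()          # defaults: shortest_edge=378, longest_edge=980
image = Image.open("example.jpg")             # placeholder path

batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)            # (batch, num_images, 3, H, W), zero-padded
print(batch["navit_pixel_values"].shape)      # padding images dropped: (num_real_images, 3, H, W)
print(batch["pixel_attention_mask"].shape)    # patch-level mask: (num_real_images, H//14, W//14)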
image_processor_2k.py ADDED
@@ -0,0 +1,751 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import json
21
+ import torch
22
+
23
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from transformers.image_transforms import PaddingMode, pad, resize, to_channel_dimension_format
25
+ from transformers.image_utils import (
26
+ IMAGENET_STANDARD_MEAN,
27
+ IMAGENET_STANDARD_STD,
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ get_image_size,
32
+ infer_channel_dimension_format,
33
+ is_scaled_image,
34
+ is_valid_image,
35
+ to_numpy_array,
36
+ valid_images,
37
+ validate_preprocess_arguments,
38
+ )
39
+ from transformers.utils import TensorType, is_vision_available, logging
40
+ import PIL
41
+ from PIL import Image
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+
47
+
48
+ def get_resize_output_image_size(image, size, input_data_format) -> Tuple[int, int]:
49
+ """
50
+ Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
51
+
52
+ Args:
53
+ image (`np.ndarray`):
54
+ Image to resize.
55
+ size (`Dict[str, int]`):
56
+ Size of the output image containing the keys "shortest_edge" and "longest_edge".
57
+ input_data_format (`ChannelDimension` or `str`):
58
+ The channel dimension format of the input image.
59
+
60
+ Returns:
61
+ The output size of the image after resizing.
62
+ """
63
+ height, width = get_image_size(image, channel_dim=input_data_format)
64
+
65
+ min_len = size["shortest_edge"]
66
+ max_len = size["longest_edge"]
67
+ aspect_ratio = width / height
68
+
69
+ if width >= height:
70
+ width = max_len
71
+ height = int(width / aspect_ratio)
72
+ elif height > width:
73
+ height = max_len
74
+ width = int(height * aspect_ratio)
75
+ height = max(height, min_len)
76
+ width = max(width, min_len)
77
+ return height, width
78
+
79
+
80
+ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
81
+ """
82
+ Convert a single image or a list of images to a list of numpy arrays.
83
+
84
+ Args:
85
+ images (`ImageInput`):
86
+ A single image or a list of images.
87
+
88
+ Returns:
89
+ A list of numpy arrays.
90
+ """
91
+ # If it's a single image, convert it to a list of lists
92
+ if is_valid_image(images):
93
+ images = [[images]]
94
+ # If it's a list of images, it's a single batch, so convert it to a list of lists
95
+ elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
96
+ images = [images]
97
+ # If it's a list of batches, it's already in the right format
98
+ elif (
99
+ isinstance(images, (list, tuple))
100
+ and len(images) > 0
101
+ and isinstance(images[0], (list, tuple))
102
+ and is_valid_image(images[0][0])
103
+ ):
104
+ pass
105
+ else:
106
+ raise ValueError(
107
+ "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
108
+ )
109
+ return images
110
+
111
+
112
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
113
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
114
+ """
115
+ Return the maximum value across all indices of an iterable of values.
116
+ """
117
+ return [max(values_i) for values_i in zip(*values)]
118
+
119
+
120
+ def get_max_height_width(
121
+ images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
122
+ ) -> List[int]:
123
+ """
124
+ Get the maximum height and width across all images in a batch.
125
+ """
126
+ if input_data_format is None:
127
+ input_data_format = infer_channel_dimension_format(images_list[0][0])
128
+
129
+ image_sizes = []
130
+ for images in images_list:
131
+ for image in images:
132
+ image_sizes.append(get_image_size(image, channel_dim=input_data_format))
133
+
134
+ max_height, max_width = max_across_indices(image_sizes)
135
+ return (max_height, max_width)
136
+
137
+
138
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
139
+ def make_pixel_mask(
140
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
141
+ ) -> np.ndarray:
142
+ """
143
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
144
+
145
+ Args:
146
+ image (`np.ndarray`):
147
+ Image to make the pixel mask for.
148
+ output_size (`Tuple[int, int]`):
149
+ Output size of the mask.
150
+ """
151
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
152
+ mask = np.zeros(output_size, dtype=np.int64)
153
+ mask[:input_height, :input_width] = 1
154
+ return mask
155
+
156
+
157
+ # FIXME Amy: merge this function with the one in image_transforms.py
158
+ def convert_to_rgb(image: ImageInput) -> ImageInput:
159
+ """
160
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
161
+ as is.
162
+ Args:
163
+ image (Image):
164
+ The image to convert.
165
+ """
166
+ if not isinstance(image, PIL.Image.Image):
167
+ return image
168
+
169
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
170
+ # for transparent images. The call to `alpha_composite` handles this case
171
+ if image.mode == "RGB":
172
+ return image
173
+
174
+ image_rgba = image.convert("RGBA")
175
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
176
+ alpha_composite = Image.alpha_composite(background, image_rgba)
177
+ alpha_composite = alpha_composite.convert("RGB")
178
+ return alpha_composite
179
+
180
+
181
+ class Idefics2ImageProcessor(BaseImageProcessor):
182
+ r"""
183
+ Constructs an Idefics2 image processor.
184
+
185
+ Args:
186
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
187
+ Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
188
+ Only has an effect if the input image is in the PIL format.
189
+ do_resize (`bool`, *optional*, defaults to `True`):
190
+ Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
191
+ shortest edge resized to keep the input aspect ratio, with a minimum size of `size["shortest_edge"]`.
192
+ size (`Dict`, *optional*):
193
+ Controls the size of the output image. This is a dictionary containing the keys "shortest_edge" and "longest_edge".
194
+ resample (`Resampling`, *optional*, defaults to `Resampling.BILINEAR`):
195
+ Resampling filter to use when resizing the image.
196
+ do_rescale (`bool`, *optional*, defaults to `True`):
197
+ Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
198
+ rescale_factor (`float`, *optional*, defaults to `1/255`):
199
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
200
+ do_normalize (`bool`, *optional*, defaults to `True`):
201
+ Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
202
+ a standard deviation of `image_std`.
203
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
211
+ do_pad (`bool`, *optional*, defaults to `True`):
212
+ Whether or not to pad the images to the largest height and width in the batch and number of images per
213
+ sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
214
+ do_image_splitting (`bool`, *optional*, defaults to `False`):
215
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
216
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
217
+ """
218
+
219
+ model_input_names = ["pixel_values"]
220
+
221
+ def __init__(
222
+ self,
223
+ do_convert_rgb: bool = True,
224
+ do_resize: bool = True,
225
+ size: Dict[str, int] = None,
226
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
227
+ do_rescale: bool = True,
228
+ rescale_factor: float = 1 / 255,
229
+ do_normalize: bool = True,
230
+ image_mean: Optional[Union[float, List[float]]] = None,
231
+ image_std: Optional[Union[float, List[float]]] = None,
232
+ do_pad: bool = True,
233
+ do_image_splitting: bool = False,
234
+ **kwargs,
235
+ ) -> None:
236
+ super().__init__(**kwargs)
237
+ self.do_convert_rgb = do_convert_rgb
238
+ self.do_resize = do_resize
239
+ self.size = size if size is not None else {"shortest_edge": 756, "longest_edge": 1960}
240
+ self.resample = resample
241
+ self.do_rescale = do_rescale
242
+ self.rescale_factor = rescale_factor
243
+ self.do_normalize = do_normalize
244
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
245
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
246
+ self.do_pad = do_pad
247
+ self.do_image_splitting = do_image_splitting
248
+
249
+ def resize(
250
+ self,
251
+ image: np.ndarray,
252
+ size: Dict[str, int],
253
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
254
+ data_format: Optional[Union[str, ChannelDimension]] = None,
255
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
256
+ **kwargs,
257
+ ) -> np.ndarray:
258
+ """
259
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
260
+ resized to keep the input aspect ratio.
261
+
262
+ Args:
263
+ image (`np.ndarray`):
264
+ Image to resize.
265
+ size (`Dict[str, int]`):
266
+ Size of the output image.
267
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resizing the image.
269
+ data_format (`str` or `ChannelDimension`, *optional*):
270
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
271
+ input_data_format (`ChannelDimension` or `str`, *optional*):
272
+ The channel dimension format of the input image. If not provided, it will be inferred.
273
+ """
274
+ if "shortest_edge" in size and "longest_edge" in size:
275
+ size = get_resize_output_image_size(image, size, input_data_format)
276
+ elif "height" in size and "width" in size:
277
+ size = (size["height"], size["width"])
278
+ else:
279
+ raise ValueError(
280
+ "size must be a dictionary with keys 'shortest_edge' and 'longest_edge' or 'height' and 'width'."
281
+ )
282
+ try:
+ return resize(
+ image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+ except Exception:
+ # Log the offending input before re-raising, rather than calling resize a second time.
+ print(f"resize error with image: {image.shape}")
+ raise
292
+
293
+ # Copied from transformers.models.vilt.image_processing_vilt.ViltImageProcessor._pad_image
294
+ def _pad_image(
295
+ self,
296
+ image: np.ndarray,
297
+ output_size: Tuple[int, int],
298
+ constant_values: Union[float, Iterable[float]] = 0,
299
+ data_format: Optional[ChannelDimension] = None,
300
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
301
+ ) -> np.ndarray:
302
+ """
303
+ Pad an image with zeros to the given size.
304
+ """
305
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
306
+ output_height, output_width = output_size
307
+
308
+ pad_bottom = output_height - input_height
309
+ pad_right = output_width - input_width
310
+ padding = ((0, pad_bottom), (0, pad_right))
311
+ padded_image = pad(
312
+ image,
313
+ padding,
314
+ mode=PaddingMode.CONSTANT,
315
+ constant_values=constant_values,
316
+ data_format=data_format,
317
+ input_data_format=input_data_format,
318
+ )
319
+ return padded_image
320
+
321
+ def pad(
322
+ self,
323
+ images: List[np.ndarray],
324
+ constant_values: Union[float, Iterable[float]] = 0,
325
+ return_pixel_mask: bool = True,
326
+ return_tensors: Optional[Union[str, TensorType]] = None,
327
+ data_format: Optional[ChannelDimension] = None,
328
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
329
+ ) -> BatchFeature:
330
+ """
331
+ For a list of lists of images, pads each image at the bottom and right with zeros up to the largest height and width in the batch.
+ For each sample in the batch, pads the sample with empty images up to the maximum number of images per sample in the batch. Optionally returns a pixel mask.
333
+
334
+ Args:
335
+ images (`np.ndarray`):
336
+ List of list of images to pad. Pads to the largest height and width in the batch.
337
+ constant_values (`float` or `Iterable[float]`, *optional*):
338
+ The value to use for the padding if `mode` is `"constant"`.
339
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
340
+ Whether to return a pixel mask.
341
+ return_tensors (`str` or `TensorType`, *optional*):
342
+ The type of tensors to return. Can be one of:
343
+ - Unset: Return a list of `np.ndarray`.
344
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
345
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
346
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
347
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
348
+ data_format (`str` or `ChannelDimension`, *optional*):
349
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
350
+ input_data_format (`ChannelDimension` or `str`, *optional*):
351
+ The channel dimension format of the input image. If not provided, it will be inferred.
352
+ """
353
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
354
+
355
+ # align with patch size
356
+ patch_size = 14
357
+ pad_size = [int(np.ceil(x / patch_size)) * patch_size for x in pad_size]
358
+
359
+ batch_size = len(images)
360
+ max_num_images = max(len(images_) for images_ in images)
361
+ input_data_format = (
362
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
363
+ )
364
+ data_format = input_data_format if data_format is None else data_format
365
+
366
+ def empty_image(size, input_data_format):
367
+ if input_data_format == ChannelDimension.FIRST:
368
+ return np.zeros((3, *size), dtype=np.uint8)
369
+ elif input_data_format == ChannelDimension.LAST:
370
+ return np.zeros((*size, 3), dtype=np.uint8)
371
+ raise ValueError("Invalid channel dimension format.")
372
+
373
+ padded_images_list = [
374
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
375
+ ]
376
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
377
+
378
+ for batch_idx in range(batch_size):
379
+ for sample_idx, image in enumerate(images[batch_idx]):
380
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
381
+ image,
382
+ pad_size,
383
+ constant_values=constant_values,
384
+ data_format=data_format,
385
+ input_data_format=input_data_format,
386
+ )
387
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
388
+ image, output_size=pad_size, input_data_format=input_data_format
389
+ )
390
+
391
+ padded_masks = padded_masks if return_pixel_mask else None
392
+ return padded_images_list, padded_masks
393
+
394
+ def _crop(
395
+ self,
396
+ im: np.ndarray,
397
+ w1: int,
398
+ h1: int,
399
+ w2: int,
400
+ h2: int,
401
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
402
+ ) -> np.ndarray:
403
+ if input_data_format == ChannelDimension.FIRST:
404
+ return im[:, h1:h2, w1:w2]
405
+ elif input_data_format == ChannelDimension.LAST:
406
+ return im[h1:h2, w1:w2, :]
407
+
408
+ def split_image(
409
+ self,
410
+ image: np.ndarray,
411
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
412
+ ):
413
+ """
414
+ Split an image into 4 equal sub-images, and then concatenate that sequence with the original image.
415
+ That means that a single image becomes a sequence of 5 images.
416
+ This is a "trick" to spend more compute on each image with no changes in the vision encoder.
417
+
418
+ Args:
419
+ image (`np.ndarray`):
420
+ Images to split.
421
+ input_data_format (`ChannelDimension` or `str`, *optional*):
422
+ The channel dimension format of the input image. If not provided, it will be inferred.
423
+ """
424
+ height, width = get_image_size(image, input_data_format)
425
+
426
+ mid_width = width // 2
427
+ mid_height = height // 2
428
+ image_list = [
429
+ self._crop(image, 0, 0, mid_width, mid_height, input_data_format),
430
+ self._crop(image, mid_width, 0, width, mid_height, input_data_format),
431
+ self._crop(image, 0, mid_height, mid_width, height, input_data_format),
432
+ self._crop(image, mid_width, mid_height, width, height, input_data_format),
433
+ image,
434
+ ]
435
+ # for img in image_list:
436
+ # print(type(img),img.dtype)
437
+ return image_list
438
+
439
+ def preprocess(
440
+ self,
441
+ images: ImageInput,
442
+ do_convert_rgb: Optional[bool] = None,
443
+ do_resize: Optional[bool] = None,
444
+ size: Optional[Dict[str, int]] = None,
445
+ resample: PILImageResampling = None,
446
+ do_rescale: Optional[bool] = None,
447
+ rescale_factor: Optional[float] = None,
448
+ do_normalize: Optional[bool] = None,
449
+ image_mean: Optional[Union[float, List[float]]] = None,
450
+ image_std: Optional[Union[float, List[float]]] = None,
451
+ do_pad: Optional[bool] = None,
452
+ do_image_splitting: Optional[bool] = None,
453
+ return_tensors: Optional[Union[str, TensorType]] = None,
454
+ input_data_format: Optional[ChannelDimension] = None,
455
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
456
+ ):
457
+ """
458
+ Preprocess a batch of images.
459
+
460
+ Args:
461
+ images (`ImageInput`):
462
+ A list of images to preprocess.
463
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
464
+ Whether to convert the image to RGB.
465
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
466
+ Whether to resize the image.
467
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
468
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
469
+ the longest edge resized to keep the input aspect ratio.
470
+ resample (`int`, *optional*, defaults to `self.resample`):
471
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
472
+ has an effect if `do_resize` is set to `True`.
473
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
474
+ Whether to rescale the image.
475
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
476
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
477
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
478
+ Whether to normalize the image.
479
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
480
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
481
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
482
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
483
+ `True`.
484
+ do_pad (`bool`, *optional*, defaults to `self.do_pad`):
485
+ Whether or not to pad the images to the largest height and width in the batch.
486
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
487
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
488
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
489
+ return_tensors (`str` or `TensorType`, *optional*):
490
+ The type of tensors to return. Can be one of:
491
+ - Unset: Return a list of `np.ndarray`.
492
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
493
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
494
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
495
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
496
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
497
+ The channel dimension format for the output image. Can be one of:
498
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
499
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
500
+ - Unset: Use the channel dimension format of the input image.
501
+ input_data_format (`ChannelDimension` or `str`, *optional*):
502
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
503
+ from the input image. Can be one of:
504
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
505
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
506
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
507
+ """
508
+ do_resize = do_resize if do_resize is not None else self.do_resize
509
+ size = size if size is not None else self.size
510
+ resample = resample if resample is not None else self.resample
511
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
512
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
513
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
514
+ image_mean = image_mean if image_mean is not None else self.image_mean
515
+ image_std = image_std if image_std is not None else self.image_std
516
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
517
+ do_pad = do_pad if do_pad is not None else self.do_pad
518
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
519
+
520
+ images_list = make_list_of_images(images)
521
+
522
+ if not valid_images(images_list[0]):
523
+ raise ValueError(
524
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
525
+ "torch.Tensor, tf.Tensor or jax.ndarray."
526
+ )
527
+
528
+ validate_preprocess_arguments(
529
+ do_rescale=do_rescale,
530
+ rescale_factor=rescale_factor,
531
+ do_normalize=do_normalize,
532
+ image_mean=image_mean,
533
+ image_std=image_std,
534
+ do_resize=do_resize,
535
+ size=size,
536
+ resample=resample,
537
+ )
538
+
539
+ if do_convert_rgb:
540
+ images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
541
+
542
+ # All transformations expect numpy arrays.
543
+ images_list = [[to_numpy_array(image) for image in images] for images in images_list]
544
+
545
+ if is_scaled_image(images_list[0][0]) and do_rescale:
546
+ logger.warning_once(
547
+ "It looks like you are trying to rescale already rescaled images. If the input"
548
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
549
+ )
550
+
551
+ if input_data_format is None:
552
+ # We assume that all images have the same channel dimension format.
553
+ input_data_format = ChannelDimension.LAST  # hard-coded to channels-last instead of calling infer_channel_dimension_format(images_list[0][0])
554
+
555
+ if do_image_splitting:
556
+ new_images_list = []
557
+ for images in images_list:
558
+ new_images = []
559
+ for image in images:
560
+ new_images.extend(self.split_image(image, input_data_format))
561
+ new_images_list.append(new_images)
562
+ images_list = new_images_list
563
+
564
+ if do_resize:
565
+ images_list = [
566
+ [
567
+ self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
568
+ for image in images
569
+ ]
570
+ for images in images_list
571
+ ]
572
+
573
+ if do_rescale:
574
+ images_list = [
575
+ [
576
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
577
+ for image in images
578
+ ]
579
+ for images in images_list
580
+ ]
581
+
582
+ if do_normalize:
583
+ images_list = [
584
+ [
585
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
586
+ for image in images
587
+ ]
588
+ for images in images_list
589
+ ]
590
+
591
+ pixel_attention_mask = None
592
+ if do_pad:
593
+ images_list, pixel_attention_mask = self.pad(
594
+ images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
595
+ )
596
+
597
+ if data_format is not None:
598
+ images_list = [
599
+ [
600
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
601
+ for image in images
602
+ ]
603
+ for images in images_list
604
+ ]
605
+
606
+ data = {"pixel_values": np.array(images_list) if do_pad else images_list} # Faster tensor conversion
607
+ if pixel_attention_mask is not None:
608
+ data["pixel_attention_mask"] = np.array(pixel_attention_mask) if do_pad else pixel_attention_mask
609
+
610
+
611
+ temp_pixel_values = data["pixel_values"].copy()
612
+ temp_pixel_values = torch.from_numpy(temp_pixel_values)
613
+ batch_size, num_images, num_channels, height, width = temp_pixel_values.shape
614
+ temp_pixel_values = temp_pixel_values.view(batch_size * num_images, *temp_pixel_values.shape[2:])
615
+ # Remove padding images - padding images are all zeros.
616
+ nb_values_per_image = temp_pixel_values.shape[1:].numel()
617
+ real_images_inds = (temp_pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
618
+ temp_pixel_values = temp_pixel_values[real_images_inds].contiguous()
619
+ # Reuse the pixel attention mask produced by padding if it exists; otherwise build an all-ones mask.
620
+ if 'pixel_attention_mask' in data:
621
+ pixel_attention_mask = torch.from_numpy(data['pixel_attention_mask'])
622
+ # Remove padding images from the mask as well.
623
+ pixel_attention_mask = pixel_attention_mask.view(
624
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
625
+ )
626
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
627
+ pixel_attention_mask = pixel_attention_mask.to(torch.bool)
628
+ else:
629
+ pixel_attention_mask = torch.ones(
630
+ size=(temp_pixel_values.size(0), temp_pixel_values.size(2), temp_pixel_values.size(3)),
631
+ dtype=torch.bool,
632
+ device=temp_pixel_values.device,
633
+ )
634
+
635
+ im_sizes = [torch.nonzero(mask > 0).max(dim=0)[0] + 1 for mask in pixel_attention_mask]
636
+ # im_sizes holds the valid (height, width) extent of each image; kept for debugging only.
637
+ patch_size = 14  # hard-coded; mirrors self.config.vision_config.patch_size
638
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
639
+ # after unfolding the height dimension: (N, H // patch_size, W, patch_size)
640
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
641
+ # after unfolding the width dimension: (N, H // patch_size, W // patch_size, patch_size, patch_size)
642
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
643
+
644
+ data["navit_pixel_values"] = temp_pixel_values
645
+ data["pixel_attention_mask"] = patch_attention_mask
646
+
647
+ return BatchFeature(data=data, tensor_type=return_tensors)
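The tail of preprocess above does two extra things on top of the standard resize/rescale/normalize/pad chain: it flattens the batch and drops images that are entirely zero (the padding images added by pad), and it turns the pixel-level attention mask into a patch-level mask by unfolding it into non-overlapping 14x14 windows. A self-contained sketch of that logic with made-up shapes (the 56x70 resolution and the all-zero third image are purely illustrative):

import torch

patch_size = 14  # same hard-coded value as in the method above

# hypothetical flattened batch: three 3x56x70 images, the last one a zero padding image
pixel_values = torch.full((3, 3, 56, 70), 0.5)
pixel_values[2].zero_()
pixel_mask = torch.ones(3, 56, 70, dtype=torch.bool)

# an image is kept unless every single value in it is exactly 0.0
nb_values_per_image = pixel_values.shape[1:].numel()
real = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real]   # -> (2, 3, 56, 70)
pixel_mask = pixel_mask[real]       # -> (2, 56, 70)

# unfold height, then width, into 14x14 tiles; a patch is attended to if any of its pixels is valid
subgrid = pixel_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patch_attention_mask = subgrid.sum(dim=(-1, -2)) > 0
assert patch_attention_mask.shape == (2, 56 // patch_size, 70 // patch_size)  # (2, 4, 5)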
648
+
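A usage sketch for preprocess as a whole. The image_processor instance is assumed to exist (for example built through the from_pretrained helper further below); note that the tail of the method converts data["pixel_values"] with torch.from_numpy, so it effectively expects do_pad to be enabled so that the padded batch is a single stacked array.

import numpy as np

# hypothetical 480x640 RGB input as a channels-last uint8 array
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)

batch = image_processor.preprocess([image], do_image_splitting=True, return_tensors="pt")

# with splitting enabled the single input becomes 4 crops + the original (5 views per sample):
# - "pixel_values":         padded batch, roughly (1, 5, 3, padded_height, padded_width)
# - "navit_pixel_values":   the flattened views with zero padding images removed
# - "pixel_attention_mask": the patch-level mask, one entry per 14x14 patch of each kept view
for name, tensor in batch.items():
    print(name, tuple(tensor.shape))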
649
+ def infer_processed_size(
650
+ self,
651
+ image,
652
+ do_resize: Optional[bool] = None,
653
+ size: Optional[Dict[str, int]] = None,
654
+ do_image_splitting: Optional[bool] = None,
655
+ ):
656
+ """
657
+ Infer the processed (patch-grid) sizes for a single image without running the full preprocessing pipeline.
658
+
659
+ Args:
660
+ image (`str` or `np.ndarray`):
661
+ Path to an image file, or a decoded image array, whose processed size should be inferred.
662
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
663
+ Whether to resize the image.
664
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
665
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
666
+ the longest edge resized to keep the input aspect ratio.
667
+ do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
668
+ Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image. That
669
+ strategy was first introduced in https://arxiv.org/abs/2311.06607.
670
+ """
671
+ do_resize = do_resize if do_resize is not None else self.do_resize
672
+ size = size if size is not None else self.size
673
+ do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
674
+
675
+ if isinstance(image, str):
676
+ try:
677
+ tmp_img = Image.open(image)
678
+ w, h = tmp_img.size
679
+ except Exception as e:
680
+ error_str = f"load {image} error: {e}. casting to default image size (black image in the dataloader)...\n"
681
+ with open('/tmp/image_processor_log.txt', 'a') as f:
682
+ f.write(error_str)
683
+ print(error_str)
684
+ NAVIT_MIN_RES = 378
685
+ w, h = NAVIT_MIN_RES, NAVIT_MIN_RES
686
+ else:
687
+ h, w, _ = image.shape
688
+ assert w > 4 and h > 4
689
+
690
+ size_list = None
691
+
692
+ if do_image_splitting:
693
+ mid_width = w // 2
694
+ mid_height = h // 2
695
+
696
+ size_list = [
697
+ [mid_width, mid_height],
698
+ [w - mid_width, mid_height],
699
+ [mid_width, h - mid_height],
700
+ [w - mid_width, h - mid_height],
701
+ [w, h]
702
+ ]
703
+ else:
704
+ size_list = [
705
+ [w, h]
706
+ ]
707
+
708
+ if do_resize:
709
+ def get_resized_size(input_size, size):
710
+ width, height = input_size
711
+ min_len = size["shortest_edge"]
712
+ max_len = size["longest_edge"]
713
+ aspect_ratio = width / height
714
+
715
+ if width >= height and width > max_len:
716
+ width = max_len
717
+ height = int(width / aspect_ratio)
718
+ elif height > width and height > max_len:
719
+ height = max_len
720
+ width = int(height * aspect_ratio)
721
+
722
+ height = max(height, min_len)
723
+ width = max(width, min_len)
724
+ return [width, height]
725
+ size_list = [get_resized_size(input_size, size) for input_size in size_list]
726
+
727
+ patch_size = 14  # hard-coded; mirrors self.config.vision_config.patch_size
728
+
729
+ size_list = [[int(np.ceil(w / patch_size)), int(np.ceil(h / patch_size))] for w, h in size_list]
730
+
731
+ return size_list
732
+
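A worked example of the arithmetic in infer_processed_size, using illustrative resize bounds (the real shortest_edge/longest_edge values come from the processor's size configuration). For a 1000x600 image with splitting enabled, the four half-size crops and the full image are first resized under the longest-edge/shortest-edge constraints and then converted to 14x14-patch grid sizes with a ceiling division:

import numpy as np

size = {"shortest_edge": 378, "longest_edge": 980}  # illustrative values only
patch_size = 14

def resized(w, h):
    # mirrors get_resized_size above: clamp the longest edge, then enforce the minimum edge
    aspect = w / h
    if w >= h and w > size["longest_edge"]:
        w = size["longest_edge"]
        h = int(w / aspect)
    elif h > w and h > size["longest_edge"]:
        h = size["longest_edge"]
        w = int(h * aspect)
    return max(w, size["shortest_edge"]), max(h, size["shortest_edge"])

w, h = 1000, 600                              # original width x height
crops = [(w // 2, h // 2)] * 4 + [(w, h)]     # four half-size crops plus the full image
grids = [
    (int(np.ceil(rw / patch_size)), int(np.ceil(rh / patch_size)))
    for rw, rh in (resized(cw, ch) for cw, ch in crops)
]
print(grids)  # four 36x27 patch grids for the crops, and roughly 70x42 for the full image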
733
+ @classmethod
734
+ def from_pretrained(cls, config_path):
735
+ with open(f'{config_path}/config.json', "r", encoding="utf-8") as f:
736
+ config = json.load(f)
737
+
738
+ image_processor = cls(
739
+ do_convert_rgb = config['do_convert_rgb'],
740
+ do_resize = config['do_resize'],
741
+ size = config['size'],
742
+ resample = config['resample'],
743
+ do_rescale = config['do_rescale'],
744
+ rescale_factor = config['rescale_factor'],
745
+ do_normalize = config['do_normalize'],
746
+ image_mean = config['image_mean'],
747
+ image_std = config['image_std'],
748
+ do_pad = config['do_pad'],
749
+ do_image_splitting = config['do_image_splitting']
750
+ )
751
+ return image_processor
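A short usage sketch for the loader above. The path is hypothetical; the directory only needs to contain a config.json exposing the flat preprocessing keys read here (do_convert_rgb, do_resize, size, and so on), and the attribute names assume the usual image-processor constructor that stores each argument on the instance.

# hypothetical local checkout containing the expected config.json
image_processor = Idefics2ImageProcessor.from_pretrained("/path/to/wemm-checkpoint")
print(image_processor.size, image_processor.do_image_splitting)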
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b799aaaaacd35d29d0126b1b64f4ee3a550118729dca1d7299aad4a17651576a
3
+ size 4985298840
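This entry and the three shard entries that follow are Git LFS pointer files rather than the tensors themselves: each records the spec version, the SHA-256 of the actual blob, and its byte size (about 4.99 GB for this first shard). A tiny sketch of reading such a pointer:

def parse_lfs_pointer(text: str) -> dict:
    # a pointer is just three "key value" lines: version, oid, size
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    fields["size"] = int(fields["size"])
    return fields

pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:b799aaaaacd35d29d0126b1b64f4ee3a550118729dca1d7299aad4a17651576a\n"
    "size 4985298840\n"
)
print(pointer["oid"], pointer["size"] / 1e9)  # ~4.99 GB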
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7301ea2a08c6cbc778981c49fa2fbac7fe86862ca7dbf3b044fdc1c572648f2e
3
+ size 4995585752
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d4f8874cf1556c460b7dc643ac57cc763dfc571d77fa09d93c69baed5487684
3
+ size 4999797696
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a81080d0c17f10e73753f307039ca2f411e2c9907ccda02e78d9bcd4ad2eafc
3
+ size 3964803800
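The index file added below ties the four shards together: metadata.total_size (about 18.9 GB here) is the combined size of all tensors, and weight_map maps every parameter name to the shard file that stores it. A small sketch of using the index to locate and load a single tensor with the safetensors library (loading the full model would normally go through the Transformers auto classes instead):

import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "language_model.model.layers.0.attention.wqkv.original_linear.weight"
shard = index["weight_map"][name]  # -> "model-00001-of-00004.safetensors"
with safe_open(shard, framework="pt", device="cpu") as f:
    tensor = f.get_tensor(name)
print(shard, tuple(tensor.shape))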
model.safetensors.index.json ADDED
@@ -0,0 +1,994 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18945355232
4
+ },
5
+ "weight_map": {
6
+ "downsampler.group_op.weight": "model-00001-of-00004.safetensors",
7
+ "downsampler.linear_model.1.weight": "model-00001-of-00004.safetensors",
8
+ "language_model.model.layers.0.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
9
+ "language_model.model.layers.0.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
10
+ "language_model.model.layers.0.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
11
+ "language_model.model.layers.0.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
12
+ "language_model.model.layers.0.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
13
+ "language_model.model.layers.0.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
14
+ "language_model.model.layers.0.attention_norm.weight": "model-00001-of-00004.safetensors",
15
+ "language_model.model.layers.0.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
16
+ "language_model.model.layers.0.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
17
+ "language_model.model.layers.0.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
18
+ "language_model.model.layers.0.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
19
+ "language_model.model.layers.0.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
20
+ "language_model.model.layers.0.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
21
+ "language_model.model.layers.0.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
22
+ "language_model.model.layers.0.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
23
+ "language_model.model.layers.0.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
24
+ "language_model.model.layers.0.ffn_norm.weight": "model-00001-of-00004.safetensors",
25
+ "language_model.model.layers.1.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
26
+ "language_model.model.layers.1.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
27
+ "language_model.model.layers.1.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
28
+ "language_model.model.layers.1.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
29
+ "language_model.model.layers.1.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
30
+ "language_model.model.layers.1.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
31
+ "language_model.model.layers.1.attention_norm.weight": "model-00001-of-00004.safetensors",
32
+ "language_model.model.layers.1.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
33
+ "language_model.model.layers.1.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
34
+ "language_model.model.layers.1.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
35
+ "language_model.model.layers.1.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
36
+ "language_model.model.layers.1.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
37
+ "language_model.model.layers.1.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
38
+ "language_model.model.layers.1.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
39
+ "language_model.model.layers.1.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
40
+ "language_model.model.layers.1.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
41
+ "language_model.model.layers.1.ffn_norm.weight": "model-00001-of-00004.safetensors",
42
+ "language_model.model.layers.10.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
43
+ "language_model.model.layers.10.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
44
+ "language_model.model.layers.10.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
45
+ "language_model.model.layers.10.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
46
+ "language_model.model.layers.10.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
47
+ "language_model.model.layers.10.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
48
+ "language_model.model.layers.10.attention_norm.weight": "model-00002-of-00004.safetensors",
49
+ "language_model.model.layers.10.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
50
+ "language_model.model.layers.10.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
51
+ "language_model.model.layers.10.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
52
+ "language_model.model.layers.10.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
53
+ "language_model.model.layers.10.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
54
+ "language_model.model.layers.10.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
55
+ "language_model.model.layers.10.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
56
+ "language_model.model.layers.10.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
57
+ "language_model.model.layers.10.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
58
+ "language_model.model.layers.10.ffn_norm.weight": "model-00002-of-00004.safetensors",
59
+ "language_model.model.layers.11.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
60
+ "language_model.model.layers.11.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
61
+ "language_model.model.layers.11.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
62
+ "language_model.model.layers.11.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
63
+ "language_model.model.layers.11.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
64
+ "language_model.model.layers.11.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
65
+ "language_model.model.layers.11.attention_norm.weight": "model-00002-of-00004.safetensors",
66
+ "language_model.model.layers.11.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
67
+ "language_model.model.layers.11.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
68
+ "language_model.model.layers.11.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
69
+ "language_model.model.layers.11.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
70
+ "language_model.model.layers.11.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
71
+ "language_model.model.layers.11.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
72
+ "language_model.model.layers.11.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
73
+ "language_model.model.layers.11.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
74
+ "language_model.model.layers.11.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
75
+ "language_model.model.layers.11.ffn_norm.weight": "model-00002-of-00004.safetensors",
76
+ "language_model.model.layers.12.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
77
+ "language_model.model.layers.12.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
78
+ "language_model.model.layers.12.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
79
+ "language_model.model.layers.12.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
80
+ "language_model.model.layers.12.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
81
+ "language_model.model.layers.12.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
82
+ "language_model.model.layers.12.attention_norm.weight": "model-00002-of-00004.safetensors",
83
+ "language_model.model.layers.12.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
84
+ "language_model.model.layers.12.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
85
+ "language_model.model.layers.12.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
86
+ "language_model.model.layers.12.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
87
+ "language_model.model.layers.12.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
88
+ "language_model.model.layers.12.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
89
+ "language_model.model.layers.12.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
90
+ "language_model.model.layers.12.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
91
+ "language_model.model.layers.12.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
92
+ "language_model.model.layers.12.ffn_norm.weight": "model-00002-of-00004.safetensors",
93
+ "language_model.model.layers.13.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
94
+ "language_model.model.layers.13.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
95
+ "language_model.model.layers.13.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
96
+ "language_model.model.layers.13.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
97
+ "language_model.model.layers.13.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
98
+ "language_model.model.layers.13.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
99
+ "language_model.model.layers.13.attention_norm.weight": "model-00002-of-00004.safetensors",
100
+ "language_model.model.layers.13.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
101
+ "language_model.model.layers.13.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
102
+ "language_model.model.layers.13.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
103
+ "language_model.model.layers.13.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
104
+ "language_model.model.layers.13.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
105
+ "language_model.model.layers.13.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
106
+ "language_model.model.layers.13.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
107
+ "language_model.model.layers.13.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
108
+ "language_model.model.layers.13.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
109
+ "language_model.model.layers.13.ffn_norm.weight": "model-00002-of-00004.safetensors",
110
+ "language_model.model.layers.14.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
111
+ "language_model.model.layers.14.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
112
+ "language_model.model.layers.14.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
113
+ "language_model.model.layers.14.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
114
+ "language_model.model.layers.14.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
115
+ "language_model.model.layers.14.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
116
+ "language_model.model.layers.14.attention_norm.weight": "model-00002-of-00004.safetensors",
117
+ "language_model.model.layers.14.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
118
+ "language_model.model.layers.14.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
119
+ "language_model.model.layers.14.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
120
+ "language_model.model.layers.14.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
121
+ "language_model.model.layers.14.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
122
+ "language_model.model.layers.14.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
123
+ "language_model.model.layers.14.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
124
+ "language_model.model.layers.14.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
125
+ "language_model.model.layers.14.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
126
+ "language_model.model.layers.14.ffn_norm.weight": "model-00002-of-00004.safetensors",
127
+ "language_model.model.layers.15.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
128
+ "language_model.model.layers.15.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
129
+ "language_model.model.layers.15.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
130
+ "language_model.model.layers.15.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
131
+ "language_model.model.layers.15.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
132
+ "language_model.model.layers.15.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
133
+ "language_model.model.layers.15.attention_norm.weight": "model-00003-of-00004.safetensors",
134
+ "language_model.model.layers.15.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
135
+ "language_model.model.layers.15.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
136
+ "language_model.model.layers.15.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
137
+ "language_model.model.layers.15.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
138
+ "language_model.model.layers.15.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
139
+ "language_model.model.layers.15.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
140
+ "language_model.model.layers.15.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
141
+ "language_model.model.layers.15.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
142
+ "language_model.model.layers.15.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
143
+ "language_model.model.layers.15.ffn_norm.weight": "model-00003-of-00004.safetensors",
144
+ "language_model.model.layers.16.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
145
+ "language_model.model.layers.16.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
146
+ "language_model.model.layers.16.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
147
+ "language_model.model.layers.16.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
148
+ "language_model.model.layers.16.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
149
+ "language_model.model.layers.16.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
150
+ "language_model.model.layers.16.attention_norm.weight": "model-00003-of-00004.safetensors",
151
+ "language_model.model.layers.16.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
152
+ "language_model.model.layers.16.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
153
+ "language_model.model.layers.16.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
154
+ "language_model.model.layers.16.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
155
+ "language_model.model.layers.16.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
156
+ "language_model.model.layers.16.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
157
+ "language_model.model.layers.16.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
158
+ "language_model.model.layers.16.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
159
+ "language_model.model.layers.16.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
160
+ "language_model.model.layers.16.ffn_norm.weight": "model-00003-of-00004.safetensors",
161
+ "language_model.model.layers.17.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
162
+ "language_model.model.layers.17.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
163
+ "language_model.model.layers.17.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
164
+ "language_model.model.layers.17.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
165
+ "language_model.model.layers.17.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
166
+ "language_model.model.layers.17.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
167
+ "language_model.model.layers.17.attention_norm.weight": "model-00003-of-00004.safetensors",
168
+ "language_model.model.layers.17.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
169
+ "language_model.model.layers.17.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
170
+ "language_model.model.layers.17.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
171
+ "language_model.model.layers.17.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
172
+ "language_model.model.layers.17.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
173
+ "language_model.model.layers.17.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
174
+ "language_model.model.layers.17.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
175
+ "language_model.model.layers.17.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
176
+ "language_model.model.layers.17.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
177
+ "language_model.model.layers.17.ffn_norm.weight": "model-00003-of-00004.safetensors",
178
+ "language_model.model.layers.18.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
179
+ "language_model.model.layers.18.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
180
+ "language_model.model.layers.18.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
181
+ "language_model.model.layers.18.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
182
+ "language_model.model.layers.18.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
183
+ "language_model.model.layers.18.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
184
+ "language_model.model.layers.18.attention_norm.weight": "model-00003-of-00004.safetensors",
185
+ "language_model.model.layers.18.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
186
+ "language_model.model.layers.18.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
187
+ "language_model.model.layers.18.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
188
+ "language_model.model.layers.18.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
189
+ "language_model.model.layers.18.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
190
+ "language_model.model.layers.18.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
191
+ "language_model.model.layers.18.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
192
+ "language_model.model.layers.18.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
193
+ "language_model.model.layers.18.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
194
+ "language_model.model.layers.18.ffn_norm.weight": "model-00003-of-00004.safetensors",
195
+ "language_model.model.layers.19.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
196
+ "language_model.model.layers.19.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
197
+ "language_model.model.layers.19.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
198
+ "language_model.model.layers.19.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
199
+ "language_model.model.layers.19.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
200
+ "language_model.model.layers.19.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
201
+ "language_model.model.layers.19.attention_norm.weight": "model-00003-of-00004.safetensors",
202
+ "language_model.model.layers.19.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
203
+ "language_model.model.layers.19.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
204
+ "language_model.model.layers.19.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
205
+ "language_model.model.layers.19.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
206
+ "language_model.model.layers.19.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
207
+ "language_model.model.layers.19.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
208
+ "language_model.model.layers.19.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
209
+ "language_model.model.layers.19.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
210
+ "language_model.model.layers.19.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
211
+ "language_model.model.layers.19.ffn_norm.weight": "model-00003-of-00004.safetensors",
212
+ "language_model.model.layers.2.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
213
+ "language_model.model.layers.2.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
214
+ "language_model.model.layers.2.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
215
+ "language_model.model.layers.2.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
216
+ "language_model.model.layers.2.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
217
+ "language_model.model.layers.2.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
218
+ "language_model.model.layers.2.attention_norm.weight": "model-00001-of-00004.safetensors",
219
+ "language_model.model.layers.2.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
220
+ "language_model.model.layers.2.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
221
+ "language_model.model.layers.2.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
222
+ "language_model.model.layers.2.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
223
+ "language_model.model.layers.2.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
224
+ "language_model.model.layers.2.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
225
+ "language_model.model.layers.2.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
226
+ "language_model.model.layers.2.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
227
+ "language_model.model.layers.2.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
228
+ "language_model.model.layers.2.ffn_norm.weight": "model-00001-of-00004.safetensors",
229
+ "language_model.model.layers.20.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
230
+ "language_model.model.layers.20.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
231
+ "language_model.model.layers.20.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
232
+ "language_model.model.layers.20.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
233
+ "language_model.model.layers.20.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
234
+ "language_model.model.layers.20.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
235
+ "language_model.model.layers.20.attention_norm.weight": "model-00003-of-00004.safetensors",
236
+ "language_model.model.layers.20.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
237
+ "language_model.model.layers.20.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
238
+ "language_model.model.layers.20.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
239
+ "language_model.model.layers.20.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
240
+ "language_model.model.layers.20.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
241
+ "language_model.model.layers.20.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
242
+ "language_model.model.layers.20.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
243
+ "language_model.model.layers.20.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
244
+ "language_model.model.layers.20.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
245
+ "language_model.model.layers.20.ffn_norm.weight": "model-00003-of-00004.safetensors",
246
+ "language_model.model.layers.21.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
247
+ "language_model.model.layers.21.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
248
+ "language_model.model.layers.21.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
249
+ "language_model.model.layers.21.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
250
+ "language_model.model.layers.21.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
251
+ "language_model.model.layers.21.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
252
+ "language_model.model.layers.21.attention_norm.weight": "model-00003-of-00004.safetensors",
253
+ "language_model.model.layers.21.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
254
+ "language_model.model.layers.21.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
255
+ "language_model.model.layers.21.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
256
+ "language_model.model.layers.21.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
257
+ "language_model.model.layers.21.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
258
+ "language_model.model.layers.21.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
259
+ "language_model.model.layers.21.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
260
+ "language_model.model.layers.21.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
261
+ "language_model.model.layers.21.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
262
+ "language_model.model.layers.21.ffn_norm.weight": "model-00003-of-00004.safetensors",
263
+ "language_model.model.layers.22.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
264
+ "language_model.model.layers.22.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
265
+ "language_model.model.layers.22.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
266
+ "language_model.model.layers.22.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
267
+ "language_model.model.layers.22.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
268
+ "language_model.model.layers.22.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
269
+ "language_model.model.layers.22.attention_norm.weight": "model-00003-of-00004.safetensors",
270
+ "language_model.model.layers.22.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
271
+ "language_model.model.layers.22.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
272
+ "language_model.model.layers.22.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
273
+ "language_model.model.layers.22.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
274
+ "language_model.model.layers.22.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
275
+ "language_model.model.layers.22.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
276
+ "language_model.model.layers.22.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
277
+ "language_model.model.layers.22.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
278
+ "language_model.model.layers.22.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
279
+ "language_model.model.layers.22.ffn_norm.weight": "model-00003-of-00004.safetensors",
280
+ "language_model.model.layers.23.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
281
+ "language_model.model.layers.23.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
282
+ "language_model.model.layers.23.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
283
+ "language_model.model.layers.23.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
284
+ "language_model.model.layers.23.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
285
+ "language_model.model.layers.23.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
286
+ "language_model.model.layers.23.attention_norm.weight": "model-00003-of-00004.safetensors",
287
+ "language_model.model.layers.23.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
288
+ "language_model.model.layers.23.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
289
+ "language_model.model.layers.23.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
290
+ "language_model.model.layers.23.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
291
+ "language_model.model.layers.23.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
292
+ "language_model.model.layers.23.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
293
+ "language_model.model.layers.23.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
294
+ "language_model.model.layers.23.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
295
+ "language_model.model.layers.23.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
296
+ "language_model.model.layers.23.ffn_norm.weight": "model-00003-of-00004.safetensors",
297
+ "language_model.model.layers.24.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
298
+ "language_model.model.layers.24.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
299
+ "language_model.model.layers.24.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
300
+ "language_model.model.layers.24.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
301
+ "language_model.model.layers.24.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
302
+ "language_model.model.layers.24.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
303
+ "language_model.model.layers.24.attention_norm.weight": "model-00003-of-00004.safetensors",
304
+ "language_model.model.layers.24.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
305
+ "language_model.model.layers.24.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
306
+ "language_model.model.layers.24.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
307
+ "language_model.model.layers.24.feed_forward.w2.Plora_A.weight": "model-00003-of-00004.safetensors",
308
+ "language_model.model.layers.24.feed_forward.w2.Plora_B.weight": "model-00003-of-00004.safetensors",
309
+ "language_model.model.layers.24.feed_forward.w2.original_linear.weight": "model-00003-of-00004.safetensors",
310
+ "language_model.model.layers.24.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
311
+ "language_model.model.layers.24.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
312
+ "language_model.model.layers.24.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
313
+ "language_model.model.layers.24.ffn_norm.weight": "model-00003-of-00004.safetensors",
314
+ "language_model.model.layers.25.attention.wo.Plora_A.weight": "model-00003-of-00004.safetensors",
315
+ "language_model.model.layers.25.attention.wo.Plora_B.weight": "model-00003-of-00004.safetensors",
316
+ "language_model.model.layers.25.attention.wo.original_linear.weight": "model-00003-of-00004.safetensors",
317
+ "language_model.model.layers.25.attention.wqkv.Plora_A.weight": "model-00003-of-00004.safetensors",
318
+ "language_model.model.layers.25.attention.wqkv.Plora_B.weight": "model-00003-of-00004.safetensors",
319
+ "language_model.model.layers.25.attention.wqkv.original_linear.weight": "model-00003-of-00004.safetensors",
320
+ "language_model.model.layers.25.attention_norm.weight": "model-00004-of-00004.safetensors",
321
+ "language_model.model.layers.25.feed_forward.w1.Plora_A.weight": "model-00003-of-00004.safetensors",
322
+ "language_model.model.layers.25.feed_forward.w1.Plora_B.weight": "model-00003-of-00004.safetensors",
323
+ "language_model.model.layers.25.feed_forward.w1.original_linear.weight": "model-00003-of-00004.safetensors",
324
+ "language_model.model.layers.25.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
325
+ "language_model.model.layers.25.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
326
+ "language_model.model.layers.25.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
327
+ "language_model.model.layers.25.feed_forward.w3.Plora_A.weight": "model-00003-of-00004.safetensors",
328
+ "language_model.model.layers.25.feed_forward.w3.Plora_B.weight": "model-00003-of-00004.safetensors",
329
+ "language_model.model.layers.25.feed_forward.w3.original_linear.weight": "model-00003-of-00004.safetensors",
330
+ "language_model.model.layers.25.ffn_norm.weight": "model-00004-of-00004.safetensors",
331
+ "language_model.model.layers.26.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
332
+ "language_model.model.layers.26.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
333
+ "language_model.model.layers.26.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
334
+ "language_model.model.layers.26.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
335
+ "language_model.model.layers.26.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
336
+ "language_model.model.layers.26.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
337
+ "language_model.model.layers.26.attention_norm.weight": "model-00004-of-00004.safetensors",
338
+ "language_model.model.layers.26.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
339
+ "language_model.model.layers.26.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
340
+ "language_model.model.layers.26.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
341
+ "language_model.model.layers.26.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
342
+ "language_model.model.layers.26.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
343
+ "language_model.model.layers.26.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
344
+ "language_model.model.layers.26.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
345
+ "language_model.model.layers.26.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
346
+ "language_model.model.layers.26.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
347
+ "language_model.model.layers.26.ffn_norm.weight": "model-00004-of-00004.safetensors",
348
+ "language_model.model.layers.27.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
349
+ "language_model.model.layers.27.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
350
+ "language_model.model.layers.27.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
351
+ "language_model.model.layers.27.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
352
+ "language_model.model.layers.27.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
353
+ "language_model.model.layers.27.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
354
+ "language_model.model.layers.27.attention_norm.weight": "model-00004-of-00004.safetensors",
355
+ "language_model.model.layers.27.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
356
+ "language_model.model.layers.27.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
357
+ "language_model.model.layers.27.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
358
+ "language_model.model.layers.27.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
359
+ "language_model.model.layers.27.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
360
+ "language_model.model.layers.27.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
361
+ "language_model.model.layers.27.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
362
+ "language_model.model.layers.27.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
363
+ "language_model.model.layers.27.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
364
+ "language_model.model.layers.27.ffn_norm.weight": "model-00004-of-00004.safetensors",
365
+ "language_model.model.layers.28.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
366
+ "language_model.model.layers.28.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
367
+ "language_model.model.layers.28.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
368
+ "language_model.model.layers.28.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
369
+ "language_model.model.layers.28.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
370
+ "language_model.model.layers.28.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
371
+ "language_model.model.layers.28.attention_norm.weight": "model-00004-of-00004.safetensors",
372
+ "language_model.model.layers.28.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
373
+ "language_model.model.layers.28.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
374
+ "language_model.model.layers.28.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
375
+ "language_model.model.layers.28.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
376
+ "language_model.model.layers.28.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
377
+ "language_model.model.layers.28.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
378
+ "language_model.model.layers.28.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
379
+ "language_model.model.layers.28.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
380
+ "language_model.model.layers.28.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
381
+ "language_model.model.layers.28.ffn_norm.weight": "model-00004-of-00004.safetensors",
382
+ "language_model.model.layers.29.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
383
+ "language_model.model.layers.29.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
384
+ "language_model.model.layers.29.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
385
+ "language_model.model.layers.29.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
386
+ "language_model.model.layers.29.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
387
+ "language_model.model.layers.29.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
388
+ "language_model.model.layers.29.attention_norm.weight": "model-00004-of-00004.safetensors",
389
+ "language_model.model.layers.29.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
390
+ "language_model.model.layers.29.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
391
+ "language_model.model.layers.29.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
392
+ "language_model.model.layers.29.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
393
+ "language_model.model.layers.29.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
394
+ "language_model.model.layers.29.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
395
+ "language_model.model.layers.29.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
396
+ "language_model.model.layers.29.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
397
+ "language_model.model.layers.29.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
398
+ "language_model.model.layers.29.ffn_norm.weight": "model-00004-of-00004.safetensors",
399
+ "language_model.model.layers.3.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
400
+ "language_model.model.layers.3.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
401
+ "language_model.model.layers.3.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
402
+ "language_model.model.layers.3.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
403
+ "language_model.model.layers.3.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
404
+ "language_model.model.layers.3.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
405
+ "language_model.model.layers.3.attention_norm.weight": "model-00001-of-00004.safetensors",
406
+ "language_model.model.layers.3.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
407
+ "language_model.model.layers.3.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
408
+ "language_model.model.layers.3.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
409
+ "language_model.model.layers.3.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
410
+ "language_model.model.layers.3.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
411
+ "language_model.model.layers.3.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
412
+ "language_model.model.layers.3.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
413
+ "language_model.model.layers.3.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
414
+ "language_model.model.layers.3.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
415
+ "language_model.model.layers.3.ffn_norm.weight": "model-00001-of-00004.safetensors",
416
+ "language_model.model.layers.30.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
417
+ "language_model.model.layers.30.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
418
+ "language_model.model.layers.30.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
419
+ "language_model.model.layers.30.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
420
+ "language_model.model.layers.30.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
421
+ "language_model.model.layers.30.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
422
+ "language_model.model.layers.30.attention_norm.weight": "model-00004-of-00004.safetensors",
423
+ "language_model.model.layers.30.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
424
+ "language_model.model.layers.30.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
425
+ "language_model.model.layers.30.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
426
+ "language_model.model.layers.30.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
427
+ "language_model.model.layers.30.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
428
+ "language_model.model.layers.30.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
429
+ "language_model.model.layers.30.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
430
+ "language_model.model.layers.30.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
431
+ "language_model.model.layers.30.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
432
+ "language_model.model.layers.30.ffn_norm.weight": "model-00004-of-00004.safetensors",
433
+ "language_model.model.layers.31.attention.wo.Plora_A.weight": "model-00004-of-00004.safetensors",
434
+ "language_model.model.layers.31.attention.wo.Plora_B.weight": "model-00004-of-00004.safetensors",
435
+ "language_model.model.layers.31.attention.wo.original_linear.weight": "model-00004-of-00004.safetensors",
436
+ "language_model.model.layers.31.attention.wqkv.Plora_A.weight": "model-00004-of-00004.safetensors",
437
+ "language_model.model.layers.31.attention.wqkv.Plora_B.weight": "model-00004-of-00004.safetensors",
438
+ "language_model.model.layers.31.attention.wqkv.original_linear.weight": "model-00004-of-00004.safetensors",
439
+ "language_model.model.layers.31.attention_norm.weight": "model-00004-of-00004.safetensors",
440
+ "language_model.model.layers.31.feed_forward.w1.Plora_A.weight": "model-00004-of-00004.safetensors",
441
+ "language_model.model.layers.31.feed_forward.w1.Plora_B.weight": "model-00004-of-00004.safetensors",
442
+ "language_model.model.layers.31.feed_forward.w1.original_linear.weight": "model-00004-of-00004.safetensors",
443
+ "language_model.model.layers.31.feed_forward.w2.Plora_A.weight": "model-00004-of-00004.safetensors",
444
+ "language_model.model.layers.31.feed_forward.w2.Plora_B.weight": "model-00004-of-00004.safetensors",
445
+ "language_model.model.layers.31.feed_forward.w2.original_linear.weight": "model-00004-of-00004.safetensors",
446
+ "language_model.model.layers.31.feed_forward.w3.Plora_A.weight": "model-00004-of-00004.safetensors",
447
+ "language_model.model.layers.31.feed_forward.w3.Plora_B.weight": "model-00004-of-00004.safetensors",
448
+ "language_model.model.layers.31.feed_forward.w3.original_linear.weight": "model-00004-of-00004.safetensors",
449
+ "language_model.model.layers.31.ffn_norm.weight": "model-00004-of-00004.safetensors",
450
+ "language_model.model.layers.4.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
451
+ "language_model.model.layers.4.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
452
+ "language_model.model.layers.4.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
453
+ "language_model.model.layers.4.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
454
+ "language_model.model.layers.4.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
455
+ "language_model.model.layers.4.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
456
+ "language_model.model.layers.4.attention_norm.weight": "model-00001-of-00004.safetensors",
457
+ "language_model.model.layers.4.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
458
+ "language_model.model.layers.4.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
459
+ "language_model.model.layers.4.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
460
+ "language_model.model.layers.4.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
461
+ "language_model.model.layers.4.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
462
+ "language_model.model.layers.4.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
463
+ "language_model.model.layers.4.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
464
+ "language_model.model.layers.4.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
465
+ "language_model.model.layers.4.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
466
+ "language_model.model.layers.4.ffn_norm.weight": "model-00001-of-00004.safetensors",
467
+ "language_model.model.layers.5.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
468
+ "language_model.model.layers.5.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
469
+ "language_model.model.layers.5.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
470
+ "language_model.model.layers.5.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
471
+ "language_model.model.layers.5.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
472
+ "language_model.model.layers.5.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
473
+ "language_model.model.layers.5.attention_norm.weight": "model-00001-of-00004.safetensors",
474
+ "language_model.model.layers.5.feed_forward.w1.Plora_A.weight": "model-00001-of-00004.safetensors",
475
+ "language_model.model.layers.5.feed_forward.w1.Plora_B.weight": "model-00001-of-00004.safetensors",
476
+ "language_model.model.layers.5.feed_forward.w1.original_linear.weight": "model-00001-of-00004.safetensors",
477
+ "language_model.model.layers.5.feed_forward.w2.Plora_A.weight": "model-00001-of-00004.safetensors",
478
+ "language_model.model.layers.5.feed_forward.w2.Plora_B.weight": "model-00001-of-00004.safetensors",
479
+ "language_model.model.layers.5.feed_forward.w2.original_linear.weight": "model-00001-of-00004.safetensors",
480
+ "language_model.model.layers.5.feed_forward.w3.Plora_A.weight": "model-00001-of-00004.safetensors",
481
+ "language_model.model.layers.5.feed_forward.w3.Plora_B.weight": "model-00001-of-00004.safetensors",
482
+ "language_model.model.layers.5.feed_forward.w3.original_linear.weight": "model-00001-of-00004.safetensors",
483
+ "language_model.model.layers.5.ffn_norm.weight": "model-00001-of-00004.safetensors",
484
+ "language_model.model.layers.6.attention.wo.Plora_A.weight": "model-00001-of-00004.safetensors",
485
+ "language_model.model.layers.6.attention.wo.Plora_B.weight": "model-00001-of-00004.safetensors",
486
+ "language_model.model.layers.6.attention.wo.original_linear.weight": "model-00001-of-00004.safetensors",
487
+ "language_model.model.layers.6.attention.wqkv.Plora_A.weight": "model-00001-of-00004.safetensors",
488
+ "language_model.model.layers.6.attention.wqkv.Plora_B.weight": "model-00001-of-00004.safetensors",
489
+ "language_model.model.layers.6.attention.wqkv.original_linear.weight": "model-00001-of-00004.safetensors",
490
+ "language_model.model.layers.6.attention_norm.weight": "model-00002-of-00004.safetensors",
491
+ "language_model.model.layers.6.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
492
+ "language_model.model.layers.6.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
493
+ "language_model.model.layers.6.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
494
+ "language_model.model.layers.6.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
495
+ "language_model.model.layers.6.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
496
+ "language_model.model.layers.6.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
497
+ "language_model.model.layers.6.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
498
+ "language_model.model.layers.6.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
499
+ "language_model.model.layers.6.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
500
+ "language_model.model.layers.6.ffn_norm.weight": "model-00002-of-00004.safetensors",
501
+ "language_model.model.layers.7.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
502
+ "language_model.model.layers.7.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
503
+ "language_model.model.layers.7.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
504
+ "language_model.model.layers.7.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
505
+ "language_model.model.layers.7.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
506
+ "language_model.model.layers.7.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
507
+ "language_model.model.layers.7.attention_norm.weight": "model-00002-of-00004.safetensors",
508
+ "language_model.model.layers.7.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
509
+ "language_model.model.layers.7.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
510
+ "language_model.model.layers.7.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
511
+ "language_model.model.layers.7.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
512
+ "language_model.model.layers.7.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
513
+ "language_model.model.layers.7.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
514
+ "language_model.model.layers.7.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
515
+ "language_model.model.layers.7.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
516
+ "language_model.model.layers.7.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
517
+ "language_model.model.layers.7.ffn_norm.weight": "model-00002-of-00004.safetensors",
518
+ "language_model.model.layers.8.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
519
+ "language_model.model.layers.8.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
520
+ "language_model.model.layers.8.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
521
+ "language_model.model.layers.8.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
522
+ "language_model.model.layers.8.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
523
+ "language_model.model.layers.8.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
524
+ "language_model.model.layers.8.attention_norm.weight": "model-00002-of-00004.safetensors",
525
+ "language_model.model.layers.8.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
526
+ "language_model.model.layers.8.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
527
+ "language_model.model.layers.8.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
528
+ "language_model.model.layers.8.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
529
+ "language_model.model.layers.8.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
530
+ "language_model.model.layers.8.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
531
+ "language_model.model.layers.8.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
532
+ "language_model.model.layers.8.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
533
+ "language_model.model.layers.8.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
534
+ "language_model.model.layers.8.ffn_norm.weight": "model-00002-of-00004.safetensors",
535
+ "language_model.model.layers.9.attention.wo.Plora_A.weight": "model-00002-of-00004.safetensors",
536
+ "language_model.model.layers.9.attention.wo.Plora_B.weight": "model-00002-of-00004.safetensors",
537
+ "language_model.model.layers.9.attention.wo.original_linear.weight": "model-00002-of-00004.safetensors",
538
+ "language_model.model.layers.9.attention.wqkv.Plora_A.weight": "model-00002-of-00004.safetensors",
539
+ "language_model.model.layers.9.attention.wqkv.Plora_B.weight": "model-00002-of-00004.safetensors",
540
+ "language_model.model.layers.9.attention.wqkv.original_linear.weight": "model-00002-of-00004.safetensors",
541
+ "language_model.model.layers.9.attention_norm.weight": "model-00002-of-00004.safetensors",
542
+ "language_model.model.layers.9.feed_forward.w1.Plora_A.weight": "model-00002-of-00004.safetensors",
543
+ "language_model.model.layers.9.feed_forward.w1.Plora_B.weight": "model-00002-of-00004.safetensors",
544
+ "language_model.model.layers.9.feed_forward.w1.original_linear.weight": "model-00002-of-00004.safetensors",
545
+ "language_model.model.layers.9.feed_forward.w2.Plora_A.weight": "model-00002-of-00004.safetensors",
546
+ "language_model.model.layers.9.feed_forward.w2.Plora_B.weight": "model-00002-of-00004.safetensors",
547
+ "language_model.model.layers.9.feed_forward.w2.original_linear.weight": "model-00002-of-00004.safetensors",
548
+ "language_model.model.layers.9.feed_forward.w3.Plora_A.weight": "model-00002-of-00004.safetensors",
549
+ "language_model.model.layers.9.feed_forward.w3.Plora_B.weight": "model-00002-of-00004.safetensors",
550
+ "language_model.model.layers.9.feed_forward.w3.original_linear.weight": "model-00002-of-00004.safetensors",
551
+ "language_model.model.layers.9.ffn_norm.weight": "model-00002-of-00004.safetensors",
552
+ "language_model.model.norm.weight": "model-00004-of-00004.safetensors",
553
+ "language_model.model.tok_embeddings.weight": "model-00001-of-00004.safetensors",
554
+ "language_model.output.weight": "model-00004-of-00004.safetensors",
555
+ "vision_tower.embeddings.patch_embedding.bias": "model-00001-of-00004.safetensors",
556
+ "vision_tower.embeddings.patch_embedding.weight": "model-00001-of-00004.safetensors",
557
+ "vision_tower.embeddings.position_embedding.weight": "model-00001-of-00004.safetensors",
558
+ "vision_tower.encoder.layers.0.layer_norm1.bias": "model-00001-of-00004.safetensors",
559
+ "vision_tower.encoder.layers.0.layer_norm1.weight": "model-00001-of-00004.safetensors",
560
+ "vision_tower.encoder.layers.0.layer_norm2.bias": "model-00001-of-00004.safetensors",
561
+ "vision_tower.encoder.layers.0.layer_norm2.weight": "model-00001-of-00004.safetensors",
562
+ "vision_tower.encoder.layers.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
563
+ "vision_tower.encoder.layers.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
564
+ "vision_tower.encoder.layers.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
565
+ "vision_tower.encoder.layers.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
566
+ "vision_tower.encoder.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
567
+ "vision_tower.encoder.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
568
+ "vision_tower.encoder.layers.0.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
569
+ "vision_tower.encoder.layers.0.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
570
+ "vision_tower.encoder.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
571
+ "vision_tower.encoder.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
572
+ "vision_tower.encoder.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
573
+ "vision_tower.encoder.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
574
+ "vision_tower.encoder.layers.1.layer_norm1.bias": "model-00001-of-00004.safetensors",
575
+ "vision_tower.encoder.layers.1.layer_norm1.weight": "model-00001-of-00004.safetensors",
576
+ "vision_tower.encoder.layers.1.layer_norm2.bias": "model-00001-of-00004.safetensors",
577
+ "vision_tower.encoder.layers.1.layer_norm2.weight": "model-00001-of-00004.safetensors",
578
+ "vision_tower.encoder.layers.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
579
+ "vision_tower.encoder.layers.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
580
+ "vision_tower.encoder.layers.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
581
+ "vision_tower.encoder.layers.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
582
+ "vision_tower.encoder.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
583
+ "vision_tower.encoder.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
584
+ "vision_tower.encoder.layers.1.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
585
+ "vision_tower.encoder.layers.1.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
586
+ "vision_tower.encoder.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
587
+ "vision_tower.encoder.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
588
+ "vision_tower.encoder.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
589
+ "vision_tower.encoder.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
590
+ "vision_tower.encoder.layers.10.layer_norm1.bias": "model-00001-of-00004.safetensors",
591
+ "vision_tower.encoder.layers.10.layer_norm1.weight": "model-00001-of-00004.safetensors",
592
+ "vision_tower.encoder.layers.10.layer_norm2.bias": "model-00001-of-00004.safetensors",
593
+ "vision_tower.encoder.layers.10.layer_norm2.weight": "model-00001-of-00004.safetensors",
594
+ "vision_tower.encoder.layers.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
595
+ "vision_tower.encoder.layers.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
596
+ "vision_tower.encoder.layers.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
597
+ "vision_tower.encoder.layers.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
598
+ "vision_tower.encoder.layers.10.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
599
+ "vision_tower.encoder.layers.10.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
600
+ "vision_tower.encoder.layers.10.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
601
+ "vision_tower.encoder.layers.10.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
602
+ "vision_tower.encoder.layers.10.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
603
+ "vision_tower.encoder.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
604
+ "vision_tower.encoder.layers.10.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
605
+ "vision_tower.encoder.layers.10.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
606
+ "vision_tower.encoder.layers.11.layer_norm1.bias": "model-00001-of-00004.safetensors",
607
+ "vision_tower.encoder.layers.11.layer_norm1.weight": "model-00001-of-00004.safetensors",
608
+ "vision_tower.encoder.layers.11.layer_norm2.bias": "model-00001-of-00004.safetensors",
609
+ "vision_tower.encoder.layers.11.layer_norm2.weight": "model-00001-of-00004.safetensors",
610
+ "vision_tower.encoder.layers.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
611
+ "vision_tower.encoder.layers.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
612
+ "vision_tower.encoder.layers.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
613
+ "vision_tower.encoder.layers.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
614
+ "vision_tower.encoder.layers.11.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
615
+ "vision_tower.encoder.layers.11.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
616
+ "vision_tower.encoder.layers.11.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
617
+ "vision_tower.encoder.layers.11.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
618
+ "vision_tower.encoder.layers.11.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
619
+ "vision_tower.encoder.layers.11.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
620
+ "vision_tower.encoder.layers.11.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
621
+ "vision_tower.encoder.layers.11.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
622
+ "vision_tower.encoder.layers.12.layer_norm1.bias": "model-00001-of-00004.safetensors",
623
+ "vision_tower.encoder.layers.12.layer_norm1.weight": "model-00001-of-00004.safetensors",
624
+ "vision_tower.encoder.layers.12.layer_norm2.bias": "model-00001-of-00004.safetensors",
625
+ "vision_tower.encoder.layers.12.layer_norm2.weight": "model-00001-of-00004.safetensors",
626
+ "vision_tower.encoder.layers.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
627
+ "vision_tower.encoder.layers.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
628
+ "vision_tower.encoder.layers.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
629
+ "vision_tower.encoder.layers.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
630
+ "vision_tower.encoder.layers.12.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
631
+ "vision_tower.encoder.layers.12.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
632
+ "vision_tower.encoder.layers.12.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
633
+ "vision_tower.encoder.layers.12.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
634
+ "vision_tower.encoder.layers.12.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
635
+ "vision_tower.encoder.layers.12.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
636
+ "vision_tower.encoder.layers.12.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
637
+ "vision_tower.encoder.layers.12.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
638
+ "vision_tower.encoder.layers.13.layer_norm1.bias": "model-00001-of-00004.safetensors",
639
+ "vision_tower.encoder.layers.13.layer_norm1.weight": "model-00001-of-00004.safetensors",
640
+ "vision_tower.encoder.layers.13.layer_norm2.bias": "model-00001-of-00004.safetensors",
641
+ "vision_tower.encoder.layers.13.layer_norm2.weight": "model-00001-of-00004.safetensors",
642
+ "vision_tower.encoder.layers.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
643
+ "vision_tower.encoder.layers.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
644
+ "vision_tower.encoder.layers.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
645
+ "vision_tower.encoder.layers.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
646
+ "vision_tower.encoder.layers.13.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
647
+ "vision_tower.encoder.layers.13.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
648
+ "vision_tower.encoder.layers.13.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
649
+ "vision_tower.encoder.layers.13.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
650
+ "vision_tower.encoder.layers.13.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
651
+ "vision_tower.encoder.layers.13.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
652
+ "vision_tower.encoder.layers.13.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
653
+ "vision_tower.encoder.layers.13.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
654
+ "vision_tower.encoder.layers.14.layer_norm1.bias": "model-00001-of-00004.safetensors",
655
+ "vision_tower.encoder.layers.14.layer_norm1.weight": "model-00001-of-00004.safetensors",
656
+ "vision_tower.encoder.layers.14.layer_norm2.bias": "model-00001-of-00004.safetensors",
657
+ "vision_tower.encoder.layers.14.layer_norm2.weight": "model-00001-of-00004.safetensors",
658
+ "vision_tower.encoder.layers.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
659
+ "vision_tower.encoder.layers.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
660
+ "vision_tower.encoder.layers.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
661
+ "vision_tower.encoder.layers.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
662
+ "vision_tower.encoder.layers.14.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
663
+ "vision_tower.encoder.layers.14.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
664
+ "vision_tower.encoder.layers.14.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
665
+ "vision_tower.encoder.layers.14.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
666
+ "vision_tower.encoder.layers.14.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
667
+ "vision_tower.encoder.layers.14.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
668
+ "vision_tower.encoder.layers.14.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
669
+ "vision_tower.encoder.layers.14.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
670
+ "vision_tower.encoder.layers.15.layer_norm1.bias": "model-00001-of-00004.safetensors",
671
+ "vision_tower.encoder.layers.15.layer_norm1.weight": "model-00001-of-00004.safetensors",
672
+ "vision_tower.encoder.layers.15.layer_norm2.bias": "model-00001-of-00004.safetensors",
673
+ "vision_tower.encoder.layers.15.layer_norm2.weight": "model-00001-of-00004.safetensors",
674
+ "vision_tower.encoder.layers.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
675
+ "vision_tower.encoder.layers.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
676
+ "vision_tower.encoder.layers.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
677
+ "vision_tower.encoder.layers.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
678
+ "vision_tower.encoder.layers.15.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
679
+ "vision_tower.encoder.layers.15.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
680
+ "vision_tower.encoder.layers.15.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
681
+ "vision_tower.encoder.layers.15.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
682
+ "vision_tower.encoder.layers.15.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
683
+ "vision_tower.encoder.layers.15.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
684
+ "vision_tower.encoder.layers.15.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
685
+ "vision_tower.encoder.layers.15.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
686
+ "vision_tower.encoder.layers.16.layer_norm1.bias": "model-00001-of-00004.safetensors",
687
+ "vision_tower.encoder.layers.16.layer_norm1.weight": "model-00001-of-00004.safetensors",
688
+ "vision_tower.encoder.layers.16.layer_norm2.bias": "model-00001-of-00004.safetensors",
689
+ "vision_tower.encoder.layers.16.layer_norm2.weight": "model-00001-of-00004.safetensors",
690
+ "vision_tower.encoder.layers.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
691
+ "vision_tower.encoder.layers.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
692
+ "vision_tower.encoder.layers.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
693
+ "vision_tower.encoder.layers.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
694
+ "vision_tower.encoder.layers.16.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
695
+ "vision_tower.encoder.layers.16.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
696
+ "vision_tower.encoder.layers.16.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
697
+ "vision_tower.encoder.layers.16.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
698
+ "vision_tower.encoder.layers.16.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
699
+ "vision_tower.encoder.layers.16.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
700
+ "vision_tower.encoder.layers.16.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
701
+ "vision_tower.encoder.layers.16.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
702
+ "vision_tower.encoder.layers.17.layer_norm1.bias": "model-00001-of-00004.safetensors",
703
+ "vision_tower.encoder.layers.17.layer_norm1.weight": "model-00001-of-00004.safetensors",
704
+ "vision_tower.encoder.layers.17.layer_norm2.bias": "model-00001-of-00004.safetensors",
705
+ "vision_tower.encoder.layers.17.layer_norm2.weight": "model-00001-of-00004.safetensors",
706
+ "vision_tower.encoder.layers.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
707
+ "vision_tower.encoder.layers.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
708
+ "vision_tower.encoder.layers.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
709
+ "vision_tower.encoder.layers.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
710
+ "vision_tower.encoder.layers.17.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
711
+ "vision_tower.encoder.layers.17.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
712
+ "vision_tower.encoder.layers.17.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
713
+ "vision_tower.encoder.layers.17.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
714
+ "vision_tower.encoder.layers.17.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
715
+ "vision_tower.encoder.layers.17.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
716
+ "vision_tower.encoder.layers.17.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
717
+ "vision_tower.encoder.layers.17.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
718
+ "vision_tower.encoder.layers.18.layer_norm1.bias": "model-00001-of-00004.safetensors",
719
+ "vision_tower.encoder.layers.18.layer_norm1.weight": "model-00001-of-00004.safetensors",
720
+ "vision_tower.encoder.layers.18.layer_norm2.bias": "model-00001-of-00004.safetensors",
721
+ "vision_tower.encoder.layers.18.layer_norm2.weight": "model-00001-of-00004.safetensors",
722
+ "vision_tower.encoder.layers.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
723
+ "vision_tower.encoder.layers.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
724
+ "vision_tower.encoder.layers.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
725
+ "vision_tower.encoder.layers.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
726
+ "vision_tower.encoder.layers.18.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
727
+ "vision_tower.encoder.layers.18.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
728
+ "vision_tower.encoder.layers.18.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
729
+ "vision_tower.encoder.layers.18.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
730
+ "vision_tower.encoder.layers.18.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
731
+ "vision_tower.encoder.layers.18.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
732
+ "vision_tower.encoder.layers.18.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
733
+ "vision_tower.encoder.layers.18.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
734
+ "vision_tower.encoder.layers.19.layer_norm1.bias": "model-00001-of-00004.safetensors",
735
+ "vision_tower.encoder.layers.19.layer_norm1.weight": "model-00001-of-00004.safetensors",
736
+ "vision_tower.encoder.layers.19.layer_norm2.bias": "model-00001-of-00004.safetensors",
737
+ "vision_tower.encoder.layers.19.layer_norm2.weight": "model-00001-of-00004.safetensors",
738
+ "vision_tower.encoder.layers.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
739
+ "vision_tower.encoder.layers.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
740
+ "vision_tower.encoder.layers.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
741
+ "vision_tower.encoder.layers.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
742
+ "vision_tower.encoder.layers.19.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
743
+ "vision_tower.encoder.layers.19.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
744
+ "vision_tower.encoder.layers.19.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
745
+ "vision_tower.encoder.layers.19.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
746
+ "vision_tower.encoder.layers.19.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
747
+ "vision_tower.encoder.layers.19.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
748
+ "vision_tower.encoder.layers.19.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
749
+ "vision_tower.encoder.layers.19.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
750
+ "vision_tower.encoder.layers.2.layer_norm1.bias": "model-00001-of-00004.safetensors",
751
+ "vision_tower.encoder.layers.2.layer_norm1.weight": "model-00001-of-00004.safetensors",
752
+ "vision_tower.encoder.layers.2.layer_norm2.bias": "model-00001-of-00004.safetensors",
753
+ "vision_tower.encoder.layers.2.layer_norm2.weight": "model-00001-of-00004.safetensors",
754
+ "vision_tower.encoder.layers.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
755
+ "vision_tower.encoder.layers.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
756
+ "vision_tower.encoder.layers.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
757
+ "vision_tower.encoder.layers.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
758
+ "vision_tower.encoder.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
759
+ "vision_tower.encoder.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
760
+ "vision_tower.encoder.layers.2.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
761
+ "vision_tower.encoder.layers.2.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
762
+ "vision_tower.encoder.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
763
+ "vision_tower.encoder.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
764
+ "vision_tower.encoder.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
765
+ "vision_tower.encoder.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
766
+ "vision_tower.encoder.layers.20.layer_norm1.bias": "model-00001-of-00004.safetensors",
767
+ "vision_tower.encoder.layers.20.layer_norm1.weight": "model-00001-of-00004.safetensors",
768
+ "vision_tower.encoder.layers.20.layer_norm2.bias": "model-00001-of-00004.safetensors",
769
+ "vision_tower.encoder.layers.20.layer_norm2.weight": "model-00001-of-00004.safetensors",
770
+ "vision_tower.encoder.layers.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
771
+ "vision_tower.encoder.layers.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
772
+ "vision_tower.encoder.layers.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
773
+ "vision_tower.encoder.layers.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
774
+ "vision_tower.encoder.layers.20.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
775
+ "vision_tower.encoder.layers.20.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
776
+ "vision_tower.encoder.layers.20.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
777
+ "vision_tower.encoder.layers.20.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
778
+ "vision_tower.encoder.layers.20.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
779
+ "vision_tower.encoder.layers.20.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
780
+ "vision_tower.encoder.layers.20.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
781
+ "vision_tower.encoder.layers.20.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
782
+ "vision_tower.encoder.layers.21.layer_norm1.bias": "model-00001-of-00004.safetensors",
783
+ "vision_tower.encoder.layers.21.layer_norm1.weight": "model-00001-of-00004.safetensors",
784
+ "vision_tower.encoder.layers.21.layer_norm2.bias": "model-00001-of-00004.safetensors",
785
+ "vision_tower.encoder.layers.21.layer_norm2.weight": "model-00001-of-00004.safetensors",
786
+ "vision_tower.encoder.layers.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
787
+ "vision_tower.encoder.layers.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
788
+ "vision_tower.encoder.layers.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
789
+ "vision_tower.encoder.layers.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
790
+ "vision_tower.encoder.layers.21.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
791
+ "vision_tower.encoder.layers.21.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
792
+ "vision_tower.encoder.layers.21.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
793
+ "vision_tower.encoder.layers.21.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
794
+ "vision_tower.encoder.layers.21.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
795
+ "vision_tower.encoder.layers.21.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
796
+ "vision_tower.encoder.layers.21.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
797
+ "vision_tower.encoder.layers.21.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
798
+ "vision_tower.encoder.layers.22.layer_norm1.bias": "model-00001-of-00004.safetensors",
799
+ "vision_tower.encoder.layers.22.layer_norm1.weight": "model-00001-of-00004.safetensors",
800
+ "vision_tower.encoder.layers.22.layer_norm2.bias": "model-00001-of-00004.safetensors",
801
+ "vision_tower.encoder.layers.22.layer_norm2.weight": "model-00001-of-00004.safetensors",
802
+ "vision_tower.encoder.layers.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
803
+ "vision_tower.encoder.layers.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
804
+ "vision_tower.encoder.layers.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
805
+ "vision_tower.encoder.layers.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
806
+ "vision_tower.encoder.layers.22.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
807
+ "vision_tower.encoder.layers.22.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
808
+ "vision_tower.encoder.layers.22.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
809
+ "vision_tower.encoder.layers.22.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
810
+ "vision_tower.encoder.layers.22.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
811
+ "vision_tower.encoder.layers.22.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
812
+ "vision_tower.encoder.layers.22.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
813
+ "vision_tower.encoder.layers.22.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
814
+ "vision_tower.encoder.layers.23.layer_norm1.bias": "model-00001-of-00004.safetensors",
815
+ "vision_tower.encoder.layers.23.layer_norm1.weight": "model-00001-of-00004.safetensors",
816
+ "vision_tower.encoder.layers.23.layer_norm2.bias": "model-00001-of-00004.safetensors",
817
+ "vision_tower.encoder.layers.23.layer_norm2.weight": "model-00001-of-00004.safetensors",
818
+ "vision_tower.encoder.layers.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
819
+ "vision_tower.encoder.layers.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
820
+ "vision_tower.encoder.layers.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
821
+ "vision_tower.encoder.layers.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
822
+ "vision_tower.encoder.layers.23.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
823
+ "vision_tower.encoder.layers.23.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
824
+ "vision_tower.encoder.layers.23.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
825
+ "vision_tower.encoder.layers.23.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
826
+ "vision_tower.encoder.layers.23.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
827
+ "vision_tower.encoder.layers.23.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
828
+ "vision_tower.encoder.layers.23.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
829
+ "vision_tower.encoder.layers.23.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
830
+ "vision_tower.encoder.layers.24.layer_norm1.bias": "model-00001-of-00004.safetensors",
831
+ "vision_tower.encoder.layers.24.layer_norm1.weight": "model-00001-of-00004.safetensors",
832
+ "vision_tower.encoder.layers.24.layer_norm2.bias": "model-00001-of-00004.safetensors",
833
+ "vision_tower.encoder.layers.24.layer_norm2.weight": "model-00001-of-00004.safetensors",
834
+ "vision_tower.encoder.layers.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
835
+ "vision_tower.encoder.layers.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
836
+ "vision_tower.encoder.layers.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
837
+ "vision_tower.encoder.layers.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
838
+ "vision_tower.encoder.layers.24.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
839
+ "vision_tower.encoder.layers.24.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
840
+ "vision_tower.encoder.layers.24.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
841
+ "vision_tower.encoder.layers.24.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
842
+ "vision_tower.encoder.layers.24.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
843
+ "vision_tower.encoder.layers.24.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
844
+ "vision_tower.encoder.layers.24.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
845
+ "vision_tower.encoder.layers.24.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
846
+ "vision_tower.encoder.layers.25.layer_norm1.bias": "model-00001-of-00004.safetensors",
847
+ "vision_tower.encoder.layers.25.layer_norm1.weight": "model-00001-of-00004.safetensors",
848
+ "vision_tower.encoder.layers.25.layer_norm2.bias": "model-00001-of-00004.safetensors",
849
+ "vision_tower.encoder.layers.25.layer_norm2.weight": "model-00001-of-00004.safetensors",
850
+ "vision_tower.encoder.layers.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
851
+ "vision_tower.encoder.layers.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
852
+ "vision_tower.encoder.layers.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
853
+ "vision_tower.encoder.layers.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
854
+ "vision_tower.encoder.layers.25.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
855
+ "vision_tower.encoder.layers.25.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
856
+ "vision_tower.encoder.layers.25.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
857
+ "vision_tower.encoder.layers.25.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
858
+ "vision_tower.encoder.layers.25.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
859
+ "vision_tower.encoder.layers.25.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
860
+ "vision_tower.encoder.layers.25.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
861
+ "vision_tower.encoder.layers.25.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
862
+ "vision_tower.encoder.layers.26.layer_norm1.bias": "model-00001-of-00004.safetensors",
863
+ "vision_tower.encoder.layers.26.layer_norm1.weight": "model-00001-of-00004.safetensors",
864
+ "vision_tower.encoder.layers.26.layer_norm2.bias": "model-00001-of-00004.safetensors",
865
+ "vision_tower.encoder.layers.26.layer_norm2.weight": "model-00001-of-00004.safetensors",
866
+ "vision_tower.encoder.layers.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
867
+ "vision_tower.encoder.layers.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
868
+ "vision_tower.encoder.layers.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
869
+ "vision_tower.encoder.layers.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
870
+ "vision_tower.encoder.layers.26.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
871
+ "vision_tower.encoder.layers.26.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
872
+ "vision_tower.encoder.layers.26.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
873
+ "vision_tower.encoder.layers.26.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
874
+ "vision_tower.encoder.layers.26.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
875
+ "vision_tower.encoder.layers.26.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
876
+ "vision_tower.encoder.layers.26.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
877
+ "vision_tower.encoder.layers.26.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
878
+ "vision_tower.encoder.layers.3.layer_norm1.bias": "model-00001-of-00004.safetensors",
879
+ "vision_tower.encoder.layers.3.layer_norm1.weight": "model-00001-of-00004.safetensors",
880
+ "vision_tower.encoder.layers.3.layer_norm2.bias": "model-00001-of-00004.safetensors",
881
+ "vision_tower.encoder.layers.3.layer_norm2.weight": "model-00001-of-00004.safetensors",
882
+ "vision_tower.encoder.layers.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
883
+ "vision_tower.encoder.layers.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
884
+ "vision_tower.encoder.layers.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
885
+ "vision_tower.encoder.layers.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
886
+ "vision_tower.encoder.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
887
+ "vision_tower.encoder.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
888
+ "vision_tower.encoder.layers.3.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
889
+ "vision_tower.encoder.layers.3.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
890
+ "vision_tower.encoder.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
891
+ "vision_tower.encoder.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
892
+ "vision_tower.encoder.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
893
+ "vision_tower.encoder.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
894
+ "vision_tower.encoder.layers.4.layer_norm1.bias": "model-00001-of-00004.safetensors",
895
+ "vision_tower.encoder.layers.4.layer_norm1.weight": "model-00001-of-00004.safetensors",
896
+ "vision_tower.encoder.layers.4.layer_norm2.bias": "model-00001-of-00004.safetensors",
897
+ "vision_tower.encoder.layers.4.layer_norm2.weight": "model-00001-of-00004.safetensors",
898
+ "vision_tower.encoder.layers.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
899
+ "vision_tower.encoder.layers.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
900
+ "vision_tower.encoder.layers.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
901
+ "vision_tower.encoder.layers.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
902
+ "vision_tower.encoder.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
903
+ "vision_tower.encoder.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
904
+ "vision_tower.encoder.layers.4.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
905
+ "vision_tower.encoder.layers.4.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
906
+ "vision_tower.encoder.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
907
+ "vision_tower.encoder.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
908
+ "vision_tower.encoder.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
909
+ "vision_tower.encoder.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
910
+ "vision_tower.encoder.layers.5.layer_norm1.bias": "model-00001-of-00004.safetensors",
911
+ "vision_tower.encoder.layers.5.layer_norm1.weight": "model-00001-of-00004.safetensors",
912
+ "vision_tower.encoder.layers.5.layer_norm2.bias": "model-00001-of-00004.safetensors",
913
+ "vision_tower.encoder.layers.5.layer_norm2.weight": "model-00001-of-00004.safetensors",
914
+ "vision_tower.encoder.layers.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
915
+ "vision_tower.encoder.layers.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
916
+ "vision_tower.encoder.layers.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
917
+ "vision_tower.encoder.layers.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
918
+ "vision_tower.encoder.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
919
+ "vision_tower.encoder.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
920
+ "vision_tower.encoder.layers.5.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
921
+ "vision_tower.encoder.layers.5.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
922
+ "vision_tower.encoder.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
923
+ "vision_tower.encoder.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
924
+ "vision_tower.encoder.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
925
+ "vision_tower.encoder.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
926
+ "vision_tower.encoder.layers.6.layer_norm1.bias": "model-00001-of-00004.safetensors",
927
+ "vision_tower.encoder.layers.6.layer_norm1.weight": "model-00001-of-00004.safetensors",
928
+ "vision_tower.encoder.layers.6.layer_norm2.bias": "model-00001-of-00004.safetensors",
929
+ "vision_tower.encoder.layers.6.layer_norm2.weight": "model-00001-of-00004.safetensors",
930
+ "vision_tower.encoder.layers.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
931
+ "vision_tower.encoder.layers.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
932
+ "vision_tower.encoder.layers.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
933
+ "vision_tower.encoder.layers.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
934
+ "vision_tower.encoder.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
935
+ "vision_tower.encoder.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
936
+ "vision_tower.encoder.layers.6.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
937
+ "vision_tower.encoder.layers.6.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
938
+ "vision_tower.encoder.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
939
+ "vision_tower.encoder.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
940
+ "vision_tower.encoder.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
941
+ "vision_tower.encoder.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
942
+ "vision_tower.encoder.layers.7.layer_norm1.bias": "model-00001-of-00004.safetensors",
943
+ "vision_tower.encoder.layers.7.layer_norm1.weight": "model-00001-of-00004.safetensors",
944
+ "vision_tower.encoder.layers.7.layer_norm2.bias": "model-00001-of-00004.safetensors",
945
+ "vision_tower.encoder.layers.7.layer_norm2.weight": "model-00001-of-00004.safetensors",
946
+ "vision_tower.encoder.layers.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
947
+ "vision_tower.encoder.layers.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
948
+ "vision_tower.encoder.layers.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
949
+ "vision_tower.encoder.layers.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
950
+ "vision_tower.encoder.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
951
+ "vision_tower.encoder.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
952
+ "vision_tower.encoder.layers.7.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
953
+ "vision_tower.encoder.layers.7.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
954
+ "vision_tower.encoder.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
955
+ "vision_tower.encoder.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
956
+ "vision_tower.encoder.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
957
+ "vision_tower.encoder.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
958
+ "vision_tower.encoder.layers.8.layer_norm1.bias": "model-00001-of-00004.safetensors",
959
+ "vision_tower.encoder.layers.8.layer_norm1.weight": "model-00001-of-00004.safetensors",
960
+ "vision_tower.encoder.layers.8.layer_norm2.bias": "model-00001-of-00004.safetensors",
961
+ "vision_tower.encoder.layers.8.layer_norm2.weight": "model-00001-of-00004.safetensors",
962
+ "vision_tower.encoder.layers.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
963
+ "vision_tower.encoder.layers.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
964
+ "vision_tower.encoder.layers.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
965
+ "vision_tower.encoder.layers.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
966
+ "vision_tower.encoder.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
967
+ "vision_tower.encoder.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
968
+ "vision_tower.encoder.layers.8.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
969
+ "vision_tower.encoder.layers.8.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
970
+ "vision_tower.encoder.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
971
+ "vision_tower.encoder.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
972
+ "vision_tower.encoder.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
973
+ "vision_tower.encoder.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
974
+ "vision_tower.encoder.layers.9.layer_norm1.bias": "model-00001-of-00004.safetensors",
975
+ "vision_tower.encoder.layers.9.layer_norm1.weight": "model-00001-of-00004.safetensors",
976
+ "vision_tower.encoder.layers.9.layer_norm2.bias": "model-00001-of-00004.safetensors",
977
+ "vision_tower.encoder.layers.9.layer_norm2.weight": "model-00001-of-00004.safetensors",
978
+ "vision_tower.encoder.layers.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
979
+ "vision_tower.encoder.layers.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
980
+ "vision_tower.encoder.layers.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
981
+ "vision_tower.encoder.layers.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
982
+ "vision_tower.encoder.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
983
+ "vision_tower.encoder.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
984
+ "vision_tower.encoder.layers.9.self_attn.out_proj.bias": "model-00001-of-00004.safetensors",
985
+ "vision_tower.encoder.layers.9.self_attn.out_proj.weight": "model-00001-of-00004.safetensors",
986
+ "vision_tower.encoder.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
987
+ "vision_tower.encoder.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
988
+ "vision_tower.encoder.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
989
+ "vision_tower.encoder.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
990
+ "vision_tower.post_layernorm.bias": "model-00001-of-00004.safetensors",
991
+ "vision_tower.post_layernorm.weight": "model-00001-of-00004.safetensors",
992
+ "visual_source_spliter_emb.weight": "model-00001-of-00004.safetensors"
993
+ }
994
+ }
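The mapping that ends here is the sharded-checkpoint index: every parameter name is keyed to the .safetensors shard that stores it (model-00001 through model-00004). As a minimal sketch of how such an index is consumed (assuming the standard sharded-index layout with a "weight_map" key and an index file saved as model.safetensors.index.json next to the shards; the snippet is illustrative and not part of this commit):

    import json
    from safetensors import safe_open

    ckpt_dir = "WeMM"  # hypothetical local checkout of this repository

    # The index maps parameter name -> shard file name.
    with open(f"{ckpt_dir}/model.safetensors.index.json") as f:
        weight_map = json.load(f)["weight_map"]

    name = "language_model.model.layers.29.attention.wqkv.Plora_A.weight"
    shard = weight_map[name]  # "model-00004-of-00004.safetensors" per the map above

    # Open only the shard that holds this tensor and read it on CPU.
    with safe_open(f"{ckpt_dir}/{shard}", framework="pt", device="cpu") as f:
        tensor = f.get_tensor(name)

    print(name, tuple(tensor.shape))

Note also the naming pattern in the language-model entries: each attention and feed-forward projection stores an original_linear.weight together with a Plora_A/Plora_B pair, i.e. a base weight plus a low-rank adapter. The following is only a schematic, hypothetical illustration of how three such sub-weights typically combine; the authoritative definition presumably lives in modeling_internlm2.py below and may differ:

    import torch.nn as nn

    class PLoRALinearSketch(nn.Module):
        # Hypothetical wrapper matching the Plora_A / Plora_B / original_linear names;
        # the rank and the masking behaviour are assumptions, not read from this repository.
        def __init__(self, in_features, out_features, rank=256, bias=False):
            super().__init__()
            self.original_linear = nn.Linear(in_features, out_features, bias=bias)
            self.Plora_A = nn.Linear(in_features, rank, bias=False)
            self.Plora_B = nn.Linear(rank, out_features, bias=False)

        def forward(self, x, im_mask=None):
            out = self.original_linear(x)
            if im_mask is not None and im_mask.any():
                # Apply the low-rank path only on masked (e.g. image) positions.
                out = out + self.Plora_B(self.Plora_A(x)) * im_mask.unsqueeze(-1).to(out.dtype)
            return out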
modeling_downsampler.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_downsampler import DownsamplerConfig
8
+
9
+
10
+ class DownsamplerModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = DownsamplerConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: DownsamplerConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ self.group_op = nn.Conv2d(
21
+ in_channels=config.visual_hidden_size,
22
+ out_channels=config.llm_hidden_size,
23
+ bias=config.bias,
24
+ kernel_size=config.kernel_size, stride=config.stride)
25
+ modules = list()
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.linear_model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, DownsamplerModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def _forward(self, x):
47
+
48
+ # (B, FULL_H, FULL_W, D) -> (B, D, FULL_H, FULL_W)
49
+ x = x.permute(0, 3, 1, 2)
50
+ x = self.group_op(x)
51
+ # (B, D, NEW_H, NEW_W) -> (B, NEW_H, NEW_W, D)
52
+ x = x.permute(0, 2, 3, 1)
53
+ x = self.linear_model(x)
54
+
55
+ return x
56
+
57
+ def forward(self, x):
58
+ if self.gradient_checkpointing and self.training:
59
+ layer_outputs = torch.utils.checkpoint.checkpoint(self._forward, x)
60
+ else:
61
+ layer_outputs = self._forward(x)
62
+ return layer_outputs
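To make the tensor bookkeeping in `_forward` concrete, here is a small shape check built from a plain `nn.Conv2d`; all sizes (feature dimensions, kernel_size, stride, grid size) are illustrative stand-ins for the values a DownsamplerConfig would carry, not values taken from this repository:

import torch
import torch.nn as nn

vis_dim, llm_dim = 1152, 4096          # illustrative visual_hidden_size / llm_hidden_size
kernel, stride = 2, 2                  # illustrative kernel_size / stride

group_op = nn.Conv2d(vis_dim, llm_dim, kernel_size=kernel, stride=stride, bias=True)

x = torch.randn(1, 24, 24, vis_dim)    # (B, FULL_H, FULL_W, D), the layout _forward expects
y = group_op(x.permute(0, 3, 1, 2))    # Conv2d wants channels-first: (B, D, H, W)
y = y.permute(0, 2, 3, 1)              # back to channels-last
print(y.shape)                         # torch.Size([1, 12, 12, 4096]) with these sizes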
modeling_internlm2.py ADDED
@@ -0,0 +1,1495 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (
31
+ BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+
43
+ try:
44
+ from transformers.generation.streamers import BaseStreamer
45
+ except: # noqa # pylint: disable=bare-except
46
+ BaseStreamer = None
47
+
48
+ from .configuration_internlm2 import InternLM2Config
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _PLORA_DIM = 512
53
+ _CONFIG_FOR_DOC = "InternLM2Config"
54
+
55
+ flash_attn_func, flash_attn_varlen_func = None, None
56
+ pad_input, index_first_axis, unpad_input = None, None, None
57
+ def _import_flash_attn():
58
+ global flash_attn_func, flash_attn_varlen_func
59
+ global pad_input, index_first_axis, unpad_input
60
+ try:
61
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
62
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
63
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
64
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
65
+ except ImportError:
66
+ raise ImportError("flash_attn is not installed.")
67
+
68
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
69
+ def _get_unpad_data(attention_mask):
70
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
71
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
72
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
73
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
74
+ return (
75
+ indices,
76
+ cu_seqlens,
77
+ max_seqlen_in_batch,
78
+ )
79
+
80
+
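A quick illustration of what `_get_unpad_data` produces for a toy padding mask (values chosen arbitrarily):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # real-token positions
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # cumulative lengths

print(indices)             # tensor([0, 1, 2, 4, 5])
print(cu_seqlens)          # tensor([0, 3, 5], dtype=torch.int32)
print(int(seqlens.max()))  # 3, the max_seqlen_in_batch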
81
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
82
+ def _make_causal_mask(
83
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
84
+ ):
85
+ """
86
+ Make causal mask used for bi-directional self-attention.
87
+ """
88
+ bsz, tgt_len = input_ids_shape
89
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
90
+ mask_cond = torch.arange(mask.size(-1), device=device)
91
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
92
+ mask = mask.to(dtype)
93
+
94
+ if past_key_values_length > 0:
95
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
96
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
97
+
98
+
99
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
100
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
101
+ """
102
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
103
+ """
104
+ bsz, src_len = mask.size()
105
+ tgt_len = tgt_len if tgt_len is not None else src_len
106
+
107
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
108
+
109
+ inverted_mask = 1.0 - expanded_mask
110
+
111
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
112
+
113
+
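For intuition, the additive mask `_make_causal_mask` builds for a fresh 3-token sequence (no cached keys) looks like this: 0 where a position may attend, the dtype's most negative value where it may not:

import torch

tgt_len, dtype = 3, torch.float32
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(tgt_len)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)

print(mask)  # row i is 0 for columns j <= i and stays at finfo.min above the diagonal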
114
+ class PLoRA(nn.Module):
115
+
116
+ def __init__(self,
117
+ in_features: int,
118
+ out_features: int,
119
+ bias: bool = True,
120
+ device=None,
121
+ dtype=None,
122
+ lora_r=8,
123
+ lora_alpha=16,
124
+ lora_dropout=0.05,
125
+ lora_len=0,
126
+ **kwargs) -> None:
127
+ super().__init__()
128
+
129
+ self.original_linear = nn.Linear(in_features, out_features, bias, device, dtype)
130
+
131
+ self.lora_r = lora_r
132
+ self.lora_alpha = lora_alpha
133
+ self.lora_len = lora_len
134
+ if lora_dropout > 0.:
135
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
136
+ else:
137
+ self.lora_dropout = lambda x: x
138
+ self.lora_scaling = self.lora_alpha / self.lora_r
139
+
140
+ self.Plora_A = nn.Linear(
141
+ in_features, self.lora_r, bias=False, device=device, dtype=dtype)
142
+ self.Plora_B = nn.Linear(
143
+ self.lora_r, out_features, bias=False, device=device, dtype=dtype)
144
+
145
+ self.reset_parameters()
146
+
147
+ def reset_parameters(self):
148
+ if hasattr(self, 'Plora_A'):
149
+ # initialize A the same way as the default for nn.Linear and B to zero
150
+ nn.init.kaiming_uniform_(self.Plora_A.weight, a=math.sqrt(5))
151
+ nn.init.zeros_(self.Plora_B.weight)
152
+
153
+ def forward(self, x, im_mask=None):
154
+ res = self.original_linear(x)
155
+
156
+ if im_mask is not None:
157
+ if torch.sum(im_mask) > 0:
158
+ part_x = x[im_mask]
159
+ res[im_mask] += self.Plora_B(
160
+ self.Plora_A(
161
+ self.lora_dropout(part_x))) * self.lora_scaling
162
+ else:
163
+ part_x = x[:, :1]
164
+ res[:, :1] += self.Plora_B(
165
+ self.Plora_A(self.lora_dropout(part_x))) * 0
166
+ return res
167
+
168
+
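The key point of `PLoRA` (partial LoRA) is that the low-rank update is added only at positions flagged by `im_mask`, i.e. image tokens, while text tokens see just the frozen base linear layer. A stripped-down sketch with made-up dimensions:

import torch
import torch.nn as nn

d_in, d_out, rank, scaling = 8, 8, 2, 1.0
base = nn.Linear(d_in, d_out, bias=False)
plora_a = nn.Linear(d_in, rank, bias=False)
plora_b = nn.Linear(rank, d_out, bias=False)
nn.init.zeros_(plora_b.weight)                 # the LoRA branch starts as a no-op

x = torch.randn(1, 5, d_in)                    # (batch, seq, dim)
im_mask = torch.tensor([[True, True, False, False, False]])   # first two tokens are image tokens

out = base(x)                                  # every position goes through the base layer
out[im_mask] = out[im_mask] + plora_b(plora_a(x[im_mask])) * scaling   # extra path for image tokens only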
169
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
170
+ class InternLM2RMSNorm(nn.Module):
171
+ def __init__(self, hidden_size, eps=1e-6):
172
+ """
173
+ InternLM2RMSNorm is equivalent to T5LayerNorm
174
+ """
175
+ super().__init__()
176
+ self.weight = nn.Parameter(torch.ones(hidden_size))
177
+ self.variance_epsilon = eps
178
+
179
+ def forward(self, hidden_states):
180
+ input_dtype = hidden_states.dtype
181
+ hidden_states = hidden_states.to(torch.float32)
182
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
183
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
184
+ return self.weight * hidden_states.to(input_dtype)
185
+
186
+
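`InternLM2RMSNorm` normalizes each hidden vector by its root-mean-square (no mean subtraction, unlike LayerNorm) and rescales with a learned weight; the same computation in one line on a toy tensor:

import torch

x = torch.randn(2, 5, 8)                      # (batch, seq, hidden)
weight, eps = torch.ones(8), 1e-5

y = weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps))
print(y.shape)                                # torch.Size([2, 5, 8])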
187
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
188
+ class InternLM2RotaryEmbedding(nn.Module):
189
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
190
+ super().__init__()
191
+
192
+ self.dim = dim
193
+ self.max_position_embeddings = max_position_embeddings
194
+ self.base = base
195
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
196
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
197
+
198
+ # Build here to make `torch.jit.trace` work.
199
+ self._set_cos_sin_cache(
200
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
201
+ )
202
+
203
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
204
+ self.max_seq_len_cached = seq_len
205
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
206
+
207
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
208
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
209
+ emb = torch.cat((freqs, freqs), dim=-1)
210
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
211
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
212
+
213
+ def forward(self, x, seq_len=None):
214
+ # x: [bs, num_attention_heads, seq_len, head_size]
215
+ if seq_len > self.max_seq_len_cached:
216
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
217
+
218
+ return (
219
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
220
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
221
+ )
222
+
223
+
224
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
225
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
226
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
227
+
228
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
229
+ self.scaling_factor = scaling_factor
230
+ super().__init__(dim, max_position_embeddings, base, device)
231
+
232
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
233
+ self.max_seq_len_cached = seq_len
234
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
235
+ t = t / self.scaling_factor
236
+
237
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
238
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
239
+ emb = torch.cat((freqs, freqs), dim=-1)
240
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
241
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
242
+
243
+
244
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
245
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
246
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
247
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
248
+ """
249
+
250
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
251
+ self.scaling_factor = scaling_factor
252
+ super().__init__(dim, max_position_embeddings, base, device)
253
+
254
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
255
+ self.max_seq_len_cached = seq_len
256
+
257
+ if seq_len > self.max_position_embeddings:
258
+ base = self.base * (
259
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
260
+ ) ** (self.dim / (self.dim - 2))
261
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
262
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
263
+
264
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
265
+
266
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
267
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
268
+ emb = torch.cat((freqs, freqs), dim=-1)
269
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
270
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
271
+
272
+
273
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
274
+ def rotate_half(x):
275
+ """Rotates half the hidden dims of the input."""
276
+ x1 = x[..., : x.shape[-1] // 2]
277
+ x2 = x[..., x.shape[-1] // 2 :]
278
+ return torch.cat((-x2, x1), dim=-1)
279
+
280
+
281
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
282
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
283
+ """Applies Rotary Position Embedding to the query and key tensors."""
284
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
285
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
286
+ q_embed = (q * cos) + (rotate_half(q) * sin)
287
+ k_embed = (k * cos) + (rotate_half(k) * sin)
288
+ return q_embed, k_embed
289
+
290
+
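Putting the pieces above together (the cached cos/sin tables, `rotate_half`, and `apply_rotary_pos_emb`), here is a self-contained sketch of rotary position embedding on toy shapes:

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, base, seq_len = 4, 10000, 6
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)       # (seq_len, head_dim), as in _set_cos_sin_cache
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 2, seq_len, head_dim)      # (batch, heads, seq, head_dim)
position_ids = torch.arange(seq_len)[None]
cos_p = cos[position_ids].unsqueeze(1)        # unsqueeze over the head dimension
sin_p = sin[position_ids].unsqueeze(1)
q_rot = q * cos_p + rotate_half(q) * sin_p
print(q_rot.shape)                            # torch.Size([1, 2, 6, 4])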
291
+ class InternLM2MLP(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.config = config
295
+ self.hidden_size = config.hidden_size
296
+ self.intermediate_size = config.intermediate_size
297
+ # self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
298
+ # self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
299
+ # self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
300
+
301
+ self.w1 = PLoRA(
302
+ self.hidden_size,
303
+ self.intermediate_size,
304
+ bias=False,
305
+ lora_r=_PLORA_DIM,
306
+ lora_alpha=_PLORA_DIM,
307
+ lora_len=576)
308
+ self.w3 = PLoRA(
309
+ self.hidden_size,
310
+ self.intermediate_size,
311
+ bias=False,
312
+ lora_r=_PLORA_DIM,
313
+ lora_alpha=_PLORA_DIM,
314
+ lora_len=576)
315
+ self.w2 = PLoRA(
316
+ self.intermediate_size,
317
+ self.hidden_size,
318
+ bias=False,
319
+ lora_r=_PLORA_DIM,
320
+ lora_alpha=_PLORA_DIM,
321
+ lora_len=576)
322
+
323
+ self.act_fn = ACT2FN[config.hidden_act]
324
+
325
+ def forward(self, x, im_mask):
326
+ down_proj = self.w2(self.act_fn(self.w1(x, im_mask)) * self.w3(x, im_mask), im_mask)
327
+ return down_proj
328
+
329
+
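Ignoring the PLoRA wrappers, `InternLM2MLP` is a SwiGLU-style gated feed-forward: `w2(act(w1(x)) * w3(x))`. A plain-Linear sketch with toy sizes and SiLU as the activation:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 64        # toy sizes, not the model's
w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
w2 = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 7, hidden_size)
down_proj = w2(F.silu(w1(x)) * w3(x))          # gate * value, then project back down
print(down_proj.shape)                         # torch.Size([2, 7, 16])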
330
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
331
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
332
+ """
333
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
334
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
335
+ """
336
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
337
+ if n_rep == 1:
338
+ return hidden_states
339
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
340
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
341
+
342
+
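As the docstring notes, `repeat_kv` is a reshape-based equivalent of `torch.repeat_interleave` along the head dimension (used to expand grouped key/value heads to the full head count); a quick numerical check:

import torch

batch, num_kv_heads, seq_len, head_dim, n_rep = 2, 4, 3, 8, 2
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
via_reshape = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
via_repeat = torch.repeat_interleave(kv, repeats=n_rep, dim=1)

print(torch.equal(via_reshape, via_repeat))    # True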
343
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
344
+ class InternLM2Attention(nn.Module):
345
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
346
+
347
+ def __init__(self, config: InternLM2Config):
348
+ super().__init__()
349
+ self.config = config
350
+ self.hidden_size = config.hidden_size
351
+ self.num_heads = config.num_attention_heads
352
+ self.head_dim = self.hidden_size // self.num_heads
353
+ self.num_key_value_heads = config.num_key_value_heads
354
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
355
+ self.max_position_embeddings = config.max_position_embeddings
356
+ self.is_causal = True
357
+
358
+ if (self.head_dim * self.num_heads) != self.hidden_size:
359
+ raise ValueError(
360
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
361
+ f" and `num_heads`: {self.num_heads})."
362
+ )
363
+
364
+ # self.wqkv = nn.Linear(
365
+ # self.hidden_size,
366
+ # (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
367
+ # bias=config.bias,
368
+ # )
369
+ #
370
+ # self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
371
+
372
+ self.wqkv = PLoRA(
373
+ self.hidden_size,
374
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
375
+ bias=config.bias,
376
+ lora_r=_PLORA_DIM,
377
+ lora_alpha=_PLORA_DIM,
378
+ lora_len=576)
379
+
380
+ self.wo = PLoRA(
381
+ self.num_heads * self.head_dim,
382
+ self.hidden_size,
383
+ bias=config.bias,
384
+ lora_r=_PLORA_DIM,
385
+ lora_alpha=_PLORA_DIM,
386
+ lora_len=576)
387
+ self._init_rope()
388
+
389
+ def _init_rope(self):
390
+ if self.config.rope_scaling is None:
391
+ self.rotary_emb = InternLM2RotaryEmbedding(
392
+ self.head_dim,
393
+ max_position_embeddings=self.max_position_embeddings,
394
+ base=self.config.rope_theta,
395
+ )
396
+ else:
397
+ scaling_type = self.config.rope_scaling["type"]
398
+ scaling_factor = self.config.rope_scaling["factor"]
399
+ if scaling_type == "dynamic":
400
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
401
+ self.head_dim,
402
+ max_position_embeddings=self.max_position_embeddings,
403
+ base=self.config.rope_theta,
404
+ scaling_factor=scaling_factor,
405
+ )
406
+ elif scaling_type == "linear":
407
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
408
+ self.head_dim,
409
+ max_position_embeddings=self.max_position_embeddings,
410
+ base=self.config.rope_theta,
411
+ scaling_factor=scaling_factor,
412
+ )
413
+ else:
414
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
415
+ return self.rotary_emb
416
+
417
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
418
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
419
+
420
+ def forward(
421
+ self,
422
+ hidden_states: torch.Tensor,
423
+ attention_mask: Optional[torch.Tensor] = None,
424
+ position_ids: Optional[torch.LongTensor] = None,
425
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
426
+ output_attentions: bool = False,
427
+ use_cache: bool = False,
428
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
429
+ **kwargs,
430
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
431
+ if "padding_mask" in kwargs:
432
+ warnings.warn(
433
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
434
+ "Please make sure to use `attention_mask` instead."
435
+ )
436
+
437
+ bsz, q_len, _ = hidden_states.size()
438
+ qkv_states = self.wqkv(hidden_states, im_mask)
439
+
440
+ qkv_states = rearrange(
441
+ qkv_states,
442
+ "b q (h gs d) -> b q h gs d",
443
+ gs=2 + self.num_key_value_groups,
444
+ d=self.head_dim,
445
+ )
446
+
447
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
448
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
449
+ key_states = qkv_states[..., -2, :]
450
+ value_states = qkv_states[..., -1, :]
451
+
452
+ query_states = query_states.transpose(1, 2)
453
+ key_states = key_states.transpose(1, 2)
454
+ value_states = value_states.transpose(1, 2)
455
+
456
+ kv_seq_len = key_states.shape[-2]
457
+ if past_key_value is not None:
458
+ kv_seq_len += past_key_value[0].shape[-2]
459
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
460
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
461
+
462
+ if past_key_value is not None:
463
+ # reuse k, v, self_attention
464
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
465
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
466
+
467
+ past_key_value = (key_states, value_states) if use_cache else None
468
+
469
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
470
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
471
+
472
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
473
+
474
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
475
+ raise ValueError(
476
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
477
+ f" {attn_weights.size()}"
478
+ )
479
+
480
+ if attention_mask is not None:
481
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
482
+ raise ValueError(
483
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
484
+ )
485
+ attn_weights = attn_weights + attention_mask
486
+
487
+ # upcast attention to fp32
488
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
489
+ attn_output = torch.matmul(attn_weights, value_states)
490
+
491
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
492
+ raise ValueError(
493
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
494
+ f" {attn_output.size()}"
495
+ )
496
+
497
+ attn_output = attn_output.transpose(1, 2).contiguous()
498
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
499
+
500
+ attn_output = self.wo(attn_output, im_mask)
501
+
502
+ if not output_attentions:
503
+ attn_weights = None
504
+
505
+ return attn_output, attn_weights, past_key_value
506
+
507
+
508
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
509
+ class InternLM2FlashAttention2(InternLM2Attention):
510
+ """
511
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
512
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
513
+ flash attention and deal with padding tokens in case the input contains any of them.
514
+ """
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states: torch.Tensor,
519
+ attention_mask: Optional[torch.LongTensor] = None,
520
+ position_ids: Optional[torch.LongTensor] = None,
521
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
522
+ output_attentions: bool = False,
523
+ use_cache: bool = False,
524
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
525
+ **kwargs,
526
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
527
+ # InternLM2FlashAttention2 attention does not support output_attentions
528
+ if "padding_mask" in kwargs:
529
+ warnings.warn(
530
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
531
+ "Please make sure to use `attention_mask` instead."
532
+ )
533
+
534
+ # overwrite attention_mask with padding_mask
535
+ attention_mask = kwargs.pop("padding_mask")
536
+
537
+ output_attentions = False
538
+
539
+ bsz, q_len, _ = hidden_states.size()
540
+ qkv_states = self.wqkv(hidden_states, im_mask)
541
+
542
+ qkv_states = rearrange(
543
+ qkv_states,
544
+ "b q (h gs d) -> b q h gs d",
545
+ gs=2 + self.num_key_value_groups,
546
+ d=self.head_dim,
547
+ )
548
+
549
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
550
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
551
+ key_states = qkv_states[..., -2, :]
552
+ value_states = qkv_states[..., -1, :]
553
+
554
+ query_states = query_states.transpose(1, 2)
555
+ key_states = key_states.transpose(1, 2)
556
+ value_states = value_states.transpose(1, 2)
557
+
558
+ kv_seq_len = key_states.shape[-2]
559
+ if past_key_value is not None:
560
+ kv_seq_len += past_key_value[0].shape[-2]
561
+
562
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
563
+
564
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
565
+
566
+ if past_key_value is not None:
567
+ # reuse k, v, self_attention
568
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
569
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
570
+
571
+ past_key_value = (key_states, value_states) if use_cache else None
572
+
573
+ query_states = query_states.transpose(1, 2)
574
+ key_states = key_states.transpose(1, 2)
575
+ value_states = value_states.transpose(1, 2)
576
+
577
+ attn_output = self._flash_attention_forward(
578
+ query_states, key_states, value_states, attention_mask, q_len
579
+ )
580
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
581
+ attn_output = self.wo(attn_output, im_mask)
582
+
583
+ if not output_attentions:
584
+ attn_weights = None
585
+
586
+ return attn_output, attn_weights, past_key_value
587
+
588
+ def _flash_attention_forward(
589
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
590
+ ):
591
+ """
592
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
593
+ first unpad the input, then computes the attention scores and pad the final attention scores.
594
+
595
+ Args:
596
+ query_states (`torch.Tensor`):
597
+ Input query states to be passed to Flash Attention API
598
+ key_states (`torch.Tensor`):
599
+ Input key states to be passed to Flash Attention API
600
+ value_states (`torch.Tensor`):
601
+ Input value states to be passed to Flash Attention API
602
+ attention_mask (`torch.Tensor`):
603
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
604
+ position of padding tokens and 1 for the position of non-padding tokens.
605
+ dropout (`int`, *optional*):
606
+ Attention dropout
607
+ softmax_scale (`float`, *optional*):
608
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
609
+ """
610
+ # Contains at least one padding token in the sequence
611
+ causal = self.is_causal and query_length != 1
612
+ if attention_mask is not None:
613
+ batch_size = query_states.shape[0]
614
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
615
+ query_states, key_states, value_states, attention_mask, query_length
616
+ )
617
+
618
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
619
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
620
+
621
+ attn_output_unpad = flash_attn_varlen_func(
622
+ query_states,
623
+ key_states,
624
+ value_states,
625
+ cu_seqlens_q=cu_seqlens_q,
626
+ cu_seqlens_k=cu_seqlens_k,
627
+ max_seqlen_q=max_seqlen_in_batch_q,
628
+ max_seqlen_k=max_seqlen_in_batch_k,
629
+ dropout_p=dropout,
630
+ softmax_scale=softmax_scale,
631
+ causal=causal,
632
+ )
633
+
634
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
635
+ else:
636
+ attn_output = flash_attn_func(
637
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
638
+ )
639
+
640
+ return attn_output
641
+
642
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
643
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
644
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
645
+
646
+ key_layer = index_first_axis(
647
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
648
+ )
649
+ value_layer = index_first_axis(
650
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
651
+ )
652
+
653
+ if query_length == kv_seq_len:
654
+ query_layer = index_first_axis(
655
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
656
+ )
657
+ cu_seqlens_q = cu_seqlens_k
658
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
659
+ indices_q = indices_k
660
+ elif query_length == 1:
661
+ max_seqlen_in_batch_q = 1
662
+ cu_seqlens_q = torch.arange(
663
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
664
+ ) # There is a memcpy here, that is very bad.
665
+ indices_q = cu_seqlens_q[:-1]
666
+ query_layer = query_layer.squeeze(1)
667
+ else:
668
+ # The -q_len: slice assumes left padding.
669
+ attention_mask = attention_mask[:, -query_length:]
670
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
671
+
672
+ return (
673
+ query_layer,
674
+ key_layer,
675
+ value_layer,
676
+ indices_q.to(torch.int64),
677
+ (cu_seqlens_q, cu_seqlens_k),
678
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
679
+ )
680
+
681
+ INTERNLM2_ATTENTION_CLASSES = {
682
+ "eager": InternLM2Attention,
683
+ "flash_attention_2": InternLM2FlashAttention2,
684
+ }
685
+
686
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
687
+ class InternLM2DecoderLayer(nn.Module):
688
+ def __init__(self, config: InternLM2Config):
689
+ super().__init__()
690
+ self.hidden_size = config.hidden_size
691
+
692
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
693
+ self.feed_forward = InternLM2MLP(config)
694
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
695
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
696
+
697
+ def forward(
698
+ self,
699
+ hidden_states: torch.Tensor,
700
+ attention_mask: Optional[torch.Tensor] = None,
701
+ position_ids: Optional[torch.LongTensor] = None,
702
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
703
+ output_attentions: Optional[bool] = False,
704
+ use_cache: Optional[bool] = False,
705
+ im_mask: Optional[Tuple[torch.Tensor]] = None,
706
+ **kwargs,
707
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
708
+ """
709
+ Args:
710
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
711
+ attention_mask (`torch.FloatTensor`, *optional*):
712
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
713
+ query_sequence_length, key_sequence_length)` if default attention is used.
714
+ output_attentions (`bool`, *optional*):
715
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
716
+ returned tensors for more detail.
717
+ use_cache (`bool`, *optional*):
718
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
719
+ (see `past_key_values`).
720
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
721
+ """
722
+ if "padding_mask" in kwargs:
723
+ warnings.warn(
724
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. "
725
+ "Please make sure to use `attention_mask` instead."
726
+ )
727
+
728
+ residual = hidden_states
729
+
730
+ hidden_states = self.attention_norm(hidden_states)
731
+ # Self Attention
732
+ hidden_states, self_attn_weights, present_key_value = self.attention(
733
+ hidden_states=hidden_states,
734
+ attention_mask=attention_mask,
735
+ position_ids=position_ids,
736
+ past_key_value=past_key_value,
737
+ output_attentions=output_attentions,
738
+ use_cache=use_cache,
739
+ im_mask=im_mask,
740
+ **kwargs,
741
+ )
742
+ hidden_states = residual + hidden_states
743
+
744
+ # Fully Connected
745
+ residual = hidden_states
746
+ hidden_states = self.ffn_norm(hidden_states)
747
+ hidden_states = self.feed_forward(hidden_states, im_mask)
748
+ hidden_states = residual + hidden_states
749
+
750
+ outputs = (hidden_states,)
751
+
752
+ if output_attentions:
753
+ outputs += (self_attn_weights,)
754
+
755
+ if use_cache:
756
+ outputs += (present_key_value,)
757
+
758
+ return outputs
759
+
760
+
761
+ InternLM2_START_DOCSTRING = r"""
762
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
763
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
764
+ etc.)
765
+
766
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
767
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
768
+ and behavior.
769
+
770
+ Parameters:
771
+ config ([`InternLM2Config`]):
772
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
773
+ load the weights associated with the model, only the configuration. Check out the
774
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
775
+ """
776
+
777
+
778
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
779
+ @add_start_docstrings(
780
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
781
+ InternLM2_START_DOCSTRING,
782
+ )
783
+ class InternLM2PreTrainedModel(PreTrainedModel):
784
+ config_class = InternLM2Config
785
+ base_model_prefix = "model"
786
+ supports_gradient_checkpointing = True
787
+ _no_split_modules = ["InternLM2DecoderLayer"]
788
+ _skip_keys_device_placement = "past_key_values"
789
+
790
+ def _init_weights(self, module):
791
+ std = self.config.initializer_range
792
+ if isinstance(module, nn.Linear):
793
+ module.weight.data.normal_(mean=0.0, std=std)
794
+ if module.bias is not None:
795
+ module.bias.data.zero_()
796
+ elif isinstance(module, nn.Embedding):
797
+ module.weight.data.normal_(mean=0.0, std=std)
798
+ if module.padding_idx is not None:
799
+ module.weight.data[module.padding_idx].zero_()
800
+
801
+
802
+ InternLM2_INPUTS_DOCSTRING = r"""
803
+ Args:
804
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
805
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
806
+ it.
807
+
808
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
809
+ [`PreTrainedTokenizer.__call__`] for details.
810
+
811
+ [What are input IDs?](../glossary#input-ids)
812
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
813
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
814
+
815
+ - 1 for tokens that are **not masked**,
816
+ - 0 for tokens that are **masked**.
817
+
818
+ [What are attention masks?](../glossary#attention-mask)
819
+
820
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
821
+ [`PreTrainedTokenizer.__call__`] for details.
822
+
823
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
824
+ `past_key_values`).
825
+
826
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
827
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
828
+ information on the default strategy.
829
+
830
+ - 1 indicates the head is **not masked**,
831
+ - 0 indicates the head is **masked**.
832
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
833
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
834
+ config.n_positions - 1]`.
835
+
836
+ [What are position IDs?](../glossary#position-ids)
837
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
838
+ when `config.use_cache=True`):
839
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
840
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
841
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
842
+
843
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
844
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
845
+
846
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
847
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
848
+ of shape `(batch_size, sequence_length)`.
849
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
850
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
851
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
852
+ model's internal embedding lookup matrix.
853
+ use_cache (`bool`, *optional*):
854
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
855
+ `past_key_values`).
856
+ output_attentions (`bool`, *optional*):
857
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
858
+ tensors for more detail.
859
+ output_hidden_states (`bool`, *optional*):
860
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
861
+ more detail.
862
+ return_dict (`bool`, *optional*):
863
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
864
+ """
865
+
866
+
867
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
868
+ @add_start_docstrings(
869
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
870
+ InternLM2_START_DOCSTRING,
871
+ )
872
+ class InternLM2Model(InternLM2PreTrainedModel):
873
+ """
874
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
875
+
876
+ Args:
877
+ config: InternLM2Config
878
+ """
879
+
880
+ _auto_class = "AutoModel"
881
+
882
+ def __init__(self, config: InternLM2Config):
883
+ super().__init__(config)
884
+ self.padding_idx = config.pad_token_id
885
+ self.vocab_size = config.vocab_size
886
+ self.config = config
887
+
888
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
889
+
890
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
891
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
892
+
893
+ self.gradient_checkpointing = False
894
+ # Initialize weights and apply final processing
895
+ self.post_init()
896
+
897
+ def get_input_embeddings(self):
898
+ return self.tok_embeddings
899
+
900
+ def set_input_embeddings(self, value):
901
+ self.tok_embeddings = value
902
+
903
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
904
+ # create causal mask
905
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
906
+ combined_attention_mask = None
907
+ if input_shape[-1] > 1:
908
+ combined_attention_mask = _make_causal_mask(
909
+ input_shape,
910
+ inputs_embeds.dtype,
911
+ device=inputs_embeds.device,
912
+ past_key_values_length=past_key_values_length,
913
+ )
914
+
915
+ if attention_mask is not None:
916
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
917
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
918
+ inputs_embeds.device
919
+ )
920
+ combined_attention_mask = (
921
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
922
+ )
923
+
924
+ return combined_attention_mask
925
+
926
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
927
+ def forward(
928
+ self,
929
+ input_ids: torch.LongTensor = None,
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ position_ids: Optional[torch.LongTensor] = None,
932
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
933
+ inputs_embeds: Optional[torch.FloatTensor] = None,
934
+ use_cache: Optional[bool] = None,
935
+ output_attentions: Optional[bool] = None,
936
+ output_hidden_states: Optional[bool] = None,
937
+ return_dict: Optional[bool] = None,
938
+ **kwargs
939
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
940
+
941
+ im_mask = kwargs.get('im_mask', None)
942
+
943
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
944
+ output_hidden_states = (
945
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
946
+ )
947
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
948
+
949
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
950
+
951
+ if self.config.attn_implementation == "flash_attention_2":
952
+ _import_flash_attn()
953
+
954
+ # retrieve input_ids and inputs_embeds
955
+ if input_ids is not None and inputs_embeds is not None:
956
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
957
+ elif input_ids is not None:
958
+ batch_size, seq_length = input_ids.shape[:2]
959
+ elif inputs_embeds is not None:
960
+ batch_size, seq_length = inputs_embeds.shape[:2]
961
+ else:
962
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
963
+
964
+ seq_length_with_past = seq_length
965
+ past_key_values_length = 0
966
+ if past_key_values is not None:
967
+ past_key_values_length = past_key_values[0][0].shape[2]
968
+ seq_length_with_past = seq_length_with_past + past_key_values_length
969
+
970
+ if position_ids is None:
971
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
972
+ position_ids = torch.arange(
973
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
974
+ )
975
+ position_ids = position_ids.unsqueeze(0)
976
+
977
+ if inputs_embeds is None:
978
+ inputs_embeds = self.tok_embeddings(input_ids)
979
+ im_mask = torch.zeros(inputs_embeds.shape[:2]).to(
980
+ inputs_embeds.device).bool()
981
+
982
+ if self.config.attn_implementation == "flash_attention_2":
983
+ # 2d mask is passed through the layers
984
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
985
+ else:
986
+ if attention_mask is None:
987
+ attention_mask = torch.ones(
988
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
989
+ )
990
+ attention_mask = self._prepare_decoder_attention_mask(
991
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
992
+ )
993
+
994
+ # embed positions
995
+ hidden_states = inputs_embeds
996
+
997
+ if self.gradient_checkpointing and self.training:
998
+ if use_cache:
999
+ logger.warning_once(
1000
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1001
+ )
1002
+ use_cache = False
1003
+
1004
+ # decoder layers
1005
+ all_hidden_states = () if output_hidden_states else None
1006
+ all_self_attns = () if output_attentions else None
1007
+ next_decoder_cache = () if use_cache else None
1008
+
1009
+ for idx, decoder_layer in enumerate(self.layers):
1010
+ if output_hidden_states:
1011
+ all_hidden_states += (hidden_states,)
1012
+
1013
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1014
+
1015
+ if self.gradient_checkpointing and self.training:
1016
+
1017
+ def create_custom_forward(module):
1018
+ def custom_forward(*inputs):
1019
+ # None for past_key_value
1020
+ return module(*inputs, output_attentions, None, im_mask)
1021
+
1022
+ return custom_forward
1023
+
1024
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1025
+ create_custom_forward(decoder_layer),
1026
+ hidden_states,
1027
+ attention_mask,
1028
+ position_ids,
1029
+ None,
1030
+ )
1031
+ else:
1032
+ layer_outputs = decoder_layer(
1033
+ hidden_states,
1034
+ attention_mask=attention_mask,
1035
+ position_ids=position_ids,
1036
+ past_key_value=past_key_value,
1037
+ output_attentions=output_attentions,
1038
+ use_cache=use_cache,
1039
+ im_mask=im_mask,
1040
+ )
1041
+
1042
+ hidden_states = layer_outputs[0]
1043
+
1044
+ if use_cache:
1045
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1046
+
1047
+ if output_attentions:
1048
+ all_self_attns += (layer_outputs[1],)
1049
+
1050
+ hidden_states = self.norm(hidden_states)
1051
+
1052
+ # add hidden states from the last decoder layer
1053
+ if output_hidden_states:
1054
+ all_hidden_states += (hidden_states,)
1055
+
1056
+ next_cache = next_decoder_cache if use_cache else None
1057
+ if not return_dict:
1058
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1059
+ return BaseModelOutputWithPast(
1060
+ last_hidden_state=hidden_states,
1061
+ past_key_values=next_cache,
1062
+ hidden_states=all_hidden_states,
1063
+ attentions=all_self_attns,
1064
+ )
1065
+
1066
+
1067
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
1068
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1069
+ _auto_class = "AutoModelForCausalLM"
1070
+
1071
+ _tied_weights_keys = ["output.weight"]
1072
+
1073
+ def __init__(self, config):
1074
+ super().__init__(config)
1075
+ self.model = InternLM2Model(config)
1076
+ self.vocab_size = config.vocab_size
1077
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1078
+
1079
+ # Initialize weights and apply final processing
1080
+ self.post_init()
1081
+
1082
+ def get_input_embeddings(self):
1083
+ return self.model.tok_embeddings
1084
+
1085
+ def set_input_embeddings(self, value):
1086
+ self.model.tok_embeddings = value
1087
+
1088
+ def get_output_embeddings(self):
1089
+ return self.output
1090
+
1091
+ def set_output_embeddings(self, new_embeddings):
1092
+ self.output = new_embeddings
1093
+
1094
+ def set_decoder(self, decoder):
1095
+ self.model = decoder
1096
+
1097
+ def get_decoder(self):
1098
+ return self.model
1099
+
1100
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1101
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1102
+ def forward(
1103
+ self,
1104
+ input_ids: torch.LongTensor = None,
1105
+ attention_mask: Optional[torch.Tensor] = None,
1106
+ position_ids: Optional[torch.LongTensor] = None,
1107
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1108
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1109
+ labels: Optional[torch.LongTensor] = None,
1110
+ use_cache: Optional[bool] = None,
1111
+ output_attentions: Optional[bool] = None,
1112
+ output_hidden_states: Optional[bool] = None,
1113
+ return_dict: Optional[bool] = None,
1114
+ im_mask: Optional[torch.Tensor] = None,
1115
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1116
+ r"""
1117
+ Args:
1118
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1119
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1120
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1121
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1122
+
1123
+ Returns:
1124
+
1125
+ Example:
1126
+
1127
+ ```python
1128
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1129
+
1130
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1131
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1132
+
1133
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1134
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1135
+
1136
+ >>> # Generate
1137
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1138
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1139
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1140
+ ```"""
1141
+
1142
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1143
+ output_hidden_states = (
1144
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1145
+ )
1146
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1147
+
1148
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1149
+ outputs = self.model(
1150
+ input_ids=input_ids,
1151
+ attention_mask=attention_mask,
1152
+ position_ids=position_ids,
1153
+ past_key_values=past_key_values,
1154
+ inputs_embeds=inputs_embeds,
1155
+ use_cache=use_cache,
1156
+ output_attentions=output_attentions,
1157
+ output_hidden_states=output_hidden_states,
1158
+ return_dict=return_dict,
1159
+ im_mask=im_mask,
1160
+ )
1161
+
1162
+ hidden_states = outputs[0]
1163
+ logits = self.output(hidden_states)
1164
+ logits = logits.float()
1165
+
1166
+ loss = None
1167
+ if labels is not None:
1168
+ # Shift so that tokens < n predict n
1169
+ shift_logits = logits[..., :-1, :].contiguous()
1170
+ shift_labels = labels[..., 1:].contiguous()
1171
+ # Flatten the tokens
1172
+ loss_fct = CrossEntropyLoss()
1173
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1174
+ shift_labels = shift_labels.view(-1)
1175
+ # Enable model parallelism
1176
+ shift_labels = shift_labels.to(shift_logits.device)
1177
+ loss = loss_fct(shift_logits, shift_labels)
1178
+
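The block above is the standard shift-by-one causal language-modeling objective: the logits at position t are scored against the token at position t+1. On toy data:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)                        # (batch, seq, vocab)
labels = torch.tensor([[3, 1, 4, 1]])

shift_logits = logits[..., :-1, :].reshape(-1, vocab_size)    # drop the last position
shift_labels = labels[..., 1:].reshape(-1)                    # drop the first label
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())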
1179
+ if not return_dict:
1180
+ output = (logits,) + outputs[1:]
1181
+ return (loss,) + output if loss is not None else output
1182
+
1183
+ return CausalLMOutputWithPast(
1184
+ loss=loss,
1185
+ logits=logits,
1186
+ past_key_values=outputs.past_key_values,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ )
1190
+
1191
+ def prepare_inputs_for_generation(
1192
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1193
+ ):
1194
+ if past_key_values is not None:
1195
+ past_length = past_key_values[0][0].shape[2]
1196
+
1197
+ # Some generation methods already pass only the last input ID
1198
+ if input_ids.shape[1] > past_length:
1199
+ remove_prefix_length = past_length
1200
+ else:
1201
+ # Default to old behavior: keep only final ID
1202
+ remove_prefix_length = input_ids.shape[1] - 1
1203
+
1204
+ input_ids = input_ids[:, remove_prefix_length:]
1205
+
1206
+ position_ids = kwargs.get("position_ids", None)
1207
+ if attention_mask is not None and position_ids is None:
1208
+ # create position_ids on the fly for batch generation
1209
+ position_ids = attention_mask.long().cumsum(-1) - 1
1210
+ position_ids.masked_fill_(attention_mask == 0, 1)
1211
+ if past_key_values:
1212
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1213
+
1214
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1215
+ if inputs_embeds is not None and past_key_values is None:
1216
+ model_inputs = {"inputs_embeds": inputs_embeds}
1217
+ else:
1218
+ model_inputs = {"input_ids": input_ids}
1219
+
1220
+ model_inputs.update(
1221
+ {
1222
+ "position_ids": position_ids,
1223
+ "past_key_values": past_key_values,
1224
+ "use_cache": kwargs.get("use_cache"),
1225
+ "attention_mask": attention_mask,
1226
+ "im_mask": kwargs.get("im_mask", None),
1227
+ }
1228
+ )
1229
+ return model_inputs
1230
+
1231
+ @staticmethod
1232
+ def _reorder_cache(past_key_values, beam_idx):
1233
+ reordered_past = ()
1234
+ for layer_past in past_key_values:
1235
+ reordered_past += (
1236
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1237
+ )
1238
+ return reordered_past
1239
+
1240
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1241
+ if tokenizer.add_bos_token:
1242
+ prompt = ""
1243
+ else:
1244
+ prompt = tokenizer.bos_token
1245
+ if meta_instruction:
1246
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1247
+ for record in history:
1248
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1249
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1250
+ return tokenizer([prompt], return_tensors="pt")
1251
+
1252
+ @torch.no_grad()
1253
+ def chat(
1254
+ self,
1255
+ tokenizer,
1256
+ query: str,
1257
+ history: List[Tuple[str, str]] = [],
1258
+ streamer: Optional[BaseStreamer] = None,
1259
+ max_new_tokens: int = 1024,
1260
+ do_sample: bool = True,
1261
+ temperature: float = 0.8,
1262
+ top_p: float = 0.8,
1263
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1264
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1265
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1266
+ **kwargs,
1267
+ ):
1268
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1269
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1270
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1271
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1272
+ outputs = self.generate(
1273
+ **inputs,
1274
+ streamer=streamer,
1275
+ max_new_tokens=max_new_tokens,
1276
+ do_sample=do_sample,
1277
+ temperature=temperature,
1278
+ top_p=top_p,
1279
+ eos_token_id=eos_token_id,
1280
+ **kwargs,
1281
+ )
1282
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1283
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1284
+ response = response.split("<|im_end|>")[0]
1285
+ history = history + [(query, response)]
1286
+ return response, history
1287
+
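+ # Usage sketch (illustrative only; assumes `model` and `tokenizer` were loaded
+ # from this repository with trust_remote_code=True):
+ #   response, history = model.chat(tokenizer, "Hello", history=[])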
1288
+ @torch.no_grad()
1289
+ def stream_chat(
1290
+ self,
1291
+ tokenizer,
1292
+ query: str,
1293
+ history: List[Tuple[str, str]] = [],
1294
+ max_new_tokens: int = 1024,
1295
+ do_sample: bool = True,
1296
+ temperature: float = 0.8,
1297
+ top_p: float = 0.8,
1298
+ **kwargs,
1299
+ ):
1300
+ """
1301
+ Return a generator that yields (response, history) tuples.
1302
+ E.g.
1303
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1304
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1305
+ """
1306
+ if BaseStreamer is None:
1307
+ raise ModuleNotFoundError(
1308
+ "The version of `transformers` is too low. Please make sure "
1309
+ "that you have installed `transformers>=4.28.0`."
1310
+ )
1311
+
1312
+ response_queue = queue.Queue(maxsize=20)
1313
+
1314
+ class ChatStreamer(BaseStreamer):
1315
+ def __init__(self, tokenizer) -> None:
1316
+ super().__init__()
1317
+ self.tokenizer = tokenizer
1318
+ self.queue = response_queue
1319
+ self.query = query
1320
+ self.history = history
1321
+ self.response = ""
1322
+ self.cache = []
1323
+ self.received_inputs = False
1324
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1325
+
1326
+ def put(self, value):
1327
+ if len(value.shape) > 1 and value.shape[0] > 1:
1328
+ raise ValueError("ChatStreamer only supports batch size 1")
1329
+ elif len(value.shape) > 1:
1330
+ value = value[0]
1331
+
1332
+ if not self.received_inputs:
1333
+ # The first received value is input_ids, ignore here
1334
+ self.received_inputs = True
1335
+ return
1336
+
1337
+ self.cache.extend(value.tolist())
1338
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1339
+ if token.strip() != "<|im_end|>":
1340
+ self.response = self.response + token
1341
+ history = self.history + [(self.query, self.response)]
1342
+ self.queue.put((self.response, history))
1343
+ self.cache = []
1344
+ else:
1345
+ self.end()
1346
+
1347
+ def end(self):
1348
+ self.queue.put(None)
1349
+
1350
+ def stream_producer():
1351
+ return self.chat(
1352
+ tokenizer=tokenizer,
1353
+ query=query,
1354
+ streamer=ChatStreamer(tokenizer=tokenizer),
1355
+ history=history,
1356
+ max_new_tokens=max_new_tokens,
1357
+ do_sample=do_sample,
1358
+ temperature=temperature,
1359
+ top_p=top_p,
1360
+ **kwargs,
1361
+ )
1362
+
1363
+ def consumer():
1364
+ producer = threading.Thread(target=stream_producer)
1365
+ producer.start()
1366
+ while True:
1367
+ res = response_queue.get()
1368
+ if res is None:
1369
+ return
1370
+ yield res
1371
+
1372
+ return consumer()
1373
+
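+ # Usage sketch (illustrative only): stream_chat() returns a generator, and the
+ # yielded `response` string grows as new tokens arrive.
+ #   for response, history in model.stream_chat(tokenizer, "Hello"):
+ #       print(response)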
1374
+
1375
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1376
+ @add_start_docstrings(
1377
+ """
1378
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1379
+
1380
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1381
+ as other causal models (e.g. GPT-2) do.
1382
+
1383
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1384
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1385
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1386
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1387
+ each row of the batch).
1388
+ """,
1389
+ InternLM2_START_DOCSTRING,
1390
+ )
1391
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1392
+ def __init__(self, config):
1393
+ super().__init__(config)
1394
+ self.num_labels = config.num_labels
1395
+ self.model = InternLM2Model(config)
1396
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1397
+
1398
+ # Initialize weights and apply final processing
1399
+ self.post_init()
1400
+
1401
+ def get_input_embeddings(self):
1402
+ return self.model.tok_embeddings
1403
+
1404
+ def set_input_embeddings(self, value):
1405
+ self.model.tok_embeddings = value
1406
+
1407
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1408
+ def forward(
1409
+ self,
1410
+ input_ids: torch.LongTensor = None,
1411
+ attention_mask: Optional[torch.Tensor] = None,
1412
+ position_ids: Optional[torch.LongTensor] = None,
1413
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1414
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1415
+ labels: Optional[torch.LongTensor] = None,
1416
+ use_cache: Optional[bool] = None,
1417
+ output_attentions: Optional[bool] = None,
1418
+ output_hidden_states: Optional[bool] = None,
1419
+ return_dict: Optional[bool] = None,
1420
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1421
+ r"""
1422
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1423
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1424
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1425
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1426
+ """
1427
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1428
+
1429
+ transformer_outputs = self.model(
1430
+ input_ids,
1431
+ attention_mask=attention_mask,
1432
+ position_ids=position_ids,
1433
+ past_key_values=past_key_values,
1434
+ inputs_embeds=inputs_embeds,
1435
+ use_cache=use_cache,
1436
+ output_attentions=output_attentions,
1437
+ output_hidden_states=output_hidden_states,
1438
+ return_dict=return_dict,
1439
+ )
1440
+ hidden_states = transformer_outputs[0]
1441
+ logits = self.score(hidden_states)
1442
+
1443
+ if input_ids is not None:
1444
+ batch_size = input_ids.shape[0]
1445
+ else:
1446
+ batch_size = inputs_embeds.shape[0]
1447
+
1448
+ if self.config.pad_token_id is None and batch_size != 1:
1449
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1450
+ if self.config.pad_token_id is None:
1451
+ sequence_lengths = -1
1452
+ else:
1453
+ if input_ids is not None:
1454
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1455
+ logits.device
1456
+ )
1457
+ else:
1458
+ sequence_lengths = -1
1459
+
1460
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1461
+
1462
+ loss = None
1463
+ if labels is not None:
1464
+ labels = labels.to(logits.device)
1465
+ if self.config.problem_type is None:
1466
+ if self.num_labels == 1:
1467
+ self.config.problem_type = "regression"
1468
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1469
+ self.config.problem_type = "single_label_classification"
1470
+ else:
1471
+ self.config.problem_type = "multi_label_classification"
1472
+
1473
+ if self.config.problem_type == "regression":
1474
+ loss_fct = MSELoss()
1475
+ if self.num_labels == 1:
1476
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1477
+ else:
1478
+ loss = loss_fct(pooled_logits, labels)
1479
+ elif self.config.problem_type == "single_label_classification":
1480
+ loss_fct = CrossEntropyLoss()
1481
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1482
+ elif self.config.problem_type == "multi_label_classification":
1483
+ loss_fct = BCEWithLogitsLoss()
1484
+ loss = loss_fct(pooled_logits, labels)
1485
+ if not return_dict:
1486
+ output = (pooled_logits,) + transformer_outputs[1:]
1487
+ return ((loss,) + output) if loss is not None else output
1488
+
1489
+ return SequenceClassifierOutputWithPast(
1490
+ loss=loss,
1491
+ logits=pooled_logits,
1492
+ past_key_values=transformer_outputs.past_key_values,
1493
+ hidden_states=transformer_outputs.hidden_states,
1494
+ attentions=transformer_outputs.attentions,
1495
+ )
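A minimal sketch of the last-non-padding-token pooling rule used by the classification head above (illustrative tensors only; the pad token id below is an assumption):

import torch

input_ids = torch.tensor([[11, 12, 13, 0, 0]])   # toy batch, 0 assumed to be pad_token_id
pad_token_id = 0
# Index of the last non-padding token per row, exactly as computed in forward() above.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1   # tensor([2])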
modeling_projector.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
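A hedged construction sketch for the projector above. The ProjectorConfig field names are inferred from their use in __init__ and are assumptions about configuration_projector.py, which is not part of this excerpt; the values below are placeholders.

import torch
from configuration_projector import ProjectorConfig   # assumed local module of this repo
from modeling_projector import ProjectorModel

cfg = ProjectorConfig(visual_hidden_size=1152, llm_hidden_size=4096,
                      depth=2, hidden_act="gelu", bias=True)    # placeholder values
projector = ProjectorModel(cfg)
visual_feats = torch.randn(1, 64, 1152)     # (batch, tokens, visual_hidden_size)
print(projector(visual_feats).shape)        # expected: torch.Size([1, 64, 4096])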
modeling_wemm.py ADDED
@@ -0,0 +1,528 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from transformers import PreTrainedModel
7
+ from transformers.activations import ACT2FN
8
+ from transformers.cache_utils import Cache
9
+ from transformers.modeling_outputs import ModelOutput
10
+ from transformers.utils import (
11
+ add_start_docstrings,
12
+ add_start_docstrings_to_model_forward,
13
+ logging,
14
+ replace_return_docstrings,
15
+ )
16
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoConfig
17
+ from .configuration_wemm import WeMMConfig
18
+ from .vision_model import Idefics2VisionTransformer
19
+ from .connector import Idefics2Connector
20
+ from .image_processor_2k import Idefics2ImageProcessor
21
+ from .modeling_downsampler import DownsamplerModel
22
+ from .modeling_projector import ProjectorModel
23
+ from .modeling_internlm2 import InternLM2ForCausalLM
24
+ from .tokenization_internlm2 import InternLM2Tokenizer
25
+ from peft import PeftModel
26
+ from peft import PeftConfig
27
+ import os
28
+ from PIL import Image
29
+ import numpy as np
30
+ IMAGE_TOKEN_INDEX = -200
31
+ DEFAULT_IMAGE_TOKEN = "<image>"
32
+ IGNORE_INDEX = -100
33
+ from transformers import StoppingCriteria
34
+ from transformers import PreTrainedTokenizerFast, StoppingCriteriaList
35
+ import torch.nn.functional as F
36
+ class StopWordStoppingCriteria(StoppingCriteria):
37
+ """StopWord stopping criteria."""
38
+ def __init__(self, tokenizer, stop_word):
39
+ self.tokenizer = tokenizer
40
+ self.stop_word = stop_word
41
+ self.length = len(self.stop_word)
42
+ def __call__(self, input_ids, *args, **kwargs) -> bool:
43
+ cur_text = self.tokenizer.decode(input_ids[0])
44
+ cur_text = cur_text.replace('\r', '').replace('\n', '')
45
+ return cur_text[-self.length:] == self.stop_word
46
+ def get_stop_criteria(
47
+ tokenizer,
48
+ stop_words=[],
49
+ ):
50
+ stop_criteria = StoppingCriteriaList()
51
+ for word in stop_words:
52
+ stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
53
+ return stop_criteria
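+ # e.g. get_stop_criteria(tokenizer, stop_words=['<|im_end|>']) returns a
+ # StoppingCriteriaList that halts generation once the decoded tail of the
+ # sequence equals the stop word.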
54
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
55
+ assert embed_dim % 2 == 0
56
+ # use half of dimensions to encode grid_h
57
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H, W, D/2)
58
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H, W, D/2)
59
+ emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
60
+ return emb
61
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
62
+ """
63
+ embed_dim: output dimension for each position
64
+ pos: a list of positions to be encoded: size (M,)
65
+ out: (M, D)
66
+ """
67
+ assert embed_dim % 2 == 0
68
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
69
+ omega /= embed_dim / 2.
70
+ omega = 1. / 10000**omega # (D/2,)
71
+ pos = np.squeeze(pos) # (1, H, W) -> (H, W)
72
+ out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product
73
+ emb_sin = np.sin(out) # (H, W, D/2)
74
+ emb_cos = np.cos(out) # (H, W, D/2)
75
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
76
+ return emb
77
+ # 2D sine-cosine position embedding
78
+ # References:
79
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
80
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
81
+ # --------------------------------------------------------
82
+ def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
83
+ """
84
+ grid_size_h, grid_size_w: ints giving the grid height and width
+ return:
+ pos_embed: array of shape (grid_size_h, grid_size_w, embed_dim)
87
+ """
88
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
89
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
90
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
91
+ grid = np.stack(grid, axis=0)
92
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
93
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
94
+ if cls_token:
95
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
96
+ return pos_embed
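+ # e.g. get_2d_sincos_pos_embed(4096, grid_size_h=3, grid_size_w=4) returns an
+ # array of shape (3, 4, 4096): one embed_dim-sized embedding per grid cell.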
97
+ def recover_navit_subimages_with_pos_emb(
98
+ sub_image_hidden_states,
99
+ attention_mask,
100
+ num_sub_images,
101
+ visual_embedding_group,
102
+ pos_hidden_size,
103
+ thumbnail_only=False):
104
+ if num_sub_images < 0:
105
+ num_sub_images = 0
106
+ _slice = int(np.sqrt(num_sub_images))
107
+ N, L, D = sub_image_hidden_states.shape
108
+ _, H, W = attention_mask.shape
109
+ if thumbnail_only is True:
110
+ num_sub_images += 1
111
+ sub_image_hidden_states = sub_image_hidden_states.reshape(-1, num_sub_images, H, W, D)
112
+ attention_mask = attention_mask.reshape(-1, num_sub_images, H, W)
113
+ if thumbnail_only is True:
114
+ sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
115
+ attention_mask = attention_mask[:, -1:, :, :]
116
+ _slice = 1
117
+ def _infer_ori_image_patch_shape(sub_image_attention_mask):
118
+ ind_h, ind_w = torch.where(sub_image_attention_mask > 0)
119
+ return torch.max(ind_h) + 1, torch.max(ind_w) + 1
120
+ def _pad_to_same(image_hidden):
121
+ _dtype = image_hidden.dtype
122
+ visual_downsample_stride = int(np.sqrt(visual_embedding_group))
123
+ full_h, full_w, _ = image_hidden.shape
124
+ target_h, target_w = H * _slice, W * _slice
125
+ # ensure all contents are included during downsampling
126
+ to_pad_h = (target_h - full_h) + (
127
+ visual_downsample_stride - target_h % visual_downsample_stride) % visual_downsample_stride
128
+ to_pad_w = (target_w - full_w) + (
129
+ visual_downsample_stride - target_w % visual_downsample_stride) % visual_downsample_stride
130
+ # (H,W,D) -> (1,D,H,W) to support replicate padding
131
+ image_hidden = image_hidden.permute(2, 0, 1).unsqueeze(0)
132
+ pad_size = (0, to_pad_w, 0, to_pad_h)
133
+ # (1,D,H,W) -> (H,W,D)
134
+ image_hidden = F.pad(image_hidden.to(torch.float32), pad_size, mode='replicate').squeeze(0).permute(1, 2, 0)
135
+ return image_hidden.to(_dtype)
136
+
137
+ image_hidden_states = list()
138
+ valid_image_token = list()
139
+ image_2d_pos = list()
140
+ for batch_id in range(len(sub_image_hidden_states)):
141
+ ori_h, ori_w = _infer_ori_image_patch_shape(attention_mask[batch_id][0])
142
+ full_h, full_w = ori_h * _slice, ori_w * _slice
143
+ # (S,H,W,D) -> (S_h,S_w,H,W,D) -> (S_h,H,S_w,W,D) -> (S_h*H,S_w*W,D)
144
+ this_image_hidden = sub_image_hidden_states[batch_id][:, 0:ori_h, 0:ori_w, :] \
145
+ .view(_slice, _slice, ori_h, ori_w, D).permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)
146
+ pos_emb = get_2d_sincos_pos_embed(pos_hidden_size, grid_size_h=full_h,
147
+ grid_size_w=full_w) # (H, W, D)
148
+ pos_emb = torch.tensor(pos_emb, dtype=this_image_hidden.dtype, device=this_image_hidden.device)
149
+ image_hidden_states.append(_pad_to_same(this_image_hidden))
150
+ image_2d_pos.append(_pad_to_same(pos_emb))
151
+ valid_image_token.append([full_h, full_w])
152
+ image_hidden_states = torch.stack(image_hidden_states)
153
+ image_2d_pos = torch.stack(image_2d_pos)
154
+ valid_image_token = torch.tensor(valid_image_token, dtype=torch.int64)
155
+ return image_hidden_states, image_2d_pos, valid_image_token
156
+ def visiual_token_downsample(
157
+ visual_downsampler,
158
+ image_hidden_states,
159
+ valid_image_token,
160
+ visual_embedding_group,
161
+ image_2d_pos):
162
+ if image_2d_pos is not None:
163
+ image_hidden_states = image_hidden_states + image_2d_pos
164
+ image_hidden_states = visual_downsampler(image_hidden_states)
165
+ valid_image_token = torch.ceil(valid_image_token / np.sqrt(visual_embedding_group)).to(torch.int64)
166
+ return image_hidden_states, valid_image_token
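+ # With visual_embedding_group=16 the grid is downsampled by sqrt(16)=4 along
+ # each axis, and valid_image_token is ceil-divided to stay consistent.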
167
+ def merge_native_qformer(
168
+ clip_embeddings_native_patch,
169
+ valid_image_token_shape,
170
+ clip_embeddings_qformer,
171
+ visual_source_spliter,
172
+ num_sub_images):
173
+
174
+ def add_split_token_for_qformer_token(qformer_emb):
175
+ # + 1 for thumbnail
176
+ len_per_token = int(qformer_emb.size(0) // (num_sub_images + 1))
177
+ qformer_emb_with_spliter = list()
178
+ for i in range(num_sub_images + 1):
179
+ qformer_emb_with_spliter.append(
180
+ visual_source_spliter(torch.tensor([2 * i]).to(visual_source_spliter.weight.device))
181
+ )
182
+ qformer_emb_with_spliter.append(qformer_emb[i * len_per_token:(i + 1) * len_per_token])
183
+ qformer_emb_with_spliter.append(
184
+ visual_source_spliter(torch.tensor([2 * i + 1]).to(visual_source_spliter.weight.device))
185
+ )
186
+ return torch.cat(qformer_emb_with_spliter, dim=0)
187
+
188
+ merged_visual_embeddings = list()
189
+ for batch_id in range(clip_embeddings_native_patch.size(0)):
190
+ h, w = valid_image_token_shape[batch_id]
191
+ native_patch_emb = clip_embeddings_native_patch[batch_id][:h, :w, :].reshape(h*w, -1)
192
+ if clip_embeddings_qformer is not None:
193
+ qformer_emb = clip_embeddings_qformer[batch_id]
194
+ qformer_emb = add_split_token_for_qformer_token(qformer_emb)
195
+ merged_visual_embeddings.append(
196
+ torch.cat(
197
+ [visual_source_spliter(torch.tensor([10]).to(visual_source_spliter.weight.device)),
198
+ native_patch_emb,
199
+ visual_source_spliter(torch.tensor([11]).to(visual_source_spliter.weight.device)),
200
+ qformer_emb],
201
+ dim=0))
202
+ else:
203
+ merged_visual_embeddings.append(
204
+ torch.cat(
205
+ [visual_source_spliter(torch.tensor([0]).to(visual_source_spliter.weight.device)),
206
+ native_patch_emb,
207
+ visual_source_spliter(torch.tensor([1]).to(visual_source_spliter.weight.device))],
208
+ dim=0))
209
+
210
+ return merged_visual_embeddings
211
+ class WemmForConditionalGeneration(PreTrainedModel):
212
+ config_class = WeMMConfig
213
+ def __init__(self, config: WeMMConfig):
214
+ super().__init__(config)
215
+
216
+ self.vision_tower = Idefics2VisionTransformer(config.vision_config)
217
+ self.image_processor = Idefics2ImageProcessor(config.image_processor)
218
+ self.language_model = InternLM2ForCausalLM(config.text_config)
219
+ self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True, encode_special_tokens=True)
220
+ self.downsampler = DownsamplerModel(config.downsampler_config)
221
+ self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)
222
+
223
+
224
+ self.gen_config = GenerationConfig(
225
+ max_new_tokens=512,
226
+ do_sample=False,
227
+ eos_token_id=self.tokenizer.eos_token_id,
228
+ pad_token_id=self.tokenizer.pad_token_id
229
+ if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
230
+ )
231
+ self.do_image_splitting = config.do_image_splitting
232
+ self.stop_criteria = get_stop_criteria(
233
+ tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
234
+ self.config = config
235
+
236
+ def chat(self, conversations, gen_config=None):
237
+ prompt = ""
238
+ image_path = conversations[0]['images'][0]
239
+ for i,ann in enumerate(conversations):
240
+ if(ann['role'] == 'user'):
241
+ prompt += f"<|im_start|>user\n{ann['content']}<|im_end|>\n"
242
+ elif(ann['role'] == 'assistant'):
243
+ prompt += f"<|im_start|>assistant\n{ann['content']}<|im_end|>\n"
244
+ prompt += '<|im_start|>assistant\n'
245
+ with torch.no_grad():
246
+ output = self.generate(image_path, prompt, gen_config=gen_config)
247
+ return output
248
+
249
+ def chat_v2(self, conversations, images, gen_config=None):
250
+ image_path = images["images"][0]
251
+ with torch.no_grad():
252
+ output = self.generate(image_path, conversations, gen_config=gen_config)
253
+ return output
254
+
255
+ # assert
256
+ def mm_generate(self, image_path, prompt, gen_config=None):
257
+ prompt = "<image>" + '\n' + prompt
258
+ prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
259
+ return self.generate(image_path,prompt,gen_config)
260
+
261
+ def generate(self, image_path, prompt, gen_config=None):
262
+ image = Image.open(image_path).convert('RGB')
263
+ navit980_images = self.image_processor([[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
264
+ batch_size_navit = navit980_images['pixel_values'].shape[0]
265
+ navit_pixel_values = navit980_images['navit_pixel_values'].cuda()
266
+ navit_patch_attention_mask = navit980_images["pixel_attention_mask"].cuda()
267
+ clip_visual_outputs = self.vision_tower(pixel_values=navit_pixel_values,patch_attention_mask=navit_patch_attention_mask,).last_hidden_state
268
+
269
+ super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
270
+ recover_navit_subimages_with_pos_emb(
271
+ clip_visual_outputs, navit_patch_attention_mask, num_sub_images=-1,
272
+ visual_embedding_group=16,
273
+ pos_hidden_size=4096,
274
+ thumbnail_only=True
275
+ )
276
+ clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
277
+ self.downsampler,
278
+ super_image_hidden_states, valid_image_token_shape,
279
+ visual_embedding_group=16, image_2d_pos=None
280
+ )
281
+ merged_visual_embeddings = \
282
+ merge_native_qformer(
283
+ clip_embeddings_native_patch,
284
+ valid_image_token_shape,
285
+ clip_embeddings_qformer=None,
286
+ visual_source_spliter=self.visual_source_spliter_emb,
287
+ num_sub_images=-1
288
+ )
289
+ chunk_encode = []
290
+ for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
291
+ if idx == 0:
292
+ cur_encode = self.tokenizer.encode(chunk)
293
+ else:
294
+ cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
295
+ chunk_encode.append(cur_encode)
296
+ assert len(chunk_encode) == 2
297
+ ids = []
298
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
299
+ ids.extend(cur_chunk_encode)
300
+ if idx != len(chunk_encode) - 1:
301
+ ids.append(IMAGE_TOKEN_INDEX)
302
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
303
+ pixel_values = None
304
+ mm_inputs = self.prepare_inputs_labels_for_multimodal(
305
+ llm=self.language_model, input_ids=ids, pixel_values=pixel_values, clip_embeddings=merged_visual_embeddings)
306
+ generate_output = self.language_model.generate(
307
+ **mm_inputs,
308
+ generation_config=gen_config if gen_config is not None else self.gen_config,
309
+ streamer=None,
310
+ bos_token_id=self.tokenizer.bos_token_id,
311
+ stopping_criteria=self.stop_criteria
312
+ )
313
+ predict = self.tokenizer.decode(
314
+ generate_output[0], skip_special_tokens=True).strip()
315
+ return predict
316
+ def get_valid_visual_embedding(self, embedding, valid_token_shape):
317
+ if valid_token_shape is None:
318
+ return embedding
319
+ h, w = valid_token_shape
320
+ return embedding[:h, :w, :].reshape(h*w, -1)
321
+ # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
322
+ def prepare_inputs_labels_for_multimodal(
323
+ self,
324
+ llm: PreTrainedModel,
325
+ input_ids: torch.LongTensor = None,
326
+ position_ids: Optional[torch.LongTensor] = None,
327
+ attention_mask: Optional[torch.Tensor] = None,
328
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
329
+ labels: Optional[torch.LongTensor] = None,
330
+ pixel_values: Optional[torch.FloatTensor] = None,
331
+ clip_embeddings: Optional[torch.FloatTensor] = None,
332
+ hard_coded_max_len: Optional[int] = None,
333
+ **kwargs):
334
+ if pixel_values is None and clip_embeddings is None:
335
+ return {
336
+ 'input_ids': input_ids,
337
+ 'position_ids': position_ids,
338
+ 'attention_mask': attention_mask,
339
+ 'past_key_values': past_key_values,
340
+ 'inputs_embeds': None,
341
+ 'labels': labels
342
+ }
343
+
344
+ _labels = labels
345
+ _position_ids = position_ids
346
+ _attention_mask = attention_mask
347
+ if attention_mask is None:
348
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
349
+ else:
350
+ attention_mask = attention_mask.bool()
351
+ if position_ids is None:
352
+ position_ids = torch.arange(
353
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
354
+ if labels is None:
355
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
356
+
357
+ input_ids = [
358
+ cur_input_ids[cur_attention_mask]
359
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
360
+ ]
361
+ labels = [
362
+ cur_labels[cur_attention_mask]
363
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
364
+ ]
365
+
366
+ new_inputs_embeds = []
367
+ new_labels = []
368
+ new_img_masks = []
369
+ cur_image_idx = 0
370
+ for batch_idx, cur_input_ids in enumerate(input_ids):
371
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
372
+ if num_images == 0:
373
+ cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
374
+ #master_print(f"batchid {batch_idx} cur_image_idx {cur_image_idx} get valid visual from {clip_embeddings[cur_image_idx].shape}")
375
+ cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
376
+ cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
377
+ if cur_clip_emb is not None and cur_pixel_values is not None:
378
+ cur_inputs_embeds = torch.cat(
379
+ [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
380
+ elif cur_pixel_values is not None:
381
+ cur_inputs_embeds = torch.cat(
382
+ [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
383
+ elif cur_clip_emb is not None:
384
+ cur_inputs_embeds = torch.cat(
385
+ [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
386
+ else:
387
+ raise ValueError
388
+ new_inputs_embeds.append(cur_inputs_embeds)
389
+ new_labels.append(labels[batch_idx])
390
+ new_img_masks.append(torch.zeros(
391
+ cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
392
+ cur_image_idx += 1
393
+ continue
394
+
395
+ image_token_indices = [-1] + torch.where(
396
+ cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
397
+ cur_input_ids.shape[0]
398
+ ]
399
+ cur_input_ids_noim = []
400
+ cur_labels = labels[batch_idx]
401
+ cur_labels_noim = []
402
+ for i in range(len(image_token_indices) - 1):
403
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] +
404
+ 1:image_token_indices[i +
405
+ 1]])
406
+ cur_labels_noim.append(cur_labels[image_token_indices[i] +
407
+ 1:image_token_indices[i + 1]])
408
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
409
+ cur_inputs_embeds = llm.get_input_embeddings()(
410
+ torch.cat(cur_input_ids_noim))
411
+ cur_inputs_embeds_no_im = torch.split(
412
+ cur_inputs_embeds, split_sizes, dim=0)
413
+ cur_new_inputs_embeds = []
414
+ cur_new_labels = []
415
+ cur_img_masks = []
416
+
417
+ for i in range(num_images + 1):
418
+ cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
419
+ cur_new_labels.append(cur_labels_noim[i])
420
+ cur_img_masks.append(torch.zeros(
421
+ cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool())
422
+ if i < num_images:
423
+ cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
424
+ cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
425
+
426
+ cur_image_idx += 1
427
+
428
+ # discrete token embeddings
429
+ if cur_pixel_values is not None:
430
+ cur_new_inputs_embeds.append(cur_pixel_values)
431
+ cur_img_masks.append(torch.ones(
432
+ cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
433
+ cur_new_labels.append(
434
+ torch.full((cur_pixel_values.shape[0], ),
435
+ IGNORE_INDEX,
436
+ device=cur_labels.device,
437
+ dtype=cur_labels.dtype))
438
+
439
+ # clip embeddings
440
+ if cur_clip_emb is not None:
441
+ cur_new_inputs_embeds.append(cur_clip_emb)
442
+ cur_img_masks.append(torch.ones(
443
+ cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
444
+ cur_new_labels.append(
445
+ torch.full((cur_clip_emb.shape[0],),
446
+ IGNORE_INDEX,
447
+ device=cur_labels.device,
448
+ dtype=cur_labels.dtype))
449
+
450
+ cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
451
+ cur_new_labels = torch.cat(cur_new_labels)
452
+ cur_img_masks = torch.cat(cur_img_masks)
453
+
454
+ new_inputs_embeds.append(cur_new_inputs_embeds)
455
+ new_labels.append(cur_new_labels)
456
+ new_img_masks.append(cur_img_masks)
457
+
458
+ # Combine them
459
+ max_len = max(x.shape[0] for x in new_inputs_embeds)
460
+ if hard_coded_max_len is not None:
461
+ max_len = min(max_len, hard_coded_max_len)
462
+ batch_size = len(new_inputs_embeds)
463
+
464
+ new_inputs_embeds_padded = []
465
+ new_labels_padded = torch.full((batch_size, max_len),
466
+ IGNORE_INDEX,
467
+ dtype=new_labels[0].dtype,
468
+ device=new_labels[0].device)
469
+ attention_mask = torch.zeros((batch_size, max_len),
470
+ dtype=attention_mask.dtype,
471
+ device=attention_mask.device)
472
+ position_ids = torch.zeros((batch_size, max_len),
473
+ dtype=position_ids.dtype,
474
+ device=position_ids.device)
475
+ new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()
476
+
477
+ for i, (cur_new_embed,
478
+ cur_new_labels, cur_new_img_masks) in enumerate(zip(new_inputs_embeds, new_labels, new_img_masks)):
479
+ cur_new_embed = cur_new_embed[:max_len]
480
+ cur_new_labels = cur_new_labels[:max_len]
481
+ cur_new_img_masks = cur_new_img_masks[:max_len]
482
+
483
+ cur_len = cur_new_embed.shape[0]
484
+ new_inputs_embeds_padded.append(
485
+ torch.cat((cur_new_embed,
486
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
487
+ dtype=cur_new_embed.dtype,
488
+ device=cur_new_embed.device)),
489
+ dim=0))
490
+ if cur_len > 0:
491
+ new_labels_padded[i, :cur_len] = cur_new_labels
492
+ attention_mask[i, :cur_len] = True
493
+ position_ids[i, :cur_len] = torch.arange(
494
+ 0,
495
+ cur_len,
496
+ dtype=position_ids.dtype,
497
+ device=position_ids.device)
498
+ new_img_masks_padded[i, :cur_len] = cur_new_img_masks
499
+
500
+ new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
501
+
502
+ if _labels is None:
503
+ new_labels = None
504
+ else:
505
+ new_labels = new_labels_padded
506
+
507
+ if _attention_mask is None:
508
+ attention_mask = None
509
+ else:
510
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
511
+
512
+ if _position_ids is None:
513
+ position_ids = None
514
+
515
+ prepared_data = {
516
+ 'input_ids': None,
517
+ 'position_ids': position_ids,
518
+ 'attention_mask': attention_mask,
519
+ 'past_key_values': past_key_values,
520
+ 'inputs_embeds': new_inputs_embeds,
521
+ 'labels': new_labels,
522
+ }
523
+ #if pixel_values is not None:
524
+ prepared_data.update({'im_mask': new_img_masks_padded})
525
+ return prepared_data
526
+
527
+ AutoConfig.register("wemm_hf", WeMMConfig)
528
+ AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
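A hedged end-to-end sketch of loading the model registered above. The repository path and image file are placeholders, and trust_remote_code=True is required because every class is defined in this repository rather than in transformers.

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("path/to/this/repo", trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).cuda().eval()
# mm_generate() prepends the <image> placeholder and the chat markup defined above.
answer = model.mm_generate("example.jpg", "Describe this image.")
print(answer)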
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "</s>",
5
+ "unk_token": "<unk>"
6
+ }
tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) in an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) in a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) in a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
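A short loading sketch for the slow SentencePiece tokenizer defined above; the local path is a placeholder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True,
                                    use_fast=False)
ids = tok("<|im_start|>user\nhello<|im_end|>\n")["input_ids"]
print(tok.decode(ids))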
tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization Fast class for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, Optional, Tuple
22
+
23
+ from tokenizers import processors, decoders, Tokenizer, normalizers
24
+ from tokenizers.models import BPE
25
+
26
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
27
+ from transformers.utils import logging
28
+
29
+ from transformers.convert_slow_tokenizer import (
30
+ SLOW_TO_FAST_CONVERTERS,
31
+ SpmConverter,
32
+ SentencePieceExtractor,
33
+ )
34
+
35
+ from .tokenization_internlm2 import InternLM2Tokenizer
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
40
+
41
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
42
+ class InternLM2Converter(SpmConverter):
43
+ handle_byte_fallback = True
44
+
45
+ def vocab(self, proto):
46
+ vocab = [
47
+ ("<unk>", 0.0),
48
+ ("<s>", 0.0),
49
+ ("</s>", 0.0),
50
+ ]
51
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
52
+ return vocab
53
+
54
+ def unk_id(self, proto):
55
+ unk_id = 0
56
+ return unk_id
57
+
58
+ def decoder(self, replacement, add_prefix_space):
59
+ return decoders.Sequence(
60
+ [
61
+ decoders.Replace("▁", " "),
62
+ decoders.ByteFallback(),
63
+ decoders.Fuse(),
64
+ decoders.Strip(content=" ", left=1),
65
+ ]
66
+ )
67
+
68
+ def tokenizer(self, proto):
69
+ model_type = proto.trainer_spec.model_type
70
+ vocab_scores = self.vocab(proto)
71
+ # special tokens
72
+ added_tokens = self.original_tokenizer.added_tokens_decoder
73
+ for i in range(len(vocab_scores)):
74
+ piece, score = vocab_scores[i]
75
+ if i in added_tokens:
76
+ vocab_scores[i] = (added_tokens[i].content, score)
77
+ if model_type == 1:
78
+ raise RuntimeError("InternLM2 is supposed to be a BPE model!")
79
+
80
+ elif model_type == 2:
81
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
82
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
83
+ tokenizer = Tokenizer(
84
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
85
+ )
86
+ tokenizer.add_special_tokens(
87
+ [ added_token for index, added_token in added_tokens.items()]
88
+ )
89
+ else:
90
+ raise Exception(
91
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
92
+ )
93
+
94
+ return tokenizer
95
+
96
+ def normalizer(self, proto):
97
+ normalizers_list = []
98
+ if proto.normalizer_spec.add_dummy_prefix:
99
+ normalizers_list.append(normalizers.Prepend(prepend="▁"))
100
+ normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
101
+ return normalizers.Sequence(normalizers_list)
102
+
103
+ def pre_tokenizer(self, replacement, add_prefix_space):
104
+ return None
105
+
106
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
107
+
108
+
109
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
110
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
111
+ vocab_files_names = VOCAB_FILES_NAMES
112
+ slow_tokenizer_class = InternLM2Tokenizer
113
+ padding_side = "left"
114
+ model_input_names = ["input_ids", "attention_mask"]
115
+ _auto_class = "AutoTokenizer"
116
+
117
+ def __init__(
118
+ self,
119
+ vocab_file,
120
+ unk_token="<unk>",
121
+ bos_token="<s>",
122
+ eos_token="</s>",
123
+ pad_token="</s>",
124
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
125
+ add_bos_token=True,
126
+ add_eos_token=False,
127
+ decode_with_prefix_space=False,
128
+ clean_up_tokenization_spaces=False,
129
+ **kwargs,
130
+ ):
131
+ super().__init__(
132
+ vocab_file=vocab_file,
133
+ unk_token=unk_token,
134
+ bos_token=bos_token,
135
+ eos_token=eos_token,
136
+ pad_token=pad_token,
137
+ sp_model_kwargs=sp_model_kwargs,
138
+ add_bos_token=add_bos_token,
139
+ add_eos_token=add_eos_token,
140
+ decode_with_prefix_space=decode_with_prefix_space,
141
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
142
+ **kwargs,
143
+ )
144
+ self._add_bos_token = add_bos_token
145
+ self._add_eos_token = add_eos_token
146
+ self.update_post_processor()
147
+ self.vocab_file = vocab_file
148
+
149
+ @property
150
+ def can_save_slow_tokenizer(self) -> bool:
151
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
152
+
153
+ def update_post_processor(self):
154
+ """
155
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
156
+ """
157
+ bos = self.bos_token
158
+ bos_token_id = self.bos_token_id
159
+ if bos is None and self.add_bos_token:
160
+ raise ValueError("add_bos_token = True but bos_token = None")
161
+
162
+ eos = self.eos_token
163
+ eos_token_id = self.eos_token_id
164
+ if eos is None and self.add_eos_token:
165
+ raise ValueError("add_eos_token = True but eos_token = None")
166
+
167
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
168
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
169
+
170
+ special_tokens = []
171
+ if self.add_bos_token:
172
+ special_tokens.append((bos, bos_token_id))
173
+ if self.add_eos_token:
174
+ special_tokens.append((eos, eos_token_id))
175
+ self._tokenizer.post_processor = processors.TemplateProcessing(
176
+ single=single, pair=pair, special_tokens=special_tokens
177
+ )
178
+
179
+ @property
180
+ def add_eos_token(self):
181
+ return self._add_eos_token
182
+
183
+ @property
184
+ def add_bos_token(self):
185
+ return self._add_bos_token
186
+
187
+ @add_eos_token.setter
188
+ def add_eos_token(self, value):
189
+ self._add_eos_token = value
190
+ self.update_post_processor()
191
+
192
+ @add_bos_token.setter
193
+ def add_bos_token(self, value):
194
+ self._add_bos_token = value
195
+ self.update_post_processor()
196
+
197
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
198
+ if not self.can_save_slow_tokenizer:
199
+ raise ValueError(
200
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
201
+ "tokenizer."
202
+ )
203
+
204
+ if not os.path.isdir(save_directory):
205
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
206
+ return
207
+ out_vocab_file = os.path.join(
208
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
209
+ )
210
+
211
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
212
+ copyfile(self.vocab_file, out_vocab_file)
213
+
214
+ return (out_vocab_file,)
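A sketch of the add_bos_token / add_eos_token setters above, which rebuild the TemplateProcessing post-processor whenever they change. Constructing the fast tokenizer directly from the SentencePiece file relies on the slow-to-fast converter registered above and is an assumption; the path is a placeholder.

fast_tok = InternLM2TokenizerFast(vocab_file="./tokenizer.model")
fast_tok.add_eos_token = True              # setter triggers update_post_processor()
ids = fast_tok("hello")["input_ids"]
print(ids[-1] == fast_tok.eos_token_id)    # True once EOS appending is enabled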
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
3
+ size 1477754
tokenizer_config.json ADDED
@@ -0,0 +1,90 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_internlm2.InternLM2Tokenizer",
5
+ "tokenization_internlm2_fast.InternLM2TokenizerFast"
6
+ ]
7
+ },
8
+ "bos_token": "<s>",
9
+ "clean_up_tokenization_spaces": false,
10
+ "eos_token": "</s>",
11
+ "model_max_length": 1000000000000000019884624838656,
12
+ "pad_token": "</s>",
13
+ "tokenizer_class": "InternLM2Tokenizer",
14
+ "unk_token": "<unk>",
15
+ "added_tokens_decoder": {
16
+ "0": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false,
22
+ "special": true
23
+ },
24
+ "1": {
25
+ "content": "<s>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false,
30
+ "special": true
31
+ },
32
+ "2": {
33
+ "content": "</s>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false,
38
+ "special": true
39
+ },
40
+ "92543": {
41
+ "content": "<|im_start|>",
42
+ "lstrip": false,
43
+ "normalized": false,
44
+ "rstrip": false,
45
+ "single_word": false,
46
+ "special": true
47
+ },
48
+ "92542": {
49
+ "content": "<|im_end|>",
50
+ "lstrip": false,
51
+ "normalized": false,
52
+ "rstrip": false,
53
+ "single_word": false,
54
+ "special": true
55
+ },
56
+ "92541": {
57
+ "content": "<|action_start|>",
58
+ "lstrip": false,
59
+ "normalized": false,
60
+ "rstrip": false,
61
+ "single_word": false,
62
+ "special": true
63
+ },
64
+ "92540": {
65
+ "content": "<|action_end|>",
66
+ "lstrip": false,
67
+ "normalized": false,
68
+ "rstrip": false,
69
+ "single_word": false,
70
+ "special": true
71
+ },
72
+ "92539": {
73
+ "content": "<|interpreter|>",
74
+ "lstrip": false,
75
+ "normalized": false,
76
+ "rstrip": false,
77
+ "single_word": false,
78
+ "special": true
79
+ },
80
+ "92538": {
81
+ "content": "<|plugin|>",
82
+ "lstrip": false,
83
+ "normalized": false,
84
+ "rstrip": false,
85
+ "single_word": false,
86
+ "special": true
87
+ }
88
+ },
89
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
90
+ }
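A sketch of applying the chat_template declared above; `tokenizer` is assumed to have been loaded from this repository.

messages = [{"role": "user", "content": "What is in the image?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False,
                                       add_generation_prompt=True)
# -> "<s><|im_start|>user\nWhat is in the image?<|im_end|>\n<|im_start|>assistant\n"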
vision_model.py ADDED
@@ -0,0 +1,717 @@
1
+ from transformers import PretrainedConfig, PreTrainedModel
2
+
3
+ import inspect
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional, Tuple, Union
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache
17
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
18
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
19
+ from transformers.utils import (
20
+ add_start_docstrings,
21
+ add_start_docstrings_to_model_forward,
22
+ is_flash_attn_2_available,
23
+ is_flash_attn_greater_or_equal_2_10,
24
+ logging,
25
+ replace_return_docstrings,
26
+ )
27
+
28
+ if is_flash_attn_2_available():
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+
32
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
33
+
34
+
35
+ class Idefics2VisionConfig(PretrainedConfig):
36
+ r"""
37
+ This is the configuration class to store the configuration of an [`Idefics2VisionModel`]. It is used to instantiate an
38
+ Idefics2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
39
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
40
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics2 model
41
+ [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b).
42
+
43
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
44
+ documentation from [`PretrainedConfig`] for more information.
45
+
46
+ Args:
47
+ hidden_size (`int`, *optional*, defaults to 768):
48
+ Dimensionality of the encoder layers and the pooler layer.
49
+ intermediate_size (`int`, *optional*, defaults to 3072):
50
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ num_channels (`int`, *optional*, defaults to 3):
56
+ Number of channels in the input images.
57
+ image_size (`int`, *optional*, defaults to 224):
58
+ The size (resolution) of each image.
59
+ patch_size (`int`, *optional*, defaults to 32):
60
+ The size (resolution) of each patch.
61
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
62
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
63
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
64
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the layer normalization layers.
66
+ attention_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for the attention probabilities.
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation for initializing all weight matrices in the model.
70
+
71
+ Example:
72
+
73
+ ```python
74
+ >>> from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
75
+ >>> from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig
76
+
77
+ >>> # Initializing a Idefics2VisionConfig with google/siglip-base-patch16-224 style configuration
78
+ >>> configuration = Idefics2VisionConfig()
79
+
80
+ >>> # Initializing a Idefics2VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
81
+ >>> model = Idefics2VisionTransformer(configuration)
82
+
83
+ >>> # Accessing the model configuration
84
+ >>> configuration = model.config
85
+ ```"""
86
+ _auto_class = 'AutoConfig'
87
+ model_type = "Idefics2VisionConfig"
88
+
89
+ def __init__(
90
+ self,
91
+ hidden_size=768,
92
+ intermediate_size=3072,
93
+ num_hidden_layers=12,
94
+ num_attention_heads=12,
95
+ num_channels=3,
96
+ image_size=224,
97
+ patch_size=32,
98
+ hidden_act="gelu_pytorch_tanh",
99
+ layer_norm_eps=1e-6,
100
+ attention_dropout=0.0,
101
+ initializer_range=0.02,
102
+ model_type='Idefics2VisionConfig',
103
+ **kwargs,
104
+ ):
105
+ super().__init__(**kwargs)
106
+
107
+ self.hidden_size = hidden_size
108
+ self.intermediate_size = intermediate_size
109
+ self.num_hidden_layers = num_hidden_layers
110
+ self.num_attention_heads = num_attention_heads
111
+ self.num_channels = num_channels
112
+ self.patch_size = patch_size
113
+ self.image_size = image_size
114
+ self.attention_dropout = attention_dropout
115
+ self.layer_norm_eps = layer_norm_eps
116
+ self.hidden_act = hidden_act
117
+ self.initializer_range = initializer_range
118
+ """
119
+ @classmethod
120
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
121
+
122
+ with open(pretrained_model_name_or_path, "r", encoding="utf-8") as f:
123
+ config_dict = json.load(f)
124
+
125
+ cls = Idefics2VisionConfig(
126
+ hidden_size=config_dict["hidden_size"],
127
+ image_size=config_dict["image_size"],
128
+ intermediate_size = config_dict["intermediate_size"],
129
+ model_type=config_dict["model_type"],
130
+ num_attention_heads = config_dict["num_attention_heads"],
131
+ num_hidden_layers = config_dict["num_hidden_layers"],
132
+ patch_size = config_dict["patch_size"]
133
+ )
134
+
135
+ return cls
136
+ """
137
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
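+ # Given a (batch, seq_len) padding mask of 1s (real tokens) and 0s (padding), this helper returns
+ # the flat indices of the non-padding positions, the zero-prefixed cumulative sequence lengths
+ # expected by flash_attn_varlen_func, and the longest unpadded sequence length in the batch.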
138
+ def _get_unpad_data(attention_mask):
139
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
140
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
141
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
142
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
143
+ return (
144
+ indices,
145
+ cu_seqlens,
146
+ max_seqlen_in_batch,
147
+ )
148
+
149
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics2Vision
150
+ class Idefics2VisionAttention(nn.Module):
151
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
152
+
153
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
154
+ def __init__(self, config):
155
+ super().__init__()
156
+ self.config = config
157
+ self.embed_dim = config.hidden_size
158
+ self.num_heads = config.num_attention_heads
159
+ self.head_dim = self.embed_dim // self.num_heads
160
+ if self.head_dim * self.num_heads != self.embed_dim:
161
+ raise ValueError(
162
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
163
+ f" {self.num_heads})."
164
+ )
165
+ self.scale = self.head_dim**-0.5
166
+ self.dropout = config.attention_dropout
167
+
168
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
169
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
170
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
171
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
172
+
173
+ # Ignore copy
174
+ self.is_causal = False
175
+
176
+ def forward(
177
+ self,
178
+ hidden_states: torch.Tensor,
179
+ attention_mask: Optional[torch.Tensor] = None,
180
+ output_attentions: Optional[bool] = False,
181
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
182
+ """Input shape: Batch x Time x Channel"""
183
+
184
+ batch_size, q_len, _ = hidden_states.size()
185
+
186
+ query_states = self.q_proj(hidden_states)
187
+ key_states = self.k_proj(hidden_states)
188
+ value_states = self.v_proj(hidden_states)
189
+
190
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
191
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
192
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
193
+
194
+ k_v_seq_len = key_states.shape[-2]
195
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
196
+
197
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
198
+ raise ValueError(
199
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
200
+ f" {attn_weights.size()}"
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
205
+ raise ValueError(
206
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
207
+ )
208
+ attn_weights = attn_weights + attention_mask
209
+
210
+ # upcast attention to fp32
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
212
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
213
+ attn_output = torch.matmul(attn_weights, value_states)
214
+
215
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
216
+ raise ValueError(
217
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
218
+ f" {attn_output.size()}"
219
+ )
220
+
221
+ attn_output = attn_output.transpose(1, 2).contiguous()
222
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
223
+
224
+ attn_output = self.out_proj(attn_output)
225
+
226
+ return attn_output, attn_weights
227
+
228
+
229
+ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
230
+ """
231
+ Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention`, as the weights of the module stay
232
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
233
+ flash attention and deal with padding tokens in case the input contains any.
234
+ """
235
+
236
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
237
+ def __init__(self, *args, **kwargs):
238
+ super().__init__(*args, **kwargs)
239
+
240
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
241
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
242
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
243
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
244
+
245
+ def forward(
246
+ self,
247
+ hidden_states: torch.Tensor,
248
+ attention_mask: Optional[torch.LongTensor] = None,
249
+ position_ids: Optional[torch.LongTensor] = None,
250
+ past_key_value: Optional[Cache] = None,
251
+ output_attentions: bool = False,
252
+ use_cache: bool = False,
253
+ **kwargs,
254
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
255
+
256
+
257
+ output_attentions = False
258
+
259
+ bsz, q_len, _ = hidden_states.size()
260
+
261
+ query_states = self.q_proj(hidden_states)
262
+ key_states = self.k_proj(hidden_states)
263
+ value_states = self.v_proj(hidden_states)
264
+
265
+ # Flash attention requires the input to have the shape
266
+ # batch_size x seq_length x num_heads x head_dim
267
+ # therefore we just need to keep the original shape
268
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
269
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
270
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
271
+
272
+ kv_seq_len = key_states.shape[-2]
273
+ if past_key_value is not None:
274
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
275
+
276
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
277
+ # to be able to avoid many of these transpose/reshape/view.
278
+ query_states = query_states.transpose(1, 2)
279
+ key_states = key_states.transpose(1, 2)
280
+ value_states = value_states.transpose(1, 2)
281
+
282
+ dropout_rate = self.dropout if self.training else 0.0
283
+
284
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
285
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
286
+ # cast them back to the correct dtype just to be sure everything works as expected.
287
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
288
+ # in fp32. (Idefics2VisionRMSNorm handles it correctly)
289
+
290
+ input_dtype = query_states.dtype
291
+ if input_dtype == torch.float32:
292
+ if torch.is_autocast_enabled():
293
+ target_dtype = torch.get_autocast_gpu_dtype()
294
+ # Handle the case where the model is quantized
295
+ elif hasattr(self.config, "_pre_quantization_dtype"):
296
+ target_dtype = self.config._pre_quantization_dtype
297
+ else:
298
+ target_dtype = self.q_proj.weight.dtype
299
+
300
+ logger.warning_once(
301
+ f"The input hidden states seem to have been silently cast to float32; this might be because"
302
+ f" you have upcast embedding or layer norm layers to float32. We will cast the input back to"
303
+ f" {target_dtype}."
304
+ )
305
+
306
+ query_states = query_states.to(target_dtype)
307
+ key_states = key_states.to(target_dtype)
308
+ value_states = value_states.to(target_dtype)
309
+
310
+ attn_output = self._flash_attention_forward(
311
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
312
+ )
313
+
314
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
315
+ attn_output = self.out_proj(attn_output)
316
+
317
+ if not output_attentions:
318
+ attn_weights = None
319
+
320
+ return attn_output, attn_weights
321
+
322
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
323
+ def _flash_attention_forward(
324
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
325
+ ):
326
+ """
327
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
328
+ first unpad the input, then compute the attention scores, and finally pad the attention scores back.
329
+
330
+ Args:
331
+ query_states (`torch.Tensor`):
332
+ Input query states to be passed to Flash Attention API
333
+ key_states (`torch.Tensor`):
334
+ Input key states to be passed to Flash Attention API
335
+ value_states (`torch.Tensor`):
336
+ Input value states to be passed to Flash Attention API
337
+ attention_mask (`torch.Tensor`):
338
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
339
+ position of padding tokens and 1 for the position of non-padding tokens.
340
+ dropout (`float`):
341
+ Attention dropout
342
+ softmax_scale (`float`, *optional*):
343
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
344
+ """
345
+ if not self._flash_attn_uses_top_left_mask:
346
+ causal = self.is_causal
347
+ else:
348
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
349
+ causal = self.is_causal and query_length != 1
350
+
351
+ # Contains at least one padding token in the sequence
352
+ if attention_mask is not None:
353
+ batch_size = query_states.shape[0]
354
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
355
+ query_states, key_states, value_states, attention_mask, query_length
356
+ )
357
+
358
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
359
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
360
+
361
+ attn_output_unpad = flash_attn_varlen_func(
362
+ query_states,
363
+ key_states,
364
+ value_states,
365
+ cu_seqlens_q=cu_seqlens_q,
366
+ cu_seqlens_k=cu_seqlens_k,
367
+ max_seqlen_q=max_seqlen_in_batch_q,
368
+ max_seqlen_k=max_seqlen_in_batch_k,
369
+ dropout_p=dropout,
370
+ softmax_scale=softmax_scale,
371
+ causal=causal,
372
+ )
373
+
374
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
375
+ else:
376
+ attn_output = flash_attn_func(
377
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
378
+ )
379
+
380
+ return attn_output
381
+
382
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
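+ # Strips padded positions from q/k/v so that flash_attn_varlen_func only attends over real tokens.
+ # Three cases are handled: queries as long as the keys, single-token (decoding-style) queries, and
+ # left-padded shorter queries.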
383
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
384
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
385
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
386
+
387
+ key_layer = index_first_axis(
388
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
389
+ )
390
+ value_layer = index_first_axis(
391
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
392
+ )
393
+ if query_length == kv_seq_len:
394
+ query_layer = index_first_axis(
395
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
396
+ )
397
+ cu_seqlens_q = cu_seqlens_k
398
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
399
+ indices_q = indices_k
400
+ elif query_length == 1:
401
+ max_seqlen_in_batch_q = 1
402
+ cu_seqlens_q = torch.arange(
403
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
404
+ ) # There is a memcpy here, that is very bad.
405
+ indices_q = cu_seqlens_q[:-1]
406
+ query_layer = query_layer.squeeze(1)
407
+ else:
408
+ # The -q_len: slice assumes left padding.
409
+ attention_mask = attention_mask[:, -query_length:]
410
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
411
+
412
+ return (
413
+ query_layer,
414
+ key_layer,
415
+ value_layer,
416
+ indices_q,
417
+ (cu_seqlens_q, cu_seqlens_k),
418
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
419
+ )
420
+
421
+ IDEFICS_VISION_ATTENTION_CLASSES = {
422
+ "eager": Idefics2VisionAttention,
423
+ "flash_attention_2": Idefics2VisionFlashAttention2,
424
+ }
425
+
426
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics2Vision
427
+ class Idefics2VisionMLP(nn.Module):
428
+ def __init__(self, config):
429
+ super().__init__()
430
+ self.config = config
431
+ self.activation_fn = ACT2FN[config.hidden_act]
432
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
433
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
434
+
435
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
436
+ hidden_states = self.fc1(hidden_states)
437
+ hidden_states = self.activation_fn(hidden_states)
438
+ hidden_states = self.fc2(hidden_states)
439
+ return hidden_states
440
+
441
+ class Idefics2EncoderLayer(nn.Module):
442
+ def __init__(self, config: Idefics2VisionConfig):
443
+ super().__init__()
444
+ self.embed_dim = config.hidden_size
445
+ self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
446
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
447
+ self.mlp = Idefics2VisionMLP(config)
448
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
449
+
450
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
451
+ def forward(
452
+ self,
453
+ hidden_states: torch.Tensor,
454
+ attention_mask: torch.Tensor,
455
+ output_attentions: Optional[bool] = False,
456
+ ) -> Tuple[torch.FloatTensor]:
457
+ """
458
+ Args:
459
+ hidden_states (`torch.FloatTensor`):
460
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
461
+ attention_mask (`torch.FloatTensor`):
462
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
463
+ output_attentions (`bool`, *optional*, defaults to `False`):
464
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
465
+ returned tensors for more detail.
466
+ """
467
+ residual = hidden_states
468
+
469
+ hidden_states = self.layer_norm1(hidden_states)
470
+ hidden_states, attn_weights = self.self_attn(
471
+ hidden_states=hidden_states,
472
+ attention_mask=attention_mask,
473
+ output_attentions=output_attentions,
474
+ )
475
+ hidden_states = residual + hidden_states
476
+
477
+ residual = hidden_states
478
+ hidden_states = self.layer_norm2(hidden_states)
479
+ hidden_states = self.mlp(hidden_states)
480
+ hidden_states = residual + hidden_states
481
+
482
+ outputs = (hidden_states,)
483
+
484
+ if output_attentions:
485
+ outputs += (attn_weights,)
486
+
487
+ return outputs
488
+
489
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2
490
+ class Idefics2Encoder(nn.Module):
491
+ """
492
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
493
+ [`Idefics2EncoderLayer`].
494
+
495
+ Args:
496
+ config: Idefics2VisionConfig
497
+ """
498
+
499
+ def __init__(self, config: Idefics2VisionConfig):
500
+ super().__init__()
501
+ self.config = config
502
+ self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
503
+ self.gradient_checkpointing = False
504
+
505
+ # Ignore copy
506
+ def forward(
507
+ self,
508
+ inputs_embeds,
509
+ attention_mask: Optional[torch.Tensor] = None,
510
+ output_attentions: Optional[bool] = None,
511
+ output_hidden_states: Optional[bool] = None,
512
+ return_dict: Optional[bool] = None,
513
+ ) -> Union[Tuple, BaseModelOutput]:
514
+ r"""
515
+ Args:
516
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
517
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
518
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
519
+ than the model's internal embedding lookup matrix.
520
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
521
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
522
+
523
+ - 1 for tokens that are **not masked**,
524
+ - 0 for tokens that are **masked**.
525
+
526
+ [What are attention masks?](../glossary#attention-mask)
527
+ output_attentions (`bool`, *optional*):
528
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
529
+ returned tensors for more detail.
530
+ output_hidden_states (`bool`, *optional*):
531
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
532
+ for more detail.
533
+ return_dict (`bool`, *optional*):
534
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
535
+ """
536
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
537
+ output_hidden_states = (
538
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
539
+ )
540
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
541
+
542
+ encoder_states = () if output_hidden_states else None
543
+ all_attentions = () if output_attentions else None
544
+
545
+ hidden_states = inputs_embeds
546
+ for encoder_layer in self.layers:
547
+ if output_hidden_states:
548
+ encoder_states = encoder_states + (hidden_states,)
549
+ if self.gradient_checkpointing and self.training:
550
+ layer_outputs = self._gradient_checkpointing_func(
551
+ encoder_layer.__call__,
552
+ hidden_states,
553
+ attention_mask,
554
+ output_attentions,
555
+ )
556
+ else:
557
+ layer_outputs = encoder_layer(
558
+ hidden_states,
559
+ attention_mask,
560
+ output_attentions=output_attentions,
561
+ )
562
+
563
+ hidden_states = layer_outputs[0]
564
+
565
+ if output_attentions:
566
+ all_attentions = all_attentions + (layer_outputs[1],)
567
+
568
+ if output_hidden_states:
569
+ encoder_states = encoder_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
573
+ return BaseModelOutput(
574
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
575
+ )
576
+
577
+ class Idefics2VisionEmbeddings(nn.Module):
578
+ """
579
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
580
+ resolution.
581
+
582
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
583
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
584
+ fixed size. In particular, we start from the original pre-trained SigLIP model
585
+ (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
586
+ """
587
+
588
+ def __init__(self, config: Idefics2VisionConfig):
589
+ super().__init__()
590
+ self.embed_dim = config.hidden_size
591
+ self.image_size = config.image_size
592
+ self.patch_size = config.patch_size
593
+
594
+ self.patch_embedding = nn.Conv2d(
595
+ in_channels=config.num_channels,
596
+ out_channels=self.embed_dim,
597
+ kernel_size=self.patch_size,
598
+ stride=self.patch_size,
599
+ padding="valid",
600
+ )
601
+
602
+ self.num_patches_per_side = self.image_size // self.patch_size
603
+ self.num_patches = self.num_patches_per_side**2
604
+ self.num_positions = self.num_patches
605
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
606
+
607
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
608
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
609
+
610
+ patch_embeds = self.patch_embedding(pixel_values)
611
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
612
+
613
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
614
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
615
+ position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
616
+
617
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
618
+ nb_patches_h = p_attn_mask[:, 0].sum()
619
+ nb_patches_w = p_attn_mask[0].sum()
620
+
621
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
622
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
623
+
624
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
625
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
626
+
627
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
628
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
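+ # The valid nb_patches_h x nb_patches_w region of each image is mapped onto the fixed
+ # num_patches_per_side grid by bucketizing its fractional row/column coordinates, so the
+ # position ids span the full grid regardless of the image's aspect ratio or resolution.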
629
+
630
+ position_ids = position_ids.to(self.position_embedding.weight.device)
631
+ embeddings = embeddings + self.position_embedding(position_ids)
632
+ return embeddings
633
+
634
+
635
+ class Idefics2VisionTransformer(PreTrainedModel):
636
+ _auto_class = 'AutoModel'
637
+ config_class = Idefics2VisionConfig
638
+ supports_gradient_checkpointing = True
639
+
640
+ def __init__(self, config: Idefics2VisionConfig):
641
+ super().__init__(config)
642
+ embed_dim = config.hidden_size
643
+
644
+ config._attn_implementation = "flash_attention_2"
645
+ self._use_flash_attention_2 = True
646
+ self.config = config
647
+ self.embeddings = Idefics2VisionEmbeddings(config)
648
+ self.encoder = Idefics2Encoder(config)
649
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
650
+
651
+
652
+ def get_input_embeddings(self):
653
+ return self.embeddings
654
+
655
+ def set_input_embeddings(self, value):
656
+ self.embeddings = value
657
+
658
+ def forward(
659
+ self,
660
+ pixel_values,
661
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
662
+ output_attentions: Optional[bool] = None,
663
+ output_hidden_states: Optional[bool] = None,
664
+ return_dict: Optional[bool] = None,
665
+ ) -> Union[Tuple, BaseModelOutput]:
666
+
667
+ pixel_values = pixel_values.to(torch.bfloat16)
668
+
669
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
670
+ output_hidden_states = (
671
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
672
+ )
673
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
674
+
675
+ batch_size = pixel_values.size(0)
676
+ if patch_attention_mask is None:
677
+ patch_size = self.config.patch_size
678
+ patch_attention_mask = torch.ones(
679
+ (
680
+ batch_size,
681
+ pixel_values.size(2) // patch_size,
682
+ pixel_values.size(3) // patch_size,
683
+ )
684
+ )
685
+ patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
686
+
687
+
688
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
689
+
690
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
691
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
692
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
693
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
694
+ if not torch.any(~patch_attention_mask):
695
+ patch_attention_mask = None
696
+ elif not self._use_flash_attention_2:
697
+ patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
698
+
699
+ encoder_outputs = self.encoder(
700
+ inputs_embeds=hidden_states,
701
+ attention_mask=patch_attention_mask,
702
+ output_attentions=output_attentions,
703
+ output_hidden_states=output_hidden_states,
704
+ return_dict=return_dict,
705
+ )
706
+
707
+ last_hidden_state = encoder_outputs[0]
708
+ last_hidden_state = self.post_layernorm(last_hidden_state)
709
+
710
+ if not return_dict:
711
+ return (last_hidden_state,) + encoder_outputs[1:]
712
+
713
+ return BaseModelOutput(
714
+ last_hidden_state=last_hidden_state,
715
+ hidden_states=encoder_outputs.hidden_states,
716
+ attentions=encoder_outputs.attentions,
717
+ )
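As a quick sanity check of the vision tower added in this file, the sketch below instantiates it with the default configuration and runs a random image through it. This is only an illustration, not part of the commit: it assumes the script is run from a local checkout of this repository so that `vision_model` is importable, and it requires a CUDA device with the `flash-attn` package installed, since the forward pass casts pixel values to bfloat16 and forces `flash_attention_2`.

```python
import torch

from vision_model import Idefics2VisionConfig, Idefics2VisionTransformer

# Default config: 224x224 images with 32x32 patches -> (224/32)^2 = 49 patch tokens of size 768.
config = Idefics2VisionConfig()
model = Idefics2VisionTransformer(config).to("cuda", dtype=torch.bfloat16).eval()

pixel_values = torch.rand(1, 3, 224, 224, device="cuda")  # one random RGB image
with torch.no_grad():
    outputs = model(pixel_values)  # patch_attention_mask defaults to all-ones

print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 49, 768])
```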