GoodiesHere committed on
Commit
4b87979
1 Parent(s): 3e782e5

Upload folder using huggingface_hub

.mdl ADDED
Binary file (50 Bytes). View file
 
.msc ADDED
Binary file (2.44 kB). View file
 
.mv ADDED
@@ -0,0 +1 @@
Revision:master,CreatedAt:1734398114
README.md ADDED
@@ -0,0 +1,128 @@
---
license: apache-2.0
license_name: qwen
license_link: https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct/blob/main/LICENSE
tags:
- video
- video-understanding
- vision
- multimodal
- conversational
- qwen
- custom_code
- instruction-tuning
datasets:
- ApolloBench
- Video-MME
- MLVU
- LongVideoBench
- NExTQA
- PerceptionTest
inference: true
---

# Apollo: An Exploration of Video Understanding in Large Multimodal Models

Apollo is a family of Large Multimodal Models (LMMs) that pushes the state of the art in video understanding. It supports tasks including:
- Long-form video comprehension
- Temporal reasoning
- Complex video question answering
- Multi-turn conversations grounded in video content

Apollo models excel at handling hour-long videos, balancing speed and accuracy through strategic design decisions. At just 3B parameters, our models outperform most 7B competitors and even rival 30B-scale models.

**Key Highlights:**
- **Scaling Consistency**: Design decisions validated on smaller models and datasets transfer effectively to larger scales, reducing computation and experimentation costs.
- **Efficient Video Sampling**: fps sampling and advanced token resampling strategies (e.g., Perceiver) yield stronger temporal perception (see the sketch below).
- **Encoder Synergies**: Combining SigLIP-SO400M (image) with InternVideo2 (video) delivers a robust representation, outperforming single encoders on temporal tasks.
- **ApolloBench**: A streamlined evaluation benchmark (41x faster) that focuses on true video understanding capabilities.

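To make the fps-sampling idea above concrete, here is a minimal illustrative sketch (not the Apollo codebase's actual sampler; the function name and defaults are placeholders): sampling at a fixed frame rate means the number of kept frames grows with video duration instead of being a fixed count.

```python
# Illustrative only -- not Apollo's implementation. Uniform fps-based frame selection:
# the number of sampled frames scales with video duration rather than being fixed.
def fps_sample_indices(num_frames: int, native_fps: float, target_fps: float = 2.0) -> list:
    duration_s = num_frames / native_fps
    num_samples = max(1, round(duration_s * target_fps))
    step = num_frames / num_samples
    return [min(num_frames - 1, int(i * step)) for i in range(num_samples)]

print(len(fps_sample_indices(720, 24.0)))  # 30 s of 24 fps video -> 60 frames
print(len(fps_sample_indices(48, 24.0)))   # 2 s clip -> 4 frames
```
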
## Quick Start

**Installation:**
```bash
pip install -e .
pip install flash-attn --no-build-isolation
```

**Inference Example:**
```python
import torch
from transformers import AutoModelForCausalLM
from apollo.mm_utils import (
    KeywordsStoppingCriteria,
    tokenizer_mm_token,
    ApolloMMLoader
)
from apollo.conversations import conv_templates, SeparatorStyle
from huggingface_hub import snapshot_download

model_url = "Apollo-LMMs/Apollo-3B-t32"
model_path = snapshot_download(model_url, repo_type="model")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True
).to(device=device, dtype=torch.bfloat16)

tokenizer = model.tokenizer
vision_processors = model.vision_tower.vision_processor
config = model.config
num_repeat_token = config.mm_connector_cfg['num_output_tokens']
mm_processor = ApolloMMLoader(
    vision_processors,
    config.clip_duration,
    frames_per_clip=4,
    clip_sampling_ratio=0.65,
    model_max_length=config.model_max_length,
    device=device,
    num_repeat_token=num_repeat_token
)

video_path = "path/to/video.mp4"
question = "Describe this video in detail"
mm_data, replace_string = mm_processor.load_video(video_path)

conv = conv_templates["qwen_2"].copy()
conv.append_message(conv.roles[0], replace_string + "\n\n" + question)
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()
input_ids = tokenizer_mm_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0).to(device)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        vision_input=[mm_data],
        data_types=['video'],
        do_sample=True,
        temperature=0.4,
        max_new_tokens=256,
        top_p=0.7,
        use_cache=True,
        num_beams=1,
        stopping_criteria=[stopping_criteria]
    )

pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(pred)
```
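
As a rough back-of-the-envelope (an estimate based on the parameters above and this repo's `config.json`, not an exact account of `ApolloMMLoader`'s rounding or capping), each sampled clip contributes `num_output_tokens = 128` visual tokens after the connector, so the visual token count grows with video length until the loader's `model_max_length` budget bounds it:

```python
# Rough estimate only; the real loader may round, cap, or resample differently.
def estimate_visual_tokens(video_seconds, clip_duration=2.0, clip_sampling_ratio=0.65,
                           tokens_per_clip=128):
    total_clips = video_seconds / clip_duration
    sampled_clips = max(1, int(total_clips * clip_sampling_ratio))
    return sampled_clips * tokens_per_clip

print(estimate_visual_tokens(60))    # ~2432 tokens for a one-minute video
print(estimate_visual_tokens(3600))  # ~149760 before any cap -- hence model_max_length is passed to the loader
```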

## Citation

If you find this project useful, please consider citing:
```BibTeX
@article{zohar2024apollo,
  title={Apollo: An Exploration of Video Understanding in Large Multimodal Models},
  author={Zohar, Orr and Wang, Xiaohan and Dubois, Yann and Mehta, Nikhil and Xiao, Tong and Hansen-Estruch, Philippe and Yu, Licheng and Wang, Xiaofang and Juefei-Xu, Felix and Zhang, Ning and Yeung-Levy, Serena and Xia, Xide},
  journal={arXiv preprint arXiv:2412.10360},
  year={2024}
}
```

For more details, visit the [project website](https://apollo-lmms.github.io) or check out the [paper](https://arxiv.org/abs/2412.10360).
config.json ADDED
@@ -0,0 +1,293 @@
1
+ {
2
+ "architectures": [
3
+ "ApolloForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_implementation": "flash_attention_2",
7
+ "clip_duration": 2,
8
+ "drop_path_rate": 0.0,
9
+ "encode_batch_size": 25,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 1536,
12
+ "image_aspect_ratio": "square",
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 8960,
15
+ "interpolate_mode": "linear",
16
+ "llm_cfg": {
17
+ "add_cross_attention": false,
18
+ "architectures": [
19
+ "Qwen2ForCausalLM"
20
+ ],
21
+ "attention_dropout": 0.0,
22
+ "bad_words_ids": null,
23
+ "begin_suppress_tokens": null,
24
+ "bos_token_id": 151643,
25
+ "chunk_size_feed_forward": 0,
26
+ "cross_attention_hidden_size": null,
27
+ "decoder_start_token_id": null,
28
+ "diversity_penalty": 0.0,
29
+ "do_sample": false,
30
+ "early_stopping": false,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "eos_token_id": 151645,
33
+ "exponential_decay_length_penalty": null,
34
+ "finetuning_task": null,
35
+ "forced_bos_token_id": null,
36
+ "forced_eos_token_id": null,
37
+ "hidden_act": "silu",
38
+ "hidden_size": 1536,
39
+ "id2label": {
40
+ "0": "LABEL_0",
41
+ "1": "LABEL_1"
42
+ },
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 8960,
45
+ "is_decoder": false,
46
+ "is_encoder_decoder": false,
47
+ "label2id": {
48
+ "LABEL_0": 0,
49
+ "LABEL_1": 1
50
+ },
51
+ "length_penalty": 1.0,
52
+ "max_length": 20,
53
+ "max_position_embeddings": 32768,
54
+ "max_window_layers": 21,
55
+ "min_length": 0,
56
+ "model_max_length": 16384,
57
+ "model_type": "qwen2",
58
+ "no_repeat_ngram_size": 0,
59
+ "num_attention_heads": 12,
60
+ "num_beam_groups": 1,
61
+ "num_beams": 1,
62
+ "num_hidden_layers": 28,
63
+ "num_key_value_heads": 2,
64
+ "num_return_sequences": 1,
65
+ "output_attentions": false,
66
+ "output_hidden_states": false,
67
+ "output_scores": false,
68
+ "pad_token_id": null,
69
+ "prefix": null,
70
+ "problem_type": null,
71
+ "pruned_heads": {},
72
+ "remove_invalid_values": false,
73
+ "repetition_penalty": 1.0,
74
+ "return_dict": true,
75
+ "return_dict_in_generate": false,
76
+ "rms_norm_eps": 1e-06,
77
+ "rope_theta": 1000000.0,
78
+ "sep_token_id": null,
79
+ "sliding_window": null,
80
+ "suppress_tokens": null,
81
+ "task_specific_params": null,
82
+ "temperature": 1.0,
83
+ "tf_legacy_loss": false,
84
+ "tie_encoder_decoder": false,
85
+ "tie_word_embeddings": true,
86
+ "tokenizer_class": null,
87
+ "tokenizer_model_max_length": 16384,
88
+ "tokenizer_padding_side": "right",
89
+ "top_k": 50,
90
+ "top_p": 1.0,
91
+ "torch_dtype": "bfloat16",
92
+ "torchscript": false,
93
+ "typical_p": 1.0,
94
+ "use_bfloat16": false,
95
+ "use_cache": true,
96
+ "use_sliding_window": false,
97
+ "vocab_size": 151936
98
+ },
99
+ "max_position_embeddings": 32768,
100
+ "max_window_layers": 21,
101
+ "mm_connector_cfg": {
102
+ "add_cross_attention": false,
103
+ "architectures": [
104
+ "Connector"
105
+ ],
106
+ "attention_dropout": 0.0,
107
+ "bad_words_ids": null,
108
+ "begin_suppress_tokens": null,
109
+ "bos_token_id": null,
110
+ "chunk_size_feed_forward": 0,
111
+ "cross_attention_hidden_size": null,
112
+ "decoder_start_token_id": null,
113
+ "diversity_penalty": 0.0,
114
+ "do_sample": false,
115
+ "early_stopping": false,
116
+ "encoder_no_repeat_ngram_size": 0,
117
+ "eos_token_id": null,
118
+ "exponential_decay_length_penalty": null,
119
+ "ff_multi": 4,
120
+ "finetuning_task": null,
121
+ "forced_bos_token_id": null,
122
+ "forced_eos_token_id": null,
123
+ "hidden_act": "silu",
124
+ "id2label": {
125
+ "0": "LABEL_0",
126
+ "1": "LABEL_1"
127
+ },
128
+ "is_decoder": false,
129
+ "is_encoder_decoder": false,
130
+ "label2id": {
131
+ "LABEL_0": 0,
132
+ "LABEL_1": 1
133
+ },
134
+ "length_penalty": 1.0,
135
+ "max_length": 20,
136
+ "min_length": 0,
137
+ "model_type": "mm_connector",
138
+ "no_repeat_ngram_size": 0,
139
+ "num_beam_groups": 1,
140
+ "num_beams": 1,
141
+ "num_key_value_heads": 4,
142
+ "num_output_tokens": 128,
143
+ "num_patches": 24,
144
+ "num_return_sequences": 1,
145
+ "output_attentions": false,
146
+ "output_hidden_states": false,
147
+ "output_scores": false,
148
+ "pad_token_id": null,
149
+ "prefix": null,
150
+ "problem_type": null,
151
+ "projector_type": "mlp1x_gelu",
152
+ "pruned_heads": {},
153
+ "remove_invalid_values": false,
154
+ "repetition_penalty": 1.0,
155
+ "resampler_depth": 1,
156
+ "resampler_head_dim": 96,
157
+ "resampler_n_heads": 16,
158
+ "resampler_type": "perciver",
159
+ "return_dict": true,
160
+ "return_dict_in_generate": false,
161
+ "rms_norm_eps": 1e-06,
162
+ "sep_token_id": null,
163
+ "suppress_tokens": null,
164
+ "task_specific_params": null,
165
+ "temperature": 1.0,
166
+ "text_hidden_size": 1536,
167
+ "tf_legacy_loss": false,
168
+ "tie_encoder_decoder": false,
169
+ "tie_word_embeddings": true,
170
+ "token_input_shape": [
171
+ 4,
172
+ 27,
173
+ 27
174
+ ],
175
+ "tokenizer_class": null,
176
+ "top_k": 50,
177
+ "top_p": 1.0,
178
+ "torch_dtype": "bfloat16",
179
+ "torchscript": false,
180
+ "typical_p": 1.0,
181
+ "use_bfloat16": false,
182
+ "vision_hidden_size": 2560
183
+ },
184
+ "mm_connector_lr": 0.0001,
185
+ "mm_hidden_size": null,
186
+ "mm_vision_select_feature": "patch",
187
+ "mm_vision_select_layer": -2,
188
+ "model_dtype": "torch.bfloat16",
189
+ "model_type": "apollo",
190
+ "num_attention_heads": 12,
191
+ "num_encode_batch": 0,
192
+ "num_hidden_layers": 28,
193
+ "num_key_value_heads": 2,
194
+ "num_video_frames": null,
195
+ "resume_path": "./work_dirs/final_run/apollo-Qwen2.5-1.5B-Instruct-internvideo2-siglip-so400m-patch14-384-freeze-perciver_128_2-newprompt-ft",
196
+ "rms_norm_eps": 1e-06,
197
+ "rope_scaling": null,
198
+ "rope_theta": 1000000.0,
199
+ "s2": false,
200
+ "s2_max_split_size": 336,
201
+ "s2_scales": "336,672,1008",
202
+ "sliding_window": null,
203
+ "temporal_prompt": true,
204
+ "timestamp_prompt": true,
205
+ "transformers_version": "4.44.0",
206
+ "tune_language_model": true,
207
+ "tune_mm_connector": true,
208
+ "tune_vision_tower": false,
209
+ "use_cache": true,
210
+ "use_mm_patch_token": false,
211
+ "use_mm_start_end": false,
212
+ "use_sliding_window": false,
213
+ "vision_resolution": -1,
214
+ "vision_tower_cfg": {
215
+ "add_cross_attention": false,
216
+ "architectures": null,
217
+ "bad_words_ids": null,
218
+ "begin_suppress_tokens": null,
219
+ "bos_token_id": null,
220
+ "chunk_size_feed_forward": 0,
221
+ "configs": {},
222
+ "cross_attention_hidden_size": null,
223
+ "decoder_start_token_id": null,
224
+ "diversity_penalty": 0.0,
225
+ "do_sample": false,
226
+ "early_stopping": false,
227
+ "encoder_no_repeat_ngram_size": 0,
228
+ "eos_token_id": null,
229
+ "exponential_decay_length_penalty": null,
230
+ "finetuning_task": null,
231
+ "forced_bos_token_id": null,
232
+ "forced_eos_token_id": null,
233
+ "id2label": {
234
+ "0": "LABEL_0",
235
+ "1": "LABEL_1"
236
+ },
237
+ "is_decoder": false,
238
+ "is_encoder_decoder": false,
239
+ "label2id": {
240
+ "LABEL_0": 0,
241
+ "LABEL_1": 1
242
+ },
243
+ "length_penalty": 1.0,
244
+ "max_length": 20,
245
+ "min_length": 0,
246
+ "model_type": "hybrid_vision_tower",
247
+ "no_repeat_ngram_size": 0,
248
+ "num_beam_groups": 1,
249
+ "num_beams": 1,
250
+ "num_return_sequences": 1,
251
+ "num_vision_encoders": 2,
252
+ "output_attentions": false,
253
+ "output_hidden_states": false,
254
+ "output_scores": false,
255
+ "pad_token_id": null,
256
+ "prefix": null,
257
+ "problem_type": null,
258
+ "pruned_heads": {},
259
+ "remove_invalid_values": false,
260
+ "repetition_penalty": 1.0,
261
+ "return_dict": true,
262
+ "return_dict_in_generate": false,
263
+ "sep_token_id": null,
264
+ "suppress_tokens": null,
265
+ "task_specific_params": null,
266
+ "temperature": 1.0,
267
+ "tf_legacy_loss": false,
268
+ "tie_encoder_decoder": false,
269
+ "tie_word_embeddings": true,
270
+ "token_output_shape": [
271
+ 4,
272
+ 27,
273
+ 27
274
+ ],
275
+ "tokenizer_class": null,
276
+ "top_k": 50,
277
+ "top_p": 1.0,
278
+ "torch_dtype": null,
279
+ "torchscript": false,
280
+ "typical_p": 1.0,
281
+ "use_bfloat16": false,
282
+ "vision_towers": [
283
+ "siglip-so400m-patch14-384",
284
+ "internvideo2"
285
+ ]
286
+ },
287
+ "vocab_size": 151936,
288
+ "auto_map": {
289
+ "AutoConfig": "configuration_apollo.ApolloConfig",
290
+ "AutoModelForCausalLM": "modeling_apollo.ApolloForCausalLM"
291
+ },
292
+ "model_max_length": 16384
293
+ }
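For orientation, the `auto_map` entry above is what lets `transformers` resolve this custom architecture with `trust_remote_code=True`. A small sketch for inspecting the nested configuration without loading weights (the repo id is taken from the Quick Start example; the printed values are simply what the JSON above declares):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Apollo-LMMs/Apollo-3B-t32", trust_remote_code=True)
print(cfg.model_type)                             # "apollo"
print(cfg.llm_cfg["model_type"])                  # "qwen2"
print(cfg.vision_tower_cfg["vision_towers"])      # ["siglip-so400m-patch14-384", "internvideo2"]
print(cfg.mm_connector_cfg["num_output_tokens"])  # 128 visual tokens per clip after resampling
```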
configuration.json ADDED
@@ -0,0 +1 @@
{"framework": "pytorch", "task": "others", "allow_remote": true}
configuration_apollo.py ADDED
@@ -0,0 +1,47 @@
from transformers import PretrainedConfig


class ApolloConfig(PretrainedConfig):
    model_type = "apollo"

    def __init__(
        self,
        llm_cfg=None,
        vision_tower_cfg=None,
        mm_connector_cfg=None,
        architectures=None,
        resume_path=None,
        image_aspect_ratio=None,
        num_video_frames=None,
        mm_vision_select_layer=None,
        mm_vision_select_feature=None,
        use_mm_start_end=False,
        use_mm_patch_token=True,
        mm_connector_lr=None,
        vision_resolution=None,
        interpolate_mode=None,
        clip_duration=None,
        vocab_size=None,
        auto_map=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.architectures = architectures
        self.llm_cfg = llm_cfg
        self.vision_tower_cfg = vision_tower_cfg
        self.mm_connector_cfg = mm_connector_cfg
        self.resume_path = resume_path
        self.image_aspect_ratio = image_aspect_ratio
        self.num_video_frames = num_video_frames
        self.mm_vision_select_layer = mm_vision_select_layer
        self.mm_vision_select_feature = mm_vision_select_feature
        self.use_mm_start_end = use_mm_start_end
        self.use_mm_patch_token = use_mm_patch_token
        self.mm_connector_lr = mm_connector_lr
        self.vision_resolution = vision_resolution
        self.interpolate_mode = interpolate_mode
        self.clip_duration = clip_duration
        self.vocab_size = vocab_size
        self.auto_map = auto_map
llm/added_tokens.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|box_end|>": 151649,
5
+ "<|box_start|>": 151648,
6
+ "<|endoftext|>": 151643,
7
+ "<|file_sep|>": 151664,
8
+ "<|fim_middle|>": 151660,
9
+ "<|fim_pad|>": 151662,
10
+ "<|fim_prefix|>": 151659,
11
+ "<|fim_suffix|>": 151661,
12
+ "<|im_end|>": 151645,
13
+ "<|im_start|>": 151644,
14
+ "<|image_pad|>": 151655,
15
+ "<|object_ref_end|>": 151647,
16
+ "<|object_ref_start|>": 151646,
17
+ "<|quad_end|>": 151651,
18
+ "<|quad_start|>": 151650,
19
+ "<|repo_name|>": 151663,
20
+ "<|video_pad|>": 151656,
21
+ "<|vision_end|>": 151653,
22
+ "<|vision_pad|>": 151654,
23
+ "<|vision_start|>": 151652
24
+ }
llm/config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "_name_or_path": "./work_dirs/final_run/apollo-Qwen2.5-1.5B-Instruct-internvideo2-siglip-so400m-patch14-384-freeze-perciver_128_2-newprompt-ft/llm",
3
+ "architectures": [
4
+ "Qwen2ForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151645,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 1536,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 8960,
13
+ "max_position_embeddings": 32768,
14
+ "max_window_layers": 21,
15
+ "model_max_length": 16384,
16
+ "model_type": "qwen2",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 28,
19
+ "num_key_value_heads": 2,
20
+ "rms_norm_eps": 1e-06,
21
+ "rope_theta": 1000000.0,
22
+ "sliding_window": null,
23
+ "tie_word_embeddings": true,
24
+ "tokenizer_model_max_length": 16384,
25
+ "tokenizer_padding_side": "right",
26
+ "torch_dtype": "bfloat16",
27
+ "transformers_version": "4.44.0",
28
+ "use_cache": true,
29
+ "use_sliding_window": false,
30
+ "vocab_size": 151936
31
+ }
llm/generation_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "repetition_penalty": 1.1,
10
+ "temperature": 0.7,
11
+ "top_k": 20,
12
+ "top_p": 0.8,
13
+ "transformers_version": "4.44.0"
14
+ }
llm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
llm/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13fbdfc8f2b6ef0b13dac9b236064e8b70a9c2480729635de93455ed5946e774
3
+ size 3554214752
llm/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
llm/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llm/tokenizer_config.json ADDED
@@ -0,0 +1,209 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "bos_token": null,
198
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
199
+ "clean_up_tokenization_spaces": false,
200
+ "eos_token": "<|im_end|>",
201
+ "errors": "replace",
202
+ "legacy": false,
203
+ "model_max_length": 16384,
204
+ "pad_token": "<|endoftext|>",
205
+ "padding_side": "right",
206
+ "split_special_tokens": false,
207
+ "tokenizer_class": "Qwen2Tokenizer",
208
+ "unk_token": null
209
+ }
llm/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
mm_connector.py ADDED
@@ -0,0 +1,306 @@
1
+ import re, math, torch
2
+ from collections import OrderedDict
3
+ from typing import Optional, Tuple
4
+
5
+ from torch import nn
6
+ from torch.nn.init import trunc_normal_, normal_
7
+ import torch.utils.checkpoint
8
+
9
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
10
+
11
+
12
+ class ClassInstantier(OrderedDict):
13
+ def __getitem__(self, key):
14
+ content = super().__getitem__(key)
15
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
16
+ return cls(**kwargs)
17
+
18
+
19
+ ACT2CLS = {"silu": nn.SiLU}
20
+
21
+ ACT2FN = ClassInstantier(ACT2CLS)
22
+
23
+
24
+ class WeightedNorm(nn.Module):
25
+ def __init__(self, hidden_size):
26
+ """
27
+ WeightedNorm
28
+ """
29
+ super().__init__()
30
+ self.hidden_size = hidden_size
31
+ self.norm = nn.LayerNorm(self.hidden_size)
32
+ self.wheight = nn.Parameter(torch.ones(self.hidden_size))
33
+ normal_(self.wheight, mean=1, std=.02)
34
+
35
+ def forward(self, x):
36
+ x = self.norm(x)
37
+ return x * self.wheight
38
+
39
+
40
+ class PerceiverMLP(nn.Module):
41
+ def __init__(
42
+ self,
43
+ hidden_size: int,
44
+ intermediate_size: int,
45
+ output_size: int,
46
+ hidden_act: str,
47
+ ):
48
+ super().__init__()
49
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
50
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
51
+ self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
52
+ self.act_fn = ACT2FN[hidden_act]
53
+
54
+ def forward(self, x):
55
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
56
+
57
+
58
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
59
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
60
+ """
61
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
62
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
63
+ """
64
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
65
+ if n_rep == 1:
66
+ return hidden_states
67
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
68
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
69
+
70
+
71
+ class PerceiverAttention(nn.Module):
72
+ def __init__(self, connector_config, layer_idx: Optional[int] = None) -> None:
73
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
74
+ super().__init__()
75
+
76
+ self.layer_idx = None
77
+ self.hidden_size = connector_config.text_hidden_size
78
+ self.num_heads = connector_config.resampler_n_heads
79
+ self.head_dim = connector_config.resampler_head_dim
80
+ self.num_key_value_heads = connector_config.num_key_value_heads
81
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
82
+
83
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
84
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
85
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
86
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
87
+
88
+ self.is_causal = False
89
+
90
+ def forward(
91
+ self,
92
+ latents: torch.Tensor,
93
+ context: torch.Tensor,
94
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
95
+ output_attentions: bool = False,
96
+ use_cache: bool = False,
97
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
98
+ """
99
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
100
+
101
+ Args:
102
+ latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
103
+ context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
104
+ output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
105
+ use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
106
+ """
107
+ bsz, q_len, _ = latents.size()
108
+ kv_seq_len = q_len + context.size()[1]
109
+
110
+ hidden_states = torch.concat([context, latents], dim=-2)
111
+
112
+ query_states = self.q_proj(latents)
113
+ key_states = self.k_proj(hidden_states)
114
+ value_states = self.v_proj(hidden_states)
115
+
116
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
117
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
118
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
119
+
120
+ past_key_value = getattr(self, "past_key_value", past_key_value)
121
+
122
+ if past_key_value is not None:
123
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
124
+
125
+ # repeat k/v heads if n_kv_heads < n_heads
126
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
127
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
128
+
129
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
130
+
131
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
132
+ raise ValueError(
133
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
134
+ f" {attn_weights.size()}"
135
+ )
136
+
137
+ # upcast attention to fp32
138
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
139
+ attn_output = torch.matmul(attn_weights, value_states)
140
+
141
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
142
+ raise ValueError(
143
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
144
+ f" {attn_output.size()}"
145
+ )
146
+
147
+ attn_output = attn_output.transpose(1, 2).contiguous()
148
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
149
+
150
+ attn_output = self.o_proj(attn_output)
151
+
152
+ if not output_attentions:
153
+ attn_weights = None
154
+
155
+ return attn_output, attn_weights, past_key_value
156
+
157
+
158
+ PERCEIVER_ATTENTION_CLASSES = {
159
+ "eager": PerceiverAttention,
160
+ }
161
+
162
+
163
+ class PerceiverLayer(nn.Module):
164
+ def __init__(self, connector_config, layer_idx: int):
165
+ super().__init__()
166
+ self.hidden_size = connector_config.text_hidden_size
167
+ self.n_latents = connector_config.num_output_tokens
168
+ self.depth = connector_config.resampler_depth
169
+ self.ff_multi = connector_config.ff_multi
170
+
171
+ self.input_latents_norm = WeightedNorm(self.hidden_size)
172
+ self.input_context_norm = WeightedNorm(self.hidden_size)
173
+ self.self_attn = PERCEIVER_ATTENTION_CLASSES[connector_config._attn_implementation](connector_config,
174
+ layer_idx=layer_idx)
175
+ self.post_attention_layernorm = WeightedNorm(self.hidden_size)
176
+ self.mlp = PerceiverMLP(
177
+ hidden_size=connector_config.text_hidden_size,
178
+ intermediate_size=connector_config.text_hidden_size * self.ff_multi,
179
+ output_size=connector_config.text_hidden_size,
180
+ hidden_act=connector_config.hidden_act,
181
+ )
182
+
183
+ def forward(
184
+ self,
185
+ latents: torch.Tensor,
186
+ context: torch.Tensor,
187
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
188
+ output_attentions: Optional[bool] = False,
189
+ use_cache: Optional[bool] = False,
190
+ **kwargs,
191
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
192
+ """
193
+ Args:
194
+ latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
195
+ context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
198
+ returned tensors for more detail.
199
+ use_cache (`bool`, *optional*):
200
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
201
+ (see `past_key_values`).
202
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
203
+ """
204
+ residual = latents
205
+
206
+ latents = self.input_latents_norm(latents)
207
+ context = self.input_context_norm(context)
208
+
209
+ latents, self_attn_weights, present_key_value = self.self_attn(
210
+ latents=latents,
211
+ context=context,
212
+ )
213
+
214
+ latents = residual + latents
215
+ residual = latents
216
+
217
+ latents = self.post_attention_layernorm(latents)
218
+ latents = self.mlp(latents)
219
+ latents = residual + latents
220
+
221
+ outputs = (latents,)
222
+
223
+ if output_attentions:
224
+ outputs += (self_attn_weights,)
225
+
226
+ if use_cache:
227
+ outputs += (present_key_value,)
228
+
229
+ return outputs
230
+
231
+
232
+ class PerceiverResampler(nn.Module):
233
+ """Perceiver Resampler that compresses input embeddings into a fixed number of latents."""
234
+
235
+ def __init__(self, connector_config) -> None:
236
+ super().__init__()
237
+ self.hidden_size = connector_config.text_hidden_size
238
+ self.hidden_act = connector_config.hidden_act
239
+ self.n_latents = connector_config.num_output_tokens
240
+ self.depth = connector_config.resampler_depth
241
+
242
+ # Create Latents for Perceiver
243
+ self.latents = nn.Parameter(torch.zeros(self.n_latents, self.hidden_size))
244
+
245
+ # Create Transformer Blocks
246
+ self.layers = nn.ModuleList([PerceiverLayer(connector_config, idx) for idx in range(self.depth)])
247
+ self.norm = WeightedNorm(self.hidden_size)
248
+ self._use_flash_attention_2 = connector_config._attn_implementation == "flash_attention_2"
249
+
250
+ def forward(
251
+ self,
252
+ context: torch.Tensor,
253
+ attention_mask: torch.Tensor = None,
254
+ ) -> torch.Tensor:
255
+ # seq embed -> bsz seq embed
256
+ latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))
257
+
258
+ compressed_context = latents
259
+ for i, perceiver_layer in enumerate(self.layers):
260
+ layer_outputs = perceiver_layer(
261
+ compressed_context,
262
+ context,
263
+ past_key_value=None,
264
+ output_attentions=False,
265
+ use_cache=False,
266
+ )
267
+ compressed_context = layer_outputs[0]
268
+
269
+ compressed_context = self.norm(compressed_context)
270
+ return compressed_context
271
+
272
+
273
+ def build_mm_projector(
274
+ input_dim,
275
+ output_dim,
276
+ projector_type,
277
+ hidden_act='silu',
278
+ delay_load=False,
279
+ token_input_shape=0,
280
+ **kwargs
281
+ ) -> nn.Sequential:
282
+
283
+ modules = [nn.Linear(input_dim, output_dim)]
284
+ mlp_gelu_match = re.match(r'.*mlp(\d+)x_gelu$', projector_type)
285
+ if mlp_gelu_match is not None:
286
+ mlp_depth = int(mlp_gelu_match.group(1))
287
+ for _ in range(mlp_depth - 1):
288
+ modules.append(nn.GELU())
289
+ modules.append(nn.Linear(output_dim, output_dim))
290
+
291
+ return nn.Sequential(*modules)
292
+
293
+
294
+ class MMConnector(PreTrainedModel):
295
+ config_class = PretrainedConfig
296
+
297
+ def __init__(self, config: PretrainedConfig) -> None:
298
+ super().__init__(config)
299
+ self.proj = build_mm_projector(config.vision_hidden_size, config.text_hidden_size,
300
+ config.projector_type, token_input_shape=config.token_input_shape)
301
+ self.resampler = PerceiverResampler(config)
302
+
303
+ def forward(self, x):
304
+ x = self.proj(x)
305
+ x = self.resampler(x)
306
+ return x
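For readers skimming the connector code above: the core idea is that a fixed set of learned latent queries cross-attends to the projected visual tokens, compressing a variable-length token sequence into `num_output_tokens` (128) tokens at the text hidden size (1536). The following is a simplified, self-contained illustration of that pattern using standard multi-head attention; it is not the exact `PerceiverResampler`/`MMConnector` classes defined in this file.

```python
import torch
from torch import nn

class TinyResampler(nn.Module):
    """Toy Perceiver-style resampler: (B, N, dim) visual tokens -> (B, num_latents, dim)."""
    def __init__(self, dim=1536, num_latents=128, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)  # learned queries
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, visual_tokens):  # visual_tokens: (B, N, dim), N may vary per clip
        q = self.latents.unsqueeze(0).expand(visual_tokens.size(0), -1, -1)
        out, _ = self.attn(q, visual_tokens, visual_tokens)  # latents attend to visual tokens
        return self.norm(out)

tokens = torch.randn(2, 4 * 27 * 27, 1536)  # one clip: token_input_shape 4 x 27 x 27, already projected
print(TinyResampler()(tokens).shape)        # torch.Size([2, 128, 1536])
```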
mm_connector/config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "architectures": [
3
+ "Connector"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "ff_multi": 4,
7
+ "hidden_act": "silu",
8
+ "model_type": "mm_connector",
9
+ "num_key_value_heads": 4,
10
+ "num_output_tokens": 128,
11
+ "num_patches": 24,
12
+ "projector_type": "mlp1x_gelu",
13
+ "resampler_depth": 1,
14
+ "resampler_head_dim": 96,
15
+ "resampler_n_heads": 16,
16
+ "resampler_type": "perciver",
17
+ "rms_norm_eps": 1e-06,
18
+ "text_hidden_size": 1536,
19
+ "token_input_shape": [
20
+ 4,
21
+ 27,
22
+ 27
23
+ ],
24
+ "torch_dtype": "bfloat16",
25
+ "transformers_version": "4.44.0",
26
+ "vision_hidden_size": 2560,
27
+ "auto_map": {
28
+ "AutoConfig": "configuration_connector.ConnectorConfig"
29
+ }
30
+ }
mm_connector/configuration_connector.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict, List, Union
4
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
5
+ import torch.nn.functional as F
6
+ import json, os
7
+
8
+
9
+ class ConnectorConfig(PretrainedConfig):
10
+ model_type = "mm_connector"
11
+ def __init__(
12
+ self,
13
+ vision_hidden_size: List[int] = [],
14
+ text_hidden_size: int = 0,
15
+ num_patches: int = 24,
16
+ rms_norm_eps: float = 1e-4,
17
+ token_input_shape: List[int] = [],
18
+ **kwargs,
19
+ ):
20
+ super().__init__(**kwargs)
21
+ self.vision_hidden_size = vision_hidden_size
22
+ self.text_hidden_size = text_hidden_size
23
+ self.num_patches = num_patches
24
+ self.rms_norm_eps=rms_norm_eps
25
+ self.token_input_shape = token_input_shape
26
+
27
+ @classmethod
28
+ def load_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "ConnectorConfig":
29
+ cls._set_token_in_kwargs(kwargs)
30
+ config_dict, kwargs = cls.get_config_from_json(pretrained_model_name_or_path, **kwargs)
31
+ return cls.from_dict(config_dict, **kwargs)
32
+
33
+ @classmethod
34
+ def get_config_from_json(cls, config_file, **kwargs):
35
+ with open(config_file, 'r') as file:
36
+ config_data = json.load(file)
37
+ return config_data, kwargs
38
+
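A brief usage note, assuming a local checkout of the repository (the import path and file location below are illustrative): `load_config` reads a plain JSON file path directly rather than resolving a Hub repo id.

```python
# Hypothetical local usage; paths depend on where the repo was downloaded.
from configuration_connector import ConnectorConfig

cfg = ConnectorConfig.load_config("mm_connector/config.json")
print(cfg.vision_hidden_size, cfg.text_hidden_size, cfg.num_patches)  # 2560 1536 24
```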
mm_connector/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e019c89bb11f161b4c20b2880dd96daa54775a658b3d57f9c1abfc252483a80a
3
+ size 76719488
modeling_apollo.py ADDED
@@ -0,0 +1,492 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import warnings, os, torch
3
+ import torch.nn as nn
4
+
5
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
6
+ from transformers.modeling_utils import ContextManagers, no_init_weights
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+ from transformers.generation.utils import GenerateOutput
9
+ from .configuration_apollo import ApolloConfig
10
+
11
+ from .vision_tower import ApolloVisionTower
12
+ from .mm_connector import MMConnector
13
+
14
+ IGNORE_INDEX = -100
15
+ X_TOKEN_INDEX = -200
16
+
17
+
18
+ def get_model_config(config):
19
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_connector_cfg"]
20
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
21
+ root_path = config._name_or_path
22
+ else:
23
+ root_path = config.resume_path
24
+
25
+ return_pths = []
26
+ for key in default_keys:
27
+ cfg = getattr(config, key, None)
28
+ if isinstance(cfg, dict):
29
+ try:
30
+ return_pths.append(os.path.join(root_path, key[:-4]))
31
+ except:
32
+ raise ValueError(f"Cannot find resume path in config for {key}!")
33
+ elif isinstance(cfg, PretrainedConfig):
34
+ return_pths.append(os.path.join(root_path, key[:-4]))
35
+ elif isinstance(cfg, str):
36
+ return_pths.append(cfg)
37
+
38
+ return_list = []
39
+ for pth in return_pths:
40
+ return_list.append(AutoConfig.from_pretrained(pth, trust_remote_code=True))
41
+
42
+ return return_list
43
+
44
+
45
+ def build_llm_and_tokenizer(
46
+ llm_cfg: str,
47
+ config: PretrainedConfig,
48
+ attn_implementation=None,
49
+ model_max_length=None,
50
+ *args,
51
+ **kwargs,
52
+ ) -> PreTrainedModel:
53
+ llm_arch = getattr(llm_cfg, "architectures")[0].lower()
54
+
55
+ llm_path = llm_cfg._name_or_path
56
+ llm = AutoModelForCausalLM.from_pretrained(
57
+ llm_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
58
+ )
59
+
60
+ tokenizer = AutoTokenizer.from_pretrained(
61
+ llm_path,
62
+ model_max_length=llm_cfg.model_max_length,
63
+ padding_side="right",
64
+ use_fast=False,
65
+ legacy=False,
66
+ **kwargs
67
+ )
68
+
69
+ #config.hidden_size = llm.config.hidden_size
70
+ return llm, tokenizer
71
+
72
+
73
+ class ApolloForCausalLM(PreTrainedModel):
74
+ def __init__(self, config: ApolloConfig, *args, **kwargs):
75
+ super().__init__(config)
76
+ llm_cfg, vision_tower_cfg, mm_connector_cfg = get_model_config(config)
77
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
78
+ if not hasattr(config, "model_dtype"):
79
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
80
+ config.model_dtype = model_dtype
81
+ # Initialize weights and apply final processing
82
+
83
+ self.lm_head = nn.Linear(llm_cfg.hidden_size, config.vocab_size, bias=False)
84
+ self.vision_tower = ApolloVisionTower(config, vision_tower_cfg)
85
+ self.mm_connector = MMConnector.from_pretrained(mm_connector_cfg._name_or_path)
86
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
87
+ self.post_init()
88
+ self.is_loaded = True
89
+
90
+ def forward(
91
+ self,
92
+ input_ids: torch.LongTensor = None,
93
+ attention_mask: Optional[torch.Tensor] = None,
94
+ position_ids: Optional[torch.LongTensor] = None,
95
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
96
+ inputs_embeds: Optional[torch.FloatTensor] = None,
97
+ labels: Optional[torch.LongTensor] = None,
98
+ use_cache: Optional[bool] = None,
99
+ output_attentions: Optional[bool] = None,
100
+ output_hidden_states: Optional[bool] = None,
101
+ vision_input: Optional[List[torch.FloatTensor]] = None,
102
+ data_types: Optional[List[str]] = None,
103
+ return_dict: Optional[bool] = None,
104
+ cache_position=None,
105
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
106
+
107
+ if inputs_embeds is None:
108
+ (
109
+ input_ids,
110
+ position_ids,
111
+ attention_mask,
112
+ past_key_values,
113
+ inputs_embeds,
114
+ labels
115
+ ) = self.prepare_inputs_labels_for_multimodal(
116
+ input_ids,
117
+ position_ids,
118
+ attention_mask,
119
+ past_key_values,
120
+ labels,
121
+ vision_input,
122
+ data_types
123
+ )
124
+
125
+ return self.get_llm().forward(
126
+ input_ids=input_ids,
127
+ attention_mask=attention_mask,
128
+ position_ids=position_ids,
129
+ past_key_values=past_key_values,
130
+ inputs_embeds=inputs_embeds,
131
+ labels=labels,
132
+ use_cache=use_cache,
133
+ output_attentions=output_attentions,
134
+ output_hidden_states=output_hidden_states,
135
+ return_dict=return_dict,
136
+ )
137
+
138
+ @torch.no_grad()
139
+ def generate(
140
+ self,
141
+ inputs: Optional[torch.Tensor] = None,
142
+ vision_input: Optional[List[torch.Tensor]] = None,
143
+ data_types: Optional[List[str]] = None,
144
+ **kwargs,
145
+ ) -> Union[GenerateOutput, torch.LongTensor]:
146
+ position_ids = kwargs.pop("position_ids", None)
147
+ attention_mask = kwargs.pop("attention_mask", None)
148
+ if "inputs_embeds" in kwargs:
149
+ raise NotImplementedError("`inputs_embeds` is not supported")
150
+
151
+ if vision_input is not None:
152
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
153
+ inputs, position_ids, attention_mask, None, None, vision_input, data_types=data_types)
154
+ else:
155
+ inputs_embeds = self.embed_tokens(inputs)
156
+
157
+ return self.get_llm().generate(position_ids=position_ids, attention_mask=attention_mask,
158
+ inputs_embeds=inputs_embeds, **kwargs)
159
+
160
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
161
+ vision_input = kwargs.pop("vision_input", None)
162
+ data_types = kwargs.pop("data_types", None)
163
+ inputs = self.get_llm().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values,
164
+ inputs_embeds=inputs_embeds, **kwargs)
165
+ if vision_input is not None:
166
+ inputs["vision_input"] = vision_input
167
+ if data_types is not None:
168
+ inputs["data_types"] = data_types
169
+ return inputs
170
+
171
+ @classmethod
172
+ def from_pretrained(
173
+ cls,
174
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
175
+ *model_args,
176
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
177
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
178
+ ignore_mismatched_sizes: bool = False,
179
+ force_download: bool = False,
180
+ local_files_only: bool = False,
181
+ token: Optional[Union[str, bool]] = None,
182
+ revision: str = "main",
183
+ use_safetensors: bool = None,
184
+ **kwargs,
185
+ ):
186
+
187
+ return cls.load_pretrained(
188
+ pretrained_model_name_or_path,
189
+ *model_args,
190
+ config=config,
191
+ cache_dir=cache_dir,
192
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
193
+ force_download=force_download,
194
+ local_files_only=local_files_only,
195
+ token=token,
196
+ revision=revision,
197
+ use_safetensors=use_safetensors,
198
+ **kwargs,
199
+ )
200
+
201
+ def get_llm(self):
202
+ return self.llm
203
+
204
+ def get_vision_tower(self):
205
+ return self.vision_tower
206
+
207
+ def get_mm_connector(self):
208
+ return self.mm_connector
209
+
210
+ @classmethod
211
+ def load_pretrained(cls, model_path_or_config, *args, **kwargs):
212
+ kwargs.pop("config", None)
213
+
214
+ if isinstance(model_path_or_config, str):
215
+ config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True, **kwargs)
216
+ elif isinstance(model_path_or_config, ApolloConfig):
217
+ config = model_path_or_config
218
+ else:
219
+ raise NotImplementedError(f"wrong type, {type(model_path_or_config)} \
220
+ {isinstance(model_path_or_config, ApolloConfig)}")
221
+
222
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
223
+ if not hasattr(config, "model_dtype"):
224
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
225
+ config.model_dtype = model_dtype
226
+
227
+ with ContextManagers([no_init_weights(_enable=True), ]):
228
+ vlm = cls(config, *args, **kwargs)
229
+
230
+ if hasattr(vlm, "llm") and hasattr(vlm, "vision_tower") and hasattr(vlm, "mm_connector"):
231
+ if vlm.is_loaded:
232
+ return vlm
233
+ else:
234
+ print('loading model failed!')
235
+ else:
236
+ print('loading model failed!')
237
+
238
+ def _encode_mm(self, x):
239
+ x = self.get_vision_tower()(x)
240
+ x = self.mm_connector(x)
241
+ return x
242
+
243
+ def encode_mm_minibatch(self, x):
244
+ split_sizes = [x_s[0].shape[0] for x_s in x]
245
+ x = [torch.split(torch.cat([x_s[i] for x_s in x], dim=0), self.config.encode_batch_size) for i in
246
+ range(self.get_vision_tower().num_vision_encoders)]
247
+ swapped_x = []
248
+ for i in range(len(x[0])):
249
+ swapped_x.append([x_s[i] for x_s in x])
250
+
251
+ features = []
252
+ for xx in swapped_x:
253
+ xx = self._encode_mm(xx)
254
+ features.append(xx)
255
+ x = torch.cat(features, dim=0)
256
+ x = torch.split(x, split_sizes, dim=0)
257
+ return [xx.contiguous().view(-1, xx.shape[2]) for xx in x]
258
+
259
+ def prepare_inputs_labels_for_multimodal(
260
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types
261
+ ):
262
+ vision_tower = self.get_vision_tower()
263
+ if vision_tower is None or vision_input is None or input_ids.shape[1] == 1:
264
+ if (
265
+ past_key_values is not None
266
+ and vision_tower is not None
267
+ and vision_input is not None
268
+ and input_ids.shape[1] == 1
269
+ ):
270
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
271
+ attention_mask = torch.cat(
272
+ (
273
+ attention_mask,
274
+ torch.ones(
275
+ (
276
+ attention_mask.shape[0],
277
+ target_shape - attention_mask.shape[1],
278
+ ),
279
+ dtype=attention_mask.dtype,
280
+ device=attention_mask.device,
281
+ ),
282
+ ),
283
+ dim=1,
284
+ )
285
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
286
+ return (
287
+ input_ids,
288
+ position_ids,
289
+ attention_mask,
290
+ past_key_values,
291
+ None,
292
+ labels,
293
+ )
294
+
295
+ '''
296
+ vision_input is a list of tuples, and data_type is a list of strings:
297
+ data_type = ['image', 'video', 'video'..., 'text']
298
+ (for one video and two image encoders)
299
+ vision_input =
300
+ [
301
+ [image(1, T, C, H, W), image(1, T, C, H, W), image(1, T, C, H, W)],
302
+ [video(Nc1, C, T, H, W), video(Nc1, T, C, H, W), video(Nc1, T, C, H, W)],
303
+ [video(Nc2, C, T, H, W), video(Nc2, T, C, H, W), video(Nc2, T, C, H, W)],
304
+ ]
305
+ -> video encoders typically expect (C,T,H,W), image encoders expect (C,H,W).
306
+ '''
307
+ # ====================================================================================================
308
+ merged_mm_features = self.encode_mm_minibatch(vision_input)
309
+
310
+ if not getattr(self.config, "tune_language_model", True) and getattr(self.config, "use_mm_start_end", False):
311
+ raise NotImplementedError
312
+ # ====================================================================================================
313
+ # Let's just add dummy tensors if they do not exist,
314
+ # it is a headache to deal with None all the time.
315
+ # But it is not ideal, and if you have a better idea,
316
+ # please open an issue / submit a PR, thanks.
317
+ _labels = labels
318
+ _position_ids = position_ids
319
+ _attention_mask = attention_mask
320
+ if attention_mask is None:
321
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
322
+ else:
323
+ attention_mask = attention_mask.bool()
324
+ if position_ids is None:
325
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
326
+ if labels is None:
327
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
328
+
329
+ # remove the padding using attention_mask
330
+ input_ids_copy = input_ids.clone()
331
+ # kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
332
+ input_ids_copy[input_ids_copy == X_TOKEN_INDEX] = 0
333
+ input_embeds = self.get_llm().model.embed_tokens(input_ids_copy)
334
+
335
+ input_ids = [
336
+ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
337
+ ]
338
+ input_embeds_1 = [
339
+ cur_input_embeds[cur_attention_mask]
340
+ for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask)
341
+ ]
342
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
343
+ # input_ids, new_input_embeds = self.inputs_merger(input_ids, input_embeds_1, merged_mm_features)
344
+ new_labels = []
345
+ new_input_embeds = []
346
+ # print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == X_TOKEN_INDEX).sum() for x in input_ids])
347
+ # kentang-mit@: If some part of the model is executed in the loop, the loop length needs to be a constant.
348
+ for batch_idx, (cur_labels, cur_input_ids, mm_features) in enumerate(
349
+ zip(labels, input_ids, merged_mm_features)):
350
+ cur_input_ids = input_ids[batch_idx]
351
+ num_mm = (cur_input_ids == X_TOKEN_INDEX).sum()
352
+ if num_mm == 0:
353
+ cur_input_embeds_1 = input_embeds_1[batch_idx]
354
+ cur_input_embeds = torch.cat([cur_input_embeds_1, mm_features[0:0]], dim=0)
355
+ new_input_embeds.append(cur_input_embeds)
356
+ new_labels.append(cur_labels)
357
+ # kentang-mit@: we do not have a placeholder image for text-only data now.
358
+ continue
359
+
360
+ if mm_features.shape[0] != num_mm:
361
+ print(data_types[batch_idx])
362
+ assert num_mm == len(
363
+ mm_features), f'Error in {data_types[batch_idx]}{num_mm}=/={len(mm_features)} not the same number of vision tokens in and vision embeddings!'
364
+
365
+ cur_input_embeds = input_embeds_1[batch_idx]
366
+ image_token_indices = (
367
+ [-1] + torch.where(cur_input_ids == X_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
368
+ )
369
+ cur_input_ids_noim = []
370
+ cur_labels = labels[batch_idx]
371
+ cur_labels_noim = []
372
+ cur_input_embeds_no_im = []
373
+ for i in range(len(image_token_indices) - 1):
374
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]])
375
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]])
376
+ cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1: image_token_indices[i + 1]])
377
+
378
+ cur_new_input_embeds = []
379
+ cur_new_labels = []
380
+ for i in range(num_mm + 1):
381
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
382
+ # print("cur_new_input_embeds1", cur_new_input_embeds.shape[-1])
383
+ cur_new_labels.append(cur_labels_noim[i])
384
+ if i < num_mm:
385
+ cur_image_features = mm_features[i:i + 1]
386
+ cur_new_input_embeds.append(cur_image_features)
387
+ # print("cur_new_input_embeds2", cur_new_input_embeds.shape[-1])
388
+ cur_new_labels.append(
389
+ torch.full(
390
+ (cur_image_features.shape[0],),
391
+ IGNORE_INDEX,
392
+ device=cur_labels.device,
393
+ dtype=cur_labels.dtype,
394
+ )
395
+ )
396
+
397
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
398
+ cur_new_labels = torch.cat(cur_new_labels)
399
+
400
+ new_input_embeds.append(cur_new_input_embeds)
401
+ new_labels.append(cur_new_labels)
402
+
403
+ # Truncate sequences to max length as image embeddings can make the sequence longer
404
+ tokenizer_model_max_length = getattr(self.get_llm().config, "tokenizer_model_max_length", None)
405
+ if tokenizer_model_max_length is not None:
406
+ if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
407
+ print("Inputs truncated!")
408
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
409
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
410
+ # Combine them
411
+ max_len = max(x.shape[0] for x in new_input_embeds)
412
+ batch_size = len(new_input_embeds)
413
+
414
+ new_input_embeds_padded = []
415
+ new_labels_padded = torch.full(
416
+ (batch_size, max_len),
417
+ IGNORE_INDEX,
418
+ dtype=new_labels[0].dtype,
419
+ device=new_labels[0].device,
420
+ )
421
+ attention_mask = torch.zeros(
422
+ (batch_size, max_len),
423
+ dtype=attention_mask.dtype,
424
+ device=attention_mask.device,
425
+ )
426
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
427
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
428
+ cur_len = cur_new_embed.shape[0]
429
+ if getattr(self.get_llm().config, "tokenizer_padding_side", "right") == "left":
430
+ new_input_embeds_padded.append(
431
+ torch.cat(
432
+ (
433
+ torch.zeros(
434
+ (max_len - cur_len, cur_new_embed.shape[1]),
435
+ dtype=cur_new_embed.dtype,
436
+ device=cur_new_embed.device,
437
+ ),
438
+ cur_new_embed,
439
+ ),
440
+ dim=0,
441
+ )
442
+ )
443
+ if cur_len > 0:
444
+ new_labels_padded[i, -cur_len:] = cur_new_labels
445
+ attention_mask[i, -cur_len:] = True
446
+ position_ids[i, -cur_len:] = torch.arange(
447
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
448
+ )
449
+ else:
450
+ new_input_embeds_padded.append(
451
+ torch.cat(
452
+ (
453
+ cur_new_embed,
454
+ torch.zeros(
455
+ (max_len - cur_len, cur_new_embed.shape[1]),
456
+ dtype=cur_new_embed.dtype,
457
+ device=cur_new_embed.device,
458
+ ),
459
+ ),
460
+ dim=0,
461
+ )
462
+ )
463
+ if cur_len > 0:
464
+ new_labels_padded[i, :cur_len] = cur_new_labels
465
+ attention_mask[i, :cur_len] = True
466
+ position_ids[i, :cur_len] = torch.arange(
467
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
468
+ )
469
+
470
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
471
+
472
+ if _labels is None:
473
+ new_labels = None
474
+ else:
475
+ new_labels = new_labels_padded
476
+
477
+ if _attention_mask is None:
478
+ attention_mask = None
479
+ else:
480
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
481
+
482
+ if _position_ids is None:
483
+ position_ids = None
484
+
485
+ return (
486
+ None,
487
+ position_ids,
488
+ attention_mask,
489
+ past_key_values,
490
+ new_input_embeds,
491
+ new_labels,
492
+ )
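The function above splices precomputed vision features into the token-embedding sequence wherever a vision placeholder token appears, masks those positions out of the labels with `IGNORE_INDEX`, and then re-pads the batch. The following is a minimal, self-contained sketch of that splice-and-mask step; the names `PLACEHOLDER_ID` and `splice_vision_features` are illustrative stand-ins (not the repository's constants), and it assumes one feature row per placeholder, whereas the real code can expand a placeholder into a block of vision tokens.

```python
import torch

# Hypothetical constants for illustration only.
PLACEHOLDER_ID = -200   # stands in for X_TOKEN_INDEX
IGNORE_INDEX = -100

def splice_vision_features(input_ids, token_embeds, labels, mm_features):
    """Replace each placeholder token with one row of vision features.

    input_ids:    (L,) token ids, containing some placeholder ids
    token_embeds: (L, D) text-token embeddings
    labels:       (L,) training labels
    mm_features:  (num_mm, D) one feature row per placeholder
    """
    new_embeds, new_labels, mm_idx = [], [], 0
    for pos in range(input_ids.shape[0]):
        if input_ids[pos] == PLACEHOLDER_ID:
            new_embeds.append(mm_features[mm_idx:mm_idx + 1])   # spliced vision row
            new_labels.append(torch.full((1,), IGNORE_INDEX,    # never supervised
                                         dtype=labels.dtype))
            mm_idx += 1
        else:
            new_embeds.append(token_embeds[pos:pos + 1])
            new_labels.append(labels[pos:pos + 1])
    return torch.cat(new_embeds), torch.cat(new_labels)

# Toy usage: a 5-token sequence with one placeholder and a single 8-dim feature.
ids = torch.tensor([1, 2, PLACEHOLDER_ID, 3, 4])
emb = torch.randn(5, 8)
lab = torch.tensor([1, 2, PLACEHOLDER_ID, 3, 4])
feats = torch.randn(1, 8)
e, l = splice_vision_features(ids, emb, lab, feats)
assert e.shape == (5, 8)
assert l[2].item() == IGNORE_INDEX
```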
vision_tower.py ADDED
@@ -0,0 +1,556 @@
1
+ import torch, os, PIL, numbers, json
+ from collections import OrderedDict
2
+ from PIL import Image
3
+ import cv2
4
+
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
7
+ from transformers import AutoConfig, AutoModel, SiglipImageProcessor, SiglipVisionConfig, PretrainedConfig
8
+ from typing import Union
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+
13
+ def crop_clip(clip, min_h, min_w, h, w):
14
+ if isinstance(clip[0], np.ndarray):
15
+ cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
16
+
17
+ elif isinstance(clip[0], PIL.Image.Image):
18
+ cropped = [
19
+ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
20
+ ]
21
+ else:
22
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
23
+ ' but got list of {0}'.format(type(clip[0])))
24
+ return cropped
25
+
26
+
27
+ class Normalize(object):
28
+ """Normalize a clip with mean and standard deviation.
29
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
30
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
31
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
32
+ .. note::
33
+ This transform acts out of place, i.e., it does not mutate the input tensor.
34
+ Args:
35
+ mean (sequence): Sequence of means for each channel.
36
+ std (sequence): Sequence of standard deviations for each channel.
37
+ """
38
+
39
+ def __init__(self, mean, std):
40
+ self.mean = mean
41
+ self.std = std
42
+
43
+ def __call__(self, clip):
44
+ """
45
+ Args:
46
+ clip (Tensor): Tensor clip of size (C, T, H, W) to be normalized.
47
+ Returns:
48
+ Tensor: Normalized Tensor clip.
49
+ """
50
+ return normalize(clip, self.mean, self.std)
51
+
52
+ def __repr__(self):
53
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
54
+
55
+
56
+ class CenterCrop(object):
57
+ """Extract center crop at the same location for a list of images
58
+ Args:
59
+ size (sequence or int): Desired output size for the
60
+ crop in format (h, w)
61
+ """
62
+
63
+ def __init__(self, size):
64
+ if isinstance(size, numbers.Number):
65
+ size = (size, size)
66
+
67
+ self.size = size
68
+
69
+ def __call__(self, clip):
70
+ """
71
+ Args:
72
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
73
+ in format (h, w, c) in numpy.ndarray
74
+ Returns:
75
+ PIL.Image or numpy.ndarray: Cropped list of images
76
+ """
77
+ h, w = self.size
78
+ if isinstance(clip[0], np.ndarray):
79
+ im_h, im_w, im_c = clip[0].shape
80
+ elif isinstance(clip[0], PIL.Image.Image):
81
+ im_w, im_h = clip[0].size
82
+ else:
83
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
84
+ ' but got list of {0}'.format(type(clip[0])))
85
+ if w > im_w or h > im_h:
86
+ error_msg = (
87
+ 'Initial image size should be larger than '
88
+ 'cropped size but got cropped sizes : ({w}, {h}) while '
89
+ 'initial image is ({im_w}, {im_h})'.format(
90
+ im_w=im_w, im_h=im_h, w=w, h=h))
91
+ raise ValueError(error_msg)
92
+
93
+ x1 = int(round((im_w - w) / 2.))
94
+ y1 = int(round((im_h - h) / 2.))
95
+ cropped = crop_clip(clip, y1, x1, h, w)
96
+
97
+ return cropped
98
+
99
+
100
+ def resize_clip(clip, size, interpolation='bilinear'):
101
+ if isinstance(clip[0], np.ndarray):
102
+ if isinstance(size, numbers.Number):
103
+ im_h, im_w, im_c = clip[0].shape
104
+ # Min spatial dim already matches minimal size
105
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
106
+ and im_h == size):
107
+ return clip
108
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
109
+ size = (new_w, new_h)
110
+ else:
111
+ size = size[0], size[1]
112
+ if interpolation == 'bilinear':
113
+ np_inter = cv2.INTER_LINEAR
114
+ else:
115
+ np_inter = cv2.INTER_NEAREST
116
+ scaled = [
117
+ cv2.resize(img, size, interpolation=np_inter) for img in clip
118
+ ]
119
+ elif isinstance(clip[0], PIL.Image.Image):
120
+ if isinstance(size, numbers.Number):
121
+ im_w, im_h = clip[0].size
122
+ # Min spatial dim already matches minimal size
123
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w
124
+ and im_h == size):
125
+ return clip
126
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
127
+ size = (new_w, new_h)
128
+ else:
129
+ size = size[1], size[0]
130
+ if interpolation == 'bilinear':
131
+ pil_inter = PIL.Image.BILINEAR
132
+ else:
133
+ pil_inter = PIL.Image.NEAREST
134
+ scaled = [img.resize(size, pil_inter) for img in clip]
135
+ else:
136
+ raise TypeError('Expected numpy.ndarray or PIL.Image' +
137
+ ' but got list of {0}'.format(type(clip[0])))
138
+ return scaled
139
+
140
+
141
+ def _is_tensor_clip(clip):
142
+ return torch.is_tensor(clip) and clip.ndimension() == 4
143
+
144
+
145
+ def get_resize_sizes(im_h, im_w, size):
146
+ if im_w < im_h:
147
+ ow = size
148
+ oh = int(size * im_h / im_w)
149
+ else:
150
+ oh = size
151
+ ow = int(size * im_w / im_h)
152
+ return oh, ow
153
+
154
+
155
+ def normalize(clip, mean, std, inplace=False):
156
+ if not _is_tensor_clip(clip):
157
+ raise TypeError('tensor is not a torch clip.')
158
+
159
+ if not inplace:
160
+ clip = clip.clone()
161
+
162
+ dtype = clip.dtype
163
+ mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
164
+ std = torch.as_tensor(std, dtype=dtype, device=clip.device)
165
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
166
+
167
+ return clip
168
+
169
+
170
+ class Resize(object):
171
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
172
+ The larger the original image is, the longer the
173
+ interpolation takes
174
+ Args:
175
+ interpolation (str): Can be one of 'nearest', 'bilinear'
176
+ defaults to nearest
177
+ size (tuple): (width, height)
178
+ """
179
+
180
+ def __init__(self, size, interpolation='nearest'):
181
+ self.size = size
182
+ self.interpolation = interpolation
183
+
184
+ def __call__(self, clip):
185
+ resized = resize_clip(
186
+ clip, self.size, interpolation=self.interpolation)
187
+ return resized
188
+
189
+
190
+ class Compose(object):
191
+ """Composes several transforms
192
+ Args:
193
+ transforms (list of ``Transform`` objects): list of transforms
194
+ to compose
195
+ """
196
+
197
+ def __init__(self, transforms):
198
+ self.transforms = transforms
199
+
200
+ def __call__(self, clip):
201
+ for t in self.transforms:
202
+ clip = t(clip)
203
+ return clip
204
+
205
+
206
+ def convert_img(img):
207
+ """Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
208
+ if len(img.shape) == 3:
209
+ img = img.transpose(2, 0, 1)
210
+ if len(img.shape) == 2:
211
+ img = np.expand_dims(img, 0)
212
+ return img
213
+
214
+
215
+ class ClipToTensor(object):
216
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
217
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
218
+ """
219
+
220
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
221
+ self.channel_nb = channel_nb
222
+ self.div_255 = div_255
223
+ self.numpy = numpy
224
+
225
+ def __call__(self, clip):
226
+ """
227
+ Args: clip (list of numpy.ndarray): clip (list of images)
228
+ to be converted to tensor.
229
+ """
230
+ # Retrieve shape
231
+ if isinstance(clip[0], np.ndarray):
232
+ h, w, ch = clip[0].shape
233
+ assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
234
+ elif isinstance(clip[0], Image.Image):
235
+ w, h = clip[0].size
236
+ else:
237
+ raise TypeError(
238
+ "Expected numpy.ndarray or PIL.Image\
239
+ but got list of {0}".format(
240
+ type(clip[0])
241
+ )
242
+ )
243
+
244
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
245
+
246
+ # Convert
247
+ for img_idx, img in enumerate(clip):
248
+ if isinstance(img, np.ndarray):
249
+ pass
250
+ elif isinstance(img, Image.Image):
251
+ img = np.array(img, copy=False)
252
+ else:
253
+ raise TypeError(
254
+ "Expected numpy.ndarray or PIL.Image\
255
+ but got list of {0}".format(
256
+ type(clip[0])
257
+ )
258
+ )
259
+ img = convert_img(img)
260
+ np_clip[:, img_idx, :, :] = img
261
+ if self.numpy:
262
+ if self.div_255:
263
+ np_clip = np_clip / 255.0
264
+ return np_clip
265
+
266
+ else:
267
+ tensor_clip = torch.from_numpy(np_clip)
268
+
269
+ if not isinstance(tensor_clip, torch.FloatTensor):
270
+ tensor_clip = tensor_clip.float()
271
+ if self.div_255:
272
+ tensor_clip = torch.div(tensor_clip, 255)
273
+ return tensor_clip
274
+
275
+
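Taken together, the transform classes above implement the clip preprocessing used for the video branch: resize, center-crop, convert the frame list to a `(C, T, H, W)` tensor in `[0, 1]`, then normalize per channel. A small usage sketch, assuming the classes defined above are in scope; the mean/std values are the ImageNet defaults also used further down in this file:

```python
from PIL import Image
import numpy as np

# Assumes Compose, Resize, CenterCrop, ClipToTensor and Normalize from above are importable.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
transform = Compose([
    Resize(224, interpolation='bilinear'),   # shorter side -> 224
    CenterCrop(size=(224, 224)),             # spatial crop shared across frames
    ClipToTensor(),                          # list of frames -> float tensor (C, T, H, W) in [0, 1]
    Normalize(mean=mean, std=std),           # per-channel normalization
])

# Four dummy RGB frames standing in for a decoded clip.
frames = [Image.fromarray(np.zeros((256, 320, 3), dtype=np.uint8)) for _ in range(4)]
clip = transform(frames)
print(clip.shape)  # torch.Size([3, 4, 224, 224])
```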
276
+ class VisionTowerConfig(PretrainedConfig):
277
+ model_type = "vision_tower"
278
+
279
+ def __init__(self, vision_tower_name: str = None, **kwargs):
280
+ super().__init__()
281
+ self.vision_tower_name = vision_tower_name
282
+
283
+
284
+ class ProcessorWrapper:
285
+ def __init__(self, transform=None, processor=None, height=378, width=378, frames_per_clip=1,
286
+ image_mean=[0.48145466, 0.4578275, 0.40821073]):
287
+ assert transform is not None or processor is not None, "ERROR: you defined neither `transform` nor `processor`! You must define one of them."
288
+ assert transform is None or processor is None, "ERROR: you defined both `transform` and `processor`! You must define only one of them."
289
+ self._size = {
290
+ "height": height,
291
+ "width": width,
292
+ "frames_per_clip": frames_per_clip
293
+ }
294
+ self._transforms = transform
295
+ self._processor = processor
296
+ self.image_mean = image_mean
297
+
298
+ @property
299
+ def size(self):
300
+ return self._size
301
+
302
+ def preprocess(self, image, return_tensors='pt'):
303
+ # Ensure image is a PIL Image
304
+ output = {}
305
+ if self._transforms is not None:
306
+ output['pixel_values'] = [self._transforms(image)]
307
+
308
+ else:
309
+ output = self._processor(image, return_tensors='pt')
310
+ return output
311
+
312
+ def save_pretrained(self, save_path):
313
+ if self._transforms is not None:
314
+ transform_dict = transform_to_dict(self._transforms)
315
+ transform_dict["image_processor_type"] = "transforms"
316
+ with open(os.path.join(save_path, 'preprocessor_config.json'), 'w') as f:
317
+ json.dump(transform_dict, f, indent=4)
318
+ else:
319
+ self._processor.save_pretrained(save_path)
320
+ return
321
+
322
+
323
+ class VisionTower(PreTrainedModel):
324
+ config_class = VisionTowerConfig
325
+
326
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: VisionTowerConfig = None):
327
+ super().__init__(vision_config)
328
+ self.vision_tower_name = model_name_or_path
329
+ self.vision_config = vision_config
330
+ self.select_layer = getattr(config, "mm_vision_select_layer", -2)
331
+ self.select_feature = getattr(config, "mm_vision_select_feature", "patch")
332
+ self.encode_batch_size = getattr(config, "encode_batch_size", 0) // 2
333
+ self.num_encode_batch = getattr(config, "num_encode_batch", 0) // 2
334
+ self.temporal_tubelet_size = getattr(vision_config, "tubelet_size", 1)
335
+
336
+ def feature_select(self, image_features):
337
+ if self.select_layer is not None:
338
+ image_features = image_features.hidden_states[self.select_layer]
339
+
340
+ if self.select_feature == "patch":
341
+ image_features = image_features[:, 1:]
342
+ elif self.select_feature == "cls_patch":
343
+ image_features = image_features
344
+ else:
345
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
346
+
347
+ return image_features
348
+
349
+ def vision_tower_forward(self, image):
350
+ image_feature = self.vision_tower(image, output_hidden_states=True)
351
+ return image_feature
352
+
353
+ def _forward(self, images, out_T=1):
354
+ if type(images) is list:
355
+ image_features = []
356
+ for image in images:
357
+ image_feature = self.vision_tower_forward(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
358
+ image_feature = self.feature_select(image_feature).to(image.dtype)
359
+ image_feature = image_feature.reshape(image_feature.shape[0], self.W, self.H, self.hidden_size)
360
+ image_features.append(image_feature)
361
+ else:
362
+ original_shape = images.shape
363
+ if len(original_shape) == 5 and self.T == 1:
364
+ # downsample temporally if needed, and reshape from (B, T, C, W, H) to (B*T, C, W, H).
365
+ images = images[:, ::original_shape[1] // out_T, ...]
366
+ original_shape = images.shape
367
+ images = images.view(-1, *original_shape[2:])
368
+
369
+ image_features = self.vision_tower_forward(images.to(device=self.device, dtype=self.dtype))
370
+ image_features = self.feature_select(image_features).to(images.dtype)
371
+ # Reshape back to (B, T, ...) if necessary
372
+ if len(original_shape) == 5 and self.T == 1:
373
+ # Assuming the feature dimension does not change, adapt the following line if it does
374
+ new_shape = list(image_features.shape[:-2]) + [self.W, self.H, self.hidden_size]
375
+ image_features = image_features.reshape(new_shape)
376
+ feature_size = image_features.shape[1:]
377
+ image_features = image_features.view(original_shape[0], original_shape[1], *feature_size)
378
+
379
+ else:
380
+ image_features = image_features.reshape(image_features.shape[0], self.T, self.W, self.H, self.hidden_size)
381
+
382
+ return image_features
383
+
384
+ def forward(self, images):
385
+ return self._forward(images)
386
+
387
+ @property
388
+ def dummy_feature(self):
389
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
390
+
391
+ @property
392
+ def dtype(self):
393
+ return self.vision_tower.dtype
394
+
395
+ @property
396
+ def device(self):
397
+ return self.vision_tower.device
398
+
399
+ @property
400
+ def num_patches(self):
401
+ return (self.config.image_size // self.config.patch_size) ** 2
402
+
403
+
404
+ class InternVideoTower(VisionTower):
405
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
406
+ if vision_config is None:
407
+ vision_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
408
+
409
+ super().__init__(model_name_or_path, config, vision_config)
410
+ self.vision_config = vision_config
411
+ normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
412
+
413
+ print('loading: ', model_name_or_path)
414
+ model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
415
+ self.vision_tower = model.to(dtype=eval(config.model_dtype))
416
+
417
+ transform = Compose([
418
+ Resize(self.vision_config.img_size, interpolation='bilinear'),
419
+ CenterCrop(size=(self.vision_config.img_size, self.vision_config.img_size)),
420
+ ClipToTensor(),
421
+ Normalize(mean=normalize[0], std=normalize[1])
422
+ ])
423
+
424
+ self.vision_processor = ProcessorWrapper(transform=transform,
425
+ height=self.vision_config.img_size,
426
+ width=self.vision_config.img_size,
427
+ frames_per_clip=self.vision_config.num_frames,
428
+ image_mean=normalize[0])
429
+
430
+ self.W = self.H = vision_config.img_size // vision_config.patch_size
431
+ self.T = self.vision_config.num_frames // self.vision_config.tubelet_size
432
+ self.num_frames = self.vision_config.num_frames
433
+ self.hidden_size = vision_config.d_model
434
+ self.vision_select_layer = self.select_layer
435
+ self.select_layer = None
436
+
437
+ def vision_tower_forward(self, video):
438
+ if video.shape[-3] < self.num_frames:
439
+ video = video.repeat_interleave(self.num_frames, dim=-3)
440
+ elif video.shape[-3] > self.num_frames:
441
+ video = video[:, :, ::video.shape[-3] // self.num_frames, ...]
442
+
443
+ video_feature = self.vision_tower(video.to(device=self.device, dtype=self.dtype),
444
+ x_vis_return_idx=self.vision_select_layer, x_vis_only=True)
445
+
446
+ return video_feature
447
+
448
+ @property
449
+ def device(self):
450
+ return self.vision_tower.pos_embed.device
451
+
452
+
453
+ class SiglipVisionTower(VisionTower):
454
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
455
+ if vision_config is None:
456
+ vision_config = SiglipVisionConfig.from_pretrained(model_name_or_path)
457
+
458
+ super().__init__(model_name_or_path, config, vision_config)
459
+ self.vision_config = vision_config
460
+ self.vision_tower_name = model_name_or_path
461
+ self.vision_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
462
+
463
+ print('loading: ', model_name_or_path)
464
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
465
+
466
+ self.hidden_size = self.vision_config.hidden_size
467
+ self.W = self.H = self.vision_config.image_size // self.vision_config.patch_size
468
+ self.T = 1
469
+ self.select_feature = "cls_patch"
470
+
471
+
472
+ class ApolloVisionTower(PreTrainedModel):
473
+ def __init__(self, config, vision_tower_cfg):
474
+ super(ApolloVisionTower, self).__init__(config, vision_tower_cfg)
475
+ self.model_name_or_path = vision_tower_cfg._name_or_path
476
+ self.vision_towers = vision_tower_cfg.vision_towers
477
+ self._config = vision_tower_cfg
478
+
479
+ for vision_tower_name in self.vision_towers:
480
+ if 'internvideo' in vision_tower_name.lower():
481
+ vision_tower = InternVideoTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name), config)
482
+ elif 'siglip' in vision_tower_name.lower():
483
+ vision_tower = SiglipVisionTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name),
484
+ config)
485
+
486
+ setattr(self, vision_tower_name, vision_tower)
487
+
488
+ self.vision_processor = [getattr(self, vt).vision_processor for vt in self.vision_towers]
489
+ self.num_vision_encoders = len(self.vision_towers)
490
+ self.W = self.H = max([getattr(self, vt).W for vt in self.vision_towers])
491
+ self.T = max([getattr(self, vt).T for vt in self.vision_towers])
492
+ self.max_tubelet_size = max(
493
+ [getattr(getattr(self, vt).vision_config, 'tubelet_size', 1) for vt in self.vision_towers])
494
+
495
+ self._hidden_size = sum([getattr(self, vt).hidden_size for vt in self.vision_towers])
496
+ self.token_output_shape = (self.T, self.W, self.H)
497
+ self.config.num_vision_encoders = self.num_vision_encoders
498
+ self.config.vision_towers = self.vision_towers
499
+ self.config.token_output_shape = self.token_output_shape
500
+
501
+ def forward(self, x):
502
+ output_features = []
503
+ for x_s, vision_tower_name in zip(x, self.vision_towers):
504
+ vision_tower = getattr(self, vision_tower_name)
505
+ features = vision_tower._forward(x_s, out_T=self.T)
506
+
507
+ if len(features.shape) != len(self.token_output_shape) + 2:
508
+ features = features.unsqueeze(1)
509
+
510
+ if features.shape[-len(self.token_output_shape) - 1:-1] != self.token_output_shape:
511
+ features = features.permute(0, 4, 1, 2, 3).contiguous() # shape [B, D, T, W, H]
512
+ features = F.interpolate(features.to(torch.float32), size=self.token_output_shape, mode='trilinear',
513
+ align_corners=False).to(features.dtype)
514
+ features = features.permute(0, 2, 3, 4, 1).contiguous()
515
+
516
+ output_features.append(features)
517
+
518
+ output_features = torch.cat(output_features, dim=-1)
519
+ output_features = torch.flatten(output_features, start_dim=1, end_dim=-2)
520
+ return output_features
521
+
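The `forward` above aligns each encoder's feature grid to a shared `(T, W, H)` token grid with trilinear interpolation before concatenating along the channel dimension and flattening to one token per grid cell. The standalone sketch below reproduces just that resampling step on random tensors; the batch size and hidden sizes are illustrative:

```python
import torch
import torch.nn.functional as F

target_grid = (4, 27, 27)  # (T, W, H), e.g. the token_output_shape of this checkpoint

# Two hypothetical encoder outputs, laid out as (B, T, W, H, D) with different grids.
feat_image = torch.randn(2, 1, 27, 27, 1152)   # image encoder: single "frame"
feat_video = torch.randn(2, 4, 16, 16, 1408)   # video encoder: coarser spatial grid

aligned = []
for feats in (feat_image, feat_video):
    if feats.shape[1:4] != target_grid:
        x = feats.permute(0, 4, 1, 2, 3)                        # (B, D, T, W, H)
        x = F.interpolate(x.float(), size=target_grid,
                          mode='trilinear', align_corners=False)
        feats = x.permute(0, 2, 3, 4, 1).to(feats.dtype)        # back to (B, T, W, H, D)
    aligned.append(feats)

fused = torch.cat(aligned, dim=-1)   # concat channels: (2, 4, 27, 27, 1152 + 1408)
tokens = fused.flatten(1, 3)         # (B, T*W*H, D_total) -> one token per grid cell
print(tokens.shape)                  # torch.Size([2, 2916, 2560])
```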
522
+ def save_pretrained(
523
+ self,
524
+ save_directory: Union[str, os.PathLike],
525
+ state_dict=None,
526
+ **kwargs,
527
+ ):
528
+ if state_dict is None:
529
+ state_dict = self.state_dict()
530
+
531
+ for vision_tower_name in self.vision_towers:
532
+ vision_tower = getattr(self, vision_tower_name)
533
+ vision_tower_state_dict = OrderedDict(
534
+ {k.split(f"vision_tower.{vision_tower_name}.vision_tower.")[-1]: v for k, v in state_dict.items() if
535
+ vision_tower_name in k}
536
+ )
537
+ vision_tower.vision_tower.save_pretrained(os.path.join(save_directory, vision_tower_name),
538
+ state_dict=vision_tower_state_dict, **kwargs)
539
+ vision_tower.vision_processor.save_pretrained(os.path.join(save_directory, vision_tower_name))
540
+
541
+ config = self.config
542
+ config.configs = {}
543
+ config.save_pretrained(save_directory)
544
+
545
+ @property
546
+ def patch_size(self):
547
+ return self._patch_size
548
+
549
+ @property
550
+ def image_size(self):
551
+ return self._image_size
552
+
553
+ @property
554
+ def hidden_size(self):
555
+ return self._hidden_size
556
+
vision_tower/config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "configs": {},
3
+ "model_type": "hybrid_vision_tower",
4
+ "num_vision_encoders": 2,
5
+ "token_output_shape": [
6
+ 4,
7
+ 27,
8
+ 27
9
+ ],
10
+ "transformers_version": "4.44.0",
11
+ "vision_towers": [
12
+ "siglip-so400m-patch14-384",
13
+ "internvideo2"
14
+ ],
15
+ "auto_map": {
16
+ "AutoConfig": "configuration_hybrid.HybridTowerConfig"
17
+ }
18
+ }
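Because the config above registers `configuration_hybrid.HybridTowerConfig` in `auto_map`, the vision-tower folder can be loaded through the standard `AutoConfig` entry point when remote code is trusted. A minimal sketch; the local path is an illustrative stand-in for wherever this folder was downloaded:

```python
from transformers import AutoConfig

# Path is a stand-in for the downloaded vision_tower folder.
cfg = AutoConfig.from_pretrained("./Apollo-3B/vision_tower", trust_remote_code=True)

print(cfg.model_type)           # "hybrid_vision_tower"
print(cfg.vision_towers)        # ["siglip-so400m-patch14-384", "internvideo2"]
print(cfg.token_output_shape)   # [4, 27, 27]
```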
vision_tower/configuration_hybrid.py ADDED
@@ -0,0 +1,48 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
5
+ import os
6
+ import torch.nn.functional as F
7
+ from transformers.modeling_utils import PreTrainedModel
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from transformers import AutoConfig
10
+ from collections import OrderedDict
11
+
12
+
13
+ class HybridTowerConfig(PretrainedConfig):
14
+ model_type = "hybrid_vision_tower"
15
+
16
+ def __init__(self, configs=None, **kwargs):
17
+ """
18
+ Initializes the HybridTowerConfig.
19
+
20
+ Args:
21
+ configs (dict, optional): A dictionary where keys are component names and values are
22
+ instances of configurations that have a `to_dict()` method.
23
+ **kwargs: Additional keyword arguments that are passed to the superclass.
24
+ """
25
+ super().__init__(**kwargs)
26
+ self.configs = {}
27
+
28
+ if configs is not None:
29
+ if not isinstance(configs, dict):
30
+ raise TypeError("configs must be a dictionary where keys are component names and values are configuration objects.")
31
+
32
+ for component_name, config in configs.items():
33
+ if hasattr(config, 'to_dict'):
34
+ self.configs[component_name] = config.to_dict()
35
+ else:
36
+ raise TypeError(f"The configuration for '{component_name}' does not have a to_dict() method and cannot be serialized.")
37
+
38
+ def to_dict(self):
39
+ """
40
+ Serializes this instance to a Python dictionary.
41
+
42
+ Returns:
43
+ dict: A dictionary containing all the keys and values of this configuration instance.
44
+ """
45
+ config_dict = super().to_dict()
46
+ config_dict['configs'] = self.configs
47
+ return config_dict
48
+
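`HybridTowerConfig` simply serializes a dictionary of per-encoder configurations via their `to_dict()` methods. A small sketch of building one from a component config; the SigLIP config here is just an example component, and any config object exposing `to_dict()` works:

```python
from transformers import SiglipVisionConfig

# Assumes HybridTowerConfig from this file is importable.
siglip_cfg = SiglipVisionConfig(image_size=384, patch_size=14)
hybrid = HybridTowerConfig(configs={"siglip-so400m-patch14-384": siglip_cfg})

d = hybrid.to_dict()
print(d["model_type"])              # "hybrid_vision_tower"
print(sorted(d["configs"].keys()))  # ["siglip-so400m-patch14-384"]
```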
vision_tower/internvideo2/config.json ADDED
@@ -0,0 +1,54 @@
1
+ {
2
+ "architectures": [
3
+ "PretrainInternVideo2"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "attn_pool_num_heads": 16,
7
+ "checkpoint_num": 40,
8
+ "clip_embed_dim": 768,
9
+ "clip_input_resolution": 224,
10
+ "clip_norm_type": "l2",
11
+ "clip_return_layer": 6,
12
+ "clip_student_return_interval": 1,
13
+ "clip_teacher": null,
14
+ "clip_teacher_embed_dim": 3200,
15
+ "clip_teacher_final_dim": 768,
16
+ "clip_teacher_return_interval": 1,
17
+ "d_model": 1408,
18
+ "encoder_stride": 16,
19
+ "hidden_act": "gelu",
20
+ "hidden_dropout_prob": 0.0,
21
+ "hidden_size": 768,
22
+ "image_mask_ratio": 0.5,
23
+ "image_mask_type": "random",
24
+ "image_size": 224,
25
+ "img_size": 224,
26
+ "initializer_range": 0.02,
27
+ "intermediate_size": 3072,
28
+ "keep_temporal": false,
29
+ "layer_norm_eps": 1e-12,
30
+ "model_type": "internvideo2",
31
+ "name": "pretrain_internvideo2_1b_patch14_224",
32
+ "num_attention_heads": 12,
33
+ "num_channels": 3,
34
+ "num_frames": 4,
35
+ "num_heads": 16,
36
+ "num_hidden_layers": 12,
37
+ "only_mask": true,
38
+ "patch_size": 14,
39
+ "qkv_bias": false,
40
+ "sep_image_video_pos_embed": true,
41
+ "torch_dtype": "bfloat16",
42
+ "transformers_version": "4.44.0",
43
+ "tubelet_size": 1,
44
+ "use_checkpoint": true,
45
+ "use_flash_attn": false,
46
+ "use_fused_mlp": false,
47
+ "use_fused_rmsnorm": false,
48
+ "video_mask_ratio": 0.8,
49
+ "video_mask_type": "random",
50
+ "auto_map": {
51
+ "AutoConfig": "configuration_internvideo2.InternVideo2Config",
52
+ "AutoModel": "modeling_internvideo2.InternVideo2Model"
53
+ }
54
+ }
vision_tower/internvideo2/configuration_internvideo2.py ADDED
@@ -0,0 +1,91 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class InternVideo2Config(PretrainedConfig):
5
+ model_type = "internvideo2"
6
+
7
+ def __init__(
8
+ self,
9
+ img_size=224,
10
+ patch_size=14,
11
+ tubelet_size=1,
12
+ num_frames=8,
13
+ d_model=1408,
14
+ num_heads=16,
15
+ depth=40,
16
+ mlp_ratio=48 / 11,
17
+ qkv_bias=False,
18
+ init_values=1e-5,
19
+ use_checkpoint=False,
20
+ checkpoint_num=0,
21
+ use_flash_attn=False,
22
+ use_fused_mlp=False,
23
+ use_fused_rmsnorm=False,
24
+ qk_normalization=True,
25
+ clip_embed_dim=1408,
26
+ attn_pool_num_heads=16,
27
+ clip_teacher_embed_dim=512,
28
+ clip_teacher_final_dim=512,
29
+ clip_student_return_interval=4,
30
+ clip_return_layer=3,
31
+ clip_norm_type="l2",
32
+ sep_image_video_pos_embed=False,
33
+ **kwargs,
34
+ ):
35
+ """
36
+ This is the configuration class to store the configuration of a `InternVideo2Model`.
37
+ It is used to instantiate a InternVideo2 model according to the specified arguments,
38
+ defining the model architecture.
39
+
40
+ Args:
41
+ img_size (int, optional): Input image size. Defaults to 224.
42
+ patch_size (int, optional): Size of each patch. Defaults to 14.
43
+ tubelet_size (int, optional): Temporal tubelet size. Defaults to 1.
44
+ num_frames (int, optional): Number of frames in the video input. Defaults to 8.
45
+ d_model (int, optional): Dimension of the model embeddings. Defaults to 1408.
46
+ num_heads (int, optional): Number of attention heads. Defaults to 16.
47
+ depth (int, optional): Number of transformer encoder layers. Defaults to 40.
48
+ mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim. Defaults to 48/11.
49
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to False.
50
+ init_values (float, optional): Initial values for layer scale. Defaults to 1e-5.
51
+ use_checkpoint (bool, optional): Whether to use gradient checkpointing. Defaults to False.
52
+ checkpoint_num (int, optional): Number of layers to apply checkpointing. Defaults to 0.
53
+ use_flash_attn (bool, optional): Whether to use FlashAttention. Defaults to False.
54
+ use_fused_mlp (bool, optional): Whether to use fused MLP. Defaults to False.
55
+ use_fused_rmsnorm (bool, optional): Whether to use fused RMSNorm. Defaults to False.
56
+ qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to True.
57
+ clip_embed_dim (int, optional): Embedding dimension for CLIP. Defaults to 1408.
58
+ attn_pool_num_heads (int, optional): Number of heads for attention pooling. Defaults to 16.
59
+ clip_teacher_embed_dim (int, optional): Embedding dimension for CLIP teacher model. Defaults to 512.
60
+ clip_teacher_final_dim (int, optional): Final embedding dimension for CLIP teacher model. Defaults to 512.
61
+ clip_student_return_interval (int, optional): Interval for returning student layers. Defaults to 4.
62
+ clip_return_layer (int, optional): Number of layers to return for alignment. Defaults to 3.
63
+ clip_norm_type (str, optional): Normalization type for CLIP ('l2' or 'none'). Defaults to 'l2'.
64
+ sep_image_video_pos_embed (bool, optional): Whether to use separate position embeddings for image and video. Defaults to False.
65
+ **kwargs: Additional keyword arguments.
66
+ """
67
+ super().__init__(**kwargs)
68
+ self.img_size = img_size
69
+ self.patch_size = patch_size
70
+ self.tubelet_size = tubelet_size
71
+ self.num_frames = num_frames
72
+ self.d_model = d_model
73
+ self.num_heads = num_heads
74
+ self.depth = depth
75
+ self.mlp_ratio = mlp_ratio
76
+ self.qkv_bias = qkv_bias
77
+ self.init_values = init_values
78
+ self.use_checkpoint = use_checkpoint
79
+ self.checkpoint_num = checkpoint_num
80
+ self.use_flash_attn = use_flash_attn
81
+ self.use_fused_mlp = use_fused_mlp
82
+ self.use_fused_rmsnorm = use_fused_rmsnorm
83
+ self.qk_normalization = qk_normalization
84
+ self.clip_embed_dim = clip_embed_dim
85
+ self.attn_pool_num_heads = attn_pool_num_heads
86
+ self.clip_teacher_embed_dim = clip_teacher_embed_dim
87
+ self.clip_teacher_final_dim = clip_teacher_final_dim
88
+ self.clip_student_return_interval = clip_student_return_interval
89
+ self.clip_return_layer = clip_return_layer
90
+ self.clip_norm_type = clip_norm_type
91
+ self.sep_image_video_pos_embed = sep_image_video_pos_embed
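Since `InternVideo2Config` is a regular `PretrainedConfig`, the encoder hyper-parameters documented above can be overridden at construction time. A brief sketch, with values chosen to mirror the shipped `vision_tower/internvideo2/config.json` (4 input frames, tubelet size 1):

```python
# Assumes InternVideo2Config from this file is importable.
cfg = InternVideo2Config(
    img_size=224,
    patch_size=14,
    num_frames=4,      # matches the shipped config.json
    tubelet_size=1,
    d_model=1408,
    depth=40,
)

# Token grid implied by these settings: (num_frames / tubelet) * (img / patch)^2 patches.
tokens = (cfg.num_frames // cfg.tubelet_size) * (cfg.img_size // cfg.patch_size) ** 2
print(tokens)  # 4 * 16 * 16 = 1024
```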
vision_tower/internvideo2/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8984deb0dbb7b790922f75ec76dbcb85d534dc0f316d8d0dc6d3cf9f4d5becb0
3
+ size 2098289968
vision_tower/internvideo2/modeling_internvideo2.py ADDED
@@ -0,0 +1,934 @@
1
+ # modeling_internvideo2.py
2
+
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from transformers import PreTrainedModel
11
+ from transformers.utils import logging as hf_logging
12
+
13
+ from torch.utils.checkpoint import checkpoint # Correct
14
+
15
+ from functools import partial
16
+
17
+ from .configuration_internvideo2 import InternVideo2Config # Import the configuration
18
+
19
+ try:
20
+ from einops import rearrange
21
+ except ImportError:
22
+ raise ImportError("Please install einops to use this model.")
23
+
24
+ try:
25
+ from timm.models.layers import DropPath, to_2tuple
26
+ except ImportError:
27
+ raise ImportError("Please install timm to use this model.")
28
+
29
+ logger = hf_logging.get_logger(__name__)
30
+
31
+ # Position embedding functions
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
33
+ assert embed_dim % 4 == 0
34
+ embed_dim_spatial = embed_dim // 4 * 3
35
+ embed_dim_temporal = embed_dim // 4
36
+
37
+ # Spatial
38
+ grid_h = np.arange(grid_size, dtype=np.float32)
39
+ grid_w = np.arange(grid_size, dtype=np.float32)
40
+ grid = np.meshgrid(grid_w, grid_h) # W first
41
+ grid = np.stack(grid, axis=0)
42
+
43
+ grid = grid.reshape([2, 1, grid_size, grid_size])
44
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
45
+
46
+ # Temporal
47
+ grid_t = np.arange(t_size, dtype=np.float32)
48
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
49
+
50
+ # Combine spatial and temporal embeddings
51
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
52
+ pos_embed_temporal = np.repeat(pos_embed_temporal, grid_size**2, axis=1)
53
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
54
+ pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)
55
+
56
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
57
+ pos_embed = pos_embed.reshape([-1, embed_dim])
58
+
59
+ if cls_token:
60
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
61
+ return pos_embed
62
+
63
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
64
+ assert embed_dim % 2 == 0
65
+
66
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
67
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
68
+
69
+ emb = np.concatenate([emb_h, emb_w], axis=1)
70
+ return emb
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ assert embed_dim % 2 == 0
74
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
75
+ omega /= embed_dim / 2.0
76
+ omega = 1.0 / (10000 ** omega)
77
+
78
+ pos = pos.reshape(-1)
79
+ out = np.einsum('m,d->md', pos, omega)
80
+
81
+ emb_sin = np.sin(out)
82
+ emb_cos = np.cos(out)
83
+
84
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
85
+ return emb
86
+
87
+ # Define necessary classes: CrossAttention, AttentiveBlock, AttentionPoolingBlock, RMSNorm, LayerScale, Attention, Mlp, Block, PatchEmbed, Linear_Decoder
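The helpers above build a factorized 3D sine-cosine position embedding: one quarter of the channels encode time, three quarters encode space, and the two parts are broadcast over each other before being flattened to one row per patch (plus an optional zero row for the cls token). A quick shape check, assuming the functions above are in scope:

```python
import numpy as np

# Assumes get_3d_sincos_pos_embed from above is in scope.
embed_dim, grid_size, t_size = 64, 16, 4   # 64 % 4 == 0, as required by the assert

pe = get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=True)
print(pe.shape)   # (1 + t_size * grid_size**2, embed_dim) == (1025, 64)

# 16 temporal channels + 48 spatial channels per position.
assert pe.shape == (1 + t_size * grid_size ** 2, embed_dim)
```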
88
+
89
+
90
+ class CrossAttention(nn.Module):
91
+ def __init__(
92
+ self,
93
+ dim,
94
+ num_heads=8,
95
+ qkv_bias=False,
96
+ qk_scale=None,
97
+ attn_drop=0.0,
98
+ proj_drop=0.0,
99
+ attn_head_dim=None,
100
+ out_dim=None,
101
+ ):
102
+ super().__init__()
103
+ if out_dim is None:
104
+ out_dim = dim
105
+ self.num_heads = num_heads
106
+ head_dim = dim // num_heads
107
+ if attn_head_dim is not None:
108
+ head_dim = attn_head_dim
109
+ all_head_dim = head_dim * self.num_heads
110
+ self.scale = qk_scale or head_dim ** -0.5
111
+ assert all_head_dim == dim
112
+
113
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
114
+ self.k = nn.Linear(dim, all_head_dim, bias=False)
115
+ self.v = nn.Linear(dim, all_head_dim, bias=False)
116
+
117
+ if qkv_bias:
118
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
119
+ self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
120
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
121
+ else:
122
+ self.q_bias = None
123
+ self.k_bias = None
124
+ self.v_bias = None
125
+
126
+ self.attn_drop = nn.Dropout(attn_drop)
127
+ self.proj = nn.Linear(all_head_dim, out_dim)
128
+ self.proj_drop = nn.Dropout(proj_drop)
129
+
130
+ def forward(self, x, k=None, v=None):
131
+ B, N, C = x.shape
132
+ N_k = k.shape[1]
133
+ N_v = v.shape[1]
134
+
135
+ q_bias, k_bias, v_bias = None, None, None
136
+ if self.q_bias is not None:
137
+ q_bias = self.q_bias
138
+ k_bias = self.k_bias
139
+ v_bias = self.v_bias
140
+
141
+ q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
142
+ q = (
143
+ q.reshape(B, N, 1, self.num_heads, -1)
144
+ .permute(2, 0, 3, 1, 4)
145
+ .squeeze(0)
146
+ ) # (B, N_head, N_q, dim)
147
+
148
+ k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
149
+ k = (
150
+ k.reshape(B, N_k, 1, self.num_heads, -1)
151
+ .permute(2, 0, 3, 1, 4)
152
+ .squeeze(0)
153
+ )
154
+
155
+ v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
156
+ v = (
157
+ v.reshape(B, N_v, 1, self.num_heads, -1)
158
+ .permute(2, 0, 3, 1, 4)
159
+ .squeeze(0)
160
+ )
161
+
162
+ q = q * self.scale
163
+ attn = q @ k.transpose(-2, -1) # (B, N_head, N_q, N_k)
164
+
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+
172
+ return x
173
+
174
+
175
+ class AttentiveBlock(nn.Module):
176
+ def __init__(
177
+ self,
178
+ dim,
179
+ num_heads,
180
+ qkv_bias=False,
181
+ qk_scale=None,
182
+ drop=0.0,
183
+ attn_drop=0.0,
184
+ drop_path=0.0,
185
+ norm_layer=nn.LayerNorm,
186
+ attn_head_dim=None,
187
+ out_dim=None,
188
+ ):
189
+ super().__init__()
190
+
191
+ self.norm1_q = norm_layer(dim)
192
+ self.norm1_k = norm_layer(dim)
193
+ self.norm1_v = norm_layer(dim)
194
+ self.cross_attn = CrossAttention(
195
+ dim,
196
+ num_heads=num_heads,
197
+ qkv_bias=qkv_bias,
198
+ qk_scale=qk_scale,
199
+ attn_drop=attn_drop,
200
+ proj_drop=drop,
201
+ attn_head_dim=attn_head_dim,
202
+ out_dim=out_dim,
203
+ )
204
+
205
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
206
+
207
+ def forward(
208
+ self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None
209
+ ):
210
+ x_q = self.norm1_q(x_q + pos_q)
211
+ x_k = self.norm1_k(x_kv + pos_k)
212
+ x_v = self.norm1_v(x_kv)
213
+ x = self.cross_attn(x_q, k=x_k, v=x_v)
214
+
215
+ return x
216
+
217
+
218
+ class AttentionPoolingBlock(AttentiveBlock):
219
+ def forward(self, x):
220
+ x_q = x.mean(1, keepdim=True)
221
+ x_kv, pos_q, pos_k = x, 0, 0
222
+ x = super().forward(
223
+ x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None
224
+ )
225
+ x = x.squeeze(1)
226
+ return x
227
+
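`AttentionPoolingBlock` pools a token sequence into a single vector: the mean token acts as the query and cross-attends to all tokens, with `out_dim` controlling the projected size. A toy forward pass, assuming `timm` is installed and the classes above are in scope:

```python
import torch
import torch.nn as nn
from functools import partial

# Assumes AttentionPoolingBlock (and its dependencies above) are importable.
pool = AttentionPoolingBlock(
    dim=64, num_heads=4, qkv_bias=True, qk_scale=None,
    drop=0.0, attn_drop=0.0,
    norm_layer=partial(nn.LayerNorm, eps=1e-5),
    out_dim=32,
)

x = torch.randn(2, 10, 64)   # (batch, tokens, dim)
pooled = pool(x)
print(pooled.shape)          # torch.Size([2, 32])
```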
228
+
229
+ class RMSNorm(nn.Module):
230
+ def __init__(self, hidden_size, eps=1e-6):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.ones(hidden_size))
233
+ self.variance_epsilon = eps
234
+
235
+ def forward(self, hidden_states):
236
+ input_dtype = hidden_states.dtype
237
+ hidden_states = hidden_states.to(torch.float32)
238
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
239
+ hidden_states = hidden_states * torch.rsqrt(
240
+ variance + self.variance_epsilon
241
+ )
242
+ return self.weight * hidden_states.to(input_dtype)
243
+
244
+
245
+ class LayerScale(nn.Module):
246
+ def __init__(
247
+ self, dim, init_values=1e-5, inplace=False, force_fp32=False
248
+ ):
249
+ super().__init__()
250
+ self.inplace = inplace
251
+ self.weight = nn.Parameter(init_values * torch.ones(dim))
252
+ self.force_fp32 = force_fp32
253
+
254
+ @torch.cuda.amp.autocast(enabled=False)
255
+ def forward(self, x):
256
+ if self.force_fp32:
257
+ output_type = x.dtype
258
+ out = (
259
+ x.float().mul_(self.weight.float())
260
+ if self.inplace
261
+ else x.float() * self.weight.float()
262
+ )
263
+ return out.to(dtype=output_type)
264
+ else:
265
+ out = x.mul_(self.weight) if self.inplace else x * self.weight
266
+ return out
267
+
268
+
269
+ class Attention(nn.Module):
270
+ def __init__(
271
+ self,
272
+ dim,
273
+ num_heads=8,
274
+ qkv_bias=False,
275
+ attn_drop=0.0,
276
+ proj_drop=0.0,
277
+ use_flash_attn=False,
278
+ causal=False,
279
+ norm_layer=nn.LayerNorm,
280
+ qk_normalization=False,
281
+ use_fused_rmsnorm=False,
282
+ ):
283
+ super().__init__()
284
+ assert (
285
+ dim % num_heads == 0
286
+ ), "dim should be divisible by num_heads"
287
+ self.num_heads = num_heads
288
+ head_dim = dim // num_heads
289
+ self.scale = head_dim ** -0.5
290
+
291
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
292
+ self.attn_drop = nn.Dropout(attn_drop)
293
+ self.proj = nn.Linear(dim, dim)
294
+ self.proj_drop = nn.Dropout(proj_drop)
295
+
296
+ self.use_flash_attn = use_flash_attn
297
+ if use_flash_attn:
298
+ self.causal = causal
299
+ try:
300
+ from flash_attn.flash_attention import FlashAttention
301
+
302
+ self.inner_attn = FlashAttention(
303
+ attention_dropout=attn_drop
304
+ )
305
+ except ImportError:
306
+ raise ImportError(
307
+ "Please install flash_attn to use flash attention."
308
+ )
309
+
310
+ self.qk_normalization = qk_normalization
311
+ self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
312
+ self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
313
+ self.use_fused_rmsnorm = use_fused_rmsnorm
314
+
315
+ def _naive_attn(self, x):
316
+ B, N, C = x.shape
317
+ # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
318
+ qkv = (
319
+ self.qkv(x)
320
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
321
+ .permute(2, 0, 3, 1, 4)
322
+ )
323
+ q, k, v = qkv.unbind(
324
+ 0
325
+ ) # make torchscript happy (cannot use tensor as tuple)
326
+
327
+ if self.qk_normalization:
328
+ B_, H_, N_, D_ = q.shape
329
+ q = (
330
+ self.q_norm(q.transpose(1, 2).flatten(-2, -1))
331
+ .view(B_, N_, H_, D_)
332
+ .transpose(1, 2)
333
+ )
334
+ k = (
335
+ self.k_norm(k.transpose(1, 2).flatten(-2, -1))
336
+ .view(B_, N_, H_, D_)
337
+ .transpose(1, 2)
338
+ )
339
+
340
+ attn = (q * self.scale) @ k.transpose(-2, -1)
341
+ # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
342
+ attn = attn.softmax(dim=-1)
343
+ attn = self.attn_drop(attn)
344
+ # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated())
345
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
346
+ x = self.proj(x)
347
+ x = self.proj_drop(x)
348
+ return x
349
+
350
+ def _flash_attn(
351
+ self, x, key_padding_mask=None, need_weights=False
352
+ ):
353
+ qkv = self.qkv(x)
354
+ qkv = rearrange(
355
+ qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
356
+ )
357
+
358
+ if self.qk_normalization:
359
+ q, k, v = qkv.unbind(2)
360
+ if self.use_fused_rmsnorm:
361
+ q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
362
+ k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
363
+ else:
364
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
365
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
366
+ qkv = torch.stack([q, k, v], dim=2)
367
+
368
+ context, _ = self.inner_attn(
369
+ qkv,
370
+ key_padding_mask=key_padding_mask,
371
+ need_weights=need_weights,
372
+ causal=self.causal,
373
+ )
374
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
375
+ outs = self.proj_drop(outs)
376
+ return outs
377
+
378
+ def forward(self, x):
379
+ x = (
380
+ self._naive_attn(x)
381
+ if not self.use_flash_attn
382
+ else self._flash_attn(x)
383
+ )
384
+ return x
385
+
386
+
387
+ class Mlp(nn.Module):
388
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
389
+
390
+ def __init__(
391
+ self,
392
+ in_features,
393
+ hidden_features=None,
394
+ out_features=None,
395
+ act_layer=nn.GELU,
396
+ bias=True,
397
+ drop=0.0,
398
+ ):
399
+ super().__init__()
400
+ out_features = out_features or in_features
401
+ hidden_features = hidden_features or in_features
402
+ bias = to_2tuple(bias)
403
+ drop_probs = to_2tuple(drop)
404
+
405
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
406
+ self.act = act_layer()
407
+ self.drop1 = nn.Dropout(drop_probs[0])
408
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
409
+ self.drop2 = nn.Dropout(drop_probs[1])
410
+
411
+ def forward(self, x):
412
+ x = self.fc1(x)
413
+ x = self.act(x)
414
+ x = self.drop1(x)
415
+ x = self.fc2(x)
416
+ x = self.drop2(x)
417
+ return x
418
+
419
+
420
+ class Block(nn.Module):
421
+ def __init__(
422
+ self,
423
+ dim,
424
+ num_heads,
425
+ mlp_ratio=4.0,
426
+ qkv_bias=False,
427
+ drop=0.0,
428
+ attn_drop=0.0,
429
+ init_values=None,
430
+ drop_path=0.0,
431
+ act_layer=nn.GELU,
432
+ norm_layer=nn.LayerNorm,
433
+ use_flash_attn=False,
434
+ use_fused_mlp=False,
435
+ fused_mlp_heuristic=1,
436
+ with_cp=False,
437
+ qk_normalization=False,
438
+ layerscale_no_force_fp32=False,
439
+ use_fused_rmsnorm=False,
440
+ ):
441
+ super().__init__()
442
+
443
+ self.norm1 = norm_layer(dim)
444
+ self.attn = Attention(
445
+ dim,
446
+ num_heads=num_heads,
447
+ qkv_bias=qkv_bias,
448
+ attn_drop=attn_drop,
449
+ proj_drop=drop,
450
+ use_flash_attn=use_flash_attn,
451
+ causal=False,
452
+ norm_layer=norm_layer,
453
+ qk_normalization=qk_normalization,
454
+ use_fused_rmsnorm=use_fused_rmsnorm,
455
+ )
456
+ self.ls1 = (
457
+ LayerScale(
458
+ dim,
459
+ init_values=init_values,
460
+ force_fp32=(not layerscale_no_force_fp32),
461
+ )
462
+ if init_values
463
+ else nn.Identity()
464
+ )
465
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
466
+ self.drop_path1 = (
467
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
468
+ )
469
+
470
+ self.norm2 = norm_layer(dim)
471
+ mlp_hidden_dim = int(dim * mlp_ratio)
472
+ if use_fused_mlp:
473
+ try:
474
+ from flash_attn.modules.mlp import FusedMLP
475
+ except ImportError:
476
+ raise ImportError(
477
+ "Please install flash_attn to use fused MLP."
478
+ )
479
+ self.mlp = FusedMLP(
480
+ in_features=dim,
481
+ hidden_features=mlp_hidden_dim,
482
+ heuristic=fused_mlp_heuristic,
483
+ )
484
+ else:
485
+ self.mlp = Mlp(
486
+ in_features=dim,
487
+ hidden_features=mlp_hidden_dim,
488
+ act_layer=act_layer,
489
+ drop=drop,
490
+ )
491
+ self.ls2 = (
492
+ LayerScale(
493
+ dim,
494
+ init_values=init_values,
495
+ force_fp32=(not layerscale_no_force_fp32),
496
+ )
497
+ if init_values
498
+ else nn.Identity()
499
+ )
500
+ self.drop_path2 = (
501
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
502
+ )
503
+
504
+ self.with_cp = with_cp
505
+ self.use_fused_rmsnorm = use_fused_rmsnorm
506
+
507
+ def forward(self, x, residual=None):
508
+ def _inner_forward(x, residual=None):
509
+ if self.use_fused_rmsnorm:
510
+ x, residual = self.norm1(x, residual)
511
+ x = self.drop_path1(self.ls1(self.attn(x)))
512
+ x, residual = self.norm2(x, residual)
513
+ x = self.drop_path2(self.ls2(self.mlp(x)))
514
+ return x, residual
515
+ else:
516
+ assert residual is None
517
+ x = x + self.drop_path1(
518
+ self.ls1(self.attn(self.norm1(x)))
519
+ )
520
+ x = x + self.drop_path2(
521
+ self.ls2(self.mlp(self.norm2(x)))
522
+ )
523
+ return x
524
+
525
+ if self.with_cp:
526
+ return checkpoint(_inner_forward, x, residual)
527
+ else:
528
+ return _inner_forward(x, residual=residual)
529
+
530
+
531
+ class PatchEmbed(nn.Module):
532
+ """3D Image to Patch Embedding"""
533
+
534
+ def __init__(
535
+ self,
536
+ img_size=224,
537
+ patch_size=16,
538
+ in_chans=3,
539
+ embed_dim=768,
540
+ num_frames=8,
541
+ tubelet_size=1,
542
+ norm_layer=None,
543
+ ):
544
+ super().__init__()
545
+ img_size = to_2tuple(img_size)
546
+ patch_size = to_2tuple(patch_size)
547
+ self.img_size = img_size
548
+ self.patch_size = patch_size
549
+ self.grid_size = (
550
+ num_frames // tubelet_size,
551
+ img_size[0] // patch_size[0],
552
+ img_size[1] // patch_size[1],
553
+ ) # (T, H, W)
554
+ self.num_patches = (
555
+ self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
556
+ )
557
+ self.num_img_patches = self.grid_size[1] * self.grid_size[2]
558
+
559
+ self.proj = nn.Conv3d(
560
+ in_channels=in_chans,
561
+ out_channels=embed_dim,
562
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
563
+ stride=(tubelet_size, patch_size[0], patch_size[1]),
564
+ )
565
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
566
+
567
+ def forward(self, x):
568
+ x = self.proj(x)
569
+ x = (
570
+ x.flatten(3)
571
+ .permute(0, 2, 3, 1)
572
+ ) # B x C x T x HW => B x T x HW x C
573
+ x = self.norm(x)
574
+ return x
575
+
576
+
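`PatchEmbed` tokenizes a clip with a 3D convolution whose kernel and stride are `(tubelet_size, patch_size, patch_size)`, returning features laid out as `(B, T, H*W, C)`. A quick shape check under small illustrative sizes, assuming the class above is in scope:

```python
import torch

# Assumes PatchEmbed from above is importable.
patch = PatchEmbed(img_size=224, patch_size=14, in_chans=3,
                   embed_dim=32, num_frames=4, tubelet_size=1)

video = torch.randn(1, 3, 4, 224, 224)   # (B, C, T, H, W)
out = patch(video)
print(out.shape)          # torch.Size([1, 4, 256, 32]) -> (B, T, HW patches, C)
print(patch.num_patches)  # 4 * 16 * 16 = 1024
```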
577
+
578
+ class Linear_Decoder(nn.Module):
579
+ def __init__(self, in_channels=1408, out_channels=3200, norm_layer=nn.LayerNorm, clip_norm_type='l2'):
580
+ super().__init__()
581
+ self.clip_norm_type = clip_norm_type
582
+ logger.info(f'Normalization Type: {clip_norm_type}')
583
+
584
+ self.head = nn.Linear(in_channels, out_channels)
585
+ self.norm = norm_layer(out_channels)
586
+
587
+ def forward(self, x):
588
+ x = self.norm(self.head(x))
589
+
590
+ if self.clip_norm_type == 'l2':
591
+ x = x / x.norm(dim=-1, keepdim=True)
592
+ elif self.clip_norm_type == 'none':
593
+ pass
594
+ else:
595
+ raise NotImplementedError
596
+
597
+ return x
598
+
+ class InternVideo2Model(PreTrainedModel):
+     config_class = InternVideo2Config
+     base_model_prefix = "internvideo2"
+
+     def __init__(self, config: InternVideo2Config):
+         super().__init__(config)
+
+         in_chans = 3
+         drop_path_rate = 0.25
+         qk_normalization = config.qk_normalization
+         clip_embed_dim = config.clip_embed_dim
+         num_heads = config.num_heads
+         qkv_bias = config.qkv_bias
+         init_values = config.init_values
+         mlp_ratio = config.mlp_ratio
+         depth = config.depth
+         num_frames = config.num_frames
+         self.num_frames = num_frames
+         self.tubelet_size = config.tubelet_size
+         use_fused_mlp = config.use_fused_mlp
+         use_fused_rmsnorm = config.use_fused_rmsnorm
+         use_flash_attn = config.use_flash_attn
+         assert (
+             use_flash_attn
+             == use_fused_rmsnorm
+             == use_fused_mlp
+         ), "use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent"
+
+         self.use_flash_attn = use_flash_attn
+         embed_dim = config.d_model
+         self.embed_dim = embed_dim
+
+         self.depth = depth
+         self.clip_norm_type = config.clip_norm_type
+         self.return_index = []
+         for i in range(config.clip_return_layer):
+             self.return_index.append(
+                 depth - int(i * config.clip_student_return_interval) - 1
+             )
+         logger.info(f"Normalization Type: {config.clip_norm_type}")
+         logger.info(f"Student Return Index: {self.return_index}")
+
+         if use_fused_rmsnorm:
+             try:
+                 from flash_attn.ops.rms_norm import DropoutAddRMSNorm
+             except ImportError:
+                 raise ImportError(
+                     "Please install flash_attn to use fused RMSNorm."
+                 )
+             norm_layer_for_blocks = partial(
+                 DropoutAddRMSNorm, eps=1e-6, prenorm=True
+             )
+         else:
+             norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
+         self.norm_layer_for_blocks = norm_layer_for_blocks
+         self.patch_embed = PatchEmbed(
+             config.img_size,
+             config.patch_size,
+             in_chans,
+             embed_dim,
+             num_frames=num_frames,
+             tubelet_size=self.tubelet_size,
+         )
+         num_patches = self.patch_embed.num_patches
+         num_img_patches = self.patch_embed.num_img_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+         self.sep_pos_embed = False
+         self.sep_image_video_pos_embed = config.sep_image_video_pos_embed
+         if self.sep_pos_embed:
+             raise NotImplementedError
+         else:
+             if self.sep_image_video_pos_embed:
+                 logger.info(
+                     "Use joint position embedding, for image and video we use different pos_embed."
+                 )
+                 self.pos_embed = nn.Parameter(
+                     torch.zeros(1, num_patches + 1, embed_dim)
+                 )
+                 self.img_pos_embed = nn.Parameter(
+                     torch.zeros(1, num_img_patches + 1, embed_dim)
+                 )
+                 # for CLIP decoder
+                 self.clip_pos_embed = nn.Parameter(
+                     torch.zeros(1, num_patches + 1, embed_dim)
+                 )
+                 self.clip_img_pos_embed = nn.Parameter(
+                     torch.zeros(1, num_img_patches + 1, embed_dim)
+                 )
+             else:
+                 logger.info(
+                     "Use joint position embedding, for image and video we use same pos_embed."
+                 )
+                 self.pos_embed = nn.Parameter(
+                     torch.zeros(1, num_patches + 1, embed_dim)
+                 )
+                 self.clip_pos_embed = nn.Parameter(
+                     torch.zeros(1, num_patches + 1, embed_dim)
+                 )
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+         # choose which layer to use checkpoint
+         with_cp_list = [False] * depth
+         if config.use_checkpoint:
+             for idx in range(depth):
+                 if idx < config.checkpoint_num:
+                     with_cp_list[idx] = True
+         logger.info(f"Droppath rate: {dpr}")
+         logger.info(f"Checkpoint list: {with_cp_list}")
+
+         self.blocks = nn.ModuleList(
+             [
+                 Block(
+                     embed_dim,
+                     num_heads,
+                     mlp_ratio,
+                     qkv_bias=qkv_bias,
+                     norm_layer=norm_layer_for_blocks,
+                     drop_path=dpr[i],
+                     init_values=init_values,
+                     attn_drop=0.0,
+                     use_flash_attn=use_flash_attn,
+                     use_fused_mlp=use_fused_mlp,
+                     fused_mlp_heuristic=1,
+                     with_cp=with_cp_list[i],
+                     qk_normalization=qk_normalization,
+                     layerscale_no_force_fp32=False,
+                     use_fused_rmsnorm=use_fused_rmsnorm,
+                 )
+                 for i in range(depth)
+             ]
+         )
+         self.clip_projector = AttentionPoolingBlock(
+             dim=embed_dim,
+             num_heads=config.attn_pool_num_heads,
+             qkv_bias=True,
+             qk_scale=None,
+             drop=0.0,
+             attn_drop=0.0,
+             norm_layer=partial(nn.LayerNorm, eps=1e-5),
+             out_dim=clip_embed_dim,
+         )
+
+         # CLIP decoder
+         self.clip_decoder = nn.ModuleList(
+             [
+                 Linear_Decoder(
+                     in_channels=embed_dim,
+                     out_channels=config.clip_teacher_embed_dim,
+                     norm_layer=partial(nn.LayerNorm, eps=1e-5),
+                     clip_norm_type=config.clip_norm_type,
+                 )
+                 for _ in range(config.clip_return_layer)
+             ]
+         )
+         self.final_clip_decoder = nn.Identity()
+         if config.clip_teacher_final_dim > 0:
+             self.final_clip_decoder = Linear_Decoder(
+                 in_channels=clip_embed_dim,
+                 out_channels=config.clip_teacher_final_dim,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-5),
+                 clip_norm_type=config.clip_norm_type,
+             )
+
+     # Removed initialization methods and code
+
+     @property
+     def dtype(self):
+         return self.patch_embed.proj.weight.dtype
+
+     def get_num_layers(self):
+         return len(self.blocks)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {
+             "pos_embed",
+             "pos_embed_spatial",
+             "pos_embed_temporal",
+             "pos_embed_cls",
+             "img_pos_embed",
+             "cls_token",
+             "clip_pos_embed",
+             "clip_pos_embed_spatial",
+             "clip_pos_embed_temporal",
+             "clip_pos_embed_cls",
+             "clip_img_pos_embed",
+         }
+
+     def forward(
+         self,
+         x,
+         mask=None,
+         use_image=False,
+         x_vis_return_idx=-1,
+         x_vis_only=False,
+     ):
+         x = self.patch_embed(x.type(self.dtype))
+         B, T, L, C = x.shape
+         x = x.view([B, T * L, C])
+
+         # Append cls token
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # Add positional embeddings
+         if self.sep_pos_embed:
+             raise NotImplementedError
+         else:
+             if use_image:
+                 if self.sep_image_video_pos_embed:
+                     pos_embed = self.img_pos_embed
+                 else:
+                     cls_pos_embed = self.pos_embed[:, 0:1, :]
+                     img_pos_embed = (
+                         self.pos_embed[:, 1:, :]
+                         .view(
+                             1,
+                             self.num_frames,
+                             self.patch_embed.num_patches // self.num_frames,
+                             self.embed_dim,
+                         )
+                         .mean(dim=1)
+                     )
+                     pos_embed = torch.cat(
+                         [cls_pos_embed, img_pos_embed], dim=1
+                     )
+             else:
+                 pos_embed = self.pos_embed
+         x = x + pos_embed
+
+         # Mask tokens
+         if mask is not None:
+             x = x[~mask].reshape(B, -1, C)
+         else:
+             x = x.reshape(B, -1, C)
+
+         residual = None
+         x_clip = []
+         for idx, blk in enumerate(self.blocks):
+             if isinstance(x, tuple) and len(x) == 2:
+                 x, residual = x
+             x = blk(x, residual=residual)
+             # Return intermediate features
+             if idx in self.return_index:
+                 if isinstance(x, tuple) and len(x) == 2:
+                     tmp_x, tmp_residual = x
+                     if residual is not None:
+                         x_clip.append(tmp_x + tmp_residual)
+                 else:
+                     x_clip.append(x)
+             if idx == (self.depth + x_vis_return_idx):
+                 break
+
+         if isinstance(x, tuple) and len(x) == 2:
+             x, residual = x
+             if residual is not None:
+                 x = x + residual
+
+         x_vis = x
+         if x_vis_only:
+             return x_vis
+
+         x_pool_vis = self.clip_projector(x_vis)
+         x_align = self.final_clip_decoder(x_pool_vis)
+
+         # Align CLIP
+         x_clip = torch.stack(x_clip)
+         K, B, _, C_CLIP = x_clip.shape
+         # Add positional embeddings
+         if self.sep_pos_embed:
+             raise NotImplementedError
+         else:
+             if use_image:
+                 if self.sep_image_video_pos_embed:
+                     clip_pos_embed = self.clip_img_pos_embed
+                 else:
+                     clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
+                     clip_img_pos_embed = (
+                         self.clip_pos_embed[:, 1:, :]
+                         .view(
+                             1,
+                             self.num_frames,
+                             self.patch_embed.num_patches // self.num_frames,
+                             self.embed_dim,
+                         )
+                         .mean(dim=1)
+                     )
+                     clip_pos_embed = torch.cat(
+                         [clip_cls_pos_embed, clip_img_pos_embed], dim=1
+                     )
+
+             else:
+                 clip_pos_embed = self.clip_pos_embed
+
+         clip_pos_embed = clip_pos_embed.repeat(B, 1, 1)
+         if mask is not None:
+             x_clip = x_clip + clip_pos_embed[~mask].view(
+                 B, -1, C_CLIP
+             ).unsqueeze(0).repeat(K, 1, 1, 1)
+         else:
+             x_clip = x_clip + clip_pos_embed.view(B, -1, C_CLIP).unsqueeze(
+                 0
+             ).repeat(K, 1, 1, 1)
+
+         # CLIP decoder
+         x_clip_align = []
+         for idx, clip_decoder in enumerate(self.clip_decoder):
+             x_clip_align.append(clip_decoder(x_clip[idx]))
+         x_clip_align = torch.stack(x_clip_align)
+
+         return x_vis, x_pool_vis, x_clip_align, x_align
+
+
+     def load_pretrained_weights(self):
+         if self.config.pretrained is not None:
+             logger.info(f"Loading pretrained weights from {self.config.pretrained}")
+             state_dict = torch.load(self.config.pretrained, map_location='cpu')
+
+             # Key-remapping hook for LayerScale parameters. As written, '.ls1.weight' and
+             # '.ls2.weight' are mapped onto themselves, so this loop currently copies the
+             # state_dict through unchanged.
+             new_state_dict = {}
+             for key, value in state_dict.items():
+                 if key.endswith('.ls1.weight'):
+                     new_key = key.replace('.ls1.weight', '.ls1.weight')
+                     new_state_dict[new_key] = value
+                 elif key.endswith('.ls2.weight'):
+                     new_key = key.replace('.ls2.weight', '.ls2.weight')
+                     new_state_dict[new_key] = value
+                 else:
+                     new_state_dict[key] = value
+
+             # Load the adjusted state_dict
+             message = self.load_state_dict(new_state_dict, strict=False)
+             logger.info(message)
+         else:
+             logger.info("No pretrained weights provided.")
vision_tower/internvideo2/preprocessor_config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "Resize": {
+     "size": 224,
+     "interpolation": "bilinear"
+   },
+   "CenterCrop": {
+     "size": [
+       224,
+       224
+     ]
+   },
+   "ClipToTensor": {
+     "channel_nb": 3,
+     "div_255": true,
+     "numpy": false
+   },
+   "Normalize": {
+     "mean": [
+       0.485,
+       0.456,
+       0.406
+     ],
+     "std": [
+       0.229,
+       0.224,
+       0.225
+     ]
+   },
+   "image_processor_type": "transforms"
+ }
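This file is not a standard `transformers` image-processor config; it records the video transform chain by name (Resize → CenterCrop → ClipToTensor → Normalize). A rough torchvision equivalent is sketched below; approximating `ClipToTensor` with per-frame `ToTensor` followed by stacking is an assumption, not the exact implementation Apollo ships.

```python
import torch
from torchvision import transforms
from PIL import Image

# Mirrors the preprocessor_config.json above: Resize(224, bilinear) -> CenterCrop(224)
# -> scale to [0, 1] (the "div_255": true step) -> Normalize with ImageNet mean/std.
frame_transform = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_clip(frames):
    """frames: list of PIL images -> tensor of shape (3, T, 224, 224)."""
    return torch.stack([frame_transform(f) for f in frames], dim=1)

dummy_clip = [Image.new("RGB", (640, 360)) for _ in range(8)]
print(preprocess_clip(dummy_clip).shape)  # torch.Size([3, 8, 224, 224])
```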
vision_tower/siglip-so400m-patch14-384/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "_name_or_path": "/opt/hpcaas/.mounts/fs-0663e2d3c38211883/home/orrzohar/Artemis/work_dirs/final_run/apollo-Qwen2.5-1.5B-Instruct-internvideo2-siglip-so400m-patch14-384-freeze-perciver_128_2-newprompt-ft/checkpoint-11850/vision_tower/siglip-so400m-patch14-384",
+   "architectures": [
+     "SiglipVisionModel"
+   ],
+   "attention_dropout": 0.0,
+   "hidden_act": "gelu_pytorch_tanh",
+   "hidden_size": 1152,
+   "image_size": 384,
+   "intermediate_size": 4304,
+   "layer_norm_eps": 1e-06,
+   "model_type": "siglip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 27,
+   "patch_size": 14,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.44.0"
+ }
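Since this is a standard `SiglipVisionModel` config (the `_name_or_path` above is just the training-time checkpoint path), the sub-folder can in principle be loaded on its own with `transformers`; the local path in the sketch below is illustrative.

```python
from transformers import SiglipVisionConfig, SiglipVisionModel

# Illustrative path: point this at the vision-tower sub-folder of the downloaded snapshot.
tower_path = "vision_tower/siglip-so400m-patch14-384"

config = SiglipVisionConfig.from_pretrained(tower_path)
vision_tower = SiglipVisionModel.from_pretrained(tower_path, torch_dtype="auto")

print(config.hidden_size, config.num_hidden_layers, config.patch_size)  # 1152 27 14
```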
vision_tower/siglip-so400m-patch14-384/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2e9727665c906187682b29b7d6271003e2724b79983eb26249955d69719c735
+ size 856506120
vision_tower/siglip-so400m-patch14-384/preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "do_convert_rgb": null,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "image_processor_type": "SiglipImageProcessor",
+   "image_std": [
+     0.5,
+     0.5,
+     0.5
+   ],
+   "processor_class": "SiglipProcessor",
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "height": 384,
+     "width": 384
+   }
+ }
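Unlike the InternVideo2 transform list, this is a regular `SiglipImageProcessor` config, so it can be loaded directly with `transformers`; the path and the dummy frame below are again illustrative.

```python
from PIL import Image
from transformers import SiglipImageProcessor

# Illustrative path: the sub-folder containing the preprocessor_config.json above.
processor = SiglipImageProcessor.from_pretrained("vision_tower/siglip-so400m-patch14-384")

frame = Image.new("RGB", (640, 360))  # stand-in for a decoded video frame
inputs = processor(images=frame, return_tensors="pt")

print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 384, 384])
```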