Manli committed on
Commit
8b55b6a
·
1 Parent(s): fc1e86b

initial commit

README.md CHANGED
@@ -1,3 +1,86 @@
1
- ---
2
- license: cc-by-4.0
3
- ---
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language:
4
+ - en
5
+ pipeline_tag: image-text-to-text
6
+ ---
7
+
8
+
9
+ # Model description
10
+ We are excited to announce the continuation and rebranding of our **BLIP series** into **XGen-MM**, to be better aligned with Salesforce's unified XGen initiative for large foundation models! This rebranding marks a significant step in our ongoing development of cutting-edge multimodal technologies.
11
+
12
+ `XGen-MM` is a series of the latest foundational Large Multimodal Models (LMMs) developed by Salesforce AI Research. This series builds on the successful designs of the `BLIP` series, incorporating fundamental enhancements that provide a more robust foundation. These models have been trained at scale on high-quality image caption datasets and interleaved image-text data.
13
+
14
+ In the v1.1 (08/2024) release, we present a series of XGen-MM models including:
15
+ - Base model `xgen-mm-phi3-mini-base-r-v1.1`
16
+ - Single-image instruct model `xgen-mm-phi3-mini-instruct-r-v1.1`
17
+ - Multi-image instruct model `xgen-mm-phi3-mini-instruct-multi-r-v1.1`
18
+ - DPO instruct model `xgen-mm-phi3-mini-instruct-dpo-r-v1.1`
19
+
20
+ In addition to the models, we are also releasing a series of datasets for multi-modal pre-training, including:
21
+ - [MINT-1T: Scaling Open-Source Multimodal Data by 10x: A Multimodal Dataset with One Trillion Tokens](https://arxiv.org/abs/2406.11271)
22
+ - BLIP3-OCR-200M: a dataset with dense OCR annotations.
23
+ - BLIP3-GROUNDING-50M: a dataset for enhancing the ability to ground semantic concepts in images.
24
+ - BLIP3-KALE-300M (stay tuned): a large-scale curated high-quality caption dataset.
25
+
26
+ # Data
27
+
28
+
29
+ # Results
30
+
31
+ ### Base model (without instruction tuning)
32
+
33
+ ### Instruct model
34
+
35
+ ### DPO model
36
+
37
+
38
+ # How to use
39
+
40
+ Please check out our [inference notebook](demo.ipynb) for example code showing how to use the model. We also provide an example script for [batch inference](batch_inference.ipynb).
41
+
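As a rough illustration of what the notebooks do, the loading step might look like the sketch below. It relies on the `auto_map` entries visible in `config.json` and `preprocessor_config.json` (`AutoModelForVision2Seq`, `AutoImageProcessor`) and on `update_special_tokens` from `modeling_xgenmm.py`. The `load_xgenmm` helper name, the `repo_id` argument, and the use of `AutoTokenizer` with default options are assumptions; the notebooks remain the reference.

```python
def load_xgenmm(repo_id: str):
    """Load model, tokenizer and image processor from a Hub repo (hypothetical helper)."""
    # Imported lazily so that defining this helper does not itself require transformers.
    from transformers import AutoImageProcessor, AutoModelForVision2Seq, AutoTokenizer

    # trust_remote_code=True lets transformers import the custom classes shipped in the repo.
    model = AutoModelForVision2Seq.from_pretrained(repo_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    image_processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)

    # Register the model's special tokens with the tokenizer and record their ids.
    tokenizer = model.update_special_tokens(tokenizer)
    return model, tokenizer, image_processor
```

Calling the helper downloads the weights, so it needs network access and the pinned packages listed under the troubleshooting section of the README.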
42
+ # Reproducibility
43
+
44
+ Our evaluation is implemented based on [open-compass/VLMEvalKit](https://github.com/open-compass/VLMEvalKit). We will create a PR to that repo to support XGen-MM evaluation.
45
+
46
+
47
+ # Bias, Risks, Limitations, and Ethical Considerations
48
+ The training data comes mainly from the internet, including webpages,
49
+ image stock sites, and curated datasets released by the research community. We have excluded certain data, such as LAION, due to known CSAM concerns.
50
+ The model may be subject to bias from the original data source, as well as bias from LLMs and commercial APIs.
51
+ We strongly recommend that users assess safety and fairness before applying the model to downstream applications.
52
+
53
+
54
+ # License
55
+
56
+ Our code and weights are released under the Creative Commons Attribution-NonCommercial 4.0 [LICENSE](LICENSE.txt). Please fill out [this form](https://forms.gle/ffPc9oZC2ZGeJ1N68) to inquire about commercial use of the model weights.
57
+
58
+ # Code acknowledgement
59
+ Our training code is based on [OpenFlamingo: An open-source framework for training large multimodal models.](https://github.com/mlfoundations/open_flamingo), and part of our data preprocessing code is adapted from [LLaVA](https://github.com/haotian-liu/LLaVA).
60
+ Our evaluation code is based on [VLMEvalKit: Open-source evaluation toolkit of large vision-language models (LVLMs)](https://github.com/open-compass/VLMEvalKit).
61
+
62
+ We thank the authors for their open-source implementations.
63
+
64
+
65
+ # Citation
66
+ ```
67
+ @misc{xgen_mm_phi3_mini,
68
+ title={xgen-mm-phi3-mini-instruct Model Card},
69
+ url={https://huggingface.co/Salesforce/xgen-mm-phi3-mini-instruct-r-v1},
70
+ author={Salesforce AI Research},
71
+ month={May},
72
+ year={2024}
73
+ }
74
+ ```
75
+
76
+ # Troubleshooting
77
+
78
+ 1. If any packages are missing, install them with the following commands:
79
+
80
+ ```
81
+ pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
82
+ pip install open_clip_torch==2.24.0
83
+ pip install einops
84
+ pip install einops-exts
85
+ pip install transformers==4.41.1
86
+ ```
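To see whether an existing environment already matches the pins above, a small standard-library check along these lines may help. It is a sketch only: the `PINS` table is copied from the pip commands, while `installed_version` and `find_mismatches` are hypothetical helper names. Local version suffixes such as `+cu121` are stripped before comparing, on the assumption that the CUDA wheel reports its version that way.

```python
from importlib import metadata

# Pins copied from the pip commands above; einops and einops-exts are unpinned there.
PINS = {
    "torch": "2.2.1",
    "torchvision": "0.17.1",
    "torchaudio": "2.2.1",
    "open_clip_torch": "2.24.0",
    "transformers": "4.41.1",
}


def installed_version(package: str):
    """Return the installed version without a local suffix like '+cu121', or None."""
    try:
        return metadata.version(package).split("+")[0]
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package whose installed version differs from its pin to (wanted, found)."""
    mismatches = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in find_mismatches(PINS).items():
        print(f"{package}: expected {wanted}, found {found or 'not installed'}")
```

The `lookup` parameter exists only so the comparison can be exercised without touching the real environment.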
added_tokens.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "<pad>": 32011,
3
+ "<|assistant|>": 32001,
4
+ "<|endoftext|>": 32000,
5
+ "<|end|>": 32007,
6
+ "<|placeholder1|>": 32002,
7
+ "<|placeholder2|>": 32003,
8
+ "<|placeholder3|>": 32004,
9
+ "<|placeholder4|>": 32005,
10
+ "<|placeholder5|>": 32008,
11
+ "<|placeholder6|>": 32009,
12
+ "<|system|>": 32006,
13
+ "<|user|>": 32010
14
+ }
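A quick way to read this table: the twelve added tokens occupy one contiguous id range, and the largest id plus one equals the `initial_tokenizer_len` value (32012) that appears in `config.json`. The sketch below checks both observations. Interpreting `initial_tokenizer_len` as "one past the last added id" is an inference from the numbers, not something the commit states.

```python
# Token ids copied from added_tokens.json.
ADDED_TOKENS = {
    "<pad>": 32011,
    "<|assistant|>": 32001,
    "<|endoftext|>": 32000,
    "<|end|>": 32007,
    "<|placeholder1|>": 32002,
    "<|placeholder2|>": 32003,
    "<|placeholder3|>": 32004,
    "<|placeholder4|>": 32005,
    "<|placeholder5|>": 32008,
    "<|placeholder6|>": 32009,
    "<|system|>": 32006,
    "<|user|>": 32010,
}
INITIAL_TOKENIZER_LEN = 32012  # text_config.initial_tokenizer_len in config.json

ids = sorted(ADDED_TOKENS.values())
# True when the ids form an unbroken run with no gaps or repeats.
is_contiguous = ids == list(range(ids[0], ids[-1] + 1))
# True when the run ends exactly at initial_tokenizer_len - 1.
fills_vocab = ids[-1] + 1 == INITIAL_TOKENIZER_LEN
print(is_contiguous, fills_vocab)
```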
batch_inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "architectures": [
3
+ "XGenMMModelForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_xgenmm.XGenMMConfig",
7
+ "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
8
+ },
9
+ "model_type": "xgenmm",
10
+ "text_config": {
11
+ "initial_tokenizer_len": 32012,
12
+ "model_type": "phi3",
13
+ "sliding_window": 2047,
14
+ "torch_dtype": "bfloat16"
15
+ },
16
+ "torch_dtype": "float32",
17
+ "transformers_version": "4.41.1",
18
+ "vision_encoder_config": {
19
+ "anyres_patch_sampling": true,
20
+ "image_aspect_ratio": "anyres",
21
+ "model_type": "xgenmm_vision_encoder"
22
+ },
23
+ "vision_tokenizer_config": {
24
+ "model_type": "xgenmm_vision_tokenizer"
25
+ }
26
+ }
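The `auto_map` block is what ties the generic `Auto*` classes to the Python files added in this commit: each value has the form `module.ClassName`, where the module is a file in the repository. The sketch below only splits those strings to make the mapping explicit; `split_auto_map` is a hypothetical helper and does not reproduce how `transformers` performs the import.

```python
import json

# Excerpt of config.json limited to the auto_map block.
CONFIG_EXCERPT = """
{
  "auto_map": {
    "AutoConfig": "configuration_xgenmm.XGenMMConfig",
    "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
  }
}
"""


def split_auto_map(config: dict) -> dict:
    """Map each Auto class name to the (repo file, class name) its entry points at."""
    resolved = {}
    for auto_class, reference in config.get("auto_map", {}).items():
        module_name, class_name = reference.rsplit(".", 1)
        resolved[auto_class] = (f"{module_name}.py", class_name)
    return resolved


targets = split_auto_map(json.loads(CONFIG_EXCERPT))
for auto_class, (file_name, class_name) in targets.items():
    print(f"{auto_class} -> {class_name} in {file_name}")
```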
configuration_xgenmm.py ADDED
@@ -0,0 +1,159 @@
1
+ from transformers import PretrainedConfig
2
+ from transformers import logging
3
+ from transformers import CONFIG_MAPPING
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+ class XGenMMVisionEncoderConfig(PretrainedConfig):
8
+ model_type = "xgenmm_vision_encoder"
9
+
10
+ def __init__(self,
11
+ model_name: str = 'google/siglip-so400m-patch14-384',
12
+ anyres_grids: list[list[int]] = [[384, 768],[768, 384],[768, 768],[1152, 384],[384,1152]],
13
+ **kwargs):
14
+ self.model_name = model_name
15
+ self.anyres_grids = anyres_grids
16
+ super().__init__(**kwargs)
17
+
18
+
19
+ class XGenMMVisionTokenizerConfig(PretrainedConfig):
20
+ model_type = "xgenmm_vision_tokenizer"
21
+
22
+ def __init__(self,
23
+ vis_feature_dim: int = 1152,
24
+ lang_embedding_dim: int = 3072,
25
+ num_vis_tokens: int = 128,
26
+ image_aspect_ratio: str = 'anyres',
27
+ **kwargs):
28
+ self.vis_feature_dim = vis_feature_dim
29
+ self.lang_embedding_dim = lang_embedding_dim
30
+ self.num_vis_tokens = num_vis_tokens
31
+ self.image_aspect_ratio = image_aspect_ratio
32
+ super().__init__(**kwargs)
33
+
34
+
35
+ class XGenMMConfig(PretrainedConfig):
36
+ model_type = "xgenmm"
37
+
38
+ def __init__(self,
39
+ vision_encoder_config: dict = None,
40
+ vision_tokenizer_config: dict = None,
41
+ text_config: dict = None,
42
+ **kwargs):
43
+
44
+ if vision_encoder_config is None:
45
+ vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
46
+ logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
47
+
48
+ if vision_tokenizer_config is None:
49
+ vision_tokenizer_config = {}
50
+ logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
51
+
52
+ if text_config is None:
53
+ text_config = {
54
+ 'initial_tokenizer_len':32012,
55
+ 'pad_token_id':32011,
56
+ 'bos_token_id':1,
57
+ 'eos_token_id':32000,
58
+ 'vocab_size': 32064,
59
+ 'hidden_size': 3072,
60
+ 'intermediate_size': 8192,
61
+ 'num_hidden_layers': 32,
62
+ 'num_attention_heads': 32,
63
+ 'num_key_value_heads': 32,
64
+ 'resid_pdrop': 0.0,
65
+ 'embd_pdrop': 0.0,
66
+ 'attention_dropout': 0.0,
67
+ 'hidden_act': 'silu',
68
+ 'max_position_embeddings': 4096,
69
+ 'original_max_position_embeddings': 4096,
70
+ 'initializer_range': 0.02,
71
+ 'rms_norm_eps': 1e-05,
72
+ 'use_cache': True,
73
+ 'rope_theta': 10000.0,
74
+ 'rope_scaling': None,
75
+ 'sliding_window': 2047,
76
+ 'return_dict': True,
77
+ 'output_hidden_states': False,
78
+ 'output_attentions': False,
79
+ 'torchscript': False,
80
+ 'torch_dtype': 'bfloat16',
81
+ 'use_bfloat16': False,
82
+ 'tf_legacy_loss': False,
83
+ 'pruned_heads': {},
84
+ 'tie_word_embeddings': False,
85
+ 'chunk_size_feed_forward': 0,
86
+ 'is_encoder_decoder': False,
87
+ 'is_decoder': False,
88
+ 'cross_attention_hidden_size': None,
89
+ 'add_cross_attention': False,
90
+ 'tie_encoder_decoder': False,
91
+ 'max_length': 20,
92
+ 'min_length': 0,
93
+ 'do_sample': False,
94
+ 'early_stopping': False,
95
+ 'num_beams': 1,
96
+ 'num_beam_groups': 1,
97
+ 'diversity_penalty': 0.0,
98
+ 'temperature': 1.0,
99
+ 'top_k': 50,
100
+ 'top_p': 1.0,
101
+ 'typical_p': 1.0,
102
+ 'repetition_penalty': 1.0,
103
+ 'length_penalty': 1.0,
104
+ 'no_repeat_ngram_size': 0,
105
+ 'encoder_no_repeat_ngram_size': 0,
106
+ 'bad_words_ids': None,
107
+ 'num_return_sequences': 1,
108
+ 'output_scores': False,
109
+ 'return_dict_in_generate': False,
110
+ 'forced_bos_token_id': None,
111
+ 'forced_eos_token_id': None,
112
+ 'remove_invalid_values': False,
113
+ 'exponential_decay_length_penalty': None,
114
+ 'suppress_tokens': None,
115
+ 'begin_suppress_tokens': None,
116
+ 'finetuning_task': None,
117
+ 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
118
+ 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
119
+ 'tokenizer_class': None,
120
+ 'prefix': None,
121
+ 'bos_token_id': 1,
122
+ 'pad_token_id': 32000,  # NOTE: repeats the 'pad_token_id' key set to 32011 above; Python keeps this later value
123
+ 'eos_token_id': 32000,
124
+ 'sep_token_id': None,
125
+ 'decoder_start_token_id': None,
126
+ 'task_specific_params': None,
127
+ 'problem_type': None,
128
+ 'model_type': 'phi3'
129
+ }
130
+ logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
131
+
132
+ self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
133
+
134
+ self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
135
+
136
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
137
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
138
+
139
+ for key in ['initial_tokenizer_len', 'pad_token_id']:
140
+ if key not in self.text_config.to_dict():
141
+ raise ValueError(f"The key `{key}` is missing in the text_config.")
142
+
143
+ super().__init__(**kwargs)
144
+
145
+ @classmethod
146
+ def from_vision_encoder_vision_tokenizer_text_configs(
147
+ cls,
148
+ vision_encoder_config: XGenMMVisionEncoderConfig,
149
+ vision_tokenizer_config: XGenMMVisionTokenizerConfig,
150
+ text_config: PretrainedConfig,
151
+ **kwargs):
152
+
153
+ return cls(
154
+ vision_encoder_config=vision_encoder_config.to_dict(),
155
+ vision_tokenizer_config=vision_tokenizer_config.to_dict(),
156
+ text_config=text_config.to_dict(),
157
+ **kwargs,
158
+ )
159
+
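One detail of the default `text_config` is worth checking when reading this file: `pad_token_id`, `bos_token_id` and `eos_token_id` each appear twice in the dict literal. Python accepts repeated keys silently and keeps the last value, so the default `pad_token_id` evaluates to 32000 rather than 32011. The sketch below demonstrates this on a reduced excerpt; `repeated_keys` is a hypothetical helper, and which of the two values the authors intended is not stated in the commit.

```python
import ast
from collections import Counter

# Excerpt of the default text_config, reduced to the keys that repeat.
DICT_SOURCE = """{
    'pad_token_id': 32011,
    'bos_token_id': 1,
    'eos_token_id': 32000,
    'bos_token_id': 1,
    'pad_token_id': 32000,
    'eos_token_id': 32000,
}"""


def repeated_keys(source: str) -> dict:
    """Count string keys that occur more than once inside any dict literal in `source`."""
    repeats = {}
    for node in ast.walk(ast.parse(source, mode="eval")):
        if isinstance(node, ast.Dict):
            names = [
                key.value
                for key in node.keys
                if isinstance(key, ast.Constant) and isinstance(key.value, str)
            ]
            repeats.update({name: n for name, n in Counter(names).items() if n > 1})
    return repeats


# Evaluating the literal shows which value survives: later entries overwrite earlier ones.
effective = ast.literal_eval(DICT_SOURCE)
print(repeated_keys(DICT_SOURCE))
print(effective["pad_token_id"])
```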
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 32000,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.41.1"
7
+ }
image_processing_blip_3.py ADDED
@@ -0,0 +1,406 @@
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
+ import torchvision.transforms.functional as F
4
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
5
+ CenterCrop, ColorJitter, Grayscale
6
+ import numbers
7
+ import torch
8
+ import ast
9
+ import math
10
+ import numpy as np
11
+ from PIL import Image
12
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
+ from transformers.image_utils import ImageInput
14
+ from transformers.utils import TensorType
15
+
16
+ from utils import expand2square
17
+
18
+
19
+ class Blip3ImageProcessor(BaseImageProcessor):
20
+
21
+ def __init__(
22
+ self,
23
+ do_resize: bool = True,
24
+ resize_mode: str = "squash",
25
+ interpolation_mode: str = "bicubic",
26
+ size: Union[Tuple[int, int], List[int]] = None,
27
+ grids: Optional[List[List[int]]] = None,
28
+ image_mean: Optional[Union[float, List[float]]] = None,
29
+ image_std: Optional[Union[float, List[float]]] = None,
30
+ **kwargs,
31
+ ) -> None:
32
+ super().__init__(**kwargs)
33
+ self.do_resize = do_resize
34
+ self.resize_mode = resize_mode
35
+ self.interpolation_mode = interpolation_mode
36
+ self.size = size if size is not None else (384, 384)
37
+ self.grids = grids if grids is not None else [[384, 768],[768, 384],[768, 768],[1152, 384],[384,1152]]
38
+
39
+ self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
40
+ self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
41
+
42
+
43
+ @classmethod
44
+ def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
45
+ interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
46
+ if resize_mode == 'longest':
47
+ transforms = [
48
+ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
49
+ CenterCropOrPad(image_size, fill=fill_color)
50
+ ]
51
+ elif resize_mode == 'squash':
52
+ if isinstance(image_size, int):
53
+ image_size = (image_size, image_size)
54
+ transforms = [
55
+ Resize(image_size, interpolation=interpolation_mode),
56
+ ]
57
+ else:
58
+ assert resize_mode == 'shortest'
59
+ if not isinstance(image_size, (tuple, list)):
60
+ image_size = (image_size, image_size)
61
+ if image_size[0] == image_size[1]:
62
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
63
+ transforms = [
64
+ Resize(image_size[0], interpolation=interpolation_mode)
65
+ ]
66
+ else:
67
+ # resize shortest edge to matching target dim for non-square target
68
+ transforms = [ResizeKeepRatio(image_size)]
69
+ transforms += [CenterCrop(image_size)]
70
+ return transforms
71
+
72
+ @classmethod
73
+ def convert_rgb(cls, image):
74
+ return image.convert("RGB")
75
+
76
+
77
+ def _preprocess(self,
78
+ images: ImageInput
79
+ ) -> torch.Tensor:
80
+ transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
81
+ transforms.extend([
82
+ self.convert_rgb,
83
+ ToTensor(),
84
+ Normalize(mean=self.image_mean, std=self.image_std)
85
+ ])
86
+ composed_transforms = Compose(transforms)
87
+ images_tensor = composed_transforms(images)
88
+ return images_tensor
89
+
90
+ def preprocess(self,
91
+ images: ImageInput,
92
+ return_tensors: Optional[Union[str, TensorType]] = None,
93
+ **kwargs) -> BatchFeature:
94
+ if 'image_aspect_ratio' in kwargs:
95
+ image_aspect_ratio = kwargs['image_aspect_ratio']
96
+ else:
97
+ image_aspect_ratio = 'pad'
98
+ new_images = []
99
+ if image_aspect_ratio == 'pad':
100
+ for image in images:
101
+ image = expand2square(image, tuple(int(x*255) for x in self.image_mean))
102
+ image = self._preprocess(image)
103
+ new_images.append(image)
104
+ else:
105
+ for image in images:
106
+ image = process_anyres_image(image, self._preprocess, self.size,
107
+ self.grids)
108
+ new_images.append(image)
109
+
110
+ if all(x.shape == new_images[0].shape for x in new_images):
111
+ new_images = torch.stack(new_images, dim=0)
112
+ if image_aspect_ratio == 'pad':
113
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0).unsqueeze(0)}, tensor_type=return_tensors)
114
+ else:
115
+ new_images = BatchFeature(data={"pixel_values": new_images}, tensor_type=return_tensors)
116
+ return new_images
117
+ # def preprocess(self,
118
+ # images: ImageInput,
119
+ # return_tensors: Optional[Union[str, TensorType]] = None,
120
+ # **kwargs) -> BatchFeature:
121
+ # transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
122
+ # transforms.extend([
123
+ # self.convert_rgb,
124
+ # ToTensor(),
125
+ # Normalize(mean=self.image_mean, std=self.image_std)
126
+ # ])
127
+ # composed_transforms = Compose(transforms)
128
+ # images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
129
+ # encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
130
+ # return encoded_outputs
131
+
132
+
133
+ class ResizeKeepRatio:
134
+ """ Resize and Keep Ratio
135
+
136
+ Copy & paste from `timm`
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ size,
142
+ longest=0.,
143
+ interpolation=InterpolationMode.BICUBIC,
144
+ random_scale_prob=0.,
145
+ random_scale_range=(0.85, 1.05),
146
+ random_aspect_prob=0.,
147
+ random_aspect_range=(0.9, 1.11)
148
+ ):
149
+ if isinstance(size, (list, tuple)):
150
+ self.size = tuple(size)
151
+ else:
152
+ self.size = (size, size)
153
+ self.interpolation = interpolation
154
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
155
+ self.random_scale_prob = random_scale_prob
156
+ self.random_scale_range = random_scale_range
157
+ self.random_aspect_prob = random_aspect_prob
158
+ self.random_aspect_range = random_aspect_range
159
+
160
+ @staticmethod
161
+ def get_params(
162
+ img,
163
+ target_size,
164
+ longest,
165
+ random_scale_prob=0.,
166
+ random_scale_range=(0.85, 1.05),
167
+ random_aspect_prob=0.,
168
+ random_aspect_range=(0.9, 1.11)
169
+ ):
170
+ """Get parameters
171
+ """
172
+ source_size = img.size[::-1] # h, w
173
+ h, w = source_size
174
+ target_h, target_w = target_size
175
+ ratio_h = h / target_h
176
+ ratio_w = w / target_w
177
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
178
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
179
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
180
+ ratio_factor = (ratio_factor, ratio_factor)
181
+ else:
182
+ ratio_factor = (1., 1.)
183
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
184
+ aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
185
+ ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
186
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
187
+ return size
188
+
189
+ def __call__(self, img):
190
+ """
191
+ Args:
192
+ img (PIL Image): Image to be resized.
193
+
194
+ Returns:
195
+ PIL Image: Image resized with its aspect ratio preserved (no padding or cropping is applied here)
196
+ """
197
+ size = self.get_params(
198
+ img, self.size, self.longest,
199
+ self.random_scale_prob, self.random_scale_range,
200
+ self.random_aspect_prob, self.random_aspect_range
201
+ )
202
+ img = F.resize(img, size, self.interpolation)
203
+ return img
204
+
205
+ def __repr__(self):
206
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
207
+ format_string += f', interpolation={self.interpolation}'
208
+ format_string += f', longest={self.longest:.3f})'
209
+ return format_string
210
+
211
+ def _setup_size(size, error_msg):
212
+ if isinstance(size, numbers.Number):
213
+ return int(size), int(size)
214
+
215
+ if isinstance(size, Sequence) and len(size) == 1:
216
+ return size[0], size[0]
217
+
218
+ if len(size) != 2:
219
+ raise ValueError(error_msg)
220
+
221
+ return size
222
+
223
+ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
224
+ """Center crops and/or pads the given image.
225
+ If the image is torch Tensor, it is expected
226
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
227
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
228
+
229
+ Args:
230
+ img (PIL Image or Tensor): Image to be cropped.
231
+ output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
232
+ it is used for both directions.
233
+ fill (int, Tuple[int]): Padding color
234
+
235
+ Returns:
236
+ PIL Image or Tensor: Cropped image.
237
+ """
238
+ if isinstance(output_size, numbers.Number):
239
+ output_size = (int(output_size), int(output_size))
240
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
241
+ output_size = (output_size[0], output_size[0])
242
+
243
+ _, image_height, image_width = F.get_dimensions(img)
244
+ crop_height, crop_width = output_size
245
+
246
+ if crop_width > image_width or crop_height > image_height:
247
+ padding_ltrb = [
248
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
249
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
250
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
251
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
252
+ ]
253
+ img = F.pad(img, padding_ltrb, fill=fill)
254
+ _, image_height, image_width = F.get_dimensions(img)
255
+ if crop_width == image_width and crop_height == image_height:
256
+ return img
257
+
258
+ crop_top = int(round((image_height - crop_height) / 2.0))
259
+ crop_left = int(round((image_width - crop_width) / 2.0))
260
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
261
+
262
+ class CenterCropOrPad(torch.nn.Module):
263
+ """Crops the given image at the center.
264
+ If the image is torch Tensor, it is expected
265
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
266
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
267
+
268
+ Args:
269
+ size (sequence or int): Desired output size of the crop. If size is an
270
+ int instead of sequence like (h, w), a square crop (size, size) is
271
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
272
+ """
273
+
274
+ def __init__(self, size, fill=0):
275
+ super().__init__()
276
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
277
+ self.fill = fill
278
+
279
+ def forward(self, img):
280
+ """
281
+ Args:
282
+ img (PIL Image or Tensor): Image to be cropped.
283
+
284
+ Returns:
285
+ PIL Image or Tensor: Cropped image.
286
+ """
287
+ return center_crop_or_pad(img, self.size, fill=self.fill)
288
+
289
+ def __repr__(self) -> str:
290
+ return f"{self.__class__.__name__}(size={self.size})"
291
+
292
+ def process_anyres_image(image, processor, processor_size, grid_pinpoints):
293
+ """
294
+ Process an image with variable resolutions.
295
+
296
+ Args:
297
+ image (PIL.Image.Image): The input image to be processed.
298
+ processor: The image processor object.
299
+ processor_size (tuple, list): The size of the image processor.
300
+ grid_pinpoints (list or str): A list of possible resolutions, or a string representation of such a list.
301
+
302
+ Returns:
303
+ torch.Tensor: A tensor containing the processed image patches.
304
+ """
305
+ # FIXME: determine grid_pinpoints from image sizes.
306
+ if isinstance(grid_pinpoints, list):
307
+ possible_resolutions = grid_pinpoints
308
+ else:
309
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
310
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
311
+ image_padded = resize_and_pad_image(image, best_resolution)
312
+
313
+ # processor_size = processor.transforms[0].size
314
+ patches = divide_to_patches(image_padded, processor_size[0])
315
+
316
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
317
+
318
+ image_patches = [image_original_resize] + patches
319
+ image_patches = [processor(image_patch)
320
+ for image_patch in image_patches]
321
+ return torch.stack(image_patches, dim=0)
322
+
323
+
324
+ def select_best_resolution(original_size, possible_resolutions):
325
+ """
326
+ Selects the best resolution from a list of possible resolutions based on the original size.
327
+
328
+ Args:
329
+ original_size (tuple): The original size of the image in the format (width, height).
330
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
331
+
332
+ Returns:
333
+ tuple: The best fit resolution in the format (width, height).
334
+ """
335
+ original_width, original_height = original_size
336
+ best_fit = None
337
+ max_effective_resolution = 0
338
+ min_wasted_resolution = float('inf')
339
+
340
+ for width, height in possible_resolutions:
341
+ scale = min(width / original_width, height / original_height)
342
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
343
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
344
+ wasted_resolution = (width * height) - effective_resolution
345
+
346
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
347
+ max_effective_resolution = effective_resolution
348
+ min_wasted_resolution = wasted_resolution
349
+ best_fit = (width, height)
350
+
351
+ return best_fit
352
+
353
+ def resize_and_pad_image(image, target_resolution):
354
+ """
355
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
356
+
357
+ Args:
358
+ image (PIL.Image.Image): The input image.
359
+ target_resolution (tuple): The target resolution (width, height) of the image.
360
+
361
+ Returns:
362
+ PIL.Image.Image: The resized and padded image.
363
+ """
364
+ original_width, original_height = image.size
365
+ target_width, target_height = target_resolution
366
+
367
+ scale_w = target_width / original_width
368
+ scale_h = target_height / original_height
369
+
370
+ if scale_w < scale_h:
371
+ new_width = target_width
372
+ new_height = min(math.ceil(original_height * scale_w), target_height)
373
+ else:
374
+ new_height = target_height
375
+ new_width = min(math.ceil(original_width * scale_h), target_width)
376
+
377
+ # Resize the image
378
+ resized_image = image.resize((new_width, new_height))
379
+
380
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
381
+ paste_x = (target_width - new_width) // 2
382
+ paste_y = (target_height - new_height) // 2
383
+ new_image.paste(resized_image, (paste_x, paste_y))
384
+
385
+ return new_image
386
+
387
+ def divide_to_patches(image, patch_size):
388
+ """
389
+ Divides an image into patches of a specified size.
390
+
391
+ Args:
392
+ image (PIL.Image.Image): The input image.
393
+ patch_size (int): The size of each patch.
394
+
395
+ Returns:
396
+ list: A list of PIL.Image.Image objects representing the patches.
397
+ """
398
+ patches = []
399
+ width, height = image.size
400
+ for i in range(0, height, patch_size):
401
+ for j in range(0, width, patch_size):
402
+ box = (j, i, j + patch_size, i + patch_size)
403
+ patch = image.crop(box)
404
+ patches.append(patch)
405
+
406
+ return patches
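To make the any-resolution path concrete, the sketch below reproduces the selection rule from `select_best_resolution` in plain Python and applies it to two example image sizes with the default grids. The example sizes and the `num_views` helper are assumptions chosen for illustration. The view count follows from `process_anyres_image`, which stacks one resized copy of the whole image with the tiles cut by `divide_to_patches`.

```python
import math

# Default grids from the image processor, read as (width, height) by the selection rule.
GRIDS = [[384, 768], [768, 384], [768, 768], [1152, 384], [384, 1152]]


def select_best_resolution(original_size, possible_resolutions):
    """Same rule as in the file above, reproduced so the example is standalone."""
    original_width, original_height = original_size
    best_fit = None
    max_effective = 0
    min_wasted = float("inf")
    for width, height in possible_resolutions:
        # Largest scale at which the image still fits inside this grid.
        scale = min(width / original_width, height / original_height)
        scaled_w, scaled_h = int(original_width * scale), int(original_height * scale)
        effective = min(scaled_w * scaled_h, original_width * original_height)
        wasted = width * height - effective
        # Prefer more image pixels kept; break ties by less padded area.
        if effective > max_effective or (effective == max_effective and wasted < min_wasted):
            max_effective, min_wasted = effective, wasted
            best_fit = (width, height)
    return best_fit


def num_views(resolution, patch_size=384):
    """Tiles cut from the padded image plus the one resized copy of the whole image."""
    width, height = resolution
    return math.ceil(width / patch_size) * math.ceil(height / patch_size) + 1


wide = select_best_resolution((1536, 768), GRIDS)   # hypothetical 2:1 landscape image
tall = select_best_resolution((768, 2304), GRIDS)   # hypothetical 1:3 portrait image
print(wide, num_views(wide))
print(tall, num_views(tall))
```

With these inputs the landscape image maps to the 768x384 grid and the portrait image to the 384x1152 grid, so the model would see three and four views respectively.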
modeling_xgenmm.py ADDED
@@ -0,0 +1,104 @@
1
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
2
+ import torch
3
+ import open_clip
4
+ from typing import List, Optional, Tuple, Union
5
+ from utils import check_embedding_fns
6
+ from vlm import PerceiverResampler, XGenMMPerceiver
7
+ from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
8
+
9
+ class XGenMMVisionEncoder(PreTrainedModel):
10
+ main_input_name = "pixel_values"
11
+ config_class = XGenMMVisionEncoderConfig
12
+
13
+ def __init__(self, config: XGenMMVisionEncoderConfig):
14
+ super().__init__(config)
15
+ if config.model_name != 'google/siglip-so400m-patch14-384':
16
+ raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
17
+ self.model = AutoModel.from_pretrained(config.model_name)
18
+
19
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
+ # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
21
+ return self.model.get_image_features(pixel_values=pixel_values)  # SiglipModel has no open_clip-style encode_image
22
+
23
+
24
+ # vision tokenizer
25
+ class XGenMMVisionTokenizer(PreTrainedModel):
26
+ config_class = XGenMMVisionTokenizerConfig
27
+ def __init__(self, config: XGenMMVisionTokenizerConfig):
28
+ super().__init__(config)
29
+ self.model = PerceiverResampler(
30
+ dim=config.vis_feature_dim,
31
+ dim_inner=config.lang_embedding_dim,
32
+ num_latents=config.num_vis_tokens,
33
+ )
34
+
35
+ def forward(self,
36
+ vision_features: torch.Tensor,
37
+ vision_attn_masks: torch.Tensor):
38
+ return self.model(vision_features, vision_attn_masks)
39
+
40
+ # XGenMM model
41
+ class XGenMMModelForConditionalGeneration(PreTrainedModel):
42
+ config_class = XGenMMConfig
43
+
44
+ def __init__(self, config: XGenMMConfig):
45
+ super().__init__(config)
46
+
47
+ # vision encoder initialization
48
+ vision_encoder = AutoModel.from_pretrained(config.vision_encoder_config.model_name).vision_model
49
+
50
+ # language model initialization
51
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
52
+ check_embedding_fns(language_model)
53
+ # Update _tied_weights_keys using the base model used.
54
+ if language_model._tied_weights_keys is not None:
55
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
56
+
57
+ # vision tokenizer initialization
58
+ if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
59
+ overwrite = language_model.get_input_embeddings().weight.shape[1]
60
+ config.vision_tokenizer_config.lang_embedding_dim = overwrite
61
+ print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
62
+
63
+ vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
64
+
65
+ self.vlm = XGenMMPerceiver(
66
+ vision_encoder=vision_encoder,
67
+ vision_tokenizer=vision_tokenizer,
68
+ lang_model=language_model,
69
+ initial_tokenizer_len = config.text_config.initial_tokenizer_len,
70
+ pad_token_id = config.text_config.pad_token_id,
71
+ image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
72
+ anyres_patch_sampling = config.vision_encoder_config.anyres_patch_sampling,
73
+ anyres_grids = config.vision_encoder_config.anyres_grids
74
+ )
75
+ # Initialize weights and apply final processing
76
+ self.post_init()
77
+
78
+ @torch.no_grad()
79
+ def generate(
80
+ self,
81
+ pixel_values: torch.FloatTensor,
82
+ input_ids: Optional[torch.LongTensor] = None,
83
+ attention_mask: Optional[torch.LongTensor] = None,
84
+ **generate_kwargs,
85
+ ) -> torch.LongTensor:
86
+ self.vlm = self.vlm.eval()
87
+ return self.vlm.generate(
88
+ vision_x = pixel_values,
89
+ lang_x = input_ids,
90
+ attention_mask = attention_mask,
91
+ **generate_kwargs)
92
+
93
+ def update_special_tokens(self, tokenizer):
94
+ tokenizer.add_special_tokens(
95
+ {"additional_special_tokens": list(self.vlm.special_tokens.values())}
96
+ )
97
+ self.vlm.lang_model.config.vocab_size = len(tokenizer)
98
+ self.vlm.set_special_token_ids(
99
+ {
100
+ v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
101
+ }
102
+ )
103
+ return tokenizer
104
+
preprocessor_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_blip_3.Blip3ImageProcessor"
4
+ },
5
+ "do_resize": true,
6
+ "grids": [
7
+ [
8
+ 384,
9
+ 768
10
+ ],
11
+ [
12
+ 768,
13
+ 384
14
+ ],
15
+ [
16
+ 768,
17
+ 768
18
+ ],
19
+ [
20
+ 1152,
21
+ 384
22
+ ],
23
+ [
24
+ 384,
25
+ 1152
26
+ ]
27
+ ],
28
+ "image_mean": [
29
+ 0.5,
30
+ 0.5,
31
+ 0.5
32
+ ],
33
+ "image_processor_type": "Blip3ImageProcessor",
34
+ "image_std": [
35
+ 0.5,
36
+ 0.5,
37
+ 0.5
38
+ ],
39
+ "interpolation_mode": "bicubic",
40
+ "resize_mode": "squash",
41
+ "size": [
42
+ 384,
43
+ 384
44
+ ]
45
+ }
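The five entries in `grids` above are the same layouts that `process_images` in `utils.py` builds from multiples of the base image size. A minimal sketch, assuming the base size is the 384 given under `size`; `anyres_grids` is a hypothetical helper name used only for illustration and is not part of this commit:

```python
def anyres_grids(base: int) -> list[list[int]]:
    """Build the five anyres grid resolutions as multiples of the base size."""
    # Multipliers copied from the list passed to process_anyres_image in utils.py.
    multipliers = [(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)]
    return [[base * a, base * b] for a, b in multipliers]


grids = anyres_grids(384)
print(grids)
```

With a base of 384 this reproduces the `grids` list in `preprocessor_config.json`, in the same order.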
setup.sh ADDED
@@ -0,0 +1,7 @@
1
+ pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
2
+ pip install open_clip_torch==2.24.0
3
+ pip install einops
4
+ pip install einops-exts
5
+ pip install transformers==4.41.1
6
+ # optional
7
+ pip install ipywidgets
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
test_samples/images/1148.jpg ADDED
test_samples/images/152.jpg ADDED
test_samples/images/45711.jpg ADDED
test_samples/images/image-1.jpeg ADDED
test_samples/images/image-2.jpeg ADDED
test_samples/test.json ADDED
@@ -0,0 +1,27 @@
1
+ [
2
+ {
3
+ "image_path": ["./test_samples/images/image-1.jpeg",
4
+ "./test_samples/images/image-2.jpeg"],
5
+ "question": [
6
+ "What is in common between this image 1 <image> and image 2 <image>?"
7
+ ]
8
+ },
9
+ {
10
+ "image_path": ["./test_samples/images/152.jpg"],
11
+ "question": ["<image>\nCan you explain this meme?"]
12
+ },
13
+ {
14
+ "image_path": ["./test_samples/images/1148.jpg"],
15
+ "question": ["<image>\nWhat can be the relationship between the two persons in this image?"]
16
+ },
17
+ {
18
+
19
+ "image_path": ["./test_samples/images/45711.jpg"],
20
+ "question": [
21
+ "<image>\nWhat is this meeting about?",
22
+ "<image>\nHow many things are discussed in the meeting?",
23
+ "<image>\nWhat is the second agenda?",
24
+ "<image>\nWhen is the next meeting held?"
25
+ ]
26
+ }
27
+ ]
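In the samples above, every question appears to carry one `<image>` placeholder per entry in `image_path`. The sketch below checks that property. The rule is inferred from these four samples and is not stated by the commit; `placeholder_mismatches` is a hypothetical helper, and the file names in `SAMPLES` are shortened stand-ins.

```python
import json

# Two entries shaped like test_samples/test.json, with shortened paths.
SAMPLES = json.loads("""
[
  {"image_path": ["a.jpeg", "b.jpeg"],
   "question": ["What is in common between this image 1 <image> and image 2 <image>?"]},
  {"image_path": ["c.jpg"],
   "question": ["<image>\\nCan you explain this meme?"]}
]
""")


def placeholder_mismatches(samples, token="<image>"):
    """Return (sample_index, question_index) pairs whose placeholder count differs from the image count."""
    bad = []
    for i, sample in enumerate(samples):
        n_images = len(sample["image_path"])
        for j, question in enumerate(sample["question"]):
            if question.count(token) != n_images:
                bad.append((i, j))
    return bad


print(placeholder_mismatches(SAMPLES))
```

An empty list means every question references as many images as the sample provides.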
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,138 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": true,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "32000": {
30
+ "content": "<|endoftext|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<|assistant|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": true,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "32002": {
46
+ "content": "<|placeholder1|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": true,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "32003": {
54
+ "content": "<|placeholder2|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": true,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "32004": {
62
+ "content": "<|placeholder3|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": true,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "32005": {
70
+ "content": "<|placeholder4|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": true,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "32006": {
78
+ "content": "<|system|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": true,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "32007": {
86
+ "content": "<|end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": true,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "32008": {
94
+ "content": "<|placeholder5|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": true,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "32009": {
102
+ "content": "<|placeholder6|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": true,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "32010": {
110
+ "content": "<|user|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": true,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "32011": {
118
+ "content": "<pad>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ }
125
+ },
126
+ "bos_token": "<s>",
127
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
128
+ "clean_up_tokenization_spaces": false,
129
+ "eos_token": "<|endoftext|>",
130
+ "legacy": false,
131
+ "model_max_length": 4096,
132
+ "pad_token": "<pad>",
133
+ "padding_side": "left",
134
+ "sp_model_kwargs": {},
135
+ "tokenizer_class": "LlamaTokenizer",
136
+ "unk_token": "<unk>",
137
+ "use_default_system_prompt": false
138
+ }
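The `chat_template` above wraps each system, user, or assistant message as `<|role|>\n` + content + `<|end|>\n`, then appends either the assistant header or the EOS token. A minimal sketch of the same logic in plain Python, for reading the template without a Jinja renderer; `render_chat` is a hypothetical name, and in practice the tokenizer's own template rendering would be used instead:

```python
def render_chat(messages, add_generation_prompt=True, eos_token="<|endoftext|>"):
    """Mirror the Jinja chat_template using plain string operations."""
    parts = []
    for message in messages:
        role = message["role"]
        # The template only emits text for these three roles; others are skipped.
        if role in ("system", "user", "assistant"):
            parts.append(f"<|{role}|>\n{message['content']}<|end|>\n")
    # Open an assistant turn for generation, or close the sequence with EOS.
    parts.append("<|assistant|>\n" if add_generation_prompt else eos_token)
    return "".join(parts)


prompt = render_chat([{"role": "user", "content": "<image>\nCan you explain this meme?"}])
print(repr(prompt))
```

The question text is taken from `test_samples/test.json`; the printed prompt ends with `<|assistant|>\n`, which is where generation would start.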
utils.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import ast
3
+ import math
4
+ from PIL import Image
5
+ from packaging.version import Version
6
+
7
+ def has_fn(model, fn_name):
8
+ """Check if model has a function fn_name"""
9
+ return callable(getattr(model, fn_name, None))
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+ def num_params(module, filter_to_trainable=False):
15
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
+ if filter_to_trainable:
17
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
+ else:
19
+ return sum(p.numel() for p in module.parameters())
20
+
21
+ def hasattr_recursive(obj, att):
22
+ """
23
+ Check if obj has nested attribute
24
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
+ """
26
+ if att == "":
27
+ return True
28
+ i = att.find(".")
29
+ if i < 0:
30
+ return hasattr(obj, att)
31
+ else:
32
+ try:
33
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
+ except AttributeError:
35
+ return False
36
+
37
+ def getattr_recursive(obj, att):
38
+ """
39
+ Return nested attribute of obj
40
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
+ """
42
+ if att == "":
43
+ return obj
44
+ i = att.find(".")
45
+ if i < 0:
46
+ return getattr(obj, att)
47
+ else:
48
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
+
50
+
51
+ def setattr_recursive(obj, att, val):
52
+ """
53
+ Set nested attribute of obj
54
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
+ """
56
+ if "." in att:
57
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
+ setattr(obj, att.split(".")[-1], val)
59
+
60
+
61
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
+ """
63
+ Stack a list of tensors with padding on one side
64
+ Args:
65
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
+ padding_value (int, optional): Value to pad with. Defaults to 0.
67
+ padding_side (str, optional): Side to pad on. Defaults to "right".
68
+ Returns:
69
+ torch.Tensor: Stacked tensors
70
+ """
71
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
+ padded_tensors = []
73
+ for tensor in list_of_tensors:
74
+ num_tokens = tensor.size(0)
75
+ if len(tensor.size()) == 1:
76
+ padding = torch.full(
77
+ (max_tokens - num_tokens,),
78
+ padding_value,
79
+ dtype=tensor.dtype,
80
+ device=tensor.device,
81
+ )
82
+ else:
83
+ padding = torch.full(
84
+ (max_tokens - num_tokens, tensor.size(1)),
85
+ padding_value,
86
+ dtype=tensor.dtype,
87
+ device=tensor.device,
88
+ )
89
+ padded_tensor = (
90
+ torch.cat((tensor, padding), dim=0)
91
+ if padding_side == "right"
92
+ else torch.cat((padding, tensor), dim=0)
93
+ )
94
+ padded_tensors.append(padded_tensor)
95
+ return torch.stack(padded_tensors)
96
+
97
+
98
+ def check_embedding_fns(lang_model):
99
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
+ if not has_fn(lang_model, "get_input_embeddings"):
101
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
+ lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
105
+ else:
106
+ raise ValueError(
107
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
+ )
109
+
110
+ if not has_fn(lang_model, "set_input_embeddings"):
111
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
+ lang_model, "transformer.wte", x
114
+ )
115
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
+ lang_model, "model.decoder.embed_tokens", x
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
+ )
123
+
124
+ if not has_fn(lang_model, "get_output_embeddings"):
125
+ if hasattr_recursive(lang_model, "lm_head"):
126
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
+ else:
128
+ raise ValueError(
129
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
+ )
131
+
132
+ if not has_fn(lang_model, "set_output_embeddings"):
133
+ if hasattr_recursive(lang_model, "lm_head"):
134
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
+ lang_model, "lm_head", x
136
+ )
137
+ else:
138
+ raise ValueError(
139
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
+ )
141
+
142
+
143
+ def has_fn(model, fn_name):
144
+ """Check if model has a function fn_name"""
145
+ return callable(getattr(model, fn_name, None))
146
+
147
+
148
+ # Adapted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
+ #
150
+ # Licensed under the Apache License, Version 2.0 (the "License");
151
+ # you may not use this file except in compliance with the License.
152
+ # You may obtain a copy of the License at
153
+ #
154
+ # http://www.apache.org/licenses/LICENSE-2.0
155
+ #
156
+ # Unless required by applicable law or agreed to in writing, software
157
+ # distributed under the License is distributed on an "AS IS" BASIS,
158
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
+ # See the License for the specific language governing permissions and
160
+ # limitations under the License.
161
+
162
+ def unpad_image(tensor, original_size, keep_original_shape=False):
163
+ """
164
+ Unpads a PyTorch tensor of a padded and resized image.
165
+
166
+ Args:
167
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
+ original_size (tuple): The original size of the image (width, height).
169
+
170
+ Returns:
171
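+ tuple: (tensor, attention_mask). With keep_original_shape=False the tensor is cropped and the mask is None; with keep_original_shape=True the tensor is returned unchanged together with a mask that zeroes the padded region.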
172
+ """
173
+ original_width, original_height = original_size
174
+ current_height, current_width = tensor.shape[1:]
175
+
176
+ original_aspect_ratio = original_width / original_height
177
+ current_aspect_ratio = current_width / current_height
178
+
179
+ if original_aspect_ratio > current_aspect_ratio:
180
+ scale_factor = current_width / original_width
181
+ new_height = int(original_height * scale_factor)
182
+ padding = (current_height - new_height) // 2
183
+ if keep_original_shape:
184
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
+ attention_mask[:padding, :] = 0
186
+ attention_mask[current_height - padding:, :] = 0
187
+ return tensor, attention_mask
188
+ else:
189
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
+ return unpadded_tensor, None
191
+ else:
192
+ scale_factor = current_height / original_height
193
+ new_width = int(original_width * scale_factor)
194
+ padding = (current_width - new_width) // 2
195
+ if keep_original_shape:
196
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
+ attention_mask[:, :padding] = 0
198
+ attention_mask[:, current_width - padding:] = 0
199
+ return tensor, attention_mask
200
+ else:
201
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
+ return unpadded_tensor, None
203
+
204
+
205
+ def select_best_resolution(original_size, possible_resolutions):
206
+ """
207
+ Selects the best resolution from a list of possible resolutions based on the original size.
208
+
209
+ Args:
210
+ original_size (tuple): The original size of the image in the format (width, height).
211
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
+
213
+ Returns:
214
+ tuple: The best fit resolution in the format (width, height).
215
+ """
216
+ original_width, original_height = original_size
217
+ best_fit = None
218
+ max_effective_resolution = 0
219
+ min_wasted_resolution = float('inf')
220
+
221
+ for width, height in possible_resolutions:
222
+ scale = min(width / original_width, height / original_height)
223
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
+ wasted_resolution = (width * height) - effective_resolution
226
+
227
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
+ max_effective_resolution = effective_resolution
229
+ min_wasted_resolution = wasted_resolution
230
+ best_fit = (width, height)
231
+
232
+ return best_fit
233
+
234
+
235
+ def resize_and_pad_image(image, target_resolution):
236
+ """
237
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
238
+
239
+ Args:
240
+ image (PIL.Image.Image): The input image.
241
+ target_resolution (tuple): The target resolution (width, height) of the image.
242
+
243
+ Returns:
244
+ PIL.Image.Image: The resized and padded image.
245
+ """
246
+ original_width, original_height = image.size
247
+ target_width, target_height = target_resolution
248
+
249
+ scale_w = target_width / original_width
250
+ scale_h = target_height / original_height
251
+
252
+ if scale_w < scale_h:
253
+ new_width = target_width
254
+ new_height = min(math.ceil(original_height * scale_w), target_height)
255
+ else:
256
+ new_height = target_height
257
+ new_width = min(math.ceil(original_width * scale_h), target_width)
258
+
259
+ # Resize the image
260
+ resized_image = image.resize((new_width, new_height))
261
+
262
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
+ paste_x = (target_width - new_width) // 2
264
+ paste_y = (target_height - new_height) // 2
265
+ new_image.paste(resized_image, (paste_x, paste_y))
266
+
267
+ return new_image
268
+
269
+
270
+ def divide_to_patches(image, patch_size):
271
+ """
272
+ Divides an image into patches of a specified size.
273
+
274
+ Args:
275
+ image (PIL.Image.Image): The input image.
276
+ patch_size (int): The size of each patch.
277
+
278
+ Returns:
279
+ list: A list of PIL.Image.Image objects representing the patches.
280
+ """
281
+ patches = []
282
+ width, height = image.size
283
+ for i in range(0, height, patch_size):
284
+ for j in range(0, width, patch_size):
285
+ box = (j, i, j + patch_size, i + patch_size)
286
+ patch = image.crop(box)
287
+ patches.append(patch)
288
+
289
+ return patches
290
+
291
+
292
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
+ """
294
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
+
296
+ Args:
297
+ image_size (tuple): The size of the input image in the format (width, height).
298
+ grid_pinpoints (list or str): A list of possible resolutions, or a string representation of such a list.
299
+ patch_size (int): The size of each image patch.
300
+
301
+ Returns:
302
+ tuple: The shape of the image patch grid in the format (width, height).
303
+ """
304
+ if type(grid_pinpoints) is list:
305
+ possible_resolutions = grid_pinpoints
306
+ else:
307
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
308
+ width, height = select_best_resolution(image_size, possible_resolutions)
309
+ return width // patch_size, height // patch_size
310
+
311
+
312
+ def process_anyres_image(image, processor, grid_pinpoints):
313
+ """
314
+ Process an image with variable resolutions.
315
+
316
+ Args:
317
+ image (PIL.Image.Image): The input image to be processed.
318
+ processor: The image processor object.
319
+ grid_pinpoints (list or str): A list of possible resolutions, or a string representation of such a list.
320
+
321
+ Returns:
322
+ torch.Tensor: A tensor containing the processed image patches.
323
+ """
324
+ # FIXME: determine grid_pinpoints from image sizes.
325
+ if type(grid_pinpoints) is list:
326
+ possible_resolutions = grid_pinpoints
327
+ else:
328
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
329
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
330
+ image_padded = resize_and_pad_image(image, best_resolution)
331
+
332
+ processor_size = processor.transforms[0].size
333
+ patches = divide_to_patches(image_padded, processor_size[0])
334
+
335
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
+
337
+ image_patches = [image_original_resize] + patches
338
+ image_patches = [processor(image_patch)
339
+ for image_patch in image_patches]
340
+ return torch.stack(image_patches, dim=0)
341
+
342
+
343
+ def expand2square(pil_img, background_color):
344
+ width, height = pil_img.size
345
+ if width == height:
346
+ return pil_img
347
+ elif width > height:
348
+ result = Image.new(pil_img.mode, (width, width), background_color)
349
+ result.paste(pil_img, (0, (width - height) // 2))
350
+ return result
351
+ else:
352
+ result = Image.new(pil_img.mode, (height, height), background_color)
353
+ result.paste(pil_img, ((height - width) // 2, 0))
354
+ return result
355
+
356
+
357
+ def process_images(images, image_processor, model_cfg):
358
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
+ new_images = []
360
+ if image_aspect_ratio == 'pad':
361
+ for image in images:
362
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
+ image = image_processor(image)
364
+ new_images.append(image)
365
+ elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
+ base_img_size = image_processor.transforms[0].size[0]
367
+ for image in images:
368
+ image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
+ [base_img_size*2,base_img_size],
370
+ [base_img_size*2,base_img_size*2],
371
+ [base_img_size*3,base_img_size],
372
+ [base_img_size,base_img_size*3]])
373
+
374
+ # Debug anyres inference by only using a single square grid of 2*base_img_size (672x672 for a 336 base, 768x768 for a 384 base).
375
+ # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
+ new_images.append(image)
377
+ else:
378
+ return image_processor(images)
379
+ if all(x.shape == new_images[0].shape for x in new_images):
380
+ new_images = torch.stack(new_images, dim=0)
381
+ return new_images
382
+
383
+
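To make the anyres selection concrete, the sketch below restates `select_best_resolution` from the diff above in compact form and runs it on one input. The 1536x768 image size and the 384 patch size are assumptions chosen for illustration; the grid list is the one from `preprocessor_config.json`.

```python
def select_best_resolution(original_size, possible_resolutions):
    """Pick the grid that keeps the most image pixels, breaking ties by least padding."""
    original_width, original_height = original_size
    best_fit = None
    max_effective = 0
    min_wasted = float("inf")
    for width, height in possible_resolutions:
        # Largest uniform scale at which the image still fits inside the grid.
        scale = min(width / original_width, height / original_height)
        down_w, down_h = int(original_width * scale), int(original_height * scale)
        effective = min(down_w * down_h, original_width * original_height)
        wasted = width * height - effective
        if effective > max_effective or (effective == max_effective and wasted < min_wasted):
            max_effective, min_wasted, best_fit = effective, wasted, (width, height)
    return best_fit


GRIDS = [[384, 768], [768, 384], [768, 768], [1152, 384], [384, 1152]]
best = select_best_resolution((1536, 768), GRIDS)
# Patch grid as (columns, rows) for a 384-pixel patch.
patch_grid = (best[0] // 384, best[1] // 384)
# process_anyres_image prepends one resized copy of the whole image to the patches.
num_patches = 1 + patch_grid[0] * patch_grid[1]
print(best, patch_grid, num_patches)
```

For this 2:1 landscape input, three grids preserve the same number of image pixels, and the tie-break on wasted area selects the 768x384 grid, giving two patches plus the global view.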
vlm.py ADDED
@@ -0,0 +1,1506 @@
1
+
2
+ import torch
3
+ from torch import einsum, nn
4
+ from einops import rearrange, repeat
5
+ from einops_exts import rearrange_many
6
+ from einops import rearrange
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch.nn.functional as F
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from dataclasses import dataclass
11
+ from transformers import CLIPVisionModel
12
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
13
+
14
+ import transformers
15
+ from packaging.version import Version
16
+
17
+ from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
18
+
19
+
20
+ class VisionTokenizer(nn.Module):
21
+ def __init__(self, dim_media, num_tokens_per_media):
22
+ super().__init__()
23
+ self.dim_media = dim_media
24
+ self.num_tokens_per_media = num_tokens_per_media
25
+
26
+ class PerceiverAttention(nn.Module):
27
+ def __init__(self, *, dim, dim_head=64, heads=8):
28
+ super().__init__()
29
+ self.scale = dim_head**-0.5
30
+ self.heads = heads
31
+ inner_dim = dim_head * heads
32
+
33
+ self.norm_media = nn.LayerNorm(dim)
34
+ self.norm_latents = nn.LayerNorm(dim)
35
+
36
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
37
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
38
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
39
+
40
+ def forward(self, x, latents, vision_attn_masks=None):
41
+ """
42
+ Args:
43
+ x (torch.Tensor): image features
44
+ shape (b, T, n1, D)
45
+ latents (torch.Tensor): latent features
46
+ shape (b, T, n2, D)
47
+ """
48
+ x = self.norm_media(x)
49
+ latents = self.norm_latents(latents)
50
+
51
+ h = self.heads
52
+
53
+ q = self.to_q(latents)
54
+ kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
55
+ if vision_attn_masks is not None:
56
+ vision_attn_masks = torch.cat((vision_attn_masks,
57
+ torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
58
+ dim=-1)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
+ q = q * self.scale
62
+
63
+ # attention
64
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
65
+ # Apply vision attention mask here.
66
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
67
+ if vision_attn_masks is not None:
68
+ attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
69
+ vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
70
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
71
+ sim += attn_bias
72
+
73
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
74
+ attn = sim.softmax(dim=-1)
75
+
76
+
77
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
78
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
79
+ return self.to_out(out)
80
+
81
+
82
+ def FeedForward(dim, mult=4):
83
+ inner_dim = int(dim * mult)
84
+ return nn.Sequential(
85
+ nn.LayerNorm(dim),
86
+ nn.Linear(dim, inner_dim, bias=False),
87
+ nn.GELU(),
88
+ nn.Linear(inner_dim, dim, bias=False),
89
+ )
90
+
91
+
92
+ class PerceiverResampler(VisionTokenizer):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ dim,
97
+ dim_inner=None,
98
+ depth=6,
99
+ dim_head=96,
100
+ heads=16,
101
+ num_latents=128,
102
+ max_num_media=None,
103
+ max_num_frames=None,
104
+ ff_mult=4,
105
+ ):
106
+ """
107
+ Perceiver module which takes in image features and outputs image tokens.
108
+ Args:
109
+ dim (int): dimension of the incoming image features
110
+ dim_inner (int, optional): final dimension to project the incoming image features to;
111
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
112
+ depth (int, optional): number of layers. Defaults to 6.
113
+ dim_head (int, optional): dimension of each head. Defaults to 96.
114
+ heads (int, optional): number of heads. Defaults to 16.
115
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
116
+ also corresponds to number of tokens per sequence to output. Defaults to 128.
117
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
118
+ and keep positional embeddings for. If None, no positional embeddings are used.
119
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
120
+ and keep positional embeddings for. If None, no positional embeddings are used.
121
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
122
+ """
123
+ if dim_inner is not None:
124
+ projection = nn.Linear(dim, dim_inner)
125
+ else:
126
+ projection = None
127
+ dim_inner = dim
128
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
129
+ self.projection = projection
130
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
131
+
132
+ # positional embeddings
133
+ self.frame_embs = (
134
+ nn.Parameter(torch.randn(max_num_frames, dim))
135
+ if exists(max_num_frames)
136
+ else None
137
+ )
138
+ self.media_time_embs = (
139
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
140
+ if exists(max_num_media)
141
+ else None
142
+ )
143
+
144
+ self.layers = nn.ModuleList([])
145
+ for _ in range(depth):
146
+ self.layers.append(
147
+ nn.ModuleList(
148
+ [
149
+ PerceiverAttention(
150
+ dim=dim, dim_head=dim_head, heads=heads
151
+ ),
152
+ FeedForward(dim=dim, mult=ff_mult),
153
+ ]
154
+ )
155
+ )
156
+
157
+ self.norm = nn.LayerNorm(dim)
158
+
159
+ def forward(self, x, vision_attn_masks):
160
+ """
161
+ Args:
162
+ x (torch.Tensor): image features
163
+ shape (b, T, F, v, D)
164
+ vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
165
+ shape (b, v)
166
+ Returns:
167
+ shape (b, T, n, D) where n is self.num_latents
168
+ """
169
+ b, T, F, v = x.shape[:4]
170
+
171
+ # frame and media time embeddings
172
+ if exists(self.frame_embs):
173
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
174
+ x = x + frame_embs
175
+ x = rearrange(
176
+ x, "b T F v d -> b T (F v) d"
177
+ ) # flatten the frame and spatial dimensions
178
+ if exists(self.media_time_embs):
179
+ x = x + self.media_time_embs[:T]
180
+
181
+ # blocks
182
+ latents = self.latents
183
+ latents = repeat(latents, "n d -> b T n d", b=b, T=T)
184
+ for attn, ff in self.layers:
185
+ latents = attn(x, latents, vision_attn_masks) + latents
186
+ latents = ff(latents) + latents
187
+
188
+ if exists(self.projection):
189
+ return self.projection(self.norm(latents))
190
+ else:
191
+ return self.norm(latents)
192
+
193
+
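The shape bookkeeping in `PerceiverResampler.forward` can be traced without running the model. The following is a minimal sketch that assumes only the shapes stated in the docstrings above. `resampler_shapes` is a hypothetical helper, not part of this commit, and the concrete sizes in the example call are illustrative assumptions.

```python
def resampler_shapes(b, T, F, v, dim, num_latents=128, dim_inner=None):
    """Trace tensor shapes through a Perceiver-style resampler (no tensors involved)."""
    # Frame and spatial axes are flattened before cross-attention.
    flattened = (b, T, F * v, dim)
    # The learned latents are repeated once per sample and per media item.
    latents = (b, T, num_latents, dim)
    # The optional projection changes only the last axis.
    out_dim = dim if dim_inner is None else dim_inner
    return {
        "features": flattened,
        "latents": latents,
        "output": (b, T, num_latents, out_dim),
    }


# Illustrative sizes only (assumed, not taken from the commit).
shapes = resampler_shapes(b=2, T=3, F=1, v=729, dim=1152, dim_inner=3072)
print(shapes["output"])
```

The number of output tokens per media item depends only on `num_latents`, regardless of how many input tokens the vision encoder produces.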
194
+ class DecoupledEmbedding(nn.Embedding):
195
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
196
+ """
197
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
198
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
199
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
200
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ max_original_id: int,
206
+ num_additional_embeddings: int = 0,
207
+ _weight: torch.Tensor = None,
208
+ num_original_embeddings: int = None,
209
+ embedding_dim: int = None,
210
+ partially_freeze=True,
211
+ device=None,
212
+ dtype=None,
213
+ pad_token_id=None,
214
+ ) -> None:
215
+ """
216
+ Args:
217
+ max_original_id (`int`):
218
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
219
+ This is usually len(tokenizer) - 1 before additional tokens are added.
220
+ Note that this may not equal self.weight.shape[0]
221
+ num_additional_embeddings (`int`):
222
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
223
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
224
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
225
+ num_original_embeddings (`int`):
226
+ Number of rows in the regular embedding matrix (`self.weight.shape[0]`).
227
+ embedding_dim (`int`):
228
+ The size of each embedding vector
229
+ partially_freeze: (`bool`, *optional*, defaults to `True`):
230
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
231
+ pad_token_id (`int`, *optional*):
+ The padding index (must be <= `max_original_id`); passed to `nn.Embedding` as `padding_idx`.
233
+
234
+ Note: a standard `nn.Embedding` accepts other parameters, such as `max_norm` or `norm_type`.
+ These are not supported.
236
+ """
237
+ # validate args
238
+ if pad_token_id is not None and pad_token_id > max_original_id:
239
+ raise ValueError(
240
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
241
+ + " If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
242
+ )
243
+ if _weight is not None:
244
+ assert (num_original_embeddings is None) or (
245
+ _weight.shape[0] == num_original_embeddings
246
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
247
+ assert (embedding_dim is None) or (
248
+ _weight.shape[1] == embedding_dim
249
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
250
+ num_original_embeddings = _weight.shape[0]
251
+ embedding_dim = _weight.shape[1]
252
+ else:
253
+ assert (
254
+ num_original_embeddings is not None
255
+ ), "num_original_embeddings must be provided if _weight is not provided"
256
+ assert (
257
+ embedding_dim is not None
258
+ ), "embedding_dim must be provided if _weight is not provided"
259
+
260
+ super().__init__(
261
+ num_embeddings=num_original_embeddings,
262
+ embedding_dim=embedding_dim,
263
+ device=device,
264
+ dtype=dtype,
265
+ padding_idx=pad_token_id,
266
+ _weight=_weight,
267
+ )
268
+ self.max_original_id = max_original_id
269
+ self.padding_idx = pad_token_id
270
+ self.num_additional_embeddings = num_additional_embeddings
271
+ if self.num_additional_embeddings > 0:
272
+ self.additional_embedding = nn.Embedding(
273
+ num_embeddings=self.num_additional_embeddings,
274
+ embedding_dim=embedding_dim,
275
+ device=device,
276
+ dtype=dtype,
277
+ )
278
+ self.set_requires_grad(
279
+ require_regular_grad=not partially_freeze, require_additional_grad=True
280
+ )
281
+
282
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
283
+ """
284
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
285
+ """
286
+ self.weight.requires_grad_(require_regular_grad)
287
+ if self.num_additional_embeddings > 0:  # additional_embedding only exists in this case
+ self.additional_embedding.requires_grad_(require_additional_grad)
288
+
289
+ def forward(self, input_ids):
290
+ """
291
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
292
+ self.additional_embedding.weight that is being trained.
293
+
294
+ in order to make a lookup of the input ids, we:
295
+ 1. find out the indices of the entries belonging to the 2nd embedding
296
+ 2. extract those values while subtracting the number of regular token ids (`max_original_id + 1`), since the 2nd
+ embedding starts from 0 and not from `max_original_id + 1`
298
+ 3. perform the 2nd embedding lookup
299
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with 0 (a placeholder whose lookup result is discarded)
300
+ 5. perform the 1st embedding lookup
301
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
302
+
303
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
304
+ then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
305
+ i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
306
+ usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
307
+ measure.
308
+
309
+ """
310
+ if self.num_additional_embeddings == 0:
311
+ return F.embedding(input_ids, self.weight)
312
+
313
+ # Clone so that we don't modify the original input_ids later on
314
+ input_ids = input_ids.clone()
315
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
316
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
317
+ additional_embeddings = self.additional_embedding(
318
+ input_ids_additional_vocab - self.max_original_id - 1
319
+ )
320
+
321
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
322
+ input_ids[additional_vocab_indices] = 0
323
+ full_vector = F.embedding(input_ids, self.weight)
324
+
325
+ # overwrite the records with high indices
326
+ full_vector[additional_vocab_indices] = additional_embeddings
327
+
328
+ return full_vector
329
+
330
+ def extra_repr(self) -> str:
331
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
332
+ self.max_original_id + 1,
333
+ self.num_additional_embeddings,
334
+ self.embedding_dim,
335
+ (not self.weight.requires_grad),
336
+ )
337
+
338
+
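The lookup steps described in `DecoupledEmbedding.forward` can be sketched in plain Python. This is a minimal illustration that assumes nested lists stand in for the two weight tensors. `decoupled_lookup` is a hypothetical helper, not part of this commit, and it routes ids one at a time instead of using the vectorized overwrite from the real module.

```python
def decoupled_lookup(input_ids, regular_table, additional_table, max_original_id):
    """Route each id to the regular or the additional table."""
    out = []
    for token_id in input_ids:
        if token_id > max_original_id:
            # Additional ids are re-based so that the first one maps to row 0.
            out.append(additional_table[token_id - max_original_id - 1])
        else:
            out.append(regular_table[token_id])
    return out


# Toy tables (assumed values): ids 0..2 are regular, ids 3..4 are additional.
regular = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
additional = [[9.0, 9.0], [8.0, 8.0]]
result = decoupled_lookup([1, 4, 0, 3], regular, additional, max_original_id=2)
print(result)
```

Keeping the two tables separate is what lets the regular rows stay frozen while the rows for newly added tokens remain trainable.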
339
+ class DecoupledLinear(nn.Linear):
340
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
341
+ """
342
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
343
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
344
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
345
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ max_original_id: int,
351
+ additional_out_features: int = 0,
352
+ _weight: torch.Tensor = None,
353
+ _bias: torch.Tensor = None,
354
+ in_features: int = None,
355
+ original_out_features: int = None,
356
+ bias: bool = True,
357
+ partially_freeze: bool = True,
358
+ device=None,
359
+ dtype=None,
360
+ ) -> None:
361
+ """
362
+ Args:
363
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
364
+ This is usually len(tokenizer) - 1 before additional tokens are added.
365
+ Note that this may not equal original_out_features - 1
366
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
367
+ If provided, this sets the `in_features` and `original_out_features` parameters.
368
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
369
+ in_features: int. Input hidden size.
370
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
371
+ additional_out_features: int. Number of additional trainable dimensions.
372
+ bias: bool. Whether to include a bias term.
373
+ partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` will be frozen.
374
+ """
375
+ # argument validation
376
+ if _weight is not None:
377
+ assert (_weight.shape[0] == original_out_features) or (
378
+ original_out_features is None
379
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
380
+ assert (_weight.shape[1] == in_features) or (
381
+ in_features is None
382
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
383
+ in_features = _weight.shape[1]
384
+ original_out_features = _weight.shape[0]
385
+ else:
386
+ assert (
387
+ in_features is not None
388
+ ), "in_features must be provided if _weight is not provided"
389
+ assert (
390
+ original_out_features is not None
391
+ ), "original_out_features must be provided if _weight is not provided"
392
+
393
+ if _bias is not None:
394
+ assert bias is True, "bias must be True if _bias is provided"
395
+
396
+ # initialize original linear
397
+ super().__init__(
398
+ in_features,
399
+ original_out_features,
400
+ bias,
401
+ device,
402
+ dtype)
403
+
404
+ # set weight and bias manually
405
+ if _weight is not None:
406
+ self.weight = nn.Parameter(_weight)
407
+ if _bias is not None:
408
+ self.bias = nn.Parameter(_bias)
409
+
410
+ self.in_features = in_features
411
+ self.original_out_features = original_out_features
412
+ self.max_original_id = max_original_id
413
+
414
+ # initialize additional linear
415
+ self.additional_out_features = additional_out_features
416
+ self.has_bias = bias
417
+ if additional_out_features > 0:
418
+ self.additional_fc = nn.Linear(
419
+ in_features=in_features,
420
+ out_features=additional_out_features,
421
+ bias=self.has_bias,
422
+ device=device,
423
+ dtype=dtype,
424
+ )
425
+ self.set_requires_grad(
426
+ require_regular_grad=not partially_freeze, require_additional_grad=True
427
+ )
428
+
429
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
430
+ """
431
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
432
+ """
433
+ self.weight.requires_grad_(require_regular_grad)
434
+ if self.has_bias:
435
+ self.bias.requires_grad_(require_regular_grad)
436
+ if self.additional_out_features > 0:  # additional_fc only exists in this case
+ self.additional_fc.requires_grad_(require_additional_grad)
437
+
438
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
439
+ output = F.linear(input, self.weight, self.bias)
440
+ output = output[..., : self.max_original_id + 1]
441
+
442
+ if self.additional_out_features > 0:
443
+ additional_features = F.linear(
444
+ input, self.additional_fc.weight, self.additional_fc.bias
445
+ )
446
+ output = torch.cat((output, additional_features), -1)
447
+ return output
448
+
449
+ def extra_repr(self) -> str:
450
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
451
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
452
+ self.in_features,
453
+ self.max_original_id + 1,
454
+ self.additional_out_features,
455
+ self.bias is not None,
456
+ (not self.weight.requires_grad or (self.bias is not None and not self.bias.requires_grad)),
457
+ )
458
+
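`DecoupledLinear.forward` truncates the regular logits to `max_original_id + 1` columns and then appends the logits for the additional tokens. The following is a minimal sketch of that behavior, assuming plain lists in place of tensors and no bias. `linear` and `decoupled_linear` are hypothetical helpers, not part of this commit.

```python
def linear(x, weight):
    """Bias-free linear map: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def decoupled_linear(x, weight, additional_weight, max_original_id):
    # Keep only logits for ids 0..max_original_id; the regular matrix may have
    # extra (padded) rows beyond the original tokenizer vocabulary.
    regular_logits = linear(x, weight)[: max_original_id + 1]
    # Logits for newly added tokens come from the separate trainable matrix.
    return regular_logits + linear(x, additional_weight)


# Toy values (assumed): 4 regular rows, of which only ids 0..2 are real tokens.
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [5.0, 5.0]]
additional_weight = [[2.0, 0.0], [0.0, 2.0]]
logits = decoupled_linear([1.0, 2.0], weight, additional_weight, max_original_id=2)
print(logits)
```

The truncation explains the docstring's remark that `max_original_id` may differ from `original_out_features - 1`: padded rows in the original head are dropped so the additional token logits line up with their token ids.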
459
+ class VLM(nn.Module):
460
+ """
461
+ Generic vision-language model (VLM) class.
462
+ A VLM consists of four components:
463
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
464
+ input: (B, T_img, F, C, H, W)
465
+ output: (B, T_img, F, v, d)
466
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
467
+ input: (B, T_img, F, v, d)
468
+ output: (B, T_img, n, d)
469
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
470
+ 4. A language model
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ vision_encoder: nn.Module,
476
+ vision_tokenizer: nn.Module,
477
+ lang_model: nn.Module,
478
+ initial_tokenizer_len: int,
479
+ pad_token_id: int,
480
+ gradient_checkpointing: bool = False,
481
+ ):
482
+ """
483
+ Args:
484
+ vision_encoder (nn.Module): e.g. CLIP
485
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
486
+ lang_model (nn.Module): e.g. MPT
487
+ initial_tokenizer_len (int): size of the original tokenizer vocab
488
+ pad_token_id (int): id of the pad token
489
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
490
+ """
491
+ super().__init__()
492
+
493
+ # save dimension information
494
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
495
+ if hasattr(lang_model.config, "d_model"):
496
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
497
+ else:
498
+ self.lang_hidden_dim = lang_model.config.hidden_size
499
+ self.vis_embedding_dim = vision_tokenizer.dim_media
500
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
501
+
502
+ # core components
503
+ self.vision_encoder = vision_encoder
504
+ self.vision_tokenizer = vision_tokenizer
505
+ self.lang_model = lang_model
506
+
507
+ # lm embeddings
508
+ self.pad_token_id = pad_token_id
509
+ self.initial_tokenizer_len = initial_tokenizer_len
510
+ input_embeds = DecoupledEmbedding(
511
+ max_original_id=initial_tokenizer_len - 1,
512
+ num_additional_embeddings=len(self.special_tokens),
513
+ _weight=self.lang_model.get_input_embeddings().weight,
514
+ pad_token_id=self.pad_token_id,
515
+ )
516
+ if hasattr(input_embeds, "additional_embedding"):
517
+ input_embeds.additional_embedding.weight.data.normal_(
518
+ mean=0.0,
519
+ std=self.lang_model.config.initializer_range
520
+ if hasattr(self.lang_model.config, "initializer_range")
521
+ else 0.02,
522
+ )
523
+ self.lang_model.set_input_embeddings(input_embeds)
524
+
525
+ out_embeds = DecoupledLinear(
526
+ max_original_id=initial_tokenizer_len - 1,
527
+ additional_out_features=len(self.special_tokens),
528
+ _weight=self.lang_model.get_output_embeddings().weight,
529
+ _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
530
+ )
531
+ if hasattr(out_embeds, "additional_fc"):
532
+ out_embeds.additional_fc.weight.data.normal_(
533
+ mean=0.0,
534
+ std=self.lang_model.config.initializer_range
535
+ if hasattr(self.lang_model.config, "initializer_range")
536
+ else 0.02,
537
+ )
538
+ self.lang_model.set_output_embeddings(out_embeds)
539
+
540
+ # gradient checkpointing
541
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
542
+
543
+ def forward(
544
+ self,
545
+ vision_x: Optional[torch.Tensor],
546
+ lang_x: torch.Tensor,
547
+ attention_mask: Optional[torch.Tensor] = None,
548
+ labels: Optional[torch.Tensor] = None,
549
+ past_key_values: Optional[
550
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
551
+ ] = None,
552
+ past_media_locations: Optional[torch.Tensor] = None,
553
+ past_vision_tokens: Optional[torch.Tensor] = None,
554
+ use_cache: Optional[bool] = False,
555
+ **kwargs,
556
+ ):
557
+ """
558
+ Args:
559
+ vision_x: Vision input
560
+ shape (B, T_img, F, C, H, W) with F=1
561
+ only F = 1 is supported (single-frame videos)
562
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
563
+ only the first N images are used, where N is the number of media tokens in lang_x
564
+ lang_x: Language input ids, with media tokens denoting where
565
+ visual media should be inserted.
566
+ shape (B, T_txt)
567
+ attention_mask: Attention mask. Defaults to None.
568
+ labels: Labels. Defaults to None.
569
+ shape (B, T_txt)
570
+ past_key_values (List[Union[torch.Tensor, Tuple[torch.Tensor]]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
571
+ list of length = number of decoder layers in the LM
572
+ exact implementation depends on LM, see Hugging Face docs
573
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
574
+ shape (B, T_txt)
575
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
576
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
577
+ If True, includes key_values, media_locations, and vision_tokens in the output.
578
+ """
579
+ assert not (past_vision_tokens is None) ^ (
580
+ past_media_locations is None
581
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
582
+
583
+ # convert pixels to vision tokens
584
+ if vision_x is not None:
585
+ vision_features = self._encode_vision_x(vision_x=vision_x)
586
+ vision_tokens = self.vision_tokenizer(vision_features)
587
+ else:
588
+ vision_tokens = None
589
+
590
+ # fuse the vision and language tokens
591
+ new_inputs = self._prepare_inputs_for_forward(
592
+ vision_tokens=vision_tokens,
593
+ lang_x=lang_x,
594
+ attention_mask=attention_mask,
595
+ labels=labels,
596
+ past_key_values=past_key_values,
597
+ past_media_locations=past_media_locations,
598
+ padding_side="right",
599
+ past_vision_tokens=past_vision_tokens,
600
+ )
601
+ output = self.lang_model(
602
+ **new_inputs,
603
+ use_cache=use_cache,
604
+ past_key_values=past_key_values,
605
+ **kwargs,
606
+ )
607
+
608
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
609
+ # or to add the past_vision_tokens and past_media_locations to the output
610
+ output = self._postprocess_outputs_from_forward(
611
+ output=output,
612
+ lang_x=lang_x,
613
+ vision_tokens=vision_tokens,
614
+ use_cache=use_cache,
615
+ past_vision_tokens=past_vision_tokens,
616
+ past_media_locations=past_media_locations,
617
+ )
618
+
619
+ # postforward hooks
620
+ self._post_forward_hook()
621
+ return output
622
+
623
+ def _encode_vision_x_anyres(self, samples, device):
624
+ assert self.anyres_grids is not None
625
+ image_raw = samples["image"] # list of per-image patch tensors, each of shape [1, N_patch, C, H, W]
626
+ image_sizes = samples["image_size"]
627
+
628
+ # Image_raw can be a list of list of patches, when a `samples` has multiple images.
629
+ if isinstance(image_raw[0], list):
630
+ images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
631
+ image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
632
+ else:
633
+ # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
634
+ # concatenate the list of patches into one big batch for any-res encoding.
635
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
636
+ image = torch.cat(images, dim=0) # [\sum_{B}(N_patch_i), C, H, W]
637
+ image = image.to(device)
638
+
639
+ with torch.no_grad():
640
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
641
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
642
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
643
+ image_embeds = self.vision_encoder(image).last_hidden_state
644
+ else:
645
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
646
+
647
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
648
+ base_img_size = self.vision_encoder.config.image_size
649
+ else:
650
+ base_img_size = self.vision_encoder.image_size[0]
651
+
652
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
653
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
654
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
655
+ grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
656
+ grid_size = (grid_size_base, grid_size_base)
657
+ else:
658
+ grid_size = self.vision_encoder.grid_size
659
+ height, width = grid_size
660
+
661
+ if not image_embeds.shape[1] == height * width:
662
+ assert image_embeds.shape[1] == height * width + 1 # For vision encoders that have a [CLS] token.
663
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
664
+ n_vis_token_per_patch = image_embeds.shape[1]
665
+
666
+ # Split encoded patches and merge patch features
667
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
668
+ split_sizes = [image.shape[0] for image in images]
669
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
670
+ # 2. For each image (consisting of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
671
+ new_image_embeds = []
672
+ patch_attn_masks = []
673
+ max_n_img_token = -1
674
+ for idx, patch_embeds in enumerate(image_embeds):
675
+ if patch_embeds.shape[0] > 1:
676
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
677
+ base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
678
+ patch_embeds = patch_embeds[1:]
679
+
680
+ assert height * width == base_patch_embeds.shape[0]
681
+
682
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
683
+ self.anyres_grids,
684
+ base_img_size) # Hardcoded grid_pinpoints.
685
+ patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
686
+
687
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
688
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
689
+ patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
690
+ if hasattr(self, 'image_newline'):
691
+ patch_embeds = torch.cat((
692
+ patch_embeds,
693
+ self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
694
+ ), dim=-1)
695
+ if self.anyres_patch_sampling:
696
+ patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
697
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
698
+ assert patch_attn_mask is not None
699
+ patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
700
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
701
+ patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
702
+ patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
703
+ else:
704
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
705
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
706
+ else:
707
+ patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
708
+ patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
709
+ if hasattr(self, 'image_newline'):
710
+ patch_embeds = torch.cat((
711
+ patch_embeds,
712
+ self.image_newline[None]
713
+ ), dim=0)
714
+ if not self.anyres_patch_sampling:
715
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
716
+
717
+ new_image_embeds.append(patch_embeds)
718
+ patch_attn_masks.append(patch_attn_mask)
719
+
720
+ if self.anyres_patch_sampling:
721
+ # Return individual patches for independent token downsampling.
722
+ return new_image_embeds, patch_attn_masks
723
+
724
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
725
+ image_embeds = []
726
+ image_atts = []
727
+ for image_embed in new_image_embeds:
728
+ n_img_token = image_embed.shape[0]
729
+ img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
730
+ if n_img_token < max_n_img_token:
731
+ padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
732
+ padded_embed[:n_img_token, :] = image_embed
733
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
734
+ else:
735
+ padded_embed = image_embed
736
+ image_embeds.append(padded_embed)
737
+ image_atts.append(img_attn)
738
+ image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
739
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
740
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
741
+ image_embeds = image_embeds[:, None, None, :, :]
742
+ # image_atts = image_atts[:, None, None, :, :]
743
+
744
+ return image_embeds, image_atts
745
+
746
+ def _encode_vision_x(self, vision_x: torch.Tensor):
747
+ """
748
+ Compute vision features from the vision input by passing it through the vision encoder.
749
+ Args:
750
+ vision_x: Vision input
751
+ shape (B, T_img, F, C, H, W)
752
+ Images in the same chunk are collated along T_img, and frames are collated along F
753
+ Currently only F=1 is supported (single-frame videos)
754
+
755
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
756
+ """
757
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
758
+ b, T, F = vision_x.shape[:3]
759
+
760
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
761
+ with torch.no_grad():
762
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
763
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
764
+ elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
765
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
766
+ else:
767
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
768
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
769
+ return vision_x
770
+
771
+ def _concat_vision_cache(
772
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
773
+ ):
774
+ """
775
+ Helper function to include the past vision tokens and past media locations in the output.
776
+ """
777
+ if use_cache:
778
+ if past_media_locations is not None and past_vision_tokens is not None:
779
+ if vision_tokens is not None:
780
+ updated_vision_tokens = torch.cat(
781
+ [
782
+ past_vision_tokens,
783
+ vision_tokens,
784
+ ],
785
+ dim=1,
786
+ )
787
+ else:
788
+ updated_vision_tokens = past_vision_tokens
789
+ updated_media_locations = torch.cat(
790
+ [
791
+ past_media_locations,
792
+ lang_x == self.media_token_id,
793
+ ],
794
+ dim=1,
795
+ )
796
+ else:
797
+ updated_vision_tokens = vision_tokens
798
+ updated_media_locations = lang_x == self.media_token_id
799
+
800
+ else:
801
+ updated_vision_tokens = None
802
+ updated_media_locations = None
803
+
804
+ return updated_vision_tokens, updated_media_locations
805
+
806
+ def generate(
807
+ self,
808
+ vision_x: torch.Tensor,
809
+ lang_x: torch.Tensor,
810
+ attention_mask: torch.Tensor = None,
811
+ past_key_values: Optional[
812
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
813
+ ] = None,
814
+ past_media_locations: Optional[torch.Tensor] = None,
815
+ past_vision_tokens: Optional[torch.Tensor] = None,
816
+ **kwargs,
817
+ ):
818
+ """
819
+ Generate text conditioned on vision and language inputs.
820
+ Args:
821
+ vision_x (torch.Tensor): Vision input
822
+ shape (B, T_img, F, C, H, W)
823
+ see documentation for forward
824
+ lang_x (torch.Tensor): Language input
825
+ shape (B, T_txt)
826
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
827
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
828
+ Returns:
829
+ torch.Tensor: lang_x with generated tokens appended to it
830
+ """
831
+ num_beams = kwargs.pop("num_beams", 1)
832
+
833
+ # convert pixels to vision tokens
834
+ if vision_x is not None:
835
+ vision_features = self._encode_vision_x(vision_x=vision_x)
836
+ vision_tokens = self.vision_tokenizer(vision_features)
837
+ else:
838
+ vision_tokens = None
839
+
840
+ # fuse the vision and language tokens
841
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
842
+ # the total batch size is B * num_beams
843
+ new_inputs = self._prepare_inputs_for_forward(
844
+ vision_tokens=vision_tokens,
845
+ lang_x=lang_x,
846
+ attention_mask=attention_mask,
847
+ past_key_values=past_key_values,
848
+ past_media_locations=past_media_locations,
849
+ past_vision_tokens=past_vision_tokens,
850
+ padding_side="left",
851
+ num_beams=num_beams,
852
+ )
853
+ output = self.lang_model.generate(
854
+ **new_inputs,
855
+ past_key_values=past_key_values,
856
+ num_beams=num_beams,
857
+ use_cache=True,
858
+ **kwargs,
859
+ )
860
+ self._post_forward_hook()
861
+ return output
862
+
863
+ @property
864
+ def num_trainable_params(self):
865
+ """Return the number of trainable parameters"""
866
+ return num_params(self, filter_to_trainable=True)
867
+
868
+ def set_trainable(self):
869
+ """
870
+ Freeze appropriate parameters in the model.
871
+ """
872
+ raise NotImplementedError
873
+
874
+ def group_params_by_weight_decay(self):
875
+ """
876
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
877
+ """
878
+ params_with_wd, params_without_wd = [], []
879
+ for n, p in self.named_parameters():
880
+ if p.requires_grad:
881
+ if self._should_apply_weight_decay(n):
882
+ params_with_wd.append(p)
883
+ else:
884
+ params_without_wd.append(p)
885
+ return params_with_wd, params_without_wd
886
+
887
+ def _should_apply_weight_decay(self, parameter_name):
888
+ """
889
+ Return whether weight decay should be applied to a parameter.
890
+ """
891
+ raise NotImplementedError
892
+
893
+ @property
894
+ def special_tokens(self):
895
+ """
896
+ Returns a dict mapping from the attribute name of a special token to its string format,
897
+ e.g. "media_token": "<image>"
898
+ """
899
+ assert (
900
+ "media_token" in self._special_tokens
901
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
902
+ return self._special_tokens
903
+
904
+ @property
905
+ def special_token_ids(self):
906
+ """
907
+ Returns a list of the special token ids
908
+ """
909
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
910
+
911
+ def set_special_token_ids(self, string_to_ids):
912
+ """
913
+ Args:
914
+ string_to_ids (dict): mapping from token string to id
915
+ """
916
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
917
+ for att_name, token_str in self.special_tokens.items():
918
+ token_id = string_to_ids[token_str]
919
+ setattr(self, f"{att_name}_id", token_id)
920
+ setattr(self.lang_model, f"{att_name}_id", token_id)
921
+
922
+ def init_gradient_checkpointing(self):
923
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
924
+ checkpoint_wrapper,
925
+ CheckpointWrapper,
926
+ CheckpointImpl,
927
+ apply_activation_checkpointing,
928
+ )
929
+ from functools import partial
930
+
931
+ non_reentrant_wrapper = partial(
932
+ checkpoint_wrapper,
933
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
934
+ )
935
+ apply_activation_checkpointing(
936
+ self,
937
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
938
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
939
+ and not isinstance(m, CheckpointWrapper),
940
+ )
941
+
942
+ @dataclass
943
+ class VLMOutputWithPast(CausalLMOutputWithPast):
944
+ """
945
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
946
+ past_media_locations: Optional[torch.Tensor] = None,
947
+ past_vision_tokens: Optional[torch.Tensor] = None,
948
+ """
949
+
950
+ past_media_locations: Optional[torch.Tensor] = None
951
+ past_vision_tokens: Optional[torch.Tensor] = None
952
+
953
+
954
+ def exists(val):
955
+ return val is not None
956
+
957
+
958
+ def FeedForward(dim, mult=4):
959
+ inner_dim = int(dim * mult)
960
+ return nn.Sequential(
961
+ nn.LayerNorm(dim),
962
+ nn.Linear(dim, inner_dim, bias=False),
963
+ nn.GELU(),
964
+ nn.Linear(inner_dim, dim, bias=False),
965
+ )
966
+
967
+ class VLMWithLanguageStream(VLM):
968
+ """
969
+ VLM that fuses modalities by inserting vision tokens directly into the language stream.
970
+ """
971
+
972
+ def __init__(
973
+ self,
974
+ vision_encoder: nn.Module,
975
+ vision_tokenizer: nn.Module,
976
+ lang_model: nn.Module,
977
+ initial_tokenizer_len: int,
978
+ pad_token_id: int,
979
+ decoder_layers_attr_name: str = None,
980
+ gradient_checkpointing: bool = False,
981
+ ):
982
+ super().__init__(
983
+ vision_encoder=vision_encoder,
984
+ vision_tokenizer=vision_tokenizer,
985
+ lang_model=lang_model,
986
+ initial_tokenizer_len=initial_tokenizer_len,
987
+ pad_token_id=pad_token_id,
988
+ gradient_checkpointing=gradient_checkpointing,
989
+ )
990
+ self.decoder_layers_attr_name = decoder_layers_attr_name
991
+ if decoder_layers_attr_name is not None:
992
+ for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
993
+ block._use_gradient_checkpointing = gradient_checkpointing
994
+
995
+ def _prepare_inputs_for_forward(
996
+ self,
997
+ vision_tokens: torch.Tensor,
998
+ lang_x: torch.Tensor,
999
+ attention_mask: torch.Tensor,
1000
+ labels: torch.Tensor = None,
1001
+ past_key_values=None,
1002
+ vision_attention_mask: Optional[torch.Tensor] = None,
1003
+ past_media_locations: torch.Tensor = None,
1004
+ past_vision_tokens: torch.Tensor = None,
1005
+ padding_side: str = "left",
1006
+ num_beams: int = 1,
1007
+ ):
1008
+ """
1009
+ Insert the vision tokens directly into the language stream.
1010
+ This requires us to modify the input_ids, attention_mask, and labels.
1011
+ """
1012
+ if past_key_values is not None:
1013
+ past_len = past_key_values[0][0].shape[2]
1014
+ assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1015
+ "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1016
+ + "Check that you've expanded the attention mask to account for past image tokens."
1017
+ )
1018
+
1019
+ if vision_tokens is None:
1020
+ return {
1021
+ "input_ids": lang_x,
1022
+ "attention_mask": attention_mask,
1023
+ "labels": labels,
1024
+ }
1025
+
1026
+ # get the language embeddings
1027
+ lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1028
+
1029
+ # build up the multimodal embeddings
1030
+ B = lang_x.shape[0]
1031
+ has_labels = labels is not None
1032
+ multimodal_embeds = []
1033
+ multimodal_attention_mask = []
1034
+ multimodal_labels = [] if has_labels else None
1035
+ for i in range(B):
1036
+ # get index of <image> tokens in lang_x[i]
1037
+ image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1038
+
1039
+ if len(image_token_idxs) == 0:
1040
+ multimodal_embeds.append(lang_embeds[i].clone())
1041
+ multimodal_attention_mask.append(attention_mask[i].clone())
1042
+ if has_labels:
1043
+ multimodal_labels.append(labels[i].clone())
1044
+ continue
1045
+
1046
+ # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
1047
+ for j, img_idx in enumerate(image_token_idxs):
1048
+ image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j # FIXME: different offset for any resolution encoding when has multiple images.
1049
+
1050
+ # loop through the image_token_idxs and insert the vision tokens
1051
+ new_embed = lang_embeds[i].clone()
1052
+ new_attention_mask = (
1053
+ attention_mask[i].clone() if attention_mask is not None else None
1054
+ )
1055
+ if has_labels:
1056
+ new_label = labels[i].clone()
1057
+
1058
+ for img_num, img_idx in enumerate(image_token_idxs):
1059
+ if img_num > 0:
1060
+ # FIXME: hardcoded as such to avoid assertion error, but this only works for single image samples.
1061
+ break
1062
+ # Get vision token attention mask for padded llava-style any resolution image tokens.
1063
+ if self.image_aspect_ratio == 'anyres':
1064
+ num_vis_tokens = vision_tokens[i][img_num].shape[0]
1065
+ if vision_attention_mask is not None:
1066
+ vis_attention_mask = vision_attention_mask[i]
1067
+ else:
1068
+ vis_attention_mask = torch.ones(
1069
+ num_vis_tokens, dtype=torch.long
1070
+ ).to(attention_mask.device)
1071
+ else:
1072
+ assert (
1073
+ vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
1074
+ ), (
+ f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) "
+ f"vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
+ )
1076
+ # By default, vision tokens are not padded.
1077
+ num_vis_tokens = self.num_tokens_per_vis
1078
+ vis_attention_mask = torch.ones(
1079
+ num_vis_tokens, dtype=torch.long
1080
+ ).to(attention_mask.device)
1081
+
1082
+
1083
+ new_embed = torch.cat(
1084
+ (
1085
+ new_embed[:img_idx],
1086
+ vision_tokens[i][img_num],
1087
+ new_embed[img_idx + 1 :],
1088
+ ),
1089
+ dim=0,
1090
+ )
1091
+ new_attention_mask = torch.cat(
1092
+ (
1093
+ new_attention_mask[:img_idx],
1094
+ vis_attention_mask,
1095
+ new_attention_mask[img_idx + 1 :],
1096
+ ),
1097
+ dim=0,
1098
+ )
1099
+ if has_labels:
1100
+ new_label = torch.cat(
1101
+ (
1102
+ new_label[:img_idx],
1103
+ torch.ones(num_vis_tokens, dtype=torch.long).to(
1104
+ labels.device
1105
+ )
1106
+ * -100,
1107
+ new_label[img_idx + 1 :],
1108
+ ),
1109
+ dim=0,
1110
+ )
1111
+ multimodal_embeds.append(new_embed)
1112
+ multimodal_attention_mask.append(new_attention_mask)
1113
+ if has_labels:
1114
+ multimodal_labels.append(new_label)
1115
+
1116
+ # stack
1117
+ multimodal_embeds = stack_with_padding(
1118
+ multimodal_embeds,
1119
+ padding_value=self.pad_token_id,
1120
+ padding_side=padding_side,
1121
+ )
1122
+ multimodal_attention_mask = stack_with_padding(
1123
+ multimodal_attention_mask,
1124
+ padding_value=0,
1125
+ padding_side=padding_side,
1126
+ )
1127
+ if has_labels:
1128
+ multimodal_labels = stack_with_padding(
1129
+ multimodal_labels,
1130
+ padding_value=-100,
1131
+ padding_side=padding_side,
1132
+ )
1133
+
1134
+ return {
1135
+ "inputs_embeds": multimodal_embeds,
1136
+ "attention_mask": multimodal_attention_mask,
1137
+ "labels": multimodal_labels,
1138
+ }
1139
+
1140
+ def _postprocess_outputs_from_forward(
1141
+ self,
1142
+ output: CausalLMOutputWithPast,
1143
+ lang_x: torch.Tensor,
1144
+ vision_tokens: torch.Tensor,
1145
+ past_vision_tokens: torch.Tensor,
1146
+ past_media_locations: torch.Tensor,
1147
+ use_cache: bool = False,
1148
+ ):
1149
+ # Include the past vision tokens and past media locations in the output
1150
+ updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1151
+ lang_x=lang_x,
1152
+ vision_tokens=vision_tokens,
1153
+ past_vision_tokens=past_vision_tokens,
1154
+ past_media_locations=past_media_locations,
1155
+ use_cache=use_cache,
1156
+ )
1157
+
1158
+ # return logits that are the same shape as the original input_ids
1159
+ logits = output.logits
1160
+ batch_logits = []
1161
+ B, T_txt = lang_x.shape
1162
+ for i in range(B):
1163
+ sequence_logits = []
1164
+ logits_j = 0
1165
+ for j in range(T_txt):
1166
+ if lang_x[i, j] != self.media_token_id:
1167
+ sequence_logits.append(logits[i, logits_j])
1168
+ logits_j += 1
1169
+ else:
1170
+ # append the logit for the first image token, then skip over the rest
1171
+ # note: the model actually learns to predict <im_patch>, not <image>
1172
+ sequence_logits.append(logits[i, logits_j])
1173
+ logits_j += self.num_tokens_per_vis
1174
+ sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
1175
+ batch_logits.append(sequence_logits)
1176
+
1177
+ batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1178
+ # The final logits shape should be the same as the original input_ids shape
1179
+ assert batch_logits.shape[:2] == (B, T_txt)
1180
+
1181
+ # assemble the output
1182
+ output = VLMOutputWithPast(
1183
+ loss=output.loss,
1184
+ logits=batch_logits,
1185
+ past_key_values=output.past_key_values,
1186
+ hidden_states=output.hidden_states,
1187
+ attentions=output.attentions,
1188
+ past_media_locations=updated_media_locations,
1189
+ past_vision_tokens=updated_vision_tokens,
1190
+ )
1191
+
1192
+ return output
1193
+
1194
+ def _post_forward_hook(self):
1195
+ pass
1196
+
1197
+
1198
+ @property
1199
+ def num_params_per_module(self):
1199
+ """Return a string listing the number of parameters per module in the model"""
1201
+ return "\n".join(
1202
+ [
1203
+ f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1204
+ f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1205
+ f"Language model: {num_params(self.lang_model):,} parameters",
1206
+ ]
1207
+ )
1208
+
1209
+ @property
1210
+ def num_trainable_params_per_module(self):
1210
+ """Return a string listing the number of trainable parameters per module in the model"""
1212
+ return "\n".join(
1213
+ [
1214
+ f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1215
+ f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1216
+ f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1217
+ ]
1218
+ )
1219
+
1220
+
1221
+ class XGenMMPerceiver(VLMWithLanguageStream):
1222
+ def __init__(
1223
+ self,
1224
+ vision_encoder: nn.Module,
1225
+ vision_tokenizer: nn.Module,
1226
+ lang_model: nn.Module,
1227
+ initial_tokenizer_len: int,
1228
+ pad_token_id: int,
1229
+ decoder_layers_attr_name: str = None,
1230
+ gradient_checkpointing: bool = False,
1231
+ image_aspect_ratio: str = 'anyres',
1232
+ anyres_patch_sampling: bool = True,
1233
+ anyres_grids: list[int] = None,
1234
+ ):
1235
+ """
1236
+ Args:
1237
+ vision_encoder (nn.Module): HF CLIPModel
1238
+ lang_model (nn.Module): HF causal language model
1239
+ vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
1240
+ initial_tokenizer_len (int): size of the tokenizer vocab
1241
+ pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1242
+ will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1243
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1244
+ gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1245
+ """
1246
+ self._special_tokens = {
1247
+ "media_token": "<image>",
1248
+ "image_placeholder_token": "<image placeholder>",
1249
+ "end_of_trunk_token": "<|endofchunk|>",
1250
+ }
1251
+ lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1252
+ super().__init__(
1253
+ vision_encoder=vision_encoder,
1254
+ vision_tokenizer=vision_tokenizer,
1255
+ lang_model=lang_model,
1256
+ initial_tokenizer_len=initial_tokenizer_len,
1257
+ gradient_checkpointing=gradient_checkpointing,
1258
+ decoder_layers_attr_name=decoder_layers_attr_name,
1259
+ pad_token_id=pad_token_id,
1260
+ )
1261
+ self.image_aspect_ratio = image_aspect_ratio
1262
+ self.anyres_patch_sampling = anyres_patch_sampling
1263
+ self.anyres_grids = anyres_grids
1264
+
1265
+ def set_trainable(self):
1266
+ """
1267
+ Unfreeze everything except the vision_encoder
1268
+ """
1269
+ self.requires_grad_(True)
1270
+ self.vision_encoder.requires_grad_(False)
1271
+
1272
+ def _should_apply_weight_decay(self, parameter_name):
1273
+ """
1274
+ Kosmos applies 0.01 weight decay to everything
1275
+ """
1276
+ return True
1277
+
1278
+ def forward(
1279
+ self,
1280
+ vision_x: Optional[torch.Tensor],
1281
+ lang_x: torch.Tensor,
1282
+ attention_mask: Optional[torch.Tensor] = None,
1283
+ labels: Optional[torch.Tensor] = None,
1284
+ image_size: Optional[Tuple] = None,
1285
+ past_key_values: Optional[
1286
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1287
+ ] = None,
1288
+ past_media_locations: Optional[torch.Tensor] = None,
1289
+ past_vision_tokens: Optional[torch.Tensor] = None,
1290
+ use_cache: Optional[bool] = False,
1291
+ **kwargs,
1292
+ ):
1293
+ """
1294
+ Args:
1295
+ vision_x: Vision input
1296
+ shape (B, T_img, F, C, H, W) with F=1
1297
+ only F = 1 is supported (single-frame videos)
1298
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
1299
+ only the first N images are used, where N is the number of media tokens in lang_x
1300
+ lang_x: Language input ids, with media tokens denoting where
1301
+ visual media should be inserted.
1302
+ shape (B, T_txt)
1303
+ attention_mask: Attention mask. Defaults to None.
1304
+ labels: Labels. Defaults to None.
1305
+ shape (B, T_txt)
1306
+ past_key_values (Tuple[torch.Tensor], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
1307
+ list of length = number of decoder layers in the LM
1308
+ exact implementation depends on LM, see Hugging Face docs
1309
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
1310
+ shape (B, T_txt)
1311
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
1312
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
1313
+ If True, includes key_values, media_locations, and vision_tokens in the output.
1314
+ """
1315
+ assert not (past_vision_tokens is None) ^ (
1316
+ past_media_locations is None
1317
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
1318
+
1319
+ # convert pixels to vision tokens
1320
+ vision_attention_mask = None
1321
+ if vision_x is not None:
1322
+ if self.image_aspect_ratio == 'anyres':
1323
+ input_dict = dict(image=vision_x, image_size=image_size)
1324
+ vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
1325
+ else:
1326
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1327
+ vision_attn_masks = None
1328
+ # If doing patch sampling, flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]; same for attention masks: [b, Np, v] -> [b*Np, v]
1329
+ if self.anyres_patch_sampling:
1330
+ split_sizes = [feature.shape[0] for feature in vision_features]
1331
+ # Nested splits for multi-image samples.
1332
+ if isinstance(vision_x[0], list):
1333
+ nt_images = [len(images) for images in vision_x]
1334
+ split_split_sizes = []
1335
+ img_id = 0
1336
+ for nt in nt_images:
1337
+ split_split_sizes.append(split_sizes[img_id:img_id+nt])
1338
+ img_id += nt
1339
+ else:
1340
+ nt_images = [1] * len(vision_x)
1341
+ split_split_sizes = split_sizes
1342
+ vision_features = torch.cat(vision_features, dim=0)
1343
+ vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
1344
+ vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1345
+ # TODO: add an option that allows restoring the T dimension for video tokenization.
1346
+ vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
1347
+
1348
+ # Post-processing: Split the batches into groups of patches and concatenate them together.
1349
+ if self.anyres_patch_sampling:
1350
+ # assert isinstance(vision_x, list)
1351
+ if isinstance(vision_x[0], list):
1352
+ vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
1353
+ vision_tokens = []
1354
+
1355
+ for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
1356
+ patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
1357
+ flatten_vision_tokens = []
1358
+ for image_vis_token in patch_vis_token_groups:
1359
+ image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1360
+ flatten_vision_tokens.append(image_vis_token)
1361
+ vision_tokens_i = flatten_vision_tokens
1362
+ vision_tokens.append(vision_tokens_i)
1363
+ else:
1364
+ vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1365
+ vision_tokens = []
1366
+ for patch_vis_tokens in vision_token_groups:
1367
+ patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1368
+ vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
1369
+ else:
1370
+ vision_tokens = None
1371
+
1372
+ # fuse the vision and language tokens
1373
+ new_inputs = self._prepare_inputs_for_forward(
1374
+ vision_tokens=vision_tokens,
1375
+ lang_x=lang_x,
1376
+ attention_mask=attention_mask,
1377
+ vision_attention_mask=vision_attention_mask,
1378
+ labels=labels,
1379
+ past_key_values=past_key_values,
1380
+ past_media_locations=past_media_locations,
1381
+ padding_side="right",
1382
+ past_vision_tokens=past_vision_tokens,
1383
+ )
1384
+ output = self.lang_model(
1385
+ **new_inputs,
1386
+ use_cache=use_cache,
1387
+ past_key_values=past_key_values,
1388
+ **kwargs,
1389
+ )
1390
+
1391
+ # post-forward hooks
1392
+ self._post_forward_hook()
1393
+ return output
1394
+
1395
+ def generate(
1396
+ self,
1397
+ vision_x: torch.Tensor,
1398
+ lang_x: torch.Tensor,
1399
+ image_size: Optional[Tuple] = None,
1400
+ attention_mask: torch.Tensor = None,
1401
+ past_key_values: Optional[
1402
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1403
+ ] = None,
1404
+ past_media_locations: Optional[torch.Tensor] = None,
1405
+ past_vision_tokens: Optional[torch.Tensor] = None,
1406
+ **kwargs,
1407
+ ):
1408
+ """
1409
+ Generate text conditioned on vision and language inputs.
1410
+ Args:
1411
+ vision_x (torch.Tensor): Vision input
1412
+ shape (B, T_img, F, C, H, W)
1413
+ see documentation for forward
1414
+ lang_x (torch.Tensor): Language input
1415
+ shape (B, T_txt)
1416
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1417
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1418
+ Returns:
1419
+ torch.Tensor: lang_x with generated tokens appended to it
1420
+ """
1421
+ num_beams = kwargs.pop("num_beams", 1)
1422
+
1423
+ # convert pixels to vision tokens
1424
+ vision_attention_mask = None
1425
+ if vision_x is not None:
1426
+ if self.image_aspect_ratio == 'anyres':
1427
+ input_dict = dict(image=vision_x, image_size=image_size)
1428
+ vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
1429
+ else:
1430
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1431
+ vision_attn_masks = None
1432
+ # TODO: If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
1433
+ # Same for attention masks: [b, Np, v] -> [b*Np, v]
1434
+ if self.anyres_patch_sampling:
1435
+ split_sizes = [feature.shape[0] for feature in vision_features]
1436
+ # Nested splits for multi-image samples.
1437
+ if isinstance(vision_x[0], list):
1438
+ nt_images = [len(images) for images in vision_x]
1439
+ split_split_sizes = []
1440
+ img_id = 0
1441
+ for nt in nt_images:
1442
+ split_split_sizes.append(split_sizes[img_id:img_id+nt])
1443
+ img_id += nt
1444
+ else:
1445
+ nt_images = [1] * len(vision_x)
1446
+ split_split_sizes = split_sizes
1447
+ vision_features = torch.cat(vision_features, dim=0)
1448
+ vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
1449
+ vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1450
+ vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
1451
+
1452
+ # Post-processing: Split the batches into groups of patches and concatenate them together.
1453
+ if self.anyres_patch_sampling:
1454
+ assert isinstance(vision_x, list)
1455
+ if isinstance(vision_x[0], list):
1456
+ vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
1457
+ vision_tokens = []
1458
+
1459
+ for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
1460
+ patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
1461
+ flatten_vision_tokens = []
1462
+ for image_vis_token in patch_vis_token_groups:
1463
+ image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1464
+ flatten_vision_tokens.append(image_vis_token)
1465
+ vision_tokens_i = flatten_vision_tokens
1466
+ vision_tokens.append(vision_tokens_i)
1467
+ else:
1468
+ vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1469
+ vision_tokens = []
1470
+ for patch_vis_tokens in vision_token_groups:
1471
+ patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1472
+ vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
1473
+ else:
1474
+ vision_tokens = None
1475
+
1476
+ # fuse the vision and language tokens
1477
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1478
+ # the total batch size is B * num_beams
1479
+ new_inputs = self._prepare_inputs_for_forward(
1480
+ vision_tokens=vision_tokens,
1481
+ lang_x=lang_x,
1482
+ attention_mask=attention_mask,
1483
+ vision_attention_mask=vision_attention_mask,
1484
+ past_key_values=past_key_values,
1485
+ past_media_locations=past_media_locations,
1486
+ past_vision_tokens=past_vision_tokens,
1487
+ padding_side="left",
1488
+ num_beams=num_beams,
1489
+ )
1490
+ if past_key_values is not None:
1491
+ output = self.lang_model.generate(
1492
+ **new_inputs,
1493
+ past_key_values=past_key_values,
1494
+ num_beams=num_beams,
1495
+ use_cache=True,
1496
+ **kwargs,
1497
+ )
1498
+ else:
1499
+ output = self.lang_model.generate(
1500
+ **new_inputs,
1501
+ num_beams=num_beams,
1502
+ use_cache=True,
1503
+ **kwargs,
1504
+ )
1505
+ self._post_forward_hook()
1506
+ return output