Manli committed on
Commit
8888e4c
·
1 Parent(s): b5e279f

Merge modeling files into a single one to avoid relative import

Browse files
Files changed (6) hide show
  1. config.json +1 -1
  2. configuration_xgenmm.py +0 -159
  3. image_processing_blip_3.py +12 -1
  4. modeling_xgenmm.py +2033 -40
  5. utils.py +0 -383
  6. vlm.py +0 -1381
config.json CHANGED
@@ -3,7 +3,7 @@
3
  "XGenMMModelForConditionalGeneration"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "configuration_xgenmm.XGenMMConfig",
7
  "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
8
  },
9
  "model_type": "xgenmm",
 
3
  "XGenMMModelForConditionalGeneration"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "modeling_xgenmm.XGenMMConfig",
7
  "AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
8
  },
9
  "model_type": "xgenmm",
configuration_xgenmm.py DELETED
@@ -1,159 +0,0 @@
1
- from transformers import PretrainedConfig
2
- from transformers import logging
3
- from transformers import CONFIG_MAPPING
4
-
5
- logger = logging.get_logger(__name__)
6
-
7
- class XGenMMVisionEncoderConfig(PretrainedConfig):
8
- model_type = "xgenmm_vision_encoder"
9
-
10
- def __init__(self,
11
- model_name: str = 'google/siglip-so400m-patch14-384',
12
- anyres_grids: list[int] = [[384, 768],[768, 384],[768, 768],[1152, 384],[384,1152]],
13
- **kwargs):
14
- self.model_name = model_name
15
- self.anyres_grids = anyres_grids
16
- super().__init__(**kwargs)
17
-
18
-
19
- class XGenMMVisionTokenizerConfig(PretrainedConfig):
20
- model_type = "xgenmm_vision_tokenizer"
21
-
22
- def __init__(self,
23
- vis_feature_dim: int = 1152,
24
- lang_embedding_dim: int = 3072,
25
- num_vis_tokens: int = 128,
26
- image_aspect_ratio: str = 'anyres',
27
- **kwargs):
28
- self.vis_feature_dim = vis_feature_dim
29
- self.lang_embedding_dim = lang_embedding_dim
30
- self.num_vis_tokens = num_vis_tokens
31
- self.image_aspect_ratio = image_aspect_ratio
32
- super().__init__(**kwargs)
33
-
34
-
35
- class XGenMMConfig(PretrainedConfig):
36
- model_type = "xgenmm"
37
-
38
- def __init__(self,
39
- vision_encoder_config: dict = None,
40
- vision_tokenizer_config: dict = None,
41
- text_config: dict = None,
42
- **kwargs):
43
-
44
- if vision_encoder_config is None:
45
- vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
46
- logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
47
-
48
- if vision_tokenizer_config is None:
49
- vision_tokenizer_config = {}
50
- logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
51
-
52
- if text_config is None:
53
- text_config = {
54
- 'initial_tokenizer_len':32012,
55
- 'pad_token_id':32011,
56
- 'bos_token_id':1,
57
- 'eos_token_id':32000,
58
- 'vocab_size': 32064,
59
- 'hidden_size': 3072,
60
- 'intermediate_size': 8192,
61
- 'num_hidden_layers': 32,
62
- 'num_attention_heads': 32,
63
- 'num_key_value_heads': 32,
64
- 'resid_pdrop': 0.0,
65
- 'embd_pdrop': 0.0,
66
- 'attention_dropout': 0.0,
67
- 'hidden_act': 'silu',
68
- 'max_position_embeddings': 4096,
69
- 'original_max_position_embeddings': 4096,
70
- 'initializer_range': 0.02,
71
- 'rms_norm_eps': 1e-05,
72
- 'use_cache': True,
73
- 'rope_theta': 10000.0,
74
- 'rope_scaling': None,
75
- 'sliding_window': 2047,
76
- 'return_dict': True,
77
- 'output_hidden_states': False,
78
- 'output_attentions': False,
79
- 'torchscript': False,
80
- 'torch_dtype': 'bfloat16',
81
- 'use_bfloat16': False,
82
- 'tf_legacy_loss': False,
83
- 'pruned_heads': {},
84
- 'tie_word_embeddings': False,
85
- 'chunk_size_feed_forward': 0,
86
- 'is_encoder_decoder': False,
87
- 'is_decoder': False,
88
- 'cross_attention_hidden_size': None,
89
- 'add_cross_attention': False,
90
- 'tie_encoder_decoder': False,
91
- 'max_length': 20,
92
- 'min_length': 0,
93
- 'do_sample': False,
94
- 'early_stopping': False,
95
- 'num_beams': 1,
96
- 'num_beam_groups': 1,
97
- 'diversity_penalty': 0.0,
98
- 'temperature': 1.0,
99
- 'top_k': 50,
100
- 'top_p': 1.0,
101
- 'typical_p': 1.0,
102
- 'repetition_penalty': 1.0,
103
- 'length_penalty': 1.0,
104
- 'no_repeat_ngram_size': 0,
105
- 'encoder_no_repeat_ngram_size': 0,
106
- 'bad_words_ids': None,
107
- 'num_return_sequences': 1,
108
- 'output_scores': False,
109
- 'return_dict_in_generate': False,
110
- 'forced_bos_token_id': None,
111
- 'forced_eos_token_id': None,
112
- 'remove_invalid_values': False,
113
- 'exponential_decay_length_penalty': None,
114
- 'suppress_tokens': None,
115
- 'begin_suppress_tokens': None,
116
- 'finetuning_task': None,
117
- 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
118
- 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
119
- 'tokenizer_class': None,
120
- 'prefix': None,
121
- 'bos_token_id': 1,
122
- 'pad_token_id': 32000,
123
- 'eos_token_id': 32000,
124
- 'sep_token_id': None,
125
- 'decoder_start_token_id': None,
126
- 'task_specific_params': None,
127
- 'problem_type': None,
128
- 'model_type': 'phi3'
129
- }
130
- logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
131
-
132
- self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
133
-
134
- self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
135
-
136
- text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
137
- self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
138
-
139
- for key in ['initial_tokenizer_len', 'pad_token_id']:
140
- if key not in self.text_config.to_dict():
141
- raise ValueError(f"The key `{key}` is missing in the text_config.")
142
-
143
- super().__init__(**kwargs)
144
-
145
- @classmethod
146
- def from_vision_encoder_vision_tokenizer_text_configs(
147
- cls,
148
- vision_encoder_config: XGenMMVisionEncoderConfig,
149
- vision_tokenizer_config: XGenMMVisionTokenizerConfig,
150
- text_config: PretrainedConfig,
151
- **kwargs):
152
-
153
- return cls(
154
- vision_encoder_config=vision_encoder_config.to_dict(),
155
- vision_tokenizer_config=vision_tokenizer_config.to_dict(),
156
- text_config=text_config.to_dict(),
157
- **kwargs,
158
- )
159
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
image_processing_blip_3.py CHANGED
@@ -13,7 +13,18 @@ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
13
  from transformers.image_utils import ImageInput
14
  from transformers.utils import TensorType
15
 
16
- from .utils import expand2square
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  class Blip3ImageProcessor(BaseImageProcessor):
 
13
  from transformers.image_utils import ImageInput
14
  from transformers.utils import TensorType
15
 
16
def expand2square(pil_img, background_color):
    """Pad an image onto a square canvas, centred along its shorter axis.

    Args:
        pil_img: Image object exposing ``size`` as (width, height) and ``mode``.
        background_color: Fill colour used for the newly created canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new square
        image whose side equals the longer dimension of ``pil_img``.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if width > height:
        side = width
        # Landscape input: shift down so the picture is vertically centred.
        offset = (0, (width - height) // 2)
    else:
        side = height
        # Portrait input: shift right so the picture is horizontally centred.
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
28
 
29
 
30
  class Blip3ImageProcessor(BaseImageProcessor):
modeling_xgenmm.py CHANGED
@@ -1,29 +1,2013 @@
1
- from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
 
 
 
 
 
2
  import torch
3
- import open_clip
 
 
4
  from typing import List, Optional, Tuple, Union
5
- from .utils import check_embedding_fns
6
- from .vlm import PerceiverResampler, XGenMMPerceiver
7
- from .configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class XGenMMVisionEncoder(PreTrainedModel):
10
  main_input_name = "pixel_values"
11
  config_class = XGenMMVisionEncoderConfig
12
-
13
  def __init__(self, config: XGenMMVisionEncoderConfig):
14
  super().__init__(config)
15
- if config.model_name != 'google/siglip-so400m-patch14-384':
16
- raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
 
 
17
  self.model = AutoModel.from_pretrained(config.model_name)
18
-
19
  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
20
  # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
21
  return self.model.encode_image(pixel_values)
22
-
23
 
24
- # vision tokenizer
 
25
  class XGenMMVisionTokenizer(PreTrainedModel):
26
  config_class = XGenMMVisionTokenizerConfig
 
27
  def __init__(self, config: XGenMMVisionTokenizerConfig):
28
  super().__init__(config)
29
  self.model = PerceiverResampler(
@@ -31,50 +2015,58 @@ class XGenMMVisionTokenizer(PreTrainedModel):
31
  dim_inner=config.lang_embedding_dim,
32
  num_latents=config.num_vis_tokens,
33
  )
34
-
35
- def forward(self,
36
- vision_features: torch.Tensor,
37
- vision_attn_masks: torch.Tensor):
38
  return self.model(vision_features, vision_attn_masks)
39
-
 
40
  # XGenMM model
41
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
42
  config_class = XGenMMConfig
43
-
44
  def __init__(self, config: XGenMMConfig):
45
  super().__init__(config)
46
-
47
  # vision encoder initialization
48
- vision_encoder = AutoModel.from_pretrained(config.vision_encoder_config.model_name).vision_model
49
-
50
- # language model initialization
 
 
51
  language_model = AutoModelForCausalLM.from_config(config.text_config)
52
  check_embedding_fns(language_model)
53
  # Update _tied_weights_keys using the base model used.
54
  if language_model._tied_weights_keys is not None:
55
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
56
-
 
 
57
  # vision tokenizer initialization
58
- if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
 
 
 
59
  overwrite = language_model.get_input_embeddings().weight.shape[1]
60
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
61
- print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
62
-
 
 
63
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
64
 
65
  self.vlm = XGenMMPerceiver(
66
  vision_encoder=vision_encoder,
67
  vision_tokenizer=vision_tokenizer,
68
  lang_model=language_model,
69
- initial_tokenizer_len = config.text_config.initial_tokenizer_len,
70
- pad_token_id = config.text_config.pad_token_id,
71
- image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
72
- anyres_patch_sampling = config.vision_encoder_config.anyres_patch_sampling,
73
- anyres_grids = config.vision_encoder_config.anyres_grids
74
  )
75
  # Initialize weights and apply final processing
76
  self.post_init()
77
-
78
  @torch.no_grad()
79
  def generate(
80
  self,
@@ -82,14 +2074,15 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
82
  input_ids: Optional[torch.LongTensor] = None,
83
  attention_mask: Optional[torch.LongTensor] = None,
84
  **generate_kwargs,
85
- ) -> torch.LongTensor:
86
  self.vlm = self.vlm.eval()
87
  return self.vlm.generate(
88
- vision_x = pixel_values,
89
- lang_x = input_ids,
90
- attention_mask = attention_mask,
91
- **generate_kwargs)
92
-
 
93
  def update_special_tokens(self, tokenizer):
94
  tokenizer.add_special_tokens(
95
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
@@ -97,8 +2090,8 @@ class XGenMMModelForConditionalGeneration(PreTrainedModel):
97
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
98
  self.vlm.set_special_token_ids(
99
  {
100
- v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
 
101
  }
102
  )
103
  return tokenizer
104
-
 
1
+ import ast
2
+ import math
3
+ from einops import rearrange, repeat
4
+ from einops_exts import rearrange_many
5
+ from einops import rearrange
6
+ from PIL import Image
7
  import torch
8
+ from torch import einsum, nn
9
+
10
+
11
  from typing import List, Optional, Tuple, Union
12
+ import torch.nn.functional as F
13
+ from transformers.modeling_outputs import CausalLMOutputWithPast
14
+ from dataclasses import dataclass
15
+ from transformers import CLIPVisionModel
16
+ from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
17
+ from transformers import PretrainedConfig, logging, CONFIG_MAPPING
18
+ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class XGenMMVisionEncoderConfig(PretrainedConfig):
    """Configuration of the XGenMM vision encoder.

    Args:
        model_name: Identifier of the pretrained vision model.
        anyres_grids: List of ``[a, b]`` integer pairs; presumably the
            candidate resolutions for any-resolution processing (TODO confirm
            the axis order against the caller). When omitted, a fresh copy of
            the default grid is created for this instance.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "xgenmm_vision_encoder"

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        anyres_grids: Optional[List[List[int]]] = None,
        **kwargs,
    ):
        if anyres_grids is None:
            # Build the default per call: a list literal used as a default
            # argument would be shared (and mutable) across all instances.
            anyres_grids = [
                [384, 768],
                [768, 384],
                [768, 768],
                [1152, 384],
                [384, 1152],
            ]
        self.model_name = model_name
        self.anyres_grids = anyres_grids
        super().__init__(**kwargs)
42
+
43
+
44
class XGenMMVisionTokenizerConfig(PretrainedConfig):
    """Configuration of the XGenMM vision tokenizer.

    Args:
        vis_feature_dim: Stored as ``self.vis_feature_dim``.
        lang_embedding_dim: Stored as ``self.lang_embedding_dim``.
        num_vis_tokens: Stored as ``self.num_vis_tokens``.
        image_aspect_ratio: Stored as ``self.image_aspect_ratio``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "xgenmm_vision_tokenizer"

    def __init__(
        self,
        vis_feature_dim: int = 1152,
        lang_embedding_dim: int = 3072,
        num_vis_tokens: int = 128,
        image_aspect_ratio: str = "anyres",
        **kwargs,
    ):
        # Own fields are assigned before the base initializer runs.
        own_fields = (
            ("vis_feature_dim", vis_feature_dim),
            ("lang_embedding_dim", lang_embedding_dim),
            ("num_vis_tokens", num_vis_tokens),
            ("image_aspect_ratio", image_aspect_ratio),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
60
+
61
+
62
class XGenMMConfig(PretrainedConfig):
    """Top-level XGenMM configuration bundling three sub-configurations.

    Args:
        vision_encoder_config: Keyword arguments for
            ``XGenMMVisionEncoderConfig``. Defaults to any-resolution settings.
        vision_tokenizer_config: Keyword arguments for
            ``XGenMMVisionTokenizerConfig``. Defaults to an empty dict.
        text_config: Keyword arguments for the text model configuration,
            looked up in ``CONFIG_MAPPING`` by its ``model_type`` entry
            (``"phi3"`` when absent). Defaults to the values below.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If the resulting text config lacks
            ``initial_tokenizer_len`` or ``pad_token_id``.
    """

    model_type = "xgenmm"

    def __init__(
        self,
        vision_encoder_config: dict = None,
        vision_tokenizer_config: dict = None,
        text_config: dict = None,
        **kwargs,
    ):

        if vision_encoder_config is None:
            vision_encoder_config = {
                "image_aspect_ratio": "anyres",
                "anyres_patch_sampling": True,
            }
            logger.info(
                "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
            )

        if vision_tokenizer_config is None:
            vision_tokenizer_config = {}
            logger.info(
                "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
            )

        if text_config is None:
            # Each key appears exactly once. The previous literal repeated
            # pad_token_id / bos_token_id / eos_token_id, and the later entries
            # silently won, so the effective values are kept here.
            # NOTE(review): the shadowed entry was pad_token_id=32011; confirm
            # that 32000 is really the intended default.
            text_config = {
                "initial_tokenizer_len": 32012,
                "pad_token_id": 32000,
                "bos_token_id": 1,
                "eos_token_id": 32000,
                "vocab_size": 32064,
                "hidden_size": 3072,
                "intermediate_size": 8192,
                "num_hidden_layers": 32,
                "num_attention_heads": 32,
                "num_key_value_heads": 32,
                "resid_pdrop": 0.0,
                "embd_pdrop": 0.0,
                "attention_dropout": 0.0,
                "hidden_act": "silu",
                "max_position_embeddings": 4096,
                "original_max_position_embeddings": 4096,
                "initializer_range": 0.02,
                "rms_norm_eps": 1e-05,
                "use_cache": True,
                "rope_theta": 10000.0,
                "rope_scaling": None,
                "sliding_window": 2047,
                "return_dict": True,
                "output_hidden_states": False,
                "output_attentions": False,
                "torchscript": False,
                "torch_dtype": "bfloat16",
                "use_bfloat16": False,
                "tf_legacy_loss": False,
                "pruned_heads": {},
                "tie_word_embeddings": False,
                "chunk_size_feed_forward": 0,
                "is_encoder_decoder": False,
                "is_decoder": False,
                "cross_attention_hidden_size": None,
                "add_cross_attention": False,
                "tie_encoder_decoder": False,
                "max_length": 20,
                "min_length": 0,
                "do_sample": False,
                "early_stopping": False,
                "num_beams": 1,
                "num_beam_groups": 1,
                "diversity_penalty": 0.0,
                "temperature": 1.0,
                "top_k": 50,
                "top_p": 1.0,
                "typical_p": 1.0,
                "repetition_penalty": 1.0,
                "length_penalty": 1.0,
                "no_repeat_ngram_size": 0,
                "encoder_no_repeat_ngram_size": 0,
                "bad_words_ids": None,
                "num_return_sequences": 1,
                "output_scores": False,
                "return_dict_in_generate": False,
                "forced_bos_token_id": None,
                "forced_eos_token_id": None,
                "remove_invalid_values": False,
                "exponential_decay_length_penalty": None,
                "suppress_tokens": None,
                "begin_suppress_tokens": None,
                "finetuning_task": None,
                "id2label": {0: "LABEL_0", 1: "LABEL_1"},
                "label2id": {"LABEL_0": 0, "LABEL_1": 1},
                "tokenizer_class": None,
                "prefix": None,
                "sep_token_id": None,
                "decoder_start_token_id": None,
                "task_specific_params": None,
                "problem_type": None,
                "model_type": "phi3",
            }
            logger.info(
                "text_config is None. Initializing the text config with default values (`Phi3Config`)."
            )

        self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)

        self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(
            **vision_tokenizer_config
        )

        # Fall back to phi3 when the caller's dict does not name a model type.
        text_model_type = text_config.get("model_type", "phi3")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Serialize once and validate both required keys against that snapshot.
        text_config_dict = self.text_config.to_dict()
        for key in ["initial_tokenizer_len", "pad_token_id"]:
            if key not in text_config_dict:
                raise ValueError(f"The key `{key}` is missing in the text_config.")

        super().__init__(**kwargs)
186
+
187
+
188
def hasattr_recursive(obj, att):
    """Check whether ``obj`` exposes a dotted (nested) attribute path.

    Example: ``hasattr_recursive(obj, 'a.b.c')`` is equivalent to
    ``hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')``.

    Args:
        obj: Object at which the lookup starts.
        att: Dotted attribute path. An empty path is always considered present.

    Returns:
        True when every component of the path resolves, False otherwise.
    """
    *parents, leaf = att.split(".")
    for name in parents:
        try:
            obj = getattr(obj, name)
        except AttributeError:
            # Only a missing attribute means "absent"; unrelated errors
            # (including KeyboardInterrupt) are no longer swallowed.
            return False
    # An empty final component (empty path or trailing dot) needs no lookup.
    return True if leaf == "" else hasattr(obj, leaf)
203
+
204
+
205
def getattr_recursive(obj, att):
    """Resolve a dotted attribute path starting from ``obj``.

    Example: ``getattr_recursive(obj, 'a.b.c')`` is equivalent to ``obj.a.b.c``.
    An empty path returns ``obj`` itself.
    """
    target = obj
    remaining = att
    # Peel one component off the front of the path per iteration.
    while remaining != "":
        head, _, remaining = remaining.partition(".")
        target = getattr(target, head)
    return target
217
+
218
+
219
def setattr_recursive(obj, att, val):
    """Assign ``val`` to a dotted attribute path below ``obj``.

    Example: ``setattr_recursive(obj, 'a.b.c', val)`` is equivalent to
    ``obj.a.b.c = val``.
    """
    parent_path, dot, leaf = att.rpartition(".")
    if dot:
        # Walk to the object that owns the final attribute.
        obj = getattr_recursive(obj, parent_path)
    setattr(obj, leaf, val)
227
+
228
+
229
def check_embedding_fns(lang_model):
    """Ensure ``lang_model`` has {get,set}_{input,output}_embeddings callables.

    For each accessor that is missing, a lambda is attached to ``lang_model``
    based on the attribute layout that is found: ``transformer.wte`` or
    ``model.decoder.embed_tokens`` for input embeddings, ``lm_head`` for output
    embeddings.

    Args:
        lang_model: Language model object, modified in place.

    Raises:
        ValueError: If an accessor is missing and none of the known attribute
            paths exists on ``lang_model``.
    """
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            # Read from the same path that was just probed (and that the
            # setter below writes to); `lang_model.decoder` is a different,
            # unchecked attribute.
            lang_model.get_input_embeddings = (
                lambda: lang_model.model.decoder.embed_tokens
            )
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )
272
+
273
+
274
def has_fn(model, fn_name):
    """Return True when ``model`` has a callable attribute named ``fn_name``."""
    candidate = getattr(model, fn_name, None)
    return callable(candidate)
277
+
278
+
279
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """Stack tensors of differing first-dimension length, padding the short ones.

    Each tensor is padded along dimension 0 up to the longest length found in
    the list; all trailing dimensions are kept as they are (tensors of any rank
    >= 1 are supported, provided their trailing shapes agree).

    Args:
        list_of_tensors (list[torch.Tensor]): Tensors to stack.
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): ``"right"`` appends the padding; any
            other value prepends it. Defaults to "right".

    Returns:
        torch.Tensor: The padded tensors stacked along a new leading dimension.
    """
    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        # Same trailing shape as the tensor, so the concat works at any rank.
        pad_shape = (max_tokens - tensor.size(0), *tensor.shape[1:])
        padding = torch.full(
            pad_shape,
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        pieces = (tensor, padding) if padding_side == "right" else (padding, tensor)
        padded_tensors.append(torch.cat(pieces, dim=0))
    return torch.stack(padded_tensors)
314
+
315
+
316
def unpad_image(tensor, original_size, keep_original_shape=False):
    """Remove (or mask out) the centred padding of a padded, resized image.

    Args:
        tensor (torch.Tensor): Image tensor whose last two dimensions are read
            as (height, width); the leading dimension is left untouched.
        original_size (tuple): Size of the original image, unpacked by the code
            as (width, height).
        keep_original_shape (bool): When True the tensor is returned unchanged
            together with a mask; when False the padding is sliced away.

    Returns:
        tuple: ``(tensor, attention_mask)`` when ``keep_original_shape`` is
        True, where the mask is 1 over image content and 0 over padding;
        otherwise ``(unpadded_tensor, None)``.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    # A relatively wider original means the padding sits above and below.
    pad_rows = (original_width / original_height) > (current_width / current_height)

    if pad_rows:
        scale_factor = current_width / original_width
        content_extent = int(original_height * scale_factor)
        padding = (current_height - content_extent) // 2
    else:
        scale_factor = current_height / original_height
        content_extent = int(original_width * scale_factor)
        padding = (current_width - content_extent) // 2

    if not keep_original_shape:
        if pad_rows:
            return tensor[:, padding : current_height - padding, :], None
        return tensor[:, :, padding : current_width - padding], None

    attention_mask = torch.ones((current_height, current_width), device=tensor.device)
    if pad_rows:
        attention_mask[:padding, :] = 0
        attention_mask[current_height - padding :, :] = 0
    else:
        attention_mask[:, :padding] = 0
        attention_mask[:, current_width - padding :] = 0
    return tensor, attention_mask
361
+
362
+
363
def select_best_resolution(original_size, possible_resolutions):
    """Pick the candidate resolution that preserves the most image content.

    Candidates are ranked by effective resolution (pixels of the original that
    survive an aspect-preserving fit, capped at the original pixel count); ties
    are broken by the smaller amount of wasted canvas. Among fully tied
    candidates the earliest one wins.

    Args:
        original_size (tuple): Original image size as (width, height).
        possible_resolutions (list): Candidates as [(width1, height1), ...].

    Returns:
        tuple: The best candidate as (width, height), or None when
        ``possible_resolutions`` is empty.
    """
    original_width, original_height = original_size
    original_area = original_width * original_height

    best_fit = None
    # Ranking key: (effective resolution, negated waste); larger is better.
    best_key = (0, -float("inf"))

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        fitted_area = int(original_width * scale) * int(original_height * scale)
        effective_resolution = min(fitted_area, original_area)
        wasted_resolution = (width * height) - effective_resolution

        candidate_key = (effective_resolution, -wasted_resolution)
        if candidate_key > best_key:
            best_key = candidate_key
            best_fit = (width, height)

    return best_fit
398
+
399
+
400
def resize_and_pad_image(image, target_resolution):
    """Fit an image inside a target canvas, preserving aspect ratio.

    The image is resized so that it fits within the target, then pasted centred
    onto a black RGB canvas of exactly the target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): Target size as (width, height).

    Returns:
        PIL.Image.Image: New RGB image of the target size.
    """
    source_width, source_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / source_width
    scale_h = target_height / source_height

    # The tighter axis fills the canvas; the other is scaled and clamped.
    if scale_w < scale_h:
        fitted_width = target_width
        fitted_height = min(math.ceil(source_height * scale_w), target_height)
    else:
        fitted_height = target_height
        fitted_width = min(math.ceil(source_width * scale_h), target_width)

    resized = image.resize((fitted_width, fitted_height))

    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    top_left = (
        (target_width - fitted_width) // 2,
        (target_height - fitted_height) // 2,
    )
    canvas.paste(resized, top_left)
    return canvas
433
+
434
+
435
def divide_to_patches(image, patch_size):
    """Cut an image into square patches, row by row from the top-left corner.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): Side length of each patch.

    Returns:
        list: The crops returned by ``image.crop`` for each patch box, ordered
        left-to-right within a row and rows top-to-bottom.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
455
+
456
+
457
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Compute the patch-grid shape chosen for an image of arbitrary resolution.

    Args:
        image_size (tuple): Size of the input image as (width, height).
        grid_pinpoints (list | tuple | str): Possible resolutions, either as a
            sequence or as a string holding the literal representation of one.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The grid shape as (width, height), i.e. the selected resolution
        floor-divided by ``patch_size``.
    """
    # isinstance also accepts tuples and list subclasses, which an exact
    # type comparison would have sent to the string parser.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
475
+
476
+
477
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best matching candidate resolution,
    cut into square patches, and a resized copy of the whole image is prepended
    to the patches before everything is run through ``processor``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. It must be callable on a single
            image and expose ``processor.transforms[0].size`` as an indexable
            value whose first entry is used as the patch edge length.
        grid_pinpoints (list | tuple | str): The candidate resolutions, either as an
            actual sequence or as a string holding the Python literal of one.

    Returns:
        torch.Tensor: The processed images stacked along a new leading dimension,
            with the resized full image first, followed by the patches.
    """
    # FIXME: determine grid_pinpoints from image sizes.
    # isinstance (rather than an exact `type(...) is list` test) lets tuples and
    # list subclasses through directly; anything else is parsed as a literal.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    # Only the first entry of the processor size is used, for both dimensions.
    patch_edge = processor.transforms[0].size[0]
    patches = divide_to_patches(image_padded, patch_edge)

    # Global view: the un-padded original squashed to a single patch.
    image_original_resize = image.resize((patch_edge, patch_edge))

    image_patches = [image_original_resize] + patches
    image_patches = [processor(image_patch) for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
505
+
506
+
507
def expand2square(pil_img, background_color):
    """
    Pad an image to a square, centring it along its shorter dimension.

    Args:
        pil_img (PIL.Image.Image): The image to pad.
        background_color: Fill value passed to ``Image.new`` for the padding.

    Returns:
        PIL.Image.Image: ``pil_img`` itself when it is already square,
            otherwise a new square image in the same mode whose side equals
            the longer edge of the input.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if width > height:
        side = width
        offset = (0, (width - height) // 2)
    else:
        side = height
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
519
+
520
+
521
class VisionTokenizer(nn.Module):
    """
    Base class for modules that turn vision features into visual tokens.

    It only records two sizes on the instance.

    Args:
        dim_media: Stored as ``self.dim_media``.
        num_tokens_per_media: Stored as ``self.num_tokens_per_media``.
    """

    def __init__(self, dim_media, num_tokens_per_media):
        super().__init__()
        self.num_tokens_per_media = num_tokens_per_media
        self.dim_media = dim_media
526
+
527
+
528
class PerceiverAttention(nn.Module):
    """
    Multi-head attention in which latent queries attend over the media
    features concatenated with the latents themselves.

    Args:
        dim (int): Feature size of both the media features and the latents.
        dim_head (int): Size of each attention head.
        heads (int): Number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            vision_attn_masks (torch.Tensor, optional): mask over the image
                features, shape (b, n1); entries that evaluate to False are
                excluded from attention. The latents are always attended to.

        Returns:
            torch.Tensor: updated latents, shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        num_heads = self.heads

        queries = self.to_q(latents)
        # Keys and values cover the media tokens followed by the latents.
        # TODO: Change the shape of vision attention mask according to this.
        kv_input = torch.cat((x, latents), dim=-2)
        if vision_attn_masks is not None:
            # Extend the mask with ones so the latent positions stay visible.
            latent_mask = torch.ones(
                (latents.shape[0], latents.shape[-2]),
                dtype=latents.dtype,
                device=latents.device,
            )
            vision_attn_masks = torch.cat((vision_attn_masks, latent_mask), dim=-1)

        keys, values = self.to_kv(kv_input).chunk(2, dim=-1)
        queries, keys, values = rearrange_many(
            (queries, keys, values), "b t n (h d) -> b h t n d", h=num_heads
        )
        queries = queries * self.scale

        # attention scores
        sim = einsum("... i d, ... j d -> ... i j", queries, keys)
        # Apply the vision attention mask as an additive bias.
        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
        if vision_attn_masks is not None:
            num_queries = queries.size(-2)
            attn_bias = torch.zeros(
                (queries.size(0), 1, 1, num_queries, keys.size(-2)),
                dtype=queries.dtype,
                device=queries.device,
            )
            vision_attn_masks = repeat(
                vision_attn_masks, "b n -> b 1 1 l n", l=num_queries
            )
            attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
            sim += attn_bias

        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, values)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=num_heads)
        return self.to_out(out)
597
+
598
+
599
def FeedForward(dim, mult=4):
    """
    Build a bias-free two-layer MLP preceded by a LayerNorm.

    Args:
        dim (int): Input and output feature size.
        mult (int | float): The hidden size is ``int(dim * mult)``.

    Returns:
        nn.Sequential: LayerNorm -> Linear -> GELU -> Linear.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
607
+
608
+
609
def num_params(module, filter_to_trainable=False):
    """
    Count the parameter elements of a module.

    Args:
        module: Object exposing ``parameters()``.
        filter_to_trainable (bool): When True, only parameters with
            ``requires_grad`` set are counted.

    Returns:
        int: Total number of elements across the counted parameters.
    """
    total = 0
    for param in module.parameters():
        if filter_to_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total
615
+
616
+
617
class PerceiverResampler(VisionTokenizer):
    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=96,
        heads=16,
        num_latents=128,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the resampled tokens to;
                also the final dimension of the outputted features. If None, no projection is used
                and the output keeps dimension `dim`.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 96.
            heads (int, optional): number of heads. Defaults to 16.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 128.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        # Built before super().__init__() but only registered afterwards, so the
        # order of random initialisation (projection, then latents) is unchanged.
        projection = nn.Linear(dim, dim_inner) if dim_inner is not None else None
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if exists(max_num_frames)
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if exists(max_num_media)
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            vision_attn_masks (torch.Tensor, optional): attention masks for padded vision tokens (i.e., x)
                shape (b, v). Defaults to None (no masking), so the module can also be
                called with the features alone.
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        # `num_frames` rather than `F`, to avoid shadowing the conventional
        # torch.nn.functional alias inside this method.
        b, T, num_frames, v = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=b, T=T, v=v
            )
            x = x + frame_embs
        x = rearrange(
            x, "b T F v d -> b T (F v) d"
        )  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:T]

        # blocks: each layer is residual attention followed by a residual MLP
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents, vision_attn_masks) + latents
            latents = ff(latents) + latents

        latents = self.norm(latents)
        if exists(self.projection):
            return self.projection(latents)
        return latents
715
+
716
+
717
class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze: (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
            device, dtype:
                Forwarded to both the regular and the additional embedding.
            pad_token_id (`int`, *optional*):
                The padding index, forwarded to `nn.Embedding` as `padding_idx`
                (must be <= `max_original_id`).

        Raises:
            ValueError: If `pad_token_id` is larger than `max_original_id`.
            AssertionError: If the sizes are inconsistent with `_weight`, or missing when `_weight` is None.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as
        `max_norm` or `norm_type`. We are not supporting these.
        """
        # validate args
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        `require_additional_grad` is ignored when the module was built without additional embeddings.
        """
        self.weight.requires_grad_(require_regular_grad)
        # `additional_embedding` only exists when num_additional_embeddings > 0;
        # touching it unconditionally made the constructor fail for the plain case.
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (max_original_id + 1), since the
           2nd embedding starts from 0 and not max_original_id + 1
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with index 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
        measure.

        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids > self.max_original_id)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(
            input_ids_additional_vocab - self.max_original_id - 1
        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )
860
+
861
+
862
class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
    then it will create `additional_out_features * in_features` additional parameters that are always trained. If
    `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        max_original_id: int,
        additional_out_features: int = 0,
        _weight: torch.Tensor = None,
        _bias: torch.Tensor = None,
        in_features: int = None,
        original_out_features: int = None,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`): The largest token id that should be extracted from the regular weight.
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal original_out_features - 1
            _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
                If provided, this sets the `in_features` and `original_out_features` parameters.
            _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
            in_features: int. Input hidden size.
            original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
            additional_out_features: int. Number of additional trainable dimensions.
            bias: bool. Whether to include a bias term (applies to both the regular and the additional linear).
            partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
            device, dtype: Forwarded to both the regular and the additional linear.

        Raises:
            AssertionError: If the sizes are inconsistent with `_weight`, missing when `_weight` is None,
                or if `_bias` is given while `bias` is not True.
        """
        # argument validation
        if _weight is not None:
            assert (_weight.shape[0] == original_out_features) or (
                original_out_features is None
            ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
            assert (_weight.shape[1] == in_features) or (
                in_features is None
            ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
            in_features = _weight.shape[1]
            original_out_features = _weight.shape[0]
        else:
            assert (
                in_features is not None
            ), "in_features must be provided if _weight is not provided"
            assert (
                original_out_features is not None
            ), "original_out_features must be provided if _weight is not provided"

        if _bias is not None:
            assert bias is True, "bias must be True if _bias is provided"

        # initialize original linear
        super().__init__(in_features, original_out_features, bias, device, dtype)

        # set weight and bias manually
        if _weight is not None:
            self.weight = nn.Parameter(_weight)
        if _bias is not None:
            self.bias = nn.Parameter(_bias)

        self.in_features = in_features
        self.original_out_features = original_out_features
        self.max_original_id = max_original_id

        # initialize additional linear
        self.additional_out_features = additional_out_features
        self.has_bias = bias
        if additional_out_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=additional_out_features,
                bias=self.has_bias,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        `require_additional_grad` is ignored when the module was built without additional output features.
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.has_bias:
            self.bias.requires_grad_(require_regular_grad)
        # `additional_fc` only exists when additional_out_features > 0; touching
        # it unconditionally made the constructor fail for the plain case.
        if self.additional_out_features > 0:
            self.additional_fc.requires_grad_(require_additional_grad)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Project `input` with the regular weight, keep only the first
        `max_original_id + 1` outputs, and append the outputs of the
        additional linear layer when there is one.
        """
        output = F.linear(input, self.weight, self.bias)
        output = output[..., : self.max_original_id + 1]

        if self.additional_out_features > 0:
            additional_features = F.linear(
                input, self.additional_fc.weight, self.additional_fc.bias
            )
            output = torch.cat((output, additional_features), -1)
        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        # The bias may be None (bias=False); only inspect it when it exists.
        bias_frozen = self.bias is not None and not self.bias.requires_grad
        return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.max_original_id + 1,
            self.additional_out_features,
            self.bias is not None,
            (not self.weight.requires_grad or bias_frozen),
        )
976
+
977
+
978
+ class VLM(nn.Module):
979
+ """
980
+ Generic vision-language model (VLM) class.
981
+ A VLM consists of four components:
982
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
983
+ input: (B, T_img, F, C, H, W)
984
+ output: (B, T_img, F, v, d)
985
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
986
+ input: (B, T_img, F, v, d)
987
+ output: (B, T_img, n, d)
988
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
989
+ 4. A language model
990
+ """
991
+
992
def __init__(
    self,
    vision_encoder: nn.Module,
    vision_tokenizer: nn.Module,
    lang_model: nn.Module,
    initial_tokenizer_len: int,
    pad_token_id: int,
    gradient_checkpointing: bool = False,
):
    """
    Wire the three components together and swap the language model's
    input/output embeddings for decoupled versions that add one extra
    slot per entry of `self.special_tokens`.

    Args:
        vision_encoder (nn.Module): e.g. CLIP
        vision_tokenizer (nn.Module): e.g. PerceiverResampler; must expose
            `dim_media` and `num_tokens_per_media`
        lang_model (nn.Module): e.g. MPT
        initial_tokenizer_len (int): size of the original tokenizer vocab
        pad_token_id (int): id of the pad token
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
    """
    super().__init__()

    # save dimension information
    self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
    lm_config = lang_model.config
    # mpt uses d_model, other configs use hidden_size
    self.lang_hidden_dim = (
        lm_config.d_model if hasattr(lm_config, "d_model") else lm_config.hidden_size
    )
    self.vis_embedding_dim = vision_tokenizer.dim_media
    self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media

    # core components
    self.vision_encoder = vision_encoder
    self.vision_tokenizer = vision_tokenizer
    self.lang_model = lang_model

    # lm embeddings
    self.pad_token_id = pad_token_id
    self.initial_tokenizer_len = initial_tokenizer_len
    last_original_id = initial_tokenizer_len - 1

    input_embeds = DecoupledEmbedding(
        max_original_id=last_original_id,
        num_additional_embeddings=len(self.special_tokens),
        _weight=self.lang_model.get_input_embeddings().weight,
        pad_token_id=self.pad_token_id,
    )
    if hasattr(input_embeds, "additional_embedding"):
        # New rows use the LM's own initializer range when it defines one.
        input_embeds.additional_embedding.weight.data.normal_(
            mean=0.0,
            std=getattr(self.lang_model.config, "initializer_range", 0.02),
        )
    self.lang_model.set_input_embeddings(input_embeds)

    original_head = self.lang_model.get_output_embeddings()
    out_embeds = DecoupledLinear(
        max_original_id=last_original_id,
        additional_out_features=len(self.special_tokens),
        _weight=original_head.weight,
        _bias=getattr(original_head, "bias", None),
    )
    if hasattr(out_embeds, "additional_fc"):
        out_embeds.additional_fc.weight.data.normal_(
            mean=0.0,
            std=getattr(self.lang_model.config, "initializer_range", 0.02),
        )
    self.lang_model.set_output_embeddings(out_embeds)

    # gradient checkpointing
    self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
1069
+
1070
def forward(
    self,
    vision_x: Optional[torch.Tensor],
    lang_x: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    past_key_values: Optional[
        List[Union[torch.Tensor, Tuple[torch.Tensor]]]
    ] = None,
    past_media_locations: Optional[torch.Tensor] = None,
    past_vision_tokens: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = False,
    **kwargs,
):
    """
    Encode the vision input (if any), fuse it with the language input, run
    the language model and post-process its output.

    Args:
        vision_x: Vision input
            shape (B, T_img, F, C, H, W) with F=1
            only F = 1 is supported (single-frame videos)
            if T_img > the number of media tokens in the corresponding input_ids (lang_x),
            only the first number of media tokens in lang_x are used
        lang_x: Language input ids, with media tokens denoting where
            visual media should be inserted.
            shape (B, T_txt)
        attention_mask: Attention mask. Defaults to None.
        labels: Labels. Defaults to None.
            shape (B, T_txt)
        past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
            list of length = number of decoder layers in the LM
            exact implementation depends on LM, see Hugging Face docs
        past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
            shape (B, T_txt)
        past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
        use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
            If True, includes key_values, media_locations, and vision_tokens in the output.
        **kwargs: Forwarded unchanged to the language model.

    Returns:
        Whatever `_postprocess_outputs_from_forward` returns for the language model output.
    """
    # The two "past" arguments only make sense together.
    assert (past_vision_tokens is None) == (
        past_media_locations is None
    ), "past_vision_tokens and past_media_locations must both be None or both be not None"

    # convert pixels to vision tokens
    vision_tokens = None
    if vision_x is not None:
        vision_features = self._encode_vision_x(vision_x=vision_x)
        vision_tokens = self.vision_tokenizer(vision_features)

    # fuse the vision and language tokens
    fused_inputs = self._prepare_inputs_for_forward(
        vision_tokens=vision_tokens,
        lang_x=lang_x,
        attention_mask=attention_mask,
        labels=labels,
        past_key_values=past_key_values,
        past_media_locations=past_media_locations,
        padding_side="right",
        past_vision_tokens=past_vision_tokens,
    )
    lm_output = self.lang_model(
        **fused_inputs,
        use_cache=use_cache,
        past_key_values=past_key_values,
        **kwargs,
    )

    # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
    # or to add the past_vision_tokens and past_media_locations to the output
    result = self._postprocess_outputs_from_forward(
        output=lm_output,
        lang_x=lang_x,
        vision_tokens=vision_tokens,
        use_cache=use_cache,
        past_vision_tokens=past_vision_tokens,
        past_media_locations=past_media_locations,
    )

    # postforward hooks
    self._post_forward_hook()
    return result
1149
+
1150
+ def _encode_vision_x_anyres(self, samples, device):
1151
+ assert self.anyres_grids is not None
1152
+ image_raw = samples[
1153
+ "image"
1154
+ ] # list of patch list in of shape [1, N_patch, C, H, W]
1155
+ image_sizes = samples["image_size"]
1156
+
1157
+ # Image_raw can be a list of list of patches, when a `samples` has multiple images.
1158
+ if isinstance(image_raw[0], list):
1159
+ images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
1160
+ image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
1161
+ else:
1162
+ # assert isinstance(image_raw[0], torch.Tensor), f"Unkown image type: {image_raw[0]}"
1163
+ # concate list of patches into one big patch for any res encoding.
1164
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
1165
+ image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
1166
+ image = image.to(device)
1167
+
1168
+ with torch.no_grad():
1169
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1170
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
1171
+ elif self.vision_encoder.__class__.__name__ in [
1172
+ "CLIPVisionModel",
1173
+ "SiglipVisionTransformer",
1174
+ ]:
1175
+ image_embeds = self.vision_encoder(image).last_hidden_state
1176
+ else:
1177
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
1178
+
1179
+ if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(
1180
+ self.vision_encoder, SiglipVisionTransformer
1181
+ ):
1182
+ base_img_size = self.vision_encoder.config.image_size
1183
+ else:
1184
+ base_img_size = self.vision_encoder.image_size[0]
1185
+
1186
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
1187
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
1188
+ elif self.vision_encoder.__class__.__name__ in [
1189
+ "CLIPVisionModel",
1190
+ "SiglipVisionTransformer",
1191
+ ]:
1192
+ grid_size_base = (
1193
+ self.vision_encoder.config.image_size
1194
+ // self.vision_encoder.config.patch_size
1195
+ )
1196
+ grid_size = (grid_size_base, grid_size_base)
1197
+ else:
1198
+ grid_size = self.vision_encoder.grid_size
1199
+ height, width = grid_size
1200
+
1201
+ if not image_embeds.shape[1] == height * width:
1202
+ assert (
1203
+ image_embeds.shape[1] == height * width + 1
1204
+ ) # For vision encoders that has [CLS] token.
1205
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
1206
+ n_vis_token_per_patch = image_embeds.shape[1]
1207
+
1208
+ # Split encoded patches and merge patch features
1209
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
1210
+ split_sizes = [image.shape[0] for image in images]
1211
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
1212
+ # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
1213
+ new_image_embeds = []
1214
+ patch_attn_masks = []
1215
+ max_n_img_token = -1
1216
+ for idx, patch_embeds in enumerate(image_embeds):
1217
+ if patch_embeds.shape[0] > 1:
1218
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
1219
+ base_patch_embeds = patch_embeds[
1220
+ 0
1221
+ ] # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
1222
+ patch_embeds = patch_embeds[1:]
1223
+
1224
+ assert height * width == base_patch_embeds.shape[0]
1225
+
1226
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
1227
+ image_sizes[idx], self.anyres_grids, base_img_size
1228
+ ) # Hardcoded grid_pinpoints.
1229
+ patch_embeds = patch_embeds.view(
1230
+ num_patch_height, num_patch_width, height, width, -1
1231
+ )
1232
+
1233
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
1234
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
1235
+ patch_embeds, patch_attn_mask = unpad_image(
1236
+ patch_embeds, image_sizes[idx], self.anyres_patch_sampling
1237
+ )
1238
+ if hasattr(self, "image_newline"):
1239
+ patch_embeds = torch.cat(
1240
+ (
1241
+ patch_embeds,
1242
+ self.image_newline[:, None, None].expand(
1243
+ *patch_embeds.shape[:-1], 1
1244
+ ),
1245
+ ),
1246
+ dim=-1,
1247
+ )
1248
+ if self.anyres_patch_sampling:
1249
+ patch_embeds = patch_embeds.view(
1250
+ -1, num_patch_height, num_patch_width, height * width
1251
+ )
1252
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
1253
+ assert patch_attn_mask is not None
1254
+ patch_attn_mask = patch_attn_mask.view(
1255
+ num_patch_height, num_patch_width, height * width
1256
+ )
1257
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
1258
+ patch_embeds = torch.cat(
1259
+ (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0
1260
+ )
1261
+ patch_attn_mask = torch.cat(
1262
+ (
1263
+ torch.ones(
1264
+ n_vis_token_per_patch, device=patch_embeds.device
1265
+ ).unsqueeze(0),
1266
+ patch_attn_mask,
1267
+ ),
1268
+ dim=0,
1269
+ )
1270
+ else:
1271
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
1272
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
1273
+ else:
1274
+ patch_embeds = (
1275
+ patch_embeds[0].unsqueeze(0)
1276
+ if self.anyres_patch_sampling
1277
+ else patch_embeds[0]
1278
+ )
1279
+ patch_attn_mask = (
1280
+ torch.ones(
1281
+ n_vis_token_per_patch, device=patch_embeds.device
1282
+ ).unsqueeze(0)
1283
+ if self.anyres_patch_sampling
1284
+ else None
1285
+ )
1286
+ if hasattr(self, "image_newline"):
1287
+ patch_embeds = torch.cat(
1288
+ (patch_embeds, self.image_newline[None]), dim=0
1289
+ )
1290
+ if not self.anyres_patch_sampling:
1291
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
1292
+
1293
+ new_image_embeds.append(patch_embeds)
1294
+ patch_attn_masks.append(patch_attn_mask)
1295
+
1296
+ if self.anyres_patch_sampling:
1297
+ # Return individual patches for independent token downsampling.
1298
+ return new_image_embeds, patch_attn_masks
1299
+
1300
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
1301
+ image_embeds = []
1302
+ image_atts = []
1303
+ for image_embed in new_image_embeds:
1304
+ n_img_token = image_embed.shape[0]
1305
+ img_attn = torch.ones(
1306
+ (max_n_img_token), dtype=torch.long, device=image_embed.device
1307
+ )
1308
+ if n_img_token < max_n_img_token:
1309
+ padded_embed = torch.zeros(
1310
+ (max_n_img_token, image_embed.shape[-1]),
1311
+ dtype=image_embed.dtype,
1312
+ device=image_embed.device,
1313
+ )
1314
+ padded_embed[:n_img_token, :] = image_embed
1315
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
1316
+ else:
1317
+ padded_embed = image_embed
1318
+ image_embeds.append(padded_embed)
1319
+ image_atts.append(img_attn)
1320
+ image_embeds = torch.stack(
1321
+ image_embeds, dim=0
1322
+ ) # Shape [B, N_tok_longest, C_dim]
1323
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest, C_dim]
1324
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
1325
+ image_embeds = image_embeds[:, None, None, :, :]
1326
+ # image_atts = image_atts[:, None, None, :, :]
1327
+
1328
+ return image_embeds, image_atts
1329
+
1330
def _encode_vision_x(self, vision_x: torch.Tensor):
    """
    Encode raw pixels into per-image visual features with the vision encoder.

    The encoder is run under ``torch.no_grad()``, so no gradients are tracked
    through it here.

    Args:
        vision_x: Vision input of shape (B, T_img, F, C, H, W).
            Images in the same chunk are collated along T_img, and frames are
            collated along F. Currently only F=1 is supported (single-frame videos).

    Returns:
        Encoder features rearranged to shape (B, T_img, F, v, d).

    rearrange code based on https://github.com/dhansmair/flamingo-mini
    """
    assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
    batch, n_images, n_frames = vision_x.shape[:3]

    # Fold batch / image / frame axes together so the encoder sees plain 4D input.
    flat_pixels = rearrange(vision_x, "b T F c h w -> (b T F) c h w")

    encoder = self.vision_encoder
    encoder_kind = encoder.__class__.__name__
    with torch.no_grad():
        if encoder_kind == "TimmModel":
            features = encoder.trunk.forward_features(flat_pixels)
        elif encoder_kind in ("CLIPVisionModel", "SiglipVisionTransformer"):
            features = encoder(flat_pixels).last_hidden_state
        else:
            features = encoder(flat_pixels)[1]  # OpenCLIP returns tuples

    # Restore the batch / image / frame axes.
    return rearrange(
        features, "(b T F) v d -> b T F v d", b=batch, T=n_images, F=n_frames
    )
1357
+
1358
+ def _concat_vision_cache(
1359
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
1360
+ ):
1361
+ """
1362
+ Helper function to include the past vision tokens and past media locations in the output.
1363
+ """
1364
+ if use_cache:
1365
+ if past_media_locations is not None and past_vision_tokens is not None:
1366
+ if vision_tokens is not None:
1367
+ updated_vision_tokens = torch.cat(
1368
+ [
1369
+ past_vision_tokens,
1370
+ vision_tokens,
1371
+ ],
1372
+ dim=1,
1373
+ )
1374
+ else:
1375
+ updated_vision_tokens = past_vision_tokens
1376
+ updated_media_locations = torch.cat(
1377
+ [
1378
+ past_media_locations,
1379
+ lang_x == self.media_token_id,
1380
+ ],
1381
+ dim=1,
1382
+ )
1383
+ else:
1384
+ updated_vision_tokens = vision_tokens
1385
+ updated_media_locations = lang_x == self.media_token_id
1386
+
1387
+ else:
1388
+ updated_vision_tokens = None
1389
+ updated_media_locations = None
1390
+
1391
+ return updated_vision_tokens, updated_media_locations
1392
+
1393
def generate(
    self,
    vision_x: torch.Tensor,
    lang_x: torch.Tensor,
    attention_mask: torch.Tensor = None,
    past_key_values: Optional[
        List[Union[torch.Tensor, Tuple[torch.Tensor]]]
    ] = None,
    past_media_locations: Optional[torch.Tensor] = None,
    past_vision_tokens: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Generate text conditioned on vision and language inputs.

    Args:
        vision_x (torch.Tensor): Vision input of shape (B, T_img, F, C, H, W),
            or None to generate from text only.
        lang_x (torch.Tensor): Language input of shape (B, T_txt).
        attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
        past_key_values: cached key/values forwarded to the language model.
        past_media_locations: cached media locations from earlier steps.
        past_vision_tokens: cached vision tokens from earlier steps.
        **kwargs: see generate documentation in Hugging Face CausalLM models.
            ``num_beams`` is read from here (default 1).

    Returns:
        Whatever ``self.lang_model.generate`` returns.
    """
    num_beams = kwargs.pop("num_beams", 1)

    # Convert pixels to vision tokens (skipped entirely for text-only calls).
    vision_tokens = None
    if vision_x is not None:
        vision_tokens = self.vision_tokenizer(
            self._encode_vision_x(vision_x=vision_x)
        )

    # Fuse the vision and language tokens.
    # For xattn, vision_x and media_location are repeat_interleaved s.t.
    # the total batch size is B * num_beams.
    lm_inputs = self._prepare_inputs_for_forward(
        vision_tokens=vision_tokens,
        lang_x=lang_x,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        past_media_locations=past_media_locations,
        past_vision_tokens=past_vision_tokens,
        padding_side="left",
        num_beams=num_beams,
    )
    generated = self.lang_model.generate(
        **lm_inputs,
        past_key_values=past_key_values,
        num_beams=num_beams,
        use_cache=True,
        **kwargs,
    )
    self._post_forward_hook()
    return generated
1449
+
1450
@property
def num_trainable_params(self):
    """Return the number of trainable parameters (via ``num_params`` with ``filter_to_trainable=True``)."""
    return num_params(self, filter_to_trainable=True)
1454
+
1455
def set_trainable(self):
    """
    Freeze appropriate parameters in the model.

    Abstract hook: this base implementation always raises; subclasses are
    expected to override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
1460
+
1461
def group_params_by_weight_decay(self):
    """
    Partition the trainable parameters by whether weight decay applies to them.

    Parameters with ``requires_grad=False`` are left out of both groups; the
    decision for each remaining parameter is delegated to
    ``self._should_apply_weight_decay(name)``.

    Returns:
        Tuple ``(params to optimize w/ weight decay, params to optimize w/o weight decay)``.
    """
    decayed = []
    undecayed = []
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        bucket = decayed if self._should_apply_weight_decay(name) else undecayed
        bucket.append(param)
    return decayed, undecayed
1473
+
1474
def _should_apply_weight_decay(self, parameter_name):
    """
    Return whether weight decay should be applied to a parameter.

    Abstract hook consulted by ``group_params_by_weight_decay``; this base
    implementation always raises.

    Args:
        parameter_name: parameter name as yielded by ``named_parameters()``.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
1479
+
1480
@property
def special_tokens(self):
    """
    Returns a dict mapping from the attribute name of a special token to its string format,
    e.g. "media_token": "<image>"

    Raises:
        AssertionError: if ``self._special_tokens`` has no "media_token" entry.
    """
    # Every VLM must declare a media token; fail early if a subclass forgot to.
    assert (
        "media_token" in self._special_tokens
    ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
    return self._special_tokens
1490
+
1491
@property
def special_token_ids(self):
    """
    Returns a list of the special token ids.

    Each id is read from the ``<att_name>_id`` attribute on ``self``, in the
    iteration order of ``self.special_tokens``; those attributes are assigned
    by ``set_special_token_ids``.
    """
    return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
1497
+
1498
def set_special_token_ids(self, string_to_ids):
    """
    Record the id of every special token on both this model and its language model.

    For each ``att_name -> token_str`` pair in ``self.special_tokens`` the
    attribute ``<att_name>_id`` is set on ``self`` and on ``self.lang_model``.

    Args:
        string_to_ids (dict): mapping from token string to id; must contain
            every special token string.

    Raises:
        AssertionError: if a special token string is absent from ``string_to_ids``.
    """
    unknown_tokens = set(self.special_tokens.values()) - set(string_to_ids.keys())
    assert not unknown_tokens
    for att_name, token_str in self.special_tokens.items():
        id_attr = f"{att_name}_id"
        token_id = string_to_ids[token_str]
        for owner in (self, self.lang_model):
            setattr(owner, id_attr, token_id)
1508
+
1509
def init_gradient_checkpointing(self):
    """
    Wrap every flagged submodule in a non-reentrant activation-checkpoint wrapper.

    A submodule is wrapped when its ``_use_gradient_checkpointing`` attribute is
    truthy and it is not already a ``CheckpointWrapper``.
    """
    from functools import partial
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        CheckpointWrapper,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    def _needs_wrapping(module):
        # Skip modules without the flag and never wrap a wrapper a second time.
        flagged = getattr(module, "_use_gradient_checkpointing", False)
        return flagged and not isinstance(module, CheckpointWrapper)

    wrap_non_reentrant = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        self,
        checkpoint_wrapper_fn=wrap_non_reentrant,
        check_fn=_needs_wrapping,
    )
1528
+
1529
+
1530
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
    """
    VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,

    Both fields default to None.
    """

    # Media-location mask accumulated over earlier steps (None when nothing is cached).
    past_media_locations: Optional[torch.Tensor] = None
    # Vision tokens accumulated over earlier steps (None when nothing is cached).
    past_vision_tokens: Optional[torch.Tensor] = None
1540
+
1541
+
1542
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or "" still count)."""
    return not (val is None)
1544
+
1545
+
1546
def FeedForward(dim, mult=4):
    """
    Build a LayerNorm -> Linear -> GELU -> Linear feed-forward block.

    Args:
        dim: input and output feature width.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` with the four layers above; both linear layers are bias-free.
    """
    hidden_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*layers)
1554
+
1555
+
1556
class VLMWithLanguageStream(VLM):
    """
    VLM that fuses modalities by inserting vision tokens directly into the language stream.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder, vision_tokenizer, lang_model, initial_tokenizer_len,
            pad_token_id, gradient_checkpointing: forwarded unchanged to ``VLM.__init__``.
            decoder_layers_attr_name (str, optional): dotted attribute path, relative to
                ``lang_model``, of the iterable of decoder blocks. When given, every block
                gets its ``_use_gradient_checkpointing`` flag set to ``gradient_checkpointing``.
        """
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            pad_token_id=pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.decoder_layers_attr_name = decoder_layers_attr_name
        if decoder_layers_attr_name is not None:
            for block in getattr_recursive(
                self.lang_model, self.decoder_layers_attr_name
            ):
                block._use_gradient_checkpointing = gradient_checkpointing

    def _prepare_inputs_for_forward(
        self,
        vision_tokens: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
        past_key_values=None,
        vision_attention_mask: Optional[torch.Tensor] = None,
        past_media_locations: torch.Tensor = None,
        past_vision_tokens: torch.Tensor = None,
        padding_side: str = "left",
        num_beams: int = 1,
    ):
        """
        Insert the vision tokens directly into the language stream.
        This requires us to modify the input_ids, attention_mask, and labels.

        Every ``<image>`` token in ``lang_x[i]`` is replaced by the embeddings in
        ``vision_tokens[i][k]`` (k-th image of sample i). The attention mask is
        expanded to match, and the inserted positions are labelled -100.

        Args:
            vision_tokens: per-sample, per-image vision embeddings, indexed as
                ``vision_tokens[i][img_num]``; None for text-only input.
            lang_x: input ids of shape (B, T_txt).
            attention_mask: mask aligned with ``lang_x`` (plus the past length when
                ``past_key_values`` is given). If None while vision tokens are present,
                an all-ones mask is used.
            labels: optional labels aligned with ``lang_x``.
            past_key_values: cached key/values; only used for a length sanity check.
            vision_attention_mask: optional mask for the inserted vision tokens
                (any-resolution mode only).
            past_media_locations, past_vision_tokens, num_beams: accepted for
                interface compatibility; not used by this implementation.
            padding_side: side on which shorter sequences are padded when stacking.

        Returns:
            dict with ``input_ids`` (text-only) or ``inputs_embeds`` (multimodal),
            plus ``attention_mask`` and ``labels``.
        """
        if past_key_values is not None:
            past_len = past_key_values[0][0].shape[2]
            assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
                "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
                + "Check that you've expanded the attention mask to account for past image tokens."
            )

        if vision_tokens is None:
            return {
                "input_ids": lang_x,
                "attention_mask": attention_mask,
                "labels": labels,
            }

        # get the language embeddings
        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)

        # Without a mask, attend to every text position; previously this path
        # failed as soon as the mask was indexed.
        if attention_mask is None:
            attention_mask = torch.ones_like(lang_x)

        # build up the multimodal embeddings
        B = lang_x.shape[0]
        has_labels = labels is not None
        multimodal_embeds = []
        multimodal_attention_mask = []
        multimodal_labels = [] if has_labels else None
        for i in range(B):
            # get index of <image> tokens in lang_x[i]
            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]

            if len(image_token_idxs) == 0:
                multimodal_embeds.append(lang_embeds[i].clone())
                multimodal_attention_mask.append(attention_mask[i].clone())
                if has_labels:
                    multimodal_labels.append(labels[i].clone())
                continue

            # loop through the image_token_idxs and insert the vision tokens
            new_embed = lang_embeds[i].clone()
            new_attention_mask = attention_mask[i].clone()
            if has_labels:
                new_label = labels[i].clone()

            # image_token_idxs are positions in the ORIGINAL sequence. Each insertion
            # replaces one token by num_vis_tokens tokens, so later positions shift
            # right; ``offset`` tracks the accumulated shift.
            offset = 0
            for img_num, img_idx in enumerate(image_token_idxs):
                # Get vision token attention mask for padded llava-style any resolution image tokens.
                if self.image_aspect_ratio == "anyres":
                    num_vis_tokens = vision_tokens[i][img_num].shape[0]
                    if vision_attention_mask is not None:
                        # NOTE(review): indexed per sample, not per image -- confirm the
                        # expected layout of vision_attention_mask for multi-image samples.
                        vis_attention_mask = vision_attention_mask[i]
                    else:
                        vis_attention_mask = torch.ones(
                            num_vis_tokens, dtype=torch.long
                        ).to(attention_mask.device)
                else:
                    assert (
                        vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
                    ), (
                        f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) "
                        f"vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
                    )
                    # By default, vision tokens are not padded.
                    num_vis_tokens = self.num_tokens_per_vis
                    vis_attention_mask = torch.ones(
                        num_vis_tokens, dtype=torch.long
                    ).to(attention_mask.device)

                insert_at = int(img_idx) + offset
                new_embed = torch.cat(
                    (
                        new_embed[:insert_at],
                        vision_tokens[i][img_num],
                        new_embed[insert_at + 1 :],
                    ),
                    dim=0,
                )
                new_attention_mask = torch.cat(
                    (
                        new_attention_mask[:insert_at],
                        vis_attention_mask,
                        new_attention_mask[insert_at + 1 :],
                    ),
                    dim=0,
                )
                if has_labels:
                    new_label = torch.cat(
                        (
                            new_label[:insert_at],
                            torch.ones(num_vis_tokens, dtype=torch.long).to(
                                labels.device
                            )
                            * -100,
                            new_label[insert_at + 1 :],
                        ),
                        dim=0,
                    )
                # One <image> token became num_vis_tokens tokens.
                offset += num_vis_tokens - 1

            multimodal_embeds.append(new_embed)
            multimodal_attention_mask.append(new_attention_mask)
            if has_labels:
                multimodal_labels.append(new_label)

        # stack
        multimodal_embeds = stack_with_padding(
            multimodal_embeds,
            padding_value=self.pad_token_id,
            padding_side=padding_side,
        )
        multimodal_attention_mask = stack_with_padding(
            multimodal_attention_mask,
            padding_value=0,
            padding_side=padding_side,
        )
        if has_labels:
            multimodal_labels = stack_with_padding(
                multimodal_labels,
                padding_value=-100,
                padding_side=padding_side,
            )

        return {
            "inputs_embeds": multimodal_embeds,
            "attention_mask": multimodal_attention_mask,
            "labels": multimodal_labels,
        }

    def _postprocess_outputs_from_forward(
        self,
        output: CausalLMOutputWithPast,
        lang_x: torch.Tensor,
        vision_tokens: torch.Tensor,
        past_vision_tokens: torch.Tensor,
        past_media_locations: torch.Tensor,
        use_cache: bool = False,
    ):
        """
        Collapse the multimodal logits back to the shape of ``lang_x`` and attach
        the updated vision cache.

        For every ``<image>`` position only the logit of the first inserted vision
        token is kept; the read pointer then skips ``self.num_tokens_per_vis`` positions.

        Returns:
            VLMOutputWithPast whose ``logits`` have leading shape (B, T_txt).
        """
        # Include the past vision tokens and past media locations in the output
        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
            use_cache=use_cache,
        )

        # return logits that are the same shape as the original input_ids
        logits = output.logits
        batch_logits = []
        B, T_txt = lang_x.shape
        for i in range(B):
            sequence_logits = []
            logits_j = 0
            for j in range(T_txt):
                if lang_x[i, j] != self.media_token_id:
                    sequence_logits.append(logits[i, logits_j])
                    logits_j += 1
                else:
                    # append the logit for the first image token, then skip over the rest
                    # note: the model actually learns to predict <im_patch>, not <image>
                    sequence_logits.append(logits[i, logits_j])
                    logits_j += self.num_tokens_per_vis
            sequence_logits = torch.stack(sequence_logits, dim=0)  # (T_txt, vocab_size)
            batch_logits.append(sequence_logits)

        batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
        # The final logits shape should be the same as the original input_ids shape
        assert batch_logits.shape[:2] == (B, T_txt)

        # assemble the output
        output = VLMOutputWithPast(
            loss=output.loss,
            logits=batch_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            past_media_locations=updated_media_locations,
            past_vision_tokens=updated_vision_tokens,
        )

        return output

    def _post_forward_hook(self):
        """No-op: this fusion strategy needs no cleanup after a forward/generate call."""
        pass

    @property
    def num_params_per_module(self):
        """Return a multi-line summary of the number of parameters per module in the model"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
                f"Language model: {num_params(self.lang_model):,} parameters",
            ]
        )

    @property
    def num_trainable_params_per_module(self):
        """Return a multi-line summary of the number of trainable parameters per module in the model"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
                f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
            ]
        )
1801
+
1802
+
1803
class XGenMMPerceiver(VLMWithLanguageStream):
    """Language-stream VLM that supports any-resolution image input with per-patch sampling."""

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        image_aspect_ratio: str = "anyres",
        anyres_patch_sampling: bool = True,
        anyres_grids: list[int] = None,
    ):
        """
        Args:
            vision_encoder (nn.Module): vision backbone that encodes pixels.
            vision_tokenizer (nn.Module): module turning vision features into vision tokens.
            lang_model (nn.Module): HF causal language model
            initial_tokenizer_len (int): size of the tokenizer vocab
            pad_token_id (int): id of the padding token.
            decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
            image_aspect_ratio (str, optional): "anyres" selects the any-resolution
                encoding path; any other value selects the plain encoder path.
            anyres_patch_sampling (bool, optional): in any-resolution mode, tokenize each
                image patch independently before concatenating the tokens.
            anyres_grids (optional): grid configurations used by the any-resolution encoder.
        """
        # Must be set before super().__init__ so the base class can see the special tokens.
        self._special_tokens = {
            "media_token": "<image>",
            "image_placeholder_token": "<image placeholder>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )
        self.image_aspect_ratio = image_aspect_ratio
        self.anyres_patch_sampling = anyres_patch_sampling
        self.anyres_grids = anyres_grids

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def _should_apply_weight_decay(self, parameter_name):
        """
        Kosmos applies 0.01 weight decay to everything
        """
        return True

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        image_size: Optional[Tuple] = None,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x: Vision input, or None for text-only generation.
                In any-resolution mode with patch sampling this must be a list
                (one entry per sample; an entry may itself be a list of images
                for multi-image samples). Otherwise a tensor of shape
                (B, T_img, F, C, H, W).
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            image_size (optional): image size information forwarded to the
                any-resolution encoder; unused on the plain encoder path.
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            past_key_values: cached key/values; forwarded to the language model when given.
            past_media_locations, past_vision_tokens: cached vision state from earlier steps.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
                ``num_beams`` is read from here (default 1).

        Returns:
            Whatever ``self.lang_model.generate`` returns.
        """
        num_beams = kwargs.pop("num_beams", 1)

        # No separate mask is produced for the inserted vision tokens on this path.
        vision_attention_mask = None

        # Per-patch sampling is only defined for any-resolution features. Without this
        # guard the plain encoder path would try to concatenate a missing mask list.
        patch_sampling = (
            self.anyres_patch_sampling and self.image_aspect_ratio == "anyres"
        )

        # convert pixels to vision tokens
        if vision_x is not None:
            if self.image_aspect_ratio == "anyres":
                input_dict = dict(image=vision_x, image_size=image_size)
                vision_features, vision_attn_masks = self._encode_vision_x_anyres(
                    input_dict, lang_x.device
                )
            else:
                vision_features = self._encode_vision_x(vision_x=vision_x)
                vision_attn_masks = None

            # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
            # Same for attention masks: [b, Np, v] -> [b*Np, v]
            if patch_sampling:
                split_sizes = [feature.shape[0] for feature in vision_features]
                # Nested splits for multi-image samples.
                if isinstance(vision_x[0], list):
                    split_split_sizes = []
                    img_id = 0
                    for images in vision_x:
                        nt = len(images)
                        split_split_sizes.append(split_sizes[img_id : img_id + nt])
                        img_id += nt
                else:
                    split_split_sizes = split_sizes
                vision_features = torch.cat(vision_features, dim=0)
                vision_features = vision_features[
                    :, None, None, :, :
                ]  # Expand dimensions.
                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)

            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)

            # Post-processing: Split the batches into groups of patches and concatenate them together.
            if patch_sampling:
                assert isinstance(vision_x, list)
                if isinstance(vision_x[0], list):
                    vision_token_groups = torch.split(
                        vision_tokens,
                        [sum(nt_img) for nt_img in split_split_sizes],
                        dim=0,
                    )
                    vision_tokens = []

                    for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
                        patch_vis_token_groups = torch.split(
                            patch_vis_tokens, split_split_sizes[sample_id], dim=0
                        )  # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
                        flatten_vision_tokens = []
                        for image_vis_token in patch_vis_token_groups:
                            image_vis_token = image_vis_token.flatten(
                                0, 2
                            )  # [Np, 1, v, d] -> [Np*v, d]
                            flatten_vision_tokens.append(image_vis_token)
                        vision_tokens.append(flatten_vision_tokens)
                else:
                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
                    vision_tokens = []
                    for patch_vis_tokens in vision_token_groups:
                        patch_vis_tokens = patch_vis_tokens.flatten(
                            0, 2
                        )  # [Np, 1, v, d] -> [Np*v, d]
                        vision_tokens.append(
                            patch_vis_tokens.unsqueeze(0)
                        )  # Add the nt dimension.
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        # for xattn, vision_x and media_location are repeat_interleaved s.t.
        # the total batch size is B * num_beams
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )

        # Only forward past_key_values when a cache was actually supplied.
        cache_kwargs = {}
        if past_key_values is not None:
            cache_kwargs["past_key_values"] = past_key_values
        output = self.lang_model.generate(
            **new_inputs,
            **cache_kwargs,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output
1988
+
1989
 
1990
class XGenMMVisionEncoder(PreTrainedModel):
    """
    PreTrainedModel wrapper around the vision backbone named in the config.

    Only "google/siglip-so400m-patch14-384" is accepted; any other
    ``config.model_name`` is rejected at construction time.
    """

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """
        Load the backbone with ``AutoModel.from_pretrained(config.model_name)``.

        Raises:
            ValueError: if ``config.model_name`` is not the supported SigLIP checkpoint.
        """
        super().__init__(config)
        if config.model_name != "google/siglip-so400m-patch14-384":
            raise ValueError(
                f"Unsupported model {config.model_name}. New vision models will be added soon."
            )
        self.model = AutoModel.from_pretrained(config.model_name)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` by delegating to ``self.model.encode_image``."""
        # NOTE(review): relies on the loaded model exposing ``encode_image`` -- confirm
        # that the SigLIP model returned by AutoModel provides this method.
        # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
        return self.model.encode_image(pixel_values)
 
2005
 
2006
+
2007
+ # vision tokenizer
2008
  class XGenMMVisionTokenizer(PreTrainedModel):
2009
  config_class = XGenMMVisionTokenizerConfig
2010
+
2011
  def __init__(self, config: XGenMMVisionTokenizerConfig):
2012
  super().__init__(config)
2013
  self.model = PerceiverResampler(
 
2015
  dim_inner=config.lang_embedding_dim,
2016
  num_latents=config.num_vis_tokens,
2017
  )
2018
+
2019
+ def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
 
 
2020
  return self.model(vision_features, vision_attn_masks)
2021
+
2022
+
2023
  # XGenMM model
2024
  class XGenMMModelForConditionalGeneration(PreTrainedModel):
2025
  config_class = XGenMMConfig
2026
+
2027
  def __init__(self, config: XGenMMConfig):
2028
  super().__init__(config)
2029
+
2030
  # vision encoder initialization
2031
+ vision_encoder = AutoModel.from_pretrained(
2032
+ config.vision_encoder_config.model_name
2033
+ ).vision_model
2034
+
2035
+ # language model initialization
2036
  language_model = AutoModelForCausalLM.from_config(config.text_config)
2037
  check_embedding_fns(language_model)
2038
  # Update _tied_weights_keys using the base model used.
2039
  if language_model._tied_weights_keys is not None:
2040
+ self._tied_weights_keys = [
2041
+ f"language_model.{k}" for k in language_model._tied_weights_keys
2042
+ ]
2043
+
2044
  # vision tokenizer initialization
2045
+ if (
2046
+ config.vision_tokenizer_config.lang_embedding_dim
2047
+ != language_model.get_input_embeddings().weight.shape[1]
2048
+ ):
2049
  overwrite = language_model.get_input_embeddings().weight.shape[1]
2050
  config.vision_tokenizer_config.lang_embedding_dim = overwrite
2051
+ print(
2052
+ f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}."
2053
+ )
2054
+
2055
  vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
2056
 
2057
  self.vlm = XGenMMPerceiver(
2058
  vision_encoder=vision_encoder,
2059
  vision_tokenizer=vision_tokenizer,
2060
  lang_model=language_model,
2061
+ initial_tokenizer_len=config.text_config.initial_tokenizer_len,
2062
+ pad_token_id=config.text_config.pad_token_id,
2063
+ image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
2064
+ anyres_patch_sampling=config.vision_encoder_config.anyres_patch_sampling,
2065
+ anyres_grids=config.vision_encoder_config.anyres_grids,
2066
  )
2067
  # Initialize weights and apply final processing
2068
  self.post_init()
2069
+
2070
  @torch.no_grad()
2071
  def generate(
2072
  self,
 
2074
  input_ids: Optional[torch.LongTensor] = None,
2075
  attention_mask: Optional[torch.LongTensor] = None,
2076
  **generate_kwargs,
2077
+ ) -> torch.LongTensor:
2078
  self.vlm = self.vlm.eval()
2079
  return self.vlm.generate(
2080
+ vision_x=pixel_values,
2081
+ lang_x=input_ids,
2082
+ attention_mask=attention_mask,
2083
+ **generate_kwargs,
2084
+ )
2085
+
2086
  def update_special_tokens(self, tokenizer):
2087
  tokenizer.add_special_tokens(
2088
  {"additional_special_tokens": list(self.vlm.special_tokens.values())}
 
2090
  self.vlm.lang_model.config.vocab_size = len(tokenizer)
2091
  self.vlm.set_special_token_ids(
2092
  {
2093
+ v: tokenizer.convert_tokens_to_ids(v)
2094
+ for v in self.vlm.special_tokens.values()
2095
  }
2096
  )
2097
  return tokenizer
 
utils.py DELETED
@@ -1,383 +0,0 @@
1
- import torch
2
- import ast
3
- import math
4
- from PIL import Image
5
- from packaging.version import Version
6
-
7
def has_fn(model, fn_name):
    """Return True when ``model`` exposes a callable attribute called ``fn_name``."""
    candidate = getattr(model, fn_name, None)
    return callable(candidate)
10
-
11
def exists(val):
    """Return True for any value other than ``None`` (falsy values such as 0 count as existing)."""
    return not (val is None)
13
-
14
def num_params(module, filter_to_trainable=False):
    """
    Count the scalar parameters held by ``module``.

    Args:
        module: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        filter_to_trainable (bool): when True, only parameters whose
            ``requires_grad`` flag is set are counted.
    Returns:
        int: total number of elements across the selected parameters.
    """
    selected = module.parameters()
    if filter_to_trainable:
        selected = (param for param in selected if param.requires_grad)
    return sum(param.numel() for param in selected)
20
-
21
def hasattr_recursive(obj, att):
    """
    Check if obj has nested attribute
    Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')

    An empty ``att`` is treated as "no attribute requested" and yields True.
    """
    if att == "":
        return True
    i = att.find(".")
    if i < 0:
        return hasattr(obj, att)
    try:
        return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
    except Exception:
        # Any ordinary failure while walking the path means "not present".
        # Deliberately not a bare ``except``: KeyboardInterrupt / SystemExit
        # must propagate instead of being reported as a missing attribute.
        return False
36
-
37
def getattr_recursive(obj, att):
    """
    Resolve a dotted attribute path on ``obj``.
    Example: getattr_recursive(obj, 'a.b.c') returns obj.a.b.c; an empty path returns ``obj`` itself.
    """
    target = obj
    remaining = att
    # Peel off one path segment per iteration until nothing is left.
    while remaining != "":
        segment, _, remaining = remaining.partition(".")
        target = getattr(target, segment)
    return target
49
-
50
-
51
def setattr_recursive(obj, att, val):
    """
    Assign ``val`` to a dotted attribute path on ``obj``.
    Example: setattr_recursive(obj, 'a.b.c', val) performs obj.a.b.c = val
    """
    parent_path, separator, leaf = att.rpartition(".")
    if separator:
        # Walk down to the object that owns the final attribute.
        obj = getattr_recursive(obj, parent_path)
    setattr(obj, leaf, val)
59
-
60
-
61
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack a list of tensors with padding on one side

    Every tensor is padded along dim 0 up to the longest dim-0 length in the
    list; all trailing dimensions must already match. Tensors of any rank >= 1
    are supported.

    Args:
        list_of_tensors (list[torch.Tensor]): List of tensors to stack
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): Side to pad on. Defaults to "right";
            any other value pads on the left.
    Returns:
        torch.Tensor: Stacked tensors
    """
    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        # Padding matches the tensor on every trailing dim, whatever its rank.
        pad_shape = (max_tokens - tensor.size(0), *tensor.shape[1:])
        padding = torch.full(
            pad_shape,
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        pieces = (tensor, padding) if padding_side == "right" else (padding, tensor)
        padded_tensors.append(torch.cat(pieces, dim=0))
    return torch.stack(padded_tensors)
96
-
97
-
98
def check_embedding_fns(lang_model):
    """
    Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model

    For each accessor the model lacks, a lambda is attached to ``lang_model``
    based on known attribute layouts (``transformer.wte``,
    ``model.decoder.embed_tokens``, ``lm_head``).

    Raises:
        ValueError: if an accessor is missing and no known attribute layout matches.
    """
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            # Read from the same path that was just probed and that the setter
            # below writes to (``model.decoder``, not ``decoder``).
            lang_model.get_input_embeddings = (
                lambda: lang_model.model.decoder.embed_tokens
            )
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )
141
-
142
-
143
def has_fn(model, fn_name):
    """Return True when ``model`` exposes a callable attribute called ``fn_name``."""
    # NOTE: this name is also defined earlier in the module with the same body;
    # being later, this definition is the one left bound after import.
    candidate = getattr(model, fn_name, None)
    return callable(candidate)
146
-
147
-
148
- # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
- #
150
- # Licensed under the Apache License, Version 2.0 (the "License");
151
- # you may not use this file except in compliance with the License.
152
- # You may obtain a copy of the License at
153
- #
154
- # http://www.apache.org/licenses/LICENSE-2.0
155
- #
156
- # Unless required by applicable law or agreed to in writing, software
157
- # distributed under the License is distributed on an "AS IS" BASIS,
158
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
- # See the License for the specific language governing permissions and
160
- # limitations under the License.
161
-
162
def unpad_image(tensor, original_size, keep_original_shape=False):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The padding is assumed to be symmetric: top/bottom when the original image
    is relatively wider than the tensor, left/right otherwise.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image as
            (width, height) -- the order in which it is unpacked below.
        keep_original_shape (bool): If True, the tensor is returned untouched
            together with an (H, W) mask that is 0 on padded rows/columns and
            1 elsewhere. If False, the padding is cropped away and the mask is
            None.

    Returns:
        tuple: ``(tensor, attention_mask)`` where ``attention_mask`` is a
        torch.Tensor when ``keep_original_shape`` is True and None otherwise.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Padding sits on the height axis when the original is relatively wider.
    pad_rows = original_aspect_ratio > current_aspect_ratio
    if pad_rows:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2

    if keep_original_shape:
        attention_mask = torch.ones(
            (current_height, current_width), device=tensor.device
        )
        if pad_rows:
            attention_mask[:padding, :] = 0
            attention_mask[current_height - padding:, :] = 0
        else:
            attention_mask[:, :padding] = 0
            attention_mask[:, current_width - padding:] = 0
        return tensor, attention_mask

    if pad_rows:
        return tensor[:, padding:current_height - padding, :], None
    return tensor[:, :, padding:current_width - padding], None
203
-
204
-
205
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that best fits an image of ``original_size``.

    Candidates are ranked by how many of the image's pixels survive an
    aspect-preserving fit (capped at the original pixel count), with the
    smaller amount of unused canvas breaking ties; the first best candidate wins.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): Candidate resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height), or None when there are no candidates.
    """
    src_w, src_h = original_size
    src_area = src_w * src_h

    chosen = None
    # Score is (effective pixels, -wasted pixels); larger tuples are better.
    best_score = (0, -float('inf'))

    for cand_w, cand_h in possible_resolutions:
        ratio = min(cand_w / src_w, cand_h / src_h)
        fitted_area = int(src_w * ratio) * int(src_h * ratio)
        effective = min(fitted_area, src_area)
        score = (effective, effective - cand_w * cand_h)
        if score > best_score:
            best_score = score
            chosen = (cand_w, cand_h)

    return chosen
233
-
234
-
235
def resize_and_pad_image(image, target_resolution):
    """
    Fit an image inside a target canvas without distorting it, then center it on black padding.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: A new RGB image of exactly the target resolution.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution

    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h

    # The tighter axis fills the canvas; the other is scaled by the same ratio
    # and clamped so rounding up can never overflow the target.
    if ratio_w < ratio_h:
        fit_w, fit_h = dst_w, min(math.ceil(src_h * ratio_w), dst_h)
    else:
        fit_w, fit_h = min(math.ceil(src_w * ratio_h), dst_w), dst_h

    scaled = image.resize((fit_w, fit_h))

    canvas = Image.new('RGB', (dst_w, dst_h), (0, 0, 0))
    canvas.paste(scaled, ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2))
    return canvas
268
-
269
-
270
def divide_to_patches(image, patch_size):
    """
    Cut an image into square tiles, scanning left-to-right and then top-to-bottom.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The side length of each tile, also used as the stride.

    Returns:
        list: The crops returned by ``image.crop`` for each tile, in row-major order.
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
290
-
291
-
292
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (list | tuple | str): The possible resolutions, either as a
            sequence of (width, height) pairs or as a string representation of one
            (parsed with ``ast.literal_eval``).
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    # isinstance (rather than an exact type check) so tuples and list
    # subclasses are used directly instead of being sent to literal_eval.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
310
-
311
-
312
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The image is resized and padded to the best-fitting resolution and cut into
    square patches; a resized copy of the whole image is placed in front of the
    patches, and every entry is run through ``processor``.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object. It must be callable on a single
            image and expose ``transforms[0].size``, whose first entry is used
            as the patch side length.
        grid_pinpoints (list | tuple | str): The possible resolutions, either as a
            sequence of (width, height) pairs or as a string representation of one
            (parsed with ``ast.literal_eval``).

    Returns:
        torch.Tensor: A tensor containing the processed image patches, stacked on dim 0.
    """
    # FIXME: determine grid_pinpoints from image sizes.
    # isinstance (rather than an exact type check) so tuples and list
    # subclasses are used directly instead of being sent to literal_eval.
    if isinstance(grid_pinpoints, (list, tuple)):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    processor_size = processor.transforms[0].size
    patches = divide_to_patches(image_padded, processor_size[0])

    # Global view: the whole image squashed to a single patch-sized square.
    image_original_resize = image.resize((processor_size[0], processor_size[0]))

    image_patches = [image_original_resize] + patches
    image_patches = [processor(image_patch) for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)
341
-
342
-
343
def expand2square(pil_img, background_color):
    """
    Pad a PIL image to a square canvas, centering it along the shorter axis.

    An image that is already square is returned unchanged (same object);
    otherwise a new image filled with ``background_color`` is returned.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
355
-
356
-
357
def process_images(images, image_processor, model_cfg):
    """
    Preprocess a batch of images according to ``model_cfg.image_aspect_ratio``.

    - 'pad': each image is squared with ``expand2square`` (fill colour taken from
      the mean of the processor's last transform) and then processed.
    - 'anyres' / 'anyres-legacy': each image goes through ``process_anyres_image``
      with five grids built from the processor's base image size.
    - anything else (including a missing attribute): ``image_processor(images)``
      is returned directly.

    Returns:
        A stacked tensor when all per-image results share one shape, otherwise
        the list of per-image results.
    """
    mode = getattr(model_cfg, "image_aspect_ratio", None)

    if mode == 'pad':
        new_images = []
        for img in images:
            fill = tuple(int(x*255) for x in image_processor.transforms[-1].mean)
            new_images.append(image_processor(expand2square(img, fill)))
    elif mode in ["anyres", "anyres-legacy"]:
        base = image_processor.transforms[0].size[0]
        grids = [
            [base, base * 2],
            [base * 2, base],
            [base * 2, base * 2],
            [base * 3, base],
            [base, base * 3],
        ]
        new_images = [
            process_anyres_image(img, image_processor, grids) for img in images
        ]
    else:
        return image_processor(images)

    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
382
-
383
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vlm.py DELETED
@@ -1,1381 +0,0 @@
1
-
2
- import torch
3
- from torch import einsum, nn
4
- from einops import rearrange, repeat
5
- from einops_exts import rearrange_many
6
- from einops import rearrange
7
- from typing import List, Optional, Tuple, Union
8
- import torch.nn.functional as F
9
- from transformers.modeling_outputs import CausalLMOutputWithPast
10
- from dataclasses import dataclass
11
- from transformers import CLIPVisionModel
12
- from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
13
-
14
- import transformers
15
- from packaging.version import Version
16
-
17
- from .utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
18
-
19
-
20
class VisionTokenizer(nn.Module):
    """Base module that records the media feature width and the token count emitted per media item."""

    def __init__(self, dim_media, num_tokens_per_media):
        super().__init__()
        self.dim_media, self.num_tokens_per_media = dim_media, num_tokens_per_media
25
-
26
class PerceiverAttention(nn.Module):
    """Attention layer in which latent queries attend over media features concatenated with the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        self.heads = heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            vision_attn_masks (torch.Tensor, optional): per-token mask over ``x``;
                positions where it is zero are excluded from attention.
                Presumably shape (b, n1) -- TODO confirm against callers.
        """
        media = self.norm_media(x)
        latents = self.norm_latents(latents)
        num_heads = self.heads

        q = self.to_q(latents)
        # Keys/values cover the media tokens followed by the latents.
        kv_input = torch.cat((media, latents), dim=-2)
        if vision_attn_masks is not None:
            # Latent positions are always attendable, so extend the mask with ones.
            latent_mask = torch.ones(
                (latents.shape[0], latents.shape[-2]),
                dtype=latents.dtype,
                device=latents.device,
            )
            vision_attn_masks = torch.cat((vision_attn_masks, latent_mask), dim=-1)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=num_heads)
        q = q * self.scale

        scores = einsum("... i d, ... j d -> ... i j", q, k)
        if vision_attn_masks is not None:
            # Additive bias: -inf on masked key positions, 0 elsewhere.
            attn_bias = torch.zeros(
                (q.size(0), 1, 1, q.size(-2), k.size(-2)),
                dtype=q.dtype,
                device=q.device,
            )
            key_mask = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
            attn_bias.masked_fill_(key_mask.logical_not(), float("-inf"))
            scores += attn_bias

        # Subtract the row max before softmax for numerical stability.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        mixed = einsum("... i j, ... j d -> ... i d", weights, v)
        mixed = rearrange(mixed, "b h t n d -> b t n (h d)", h=num_heads)
        return self.to_out(mixed)
80
-
81
-
82
def FeedForward(dim, mult=4):
    """Build a bias-free LayerNorm -> Linear -> GELU -> Linear block whose hidden width is ``int(dim * mult)``."""
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
90
-
91
-
92
class PerceiverResampler(VisionTokenizer):
    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=96,
        heads=16,
        num_latents=128,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the incoming image features to;
                also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 96.
            heads (int, optional): number of heads. Defaults to 16.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 128.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        # Build the projection before super().__init__ only as a plain local;
        # it is registered as a submodule once nn.Module is initialised.
        if dim_inner is not None:
            projection = nn.Linear(dim, dim_inner)
        else:
            projection = None
            dim_inner = dim
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings
        # (explicit None checks: no helper from another module is required)
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if max_num_frames is not None
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if max_num_media is not None
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(
                            dim=dim, dim_head=dim_head, heads=heads
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x, vision_attn_masks):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
                shape (b, v)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        # Named num_frames (the "F" axis) so the module-level alias for
        # torch.nn.functional is not shadowed inside this method.
        b, T, num_frames, v = x.shape[:4]

        # frame and media time embeddings
        if self.frame_embs is not None:
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=b, T=T, v=v
            )
            x = x + frame_embs
        x = rearrange(
            x, "b T F v d -> b T (F v) d"
        )  # flatten the frame and spatial dimensions
        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:T]

        # blocks
        latents = self.latents
        latents = repeat(latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents, vision_attn_masks) + latents
            latents = ff(latents) + latents

        if self.projection is not None:
            return self.projection(self.norm(latents))
        else:
            return self.norm(latents)
192
-
193
-
194
class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze: (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
            pad_token_id (`int`, *optional*):
                The padding index (must be <= max_original_id). Forwarded to `nn.Embedding` as `padding_idx`.

        Raises:
            ValueError: if `pad_token_id` is larger than `max_original_id`.

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as
        `max_norm` or `norm_type`. We are not supporting these.
        """
        # validate args
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
        # Applied whether or not an additional table exists, so that
        # `partially_freeze` is honoured for the regular weight in both cases.
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
        """
        self.weight.requires_grad_(require_regular_grad)
        # The additional table is only created when num_additional_embeddings > 0;
        # touching it unconditionally raises AttributeError for the default of 0.
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting max_original_id + 1, since the 2nd
           embedding starts from 0
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with index 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
        measure.

        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids > self.max_original_id)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(
            input_ids_additional_vocab - self.max_original_id - 1
        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )
337
-
338
-
339
- class DecoupledLinear(nn.Linear):
340
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
341
- """
342
- Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
343
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
344
- then it will create `additional_out_features * in_features` additional parameters that are always trained. If
345
- `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
346
- """
347
-
348
- def __init__(
349
- self,
350
- max_original_id: int,
351
- additional_out_features: int = 0,
352
- _weight: torch.Tensor = None,
353
- _bias: torch.Tensor = None,
354
- in_features: int = None,
355
- original_out_features: int = None,
356
- bias: bool = True,
357
- partially_freeze: bool = True,
358
- device=None,
359
- dtype=None,
360
- ) -> None:
361
- """
362
- Args:
363
- max_original_id (`int`): The largest token id that should be extracted from the regular weight.
364
- This is usually len(tokenizer) - 1 before additional tokens are added.
365
- Note that this may not equal original_out_features - 1
366
- _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
367
- If provided, this sets the `in_features` and `original_out_features` parameters.
368
- _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
369
- in_features: int. Input hidden size.
370
- original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
371
- additional_out_features: int. Number of additional trainable dimensions.
372
- bias: bool. Whether to include a bias term.
373
- partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
374
- """
375
- # argument validation
376
- if _weight is not None:
377
- assert (_weight.shape[0] == original_out_features) or (
378
- original_out_features is None
379
- ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
380
- assert (_weight.shape[1] == in_features) or (
381
- in_features is None
382
- ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
383
- in_features = _weight.shape[1]
384
- original_out_features = _weight.shape[0]
385
- else:
386
- assert (
387
- in_features is not None
388
- ), "in_features must be provided if _weight is not provided"
389
- assert (
390
- original_out_features is not None
391
- ), "original_out_features must be provided if _weight is not provided"
392
-
393
- if _bias is not None:
394
- assert bias is True, "bias must be True if _bias is provided"
395
-
396
- # initialize original linear
397
- super().__init__(
398
- in_features,
399
- original_out_features,
400
- bias,
401
- device,
402
- dtype)
403
-
404
- # set weight and bias manually
405
- if _weight is not None:
406
- self.weight = nn.Parameter(_weight)
407
- if _bias is not None:
408
- self.bias = nn.Parameter(_bias)
409
-
410
- self.in_features = in_features
411
- self.original_out_features = original_out_features
412
- self.max_original_id = max_original_id
413
-
414
- # initialize additional linear
415
- self.additional_out_features = additional_out_features
416
- self.has_bias = bias
417
- if additional_out_features > 0:
418
- self.additional_fc = nn.Linear(
419
- in_features=in_features,
420
- out_features=additional_out_features,
421
- bias=self.has_bias,
422
- device=device,
423
- dtype=dtype,
424
- )
425
- self.set_requires_grad(
426
- require_regular_grad=not partially_freeze, require_additional_grad=True
427
- )
428
-
429
- def set_requires_grad(self, require_regular_grad, require_additional_grad):
430
- """
431
- Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
432
- """
433
- self.weight.requires_grad_(require_regular_grad)
434
- if self.has_bias:
435
- self.bias.requires_grad_(require_regular_grad)
436
- self.additional_fc.requires_grad_(require_additional_grad)
437
-
438
- def forward(self, input: torch.Tensor) -> torch.Tensor:
439
- output = F.linear(input, self.weight, self.bias)
440
- output = output[..., : self.max_original_id + 1]
441
-
442
- if self.additional_out_features > 0:
443
- additional_features = F.linear(
444
- input, self.additional_fc.weight, self.additional_fc.bias
445
- )
446
- output = torch.cat((output, additional_features), -1)
447
- return output
448
-
449
- def extra_repr(self) -> str:
450
- """Overwriting `nn.Linear.extra_repr` to include new parameters."""
451
- return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
452
- self.in_features,
453
- self.max_original_id + 1,
454
- self.additional_out_features,
455
- self.bias is not None,
456
- (not self.weight.requires_grad or not self.bias.requires_grad),
457
- )
458
-
459
- class VLM(nn.Module):
460
- """
461
- Generic vision-language model (VLM) class.
462
- A VLM consists of four components:
463
- 1. A vision encoder that extracts features from pixels, e.g. CLIP
464
- input: (B, T_img, F, C, H, W)
465
- output: (B, T_img, F, v, d)
466
- 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
467
- input: (B, T_img, F, v, d)
468
- output: (B, T_img, n, d)
469
- 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
470
- 4. A language model
471
- """
472
-
473
- def __init__(
474
- self,
475
- vision_encoder: nn.Module,
476
- vision_tokenizer: nn.Module,
477
- lang_model: nn.Module,
478
- initial_tokenizer_len: int,
479
- pad_token_id: int,
480
- gradient_checkpointing: bool = False,
481
- ):
482
- """
483
- Args:
484
- vision_encoder (nn.Module): e.g. CLIP
485
- vision_tokenizer (nn.Module): e.g. PerceiverResampler
486
- lang_model (nn.Module): e.g. MPT
487
- initial_tokenizer_len (int): size of the original tokenizer vocab
488
- pad_token_id (int): id of the pad token
489
- gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
490
- """
491
- super().__init__()
492
-
493
- # save dimension information
494
- self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
495
- if hasattr(lang_model.config, "d_model"):
496
- self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
497
- else:
498
- self.lang_hidden_dim = lang_model.config.hidden_size
499
- self.vis_embedding_dim = vision_tokenizer.dim_media
500
- self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
501
-
502
- # core components
503
- self.vision_encoder = vision_encoder
504
- self.vision_tokenizer = vision_tokenizer
505
- self.lang_model = lang_model
506
-
507
- # lm embeddings
508
- self.pad_token_id = pad_token_id
509
- self.initial_tokenizer_len = initial_tokenizer_len
510
- input_embeds = DecoupledEmbedding(
511
- max_original_id=initial_tokenizer_len - 1,
512
- num_additional_embeddings=len(self.special_tokens),
513
- _weight=self.lang_model.get_input_embeddings().weight,
514
- pad_token_id=self.pad_token_id,
515
- )
516
- if hasattr(input_embeds, "additional_embedding"):
517
- input_embeds.additional_embedding.weight.data.normal_(
518
- mean=0.0,
519
- std=self.lang_model.config.initializer_range
520
- if hasattr(self.lang_model.config, "initializer_range")
521
- else 0.02,
522
- )
523
- self.lang_model.set_input_embeddings(input_embeds)
524
-
525
- out_embeds = DecoupledLinear(
526
- max_original_id=initial_tokenizer_len - 1,
527
- additional_out_features=len(self.special_tokens),
528
- _weight=self.lang_model.get_output_embeddings().weight,
529
- _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
530
- )
531
- if hasattr(out_embeds, "additional_fc"):
532
- out_embeds.additional_fc.weight.data.normal_(
533
- mean=0.0,
534
- std=self.lang_model.config.initializer_range
535
- if hasattr(self.lang_model.config, "initializer_range")
536
- else 0.02,
537
- )
538
- self.lang_model.set_output_embeddings(out_embeds)
539
-
540
- # gradient checkpointing
541
- self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
542
-
543
- def forward(
544
- self,
545
- vision_x: Optional[torch.Tensor],
546
- lang_x: torch.Tensor,
547
- attention_mask: Optional[torch.Tensor] = None,
548
- labels: Optional[torch.Tensor] = None,
549
- past_key_values: Optional[
550
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
551
- ] = None,
552
- past_media_locations: Optional[torch.Tensor] = None,
553
- past_vision_tokens: Optional[torch.Tensor] = None,
554
- use_cache: Optional[bool] = False,
555
- **kwargs,
556
- ):
557
- """
558
- Args:
559
- vision_x: Vision input
560
- shape (B, T_img, F, C, H, W) with F=1
561
- only F = 1 is supported (single-frame videos)
562
- if T_img > the number of media tokens in the corresponding input_ids (lang_x),
563
- only the first number of media tokens in lang_x are used
564
- lang_x: Language input ids, with media tokens denoting where
565
- visual media should be inserted.
566
- shape (B, T_txt)
567
- attention_mask: Attention mask. Defaults to None.
568
- labels: Labels. Defaults to None.
569
- shape (B, T_txt)
570
- past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
571
- list of length = number of decoder layers in the LM
572
- exact implementation depends on LM, see Hugging Face docs
573
- past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
574
- shape (B, T_txt)
575
- past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
576
- use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
577
- If True, includes key_values, media_locations, and vision_tokens in the output.
578
- """
579
- assert not (past_vision_tokens is None) ^ (
580
- past_media_locations is None
581
- ), "past_vision_tokens and past_media_locations must both be None or both be not None"
582
-
583
- # convert pixels to vision tokens
584
- if vision_x is not None:
585
- vision_features = self._encode_vision_x(vision_x=vision_x)
586
- vision_tokens = self.vision_tokenizer(vision_features)
587
- else:
588
- vision_tokens = None
589
-
590
- # fuse the vision and language tokens
591
- new_inputs = self._prepare_inputs_for_forward(
592
- vision_tokens=vision_tokens,
593
- lang_x=lang_x,
594
- attention_mask=attention_mask,
595
- labels=labels,
596
- past_key_values=past_key_values,
597
- past_media_locations=past_media_locations,
598
- padding_side="right",
599
- past_vision_tokens=past_vision_tokens,
600
- )
601
- output = self.lang_model(
602
- **new_inputs,
603
- use_cache=use_cache,
604
- past_key_values=past_key_values,
605
- **kwargs,
606
- )
607
-
608
- # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
609
- # or to add the past_vision_tokens and past_media_locations to the output
610
- output = self._postprocess_outputs_from_forward(
611
- output=output,
612
- lang_x=lang_x,
613
- vision_tokens=vision_tokens,
614
- use_cache=use_cache,
615
- past_vision_tokens=past_vision_tokens,
616
- past_media_locations=past_media_locations,
617
- )
618
-
619
- # postforward hooks
620
- self._post_forward_hook()
621
- return output
622
-
623
    def _encode_vision_x_anyres(self, samples, device):
        """
        Encode any-resolution image patches and merge the patch features per image.

        Args:
            samples: mapping with the keys "image" (one entry of patches per image, or
                a list of such entries per sample when a sample holds several images)
                and "image_size" (the matching sizes, forwarded to
                `get_anyres_image_grid_shape` and `unpad_image`).
            device: device the concatenated patches are moved to before encoding.

        Returns:
            If `self.anyres_patch_sampling` is truthy: `(new_image_embeds, patch_attn_masks)`,
            two lists with one entry per image.
            Otherwise: `(image_embeds, image_atts)`, where the per-image embeddings are
            zero-padded to the longest image, stacked, and given two singleton axes
            after the batch axis, and `image_atts` is the stacked 0/1 mask that marks
            the non-padded entries.
        """
        assert self.anyres_grids is not None
        image_raw = samples["image"] # list of patch lists, each of shape [1, N_patch, C, H, W]
        image_sizes = samples["image_size"]

        # image_raw can be a list of lists of patches when a sample has multiple images.
        if isinstance(image_raw[0], list):
            images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
            image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
        else:
            # Concatenate the list of patches into one big batch for any-res encoding.
            images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
        image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
        image = image.to(device)

        # The vision encoder runs without gradients here.
        with torch.no_grad():
            if self.vision_encoder.__class__.__name__ == "TimmModel":
                image_embeds = self.vision_encoder.trunk.forward_features(image)
            elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
                image_embeds = self.vision_encoder(image).last_hidden_state
            else:
                image_embeds = self.vision_encoder(image)[1]  # OpenCLIP returns tuples

        # Side length of the base image, read from wherever the encoder type stores it.
        if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
            base_img_size = self.vision_encoder.config.image_size
        else:
            base_img_size = self.vision_encoder.image_size[0]

        # (height, width) of the token grid produced for a single patch.
        if self.vision_encoder.__class__.__name__ == "TimmModel":
            grid_size = self.vision_encoder.trunk.patch_embed.grid_size
        elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
            grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
            grid_size = (grid_size_base, grid_size_base)
        else:
            grid_size = self.vision_encoder.grid_size
        height, width = grid_size

        if not image_embeds.shape[1] == height * width:
            assert image_embeds.shape[1] == height * width + 1 # For vision encoders that have a [CLS] token.
            image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
        n_vis_token_per_patch = image_embeds.shape[1]

        # Split encoded patches and merge patch features
        # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
        split_sizes = [image.shape[0] for image in images]
        image_embeds = torch.split(image_embeds, split_sizes, dim=0)
        # 2. For each image (consisting of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
        new_image_embeds = []
        patch_attn_masks = []
        max_n_img_token = -1  # longest per-image token count; only tracked without patch sampling
        for idx, patch_embeds in enumerate(image_embeds):
            if patch_embeds.shape[0] > 1:
                # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
                # The first entry is treated as the base patch; the rest are the grid patches.
                base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
                patch_embeds = patch_embeds[1:]

                assert height * width == base_patch_embeds.shape[0]

                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
                                                                                self.anyres_grids,
                                                                                base_img_size) # Hardcoded grid_pinpoints.
                patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)

                # Channels first, then merge (patch row, token row) and (patch col, token col).
                patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
                patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
                # NOTE(review): unpad_image is defined outside this block; it is used
                # here as returning (embeddings, attention mask) - confirm its contract.
                patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
                if hasattr(self, 'image_newline'):
                    # Append one `image_newline` column along the last axis.
                    patch_embeds = torch.cat((
                        patch_embeds,
                        self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
                    ), dim=-1)
                if self.anyres_patch_sampling:
                    # Regroup into individual patches so that each can be handled on its own.
                    patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
                    patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
                    assert patch_attn_mask is not None
                    patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
                    patch_attn_mask = patch_attn_mask.flatten(0, 1)
                    # Put the base patch (with an all-ones mask) in front of the grid patches.
                    patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
                    patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
                else:
                    # One flat token sequence per image: base patch tokens first.
                    patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
                    patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
            else:
                # Single-patch image: nothing to merge.
                patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
                patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
                if hasattr(self, 'image_newline'):
                    patch_embeds = torch.cat((
                        patch_embeds,
                        self.image_newline[None]
                    ), dim=0)
            if not self.anyres_patch_sampling:
                max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)

            new_image_embeds.append(patch_embeds)
            patch_attn_masks.append(patch_attn_mask)

        if self.anyres_patch_sampling:
            # Return individual patches for independent token downsampling.
            return new_image_embeds, patch_attn_masks

        # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
        image_embeds = []
        image_atts = []
        for image_embed in new_image_embeds:
            n_img_token = image_embed.shape[0]
            img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
            if n_img_token < max_n_img_token:
                padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
                padded_embed[:n_img_token, :] = image_embed
                img_attn[n_img_token:] = 0 # Mask out the padded entries.
            else:
                padded_embed = image_embed
            image_embeds.append(padded_embed)
            image_atts.append(img_attn)
        image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
        image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
        # TODO: reshape image_embeds and image_atts to "b T F v d"
        image_embeds = image_embeds[:, None, None, :, :]
        # image_atts is returned without the two extra singleton axes.

        return image_embeds, image_atts
745
-
746
- def _encode_vision_x(self, vision_x: torch.Tensor):
747
- """
748
- Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
749
- Args:
750
- vision_x: Vision input
751
- shape (B, T_img, F, C, H, W)
752
- Images in the same chunk are collated along T_img, and frames are collated along F
753
- Currently only F=1 is supported (single-frame videos)
754
-
755
- rearrange code based on https://github.com/dhansmair/flamingo-mini
756
- """
757
- assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
758
- b, T, F = vision_x.shape[:3]
759
-
760
- vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
761
- with torch.no_grad():
762
- if self.vision_encoder.__class__.__name__ == "TimmModel":
763
- vision_x = self.vision_encoder.trunk.forward_features(vision_x)
764
- elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
765
- vision_x = self.vision_encoder(vision_x).last_hidden_state
766
- else:
767
- vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
768
- vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
769
- return vision_x
770
-
771
- def _concat_vision_cache(
772
- self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
773
- ):
774
- """
775
- Helper function to include the past vision tokens and past media locations in the output.
776
- """
777
- if use_cache:
778
- if past_media_locations is not None and past_vision_tokens is not None:
779
- if vision_tokens is not None:
780
- updated_vision_tokens = torch.cat(
781
- [
782
- past_vision_tokens,
783
- vision_tokens,
784
- ],
785
- dim=1,
786
- )
787
- else:
788
- updated_vision_tokens = past_vision_tokens
789
- updated_media_locations = torch.cat(
790
- [
791
- past_media_locations,
792
- lang_x == self.media_token_id,
793
- ],
794
- dim=1,
795
- )
796
- else:
797
- updated_vision_tokens = vision_tokens
798
- updated_media_locations = lang_x == self.media_token_id
799
-
800
- else:
801
- updated_vision_tokens = None
802
- updated_media_locations = None
803
-
804
- return updated_vision_tokens, updated_media_locations
805
-
806
- def generate(
807
- self,
808
- vision_x: torch.Tensor,
809
- lang_x: torch.Tensor,
810
- attention_mask: torch.Tensor = None,
811
- past_key_values: Optional[
812
- List[Union[torch.Tensor, Tuple[torch.Tensor]]]
813
- ] = None,
814
- past_media_locations: Optional[torch.Tensor] = None,
815
- past_vision_tokens: Optional[torch.Tensor] = None,
816
- **kwargs,
817
- ):
818
- """
819
- Generate text conditioned on vision and language inputs.
820
- Args:
821
- vision_x (torch.Tensor): Vision input
822
- shape (B, T_img, F, C, H, W)
823
- see documentation for forward
824
- lang_x (torch.Tensor): Language input
825
- shape (B, T_txt)
826
- attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
827
- **kwargs: see generate documentation in Hugging Face CausalLM models.
828
- Returns:
829
- torch.Tensor: lang_x with generated tokens appended to it
830
- """
831
- num_beams = kwargs.pop("num_beams", 1)
832
-
833
- # convert pixels to vision tokens
834
- if vision_x is not None:
835
- vision_features = self._encode_vision_x(vision_x=vision_x)
836
- vision_tokens = self.vision_tokenizer(vision_features)
837
- else:
838
- vision_tokens = None
839
-
840
- # fuse the vision and language tokens
841
- # for xattn, vision_x and media_location are repeat_interleaved s.t.
842
- # the total batch size is B * num_beams
843
- new_inputs = self._prepare_inputs_for_forward(
844
- vision_tokens=vision_tokens,
845
- lang_x=lang_x,
846
- attention_mask=attention_mask,
847
- past_key_values=past_key_values,
848
- past_media_locations=past_media_locations,
849
- past_vision_tokens=past_vision_tokens,
850
- padding_side="left",
851
- num_beams=num_beams,
852
- )
853
- output = self.lang_model.generate(
854
- **new_inputs,
855
- past_key_values=past_key_values,
856
- num_beams=num_beams,
857
- use_cache=True,
858
- **kwargs,
859
- )
860
- self._post_forward_hook()
861
- return output
862
-
863
- @property
864
- def num_trainable_params(self):
865
- """Print the number of trainable parameters"""
866
- return num_params(self, filter_to_trainable=True)
867
-
868
- def set_trainable(self):
869
- """
870
- Freeze appropriate parameters in the model.
871
- """
872
- raise NotImplementedError
873
-
874
- def group_params_by_weight_decay(self):
875
- """
876
- Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
877
- """
878
- params_with_wd, params_without_wd = [], []
879
- for n, p in self.named_parameters():
880
- if p.requires_grad:
881
- if self._should_apply_weight_decay(n):
882
- params_with_wd.append(p)
883
- else:
884
- params_without_wd.append(p)
885
- return params_with_wd, params_without_wd
886
-
887
- def _should_apply_weight_decay(self, parameter_name):
888
- """
889
- Return whether weight decay should be applied to a parameter.
890
- """
891
- raise NotImplementedError
892
-
893
- @property
894
- def special_tokens(self):
895
- """
896
- Returns a dict mapping from the attribute name of a special token to its string format,
897
- e.g. "media_token": "<image>"
898
- """
899
- assert (
900
- "media_token" in self._special_tokens
901
- ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
902
- return self._special_tokens
903
-
904
- @property
905
- def special_token_ids(self):
906
- """
907
- Returns a list of the special token ids
908
- """
909
- return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
910
-
911
- def set_special_token_ids(self, string_to_ids):
912
- """
913
- Args:
914
- string_to_ids (dict): mapping from token string to id
915
- """
916
- assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
917
- for att_name, token_str in self.special_tokens.items():
918
- token_id = string_to_ids[token_str]
919
- setattr(self, f"{att_name}_id", token_id)
920
- setattr(self.lang_model, f"{att_name}_id", token_id)
921
-
922
- def init_gradient_checkpointing(self):
923
- from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
924
- checkpoint_wrapper,
925
- CheckpointWrapper,
926
- CheckpointImpl,
927
- apply_activation_checkpointing,
928
- )
929
- from functools import partial
930
-
931
- non_reentrant_wrapper = partial(
932
- checkpoint_wrapper,
933
- checkpoint_impl=CheckpointImpl.NO_REENTRANT,
934
- )
935
- apply_activation_checkpointing(
936
- self,
937
- checkpoint_wrapper_fn=non_reentrant_wrapper,
938
- check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
939
- and not isinstance(m, CheckpointWrapper),
940
- )
941
-
942
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
    """
    CausalLMOutputWithPast extended with the vision cache that has to be carried
    between forward calls.

    Adds two optional fields, both defaulting to None:
        past_media_locations: Optional[torch.Tensor]
        past_vision_tokens: Optional[torch.Tensor]
    """

    # Media-location mask accumulated so far (see VLM._concat_vision_cache).
    past_media_locations: Optional[torch.Tensor] = None
    # Vision tokens accumulated so far (see VLM._concat_vision_cache).
    past_vision_tokens: Optional[torch.Tensor] = None
952
-
953
-
954
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
956
-
957
-
958
def FeedForward(dim, mult=4):
    """
    Build a LayerNorm -> Linear -> GELU -> Linear block without bias terms.

    Args:
        dim: input and output feature size.
        mult: the hidden size is `int(dim * mult)`.
    """
    hidden_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*layers)
966
-
967
- class VLMWithLanguageStream(VLM):
968
- """
969
- VLM that fuses modalities by inserting vision tokens directly into the language stream.
970
- """
971
-
972
- def __init__(
973
- self,
974
- vision_encoder: nn.Module,
975
- vision_tokenizer: nn.Module,
976
- lang_model: nn.Module,
977
- initial_tokenizer_len: int,
978
- pad_token_id: int,
979
- decoder_layers_attr_name: str = None,
980
- gradient_checkpointing: bool = False,
981
- ):
982
- super().__init__(
983
- vision_encoder=vision_encoder,
984
- vision_tokenizer=vision_tokenizer,
985
- lang_model=lang_model,
986
- initial_tokenizer_len=initial_tokenizer_len,
987
- pad_token_id=pad_token_id,
988
- gradient_checkpointing=gradient_checkpointing,
989
- )
990
- self.decoder_layers_attr_name = decoder_layers_attr_name
991
- if decoder_layers_attr_name is not None:
992
- for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
993
- block._use_gradient_checkpointing = gradient_checkpointing
994
-
995
- def _prepare_inputs_for_forward(
996
- self,
997
- vision_tokens: torch.Tensor,
998
- lang_x: torch.Tensor,
999
- attention_mask: torch.Tensor,
1000
- labels: torch.Tensor = None,
1001
- past_key_values=None,
1002
- vision_attention_mask: Optional[torch.Tensor] = None,
1003
- past_media_locations: torch.Tensor = None,
1004
- past_vision_tokens: torch.Tensor = None,
1005
- padding_side: str = "left",
1006
- num_beams: int = 1,
1007
- ):
1008
- """
1009
- Insert the vision tokens directly into the language stream/
1010
- This requires us to modify the input_ids, attention_mask, and labels.
1011
- """
1012
- if past_key_values is not None:
1013
- past_len = past_key_values[0][0].shape[2]
1014
- assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1015
- "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1016
- + "Check that you've expanded the attention mask to account for past image tokens."
1017
- )
1018
-
1019
- if vision_tokens is None:
1020
- return {
1021
- "input_ids": lang_x,
1022
- "attention_mask": attention_mask,
1023
- "labels": labels,
1024
- }
1025
-
1026
- # get the language embeddings
1027
- lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1028
-
1029
- # build up the multimodal embeddings
1030
- B = lang_x.shape[0]
1031
- has_labels = labels is not None
1032
- multimodal_embeds = []
1033
- multimodal_attention_mask = []
1034
- multimodal_labels = [] if has_labels else None
1035
- for i in range(B):
1036
- # get index of <image> tokens in lang_x[i]
1037
- image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1038
-
1039
- if len(image_token_idxs) == 0:
1040
- multimodal_embeds.append(lang_embeds[i].clone())
1041
- multimodal_attention_mask.append(attention_mask[i].clone())
1042
- if has_labels:
1043
- multimodal_labels.append(labels[i].clone())
1044
- continue
1045
-
1046
- # loop through the image_token_idxs and insert the vision tokens
1047
- new_embed = lang_embeds[i].clone()
1048
- new_attention_mask = (
1049
- attention_mask[i].clone() if attention_mask is not None else None
1050
- )
1051
- if has_labels:
1052
- new_label = labels[i].clone()
1053
-
1054
- for img_num, img_idx in enumerate(image_token_idxs):
1055
- # Get vision token attention mask for padded llava-style any resolution image tokens.
1056
- if self.image_aspect_ratio =='anyres':
1057
- num_vis_tokens = vision_tokens[i][img_num].shape[0]
1058
- if vision_attention_mask is not None:
1059
- vis_attention_mask = vision_attention_mask[i]
1060
- else:
1061
- vis_attention_mask = torch.ones(
1062
- num_vis_tokens, dtype=torch.long
1063
- ).to(attention_mask.device)
1064
- else:
1065
- assert (
1066
- vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
1067
- ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
1068
- vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
1069
- # By default, vision tokens are not padded.
1070
- num_vis_tokens = self.num_tokens_per_vis
1071
- vis_attention_mask = torch.ones(
1072
- num_vis_tokens, dtype=torch.long
1073
- ).to(attention_mask.device)
1074
-
1075
- new_embed = torch.cat(
1076
- (
1077
- new_embed[:img_idx],
1078
- vision_tokens[i][img_num],
1079
- new_embed[img_idx + 1 :],
1080
- ),
1081
- dim=0,
1082
- )
1083
- new_attention_mask = torch.cat(
1084
- (
1085
- new_attention_mask[:img_idx],
1086
- vis_attention_mask,
1087
- new_attention_mask[img_idx + 1 :],
1088
- ),
1089
- dim=0,
1090
- )
1091
- if has_labels:
1092
- new_label = torch.cat(
1093
- (
1094
- new_label[:img_idx],
1095
- torch.ones(num_vis_tokens, dtype=torch.long).to(
1096
- labels.device
1097
- )
1098
- * -100,
1099
- new_label[img_idx + 1 :],
1100
- ),
1101
- dim=0,
1102
- )
1103
- multimodal_embeds.append(new_embed)
1104
- multimodal_attention_mask.append(new_attention_mask)
1105
- if has_labels:
1106
- multimodal_labels.append(new_label)
1107
-
1108
- # stack
1109
- multimodal_embeds = stack_with_padding(
1110
- multimodal_embeds,
1111
- padding_value=self.pad_token_id,
1112
- padding_side=padding_side,
1113
- )
1114
- multimodal_attention_mask = stack_with_padding(
1115
- multimodal_attention_mask,
1116
- padding_value=0,
1117
- padding_side=padding_side,
1118
- )
1119
- if has_labels:
1120
- multimodal_labels = stack_with_padding(
1121
- multimodal_labels,
1122
- padding_value=-100,
1123
- padding_side=padding_side,
1124
- )
1125
-
1126
- return {
1127
- "inputs_embeds": multimodal_embeds,
1128
- "attention_mask": multimodal_attention_mask,
1129
- "labels": multimodal_labels,
1130
- }
1131
-
1132
- def _postprocess_outputs_from_forward(
1133
- self,
1134
- output: CausalLMOutputWithPast,
1135
- lang_x: torch.Tensor,
1136
- vision_tokens: torch.Tensor,
1137
- past_vision_tokens: torch.Tensor,
1138
- past_media_locations: torch.Tensor,
1139
- use_cache: bool = False,
1140
- ):
1141
- # Include the past vision tokens and past media locations in the output
1142
- updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1143
- lang_x=lang_x,
1144
- vision_tokens=vision_tokens,
1145
- past_vision_tokens=past_vision_tokens,
1146
- past_media_locations=past_media_locations,
1147
- use_cache=use_cache,
1148
- )
1149
-
1150
- # return logits that are the same shape as the original input_ids
1151
- logits = output.logits
1152
- batch_logits = []
1153
- B, T_txt = lang_x.shape
1154
- for i in range(B):
1155
- sequence_logits = []
1156
- logits_j = 0
1157
- for j in range(T_txt):
1158
- if lang_x[i, j] != self.media_token_id:
1159
- sequence_logits.append(logits[i, logits_j])
1160
- logits_j += 1
1161
- else:
1162
- # append the logit for the first image token, then skip over the rest
1163
- # note: the model actually learns to predict <im_patch>, not <image>
1164
- sequence_logits.append(logits[i, logits_j])
1165
- logits_j += self.num_tokens_per_vis
1166
- sequence_logits = torch.stack(sequence_logits, dim=0) # (B, vocab_size)
1167
- batch_logits.append(sequence_logits)
1168
-
1169
- batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1170
- # The final logits shape should be the same as the original input_ids shape
1171
- assert batch_logits.shape[:2] == (B, T_txt)
1172
-
1173
- # assemble the output
1174
- output = VLMOutputWithPast(
1175
- loss=output.loss,
1176
- logits=batch_logits,
1177
- past_key_values=output.past_key_values,
1178
- hidden_states=output.hidden_states,
1179
- attentions=output.attentions,
1180
- past_media_locations=updated_media_locations,
1181
- past_vision_tokens=updated_vision_tokens,
1182
- )
1183
-
1184
- return output
1185
-
1186
- def _post_forward_hook(self):
1187
- pass
1188
-
1189
-
1190
- @property
1191
- def num_params_per_module(self):
1192
- """Print the number of parameters per module in the model"""
1193
- return "\n".join(
1194
- [
1195
- f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1196
- f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1197
- f"Language model: {num_params(self.lang_model):,} parameters",
1198
- ]
1199
- )
1200
-
1201
- @property
1202
- def num_trainable_params_per_module(self):
1203
- """Print the number of trainable parameters per module in the model"""
1204
- return "\n".join(
1205
- [
1206
- f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1207
- f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1208
- f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1209
- ]
1210
- )
1211
-
1212
-
1213
class XGenMMPerceiver(VLMWithLanguageStream):
    """``VLMWithLanguageStream`` variant whose ``generate`` supports any-resolution images.

    On top of the base class it stores three image-handling options
    (``image_aspect_ratio``, ``anyres_patch_sampling``, ``anyres_grids``) and
    overrides ``generate`` so that pixel inputs are encoded, passed through the
    vision tokenizer, regrouped per sample/image and handed to the language
    model together with the text tokens.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        image_aspect_ratio: str = 'anyres',
        anyres_patch_sampling: bool = True,
        anyres_grids: Optional[List[List[int]]] = None,
    ):
        """
        Args:
            vision_encoder (nn.Module): vision backbone; kept frozen by ``set_trainable``.
            vision_tokenizer (nn.Module): module called as
                ``vision_tokenizer(vision_features, vision_attn_masks)`` to turn
                encoder features into vision tokens.
            lang_model (nn.Module): HF causal language model; its ``generate`` is used.
            initial_tokenizer_len (int): size of the tokenizer vocab.
            pad_token_id (int): id of the padding token. None if no padding token; then a
                padding token will be inserted into self.special_tokens, which factory.py
                fills after creating new tokens.
            decoder_layers_attr_name (str, optional): name of the decoder layers attribute.
                Defaults to None.
            gradient_checkpointing (bool, optional): whether to use gradient checkpointing.
                Defaults to False.
            image_aspect_ratio (str, optional): ``'anyres'`` selects the any-resolution
                encoding path in ``generate``; any other value uses the plain encoder path.
            anyres_patch_sampling (bool, optional): whether ``generate`` flattens the
                per-image patches into one batch before the vision tokenizer and regroups
                them afterwards. Defaults to True.
            anyres_grids (list of [int, int], optional): stored on the instance; not read
                by any method of this class. Defaults to None.
        """
        # Must be set before super().__init__(): the original code assigns it first,
        # presumably because the base constructor reads it -- TODO confirm.
        self._special_tokens = {
            "media_token": "<image>",
            "image_placeholder_token": "<image placeholder>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )
        self.image_aspect_ratio = image_aspect_ratio
        self.anyres_patch_sampling = anyres_patch_sampling
        self.anyres_grids = anyres_grids

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def _should_apply_weight_decay(self, parameter_name):
        """
        Kosmos applies 0.01 weight decay to everything
        """
        return True

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        image_size: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x: Vision input, or None for text-only generation.
                Per the original documentation a tensor of shape
                (B, T_img, F, C, H, W); when ``anyres_patch_sampling`` is on it
                must be a ``list`` (asserted below) whose elements may
                themselves be lists for multi-image samples.
            lang_x (torch.Tensor): Language input, shape (B, T_txt).
            image_size (tuple, optional): forwarded to the any-resolution
                encoder together with ``vision_x``.
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            past_key_values: cached key/values; forwarded to the language
                model's ``generate`` only when not None.
            past_media_locations (torch.Tensor, optional): forwarded to
                ``_prepare_inputs_for_forward``.
            past_vision_tokens (torch.Tensor, optional): forwarded to
                ``_prepare_inputs_for_forward``.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
                ``num_beams`` is popped from here (default 1).

        Returns:
            Whatever ``self.lang_model.generate`` returns (per the original
            documentation: lang_x with generated tokens appended to it).
        """
        num_beams = kwargs.pop("num_beams", 1)

        # NOTE(review): this stays None for the whole call; the masks computed
        # below are only given to the vision tokenizer -- confirm intentional.
        vision_attention_mask = None

        # convert pixels to vision tokens
        if vision_x is not None:
            if self.image_aspect_ratio == 'anyres':
                input_dict = dict(image=vision_x, image_size=image_size)
                vision_features, vision_attn_masks = self._encode_vision_x_anyres(
                    input_dict, lang_x.device
                )
            else:
                # NOTE(review): with anyres_patch_sampling=True this branch leaves
                # vision_attn_masks as None and the torch.cat below would fail --
                # confirm the two options are never combined that way.
                vision_features = self._encode_vision_x(vision_x=vision_x)
                vision_attn_masks = None

            # If doing patch sampling, then flatten patches of shape [b, Np_i, v, d] -> [b*Np, v, d]
            # Same for attention masks: [b, Np, v] -> [b*Np, v]
            if self.anyres_patch_sampling:
                split_sizes = [feature.shape[0] for feature in vision_features]
                multi_image = isinstance(vision_x[0], list)
                if multi_image:
                    # Nested splits for multi-image samples: one list of
                    # per-image patch counts for every sample.
                    split_split_sizes = []
                    img_id = 0
                    for images in vision_x:
                        nt = len(images)
                        split_split_sizes.append(split_sizes[img_id:img_id + nt])
                        img_id += nt
                vision_features = torch.cat(vision_features, dim=0)
                vision_features = vision_features[:, None, None, :, :]  # Expand dimensions.
                vision_attn_masks = torch.cat(vision_attn_masks, dim=0)

            vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)

            # Post-processing: Split the batches into groups of patches and concatenate them together.
            if self.anyres_patch_sampling:
                assert isinstance(vision_x, list)
                if multi_image:
                    patches_per_sample = [sum(sizes) for sizes in split_split_sizes]
                    vision_token_groups = torch.split(vision_tokens, patches_per_sample, dim=0)
                    vision_tokens = []
                    for image_sizes, patch_vis_tokens in zip(split_split_sizes, vision_token_groups):
                        # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
                        per_image_groups = torch.split(patch_vis_tokens, image_sizes, dim=0)
                        # [Np, 1, v, d] -> [Np*v, d] for every image of the sample
                        vision_tokens.append(
                            [image_vis_token.flatten(0, 2) for image_vis_token in per_image_groups]
                        )
                else:
                    vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
                    # [Np, 1, v, d] -> [Np*v, d], then add the nt dimension.
                    vision_tokens = [
                        patch_vis_tokens.flatten(0, 2).unsqueeze(0)
                        for patch_vis_tokens in vision_token_groups
                    ]
        else:
            vision_tokens = None

        # fuse the vision and language tokens
        # for xattn, vision_x and media_location are repeat_interleaved s.t.
        # the total batch size is B * num_beams
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )

        # Only pass the cache through when the caller supplied one.
        cache_kwargs = {}
        if past_key_values is not None:
            cache_kwargs["past_key_values"] = past_key_values
        output = self.lang_model.generate(
            **new_inputs,
            **cache_kwargs,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output