VictorSanh commited on
Commit
9505bbc
·
1 Parent(s): a1abacc
Files changed (3) hide show
  1. configuration_img2html.py +310 -0
  2. modeling_img2html.py +1772 -0
  3. vision.py +1361 -0
configuration_img2html.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Img2HTML model configuration"""
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
+ "HuggingFaceM4/Img2HTML": "https://huggingface.co/HuggingFaceM4/Img2HTML/resolve/main/config.json",
24
+ }
25
+
26
+
27
+ class VMistralVisionConfig(PretrainedConfig):
28
+ r"""
29
+ """
30
+ model_type = "vmistral"
31
+
32
+ def __init__(
33
+ self,
34
+ hidden_size=768,
35
+ intermediate_size=3072,
36
+ projection_dim=512,
37
+ num_hidden_layers=12,
38
+ num_attention_heads=12,
39
+ num_channels=3,
40
+ image_size=224,
41
+ patch_size=32,
42
+ hidden_act="gelu_pytorch_tanh",
43
+ layer_norm_eps=1e-6,
44
+ attention_dropout=0.0,
45
+ initializer_range=0.02,
46
+ initializer_factor=1.0,
47
+ _flash_attn_2_enabled=True,
48
+ **kwargs,
49
+ ):
50
+ super().__init__(**kwargs)
51
+
52
+ self.hidden_size = hidden_size
53
+ self.intermediate_size = intermediate_size
54
+ self.projection_dim = projection_dim
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_attention_heads = num_attention_heads
57
+ self.num_channels = num_channels
58
+ self.patch_size = patch_size
59
+ self.image_size = image_size
60
+ self.initializer_range = initializer_range
61
+ self.initializer_factor = initializer_factor
62
+ self.attention_dropout = attention_dropout
63
+ self.layer_norm_eps = layer_norm_eps
64
+ self.hidden_act = hidden_act
65
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
66
+
67
+
68
+ class VMistralPerceiverConfig(PretrainedConfig):
69
+ r"""
70
+ TThis is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
71
+ Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
72
+ with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
73
+
74
+ [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
75
+ [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
76
+
77
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
78
+ documentation from [`PretrainedConfig`] for more information.
79
+
80
+ Args:
81
+ use_resampler (`bool`, *optional*, defaults to `False`):
82
+ Whether or not to use the resampler
83
+ resampler_n_latents (`int`, *optional*, defaults to ):
84
+ Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
85
+ resampler_depth (`int`, *optional*, defaults to 6):
86
+ Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
87
+ resampler_n_heads (`int`, *optional*, defaults to 16):
88
+ Number of heads in each Transformer block (for multi-headed self-attention).
89
+ resampler_head_dim (`int`, *optional*, defaults to 96):
90
+ Dimensionality of each head projection in the Transformer block.
91
+ qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):
92
+ Whether or not to use qk layer norms in perceiver
93
+ """
94
+ model_type = "vmistral"
95
+
96
+ def __init__(
97
+ self,
98
+ resampler_n_latents=64,
99
+ resampler_depth=6,
100
+ resampler_n_heads=16,
101
+ resampler_head_dim=96,
102
+ qk_layer_norms_perceiver=False,
103
+ **kwargs,
104
+ ):
105
+ self.resampler_n_latents = resampler_n_latents
106
+ self.resampler_depth = resampler_depth
107
+ self.resampler_n_heads = resampler_n_heads
108
+ self.resampler_head_dim = resampler_head_dim
109
+ self.qk_layer_norms_perceiver = qk_layer_norms_perceiver
110
+
111
+ super().__init__(**kwargs)
112
+
113
+
114
+ class VMistralConfig(PretrainedConfig):
115
+ r"""
116
+ This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
117
+ Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
118
+ with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
119
+
120
+ [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
121
+ [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
122
+
123
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
124
+ documentation from [`PretrainedConfig`] for more information.
125
+
126
+ Args:
127
+ additional_vocab_size (`int`, *optional`, defaults to 0):
128
+ Additional vocabulary size of the model, typically for the special "<img>" token. Additional vocab tokens
129
+ are always trainable whereas regular vocab tokens can be frozen or not.
130
+ vocab_size (`int`, *optional*, defaults to 32000):
131
+ Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
132
+ `inputs_ids` passed when calling [`MistralModel`]
133
+ hidden_size (`int`, *optional*, defaults to 4096):
134
+ Dimension of the hidden representations.
135
+ intermediate_size (`int`, *optional*, defaults to 14336):
136
+ Dimension of the MLP representations.
137
+ num_hidden_layers (`int`, *optional*, defaults to 32):
138
+ Number of hidden layers in the Transformer encoder.
139
+ num_attention_heads (`int`, *optional*, defaults to 32):
140
+ Number of attention heads for each attention layer in the Transformer encoder.
141
+ num_key_value_heads (`int`, *optional*, defaults to 8):
142
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
143
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
144
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
145
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
146
+ by meanpooling all the original heads within that group. For more details checkout [this
147
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
148
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
149
+ The non-linear activation function (function or string) in the decoder.
150
+ max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
151
+ The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
152
+ allows sequence of up to 4096*32 tokens.
153
+ initializer_range (`float`, *optional*, defaults to 0.02):
154
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
155
+ alpha_initializer (`str`, *optional*, defaults to `"zeros"`):
156
+ Initialization type for the alphas.
157
+ alphas_initializer_range (`float`, *optional*, defaults to 0.0):
158
+ The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross
159
+ Attention.
160
+ alpha_type (`str`, *optional*, defaults to `"float"`):
161
+ Whether the gating alphas should be vectors or single floats.
162
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
163
+ The epsilon used by the rms normalization layers.
164
+ use_cache (`bool`, *optional*, defaults to `True`):
165
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
166
+ relevant if `config.is_decoder=True`.
167
+ pad_token_id (`int`, *optional*):
168
+ The id of the padding token.
169
+ bos_token_id (`int`, *optional*, defaults to 1):
170
+ The id of the "beginning-of-sequence" token.
171
+ eos_token_id (`int`, *optional*, defaults to 2):
172
+ The id of the "end-of-sequence" token.
173
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
174
+ Whether the model's input and output word embeddings should be tied.
175
+ rope_theta (`float`, *optional*, defaults to 10000.0):
176
+ The base period of the RoPE embeddings.
177
+ sliding_window (`int`, *optional*, defaults to 4096):
178
+ Sliding window attention window size. If not specified, will default to `4096`.
179
+ cross_layer_interval (`int`, *optional*, default to 1)
180
+ Interval for cross attention (from text to image) layers.
181
+ qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k
182
+ freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers
183
+ freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):
184
+ Exceptions to freezing text layers when `freeze_text_layers` is `True`
185
+ freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head
186
+ freeze_vision_layers (`bool`, *optional*, defaults to `True`): Whether to freeze vision layers
187
+ freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):
188
+ Exceptions to freezing vision layers when `freeze_vision_layers` is `True`
189
+ use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler
190
+ vision_config (`IdeficsVisionConfig`, *optional*): Custom vision config or dict
191
+ perceiver_config (`IdeficsPerceiverConfig`, *optional*): Custom perceiver config or dict
192
+
193
+ Example:
194
+ ```python
195
+ >>> from transformers import MistralModel, MistralConfig
196
+
197
+ >>> # Initializing a Mistral 7B style configuration
198
+ >>> configuration = MistralConfig()
199
+
200
+ >>> # Initializing a model from the Mistral 7B style configuration
201
+ >>> model = MistralModel(configuration)
202
+
203
+ >>> # Accessing the model configuration
204
+ >>> configuration = model.config
205
+ ```"""
206
+ model_type = "vmistral"
207
+ is_composition = False
208
+
209
+ def __init__(
210
+ self,
211
+ additional_vocab_size=0,
212
+ vocab_size=32000,
213
+ hidden_size=4096,
214
+ intermediate_size=14336,
215
+ num_hidden_layers=32,
216
+ num_attention_heads=32,
217
+ num_key_value_heads=8,
218
+ hidden_act="silu",
219
+ max_position_embeddings=4096 * 32,
220
+ initializer_range=0.02,
221
+ alpha_initializer="zeros",
222
+ alphas_initializer_range=0.0,
223
+ alpha_type="float",
224
+ rms_norm_eps=1e-6,
225
+ use_cache=True,
226
+ pad_token_id=0, # None in the original configuration_mistral, we set it to the unk_token_id
227
+ bos_token_id=1,
228
+ eos_token_id=2,
229
+ image_token_id=32_001,
230
+ tie_word_embeddings=False,
231
+ rope_theta=10000.0,
232
+ sliding_window=4096,
233
+ cross_layer_interval=1,
234
+ qk_layer_norms=False,
235
+ freeze_text_layers=True,
236
+ freeze_text_module_exceptions=[],
237
+ freeze_lm_head=False,
238
+ freeze_vision_layers=True,
239
+ freeze_vision_module_exceptions=[],
240
+ attention_dropout=0.0,
241
+ _flash_attn_2_enabled=True,
242
+ use_resampler=False,
243
+ vision_config=None,
244
+ perceiver_config=None,
245
+ **kwargs,
246
+ ):
247
+ self.vocab_size = vocab_size
248
+ self.additional_vocab_size = additional_vocab_size
249
+ self.image_token_id = image_token_id
250
+ self.max_position_embeddings = max_position_embeddings
251
+ self.hidden_size = hidden_size
252
+ self.intermediate_size = intermediate_size
253
+ self.num_hidden_layers = num_hidden_layers
254
+ self.num_attention_heads = num_attention_heads
255
+ self.sliding_window = sliding_window
256
+
257
+ # for backward compatibility
258
+ if num_key_value_heads is None:
259
+ num_key_value_heads = num_attention_heads
260
+
261
+ self.num_key_value_heads = num_key_value_heads
262
+ self.hidden_act = hidden_act
263
+ self.initializer_range = initializer_range
264
+ self.alpha_initializer = alpha_initializer
265
+ self.alphas_initializer_range = alphas_initializer_range
266
+ self.alpha_type = alpha_type
267
+ self.rms_norm_eps = rms_norm_eps
268
+ self.use_cache = use_cache
269
+ self.rope_theta = rope_theta
270
+
271
+ self.cross_layer_interval = cross_layer_interval
272
+ self.qk_layer_norms = qk_layer_norms
273
+ self.freeze_vision_layers = freeze_vision_layers
274
+
275
+ self.freeze_text_layers = freeze_text_layers
276
+ self.freeze_text_module_exceptions = freeze_text_module_exceptions
277
+ self.freeze_vision_module_exceptions = freeze_vision_module_exceptions
278
+ self.freeze_lm_head = freeze_lm_head
279
+
280
+ self.use_resampler = use_resampler
281
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
282
+ self.attention_dropout = attention_dropout
283
+
284
+ if perceiver_config is None:
285
+ self.perceiver_config = VMistralPerceiverConfig()
286
+ elif isinstance(perceiver_config, dict):
287
+ self.perceiver_config = VMistralPerceiverConfig(**perceiver_config)
288
+ elif isinstance(perceiver_config, VMistralPerceiverConfig):
289
+ self.perceiver_config = perceiver_config
290
+
291
+ if vision_config is None:
292
+ self.vision_config = VMistralVisionConfig()
293
+ elif isinstance(vision_config, dict):
294
+ self.vision_config = VMistralVisionConfig(**vision_config)
295
+ elif isinstance(vision_config, VMistralVisionConfig):
296
+ self.vision_config = vision_config
297
+
298
+ super().__init__(
299
+ pad_token_id=pad_token_id,
300
+ bos_token_id=bos_token_id,
301
+ eos_token_id=eos_token_id,
302
+ tie_word_embeddings=tie_word_embeddings,
303
+ **kwargs,
304
+ )
305
+
306
+ # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since
307
+ # PretrainedConfig.from_dict first instantiates the class with the config dict and only then
308
+ # updates the config object with `kwargs` from from_pretrained, so during the instantiation
309
+ # of this object many attributes have default values and haven't yet been overridden.
310
+ # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.
modeling_img2html.py ADDED
@@ -0,0 +1,1772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch Mistral model."""
21
+ from dataclasses import dataclass
22
+ import inspect
23
+ import math
24
+ import warnings
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import CrossEntropyLoss
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
34
+ from transformers.utils import (
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ is_flash_attn_2_available,
38
+ replace_return_docstrings,
39
+ )
40
+
41
+ from einops import rearrange, repeat
42
+ from transformers import PreTrainedModel
43
+ from transformers.utils import logging
44
+ from transformers.modeling_outputs import ModelOutput
45
+
46
+ from .configuration_img2html import VMistralConfig
47
+ from .vision import SiglipVisionModel
48
+
49
+
50
+ if is_flash_attn_2_available():
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
53
+
54
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CONFIG_FOR_DOC = "VMistralConfig"
59
+
60
+ IMG2HTML_PRETRAINED_MODEL_ARCHIVE_LIST = [
61
+ "HuggingFaceM4/Img2HTML"
62
+ ]
63
+
64
+ @dataclass
65
+ class Img2HTMLBaseModelOutputWithPast(ModelOutput):
66
+ """
67
+ Base class for Img2HTML model's outputs that may also contain a past key/values (to speed up sequential decoding).
68
+
69
+ Args:
70
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
71
+ Sequence of hidden-states at the output of the last layer of the model.
72
+
73
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
74
+ hidden_size)` is output.
75
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
76
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
77
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
78
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
79
+ encoder_sequence_length, embed_size_per_head)`.
80
+
81
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
82
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
83
+ input) to speed up sequential decoding.
84
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
85
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
86
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
87
+
88
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
89
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
90
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
91
+ sequence_length)`.
92
+
93
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
94
+ heads.
95
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
96
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
97
+ sequence_length, hidden_size)`.
98
+
99
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
100
+ """
101
+
102
+ last_hidden_state: torch.FloatTensor = None
103
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
104
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
105
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
106
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
107
+
108
+
109
+ @dataclass
110
+ class Img2HTMLCausalLMOutputWithPast(ModelOutput):
111
+ """
112
+ Base class for Idefics causal language model (or autoregressive) outputs.
113
+
114
+ Args:
115
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
116
+ Language modeling loss (for next-token prediction).
117
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
118
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
119
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
120
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
121
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
122
+
123
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
124
+ `past_key_values` input) to speed up sequential decoding.
125
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
126
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
127
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
128
+
129
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
130
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
131
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
132
+ sequence_length)`.
133
+
134
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
135
+ heads.
136
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
137
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
138
+ sequence_length, hidden_size)`.
139
+
140
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
141
+ """
142
+
143
+ loss: Optional[torch.FloatTensor] = None
144
+ logits: torch.FloatTensor = None
145
+ past_key_values: Optional[List[torch.FloatTensor]] = None
146
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
147
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
148
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
149
+
150
+
151
+ def expand_inputs_for_generation(
152
+ input_ids,
153
+ expand_size=1,
154
+ is_encoder_decoder=False,
155
+ attention_mask=None,
156
+ encoder_outputs=None,
157
+ **model_kwargs,
158
+ ):
159
+ expanded_return_idx = (
160
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
161
+ )
162
+ input_ids = input_ids.index_select(0, expanded_return_idx)
163
+ model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
164
+ model_kwargs["image_hidden_states"] = model_kwargs.get("image_hidden_states", None)
165
+ model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
166
+
167
+ if "token_type_ids" in model_kwargs:
168
+ token_type_ids = model_kwargs["token_type_ids"]
169
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
170
+
171
+ if attention_mask is not None:
172
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
173
+
174
+ if model_kwargs["image_attention_mask"] is not None:
175
+ model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
176
+ 0, expanded_return_idx
177
+ )
178
+
179
+ if model_kwargs["pixel_values"] is not None:
180
+ model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
181
+
182
+ elif model_kwargs["image_hidden_states"] is not None:
183
+ model_kwargs["image_hidden_states"] = model_kwargs["image_hidden_states"].index_select(
184
+ 0, expanded_return_idx
185
+ )
186
+
187
+ return input_ids, model_kwargs
188
+
189
+
190
+ def update_model_kwargs_for_generation(outputs, model_kwargs):
191
+ # must have this key set to at least None
192
+ if "past_key_values" in outputs:
193
+ model_kwargs["past_key_values"] = outputs.past_key_values
194
+ else:
195
+ model_kwargs["past_key_values"] = None
196
+
197
+ # update token_type_ids with last value
198
+ if "token_type_ids" in model_kwargs:
199
+ token_type_ids = model_kwargs["token_type_ids"]
200
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
201
+
202
+ # update attention masks
203
+ if "attention_mask" in model_kwargs:
204
+ attention_mask = model_kwargs["attention_mask"]
205
+ model_kwargs["attention_mask"] = torch.cat(
206
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
207
+ )
208
+ if "image_attention_mask" in model_kwargs:
209
+ image_attention_mask = model_kwargs["image_attention_mask"]
210
+ last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
211
+ model_kwargs["image_attention_mask"] = last_mask
212
+
213
+ # Get the precomputed image_hidden_states
214
+ model_kwargs["image_hidden_states"] = outputs.image_hidden_states
215
+
216
+ return model_kwargs
217
+
218
+
219
+ def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
220
+ token_type_ids = kwargs.get("token_type_ids", None)
221
+ # only last token for inputs_ids if past is defined in kwargs
222
+ if past_key_values:
223
+ input_ids = input_ids[:, -1].unsqueeze(-1)
224
+ if token_type_ids is not None:
225
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
226
+
227
+ attention_mask = kwargs.get("attention_mask", None)
228
+ position_ids = kwargs.get("position_ids", None)
229
+
230
+ if attention_mask is not None and position_ids is None:
231
+ # create position_ids on the fly for batch generation
232
+ position_ids = attention_mask.long().cumsum(-1) - 1
233
+ position_ids.masked_fill_(attention_mask == 0, 1)
234
+ if past_key_values:
235
+ position_ids = position_ids[:, -1].unsqueeze(-1)
236
+
237
+ pixel_values = kwargs.get("pixel_values", None)
238
+ image_hidden_states = kwargs.get("image_hidden_states", None)
239
+ image_attention_mask = kwargs.get("image_attention_mask", None)
240
+
241
+ return {
242
+ "input_ids": input_ids,
243
+ "past_key_values": past_key_values,
244
+ "use_cache": kwargs.get("use_cache"),
245
+ "position_ids": position_ids,
246
+ "attention_mask": attention_mask,
247
+ "token_type_ids": token_type_ids,
248
+ "pixel_values": pixel_values,
249
+ "image_hidden_states": image_hidden_states,
250
+ "image_attention_mask": image_attention_mask,
251
+ }
252
+
253
+
254
+ def freeze_model(model, module_exceptions=[]):
255
+ mapping = {
256
+ "LayerNorm": nn.LayerNorm,
257
+ "Linear": nn.Linear,
258
+ "Embedding": nn.Embedding,
259
+ }
260
+ module_exceptions_mapped = [mapping[m] for m in module_exceptions]
261
+ for module in model.modules():
262
+ if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
263
+ module.requires_grad_(True) # Explicitly setting it to true to avoid any mistakes
264
+ else:
265
+ module.requires_grad_(False)
266
+ return model
267
+
268
+
269
+ class DecoupledEmbedding(nn.Embedding):
270
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
271
+ """
272
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings.
273
+ In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, then it will create `num_additional_embeddings` additional parameters that are always trained.
274
+ If `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
275
+ """
276
+
277
+ def __init__(
278
+ self,
279
+ num_embeddings,
280
+ num_additional_embeddings,
281
+ embedding_dim,
282
+ partially_freeze=False,
283
+ device=None,
284
+ dtype=None,
285
+ padding_idx=None,
286
+ **kwargs,
287
+ ) -> None:
288
+ """
289
+ num_additional_embeddings: int. Number of additional embeddings. Only useful when you `partially_freeze=True`.
290
+ partially_freeze: bool. If True, the regular `weight` will be frozen. `additional_weight` is never frozen.
291
+
292
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, `max_norm` or `norm_type`. We are not supporting these.
293
+ """
294
+ if padding_idx is not None and padding_idx > num_embeddings:
295
+ raise ValueError(f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}")
296
+ super().__init__(
297
+ num_embeddings=num_embeddings,
298
+ embedding_dim=embedding_dim,
299
+ device=device,
300
+ dtype=dtype,
301
+ padding_idx=padding_idx,
302
+ **kwargs,
303
+ )
304
+ self.num_embeddings = num_embeddings
305
+ self.padding_idx = padding_idx
306
+ self.num_additional_embeddings = num_additional_embeddings
307
+ self.partially_freeze = partially_freeze
308
+
309
+ if partially_freeze:
310
+ self.weight.requires_grad_(False)
311
+
312
+ if self.num_additional_embeddings > 0:
313
+ self.additional_embedding = nn.Embedding(
314
+ num_embeddings=self.num_additional_embeddings,
315
+ embedding_dim=embedding_dim,
316
+ device=device,
317
+ dtype=dtype,
318
+ )
319
+
320
+ def forward(self, input_ids):
321
+ """
322
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
323
+ self.additional_embedding.weight that is being trained.
324
+
325
+ in order to make a lookup of the input ids, we:
326
+ 1. find out the indices of the entries belonging to the 2nd embedding
327
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings),
328
+ since the 2nd embedding starts from 0 and not num_embeddings
329
+ 3. perform the 2nd embedding lookup
330
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
331
+ 5. perform the 1st embedding lookup
332
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
333
+
334
+ note: for the 1st embedding lookup we could have looked up only the low indices and not do
335
+ the padding, but then we have to create a new tensor and populate it with 2 tensors that are
336
+ spread out across various indices - i.e. not a simple concat - I haven't benchmarked the
337
+ complex case if it's any faster, given that seqlens are usually relatively short it's
338
+ probably not faster or if faster not by much - but might be a good idea to measure.
339
+
340
+ """
341
+ if self.num_additional_embeddings == 0:
342
+ return self.additional_embedding(input_ids)
343
+
344
+ # Clone so that we don't modify the original input_ids later on
345
+ input_ids = input_ids.clone()
346
+ additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
347
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
348
+ additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
349
+
350
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
351
+ input_ids[additional_vocab_indices] = 0
352
+ full_vector = F.embedding(input_ids, self.weight)
353
+
354
+ # overwrite the records with high indices
355
+ full_vector[additional_vocab_indices] = additional_embeddings
356
+
357
+ return full_vector
358
+
359
+ def extra_repr(self) -> str:
360
+ return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
361
+ self.num_embeddings,
362
+ self.num_additional_embeddings,
363
+ self.embedding_dim,
364
+ self.partially_freeze,
365
+ )
366
+
367
+ class DecoupledLinear(nn.Linear):
368
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
369
+ """
370
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
371
+ In practise, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, then it will create `out_additional_features * in_features` additional parameters that are always trained.
372
+ If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
373
+ """
374
+
375
+ def __init__(
376
+ self,
377
+ in_features: int,
378
+ out_features: int,
379
+ out_additional_features: int = 0,
380
+ bias: bool = True,
381
+ partially_freeze: bool = True,
382
+ device=None,
383
+ dtype=None,
384
+ ) -> None:
385
+ """
386
+ out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
387
+ partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
388
+ """
389
+ super().__init__(in_features, out_features, bias, device, dtype)
390
+ self.out_additional_features = out_additional_features
391
+ self.partially_freeze = partially_freeze
392
+
393
+ self.in_features = in_features
394
+ self.out_features = out_features
395
+
396
+ if partially_freeze:
397
+ self.weight.requires_grad_(False)
398
+ if bias:
399
+ self.bias.requires_grad_(False)
400
+
401
+ if out_additional_features > 0:
402
+ self.additional_fc = nn.Linear(
403
+ in_features=in_features,
404
+ out_features=out_additional_features,
405
+ bias=bias,
406
+ device=device,
407
+ dtype=dtype,
408
+ )
409
+
410
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
411
+ output = F.linear(input, self.weight, self.bias)
412
+
413
+ if self.out_additional_features > 0:
414
+ additional_features = self.additional_fc(input)
415
+ output = torch.cat((output, additional_features), -1)
416
+
417
+ return output
418
+
419
+ def extra_repr(self) -> str:
420
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
421
+ return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
422
+ self.in_features,
423
+ self.out_features,
424
+ self.out_additional_features,
425
+ self.bias is not None,
426
+ self.partially_freeze,
427
+ )
428
+
429
+ class SwiGLU(nn.Module):
430
+ def __init__(self, embed_dim) -> None:
431
+ super().__init__()
432
+ self.fc1 = nn.Linear(embed_dim, embed_dim, bias=False)
433
+ self.fc2 = nn.Linear(embed_dim, embed_dim, bias=False)
434
+
435
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
436
+ x_1 = self.fc1(x)
437
+ x_1 = torch.mul(x_1, torch.sigmoid(x_1))
438
+ x_2 = self.fc2(x)
439
+ x = torch.mul(x_1, x_2)
440
+ return x
441
+
442
+
443
+ class ModalityProjection(nn.Module):
444
+ def __init__(self, embed_dim_in, embed_dim_out) -> None:
445
+ super().__init__()
446
+ self.fc1 = nn.Linear(embed_dim_in, embed_dim_out, bias=False)
447
+ self.act = SwiGLU(embed_dim_out)
448
+ self.fc2 = nn.Linear(embed_dim_out, embed_dim_out, bias=False)
449
+
450
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
451
+ x = self.fc1(x)
452
+ x = self.act(x)
453
+ x = self.fc2(x)
454
+ return x
455
+
456
+
457
+ class PerceiverResampler(nn.Module):
458
+ def __init__(
459
+ self, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, qk_layer_norms: bool
460
+ ) -> None:
461
+ """
462
+ Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
463
+ MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
464
+ returns a Tensor of shape [bsz, n_latents, embed_dim].
465
+ :param embed_dim: Dimensionality of embeddings being fed to the Perceiver Resampler (also dimensionality of
466
+ latent embeddings *returned* by the Perceiver Resampler. Could be e.g., VIT embed_dim, ResNet
467
+ pool dim, and so on.
468
+ :param depth: Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
469
+ :param n_heads: Number of heads in each Transformer block (for multi-headed self-attention).
470
+ :param head_dim: Dimensionality of each head projection in the Transformer block.
471
+ :param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
472
+ """
473
+ super().__init__()
474
+ self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
475
+ self.qk_layer_norms = qk_layer_norms
476
+
477
+ # Create Latents for Perceiver
478
+ self.latents = nn.Parameter(torch.ones(self.n_latents, self.embed_dim))
479
+
480
+ self.intermediate_dim = self.embed_dim * 4
481
+ # Create Transformer Blocks
482
+ self.blocks = nn.ModuleList(
483
+ [
484
+ nn.ModuleList(
485
+ [
486
+ PerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms),
487
+ MLP(self.embed_dim, self.intermediate_dim),
488
+ ]
489
+ )
490
+ for _ in range(depth)
491
+ ]
492
+ )
493
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
494
+
495
+ def forward(self, context: torch.Tensor) -> torch.Tensor:
496
+ """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
497
+ latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
498
+
499
+ # Feed through Perceiver Attention blocks...
500
+ for attn, ff in self.blocks:
501
+ latents = attn(context, latents) + latents
502
+ latents = ff(latents) + latents
503
+
504
+ return self.layer_norm(latents)
505
+
506
+
507
+ class PerceiverAttention(nn.Module):
508
+ def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None:
509
+ """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
510
+ super().__init__()
511
+ self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
512
+ self.qk_layer_norms = qk_layer_norms
513
+ # Normalization & Scaling
514
+ self.context_layer_norm = nn.LayerNorm(self.embed_dim)
515
+ self.latents_layer_norm = nn.LayerNorm(self.embed_dim)
516
+ if self.qk_layer_norms:
517
+ self.q_layer_norm = nn.LayerNorm(self.head_dim)
518
+ self.k_layer_norm = nn.LayerNorm(self.head_dim)
519
+
520
+ self.qk_scale = self.head_dim**-0.5
521
+
522
+ # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
523
+ self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
524
+ self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
525
+ self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
526
+
527
+ self.output_proj = nn.Linear(self.n_heads * self.head_dim, self.embed_dim, bias=False)
528
+
529
+ def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
530
+ """
531
+ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
532
+ :param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
533
+ :param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
534
+ :return: Tensor of shape [bsz, n_latents, embed_dim] representing attention over latents w/ cross from context.
535
+ """
536
+ context = self.context_layer_norm(context)
537
+ latents = self.latents_layer_norm(latents)
538
+
539
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
540
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
541
+ q = self.q_proj(latents)
542
+ k = self.k_proj(torch.cat([context, latents], dim=-2))
543
+ v = self.v_proj(torch.cat([context, latents], dim=-2))
544
+
545
+ # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
546
+ # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
547
+ q, k, v = [rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) for x in (q, k, v)]
548
+ if self.qk_layer_norms:
549
+ q = self.q_layer_norm(q)
550
+ k = self.k_layer_norm(k)
551
+
552
+ scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
553
+ stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
554
+ attn = stabilized_scores.softmax(dim=-1)
555
+
556
+ # Attend & project back to output...
557
+ resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
558
+ return self.output_proj(
559
+ rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
560
+ )
561
+
562
+
563
+ class MLP(nn.Module):
564
+ def __init__(self, embed_dim, intermediate_size):
565
+ """Simple MLP block with intermediate_size and embedding size"""
566
+ super().__init__()
567
+ self.embed_dim = embed_dim
568
+ self.ln = nn.LayerNorm(self.embed_dim)
569
+ self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False)
570
+ self.act = nn.ReLU()
571
+ self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False)
572
+
573
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
574
+ hidden_states = self.ln(hidden_states)
575
+ hidden_states = self.fc(hidden_states)
576
+ hidden_states = self.act(hidden_states)
577
+ hidden_states = self.c_proj(hidden_states)
578
+
579
+ return hidden_states
580
+
581
+
582
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
583
+ def _get_unpad_data(attention_mask):
584
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
585
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
586
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
587
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
588
+ return (
589
+ indices,
590
+ cu_seqlens,
591
+ max_seqlen_in_batch,
592
+ )
593
+
594
+
595
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
596
+ class MistralRMSNorm(nn.Module):
597
+ def __init__(self, hidden_size, eps=1e-6):
598
+ """
599
+ MistralRMSNorm is equivalent to T5LayerNorm
600
+ """
601
+ super().__init__()
602
+ self.weight = nn.Parameter(torch.ones(hidden_size))
603
+ self.variance_epsilon = eps
604
+
605
+ def forward(self, hidden_states):
606
+ input_dtype = hidden_states.dtype
607
+ hidden_states = hidden_states.to(torch.float32)
608
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
609
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
610
+ return self.weight * hidden_states.to(input_dtype)
611
+
612
+
613
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
614
+ class MistralRotaryEmbedding(nn.Module):
615
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
616
+ super().__init__()
617
+
618
+ self.dim = dim
619
+ self.max_position_embeddings = max_position_embeddings
620
+ self.base = base
621
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
622
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
623
+
624
+ # Build here to make `torch.jit.trace` work.
625
+ self._set_cos_sin_cache(
626
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
627
+ )
628
+
629
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
630
+ self.max_seq_len_cached = seq_len
631
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
632
+
633
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
634
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
635
+ emb = torch.cat((freqs, freqs), dim=-1)
636
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
637
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
638
+
639
+ def forward(self, x, seq_len=None):
640
+ # x: [bs, num_attention_heads, seq_len, head_size]
641
+ if seq_len > self.max_seq_len_cached:
642
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
643
+
644
+ return (
645
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
646
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
647
+ )
648
+
649
+
650
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
651
+ def rotate_half(x):
652
+ """Rotates half the hidden dims of the input."""
653
+ x1 = x[..., : x.shape[-1] // 2]
654
+ x2 = x[..., x.shape[-1] // 2 :]
655
+ return torch.cat((-x2, x1), dim=-1)
656
+
657
+
658
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
659
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
660
+ cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
661
+ sin = sin[position_ids].unsqueeze(1)
662
+ q_embed = (q * cos) + (rotate_half(q) * sin)
663
+ k_embed = (k * cos) + (rotate_half(k) * sin)
664
+ return q_embed, k_embed
665
+
666
+
667
+ class MistralMLP(nn.Module):
668
+ def __init__(self, config):
669
+ super().__init__()
670
+ self.config = config
671
+ self.hidden_size = config.hidden_size
672
+ self.intermediate_size = config.intermediate_size
673
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
674
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
675
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
676
+ self.act_fn = ACT2FN[config.hidden_act]
677
+
678
+ def forward(self, x):
679
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
680
+
681
+
682
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
683
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
684
+ """
685
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
686
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
687
+ """
688
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
689
+ if n_rep == 1:
690
+ return hidden_states
691
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
692
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
693
+
694
+
695
+ class MistralAttention(nn.Module):
696
+ """
697
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
698
+ and "Generating Long Sequences with Sparse Transformers".
699
+ """
700
+
701
+ def __init__(self, config: VMistralConfig, qk_layer_norms: bool = False):
702
+ super().__init__()
703
+ self.config = config
704
+ self.hidden_size = config.hidden_size
705
+ self.num_heads = config.num_attention_heads
706
+ self.head_dim = self.hidden_size // self.num_heads
707
+ self.num_key_value_heads = config.num_key_value_heads
708
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
709
+ self.max_position_embeddings = config.max_position_embeddings
710
+ self.rope_theta = config.rope_theta
711
+ self.is_causal = True
712
+
713
+ if (self.head_dim * self.num_heads) != self.hidden_size:
714
+ raise ValueError(
715
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
716
+ f" and `num_heads`: {self.num_heads})."
717
+ )
718
+
719
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
720
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
721
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
722
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
723
+
724
+ self.qk_layer_norms = qk_layer_norms
725
+ if self.qk_layer_norms:
726
+ self.q_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
727
+ self.k_layer_norm = MistralRMSNorm(self.head_dim, eps=config.rms_norm_eps)
728
+
729
+ self.rotary_emb = MistralRotaryEmbedding(
730
+ self.head_dim,
731
+ max_position_embeddings=self.max_position_embeddings,
732
+ base=self.rope_theta,
733
+ )
734
+ self.attention_dropout = config.attention_dropout
735
+
736
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
737
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
738
+
739
+ def forward(
740
+ self,
741
+ hidden_states: torch.Tensor,
742
+ key_value_states: Optional[torch.Tensor] = None,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ position_ids: Optional[torch.LongTensor] = None,
745
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
746
+ output_attentions: bool = False,
747
+ use_cache: bool = False,
748
+ **kwargs,
749
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
750
+ if "padding_mask" in kwargs:
751
+ warnings.warn(
752
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
753
+ " `attention_mask` instead.`"
754
+ )
755
+
756
+ bsz, q_len, _ = hidden_states.size()
757
+
758
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
759
+ key_states = (
760
+ self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
761
+ )
762
+ value_states = (
763
+ self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
764
+ )
765
+
766
+ kv_seq_len = key_states.shape[-2]
767
+ if past_key_value is not None:
768
+ kv_seq_len += past_key_value[0].shape[-2]
769
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
770
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
771
+
772
+ if past_key_value is not None:
773
+ # reuse k, v, self_attention
774
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
775
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
776
+
777
+ past_key_value = (key_states, value_states) if use_cache else None
778
+
779
+ if self.qk_layer_norms:
780
+ query_states = self.q_layer_norm(query_states)
781
+ key_states = self.k_layer_norm(key_states)
782
+
783
+ # repeat k/v heads if n_kv_heads < n_heads
784
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
785
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
786
+
787
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
788
+
789
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
790
+ raise ValueError(
791
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
792
+ f" {attn_weights.size()}"
793
+ )
794
+
795
+ if attention_mask is not None:
796
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
797
+ raise ValueError(
798
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
799
+ )
800
+
801
+ attn_weights = attn_weights + attention_mask
802
+
803
+ # upcast attention to fp32
804
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
805
+ attn_output = torch.matmul(attn_weights, value_states)
806
+
807
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
808
+ raise ValueError(
809
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
810
+ f" {attn_output.size()}"
811
+ )
812
+
813
+ attn_output = attn_output.transpose(1, 2).contiguous()
814
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
815
+
816
+ attn_output = self.o_proj(attn_output)
817
+
818
+ if not output_attentions:
819
+ attn_weights = None
820
+
821
+ return attn_output, attn_weights, past_key_value
822
+
823
+
824
+ class MistralFlashAttention2(MistralAttention):
825
+ """
826
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
827
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
828
+ flash attention and deal with padding tokens in case the input contains any of them.
829
+ """
830
+
831
+ def forward(
832
+ self,
833
+ hidden_states: torch.Tensor,
834
+ attention_mask: Optional[torch.Tensor] = None,
835
+ position_ids: Optional[torch.LongTensor] = None,
836
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
837
+ output_attentions: bool = False,
838
+ use_cache: bool = False,
839
+ **kwargs,
840
+ ):
841
+ if "padding_mask" in kwargs:
842
+ warnings.warn(
843
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
844
+ " `attention_mask` instead.`"
845
+ )
846
+
847
+ # overwrite attention_mask with padding_mask
848
+ attention_mask = kwargs.pop("padding_mask")
849
+ bsz, q_len, _ = hidden_states.size()
850
+
851
+ query_states = self.q_proj(hidden_states)
852
+ key_states = self.k_proj(hidden_states)
853
+ value_states = self.v_proj(hidden_states)
854
+
855
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
856
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
857
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
858
+
859
+ kv_seq_len = key_states.shape[-2]
860
+ if past_key_value is not None:
861
+ kv_seq_len += past_key_value[0].shape[-2]
862
+
863
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
864
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
865
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
866
+
867
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
868
+
869
+ use_sliding_windows = (
870
+ _flash_supports_window_size
871
+ and hasattr(self.config, "sliding_window") is not None
872
+ and kv_seq_len > self.config.sliding_window
873
+ )
874
+
875
+ if not _flash_supports_window_size:
876
+ logger.warning_once(
877
+ "The current flash attention version does not support sliding window attention, for a more memory"
878
+ " efficient implementation make sure to upgrade flash-attn library."
879
+ )
880
+
881
+ if past_key_value is not None:
882
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
883
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
884
+ slicing_tokens = kv_seq_len - self.config.sliding_window
885
+
886
+ past_key = past_key_value[0]
887
+ past_value = past_key_value[1]
888
+
889
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
890
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
891
+
892
+ if past_key.shape[-2] != self.config.sliding_window - 1:
893
+ raise ValueError(
894
+ "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
895
+ f" head_dim`), got {past_key.shape}"
896
+ )
897
+
898
+ past_key_value = (past_key, past_value)
899
+
900
+ if attention_mask is not None:
901
+ attention_mask = attention_mask[:, slicing_tokens:]
902
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
903
+
904
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
905
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
906
+
907
+ past_key_value = (key_states, value_states) if use_cache else None
908
+
909
+ # repeat k/v heads if n_kv_heads < n_heads
910
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
911
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
912
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
913
+
914
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
915
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
916
+ # cast them back in float16 just to be sure everything works as expected.
917
+ input_dtype = query_states.dtype
918
+ if input_dtype == torch.float32:
919
+ # Handle the case where the model is quantized
920
+ if hasattr(self.config, "_pre_quantization_dtype"):
921
+ target_dtype = self.config._pre_quantization_dtype
922
+ else:
923
+ target_dtype = self.q_proj.weight.dtype
924
+
925
+ logger.warning_once(
926
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
927
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
928
+ f" {target_dtype}."
929
+ )
930
+
931
+ query_states = query_states.to(target_dtype)
932
+ key_states = key_states.to(target_dtype)
933
+ value_states = value_states.to(target_dtype)
934
+
935
+ # Reashape to the expected shape for Flash Attention
936
+ query_states = query_states.transpose(1, 2)
937
+ key_states = key_states.transpose(1, 2)
938
+ value_states = value_states.transpose(1, 2)
939
+
940
+ attn_output = self._flash_attention_forward(
941
+ query_states,
942
+ key_states,
943
+ value_states,
944
+ attention_mask,
945
+ q_len,
946
+ dropout=dropout_rate,
947
+ use_sliding_windows=use_sliding_windows,
948
+ )
949
+
950
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
951
+ attn_output = self.o_proj(attn_output)
952
+
953
+ if not output_attentions:
954
+ attn_weights = None
955
+
956
+ return attn_output, attn_weights, past_key_value
957
+
958
+ def _flash_attention_forward(
959
+ self,
960
+ query_states,
961
+ key_states,
962
+ value_states,
963
+ attention_mask,
964
+ query_length,
965
+ dropout=0.0,
966
+ softmax_scale=None,
967
+ use_sliding_windows=False,
968
+ ):
969
+ """
970
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
971
+ first unpad the input, then computes the attention scores and pad the final attention scores.
972
+
973
+ Args:
974
+ query_states (`torch.Tensor`):
975
+ Input query states to be passed to Flash Attention API
976
+ key_states (`torch.Tensor`):
977
+ Input key states to be passed to Flash Attention API
978
+ value_states (`torch.Tensor`):
979
+ Input value states to be passed to Flash Attention API
980
+ attention_mask (`torch.Tensor`):
981
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
982
+ position of padding tokens and 1 for the position of non-padding tokens.
983
+ dropout (`int`, *optional*):
984
+ Attention dropout
985
+ softmax_scale (`float`, *optional*):
986
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
987
+ use_sliding_windows (`bool`, *optional*):
988
+ Whether to activate sliding window attention.
989
+ """
990
+ # Contains at least one padding token in the sequence
991
+ if attention_mask is not None:
992
+ batch_size = query_states.shape[0]
993
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
994
+ query_states, key_states, value_states, attention_mask, query_length
995
+ )
996
+
997
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
998
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
999
+
1000
+ if not use_sliding_windows:
1001
+ attn_output_unpad = flash_attn_varlen_func(
1002
+ query_states,
1003
+ key_states,
1004
+ value_states,
1005
+ cu_seqlens_q=cu_seqlens_q,
1006
+ cu_seqlens_k=cu_seqlens_k,
1007
+ max_seqlen_q=max_seqlen_in_batch_q,
1008
+ max_seqlen_k=max_seqlen_in_batch_k,
1009
+ dropout_p=dropout,
1010
+ softmax_scale=softmax_scale,
1011
+ causal=self.is_causal,
1012
+ )
1013
+ else:
1014
+ attn_output_unpad = flash_attn_varlen_func(
1015
+ query_states,
1016
+ key_states,
1017
+ value_states,
1018
+ cu_seqlens_q=cu_seqlens_q,
1019
+ cu_seqlens_k=cu_seqlens_k,
1020
+ max_seqlen_q=max_seqlen_in_batch_q,
1021
+ max_seqlen_k=max_seqlen_in_batch_k,
1022
+ dropout_p=dropout,
1023
+ softmax_scale=softmax_scale,
1024
+ causal=self.is_causal,
1025
+ window_size=(self.config.sliding_window, self.config.sliding_window),
1026
+ )
1027
+
1028
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1029
+ else:
1030
+ if not use_sliding_windows:
1031
+ attn_output = flash_attn_func(
1032
+ query_states,
1033
+ key_states,
1034
+ value_states,
1035
+ dropout,
1036
+ softmax_scale=softmax_scale,
1037
+ causal=self.is_causal,
1038
+ )
1039
+ else:
1040
+ attn_output = flash_attn_func(
1041
+ query_states,
1042
+ key_states,
1043
+ value_states,
1044
+ dropout,
1045
+ softmax_scale=softmax_scale,
1046
+ causal=self.is_causal,
1047
+ window_size=(self.config.sliding_window, self.config.sliding_window),
1048
+ )
1049
+
1050
+ return attn_output
1051
+
1052
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
1053
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
1054
+
1055
+ # On the first iteration we need to properly re-create the padding mask
1056
+ # by slicing it on the proper place
1057
+ if kv_seq_len != attention_mask.shape[-1]:
1058
+ attention_mask_num_tokens = attention_mask.shape[-1]
1059
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
1060
+
1061
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1062
+
1063
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
1064
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
1065
+
1066
+ if query_length == kv_seq_len:
1067
+ query_layer = index_first_axis(
1068
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
1069
+ )
1070
+ cu_seqlens_q = cu_seqlens_k
1071
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
1072
+ indices_q = indices_k
1073
+ elif query_length == 1:
1074
+ max_seqlen_in_batch_q = 1
1075
+ cu_seqlens_q = torch.arange(
1076
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
1077
+ ) # There is a memcpy here, that is very bad.
1078
+ indices_q = cu_seqlens_q[:-1]
1079
+ query_layer = query_layer.squeeze(1)
1080
+ else:
1081
+ # The -q_len: slice assumes left padding.
1082
+ attention_mask = attention_mask[:, -query_length:]
1083
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
1084
+
1085
+ return (
1086
+ query_layer,
1087
+ key_layer,
1088
+ value_layer,
1089
+ indices_q,
1090
+ (cu_seqlens_q, cu_seqlens_k),
1091
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
1092
+ )
1093
+
1094
+
1095
+ class MistralDecoderLayer(nn.Module):
1096
+ def __init__(self, config: VMistralConfig):
1097
+ super().__init__()
1098
+ self.hidden_size = config.hidden_size
1099
+ self.self_attn = (
1100
+ MistralAttention(config=config)
1101
+ if not getattr(config, "_flash_attn_2_enabled", False)
1102
+ else MistralFlashAttention2(config)
1103
+ )
1104
+ self.mlp = MistralMLP(config)
1105
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1106
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1107
+
1108
+ def forward(
1109
+ self,
1110
+ hidden_states: torch.Tensor,
1111
+ attention_mask: Optional[torch.Tensor] = None,
1112
+ position_ids: Optional[torch.LongTensor] = None,
1113
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1114
+ output_attentions: Optional[bool] = False,
1115
+ use_cache: Optional[bool] = False,
1116
+ **kwargs,
1117
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1118
+ if "padding_mask" in kwargs:
1119
+ warnings.warn(
1120
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
1121
+ " `attention_mask` instead.`"
1122
+ )
1123
+ """
1124
+ Args:
1125
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1126
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1127
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1128
+ output_attentions (`bool`, *optional*):
1129
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1130
+ returned tensors for more detail.
1131
+ use_cache (`bool`, *optional*):
1132
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1133
+ (see `past_key_values`).
1134
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1135
+ """
1136
+
1137
+ residual = hidden_states
1138
+
1139
+ hidden_states = self.input_layernorm(hidden_states)
1140
+
1141
+ # Self Attention
1142
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1143
+ hidden_states=hidden_states,
1144
+ attention_mask=attention_mask,
1145
+ position_ids=position_ids,
1146
+ past_key_value=past_key_value,
1147
+ output_attentions=output_attentions,
1148
+ use_cache=use_cache,
1149
+ )
1150
+ hidden_states = residual + hidden_states
1151
+
1152
+ # Fully Connected
1153
+ residual = hidden_states
1154
+ hidden_states = self.post_attention_layernorm(hidden_states)
1155
+ hidden_states = self.mlp(hidden_states)
1156
+ hidden_states = residual + hidden_states
1157
+
1158
+ outputs = (hidden_states,)
1159
+
1160
+ if output_attentions:
1161
+ outputs += (self_attn_weights,)
1162
+
1163
+ if use_cache:
1164
+ outputs += (present_key_value,)
1165
+
1166
+ return outputs
1167
+
1168
+
1169
+ MISTRAL_START_DOCSTRING = r"""
1170
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1171
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1172
+ etc.)
1173
+
1174
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1175
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1176
+ and behavior.
1177
+
1178
+ Parameters:
1179
+ config ([`VMistralConfig`]):
1180
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1181
+ load the weights associated with the model, only the configuration. Check out the
1182
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1183
+ """
1184
+
1185
+
1186
+ @add_start_docstrings(
1187
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1188
+ MISTRAL_START_DOCSTRING,
1189
+ )
1190
+ class VMistralPreTrainedModel(PreTrainedModel):
1191
+ config_class = VMistralConfig
1192
+ base_model_prefix = "model"
1193
+ supports_gradient_checkpointing = True
1194
+ _no_split_modules = ["MistralDecoderLayer"]
1195
+ _skip_keys_device_placement = "past_key_values"
1196
+ _supports_sdpa = False
1197
+
1198
+ def _init_weights(self, module):
1199
+ # important: this ported version of the model isn't meant for training from scratch - only
1200
+ # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
1201
+ # base should be used for training from scratch and it contains the correct code.
1202
+ std = self.config.initializer_range
1203
+ if isinstance(module, nn.Linear):
1204
+ module.weight.data.normal_(mean=0.0, std=std)
1205
+ if module.bias is not None:
1206
+ module.bias.data.zero_()
1207
+ elif isinstance(module, nn.Embedding):
1208
+ module.weight.data.normal_(mean=0.0, std=std)
1209
+ if module.padding_idx is not None:
1210
+ module.weight.data[module.padding_idx].zero_()
1211
+
1212
+ # @classmethod
1213
+ # def override_vision_model_wrapper(cls, model, config, vision_model_name, vision_model_params, torch_dtype):
1214
+ # # this can be called via from_pretrained from a class w/ head or w/o head so we extract the beheaded model version
1215
+ # beheaded_model = model.model if hasattr(model, "model") else model
1216
+ # cls.override_vision_model(beheaded_model, vision_model_name, vision_model_params, torch_dtype)
1217
+ # beheaded_model.freeze_relevant_params(config)
1218
+
1219
+
1220
+ MISTRAL_INPUTS_DOCSTRING = r"""
1221
+ Args:
1222
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1223
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1224
+ it.
1225
+
1226
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1227
+ [`PreTrainedTokenizer.__call__`] for details.
1228
+
1229
+ [What are input IDs?](../glossary#input-ids)
1230
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1231
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1232
+
1233
+ - 1 for tokens that are **not masked**,
1234
+ - 0 for tokens that are **masked**.
1235
+
1236
+ [What are attention masks?](../glossary#attention-mask)
1237
+
1238
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1239
+ [`PreTrainedTokenizer.__call__`] for details.
1240
+
1241
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1242
+ `past_key_values`).
1243
+
1244
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1245
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1246
+ information on the default strategy.
1247
+
1248
+ - 1 indicates the head is **not masked**,
1249
+ - 0 indicates the head is **masked**.
1250
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1251
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1252
+ config.n_positions - 1]`.
1253
+
1254
+ [What are position IDs?](../glossary#position-ids)
1255
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1256
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
1257
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
1258
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1259
+
1260
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1261
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1262
+
1263
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1264
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1265
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1266
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1267
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1268
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1269
+ model's internal embedding lookup matrix.
1270
+ use_cache (`bool`, *optional*):
1271
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1272
+ `past_key_values`).
1273
+ output_attentions (`bool`, *optional*):
1274
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1275
+ tensors for more detail.
1276
+ output_hidden_states (`bool`, *optional*):
1277
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1278
+ more detail.
1279
+ return_dict (`bool`, *optional*):
1280
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1281
+ """
1282
+
1283
+
1284
+ @add_start_docstrings(
1285
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
1286
+ MISTRAL_START_DOCSTRING,
1287
+ )
1288
+ class VMistralModel(VMistralPreTrainedModel):
1289
+ """
1290
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
1291
+
1292
+ Args:
1293
+ config: VMistralConfig
1294
+ """
1295
+
1296
+ def __init__(self, config: VMistralConfig, vision_model=None):
1297
+ super().__init__(config)
1298
+ self.config = config
1299
+ self.padding_idx = config.pad_token_id
1300
+ self.vocab_size = config.vocab_size
1301
+
1302
+ self.sliding_window = config.sliding_window
1303
+
1304
+ self.embed_tokens = DecoupledEmbedding(
1305
+ num_embeddings=config.vocab_size,
1306
+ num_additional_embeddings=config.additional_vocab_size,
1307
+ embedding_dim=config.hidden_size,
1308
+ partially_freeze=config.freeze_text_layers,
1309
+ padding_idx=self.padding_idx,
1310
+ )
1311
+
1312
+ # Load an uninitialized model and later in from_pretrained will load the pre-trained model -
1313
+ # this solves the losing of weights in `from_pretrained` on the main model
1314
+ self.vision_model = SiglipVisionModel(config.vision_config)
1315
+
1316
+ # Dim projection - projecting from the vision dim to the text dim
1317
+ self.modality_projection = ModalityProjection(
1318
+ embed_dim_in=self.config.vision_config.hidden_size, embed_dim_out=self.config.hidden_size
1319
+ )
1320
+
1321
+ # Perceiver Resampler
1322
+ if config.use_resampler:
1323
+ self.perceiver_resampler = PerceiverResampler(
1324
+ config.hidden_size,
1325
+ config.perceiver_config.resampler_depth,
1326
+ config.perceiver_config.resampler_n_heads,
1327
+ config.perceiver_config.resampler_head_dim,
1328
+ config.perceiver_config.resampler_n_latents,
1329
+ config.perceiver_config.qk_layer_norms_perceiver,
1330
+ )
1331
+
1332
+ if config.use_resampler:
1333
+ self.image_seq_len = config.perceiver_config.resampler_n_latents
1334
+ else:
1335
+ self.image_seq_len = (
1336
+ config.vision_config.image_size // config.vision_config.patch_size
1337
+ ) ** 2 # TODO: pretty sure that does not work for CLIP models since there is the CLS token
1338
+ self.image_token_id = self.config.image_token_id
1339
+
1340
+ self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
1341
+
1342
+ self.gradient_checkpointing = False
1343
+
1344
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1345
+
1346
+ # Initialize weights and apply final processing
1347
+ self.post_init()
1348
+
1349
+ self.freeze_relevant_params(config)
1350
+
1351
+ def freeze_relevant_params(self, config=None):
1352
+ if config is None:
1353
+ config = self.config
1354
+
1355
+ if config.freeze_text_layers:
1356
+ self.freeze_text_layers(config.freeze_text_module_exceptions)
1357
+
1358
+ if config.freeze_vision_layers:
1359
+ freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
1360
+
1361
+ def freeze_text_layers(self, module_exceptions):
1362
+ for module in [self.layers, self.norm]:
1363
+ freeze_model(module, module_exceptions=module_exceptions)
1364
+
1365
+ def get_input_embeddings(self):
1366
+ return self.embed_tokens
1367
+
1368
+ def set_input_embeddings(self, value):
1369
+ self.embed_tokens = value
1370
+
1371
+ def inputs_merger(
1372
+ self,
1373
+ input_ids: torch.LongTensor = None,
1374
+ inputs_embeds: Optional[torch.Tensor] = None,
1375
+ image_hidden_states: Optional[torch.Tensor] = None,
1376
+ num_images: Optional[int] = None,
1377
+ ):
1378
+ """
1379
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
1380
+ The merging happens as follows:
1381
+ - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
1382
+ - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
1383
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
1384
+ - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
1385
+ - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
1386
+ """
1387
+ batch_size = input_ids.size(0)
1388
+
1389
+ if inputs_embeds is not None:
1390
+ vision_pipeline_output_seq_len = image_hidden_states.shape[1]
1391
+ vision_hidden_size = image_hidden_states.shape[2]
1392
+ new_inputs_embeds = inputs_embeds.clone()
1393
+ # Get a view of the image_hidden_states separating batch_size and num_images, to discard padding hidden_states
1394
+ image_hidden_states = image_hidden_states.view(
1395
+ batch_size, num_images, vision_pipeline_output_seq_len, vision_hidden_size
1396
+ )
1397
+ for batch_idx in range(batch_size):
1398
+ # Get the number of images for this particular example
1399
+ example_num_images = (input_ids[batch_idx] == self.image_token_id).sum() // self.image_seq_len
1400
+ # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
1401
+ example_true_image_hidden_states = image_hidden_states[batch_idx][:example_num_images]
1402
+ if (
1403
+ new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
1404
+ != example_num_images * vision_pipeline_output_seq_len
1405
+ ):
1406
+ raise ValueError(
1407
+ "new_inputs_embeds to replace has shape[0]:"
1408
+ f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
1409
+ " should have shape[0]:"
1410
+ f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
1411
+ )
1412
+ # Insert the image_hidden_states
1413
+ new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
1414
+ example_true_image_hidden_states.view(
1415
+ example_num_images * vision_pipeline_output_seq_len,
1416
+ vision_hidden_size,
1417
+ )
1418
+ )
1419
+
1420
+ return_dict = {}
1421
+ if inputs_embeds is not None:
1422
+ return_dict["inputs_embeds"] = new_inputs_embeds
1423
+
1424
+ return return_dict
1425
+
1426
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1427
+ def forward(
1428
+ self,
1429
+ input_ids: torch.LongTensor = None,
1430
+ attention_mask: Optional[torch.Tensor] = None,
1431
+ position_ids: Optional[torch.LongTensor] = None,
1432
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1433
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1434
+ pixel_values: Optional[torch.FloatTensor] = None,
1435
+ image_hidden_states: Optional[torch.FloatTensor] = None,
1436
+ use_cache: Optional[bool] = None,
1437
+ output_attentions: Optional[bool] = None,
1438
+ output_hidden_states: Optional[bool] = None,
1439
+ return_dict: Optional[bool] = None,
1440
+ ) -> Union[Tuple, Img2HTMLBaseModelOutputWithPast]:
1441
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1442
+
1443
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1444
+ output_hidden_states = (
1445
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1446
+ )
1447
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1448
+
1449
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1450
+
1451
+ # retrieve input_ids and inputs_embeds
1452
+ if input_ids is not None and inputs_embeds is not None:
1453
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1454
+ elif input_ids is not None:
1455
+ batch_size, seq_length = input_ids.shape
1456
+ elif inputs_embeds is not None:
1457
+ batch_size, seq_length, _ = inputs_embeds.shape
1458
+ else:
1459
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1460
+
1461
+ seq_length_with_past = seq_length
1462
+ past_key_values_length = 0
1463
+
1464
+ if past_key_values is not None:
1465
+ past_key_values_length = past_key_values[0][0].shape[2]
1466
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1467
+
1468
+ if position_ids is None:
1469
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1470
+ position_ids = torch.arange(
1471
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1472
+ )
1473
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1474
+ else:
1475
+ position_ids = position_ids.view(-1, seq_length).long()
1476
+
1477
+ if inputs_embeds is None:
1478
+ inputs_embeds = self.embed_tokens(input_ids)
1479
+
1480
+ # START VISUAL INPUTS INTEGRATION
1481
+ if pixel_values is not None and image_hidden_states is not None:
1482
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
1483
+ elif pixel_values is not None:
1484
+ pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device) # fp16 compatibility
1485
+ batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
1486
+ pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
1487
+ # Get sequence from the vision encoder
1488
+ image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
1489
+
1490
+ # Modality projection
1491
+ image_hidden_states = self.modality_projection(image_hidden_states)
1492
+
1493
+ if self.config.use_resampler:
1494
+ image_hidden_states = self.perceiver_resampler(image_hidden_states)
1495
+
1496
+ if past_key_values is None:
1497
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
1498
+ # that simply don't exist
1499
+ new_inp = self.inputs_merger(
1500
+ input_ids=input_ids,
1501
+ inputs_embeds=inputs_embeds,
1502
+ image_hidden_states=image_hidden_states,
1503
+ num_images=num_images,
1504
+ )
1505
+ inputs_embeds = new_inp["inputs_embeds"]
1506
+
1507
+ # Can do add some token types embeddings here (image token vs text token)
1508
+ # something like inputs_embeds += self.token_types(token_types)
1509
+
1510
+ # embed positions
1511
+ if (
1512
+ attention_mask is not None
1513
+ and hasattr(self.config, "_flash_attn_2_enabled")
1514
+ and self.config._flash_attn_2_enabled
1515
+ and past_key_values is not None
1516
+ ):
1517
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1518
+ if is_padding_right:
1519
+ raise ValueError(
1520
+ "You are attempting to perform batched generation with padding_side='right'"
1521
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
1522
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1523
+ )
1524
+
1525
+ if getattr(self.config, "_flash_attn_2_enabled", False):
1526
+ # 2d mask is passed through the layers
1527
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1528
+ else:
1529
+ # 4d mask is passed through the layers
1530
+ attention_mask = _prepare_4d_causal_attention_mask(
1531
+ attention_mask,
1532
+ (batch_size, seq_length),
1533
+ inputs_embeds,
1534
+ past_key_values_length,
1535
+ sliding_window=self.config.sliding_window,
1536
+ )
1537
+ attention_mask[attention_mask == -float("inf")] = torch.finfo(self.dtype).min
1538
+
1539
+ hidden_states = inputs_embeds
1540
+
1541
+ if self.gradient_checkpointing and self.training:
1542
+ if use_cache:
1543
+ logger.warning_once(
1544
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1545
+ )
1546
+ use_cache = False
1547
+
1548
+ # decoder layers
1549
+ all_hidden_states = () if output_hidden_states else None
1550
+ all_self_attns = () if output_attentions else None
1551
+ next_decoder_cache = () if use_cache else None
1552
+
1553
+ for idx, decoder_layer in enumerate(self.layers):
1554
+ if output_hidden_states:
1555
+ all_hidden_states += (hidden_states,)
1556
+
1557
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1558
+
1559
+ if self.gradient_checkpointing and self.training:
1560
+ layer_outputs = self._gradient_checkpointing_func(
1561
+ decoder_layer.__call__,
1562
+ hidden_states,
1563
+ attention_mask,
1564
+ position_ids,
1565
+ past_key_value,
1566
+ output_attentions,
1567
+ use_cache,
1568
+ )
1569
+ else:
1570
+ layer_outputs = decoder_layer(
1571
+ hidden_states,
1572
+ attention_mask=attention_mask,
1573
+ position_ids=position_ids,
1574
+ past_key_value=past_key_value,
1575
+ output_attentions=output_attentions,
1576
+ use_cache=use_cache,
1577
+ )
1578
+
1579
+ hidden_states = layer_outputs[0]
1580
+
1581
+ if use_cache:
1582
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
1583
+
1584
+ if output_attentions:
1585
+ all_self_attns += (layer_outputs[1],)
1586
+
1587
+ hidden_states = self.norm(hidden_states)
1588
+
1589
+ # add hidden states from the last decoder layer
1590
+ if output_hidden_states:
1591
+ all_hidden_states += (hidden_states,)
1592
+
1593
+ next_cache = next_decoder_cache if use_cache else None
1594
+ if not return_dict:
1595
+ return tuple(
1596
+ v
1597
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
1598
+ if v is not None
1599
+ )
1600
+ return Img2HTMLBaseModelOutputWithPast(
1601
+ last_hidden_state=hidden_states,
1602
+ past_key_values=next_cache,
1603
+ hidden_states=all_hidden_states,
1604
+ attentions=all_self_attns,
1605
+ image_hidden_states=image_hidden_states,
1606
+ )
1607
+
1608
+
1609
+ class Img2HTMLForVisionText2Text(VMistralPreTrainedModel):
1610
+ _tied_weights_keys = ["lm_head.weight"]
1611
+
1612
+ def __init__(self, config, vision_model=None):
1613
+ super().__init__(config)
1614
+ self.model = VMistralModel(config, vision_model=vision_model)
1615
+ self.image_token_id = self.config.image_token_id
1616
+ self.lm_head = DecoupledLinear(
1617
+ in_features=config.hidden_size,
1618
+ out_features=config.vocab_size,
1619
+ out_additional_features=config.additional_vocab_size,
1620
+ bias=False,
1621
+ partially_freeze=config.freeze_lm_head,
1622
+ )
1623
+
1624
+ # Initialize weights and apply final processing
1625
+ self.post_init()
1626
+
1627
+ def get_input_embeddings(self):
1628
+ return self.model.embed_tokens
1629
+
1630
+ def set_input_embeddings(self, value):
1631
+ self.model.embed_tokens = value
1632
+
1633
+ def get_output_embeddings(self):
1634
+ return self.lm_head
1635
+
1636
+ def set_output_embeddings(self, new_embeddings):
1637
+ self.lm_head = new_embeddings
1638
+
1639
+ def set_decoder(self, decoder):
1640
+ self.model = decoder
1641
+
1642
+ def get_decoder(self):
1643
+ return self.model
1644
+
1645
+ def tie_weights(self):
1646
+ """
1647
+ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
1648
+ """
1649
+ output_embeddings = self.get_output_embeddings()
1650
+ input_embeddings = self.get_input_embeddings()
1651
+
1652
+ if getattr(self.config, "tie_word_embeddings", True):
1653
+ output_embeddings.weight = input_embeddings.weight
1654
+ if input_embeddings.num_additional_embeddings > 0:
1655
+ assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
1656
+ output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
1657
+
1658
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
1659
+ output_embeddings.out_features = input_embeddings.num_embeddings
1660
+ if hasattr(output_embeddings, "out_additional_features") and hasattr(
1661
+ input_embeddings, "num_additional_embeddings"
1662
+ ):
1663
+ output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
1664
+
1665
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1666
+ @replace_return_docstrings(output_type=Img2HTMLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1667
+ def forward(
1668
+ self,
1669
+ input_ids: torch.LongTensor = None,
1670
+ attention_mask: Optional[torch.Tensor] = None,
1671
+ position_ids: Optional[torch.LongTensor] = None,
1672
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1673
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1674
+ pixel_values: Optional[torch.FloatTensor] = None,
1675
+ image_hidden_states: Optional[torch.FloatTensor] = None,
1676
+ labels: Optional[torch.LongTensor] = None,
1677
+ use_cache: Optional[bool] = None,
1678
+ output_attentions: Optional[bool] = None,
1679
+ output_hidden_states: Optional[bool] = None,
1680
+ return_dict: Optional[bool] = None,
1681
+ ) -> Union[Tuple, Img2HTMLCausalLMOutputWithPast]:
1682
+ r"""
1683
+ Args:
1684
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1685
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1686
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1687
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1688
+
1689
+ Returns:
1690
+
1691
+ """
1692
+
1693
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1694
+ output_hidden_states = (
1695
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1696
+ )
1697
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1698
+
1699
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1700
+ outputs = self.model(
1701
+ input_ids=input_ids,
1702
+ attention_mask=attention_mask,
1703
+ position_ids=position_ids,
1704
+ past_key_values=past_key_values,
1705
+ inputs_embeds=inputs_embeds,
1706
+ pixel_values=pixel_values,
1707
+ image_hidden_states=image_hidden_states,
1708
+ use_cache=use_cache,
1709
+ output_attentions=output_attentions,
1710
+ output_hidden_states=output_hidden_states,
1711
+ return_dict=return_dict,
1712
+ )
1713
+
1714
+ hidden_states = outputs[0]
1715
+ logits = self.lm_head(hidden_states)
1716
+ logits = logits.float()
1717
+
1718
+ loss = None
1719
+ if labels is not None:
1720
+ labels = labels.to(logits.device)
1721
+ # Shift so that tokens < n predict n
1722
+ if attention_mask is not None:
1723
+ shift_attention_mask = attention_mask[..., 1:].to(logits.device)
1724
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
1725
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
1726
+ else:
1727
+ shift_logits = logits[..., :-1, :].contiguous()
1728
+ shift_labels = labels[..., 1:].contiguous()
1729
+ # Flatten the tokens
1730
+ loss_fct = CrossEntropyLoss(ignore_index=self.image_token_id)
1731
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1732
+
1733
+ if not return_dict:
1734
+ output = (logits,) + outputs[1:]
1735
+ return (loss,) + output if loss is not None else output
1736
+
1737
+ return Img2HTMLCausalLMOutputWithPast(
1738
+ loss=loss,
1739
+ logits=logits,
1740
+ past_key_values=outputs.past_key_values,
1741
+ hidden_states=outputs.hidden_states,
1742
+ attentions=outputs.attentions,
1743
+ image_hidden_states=outputs.image_hidden_states,
1744
+ )
1745
+
1746
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
1747
+ image_hidden_states = kwargs.pop("image_hidden_states", None)
1748
+ if image_hidden_states is not None:
1749
+ kwargs["pixel_values"] = None
1750
+ inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
1751
+ unwanted_kwargs = ["token_type_ids"]
1752
+ for kwarg in unwanted_kwargs:
1753
+ inputs.pop(kwarg, None)
1754
+ return inputs
1755
+
1756
+ @staticmethod
1757
+ def _expand_inputs_for_generation(
1758
+ *args,
1759
+ **model_kwargs,
1760
+ ):
1761
+ return expand_inputs_for_generation(*args, **model_kwargs)
1762
+
1763
+ @staticmethod
1764
+ def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
1765
+ return update_model_kwargs_for_generation(outputs, model_kwargs)
1766
+
1767
+ @staticmethod
1768
+ def _reorder_cache(past, beam_idx):
1769
+ reordered_past = ()
1770
+ for layer_past in past:
1771
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1772
+ return reordered_past
vision.py ADDED
@@ -0,0 +1,1361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ A simplified copy of https://huggingface.co/HuggingFaceM4/siglip-so400m-14-384-flash-attn2 """
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ is_flash_attn_2_available,
33
+ logging,
34
+ replace_return_docstrings,
35
+ )
36
+
37
+ from .configuration_img2html import VMistralVisionConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ # _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
43
+
44
+ # SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ # "google/siglip-base-patch16-224",
46
+ # # See all SigLIP models at https://huggingface.co/models?filter=siglip
47
+ # ]
48
+
49
+ if is_flash_attn_2_available():
50
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
51
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
+
53
+
54
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55
+ def _get_unpad_data(attention_mask):
56
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
57
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
58
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
59
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
60
+ return (
61
+ indices,
62
+ cu_seqlens,
63
+ max_seqlen_in_batch,
64
+ )
65
+
66
+
67
+ # # Copied from transformers.models.bart.modeling_bart._expand_mask
68
+ # def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
69
+ # """
70
+ # Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
71
+ # """
72
+ # bsz, src_len = mask.size()
73
+ # tgt_len = tgt_len if tgt_len is not None else src_len
74
+
75
+ # expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
76
+
77
+ # inverted_mask = 1.0 - expanded_mask
78
+
79
+ # return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
80
+
81
+
82
+ # # contrastive loss function, adapted from
83
+ # # https://sachinruk.github.io/blog/2021-03-07-siglip.html
84
+ # def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
85
+ # return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
86
+
87
+
88
+ # # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->siglip
89
+ # def siglip_loss(similarity: torch.Tensor) -> torch.Tensor:
90
+ # caption_loss = contrastive_loss(similarity)
91
+ # image_loss = contrastive_loss(similarity.t())
92
+ # return (caption_loss + image_loss) / 2.0
93
+
94
+
95
+ @dataclass
96
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
97
+ class SiglipVisionModelOutput(ModelOutput):
98
+ """
99
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
100
+
101
+ Args:
102
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
103
+ The image embeddings obtained by applying the projection layer to the pooler_output.
104
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
105
+ Sequence of hidden-states at the output of the last layer of the model.
106
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
108
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
109
+
110
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
111
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
112
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
113
+ sequence_length)`.
114
+
115
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
116
+ heads.
117
+ """
118
+
119
+ image_embeds: Optional[torch.FloatTensor] = None
120
+ last_hidden_state: torch.FloatTensor = None
121
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
122
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
123
+
124
+
125
+ # @dataclass
126
+ # # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
127
+ # class SiglipTextModelOutput(ModelOutput):
128
+ # """
129
+ # Base class for text model's outputs that also contains a pooling of the last hidden states.
130
+
131
+ # Args:
132
+ # text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
133
+ # The text embeddings obtained by applying the projection layer to the pooler_output.
134
+ # last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
135
+ # Sequence of hidden-states at the output of the last layer of the model.
136
+ # hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
137
+ # Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
138
+ # one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
139
+
140
+ # Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
141
+ # attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
142
+ # Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
143
+ # sequence_length)`.
144
+
145
+ # Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
146
+ # heads.
147
+ # """
148
+
149
+ # text_embeds: Optional[torch.FloatTensor] = None
150
+ # last_hidden_state: torch.FloatTensor = None
151
+ # hidden_states: Optional[Tuple[torch.FloatTensor]] = None
152
+ # attentions: Optional[Tuple[torch.FloatTensor]] = None
153
+
154
+
155
+ # @dataclass
156
+ # # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
157
+ # class SiglipOutput(ModelOutput):
158
+ # """
159
+ # Args:
160
+ # loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
161
+ # Contrastive loss for image-text similarity.
162
+ # logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
163
+ # The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
164
+ # similarity scores.
165
+ # logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
166
+ # The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
167
+ # similarity scores.
168
+ # text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
169
+ # The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
170
+ # image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
171
+ # The image embeddings obtained by applying the projection layer to the pooled output of
172
+ # [`SiglipVisionModel`].
173
+ # text_model_output(`BaseModelOutputWithPooling`):
174
+ # The output of the [`SiglipTextModel`].
175
+ # vision_model_output(`BaseModelOutputWithPooling`):
176
+ # The output of the [`SiglipVisionModel`].
177
+ # """
178
+
179
+ # loss: Optional[torch.FloatTensor] = None
180
+ # logits_per_image: torch.FloatTensor = None
181
+ # logits_per_text: torch.FloatTensor = None
182
+ # text_embeds: torch.FloatTensor = None
183
+ # image_embeds: torch.FloatTensor = None
184
+ # text_model_output: BaseModelOutputWithPooling = None
185
+ # vision_model_output: BaseModelOutputWithPooling = None
186
+
187
+ # def to_tuple(self) -> Tuple[Any]:
188
+ # return tuple(
189
+ # self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
190
+ # for k in self.keys()
191
+ # )
192
+
193
+
194
+ class SiglipVisionEmbeddings(nn.Module):
195
+ def __init__(self, config: VMistralVisionConfig):
196
+ super().__init__()
197
+ self.config = config
198
+ self.embed_dim = config.hidden_size
199
+ self.image_size = config.image_size
200
+ self.patch_size = config.patch_size
201
+
202
+ self.patch_embedding = nn.Conv2d(
203
+ in_channels=config.num_channels,
204
+ out_channels=self.embed_dim,
205
+ kernel_size=self.patch_size,
206
+ stride=self.patch_size,
207
+ padding="valid",
208
+ )
209
+
210
+ self.num_patches = (self.image_size // self.patch_size) ** 2
211
+ self.num_positions = self.num_patches
212
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
213
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
214
+
215
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
216
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
217
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
218
+
219
+ embeddings = embeddings + self.position_embedding(self.position_ids)
220
+ return embeddings
221
+
222
+
223
+ # # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
224
+ # class SiglipTextEmbeddings(nn.Module):
225
+ # def __init__(self, config: SiglipTextConfig):
226
+ # super().__init__()
227
+ # embed_dim = config.hidden_size
228
+
229
+ # self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
230
+ # self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
231
+
232
+ # # position_ids (1, len position emb) is contiguous in memory and exported when serialized
233
+ # self.register_buffer(
234
+ # "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
235
+ # )
236
+
237
+ # def forward(
238
+ # self,
239
+ # input_ids: Optional[torch.LongTensor] = None,
240
+ # position_ids: Optional[torch.LongTensor] = None,
241
+ # inputs_embeds: Optional[torch.FloatTensor] = None,
242
+ # ) -> torch.Tensor:
243
+ # seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
244
+
245
+ # if position_ids is None:
246
+ # position_ids = self.position_ids[:, :seq_length]
247
+
248
+ # if inputs_embeds is None:
249
+ # inputs_embeds = self.token_embedding(input_ids)
250
+
251
+ # position_embeddings = self.position_embedding(position_ids)
252
+ # embeddings = inputs_embeds + position_embeddings
253
+
254
+ # return embeddings
255
+
256
+
257
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
258
+ class SiglipAttention(nn.Module):
259
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
260
+
261
+ def __init__(self, config):
262
+ super().__init__()
263
+ self.config = config
264
+ self.embed_dim = config.hidden_size
265
+ self.num_heads = config.num_attention_heads
266
+ self.head_dim = self.embed_dim // self.num_heads
267
+ if self.head_dim * self.num_heads != self.embed_dim:
268
+ raise ValueError(
269
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
270
+ f" {self.num_heads})."
271
+ )
272
+ self.scale = self.head_dim**-0.5
273
+ self.dropout = config.attention_dropout
274
+
275
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
276
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
+
280
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
281
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
282
+
283
+ def forward(
284
+ self,
285
+ hidden_states: torch.Tensor,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ causal_attention_mask: Optional[torch.Tensor] = None,
288
+ output_attentions: Optional[bool] = False,
289
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
+ """Input shape: Batch x Time x Channel"""
291
+
292
+ bsz, tgt_len, embed_dim = hidden_states.size()
293
+
294
+ # get query proj
295
+ query_states = self.q_proj(hidden_states) * self.scale
296
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
297
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
298
+
299
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
300
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
301
+ key_states = key_states.view(*proj_shape)
302
+ value_states = value_states.view(*proj_shape)
303
+
304
+ src_len = key_states.size(1)
305
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
306
+
307
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
308
+ raise ValueError(
309
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
310
+ f" {attn_weights.size()}"
311
+ )
312
+
313
+ # apply the causal_attention_mask first
314
+ if causal_attention_mask is not None:
315
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
316
+ raise ValueError(
317
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
318
+ f" {causal_attention_mask.size()}"
319
+ )
320
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
321
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
322
+
323
+ if attention_mask is not None:
324
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
325
+ raise ValueError(
326
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
327
+ )
328
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
329
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
330
+
331
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
332
+
333
+ if output_attentions:
334
+ # this operation is a bit akward, but it's required to
335
+ # make sure that attn_weights keeps its gradient.
336
+ # In order to do so, attn_weights have to reshaped
337
+ # twice and have to be reused in the following
338
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
339
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
340
+ else:
341
+ attn_weights_reshaped = None
342
+
343
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
344
+
345
+ attn_output = torch.bmm(attn_probs, value_states)
346
+
347
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
348
+ raise ValueError(
349
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
350
+ f" {attn_output.size()}"
351
+ )
352
+
353
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
354
+ attn_output = attn_output.transpose(1, 2)
355
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
356
+
357
+ attn_output = self.out_proj(attn_output)
358
+
359
+ return attn_output, attn_weights_reshaped
360
+
361
+
362
+ class SiglipFlashAttention2(SiglipAttention):
363
+ """
364
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
365
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
366
+ flash attention and deal with padding tokens in case the input contains any of them.
367
+ """
368
+
369
+ def __init__(self, *args, **kwargs):
370
+ super().__init__(*args, **kwargs)
371
+ self.is_causal = False # Hack to make sure we don't use a causal mask
372
+
373
+ def forward(
374
+ self,
375
+ hidden_states: torch.Tensor,
376
+ attention_mask: Optional[torch.LongTensor] = None,
377
+ position_ids: Optional[torch.LongTensor] = None,
378
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
379
+ output_attentions: bool = False,
380
+ use_cache: bool = False,
381
+ **kwargs,
382
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
383
+ output_attentions = False
384
+
385
+ bsz, q_len, _ = hidden_states.size()
386
+
387
+ query_states = self.q_proj(hidden_states)
388
+ key_states = self.k_proj(hidden_states)
389
+ value_states = self.v_proj(hidden_states)
390
+
391
+ # Flash attention requires the input to have the shape
392
+ # batch_size x seq_length x head_dim x hidden_dim
393
+ # therefore we just need to keep the original shape
394
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
395
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
396
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
397
+
398
+ kv_seq_len = key_states.shape[-2]
399
+ if past_key_value is not None:
400
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
401
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
402
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
403
+
404
+ # if past_key_value is not None:
405
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
406
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
407
+
408
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
409
+ # to be able to avoid many of these transpose/reshape/view.
410
+ query_states = query_states.transpose(1, 2)
411
+ key_states = key_states.transpose(1, 2)
412
+ value_states = value_states.transpose(1, 2)
413
+
414
+ dropout_rate = self.dropout if self.training else 0.0
415
+
416
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
417
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
418
+ # cast them back in the correct dtype just to be sure everything works as expected.
419
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
420
+ # in fp32. (LlamaRMSNorm handles it correctly)
421
+
422
+ input_dtype = query_states.dtype
423
+ if input_dtype == torch.float32:
424
+ if torch.is_autocast_enabled():
425
+ target_dtype = torch.get_autocast_gpu_dtype()
426
+ # Handle the case where the model is quantized
427
+ elif hasattr(self.config, "_pre_quantization_dtype"):
428
+ target_dtype = self.config._pre_quantization_dtype
429
+ else:
430
+ target_dtype = self.q_proj.weight.dtype
431
+
432
+ logger.warning_once(
433
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
434
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
+ f" {target_dtype}."
436
+ )
437
+
438
+ query_states = query_states.to(target_dtype)
439
+ key_states = key_states.to(target_dtype)
440
+ value_states = value_states.to(target_dtype)
441
+
442
+ attn_output = self._flash_attention_forward(
443
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
444
+ )
445
+
446
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
447
+ attn_output = self.out_proj(attn_output)
448
+
449
+ if not output_attentions:
450
+ attn_weights = None
451
+
452
+ return attn_output, attn_weights
453
+
454
+ def _flash_attention_forward(
455
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
456
+ ):
457
+ """
458
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
459
+ first unpad the input, then computes the attention scores and pad the final attention scores.
460
+
461
+ Args:
462
+ query_states (`torch.Tensor`):
463
+ Input query states to be passed to Flash Attention API
464
+ key_states (`torch.Tensor`):
465
+ Input key states to be passed to Flash Attention API
466
+ value_states (`torch.Tensor`):
467
+ Input value states to be passed to Flash Attention API
468
+ attention_mask (`torch.Tensor`):
469
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
470
+ position of padding tokens and 1 for the position of non-padding tokens.
471
+ dropout (`int`, *optional*):
472
+ Attention dropout
473
+ softmax_scale (`float`, *optional*):
474
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
475
+ """
476
+
477
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
478
+ causal = self.is_causal and query_length != 1
479
+
480
+ # Contains at least one padding token in the sequence
481
+ if attention_mask is not None:
482
+ batch_size = query_states.shape[0]
483
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
484
+ query_states, key_states, value_states, attention_mask, query_length
485
+ )
486
+
487
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
488
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
489
+
490
+ attn_output_unpad = flash_attn_varlen_func(
491
+ query_states,
492
+ key_states,
493
+ value_states,
494
+ cu_seqlens_q=cu_seqlens_q,
495
+ cu_seqlens_k=cu_seqlens_k,
496
+ max_seqlen_q=max_seqlen_in_batch_q,
497
+ max_seqlen_k=max_seqlen_in_batch_k,
498
+ dropout_p=dropout,
499
+ softmax_scale=softmax_scale,
500
+ causal=causal,
501
+ )
502
+
503
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
504
+ else:
505
+ attn_output = flash_attn_func(
506
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
507
+ )
508
+
509
+ return attn_output
510
+
511
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
512
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
513
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
514
+
515
+ key_layer = index_first_axis(
516
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
517
+ )
518
+ value_layer = index_first_axis(
519
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
520
+ )
521
+ if query_length == kv_seq_len:
522
+ query_layer = index_first_axis(
523
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
524
+ )
525
+ cu_seqlens_q = cu_seqlens_k
526
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
527
+ indices_q = indices_k
528
+ elif query_length == 1:
529
+ max_seqlen_in_batch_q = 1
530
+ cu_seqlens_q = torch.arange(
531
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
532
+ ) # There is a memcpy here, that is very bad.
533
+ indices_q = cu_seqlens_q[:-1]
534
+ query_layer = query_layer.squeeze(1)
535
+ else:
536
+ # The -q_len: slice assumes left padding.
537
+ attention_mask = attention_mask[:, -query_length:]
538
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
539
+
540
+ return (
541
+ query_layer,
542
+ key_layer,
543
+ value_layer,
544
+ indices_q,
545
+ (cu_seqlens_q, cu_seqlens_k),
546
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
547
+ )
548
+
549
+
550
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
551
+ class SiglipMLP(nn.Module):
552
+ def __init__(self, config):
553
+ super().__init__()
554
+ self.config = config
555
+ self.activation_fn = ACT2FN[config.hidden_act]
556
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
557
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
558
+
559
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
560
+ hidden_states = self.fc1(hidden_states)
561
+ hidden_states = self.activation_fn(hidden_states)
562
+ hidden_states = self.fc2(hidden_states)
563
+ return hidden_states
564
+
565
+
566
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
567
+ class SiglipEncoderLayer(nn.Module):
568
+ def __init__(self, config: VMistralVisionConfig):
569
+ super().__init__()
570
+ self.embed_dim = config.hidden_size
571
+ self.self_attn = (
572
+ SiglipAttention(config)
573
+ if not getattr(config, "_flash_attn_2_enabled", False)
574
+ else SiglipFlashAttention2(config)
575
+ )
576
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
577
+ self.mlp = SiglipMLP(config)
578
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
579
+
580
+ def forward(
581
+ self,
582
+ hidden_states: torch.Tensor,
583
+ attention_mask: torch.Tensor,
584
+ causal_attention_mask: torch.Tensor,
585
+ output_attentions: Optional[bool] = False,
586
+ ) -> Tuple[torch.FloatTensor]:
587
+ """
588
+ Args:
589
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
590
+ attention_mask (`torch.FloatTensor`): attention mask of size
591
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
592
+ `(config.encoder_attention_heads,)`.
593
+ output_attentions (`bool`, *optional*):
594
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
595
+ returned tensors for more detail.
596
+ """
597
+ residual = hidden_states
598
+
599
+ hidden_states = self.layer_norm1(hidden_states)
600
+ hidden_states, attn_weights = self.self_attn(
601
+ hidden_states=hidden_states,
602
+ attention_mask=attention_mask,
603
+ causal_attention_mask=causal_attention_mask,
604
+ output_attentions=output_attentions,
605
+ )
606
+ hidden_states = residual + hidden_states
607
+
608
+ residual = hidden_states
609
+ hidden_states = self.layer_norm2(hidden_states)
610
+ hidden_states = self.mlp(hidden_states)
611
+ hidden_states = residual + hidden_states
612
+
613
+ outputs = (hidden_states,)
614
+
615
+ if output_attentions:
616
+ outputs += (attn_weights,)
617
+
618
+ return outputs
619
+
620
+
621
+ # class SiglipPreTrainedModel(PreTrainedModel):
622
+ # """
623
+ # An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
624
+ # models.
625
+ # """
626
+
627
+ # config_class = SiglipConfig
628
+ # base_model_prefix = "siglip"
629
+ # supports_gradient_checkpointing = True
630
+
631
+ # def _init_weights(self, module):
632
+ # """Initialize the weights"""
633
+ # factor = self.config.initializer_factor
634
+ # if isinstance(module, SiglipVisionEmbeddings):
635
+ # factor = self.config.initializer_factor
636
+ # nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
637
+ # nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
638
+ # elif isinstance(module, SiglipAttention):
639
+ # factor = self.config.initializer_factor
640
+ # in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
641
+ # out_proj_std = (module.embed_dim**-0.5) * factor
642
+ # nn.init.normal_(module.q_proj.weight, std=in_proj_std)
643
+ # nn.init.normal_(module.k_proj.weight, std=in_proj_std)
644
+ # nn.init.normal_(module.v_proj.weight, std=in_proj_std)
645
+ # nn.init.normal_(module.out_proj.weight, std=out_proj_std)
646
+ # elif isinstance(module, SiglipMLP):
647
+ # factor = self.config.initializer_factor
648
+ # in_proj_std = (
649
+ # (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
650
+ # )
651
+ # fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
652
+ # nn.init.normal_(module.fc1.weight, std=fc_std)
653
+ # nn.init.normal_(module.fc2.weight, std=in_proj_std)
654
+ # if isinstance(module, nn.LayerNorm):
655
+ # module.bias.data.zero_()
656
+ # module.weight.data.fill_(1.0)
657
+ # if isinstance(module, nn.Linear) and module.bias is not None:
658
+ # module.bias.data.zero_()
659
+
660
+ # def _set_gradient_checkpointing(self, module, value=False):
661
+ # if isinstance(module, SiglipEncoder):
662
+ # module.gradient_checkpointing = value
663
+
664
+
665
+ # SIGLIP_START_DOCSTRING = r"""
666
+ # This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
667
+ # library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
668
+ # etc.)
669
+
670
+ # This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
671
+ # Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
672
+ # and behavior.
673
+
674
+ # Parameters:
675
+ # config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
676
+ # Initializing with a config file does not load the weights associated with the model, only the
677
+ # configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
678
+ # """
679
+
680
+ # SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
681
+ # Args:
682
+ # input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
683
+ # Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
684
+ # it.
685
+
686
+ # Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
687
+ # [`PreTrainedTokenizer.__call__`] for details.
688
+
689
+ # [What are input IDs?](../glossary#input-ids)
690
+ # attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
691
+ # Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
692
+
693
+ # - 1 for tokens that are **not masked**,
694
+ # - 0 for tokens that are **masked**.
695
+
696
+ # [What are attention masks?](../glossary#attention-mask)
697
+ # position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
698
+ # Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
699
+ # config.max_position_embeddings - 1]`.
700
+
701
+ # [What are position IDs?](../glossary#position-ids)
702
+ # output_attentions (`bool`, *optional*):
703
+ # Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
704
+ # tensors for more detail.
705
+ # output_hidden_states (`bool`, *optional*):
706
+ # Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
707
+ # more detail.
708
+ # return_dict (`bool`, *optional*):
709
+ # Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
710
+ # """
711
+
712
+ # SIGLIP_VISION_INPUTS_DOCSTRING = r"""
713
+ # Args:
714
+ # pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
715
+ # Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
716
+ # [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
717
+ # output_attentions (`bool`, *optional*):
718
+ # Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
719
+ # tensors for more detail.
720
+ # output_hidden_states (`bool`, *optional*):
721
+ # Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
722
+ # more detail.
723
+ # return_dict (`bool`, *optional*):
724
+ # Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
725
+ # """
726
+
727
+ # SIGLIP_INPUTS_DOCSTRING = r"""
728
+ # Args:
729
+ # input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
730
+ # Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
731
+ # it.
732
+
733
+ # Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
734
+ # [`PreTrainedTokenizer.__call__`] for details.
735
+
736
+ # [What are input IDs?](../glossary#input-ids)
737
+ # attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
738
+ # Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
739
+
740
+ # - 1 for tokens that are **not masked**,
741
+ # - 0 for tokens that are **masked**.
742
+
743
+ # [What are attention masks?](../glossary#attention-mask)
744
+ # position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ # Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
746
+ # config.max_position_embeddings - 1]`.
747
+
748
+ # [What are position IDs?](../glossary#position-ids)
749
+ # pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
750
+ # Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
751
+ # [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
752
+ # return_loss (`bool`, *optional*):
753
+ # Whether or not to return the contrastive loss.
754
+ # output_attentions (`bool`, *optional*):
755
+ # Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
756
+ # tensors for more detail.
757
+ # output_hidden_states (`bool`, *optional*):
758
+ # Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
759
+ # more detail.
760
+ # return_dict (`bool`, *optional*):
761
+ # Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
762
+ # """
763
+
764
+
765
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
766
+ class SiglipEncoder(nn.Module):
767
+ """
768
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
769
+ [`SiglipEncoderLayer`].
770
+
771
+ Args:
772
+ config: SiglipConfig
773
+ """
774
+
775
+ def __init__(self, config):
776
+ super().__init__()
777
+ self.config = config
778
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
779
+ self.gradient_checkpointing = False
780
+
781
+ def forward(
782
+ self,
783
+ inputs_embeds,
784
+ attention_mask: Optional[torch.Tensor] = None,
785
+ causal_attention_mask: Optional[torch.Tensor] = None,
786
+ output_attentions: Optional[bool] = None,
787
+ output_hidden_states: Optional[bool] = None,
788
+ return_dict: Optional[bool] = None,
789
+ ) -> Union[Tuple, BaseModelOutput]:
790
+ r"""
791
+ Args:
792
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
793
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
794
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
795
+ than the model's internal embedding lookup matrix.
796
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
797
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
798
+
799
+ - 1 for tokens that are **not masked**,
800
+ - 0 for tokens that are **masked**.
801
+
802
+ [What are attention masks?](../glossary#attention-mask)
803
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
804
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
805
+
806
+ - 1 for tokens that are **not masked**,
807
+ - 0 for tokens that are **masked**.
808
+
809
+ [What are attention masks?](../glossary#attention-mask)
810
+ output_attentions (`bool`, *optional*):
811
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
812
+ returned tensors for more detail.
813
+ output_hidden_states (`bool`, *optional*):
814
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
815
+ for more detail.
816
+ return_dict (`bool`, *optional*):
817
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
818
+ """
819
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
820
+ output_hidden_states = (
821
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
822
+ )
823
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
824
+
825
+ encoder_states = () if output_hidden_states else None
826
+ all_attentions = () if output_attentions else None
827
+
828
+ hidden_states = inputs_embeds
829
+ for idx, encoder_layer in enumerate(self.layers):
830
+ if output_hidden_states:
831
+ encoder_states = encoder_states + (hidden_states,)
832
+ if self.gradient_checkpointing and self.training:
833
+
834
+ def create_custom_forward(module):
835
+ def custom_forward(*inputs):
836
+ return module(*inputs, output_attentions)
837
+
838
+ return custom_forward
839
+
840
+ layer_outputs = torch.utils.checkpoint.checkpoint(
841
+ create_custom_forward(encoder_layer),
842
+ hidden_states,
843
+ attention_mask,
844
+ causal_attention_mask,
845
+ )
846
+ else:
847
+ layer_outputs = encoder_layer(
848
+ hidden_states,
849
+ attention_mask,
850
+ causal_attention_mask,
851
+ output_attentions=output_attentions,
852
+ )
853
+
854
+ hidden_states = layer_outputs[0]
855
+
856
+ if output_attentions:
857
+ all_attentions = all_attentions + (layer_outputs[1],)
858
+
859
+ if output_hidden_states:
860
+ encoder_states = encoder_states + (hidden_states,)
861
+
862
+ if not return_dict:
863
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
864
+ return BaseModelOutput(
865
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
866
+ )
867
+
868
+
869
+ # class SiglipTextTransformer(nn.Module):
870
+ # def __init__(self, config: SiglipTextConfig):
871
+ # super().__init__()
872
+ # self.config = config
873
+ # embed_dim = config.hidden_size
874
+ # self.embeddings = SiglipTextEmbeddings(config)
875
+ # self.encoder = SiglipEncoder(config)
876
+ # self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
877
+
878
+ # self.head = nn.Linear(embed_dim, embed_dim)
879
+
880
+ # @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
881
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
882
+ # def forward(
883
+ # self,
884
+ # input_ids: Optional[torch.Tensor] = None,
885
+ # attention_mask: Optional[torch.Tensor] = None,
886
+ # position_ids: Optional[torch.Tensor] = None,
887
+ # output_attentions: Optional[bool] = None,
888
+ # output_hidden_states: Optional[bool] = None,
889
+ # return_dict: Optional[bool] = None,
890
+ # ) -> Union[Tuple, BaseModelOutputWithPooling]:
891
+ # r"""
892
+ # Returns:
893
+
894
+ # """
895
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
896
+ # output_hidden_states = (
897
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
898
+ # )
899
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
900
+
901
+ # if input_ids is None:
902
+ # raise ValueError("You have to specify input_ids")
903
+
904
+ # input_shape = input_ids.size()
905
+ # input_ids = input_ids.view(-1, input_shape[-1])
906
+
907
+ # hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
908
+
909
+ # # note: SigLIP's text model does not use q causal mask, unlike the original CLIP model.
910
+ # # expand attention_mask
911
+ # if attention_mask is not None:
912
+ # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
913
+ # attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
914
+
915
+ # encoder_outputs = self.encoder(
916
+ # inputs_embeds=hidden_states,
917
+ # attention_mask=None,
918
+ # causal_attention_mask=None,
919
+ # output_attentions=output_attentions,
920
+ # output_hidden_states=output_hidden_states,
921
+ # return_dict=return_dict,
922
+ # )
923
+
924
+ # last_hidden_state = encoder_outputs[0]
925
+ # last_hidden_state = self.final_layer_norm(last_hidden_state)
926
+
927
+ # # Assuming "sticky" EOS tokenization, last token is always EOS.
928
+ # pooled_output = last_hidden_state[:, -1, :]
929
+ # pooled_output = self.head(pooled_output)
930
+
931
+ # if not return_dict:
932
+ # return (last_hidden_state, pooled_output) + encoder_outputs[1:]
933
+
934
+ # return BaseModelOutputWithPooling(
935
+ # last_hidden_state=last_hidden_state,
936
+ # pooler_output=pooled_output,
937
+ # hidden_states=encoder_outputs.hidden_states,
938
+ # attentions=encoder_outputs.attentions,
939
+ # )
940
+
941
+
942
+ # @add_start_docstrings(
943
+ # """The text model from SigLIP without any head or projection on top.""",
944
+ # SIGLIP_START_DOCSTRING,
945
+ # )
946
+ # class SiglipTextModel(SiglipPreTrainedModel):
947
+ # config_class = SiglipTextConfig
948
+
949
+ # _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
950
+
951
+ # def __init__(self, config: SiglipTextConfig):
952
+ # super().__init__(config)
953
+ # self.text_model = SiglipTextTransformer(config)
954
+ # # Initialize weights and apply final processing
955
+ # self.post_init()
956
+
957
+ # def get_input_embeddings(self) -> nn.Module:
958
+ # return self.text_model.embeddings.token_embedding
959
+
960
+ # def set_input_embeddings(self, value):
961
+ # self.text_model.embeddings.token_embedding = value
962
+
963
+ # @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
964
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
965
+ # def forward(
966
+ # self,
967
+ # input_ids: Optional[torch.Tensor] = None,
968
+ # attention_mask: Optional[torch.Tensor] = None,
969
+ # position_ids: Optional[torch.Tensor] = None,
970
+ # output_attentions: Optional[bool] = None,
971
+ # output_hidden_states: Optional[bool] = None,
972
+ # return_dict: Optional[bool] = None,
973
+ # ) -> Union[Tuple, BaseModelOutputWithPooling]:
974
+ # r"""
975
+ # Returns:
976
+
977
+ # Examples:
978
+
979
+ # ```python
980
+ # >>> from transformers import AutoTokenizer, SiglipTextModel
981
+
982
+ # >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
983
+ # >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
984
+
985
+ # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
986
+
987
+ # >>> outputs = model(**inputs)
988
+ # >>> last_hidden_state = outputs.last_hidden_state
989
+ # >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
990
+ # ```"""
991
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
992
+
993
+ # return self.text_model(
994
+ # input_ids=input_ids,
995
+ # attention_mask=attention_mask,
996
+ # position_ids=position_ids,
997
+ # output_attentions=output_attentions,
998
+ # output_hidden_states=output_hidden_states,
999
+ # return_dict=return_dict,
1000
+ # )
1001
+
1002
+
1003
+ class SiglipVisionTransformer(nn.Module):
1004
+ def __init__(self, config: VMistralVisionConfig):
1005
+ super().__init__()
1006
+ self.config = config
1007
+ embed_dim = config.hidden_size
1008
+
1009
+ self.embeddings = SiglipVisionEmbeddings(config)
1010
+ self.encoder = SiglipEncoder(config)
1011
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1012
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
1013
+
1014
+ # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1015
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=VMistralVisionConfig)
1016
+ def forward(
1017
+ self,
1018
+ pixel_values,
1019
+ output_attentions: Optional[bool] = None,
1020
+ output_hidden_states: Optional[bool] = None,
1021
+ return_dict: Optional[bool] = None,
1022
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1023
+ r"""
1024
+ Returns:
1025
+
1026
+ """
1027
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1028
+ output_hidden_states = (
1029
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1030
+ )
1031
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1032
+
1033
+ hidden_states = self.embeddings(pixel_values)
1034
+
1035
+ encoder_outputs = self.encoder(
1036
+ inputs_embeds=hidden_states,
1037
+ output_attentions=output_attentions,
1038
+ output_hidden_states=output_hidden_states,
1039
+ return_dict=return_dict,
1040
+ )
1041
+
1042
+ last_hidden_state = encoder_outputs[0]
1043
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1044
+
1045
+ pooled_output = self.head(last_hidden_state)
1046
+
1047
+ if not return_dict:
1048
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1049
+
1050
+ return BaseModelOutputWithPooling(
1051
+ last_hidden_state=last_hidden_state,
1052
+ pooler_output=pooled_output,
1053
+ hidden_states=encoder_outputs.hidden_states,
1054
+ attentions=encoder_outputs.attentions,
1055
+ )
1056
+
1057
+
1058
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1059
+ """Multihead Attention Pooling."""
1060
+
1061
+ def __init__(self, config: VMistralVisionConfig):
1062
+ super().__init__()
1063
+
1064
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1065
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1066
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1067
+ self.mlp = SiglipMLP(config)
1068
+
1069
+ def forward(self, hidden_state):
1070
+ batch_size = hidden_state.shape[0]
1071
+ probe = self.probe.repeat(batch_size, 1, 1)
1072
+
1073
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
1074
+
1075
+ residual = hidden_state
1076
+ hidden_state = self.layernorm(hidden_state)
1077
+ hidden_state = residual + self.mlp(hidden_state)
1078
+
1079
+ return hidden_state[:, 0]
1080
+
1081
+
1082
+ # @add_start_docstrings(
1083
+ # """The vision model from SigLIP without any head or projection on top.""",
1084
+ # SIGLIP_START_DOCSTRING,
1085
+ # )
1086
+ class SiglipVisionModel(nn.Module):
1087
+ def __init__(self, config: VMistralVisionConfig):
1088
+ super().__init__()
1089
+
1090
+ self.vision_model = SiglipVisionTransformer(config)
1091
+
1092
+ # # Initialize weights and apply final processing
1093
+ # self.post_init()
1094
+
1095
+ # def get_input_embeddings(self) -> nn.Module:
1096
+ # return self.vision_model.embeddings.patch_embedding
1097
+
1098
+ # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1099
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=VMistralVisionConfig)
1100
+ def forward(
1101
+ self,
1102
+ pixel_values,
1103
+ output_attentions: Optional[bool] = None,
1104
+ output_hidden_states: Optional[bool] = None,
1105
+ return_dict: Optional[bool] = None,
1106
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1107
+ # r"""
1108
+ # Returns:
1109
+
1110
+ # Examples:
1111
+
1112
+ # ```python
1113
+ # >>> from PIL import Image
1114
+ # >>> import requests
1115
+ # >>> from transformers import AutoProcessor, SiglipVisionModel
1116
+
1117
+ # >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1118
+ # >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1119
+
1120
+ # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1121
+ # >>> image = Image.open(requests.get(url, stream=True).raw)
1122
+
1123
+ # >>> inputs = processor(images=image, return_tensors="pt")
1124
+
1125
+ # >>> outputs = model(**inputs)
1126
+ # >>> last_hidden_state = outputs.last_hidden_state
1127
+ # >>> pooled_output = outputs.pooler_output # pooled CLS states
1128
+ # ```"""
1129
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1130
+
1131
+ return self.vision_model(
1132
+ pixel_values=pixel_values,
1133
+ output_attentions=output_attentions,
1134
+ output_hidden_states=output_hidden_states,
1135
+ return_dict=return_dict,
1136
+ )
1137
+
1138
+
1139
+ # @add_start_docstrings(SIGLIP_START_DOCSTRING)
1140
+ # class SiglipModel(SiglipPreTrainedModel):
1141
+ # config_class = SiglipConfig
1142
+
1143
+ # def __init__(self, config: SiglipConfig):
1144
+ # super().__init__(config)
1145
+
1146
+ # if not isinstance(config.text_config, SiglipTextConfig):
1147
+ # raise ValueError(
1148
+ # "config.text_config is expected to be of type SiglipTextConfig but is of type"
1149
+ # f" {type(config.text_config)}."
1150
+ # )
1151
+
1152
+ # if not isinstance(config.vision_config, SiglipVisionConfig):
1153
+ # raise ValueError(
1154
+ # "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1155
+ # f" {type(config.vision_config)}."
1156
+ # )
1157
+
1158
+ # text_config = config.text_config
1159
+ # vision_config = config.vision_config
1160
+
1161
+ # self.text_model = SiglipTextModel(text_config)
1162
+ # self.vision_model = SiglipVisionModel(vision_config)
1163
+
1164
+ # self.temperature = nn.Parameter(
1165
+ # torch.randn(
1166
+ # 1,
1167
+ # )
1168
+ # )
1169
+ # self.bias = nn.Parameter(
1170
+ # torch.randn(
1171
+ # 1,
1172
+ # )
1173
+ # )
1174
+
1175
+ # # Initialize weights and apply final processing
1176
+ # self.post_init()
1177
+
1178
+ # @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1179
+ # def get_text_features(
1180
+ # self,
1181
+ # input_ids: Optional[torch.Tensor] = None,
1182
+ # attention_mask: Optional[torch.Tensor] = None,
1183
+ # position_ids: Optional[torch.Tensor] = None,
1184
+ # output_attentions: Optional[bool] = None,
1185
+ # output_hidden_states: Optional[bool] = None,
1186
+ # return_dict: Optional[bool] = None,
1187
+ # ) -> torch.FloatTensor:
1188
+ # r"""
1189
+ # Returns:
1190
+ # text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1191
+ # applying the projection layer to the pooled output of [`SiglipTextModel`].
1192
+
1193
+ # Examples:
1194
+
1195
+ # ```python
1196
+ # >>> from transformers import AutoTokenizer, SiglipModel
1197
+
1198
+ # >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1199
+ # >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1200
+
1201
+ # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1202
+ # >>> text_features = model.get_text_features(**inputs)
1203
+ # ```"""
1204
+ # # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1205
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1206
+ # output_hidden_states = (
1207
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1208
+ # )
1209
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1210
+
1211
+ # text_outputs = self.text_model(
1212
+ # input_ids=input_ids,
1213
+ # attention_mask=attention_mask,
1214
+ # position_ids=position_ids,
1215
+ # output_attentions=output_attentions,
1216
+ # output_hidden_states=output_hidden_states,
1217
+ # return_dict=return_dict,
1218
+ # )
1219
+
1220
+ # pooled_output = text_outputs[1]
1221
+
1222
+ # return pooled_output
1223
+
1224
+ # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1225
+ # def get_image_features(
1226
+ # self,
1227
+ # pixel_values: Optional[torch.FloatTensor] = None,
1228
+ # output_attentions: Optional[bool] = None,
1229
+ # output_hidden_states: Optional[bool] = None,
1230
+ # return_dict: Optional[bool] = None,
1231
+ # ) -> torch.FloatTensor:
1232
+ # r"""
1233
+ # Returns:
1234
+ # image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
1235
+ # applying the projection layer to the pooled output of [`SiglipVisionModel`].
1236
+
1237
+ # Examples:
1238
+
1239
+ # ```python
1240
+ # >>> from PIL import Image
1241
+ # >>> import requests
1242
+ # >>> from transformers import AutoProcessor, SiglipModel
1243
+
1244
+ # >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1245
+ # >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1246
+
1247
+ # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1248
+ # >>> image = Image.open(requests.get(url, stream=True).raw)
1249
+
1250
+ # >>> inputs = processor(images=image, return_tensors="pt")
1251
+
1252
+ # >>> image_features = model.get_image_features(**inputs)
1253
+ # ```"""
1254
+ # # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1255
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1256
+ # output_hidden_states = (
1257
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1258
+ # )
1259
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1260
+
1261
+ # vision_outputs = self.vision_model(
1262
+ # pixel_values=pixel_values,
1263
+ # output_attentions=output_attentions,
1264
+ # output_hidden_states=output_hidden_states,
1265
+ # return_dict=return_dict,
1266
+ # )
1267
+
1268
+ # pooled_output = vision_outputs[1]
1269
+
1270
+ # return pooled_output
1271
+
1272
+ # @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1273
+ # @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1274
+ # def forward(
1275
+ # self,
1276
+ # input_ids: Optional[torch.LongTensor] = None,
1277
+ # pixel_values: Optional[torch.FloatTensor] = None,
1278
+ # attention_mask: Optional[torch.Tensor] = None,
1279
+ # position_ids: Optional[torch.LongTensor] = None,
1280
+ # return_loss: Optional[bool] = None,
1281
+ # output_attentions: Optional[bool] = None,
1282
+ # output_hidden_states: Optional[bool] = None,
1283
+ # return_dict: Optional[bool] = None,
1284
+ # ) -> Union[Tuple, SiglipOutput]:
1285
+ # r"""
1286
+ # Returns:
1287
+
1288
+ # Examples:
1289
+
1290
+ # ```python
1291
+ # >>> from PIL import Image
1292
+ # >>> import requests
1293
+ # >>> from transformers import AutoProcessor, SiglipModel
1294
+
1295
+ # >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1296
+ # >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1297
+
1298
+ # >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1299
+ # >>> image = Image.open(requests.get(url, stream=True).raw)
1300
+
1301
+ # >>> inputs = processor(
1302
+ # ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1303
+ # ... )
1304
+
1305
+ # >>> outputs = model(**inputs)
1306
+ # >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1307
+ # >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1308
+ # ```"""
1309
+ # # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1310
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1311
+ # output_hidden_states = (
1312
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1313
+ # )
1314
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1315
+
1316
+ # vision_outputs = self.vision_model(
1317
+ # pixel_values=pixel_values,
1318
+ # output_attentions=output_attentions,
1319
+ # output_hidden_states=output_hidden_states,
1320
+ # return_dict=return_dict,
1321
+ # )
1322
+
1323
+ # text_outputs = self.text_model(
1324
+ # input_ids=input_ids,
1325
+ # attention_mask=attention_mask,
1326
+ # position_ids=position_ids,
1327
+ # output_attentions=output_attentions,
1328
+ # output_hidden_states=output_hidden_states,
1329
+ # return_dict=return_dict,
1330
+ # )
1331
+
1332
+ # image_embeds = vision_outputs[1]
1333
+ # text_embeds = text_outputs[1]
1334
+
1335
+ # # normalized features
1336
+ # image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1337
+ # text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1338
+
1339
+ # # cosine similarity as logits
1340
+ # logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.temperature.exp() + self.bias
1341
+ # logits_per_image = logits_per_text.t()
1342
+
1343
+ # z = torch.matmul(image_embeds, text_embeds.t()) * self.temperature.exp()
1344
+
1345
+ # loss = None
1346
+ # if return_loss:
1347
+ # raise NotImplementedError("SigLIP loss to be implemented")
1348
+
1349
+ # if not return_dict:
1350
+ # output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1351
+ # return ((loss,) + output) if loss is not None else output
1352
+
1353
+ # return SiglipOutput(
1354
+ # loss=loss,
1355
+ # logits_per_image=logits_per_image,
1356
+ # logits_per_text=logits_per_text,
1357
+ # text_embeds=text_embeds,
1358
+ # image_embeds=image_embeds,
1359
+ # text_model_output=text_outputs,
1360
+ # vision_model_output=vision_outputs,
1361
+ # )