Leyo committed on
Commit
778a0be
1 Parent(s): 11ef1e5

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "_name_or_path": "HuggingFaceM4/siglip-so400m-14-384-flash-attn",
3
+ "architectures": [
4
+ "SiglipModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "HuggingFaceM4/siglip-so400m-14-384-flash-attn--configuration_siglip.SiglipConfig",
8
+ "AutoModel": "HuggingFaceM4/siglip-so400m-14-384-flash-attn--modeling_siglip.SiglipModel"
9
+ },
10
+ "initializer_factor": 1.0,
11
+ "logit_scale_init_value": 2.6592,
12
+ "model_type": "siglip",
13
+ "projection_dim": 512,
14
+ "text_config": {
15
+ "hidden_size": 1152,
16
+ "intermediate_size": 4304,
17
+ "model_type": "siglip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 27,
20
+ "vocab_size": 32000
21
+ },
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.35.2",
24
+ "vision_config": {
25
+ "hidden_size": 1152,
26
+ "image_size": 384,
27
+ "intermediate_size": 4304,
28
+ "model_type": "siglip_vision_model",
29
+ "num_attention_heads": 16,
30
+ "num_hidden_layers": 27,
31
+ "patch_size": 14
32
+ }
33
+ }
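
Because `auto_map` in the config above points at the repo's custom `configuration_siglip.py` and `modeling_siglip.py`, loading this checkpoint through the Auto classes requires `trust_remote_code=True`. A minimal loading sketch (the repo id is taken from `_name_or_path` above and may differ from the repo this commit lives in):

```python
from transformers import AutoModel

# trust_remote_code is needed because `auto_map` resolves to the custom config/model files in this upload
model = AutoModel.from_pretrained(
    "HuggingFaceM4/siglip-so400m-14-384-flash-attn",
    trust_remote_code=True,
)
```
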
configuration_siglip.py ADDED
@@ -0,0 +1,448 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Siglip model configuration"""
16
+
17
+ import os
18
+ from collections import OrderedDict
19
+ from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from transformers.processing_utils import ProcessorMixin
24
+ from transformers.utils import TensorType
25
+
26
+ from transformers.configuration_utils import PretrainedConfig
27
+ from transformers.onnx import OnnxConfig
28
+ from transformers.utils import logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
34
+ "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
35
+ }
36
+
37
+
38
+ class SiglipTextConfig(PretrainedConfig):
39
+ r"""
40
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
41
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
42
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
43
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
44
+
45
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
46
+ documentation from [`PretrainedConfig`] for more information.
47
+
48
+ Args:
49
+ vocab_size (`int`, *optional*, defaults to 49408):
50
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
51
+ the `inputs_ids` passed when calling [`SiglipModel`].
52
+ hidden_size (`int`, *optional*, defaults to 512):
53
+ Dimensionality of the encoder layers and the pooler layer.
54
+ intermediate_size (`int`, *optional*, defaults to 2048):
55
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
56
+ num_hidden_layers (`int`, *optional*, defaults to 12):
57
+ Number of hidden layers in the Transformer encoder.
58
+ num_attention_heads (`int`, *optional*, defaults to 8):
59
+ Number of attention heads for each attention layer in the Transformer encoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 64):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
64
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
65
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
66
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
67
+ The epsilon used by the layer normalization layers.
68
+ attention_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for the attention probabilities.
70
+ initializer_range (`float`, *optional*, defaults to 0.02):
71
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
72
+ initializer_factor (`float`, *optional*, defaults to 1):
73
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
74
+ testing).
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
80
+
81
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
82
+ >>> configuration = SiglipTextConfig()
83
+
84
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
85
+ >>> model = SiglipTextModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+ model_type = "siglip_text_model"
91
+
92
+ def __init__(
93
+ self,
94
+ vocab_size=49408,
95
+ hidden_size=512,
96
+ intermediate_size=2048,
97
+ projection_dim=512,
98
+ num_hidden_layers=12,
99
+ num_attention_heads=8,
100
+ max_position_embeddings=64,
101
+ hidden_act="gelu_pytorch_tanh",
102
+ layer_norm_eps=1e-6,
103
+ attention_dropout=0.0,
104
+ initializer_range=0.02,
105
+ initializer_factor=1.0,
106
+ # This differs from `CLIPTokenizer`'s default and from openai/clip
107
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
108
+ pad_token_id=1,
109
+ bos_token_id=49406,
110
+ eos_token_id=49407,
111
+ _flash_attn_2_enabled=True,
112
+ **kwargs,
113
+ ):
114
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
115
+
116
+ self.vocab_size = vocab_size
117
+ self.hidden_size = hidden_size
118
+ self.intermediate_size = intermediate_size
119
+ self.projection_dim = projection_dim
120
+ self.num_hidden_layers = num_hidden_layers
121
+ self.num_attention_heads = num_attention_heads
122
+ self.max_position_embeddings = max_position_embeddings
123
+ self.layer_norm_eps = layer_norm_eps
124
+ self.hidden_act = hidden_act
125
+ self.initializer_range = initializer_range
126
+ self.initializer_factor = initializer_factor
127
+ self.attention_dropout = attention_dropout
128
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
129
+
130
+ @classmethod
131
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
132
+ cls._set_token_in_kwargs(kwargs)
133
+
134
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
135
+
136
+ # get the text config dict if we are loading from SiglipConfig
137
+ if config_dict.get("model_type") == "siglip":
138
+ config_dict = config_dict["text_config"]
139
+
140
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
141
+ logger.warning(
142
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
143
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
144
+ )
145
+
146
+ return cls.from_dict(config_dict, **kwargs)
147
+
148
+
149
+ class SiglipVisionConfig(PretrainedConfig):
150
+ r"""
151
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
152
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
153
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
154
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
155
+
156
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
157
+ documentation from [`PretrainedConfig`] for more information.
158
+
159
+ Args:
160
+ hidden_size (`int`, *optional*, defaults to 768):
161
+ Dimensionality of the encoder layers and the pooler layer.
162
+ intermediate_size (`int`, *optional*, defaults to 3072):
163
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
164
+ num_hidden_layers (`int`, *optional*, defaults to 12):
165
+ Number of hidden layers in the Transformer encoder.
166
+ num_attention_heads (`int`, *optional*, defaults to 12):
167
+ Number of attention heads for each attention layer in the Transformer encoder.
168
+ image_size (`int`, *optional*, defaults to 224):
169
+ The size (resolution) of each image.
170
+ patch_size (`int`, *optional*, defaults to 32):
171
+ The size (resolution) of each patch.
172
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
173
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
174
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
175
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
176
+ The epsilon used by the layer normalization layers.
177
+ attention_dropout (`float`, *optional*, defaults to 0.0):
178
+ The dropout ratio for the attention probabilities.
179
+ initializer_range (`float`, *optional*, defaults to 0.02):
180
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
181
+ initializer_factor (`float`, *optional*, defaults to 1):
182
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
183
+ testing).
184
+
185
+ Example:
186
+
187
+ ```python
188
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
189
+
190
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
191
+ >>> configuration = SiglipVisionConfig()
192
+
193
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
194
+ >>> model = SiglipVisionModel(configuration)
195
+
196
+ >>> # Accessing the model configuration
197
+ >>> configuration = model.config
198
+ ```"""
199
+
200
+ model_type = "siglip_vision_model"
201
+
202
+ def __init__(
203
+ self,
204
+ hidden_size=768,
205
+ intermediate_size=3072,
206
+ projection_dim=512,
207
+ num_hidden_layers=12,
208
+ num_attention_heads=12,
209
+ num_channels=3,
210
+ image_size=224,
211
+ patch_size=32,
212
+ hidden_act="gelu_pytorch_tanh",
213
+ layer_norm_eps=1e-6,
214
+ attention_dropout=0.0,
215
+ initializer_range=0.02,
216
+ initializer_factor=1.0,
217
+ _flash_attn_2_enabled=True,
218
+ **kwargs,
219
+ ):
220
+ super().__init__(**kwargs)
221
+
222
+ self.hidden_size = hidden_size
223
+ self.intermediate_size = intermediate_size
224
+ self.projection_dim = projection_dim
225
+ self.num_hidden_layers = num_hidden_layers
226
+ self.num_attention_heads = num_attention_heads
227
+ self.num_channels = num_channels
228
+ self.patch_size = patch_size
229
+ self.image_size = image_size
230
+ self.initializer_range = initializer_range
231
+ self.initializer_factor = initializer_factor
232
+ self.attention_dropout = attention_dropout
233
+ self.layer_norm_eps = layer_norm_eps
234
+ self.hidden_act = hidden_act
235
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
236
+
237
+ @classmethod
238
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
239
+ cls._set_token_in_kwargs(kwargs)
240
+
241
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
242
+
243
+ # get the vision config dict if we are loading from SiglipConfig
244
+ if config_dict.get("model_type") == "siglip":
245
+ config_dict = config_dict["vision_config"]
246
+
247
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
248
+ logger.warning(
249
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
250
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
251
+ )
252
+
253
+ return cls.from_dict(config_dict, **kwargs)
254
+
255
+
256
+ class SiglipConfig(PretrainedConfig):
257
+ r"""
258
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
259
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
260
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
261
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
262
+
263
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
264
+ documentation from [`PretrainedConfig`] for more information.
265
+
266
+ Args:
267
+ text_config (`dict`, *optional*):
268
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
269
+ vision_config (`dict`, *optional*):
270
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
271
+ projection_dim (`int`, *optional*, defaults to 512):
272
+ Dimensionality of text and vision projection layers.
273
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
274
+ The initial value of the *logit_scale* parameter. Default is used as per the original Siglip implementation.
275
+ kwargs (*optional*):
276
+ Dictionary of keyword arguments.
277
+
278
+ Example:
279
+
280
+ ```python
281
+ >>> from transformers import SiglipConfig, SiglipModel
282
+
283
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
284
+ >>> configuration = SiglipConfig()
285
+
286
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
287
+ >>> model = SiglipModel(configuration)
288
+
289
+ >>> # Accessing the model configuration
290
+ >>> configuration = model.config
291
+
292
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
293
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
294
+
295
+ >>> # Initializing a SiglipText and SiglipVision configuration
296
+ >>> config_text = SiglipTextConfig()
297
+ >>> config_vision = SiglipVisionConfig()
298
+
299
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
300
+ ```"""
301
+
302
+ model_type = "siglip"
303
+
304
+ def __init__(
305
+ self, text_config=None, vision_config=None, projection_dim=512, logit_scale_init_value=2.6592, **kwargs
306
+ ):
307
+ # If `_config_dict` exist, we use them for the backward compatibility.
308
+ # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
309
+ # of confusion!).
310
+ text_config_dict = kwargs.pop("text_config_dict", None)
311
+ vision_config_dict = kwargs.pop("vision_config_dict", None)
312
+
313
+ super().__init__(**kwargs)
314
+
315
+ # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
316
+ # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
317
+ # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
318
+ if text_config_dict is not None:
319
+ if text_config is None:
320
+ text_config = {}
321
+
322
+ # This is the complete result when using `text_config_dict`.
323
+ _text_config_dict = SiglipTextConfig(**text_config_dict).to_dict()
324
+
325
+ # Give a warning if the values exist in both `_text_config_dict` and `text_config` but are different.
326
+ for key, value in _text_config_dict.items():
327
+ if key in text_config and value != text_config[key] and key not in ["transformers_version"]:
328
+ # If specified in `text_config_dict`
329
+ if key in text_config_dict:
330
+ message = (
331
+ f"`{key}` is found in both `text_config_dict` and `text_config` but with different values."
332
+ f' The value `text_config_dict["{key}"]` will be used instead.'
333
+ )
334
+ # If inferred from default argument values (just to be super careful)
335
+ else:
336
+ message = (
337
+ "`text_config_dict` is provided which will be used to initialize `SiglipTextConfig`. The "
338
+ f'value `text_config["{key}"]` will be overridden.'
339
+ )
340
+ logger.warning(message)
341
+
342
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
343
+ text_config.update(_text_config_dict)
344
+
345
+ if vision_config_dict is not None:
346
+ if vision_config is None:
347
+ vision_config = {}
348
+
349
+ # This is the complete result when using `vision_config_dict`.
350
+ _vision_config_dict = SiglipVisionConfig(**vision_config_dict).to_dict()
351
+ # convert keys to string instead of integer
352
+ if "id2label" in _vision_config_dict:
353
+ _vision_config_dict["id2label"] = {
354
+ str(key): value for key, value in _vision_config_dict["id2label"].items()
355
+ }
356
+
357
+ # Give a warning if the values exist in both `_vision_config_dict` and `vision_config` but are different.
358
+ for key, value in _vision_config_dict.items():
359
+ if key in vision_config and value != vision_config[key] and key not in ["transformers_version"]:
360
+ # If specified in `vision_config_dict`
361
+ if key in vision_config_dict:
362
+ message = (
363
+ f"`{key}` is found in both `vision_config_dict` and `vision_config` but with different "
364
+ f'values. The value `vision_config_dict["{key}"]` will be used instead.'
365
+ )
366
+ # If inferred from default argument values (just to be super careful)
367
+ else:
368
+ message = (
369
+ "`vision_config_dict` is provided which will be used to initialize `SiglipVisionConfig`. "
370
+ f'The value `vision_config["{key}"]` will be overridden.'
371
+ )
372
+ logger.warning(message)
373
+
374
+ # Update all values in `vision_config` with the ones in `_vision_config_dict`.
375
+ vision_config.update(_vision_config_dict)
376
+
377
+ if text_config is None:
378
+ text_config = {}
379
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
380
+
381
+ if vision_config is None:
382
+ vision_config = {}
383
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
384
+
385
+ self.text_config = SiglipTextConfig(**text_config)
386
+ self.vision_config = SiglipVisionConfig(**vision_config)
387
+
388
+ self.projection_dim = projection_dim
389
+ self.logit_scale_init_value = logit_scale_init_value
390
+ self.initializer_factor = 1.0
391
+
392
+ @classmethod
393
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
394
+ r"""
395
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
396
+ model configuration.
397
+
398
+ Returns:
399
+ [`SiglipConfig`]: An instance of a configuration object
400
+ """
401
+
402
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
403
+
404
+
405
+ class SiglipOnnxConfig(OnnxConfig):
406
+ @property
407
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
408
+ return OrderedDict(
409
+ [
410
+ ("input_ids", {0: "batch", 1: "sequence"}),
411
+ ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
412
+ ("attention_mask", {0: "batch", 1: "sequence"}),
413
+ ]
414
+ )
415
+
416
+ @property
417
+ def outputs(self) -> Mapping[str, Mapping[int, str]]:
418
+ return OrderedDict(
419
+ [
420
+ ("logits_per_image", {0: "batch"}),
421
+ ("logits_per_text", {0: "batch"}),
422
+ ("text_embeds", {0: "batch"}),
423
+ ("image_embeds", {0: "batch"}),
424
+ ]
425
+ )
426
+
427
+ @property
428
+ def atol_for_validation(self) -> float:
429
+ return 1e-4
430
+
431
+ def generate_dummy_inputs(
432
+ self,
433
+ processor: "ProcessorMixin",
434
+ batch_size: int = -1,
435
+ seq_length: int = -1,
436
+ framework: Optional["TensorType"] = None,
437
+ ) -> Mapping[str, Any]:
438
+ text_input_dict = super().generate_dummy_inputs(
439
+ processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
440
+ )
441
+ image_input_dict = super().generate_dummy_inputs(
442
+ processor.image_processor, batch_size=batch_size, framework=framework
443
+ )
444
+ return {**text_input_dict, **image_input_dict}
445
+
446
+ @property
447
+ def default_onnx_opset(self) -> int:
448
+ return 14
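
The values in config.json above can be reproduced from the classes in this file. A short sketch, assuming `configuration_siglip.py` is importable from a local checkout of the repo; unspecified fields fall back to the defaults defined above:

```python
from configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Mirror the text_config / vision_config blocks of config.json
text_config = SiglipTextConfig(
    hidden_size=1152, intermediate_size=4304, num_attention_heads=16, num_hidden_layers=27, vocab_size=32000
)
vision_config = SiglipVisionConfig(
    hidden_size=1152, intermediate_size=4304, num_attention_heads=16, num_hidden_layers=27,
    image_size=384, patch_size=14
)
config = SiglipConfig.from_text_vision_configs(text_config, vision_config, logit_scale_init_value=2.6592)
```
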
image_processing_siglip.py ADDED
@@ -0,0 +1,229 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for SigLIP."""
16
+
17
+ from typing import Dict, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from transformers.image_transforms import (
23
+ rescale,
24
+ resize,
25
+ to_channel_dimension_format,
26
+ )
27
+ from transformers.image_utils import (
28
+ ChannelDimension,
29
+ ImageInput,
30
+ PILImageResampling,
31
+ infer_channel_dimension_format,
32
+ is_scaled_image,
33
+ make_list_of_images,
34
+ to_numpy_array,
35
+ valid_images,
36
+ )
37
+ from transformers.utils import TensorType, is_vision_available, logging
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ if is_vision_available():
44
+ import PIL
45
+
46
+
47
+ class SiglipImageProcessor(BaseImageProcessor):
48
+ r"""
49
+ Constructs a SigLIP image processor.
50
+
51
+ Args:
52
+ do_resize (`bool`, *optional*, defaults to `True`):
53
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
54
+ `do_resize` in the `preprocess` method.
55
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
56
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
57
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
58
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
59
+ do_rescale (`bool`, *optional*, defaults to `True`):
60
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
61
+ the `preprocess` method.
62
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
63
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
64
+ method.
65
+ """
66
+
67
+ model_input_names = ["pixel_values"]
68
+
69
+ def __init__(
70
+ self,
71
+ do_resize: bool = True,
72
+ size: Dict[str, int] = None,
73
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
74
+ do_rescale: bool = True,
75
+ rescale_factor: Union[int, float] = 1 / 255,
76
+ **kwargs,
77
+ ) -> None:
78
+ super().__init__(**kwargs)
79
+ size = size if size is not None else {"height": 224, "width": 224}
80
+ size = get_size_dict(size, default_to_square=False)
81
+
82
+ self.do_resize = do_resize
83
+ self.size = size
84
+ self.resample = resample
85
+ self.do_rescale = do_rescale
86
+ self.rescale_factor = rescale_factor
87
+
88
+ def rescale(
89
+ self,
90
+ image: np.ndarray,
91
+ rescale_factor: float,
92
+ data_format: Optional[Union[str, ChannelDimension]] = None,
93
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
94
+ **kwargs,
95
+ ) -> np.ndarray:
96
+ """
97
+ Rescale an image by a scale factor. image = image * scale, after which image = image * 2 - 1.
98
+
99
+ Args:
100
+ image (`np.ndarray`):
101
+ Image to rescale.
102
+ scale (`float`):
103
+ The scaling factor to rescale pixel values by.
104
+ data_format (`str` or `ChannelDimension`, *optional*):
105
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
106
+ image is used. Can be one of:
107
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
108
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
109
+ input_data_format (`ChannelDimension` or `str`, *optional*):
110
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
111
+ from the input image. Can be one of:
112
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
113
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
114
+
115
+ Returns:
116
+ `np.ndarray`: The rescaled image.
117
+ """
118
+ # first, rescale to 0->1
119
+ rescaled_image = rescale(
120
+ image, scale=rescale_factor, data_format=data_format, input_data_format=input_data_format, **kwargs
121
+ )
122
+
123
+ # next, rescale to -1->1
124
+ rescaled_image = 2 * rescaled_image - 1
125
+
126
+ return rescaled_image
127
+
128
+ def preprocess(
129
+ self,
130
+ images: ImageInput,
131
+ do_resize: bool = None,
132
+ size: Dict[str, int] = None,
133
+ resample: PILImageResampling = None,
134
+ do_rescale: bool = None,
135
+ rescale_factor: float = None,
136
+ return_tensors: Optional[Union[str, TensorType]] = None,
137
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
138
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
139
+ **kwargs,
140
+ ) -> PIL.Image.Image:
141
+ """
142
+ Preprocess an image or batch of images.
143
+
144
+ Args:
145
+ images (`ImageInput`):
146
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
147
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
148
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
149
+ Whether to resize the image.
150
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
151
+ Size of the image after resizing.
152
+ resample (`int`, *optional*, defaults to `self.resample`):
153
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
154
+ has an effect if `do_resize` is set to `True`.
155
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
156
+ Whether to rescale the image.
157
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
158
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
159
+ return_tensors (`str` or `TensorType`, *optional*):
160
+ The type of tensors to return. Can be one of:
161
+ - Unset: Return a list of `np.ndarray`.
162
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
163
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
164
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
165
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
166
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
167
+ The channel dimension format for the output image. Can be one of:
168
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
169
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
170
+ - Unset: Use the channel dimension format of the input image.
171
+ input_data_format (`ChannelDimension` or `str`, *optional*):
172
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
173
+ from the input image. Can be one of:
174
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
175
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
176
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
177
+ """
178
+ do_resize = do_resize if do_resize is not None else self.do_resize
179
+ size = size if size is not None else self.size
180
+ size = get_size_dict(size, param_name="size", default_to_square=False)
181
+ resample = resample if resample is not None else self.resample
182
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
183
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
184
+
185
+ images = make_list_of_images(images)
186
+
187
+ if not valid_images(images):
188
+ raise ValueError(
189
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
190
+ "torch.Tensor, tf.Tensor or jax.ndarray."
191
+ )
192
+
193
+ if do_resize and size is None:
194
+ raise ValueError("Size must be specified if do_resize is True.")
195
+
196
+ if do_rescale and rescale_factor is None:
197
+ raise ValueError("Rescale factor must be specified if do_rescale is True.")
198
+
199
+ # All transformations expect numpy arrays.
200
+ images = [to_numpy_array(image) for image in images]
201
+
202
+ if is_scaled_image(images[0]) and do_rescale:
203
+ logger.warning_once(
204
+ "It looks like you are trying to rescale already rescaled images. If the input"
205
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
206
+ )
207
+
208
+ if input_data_format is None:
209
+ # We assume that all images have the same channel dimension format.
210
+ input_data_format = infer_channel_dimension_format(images[0])
211
+
212
+ if do_resize:
213
+ images = [
214
+ resize(image=image, size=(size["height"], size["width"]), resample=resample, input_data_format=input_data_format)
215
+ for image in images
216
+ ]
217
+
218
+ if do_rescale:
219
+ images = [
220
+ self.rescale(image=image, rescale_factor=rescale_factor, input_data_format=input_data_format)
221
+ for image in images
222
+ ]
223
+
224
+ images = [
225
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
226
+ ]
227
+
228
+ data = {"pixel_values": images}
229
+ return BatchFeature(data=data, tensor_type=return_tensors)
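
A quick usage sketch of the processor defined above, assuming the file is importable from a local checkout of the repo; the random array is only a stand-in for a real image, and the 384x384 size matches `vision_config` in config.json:

```python
import numpy as np
from image_processing_siglip import SiglipImageProcessor

processor = SiglipImageProcessor(size={"height": 384, "width": 384})
image = np.random.randint(0, 256, (512, 640, 3), dtype=np.uint8)  # HWC image with 0-255 pixel values
batch = processor.preprocess(image, return_tensors="pt")
# Pixels are resized to 384x384 and mapped to roughly [-1, 1] by the custom `rescale` above
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 384, 384])
```
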
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90e016fcc4615f3c5d998008727291dbe0c00c1fe8e1089cdb2c04565286422c
3
+ size 3511961264
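
The entry above is a Git LFS pointer, not the weights themselves; the safetensors payload (~3.5 GB of float32 tensors, per `torch_dtype` in config.json) is resolved on download. A hedged sketch for inspecting it directly, assuming the repo id from `_name_or_path` in config.json:

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

path = hf_hub_download("HuggingFaceM4/siglip-so400m-14-384-flash-attn", "model.safetensors")
with safe_open(path, framework="pt") as f:
    print(len(f.keys()), "tensors")  # tensor names and shapes can be inspected without loading everything
```
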
modeling_siglip.py ADDED
@@ -0,0 +1,1367 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ is_flash_attn_2_available,
33
+ logging,
34
+ replace_return_docstrings,
35
+ )
36
+
37
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
43
+
44
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ "google/siglip-base-patch16-224",
46
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
47
+ ]
48
+
49
+ if is_flash_attn_2_available():
50
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
51
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
+
53
+
54
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55
+ def _get_unpad_data(attention_mask):
56
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
57
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
58
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
59
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
60
+ return (
61
+ indices,
62
+ cu_seqlens,
63
+ max_seqlen_in_batch,
64
+ )
65
+
66
+
67
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
68
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
69
+ """
70
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
71
+ """
72
+ bsz, src_len = mask.size()
73
+ tgt_len = tgt_len if tgt_len is not None else src_len
74
+
75
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
76
+
77
+ inverted_mask = 1.0 - expanded_mask
78
+
79
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
80
+
81
+
82
+ # contrastive loss function, adapted from
83
+ # https://sachinruk.github.io/blog/2021-03-07-siglip.html
84
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
85
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
86
+
87
+
88
+ # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->siglip
89
+ def siglip_loss(similarity: torch.Tensor) -> torch.Tensor:
90
+ caption_loss = contrastive_loss(similarity)
91
+ image_loss = contrastive_loss(similarity.t())
92
+ return (caption_loss + image_loss) / 2.0
93
+
94
+
95
+ @dataclass
96
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
97
+ class SiglipVisionModelOutput(ModelOutput):
98
+ """
99
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
100
+
101
+ Args:
102
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
103
+ The image embeddings obtained by applying the projection layer to the pooler_output.
104
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
105
+ Sequence of hidden-states at the output of the last layer of the model.
106
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
107
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
108
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
109
+
110
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
111
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
112
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
113
+ sequence_length)`.
114
+
115
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
116
+ heads.
117
+ """
118
+
119
+ image_embeds: Optional[torch.FloatTensor] = None
120
+ last_hidden_state: torch.FloatTensor = None
121
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
122
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
123
+
124
+
125
+ @dataclass
126
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
127
+ class SiglipTextModelOutput(ModelOutput):
128
+ """
129
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
130
+
131
+ Args:
132
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
133
+ The text embeddings obtained by applying the projection layer to the pooler_output.
134
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
135
+ Sequence of hidden-states at the output of the last layer of the model.
136
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
137
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
138
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
139
+
140
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
141
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
142
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
143
+ sequence_length)`.
144
+
145
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
146
+ heads.
147
+ """
148
+
149
+ text_embeds: Optional[torch.FloatTensor] = None
150
+ last_hidden_state: torch.FloatTensor = None
151
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
152
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
153
+
154
+
155
+ @dataclass
156
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
157
+ class SiglipOutput(ModelOutput):
158
+ """
159
+ Args:
160
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
161
+ Contrastive loss for image-text similarity.
162
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
163
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
164
+ similarity scores.
165
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
166
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
167
+ similarity scores.
168
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
169
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
170
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
171
+ The image embeddings obtained by applying the projection layer to the pooled output of
172
+ [`SiglipVisionModel`].
173
+ text_model_output(`BaseModelOutputWithPooling`):
174
+ The output of the [`SiglipTextModel`].
175
+ vision_model_output(`BaseModelOutputWithPooling`):
176
+ The output of the [`SiglipVisionModel`].
177
+ """
178
+
179
+ loss: Optional[torch.FloatTensor] = None
180
+ logits_per_image: torch.FloatTensor = None
181
+ logits_per_text: torch.FloatTensor = None
182
+ text_embeds: torch.FloatTensor = None
183
+ image_embeds: torch.FloatTensor = None
184
+ text_model_output: BaseModelOutputWithPooling = None
185
+ vision_model_output: BaseModelOutputWithPooling = None
186
+
187
+ def to_tuple(self) -> Tuple[Any]:
188
+ return tuple(
189
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
190
+ for k in self.keys()
191
+ )
192
+
193
+
194
+ class SiglipVisionEmbeddings(nn.Module):
195
+ def __init__(self, config: SiglipVisionConfig):
196
+ super().__init__()
197
+ self.config = config
198
+ self.embed_dim = config.hidden_size
199
+ self.image_size = config.image_size
200
+ self.patch_size = config.patch_size
201
+
202
+ self.patch_embedding = nn.Conv2d(
203
+ in_channels=config.num_channels,
204
+ out_channels=self.embed_dim,
205
+ kernel_size=self.patch_size,
206
+ stride=self.patch_size,
207
+ padding="valid",
208
+ )
209
+
210
+ self.num_patches = (self.image_size // self.patch_size) ** 2
211
+ self.num_positions = self.num_patches
212
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
213
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
214
+
215
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
216
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
217
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
218
+
219
+ embeddings = embeddings + self.position_embedding(self.position_ids)
220
+ return embeddings
221
+
222
+
223
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
224
+ class SiglipTextEmbeddings(nn.Module):
225
+ def __init__(self, config: SiglipTextConfig):
226
+ super().__init__()
227
+ embed_dim = config.hidden_size
228
+
229
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
230
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
231
+
232
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
233
+ self.register_buffer(
234
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
235
+ )
236
+
237
+ def forward(
238
+ self,
239
+ input_ids: Optional[torch.LongTensor] = None,
240
+ position_ids: Optional[torch.LongTensor] = None,
241
+ inputs_embeds: Optional[torch.FloatTensor] = None,
242
+ ) -> torch.Tensor:
243
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
244
+
245
+ if position_ids is None:
246
+ position_ids = self.position_ids[:, :seq_length]
247
+
248
+ if inputs_embeds is None:
249
+ inputs_embeds = self.token_embedding(input_ids)
250
+
251
+ position_embeddings = self.position_embedding(position_ids)
252
+ embeddings = inputs_embeds + position_embeddings
253
+
254
+ return embeddings
255
+
256
+
257
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
258
+ class SiglipAttention(nn.Module):
259
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
260
+
261
+ def __init__(self, config):
262
+ super().__init__()
263
+ self.config = config
264
+ self.embed_dim = config.hidden_size
265
+ self.num_heads = config.num_attention_heads
266
+ self.head_dim = self.embed_dim // self.num_heads
267
+ if self.head_dim * self.num_heads != self.embed_dim:
268
+ raise ValueError(
269
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
270
+ f" {self.num_heads})."
271
+ )
272
+ self.scale = self.head_dim**-0.5
273
+ self.dropout = config.attention_dropout
274
+
275
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
276
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
+
280
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
281
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
282
+
283
+ def forward(
284
+ self,
285
+ hidden_states: torch.Tensor,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ causal_attention_mask: Optional[torch.Tensor] = None,
288
+ output_attentions: Optional[bool] = False,
289
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
+ """Input shape: Batch x Time x Channel"""
291
+
292
+ bsz, tgt_len, embed_dim = hidden_states.size()
293
+
294
+ # get query proj
295
+ query_states = self.q_proj(hidden_states) * self.scale
296
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
297
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
298
+
299
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
300
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
301
+ key_states = key_states.view(*proj_shape)
302
+ value_states = value_states.view(*proj_shape)
303
+
304
+ src_len = key_states.size(1)
305
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
306
+
307
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
308
+ raise ValueError(
309
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
310
+ f" {attn_weights.size()}"
311
+ )
312
+
313
+ # apply the causal_attention_mask first
314
+ if causal_attention_mask is not None:
315
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
316
+ raise ValueError(
317
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
318
+ f" {causal_attention_mask.size()}"
319
+ )
320
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
321
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
322
+
323
+ if attention_mask is not None:
324
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
325
+ raise ValueError(
326
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
327
+ )
328
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
329
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
330
+
331
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
332
+
333
+ if output_attentions:
334
+ # this operation is a bit awkward, but it's required to
335
+ # make sure that attn_weights keeps its gradient.
336
+ # In order to do so, attn_weights have to be reshaped
337
+ # twice and have to be reused in the following
338
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
339
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
340
+ else:
341
+ attn_weights_reshaped = None
342
+
343
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
344
+
345
+ attn_output = torch.bmm(attn_probs, value_states)
346
+
347
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
348
+ raise ValueError(
349
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
350
+ f" {attn_output.size()}"
351
+ )
352
+
353
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
354
+ attn_output = attn_output.transpose(1, 2)
355
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
356
+
357
+ attn_output = self.out_proj(attn_output)
358
+
359
+ return attn_output, attn_weights_reshaped
360
+
361
+
362
+ class SiglipFlashAttention2(SiglipAttention):
363
+ """
364
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
365
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
366
+ flash attention and deal with padding tokens in case the input contains any of them.
367
+ """
368
+
369
+ def __init__(self, *args, **kwargs):
370
+ super().__init__(*args, **kwargs)
371
+ self.is_causal = False # Hack to make sure we don't use a causal mask
372
+
373
+ def forward(
374
+ self,
375
+ hidden_states: torch.Tensor,
376
+ attention_mask: Optional[torch.LongTensor] = None,
377
+ position_ids: Optional[torch.LongTensor] = None,
378
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
379
+ output_attentions: bool = False,
380
+ use_cache: bool = False,
381
+ **kwargs,
382
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
383
+ output_attentions = False
384
+
385
+ bsz, q_len, _ = hidden_states.size()
386
+
387
+ query_states = self.q_proj(hidden_states)
388
+ key_states = self.k_proj(hidden_states)
389
+ value_states = self.v_proj(hidden_states)
390
+
391
+ # Flash attention requires the input to have the shape
392
+ # batch_size x seq_length x head_dim x hidden_dim
393
+ # therefore we just need to keep the original shape
394
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
395
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
396
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
397
+
398
+ kv_seq_len = key_states.shape[-2]
399
+ if past_key_value is not None:
400
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
401
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
402
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
403
+
404
+ # if past_key_value is not None:
405
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
406
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
407
+
408
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
409
+ # to be able to avoid many of these transpose/reshape/view.
410
+ query_states = query_states.transpose(1, 2)
411
+ key_states = key_states.transpose(1, 2)
412
+ value_states = value_states.transpose(1, 2)
413
+
414
+ dropout_rate = self.dropout if self.training else 0.0
415
+
416
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
417
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
418
+ # cast them back in the correct dtype just to be sure everything works as expected.
419
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
420
+ # in fp32. (LlamaRMSNorm handles it correctly)
421
+
422
+ input_dtype = query_states.dtype
423
+ if input_dtype == torch.float32:
424
+ if torch.is_autocast_enabled():
425
+ target_dtype = torch.get_autocast_gpu_dtype()
426
+ # Handle the case where the model is quantized
427
+ elif hasattr(self.config, "_pre_quantization_dtype"):
428
+ target_dtype = self.config._pre_quantization_dtype
429
+ else:
430
+ target_dtype = self.q_proj.weight.dtype
431
+
432
+ logger.warning_once(
433
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
434
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
+ f" {target_dtype}."
436
+ )
437
+
438
+ query_states = query_states.to(target_dtype)
439
+ key_states = key_states.to(target_dtype)
440
+ value_states = value_states.to(target_dtype)
441
+
442
+ attn_output = self._flash_attention_forward(
443
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
444
+ )
445
+
446
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
447
+ attn_output = self.out_proj(attn_output)
448
+
449
+ if not output_attentions:
450
+ attn_weights = None
451
+
452
+ return attn_output, attn_weights
453
+
454
+ def _flash_attention_forward(
455
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
456
+ ):
457
+ """
458
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
+ the input is first unpadded, the attention scores are computed on the packed sequences, and the output is
+ padded back to the original shape.
460
+
461
+ Args:
462
+ query_states (`torch.Tensor`):
463
+ Input query states to be passed to Flash Attention API
464
+ key_states (`torch.Tensor`):
465
+ Input key states to be passed to Flash Attention API
466
+ value_states (`torch.Tensor`):
467
+ Input value states to be passed to Flash Attention API
468
+ attention_mask (`torch.Tensor`):
469
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
470
+ position of padding tokens and 1 for the position of non-padding tokens.
471
+ dropout (`float`, *optional*):
+ Attention dropout probability.
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Defaults to `1 / sqrt(head_dim)`.
475
+ """
476
+
477
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
478
+ causal = self.is_causal and query_length != 1
479
+
480
+ # Contains at least one padding token in the sequence
481
+ if attention_mask is not None:
482
+ batch_size = query_states.shape[0]
483
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
484
+ query_states, key_states, value_states, attention_mask, query_length
485
+ )
486
+
487
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
488
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
489
+
490
+ attn_output_unpad = flash_attn_varlen_func(
491
+ query_states,
492
+ key_states,
493
+ value_states,
494
+ cu_seqlens_q=cu_seqlens_q,
495
+ cu_seqlens_k=cu_seqlens_k,
496
+ max_seqlen_q=max_seqlen_in_batch_q,
497
+ max_seqlen_k=max_seqlen_in_batch_k,
498
+ dropout_p=dropout,
499
+ softmax_scale=softmax_scale,
500
+ causal=causal,
501
+ )
502
+
503
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
504
+ else:
505
+ attn_output = flash_attn_func(
506
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
507
+ )
508
+
509
+ return attn_output
510
+
511
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
512
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
513
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
514
+
515
+ key_layer = index_first_axis(
516
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
517
+ )
518
+ value_layer = index_first_axis(
519
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
520
+ )
521
+ if query_length == kv_seq_len:
522
+ query_layer = index_first_axis(
523
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
524
+ )
525
+ cu_seqlens_q = cu_seqlens_k
526
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
527
+ indices_q = indices_k
528
+ elif query_length == 1:
529
+ max_seqlen_in_batch_q = 1
530
+ cu_seqlens_q = torch.arange(
531
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
532
+ ) # There is a memcpy here, that is very bad.
533
+ indices_q = cu_seqlens_q[:-1]
534
+ query_layer = query_layer.squeeze(1)
535
+ else:
536
+ # The -q_len: slice assumes left padding.
537
+ attention_mask = attention_mask[:, -query_length:]
538
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
539
+
540
+ return (
541
+ query_layer,
542
+ key_layer,
543
+ value_layer,
544
+ indices_q,
545
+ (cu_seqlens_q, cu_seqlens_k),
546
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
547
+ )
548
+
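Note for readers of `_upad_input` above: it relies on a `_get_unpad_data` helper defined earlier in this file (outside this hunk). The sketch below is an illustration of what such a helper typically computes from a `(batch_size, seq_len)` padding mask; the function name and exact signature here are assumptions, not necessarily the file's own definition.

```python
import torch
import torch.nn.functional as F


def _get_unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask: (batch_size, seq_len) with 1 for real tokens and 0 for padding.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)                # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()     # flat positions of real tokens
    max_seqlen_in_batch = int(seqlens_in_batch.max())                               # longest unpadded sequence
    # Cumulative sequence lengths, prefixed with 0, as expected by flash_attn_varlen_func.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```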
549
+
550
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
551
+ class SiglipMLP(nn.Module):
552
+ def __init__(self, config):
553
+ super().__init__()
554
+ self.config = config
555
+ self.activation_fn = ACT2FN[config.hidden_act]
556
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
557
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
558
+
559
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
560
+ hidden_states = self.fc1(hidden_states)
561
+ hidden_states = self.activation_fn(hidden_states)
562
+ hidden_states = self.fc2(hidden_states)
563
+ return hidden_states
564
+
565
+
566
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
567
+ class SiglipEncoderLayer(nn.Module):
568
+ def __init__(self, config: SiglipConfig):
569
+ super().__init__()
570
+ self.embed_dim = config.hidden_size
571
+ self.self_attn = (
572
+ SiglipAttention(config)
573
+ if not getattr(config, "_flash_attn_2_enabled", False)
574
+ else SiglipFlashAttention2(config)
575
+ )
576
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
577
+ self.mlp = SiglipMLP(config)
578
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
579
+
580
+ def forward(
581
+ self,
582
+ hidden_states: torch.Tensor,
583
+ attention_mask: torch.Tensor,
584
+ causal_attention_mask: torch.Tensor,
585
+ output_attentions: Optional[bool] = False,
586
+ ) -> Tuple[torch.FloatTensor]:
587
+ """
588
+ Args:
589
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
590
+ attention_mask (`torch.FloatTensor`): attention mask of size
591
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
593
+ output_attentions (`bool`, *optional*):
594
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
595
+ returned tensors for more detail.
596
+ """
597
+ residual = hidden_states
598
+
599
+ hidden_states = self.layer_norm1(hidden_states)
600
+ hidden_states, attn_weights = self.self_attn(
601
+ hidden_states=hidden_states,
602
+ attention_mask=attention_mask,
603
+ causal_attention_mask=causal_attention_mask,
604
+ output_attentions=output_attentions,
605
+ )
606
+ hidden_states = residual + hidden_states
607
+
608
+ residual = hidden_states
609
+ hidden_states = self.layer_norm2(hidden_states)
610
+ hidden_states = self.mlp(hidden_states)
611
+ hidden_states = residual + hidden_states
612
+
613
+ outputs = (hidden_states,)
614
+
615
+ if output_attentions:
616
+ outputs += (attn_weights,)
617
+
618
+ return outputs
619
+
620
+
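As the `__init__` of `SiglipEncoderLayer` above shows, the attention implementation is picked from a `_flash_attn_2_enabled` attribute on the config. A minimal usage sketch follows; it assumes flash-attn 2 is installed, a CUDA device with fp16/bf16 weights for the flash path, and that you import directly from a local clone of this repo (the import paths are assumptions for illustration).

```python
# Hypothetical usage sketch, not part of the uploaded file.
from configuration_siglip import SiglipVisionConfig
from modeling_siglip import SiglipEncoderLayer

config = SiglipVisionConfig()             # default vision config from configuration_siglip.py
layer = SiglipEncoderLayer(config)        # -> SiglipAttention (eager implementation)

config._flash_attn_2_enabled = True       # flag checked via getattr(...) in __init__ above
flash_layer = SiglipEncoderLayer(config)  # -> SiglipFlashAttention2
```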
621
+ class SiglipPreTrainedModel(PreTrainedModel):
622
+ """
623
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
624
+ models.
625
+ """
626
+
627
+ config_class = SiglipConfig
628
+ base_model_prefix = "siglip"
629
+ supports_gradient_checkpointing = True
630
+
631
+ def _init_weights(self, module):
632
+ """Initialize the weights"""
633
+ factor = self.config.initializer_factor
634
+ if isinstance(module, SiglipTextEmbeddings):
635
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
636
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
637
+ elif isinstance(module, SiglipVisionEmbeddings):
638
+ factor = self.config.initializer_factor
639
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
640
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
641
+ elif isinstance(module, SiglipAttention):
642
+ factor = self.config.initializer_factor
643
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
644
+ out_proj_std = (module.embed_dim**-0.5) * factor
645
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
646
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
647
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
648
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
649
+ elif isinstance(module, SiglipMLP):
650
+ factor = self.config.initializer_factor
651
+ in_proj_std = (
652
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
653
+ )
654
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
655
+ nn.init.normal_(module.fc1.weight, std=fc_std)
656
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
657
+ if isinstance(module, nn.LayerNorm):
658
+ module.bias.data.zero_()
659
+ module.weight.data.fill_(1.0)
660
+ if isinstance(module, nn.Linear) and module.bias is not None:
661
+ module.bias.data.zero_()
662
+
663
+ def _set_gradient_checkpointing(self, module, value=False):
664
+ if isinstance(module, SiglipEncoder):
665
+ module.gradient_checkpointing = value
666
+
667
+
668
+ SIGLIP_START_DOCSTRING = r"""
669
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
670
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
671
+ etc.)
672
+
673
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
674
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
675
+ and behavior.
676
+
677
+ Parameters:
678
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
679
+ Initializing with a config file does not load the weights associated with the model, only the
680
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
681
+ """
682
+
683
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
684
+ Args:
685
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
686
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
687
+ it.
688
+
689
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
690
+ [`PreTrainedTokenizer.__call__`] for details.
691
+
692
+ [What are input IDs?](../glossary#input-ids)
693
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
694
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
695
+
696
+ - 1 for tokens that are **not masked**,
697
+ - 0 for tokens that are **masked**.
698
+
699
+ [What are attention masks?](../glossary#attention-mask)
700
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
701
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
702
+ config.max_position_embeddings - 1]`.
703
+
704
+ [What are position IDs?](../glossary#position-ids)
705
+ output_attentions (`bool`, *optional*):
706
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
707
+ tensors for more detail.
708
+ output_hidden_states (`bool`, *optional*):
709
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
710
+ more detail.
711
+ return_dict (`bool`, *optional*):
712
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
713
+ """
714
+
715
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
716
+ Args:
717
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
718
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
719
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
720
+ output_attentions (`bool`, *optional*):
721
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
722
+ tensors for more detail.
723
+ output_hidden_states (`bool`, *optional*):
724
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
725
+ more detail.
726
+ return_dict (`bool`, *optional*):
727
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
728
+ """
729
+
730
+ SIGLIP_INPUTS_DOCSTRING = r"""
731
+ Args:
732
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
733
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
734
+ it.
735
+
736
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
737
+ [`PreTrainedTokenizer.__call__`] for details.
738
+
739
+ [What are input IDs?](../glossary#input-ids)
740
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
741
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
742
+
743
+ - 1 for tokens that are **not masked**,
744
+ - 0 for tokens that are **masked**.
745
+
746
+ [What are attention masks?](../glossary#attention-mask)
747
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
748
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
749
+ config.max_position_embeddings - 1]`.
750
+
751
+ [What are position IDs?](../glossary#position-ids)
752
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
753
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
754
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
755
+ return_loss (`bool`, *optional*):
756
+ Whether or not to return the contrastive loss.
757
+ output_attentions (`bool`, *optional*):
758
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
759
+ tensors for more detail.
760
+ output_hidden_states (`bool`, *optional*):
761
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
762
+ more detail.
763
+ return_dict (`bool`, *optional*):
764
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
765
+ """
766
+
767
+
768
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
769
+ class SiglipEncoder(nn.Module):
770
+ """
771
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
772
+ [`SiglipEncoderLayer`].
773
+
774
+ Args:
775
+ config: SiglipConfig
776
+ """
777
+
778
+ def __init__(self, config: SiglipConfig):
779
+ super().__init__()
780
+ self.config = config
781
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
782
+ self.gradient_checkpointing = False
783
+
784
+ def forward(
785
+ self,
786
+ inputs_embeds,
787
+ attention_mask: Optional[torch.Tensor] = None,
788
+ causal_attention_mask: Optional[torch.Tensor] = None,
789
+ output_attentions: Optional[bool] = None,
790
+ output_hidden_states: Optional[bool] = None,
791
+ return_dict: Optional[bool] = None,
792
+ ) -> Union[Tuple, BaseModelOutput]:
793
+ r"""
794
+ Args:
795
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
796
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
797
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
798
+ than the model's internal embedding lookup matrix.
799
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
800
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
801
+
802
+ - 1 for tokens that are **not masked**,
803
+ - 0 for tokens that are **masked**.
804
+
805
+ [What are attention masks?](../glossary#attention-mask)
806
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
807
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
808
+
809
+ - 1 for tokens that are **not masked**,
810
+ - 0 for tokens that are **masked**.
811
+
812
+ [What are attention masks?](../glossary#attention-mask)
813
+ output_attentions (`bool`, *optional*):
814
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
815
+ returned tensors for more detail.
816
+ output_hidden_states (`bool`, *optional*):
817
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
818
+ for more detail.
819
+ return_dict (`bool`, *optional*):
820
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
821
+ """
822
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
823
+ output_hidden_states = (
824
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
825
+ )
826
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
827
+
828
+ encoder_states = () if output_hidden_states else None
829
+ all_attentions = () if output_attentions else None
830
+
831
+ hidden_states = inputs_embeds
832
+ for idx, encoder_layer in enumerate(self.layers):
833
+ if output_hidden_states:
834
+ encoder_states = encoder_states + (hidden_states,)
835
+ if self.gradient_checkpointing and self.training:
836
+
837
+ def create_custom_forward(module):
838
+ def custom_forward(*inputs):
839
+ return module(*inputs, output_attentions)
840
+
841
+ return custom_forward
842
+
843
+ layer_outputs = torch.utils.checkpoint.checkpoint(
844
+ create_custom_forward(encoder_layer),
845
+ hidden_states,
846
+ attention_mask,
847
+ causal_attention_mask,
848
+ )
849
+ else:
850
+ layer_outputs = encoder_layer(
851
+ hidden_states,
852
+ attention_mask,
853
+ causal_attention_mask,
854
+ output_attentions=output_attentions,
855
+ )
856
+
857
+ hidden_states = layer_outputs[0]
858
+
859
+ if output_attentions:
860
+ all_attentions = all_attentions + (layer_outputs[1],)
861
+
862
+ if output_hidden_states:
863
+ encoder_states = encoder_states + (hidden_states,)
864
+
865
+ if not return_dict:
866
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
867
+ return BaseModelOutput(
868
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
869
+ )
870
+
871
+
872
+ class SiglipTextTransformer(nn.Module):
873
+ def __init__(self, config: SiglipTextConfig):
874
+ super().__init__()
875
+ self.config = config
876
+ embed_dim = config.hidden_size
877
+ self.embeddings = SiglipTextEmbeddings(config)
878
+ self.encoder = SiglipEncoder(config)
879
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
880
+
881
+ self.head = nn.Linear(embed_dim, embed_dim)
882
+
883
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
884
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
885
+ def forward(
886
+ self,
887
+ input_ids: Optional[torch.Tensor] = None,
888
+ attention_mask: Optional[torch.Tensor] = None,
889
+ position_ids: Optional[torch.Tensor] = None,
890
+ output_attentions: Optional[bool] = None,
891
+ output_hidden_states: Optional[bool] = None,
892
+ return_dict: Optional[bool] = None,
893
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
894
+ r"""
895
+ Returns:
896
+
897
+ """
898
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
899
+ output_hidden_states = (
900
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
901
+ )
902
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
903
+
904
+ if input_ids is None:
905
+ raise ValueError("You have to specify input_ids")
906
+
907
+ input_shape = input_ids.size()
908
+ input_ids = input_ids.view(-1, input_shape[-1])
909
+
910
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
911
+
912
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
913
+ # expand attention_mask
914
+ if attention_mask is not None:
915
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
916
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
917
+
918
+ encoder_outputs = self.encoder(
919
+ inputs_embeds=hidden_states,
920
+ attention_mask=None,
921
+ causal_attention_mask=None,
922
+ output_attentions=output_attentions,
923
+ output_hidden_states=output_hidden_states,
924
+ return_dict=return_dict,
925
+ )
926
+
927
+ last_hidden_state = encoder_outputs[0]
928
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
929
+
930
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
931
+ pooled_output = last_hidden_state[:, -1, :]
932
+ pooled_output = self.head(pooled_output)
933
+
934
+ if not return_dict:
935
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
936
+
937
+ return BaseModelOutputWithPooling(
938
+ last_hidden_state=last_hidden_state,
939
+ pooler_output=pooled_output,
940
+ hidden_states=encoder_outputs.hidden_states,
941
+ attentions=encoder_outputs.attentions,
942
+ )
943
+
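Because the pooled text output above is simply the hidden state at the last position, the result depends on how sequences are padded. A hedged usage sketch follows; the checkpoint name mirrors the examples already used in this file, and `padding="max_length"` reflects how the public SigLIP checkpoints are commonly used (fixed-length text during training), which is an assumption rather than something stated in this file.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
# Pad to a fixed length so the last position is laid out the same way as during training;
# with dynamic padding the pooled position would shift with the longest sequence in the batch.
inputs = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    padding="max_length",
    return_tensors="pt",
)
```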
944
+
945
+ @add_start_docstrings(
946
+ """The text model from SigLIP without any head or projection on top.""",
947
+ SIGLIP_START_DOCSTRING,
948
+ )
949
+ class SiglipTextModel(SiglipPreTrainedModel):
950
+ config_class = SiglipTextConfig
951
+
952
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
953
+
954
+ def __init__(self, config: SiglipTextConfig):
955
+ super().__init__(config)
956
+ self.text_model = SiglipTextTransformer(config)
957
+ # Initialize weights and apply final processing
958
+ self.post_init()
959
+
960
+ def get_input_embeddings(self) -> nn.Module:
961
+ return self.text_model.embeddings.token_embedding
962
+
963
+ def set_input_embeddings(self, value):
964
+ self.text_model.embeddings.token_embedding = value
965
+
966
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
967
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
968
+ def forward(
969
+ self,
970
+ input_ids: Optional[torch.Tensor] = None,
971
+ attention_mask: Optional[torch.Tensor] = None,
972
+ position_ids: Optional[torch.Tensor] = None,
973
+ output_attentions: Optional[bool] = None,
974
+ output_hidden_states: Optional[bool] = None,
975
+ return_dict: Optional[bool] = None,
976
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
977
+ r"""
978
+ Returns:
979
+
980
+ Examples:
981
+
982
+ ```python
983
+ >>> from transformers import AutoTokenizer, SiglipTextModel
984
+
985
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
986
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
987
+
988
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
989
+
990
+ >>> outputs = model(**inputs)
991
+ >>> last_hidden_state = outputs.last_hidden_state
992
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
993
+ ```"""
994
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995
+
996
+ return self.text_model(
997
+ input_ids=input_ids,
998
+ attention_mask=attention_mask,
999
+ position_ids=position_ids,
1000
+ output_attentions=output_attentions,
1001
+ output_hidden_states=output_hidden_states,
1002
+ return_dict=return_dict,
1003
+ )
1004
+
1005
+
1006
+ class SiglipVisionTransformer(nn.Module):
1007
+ def __init__(self, config: SiglipVisionConfig):
1008
+ super().__init__()
1009
+ self.config = config
1010
+ embed_dim = config.hidden_size
1011
+
1012
+ self.embeddings = SiglipVisionEmbeddings(config)
1013
+ self.encoder = SiglipEncoder(config)
1014
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1015
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
1016
+
1017
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1018
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1019
+ def forward(
1020
+ self,
1021
+ pixel_values,
1022
+ output_attentions: Optional[bool] = None,
1023
+ output_hidden_states: Optional[bool] = None,
1024
+ return_dict: Optional[bool] = None,
1025
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1026
+ r"""
1027
+ Returns:
1028
+
1029
+ """
1030
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1031
+ output_hidden_states = (
1032
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1033
+ )
1034
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1035
+
1036
+ hidden_states = self.embeddings(pixel_values)
1037
+
1038
+ encoder_outputs = self.encoder(
1039
+ inputs_embeds=hidden_states,
1040
+ output_attentions=output_attentions,
1041
+ output_hidden_states=output_hidden_states,
1042
+ return_dict=return_dict,
1043
+ )
1044
+
1045
+ last_hidden_state = encoder_outputs[0]
1046
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1047
+
1048
+ pooled_output = self.head(last_hidden_state)
1049
+
1050
+ if not return_dict:
1051
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
1052
+
1053
+ return BaseModelOutputWithPooling(
1054
+ last_hidden_state=last_hidden_state,
1055
+ pooler_output=pooled_output,
1056
+ hidden_states=encoder_outputs.hidden_states,
1057
+ attentions=encoder_outputs.attentions,
1058
+ )
1059
+
1060
+
1061
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1062
+ """Multihead Attention Pooling."""
1063
+
1064
+ def __init__(self, config: SiglipVisionConfig):
1065
+ super().__init__()
1066
+
1067
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1068
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1069
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1070
+ self.mlp = SiglipMLP(config)
1071
+
1072
+ def forward(self, hidden_state):
1073
+ batch_size = hidden_state.shape[0]
1074
+ probe = self.probe.repeat(batch_size, 1, 1)
1075
+
1076
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
1077
+
1078
+ residual = hidden_state
1079
+ hidden_state = self.layernorm(hidden_state)
1080
+ hidden_state = residual + self.mlp(hidden_state)
1081
+
1082
+ return hidden_state[:, 0]
1083
+
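The pooling head above attends from a single learned probe token to all patch embeddings and keeps only the probe's representation. A minimal shape sketch is shown below; the sizes are illustrative assumptions, not the checkpoint's actual dimensions.

```python
import torch

# Illustrative sizes only.
batch_size, num_patches, hidden_size, num_heads = 2, 196, 768, 12

probe = torch.randn(1, 1, hidden_size).repeat(batch_size, 1, 1)   # (2, 1, 768) learned query
patches = torch.randn(batch_size, num_patches, hidden_size)       # (2, 196, 768) encoder output

attention = torch.nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
pooled, _ = attention(probe, patches, patches)                    # (2, 1, 768)
pooled = pooled[:, 0]                                             # (2, 768): one vector per image
```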
1084
+
1085
+ @add_start_docstrings(
1086
+ """The vision model from SigLIP without any head or projection on top.""",
1087
+ SIGLIP_START_DOCSTRING,
1088
+ )
1089
+ class SiglipVisionModel(SiglipPreTrainedModel):
1090
+ config_class = SiglipVisionConfig
1091
+ main_input_name = "pixel_values"
1092
+
1093
+ def __init__(self, config: SiglipVisionConfig):
1094
+ super().__init__(config)
1095
+
1096
+ self.vision_model = SiglipVisionTransformer(config)
1097
+
1098
+ # Initialize weights and apply final processing
1099
+ self.post_init()
1100
+
1101
+ def get_input_embeddings(self) -> nn.Module:
1102
+ return self.vision_model.embeddings.patch_embedding
1103
+
1104
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1105
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1106
+ def forward(
1107
+ self,
1108
+ pixel_values,
1109
+ output_attentions: Optional[bool] = None,
1110
+ output_hidden_states: Optional[bool] = None,
1111
+ return_dict: Optional[bool] = None,
1112
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1113
+ r"""
1114
+ Returns:
1115
+
1116
+ Examples:
1117
+
1118
+ ```python
1119
+ >>> from PIL import Image
1120
+ >>> import requests
1121
+ >>> from transformers import AutoProcessor, SiglipVisionModel
1122
+
1123
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1124
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1125
+
1126
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1127
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1128
+
1129
+ >>> inputs = processor(images=image, return_tensors="pt")
1130
+
1131
+ >>> outputs = model(**inputs)
1132
+ >>> last_hidden_state = outputs.last_hidden_state
1133
+ >>> pooled_output = outputs.pooler_output # features pooled by the attention pooling head
1134
+ ```"""
1135
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1136
+
1137
+ return self.vision_model(
1138
+ pixel_values=pixel_values,
1139
+ output_attentions=output_attentions,
1140
+ output_hidden_states=output_hidden_states,
1141
+ return_dict=return_dict,
1142
+ )
1143
+
1144
+
1145
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1146
+ class SiglipModel(SiglipPreTrainedModel):
1147
+ config_class = SiglipConfig
1148
+
1149
+ def __init__(self, config: SiglipConfig):
1150
+ super().__init__(config)
1151
+
1152
+ if not isinstance(config.text_config, SiglipTextConfig):
1153
+ raise ValueError(
1154
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1155
+ f" {type(config.text_config)}."
1156
+ )
1157
+
1158
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1159
+ raise ValueError(
1160
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1161
+ f" {type(config.vision_config)}."
1162
+ )
1163
+
1164
+ text_config = config.text_config
1165
+ vision_config = config.vision_config
1166
+
1167
+ self.text_model = SiglipTextModel(text_config)
1168
+ self.vision_model = SiglipVisionModel(vision_config)
1169
+
1170
+ self.temperature = nn.Parameter(
1171
+ torch.randn(
1172
+ 1,
1173
+ )
1174
+ )
1175
+ self.bias = nn.Parameter(
1176
+ torch.randn(
1177
+ 1,
1178
+ )
1179
+ )
1180
+
1181
+ # Initialize weights and apply final processing
1182
+ self.post_init()
1183
+
1184
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1185
+ def get_text_features(
1186
+ self,
1187
+ input_ids: Optional[torch.Tensor] = None,
1188
+ attention_mask: Optional[torch.Tensor] = None,
1189
+ position_ids: Optional[torch.Tensor] = None,
1190
+ output_attentions: Optional[bool] = None,
1191
+ output_hidden_states: Optional[bool] = None,
1192
+ return_dict: Optional[bool] = None,
1193
+ ) -> torch.FloatTensor:
1194
+ r"""
1195
+ Returns:
1196
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained from
+ the pooled output of [`SiglipTextModel`].
1198
+
1199
+ Examples:
1200
+
1201
+ ```python
1202
+ >>> from transformers import AutoTokenizer, SiglipModel
1203
+
1204
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1205
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1206
+
1207
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1208
+ >>> text_features = model.get_text_features(**inputs)
1209
+ ```"""
1210
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1211
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1212
+ output_hidden_states = (
1213
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1214
+ )
1215
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1216
+
1217
+ text_outputs = self.text_model(
1218
+ input_ids=input_ids,
1219
+ attention_mask=attention_mask,
1220
+ position_ids=position_ids,
1221
+ output_attentions=output_attentions,
1222
+ output_hidden_states=output_hidden_states,
1223
+ return_dict=return_dict,
1224
+ )
1225
+
1226
+ pooled_output = text_outputs[1]
1227
+
1228
+ return pooled_output
1229
+
1230
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1231
+ def get_image_features(
1232
+ self,
1233
+ pixel_values: Optional[torch.FloatTensor] = None,
1234
+ output_attentions: Optional[bool] = None,
1235
+ output_hidden_states: Optional[bool] = None,
1236
+ return_dict: Optional[bool] = None,
1237
+ ) -> torch.FloatTensor:
1238
+ r"""
1239
+ Returns:
1240
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained from
+ the pooled output of [`SiglipVisionModel`].
1242
+
1243
+ Examples:
1244
+
1245
+ ```python
1246
+ >>> from PIL import Image
1247
+ >>> import requests
1248
+ >>> from transformers import AutoProcessor, SiglipModel
1249
+
1250
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1251
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1252
+
1253
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1254
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1255
+
1256
+ >>> inputs = processor(images=image, return_tensors="pt")
1257
+
1258
+ >>> image_features = model.get_image_features(**inputs)
1259
+ ```"""
1260
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1261
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1262
+ output_hidden_states = (
1263
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1264
+ )
1265
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1266
+
1267
+ vision_outputs = self.vision_model(
1268
+ pixel_values=pixel_values,
1269
+ output_attentions=output_attentions,
1270
+ output_hidden_states=output_hidden_states,
1271
+ return_dict=return_dict,
1272
+ )
1273
+
1274
+ pooled_output = vision_outputs[1]
1275
+
1276
+ return pooled_output
1277
+
1278
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1279
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1280
+ def forward(
1281
+ self,
1282
+ input_ids: Optional[torch.LongTensor] = None,
1283
+ pixel_values: Optional[torch.FloatTensor] = None,
1284
+ attention_mask: Optional[torch.Tensor] = None,
1285
+ position_ids: Optional[torch.LongTensor] = None,
1286
+ return_loss: Optional[bool] = None,
1287
+ output_attentions: Optional[bool] = None,
1288
+ output_hidden_states: Optional[bool] = None,
1289
+ return_dict: Optional[bool] = None,
1290
+ ) -> Union[Tuple, SiglipOutput]:
1291
+ r"""
1292
+ Returns:
1293
+
1294
+ Examples:
1295
+
1296
+ ```python
1297
+ >>> from PIL import Image
1298
+ >>> import requests
1299
+ >>> from transformers import AutoProcessor, SiglipModel
1300
+
1301
+ >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1302
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1303
+
1304
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1305
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1306
+
1307
+ >>> inputs = processor(
1308
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1309
+ ... )
1310
+
1311
+ >>> outputs = model(**inputs)
1312
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1313
+ >>> probs = logits_per_image.sigmoid() # SigLIP is trained with a sigmoid loss, so probabilities are per image-text pair
1314
+ ```"""
1315
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1316
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1317
+ output_hidden_states = (
1318
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1319
+ )
1320
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1321
+
1322
+ vision_outputs = self.vision_model(
1323
+ pixel_values=pixel_values,
1324
+ output_attentions=output_attentions,
1325
+ output_hidden_states=output_hidden_states,
1326
+ return_dict=return_dict,
1327
+ )
1328
+
1329
+ text_outputs = self.text_model(
1330
+ input_ids=input_ids,
1331
+ attention_mask=attention_mask,
1332
+ position_ids=position_ids,
1333
+ output_attentions=output_attentions,
1334
+ output_hidden_states=output_hidden_states,
1335
+ return_dict=return_dict,
1336
+ )
1337
+
1338
+ image_embeds = vision_outputs[1]
1339
+ text_embeds = text_outputs[1]
1340
+
1341
+ # normalized features
1342
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1343
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1344
+
1345
+ # cosine similarity as logits
1346
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.temperature.exp() + self.bias
1347
+ logits_per_image = logits_per_text.t()
1348
+
1351
+ loss = None
1352
+ if return_loss:
1353
+ raise NotImplementedError("SigLIP loss to be implemented")
1354
+
1355
+ if not return_dict:
1356
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1357
+ return ((loss,) + output) if loss is not None else output
1358
+
1359
+ return SiglipOutput(
1360
+ loss=loss,
1361
+ logits_per_image=logits_per_image,
1362
+ logits_per_text=logits_per_text,
1363
+ text_embeds=text_embeds,
1364
+ image_embeds=image_embeds,
1365
+ text_model_output=text_outputs,
1366
+ vision_model_output=vision_outputs,
1367
+ )
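Note that `return_loss=True` currently raises `NotImplementedError` in the forward above. For reference, here is a hedged sketch of the pairwise sigmoid loss described in the SigLIP paper, written against the `logits_per_text` computed in this forward pass; it is an illustration under the assumption of a square batch (as many texts as images), not the authors' implementation.

```python
import torch
import torch.nn.functional as F


def siglip_loss_sketch(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text: (num_texts, num_images) = temperature.exp() * cos_sim + bias, as above.
    # Assumes num_texts == num_images: matching pairs sit on the diagonal (+1), all others get -1.
    n = logits_per_text.size(0)
    labels = 2 * torch.eye(n, device=logits_per_text.device) - 1
    # -log sigmoid(label * logit), summed over pairs and averaged over the batch dimension.
    return -F.logsigmoid(labels * logits_per_text).sum(dim=-1).mean()
```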