koukandre committed on
Commit
7dc253b
1 Parent(s): 952897b

feat-matryoshka-support

Files changed (2)
  1. configuration_clip.py +5 -1
  2. modeling_clip.py +43 -5
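
In short, this commit threads Matryoshka-style embedding truncation through both files: the config gains `matryoshka_dimensions` and `truncate_dim`, and `encode_text` / `encode_image` gain a `truncate_dim` argument that slices each embedding down to a supported dimension before L2 normalization. A minimal standalone sketch of that operation (the function name and sizes below are illustrative, not part of this commit):

import torch
import torch.nn.functional as F

def truncate_and_normalize(embeddings: torch.Tensor, truncate_dim: int) -> torch.Tensor:
    # Keep only the first `truncate_dim` components of each row, then re-normalize
    # so cosine / dot-product scoring still works on the shorter vectors.
    return F.normalize(embeddings[:, :truncate_dim], p=2, dim=1)

full = torch.randn(4, 1024)                # batch of 4 full-size embeddings (size illustrative)
small = truncate_and_normalize(full, 512)  # shape (4, 512), unit-norm rows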
configuration_clip.py CHANGED
@@ -6,7 +6,7 @@
 
 import os
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from transformers import PretrainedConfig, logging
 
@@ -157,6 +157,8 @@ class JinaCLIPConfig(PretrainedConfig):
         logit_scale_init_value: float = 2.6592,
         use_text_flash_attn: Optional[bool] = None,
         use_vision_xformers: Optional[bool] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
         **kwargs,
     ):
         # If `_config_dict` exist, we use them for the backward compatibility.
@@ -167,6 +169,8 @@ class JinaCLIPConfig(PretrainedConfig):
         vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
         self.use_text_flash_attn = use_text_flash_attn
         self.use_vision_xformers = use_vision_xformers
+        self.matryoshka_dimensions = matryoshka_dimensions
+        self.truncate_dim = truncate_dim
 
         super().__init__(**kwargs)
 
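For context, a hedged sketch of how the two new config fields could be set when loading the model with remote code; the checkpoint name and the dimension values are assumptions for illustration, not taken from this commit:

from transformers import AutoConfig

config = AutoConfig.from_pretrained('jinaai/jina-clip-v2', trust_remote_code=True)
# Assumed values: the actual list would ship with the checkpoint's config.json.
config.matryoshka_dimensions = [64, 128, 256, 512, 768, 1024]
config.truncate_dim = 512  # default truncation applied by encode_text / encode_image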
modeling_clip.py CHANGED
@@ -4,12 +4,13 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
 # and adjusted for Jina CLIP
 
+import base64
 from functools import partial
-from typing import List, Optional, Tuple, Union
 from io import BytesIO
-import requests
-import base64
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
+import requests
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
@@ -39,9 +40,14 @@ except ImportError:
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
 from .hf_model import HFTextEncoder
+
 # needed for HF to correctly import in cache
 from .rope_embeddings import VisionRotaryEmbeddingFast  # noqa: F401
-from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform  # noqa: F401
+from .transform import (  # noqa: F401
+    OPENAI_DATASET_MEAN,
+    OPENAI_DATASET_STD,
+    image_transform,
+)
 
 logger = logging.get_logger(__name__)
 
@@ -280,6 +286,25 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             )
         return self.visual_projection(self.vision_model(x=x))
 
+    def truncate_embeddings(self, embeddings, truncate_dim):
+        if "jina-clip-v1" in self.config._name_or_path:
+            logger.warning(
+                "Matryoshka embeddings are not supported for jina-clip-v1, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif not self.config.matryoshka_dimensions:
+            logger.warning(
+                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif truncate_dim in self.config.matryoshka_dimensions:
+            return embeddings[:, :truncate_dim]
+        else:
+            raise ValueError(
+                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
+                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+            )
+
     @torch.inference_mode()
     def encode_text(
         self,
@@ -290,6 +315,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -315,6 +341,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -364,6 +392,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         else:
             range_iter = range(0, len(sentences), batch_size)
 
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
@@ -372,6 +401,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             ).to(self.device)
 
             embeddings = self.get_text_features(input_ids=encoded_input)
+
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
@@ -406,6 +438,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
@@ -431,6 +464,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
         Returns:
             By default, a list of tensors is returned.
             If convert_to_tensor, a stacked tensor is returned.
@@ -476,7 +511,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             range_iter = range(0, len(images), batch_size)
 
         from PIL import Image
-
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             batch_images = images[i:i+batch_size]
             processed_inputs = []
@@ -501,6 +537,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             processed_inputs = processed_inputs.to(self.device)
             embeddings = self.get_image_features(processed_inputs)
 
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
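
Finally, a possible end-to-end usage sketch of the new `truncate_dim` argument; the checkpoint name and the input values are assumptions for illustration, not part of this commit:

from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v2', trust_remote_code=True)

# `truncate_dim` must be one of `config.matryoshka_dimensions`, otherwise
# `truncate_embeddings` raises a ValueError. Truncation happens before the
# L2 normalization step, so the shorter vectors are still unit-norm.
text_emb = model.encode_text(['a photo of a cat'], truncate_dim=512)
image_emb = model.encode_image(['cat.jpg'], truncate_dim=512)

# When `truncate_dim` is omitted, `config.truncate_dim` (if set) is used instead.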