jupyterjazz commited on
Commit
8542ad8
·
verified ·
1 Parent(s): 27d23b2

truncate-embedding-dimension (#10)

Browse files

- feat: matryoshka embeddings (ee8863c9cbae496b327c9d5761b54f02a2f90954)
- refactor: optional arg (3f72891549ab3f1a6a3cd4e8b40dff6d5c50d1b1)
- fix: var name (fd34c40e6fcbb638e225ccd8b47f5b9c487bd8a4)
- fix: another one (b27fa557459cf35a2520c39da441b5e79e455068)
- refactor: truncation fn (c55e59156fa5b02100f7a7707324f3ce4f92714f)
- feat: truncation option during init (943cec246f8df968b3c6b2bd10e89f9529797b25)

configuration_xlm_roberta.py CHANGED
@@ -31,6 +31,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
31
  use_flash_attn=True,
32
  torch_dtype=None,
33
  emb_pooler=None,
 
 
34
  **kwargs,
35
  ):
36
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -59,6 +61,8 @@ class XLMRobertaFlashConfig(PretrainedConfig):
59
  self.lora_main_params_trainable = lora_main_params_trainable
60
  self.use_flash_attn = use_flash_attn
61
  self.emb_pooler = emb_pooler
 
 
62
  if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
63
  self.torch_dtype = getattr(torch, torch_dtype)
64
  else:
 
31
  use_flash_attn=True,
32
  torch_dtype=None,
33
  emb_pooler=None,
34
+ matryoshka_dimensions=None,
35
+ truncate_dim=None,
36
  **kwargs,
37
  ):
38
  super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
61
  self.lora_main_params_trainable = lora_main_params_trainable
62
  self.use_flash_attn = use_flash_attn
63
  self.emb_pooler = emb_pooler
64
+ self.matryoshka_dimensions = matryoshka_dimensions
65
+ self.truncate_dim = truncate_dim
66
  if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
67
  self.torch_dtype = getattr(torch, torch_dtype)
68
  else:
modeling_xlm_roberta.py CHANGED
@@ -452,6 +452,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
  normalize_embeddings: bool = False,
 
455
  **tokenizer_kwargs,
456
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
457
  """
@@ -481,6 +482,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
481
  If set to true, returned vectors will have length 1. In that case, the
482
  faster dot-product (util.dot_score) instead of cosine similarity can
483
  be used.
 
 
484
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
485
  Keyword arguments for the tokenizer
486
  Returns:
@@ -575,6 +578,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
575
 
576
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
577
 
 
 
 
 
578
  if convert_to_tensor:
579
  all_embeddings = torch.stack(all_embeddings)
580
  elif convert_to_numpy:
@@ -586,6 +593,19 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
586
  self.train(is_training)
587
  return all_embeddings
588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  def mean_pooling(
590
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
591
  ):
 
452
  convert_to_tensor: bool = False,
453
  device: Optional[torch.device] = None,
454
  normalize_embeddings: bool = False,
455
+ truncate_dim: Optional[int] = None,
456
  **tokenizer_kwargs,
457
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
458
  """
 
482
  If set to true, returned vectors will have length 1. In that case, the
483
  faster dot-product (util.dot_score) instead of cosine similarity can
484
  be used.
485
+ truncate_dim(`int`, *optional*, defaults to None):
486
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
487
  tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
488
  Keyword arguments for the tokenizer
489
  Returns:
 
578
 
579
  all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
580
 
581
+ truncate_dim = truncate_dim or self.config.truncate_dim
582
+ if truncate_dim:
583
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
584
+
585
  if convert_to_tensor:
586
  all_embeddings = torch.stack(all_embeddings)
587
  elif convert_to_numpy:
 
593
  self.train(is_training)
594
  return all_embeddings
595
 
596
+
597
+ def truncate_embeddings(self, embeddings, truncate_dim):
598
+ if not self.config.matryoshka_dimensions:
599
+ logger.warning(
600
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
601
+ )
602
+ return embeddings
603
+ elif truncate_dim in self.config.matryoshka_dimensions:
604
+ return [tensor[:truncate_dim] for tensor in embeddings]
605
+ else:
606
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
607
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
608
+
609
  def mean_pooling(
610
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
611
  ):