jupyterjazz commited on
Commit
3eb20d0
1 Parent(s): 509511d

refactor: modify encode

Browse files

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (2) hide show
  1. modeling_lora.py +7 -9
  2. modeling_xlm_roberta.py +5 -2
modeling_lora.py CHANGED
@@ -337,7 +337,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
337
  def encode(
338
  self,
339
  *args,
340
- task: Union[str, None] = LORA_NO_UPDATE,
341
  **kwargs,
342
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
343
  """
@@ -351,13 +351,11 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
351
  adapters are disabled, and the model reverts to its original, general-purpose weights.
352
  If `task` is set to a specific LoRA adaptation, that adaptation is activated.
353
  """
354
- if task != LORA_NO_UPDATE:
355
- if not task:
356
- warnings.warn(
357
- f"Task-specific embeddings are disabled. To enable, specify the `task` "
358
- f"argument with one of the supported tasks: {', '.join(self.config.lora_adaptations)}",
359
- category=UserWarning,
360
- )
361
- self.current_task = task
362
 
363
  return self.roberta.encode(*args, **kwargs)
 
337
  def encode(
338
  self,
339
  *args,
340
+ task: Optional[str] = None,
341
  **kwargs,
342
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
343
  """
 
351
  adapters are disabled, and the model reverts to its original, general-purpose weights.
352
  If `task` is set to a specific LoRA adaptation, that adaptation is activated.
353
  """
354
+ if task and task not in self._lora_adaptations:
355
+ raise ValueError(
356
+ f"Unsupported task '{task}'. "
357
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
358
+ f"Alternatively, don't pass the `task` argument to disable LoRA."
359
+ )
 
 
360
 
361
  return self.roberta.encode(*args, **kwargs)
modeling_xlm_roberta.py CHANGED
@@ -459,6 +459,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
459
  device: Optional[torch.device] = None,
460
  normalize_embeddings: bool = False,
461
  truncate_dim: Optional[int] = None,
 
462
  **tokenizer_kwargs,
463
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
464
  """
@@ -549,14 +550,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
549
  )
550
  else:
551
  range_iter = range(0, len(sentences), batch_size)
552
-
 
 
553
  for i in range_iter:
554
  encoded_input = self.tokenizer(
555
  sentences[i : i + batch_size],
556
  return_tensors='pt',
557
  **tokenizer_kwargs,
558
  ).to(self.device)
559
- token_embs = self.forward(**encoded_input)[0]
560
 
561
  # Accumulate in fp32 to avoid overflow
562
  token_embs = token_embs.float()
 
459
  device: Optional[torch.device] = None,
460
  normalize_embeddings: bool = False,
461
  truncate_dim: Optional[int] = None,
462
+ task: Optional[str] = None,
463
  **tokenizer_kwargs,
464
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
465
  """
 
550
  )
551
  else:
552
  range_iter = range(0, len(sentences), batch_size)
553
+ lora_kwargs = {}
554
+ if task:
555
+ lora_kwargs['task'] = task
556
  for i in range_iter:
557
  encoded_input = self.tokenizer(
558
  sentences[i : i + batch_size],
559
  return_tensors='pt',
560
  **tokenizer_kwargs,
561
  ).to(self.device)
562
+ token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
563
 
564
  # Accumulate in fp32 to avoid overflow
565
  token_embs = token_embs.float()