jupyterjazz commited on
Commit
3afddee
1 Parent(s): b20a611

rename task type

Browse files

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (2) hide show
  1. modeling_lora.py +10 -10
  2. modeling_xlm_roberta.py +1 -1
modeling_lora.py CHANGED
@@ -367,35 +367,35 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
367
  self,
368
  sentences: Union[str, List[str]],
369
  *args,
370
- task_type: Optional[str] = None,
371
  **kwargs,
372
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
373
  """
374
  Computes sentence embeddings.
375
  sentences(`str` or `List[str]`):
376
  Sentence or sentences to be encoded
377
- task_type(`str`, *optional*, defaults to `None`):
378
- Specifies the task for which the encoding is intended. If `task_type` is not provided,
379
  all LoRA adapters are disabled, and the model reverts to its original,
380
  general-purpose weights.
381
  """
382
- if task_type and task_type not in self._lora_adaptations:
383
  raise ValueError(
384
- f"Unsupported task '{task_type}'. "
385
  f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
386
- f"Alternatively, don't pass the `task_type` argument to disable LoRA."
387
  )
388
  adapter_mask = None
389
- if task_type:
390
- task_id = self._adaptation_map[task_type]
391
  num_examples = 1 if isinstance(sentences, str) else len(sentences)
392
  adapter_mask = torch.full(
393
  (num_examples,), task_id, dtype=torch.int32, device=self.device
394
  )
395
  if isinstance(sentences, str):
396
- sentences = self._task_instructions[task_type] + sentences
397
  else:
398
- sentences = [self._task_instructions[task_type] + sentence for sentence in sentences]
399
  return self.roberta.encode(
400
  sentences, *args, adapter_mask=adapter_mask, **kwargs
401
  )
 
367
  self,
368
  sentences: Union[str, List[str]],
369
  *args,
370
+ task: Optional[str] = None,
371
  **kwargs,
372
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
373
  """
374
  Computes sentence embeddings.
375
  sentences(`str` or `List[str]`):
376
  Sentence or sentences to be encoded
377
+ task(`str`, *optional*, defaults to `None`):
378
+ Specifies the task for which the encoding is intended. If `task` is not provided,
379
  all LoRA adapters are disabled, and the model reverts to its original,
380
  general-purpose weights.
381
  """
382
+ if task and task not in self._lora_adaptations:
383
  raise ValueError(
384
+ f"Unsupported task '{task}'. "
385
  f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
386
+ f"Alternatively, don't pass the `task` argument to disable LoRA."
387
  )
388
  adapter_mask = None
389
+ if task:
390
+ task_id = self._adaptation_map[task]
391
  num_examples = 1 if isinstance(sentences, str) else len(sentences)
392
  adapter_mask = torch.full(
393
  (num_examples,), task_id, dtype=torch.int32, device=self.device
394
  )
395
  if isinstance(sentences, str):
396
+ sentences = self._task_instructions[task] + sentences
397
  else:
398
+ sentences = [self._task_instructions[task] + sentence for sentence in sentences]
399
  return self.roberta.encode(
400
  sentences, *args, adapter_mask=adapter_mask, **kwargs
401
  )
modeling_xlm_roberta.py CHANGED
@@ -473,7 +473,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
473
  normalize_embeddings: bool = True,
474
  truncate_dim: Optional[int] = None,
475
  adapter_mask: Optional[torch.Tensor] = None,
476
- task_type: Optional[str] = None,
477
  **tokenizer_kwargs,
478
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
479
  """
 
473
  normalize_embeddings: bool = True,
474
  truncate_dim: Optional[int] = None,
475
  adapter_mask: Optional[torch.Tensor] = None,
476
+ task: Optional[str] = None,
477
  **tokenizer_kwargs,
478
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
479
  """