Jackmin108 commited on
Commit
5f8e4b6
1 Parent(s): c35a42b

fix: sentences as a str

Browse files

Signed-off-by: Meow <ongjackm@gmail.com>

Files changed (1) hide show
  1. modeling_lora.py +9 -6
modeling_lora.py CHANGED
@@ -393,17 +393,20 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
393
  f"Alternatively, don't pass the `task_type` argument to disable LoRA."
394
  )
395
  adapter_mask = None
396
- sentences = list(sentences) if isinstance(sentences, str) else sentences
397
  if task_type:
398
  task_id = self._adaptation_map[task_type]
 
399
  adapter_mask = torch.full(
400
- (len(sentences),), task_id, dtype=torch.int32, device=self.device
401
  )
402
  if task_type in ["query", "passage"]:
403
- sentences = [
404
- self._task_instructions[task_type] + " " + sentence
405
- for sentence in sentences
406
- ]
 
 
 
407
  return self.roberta.encode(
408
  sentences, *args, adapter_mask=adapter_mask, **kwargs
409
  )
 
393
  f"Alternatively, don't pass the `task_type` argument to disable LoRA."
394
  )
395
  adapter_mask = None
 
396
  if task_type:
397
  task_id = self._adaptation_map[task_type]
398
+ num_examples = 1 if isinstance(sentences, str) else len(sentences)
399
  adapter_mask = torch.full(
400
+ (num_examples,), task_id, dtype=torch.int32, device=self.device
401
  )
402
  if task_type in ["query", "passage"]:
403
+ if isinstance(sentences, str):
404
+ sentences = self._task_instructions[task_type] + " " + sentences
405
+ else:
406
+ sentences = [
407
+ self._task_instructions[task_type] + " " + sentence
408
+ for sentence in sentences
409
+ ]
410
  return self.roberta.encode(
411
  sentences, *args, adapter_mask=adapter_mask, **kwargs
412
  )