Jackmin108
commited on
Commit
•
5f8e4b6
1
Parent(s):
c35a42b
fix: sentences as a str
Browse filesSigned-off-by: Meow <ongjackm@gmail.com>
- modeling_lora.py +9 -6
modeling_lora.py
CHANGED
@@ -393,17 +393,20 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
393 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
394 |
)
|
395 |
adapter_mask = None
|
396 |
-
sentences = list(sentences) if isinstance(sentences, str) else sentences
|
397 |
if task_type:
|
398 |
task_id = self._adaptation_map[task_type]
|
|
|
399 |
adapter_mask = torch.full(
|
400 |
-
(
|
401 |
)
|
402 |
if task_type in ["query", "passage"]:
|
403 |
-
sentences
|
404 |
-
self._task_instructions[task_type] + " " +
|
405 |
-
|
406 |
-
|
|
|
|
|
|
|
407 |
return self.roberta.encode(
|
408 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
409 |
)
|
|
|
393 |
f"Alternatively, don't pass the `task_type` argument to disable LoRA."
|
394 |
)
|
395 |
adapter_mask = None
|
|
|
396 |
if task_type:
|
397 |
task_id = self._adaptation_map[task_type]
|
398 |
+
num_examples = 1 if isinstance(sentences, str) else len(sentences)
|
399 |
adapter_mask = torch.full(
|
400 |
+
(num_examples,), task_id, dtype=torch.int32, device=self.device
|
401 |
)
|
402 |
if task_type in ["query", "passage"]:
|
403 |
+
if isinstance(sentences, str):
|
404 |
+
sentences = self._task_instructions[task_type] + " " + sentences
|
405 |
+
else:
|
406 |
+
sentences = [
|
407 |
+
self._task_instructions[task_type] + " " + sentence
|
408 |
+
for sentence in sentences
|
409 |
+
]
|
410 |
return self.roberta.encode(
|
411 |
sentences, *args, adapter_mask=adapter_mask, **kwargs
|
412 |
)
|