makram93 committed
Commit: 51411ff
Parent(s): 71b163e

fix: read prompts from config


Signed-off-by: Mohammad Kalim Akram <kalim.akram@jina.ai>

Files changed (2):
  1. configuration_xlm_roberta.py +2 -0
  2. modeling_lora.py +11 -10
configuration_xlm_roberta.py CHANGED
@@ -23,6 +23,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         use_cache=True,
         classifier_dropout=None,
         lora_adaptations=None,
+        lora_prompts=None,
         lora_rank=4,
         lora_dropout_p=0.0,
         lora_alpha=1,
@@ -55,6 +56,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
+        self.lora_prompts = lora_prompts
         self.lora_rank = lora_rank
         self.lora_dropout_p = lora_dropout_p
         self.lora_alpha = lora_alpha
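
As a minimal sketch of how the two new fields fit together: the adaptation names and prompt strings below are exactly the ones removed from the hard-coded dict in modeling_lora.py, while the local import is an assumption (in practice the config ships as remote code and is loaded via trust_remote_code=True). Note that the length check against lora_adaptations happens in XLMRobertaLoRA.__init__, not in the config itself.

# Sketch only: build a config whose prompts match its adaptations one-to-one.
from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumed local checkout

config = XLMRobertaFlashConfig(
    lora_adaptations=['query', 'document', 'sts', 'clustering', 'classification'],
    lora_prompts={
        'query': 'Represent the query for retrieving supporting documents: ',
        'document': 'Represent the document for retrieval: ',
        'sts': 'Represent the text for Semantic Textual Similarity: ',
        'clustering': 'Cluster the text: ',
        'classification': 'Classify the text: ',
    },
)
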
modeling_lora.py CHANGED
@@ -228,6 +228,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             raise ValueError(
                 f'`lora_adaptations` must be a list and contain at least one element'
             )
+        self._lora_prompts = config.lora_prompts
+        if (
+            not isinstance(self._lora_prompts, dict)
+            or len(self._lora_prompts) != len(self._lora_adaptations)
+        ):
+            raise ValueError(
+                f'`lora_prompts` must be a dict and contain the same number of elements as `lora_adaptations`'
+            )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }
@@ -244,13 +252,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self._task_idx = None
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None
-        self.prompts = {
-            'query': 'Represent the query for retrieving supporting documents: ',
-            'document': 'Represent the document for retrieval: ',
-            'sts': 'Represent the text for Semantic Textual Similarity: ',
-            'clustering': 'Cluster the text: ',
-            'classification': 'Classify the text: ',
-        }
 
     @property
     def main_params_trainable(self):
@@ -342,7 +343,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         else:
             input_ids = kwargs["input_ids"]
             input_text = self.roberta.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-            for task_name, prompt in self.prompts.items():
+            for task_name, prompt in self._lora_prompts.items():
                 if input_text.startswith(prompt):
                     self.current_task = task_name
                     break
@@ -385,7 +386,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             self.current_task = task_type
         else:  # infer the task from the input text
            input_text = args[0][0] if isinstance(args[0], list) else args[0]  # take only the first sentence
-            for task_name, prompt in self.prompts.items():
+            for task_name, prompt in self._lora_prompts.items():
                 if input_text.startswith(prompt):
                     self.current_task = task_name
                     break
@@ -397,4 +398,4 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             )
             self.current_task = None  # No task-specific adapter is found, just use the general-purpose weights
 
-        return self.roberta.encode(*args, **kwargs)
+        return self.roberta.encode(*args, **kwargs)
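
To illustrate the inference path that now reads from the config: encode() checks whether the first input sentence starts with one of the configured prompts and selects that adapter, falling back to the general-purpose weights (current_task = None) when nothing matches. infer_task below is a hypothetical stand-in for the inline loop, not a function from the repository.

# Hypothetical helper mirroring the loop in encode(); `prompts` plays the
# role of config.lora_prompts.
def infer_task(input_text, prompts):
    # Return the task whose prompt prefixes the input text,
    # or None to fall back to the general-purpose weights.
    for task_name, prompt in prompts.items():
        if input_text.startswith(prompt):
            return task_name
    return None

prompts = {
    'query': 'Represent the query for retrieving supporting documents: ',
    'document': 'Represent the document for retrieval: ',
}
assert infer_task(prompts['query'] + 'what is LoRA?', prompts) == 'query'
assert infer_task('plain text with no prompt prefix', prompts) is None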