Markus28 commited on
Commit
851184a
·
1 Parent(s): 2e2b8d0

feat: made from_bert work

Browse files
Files changed (1) hide show
  1. modeling_lora.py +11 -5
modeling_lora.py CHANGED
@@ -174,18 +174,24 @@ class LoRAParametrization(nn.Module):
174
 
175
 
176
  class BertLoRA(BertPreTrainedModel):
177
- def __init__(self, config: JinaBertConfig, add_pooling_layer=True, num_adaptions=1):
178
  super().__init__(config)
179
- self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
 
 
 
180
  self._register_lora(num_adaptions)
181
  for name, param in super().named_parameters():
182
  if "lora" not in name:
183
  param.requires_grad_(False)
184
  self.select_task(0)
185
 
186
- def from_bert(self, *args, num_adaptions=1, **kwargs):
187
- self.bert = BertModel.from_pretrained(*args, **kwargs)
188
- self._register_lora(num_adaptions)
 
 
 
189
 
190
  def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
191
  self.apply(
 
174
 
175
 
176
  class BertLoRA(BertPreTrainedModel):
177
+ def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True, num_adaptions=1):
178
  super().__init__(config)
179
+ if bert is None:
180
+ self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
181
+ else:
182
+ self.bert = bert
183
  self._register_lora(num_adaptions)
184
  for name, param in super().named_parameters():
185
  if "lora" not in name:
186
  param.requires_grad_(False)
187
  self.select_task(0)
188
 
189
+ @classmethod
190
+ def from_bert(cls, *args, num_adaptions=1, **kwargs):
191
+ bert = BertModel.from_pretrained(*args, **kwargs)
192
+ config = JinaBertConfig.from_pretrained(*args, **kwargs)
193
+ return cls(config, bert=bert, num_adaptions=num_adaptions)
194
+
195
 
196
  def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
197
  self.apply(