Markus28 committed
Commit
5549314
1 Parent(s): ed1b276

feat: return from_bert for from_pretrained

Files changed (1)
  1. modeling_lora.py +21 -0
modeling_lora.py CHANGED
@@ -1,4 +1,5 @@
 import math
+import os
 from functools import partial
 from typing import Iterator, Optional, Tuple, Union
 
@@ -6,6 +7,7 @@ import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
+from transformers import PretrainedConfig
 
 from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
 
@@ -193,6 +195,25 @@ class BertLoRA(BertPreTrainedModel):
         return cls(config, bert=bert, num_adaptions=num_adaptions)
 
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        # TODO: choose between from_bert and super().from_pretrained
+        return cls.from_bert(pretrained_model_name_or_path)
+
+
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
         self.apply(
             partial(
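
A minimal usage sketch of what this change implies (the checkpoint name below is only an illustrative placeholder; note that the extra keyword arguments accepted by from_pretrained are currently not forwarded to from_bert):

# Sketch only: after this commit, BertLoRA.from_pretrained delegates to from_bert,
# so the checkpoint is loaded as a plain BERT model and the LoRA adapters are
# presumably freshly initialized rather than restored (see the TODO in the diff).
from modeling_lora import BertLoRA

model = BertLoRA.from_pretrained("bert-base-uncased")  # placeholder checkpoint name
# Per the current implementation this is equivalent to:
model = BertLoRA.from_bert("bert-base-uncased")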