feat: return from_bert for from_pretrained
Browse files- modeling_lora.py +21 -0
modeling_lora.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import math
|
|
|
2 |
from functools import partial
|
3 |
from typing import Iterator, Optional, Tuple, Union
|
4 |
|
@@ -6,6 +7,7 @@ import torch
|
|
6 |
import torch.nn.utils.parametrize as parametrize
|
7 |
from torch import nn
|
8 |
from torch.nn import Parameter
|
|
|
9 |
|
10 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
11 |
|
@@ -193,6 +195,25 @@ class BertLoRA(BertPreTrainedModel):
|
|
193 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
194 |
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
197 |
self.apply(
|
198 |
partial(
|
|
|
1 |
import math
|
2 |
+
import os
|
3 |
from functools import partial
|
4 |
from typing import Iterator, Optional, Tuple, Union
|
5 |
|
|
|
7 |
import torch.nn.utils.parametrize as parametrize
|
8 |
from torch import nn
|
9 |
from torch.nn import Parameter
|
10 |
+
from transformers import PretrainedConfig
|
11 |
|
12 |
from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
|
13 |
|
|
|
195 |
return cls(config, bert=bert, num_adaptions=num_adaptions)
|
196 |
|
197 |
|
198 |
+
@classmethod
|
199 |
+
def from_pretrained(
|
200 |
+
cls,
|
201 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
202 |
+
*model_args,
|
203 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
204 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
205 |
+
ignore_mismatched_sizes: bool = False,
|
206 |
+
force_download: bool = False,
|
207 |
+
local_files_only: bool = False,
|
208 |
+
token: Optional[Union[str, bool]] = None,
|
209 |
+
revision: str = "main",
|
210 |
+
use_safetensors: bool = None,
|
211 |
+
**kwargs,
|
212 |
+
):
|
213 |
+
# TODO: choose between from_bert and super().from_pretrained
|
214 |
+
return cls.from_bert(pretrained_model_name_or_path)
|
215 |
+
|
216 |
+
|
217 |
def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
|
218 |
self.apply(
|
219 |
partial(
|