feat: formatting and type hints
modeling_lora.py CHANGED (+122 -35)
@@ -1,23 +1,24 @@
+import math
 from functools import partial
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple, Union
 
 import torch
-from torch import nn
 import torch.nn.utils.parametrize as parametrize
-import math
-
+from torch import nn
 from torch.nn import Parameter
 
 from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
 
 
-def initialized_weights(shape, num_adaptions, init='kaiming'):
+def initialized_weights(
+    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
+) -> torch.Tensor:
     weight_data = []
     for _ in range(num_adaptions):
         new_adaption = torch.zeros(shape)
-        if init == 'kaiming':
+        if init == "kaiming":
             nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
-        elif init == 'normal':
+        elif init == "normal":
             nn.init.normal_(new_adaption)
         else:
             raise NotImplementedError
@@ -26,27 +27,48 @@ def initialized_weights(shape, num_adaptions, init='kaiming'):
 
 
 class LoRAParametrization(nn.Module):
-    def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def __init__(
+        self,
+        fan_in: int,
+        fan_out: int,
+        layer_type: str = "linear",
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: float = 1,
+    ):
         super().__init__()
         # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
         # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
-        fan_in_fan_out = layer_type == 'embedding'
+        fan_in_fan_out = layer_type == "embedding"
         self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
 
-        if layer_type == 'linear':
-            self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
+        if layer_type == "linear":
+            self.lora_A = nn.Parameter(
+                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+            )
             self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
-        elif layer_type == 'embedding':
+        elif layer_type == "embedding":
             self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
-            self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))
+            self.lora_B = nn.Parameter(
+                initialized_weights(
+                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                )
+            )
         else:
             raise NotImplementedError
 
         self.lora_alpha, self.rank = lora_alpha, rank
         self.scaling = lora_alpha / rank
-        self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        self.lora_dropout = (
+            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        )
         self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
-        self.register_buffer('lora_dropout_mask', torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype), persistent=False)
+        self.register_buffer(
+            "lora_dropout_mask",
+            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
+            persistent=False,
+        )
         self.forward_fn = lambda x: x
         self.current_task = None
 
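
For a linear layer, lora_A is initialized with Kaiming noise while lora_B starts at zero, so the low-rank update B @ A is zero at initialization and the parametrized weight initially equals the base weight. A minimal shape-bookkeeping sketch, with made-up dimensions (fan_in=64, fan_out=32, rank=4, num_adaptions=2):

import torch

num_adaptions, rank, fan_in, fan_out = 2, 4, 64, 32

# per-task A factors (Kaiming-initialized in the module above)
lora_A = torch.randn(num_adaptions, rank, fan_in)
# per-task B factors (zero-initialized, so the initial update vanishes)
lora_B = torch.zeros(num_adaptions, fan_out, rank)

delta = lora_B[0] @ lora_A[0]  # (fan_out, fan_in), same layout as nn.Linear.weight
assert delta.shape == (fan_out, fan_in)
assert torch.all(delta == 0)   # the base weight is untouched at initialization
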
@@ -56,7 +78,18 @@ class LoRAParametrization(nn.Module):
 
     def lora_forward(self, X):
         assert self.current_task is not None
-        return X + torch.matmul(*self.swap((self.lora_B[self.current_task], self.dropout_fn(self.lora_A[self.current_task])))).view(X.shape) * self.scaling
+        return (
+            X
+            + torch.matmul(
+                *self.swap(
+                    (
+                        self.lora_B[self.current_task],
+                        self.dropout_fn(self.lora_A[self.current_task]),
+                    )
+                )
+            ).view(X.shape)
+            * self.scaling
+        )
 
     def forward(self, X):
         return self.forward_fn(X)
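
Because this module is registered as a parametrization on weight, the X passed to lora_forward is the layer's weight tensor, not an activation; the method returns the adapted weight X + scaling * (B[task] @ dropout(A[task])), with the factors swapped for embeddings. A standalone sketch of the same computation for the linear case, using made-up shapes and the default scaling lora_alpha / rank = 1 / 4:

import torch

fan_out, fan_in, rank, scaling = 32, 64, 4, 1 / 4

W = torch.randn(fan_out, fan_in)  # plays the role of X, the original weight
A = torch.randn(rank, fan_in)     # lora_A[current_task]
B = torch.randn(fan_out, rank)    # lora_B[current_task]

# same arithmetic as lora_forward (no dropout, no swap)
W_adapted = W + torch.matmul(B, A).view(W.shape) * scaling
assert W_adapted.shape == W.shape
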
@@ -69,28 +102,73 @@ class LoRAParametrization(nn.Module):
         self.forward_fn = self.lora_forward
 
     @classmethod
-    def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def from_linear(
+        cls,
+        layer: nn.Module,
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+    ):
+        assert isinstance(layer, nn.Linear)
         fan_out, fan_in = layer.weight.shape
         return cls(
-            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="linear",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
         )
 
     @classmethod
-    def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def from_embedding(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        assert isinstance(layer, nn.Embedding)
         fan_in, fan_out = layer.weight.shape
         return cls(
-            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="embedding",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
        )
 
     @classmethod
-    def add_to_layer(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def add_to_layer(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
         if isinstance(layer, nn.Linear):
-            parametrize.register_parametrization(layer, 'weight', cls.from_linear(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_linear(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
         elif isinstance(layer, nn.Embedding):
-            parametrize.register_parametrization(layer, 'weight', cls.from_embedding(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_embedding(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
 
     @classmethod
-    def select_task_for_layer(cls, layer, task_idx=None):
+    def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
         if isinstance(layer, LoRAParametrization):
             layer.select_task(task_idx)
 
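
register_parametrization wraps the layer's weight so that every read of layer.weight is routed through LoRAParametrization.forward, while the base tensor is kept at layer.parametrizations.weight.original. A rough usage sketch on a plain nn.Linear; the import path is hypothetical and depends on how this file is packaged:

import torch
from torch import nn
from torch.nn.utils import parametrize

from modeling_lora import LoRAParametrization  # hypothetical import path

layer = nn.Linear(64, 32)
LoRAParametrization.add_to_layer(layer, num_adaptions=2, rank=4)

assert parametrize.is_parametrized(layer, "weight")
print(layer.parametrizations.weight.original.shape)  # base weight tensor, torch.Size([32, 64])
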
@@ -101,7 +179,7 @@ class BertLoRA(BertPreTrainedModel):
         self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         self._register_lora(num_adaptions)
         for name, param in super().named_parameters():
-            if 'lora' not in name:
+            if "lora" not in name:
                 param.requires_grad_(False)
 
     def from_bert(self, *args, num_adaptions=1, **kwargs):
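
Freezing every parameter whose name does not contain "lora" leaves the base BERT weights, including the originals stored by the parametrizations, fixed, while only the adapter factors keep requires_grad=True. A quick check, assuming a constructed BertLoRA instance named model (see the usage sketch after the next hunk):

# model: an assumed BertLoRA instance, built as in the sketch further below
trainable = [name for name, param in model.bert.named_parameters() if param.requires_grad]
assert all("lora" in name for name in trainable)
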
@@ -109,10 +187,20 @@ class BertLoRA(BertPreTrainedModel):
         self._register_lora(num_adaptions)
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
-        self.apply(partial(LoRAParametrization.add_to_layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+        self.apply(
+            partial(
+                LoRAParametrization.add_to_layer,
+                num_adaptions=num_adaptions,
+                rank=rank,
+                lora_dropout_p=lora_dropout_p,
+                lora_alpha=lora_alpha,
+            )
+        )
 
-    def select_task(self, task_idx):
-        self.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx))
+    def select_task(self, task_idx: Union[None, int]):
+        self.apply(
+            partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+        )
 
     def forward(self, *args, **kwargs):
         return self.bert(*args, **kwargs)
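
_register_lora walks every submodule via Module.apply and attaches a LoRAParametrization to each nn.Linear and nn.Embedding, and select_task chooses which of the num_adaptions adapter sets is active for the whole model. A hedged end-to-end sketch; the constructor arguments and config defaults are illustrative assumptions, not taken from this diff:

import torch

config = JinaBertConfig()                  # assumed to provide standard BERT fields such as vocab_size
model = BertLoRA(config, num_adaptions=2)  # assumed constructor signature

model.select_task(0)                       # route every parametrized weight through adapter set 0
input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids)       # dispatches to the wrapped BertModel
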
@@ -122,11 +210,10 @@ class BertLoRA(BertPreTrainedModel):
            yield param
 
     def named_parameters(
-            self,
-            prefix: str = '',
-            recurse: bool = True,
-            remove_duplicate: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, Parameter]]:
-        for name, param in super().named_parameters(
-                prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate):
-            if 'lora' in name: yield name, param
+        for name, param in super().named_parameters(
+            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
+        ):
+            if "lora" in name:
+                yield name, param
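
With named_parameters restricted to names containing "lora", parameters() (which iterates over named_parameters internally) and therefore any optimizer built from it only ever sees the adapter factors. A short sketch, assuming the model from the usage example above:

import torch

# model: the assumed BertLoRA instance from the sketch above
assert all("lora" in name for name, _ in model.named_parameters())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # trains only lora_A / lora_B
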