Markus28 committed on
Commit 617fe56 · 1 Parent(s): e151a8f

feat: formatting and type hints

Files changed (1)
modeling_lora.py +122 -35
modeling_lora.py CHANGED
@@ -1,23 +1,24 @@
+import math
 from functools import partial
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple, Union
 
 import torch
-from torch import nn
 import torch.nn.utils.parametrize as parametrize
-import math
-
+from torch import nn
 from torch.nn import Parameter
 
 from .modeling_bert import BertModel, BertPreTrainedModel, JinaBertConfig
 
 
-def initialized_weights(shape, num_adaptions, init='kaiming'):
+def initialized_weights(
+    shape: Tuple[int], num_adaptions: int, init: str = "kaiming"
+) -> torch.Tensor:
     weight_data = []
     for _ in range(num_adaptions):
         new_adaption = torch.zeros(shape)
-        if init == 'kaiming':
+        if init == "kaiming":
             nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
-        elif init == 'normal':
+        elif init == "normal":
             nn.init.normal_(new_adaption)
         else:
             raise NotImplementedError
@@ -26,27 +27,48 @@ def initialized_weights(shape, num_adaptions, init='kaiming'):
 
 
 class LoRAParametrization(nn.Module):
-    def __init__(self, fan_in, fan_out, layer_type='linear', num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def __init__(
+        self,
+        fan_in: int,
+        fan_out: int,
+        layer_type: str = "linear",
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: float = 1,
+    ):
         super().__init__()
         # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
         # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
-        fan_in_fan_out = (layer_type == 'embedding')
+        fan_in_fan_out = layer_type == "embedding"
         self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
 
-        if layer_type == 'linear':
-            self.lora_A = nn.Parameter(initialized_weights((rank, fan_in), num_adaptions, init='kaiming'))
+        if layer_type == "linear":
+            self.lora_A = nn.Parameter(
+                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+            )
             self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
-        elif layer_type == 'embedding':
+        elif layer_type == "embedding":
             self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
-            self.lora_B = nn.Parameter(initialized_weights((rank, fan_out), num_adaptions=num_adaptions, init='normal'))
+            self.lora_B = nn.Parameter(
+                initialized_weights(
+                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                )
+            )
         else:
             raise NotImplementedError
 
         self.lora_alpha, self.rank = lora_alpha, rank
         self.scaling = lora_alpha / rank
-        self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        self.lora_dropout = (
+            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        )
         self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
-        self.register_buffer("lora_dropout_mask", torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype), persistent=False)
+        self.register_buffer(
+            "lora_dropout_mask",
+            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
+            persistent=False,
+        )
         self.forward_fn = lambda x: x
         self.current_task = None
 
@@ -56,7 +78,18 @@ class LoRAParametrization(nn.Module):
 
     def lora_forward(self, X):
         assert self.current_task is not None
-        return X + torch.matmul(*self.swap((self.lora_B[self.current_task], self.dropout_fn(self.lora_A[self.current_task])))).view(X.shape) * self.scaling
+        return (
+            X
+            + torch.matmul(
+                *self.swap(
+                    (
+                        self.lora_B[self.current_task],
+                        self.dropout_fn(self.lora_A[self.current_task]),
+                    )
+                )
+            ).view(X.shape)
+            * self.scaling
+        )
 
     def forward(self, X):
         return self.forward_fn(X)
@@ -69,28 +102,73 @@ class LoRAParametrization(nn.Module):
         self.forward_fn = self.lora_forward
 
     @classmethod
-    def from_linear(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def from_linear(
+        cls,
+        layer: nn.Module,
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+    ):
+        assert isinstance(layer, nn.Linear)
         fan_out, fan_in = layer.weight.shape
         return cls(
-            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='linear', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="linear",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
         )
 
     @classmethod
-    def from_embedding(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def from_embedding(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        assert isinstance(layer, nn.Embedding)
         fan_in, fan_out = layer.weight.shape
         return cls(
-            fan_in, fan_out, num_adaptions=num_adaptions, layer_type='embedding', rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="embedding",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
         )
 
     @classmethod
-    def add_to_layer(cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+    def add_to_layer(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
         if isinstance(layer, nn.Linear):
-            parametrize.register_parametrization(layer, "weight", cls.from_linear(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_linear(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
         elif isinstance(layer, nn.Embedding):
-            parametrize.register_parametrization(layer, "weight", cls.from_embedding(layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_embedding(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
 
     @classmethod
-    def select_task_for_layer(cls, layer, task_idx=None):
+    def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
         if isinstance(layer, LoRAParametrization):
             layer.select_task(task_idx)
 
@@ -101,7 +179,7 @@ class BertLoRA(BertPreTrainedModel):
         self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         self._register_lora(num_adaptions)
         for name, param in super().named_parameters():
-            if 'lora' not in name:
+            if "lora" not in name:
                 param.requires_grad_(False)
 
     def from_bert(self, *args, num_adaptions=1, **kwargs):
@@ -109,10 +187,20 @@ class BertLoRA(BertPreTrainedModel):
         self._register_lora(num_adaptions)
 
     def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
-        self.apply(partial(LoRAParametrization.add_to_layer, num_adaptions=num_adaptions, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha))
+        self.apply(
+            partial(
+                LoRAParametrization.add_to_layer,
+                num_adaptions=num_adaptions,
+                rank=rank,
+                lora_dropout_p=lora_dropout_p,
+                lora_alpha=lora_alpha,
+            )
+        )
 
-    def select_task(self, task_idx):
-        self.apply(partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx))
+    def select_task(self, task_idx: Union[None, int]):
+        self.apply(
+            partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+        )
 
     def forward(self, *args, **kwargs):
         return self.bert(*args, **kwargs)
@@ -122,11 +210,10 @@ class BertLoRA(BertPreTrainedModel):
             yield param
 
     def named_parameters(
-        self,
-        prefix: str = '',
-        recurse: bool = True,
-        remove_duplicate: bool = True
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
-        for name, param in super().named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate):
-            if 'lora' in name:
-                yield name, param
+        for name, param in super().named_parameters(
+            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
+        ):
+            if "lora" in name:
+                yield name, param
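
For orientation, a minimal usage sketch of the class as it reads after this commit (not part of the diff). It assumes the BertLoRA constructor accepts num_adaptions, as suggested by the call to self._register_lora(num_adaptions); that JinaBertConfig is default-constructible and exposes vocab_size; and it uses plain module imports in place of the package-relative ones inside the file.

import torch

from modeling_bert import JinaBertConfig  # imported relatively inside modeling_lora.py
from modeling_lora import BertLoRA

config = JinaBertConfig()                  # assumption: default-constructible config
model = BertLoRA(config, num_adaptions=2)  # assumption: __init__ forwards num_adaptions to _register_lora

# __init__ freezes everything except the LoRA tensors, and the overridden
# named_parameters() yields only parameters whose names contain "lora".
print([name for name, _ in model.named_parameters()])

# Activate one of the num_adaptions adapters, then run a regular forward pass,
# which is delegated to the wrapped BertModel.
model.select_task(0)
input_ids = torch.randint(0, config.vocab_size, (1, 16))  # assumption: config exposes vocab_size
outputs = model(input_ids=input_ids)

Because the adapters are attached with torch.nn.utils.parametrize, the stored base weights are left untouched and the low-rank update for the selected task is applied on the fly whenever a parametrized layer's weight is read.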