jupyterjazz
commited on
Commit
•
6cc0f51
1
Parent(s):
4d09ca8
draft
Browse filesSigned-off-by: jupyterjazz <saba.sturua@jina.ai>
- embedding.py +2 -1
- mha.py +4 -3
- modeling_lora.py +44 -5
- modeling_xlm_roberta.py +1 -1
embedding.py
CHANGED
@@ -47,7 +47,8 @@ class XLMRobertaEmbeddings(nn.Module):
|
|
47 |
token_type_ids: (batch, seqlen)
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
-
|
|
|
51 |
if self.max_position_embeddings > 0:
|
52 |
if position_ids is None:
|
53 |
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
|
|
47 |
token_type_ids: (batch, seqlen)
|
48 |
"""
|
49 |
batch_size, seqlen = input_ids.shape
|
50 |
+
print('input shape', input_ids.shape)
|
51 |
+
embeddings = self.word_embeddings(input_ids, task='sts')
|
52 |
if self.max_position_embeddings > 0:
|
53 |
if position_ids is None:
|
54 |
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
mha.py
CHANGED
@@ -340,8 +340,8 @@ class CrossAttention(nn.Module):
|
|
340 |
class LinearResidual(nn.Linear):
|
341 |
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
342 |
|
343 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
344 |
-
return super().forward(input), input
|
345 |
|
346 |
|
347 |
def _update_kv_cache(kv, inference_params, layer_idx):
|
@@ -450,6 +450,7 @@ class MHA(nn.Module):
|
|
450 |
|
451 |
if fused_bias_fc and FusedDense is None:
|
452 |
raise ImportError("fused_dense is not installed")
|
|
|
453 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
454 |
linear_resid_cls = (
|
455 |
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
@@ -646,7 +647,7 @@ class MHA(nn.Module):
|
|
646 |
if not self.return_residual:
|
647 |
qkv = self.Wqkv(x)
|
648 |
else:
|
649 |
-
qkv, x = self.Wqkv(x)
|
650 |
if self.dwconv:
|
651 |
qkv = rearrange(
|
652 |
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
|
|
340 |
class LinearResidual(nn.Linear):
|
341 |
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
342 |
|
343 |
+
def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
|
344 |
+
return super().forward(input, task=task), input
|
345 |
|
346 |
|
347 |
def _update_kv_cache(kv, inference_params, layer_idx):
|
|
|
450 |
|
451 |
if fused_bias_fc and FusedDense is None:
|
452 |
raise ImportError("fused_dense is not installed")
|
453 |
+
print('is this true', fused_bias_fc)
|
454 |
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
455 |
linear_resid_cls = (
|
456 |
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
|
|
647 |
if not self.return_residual:
|
648 |
qkv = self.Wqkv(x)
|
649 |
else:
|
650 |
+
qkv, x = self.Wqkv(x, task='sts')
|
651 |
if self.dwconv:
|
652 |
qkv = rearrange(
|
653 |
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
modeling_lora.py
CHANGED
@@ -98,15 +98,15 @@ class LoRAParametrization(nn.Module):
|
|
98 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
99 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
100 |
|
101 |
-
def lora_forward(self, X):
|
102 |
-
|
103 |
return (
|
104 |
X
|
105 |
+ torch.matmul(
|
106 |
*self.swap(
|
107 |
(
|
108 |
-
self.lora_B[
|
109 |
-
self.dropout_fn(self.lora_A[
|
110 |
)
|
111 |
)
|
112 |
).view(X.shape)
|
@@ -114,7 +114,10 @@ class LoRAParametrization(nn.Module):
|
|
114 |
)
|
115 |
|
116 |
def forward(self, X):
|
117 |
-
|
|
|
|
|
|
|
118 |
|
119 |
@property
|
120 |
def current_task(self):
|
@@ -178,6 +181,7 @@ class LoRAParametrization(nn.Module):
|
|
178 |
rank: int,
|
179 |
dropout_p: float,
|
180 |
alpha: float,
|
|
|
181 |
):
|
182 |
if isinstance(layer, nn.Linear):
|
183 |
parametrize.register_parametrization(
|
@@ -191,6 +195,16 @@ class LoRAParametrization(nn.Module):
|
|
191 |
alpha=alpha,
|
192 |
),
|
193 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
elif isinstance(layer, nn.Embedding):
|
195 |
parametrize.register_parametrization(
|
196 |
layer,
|
@@ -203,6 +217,23 @@ class LoRAParametrization(nn.Module):
|
|
203 |
alpha=alpha,
|
204 |
),
|
205 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
@staticmethod
|
208 |
def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
|
@@ -247,6 +278,13 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
247 |
self._task_idx = None
|
248 |
# By default, disable LoRA until it's specified which adapter/task to use
|
249 |
self.current_task = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
@property
|
252 |
def main_params_trainable(self):
|
@@ -300,6 +338,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
300 |
rank=rank,
|
301 |
dropout_p=dropout_p,
|
302 |
alpha=alpha,
|
|
|
303 |
)
|
304 |
)
|
305 |
|
|
|
98 |
# to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
|
99 |
return A * self.lora_dropout(self.lora_dropout_mask)
|
100 |
|
101 |
+
def lora_forward(self, X, current_task=None):
|
102 |
+
print('lora input shape', X.shape)
|
103 |
return (
|
104 |
X
|
105 |
+ torch.matmul(
|
106 |
*self.swap(
|
107 |
(
|
108 |
+
self.lora_B[current_task],
|
109 |
+
self.dropout_fn(self.lora_A[current_task]),
|
110 |
)
|
111 |
)
|
112 |
).view(X.shape)
|
|
|
114 |
)
|
115 |
|
116 |
def forward(self, X):
|
117 |
+
print('forward input shape', X.shape, X)
|
118 |
+
out = self.forward_fn(X)
|
119 |
+
print(out.shape)
|
120 |
+
return out
|
121 |
|
122 |
@property
|
123 |
def current_task(self):
|
|
|
181 |
rank: int,
|
182 |
dropout_p: float,
|
183 |
alpha: float,
|
184 |
+
adaptation_map: dict,
|
185 |
):
|
186 |
if isinstance(layer, nn.Linear):
|
187 |
parametrize.register_parametrization(
|
|
|
195 |
alpha=alpha,
|
196 |
),
|
197 |
)
|
198 |
+
original_forward = layer.forward
|
199 |
+
|
200 |
+
def new_forward(self, input, task):
|
201 |
+
print('an aq mitxari aba')
|
202 |
+
output = original_forward(input, task=task)
|
203 |
+
weight = self.parametrizations.weight(self.weight, task)
|
204 |
+
return nn.functional.linear(input, weight, self.bias)
|
205 |
+
|
206 |
+
layer.forward = new_forward.__get__(layer, layer.__class__)
|
207 |
+
|
208 |
elif isinstance(layer, nn.Embedding):
|
209 |
parametrize.register_parametrization(
|
210 |
layer,
|
|
|
217 |
alpha=alpha,
|
218 |
),
|
219 |
)
|
220 |
+
original_forward = layer.forward
|
221 |
+
|
222 |
+
def new_forward(self, input, task):
|
223 |
+
print('input here', input, input.shape)
|
224 |
+
print('func', original_forward)
|
225 |
+
# original_forward['parametrizations'] = None
|
226 |
+
# print('funcc', original_forward.__dict__)
|
227 |
+
output = original_forward(input)
|
228 |
+
print(output.shape, 'output shape')
|
229 |
+
task_idx = adaptation_map[task] if task else None
|
230 |
+
if task_idx:
|
231 |
+
output = self.parametrizations.weight[0].lora_forward(output, current_task=task_idx)
|
232 |
+
print('thats it')
|
233 |
+
return output
|
234 |
+
|
235 |
+
layer.forward = new_forward.__get__(layer, layer.__class__)
|
236 |
+
|
237 |
|
238 |
@staticmethod
|
239 |
def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
|
|
|
278 |
self._task_idx = None
|
279 |
# By default, disable LoRA until it's specified which adapter/task to use
|
280 |
self.current_task = None
|
281 |
+
for name, param in super().named_parameters():
|
282 |
+
if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_A':
|
283 |
+
print('A0', param[0])
|
284 |
+
print('A1', param[1])
|
285 |
+
if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_B':
|
286 |
+
print('B0', param[0])
|
287 |
+
print('B1', param[1])
|
288 |
|
289 |
@property
|
290 |
def main_params_trainable(self):
|
|
|
338 |
rank=rank,
|
339 |
dropout_p=dropout_p,
|
340 |
alpha=alpha,
|
341 |
+
adaptation_map=self._adaptation_map,
|
342 |
)
|
343 |
)
|
344 |
|
modeling_xlm_roberta.py
CHANGED
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
|
|
204 |
def gradient_checkpointing(self, value):
|
205 |
self._grad_checkpointing = value
|
206 |
|
207 |
-
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
208 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
209 |
This means that we only compute the last layer output for these tokens.
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
|
|
204 |
def gradient_checkpointing(self, value):
|
205 |
self._grad_checkpointing = value
|
206 |
|
207 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task=None):
|
208 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
209 |
This means that we only compute the last layer output for these tokens.
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|