Files changed (7) hide show
  1. block.py +1 -1
  2. embedding.py +23 -5
  3. mha.py +36 -9
  4. mlp.py +29 -5
  5. modeling_lora.py +13 -13
  6. modeling_xlm_roberta.py +29 -22
  7. xlm_padding.py +5 -1
block.py CHANGED
@@ -233,7 +233,7 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'))
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ mlp_out = self.mlp(hidden_states, cu_adapter_mask=mixer_kwargs.get('cu_adapter_mask'))
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
embedding.py CHANGED
@@ -40,15 +40,25 @@ class XLMRobertaEmbeddings(nn.Module):
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
- def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
51
- embeddings = self.word_embeddings(input_ids, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
52
  if self.max_position_embeddings > 0:
53
  if position_ids is None:
54
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
@@ -58,6 +68,14 @@ class XLMRobertaEmbeddings(nn.Module):
58
  if self.type_vocab_size > 0:
59
  if token_type_ids is None:
60
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
61
- token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
62
- embeddings = embeddings + token_type_embeddings
 
 
 
 
 
 
 
 
63
  return embeddings
 
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ if adapter_mask is not None:
51
+ unique_tasks = torch.unique(adapter_mask)
52
+ embedding_dtype = next(self.word_embeddings.parameters()).dtype
53
+ embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
54
+ dtype=embedding_dtype, device=input_ids.device)
55
+ for task_id in unique_tasks:
56
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
57
+ task_input_ids = input_ids[task_indices]
58
+ task_embeddings = self.word_embeddings(task_input_ids, task_id=task_id)
59
+ embeddings[task_indices] = task_embeddings
60
+ else:
61
+ embeddings = self.word_embeddings(input_ids)
62
  if self.max_position_embeddings > 0:
63
  if position_ids is None:
64
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
 
68
  if self.type_vocab_size > 0:
69
  if token_type_ids is None:
70
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
71
+
72
+ if adapter_mask is not None:
73
+ unique_tasks = torch.unique(adapter_mask)
74
+ for task_id in unique_tasks:
75
+ task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
76
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
77
+ embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
78
+ else:
79
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
80
+ embeddings = embeddings + token_type_embeddings
81
  return embeddings
mha.py CHANGED
@@ -590,7 +590,7 @@ class MHA(nn.Module):
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
- task_type=None,
594
  **kwargs,
595
  ):
596
  """
@@ -643,15 +643,31 @@ class MHA(nn.Module):
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
 
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
648
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
649
- if not self.return_residual:
650
- qkv = self.Wqkv(x, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
 
651
  else:
652
- if lora_kwargs:
653
- lora_kwargs['residual'] = True
654
- qkv, x = self.Wqkv(x, **lora_kwargs)
 
 
 
 
655
 
656
  if self.dwconv:
657
  qkv = rearrange(
@@ -738,6 +754,17 @@ class MHA(nn.Module):
738
  else:
739
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
740
 
741
- lora_kwargs.pop('residual', None)
742
- out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
 
743
  return out if not self.return_residual else (out, x)
 
590
  max_seqlen=None,
591
  mixer_subset=None,
592
  inference_params=None,
593
+ cu_adapter_mask=None,
594
  **kwargs,
595
  ):
596
  """
 
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
646
+ lora_kwargs = {}
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
+
650
+ if cu_adapter_mask is not None:
651
+ unique_tasks = torch.unique(cu_adapter_mask)
652
+ qkv_dtype = next(self.Wqkv.parameters()).dtype
653
+ qkv = torch.empty(x.shape[0], self.Wqkv.out_features,
654
+ dtype=qkv_dtype, device=x.device)
655
+ for task_id in unique_tasks:
656
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
657
+ task_tensor = x[task_indices]
658
+ if not self.return_residual:
659
+ task_qkv = self.Wqkv(task_tensor, task_id=task_id)
660
+ else:
661
+ task_qkv, _ = self.Wqkv(task_tensor, task_id=task_id, residual=True)
662
+ qkv[task_indices] = task_qkv
663
  else:
664
+ if not self.return_residual:
665
+ qkv = self.Wqkv(x)
666
+ else:
667
+ if hasattr(self.Wqkv, 'parametrizations'):
668
+ qkv, x = self.Wqkv(x, residual=True)
669
+ else:
670
+ qkv, x = self.Wqkv(x)
671
 
672
  if self.dwconv:
673
  qkv = rearrange(
 
754
  else:
755
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
756
 
757
+ inp = rearrange(context, "... h d -> ... (h d)")
758
+ if cu_adapter_mask is not None:
759
+ unique_tasks = torch.unique(cu_adapter_mask)
760
+ out_dtype = next(self.out_proj.parameters()).dtype
761
+ out = torch.empty(inp.shape[0], self.out_proj.out_features,
762
+ dtype=out_dtype, device=inp.device)
763
+ for task_id in unique_tasks:
764
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
765
+ task_tensor = inp[task_indices]
766
+ task_out = self.out_proj(task_tensor, task_id=task_id)
767
+ out[task_indices] = task_out
768
+ else:
769
+ out = self.out_proj(inp)
770
  return out if not self.return_residual else (out, x)
mlp.py CHANGED
@@ -47,12 +47,36 @@ class Mlp(nn.Module):
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
- def forward(self, x, task_type=None):
51
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
52
- y = self.fc1(x, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
 
53
  y = self.activation(y)
54
- y = self.fc2(y, **lora_kwargs)
55
- return y if not self.return_residual else (y, x)
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  class ParallelMLP(nn.Module):
 
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
+ def forward(self, x, cu_adapter_mask=None):
51
+ if cu_adapter_mask is not None:
52
+ unique_tasks = torch.unique(cu_adapter_mask)
53
+ fc1_dtype = next(self.fc1.parameters()).dtype
54
+ y = torch.empty(x.shape[0], self.fc1.out_features,
55
+ dtype=fc1_dtype, device=x.device)
56
+ for task_id in unique_tasks:
57
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
+ task_tensor = x[task_indices]
59
+ task_y = self.fc1(task_tensor, task_id=task_id)
60
+ y[task_indices] = task_y
61
+ else:
62
+ y = self.fc1(x)
63
+
64
  y = self.activation(y)
65
+
66
+ if cu_adapter_mask is not None:
67
+ unique_tasks = torch.unique(cu_adapter_mask)
68
+ fc2_dtype = next(self.fc2.parameters()).dtype
69
+ out = torch.empty(y.shape[0], self.fc2.out_features,
70
+ dtype=fc2_dtype, device=y.device)
71
+ for task_id in unique_tasks:
72
+ task_indices = (cu_adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
+ task_tensor = y[task_indices]
74
+ task_out = self.fc2(task_tensor, task_id=task_id)
75
+ out[task_indices] = task_out
76
+ else:
77
+ out = self.fc2(y)
78
+
79
+ return out if not self.return_residual else (out, x)
80
 
81
 
82
  class ParallelMLP(nn.Module):
modeling_lora.py CHANGED
@@ -161,7 +161,6 @@ class LoRAParametrization(nn.Module):
161
  rank: int,
162
  dropout_p: float,
163
  alpha: float,
164
- adaptation_map: dict,
165
  ):
166
  if isinstance(layer, nn.Linear):
167
  parametrize.register_parametrization(
@@ -176,10 +175,9 @@ class LoRAParametrization(nn.Module):
176
  ),
177
  )
178
 
179
- def new_forward(self, input, task_type, residual=False):
180
- task_idx = adaptation_map[task_type] if task_type else None
181
- if task_idx is not None:
182
- weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
183
  else:
184
  weights = self.weight
185
 
@@ -204,10 +202,9 @@ class LoRAParametrization(nn.Module):
204
  ),
205
  )
206
 
207
- def new_forward(self, input, task_type):
208
- task_idx = adaptation_map[task_type] if task_type else None
209
- if task_idx is not None:
210
- weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
211
  else:
212
  weights = self.weight
213
 
@@ -227,7 +224,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
227
  roberta: Optional[XLMRobertaModel] = None
228
  ):
229
  super().__init__(config)
230
-
231
  if roberta is None:
232
  self.roberta = XLMRobertaModel(config)
233
  else:
@@ -318,7 +314,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
318
  rank=rank,
319
  dropout_p=dropout_p,
320
  alpha=alpha,
321
- adaptation_map=self._adaptation_map,
322
  )
323
  )
324
 
@@ -341,6 +336,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
341
  @torch.inference_mode()
342
  def encode(
343
  self,
 
344
  *args,
345
  task_type: Optional[str] = None,
346
  **kwargs,
@@ -359,5 +355,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
359
  f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
360
  f"Alternatively, don't pass the `task_type` argument to disable LoRA."
361
  )
362
-
363
- return self.roberta.encode(*args, task_type=task_type, **kwargs)
 
 
 
 
 
161
  rank: int,
162
  dropout_p: float,
163
  alpha: float,
 
164
  ):
165
  if isinstance(layer, nn.Linear):
166
  parametrize.register_parametrization(
 
175
  ),
176
  )
177
 
178
+ def new_forward(self, input, task_id=None, residual=False):
179
+ if task_id is not None:
180
+ weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
 
181
  else:
182
  weights = self.weight
183
 
 
202
  ),
203
  )
204
 
205
+ def new_forward(self, input, task_id=None):
206
+ if task_id is not None:
207
+ weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
 
208
  else:
209
  weights = self.weight
210
 
 
224
  roberta: Optional[XLMRobertaModel] = None
225
  ):
226
  super().__init__(config)
 
227
  if roberta is None:
228
  self.roberta = XLMRobertaModel(config)
229
  else:
 
314
  rank=rank,
315
  dropout_p=dropout_p,
316
  alpha=alpha,
 
317
  )
318
  )
319
 
 
336
  @torch.inference_mode()
337
  def encode(
338
  self,
339
+ sentences: Union[str, List[str]],
340
  *args,
341
  task_type: Optional[str] = None,
342
  **kwargs,
 
355
  f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
356
  f"Alternatively, don't pass the `task_type` argument to disable LoRA."
357
  )
358
+ adapter_mask = None
359
+ if task_type:
360
+ task_id = self._adaptation_map[task_type]
361
+ num_examples = 1 if isinstance(sentences, str) else len(sentences)
362
+ adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=self.device)
363
+ return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
modeling_xlm_roberta.py CHANGED
@@ -204,18 +204,15 @@ class XLMRobertaEncoder(nn.Module):
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
- def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
- mixer_kwargs = (
214
- {"key_padding_mask": key_padding_mask.bool()}
215
- if key_padding_mask is not None
216
- else None
217
- )
218
- mixer_kwargs['task_type'] = task_type
219
  for layer in self.layers:
220
  if self._grad_checkpointing:
221
  hidden_states = torch.utils.checkpoint.checkpoint(
@@ -230,10 +227,11 @@ class XLMRobertaEncoder(nn.Module):
230
  hidden_states = hidden_states[subset_mask]
231
  else:
232
  batch, seqlen = hidden_states.shape[:2]
233
- hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
234
- hidden_states, key_padding_mask
235
  )
236
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
 
237
  if subset_mask is None:
238
  for layer in self.layers:
239
  if self._grad_checkpointing:
@@ -310,13 +308,22 @@ class XLMRobertaPooler(nn.Module):
310
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
311
  self.activation = nn.Tanh()
312
 
313
- def forward(self, hidden_states, pool=True, task_type=None):
314
  # We "pool" the model by simply taking the hidden state corresponding
315
  # to the first token.
316
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
317
-
318
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
319
- pooled_output = self.dense(first_token_tensor, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
 
320
  pooled_output = self.activation(pooled_output)
321
  return pooled_output
322
 
@@ -429,7 +436,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
429
  "gelu_fast",
430
  "gelu_pytorch_tanh",
431
  ]
432
-
433
  self.embeddings = XLMRobertaEmbeddings(
434
  config.hidden_size,
435
  config.vocab_size,
@@ -457,6 +463,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
457
  device: Optional[torch.device] = None,
458
  normalize_embeddings: bool = False,
459
  truncate_dim: Optional[int] = None,
 
460
  task_type: Optional[str] = None,
461
  **tokenizer_kwargs,
462
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
@@ -542,14 +549,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
542
  )
543
  else:
544
  range_iter = range(0, len(sentences), batch_size)
545
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
546
  for i in range_iter:
547
  encoded_input = self.tokenizer(
548
  sentences[i : i + batch_size],
549
  return_tensors='pt',
550
  **tokenizer_kwargs,
551
  ).to(self.device)
552
- token_embs = self.forward(**encoded_input, **lora_kwargs)[0]
553
 
554
  # Accumulate in fp32 to avoid overflow
555
  token_embs = token_embs.float()
@@ -637,7 +644,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
637
  layer output for these tokens.
638
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
639
  """
640
- task_type = kwargs.pop('task_type', None)
641
  if kwargs:
642
  for key, value in kwargs.items():
643
  if value is not None:
@@ -651,7 +658,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
651
  )
652
 
653
  hidden_states = self.embeddings(
654
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
655
  )
656
  # TD [2022-12:18]: Don't need to force residual in fp32
657
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -675,12 +682,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
675
  subset_mask = None
676
 
677
  sequence_output = self.encoder(
678
- hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
679
  )
680
 
681
  if masked_tokens_mask is None:
682
  pooled_output = (
683
- self.pooler(sequence_output, task_type=task_type) if self.pooler is not None else None
684
  )
685
  else:
686
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -694,7 +701,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
694
  pool_input = sequence_output[first_col_mask[subset_mask]]
695
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
696
  pooled_output = (
697
- self.pooler(pool_input, pool=False, task_type=task_type) if self.pooler is not None else None
698
  )
699
 
700
  if not return_dict:
 
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
+ mixer_kwargs = {'adapter_mask': adapter_mask}
214
+ if key_padding_mask is not None:
215
+ mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
 
 
 
216
  for layer in self.layers:
217
  if self._grad_checkpointing:
218
  hidden_states = torch.utils.checkpoint.checkpoint(
 
227
  hidden_states = hidden_states[subset_mask]
228
  else:
229
  batch, seqlen = hidden_states.shape[:2]
230
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
231
+ hidden_states, key_padding_mask, adapter_mask
232
  )
233
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "cu_adapter_mask": cu_adapter_mask}
234
+
235
  if subset_mask is None:
236
  for layer in self.layers:
237
  if self._grad_checkpointing:
 
308
  self.dense = linear_cls(config.hidden_size, config.hidden_size)
309
  self.activation = nn.Tanh()
310
 
311
+ def forward(self, hidden_states, pool=True, adapter_mask=None):
312
  # We "pool" the model by simply taking the hidden state corresponding
313
  # to the first token.
 
 
314
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
315
+ if adapter_mask is not None:
316
+ unique_tasks = torch.unique(adapter_mask)
317
+ pool_dtype = next(self.dense.parameters()).dtype
318
+ pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
319
+ dtype=pool_dtype, device=first_token_tensor.device)
320
+ for task_id in unique_tasks:
321
+ task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
322
+ task_first_token_tensor = first_token_tensor[task_indices]
323
+ task_pooled_output = self.dense(task_first_token_tensor, task_id=task_id)
324
+ pooled_output[task_indices] = task_pooled_output
325
+ else:
326
+ pooled_output = self.dense(first_token_tensor)
327
  pooled_output = self.activation(pooled_output)
328
  return pooled_output
329
 
 
436
  "gelu_fast",
437
  "gelu_pytorch_tanh",
438
  ]
 
439
  self.embeddings = XLMRobertaEmbeddings(
440
  config.hidden_size,
441
  config.vocab_size,
 
463
  device: Optional[torch.device] = None,
464
  normalize_embeddings: bool = False,
465
  truncate_dim: Optional[int] = None,
466
+ adapter_mask: Optional[torch.Tensor] = None,
467
  task_type: Optional[str] = None,
468
  **tokenizer_kwargs,
469
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
 
549
  )
550
  else:
551
  range_iter = range(0, len(sentences), batch_size)
552
+ lora_arguments = {'adapter_mask': adapter_mask} if adapter_mask is not None else {}
553
  for i in range_iter:
554
  encoded_input = self.tokenizer(
555
  sentences[i : i + batch_size],
556
  return_tensors='pt',
557
  **tokenizer_kwargs,
558
  ).to(self.device)
559
+ token_embs = self.forward(**encoded_input, **lora_arguments)[0]
560
 
561
  # Accumulate in fp32 to avoid overflow
562
  token_embs = token_embs.float()
 
644
  layer output for these tokens.
645
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
646
  """
647
+ adapter_mask = kwargs.pop('adapter_mask', None)
648
  if kwargs:
649
  for key, value in kwargs.items():
650
  if value is not None:
 
658
  )
659
 
660
  hidden_states = self.embeddings(
661
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, adapter_mask=adapter_mask
662
  )
663
  # TD [2022-12:18]: Don't need to force residual in fp32
664
  # BERT puts embedding LayerNorm before embedding dropout.
 
682
  subset_mask = None
683
 
684
  sequence_output = self.encoder(
685
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, adapter_mask=adapter_mask
686
  )
687
 
688
  if masked_tokens_mask is None:
689
  pooled_output = (
690
+ self.pooler(sequence_output, adapter_mask=adapter_mask) if self.pooler is not None else None
691
  )
692
  else:
693
  # TD [2022-03-01]: the indexing here is very tricky.
 
701
  pool_input = sequence_output[first_col_mask[subset_mask]]
702
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
703
  pooled_output = (
704
+ self.pooler(pool_input, pool=False, adapter_mask=adapter_mask) if self.pooler is not None else None
705
  )
706
 
707
  if not return_dict:
xlm_padding.py CHANGED
@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
- def unpad_input(hidden_states, attention_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
@@ -113,6 +113,9 @@ def unpad_input(hidden_states, attention_mask):
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
 
116
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -123,6 +126,7 @@ def unpad_input(hidden_states, attention_mask):
123
  indices,
124
  cu_seqlens,
125
  max_seqlen_in_batch,
 
126
  )
127
 
128
 
 
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
+ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
 
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
+
117
+ cu_adapter_mask = torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1]) if adapter_mask is not None else None
118
+
119
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
120
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
121
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
126
  indices,
127
  cu_seqlens,
128
  max_seqlen_in_batch,
129
+ cu_adapter_mask,
130
  )
131
 
132