feat: implement task type embeddings

#1
by Markus28 - opened
Files changed (2)
  1. configuration_bert.py +4 -0
  2. modeling_bert.py +14 -2
configuration_bert.py CHANGED
@@ -81,6 +81,8 @@ class JinaBertConfig(PretrainedConfig):
         fused_dropout_add_ln=False,
         fused_bias_fc=False,
         pad_vocab_size_multiple=1,
+        num_tasks=0,
+        use_flash_attn=True,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs
@@ -106,3 +108,5 @@ class JinaBertConfig(PretrainedConfig):
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.fused_bias_fc = fused_bias_fc
         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        self.num_tasks = num_tasks
+        self.use_flash_attn = use_flash_attn
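For reference, a minimal sketch of how the two new config fields could be set; the argument values are illustrative and the import path assumes the module name used in this repo:

```python
from configuration_bert import JinaBertConfig  # module name as in this repo

# num_tasks > 0 sizes the task-type embedding table added in modeling_bert.py;
# use_flash_attn=False keeps the model on the non-flash attention path.
config = JinaBertConfig(num_tasks=5, use_flash_attn=False)
```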
modeling_bert.py CHANGED
@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
+    use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     window_size = getattr(config, "window_size", (-1, -1))
     mixer_cls = partial(
@@ -68,7 +69,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         dropout=config.attention_probs_dropout_prob,
         causal=False,
         fused_bias_fc=fused_bias_fc,
-        use_flash_attn=True,
+        use_flash_attn=use_flash_attn,
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
@@ -151,6 +152,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
+        self.use_flash_attn = getattr(config, "use_flash_attn", False)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -171,7 +173,7 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None:
+        if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
@@ -340,14 +342,21 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
+        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        # We now initialize the task embeddings to 0; We do not use task types during
+        # pretraining. When we start using task types during embedding training,
+        # we want the model to behave exactly as in pretraining (i.e. task types
+        # have no effect).
+        nn.init.zeros_(self.task_type_embeddings.weight)
 
     def forward(
         self,
         input_ids,
         position_ids=None,
         token_type_ids=None,
+        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
     ):
@@ -359,6 +368,9 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
+        if task_type_ids is not None:
+            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
+
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
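Putting the pieces together, here is a sketch of how the new task_type_ids argument might be exercised; the model instance, token ids, and task id below are assumptions for illustration, not part of the PR:

```python
import torch

# Assumptions: `model` is a BertModel built from a JinaBertConfig whose
# num_tasks covers the task ids used below, already loaded and in eval mode.
input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # (batch, seqlen)
task_type_ids = torch.full_like(input_ids, 2)                   # task id 2 for every token

out_with_task = model(input_ids, task_type_ids=task_type_ids)
out_no_task = model(input_ids)

# Because the task-type embedding weights start at zero (nn.init.zeros_ above),
# both calls yield identical outputs until those weights are trained.
```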