Markus28 committed
Commit a0ba9b2
1 Parent(s): f669876

Use attention dropout during training (#1)


- fix: use attention dropout during training (e6bd2263db385384cdcf3e4b922cf42be912aef6)
- feat: use self.dropout_p (e2f03eb7b307b0826717d2df224415dbb4e7eead)

Files changed (1)
  1. modeling_bert.py +4 -2
modeling_bert.py CHANGED

@@ -282,7 +282,8 @@ class JinaBertSelfAttention(nn.Module):
         self.layer_norm_q = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.layer_norm_k = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.dropout_p = config.attention_probs_dropout_prob
+        self.dropout = nn.Dropout(self.dropout_p)
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
         )
@@ -357,7 +358,8 @@
         if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
             b, _, s, _ = query_layer.shape
             new_bias = attention_mask + bias
-            attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias)
+            dropout_p = self.dropout_p if self.training else 0.0
+            attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias, dropout_p=dropout_p)
             attn = attn.permute(0, 2, 1, 3).contiguous()
             return (attn.view(b, s, self.all_head_size),)
 
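For context, here is a minimal, self-contained sketch of the pattern this commit adopts. It is not the repository's JinaBertSelfAttention: ToySelfAttention, its shapes, and the toy config values are illustrative assumptions, and it assumes PyTorch >= 2.0 for torch.nn.functional.scaled_dot_product_attention. The point it demonstrates is that SDPA applies dropout whenever dropout_p > 0 regardless of the module's train/eval mode, so the caller has to gate the probability on self.training, exactly as the second hunk above does.

```python
# Minimal sketch of gating SDPA dropout on self.training (illustrative toy module,
# not the repository's JinaBertSelfAttention). Assumes PyTorch >= 2.0.
import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention


class ToySelfAttention(nn.Module):
    def __init__(self, hidden_size: int = 64, num_heads: int = 4,
                 attention_probs_dropout_prob: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        # Keep the raw probability around so the SDPA call can use it directly,
        # mirroring the commit's new self.dropout_p attribute.
        self.dropout_p = attention_probs_dropout_prob

    def forward(self, x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, s, h*d) -> (b, h, s, d)
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # SDPA does not check the module's training flag, so zero the
        # probability at inference time to keep eval deterministic.
        dropout_p = self.dropout_p if self.training else 0.0
        attn = scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=dropout_p)
        return attn.transpose(1, 2).reshape(b, s, -1)


if __name__ == "__main__":
    layer = ToySelfAttention()
    x = torch.randn(2, 8, 64)
    bias = torch.zeros(2, 4, 8, 8)  # e.g. an additive attention bias
    layer.train()
    y_train = layer(x, bias)  # attention dropout active
    layer.eval()
    y_eval = layer(x, bias)   # dropout_p forced to 0.0
    print(y_train.shape, y_eval.shape)
```

The committed change applies the same gating inside the torch SDPA branch of JinaBertSelfAttention, while keeping the existing nn.Dropout module, which presumably still serves the non-SDPA attention path.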