Daniel Hesslow committed on
Commit
2d91d77
1 Parent(s): e881e96

Update modelling_RW.py

Browse files
Files changed (1) hide show
  1. modelling_RW.py +2 -2
modelling_RW.py CHANGED
@@ -21,7 +21,7 @@ from transformers.modeling_outputs import (
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
  from transformers.utils import logging
24
- from configuration_RW import RWConfig
25
 
26
  logger = logging.get_logger(__name__)
27
 
@@ -303,7 +303,7 @@ class Attention(nn.Module):
303
  attention_scores = attention_scores.to(torch.float32)
304
  # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
  attention_probs = F.softmax(
306
- (attention_scores + alibi) * self.inv_norm_factor + attention_mask_float,
307
  dim=-1,
308
  dtype=hidden_states.dtype,
309
  )
 
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
  from transformers.utils import logging
24
+ from .configuration_RW import RWConfig
25
 
26
  logger = logging.get_logger(__name__)
27
 
 
303
  attention_scores = attention_scores.to(torch.float32)
304
  # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
  attention_probs = F.softmax(
306
+ (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
307
  dim=-1,
308
  dtype=hidden_states.dtype,
309
  )