P01son committed
Commit 88cbd5d
1 Parent(s): 4e2328a

Upload modelling_RW.py

Files changed (1): modelling_RW.py +18 -4
modelling_RW.py CHANGED
@@ -276,9 +276,23 @@ class Attention(nn.Module):
             key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            if torch.__version__ < "2.0.0":
+                mask = torch.ones(q_length, q_length, device=query_layer_.device)
+                mask = torch.tril(mask)
+                mask = (1.0 - mask) * -10000
+                mask = mask.repeat(batch_size, 1, 1, 1)
+
+                scores = torch.matmul(query_layer_, key_layer_.transpose(-2, -1))
+                scores = scores / math.sqrt(float(self.head_dim))
+                scores = scores + mask.type_as(scores)
+
+                probs = nn.Softmax(dim=-1)(scores)
+
+                attn_output = probs @ value_layer_
+            else:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
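The new torch < 2.0 branch above reimplements causal self-attention by hand: build a lower-triangular mask, turn masked positions into a large negative additive bias, scale QK^T by 1/sqrt(head_dim), apply softmax, and multiply by V. As a reference, here is a minimal self-contained sketch of that equivalence, using illustrative shapes and names (not the model's), checking the manual path against F.scaled_dot_product_attention on PyTorch 2.x:

import math
import torch
import torch.nn.functional as F
from torch import nn

batch_size, num_heads, q_length, head_dim = 2, 4, 8, 16  # illustrative sizes
query = torch.randn(batch_size, num_heads, q_length, head_dim)
key = torch.randn(batch_size, num_heads, q_length, head_dim)
value = torch.randn(batch_size, num_heads, q_length, head_dim)

# Manual causal attention, mirroring the torch < 2.0 branch in the diff.
mask = torch.tril(torch.ones(q_length, q_length))
mask = (1.0 - mask) * -10000              # 0 where attended, -10000 where masked
mask = mask.repeat(batch_size, 1, 1, 1)   # -> (batch, 1, q, q), broadcasts over heads

scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(float(head_dim))
scores = scores + mask.type_as(scores)
probs = nn.Softmax(dim=-1)(scores)
manual = probs @ value

# Fused kernel, mirroring the torch >= 2.0 branch in the diff.
fused = F.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=True)

# -10000 is a close stand-in for -inf after softmax, so the two outputs agree.
print(torch.allclose(manual, fused, atol=1e-5))

One caveat: torch.__version__ < "2.0.0" is a plain string comparison. It separates 1.x from 2.x correctly, but it is not a general version comparison; packaging.version.parse(torch.__version__) is the usual robust alternative.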
@@ -945,7 +959,7 @@ class RWForTokenClassification(RWPreTrainedModel):
         else:
             classifier_dropout = 0.1
         self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
         # Initialize weights and apply final processing
         self.post_init()
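For context, the line touched in this second hunk is the standard token-classification head: dropout followed by a linear projection from hidden_size to num_labels, giving per-token logits. A minimal sketch of that pattern, with illustrative dimensions rather than the model's real config:

import torch
from torch import nn

hidden_size, num_labels = 64, 5                  # illustrative values, not the model's config
dropout = nn.Dropout(0.1)                        # the classifier_dropout default from the diff
classifier = nn.Linear(hidden_size, num_labels)

hidden_states = torch.randn(2, 8, hidden_size)   # (batch, seq_len, hidden_size)
logits = classifier(dropout(hidden_states))      # (batch, seq_len, num_labels)
print(logits.shape)                              # torch.Size([2, 8, 5])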
 
 