Upload modelling_RW.py
Browse files- modelling_RW.py +18 -4
modelling_RW.py
CHANGED
@@ -276,9 +276,23 @@ class Attention(nn.Module):
|
|
276 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
277 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
284 |
x = x.permute(0, 2, 1, 3)
|
@@ -945,7 +959,7 @@ class RWForTokenClassification(RWPreTrainedModel):
|
|
945 |
else:
|
946 |
classifier_dropout = 0.1
|
947 |
self.dropout = nn.Dropout(classifier_dropout)
|
948 |
-
self.classifier = nn.Linear(config.hidden_size, config.
|
949 |
|
950 |
# Initialize weights and apply final processing
|
951 |
self.post_init()
|
|
|
276 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
277 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
278 |
|
279 |
+
if torch.__version__ < "2.0.0":
|
280 |
+
mask = torch.ones(q_length, q_length, device=query_layer_.device)
|
281 |
+
mask = torch.tril(mask)
|
282 |
+
mask = (1.0 - mask) * -10000
|
283 |
+
mask = mask.repeat(batch_size, 1, 1, 1)
|
284 |
+
|
285 |
+
scores = torch.matmul(query_layer_, key_layer_.transpose(-2, -1))
|
286 |
+
scores = scores / math.sqrt(float(self.head_dim))
|
287 |
+
scores = scores + mask.type_as(scores)
|
288 |
+
|
289 |
+
probs = nn.Softmax(dim=-1)(scores)
|
290 |
+
|
291 |
+
attn_output = probs @ value_layer_
|
292 |
+
else:
|
293 |
+
attn_output = F.scaled_dot_product_attention(
|
294 |
+
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|
295 |
+
)
|
296 |
|
297 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
298 |
x = x.permute(0, 2, 1, 3)
|
|
|
959 |
else:
|
960 |
classifier_dropout = 0.1
|
961 |
self.dropout = nn.Dropout(classifier_dropout)
|
962 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_lab els)
|
963 |
|
964 |
# Initialize weights and apply final processing
|
965 |
self.post_init()
|