Update modeling_xlm_roberta.py (#12)
- Update modeling_xlm_roberta.py (6a473f18306789a58285f781c0b6ec6f4df03fdb)
- modeling_xlm_roberta.py +4 -3
modeling_xlm_roberta.py CHANGED
@@ -61,7 +61,7 @@ except ImportError:
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
 except ImportError:
-    CrossEntropyLoss = None
+    CrossEntropyLoss = torch.nn.CrossEntropyLoss
 
 try:
     from tqdm.autonotebook import trange
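In plain terms, this first hunk fixes the fallback path for installs without flash-attn: `CrossEntropyLoss` is now aliased to `torch.nn.CrossEntropyLoss` instead of being left unusable, so constructing the loss works with either backend. A minimal sketch of the pattern, assuming flash-attn may or may not be present (the `ignore_index` argument is illustrative, not from the diff):

import torch

# Prefer flash-attn's fused cross-entropy; otherwise fall back to the
# stock PyTorch implementation, which has a compatible call signature.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

# This construction now succeeds on both paths; with a bare `None`
# fallback it would raise "TypeError: 'NoneType' object is not callable".
loss_fn = CrossEntropyLoss(ignore_index=-100)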
@@ -1168,14 +1168,15 @@ class XLMRobertaClassificationHead(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
+        self.dense = linear_cls(config.hidden_size, config.hidden_size)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
             else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
-        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        self.out_proj = linear_cls(config.hidden_size, config.num_labels)
 
     def forward(self, features, **kwargs):
         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
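The second hunk extends the file's existing `fused_bias_fc` switch to the classification head: when flash-attn's fused bias-FC kernel is available, both `dense` and `out_proj` become `FusedDense` layers; otherwise they stay plain `nn.Linear`. A self-contained sketch under stated assumptions: `fused_bias_fc` is derived here from import availability (in the real file it is a flag set elsewhere), scalar hyperparameters stand in for `config`, and the body of `forward` beyond its first line follows the standard RoBERTa head, which the diff does not show:

import torch
import torch.nn as nn

# FusedDense is flash-attn's drop-in replacement for nn.Linear that fuses
# the matmul and bias addition into a single kernel.
try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

fused_bias_fc = FusedDense is not None  # assumption: set from config in the real file


class ClassificationHeadSketch(nn.Module):
    """Hedged sketch of XLMRobertaClassificationHead after this commit."""

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        # The commit's core change: pick the layer class once, use it twice.
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = linear_cls(hidden_size, num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # Assumed standard RoBERTa ordering for the lines the diff elides:
        x = self.dropout(x)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)


# Usage: a (batch, seq, hidden) tensor in, (batch, num_labels) logits out;
# swapping FusedDense for nn.Linear changes the kernel, not the interface.
logits = ClassificationHeadSketch(768, 2)(torch.randn(4, 16, 768))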