Update modeling.py (#5)
Browse files- Update modeling.py (d0284a3c6268d1484dd3b5398dfce7b32a51724e)
Co-authored-by: Xin Zhang <izhx@users.noreply.huggingface.co>
- modeling.py +8 -5
modeling.py
CHANGED
@@ -975,8 +975,6 @@ class NewForMaskedLM(NewPreTrainedModel):
|
|
975 |
self.lm_head = NewLMPredictionHead(config)
|
976 |
self.loss_fct = nn.CrossEntropyLoss()
|
977 |
|
978 |
-
self.pretraining = True
|
979 |
-
|
980 |
# Initialize weights and apply final processing
|
981 |
self.post_init()
|
982 |
|
@@ -1009,13 +1007,13 @@ class NewForMaskedLM(NewPreTrainedModel):
|
|
1009 |
|
1010 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1011 |
|
1012 |
-
if labels is None:
|
1013 |
length = None
|
1014 |
subset_indices = None
|
1015 |
else:
|
1016 |
length = attention_mask.sum(-1).tolist()
|
1017 |
labels = labels[attention_mask.bool()].unsqueeze(0)
|
1018 |
-
subset_indices = labels > -100
|
1019 |
|
1020 |
outputs = self.new(
|
1021 |
input_ids,
|
@@ -1037,7 +1035,12 @@ class NewForMaskedLM(NewPreTrainedModel):
|
|
1037 |
|
1038 |
masked_lm_loss = None
|
1039 |
if labels is not None:
|
1040 |
-
|
|
|
|
|
|
|
|
|
|
|
1041 |
masked_lm_loss = self.loss_fct(prediction_scores, labels)
|
1042 |
|
1043 |
if not return_dict:
|
|
|
975 |
self.lm_head = NewLMPredictionHead(config)
|
976 |
self.loss_fct = nn.CrossEntropyLoss()
|
977 |
|
|
|
|
|
978 |
# Initialize weights and apply final processing
|
979 |
self.post_init()
|
980 |
|
|
|
1007 |
|
1008 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1009 |
|
1010 |
+
if labels is None or not self.new.config.unpad_inputs:
|
1011 |
length = None
|
1012 |
subset_indices = None
|
1013 |
else:
|
1014 |
length = attention_mask.sum(-1).tolist()
|
1015 |
labels = labels[attention_mask.bool()].unsqueeze(0)
|
1016 |
+
subset_indices = labels > -100
|
1017 |
|
1018 |
outputs = self.new(
|
1019 |
input_ids,
|
|
|
1035 |
|
1036 |
masked_lm_loss = None
|
1037 |
if labels is not None:
|
1038 |
+
if subset_indices is None:
|
1039 |
+
mask = attention_mask.bool()
|
1040 |
+
prediction_scores = prediction_scores[mask]
|
1041 |
+
labels = labels[mask]
|
1042 |
+
else:
|
1043 |
+
labels = labels[subset_indices]
|
1044 |
masked_lm_loss = self.loss_fct(prediction_scores, labels)
|
1045 |
|
1046 |
if not return_dict:
|