jonsaadfalcon
committed on
Commit
·
90b56d4
1
Parent(s):
434063d
Upload bert_layers.py
Browse files
- bert_layers.py +4 -9
bert_layers.py
CHANGED
@@ -610,7 +610,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
610 |
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
611 |
'bi-directional self-attention.')
|
612 |
|
613 |
-
self.bert = BertModel(config, add_pooling_layer=
|
614 |
self.cls = BertOnlyMLMHead(config,
|
615 |
self.bert.embeddings.word_embeddings.weight)
|
616 |
|
@@ -705,18 +705,13 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
705 |
return_dict=return_dict,
|
706 |
masked_tokens_mask=masked_tokens_mask,
|
707 |
)
|
708 |
-
|
709 |
if torch.isnan(outputs[0]).any():
|
710 |
print("NaNs in outputs.")
|
711 |
raise ValueError()
|
712 |
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
pooled_output = outputs[0]
|
717 |
-
|
718 |
-
last_hidden_state_formatted = outputs[0][:,0,:].view(-1, self.config.hidden_size)
|
719 |
-
return {"sentence_embedding": last_hidden_state_formatted}
|
720 |
|
721 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
722 |
attention_mask: torch.Tensor,
|
|
|
610 |
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
611 |
'bi-directional self-attention.')
|
612 |
|
613 |
+
self.bert = BertModel(config, add_pooling_layer=True)
|
614 |
self.cls = BertOnlyMLMHead(config,
|
615 |
self.bert.embeddings.word_embeddings.weight)
|
616 |
|
|
|
705 |
return_dict=return_dict,
|
706 |
masked_tokens_mask=masked_tokens_mask,
|
707 |
)
|
708 |
+
|
709 |
if torch.isnan(outputs[0]).any():
|
710 |
print("NaNs in outputs.")
|
711 |
raise ValueError()
|
712 |
|
713 |
+
pooled_output = outputs[1]
|
714 |
+
return {"sentence_embedding": pooled_output}
|
|
|
|
|
|
|
|
|
|
|
715 |
|
716 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
717 |
attention_mask: torch.Tensor,
|