jonsaadfalcon commited on
Commit
90b56d4
·
1 Parent(s): 434063d

Upload bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +4 -9
bert_layers.py CHANGED
@@ -610,7 +610,7 @@ class BertForMaskedLM(BertPreTrainedModel):
610
  'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
611
  'bi-directional self-attention.')
612
 
613
- self.bert = BertModel(config, add_pooling_layer=False)
614
  self.cls = BertOnlyMLMHead(config,
615
  self.bert.embeddings.word_embeddings.weight)
616
 
@@ -705,18 +705,13 @@ class BertForMaskedLM(BertPreTrainedModel):
705
  return_dict=return_dict,
706
  masked_tokens_mask=masked_tokens_mask,
707
  )
708
-
709
  if torch.isnan(outputs[0]).any():
710
  print("NaNs in outputs.")
711
  raise ValueError()
712
 
713
- #print("MLM Outputs")
714
- #print(outputs[0].shape)
715
-
716
- pooled_output = outputs[0]
717
-
718
- last_hidden_state_formatted = outputs[0][:,0,:].view(-1, self.config.hidden_size)
719
- return {"sentence_embedding": last_hidden_state_formatted}
720
 
721
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
722
  attention_mask: torch.Tensor,
 
610
  'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
611
  'bi-directional self-attention.')
612
 
613
+ self.bert = BertModel(config, add_pooling_layer=True)
614
  self.cls = BertOnlyMLMHead(config,
615
  self.bert.embeddings.word_embeddings.weight)
616
 
 
705
  return_dict=return_dict,
706
  masked_tokens_mask=masked_tokens_mask,
707
  )
708
+
709
  if torch.isnan(outputs[0]).any():
710
  print("NaNs in outputs.")
711
  raise ValueError()
712
 
713
+ pooled_output = outputs[1]
714
+ return {"sentence_embedding": pooled_output}
 
 
 
 
 
715
 
716
  def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
717
  attention_mask: torch.Tensor,