Raghavan commited on
Commit
814f361
1 Parent(s): 359d319

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +13 -0
modeling_indictrans.py CHANGED
@@ -606,6 +606,17 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
606
  # Initialize weights and apply final processing
607
  self.post_init()
608
 
 
 
 
 
 
 
 
 
 
 
 
609
  def forward(
610
  self,
611
  input_ids: Optional[torch.Tensor] = None,
@@ -745,6 +756,8 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
745
  if output_hidden_states:
746
  encoder_states = encoder_states + (hidden_states,)
747
 
 
 
748
  if not return_dict:
749
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
750
  return BaseModelOutput(
 
606
  # Initialize weights and apply final processing
607
  self.post_init()
608
 
609
+ def get_pooled_representation(self, hidden_states, attention_mask):
610
+ seqs = torch.clone(hidden_states)
611
+ seqs[attention_mask == 0] = 0
612
+ sentence_embedding = seqs.sum(dim=1)
613
+ weights = 1.0 / ((attention_mask != 0).float().sum(dim=1) + 1e-7)
614
+
615
+ sentence_embedding = torch.einsum(
616
+ "i...,i ->i...", sentence_embedding, weights
617
+ )
618
+ return sentence_embedding
619
+
620
  def forward(
621
  self,
622
  input_ids: Optional[torch.Tensor] = None,
 
756
  if output_hidden_states:
757
  encoder_states = encoder_states + (hidden_states,)
758
 
759
+ hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
760
+
761
  if not return_dict:
762
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
763
  return BaseModelOutput(