Raghavan committed on
Commit
2af0a6c
1 Parent(s): 814f361

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +2 -2
modeling_indictrans.py CHANGED
@@ -691,7 +691,7 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
691
  if self.layernorm_embedding is not None:
692
  x = self.layernorm_embedding(hidden_states)
693
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
694
-
695
  # expand attention_mask
696
  if attention_mask is not None:
697
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -756,7 +756,7 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
756
  if output_hidden_states:
757
  encoder_states = encoder_states + (hidden_states,)
758
 
759
- hidden_states = self.get_pooled_representation(hidden_states, attention_mask)
760
 
761
  if not return_dict:
762
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
 
691
  if self.layernorm_embedding is not None:
692
  x = self.layernorm_embedding(hidden_states)
693
  hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
694
+ original_attention_mask = attention_mask.clone()
695
  # expand attention_mask
696
  if attention_mask is not None:
697
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
756
  if output_hidden_states:
757
  encoder_states = encoder_states + (hidden_states,)
758
 
759
+ hidden_states = self.get_pooled_representation(hidden_states, original_attention_mask)
760
 
761
  if not return_dict:
762
  return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)