TCMVince committed on
Commit
9c29a9d
1 Parent(s): d25b8f2

Update flaubert2_model.py

Browse files
Files changed (1) hide show
  1. flaubert2_model.py +4 -1
flaubert2_model.py CHANGED
@@ -388,6 +388,9 @@ class Flaubert2Model(RobertaModel):
388
 
389
  sequence_output = encoder_outputs[0].transpose(0,1)
390
 
 
 
 
391
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
392
 
393
  if not return_dict:
@@ -397,7 +400,7 @@ class Flaubert2Model(RobertaModel):
397
  last_hidden_state=sequence_output,
398
  pooler_output=pooled_output,
399
  past_key_values=encoder_outputs.past_key_values,
400
- hidden_states=encoder_outputs.hidden_states,
401
  attentions=encoder_outputs.attentions,
402
  cross_attentions=encoder_outputs.cross_attentions,
403
  )
 
388
 
389
  sequence_output = encoder_outputs[0].transpose(0,1)
390
 
391
+ # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
392
+ hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
393
+
394
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
395
 
396
  if not return_dict:
 
400
  last_hidden_state=sequence_output,
401
  pooler_output=pooled_output,
402
  past_key_values=encoder_outputs.past_key_values,
403
+ hidden_states=hidden_states,
404
  attentions=encoder_outputs.attentions,
405
  cross_attentions=encoder_outputs.cross_attentions,
406
  )