Update flaubert2_model.py
flaubert2_model.py  CHANGED  (+4, -1)
@@ -388,6 +388,9 @@ class Flaubert2Model(RobertaModel):
 
         sequence_output = encoder_outputs[0].transpose(0,1)
 
+        # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
+        hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
+
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
@@ -397,7 +400,7 @@ class Flaubert2Model(RobertaModel):
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             past_key_values=encoder_outputs.past_key_values,
-            hidden_states=
+            hidden_states=hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
         )
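For context, the change reconciles two tensor layout conventions: the fairseq Linformer encoder produces hidden states shaped (seq_len, batch, hidden), while Hugging Face model outputs expect (batch, seq_len, hidden), so every per-layer tensor is transposed back before being placed in the output object. A minimal standalone sketch of that conversion follows; it uses made-up shapes and names (fairseq_hidden_states, hf_hidden_states) purely for illustration and is not part of the model file.

# Illustration only: why the per-layer transpose in the commit is needed.
# fairseq-style encoders typically return (seq_len, batch, hidden);
# Hugging Face outputs use (batch, seq_len, hidden).
import torch

seq_len, batch, hidden = 16, 2, 8

# Hypothetical per-layer hidden states as a fairseq-style encoder might return them.
fairseq_hidden_states = [torch.randn(seq_len, batch, hidden) for _ in range(4)]

# Same conversion as in the commit: swap the first two dimensions of every layer output.
hf_hidden_states = [h.transpose(0, 1) for h in fairseq_hidden_states]

assert hf_hidden_states[0].shape == (batch, seq_len, hidden)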