Update flaubert2_model.py
flaubert2_model.py (+5 −4)
@@ -388,11 +388,12 @@ class Flaubert2Model(RobertaModel):
 
         sequence_output = encoder_outputs[0].transpose(0,1)
 
-        # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
-        hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
-
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
+        # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
+        if output_hidden_states:
+            encoder_outputs.hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
+
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
@@ -400,7 +401,7 @@ class Flaubert2Model(RobertaModel):
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             past_key_values=encoder_outputs.past_key_values,
-            hidden_states=hidden_states,
+            hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
         )
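For context, a minimal sketch of the shape convention the in-code comment refers to: fairseq's Linformer encoder works with tensors laid out as (seq_len, batch, hidden), while Hugging Face models expose (batch, seq_len, hidden), so each hidden state is transposed back before being returned. The sizes below are illustrative only, not taken from the actual model config.

    import torch

    # Illustrative sizes (hypothetical, not from the model config).
    seq_len, batch, hidden = 8, 2, 16

    # fairseq-style hidden state: (seq_len, batch, hidden)
    h = torch.randn(seq_len, batch, hidden)

    # Swapping dims 0 and 1 yields the HF convention: (batch, seq_len, hidden)
    assert h.transpose(0, 1).shape == (batch, seq_len, hidden)

The new `if output_hidden_states:` guard also matters for correctness: `encoder_outputs.hidden_states` is `None` when hidden states were not requested, so the previous unconditional list comprehension would have raised a `TypeError` in that case.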