Fix encoder-decoder: guard attention mask padding when attention_mask is None
modeling_lsg_pegasus.py CHANGED (+1 -1)
@@ -1033,7 +1033,7 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
         )
 
         # Pad mask if we keep globals
-        if self.pass_global_tokens_to_decoder:
+        if self.pass_global_tokens_to_decoder and attention_mask is not None:
             attention_mask = torch.nn.functional.pad(attention_mask, pad=(self.num_global_tokens, 0), value=1)
 
         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
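The change is the added guard on the `if`: the encoder attention mask is only padded for the global tokens when a mask was actually passed in, so calling the model without an attention_mask no longer fails on torch.nn.functional.pad(None, ...). Below is a minimal standalone sketch of that padding step, based only on the hunk above; the helper name pad_mask_for_global_tokens is invented for illustration and is not part of modeling_lsg_pegasus.py.

import torch

def pad_mask_for_global_tokens(attention_mask, num_global_tokens):
    # Mirrors the guarded branch in the diff: skip padding when no mask was
    # given (F.pad cannot take None); otherwise prepend `num_global_tokens`
    # ones so the decoder can also attend to the global tokens the LSG
    # encoder keeps at the front of its output.
    if attention_mask is not None:
        attention_mask = torch.nn.functional.pad(
            attention_mask, pad=(num_global_tokens, 0), value=1
        )
    return attention_mask

# Example: batch of 2 sequences of length 4, with 7 global tokens.
mask = torch.ones(2, 4, dtype=torch.long)
print(pad_mask_for_global_tokens(mask, 7).shape)  # torch.Size([2, 11])
print(pad_mask_for_global_tokens(None, 7))        # None, instead of an error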