fix bug when using gradient_checkpointing
modeling_telechat.py (+4 -3)
@@ -43,8 +43,6 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func # flashattn2
     print("# FLASH ATTENTION 2 DETECTED #")
-r
-r
 except ImportError:
     print("# NO FLASH ATTENTION DETECTED #")
     flash_attn_unpadded_func = None
@@ -857,6 +855,8 @@ class TELECHATTransformer(TELECHATPretrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
+            rotary_embedding=self.wpe if self.relative_encoding == 'rotary' else None
+
             if self.gradient_checkpointing and self.training:
 
                 if use_cache:
@@ -880,6 +880,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                     head_mask[i],
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    rotary_embedding
                 )
             else:
                 outputs = block(
@@ -889,7 +890,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_attention_mask,
-                    rotary_embedding=
+                    rotary_embedding=rotary_embedding,
                     use_cache=use_cache,
                     output_attentions=output_attentions
                 )
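Why this fixes the bug: with gradient checkpointing enabled the blocks run through torch.utils.checkpoint.checkpoint, which forwards only the positional arguments it is given, so the checkpointed branch never received a rotary embedding at all (the removed rotary_embedding= keyword in the direct branch is shown with its right-hand side cut off in this view). The commit computes rotary_embedding once, before the branch (self.wpe when relative_encoding is 'rotary', otherwise None), then appends it to the checkpoint call's positional arguments and passes it by keyword in the direct call. Below is a minimal, self-contained sketch of that pattern; SimpleBlock and all of its tensors are illustrative stand-ins for the actual TeleChat block, and it assumes a PyTorch version that accepts checkpoint's use_reentrant flag (1.11+).

# Minimal sketch of the pattern this commit adopts; SimpleBlock and the
# zero rotary tensor are hypothetical, not the real TeleChat code.
import torch
from torch.utils.checkpoint import checkpoint


class SimpleBlock(torch.nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, hidden_states, rotary_embedding=None):
        # The block only consumes the rotary embedding when one is supplied.
        if rotary_embedding is not None:
            hidden_states = hidden_states + rotary_embedding
        return self.proj(hidden_states)


hidden = 16
block = SimpleBlock(hidden)
x = torch.randn(2, 4, hidden, requires_grad=True)

# Compute the embedding once, before branching on gradient checkpointing,
# mirroring the hoisted assignment at new line 858.
rotary_embedding = torch.zeros(hidden)


def custom_forward(hidden_states, rotary_embedding):
    return block(hidden_states, rotary_embedding)


# checkpoint() forwards its positional arguments to custom_forward, so the
# embedding is appended to the argument list, as in the checkpointed branch
# of the diff; the non-checkpointed branch passes it by keyword instead.
out = checkpoint(custom_forward, x, rotary_embedding, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4, 16]): gradients flow through the recompute

Passing the embedding positionally matters because checkpoint() replays custom_forward during the backward pass with exactly the arguments it saved; with the value hoisted above the branch, both the checkpointed and the direct call now see the same rotary_embedding.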