liusx committed on
Commit 540a64e · verified · 1 Parent(s): 9de2925

fix bug when using gradient_checkpointing

Files changed (1): modeling_telechat.py (+4 -3)
modeling_telechat.py CHANGED

@@ -43,8 +43,6 @@ except ImportError:
 try:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func # flashattn2
     print("# FLASH ATTENTION 2 DETECTED #")
-r
-r
 except ImportError:
     print("# NO FLASH ATTENTION DETECTED #")
     flash_attn_unpadded_func = None
@@ -857,6 +855,8 @@ class TELECHATTransformer(TELECHATPretrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

+            rotary_embedding=self.wpe if self.relative_encoding == 'rotary' else None
+
             if self.gradient_checkpointing and self.training:

                 if use_cache:
@@ -880,6 +880,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                         head_mask[i],
                         encoder_hidden_states,
                         encoder_attention_mask,
+                        rotary_embedding
                     )
                 else:
                     outputs = block(
@@ -889,7 +890,7 @@ class TELECHATTransformer(TELECHATPretrainedModel):
                         head_mask=head_mask[i],
                         encoder_hidden_states=encoder_hidden_states,
                         encoder_attention_mask=encoder_attention_mask,
-                        rotary_embedding=self.wpe if self.relative_encoding == 'rotary' else None,
+                        rotary_embedding=rotary_embedding,
                         use_cache=use_cache,
                         output_attentions=output_attentions
                     )
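For context, a minimal sketch of the call pattern this commit fixes. It is not the actual TeleChat code (ToyBlock, ToyTransformer, and custom_forward are illustrative names): under gradient checkpointing the layer is invoked through torch.utils.checkpoint.checkpoint with positional arguments only, so the rotary embedding has to be resolved once before the branch and passed explicitly on both call paths, as the diff above does.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy sketch of the pattern the commit introduces (illustrative names,
# not the actual TeleChat classes): resolve rotary_embedding once, before
# the gradient-checkpointing branch, and hand the same object to both paths.
class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, rotary_embedding=None, use_cache=False):
        # A real block would apply rotary_embedding inside attention.
        out = self.proj(hidden_states)
        return (out, None) if use_cache else (out,)

class ToyTransformer(nn.Module):
    def __init__(self, dim, relative_encoding='rotary'):
        super().__init__()
        self.block = ToyBlock(dim)
        self.wpe = nn.Identity()  # stand-in for the rotary embedding table
        self.relative_encoding = relative_encoding
        self.gradient_checkpointing = True

    def forward(self, hidden_states, use_cache=False):
        # Hoisted out of the branch, as in the commit.
        rotary_embedding = self.wpe if self.relative_encoding == 'rotary' else None

        if self.gradient_checkpointing and self.training:
            # checkpoint() forwards its extra arguments positionally, so
            # rotary_embedding must appear in the argument list explicitly;
            # before the fix it was never passed on this path.
            def custom_forward(hidden_states, rotary_embedding):
                return self.block(hidden_states, rotary_embedding=rotary_embedding)

            outputs = checkpoint(custom_forward, hidden_states, rotary_embedding,
                                 use_reentrant=False)
        else:
            outputs = self.block(hidden_states,
                                 rotary_embedding=rotary_embedding,
                                 use_cache=use_cache)
        return outputs[0]

model = ToyTransformer(dim=8).train()
x = torch.randn(2, 4, 8, requires_grad=True)
model(x).sum().backward()  # runs with gradient checkpointing enabled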