IEIT-Yuan committed on
Commit
4332947
1 Parent(s): 55519bf

Update yuan_hf_model.py

Browse files
Files changed (1) hide show
  1. yuan_hf_model.py +4 -3
yuan_hf_model.py CHANGED
@@ -32,8 +32,8 @@ from transformers.modeling_utils import PreTrainedModel
32
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
  from .configuration_yuan import YuanConfig
34
  from einops import rearrange
35
- from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
36
- from flash_attn import flash_attn_func
37
 
38
  import copy
39
 
@@ -268,7 +268,8 @@ class YuanAttention(nn.Module):
268
  is_first_step = False
269
  if use_cache:
270
  if past_key_value is None:
271
- inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
 
272
  is_first_step = True
273
  else:
274
  before_hidden_states = past_key_value[2]
 
32
  from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
  from .configuration_yuan import YuanConfig
34
  from einops import rearrange
35
+ #from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
36
+ #from flash_attn import flash_attn_func
37
 
38
  import copy
39
 
 
268
  is_first_step = False
269
  if use_cache:
270
  if past_key_value is None:
271
+ #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
272
+ inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
273
  is_first_step = True
274
  else:
275
  before_hidden_states = past_key_value[2]