zwt123home123 committed on
Commit
869bf60
·
verified ·
1 Parent(s): 970e45d

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +1 -1
modeling_llama.py CHANGED
@@ -457,7 +457,7 @@ class LlamaAttention(nn.Module):
457
  if attn_weights.shape[2]>576:
458
  # print("loading ... ")
459
  #print(value_states.shape)
460
- self.ae_v.load_state_dict(torch.load("weights_320/"+"autoencoder_epoch_1_L1_1280_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
461
  value_states_v = value_states[:,:,35:35+576,:]
462
  value_states_v = value_states_v.permute(0, 2, 1, 3)
463
  value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
 
457
  if attn_weights.shape[2]>576:
458
  # print("loading ... ")
459
  #print(value_states.shape)
460
+ self.ae_v.load_state_dict(torch.load("weights_320/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
461
  value_states_v = value_states[:,:,35:35+576,:]
462
  value_states_v = value_states_v.permute(0, 2, 1, 3)
463
  value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)