Update modeling_llama.py
modeling_llama.py CHANGED (+1 -1)
@@ -457,7 +457,7 @@ class LlamaAttention(nn.Module):
         if attn_weights.shape[2]>576:
             # print("loading ... ")
             #print(value_states.shape)
-            self.ae_v.load_state_dict(torch.load("weights_320/"+"
+            self.ae_v.load_state_dict(torch.load("weights_320/"+"autoencoder_epoch_1_L1_nonorm_layer_"+str(self.layer_idx)+".pth", map_location='cuda'))
             value_states_v = value_states[:,:,35:35+576,:]
             value_states_v = value_states_v.permute(0, 2, 1, 3)
             value_states_v=value_states_v.reshape(value_states_v.shape[0],value_states_v.shape[1],5120)
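For context, the unchanged lines around the edit slice the 576 visual-token positions (35 to 35+576) out of value_states, move the head axis next to the sequence axis, and flatten heads into a single 5120-dim vector per token before the per-layer autoencoder self.ae_v consumes them; the changed line only swaps the checkpoint filename that is loaded into ae_v. A minimal sketch of that tensor manipulation, assuming 40 KV heads with head_dim 128 (40 * 128 = 5120) and a hypothetical layer_idx; the commented-out load mirrors the path pattern shown in the diff and ae_v itself is defined elsewhere in modeling_llama.py:

import torch

# Assumed shapes: batch=1, 40 heads, 700 total positions (35 text + 576 image + rest), head_dim=128.
value_states = torch.randn(1, 40, 700, 128)
layer_idx = 0  # hypothetical stand-in for self.layer_idx

# Slice the 576 visual tokens starting at position 35, as in the surrounding context lines.
value_states_v = value_states[:, :, 35:35 + 576, :]   # (1, 40, 576, 128)
value_states_v = value_states_v.permute(0, 2, 1, 3)   # (1, 576, 40, 128)
value_states_v = value_states_v.reshape(value_states_v.shape[0], value_states_v.shape[1], 5120)  # (1, 576, 5120)

# The changed line then loads per-layer autoencoder weights before encoding these vectors:
# self.ae_v.load_state_dict(
#     torch.load(f"weights_320/autoencoder_epoch_1_L1_nonorm_layer_{layer_idx}.pth",
#                map_location="cuda"))

print(value_states_v.shape)  # torch.Size([1, 576, 5120])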