File size: 376 Bytes
7931de6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import numpy as np

def convert(checkpoint_path="mnist_cnn.pt", output_path="mnist.npz"):
    """Convert a PyTorch ``state_dict`` checkpoint into a NumPy ``.npz`` archive.

    Every entry of the loaded mapping is detached, moved to the CPU and
    converted to a NumPy array. Each entry's name and shape are printed, and
    all arrays are written to ``output_path`` under their original keys.

    Args:
        checkpoint_path: Path to a file written by ``torch.save`` containing a
            mapping of names to tensors (e.g. ``model.state_dict()``).
        output_path: Destination of the archive. ``np.savez`` appends
            ``.npz`` if the name does not already end with it.
    """
    # map_location="cpu" lets checkpoints saved from a GPU run be loaded on
    # CPU-only machines; otherwise torch.load tries to restore each tensor
    # onto the CUDA device it was saved from.
    # NOTE(review): torch.load unpickles the file -- only run this on
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # detach() is required because Tensor.numpy() raises on tensors that
    # still require grad.
    arrays = {
        name: param.detach().cpu().numpy() for name, param in state_dict.items()
    }

    for name, array in arrays.items():
        print(name, array.shape)

    np.savez(output_path, **arrays)
    
def main() -> None:
    """Command-line entry point: run the checkpoint-to-NumPy conversion."""
    convert()


if __name__ == "__main__":
    main()