rzimmerdev commited on
Commit
abe84f3
·
1 Parent(s): 987f571

feature: Added prediction and model loading methods

Browse files
Files changed (3) hide show
  1. images/.gitkeep +0 -0
  2. notebooks/predict.ipynb +0 -0
  3. src/predict.py +51 -0
images/.gitkeep ADDED
File without changes
notebooks/predict.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/predict.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ import plotly.express as px
8
+ from plotly.subplots import make_subplots
9
+
10
+ from trainer import LitTrainer
11
+ from models import CNN
12
+ from dataset import DatasetMNIST, load_mnist
13
+
14
+
15
+ def load_pl_net(path="../checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"):
16
+ pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10))
17
+ return pl_net
18
+
19
+
20
+ def load_torch_net(path="../checkpoints/pytorch/version_0.pt"):
21
+ net = torch.load(path)
22
+ net.eval()
23
+ return net
24
+
25
+
26
+ def get_sequence(model):
27
+ fig = make_subplots(rows=2, cols=5)
28
+
29
+ i, j = 0, np.random.randint(0, 30000)
30
+
31
+ while i < 10:
32
+ x, y = dataset[j]
33
+ y_pred = model(x.to("cuda")).detach().cpu()
34
+ p = torch.max(nn.functional.softmax(y_pred, dim=0))
35
+ y_pred = int(np.argmax(y_pred))
36
+ if y_pred == i and p > 0.95:
37
+ img = np.flip(np.array(x.reshape(28, 28)), 0)
38
+ fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i%5+1)
39
+ i += 1
40
+ j += 1
41
+ return fig
42
+
43
+
44
+ if __name__ == "__main__":
45
+ mnist = load_mnist("../downloads/mnist/")
46
+ dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"])
47
+
48
+ print("PyTorch Lightning Network")
49
+ get_sequence(load_pl_net().to("cuda")).write_image("images/pl_net.png")
50
+ print("Manual Network")
51
+ get_sequence(load_torch_net().to("cuda")).write_image("images/pytorch_net.png")