ChufanSuki committed on
Commit
201fea9
1 Parent(s): 555cd3e
Files changed (2) hide show
  1. lenet5.py +115 -0
  2. lenet_mnist_model.pth +3 -0
lenet5.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from torchvision import transforms
3
+ from torch.utils.data import DataLoader
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+
10
+
11
class LeNet(nn.Module):
    """LeNet-style convolutional classifier that emits 10 logits per sample.

    Two ``Conv2d -> ReLU -> MaxPool2d`` stages are followed by three fully
    connected layers.  ``fc1`` takes 256 input features, so the flattened
    output of the second pooling stage must hold 256 values per sample
    (16 channels of 4x4, which is what a 1x28x28 input produces with these
    kernel sizes and strides).
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 6 channels, 5x5 kernel without padding, then 2x2 pooling.
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 6 -> 16 channels, same kernel and pooling configuration.
        self.conv2 = nn.Conv2d(
            in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=256, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw class scores (no softmax) for the batch ``x``."""
        stage1 = self.pool1(self.relu1(self.conv1(x)))
        stage2 = self.pool2(self.relu2(self.conv2(stage1)))

        # Collapse everything except the batch dimension before the linear layers.
        flat = stage2.view(stage2.shape[0], -1)

        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.fc3(hidden)
47
+
48
+
49
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass of gradient updates over ``train_loader``.

    Args:
        model: Module being optimised; it is switched to training mode.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of mapping batches with ``"image"`` and
            ``"label"`` entries.  It must support ``len()`` and expose a
            ``dataset`` attribute, both of which are used for progress output.
        optimizer: Optimiser stepped once per batch.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        data = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        logits = model(data.float())
        loss = F.cross_entropy(logits, labels.long())
        loss.backward()
        optimizer.step()

        # Progress is reported on every 100th batch only (including the first).
        if batch_idx % 100 != 0:
            continue
        print(
            f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
        )
62
+
63
+
64
if __name__ == "__main__":
    # Script entry point: train LeNet on MNIST, report accuracy on the
    # held-out test split, then save the learned weights.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LeNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=2e-3)

    dataset = load_dataset("ylecun/mnist")

    # NOTE(review): an earlier revision built a torchvision pipeline
    # (ToTensor, Resize to 32x32, Normalize) and passed it to Dataset.map(),
    # but the datasets returned by map() were discarded, so the loaders
    # always served the un-transformed data.  Those dead calls are removed
    # here; the effective input pipeline is unchanged.  Any preprocessing
    # added later must keep the flattened feature count at the 256 inputs
    # LeNet.fc1 expects (a 32x32 resize would not).
    train_dataset = dataset["train"]
    train_dataset.set_format(type="torch")
    test_dataset = dataset["test"]
    test_dataset.set_format(type="torch")

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

    for epoch in range(1, 15):  # epochs 1..14 inclusive
        train(model, device, train_loader, optimizer, epoch)

    # Evaluate on the test split.  The previous loop iterated train_loader,
    # so the figure printed as "test" accuracy was really training accuracy.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            predicted = torch.argmax(model(images.float()), dim=-1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(
        "Accuracy of the network on the 10000 test images: {} %".format(
            100 * correct / total
        )
    )

    torch.save(model.state_dict(), "lenet_mnist_model.pth")
    print("Saved PyTorch Model State to lenet_mnist_model.pth")
lenet_mnist_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec17b644022a61d2639fe7f993d00b98e6fe2f72ffdbd7ed19ecd8a72f220b54
3
+ size 181508