Upload train.py
Browse files
train.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from torch import optim
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from Pytorch_MNIST图片识别.model import Net
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import os
|
9 |
+
|
10 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
11 |
+
|
12 |
+
# TODO epoch的数量定义了我们将循环整个训练数据集的次数
|
13 |
+
n_epochs = 3
|
14 |
+
|
15 |
+
# TODO 使用batch_size=64进行训练,并使用size=1000对这个数据集进行测试
|
16 |
+
batch_size_train = 64
|
17 |
+
batch_size_test = 1000
|
18 |
+
|
19 |
+
# TODO 优化器的超参数
|
20 |
+
learning_rate = 0.01
|
21 |
+
momentum = 0.5
|
22 |
+
|
23 |
+
log_interval = 10
|
24 |
+
random_seed = 1
|
25 |
+
torch.manual_seed(random_seed)
|
26 |
+
|
27 |
+
# TODO 自动将MNIST数据集下载到目录下的data文件夹
|
28 |
+
train_loader = torch.utils.data.DataLoader(
|
29 |
+
torchvision.datasets.MNIST('./data/', train=True, download=True,
|
30 |
+
transform=torchvision.transforms.Compose([
|
31 |
+
|
32 |
+
torchvision.transforms.ToTensor(),
|
33 |
+
# TODO MNIST数据集的全局平均值和标准偏差
|
34 |
+
torchvision.transforms.Normalize(
|
35 |
+
(0.1307,), (0.3081,))
|
36 |
+
])),
|
37 |
+
batch_size=batch_size_train, shuffle=True)
|
38 |
+
test_loader = torch.utils.data.DataLoader(
|
39 |
+
torchvision.datasets.MNIST('./data/', train=False, download=True,
|
40 |
+
transform=torchvision.transforms.Compose([
|
41 |
+
torchvision.transforms.ToTensor(),
|
42 |
+
# TODO MNIST数据集的全局平均值和标准偏差
|
43 |
+
torchvision.transforms.Normalize(
|
44 |
+
(0.1307,), (0.3081,))
|
45 |
+
])),
|
46 |
+
batch_size=batch_size_test, shuffle=False)
|
47 |
+
|
48 |
+
# TODO 初始化网络和优化器
|
49 |
+
network = Net()
|
50 |
+
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)
|
51 |
+
|
52 |
+
train_losses = []
|
53 |
+
train_counter = []
|
54 |
+
test_losses = []
|
55 |
+
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]
|
56 |
+
|
57 |
+
# TODO 模型存储位置(一个是完整的模型,一个是只有参数的模型)
|
58 |
+
# TODO 需要先建立一个model文件夹
|
59 |
+
model_path = './model1/model.pth'
|
60 |
+
optimizer_path = './model1/optimizer.pth'
|
61 |
+
|
62 |
+
|
63 |
+
def train(epoch):
|
64 |
+
network.train()
|
65 |
+
for batch_idx, (data, target) in enumerate(train_loader):
|
66 |
+
# TODO 需要使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度
|
67 |
+
optimizer.zero_grad()
|
68 |
+
output = network(data)
|
69 |
+
loss = F.nll_loss(output, target)
|
70 |
+
loss.backward()
|
71 |
+
optimizer.step()
|
72 |
+
if batch_idx % log_interval == 0:
|
73 |
+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
74 |
+
epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()))
|
75 |
+
train_losses.append(loss.item())
|
76 |
+
train_counter.append(
|
77 |
+
(batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
|
78 |
+
if epoch == (n_epochs - 1):
|
79 |
+
# TODO 存储模型
|
80 |
+
torch.save(network.state_dict(), model_path)
|
81 |
+
torch.save(optimizer.state_dict(), optimizer_path)
|
82 |
+
|
83 |
+
|
84 |
+
def test():
|
85 |
+
network.eval()
|
86 |
+
test_loss = 0
|
87 |
+
correct = 0
|
88 |
+
with torch.no_grad():
|
89 |
+
for data, target in test_loader:
|
90 |
+
output = network(data)
|
91 |
+
test_loss += F.nll_loss(output, target, size_average=False).item()
|
92 |
+
pred = output.data.max(1, keepdim=True)[1]
|
93 |
+
correct += pred.eq(target.data.view_as(pred)).sum()
|
94 |
+
test_loss /= len(test_loader.dataset)
|
95 |
+
test_losses.append(test_loss)
|
96 |
+
print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
|
97 |
+
test_loss, correct, len(test_loader.dataset),
|
98 |
+
100. * correct / len(test_loader.dataset)))
|
99 |
+
|
100 |
+
|
101 |
+
for epoch in range(1, n_epochs + 1):
|
102 |
+
train(epoch)
|
103 |
+
test()
|
104 |
+
|
105 |
+
fig = plt.figure()
|
106 |
+
plt.plot(train_counter, train_losses, color='blue')
|
107 |
+
print(len(test_counter))
|
108 |
+
print(len(test_losses))
|
109 |
+
plt.scatter(test_counter, test_losses, color='red')
|
110 |
+
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
|
111 |
+
plt.xlabel('number of training examples seen')
|
112 |
+
plt.ylabel('negative log likelihood loss')
|
113 |
+
plt.show()
|