isLandLZ commited on
Commit
d2ad7a2
·
1 Parent(s): 796d9ba

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +70 -0
train.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jittor as jt
2
+ import path
3
+ from jittor import nn, Module
4
+ import numpy as np
5
+ import sys, os
6
+ import random
7
+ import math
8
+ from jittor import init
9
+ from model import Model
10
+ from jittor.dataset.mnist import MNIST
11
+ import jittor.transform as trans
12
+
13
+ # if jt.flags.use_cuda = 1 will use gpu
14
+ jt.flags.use_cuda = 1
15
+ pwd_path = os.path.abspath(os.path.dirname(__file__))
16
+
17
+
18
+ def train(model, train_loader, optimizer, epoch):
19
+ model.train()
20
+ for batch_idx, (inputs, targets) in enumerate(train_loader):
21
+ outputs = model(inputs)
22
+ loss = nn.cross_entropy_loss(outputs, targets)
23
+ optimizer.step(loss)
24
+ if batch_idx % 10 == 0:
25
+ print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
26
+ epoch, batch_idx, len(train_loader),
27
+ 100. * batch_idx / len(train_loader), loss.data[0]))
28
+
29
+
30
+ def test(model, val_loader, epoch):
31
+ model.eval()
32
+
33
+ test_loss = 0
34
+ correct = 0
35
+ total_acc = 0
36
+ total_num = 0
37
+ for batch_idx, (inputs, targets) in enumerate(val_loader):
38
+ batch_size = inputs.shape[0]
39
+ outputs = model(inputs)
40
+ pred = np.argmax(outputs.data, axis=1)
41
+ acc = np.sum(targets.data == pred)
42
+ total_acc += acc
43
+ total_num += batch_size
44
+ acc = acc / batch_size
45
+ print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(epoch, batch_idx, len(val_loader), 100. * float( batch_idx ) / len(val_loader), acc))
46
+ print('Total test acc =', total_acc / total_num)
47
+
48
+
49
+ def main():
50
+ batch_size = 32
51
+ learning_rate = 0.1
52
+ momentum = 0.9
53
+ weight_decay = 1e-4
54
+ epochs = 100
55
+ train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(batch_size=batch_size, shuffle=True)
56
+
57
+ val_loader = MNIST(train=False, transform=trans.Resize(28)) .set_attrs(batch_size=1, shuffle=False)
58
+
59
+ model = Model()
60
+ optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
61
+ for epoch in range(epochs):
62
+ train(model, train_loader, optimizer, epoch)
63
+ test(model, val_loader, epoch)
64
+
65
+ save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')
66
+ model.save(save_model_path)
67
+
68
+
69
+ if __name__ == '__main__':
70
+ main()