# Source: train.py, uploaded by isLandLZ (commit d2ad7a2).
import jittor as jt
import path
from jittor import nn, Module
import numpy as np
import sys, os
import random
import math
from jittor import init
from model import Model
from jittor.dataset.mnist import MNIST
import jittor.transform as trans
# Setting jt.flags.use_cuda to 1 makes Jittor run on the GPU.
# NOTE(review): the flag is set unconditionally; presumably this requires a
# CUDA-capable machine -- confirm before running on a CPU-only host.
jt.flags.use_cuda = 1
# Absolute path of the directory that contains this script; main() builds the
# model save path relative to it.
pwd_path = os.path.abspath(os.path.dirname(__file__))
def train(model, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` batch by batch.

    Each batch is fed through the model, scored with
    ``nn.cross_entropy_loss`` and the resulting loss is passed to
    ``optimizer.step``. A progress line is printed on every tenth batch
    (batch indices 0, 10, 20, ...).

    Args:
        model: Callable network; ``model.train()`` is called before the loop.
        train_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``.
        optimizer: Object called as ``optimizer.step(loss)`` for every batch.
        epoch: Epoch number, used only in the progress message.
    """
    report_every = 10
    progress_fmt = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    for step, batch in enumerate(train_loader):
        inputs, targets = batch
        loss = nn.cross_entropy_loss(model(inputs), targets)
        optimizer.step(loss)

        # Only report on every `report_every`-th batch.
        if step % report_every:
            continue
        num_batches = len(train_loader)
        print(progress_fmt.format(
            epoch, step, num_batches,
            100. * step / num_batches, loss.data[0]))
def test(model, val_loader, epoch):
    """Evaluate ``model`` on ``val_loader`` and print per-batch and total accuracy.

    For each batch the predicted class is the arg-max over axis 1 of the
    model output; it is compared element-wise with the batch targets.

    Args:
        model: Callable network; ``model.eval()`` is called before the loop.
            Its output must expose ``.data`` usable by ``np.argmax``.
        val_loader: Iterable of ``(inputs, targets)`` batches that also
            supports ``len()``. ``inputs.shape[0]`` is taken as the batch
            size and ``targets.data`` as the ground-truth labels.
        epoch: Epoch number, used only in the progress message.

    Returns:
        The overall accuracy (correct predictions / samples seen) as a
        float, or ``0.0`` when the loader yields no samples.
    """
    model.eval()
    total_acc = 0
    total_num = 0
    # The batch count does not change during evaluation; look it up once.
    num_batches = len(val_loader)
    for batch_idx, (inputs, targets) in enumerate(val_loader):
        batch_size = inputs.shape[0]
        outputs = model(inputs)
        pred = np.argmax(outputs.data, axis=1)
        acc = np.sum(targets.data == pred)
        total_acc += acc
        total_num += batch_size
        acc = acc / batch_size
        print('Test Epoch: {} [{}/{} ({:.0f}%)]\tAcc: {:.6f}'.format(
            epoch, batch_idx, num_batches,
            100. * float(batch_idx) / num_batches, acc))

    # Guard against an empty loader instead of raising ZeroDivisionError.
    overall_acc = total_acc / total_num if total_num else 0.0
    print('Total test acc =', overall_acc)
    return float(overall_acc)
def main():
    """Train ``Model`` on MNIST, evaluate after every epoch, then save it.

    Hyper-parameters are fixed inside the function. After the last epoch the
    model is saved to ``model/mnist_model.pkl`` under ``pwd_path`` (the
    directory of this script); the target directory is created if missing.
    """
    batch_size = 32
    learning_rate = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    epochs = 100

    train_loader = MNIST(train=True, transform=trans.Resize(28)).set_attrs(
        batch_size=batch_size, shuffle=True)
    # NOTE(review): validation runs with batch_size=1, so test() prints one
    # line per sample; kept as-is to preserve the existing output.
    val_loader = MNIST(train=False, transform=trans.Resize(28)).set_attrs(
        batch_size=1, shuffle=False)

    model = Model()
    optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch)
        test(model, val_loader, epoch)

    save_model_path = os.path.join(pwd_path, 'model/mnist_model.pkl')
    # Make sure the output directory exists so the save cannot fail on a
    # fresh checkout after the whole training run has completed.
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
    model.save(save_model_path)
# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()