|
from aim import Run |
|
from aim.pytorch import track_gradients_dists, track_params_dists |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torchvision import datasets, transforms |
|
from tqdm import tqdm |
|
|
|
|
|
# Training hyper-parameters.
batch_size = 64
epochs = 10
learning_rate = 0.01

# One Aim run for this training session. Record the hyper-parameters on the
# run so they appear alongside the tracked metrics and make runs comparable
# in the Aim UI (the original script defined them but never logged them).
aim_run = Run()
aim_run['hparams'] = {
    'batch_size': batch_size,
    'epochs': epochs,
    'learning_rate': learning_rate,
}
|
|
|
class CNN(nn.Module):
    """Two-conv-layer classifier for (batch, 1, 28, 28) MNIST digits.

    Returns raw logits of shape (batch, 10); pair with
    nn.CrossEntropyLoss — no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # BUG FIX: the original convolutions used no padding, shrinking the
        # feature map 28 -> 26 -> 13 -> 11 -> 5, so the flattened size was
        # 64*5*5 = 1600 while fc1 declared 64*7*7 = 3136 — every forward
        # pass raised a shape-mismatch error. padding=1 preserves spatial
        # size through each conv (28 -> 14 -> 7 after pooling), matching
        # fc1. Weight/kernel shapes (and thus state_dict keys and shapes)
        # are unchanged.
        self.conv1 = nn.Conv2d(1, 32, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) input to (batch, 10) class logits."""
        x = self.pool(torch.relu(self.conv1(x)))  # -> (batch, 32, 14, 14)
        x = self.pool(torch.relu(self.conv2(x)))  # -> (batch, 64, 7, 7)
        x = torch.flatten(x, 1)                   # -> (batch, 3136)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
|
|
|
# MNIST: 28x28 grayscale digits, converted to float tensors in [0, 1].
# Both splits share the same root and transform; only the training split
# triggers a download.
_mnist_common = dict(root='./data', transform=transforms.ToTensor())

train_dataset = datasets.MNIST(train=True, download=True, **_mnist_common)
test_dataset = datasets.MNIST(train=False, **_mnist_common)

# Shuffle only the training stream; evaluation order stays fixed so test
# metrics are reproducible across epochs.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)
|
|
|
|
|
# Model, loss, and optimizer. CrossEntropyLoss consumes raw logits, which
# is why CNN.forward applies no softmax.
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
|
|
|
|
def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one training pass over `loader`; return (mean_batch_loss, accuracy)."""
    model.train()
    running_loss, correct, seen = 0.0, 0, 0
    for data, target in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # argmax over the class dimension; replaces the deprecated
        # `torch.max(output.data, 1)` pattern (.data silently detaches
        # from autograd and is discouraged; argmax records no gradient).
        correct += (output.argmax(dim=1) == target).sum().item()
        seen += target.size(0)
    return running_loss / len(loader), correct / seen


def _evaluate(model, loader, criterion):
    """Evaluate on `loader` without gradients; return (mean_batch_loss, accuracy)."""
    model.eval()
    running_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for data, target in tqdm(loader, desc="Testing", leave=False):
            output = model(data)
            running_loss += criterion(output, target).item()
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += target.size(0)
    return running_loss / len(loader), correct / seen


for epoch in range(epochs):
    train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion)
    aim_run.track({'accuracy': train_acc, 'loss': train_loss},
                  epoch=epoch, context={'subset': 'train'})
    track_params_dists(model, aim_run, epoch=epoch, context={'subset': 'train'})
    track_gradients_dists(model, aim_run, epoch=epoch, context={'subset': 'train'})

    test_loss, test_acc = _evaluate(model, test_loader, criterion)
    aim_run.track({'accuracy': test_acc, 'loss': test_loss},
                  epoch=epoch, context={'subset': 'test'})
    track_params_dists(model, aim_run, epoch=epoch, context={'subset': 'test'})
    # NOTE(review): no backward pass runs during evaluation, so these are the
    # stale gradients left over from the last training batch, not "test"
    # gradients. Kept for parity with the original script — consider removing.
    track_gradients_dists(model, aim_run, epoch=epoch, context={'subset': 'test'})

# Persist only the weights (state_dict), not the full module object.
torch.save(model.state_dict(), 'mnist_cnn.pth')