import torch | |
def train(model, trainloader, optimizer, criterion, DEVICE):
    """Run one training pass over ``trainloader`` and return the mean batch loss.

    Args:
        model: Module to optimize; it is switched to training mode first.
        trainloader: Sized iterable of batches. Each batch is indexed as
            ``batch[0]`` (inputs) and ``batch[1]`` (targets).
        optimizer: Optimizer whose gradients are reset before, and whose
            parameters are stepped after, every batch.
        criterion: Callable invoked as ``criterion(output_logits, target)``;
            its result must support ``.item()`` and ``.backward()``.
        DEVICE: Device that each batch's inputs and targets are moved to.

    Returns:
        The sum of the per-batch loss values divided by ``len(trainloader)``.

    Raises:
        ZeroDivisionError: If ``trainloader`` is empty.
    """
    model.train()
    running_loss = 0.0
    for batch in trainloader:
        imgs, target = batch[0].to(DEVICE), batch[1].to(DEVICE)

        optimizer.zero_grad()
        output_logits = model(imgs)
        loss = criterion(output_logits, target)

        # Accumulate (not overwrite) so the epoch figure reflects every batch,
        # not just the last one.
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

    epoch_loss = running_loss / len(trainloader)
    print("epoch loss = {}".format(epoch_loss))
    return epoch_loss