# UpsideDownDetector / train.py
# Source: Hugging Face file view - author: Jauhar; commit deb7039
# ("final commit to hf"); size: 742 Bytes.
import torch
def train(model, trainloader, optimizer, criterion, DEVICE):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network to train; put into training mode via ``model.train()``
            and called as ``model(imgs)``.
        trainloader: Iterable of batches. Each batch is indexable, with the
            inputs at index 0 and the targets at index 1 (e.g. a
            ``torch.utils.data.DataLoader`` yielding ``(imgs, target)``).
            It does not need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters; gradients are
            zeroed and one step is taken per batch.
        criterion: Loss function called as ``criterion(output_logits, target)``;
            must return a single-element tensor (``.item()`` is used on it).
        DEVICE: Device (or device string) that inputs and targets are moved to.

    Returns:
        The arithmetic mean of the per-batch loss values seen during the
        epoch, or 0.0 if ``trainloader`` yields no batches. Note this is a
        mean over batches, so a smaller final batch is weighted the same as
        a full one.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data in trainloader:
        optimizer.zero_grad()
        imgs, target = data[0].to(DEVICE), data[1].to(DEVICE)
        output_logits = model(imgs)
        loss = criterion(output_logits, target)
        loss.backward()
        optimizer.step()
        # Accumulate (the original overwrote this each iteration, so only the
        # last batch was reported). .item() yields a plain number, so no
        # autograd graph is kept alive across iterations.
        running_loss += loss.item()
        num_batches += 1
    # Divide by batches actually seen rather than len(trainloader): works for
    # iterables without __len__ and avoids ZeroDivisionError when empty.
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    print("epoch loss = {}".format(epoch_loss))
    return epoch_loss