"""Fine-tune a DeePixBiS model from an existing checkpoint and save it back.

The script loads weights from ``CHECKPOINT_PATH``, builds training and
validation loaders from ``./train_data.csv`` and ``./test_data.csv``, runs
``Trainer.fit()`` and then overwrites the checkpoint with the updated
weights.
"""
import torch
import torch.nn as nn
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, Normalize, Resize, RandomRotation
import numpy as np
from torch.utils.data import DataLoader

from DeePixBis.Dataset import PixWiseDataset
from DeePixBis.Model import DeePixBiS
from DeePixBis.Loss import PixWiseBCELoss
from DeePixBis.Metrics import predict, test_accuracy, test_loss
from DeePixBis.Trainer import Trainer

# Single source of truth for the checkpoint that is both read and rewritten.
CHECKPOINT_PATH = './DeePixBiS.pth'

model = DeePixBiS()
# map_location='cpu' lets a checkpoint that was saved from a CUDA device be
# restored on a machine without a GPU; load_state_dict copies the values into
# the model's existing parameters, so their device placement is unaffected.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

loss_fn = PixWiseBCELoss()

opt = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training pipeline: resize to 224x224, apply random flip/rotation
# augmentation, convert to a tensor and normalise each channel with
# mean 0.5 and std 0.5.
train_tfms = Compose([
    Resize([224, 224]),
    RandomHorizontalFlip(),
    RandomRotation(10),
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Evaluation pipeline: same resize/normalisation, no augmentation.
test_tfms = Compose([
    Resize([224, 224]),
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = PixWiseDataset('./train_data.csv', transform=train_tfms)
train_ds = train_dataset.dataset()

val_dataset = PixWiseDataset('./test_data.csv', transform=test_tfms)
val_ds = val_dataset.dataset()

batch_size = 10
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=0, pin_memory=True)
# NOTE(review): the validation loader is shuffled as well; confirm whether
# that is intentional.
val_dl = DataLoader(val_ds, batch_size, shuffle=True, num_workers=0, pin_memory=True)

# Disabled debugging aids, kept for reference:
# for x, y, z in val_dl:
#     _, zp = model(x)
#     print(zp)
#     print (z)
#     break

# print(test_accuracy(model, train_dl))
# print(test_loss(model, train_dl, loss_fn))

# 5 epochs ran
# NOTE(review): the literal 1 is presumably the number of epochs for this
# run -- confirm against the Trainer signature.
trainer = Trainer(train_dl, val_dl, model, 1, opt, loss_fn)

print('Training Beginning\n')
trainer.fit()
print('\nTraining Complete')

# Persist the fine-tuned weights over the checkpoint that was loaded above.
torch.save(model.state_dict(), CHECKPOINT_PATH)