|
import utils, torch, time, os, pickle |
|
import numpy as np |
|
import torch.nn as nn |
|
import torch.cuda as cu |
|
import torch.optim as optim |
|
import pickle |
|
from torchvision import transforms |
|
from torchvision.utils import save_image |
|
from utils import augmentData, RGBtoL, LtoRGB |
|
from PIL import Image |
|
from dataloader import dataloader |
|
from torch.autograd import Variable |
|
import matplotlib.pyplot as plt |
|
import random |
|
from datetime import date |
|
from statistics import mean |
|
from architectures import depth_generator_UNet, \ |
|
depth_discriminator_noclass_UNet |
|
|
|
|
|
class WiggleGAN(object): |
|
def __init__(self, args):
    """Configure the model from the parsed command-line namespace ``args``.

    Copies the hyper-parameters onto the instance, derives the mode flags
    (``WGAN``, ``CR``, ``wiggle``, ``onlyGen``, ``toLoad``), creates the data
    loaders, the generator (plus the discriminator when it is trained), their
    Adam optimizers, the loss modules and the one-hot ``sample_y_`` labels.
    When ``args.seedLoad`` differs from ``"-0000"`` a checkpoint is restored
    through ``self.load()``.
    """
    # ---- plain hyper-parameters copied from args ----
    self.epoch = args.epoch
    self.nCameras = args.cameras
    self.batch_size = args.batch_size
    self.save_dir = args.save_dir
    self.result_dir = args.result_dir
    self.dataset = args.dataset
    self.log_dir = args.log_dir
    self.gpu_mode = args.gpu_mode
    self.model_name = args.gan_type
    self.input_size = args.input_size
    self.class_num = (args.cameras - 1) * 2
    self.sample_num = self.class_num ** 2
    self.imageDim = args.imageDim
    self.epochVentaja = args.epochV
    self.cantImages = args.cIm
    self.visdom = args.visdom
    self.lambdaL1 = args.lambdaL1
    self.depth = args.depth
    self.name_wiggle = args.name_wiggle

    # A positive clipping value switches the losses to the WGAN formulation.
    self.clipping = args.clipping
    self.WGAN = bool(self.clipping > 0)

    # Random tag used in checkpoint names and visdom environments.
    self.seed = str(random.randint(0, 99999))
    self.seed_load = args.seedLoad
    self.toLoad = self.seed_load != "-0000"

    # Consistency-regularisation weights; any positive one enables CR.
    self.zGenFactor = args.zGF
    self.zDisFactor = args.zDF
    self.bFactor = args.bF
    self.CR = bool(self.zGenFactor > 0 or self.zDisFactor > 0 or self.bFactor > 0)

    self.expandGen = args.expandGen
    self.expandDis = args.expandDis

    # A positive wiggleDepth puts the object in "wiggle" (inference) mode.
    self.wiggleDepth = args.wiggleDepth
    self.wiggle = bool(self.wiggleDepth > 0)

    # A non-positive discriminator learning rate means generator-only training.
    self.onlyGen = args.lrD <= 0
    uses_discriminator = not self.wiggle and not self.onlyGen

    # ---- data ----
    if not self.wiggle:
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                      split='train', trans=not self.CR)
        self.data_Validation = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                          split='validation')

        # One validation batch is kept for the periodic visualisations.
        self.dataprint = next(iter(self.data_Validation))

        # A training batch is fetched only to read the input shape for D.
        data = next(iter(self.data_loader)).get('x_im')

        if not self.onlyGen:
            self.D = depth_discriminator_noclass_UNet(input_dim=3, output_dim=1, input_shape=data.shape,
                                                      class_num=self.class_num,
                                                      expand_net=self.expandDis, depth=self.depth,
                                                      wgan=self.WGAN)
            self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

    self.data_Test = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='test')
    self.dataprint_test = next(iter(self.data_Test))

    # ---- generator ----
    self.G = depth_generator_UNet(input_dim=4, output_dim=3, class_num=self.class_num,
                                  expand_net=self.expandGen, depth=self.depth)
    self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))

    # ---- losses (moved to the GPU together with the networks) ----
    self.BCE_loss = nn.BCELoss()
    self.CE_loss = nn.CrossEntropyLoss()
    self.L1 = nn.L1Loss()
    self.MSE = nn.MSELoss()
    self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
    if self.gpu_mode:
        self.G.cuda()
        if uses_discriminator:
            self.D.cuda()
        self.BCE_loss = self.BCE_loss.cuda()
        self.CE_loss = self.CE_loss.cuda()
        self.L1 = self.L1.cuda()
        self.MSE = self.MSE.cuda()
        self.BCEWithLogitsLoss = self.BCEWithLogitsLoss.cuda()

    print('---------- Networks architecture -------------')
    utils.print_network(self.G)
    if uses_discriminator:
        utils.print_network(self.D)
    print('-----------------------------------------------')

    # One-hot labels: the sequence 0..class_num-1 repeated class_num times,
    # giving sample_num rows in total.
    label_column = torch.arange(self.class_num).repeat(self.class_num).unsqueeze(1)
    self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(
        1, label_column.type(torch.LongTensor), 1)
    if self.gpu_mode:
        self.sample_y_ = self.sample_y_.cuda()

    if self.toLoad:
        self.load()
|
|
|
def train(self):
    """Train the generator and, unless disabled, the discriminator.

    For every epoch the method runs one pass over ``self.data_loader``
    (discriminator step followed by generator step per batch), one
    gradient-free pass over ``self.data_Validation``, records per-batch and
    per-epoch losses in ``self.train_hist`` / ``self.details_hist`` /
    ``self.epoch_hist`` (optionally plotting them through visdom), and every
    10 epochs saves a checkpoint and renders sample images.  A final
    checkpoint is written when training ends.

    During the first ``self.epochVentaja`` epochs, and always when
    ``self.onlyGen`` is set, the discriminator is skipped and the generator
    is trained with the L1 reconstruction loss only.
    """

    def _epoch_mean(values):
        # An epoch can end without any recorded value (for instance when the
        # dataset holds less than one full batch); report 0 instead of
        # letting statistics.mean raise on an empty list.
        return mean(values) if values else 0

    if self.visdom:
        random.seed(time.time())
        today = date.today()
        plots_env = 'Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed

        vis = utils.VisdomLinePlotter(env_name=plots_env)
        visValidation = utils.VisdomLinePlotter(env_name=plots_env)
        visEpoch = utils.VisdomLineTwoPlotter(env_name=plots_env)
        visImages = utils.VisdomImagePlotter(env_name='Cobo_depth_Images_' + str(today) + '_' + self.seed)
        visImagesTest = utils.VisdomImagePlotter(env_name='Cobo_depth_ImagesTest_' + str(today) + '_' + self.seed)

        visLossGTest = utils.VisdomLinePlotter(env_name=plots_env)
        visLossGValidation = utils.VisdomLinePlotter(env_name=plots_env)

        visLossDTest = utils.VisdomLinePlotter(env_name=plots_env)
        visLossDValidation = utils.VisdomLinePlotter(env_name=plots_env)

    # Per-batch totals, per-batch loss components and per-epoch means.
    self.train_hist = {key: [] for key in (
        'D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation',
        'per_epoch_time', 'total_time')}
    self.details_hist = {key: [] for key in (
        'G_T_Comp_im', 'G_T_BCE_fake_real', 'G_T_Cycle', 'G_zCR',
        'G_V_Comp_im', 'G_V_BCE_fake_real', 'G_V_Cycle',
        'D_T_BCE_fake_real_R', 'D_T_BCE_fake_real_F', 'D_zCR', 'D_bCR',
        'D_V_BCE_fake_real_R', 'D_V_BCE_fake_real_F')}
    self.epoch_hist = {key: [] for key in (
        'D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation')}

    # Only full batches are used: the real/fake targets below have a fixed
    # batch_size rows.
    maxIter = len(self.data_loader.dataset) // self.batch_size
    maxIterVal = len(self.data_Validation.dataset) // self.batch_size

    if not self.WGAN:
        # Targets for the BCE-with-logits losses (not needed for WGAN).
        self.y_real_ = torch.ones(self.batch_size, 1)
        self.y_fake_ = torch.zeros(self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

    print('training start!!')
    start_time = time.time()
    epoch_times = []

    for epoch in range(self.epoch):

        # "ventaja": head-start epochs in which D is not updated.
        ventaja = epoch < self.epochVentaja
        use_discriminator = not ventaja and not self.onlyGen

        self.G.train()
        if not self.onlyGen:
            self.D.train()
        epoch_start_time = time.time()

        # Losses recorded during this epoch only; used for the epoch means.
        epoch_D_train, epoch_G_train = [], []
        epoch_D_val, epoch_G_val = [], []

        # ------------------------------------------------------------
        # Training pass
        # ------------------------------------------------------------
        for step, data in enumerate(self.data_loader):

            if step >= maxIter:
                break

            x_im = data.get('x_im')
            x_dep = data.get('x_dep')
            y_im = data.get('y_im')
            y_dep = data.get('y_dep')
            y_ = data.get('y_')

            if self.CR:
                # Augmented copies for consistency regularisation; the
                # un-moved input is kept for augmenting generated images.
                x_im_aug, y_im_aug = augmentData(x_im, y_im)
                x_im_vanilla = x_im

                if self.gpu_mode:
                    x_im_aug, y_im_aug = x_im_aug.cuda(), y_im_aug.cuda()

            if self.gpu_mode:
                x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

            # ---------------- discriminator update ----------------
            if use_discriminator:

                # Re-enable gradients; they are frozen in the G update.
                for p in self.D.parameters():
                    p.requires_grad = True

                self.D_optimizer.zero_grad()

                # Real pair
                D_real, D_features_real = self.D(y_im, x_im, y_dep, y_)

                # Generated pair
                G_, G_dep = self.G(y_, x_im, x_dep)
                D_fake, D_features_fake = self.D(G_, x_im, G_dep, y_)

                if self.WGAN:
                    D_loss_real_fake_R = - torch.mean(D_real)
                    D_loss_real_fake_F = torch.mean(D_fake)
                else:
                    D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                    D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                if self.CR:
                    # bCR: D features should agree between a sample and its
                    # augmented version (real and fake).
                    x_im_aug_bCR, G_aug_bCR = augmentData(x_im_vanilla, G_.data.cpu())

                    if self.gpu_mode:
                        G_aug_bCR, x_im_aug_bCR = G_aug_bCR.cuda(), x_im_aug_bCR.cuda()

                    _, D_features_fake_bCR = self.D(G_aug_bCR, x_im_aug_bCR, G_dep, y_)
                    _, D_features_real_bCR = self.D(y_im_aug, x_im_aug, y_dep, y_)

                    # zCR: D features of G(input) vs. G(augmented input).
                    G_aug_zCR, G_dep_aug_zCR = self.G(y_, x_im_aug, x_dep)
                    _, D_features_fake_aug_zCR = self.D(G_aug_zCR, x_im_aug, G_dep_aug_zCR, y_)

                    D_loss_real = self.MSE(D_features_real, D_features_real_bCR)
                    D_loss_fake = self.MSE(D_features_fake, D_features_fake_bCR)
                    D_bCR = (D_loss_real + D_loss_fake) * self.bFactor

                    D_zCR = self.MSE(D_features_fake, D_features_fake_aug_zCR) * self.zDisFactor

                    D_CR_losses = D_bCR + D_zCR

                    D_loss += D_CR_losses

                    self.details_hist['D_bCR'].append(D_bCR.detach().item())
                    self.details_hist['D_zCR'].append(D_zCR.detach().item())
                else:
                    self.details_hist['D_bCR'].append(0)
                    self.details_hist['D_zCR'].append(0)

                d_loss_value = D_loss.detach().item()
                self.train_hist['D_loss_train'].append(d_loss_value)
                epoch_D_train.append(d_loss_value)
                self.details_hist['D_T_BCE_fake_real_R'].append(D_loss_real_fake_R.detach().item())
                self.details_hist['D_T_BCE_fake_real_F'].append(D_loss_real_fake_F.detach().item())
                if self.visdom:
                    visLossDTest.plot('Discriminator_losses',
                                      ['D_T_BCE_fake_real_R', 'D_T_BCE_fake_real_F', 'D_bCR', 'D_zCR'], 'train',
                                      self.details_hist)

                D_loss.backward()

                self.D_optimizer.step()

                # WGAN weight clipping
                if self.WGAN:
                    for p in self.D.parameters():
                        p.data.clamp_(-self.clipping, self.clipping)

            # ------------------ generator update ------------------
            self.G_optimizer.zero_grad()

            G_, G_dep = self.G(y_, x_im, x_dep)

            if use_discriminator:
                # Freeze D so the generator step computes no D gradients.
                for p in self.D.parameters():
                    p.requires_grad = False

                D_fake, _ = self.D(G_, x_im, G_dep, y_)

                if self.WGAN:
                    G_loss_fake = -torch.mean(D_fake)
                else:
                    G_loss_fake = self.BCEWithLogitsLoss(D_fake, self.y_real_)

                # Reconstruction loss against the target view.
                G_loss_Comp = self.L1(G_, y_im)
                if self.depth:
                    G_loss_Comp += self.L1(G_dep, y_dep)

                G_loss_Dif_Comp = G_loss_Comp * self.lambdaL1

                # Cycle loss: generate back with the flipped label and
                # compare with the original input.
                reverse_y = - y_ + 1
                reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                G_loss_Cycle = self.L1(reverse_G, x_im)
                if self.depth:
                    G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2

                if self.CR:
                    G_aug, G_dep_aug = self.G(y_, x_im_aug, x_dep)
                    # NOTE(review): the augmented output is judged together
                    # with the un-augmented x_im - confirm this is intended.
                    D_fake_aug, _ = self.D(G_aug, x_im, G_dep_aug, y_)

                    if self.WGAN:
                        G_loss_fake = - (torch.mean(D_fake)+torch.mean(D_fake_aug))/2
                    else:
                        G_loss_fake = (self.BCEWithLogitsLoss(D_fake, self.y_real_) +
                                       self.BCEWithLogitsLoss(D_fake_aug, self.y_real_)) / 2

                    G_loss_Comp_Aug = self.L1(G_aug, y_im_aug)
                    if self.depth:
                        G_loss_Comp_Aug += self.L1(G_dep_aug, y_dep)
                    G_loss_Dif_Comp = (G_loss_Comp + G_loss_Comp_Aug)/2 * self.lambdaL1

                G_loss = G_loss_fake + G_loss_Dif_Comp + G_loss_Cycle

                self.details_hist['G_T_BCE_fake_real'].append(G_loss_fake.detach().item())
                self.details_hist['G_T_Comp_im'].append(G_loss_Dif_Comp.detach().item())
                self.details_hist['G_T_Cycle'].append(G_loss_Cycle.detach().item())
                self.details_hist['G_zCR'].append(0)

            else:
                # Head-start / generator-only: plain L1 reconstruction.
                G_loss = self.L1(G_, y_im)
                if self.depth:
                    G_loss += self.L1(G_dep, y_dep)
                G_loss = G_loss * self.lambdaL1
                self.details_hist['G_T_Comp_im'].append(G_loss.detach().item())
                self.details_hist['G_T_BCE_fake_real'].append(0)
                self.details_hist['G_T_Cycle'].append(0)
                self.details_hist['G_zCR'].append(0)

            G_loss.backward()
            self.G_optimizer.step()
            g_loss_value = G_loss.detach().item()
            self.train_hist['G_loss_train'].append(g_loss_value)
            epoch_G_train.append(g_loss_value)
            if self.onlyGen:
                self.train_hist['D_loss_train'].append(0)

            if self.visdom:
                visLossGTest.plot('Generator_losses',
                                  ['G_T_Comp_im', 'G_T_BCE_fake_real', 'G_zCR', 'G_T_Cycle'],
                                  'train', self.details_hist)

                vis.plot('loss', ['D_loss_train', 'G_loss_train'], 'train', self.train_hist)

        # ------------------------------------------------------------
        # Validation pass (no gradients)
        # ------------------------------------------------------------
        with torch.no_grad():

            self.G.eval()
            if not self.onlyGen:
                self.D.eval()

            for step, data in enumerate(self.data_Validation):

                if step == maxIterVal:
                    break

                x_im = data.get('x_im')
                x_dep = data.get('x_dep')
                y_im = data.get('y_im')
                y_dep = data.get('y_dep')
                y_ = data.get('y_')

                if self.gpu_mode:
                    x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                # Discriminator losses
                if use_discriminator:

                    D_real, _ = self.D(y_im, x_im, y_dep, y_)

                    G_, G_dep = self.G(y_, x_im, x_dep)
                    D_fake, _ = self.D(G_, x_im, G_dep, y_)

                    if self.WGAN:
                        D_loss_real_fake_R = - torch.mean(D_real)
                        D_loss_real_fake_F = torch.mean(D_fake)
                    else:
                        D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                        D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                    D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                    self.train_hist['D_loss_Validation'].append(D_loss.item())
                    epoch_D_val.append(D_loss.item())
                    self.details_hist['D_V_BCE_fake_real_R'].append(D_loss_real_fake_R.item())
                    self.details_hist['D_V_BCE_fake_real_F'].append(D_loss_real_fake_F.item())
                    if self.visdom:
                        visLossDValidation.plot('Discriminator_losses',
                                                ['D_V_BCE_fake_real_R', 'D_V_BCE_fake_real_F'], 'Validation',
                                                self.details_hist)

                # Generator losses
                G_, G_dep = self.G(y_, x_im, x_dep)

                if use_discriminator:

                    D_fake, _ = self.D(G_, x_im, G_dep, y_)

                    if self.WGAN:
                        G_loss = -torch.mean(D_fake)
                    else:
                        G_loss = self.BCEWithLogitsLoss(D_fake, self.y_real_)

                    self.details_hist['G_V_BCE_fake_real'].append(G_loss.item())

                    G_loss_Comp = self.L1(G_, y_im)
                    if self.depth:
                        G_loss_Comp += self.L1(G_dep, y_dep)
                    G_loss_Comp = G_loss_Comp * self.lambdaL1

                    reverse_y = - y_ + 1
                    reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                    G_loss_Cycle = self.L1(reverse_G, x_im)
                    if self.depth:
                        G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                    G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2

                    G_loss += G_loss_Comp + G_loss_Cycle

                    self.details_hist['G_V_Comp_im'].append(G_loss_Comp.item())
                    self.details_hist['G_V_Cycle'].append(G_loss_Cycle.detach().item())

                else:
                    G_loss = self.L1(G_, y_im)
                    if self.depth:
                        G_loss += self.L1(G_dep, y_dep)
                    G_loss = G_loss * self.lambdaL1
                    self.details_hist['G_V_Comp_im'].append(G_loss.item())
                    self.details_hist['G_V_BCE_fake_real'].append(0)
                    self.details_hist['G_V_Cycle'].append(0)

                self.train_hist['G_loss_Validation'].append(G_loss.item())
                epoch_G_val.append(G_loss.item())
                if self.onlyGen:
                    self.train_hist['D_loss_Validation'].append(0)

                if self.visdom:
                    visLossGValidation.plot('Generator_losses', ['G_V_Comp_im', 'G_V_BCE_fake_real', 'G_V_Cycle'],
                                            'Validation', self.details_hist)
                    visValidation.plot('loss', ['D_loss_Validation', 'G_loss_Validation'], 'Validation',
                                       self.train_hist)

        # ------------------------------------------------------------
        # Per-epoch summary
        # ------------------------------------------------------------
        if ventaja or self.onlyGen:
            self.epoch_hist['D_loss_train'].append(0)
            self.epoch_hist['D_loss_Validation'].append(0)
        else:
            self.epoch_hist['D_loss_train'].append(_epoch_mean(epoch_D_train))
            self.epoch_hist['D_loss_Validation'].append(_epoch_mean(epoch_D_val))

        self.epoch_hist['G_loss_train'].append(_epoch_mean(epoch_G_train))
        self.epoch_hist['G_loss_Validation'].append(_epoch_mean(epoch_G_val))
        if self.visdom:
            visEpoch.plot('epoch', epoch,
                          ['D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation'],
                          self.epoch_hist)

        # Keep only the most recent value of every per-batch series so the
        # histories do not grow across epochs.
        for history in (self.train_hist, self.details_hist):
            for key in history:
                history[key] = history[key][-1:]

        epoch_duration = time.time() - epoch_start_time
        self.train_hist['per_epoch_time'].append(epoch_duration)
        epoch_times.append(epoch_duration)

        if epoch % 10 == 0:
            self.save(str(epoch))
            with torch.no_grad():
                if self.visdom:
                    self.visualize_results(epoch, dataprint=self.dataprint, visual=visImages)
                    self.visualize_results(epoch, dataprint=self.dataprint_test, visual=visImagesTest)
                else:
                    imageName = self.model_name + '_' + 'Train' + '_' + str(self.seed) + '_' + str(epoch)
                    self.visualize_results(epoch, dataprint=self.dataprint, name=imageName)
                    self.visualize_results(epoch, dataprint=self.dataprint_test, name=imageName)

    self.train_hist['total_time'].append(time.time() - start_time)
    # The average is taken over every epoch, not over the trimmed history.
    avg_epoch_time = np.mean(epoch_times) if epoch_times else 0.0
    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (avg_epoch_time,
                                                                    self.epoch, self.train_hist['total_time'][0]))
    print("Training finish!... save training results")

    self.save()
|
|
|
|
|
|
|
|
|
def visualize_results(self, epoch, dataprint, visual="", name= "test"):
    """Render generated views for the first ``self.cantImages`` samples.

    For every sample of ``dataprint`` (a mapping with ``'x_im'`` and
    ``'x_dep'`` batches) the input image and depth are collected, followed by
    the generator outputs.  In wiggle mode the generator is applied
    repeatedly (``self.wiggleDepth`` times per direction) to its own output.
    The collected images are sent to ``visual.plot`` when visdom is enabled,
    otherwise written through ``utils.save_wiggle`` under ``name``.
    """
    with torch.no_grad():
        self.G.eval()

        collected = []
        pairs = zip(dataprint.get('x_im'), dataprint.get('x_dep'))
        for shown, (x_im, x_dep) in enumerate(pairs):
            if shown >= self.cantImages:
                break

            # The same sample is fed twice, once per label plane below.
            x_im_input = x_im.repeat(2, 1, 1, 1)
            x_dep_input = x_dep.repeat(2, 1, 1, 1)

            side = x_im.shape[2]

            # Label planes: all zeros except plane 1, which is all ones.
            sample_y_ = torch.zeros((self.class_num, 1, side, side))
            if self.class_num > 1:
                sample_y_[1] = torch.ones((1, side, side))

            if self.gpu_mode:
                sample_y_ = sample_y_.cuda()
                x_im_input = x_im_input.cuda()
                x_dep_input = x_dep_input.cuda()

            G_im, G_dep = self.G(sample_y_, x_im_input, x_dep_input)

            collected.append(x_im.squeeze(0))
            collected.append(x_dep.squeeze(0).expand(3, -1, -1))

            if not self.wiggle:
                collected.append(G_im.cpu()[0])
                collected.append(G_im.cpu()[1])
                collected.append(G_dep.cpu()[0].expand(3, -1, -1))
                collected.append(G_dep.cpu()[1].expand(3, -1, -1))
                continue

            im_aux, im_dep_aux = G_im, G_dep
            for direction in range(2):
                source = direction
                for hop in range(self.wiggleDepth):
                    if direction == 1:
                        if hop == 0:
                            # Restart from the first-pass outputs.
                            im_aux, im_dep_aux = G_im, G_dep
                            collected.append(G_im.cpu()[0].squeeze(0))
                            collected.append(G_im.cpu()[1].squeeze(0))
                        else:
                            # Later hops hold a single generated image.
                            source = 0

                    step_im = im_aux[source].unsqueeze(0)
                    step_dep = im_dep_aux[source].unsqueeze(0)
                    step_y = sample_y_[direction].unsqueeze(0)

                    if self.gpu_mode:
                        step_y, step_im, step_dep = step_y.cuda(), step_im.cuda(), step_dep.cuda()

                    im_aux, im_dep_aux = self.G(step_y, step_im, step_dep)

                    collected.append(im_aux.cpu()[0])

        if self.visdom:
            visual.plot(epoch, collected, int(len(collected) /self.cantImages))
        else:
            utils.save_wiggle(collected, self.cantImages, name)
|
|
|
|
|
|
|
|
|
|
|
def show_plot_images(self, images, cols=1, titles=None):
    """Display a list of images in a single figure with matplotlib.

    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow.

    cols (Default = 1): Number of columns in figure (number of rows is
                        set to np.ceil(n_images/float(cols))).

    titles: List of titles corresponding to each image. Must have
            the same length as images.
    """
    n_images = len(images)
    if n_images == 0:
        # Nothing to draw; an empty figure would be scaled to size zero.
        return

    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]

    # add_subplot needs integer grid dimensions; np.ceil returns a float.
    rows = int(np.ceil(n_images / float(cols)))

    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(rows, cols, n + 1)

        # NOTE(review): presumably maps images from [-1, 1] to a displayable
        # range - confirm the intended scale (this yields values up to 510).
        image = (image + 1) * 255.0

        if image.ndim == 2:
            plt.gray()

        plt.imshow(image)
        a.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()
|
|
|
def joinImages(self, data):
    """Return ``data`` repeated ``self.class_num`` times as a float tensor.

    The copies are stacked along a new leading axis.
    """
    copies = [data for _ in range(self.class_num)]
    stacked = np.array(copies)
    return torch.tensor(stacked.tolist()).type(torch.FloatTensor)
|
|
|
def save(self, epoch=''):
    """Write the network weights and the training history to disk.

    Files go to ``<save_dir>/<dataset>/<model_name>/``.  The generator (and
    the discriminator unless ``self.onlyGen``) state dicts are stored as
    ``<model_name>_<seed>_<epoch>_G.pkl`` / ``_D.pkl``; ``self.train_hist``
    is pickled next to them.

    Parameters
    ----------
    epoch: String tag inserted in the weight file names (default empty).
    """
    target_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    stem = self.model_name + '_' + self.seed + '_' + epoch

    checkpoints = [('_G.pkl', self.G)]
    if not self.onlyGen:
        checkpoints.append(('_D.pkl', self.D))
    for suffix, network in checkpoints:
        torch.save(network.state_dict(), os.path.join(target_dir, stem + suffix))

    history_path = os.path.join(target_dir, self.model_name + '_history_ '+self.seed+'.pkl')
    with open(history_path, 'wb') as f:
        pickle.dump(self.train_hist, f)
|
|
|
def load(self):
    """Restore network weights from the checkpoint named by ``self.seed_load``.

    Reads ``<save_dir>/<dataset>/<model_name>/<model_name>_<seed_load>_G.pkl``
    into the generator.  The matching ``_D.pkl`` file is read only when a
    discriminator exists, i.e. outside wiggle mode and outside generator-only
    training (``__init__`` does not create ``self.D`` in those cases).
    """
    save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
    prefix = os.path.join(save_dir, self.model_name + '_' + self.seed_load)

    # Checkpoints written on a GPU must be remapped when running on CPU.
    map_location = None if self.gpu_mode else 'cpu'

    self.G.load_state_dict(torch.load(prefix + '_G.pkl', map_location=map_location))
    if not self.wiggle and not self.onlyGen:
        self.D.load_state_dict(torch.load(prefix + '_D.pkl', map_location=map_location))
|
|
|
def wiggleEf(self):
    """Render the wiggle effect for the loaded checkpoint on the test batch.

    ``self.seed_load`` is split on ``'_'`` into a seed and an epoch tag.  With
    visdom enabled the images go to a plotter whose environment is named
    after the seed; otherwise they are saved under ``self.name_wiggle``.
    """
    seed, epoch = self.seed_load.split('_')

    if not self.visdom:
        self.visualize_results(epoch=epoch, dataprint=self.dataprint_test,
                               visual=None, name=self.name_wiggle)
        return

    plotter = utils.VisdomImagePlotter(env_name='Cobo_depth_wiggle_' + seed)
    self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=plotter)
|
|
|
def recreate(self):
    """Regenerate both camera views for the ``'score'`` split and save them.

    For every batch the generator produces the right view from the left
    input and the left view from the right input.  Outputs are rescaled with
    ``(x + 1) / 2`` and written as ``n_XXXX.png`` (image) and ``d_XXXX.png``
    (depth) under ``results/recreate_dataset/CAM1`` (generated right view)
    and ``results/recreate_dataset/CAM0`` (generated left view), numbered
    consecutively across batches.
    """
    dataloader_recreate = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='score')

    # save_image does not create missing folders.
    cam0_dir = os.path.join("results", "recreate_dataset", "CAM0")
    cam1_dir = os.path.join("results", "recreate_dataset", "CAM1")
    os.makedirs(cam0_dir, exist_ok=True)
    os.makedirs(cam1_dir, exist_ok=True)

    with torch.no_grad():
        self.G.eval()
        accum = 0
        for data_batch in dataloader_recreate:

            # Relies on the batch mapping listing its entries in this order.
            left, left_depth, right, right_depth, direction = data_batch.values()

            if self.gpu_mode:
                left, left_depth, right, right_depth, direction = left.cuda(), left_depth.cuda(), right.cuda(), right_depth.cuda(), direction.cuda()

            G_right, G_right_dep = self.G(direction, left, left_depth)

            # NOTE(review): train() flips labels with ``-y_ + 1``; this zeroes
            # them instead, which matches only if direction is all ones - confirm.
            reverse_direction = direction * 0
            G_left, G_left_dep = self.G(reverse_direction, right, right_depth)

            # Use the real batch length so a short final batch does not
            # raise IndexError and the numbering stays contiguous.
            batch_len = G_right.shape[0]
            for index in range(batch_len):
                image_right = (G_right[index] + 1.0)/2.0
                image_right_dep = (G_right_dep[index] + 1.0)/2.0

                image_left = (G_left[index] + 1.0)/2.0
                image_left_dep = (G_left_dep[index] + 1.0)/2.0

                number = "{num:0{width}}.png".format(num=index + accum, width=4)

                save_image(image_right, os.path.join(cam1_dir, "n_" + number))
                save_image(image_right_dep, os.path.join(cam1_dir, "d_" + number))

                save_image(image_left, os.path.join(cam0_dir, "n_" + number))
                save_image(image_left_dep, os.path.join(cam0_dir, "d_" + number))
            accum += batch_len
|
|
|
|
|
|