import utils
import torch
import time
import os
import pickle
import numpy as np
import torch.nn as nn
import torch.cuda as cu
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from utils import augmentData, RGBtoL, LtoRGB
from PIL import Image
from dataloader import dataloader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import random
from datetime import date
from statistics import mean
from architectures import depth_generator_UNet, \
    depth_discriminator_noclass_UNet


class WiggleGAN(object):
    """Depth-aware conditional GAN that re-renders an image for another viewing direction.

    The generator ``G`` is called as ``G(direction, image, depth)`` and returns a
    ``(image, depth)`` pair.  The discriminator ``D`` is called as
    ``D(candidate_image, source_image, candidate_depth, direction)`` and returns
    ``(logits, features)``; the features feed the consistency-regularization losses.
    """

    def __init__(self, args):
        """Build data loaders, networks, optimizers and losses from parsed CLI ``args``.

        The discriminator exists only when training data is loaded (``wiggleDepth <= 0``)
        and ``lrD > 0``.  When ``args.seedLoad`` is not ``"-0000"`` a checkpoint is
        loaded at the end (see ``load``).
        """
        # parameters
        self.epoch = args.epoch
        self.nCameras = args.cameras
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        # Number of direction classes; (cameras - 1) * 2 was worked out by hand by the author.
        self.class_num = (args.cameras - 1) * 2
        self.sample_num = self.class_num ** 2
        self.imageDim = args.imageDim
        self.epochVentaja = args.epochV  # "head start" epochs: only G is trained, with L1 only
        self.cantImages = args.cIm  # number of dataset images rendered by visualize_results
        self.visdom = args.visdom
        self.lambdaL1 = args.lambdaL1
        self.depth = args.depth
        self.name_wiggle = args.name_wiggle
        self.clipping = args.clipping
        # A positive clipping value switches the adversarial losses to WGAN (weight clipping).
        self.WGAN = self.clipping > 0

        self.seed = str(random.randint(0, 99999))
        self.seed_load = args.seedLoad
        # "-0000" is the sentinel for "do not load a checkpoint".
        self.toLoad = self.seed_load != "-0000"

        self.zGenFactor = args.zGF
        self.zDisFactor = args.zDF
        self.bFactor = args.bF
        # Consistency regularization (bCR / zCR) is active when any of its factors is positive.
        self.CR = self.zGenFactor > 0 or self.zDisFactor > 0 or self.bFactor > 0

        self.expandGen = args.expandGen
        self.expandDis = args.expandDis
        self.wiggleDepth = args.wiggleDepth
        # wiggleDepth > 0 selects "wiggle" mode: no train/validation data and no D are created.
        self.wiggle = self.wiggleDepth > 0

        # load dataset
        self.onlyGen = args.lrD <= 0  # a non-positive D learning rate means G is trained alone
        if not self.wiggle:
            self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                          split='train', trans=not self.CR)
            self.data_Validation = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                              split='validation')
            self.dataprint = next(iter(self.data_Validation))
            data = next(iter(self.data_loader)).get('x_im')
            if not self.onlyGen:
                self.D = depth_discriminator_noclass_UNet(input_dim=3, output_dim=1, input_shape=data.shape,
                                                          class_num=self.class_num, expand_net=self.expandDis,
                                                          depth=self.depth, wgan=self.WGAN)
                self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD,
                                              betas=(args.beta1, args.beta2))

        self.data_Test = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='test')
        self.dataprint_test = next(iter(self.data_Test))

        # networks init
        self.G = depth_generator_UNet(input_dim=4, output_dim=3, class_num=self.class_num,
                                      expand_net=self.expandGen, depth=self.depth)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            if self._has_discriminator():
                self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
            self.CE_loss = nn.CrossEntropyLoss().cuda()
            self.L1 = nn.L1Loss().cuda()
            self.MSE = nn.MSELoss().cuda()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()
            self.CE_loss = nn.CrossEntropyLoss()
            self.MSE = nn.MSELoss()
            self.L1 = nn.L1Loss()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        if self._has_discriminator():
            utils.print_network(self.D)
        print('-----------------------------------------------')

        # One-hot table of shape (class_num ** 2, class_num): row r encodes class r % class_num.
        # NOTE(review): self.sample_y_ is not referenced anywhere else in this file.
        class_ids = torch.arange(self.class_num).repeat(self.class_num).unsqueeze(1)
        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, class_ids, 1)
        if self.gpu_mode:
            self.sample_y_ = self.sample_y_.cuda()

        if self.toLoad:
            self.load()

    def _has_discriminator(self):
        """Return True when ``__init__`` created ``self.D`` (training data loaded and lrD > 0)."""
        return not self.wiggle and not self.onlyGen

    @staticmethod
    def _epoch_mean(values, start):
        """Return the mean of ``values[start:]``, or 0 when that slice is empty."""
        window = values[start:]
        return mean(window) if window else 0

    def train(self):
        """Train G (and D unless ``onlyGen``) for ``self.epoch`` epochs.

        Each epoch runs a training pass and a validation pass (under ``torch.no_grad``).
        It then appends this epoch's mean losses to ``self.epoch_hist``.  Every 10 epochs it
        saves a checkpoint and renders samples.  During the first ``epochVentaja`` epochs D
        is not updated and G is trained with the L1 reconstruction loss only.
        """
        if self.visdom:
            random.seed(time.time())
            today = date.today()
            plots_env = 'Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed
            vis = utils.VisdomLinePlotter(env_name=plots_env)
            visValidation = utils.VisdomLinePlotter(env_name=plots_env)
            visEpoch = utils.VisdomLineTwoPlotter(env_name=plots_env)
            visImages = utils.VisdomImagePlotter(env_name='Cobo_depth_Images_' + str(today) + '_' + self.seed)
            visImagesTest = utils.VisdomImagePlotter(env_name='Cobo_depth_ImagesTest_' + str(today) + '_' + self.seed)
            visLossGTest = utils.VisdomLinePlotter(env_name=plots_env)
            visLossGValidation = utils.VisdomLinePlotter(env_name=plots_env)
            visLossDTest = utils.VisdomLinePlotter(env_name=plots_env)
            visLossDValidation = utils.VisdomLinePlotter(env_name=plots_env)

        # Per-iteration loss logs (trimmed to their last entry at the end of every epoch).
        rolling_keys = ('D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation')
        self.train_hist = {key: [] for key in rolling_keys + ('per_epoch_time', 'total_time')}
        self.epoch_hist = {key: [] for key in rolling_keys}
        self.details_hist = {key: [] for key in (
            'G_T_Comp_im', 'G_T_BCE_fake_real', 'G_T_Cycle', 'G_zCR',
            'G_V_Comp_im', 'G_V_BCE_fake_real', 'G_V_Cycle',
            'D_T_BCE_fake_real_R', 'D_T_BCE_fake_real_F', 'D_zCR', 'D_bCR',
            'D_V_BCE_fake_real_R', 'D_V_BCE_fake_real_F')}

        # Only full batches are used: the BCE targets below are sized to batch_size.
        maxIter = len(self.data_loader.dataset) // self.batch_size
        maxIterVal = len(self.data_Validation.dataset) // self.batch_size

        if not self.WGAN:
            self.y_real_ = torch.ones(self.batch_size, 1)
            self.y_fake_ = torch.zeros(self.batch_size, 1)
            if self.gpu_mode:
                self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        print('training start!!')
        start_time = time.time()

        for epoch in range(self.epoch):
            ventaja = epoch < self.epochVentaja

            self.G.train()
            if not self.onlyGen:
                self.D.train()
            epoch_start_time = time.time()

            # Index where this epoch's entries begin in each rolling list (used for the epoch means).
            epoch_start = {key: len(self.train_hist[key]) for key in rolling_keys}

            # TRAIN
            for batch_idx, data in enumerate(self.data_loader):
                if batch_idx >= maxIter:
                    break

                x_im = data.get('x_im')    # input images
                x_dep = data.get('x_dep')  # input depth maps
                y_im = data.get('y_im')    # target images (changed viewing angle)
                y_dep = data.get('y_dep')  # target depth maps
                y_ = data.get('y_')        # target direction (author's note: negatives still to be handled)

                # Data augmentation for consistency regularization
                if self.CR:
                    x_im_aug, y_im_aug = augmentData(x_im, y_im)
                    x_im_vanilla = x_im  # un-augmented CPU copy, re-augmented together with G_ for bCR
                    if self.gpu_mode:
                        x_im_aug, y_im_aug = x_im_aug.cuda(), y_im_aug.cuda()

                if self.gpu_mode:
                    x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                # update D network
                if not ventaja and not self.onlyGen:
                    for p in self.D.parameters():
                        p.requires_grad = True  # they are set to False below in the G update
                    self.D_optimizer.zero_grad()

                    # Real images
                    D_real, D_features_real = self.D(y_im, x_im, y_dep, y_)
                    # Fake images. Any gradient the D loss puts on G is cleared by
                    # G_optimizer.zero_grad() before the G update, so G's graph is not needed.
                    with torch.no_grad():
                        G_, G_dep = self.G(y_, x_im, x_dep)
                    D_fake, D_features_fake = self.D(G_, x_im, G_dep, y_)

                    # GAN loss
                    if self.WGAN:
                        D_loss_real_fake_R = - torch.mean(D_real)
                        D_loss_real_fake_F = torch.mean(D_fake)
                    else:
                        D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                        D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)
                    D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                    if self.CR:
                        # bCR: augmented real and fake images should yield D features close to the originals.
                        x_im_aug_bCR, G_aug_bCR = augmentData(x_im_vanilla, G_.data.cpu())
                        if self.gpu_mode:
                            G_aug_bCR, x_im_aug_bCR = G_aug_bCR.cuda(), x_im_aug_bCR.cuda()
                        D_fake_bCR, D_features_fake_bCR = self.D(G_aug_bCR, x_im_aug_bCR, G_dep, y_)
                        D_real_bCR, D_features_real_bCR = self.D(y_im_aug, x_im_aug, y_dep, y_)

                        # zCR: fakes generated from augmented inputs
                        with torch.no_grad():
                            G_aug_zCR, G_dep_aug_zCR = self.G(y_, x_im_aug, x_dep)
                        D_fake_aug_zCR, D_features_fake_aug_zCR = self.D(G_aug_zCR, x_im_aug, G_dep_aug_zCR, y_)

                        D_loss_real = self.MSE(D_features_real, D_features_real_bCR)
                        D_loss_fake = self.MSE(D_features_fake, D_features_fake_bCR)
                        D_bCR = (D_loss_real + D_loss_fake) * self.bFactor
                        D_zCR = self.MSE(D_features_fake, D_features_fake_aug_zCR) * self.zDisFactor

                        D_loss += D_bCR + D_zCR
                        self.details_hist['D_bCR'].append(D_bCR.detach().item())
                        self.details_hist['D_zCR'].append(D_zCR.detach().item())
                    else:
                        self.details_hist['D_bCR'].append(0)
                        self.details_hist['D_zCR'].append(0)

                    self.train_hist['D_loss_train'].append(D_loss.detach().item())
                    self.details_hist['D_T_BCE_fake_real_R'].append(D_loss_real_fake_R.detach().item())
                    self.details_hist['D_T_BCE_fake_real_F'].append(D_loss_real_fake_F.detach().item())
                    if self.visdom:
                        visLossDTest.plot('Discriminator_losses',
                                          ['D_T_BCE_fake_real_R', 'D_T_BCE_fake_real_F', 'D_bCR', 'D_zCR'],
                                          'train', self.details_hist)

                    D_loss.backward()
                    self.D_optimizer.step()

                    if self.WGAN:
                        # Weight clipping. Per the WGAN paper, a clipping value that is too small
                        # leads to vanishing gradients. The improved WGAN variant would require
                        # removing batch normalization from the network.
                        for p in self.D.parameters():
                            p.data.clamp_(-self.clipping, self.clipping)

                # update G network
                self.G_optimizer.zero_grad()
                G_, G_dep = self.G(y_, x_im, x_dep)

                if not ventaja and not self.onlyGen:
                    for p in self.D.parameters():
                        p.requires_grad = False  # D is fixed during the G update

                    D_fake, _ = self.D(G_, x_im, G_dep, y_)
                    if self.WGAN:
                        G_loss_fake = -torch.mean(D_fake)
                    else:
                        G_loss_fake = self.BCEWithLogitsLoss(D_fake, self.y_real_)

                    # L1 reconstruction against the target view (plus depth when self.depth)
                    G_loss_Comp = self.L1(G_, y_im)
                    if self.depth:
                        G_loss_Comp += self.L1(G_dep, y_dep)
                    G_loss_Dif_Comp = G_loss_Comp * self.lambdaL1

                    # Cycle loss: mapping the output back with direction 1 - y_ should recover the input.
                    reverse_y = - y_ + 1
                    reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                    G_loss_Cycle = self.L1(reverse_G, x_im)
                    if self.depth:
                        G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                    G_loss_Cycle = G_loss_Cycle * self.lambdaL1 / 2

                    if self.CR:
                        # Augmented fakes
                        G_aug, G_dep_aug = self.G(y_, x_im_aug, x_dep)
                        # NOTE(review): D is conditioned on x_im here, while the D-step zCR path
                        # uses x_im_aug for the same augmented fake -- confirm which is intended.
                        D_fake_aug, _ = self.D(G_aug, x_im, G_dep_aug, y_)
                        if self.WGAN:
                            G_loss_fake = - (torch.mean(D_fake) + torch.mean(D_fake_aug)) / 2
                        else:
                            G_loss_fake = (self.BCEWithLogitsLoss(D_fake, self.y_real_)
                                           + self.BCEWithLogitsLoss(D_fake_aug, self.y_real_)) / 2

                        G_loss_Comp_Aug = self.L1(G_aug, y_im_aug)
                        if self.depth:
                            G_loss_Comp_Aug += self.L1(G_dep_aug, y_dep)
                        G_loss_Dif_Comp = (G_loss_Comp + G_loss_Comp_Aug) / 2 * self.lambdaL1

                    G_loss = G_loss_fake + G_loss_Dif_Comp + G_loss_Cycle
                    self.details_hist['G_T_BCE_fake_real'].append(G_loss_fake.detach().item())
                    self.details_hist['G_T_Comp_im'].append(G_loss_Dif_Comp.detach().item())
                    self.details_hist['G_T_Cycle'].append(G_loss_Cycle.detach().item())
                    # NOTE(review): zGenFactor never enters a loss; G_zCR is always logged as 0.
                    self.details_hist['G_zCR'].append(0)
                else:
                    # Head-start epochs / generator-only training: L1 reconstruction only.
                    G_loss = self.L1(G_, y_im)
                    if self.depth:
                        G_loss += self.L1(G_dep, y_dep)
                    G_loss = G_loss * self.lambdaL1
                    self.details_hist['G_T_Comp_im'].append(G_loss.detach().item())
                    self.details_hist['G_T_BCE_fake_real'].append(0)
                    self.details_hist['G_T_Cycle'].append(0)
                    self.details_hist['G_zCR'].append(0)

                G_loss.backward()
                self.G_optimizer.step()

                self.train_hist['G_loss_train'].append(G_loss.detach().item())
                if self.onlyGen:
                    self.train_hist['D_loss_train'].append(0)

                if self.visdom:
                    visLossGTest.plot('Generator_losses', ['G_T_Comp_im', 'G_T_BCE_fake_real', 'G_zCR', 'G_T_Cycle'],
                                      'train', self.details_hist)
                    vis.plot('loss', ['D_loss_train', 'G_loss_train'], 'train', self.train_hist)

            # VALIDATION
            with torch.no_grad():
                self.G.eval()
                if not self.onlyGen:
                    self.D.eval()

                for batch_idx, data in enumerate(self.data_Validation):
                    if batch_idx >= maxIterVal:
                        break

                    x_im = data.get('x_im')
                    x_dep = data.get('x_dep')
                    y_im = data.get('y_im')
                    y_dep = data.get('y_dep')
                    y_ = data.get('y_')

                    if self.gpu_mode:
                        x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                    # D network
                    if not ventaja and not self.onlyGen:
                        D_real, _ = self.D(y_im, x_im, y_dep, y_)
                        G_, G_dep = self.G(y_, x_im, x_dep)
                        D_fake, _ = self.D(G_, x_im, G_dep, y_)

                        if self.WGAN:
                            D_loss_real_fake_R = - torch.mean(D_real)
                            D_loss_real_fake_F = torch.mean(D_fake)
                        else:
                            D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                            D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)
                        D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                        self.train_hist['D_loss_Validation'].append(D_loss.item())
                        self.details_hist['D_V_BCE_fake_real_R'].append(D_loss_real_fake_R.item())
                        self.details_hist['D_V_BCE_fake_real_F'].append(D_loss_real_fake_F.item())
                        if self.visdom:
                            visLossDValidation.plot('Discriminator_losses',
                                                    ['D_V_BCE_fake_real_R', 'D_V_BCE_fake_real_F'],
                                                    'Validation', self.details_hist)

                    # G network
                    G_, G_dep = self.G(y_, x_im, x_dep)

                    if not ventaja and not self.onlyGen:
                        D_fake, _ = self.D(G_, x_im, G_dep, y_)
                        if self.WGAN:
                            G_loss = -torch.mean(D_fake)
                        else:
                            G_loss = self.BCEWithLogitsLoss(D_fake, self.y_real_)
                        self.details_hist['G_V_BCE_fake_real'].append(G_loss.item())

                        G_loss_Comp = self.L1(G_, y_im)
                        if self.depth:
                            G_loss_Comp += self.L1(G_dep, y_dep)
                        G_loss_Comp = G_loss_Comp * self.lambdaL1

                        reverse_y = - y_ + 1
                        reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                        G_loss_Cycle = self.L1(reverse_G, x_im)
                        if self.depth:
                            G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                        G_loss_Cycle = G_loss_Cycle * self.lambdaL1 / 2

                        G_loss += G_loss_Comp + G_loss_Cycle
                        self.details_hist['G_V_Comp_im'].append(G_loss_Comp.item())
                        self.details_hist['G_V_Cycle'].append(G_loss_Cycle.detach().item())
                    else:
                        G_loss = self.L1(G_, y_im)
                        if self.depth:
                            G_loss += self.L1(G_dep, y_dep)
                        G_loss = G_loss * self.lambdaL1
                        self.details_hist['G_V_Comp_im'].append(G_loss.item())
                        self.details_hist['G_V_BCE_fake_real'].append(0)
                        self.details_hist['G_V_Cycle'].append(0)

                    self.train_hist['G_loss_Validation'].append(G_loss.item())
                    if self.onlyGen:
                        self.train_hist['D_loss_Validation'].append(0)

                    if self.visdom:
                        visLossGValidation.plot('Generator_losses', ['G_V_Comp_im', 'G_V_BCE_fake_real', 'G_V_Cycle'],
                                                'Validation', self.details_hist)
                        visValidation.plot('loss', ['D_loss_Validation', 'G_loss_Validation'], 'Validation',
                                           self.train_hist)

            # Per-epoch means over exactly this epoch's entries.
            if ventaja or self.onlyGen:
                self.epoch_hist['D_loss_train'].append(0)
                self.epoch_hist['D_loss_Validation'].append(0)
            else:
                self.epoch_hist['D_loss_train'].append(
                    self._epoch_mean(self.train_hist['D_loss_train'], epoch_start['D_loss_train']))
                self.epoch_hist['D_loss_Validation'].append(
                    self._epoch_mean(self.train_hist['D_loss_Validation'], epoch_start['D_loss_Validation']))
            self.epoch_hist['G_loss_train'].append(
                self._epoch_mean(self.train_hist['G_loss_train'], epoch_start['G_loss_train']))
            self.epoch_hist['G_loss_Validation'].append(
                self._epoch_mean(self.train_hist['G_loss_Validation'], epoch_start['G_loss_Validation']))

            if self.visdom:
                visEpoch.plot('epoch', epoch,
                              ['D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation'],
                              self.epoch_hist)

            # Keep only the last per-iteration entry of each loss log. Timing lists are kept
            # whole so that the final average epoch time covers every epoch.
            for key in rolling_keys:
                self.train_hist[key] = self.train_hist[key][-1:]
            for key in self.details_hist:
                self.details_hist[key] = self.details_hist[key][-1:]

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)

            if epoch % 10 == 0:
                self.save(str(epoch))
                with torch.no_grad():
                    if self.visdom:
                        self.visualize_results(epoch, dataprint=self.dataprint, visual=visImages)
                        self.visualize_results(epoch, dataprint=self.dataprint_test, visual=visImagesTest)
                    else:
                        imageName = self.model_name + '_' + 'Train' + '_' + str(self.seed) + '_' + str(epoch)
                        imageNameTest = self.model_name + '_' + 'Test' + '_' + str(self.seed) + '_' + str(epoch)
                        self.visualize_results(epoch, dataprint=self.dataprint, name=imageName)
                        # Separate name so the test render does not overwrite the validation render.
                        self.visualize_results(epoch, dataprint=self.dataprint_test, name=imageNameTest)

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
                                                                         self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()

    def visualize_results(self, epoch, dataprint, visual="", name="test"):
        """Render generated views for the first ``cantImages`` samples of ``dataprint``.

        For each sample, G is run with two direction maps: all zeros, and all ones
        (the entry at index 1).  The collected images are either plotted with ``visual``
        (when ``self.visdom``) or written with ``utils.save_wiggle`` under ``name``.  In
        wiggle mode each direction is additionally applied ``wiggleDepth`` times in a chain.
        """
        with torch.no_grad():
            self.G.eval()

            newSample = []
            for shown, (x_im, x_dep) in enumerate(zip(dataprint.get('x_im'), dataprint.get('x_dep'))):
                if shown >= self.cantImages:
                    break

                x_im_input = x_im.repeat(2, 1, 1, 1)
                x_dep_input = x_dep.repeat(2, 1, 1, 1)

                sizeImage = x_im.shape[2]
                sample_y_ = torch.zeros((self.class_num, 1, sizeImage, sizeImage))
                if self.class_num > 1:
                    sample_y_[1] = torch.ones((1, sizeImage, sizeImage))

                if self.gpu_mode:
                    sample_y_, x_im_input, x_dep_input = sample_y_.cuda(), x_im_input.cuda(), x_dep_input.cuda()

                G_im, G_dep = self.G(sample_y_, x_im_input, x_dep_input)

                newSample.append(x_im.squeeze(0))
                newSample.append(x_dep.squeeze(0).expand(3, -1, -1))

                if self.wiggle:
                    im_aux, im_dep_aux = G_im, G_dep
                    for i in range(2):
                        index = i
                        for j in range(self.wiggleDepth):
                            if j == 0 and i == 1:
                                # Restart the chain from G's first output for the second direction.
                                im_aux, im_dep_aux = G_im, G_dep
                                newSample.append(G_im.cpu()[0].squeeze(0))
                                newSample.append(G_im.cpu()[1].squeeze(0))
                            elif i == 1:
                                # Later steps hold a single generated image, so it is at index 0.
                                index = 0

                            # Feed the previously generated image back into G.
                            x = im_aux[index].unsqueeze(0)
                            x_dep = im_dep_aux[index].unsqueeze(0)
                            y = sample_y_[i].unsqueeze(0)
                            if self.gpu_mode:
                                y, x, x_dep = y.cuda(), x.cuda(), x_dep.cuda()
                            im_aux, im_dep_aux = self.G(y, x, x_dep)
                            newSample.append(im_aux.cpu()[0])
                else:
                    newSample.append(G_im.cpu()[0])
                    newSample.append(G_im.cpu()[1])
                    newSample.append(G_dep.cpu()[0].expand(3, -1, -1))
                    newSample.append(G_dep.cpu()[1].expand(3, -1, -1))

            if self.visdom:
                visual.plot(epoch, newSample, int(len(newSample) / self.cantImages))
            else:
                utils.save_wiggle(newSample, self.cantImages, name)

    def show_plot_images(self, images, cols=1, titles=None):
        """Display a list of images in a single figure with matplotlib.

        Parameters
        ---------
        images: List of np.arrays compatible with plt.imshow.

        cols (Default = 1): Number of columns in figure (number of rows is
                            set to np.ceil(n_images/float(cols))).

        titles: List of titles corresponding to each image. Must have
                the same length as images.
        """
        n_images = len(images)
        if titles is None:
            titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
        # add_subplot requires integer grid dimensions.
        n_rows = int(np.ceil(n_images / float(cols)))
        fig = plt.figure()
        for n, (image, title) in enumerate(zip(images, titles)):
            a = fig.add_subplot(n_rows, cols, n + 1)
            # NOTE(review): for images in [-1, 1] this gives [0, 510]; imshow clips float RGB
            # to [0, 1], so (image + 1) / 2 may have been intended.
            image = (image + 1) * 255.0
            if image.ndim == 2:
                plt.gray()
            plt.imshow(image)
            a.set_title(title)
        fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
        plt.show()

    def joinImages(self, data):
        """Stack ``class_num`` copies of ``data`` along a new leading axis as a FloatTensor."""
        nData = np.array([data] * self.class_num)
        nData = torch.tensor(nData.tolist())
        nData = nData.type(torch.FloatTensor)

        return nData

    def save(self, epoch=''):
        """Save G (and D when it exists) plus the pickled ``train_hist``.

        Files go to ``save_dir/dataset/model_name`` and are named
        ``<model_name>_<seed>_<epoch>_G.pkl`` / ``..._D.pkl``.
        """
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        os.makedirs(save_dir, exist_ok=True)

        prefix = self.model_name + '_' + self.seed + '_' + epoch
        torch.save(self.G.state_dict(), os.path.join(save_dir, prefix + '_G.pkl'))
        if self._has_discriminator():
            torch.save(self.D.state_dict(), os.path.join(save_dir, prefix + '_D.pkl'))

        # NOTE(review): the history file name contains a space ('_history_ '); kept as-is so
        # existing readers of these files keep working.
        with open(os.path.join(save_dir, self.model_name + '_history_ ' + self.seed + '.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Load G (and D when it exists) from ``<model_name>_<seedLoad>_G.pkl`` / ``_D.pkl``."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        # Checkpoints are read onto the CPU so GPU-saved files also load on CPU-only machines;
        # load_state_dict copies the values onto each module's own device.
        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl'),
                                          map_location='cpu'))
        if self._has_discriminator():
            self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl'),
                                              map_location='cpu'))

    def wiggleEf(self):
        """Render the test samples with the checkpoint named by ``seedLoad`` (``"<seed>_<epoch>"``)."""
        seed, epoch = self.seed_load.split('_')
        if self.visdom:
            visWiggle = utils.VisdomImagePlotter(env_name='Cobo_depth_wiggle_' + seed)
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=visWiggle)
        else:
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=None, name=self.name_wiggle)

    def recreate(self):
        """Regenerate both views of the 'score' split with G and write them as PNG files.

        For every batch, G maps the left view to the right one using the batch direction.
        It maps the right view back to the left one using an all-zero direction.  Outputs are
        mapped with ``(x + 1) / 2`` (presumably from [-1, 1] to [0, 1]).  They are written to
        ``results/recreate_dataset/CAM1`` (generated right views) and ``CAM0`` (generated left
        views) as ``n_XXXX.png`` (image) and ``d_XXXX.png`` (depth).
        """
        dataloader_recreate = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='score')

        cam0_dir = os.path.join("results", "recreate_dataset", "CAM0")
        cam1_dir = os.path.join("results", "recreate_dataset", "CAM1")
        os.makedirs(cam0_dir, exist_ok=True)
        os.makedirs(cam1_dir, exist_ok=True)

        def file_name(prefix, num):
            # Zero-padded to 4 digits, e.g. n_0007.png
            return "{p}_{num:0{width}}.png".format(p=prefix, num=num, width=4)

        with torch.no_grad():
            self.G.eval()
            accum = 0
            for data_batch in dataloader_recreate:
                # Batch dict order: x_im, x_dep, y_im, y_dep, y_
                left, left_depth, right, right_depth, direction = data_batch.values()
                if self.gpu_mode:
                    left, left_depth, right, right_depth, direction = (left.cuda(), left_depth.cuda(), right.cuda(),
                                                                       right_depth.cuda(), direction.cuda())

                G_right, G_right_dep = self.G(direction, left, left_depth)
                reverse_direction = direction * 0
                G_left, G_left_dep = self.G(reverse_direction, right, right_depth)

                # Use the real batch size: the final batch may be smaller than self.batch_size.
                n_batch = G_right.shape[0]
                for index in range(n_batch):
                    num = index + accum
                    save_image((G_right[index] + 1.0) / 2.0, os.path.join(cam1_dir, file_name("n", num)))
                    save_image((G_right_dep[index] + 1.0) / 2.0, os.path.join(cam1_dir, file_name("d", num)))
                    save_image((G_left[index] + 1.0) / 2.0, os.path.join(cam0_dir, file_name("n", num)))
                    save_image((G_left_dep[index] + 1.0) / 2.0, os.path.join(cam0_dir, file_name("d", num)))
                accum += n_batch