import os
import time
import numpy as np
import torch
import torch.optim as optim
from cvxopt import matrix
from cvxopt import solvers
from cvxopt import sparse
from cvxopt import spmatrix
from torch.autograd import grad as torch_grad
from tqdm import tqdm
class WassersteinGanQuadraticCost:
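    """
    Wasserstein GAN trained with a quadratic transport cost (WGAN-QC,
    Liu et al., "Wasserstein GAN With Quadratic Transport Cost", ICCV 2019).

    Instead of enforcing a Lipschitz constraint, the critic is fitted by
    regression onto the dual potentials of a discrete optimal-transport
    problem that is solved exactly for every batch with a linear program.
    """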
    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
                 data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=(150000, 250000), lr_anneal=1.0):
self.G = generator
self.G_opt = gen_optimizer
self.D = discriminator
self.D_opt = dis_optimizer
self.losses = {
'D' : [],
'WD': [],
'G' : []
}
self.num_steps = 0
self.gen_steps = 0
self.epochs = epochs
self.n_max_iterations = n_max_iterations
        # flattened dimensionality of one data sample (channels * height * width)
self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2]
self.batch_size = batch_size
self.device = device
self.criterion = criterion
        self.mone = torch.tensor([-1.0], device=device)  # used to flip gradients when maximizing the critic output
self.tensorboard_counter = 0
if K <= 0:
self.K = 1 / self.data_dim
else:
self.K = K
        self.Kr = np.sqrt(self.K)
        self.LAMBDA = 2 * self.Kr * gamma * 2  # weight of the OT gradient regularizer
self.G = self.G.to(self.device)
self.D = self.D.to(self.device)
self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size)
    def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1):
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=last_epoch)
        return scheduler
def _quadratic_wasserstein_distance_(self, real, generated):
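        """Pairwise quadratic transport cost K * 0.5 * ||x_i - y_j||^2 between every real and every generated sample, returned as a (num_real, num_fake) matrix."""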
num_r = real.size(0)
num_f = generated.size(0)
real_flat = real.view(num_r, -1)
fake_flat = generated.view(num_f, -1)
real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim)
fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim)
        # pairwise half squared L2 distances between all real and fake samples
dif = real3D - fake3D
dist = 0.5 * dif.pow(2).sum(2).squeeze()
return self.K * dist
def _prepare_linear_programming_solver_(self, batch_size):
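        """
        Build the batch-size-dependent but data-independent parts of the OT
        linear program in cvxopt form. With x = [u; v] (dual potentials of
        the real and the fake samples), the LP reads

            minimize    -mean(u) + mean(v)
            subject to  u_i - v_j <= d_ij  for all pairs (i, j),

        which is the Kantorovich dual of the discrete OT problem; the
        right-hand side d is filled in with the cost matrix of each batch.
        """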
        # first block: row (i * batch_size + j) selects the real potential u_i
        A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size))
        for i in range(1, batch_size):
            Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size))
            A = sparse([A, Ai])
        # second block: row (i * batch_size + j) subtracts the fake potential v_j
        D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size))
        DM = D
        for i in range(1, batch_size):
            DM = sparse([DM, D])
        A = sparse([[A], [DM]])  # horizontal concatenation: each row encodes u_i - v_j
        # objective c'x = -mean(u) + mean(v), i.e. maximize mean(u) - mean(v)
        cr = matrix([-1.0 / batch_size] * batch_size)
        cf = matrix([1.0 / batch_size] * batch_size)
        c = matrix([cr, cf])
        # feasible primal starting point used to warm-start the solver
        pStart = {}
pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)])
pStart['s'] = matrix([1.0] * (2 * batch_size))
return c, A, pStart
def _linear_programming_(self, distance, batch_size):
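        """Solve the OT dual LP for one batch with GLPK, warm-started from the previous solution, and shift the potentials so that they sum to zero (they are only determined up to a common constant)."""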
        b = matrix(distance.detach().cpu().double().numpy().flatten())  # RHS: per-batch cost matrix, flattened row-major
sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk',
options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}})
offset = 0.5 * (sum(sol['x'])) / batch_size
sol['x'] = sol['x'] - offset
self.pStart['x'] = sol['x']
self.pStart['s'] = sol['s']
return sol
def _approx_OT_(self, sol):
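        """
        Recover an approximate transport map from the LP: the dual variables
        z of the inequality constraints form the transport plan, and each
        fake sample j is matched to the real sample i with the largest plan
        entry in its column.
        """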
        # for each fake sample, take the real sample with the largest transport-plan entry
ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size))
mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device)
return mapping
def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff):
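        """
        Gradient regularizer of WGAN-QC: at every fake sample, the norm of
        the critic gradient is pushed towards K * ||real - fake|| for its
        OT-matched real sample, the first-order optimality condition of the
        quadratic-cost dual.
        """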
        output_fake_grad = torch.ones_like(output_fake)  # seed for computing d D(fake) / d fake
gradients = torch_grad(outputs=output_fake, inputs=fake,
grad_outputs=output_fake_grad,
create_graph=True, retain_graph=True, only_inputs=True)[0]
n = gradients.size(0)
        grad_norm = gradients.view(n, -1).norm(dim=1)
        diff_norm = real_fake_diff.view(n, -1).norm(dim=1)
        RegLoss = 0.5 * ((grad_norm / (2 * self.Kr) - self.Kr / 2 * diff_norm).pow(2)).mean()
        fake.requires_grad = False  # stop tracking gradients on the fake batch again
return RegLoss
def _critic_deep_regression_(self, images, opt_iterations=1):
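        """
        One critic update ("deep regression"): sample a fake batch, compute
        the quadratic cost matrix against the real batch, solve the OT
        problem exactly via linear programming, and regress the critic
        outputs onto the resulting dual potentials, adding the OT gradient
        regularizer weighted by LAMBDA.
        """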
images = images.to(self.device)
        for p in self.D.parameters():  # reset requires_grad
            p.requires_grad = True  # they are frozen again in _generator_train_iteration
self.G.train()
self.D.train()
        # get a batch of generated (fake) data
generated_data = self.sample_generator(self.batch_size)
        # compute the pairwise quadratic transport cost matrix
distance = self._quadratic_wasserstein_distance_(images, generated_data)
# solve linear programming problem
sol = self._linear_programming_(distance, self.batch_size)
# approximate optimal transport
mapping = self._approx_OT_(sol)
real_ordered = images[mapping] # match real and fake
real_fake_diff = real_ordered - generated_data
        # construct regression targets from the dual potentials
target = torch.from_numpy(np.array(sol['x'])).float()
target = target.squeeze().to(self.device)
for i in range(opt_iterations):
            self.D.zero_grad()  # redundant with D_opt.zero_grad() below, but harmless
self.D_opt.zero_grad()
generated_data.requires_grad_()
if generated_data.grad is not None:
generated_data.grad.data.zero_()
output_real = self.D(images)
output_fake = self.D(generated_data)
output_real, output_fake = output_real.squeeze(), output_fake.squeeze()
output_R_mean = output_real.mean(0).view(1)
output_F_mean = output_fake.mean(0).view(1)
L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean())
L2LossD_fake = self.criterion(output_fake, target[self.batch_size:])
L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake
reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff)
total_loss = L2LossD + self.LAMBDA * reg_loss_D
self.losses['D'].append(float(total_loss.data))
total_loss.backward()
self.D_opt.step()
        # critic-based estimate of the Wasserstein distance: E[D(real)] - E[D(fake)]
wasserstein_distance = output_R_mean - output_F_mean
self.losses['WD'].append(float(wasserstein_distance.data))
def _generator_train_iteration(self, batch_size):
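        """One generator update: freeze the critic and step in the direction that maximizes the mean critic output on generated samples."""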
for p in self.D.parameters():
p.requires_grad = False # freeze critic
self.G.zero_grad()
self.G_opt.zero_grad()
if isinstance(self.G, torch.nn.parallel.DataParallel):
z = self.G.module.sample_latent(batch_size, self.G.module.z_dim)
else:
z = self.G.sample_latent(batch_size, self.G.z_dim)
z.requires_grad = True
fake = self.G(z)
output_fake = self.D(fake)
output_F_mean_after = output_fake.mean(0).view(1)
self.losses['G'].append(float(output_F_mean_after.data))
output_F_mean_after.backward(self.mone)
        self.G_opt.step()
        self.gen_steps += 1  # counted against n_max_iterations in the training loop
        self.schedulerD.step()
        self.schedulerG.step()
def _train_epoch(self, data_loader, writer, experiment):
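        """Run one pass over the data loader, alternating one critic update and one generator update per batch."""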
for i, data in enumerate(tqdm(data_loader)):
            images = data[0]
            speaker_ids = data[1]  # provided by the data loader but unused in this loop
            self.num_steps += 1
if self.gen_steps >= self.n_max_iterations:
return
self._critic_deep_regression_(images)
self._generator_train_iteration(images.size(0))
            # running averages over the entire loss history, available for logging
            D_loss_avg = np.average(self.losses['D'])
            G_loss_avg = np.average(self.losses['G'])
            wd_avg = np.average(self.losses['WD'])
def train(self, data_loader, writer, experiment=None):
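        """Full training loop; stops early once the generator has taken n_max_iterations update steps."""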
self.G.train()
self.D.train()
for epoch in range(self.epochs):
if self.gen_steps >= self.n_max_iterations:
return
            time_start_epoch = time.time()
            self._train_epoch(data_loader, writer, experiment)
            # epoch loss average and duration are computed but currently not reported
            D_loss_avg = np.average(self.losses['D'])
            time_end_epoch = time.time()
return self
def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
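        """Generate num_samples outputs from freshly drawn latent vectors, optionally without gradient tracking or with intermediate activations."""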
self.G.eval()
if isinstance(self.G, torch.nn.parallel.DataParallel):
latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0)
else:
latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0)
latent_samples = latent_samples.to(self.device)
        if nograd:
            with torch.no_grad():
                generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
        else:
            # pass return_intermediate here too, so the flag also works without nograd
            generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
self.G.train()
if return_intermediate:
return generated_data[0].detach(), generated_data[1], latent_samples
return generated_data.detach()
def sample(self, num_samples):
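        """Generate samples and return them as a numpy array with the channel dimension removed."""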
generated_data = self.sample_generator(num_samples)
        # select the first channel, dropping the channel dimension
return generated_data.data.cpu().numpy()[:, 0, :, :]
def save_model_checkpoint(self, model_path, model_parameters, timestampStr):
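        """Save generator, critic, both optimizer states and the model parameters to a timestamped checkpoint file."""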
name = '%s_%s' % (timestampStr, 'wgan')
model_filename = os.path.join(model_path, name)
torch.save({
'generator_state_dict' : self.G.state_dict(),
'critic_state_dict' : self.D.state_dict(),
'gen_optimizer_state_dict' : self.G_opt.state_dict(),
'critic_optimizer_state_dict': self.D_opt.state_dict(),
'model_parameters' : model_parameters,
'iterations' : self.num_steps
}, model_filename)