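# Wasserstein GAN trainer with a quadratic transport cost: on every mini-batch the
# critic is fit by regression onto optimal-transport potentials obtained from a
# linear program solved with cvxopt (GLPK backend).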
import os
import time

import numpy as np
import torch
import torch.optim as optim
from cvxopt import matrix, solvers, sparse, spmatrix
from torch.autograd import grad as torch_grad
from tqdm import tqdm


class WassersteinGanQuadraticCost:

    def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs,
                 n_max_iterations, data_dimensions, batch_size, device, gamma=0.1, K=-1,
                 milestones=[150000, 250000], lr_anneal=1.0):
        self.G = generator
        self.G_opt = gen_optimizer
        self.D = discriminator
        self.D_opt = dis_optimizer
        self.losses = {
            'D': [],   # critic regression loss
            'WD': [],  # estimated Wasserstein distance
            'G': []    # generator objective
        }
        self.num_steps = 0
        self.gen_steps = 0
        self.epochs = epochs
        self.n_max_iterations = n_max_iterations
        # flattened size of one dataset sample
        self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2]
        self.batch_size = batch_size
        self.device = device
        self.criterion = criterion
        self.mone = torch.tensor([-1.0], device=device)  # gradient seed for maximizing the critic output
        self.tensorboard_counter = 0

        # scale of the quadratic transport cost; defaults to 1 / data dimension
        if K <= 0:
            self.K = 1 / self.data_dim
        else:
            self.K = K
        self.Kr = np.sqrt(self.K)
        self.LAMBDA = 2 * self.Kr * gamma * 2

        self.G = self.G.to(self.device)
        self.D = self.D.to(self.device)

        self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
        self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)

        # constant parts of the optimal-transport linear program (they depend only on the batch size)
        self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size)

    def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1):
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=last_epoch)
        return scheduler

    def _quadratic_wasserstein_distance_(self, real, generated):
        num_r = real.size(0)
        num_f = generated.size(0)
        real_flat = real.view(num_r, -1)
        fake_flat = generated.view(num_f, -1)
        # broadcast to all (real, fake) pairs
        real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim)
        fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim)
        # pairwise halved squared L2 distances, scaled by K
        dif = real3D - fake3D
        dist = 0.5 * dif.pow(2).sum(2).squeeze()
        return self.K * dist

    def _prepare_linear_programming_solver_(self, batch_size):
        # Build the constant data of the dual optimal-transport LP over the potentials
        # (x_1..x_N for real samples, y_1..y_N for fake samples):
        #   minimize  -mean(x) + mean(y)   subject to   x_i - y_j <= cost(i, j).
        # A selects x_i and -y_j for every (i, j) pair; the cost vector b is supplied
        # per batch in _linear_programming_.
        A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size))
        for i in range(1, batch_size):
            Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size))
            A = sparse([A, Ai])

        D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size))
        DM = D
        for i in range(1, batch_size):
            DM = sparse([DM, D])

        A = sparse([[A], [DM]])

        cr = matrix([-1.0 / batch_size] * batch_size)
        cf = matrix([1.0 / batch_size] * batch_size)
        c = matrix([cr, cf])

        # warm start reused across batches
        pStart = {}
        pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)])
        pStart['s'] = matrix([1.0] * (2 * batch_size))

        return c, A, pStart

    def _linear_programming_(self, distance, batch_size):
        # pairwise transport costs for the current batch become the RHS of the constraints
        b = matrix(distance.cpu().double().detach().numpy().flatten())
        sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk',
                         options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}})
        # center the potentials so they sum to zero, and keep the solution as a warm start
        offset = 0.5 * (sum(sol['x'])) / batch_size
        sol['x'] = sol['x'] - offset
        self.pStart['x'] = sol['x']
        self.pStart['s'] = sol['s']
        return sol

    def _approx_OT_(self, sol):
        # Approximate the optimal-transport map: each fake sample is matched to the real
        # sample whose constraint carries the largest dual variable (transport mass).
        ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size))
        mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device)
        return mapping

    def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff):
        output_fake_grad = torch.ones(output_fake.size()).to(self.device)
        gradients = torch_grad(outputs=output_fake, inputs=fake,
                               grad_outputs=output_fake_grad,
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
        n = gradients.size(0)
        # penalize the mismatch between the critic's gradient norm and the (scaled)
        # distance to the matched real sample
        RegLoss = 0.5 * ((gradients.view(n, -1).norm(dim=1) / (2 * self.Kr)
                          - self.Kr / 2 * real_fake_diff.view(n, -1).norm(dim=1)).pow(2)).mean()
        fake.requires_grad = False
        return RegLoss

    def _critic_deep_regression_(self, images, opt_iterations=1):
        images = images.to(self.device)

        # re-enable critic gradients (they are frozen during the generator update)
        for p in self.D.parameters():
            p.requires_grad = True

        self.G.train()
        self.D.train()

        # sample a batch of fake data
        generated_data = self.sample_generator(self.batch_size)

        # pairwise quadratic transport costs between real and fake samples
        distance = self._quadratic_wasserstein_distance_(images, generated_data)
        # solve the optimal-transport linear program on this batch
        sol = self._linear_programming_(distance, self.batch_size)
        # approximate optimal-transport matching between real and fake samples
        mapping = self._approx_OT_(sol)
        real_ordered = images[mapping]  # real sample matched to each fake sample
        real_fake_diff = real_ordered - generated_data

        # regression targets for the critic: the centered OT potentials
        target = torch.from_numpy(np.array(sol['x'])).float()
        target = target.squeeze().to(self.device)

        for i in range(opt_iterations):
            self.D.zero_grad()
            self.D_opt.zero_grad()
            generated_data.requires_grad_()
            if generated_data.grad is not None:
                generated_data.grad.data.zero_()

            output_real = self.D(images)
            output_fake = self.D(generated_data)
            output_real, output_fake = output_real.squeeze(), output_fake.squeeze()
            output_R_mean = output_real.mean(0).view(1)
            output_F_mean = output_fake.mean(0).view(1)

            # regress the critic onto the OT potentials (mean over real, per-sample over fake)
            L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean())
            L2LossD_fake = self.criterion(output_fake, target[self.batch_size:])
            L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake

            reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff)

            total_loss = L2LossD + self.LAMBDA * reg_loss_D
            self.losses['D'].append(total_loss.item())

            total_loss.backward()
            self.D_opt.step()

        # estimate of the Wasserstein distance: mean critic output on real minus fake
        wasserstein_distance = output_R_mean - output_F_mean
        self.losses['WD'].append(wasserstein_distance.item())

    def _generator_train_iteration(self, batch_size):
        # freeze the critic while updating the generator
        for p in self.D.parameters():
            p.requires_grad = False

        self.G.zero_grad()
        self.G_opt.zero_grad()

        if isinstance(self.G, torch.nn.parallel.DataParallel):
            z = self.G.module.sample_latent(batch_size, self.G.module.z_dim)
        else:
            z = self.G.sample_latent(batch_size, self.G.z_dim)
        z = z.to(self.device)
        z.requires_grad = True

        fake = self.G(z)
        output_fake = self.D(fake)
        output_F_mean_after = output_fake.mean(0).view(1)

        self.losses['G'].append(output_F_mean_after.item())

        # maximize the critic's score on fake samples (hence backward with -1)
        output_F_mean_after.backward(self.mone)
        self.G_opt.step()
        self.gen_steps += 1

        self.schedulerD.step()
        self.schedulerG.step()

    def _train_epoch(self, data_loader, writer, experiment):
        for i, data in enumerate(tqdm(data_loader)):
            images = data[0]
            speaker_ids = data[1]
            self.num_steps += 1
            # self.tensorboard_counter += 1
            if self.gen_steps >= self.n_max_iterations:
                return
            self._critic_deep_regression_(images)
            self._generator_train_iteration(images.size(0))

        D_loss_avg = np.average(self.losses['D'])
        G_loss_avg = np.average(self.losses['G'])
        wd_avg = np.average(self.losses['WD'])

    def train(self, data_loader, writer, experiment=None):
        self.G.train()
        self.D.train()

        for epoch in range(self.epochs):
            if self.gen_steps >= self.n_max_iterations:
                return self
            time_start_epoch = time.time()
            self._train_epoch(data_loader, writer, experiment)

            D_loss_avg = np.average(self.losses['D'])
            time_end_epoch = time.time()

        return self

    def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
        self.G.eval()
        if isinstance(self.G, torch.nn.parallel.DataParallel):
            latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim, 1.0)
        else:
            latent_samples = self.G.sample_latent(num_samples, self.G.z_dim, 1.0)
        latent_samples = latent_samples.to(self.device)

        if nograd:
            with torch.no_grad():
                generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
        else:
            generated_data = self.G(latent_samples, return_intermediate=return_intermediate)

        self.G.train()

        if return_intermediate:
            return generated_data[0].detach(), generated_data[1], latent_samples
        return generated_data.detach()

    def sample(self, num_samples):
        generated_data = self.sample_generator(num_samples)
        # remove the color channel before returning numpy samples
        return generated_data.data.cpu().numpy()[:, 0, :, :]

    def save_model_checkpoint(self, model_path, model_parameters, timestampStr):
        # timestampStr is expected in the format "%d-%m-%Y-%H-%M-%S"
        name = '%s_%s' % (timestampStr, 'wgan')
        model_filename = os.path.join(model_path, name)
        torch.save({
            'generator_state_dict': self.G.state_dict(),
            'critic_state_dict': self.D.state_dict(),
            'gen_optimizer_state_dict': self.G_opt.state_dict(),
            'critic_optimizer_state_dict': self.D_opt.state_dict(),
            'model_parameters': model_parameters,
            'iterations': self.num_steps
        }, model_filename)
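

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal, hypothetical example of wiring the trainer together. The toy
# generator/critic below, their `z_dim`/`sample_latent` interface, and the
# random dataset are stand-ins invented for this sketch, not part of the
# original project; only the WassersteinGanQuadraticCost API above is assumed.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    data_dims = (1, 8, 8)  # (channels, height, width)
    flat_dim = data_dims[0] * data_dims[1] * data_dims[2]

    class ToyGenerator(nn.Module):
        def __init__(self, z_dim=16):
            super().__init__()
            self.z_dim = z_dim
            self.net = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, flat_dim))

        def sample_latent(self, num_samples, z_dim, std=1.0):
            return torch.randn(num_samples, z_dim) * std

        def forward(self, z, return_intermediate=False):
            out = self.net(z).view(-1, *data_dims)
            if return_intermediate:
                return out, z
            return out

    class ToyCritic(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(flat_dim, 64), nn.ReLU(), nn.Linear(64, 1))

        def forward(self, x):
            return self.net(x.view(x.size(0), -1))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 16
    G, D = ToyGenerator(), ToyCritic()
    G_opt = optim.Adam(G.parameters(), lr=1e-4)
    D_opt = optim.Adam(D.parameters(), lr=1e-4)

    # random images plus dummy speaker ids, mirroring the (images, speaker_ids) batches above;
    # drop_last keeps every batch at batch_size, as required by the precomputed LP matrices
    dataset = TensorDataset(torch.randn(128, *data_dims), torch.zeros(128, dtype=torch.long))
    loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

    trainer = WassersteinGanQuadraticCost(
        G, D, G_opt, D_opt, criterion=nn.MSELoss(), epochs=1, n_max_iterations=10,
        data_dimensions=data_dims, batch_size=batch_size, device=device)
    trainer.train(loader, writer=None)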