# CycleGAN model definitions: generator/discriminator building blocks,
# an image history buffer, and the Lightning training module.
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L
import numpy as np
class Downsampling(nn.Module):
    """Conv2d -> optional InstanceNorm2d -> optional activation.

    ``lrelu`` selects the activation: True -> LeakyReLU(0.2),
    False -> ReLU, None -> no activation at all.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, norm=True, lrelu=True):
        super().__init__()
        # The conv bias is redundant when a normalization layer follows.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=not norm),
        ]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))
        if lrelu is not None:
            layers.append(nn.LeakyReLU(0.2, True) if lrelu else nn.ReLU(True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class Upsampling(nn.Module):
    """ConvTranspose2d -> InstanceNorm2d -> optional Dropout(0.5) -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, output_padding=0, dropout=False):
        super().__init__()
        layers = [
            # Bias omitted: the affine InstanceNorm that follows supplies one.
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding,
                               output_padding=output_padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU(True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class ResBlock(nn.Module):
    """Residual block: two reflection-padded convolutions plus an identity skip."""

    def __init__(self, in_channels, kernel_size=3, padding=1):
        super().__init__()
        conv_args = dict(kernel_size=kernel_size, stride=1, padding=0)
        self.block = nn.Sequential(
            nn.ReflectionPad2d(padding),
            # First conv keeps a ReLU (lrelu=False) ...
            Downsampling(in_channels, in_channels, lrelu=False, **conv_args),
            nn.ReflectionPad2d(padding),
            # ... the second is linear (lrelu=None) so the residual sum is unbounded.
            Downsampling(in_channels, in_channels, lrelu=None, **conv_args),
        )

    def forward(self, x):
        # Identity shortcut around the convolutional path.
        return x + self.block(x)
class UNetGenerator(nn.Module):
    """U-Net generator (pix2pix style) with 8 encoder and 8 decoder stages.

    Eight stride-2 downsamplings shrink a 256x256 input to a 1x1 bottleneck;
    each decoder stage is concatenated with the matching encoder output.
    """

    def __init__(self, hid_channels, in_channels, out_channels):
        super().__init__()
        ch = hid_channels
        # Encoder: the outermost and innermost stages skip normalization.
        self.downsampling_path = nn.Sequential(
            Downsampling(in_channels, ch, norm=False),
            Downsampling(ch, ch * 2),
            Downsampling(ch * 2, ch * 4),
            Downsampling(ch * 4, ch * 8),
            Downsampling(ch * 8, ch * 8),
            Downsampling(ch * 8, ch * 8),
            Downsampling(ch * 8, ch * 8),
            Downsampling(ch * 8, ch * 8, norm=False),
        )
        # Decoder: input channels double after the first stage because the
        # skip connection is concatenated onto the previous stage's output.
        self.upsampling_path = nn.Sequential(
            Upsampling(ch * 8, ch * 8, dropout=True),
            Upsampling(ch * 16, ch * 8, dropout=True),
            Upsampling(ch * 16, ch * 8, dropout=True),
            Upsampling(ch * 16, ch * 8),
            Upsampling(ch * 16, ch * 4),
            Upsampling(ch * 8, ch * 2),
            Upsampling(ch * 4, ch),
        )
        # Final upsampling to the output resolution, squashed into [-1, 1].
        self.feature_block = nn.Sequential(
            nn.ConvTranspose2d(ch * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder pass, remembering every stage's output for the skips.
        features = []
        for encoder in self.downsampling_path:
            x = encoder(x)
            features.append(x)
        # Pair decoder stages with encoder outputs deepest-but-one first;
        # the bottleneck output itself is the decoder input, not a skip.
        for decoder, skip in zip(self.upsampling_path, features[-2::-1]):
            x = torch.cat([decoder(x), skip], dim=1)
        return self.feature_block(x)
class ResNetGenerator(nn.Module):
    """ResNet generator: 7x7 stem, two downsamplings, residual blocks,
    two upsamplings, and a 7x7 head with Tanh output in [-1, 1]."""

    def __init__(self, hid_channels, in_channels, out_channels, num_resblocks):
        super().__init__()
        ch = hid_channels
        layers = [
            nn.ReflectionPad2d(3),
            Downsampling(in_channels, ch, kernel_size=7, stride=1, padding=0, lrelu=False),
            Downsampling(ch, ch * 2, kernel_size=3, lrelu=False),
            Downsampling(ch * 2, ch * 4, kernel_size=3, lrelu=False),
        ]
        # Residual core at the lowest resolution.
        layers.extend(ResBlock(ch * 4) for _ in range(num_resblocks))
        layers.extend([
            Upsampling(ch * 4, ch * 2, kernel_size=3, output_padding=1),
            Upsampling(ch * 2, ch, kernel_size=3, output_padding=1),
            nn.ReflectionPad2d(3),
            nn.Conv2d(ch, out_channels, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
def get_gen(gen_name, hid_channels, num_resblocks, in_channels=3, out_channels=3):
    """Build the generator architecture selected by ``gen_name``.

    Args:
        gen_name: either "unet" or "resnet".
        hid_channels: base channel width of the generator.
        num_resblocks: number of residual blocks (resnet variant only).
        in_channels / out_channels: image channel counts.

    Raises:
        NotImplementedError: if ``gen_name`` is not a known architecture.
    """
    if gen_name == "unet":
        return UNetGenerator(hid_channels, in_channels, out_channels)
    if gen_name == "resnet":
        return ResNetGenerator(hid_channels, in_channels, out_channels, num_resblocks)
    raise NotImplementedError(f"Generator name '{gen_name}' not recognized.")
class Discriminator(nn.Module):
    """PatchGAN discriminator: outputs a 1-channel map of per-patch logits."""

    def __init__(self, hid_channels, in_channels=3):
        super().__init__()
        ch = hid_channels
        self.block = nn.Sequential(
            # First stage has no normalization, as in the original PatchGAN.
            Downsampling(in_channels, ch, norm=False),
            Downsampling(ch, ch * 2),
            Downsampling(ch * 2, ch * 4),
            # Final stages keep the spatial size nearly constant (stride 1).
            Downsampling(ch * 4, ch * 8, stride=1),
            nn.Conv2d(ch * 8, 1, kernel_size=4, padding=1),
        )

    def forward(self, x):
        return self.block(x)
class ImageBuffer(object):
    """History buffer of generated images for discriminator training.

    While filling, every query image is stored and returned unchanged. Once
    full, each query image is, with probability 0.5, swapped with a randomly
    chosen stored image (the stored copy is returned, the new image replaces
    it); otherwise the new image is returned as-is. A ``buffer_size`` of 0
    disables the buffer entirely.
    """

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        if self.buffer_size > 0:
            self.curr_cap = 0   # number of images stored so far
            self.buffer = []    # stored single-image tensors, shape (1, C, H, W)

    def __call__(self, imgs):
        if self.buffer_size == 0:
            # Buffer disabled: pass the batch straight through.
            return imgs
        selected = []
        for img in imgs:
            img = img.unsqueeze(dim=0)
            if self.curr_cap < self.buffer_size:
                # Still filling: keep the image and also return it.
                self.curr_cap += 1
                self.buffer.append(img)
                selected.append(img)
            elif np.random.uniform(low=0., high=1.) > 0.5:
                # Swap: hand back a stored image, stash the new one in its place.
                idx = np.random.randint(low=0, high=self.buffer_size)
                stored = self.buffer[idx].clone()
                self.buffer[idx] = img
                selected.append(stored)
            else:
                selected.append(img)
        return torch.cat(selected, dim=0)
class CycleGAN(L.LightningModule):
    """CycleGAN translating between photos (P) and Monet paintings (M).

    Trains two generators (gen_PM: photo->Monet, gen_MP: Monet->photo) and
    two discriminators with manual optimization: one joint generator step
    followed by one joint discriminator step per batch.
    """

    def __init__(self, gen_name, num_resblocks, hid_channels, optimizer, lr, lambda_idt, lambda_cycle, buffer_size, num_epochs, decay_epochs, betas):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer = optimizer
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.gen_PM = get_gen(gen_name, hid_channels, num_resblocks)  # photo -> Monet
        self.gen_MP = get_gen(gen_name, hid_channels, num_resblocks)  # Monet -> photo
        self.disc_M = Discriminator(hid_channels)
        self.disc_P = Discriminator(hid_channels)
        # History buffers of past fakes to stabilize the discriminators.
        self.buffer_fake_M = ImageBuffer(buffer_size)
        self.buffer_fake_P = ImageBuffer(buffer_size)

    def forward(self, img):
        # Default inference direction: photo -> Monet.
        return self.gen_PM(img)

    def init_weights(self):
        """Initialize conv weights ~ N(0, 0.02) and norm gains ~ N(1, 0.02)."""
        def init_fn(m):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.InstanceNorm2d):
                # Fix: norm gains must center on 1.0, not 0.0 — a zero-mean
                # gain initially cancels the normalized activations. Matches
                # the reference CycleGAN/pix2pix initialization.
                if m.weight is not None:
                    nn.init.normal_(m.weight, 1.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
        for net in [self.gen_PM, self.gen_MP, self.disc_M, self.disc_P]:
            net.apply(init_fn)

    def setup(self, stage):
        if stage == "fit":
            print("Model initialized.")

    def get_lr_scheduler(self, optimizer):
        """Constant LR through epoch ``decay_epochs``, then linear decay to 0."""
        def lr_lambda(epoch):
            len_decay_phase = self.hparams.num_epochs - self.hparams.decay_epochs + 1.0
            curr_decay_step = max(0, epoch - self.hparams.decay_epochs + 1.0)
            val = 1.0 - curr_decay_step / len_decay_phase
            return max(0.0, val)
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def configure_optimizers(self):
        """One optimizer for both generators, one for both discriminators."""
        opt_config = {
            "lr": self.hparams.lr,
            "betas": self.hparams.betas,
        }
        opt_gen = self.optimizer(
            list(self.gen_PM.parameters()) + list(self.gen_MP.parameters()),
            **opt_config,
        )
        opt_disc = self.optimizer(
            list(self.disc_M.parameters()) + list(self.disc_P.parameters()),
            **opt_config,
        )
        optimizers = [opt_gen, opt_disc]
        schedulers = [self.get_lr_scheduler(opt) for opt in optimizers]
        return optimizers, schedulers

    def adv_criterion(self, y_hat, y):
        # Least-squares GAN loss.
        return F.mse_loss(y_hat, y)

    def recon_criterion(self, y_hat, y):
        # L1 for identity and cycle-consistency terms.
        return F.l1_loss(y_hat, y)

    def get_adv_loss(self, fake, disc):
        """Generator-side adversarial loss: fakes should score as real."""
        fake_hat = disc(fake)
        real_labels = torch.ones_like(fake_hat)
        return self.adv_criterion(fake_hat, real_labels)

    def get_idt_loss(self, real, idt, lambda_cycle):
        """Identity-mapping loss, weighted by lambda_idt * lambda_cycle.

        Fix: the original computed the loss but returned 0, silently
        disabling the identity term and leaving ``lambda_idt`` unused.
        """
        idt_loss = self.recon_criterion(idt, real)
        return self.hparams.lambda_idt * lambda_cycle * idt_loss

    def get_cycle_loss(self, real, recon, lambda_cycle):
        """Cycle-consistency loss: real -> other domain -> back should match."""
        cycle_loss = self.recon_criterion(recon, real)
        return lambda_cycle * cycle_loss

    def get_gen_loss(self):
        """Total generator objective: adversarial + identity + cycle terms."""
        adv_loss_PM = self.get_adv_loss(self.fake_M, self.disc_M)
        adv_loss_MP = self.get_adv_loss(self.fake_P, self.disc_P)
        total_adv_loss = adv_loss_PM + adv_loss_MP
        # lambda_cycle[0] weights the Monet domain, [1] the photo domain.
        lambda_cycle = self.hparams.lambda_cycle
        idt_loss_MM = self.get_idt_loss(self.real_M, self.idt_M, lambda_cycle[0])
        idt_loss_PP = self.get_idt_loss(self.real_P, self.idt_P, lambda_cycle[1])
        total_idt_loss = idt_loss_MM + idt_loss_PP
        cycle_loss_MPM = self.get_cycle_loss(self.real_M, self.recon_M, lambda_cycle[0])
        cycle_loss_PMP = self.get_cycle_loss(self.real_P, self.recon_P, lambda_cycle[1])
        total_cycle_loss = cycle_loss_MPM + cycle_loss_PMP
        return total_adv_loss + total_idt_loss + total_cycle_loss

    def get_disc_loss(self, real, fake, disc):
        """LSGAN discriminator loss; fakes are detached so only disc updates."""
        real_hat = disc(real)
        real_labels = torch.ones_like(real_hat)
        real_loss = self.adv_criterion(real_hat, real_labels)
        fake_hat = disc(fake.detach())
        fake_labels = torch.zeros_like(fake_hat)
        fake_loss = self.adv_criterion(fake_hat, fake_labels)
        # Halved so the discriminator learns at half the generator's rate.
        return (fake_loss + real_loss) * 0.5

    def get_disc_loss_M(self):
        # Draw fakes through the history buffer, not just the current batch.
        fake_M = self.buffer_fake_M(self.fake_M)
        return self.get_disc_loss(self.real_M, fake_M, self.disc_M)

    def get_disc_loss_P(self):
        fake_P = self.buffer_fake_P(self.fake_P)
        return self.get_disc_loss(self.real_P, fake_P, self.disc_P)

    def training_step(self, batch, batch_idx):
        self.real_M = batch["monet"]
        self.real_P = batch["photo"]
        opt_gen, opt_disc = self.optimizers()
        # Forward passes shared by generator and discriminator losses.
        self.fake_M = self.gen_PM(self.real_P)
        self.fake_P = self.gen_MP(self.real_M)
        self.idt_M = self.gen_PM(self.real_M)
        self.idt_P = self.gen_MP(self.real_P)
        self.recon_M = self.gen_PM(self.fake_P)
        self.recon_P = self.gen_MP(self.fake_M)
        # Generator update.
        self.toggle_optimizer(opt_gen)
        gen_loss = self.get_gen_loss()
        opt_gen.zero_grad()
        self.manual_backward(gen_loss)
        opt_gen.step()
        self.untoggle_optimizer(opt_gen)
        # Discriminator update.
        self.toggle_optimizer(opt_disc)
        disc_loss_M = self.get_disc_loss_M()
        disc_loss_P = self.get_disc_loss_P()
        opt_disc.zero_grad()
        self.manual_backward(disc_loss_M)
        self.manual_backward(disc_loss_P)
        opt_disc.step()
        self.untoggle_optimizer(opt_disc)
        metrics = {
            "gen_loss": gen_loss,
            "disc_loss_M": disc_loss_M,
            "disc_loss_P": disc_loss_P,
        }
        # Fix: removed a bare ``wandb.log(metrics)`` call — ``wandb`` is never
        # imported in this file, so it raised NameError on the first step.
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.display_results(batch, batch_idx, "validate")

    def test_step(self, batch, batch_idx):
        self.display_results(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        return self(batch)

    def display_results(self, batch, batch_idx, stage):
        """Render photos next to their Monet translations in one image grid."""
        real_P = batch
        fake_M = self(real_P)
        if stage == "validate":
            title = f"Epoch {self.current_epoch+1}: Photo-to-Monet Translation"
        else:
            title = f"Sample {batch_idx+1}: Photo-to-Monet Translation"
        # NOTE(review): ``show_img`` is not defined in this file — presumably a
        # notebook-level plotting helper; confirm it is in scope before running.
        show_img(
            torch.cat([real_P, fake_M], dim=0),
            nrow=len(real_P),
            title=title,
        )

    def on_train_epoch_start(self):
        curr_lr = self.lr_schedulers()[0].get_last_lr()[0]
        self.log("lr", curr_lr, on_step=False, on_epoch=True, prog_bar=True)

    def on_train_epoch_end(self):
        # Manual optimization: LR schedulers must be stepped by hand.
        for sch in self.lr_schedulers():
            sch.step()
        logged_values = self.trainer.progress_bar_metrics
        print(
            f"Epoch {self.current_epoch+1}",
            *[f"{k}: {v:.5f}" for k, v in logged_values.items()],
            sep=" - ",
        )

    def on_train_end(self):
        print("Training ended.")

    def on_predict_epoch_end(self):
        predictions = self.trainer.predict_loop.predictions
        num_batches = len(predictions)
        batch_size = predictions[0].shape[0]
        # The final batch may be smaller than the rest; subtract the shortfall.
        last_batch_diff = batch_size - predictions[-1].shape[0]
        print(f"Number of images generated: {num_batches*batch_size-last_batch_diff}.")