import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as L
import numpy as np


class Downsampling(nn.Module):
    """Conv2d -> optional InstanceNorm2d -> optional (Leaky)ReLU.

    With the default ``kernel_size=4, stride=2, padding=1`` the spatial size
    is halved (for even input sizes). Stride-1 variants are also used below
    (ResBlock, Discriminator) as plain conv-norm-activation units.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``.
        norm: Append an affine ``InstanceNorm2d``. The conv bias is disabled
            when ``norm`` is set, since the per-channel mean subtraction of
            the norm cancels it.
        lrelu: ``True`` -> LeakyReLU(0.2), ``False`` -> ReLU,
            ``None`` -> no activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2,
                 padding=1, norm=True, lrelu=True):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=not norm),
        )
        if norm:
            self.block.append(nn.InstanceNorm2d(out_channels, affine=True))
        # Tri-state flag: None means "no activation" (used at the end of ResBlock).
        if lrelu is not None:
            self.block.append(nn.LeakyReLU(0.2, True) if lrelu else nn.ReLU(True))

    def forward(self, x):
        return self.block(x)


class Upsampling(nn.Module):
    """ConvTranspose2d -> InstanceNorm2d -> optional Dropout(0.5) -> ReLU.

    With the default ``kernel_size=4, stride=2, padding=1`` the spatial size
    is doubled.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding, output_padding: Forwarded to
            ``nn.ConvTranspose2d``.
        dropout: Insert ``nn.Dropout(0.5)`` before the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2,
                 padding=1, output_padding=0, dropout=False):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding,
                               output_padding=output_padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
        )
        if dropout:
            self.block.append(nn.Dropout(0.5))
        self.block.append(nn.ReLU(True))

    def forward(self, x):
        return self.block(x)


class ResBlock(nn.Module):
    """Residual block: ``x + f(x)``, where ``f`` is
    reflect-pad -> conv-norm-ReLU -> reflect-pad -> conv-norm.

    Reflection padding of ``padding`` on each side, stride 1 and no conv
    padding keep channel count and spatial size unchanged for the default
    ``kernel_size=3, padding=1``.
    """

    def __init__(self, in_channels, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(padding),
            Downsampling(in_channels, in_channels, kernel_size=kernel_size,
                         stride=1, padding=0, lrelu=False),
            nn.ReflectionPad2d(padding),
            # lrelu=None: no activation before the residual sum.
            Downsampling(in_channels, in_channels, kernel_size=kernel_size,
                         stride=1, padding=0, lrelu=None),
        )

    def forward(self, x):
        return x + self.block(x)


class UNetGenerator(nn.Module):
    """U-Net generator.

    Eight stride-2 ``Downsampling`` stages, then seven ``Upsampling`` stages.
    Each upsampling output is concatenated (dim=1) with the mirrored encoder
    feature. A final ``ConvTranspose2d`` + ``Tanh`` maps to ``out_channels``.

    NOTE(review): the eight halvings appear to require input height/width
    that are multiples of 256 (a 256x256 input reaches a 1x1 bottleneck).
    Confirm this against the data pipeline. The innermost stage has
    ``norm=False``, presumably because InstanceNorm is degenerate on a 1x1
    map.
    """

    def __init__(self, hid_channels, in_channels, out_channels):
        super().__init__()
        self.downsampling_path = nn.Sequential(
            Downsampling(in_channels, hid_channels, norm=False),
            Downsampling(hid_channels, hid_channels * 2),
            Downsampling(hid_channels * 2, hid_channels * 4),
            Downsampling(hid_channels * 4, hid_channels * 8),
            Downsampling(hid_channels * 8, hid_channels * 8),
            Downsampling(hid_channels * 8, hid_channels * 8),
            Downsampling(hid_channels * 8, hid_channels * 8),
            Downsampling(hid_channels * 8, hid_channels * 8, norm=False),
        )
        # Input channel counts after the first stage are doubled because each
        # previous output was concatenated with a same-width skip connection.
        self.upsampling_path = nn.Sequential(
            Upsampling(hid_channels * 8, hid_channels * 8, dropout=True),
            Upsampling(hid_channels * 16, hid_channels * 8, dropout=True),
            Upsampling(hid_channels * 16, hid_channels * 8, dropout=True),
            Upsampling(hid_channels * 16, hid_channels * 8),
            Upsampling(hid_channels * 16, hid_channels * 4),
            Upsampling(hid_channels * 8, hid_channels * 2),
            Upsampling(hid_channels * 4, hid_channels),
        )
        self.feature_block = nn.Sequential(
            nn.ConvTranspose2d(hid_channels * 2, out_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        skips = []
        for down in self.downsampling_path:
            x = down(x)
            skips.append(x)
        # The bottleneck output is not a skip; pair the remaining encoder
        # features with the decoder stages deepest-first.
        skips = reversed(skips[:-1])
        for up, skip in zip(self.upsampling_path, skips):
            x = up(x)
            x = torch.cat([x, skip], dim=1)
        return self.feature_block(x)


class ResNetGenerator(nn.Module):
    """ResNet generator.

    Structure: 7x7 stem, two stride-2 downsamplings, ``num_resblocks``
    residual blocks at ``hid_channels*4`` channels, two stride-2
    transposed-conv upsamplings, and a 7x7 output conv + ``Tanh``. For
    height/width divisible by 4 the output has the input's spatial size.
    """

    def __init__(self, hid_channels, in_channels, out_channels, num_resblocks):
        super().__init__()
        self.model = nn.Sequential(
            nn.ReflectionPad2d(3),
            Downsampling(in_channels, hid_channels, kernel_size=7, stride=1,
                         padding=0, lrelu=False),
            Downsampling(hid_channels, hid_channels * 2, kernel_size=3, lrelu=False),
            Downsampling(hid_channels * 2, hid_channels * 4, kernel_size=3, lrelu=False),
            *[ResBlock(hid_channels * 4) for _ in range(num_resblocks)],
            Upsampling(hid_channels * 4, hid_channels * 2, kernel_size=3, output_padding=1),
            Upsampling(hid_channels * 2, hid_channels, kernel_size=3, output_padding=1),
            nn.ReflectionPad2d(3),
            nn.Conv2d(hid_channels, out_channels, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.model(x)


def get_gen(gen_name, hid_channels, num_resblocks, in_channels=3, out_channels=3):
    """Build a generator by name.

    Args:
        gen_name: ``"unet"`` or ``"resnet"``.
        hid_channels: Base channel width.
        num_resblocks: Residual block count (used only by ``"resnet"``).
        in_channels, out_channels: Image channel counts.

    Raises:
        NotImplementedError: If ``gen_name`` is not recognized.
    """
    if gen_name == "unet":
        return UNetGenerator(hid_channels, in_channels, out_channels)
    elif gen_name == "resnet":
        return ResNetGenerator(hid_channels, in_channels, out_channels, num_resblocks)
    else:
        raise NotImplementedError(f"Generator name '{gen_name}' not recognized.")


class Discriminator(nn.Module):
    """PatchGAN-style discriminator.

    Three stride-2 conv blocks and one stride-1 conv block are followed by a
    conv down to a single channel. The output is a 1-channel spatial map of
    scores rather than a single scalar per image.
    """

    def __init__(self, hid_channels, in_channels=3):
        super().__init__()
        self.block = nn.Sequential(
            Downsampling(in_channels, hid_channels, norm=False),
            Downsampling(hid_channels, hid_channels * 2),
            Downsampling(hid_channels * 2, hid_channels * 4),
            Downsampling(hid_channels * 4, hid_channels * 8, stride=1),
            nn.Conv2d(hid_channels * 8, 1, kernel_size=4, padding=1),
        )

    def forward(self, x):
        return self.block(x)


class ImageBuffer(object):
    """History of previously generated images fed to a discriminator.

    Until ``buffer_size`` images have been seen, every incoming image is
    stored and also returned as-is. After that, each incoming image is
    treated as follows:
      - with probability 1/2, it replaces a uniformly chosen stored image,
        and that older image is returned in its place;
      - otherwise it is returned unchanged.
    ``buffer_size == 0`` disables the buffer entirely.
    """

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        if self.buffer_size > 0:
            self.curr_cap = 0
            self.buffer = []

    def __call__(self, imgs):
        """Return a batch of the same length as ``imgs``, mixed with history.

        Args:
            imgs: Batched tensor; iterated along dim 0.

        Returns:
            ``imgs`` itself when the buffer is disabled. Otherwise a new
            tensor concatenated along dim 0, detached from autograd.
        """
        if self.buffer_size == 0:
            return imgs
        return_imgs = []
        # Detach so stored images never keep a generator autograd graph alive
        # across training steps. The discriminator loss detaches these
        # images anyway, so no gradients are lost.
        for img in imgs.detach():
            img = img.unsqueeze(dim=0)
            if self.curr_cap < self.buffer_size:
                self.curr_cap += 1
                self.buffer.append(img)
                return_imgs.append(img)
            else:
                p = np.random.uniform(low=0., high=1.)
                if p > 0.5:
                    idx = np.random.randint(low=0, high=self.buffer_size)
                    tmp = self.buffer[idx].clone()
                    self.buffer[idx] = img
                    return_imgs.append(tmp)
                else:
                    return_imgs.append(img)
        return torch.cat(return_imgs, dim=0)


class CycleGAN(L.LightningModule):
    """CycleGAN between a photo domain (P) and a Monet domain (M).

    Naming: ``gen_PM`` maps P -> M and ``gen_MP`` maps M -> P. ``disc_M`` and
    ``disc_P`` score images of domain M and P. Training batches are dicts
    with keys ``"monet"`` and ``"photo"``. Optimization is manual: one
    generator step and one discriminator step per batch.

    Args:
        gen_name: Generator architecture passed to ``get_gen``.
        num_resblocks: Residual blocks for the ``"resnet"`` generator.
        hid_channels: Base channel width for generators and discriminators.
        optimizer: Optimizer class, called as
            ``optimizer(params, lr=lr, betas=betas)``.
        lr: Initial learning rate.
        lambda_idt: Identity-loss weight, relative to the cycle weight of
            the same domain.
        lambda_cycle: Indexable pair; ``[0]`` weights the M losses and
            ``[1]`` the P losses.
        buffer_size: Capacity of each fake-image history buffer
            (0 disables it).
        num_epochs: Total epochs; the LR reaches 0 after this.
        decay_epochs: Epoch at which linear LR decay starts.
        betas: Optimizer betas.
    """

    def __init__(self, gen_name, num_resblocks, hid_channels, optimizer, lr,
                 lambda_idt, lambda_cycle, buffer_size, num_epochs,
                 decay_epochs, betas):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer = optimizer
        self.automatic_optimization = False

        self.gen_PM = get_gen(gen_name, hid_channels, num_resblocks)
        self.gen_MP = get_gen(gen_name, hid_channels, num_resblocks)
        self.disc_M = Discriminator(hid_channels)
        self.disc_P = Discriminator(hid_channels)

        self.buffer_fake_M = ImageBuffer(buffer_size)
        self.buffer_fake_P = ImageBuffer(buffer_size)

    def forward(self, img):
        """Translate photo-domain images to the Monet domain."""
        return self.gen_PM(img)

    def init_weights(self):
        """Re-initialize all four networks in place.

        Every Conv2d, ConvTranspose2d and InstanceNorm2d layer gets weights
        drawn from N(0, 0.02) and zero biases.
        """
        def init_fn(m):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.InstanceNorm2d)):
                # Non-affine norm layers have weight/bias set to None; skip
                # them instead of failing inside nn.init.
                if m.weight is not None:
                    nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        for net in [self.gen_PM, self.gen_MP, self.disc_M, self.disc_P]:
            net.apply(init_fn)

    def setup(self, stage):
        """Initialize weights at the start of fitting.

        NOTE(review): this runs on every ``trainer.fit``. Weights loaded into
        the module before calling ``fit`` (e.g. a manual warm start) would be
        overwritten here. Confirm this matches the intended workflow.
        """
        if stage == "fit":
            self.init_weights()
            print("Model initialized.")

    def get_lr_scheduler(self, optimizer):
        """Build a LambdaLR that is constant, then decays linearly to 0.

        The multiplier is 1.0 while ``epoch < decay_epochs``. After that it
        drops by ``1 / (num_epochs - decay_epochs + 1)`` per epoch and is
        clamped at 0.
        """
        def lr_lambda(epoch):
            len_decay_phase = self.hparams.num_epochs - self.hparams.decay_epochs + 1.0
            curr_decay_step = max(0, epoch - self.hparams.decay_epochs + 1.0)
            val = 1.0 - curr_decay_step / len_decay_phase
            return max(0.0, val)

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def configure_optimizers(self):
        """Return ``[opt_gen, opt_disc]`` and one LR scheduler for each.

        Both generators share ``opt_gen``; both discriminators share
        ``opt_disc``.
        """
        opt_config = {
            "lr": self.hparams.lr,
            "betas": self.hparams.betas,
        }
        opt_gen = self.optimizer(
            list(self.gen_PM.parameters()) + list(self.gen_MP.parameters()),
            **opt_config,
        )
        opt_disc = self.optimizer(
            list(self.disc_M.parameters()) + list(self.disc_P.parameters()),
            **opt_config,
        )
        optimizers = [opt_gen, opt_disc]
        schedulers = [self.get_lr_scheduler(opt) for opt in optimizers]
        return optimizers, schedulers

    def adv_criterion(self, y_hat, y):
        """Least-squares adversarial loss (MSE against 1/0 targets)."""
        return F.mse_loss(y_hat, y)

    def recon_criterion(self, y_hat, y):
        """L1 reconstruction loss used for the cycle and identity terms."""
        return F.l1_loss(y_hat, y)

    def get_adv_loss(self, fake, disc):
        """Generator adversarial loss: push ``disc(fake)`` toward 1."""
        fake_hat = disc(fake)
        real_labels = torch.ones_like(fake_hat)
        adv_loss = self.adv_criterion(fake_hat, real_labels)
        return adv_loss

    def get_idt_loss(self, real, idt, lambda_cycle):
        """Identity loss ``lambda_idt * lambda_cycle * L1(idt, real)``.

        ``idt`` is a generator applied to an image already in its target
        domain. Set the ``lambda_idt`` hyperparameter to 0 to disable this
        term.
        """
        idt_loss = self.recon_criterion(idt, real)
        return self.hparams.lambda_idt * lambda_cycle * idt_loss

    def get_cycle_loss(self, real, recon, lambda_cycle):
        """Cycle-consistency loss ``lambda_cycle * L1(recon, real)``."""
        cycle_loss = self.recon_criterion(recon, real)
        return lambda_cycle * cycle_loss

    def get_gen_loss(self):
        """Total generator loss: adversarial + identity + cycle, both directions.

        Reads the tensors cached on ``self`` by ``training_step``.
        """
        adv_loss_PM = self.get_adv_loss(self.fake_M, self.disc_M)
        adv_loss_MP = self.get_adv_loss(self.fake_P, self.disc_P)
        total_adv_loss = adv_loss_PM + adv_loss_MP

        lambda_cycle = self.hparams.lambda_cycle
        idt_loss_MM = self.get_idt_loss(self.real_M, self.idt_M, lambda_cycle[0])
        idt_loss_PP = self.get_idt_loss(self.real_P, self.idt_P, lambda_cycle[1])
        total_idt_loss = idt_loss_MM + idt_loss_PP

        cycle_loss_MPM = self.get_cycle_loss(self.real_M, self.recon_M, lambda_cycle[0])
        cycle_loss_PMP = self.get_cycle_loss(self.real_P, self.recon_P, lambda_cycle[1])
        total_cycle_loss = cycle_loss_MPM + cycle_loss_PMP

        gen_loss = total_adv_loss + total_idt_loss + total_cycle_loss
        return gen_loss

    def get_disc_loss(self, real, fake, disc):
        """Discriminator loss: mean of MSE(disc(real), 1) and MSE(disc(fake), 0).

        ``fake`` is detached so no gradient flows back into the generators.
        """
        real_hat = disc(real)
        real_labels = torch.ones_like(real_hat)
        real_loss = self.adv_criterion(real_hat, real_labels)

        fake_hat = disc(fake.detach())
        fake_labels = torch.zeros_like(fake_hat)
        fake_loss = self.adv_criterion(fake_hat, fake_labels)

        disc_loss = (fake_loss + real_loss) * 0.5
        return disc_loss

    def get_disc_loss_M(self):
        """Loss for ``disc_M``, with fakes passed through the M history buffer."""
        fake_M = self.buffer_fake_M(self.fake_M)
        return self.get_disc_loss(self.real_M, fake_M, self.disc_M)

    def get_disc_loss_P(self):
        """Loss for ``disc_P``, with fakes passed through the P history buffer."""
        fake_P = self.buffer_fake_P(self.fake_P)
        return self.get_disc_loss(self.real_P, fake_P, self.disc_P)

    def training_step(self, batch, batch_idx):
        """One manual-optimization step: update generators, then discriminators."""
        self.real_M = batch["monet"]
        self.real_P = batch["photo"]
        opt_gen, opt_disc = self.optimizers()

        # All generator outputs are computed once and cached on self for the
        # loss helpers.
        self.fake_M = self.gen_PM(self.real_P)
        self.fake_P = self.gen_MP(self.real_M)
        self.idt_M = self.gen_PM(self.real_M)
        self.idt_P = self.gen_MP(self.real_P)
        self.recon_M = self.gen_PM(self.fake_P)
        self.recon_P = self.gen_MP(self.fake_M)

        # Generator update. toggle_optimizer disables grads for parameters
        # not owned by opt_gen while the loss is built.
        self.toggle_optimizer(opt_gen)
        gen_loss = self.get_gen_loss()
        opt_gen.zero_grad()
        self.manual_backward(gen_loss)
        opt_gen.step()
        self.untoggle_optimizer(opt_gen)

        # Discriminator update.
        self.toggle_optimizer(opt_disc)
        disc_loss_M = self.get_disc_loss_M()
        disc_loss_P = self.get_disc_loss_P()
        opt_disc.zero_grad()
        self.manual_backward(disc_loss_M)
        self.manual_backward(disc_loss_P)
        opt_disc.step()
        self.untoggle_optimizer(opt_disc)

        metrics = {
            "gen_loss": gen_loss,
            "disc_loss_M": disc_loss_M,
            "disc_loss_P": disc_loss_P,
        }
        # NOTE(review): `wandb` is not imported in this file. It must come
        # from the surrounding environment, otherwise this raises NameError.
        # self.log_dict below also records these metrics through Lightning.
        wandb.log(metrics)
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.display_results(batch, batch_idx, "validate")

    def test_step(self, batch, batch_idx):
        self.display_results(batch, batch_idx, "test")

    def predict_step(self, batch, batch_idx):
        """Return Monet-domain translations of a batch of photos."""
        return self(batch)

    def display_results(self, batch, batch_idx, stage):
        """Show input photos above their Monet translations.

        NOTE(review): ``show_img`` is not defined in this file. It is
        presumably provided elsewhere (e.g. an earlier notebook cell).
        """
        real_P = batch
        fake_M = self(real_P)
        if stage == "validate":
            title = f"Epoch {self.current_epoch+1}: Photo-to-Monet Translation"
        else:
            title = f"Sample {batch_idx+1}: Photo-to-Monet Translation"
        show_img(
            torch.cat([real_P, fake_M], dim=0),
            nrow=len(real_P),
            title=title,
        )

    def on_train_epoch_start(self):
        """Log the generator optimizer's current learning rate."""
        curr_lr = self.lr_schedulers()[0].get_last_lr()[0]
        self.log("lr", curr_lr, on_step=False, on_epoch=True, prog_bar=True)

    def on_train_epoch_end(self):
        """Step both LR schedulers (per epoch) and print progress-bar metrics."""
        for sch in self.lr_schedulers():
            sch.step()
        logged_values = self.trainer.progress_bar_metrics
        print(
            f"Epoch {self.current_epoch+1}",
            *[f"{k}: {v:.5f}" for k, v in logged_values.items()],
            sep=" - ",
        )

    def on_train_end(self):
        print("Training ended.")

    def on_predict_epoch_end(self):
        """Print how many images were generated across all predict batches."""
        predictions = self.trainer.predict_loop.predictions
        # Summing per-batch sizes handles an empty prediction list and
        # batches of unequal size.
        num_images = sum(pred.shape[0] for pred in predictions)
        print(f"Number of images generated: {num_images}.")