|
|
|
|
|
import sys |
|
import os |
|
import torch |
|
|
|
|
|
# Put the current working directory on the import path; the project-local
# imports that follow (architecture.*, train_code.*) are written relative to it.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
|
from architecture.rrdb import RRDBNet |
|
from architecture.discriminator import UNetDiscriminatorSN |
|
from train_code.train_master import train_master |
|
|
|
|
|
|
|
class train_esrgan(train_master):
    """Training driver for the "esrgan" setup, built on ``train_master``.

    The generator is an ``RRDBNet`` and the discriminator is a
    ``UNetDiscriminatorSN``. The generator objective is assembled from a
    pixel term, two perceptual terms and an adversarial term.
    """

    def __init__(self, options, args) -> None:
        """Hand the configuration to ``train_master`` under the name "esrgan".

        Args:
            options: Settings object forwarded to ``train_master``.
                ``call_model`` later reads ``self.options['scale']`` and
                ``self.options['ESR_blocks_num']`` (presumably this same
                object, stored by the base class -- confirm there).
            args: Forwarded to ``train_master`` unchanged.
        """
        # NOTE(review): the trailing ``True`` is a positional flag whose
        # meaning is defined by train_master.__init__; verify it there.
        super().__init__(options, args, "esrgan", True)

    def loss_init(self):
        """Load the loss criteria by calling the two loader hooks in order.

        Neither loader is defined in this class; they are expected to come
        from ``train_master``. ``calculate_loss`` uses ``cri_pix``,
        ``cri_danbooru_perceptual``, ``cri_vgg_perceptual`` and ``cri_gan``,
        which are presumably created by these loaders (TODO confirm).
        """
        self.pixel_loss_load()
        self.GAN_loss_load()

    def call_model(self):
        """Instantiate generator and discriminator on CUDA in training mode.

        The generator's ``scale`` and ``num_block`` come from
        ``self.options['scale']`` and ``self.options['ESR_blocks_num']``.
        """
        opts = self.options

        # The literal 3s are presumably channel counts (RGB) -- see the
        # RRDBNet / UNetDiscriminatorSN signatures to confirm.
        rrdb = RRDBNet(3, 3, scale=opts['scale'], num_block=opts['ESR_blocks_num'])
        self.generator = rrdb.cuda()
        self.discriminator = UNetDiscriminatorSN(3).cuda()

        for net in (self.generator, self.discriminator):
            net.train()

    def run(self):
        """Start training by delegating to ``master_run``."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Add the generator loss terms to ``self.generator_loss``.

        Each term is added to ``self.generator_loss`` with ``+=`` and also
        stored in ``self.weight_store`` under ``"pixel_loss"``,
        ``"perceptual_loss"`` or ``"gan_loss"``. Nothing is returned.

        Args:
            gen_hr: First argument given to every criterion and the input of
                the discriminator (presumably the generator's output batch).
            imgs_hr: Second argument given to the pixel and perceptual
                criteria (presumably the ground-truth batch).
        """

        def book(key, term):
            # Accumulate first, then record, matching the per-term sequence.
            self.generator_loss += term
            self.weight_store[key] = term

        # Pixel term.
        pixel_term = self.cri_pix(gen_hr, imgs_hr)
        book("pixel_loss", pixel_term)

        # Perceptual term: sum of the two perceptual criteria.
        danbooru_term = self.cri_danbooru_perceptual(gen_hr, imgs_hr)
        vgg_term = self.cri_vgg_perceptual(gen_hr, imgs_hr)
        perceptual_term = danbooru_term + vgg_term
        book("perceptual_loss", perceptual_term)

        # Adversarial term: discriminator output on gen_hr, scored against
        # the target ``True`` with ``is_disc=False``.
        disc_out = self.discriminator(gen_hr)
        adversarial_term = self.cri_gan(disc_out, True, is_disc=False)
        book("gan_loss", adversarial_term)

    def tensorboard_report(self, iteration):
        """Write the current loss values to ``self.writer`` at ``iteration``.

        Args:
            iteration: Step value passed as the third argument of every
                ``add_scalar`` call.
        """
        self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)

        # NOTE(review): the "Discriminator_Loss" tag is fed from
        # weight_store["gan_loss"], which calculate_loss fills with the
        # generator's adversarial term -- confirm the tag name is intended.
        stored_terms = (
            ('Loss/train-Pixel_Loss-Iteration', "pixel_loss"),
            ('Loss/train-Perceptual_Loss-Iteration', "perceptual_loss"),
            ('Loss/train-Discriminator_Loss-Iteration', "gan_loss"),
        )
        for tag, key in stored_terms:
            self.writer.add_scalar(tag, self.weight_store[key], iteration)
|
|
|
|