# -*- coding: utf-8 -*- import os import os.path as osp import sys import time from collections import defaultdict import numpy as np import paddle from paddle import nn from PIL import Image from tqdm import tqdm from starganv2vc_paddle.losses import compute_d_loss, compute_g_loss import logging logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) class Trainer(object): def __init__(self, args, model=None, model_ema=None, optimizer=None, scheduler=None, config={}, logger=logger, train_dataloader=None, val_dataloader=None, initial_steps=0, initial_epochs=0, fp16_run=False ): self.args = args self.steps = initial_steps self.epochs = initial_epochs self.model = model self.model_ema = model_ema self.optimizer = optimizer self.scheduler = scheduler self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.config = config self.finish_train = False self.logger = logger self.fp16_run = fp16_run def _train_epoch(self): """Train model one epoch.""" raise NotImplementedError @paddle.no_grad() def _eval_epoch(self): """Evaluate model one epoch.""" pass def save_checkpoint(self, checkpoint_path): """Save checkpoint. Args: checkpoint_path (str): Checkpoint path to be saved. """ state_dict = { "optimizer": self.optimizer.state_dict(), "steps": self.steps, "epochs": self.epochs, "model": {key: self.model[key].state_dict() for key in self.model} } if self.model_ema is not None: state_dict['model_ema'] = {key: self.model_ema[key].state_dict() for key in self.model_ema} if not os.path.exists(os.path.dirname(checkpoint_path)): os.makedirs(os.path.dirname(checkpoint_path)) paddle.save(state_dict, checkpoint_path) def load_checkpoint(self, checkpoint_path, load_only_params=False): """Load checkpoint. Args: checkpoint_path (str): Checkpoint path to be loaded. load_only_params (bool): Whether to load only model parameters. """ state_dict = paddle.load(checkpoint_path) if state_dict["model"] is not None: for key in self.model: self._load(state_dict["model"][key], self.model[key]) if self.model_ema is not None: for key in self.model_ema: self._load(state_dict["model_ema"][key], self.model_ema[key]) if not load_only_params: self.steps = state_dict["steps"] self.epochs = state_dict["epochs"] self.optimizer.set_state_dict(state_dict["optimizer"]) def _load(self, states, model, force_load=True): model_states = model.state_dict() for key, val in states.items(): try: if key not in model_states: continue if isinstance(val, nn.Parameter): val = val.clone().detach() if val.shape != model_states[key].shape: self.logger.info("%s does not have same shape" % key) print(val.shape, model_states[key].shape) if not force_load: continue min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape)) slices = [slice(0, min_index) for min_index in min_shape] model_states[key][slices][:] = val[slices] else: model_states[key][:] = val except: self.logger.info("not exist :%s" % key) print("not exist ", key) @staticmethod def get_gradient_norm(model): total_norm = 0 for p in model.parameters(): param_norm = p.grad.data.norm(2) total_norm += param_norm.item() ** 2 total_norm = np.sqrt(total_norm) return total_norm @staticmethod def length_to_mask(lengths): mask = paddle.arange(lengths.max()).unsqueeze(0).expand([lengths.shape[0], -1]).astype(lengths.dtype) mask = paddle.greater_than(mask+1, lengths.unsqueeze(1)) return mask def _get_lr(self): return self.optimizer.get_lr() @staticmethod def moving_average(model, model_test, beta=0.999): for param, param_test in zip(model.parameters(), model_test.parameters()): param_test.set_value(param + beta * (param_test - param)) def _train_epoch(self): self.epochs += 1 train_losses = defaultdict(list) _ = [self.model[k].train() for k in self.model] scaler = paddle.amp.GradScaler() if self.fp16_run else None use_con_reg = (self.epochs >= self.args.con_reg_epoch) use_adv_cls = (self.epochs >= self.args.adv_cls_epoch) for train_steps_per_epoch, batch in enumerate(tqdm(self.train_dataloader, desc="[train]"), 1): ### load data x_real, y_org, x_ref, x_ref2, y_trg, z_trg, z_trg2 = batch # train the discriminator (by random reference) self.optimizer.clear_grad() if scaler is not None: with paddle.amp.autocast(): d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg) scaler.scale(d_loss).backward() else: d_loss, d_losses_latent = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, z_trg=z_trg, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg) d_loss.backward() self.optimizer.step('discriminator', scaler=scaler) # train the discriminator (by target reference) self.optimizer.clear_grad() if scaler is not None: with paddle.amp.autocast(): d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg) scaler.scale(d_loss).backward() else: d_loss, d_losses_ref = compute_d_loss(self.model, self.args.d_loss, x_real, y_org, y_trg, x_ref=x_ref, use_adv_cls=use_adv_cls, use_con_reg=use_con_reg) d_loss.backward() self.optimizer.step('discriminator', scaler=scaler) # train the generator (by random reference) self.optimizer.clear_grad() if scaler is not None: with paddle.amp.autocast(): g_loss, g_losses_latent = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], use_adv_cls=use_adv_cls) scaler.scale(g_loss).backward() else: g_loss, g_losses_latent = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], use_adv_cls=use_adv_cls) g_loss.backward() self.optimizer.step('generator', scaler=scaler) self.optimizer.step('mapping_network', scaler=scaler) self.optimizer.step('style_encoder', scaler=scaler) # train the generator (by target reference) self.optimizer.clear_grad() if scaler is not None: with paddle.amp.autocast(): g_loss, g_losses_ref = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], use_adv_cls=use_adv_cls) scaler.scale(g_loss).backward() else: g_loss, g_losses_ref = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], use_adv_cls=use_adv_cls) g_loss.backward() self.optimizer.step('generator', scaler=scaler) # compute moving average of network parameters self.moving_average(self.model.generator, self.model_ema.generator, beta=0.999) self.moving_average(self.model.mapping_network, self.model_ema.mapping_network, beta=0.999) self.moving_average(self.model.style_encoder, self.model_ema.style_encoder, beta=0.999) self.optimizer.scheduler() for key in d_losses_latent: train_losses["train/%s" % key].append(d_losses_latent[key]) for key in g_losses_latent: train_losses["train/%s" % key].append(g_losses_latent[key]) train_losses = {key: np.mean(value) for key, value in train_losses.items()} return train_losses @paddle.no_grad() def _eval_epoch(self): use_adv_cls = (self.epochs >= self.args.adv_cls_epoch) eval_losses = defaultdict(list) eval_images = defaultdict(list) _ = [self.model[k].eval() for k in self.model] for eval_steps_per_epoch, batch in enumerate(tqdm(self.val_dataloader, desc="[eval]"), 1): ### load data x_real, y_org, x_ref, x_ref2, y_trg, z_trg, z_trg2 = batch # train the discriminator d_loss, d_losses_latent = compute_d_loss( self.model, self.args.d_loss, x_real, y_org, y_trg, z_trg=z_trg, use_r1_reg=False, use_adv_cls=use_adv_cls) d_loss, d_losses_ref = compute_d_loss( self.model, self.args.d_loss, x_real, y_org, y_trg, x_ref=x_ref, use_r1_reg=False, use_adv_cls=use_adv_cls) # train the generator g_loss, g_losses_latent = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], use_adv_cls=use_adv_cls) g_loss, g_losses_ref = compute_g_loss( self.model, self.args.g_loss, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], use_adv_cls=use_adv_cls) for key in d_losses_latent: eval_losses["eval/%s" % key].append(d_losses_latent[key]) for key in g_losses_latent: eval_losses["eval/%s" % key].append(g_losses_latent[key]) # if eval_steps_per_epoch % 10 == 0: # # generate x_fake # s_trg = self.model_ema.style_encoder(x_ref, y_trg) # F0 = self.model.f0_model.get_feature_GAN(x_real) # x_fake = self.model_ema.generator(x_real, s_trg, masks=None, F0=F0) # # generate x_recon # s_real = self.model_ema.style_encoder(x_real, y_org) # F0_fake = self.model.f0_model.get_feature_GAN(x_fake) # x_recon = self.model_ema.generator(x_fake, s_real, masks=None, F0=F0_fake) # eval_images['eval/image'].append( # ([x_real[0, 0].numpy(), # x_fake[0, 0].numpy(), # x_recon[0, 0].numpy()])) eval_losses = {key: np.mean(value) for key, value in eval_losses.items()} eval_losses.update(eval_images) return eval_losses