#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-05-18 13:04:06
import os, sys, math, time, random, datetime
import numpy as np
from box import Box
from pathlib import Path
from loguru import logger
from copy import deepcopy
from omegaconf import OmegaConf
from einops import rearrange
from typing import Any, Dict, List, Optional, Tuple, Union
from datapipe.datasets import create_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as udata
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.utils as vutils
from torch.nn.parallel import DistributedDataParallel as DDP
from utils import util_net
from utils import util_common
from utils import util_image
from utils.util_ops import append_dims
import pyiqa
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.data.transforms import paired_random_crop
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from diffusers import EulerDiscreteScheduler
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps
from torch.utils.tensorboard import SummaryWriter
# BitsAndBytesConfig is needed by the optional SD3 4-bit quantization path in build_model;
# recent diffusers versions export it at the top level (transformers ships a compatible class).
from diffusers import BitsAndBytesConfig
_base_seed = 10**6
_INTERPOLATION_MODE = 'bicubic'
_Latent_bound = {'min':-10.0, 'max':10.0}
_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
        'meticulous detailing, hyper sharpness, perfect without deformations'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy, ' +\
        'painting, drawing, sketch, oil painting'
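# Default prompts: when text conditioning is disabled (or random cropping breaks caption
# alignment), prepare_data substitutes _positive for the per-image captions; _negative is
# defined here but not referenced by this trainer.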
class TrainerBase:
def __init__(self, configs):
self.configs = configs
# setup distributed training: self.num_gpus, self.rank
self.setup_dist()
# setup seed
self.setup_seed()
def setup_dist(self):
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(
timeout=datetime.timedelta(seconds=3600),
backend='nccl',
init_method='env://',
)
self.num_gpus = num_gpus
self.rank = int(os.environ['LOCAL_RANK']) if num_gpus > 1 else 0
def setup_seed(self, seed=None, global_seeding=None):
if seed is None:
seed = self.configs.train.get('seed', 12345)
if global_seeding is None:
global_seeding = self.configs.train.get('global_seeding', False)
if not global_seeding:
seed += self.rank
torch.cuda.manual_seed(seed)
else:
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def init_logger(self):
if self.configs.resume:
assert self.configs.resume.endswith(".pth")
save_dir = Path(self.configs.resume).parents[1]
project_id = save_dir.name
else:
project_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
save_dir = Path(self.configs.save_dir) / project_id
if not save_dir.exists() and self.rank == 0:
save_dir.mkdir(parents=True)
# setting log counter
if self.rank == 0:
self.log_step = {phase: 1 for phase in ['train', 'val']}
self.log_step_img = {phase: 1 for phase in ['train', 'val']}
# text logging
        logtext_path = save_dir / 'training.log'
        if self.rank == 0:
            if logtext_path.exists():
                assert self.configs.resume
            self.logger = logger
            self.logger.remove()
            self.logger.add(logtext_path, format="{message}", mode='a', level='INFO')
self.logger.add(sys.stdout, format="{message}")
# tensorboard logging
log_dir = save_dir / 'tf_logs'
self.tf_logging = self.configs.train.tf_logging
if self.rank == 0 and self.tf_logging:
if not log_dir.exists():
log_dir.mkdir()
self.writer = SummaryWriter(str(log_dir))
# checkpoint saving
ckpt_dir = save_dir / 'ckpts'
self.ckpt_dir = ckpt_dir
if self.rank == 0 and (not ckpt_dir.exists()):
ckpt_dir.mkdir()
if 'ema_rate' in self.configs.train:
self.ema_rate = self.configs.train.ema_rate
            assert isinstance(self.ema_rate, float), "EMA rate must be a float"
ema_ckpt_dir = save_dir / 'ema_ckpts'
self.ema_ckpt_dir = ema_ckpt_dir
if self.rank == 0 and (not ema_ckpt_dir.exists()):
ema_ckpt_dir.mkdir()
# save images into local disk
self.local_logging = self.configs.train.local_logging
if self.rank == 0 and self.local_logging:
image_dir = save_dir / 'images'
if not image_dir.exists():
(image_dir / 'train').mkdir(parents=True)
(image_dir / 'val').mkdir(parents=True)
self.image_dir = image_dir
# logging the configurations
if self.rank == 0:
self.logger.info(OmegaConf.to_yaml(self.configs))
def close_logger(self):
if self.rank == 0 and self.tf_logging:
self.writer.close()
def resume_from_ckpt(self):
if self.configs.resume:
assert self.configs.resume.endswith(".pth") and os.path.isfile(self.configs.resume)
if self.rank == 0:
self.logger.info(f"=> Loading checkpoint from {self.configs.resume}")
ckpt = torch.load(self.configs.resume, map_location=f"cuda:{self.rank}")
util_net.reload_model(self.model, ckpt['state_dict'])
if self.configs.train.loss_coef.get('ldis', 0) > 0:
util_net.reload_model(self.discriminator, ckpt['state_dict_dis'])
torch.cuda.empty_cache()
# learning rate scheduler
self.iters_start = ckpt['iters_start']
for ii in range(1, self.iters_start+1):
self.adjust_lr(ii)
# logging
if self.rank == 0:
self.log_step = ckpt['log_step']
self.log_step_img = ckpt['log_step_img']
# EMA model
if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
ema_ckpt_path = self.ema_ckpt_dir / ("ema_"+Path(self.configs.resume).name)
self.logger.info(f"=> Loading EMA checkpoint from {str(ema_ckpt_path)}")
ema_ckpt = torch.load(ema_ckpt_path, map_location=f"cuda:{self.rank}")
util_net.reload_model(self.ema_model, ema_ckpt)
torch.cuda.empty_cache()
# AMP scaler
if self.amp_scaler is not None:
if "amp_scaler" in ckpt:
self.amp_scaler.load_state_dict(ckpt["amp_scaler"])
if self.rank == 0:
self.logger.info("Loading scaler from resumed state...")
if self.configs.get('discriminator', None) is not None:
if "amp_scaler_dis" in ckpt:
self.amp_scaler_dis.load_state_dict(ckpt["amp_scaler_dis"])
if self.rank == 0:
self.logger.info("Loading scaler (discriminator) from resumed state...")
# reset the seed
self.setup_seed(seed=self.iters_start)
else:
self.iters_start = 0
    def setup_optimization(self):
self.optimizer = torch.optim.AdamW(self.model.parameters(),
lr=self.configs.train.lr,
weight_decay=self.configs.train.weight_decay)
# amp settings
self.amp_scaler = torch.amp.GradScaler('cuda') if self.configs.train.use_amp else None
if self.configs.train.lr_schedule == 'cosin':
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer=self.optimizer,
T_max=self.configs.train.iterations - self.configs.train.warmup_iterations,
eta_min=self.configs.train.lr_min,
)
if self.configs.train.loss_coef.get('ldis', 0) > 0:
self.optimizer_dis = torch.optim.Adam(
self.discriminator.parameters(),
lr=self.configs.train.lr_dis,
weight_decay=self.configs.train.weight_decay_dis,
)
self.amp_scaler_dis = torch.amp.GradScaler('cuda') if self.configs.train.use_amp else None
def prepare_compiling(self):
# https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3#stable-diffusion-3
if not hasattr(self, "prepare_compiling_well") or (not self.prepare_compiling_well):
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
self.prepare_compiling_well = True
def build_model(self):
if self.configs.train.get("compile", True):
self.prepare_compiling()
        params = self.configs.model.get('params', dict())
model = util_common.get_obj_from_str(self.configs.model.target)(**params)
model.cuda()
if not self.configs.train.start_mode: # Loading the starting model for evaluation
self.start_model = deepcopy(model)
assert self.configs.model.ckpt_start_path is not None
ckpt_start_path = self.configs.model.ckpt_start_path
if self.rank == 0:
self.logger.info(f"Loading the starting model from {ckpt_start_path}")
ckpt = torch.load(ckpt_start_path, map_location=f"cuda:{self.rank}")
if 'state_dict' in ckpt:
ckpt = ckpt['state_dict']
util_net.reload_model(self.start_model, ckpt)
self.freeze_model(self.start_model)
self.start_model.eval()
            # remove the largest (starting) timestep from the training timesteps
start_timestep = max(self.configs.train.timesteps)
self.configs.train.timesteps.remove(start_timestep)
# end_timestep = min(self.configs.train.timesteps)
# self.configs.train.timesteps.remove(end_timestep)
# setting the training model
if self.configs.model.get('ckpt_path', None): # initialize if necessary
ckpt_path = self.configs.model.ckpt_path
if self.rank == 0:
self.logger.info(f"Initializing model from {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
if 'state_dict' in ckpt:
ckpt = ckpt['state_dict']
util_net.reload_model(model, ckpt)
if self.configs.model.get("compile", False):
if self.rank == 0:
self.logger.info("Compile the model...")
model.to(memory_format=torch.channels_last)
model = torch.compile(model, mode="max-autotune", fullgraph=False)
if self.num_gpus > 1:
model = DDP(model, device_ids=[self.rank,]) # wrap the network
if self.rank == 0 and hasattr(self.configs.train, 'ema_rate'):
self.ema_model = deepcopy(model)
self.freeze_model(self.ema_model)
self.model = model
# discriminator if necessary
if self.configs.train.loss_coef.get('ldis', 0) > 0:
assert hasattr(self.configs, 'discriminator')
            params = self.configs.discriminator.get('params', dict())
discriminator = util_common.get_obj_from_str(self.configs.discriminator.target)(**params)
discriminator.cuda()
if self.configs.discriminator.get("compile", False):
if self.rank == 0:
self.logger.info("Compile the discriminator...")
discriminator.to(memory_format=torch.channels_last)
discriminator = torch.compile(discriminator, mode="max-autotune", fullgraph=False)
if self.num_gpus > 1:
discriminator = DDP(discriminator, device_ids=[self.rank,]) # wrap the network
if self.configs.train.loss_coef.get('ldis', 0) > 0:
if self.configs.discriminator.enable_grad_checkpoint:
if self.rank == 0:
self.logger.info("Activating gradient checkpointing for discriminator...")
self.set_grad_checkpointing(discriminator)
self.discriminator = discriminator
# build the stable diffusion
params = dict(self.configs.sd_pipe.params)
torch_dtype = params.pop('torch_dtype')
params['torch_dtype'] = get_torch_dtype(torch_dtype)
# loading the fp16 robust vae for sdxl: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
if self.configs.get('vae_fp16', None) is not None:
params_vae = dict(self.configs.vae_fp16.params)
params_vae['torch_dtype'] = torch.float16
pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
if self.rank == 0:
self.logger.info(f'Loading improved vae from {pipe_id}...')
vae_pipe = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(**params_vae)
if self.rank == 0:
                self.logger.info('Loading done')
params['vae'] = vae_pipe
if ("StableDiffusion3" in self.configs.sd_pipe.target.split('.')[-1]
and self.configs.sd_pipe.get("model_quantization", False)):
if self.rank == 0:
self.logger.info(f'Loading the quantized transformer for SD3...')
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
params_model = dict(self.configs.model_nf4.params)
torch_dtype = params_model.pop('torch_dtype')
params_model['torch_dtype'] = get_torch_dtype(torch_dtype)
params_model['quantization_config'] = nf4_config
model_nf4 = util_common.get_obj_from_str(self.configs.model_nf4.target).from_pretrained(
**params_model
)
params['transformer'] = model_nf4
sd_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
if self.configs.get('scheduler', None) is not None:
pipe_id = self.configs.scheduler.target.split('.')[-1]
if self.rank == 0:
self.logger.info(f'Loading scheduler of {pipe_id}...')
sd_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
sd_pipe.scheduler.config
)
if self.rank == 0:
                self.logger.info('Loading done')
if ("StableDiffusion3" in self.configs.sd_pipe.target.split('.')[-1]
and self.configs.sd_pipe.get("model_quantization", False)):
sd_pipe.enable_model_cpu_offload(gpu_id=self.rank,device='cuda')
else:
sd_pipe.to(f"cuda:{self.rank}")
# freezing model parameters
if hasattr(sd_pipe, 'unet'):
self.freeze_model(sd_pipe.unet)
if hasattr(sd_pipe, 'transformer'):
self.freeze_model(sd_pipe.transformer)
self.freeze_model(sd_pipe.vae)
# compiling
if self.configs.sd_pipe.get('compile', True):
if self.rank == 0:
self.logger.info('Compile the SD model...')
sd_pipe.set_progress_bar_config(disable=True)
if hasattr(sd_pipe, 'unet'):
sd_pipe.unet.to(memory_format=torch.channels_last)
sd_pipe.unet = torch.compile(sd_pipe.unet, mode="max-autotune", fullgraph=False)
if hasattr(sd_pipe, 'transformer'):
sd_pipe.transformer.to(memory_format=torch.channels_last)
sd_pipe.transformer = torch.compile(sd_pipe.transformer, mode="max-autotune", fullgraph=False)
sd_pipe.vae.to(memory_format=torch.channels_last)
sd_pipe.vae = torch.compile(sd_pipe.vae, mode="max-autotune", fullgraph=True)
# setting gradient checkpoint for vae
if self.configs.sd_pipe.get("enable_grad_checkpoint_vae", True):
if self.rank == 0:
self.logger.info("Activating gradient checkpointing for VAE...")
sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.encoder)
sd_pipe.vae._set_gradient_checkpointing(sd_pipe.vae.decoder)
# setting gradient checkpoint for diffusion model
if self.configs.sd_pipe.enable_grad_checkpoint:
if self.rank == 0:
self.logger.info("Activating gradient checkpointing for SD...")
if hasattr(sd_pipe, 'unet'):
self.set_grad_checkpointing(sd_pipe.unet)
if hasattr(sd_pipe, 'transformer'):
self.set_grad_checkpointing(sd_pipe.transformer)
self.sd_pipe = sd_pipe
# latent LPIPS loss
if self.configs.train.loss_coef.get('llpips', 0) > 0:
            params = self.configs.llpips.get('params', dict())
llpips_loss = util_common.get_obj_from_str(self.configs.llpips.target)(**params)
llpips_loss.cuda()
self.freeze_model(llpips_loss)
# loading the pre-trained model
ckpt_path = self.configs.llpips.ckpt_path
self.load_model(llpips_loss, ckpt_path, tag='latent lpips')
if self.configs.llpips.get("compile", True):
if self.rank == 0:
self.logger.info('Compile the llpips loss...')
llpips_loss.to(memory_format=torch.channels_last)
llpips_loss = torch.compile(llpips_loss, mode="max-autotune", fullgraph=True)
self.llpips_loss = llpips_loss
# model information
self.print_model_info()
torch.cuda.empty_cache()
def set_grad_checkpointing(self, model):
if hasattr(model, 'down_blocks'):
for module in model.down_blocks:
module.gradient_checkpointing = True
module.training = True
if hasattr(model, 'up_blocks'):
for module in model.up_blocks:
module.gradient_checkpointing = True
module.training = True
        if hasattr(model, 'mid_block') and model.mid_block is not None:
            model.mid_block.gradient_checkpointing = True
            model.mid_block.training = True
def build_dataloader(self):
def _wrap_loader(loader):
while True: yield from loader
# make datasets
        datasets = {'train': create_dataset(self.configs.data.get('train', dict())), }
        if hasattr(self.configs.data, 'val') and self.rank == 0:
            datasets['val'] = create_dataset(self.configs.data.get('val', dict()))
if self.rank == 0:
for phase in datasets.keys():
length = len(datasets[phase])
self.logger.info('Number of images in {:s} data set: {:d}'.format(phase, length))
# make dataloaders
if self.num_gpus > 1:
sampler = udata.distributed.DistributedSampler(
datasets['train'],
num_replicas=self.num_gpus,
rank=self.rank,
)
else:
sampler = None
dataloaders = {'train': _wrap_loader(udata.DataLoader(
datasets['train'],
batch_size=self.configs.train.batch // self.num_gpus,
shuffle=False if self.num_gpus > 1 else True,
drop_last=True,
num_workers=min(self.configs.train.num_workers, 4),
pin_memory=True,
prefetch_factor=self.configs.train.get('prefetch_factor', 2),
worker_init_fn=my_worker_init_fn,
sampler=sampler,
))}
if hasattr(self.configs.data, 'val') and self.rank == 0:
dataloaders['val'] = udata.DataLoader(datasets['val'],
batch_size=self.configs.validate.batch,
shuffle=False,
drop_last=False,
num_workers=0,
pin_memory=True,
)
self.datasets = datasets
self.dataloaders = dataloaders
self.sampler = sampler
def print_model_info(self):
if self.rank == 0:
num_params = util_net.calculate_parameters(self.model) / 1000**2
# self.logger.info("Detailed network architecture:")
# self.logger.info(self.model.__repr__())
if self.configs.train.get('use_fsdp', False):
num_params *= self.num_gpus
self.logger.info(f"Number of parameters: {num_params:.2f}M")
if hasattr(self, 'discriminator'):
num_params = util_net.calculate_parameters(self.discriminator) / 1000**2
self.logger.info(f"Number of parameters in discriminator: {num_params:.2f}M")
def prepare_data(self, data, dtype=torch.float32, phase='train'):
data = {key:value.cuda().to(dtype=dtype) for key, value in data.items()}
return data
def validation(self):
pass
def train(self):
self.init_logger() # setup logger: self.logger
self.build_dataloader() # prepare data: self.dataloaders, self.datasets, self.sampler
self.build_model() # build model: self.model, self.loss
        self.setup_optimization()                  # setup optimization: self.optimizer, self.lr_scheduler
self.resume_from_ckpt() # resume if necessary
self.model.train()
num_iters_epoch = math.ceil(len(self.datasets['train']) / self.configs.train.batch)
for ii in range(self.iters_start, self.configs.train.iterations):
self.current_iters = ii + 1
# prepare data
data = self.prepare_data(next(self.dataloaders['train']), phase='train')
# training phase
self.training_step(data)
# update ema model
if hasattr(self.configs.train, 'ema_rate') and self.rank == 0:
self.update_ema_model()
# validation phase
if ((ii+1) % self.configs.train.save_freq == 0 and
'val' in self.dataloaders and
self.rank == 0
):
self.validation()
            # update learning rate
self.adjust_lr()
# save checkpoint
if (ii+1) % self.configs.train.save_freq == 0 and self.rank == 0:
self.save_ckpt()
if (ii+1) % num_iters_epoch == 0 and self.sampler is not None:
self.sampler.set_epoch(ii+1)
# close the tensorboard
self.close_logger()
def adjust_lr(self, current_iters=None):
base_lr = self.configs.train.lr
warmup_steps = self.configs.train.get("warmup_iterations", 0)
current_iters = self.current_iters if current_iters is None else current_iters
if current_iters <= warmup_steps:
for params_group in self.optimizer.param_groups:
params_group['lr'] = (current_iters / warmup_steps) * base_lr
else:
if hasattr(self, 'lr_scheduler'):
self.lr_scheduler.step()
def save_ckpt(self):
ckpt_path = self.ckpt_dir / 'model_{:d}.pth'.format(self.current_iters)
ckpt = {
'iters_start': self.current_iters,
'log_step': {phase:self.log_step[phase] for phase in ['train', 'val']},
'log_step_img': {phase:self.log_step_img[phase] for phase in ['train', 'val']},
'state_dict': self.model.state_dict(),
}
if self.amp_scaler is not None:
ckpt['amp_scaler'] = self.amp_scaler.state_dict()
if self.configs.train.loss_coef.get('ldis', 0) > 0:
ckpt['state_dict_dis'] = self.discriminator.state_dict()
if self.amp_scaler_dis is not None:
ckpt['amp_scaler_dis'] = self.amp_scaler_dis.state_dict()
torch.save(ckpt, ckpt_path)
if hasattr(self.configs.train, 'ema_rate'):
ema_ckpt_path = self.ema_ckpt_dir / 'ema_model_{:d}.pth'.format(self.current_iters)
torch.save(self.ema_model.state_dict(), ema_ckpt_path)
def logging_image(self, im_tensor, tag, phase, add_global_step=False, nrow=8):
"""
Args:
im_tensor: b x c x h x w tensor
im_tag: str
phase: 'train' or 'val'
nrow: number of displays in each row
"""
assert self.tf_logging or self.local_logging
im_tensor = vutils.make_grid(im_tensor, nrow=nrow, normalize=True, scale_each=True) # c x H x W
if self.local_logging:
im_path = str(self.image_dir / phase / f"{tag}-{self.log_step_img[phase]}.png")
im_np = im_tensor.cpu().permute(1,2,0).numpy()
util_image.imwrite(im_np, im_path)
if self.tf_logging:
self.writer.add_image(
f"{phase}-{tag}-{self.log_step_img[phase]}",
im_tensor,
self.log_step_img[phase],
)
if add_global_step:
self.log_step_img[phase] += 1
def logging_text(self, text_list, phase):
"""
Args:
text_list: (b,) list
phase: 'train' or 'val'
"""
assert self.local_logging
if self.local_logging:
text_path = str(self.image_dir / phase / f"text-{self.log_step_img[phase]}.txt")
with open(text_path, 'w') as ff:
for text in text_list:
ff.write(text + '\n')
def logging_metric(self, metrics, tag, phase, add_global_step=False):
"""
Args:
metrics: dict
tag: str
phase: 'train' or 'val'
"""
if self.tf_logging:
tag = f"{phase}-{tag}"
if isinstance(metrics, dict):
self.writer.add_scalars(tag, metrics, self.log_step[phase])
else:
self.writer.add_scalar(tag, metrics, self.log_step[phase])
if add_global_step:
self.log_step[phase] += 1
def load_model(self, model, ckpt_path=None, tag='model'):
if self.rank == 0:
self.logger.info(f'Loading {tag} from {ckpt_path}...')
ckpt = torch.load(ckpt_path, map_location=f"cuda:{self.rank}")
if 'state_dict' in ckpt:
ckpt = ckpt['state_dict']
util_net.reload_model(model, ckpt)
if self.rank == 0:
            self.logger.info('Loading done')
def freeze_model(self, net):
for params in net.parameters():
params.requires_grad = False
def unfreeze_model(self, net):
for params in net.parameters():
params.requires_grad = True
@torch.no_grad()
def update_ema_model(self):
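        # EMA decay with warmup: early in training (1 + t) / (10 + t) is well below the
        # configured ema_rate, so the EMA tracks the live model quickly before the decay
        # saturates at ema_rate.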
decay = min(self.configs.train.ema_rate, (1 + self.current_iters) / (10 + self.current_iters))
        model_params = dict(self.model.named_parameters())
        # if hasattr(self.configs.train, 'ema_rate'):
            # with FSDP.summon_full_params(self.model, writeback=True):
                # model_params = dict(self.model.named_parameters())
        # else:
            # model_params = dict(self.model.named_parameters())
        one_minus_decay = 1.0 - decay
        for key, ema_value in self.ema_model.named_parameters():
            model_value = model_params[key]
            if model_value.requires_grad:
                ema_value.sub_(one_minus_decay * (ema_value - model_value.data))
class TrainerBaseSR(TrainerBase):
@torch.no_grad()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
to increase the degradation diversity in a batch.
"""
# initialize
b, c, h, w = self.lq.size()
if not hasattr(self, 'queue_size'):
self.queue_size = self.configs.degradation.get('queue_size', b*10)
if not hasattr(self, 'queue_lr'):
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt.size()
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
_, c, h, w = self.gt_latent.size()
self.queue_gt_latent = torch.zeros(self.queue_size, c, h, w).cuda()
self.queue_txt = ["", ] * self.queue_size
self.queue_ptr = 0
if self.queue_ptr == self.queue_size: # the pool is full
# do dequeue and enqueue
# shuffle
idx = torch.randperm(self.queue_size)
self.queue_lr = self.queue_lr[idx]
self.queue_gt = self.queue_gt[idx]
self.queue_gt_latent = self.queue_gt_latent[idx]
self.queue_txt = [self.queue_txt[ii] for ii in idx]
# get first b samples
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
gt_latent_dequeue = self.queue_gt_latent[0:b, :, :, :].clone()
txt_dequeue = deepcopy(self.queue_txt[0:b])
# update the queue
self.queue_lr[0:b, :, :, :] = self.lq.clone()
self.queue_gt[0:b, :, :, :] = self.gt.clone()
self.queue_gt_latent[0:b, :, :, :] = self.gt_latent.clone()
self.queue_txt[0:b] = deepcopy(self.txt)
self.lq = lq_dequeue
self.gt = gt_dequeue
self.gt_latent = gt_latent_dequeue
self.txt = txt_dequeue
else:
# only do enqueue
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_gt_latent[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt_latent.clone()
self.queue_txt[self.queue_ptr:self.queue_ptr + b] = deepcopy(self.txt)
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def prepare_data(self, data, phase='train'):
if phase == 'train' and self.configs.data.get(phase).get('type') == 'realesrgan':
if not hasattr(self, 'jpeger'):
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
if (not hasattr(self, 'sharpener')) and self.configs.degradation.get('use_sharp', False):
self.sharpener = USMSharp().cuda()
im_gt = data['gt'].cuda()
kernel1 = data['kernel1'].cuda()
kernel2 = data['kernel2'].cuda()
sinc_kernel = data['sinc_kernel'].cuda()
ori_h, ori_w = im_gt.size()[2:4]
if isinstance(self.configs.degradation.sf, int):
sf = self.configs.degradation.sf
else:
assert len(self.configs.degradation.sf) == 2
sf = random.uniform(*self.configs.degradation.sf)
if self.configs.degradation.use_sharp:
im_gt = self.sharpener(im_gt)
# ----------------------- The first degradation process ----------------------- #
# blur
out = filter2D(im_gt, kernel1)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
self.configs.degradation['resize_prob'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, self.configs.degradation['resize_range'][1])
elif updown_type == 'down':
scale = random.uniform(self.configs.degradation['resize_range'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(out, scale_factor=scale, mode=mode)
# add noise
gray_noise_prob = self.configs.degradation['gray_noise_prob']
if random.random() < self.configs.degradation['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=self.configs.degradation['noise_range'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.configs.degradation['poisson_scale_range'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
out = self.jpeger(out, quality=jpeg_p)
# ----------------------- The second degradation process ----------------------- #
if random.random() < self.configs.degradation['second_order_prob']:
# blur
if random.random() < self.configs.degradation['second_blur_prob']:
out = filter2D(out, kernel2)
# random resize
updown_type = random.choices(
['up', 'down', 'keep'],
self.configs.degradation['resize_prob2'],
)[0]
if updown_type == 'up':
scale = random.uniform(1, self.configs.degradation['resize_range2'][1])
elif updown_type == 'down':
scale = random.uniform(self.configs.degradation['resize_range2'][0], 1)
else:
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(int(ori_h / sf * scale), int(ori_w / sf * scale)),
mode=mode,
)
# add noise
gray_noise_prob = self.configs.degradation['gray_noise_prob2']
if random.random() < self.configs.degradation['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
out,
sigma_range=self.configs.degradation['noise_range2'],
clip=True,
rounds=False,
gray_prob=gray_noise_prob,
)
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.configs.degradation['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
rounds=False,
)
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
# as one operation.
# We consider two orders:
# 1. [resize back + sinc filter] + JPEG compression
# 2. JPEG compression + [resize back + sinc filter]
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
if random.random() < 0.5:
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
else:
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
out = torch.clamp(out, 0, 1)
out = self.jpeger(out, quality=jpeg_p)
# resize back + the final sinc filter
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
out,
size=(ori_h // sf, ori_w // sf),
mode=mode,
)
out = filter2D(out, sinc_kernel)
# resize back
if self.configs.degradation.resize_back:
out = F.interpolate(out, size=(ori_h, ori_w), mode=_INTERPOLATION_MODE)
# clamp and round
im_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
self.lq, self.gt, self.txt = im_lq, im_gt, data['txt']
if "gt_moment" not in data:
self.gt_latent = self.encode_first_stage(
im_gt.cuda(),
center_input_sample=True,
deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
)
else:
self.gt_latent = self.encode_from_moment(
data['gt_moment'].cuda(),
deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
)
if (not self.configs.train.use_text) or self.configs.data.train.params.random_crop:
self.txt = [_positive,] * im_lq.shape[0]
# training pair pool
self._dequeue_and_enqueue()
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
batch = {'lq':self.lq, 'gt':self.gt, 'gt_latent':self.gt_latent, 'txt':self.txt}
elif phase == 'val':
resolution = self.configs.data.train.params.gt_size // self.configs.degradation.sf
batch = {}
batch['lq'] = data['lq'].cuda()
if 'gt' in data:
batch['gt'] = data['gt'].cuda()
batch['txt'] = [_positive, ] * data['lq'].shape[0]
else:
batch = {key:value.cuda().to(dtype=torch.float32) for key, value in data.items()}
return batch
@torch.no_grad()
def encode_from_moment(self, z, deterministic=True):
dist = DiagonalGaussianDistribution(z)
init_latents = dist.mode() if deterministic else dist.sample()
latents_mean = latents_std = None
if hasattr(self.sd_pipe.vae.config, "latents_mean") and self.sd_pipe.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.sd_pipe.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.sd_pipe.vae.config, "latents_std") and self.sd_pipe.vae.config.latents_std is not None:
latents_std = torch.tensor(self.sd_pipe.vae.config.latents_std).view(1, 4, 1, 1)
scaling_factor = self.sd_pipe.vae.config.scaling_factor
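        # Some VAE configs publish per-channel latent statistics (latents_mean / latents_std);
        # when present, latents are standardized as (z - mean) * scaling_factor / std, otherwise
        # only the usual scaling_factor is applied.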
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=z.device, dtype=z.dtype)
latents_std = latents_std.to(device=z.device, dtype=z.dtype)
init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
else:
init_latents = init_latents * scaling_factor
return init_latents
@torch.no_grad()
@torch.amp.autocast('cuda')
def encode_first_stage(self, x, deterministic=False, center_input_sample=True):
if center_input_sample:
x = x * 2.0 - 1.0
latents_mean = latents_std = None
if hasattr(self.sd_pipe.vae.config, "latents_mean") and self.sd_pipe.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.sd_pipe.vae.config.latents_mean).view(1, -1, 1, 1)
if hasattr(self.sd_pipe.vae.config, "latents_std") and self.sd_pipe.vae.config.latents_std is not None:
latents_std = torch.tensor(self.sd_pipe.vae.config.latents_std).view(1, -1, 1, 1)
if deterministic:
partial_encode = lambda xx: self.sd_pipe.vae.encode(xx).latent_dist.mode()
else:
partial_encode = lambda xx: self.sd_pipe.vae.encode(xx).latent_dist.sample()
trunk_size = self.configs.sd_pipe.vae_split
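        # Encode in chunks of `vae_split` images to bound peak memory during VAE encoding.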
if trunk_size < x.shape[0]:
init_latents = torch.cat([partial_encode(xx) for xx in x.split(trunk_size, 0)], dim=0)
else:
init_latents = partial_encode(x)
scaling_factor = self.sd_pipe.vae.config.scaling_factor
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=x.device, dtype=x.dtype)
latents_std = latents_std.to(device=x.device, dtype=x.dtype)
init_latents = (init_latents - latents_mean) * scaling_factor / latents_std
else:
init_latents = init_latents * scaling_factor
return init_latents
@torch.no_grad()
@torch.amp.autocast('cuda')
def decode_first_stage(self, z, clamp=True):
z = z / self.sd_pipe.vae.config.scaling_factor
trunk_size = 1
if trunk_size < z.shape[0]:
out = torch.cat(
[self.sd_pipe.vae.decode(xx).sample for xx in z.split(trunk_size, 0)], dim=0,
)
else:
out = self.sd_pipe.vae.decode(z).sample
if clamp:
out = out.clamp(-1.0, 1.0)
return out
    def get_loss_from_discriminator(self, logits_fake):
        if not isinstance(logits_fake, (list, tuple)):
g_loss = -torch.mean(logits_fake, dim=list(range(1, logits_fake.ndim)))
else:
g_loss = -torch.mean(logits_fake[0], dim=list(range(1, logits_fake[0].ndim)))
for current_logits in logits_fake[1:]:
g_loss += -torch.mean(current_logits, dim=list(range(1, current_logits.ndim)))
g_loss /= len(logits_fake)
return g_loss
def training_step(self, data):
current_bs = data['gt'].shape[0]
micro_bs = self.configs.train.microbatch
num_grad_accumulate = math.ceil(current_bs / micro_bs)
# grad zero
self.model.zero_grad()
# update generator
if self.configs.train.loss_coef.get('ldis', 0) > 0:
self.freeze_model(self.discriminator) # freeze discriminator
z0_pred_list = []
tt_list = []
prompt_embeds_list = []
for jj in range(0, current_bs, micro_bs):
micro_data = {key:value[jj:jj+micro_bs] for key, value in data.items()}
last_batch = (jj+micro_bs >= current_bs)
if last_batch or self.num_gpus <= 1:
losses, z0_pred, zt_noisy, tt = self.backward_step(micro_data, num_grad_accumulate)
else:
with self.model.no_sync():
losses, z0_pred, zt_noisy, tt = self.backward_step(micro_data, num_grad_accumulate)
if self.configs.train.loss_coef.get('ldis', 0) > 0:
z0_pred_list.append(z0_pred.detach())
tt_list.append(tt)
prompt_embeds_list.append(self.prompt_embeds.detach())
if self.configs.train.use_amp:
self.amp_scaler.step(self.optimizer)
self.amp_scaler.update()
else:
self.optimizer.step()
# update discriminator
if (self.configs.train.loss_coef.get('ldis', 0) > 0 and
(self.current_iters < self.configs.train.dis_init_iterations
or self.current_iters % self.configs.train.dis_update_freq == 0)
):
# grad zero
self.unfreeze_model(self.discriminator) # update discriminator
self.discriminator.zero_grad()
for ii, jj in enumerate(range(0, current_bs, micro_bs)):
micro_data = {key:value[jj:jj+micro_bs] for key, value in data.items()}
last_batch = (jj+micro_bs >= current_bs)
target = micro_data['gt_latent']
inputs = z0_pred_list[ii]
if last_batch or self.num_gpus <= 1:
logits = self.dis_backward_step(target, inputs, tt_list[ii], prompt_embeds_list[ii])
else:
with self.discriminator.no_sync():
logits = self.dis_backward_step(
target, inputs, tt_list[ii], prompt_embeds_list[ii]
)
# make logging
if self.current_iters % self.configs.train.dis_update_freq == 0 and self.rank == 0:
ndim = logits[0].ndim
losses['real'] = logits[0].detach().mean(dim=list(range(1, ndim)))
losses['fake'] = logits[1].detach().mean(dim=list(range(1, ndim)))
if self.configs.train.use_amp:
self.amp_scaler_dis.step(self.optimizer_dis)
self.amp_scaler_dis.update()
else:
self.optimizer_dis.step()
# make logging
if self.rank == 0:
self.log_step_train(
losses, tt, micro_data, z0_pred, zt_noisy, z0_gt=micro_data['gt_latent'],
)
@torch.no_grad()
def log_step_train(self, losses, tt, micro_data, z0_pred, zt_noisy, z0_gt=None, phase='train'):
        '''
        Args:
            losses: a dict recording the loss information
            tt: 1-D tensor of sampled timesteps
            micro_data: the current micro-batch of data
            z0_pred: predicted clean latent
            zt_noisy: noisy latent fed to the diffusion model
            z0_gt: ground-truth clean latent
        '''
if hasattr(self.configs.train, 'timesteps'):
if len(self.configs.train.timesteps) < 3:
record_steps = sorted(self.configs.train.timesteps)
else:
record_steps = [min(self.configs.train.timesteps),
max(self.configs.train.timesteps)]
else:
max_inference_steps = self.configs.train.max_inference_steps
record_steps = [1, max_inference_steps//2, max_inference_steps]
if ((self.current_iters // self.configs.train.dis_update_freq) %
(self.configs.train.log_freq[0] // self.configs.train.dis_update_freq) == 1):
self.loss_mean = {key:torch.zeros(size=(len(record_steps),), dtype=torch.float64)
for key in losses.keys() if key not in ['real', 'fake']}
if self.configs.train.loss_coef.get('ldis', 0) > 0:
self.logit_mean = {key:torch.zeros(size=(len(record_steps),), dtype=torch.float64)
for key in ['real', 'fake']}
self.loss_count = torch.zeros(size=(len(record_steps),), dtype=torch.float64)
for jj in range(len(record_steps)):
for key, value in losses.items():
index = record_steps[jj] - 1
mask = torch.where(tt == index, torch.ones_like(tt), torch.zeros_like(tt))
assert value.shape == mask.shape
current_loss = torch.sum(value.detach() * mask)
if key in ['real', 'fake']:
self.logit_mean[key][jj] += current_loss.item()
else:
self.loss_mean[key][jj] += current_loss.item()
self.loss_count[jj] += mask.sum().item()
if ((self.current_iters // self.configs.train.dis_update_freq) %
(self.configs.train.log_freq[0] // self.configs.train.dis_update_freq) == 0):
if torch.any(self.loss_count == 0):
self.loss_count += 1e-4
for key in losses.keys():
if key in ['real', 'fake']:
self.logit_mean[key] /= self.loss_count
else:
self.loss_mean[key] /= self.loss_count
log_str = f"Train: {self.current_iters:06d}/{self.configs.train.iterations:06d}, "
valid_keys = sorted([key for key in losses.keys() if key not in ['loss', 'real', 'fake']])
for ii, key in enumerate(valid_keys):
if ii == 0:
log_str += f"{key}"
else:
log_str += f"/{key}"
if self.configs.train.loss_coef.get('ldis', 0) > 0:
log_str += "/real/fake:"
else:
log_str += ":"
for jj, current_record in enumerate(record_steps):
for ii, key in enumerate(valid_keys):
if ii == 0:
if key in ['dis', 'ldis']:
log_str += 't({:d}):{:+6.4f}'.format(
current_record,
self.loss_mean[key][jj].item(),
)
elif key in ['lpips', 'ldif']:
log_str += 't({:d}):{:4.2f}'.format(
current_record,
self.loss_mean[key][jj].item(),
)
elif key == 'llpips':
log_str += 't({:d}):{:5.3f}'.format(
current_record,
self.loss_mean[key][jj].item(),
)
else:
log_str += 't({:d}):{:.1e}'.format(
current_record,
self.loss_mean[key][jj].item(),
)
else:
if key in ['dis', 'ldis']:
log_str += f"/{self.loss_mean[key][jj].item():+6.4f}"
elif key in ['lpips', 'ldif']:
log_str += f"/{self.loss_mean[key][jj].item():4.2f}"
elif key == 'llpips':
log_str += f"/{self.loss_mean[key][jj].item():5.3f}"
else:
log_str += f"/{self.loss_mean[key][jj].item():.1e}"
if self.configs.train.loss_coef.get('ldis', 0) > 0:
log_str += f"/{self.logit_mean['real'][jj].item():+4.2f}"
log_str += f"/{self.logit_mean['fake'][jj].item():+4.2f}, "
else:
log_str += f", "
log_str += 'lr:{:.1e}'.format(self.optimizer.param_groups[0]['lr'])
self.logger.info(log_str)
self.logging_metric(self.loss_mean, tag='Loss', phase=phase, add_global_step=True)
if ((self.current_iters // self.configs.train.dis_update_freq) %
(self.configs.train.log_freq[1] // self.configs.train.dis_update_freq) == 0):
if zt_noisy is not None:
xt_pred = self.decode_first_stage(zt_noisy.detach())
self.logging_image(xt_pred, tag='xt-noisy', phase=phase, add_global_step=False)
if z0_pred is not None:
x0_pred = self.decode_first_stage(z0_pred.detach())
self.logging_image(x0_pred, tag='x0-pred', phase=phase, add_global_step=False)
if z0_gt is not None:
x0_recon = self.decode_first_stage(z0_gt.detach())
self.logging_image(x0_recon, tag='x0-recons', phase=phase, add_global_step=False)
if 'txt' in micro_data:
self.logging_text(micro_data['txt'], phase=phase)
self.logging_image(micro_data['lq'], tag='LQ', phase=phase, add_global_step=False)
self.logging_image(micro_data['gt'], tag='GT', phase=phase, add_global_step=True)
if ((self.current_iters // self.configs.train.dis_update_freq) %
(self.configs.train.save_freq // self.configs.train.dis_update_freq) == 1):
self.tic = time.time()
if ((self.current_iters // self.configs.train.dis_update_freq) %
(self.configs.train.save_freq // self.configs.train.dis_update_freq) == 0):
            self.toc = time.time()
            elapsed = self.toc - self.tic
            self.logger.info(f"Elapsed time: {elapsed:.2f}s")
self.logger.info("="*100)
@torch.no_grad()
def validation(self, phase='val'):
torch.cuda.empty_cache()
if not (self.configs.validate.use_ema and hasattr(self.configs.train, 'ema_rate')):
self.model.eval()
if self.configs.train.start_mode:
start_noise_predictor = self.ema_model if self.configs.validate.use_ema else self.model
intermediate_noise_predictor = None
else:
start_noise_predictor = self.start_model
intermediate_noise_predictor = self.ema_model if self.configs.validate.use_ema else self.model
num_iters_epoch = math.ceil(len(self.datasets[phase]) / self.configs.validate.batch)
mean_psnr = mean_lpips = 0
for jj, data in enumerate(self.dataloaders[phase]):
data = self.prepare_data(data, phase='val')
with torch.amp.autocast('cuda'):
xt_progressive, x0_progressive = self.sample(
image_lq=data['lq'],
prompt=[_positive,]*data['lq'].shape[0],
target_size=tuple(data['gt'].shape[-2:]),
start_noise_predictor=start_noise_predictor,
intermediate_noise_predictor=intermediate_noise_predictor,
)
x0 = xt_progressive[-1]
num_inference_steps = len(xt_progressive)
if 'gt' in data:
if not hasattr(self, 'psnr_metric'):
self.psnr_metric = pyiqa.create_metric(
'psnr',
test_y_channel=self.configs.train.get('val_y_channel', True),
color_space='ycbcr',
device=torch.device("cuda"),
)
if not hasattr(self, 'lpips_metric'):
self.lpips_metric = pyiqa.create_metric(
'lpips-vgg',
device=torch.device("cuda"),
as_loss=False,
)
x0_normalize = util_image.normalize_th(x0, mean=0.5, std=0.5, reverse=True)
mean_psnr += self.psnr_metric(x0_normalize, data['gt']).sum().item()
with torch.amp.autocast('cuda'), torch.no_grad():
mean_lpips += self.lpips_metric(x0_normalize, data['gt']).sum().item()
if (jj + 1) % self.configs.validate.log_freq == 0:
self.logger.info(f'Validation: {jj+1:02d}/{num_iters_epoch:02d}...')
self.logging_image(data['gt'], tag='GT', phase=phase, add_global_step=False)
xt_progressive = rearrange(torch.cat(xt_progressive, dim=1), 'b (k c) h w -> (b k) c h w', c=3)
self.logging_image(
xt_progressive,
tag='sample-progress',
phase=phase,
add_global_step=False,
nrow=num_inference_steps,
)
x0_progressive = rearrange(torch.cat(x0_progressive, dim=1), 'b (k c) h w -> (b k) c h w', c=3)
self.logging_image(
x0_progressive,
tag='x0-progress',
phase=phase,
add_global_step=False,
nrow=num_inference_steps,
)
self.logging_image(data['lq'], tag='LQ', phase=phase, add_global_step=True)
if 'gt' in data:
mean_psnr /= len(self.datasets[phase])
mean_lpips /= len(self.datasets[phase])
self.logger.info(f'Validation Metric: PSNR={mean_psnr:5.2f}, LPIPS={mean_lpips:6.4f}...')
self.logging_metric(mean_psnr, tag='PSNR', phase=phase, add_global_step=False)
self.logging_metric(mean_lpips, tag='LPIPS', phase=phase, add_global_step=True)
self.logger.info("="*100)
if not (self.configs.validate.use_ema and hasattr(self.configs.train, 'ema_rate')):
self.model.train()
torch.cuda.empty_cache()
def backward_step(self, micro_data, num_grad_accumulate):
loss_coef = self.configs.train.loss_coef
losses = {}
z0_gt = micro_data['gt_latent']
tt = torch.tensor(
random.choices(self.configs.train.timesteps, k=z0_gt.shape[0]),
dtype=torch.int64,
device=f"cuda:{self.rank}",
) - 1
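        # Timesteps are sampled uniformly (with replacement) from the configured training set,
        # shifted by -1 to index the 0-based sigma table.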
with torch.autocast(device_type="cuda", enabled=self.configs.train.use_amp):
model_pred = self.model(
micro_data['lq'], tt, sample_posterior=False, center_input_sample=True,
)
z0_pred, zt_noisy_pred, z0_lq = self.sd_forward_step(
prompt=micro_data['txt'],
latents_hq=micro_data['gt_latent'],
image_lq=micro_data['lq'],
image_hq=micro_data['gt'],
model_pred=model_pred,
timesteps=tt,
)
# diffusion loss
if loss_coef.get('ldif', 0) > 0:
if self.configs.train.loss_type == 'L2':
ldif_loss = F.mse_loss(z0_pred, z0_gt, reduction='none')
elif self.configs.train.loss_type == 'L1':
ldif_loss = F.l1_loss(z0_pred, z0_gt, reduction='none')
else:
raise TypeError(f"Unsupported Loss type for Diffusion: {self.configs.train.loss_type}")
ldif_loss = torch.mean(ldif_loss, dim=list(range(1, z0_gt.ndim)))
losses['ldif'] = ldif_loss * loss_coef['ldif']
# Gaussian constraints
if loss_coef.get('kl', 0) > 0:
losses['kl'] = model_pred.kl() * loss_coef['kl']
if loss_coef.get('pkl', 0) > 0:
losses['pkl'] = model_pred.partial_kl() * loss_coef['pkl']
if loss_coef.get('rkl', 0) > 0:
other = Box(
{'mean': z0_gt-z0_lq,
'var':torch.ones_like(z0_gt),
'logvar':torch.zeros_like(z0_gt)}
)
losses['rkl'] = model_pred.kl(other) * loss_coef['rkl']
# discriminator loss
if loss_coef.get('ldis', 0) > 0:
if self.current_iters > self.configs.train.dis_init_iterations:
logits_fake = self.discriminator(
torch.clamp(z0_pred, min=_Latent_bound['min'], max=_Latent_bound['max']),
timestep=tt,
encoder_hidden_states=self.prompt_embeds,
)
                    losses['ldis'] = self.get_loss_from_discriminator(logits_fake) * loss_coef['ldis']
else:
losses['ldis'] = torch.zeros((z0_gt.shape[0], ), dtype=torch.float32).cuda()
# perceptual loss
if loss_coef.get('llpips', 0) > 0:
losses['llpips'] = self.llpips_loss(z0_pred, z0_gt).view(-1) * loss_coef['llpips']
for key in ['ldif', 'kl', 'rkl', 'pkl', 'ldis', 'llpips']:
if loss_coef.get(key, 0) > 0:
if not 'loss' in losses:
losses['loss'] = losses[key]
else:
losses['loss'] = losses['loss'] + losses[key]
loss = losses['loss'].mean() / num_grad_accumulate
if self.amp_scaler is None:
loss.backward()
else:
self.amp_scaler.scale(loss).backward()
return losses, z0_pred, zt_noisy_pred, tt
def dis_backward_step(self, target, inputs, tt, prompt_embeds):
with torch.autocast(device_type="cuda", enabled=self.configs.train.use_amp):
logits_real = self.discriminator(target, tt, prompt_embeds)
inputs = inputs.clamp(min=_Latent_bound['min'], max=_Latent_bound['max'])
logits_fake = self.discriminator(inputs, tt, prompt_embeds)
loss = hinge_d_loss(logits_real, logits_fake).mean()
if self.amp_scaler_dis is None:
loss.backward()
else:
self.amp_scaler_dis.scale(loss).backward()
return logits_real[-1], logits_fake[-1]
    def scale_sd_input(
        self,
        x: torch.Tensor,
        sigmas: Optional[torch.Tensor] = None,
        timestep: Optional[torch.Tensor] = None,
    ):
if sigmas is None:
if not self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1):
                self.sd_pipe.scheduler = EulerDiscreteScheduler.from_pretrained(
                    self.configs.sd_pipe.params.pretrained_model_name_or_path,
                    cache_dir=self.configs.sd_pipe.params.cache_dir,
                    subfolder='scheduler',
                )
assert self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1)
sigmas = self.sd_pipe.scheduler.sigmas.flip(0).to(x.device)[timestep] # (b,)
sigmas = append_dims(sigmas, x.ndim)
if sigmas.ndim < x.ndim:
sigmas = append_dims(sigmas, x.ndim)
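        # EDM-style input preconditioning: dividing by sqrt(sigma^2 + 1) mirrors what
        # EulerDiscreteScheduler.scale_model_input applies at the sampled noise level.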
out = x / ((sigmas**2 + 1) ** 0.5)
return out
def prepare_lq_latents(
self,
image_lq: torch.Tensor,
timestep: torch.Tensor,
height: int = 512,
width: int = 512,
start_noise_predictor: torch.nn.Module = None,
):
"""
Input:
image_lq: low-quality image, torch.Tensor, range in [0, 1]
hight, width: resolution for high-quality image
"""
image_lq_up = F.interpolate(image_lq, size=(height, width), mode='bicubic')
init_latents = self.encode_first_stage(
image_lq_up, deterministic=False, center_input_sample=True,
)
if start_noise_predictor is None:
model_pred = None
else:
model_pred = start_noise_predictor(
image_lq, timestep, sample_posterior=False, center_input_sample=True,
)
# get latents
sigmas = self.sigmas_cache[timestep]
sigmas = append_dims(sigmas, init_latents.ndim)
latents = self.add_noise(init_latents, sigmas, model_pred)
return latents
def add_noise(self, latents, sigmas, model_pred=None):
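        # Forward diffusion in sigma parameterization: z_t = z_0 + sigma * eps. When a noise
        # predictor is supplied, eps comes from its predicted distribution (mean/std under the
        # 'rkl' loss, otherwise a posterior sample) instead of a standard Gaussian.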
if sigmas.ndim < latents.ndim:
sigmas = append_dims(sigmas, latents.ndim)
if model_pred is None:
noise = torch.randn_like(latents)
zt_noisy = latents + sigmas * noise
else:
if self.configs.train.loss_coef.get('rkl', 0) > 0:
mean, std = model_pred.mean, model_pred.std
zt_noisy = latents + mean + sigmas * std * torch.randn_like(latents)
else:
zt_noisy = latents + sigmas * model_pred.sample()
return zt_noisy
def retrieve_timesteps(self):
device=torch.device(f"cuda:{self.rank}")
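        # Build a descending, evenly spaced set of 0-based timestep indices starting from the
        # largest configured training timestep, then install the matching sigma schedule
        # (with a terminal sigma of 0 appended) into the Euler scheduler.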
num_inference_steps = self.configs.train.get('num_inference_steps', 5)
timesteps = np.linspace(
max(self.configs.train.timesteps), 0, num_inference_steps,
endpoint=False, dtype=np.int64,
) - 1
timesteps = torch.from_numpy(timesteps).to(device)
self.sd_pipe.scheduler.timesteps = timesteps
sigmas = self.sigmas_cache[timesteps.long()]
sigma_last = torch.tensor([0,], dtype=torch.float32).to(device=sigmas.device)
sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32)
self.sd_pipe.scheduler.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sd_pipe.scheduler._step_index = None
self.sd_pipe.scheduler._begin_index = None
return self.sd_pipe.scheduler.timesteps, num_inference_steps
class TrainerSDTurboSR(TrainerBaseSR):
def sd_forward_step(
self,
prompt: Union[str, List[str]] = None,
latents_hq: Optional[torch.Tensor] = None,
image_lq: torch.Tensor = None,
image_hq: torch.Tensor = None,
model_pred: DiagonalGaussianDistribution = None,
timesteps: List[int] = None,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image_lq (`torch.Tensor`): The low-quality image(s) for enhancement, range in [0, 1].
image_hq (`torch.Tensor`): The high-quality image(s) for enhancement, range in [0, 1].
noise_pred (`torch.Tensor`): Predicted noise by the noise prediction model
latents_hq (`torch.Tensor`, *optional*):
Pre-generated high-quality latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. If not provided, a latents tensor will be generated by sampling using vae .
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
aesthetic_score (`float`, *optional*, defaults to 6.0):
Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
simulate an aesthetic score of the generated image by influencing the negative text condition.
"""
device=torch.device(f"cuda:{self.rank}")
# Encode input prompt
prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=False,
)
self.prompt_embeds = prompt_embeds
# select the noise level, self.scheduler.sigmas, [1001,], descending
if not hasattr(self, 'sigmas_cache'):
assert self.sd_pipe.scheduler.sigmas.numel() == (self.configs.sd_pipe.num_train_steps + 1)
self.sigmas_cache = self.sd_pipe.scheduler.sigmas.flip(0)[1:].to(device) #ascending,1000
sigmas = self.sigmas_cache[timesteps] # (b,)
# Prepare input for SD
height, width = image_hq.shape[-2:]
if self.configs.train.start_mode:
image_lq_up = F.interpolate(image_lq, size=(height, width), mode='bicubic')
zt_clean = self.encode_first_stage(
image_lq_up, center_input_sample=True,
deterministic=self.configs.train.loss_coef.get('rkl', 0) > 0,
)
else:
if latents_hq is None:
zt_clean = self.encode_first_stage(
image_hq, center_input_sample=True, deterministic=False,
)
else:
zt_clean = latents_hq
sigmas = append_dims(sigmas, zt_clean.ndim)
zt_noisy = self.add_noise(zt_clean, sigmas, model_pred)
prompt_embeds = prompt_embeds.to(device)
zt_noisy_scale = self.scale_sd_input(zt_noisy, sigmas)
eps_pred = self.sd_pipe.unet(
zt_noisy_scale,
timesteps,
encoder_hidden_states=prompt_embeds,
timestep_cond=None,
cross_attention_kwargs=None,
added_cond_kwargs=None,
return_dict=False,
)[0] # eps-mode for sdxl and sdxl-refiner
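        # eps parameterization: the clean latent is recovered as z_0 = z_t - sigma * eps. With
        # noise_detach, the additive z_t term is detached so the reconstruction gradient reaches
        # the noise predictor only through the U-Net's eps prediction.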
if self.configs.train.noise_detach:
z0_pred = zt_noisy.detach() - sigmas * eps_pred
else:
z0_pred = zt_noisy - sigmas * eps_pred
return z0_pred, zt_noisy, zt_clean
@torch.no_grad()
def sample(
self,
image_lq: torch.Tensor,
prompt: Union[str, List[str]] = None,
target_size: Tuple[int, int] = (1024, 1024),
start_noise_predictor: torch.nn.Module = None,
intermediate_noise_predictor: torch.nn.Module = None,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image_lq (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
The image(s) to modify with the pipeline.
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
The required height and width of the super-resolved image.
strength (`float`, *optional*, defaults to 0.3):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of
`denoising_start` being declared as an integer, the value of `strength` will be ignored.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
"""
device=torch.device(f"cuda:{self.rank}")
batch_size = image_lq.shape[0]
# Encode input prompt
prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=1,
do_classifier_free_guidance=False,
)
timesteps, num_inference_steps = self.retrieve_timesteps()
latent_timestep = timesteps[:1].repeat(batch_size)
# Prepare latent variables
height, width = target_size
latents = self.prepare_lq_latents(image_lq, latent_timestep.long(), height, width, start_noise_predictor)
# Prepare extra step kwargs.
extra_step_kwargs = self.sd_pipe.prepare_extra_step_kwargs(None, 0.0)
prompt_embeds = prompt_embeds.to(device)
x0_progressive = []
images_progressive = []
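        # Euler sampling loop: at each step the U-Net predicts eps, z0 is recovered for logging,
        # and the optional intermediate noise predictor supplies the noise handed to
        # scheduler.step() via extra_step_kwargs for the move to the next noise level.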
for i, t in enumerate(timesteps):
latents_scaled = self.sd_pipe.scheduler.scale_model_input(latents, t)
# predict the noise residual
eps_pred = self.sd_pipe.unet(
latents_scaled,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=None,
added_cond_kwargs=None,
return_dict=False,
)[0]
z0_pred = latents - self.sigmas_cache[t.long()] * eps_pred
# compute the previous noisy sample x_t -> x_t-1
if intermediate_noise_predictor is not None and i + 1 < len(timesteps):
t_next = timesteps[i+1]
noise = intermediate_noise_predictor(image_lq, t_next, center_input_sample=True)
else:
noise = None
extra_step_kwargs['noise'] = noise
latents = self.sd_pipe.scheduler.step(eps_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
image = self.decode_first_stage(latents)
images_progressive.append(image)
x0_pred = self.decode_first_stage(z0_pred)
x0_progressive.append(x0_pred)
return images_progressive, x0_progressive
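# Re-seed numpy inside each dataloader worker; otherwise workers forked with identical RNG
# state would produce identical random augmentations and degradations.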
def my_worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
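# Standard hinge loss for the discriminator: 0.5 * (relu(1 - D(real)) + relu(1 + D(fake))),
# reduced per-sample; multi-scale discriminators return a list of logits whose losses are
# averaged.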
def hinge_d_loss(
logits_real: Union[torch.Tensor, List[torch.Tensor,]],
logits_fake: Union[torch.Tensor, List[torch.Tensor,]],
):
def _hinge_d_loss(logits_real, logits_fake):
loss_real = F.relu(1.0 - logits_real)
loss_fake = F.relu(1.0 + logits_fake)
d_loss = 0.5 * (loss_real + loss_fake)
loss = d_loss.mean(dim=list(range(1, logits_real.ndim)))
return loss
    if not isinstance(logits_real, (list, tuple)):
loss = _hinge_d_loss(logits_real, logits_fake)
else:
loss = _hinge_d_loss(logits_real[0], logits_fake[0])
for xx, yy in zip(logits_real[1:], logits_fake[1:]):
loss += _hinge_d_loss(xx, yy)
loss /= len(logits_real)
return loss
def get_torch_dtype(torch_dtype: str):
if torch_dtype == 'torch.float16':
return torch.float16
elif torch_dtype == 'torch.bfloat16':
return torch.bfloat16
elif torch_dtype == 'torch.float32':
return torch.float32
else:
raise ValueError(f'Unexpected torch dtype:{torch_dtype}')