import argparse
import gc
import math
import os
from pathlib import Path

import clip
import diffusers
import lpips
import numpy as np
import pyiqa
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import tqdm
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

import utils.misc as misc
from de_net import DEResNet
from my_utils.testing_utils import parse_args_paired_testing, PlainDataset, lr_proc
from my_utils.utils import instantiate_from_config
from s3diff_tile import S3Diff
from utils import util_image
from utils.util_image import ImageSpliterTh
from utils.wavelet_color import wavelet_color_fix, adain_color_fix

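# Score the images in `in_path` with no-reference IQA metrics (CLIPIQA, MUSIQ,
# NIQE, MANIQA); when `ref_path` is given, also compute full-reference metrics
# (PSNR, LPIPS, DISTS, SSIM) per image and FID over the two folders.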
def evaluate(in_path, ref_path, ntest):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    metric_dict = {}
    metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
    metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
    metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
    metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
    metric_paired_dict = {}

    in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
    assert in_path.is_dir()

    ref_path_list = None
    if ref_path is not None:
        ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path
        ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")])
        if ntest is not None:
            ref_path_list = ref_path_list[:ntest]

        # Full-reference metrics are only needed when ground truth exists.
        metric_paired_dict["psnr"] = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
        metric_paired_dict["lpips"] = pyiqa.create_metric('lpips').to(device)
        metric_paired_dict["dists"] = pyiqa.create_metric('dists').to(device)
        metric_paired_dict["ssim"] = pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr').to(device)

    lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
    if ntest is not None:
        lr_path_list = lr_path_list[:ntest]

    print(f'Found {len(lr_path_list)} images in {in_path}')

    result = {}
    for i in tqdm.tqdm(range(len(lr_path_list))):
        _in_path = lr_path_list[i]
        _ref_path = ref_path_list[i] if ref_path_list is not None else None

        im_in = util_image.imread(_in_path, chn='rgb', dtype='float32')
        im_in_tensor = util_image.img2tensor(im_in).to(device)
        with torch.no_grad():
            for key, metric in metric_dict.items():
                with torch.cuda.amp.autocast():
                    result[key] = result.get(key, 0) + metric(im_in_tensor).item()

            if ref_path is not None:
                im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32')
                im_ref_tensor = util_image.img2tensor(im_ref).to(device)
                for key, metric in metric_paired_dict.items():
                    result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item()

    if ref_path is not None:
        fid_metric = pyiqa.create_metric('fid')
        result['fid'] = fid_metric(in_path, ref_path)

    print_results = []
    for key, res in result.items():
        if key == 'fid':
            print(f"{key}: {res:.2f}")
            print_results.append(f"{key}: {res:.2f}")
        else:
            # Per-image metrics were accumulated, so report the mean.
            print(f"{key}: {res / len(lr_path_list):.5f}")
            print_results.append(f"{key}: {res / len(lr_path_list):.5f}")
    return print_results


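# End-to-end test pipeline: restore every LR image from the validation config
# with S3Diff, save the outputs as PNGs, then score them with evaluate().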
def main(args):
    config = OmegaConf.load(args.base_config)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
    )

    # Keep library logging quiet on all but the local main process.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

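    # Build the two networks: the S3Diff SR generator (a Stable Diffusion model
    # adapted with LoRA ranks from the CLI args) and the degradation-estimation
    # ResNet whose scores condition the generator.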
    net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae,
                    sd_path=args.sd_path, pretrained_path=args.pretrained_path, args=args)
    net_sr.set_eval()

    net_de = DEResNet(num_in_ch=3, num_degradation=2)
    net_de.load_model(args.de_net_path)
    net_de = net_de.to(accelerator.device)
    net_de.eval()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_sr.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available; install it with `pip install xformers`.")

    if args.gradient_checkpointing:
        net_sr.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    dataset_val = PlainDataset(config.validation)
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

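    # Let Accelerate place the models, then cast them to the precision implied
    # by the --mixed_precision flag for inference.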
    net_sr, net_de = accelerator.prepare(net_sr, net_de)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    net_sr.to(accelerator.device, dtype=weight_dtype)
    net_de.to(accelerator.device, dtype=weight_dtype)

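    # Inference loop: bicubically upsample each LR image by the scale factor,
    # run the SR network, undo the padding, optionally color-correct, save PNG.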
    for step, batch_val in enumerate(dl_val):
        lr_path = batch_val['lr_path'][0]
        (path, name) = os.path.split(lr_path)

        im_lr = batch_val['lr'].to(accelerator.device)
        im_lr = im_lr.to(memory_format=torch.contiguous_format).float()

        # Pre-upsample to the target resolution (scale factor config.sf).
        ori_h, ori_w = im_lr.shape[2:]
        im_lr_resize = F.interpolate(
            im_lr,
            size=(ori_h * config.sf, ori_w * config.sf),
            mode='bicubic',
        )

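        # Map to [-1, 1], the diffusion model's input range, and reflect-pad
        # both sides to multiples of 64 so the VAE/UNet downsampling path
        # divides evenly.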
        im_lr_resize = im_lr_resize.contiguous()
        im_lr_resize_norm = im_lr_resize * 2 - 1.0
        im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
        resize_h, resize_w = im_lr_resize_norm.shape[2:]

        pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
        pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
        im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')

        B = im_lr_resize.size(0)
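        # Estimate degradation scores from the raw LR input and run the SR
        # model under no_grad; the prompts are broadcast to the batch size.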
        with torch.no_grad():
            deg_score = net_de(im_lr)
            pos_tag_prompt = [args.pos_prompt for _ in range(B)]
            neg_tag_prompt = [args.neg_prompt for _ in range(B)]
            x_tgt_pred = accelerator.unwrap_model(net_sr)(
                im_lr_resize_norm, deg_score,
                pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
            # Crop away the padding, then map back from [-1, 1] to [0, 1];
            # clamp before PIL conversion to avoid uint8 wrap-around.
            x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
            out_img = (x_tgt_pred * 0.5 + 0.5).clamp(0, 1).cpu().detach()

        output_pil = transforms.ToPILImage()(out_img[0])

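        # Optional color correction against the bicubic-upsampled input:
        # 'wavelet' matches low-frequency color, 'adain' matches channel
        # statistics, 'nofix' keeps the raw network output.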
        if args.align_method != 'nofix':
            im_lr_resize_pil = transforms.ToPILImage()(im_lr_resize[0].cpu().detach().clamp(0, 1))
            if args.align_method == 'wavelet':
                output_pil = wavelet_color_fix(output_pil, im_lr_resize_pil)
            elif args.align_method == 'adain':
                output_pil = adain_color_fix(output_pil, im_lr_resize_pil)

        fname, ext = os.path.splitext(name)
        outf = os.path.join(args.output_dir, fname + '.png')
        output_pil.save(outf)

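    # Score the restored images (comparing against args.ref_path when given)
    # and persist the metric summary next to the outputs.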
    print_results = evaluate(args.output_dir, args.ref_path, None)
    out_t = os.path.join(args.output_dir, 'results.txt')
    with open(out_t, 'w', encoding='utf-8') as f:
        for item in print_results:
            f.write(f"{item}\n")

    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    args = parse_args_paired_testing()
    main(args)