"""Evaluate super-resolution outputs against ground-truth images with pyiqa.

Per-image metrics (PSNR, SSIM, LPIPS, DISTS, NIQE, MANIQA, MUSIQ, CLIPIQA)
are averaged over all SR/HR pairs; FID is computed once between the two
directories. One ``name value`` line is printed per metric.

Command-line arguments:
    --HR_dir  directory holding the ground-truth (high-resolution) images.
    --SR_dir  directory holding the super-resolved images to evaluate.

SR and HR files are paired by their position in the sorted directory
listings, so both directories must contain the same number of files with
names that sort in the same order.
"""
import glob
import os
from argparse import ArgumentParser

import numpy as np
import pyiqa
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument("--HR_dir", type=str, default="testset/RealSR/HR")
parser.add_argument("--SR_dir", type=str, default="result/RealSR")
args = parser.parse_args()

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Full-reference metrics: called as metric(SR, HR).
psnr = pyiqa.create_metric("psnr", test_y_channel=True, color_space="ycbcr", device=device)
ssim = pyiqa.create_metric("ssim", test_y_channel=True, color_space="ycbcr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)
dists = pyiqa.create_metric("dists", device=device)
# FID is evaluated on the two directories as a whole, not per image.
fid = pyiqa.create_metric("fid", device=device)
# No-reference metrics: called as metric(SR).
niqe = pyiqa.create_metric("niqe", device=device)
maniqa = pyiqa.create_metric("maniqa-pipal", device=device)
clipiqa = pyiqa.create_metric("clipiqa", device=device)
musiq = pyiqa.create_metric("musiq", device=device)

full_reference = {"psnr": psnr, "ssim": ssim, "lpips": lpips, "dists": dists}
no_reference = {"niqe": niqe, "maniqa": maniqa, "musiq": musiq, "clipiqa": clipiqa}

test_SR_paths = sorted(glob.glob(os.path.join(args.SR_dir, "*")))
test_HR_paths = sorted(glob.glob(os.path.join(args.HR_dir, "*")))

# zip() would silently drop unpaired files and average over a truncated
# (and possibly mis-paired) set, so refuse to continue on a mismatch.
if not test_SR_paths:
    parser.error(f"no files found in SR_dir: {args.SR_dir}")
if len(test_SR_paths) != len(test_HR_paths):
    parser.error(
        f"SR_dir has {len(test_SR_paths)} files but HR_dir has "
        f"{len(test_HR_paths)}; the directories must contain matching images"
    )

# Built once instead of on every loop iteration.
to_tensor = transforms.ToTensor()


def _load_image(path):
    """Read the image at *path* as an RGB tensor with a batch axis, on ``device``."""
    # The context manager closes the underlying file as soon as the pixel
    # data has been converted.
    with Image.open(path) as img:
        return to_tensor(img.convert("RGB")).to(device).unsqueeze(0)


# Insertion order here determines the order in which results are printed.
metrics = {name: [] for name in (*full_reference, *no_reference)}

# total= lets tqdm show a real progress bar; a bare zip() has no length.
for SR_path, HR_path in tqdm(zip(test_SR_paths, test_HR_paths), total=len(test_SR_paths)):
    SR = _load_image(SR_path)
    HR = _load_image(HR_path)
    for name, metric in full_reference.items():
        metrics[name].append(metric(SR, HR).item())
    for name, metric in no_reference.items():
        metrics[name].append(metric(SR).item())

# Collapse each per-image list into its mean.
metrics = {name: np.mean(values) for name, values in metrics.items()}
metrics["fid"] = fid(args.SR_dir, args.HR_dir)

# Significant digits per metric; anything not listed uses 4.
precision = {"niqe": 3, "fid": 5}
for k, v in metrics.items():
    print(k, f"{v:.{precision.get(k, 4)}g}")