# AdcSR / evaluate.py
# Source: Hugging Face file viewer (uploader: Guaishou74851, commit 34b61ae
# "Upload 66 files", verified; file size 2.15 kB).
import torch, os, glob, pyiqa
from argparse import ArgumentParser
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
# Command-line interface: the directory of reference (HR) images and the
# directory of super-resolved (SR) results that are scored against them.
parser = ArgumentParser()
parser.add_argument("--HR_dir", type=str, default="testset/RealSR/HR")
parser.add_argument("--SR_dir", type=str, default="result/RealSR")
args = parser.parse_args()

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metrics called below with an (SR, HR) pair. PSNR and SSIM are requested on
# the Y channel of YCbCr through `test_y_channel` / `color_space`.
psnr = pyiqa.create_metric("psnr", test_y_channel=True, color_space="ycbcr", device=device)
ssim = pyiqa.create_metric("ssim", test_y_channel=True, color_space="ycbcr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)
dists = pyiqa.create_metric("dists", device=device)
# FID is called once on the two directories rather than per image pair.
fid = pyiqa.create_metric("fid", device=device)
# Metrics called below with the SR image only.
niqe = pyiqa.create_metric("niqe", device=device)
maniqa = pyiqa.create_metric("maniqa-pipal", device=device)
clipiqa = pyiqa.create_metric("clipiqa", device=device)
musiq = pyiqa.create_metric("musiq", device=device)

# SR and HR files are paired purely by their position in the sorted listings,
# so both directories must contain the same number of entries.
test_SR_paths = sorted(glob.glob(os.path.join(args.SR_dir, "*")))
test_HR_paths = sorted(glob.glob(os.path.join(args.HR_dir, "*")))
if len(test_SR_paths) != len(test_HR_paths):
    # zip() would silently drop the unmatched tail and average over an
    # arbitrary (and possibly mis-paired) subset; refuse to do that.
    raise SystemExit(
        f"File count mismatch: {len(test_SR_paths)} in {args.SR_dir!r} "
        f"vs {len(test_HR_paths)} in {args.HR_dir!r}"
    )
if not test_SR_paths:
    # Averaging empty lists would only yield NaN values and a NumPy warning.
    raise SystemExit(f"No images found in {args.SR_dir!r} / {args.HR_dir!r}")

# Built once: the transform is stateless, so there is no need to re-create it
# for every image.
to_tensor = transforms.ToTensor()


def _load_image(path):
    """Read the image at `path` as an RGB tensor with a leading batch axis on `device`."""
    # The context manager closes the underlying file handle once the pixel
    # data has been converted.
    with Image.open(path) as img:
        return to_tensor(img.convert("RGB")).unsqueeze(0).to(device)


# Per-image scores; each list is replaced by its mean after the loop. The key
# order here is also the order in which results are printed.
metrics = {"psnr": [], "ssim": [], "lpips": [], "dists": [], "niqe": [], "maniqa": [], "musiq": [], "clipiqa": []}
for SR_path, HR_path in tqdm(zip(test_SR_paths, test_HR_paths), total=len(test_SR_paths)):
    SR = _load_image(SR_path)
    HR = _load_image(HR_path)
    metrics["psnr"].append(psnr(SR, HR).item())
    metrics["ssim"].append(ssim(SR, HR).item())
    metrics["lpips"].append(lpips(SR, HR).item())
    metrics["dists"].append(dists(SR, HR).item())
    metrics["niqe"].append(niqe(SR).item())
    metrics["maniqa"].append(maniqa(SR).item())
    metrics["clipiqa"].append(clipiqa(SR).item())
    metrics["musiq"].append(musiq(SR).item())

# Collapse every per-image list into its mean, then append the directory-level FID.
metrics = {name: np.mean(values) for name, values in metrics.items()}
metrics["fid"] = fid(args.SR_dir, args.HR_dir)

# Significant digits used when printing; metrics not listed here use 4.
precision = {"niqe": ".3g", "fid": ".5g"}
for k, v in metrics.items():
    print(k, format(v, precision.get(k, ".4g")))