Spaces:
Configuration error
Configuration error
import os | |
import torch | |
from cleanfid import fid as FID | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from torchmetrics.image import StructuralSimilarityIndexMeasure | |
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity | |
from torchvision import transforms | |
from tqdm import tqdm | |
from utils import scan_files_in_dir | |
from prettytable import PrettyTable | |
class EvalDataset(Dataset): | |
def __init__(self, gt_folder, pred_folder, height=1024): | |
self.gt_folder = gt_folder | |
self.pred_folder = pred_folder | |
self.height = height | |
self.data = self.prepare_data() | |
self.to_tensor = transforms.ToTensor() | |
def extract_id_from_filename(self, filename): | |
# find first number in filename | |
start_i = None | |
for i, c in enumerate(filename): | |
if c.isdigit(): | |
start_i = i | |
break | |
if start_i is None: | |
assert False, f"Cannot find number in filename {filename}" | |
return filename[start_i:start_i+8] | |
def prepare_data(self): | |
gt_files = scan_files_in_dir(self.gt_folder, postfix={'.jpg', '.png'}) | |
gt_dict = {self.extract_id_from_filename(file.name): file for file in gt_files} | |
pred_files = scan_files_in_dir(self.pred_folder, postfix={'.jpg', '.png'}) | |
tuples = [] | |
for pred_file in pred_files: | |
pred_id = self.extract_id_from_filename(pred_file.name) | |
if pred_id not in gt_dict: | |
print(f"Cannot find gt file for {pred_file}") | |
else: | |
tuples.append((gt_dict[pred_id].path, pred_file.path)) | |
return tuples | |
def resize(self, img): | |
w, h = img.size | |
new_w = int(w * self.height / h) | |
return img.resize((new_w, self.height), Image.LANCZOS) | |
def __len__(self): | |
return len(self.data) | |
def __getitem__(self, idx): | |
gt_path, pred_path = self.data[idx] | |
gt, pred = self.resize(Image.open(gt_path)), self.resize(Image.open(pred_path)) | |
if gt.height != self.height: | |
gt = self.resize(gt) | |
if pred.height != self.height: | |
pred = self.resize(pred) | |
gt = self.to_tensor(gt) | |
pred = self.to_tensor(pred) | |
return gt, pred | |
def copy_resize_gt(gt_folder, height): | |
new_folder = f"{gt_folder}_{height}" | |
if not os.path.exists(new_folder): | |
os.makedirs(new_folder, exist_ok=True) | |
for file in tqdm(os.listdir(gt_folder)): | |
if os.path.exists(os.path.join(new_folder, file)): | |
continue | |
img = Image.open(os.path.join(gt_folder, file)) | |
w, h = img.size | |
new_w = int(w * height / h) | |
img = img.resize((new_w, height), Image.LANCZOS) | |
img.save(os.path.join(new_folder, file)) | |
return new_folder | |
def ssim(dataloader): | |
ssim_score = 0 | |
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to("cuda") | |
for gt, pred in tqdm(dataloader, desc="Calculating SSIM"): | |
batch_size = gt.size(0) | |
gt, pred = gt.to("cuda"), pred.to("cuda") | |
ssim_score += ssim(pred, gt) * batch_size | |
return ssim_score / len(dataloader.dataset) | |
def lpips(dataloader): | |
lpips_score = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to("cuda") | |
score = 0 | |
for gt, pred in tqdm(dataloader, desc="Calculating LPIPS"): | |
batch_size = gt.size(0) | |
pred = pred.to("cuda") | |
gt = gt.to("cuda") | |
# LPIPS needs the images to be in the [-1, 1] range. | |
gt = (gt * 2) - 1 | |
pred = (pred * 2) - 1 | |
score += lpips_score(gt, pred) * batch_size | |
return score / len(dataloader.dataset) | |
def eval(args): | |
# Check gt_folder has images with target height, resize if not | |
pred_sample = os.listdir(args.pred_folder)[0] | |
gt_sample = os.listdir(args.gt_folder)[0] | |
img = Image.open(os.path.join(args.pred_folder, pred_sample)) | |
gt_img = Image.open(os.path.join(args.gt_folder, gt_sample)) | |
if img.height != gt_img.height: | |
title = "--"*30 + "Resizing GT Images to height {img.height}" + "--"*30 | |
print(title) | |
args.gt_folder = copy_resize_gt(args.gt_folder, img.height) | |
print("-"*len(title)) | |
# Form dataset | |
dataset = EvalDataset(args.gt_folder, args.pred_folder, img.height) | |
dataloader = torch.utils.data.DataLoader( | |
dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, drop_last=False | |
) | |
# Calculate Metrics | |
header = [] | |
row = [] | |
header = ["FID", "KID"] | |
fid_ = FID.compute_fid(args.gt_folder, args.pred_folder) | |
kid_ = FID.compute_kid(args.gt_folder, args.pred_folder) * 1000 | |
row = [fid_, kid_] | |
if args.paired: | |
header += ["SSIM", "LPIPS"] | |
ssim_ = ssim(dataloader).item() | |
lpips_ = lpips(dataloader).item() | |
row += [ssim_, lpips_] | |
# Print Results | |
print("GT Folder : ", args.gt_folder) | |
print("Pred Folder: ", args.pred_folder) | |
table = PrettyTable() | |
table.field_names = header | |
table.add_row(row) | |
print(table) | |
if __name__ == "__main__": | |
import argparse | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--gt_folder", type=str, required=True) | |
parser.add_argument("--pred_folder", type=str, required=True) | |
parser.add_argument("--paired", action="store_true") | |
parser.add_argument("--batch_size", type=int, default=16) | |
parser.add_argument("--num_workers", type=int, default=4) | |
args = parser.parse_args() | |
eval(args) |