import os

import cv2
import numpy as np
import torch
from skimage import io
from skimage.transform import resize
from torch.utils.data import Dataset

from saicinpainting.evaluation.evaluator import InpaintingEvaluator
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore


def _to_rgb(image):
    """Return ``image`` as an ``(H, W, 3)`` array.

    Grayscale images (``(H, W)`` or ``(H, W, 1)``) and gray+alpha images
    (``(H, W, 2)``) are replicated to three channels; an alpha channel on an
    RGBA image (``(H, W, 4)``) is dropped. Any other layout is returned as-is.
    """
    if image.ndim == 2:
        image = image[..., None]
    channels = image.shape[-1]
    if channels in (1, 2):  # gray / gray + alpha -> replicate the gray plane
        image = np.repeat(image[..., :1], 3, axis=-1)
    elif channels == 4:  # RGBA -> RGB
        image = image[..., :3]
    return image


class SimpleImageDataset(Dataset):
    """Load every image file in ``root_dir``, resized to ``image_size``.

    Args:
        root_dir: directory containing the images; sub-directories and hidden
            files are ignored.
        image_size: ``(height, width)`` every image is resized to.

    Items are float32 tensors of shape ``(3, height, width)``. Values come
    straight from ``skimage.transform.resize`` (which rescales integer images
    to floats; presumably ``[0, 1]`` for uint8 inputs).
    """

    def __init__(self, root_dir, image_size=(400, 600)):
        self.root_dir = root_dir
        # Skip directories and dotfiles (e.g. ``.DS_Store``): io.imread cannot
        # decode them and would fail in the middle of an evaluation run.
        self.files = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(root_dir, name))
        )
        self.image_size = image_size

    def __getitem__(self, index):
        img_name = os.path.join(self.root_dir, self.files[index])
        image = _to_rgb(io.imread(img_name))
        # A 2-element output shape keeps the trailing channel dimension intact.
        image = resize(image, self.image_size, anti_aliasing=True)
        # resize yields float64 HWC; convert explicitly to float32 CHW.
        return torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1)

    def __len__(self):
        return len(self.files)


def create_rectangle_mask(height, width):
    """Return a ``(height, width)`` float mask of ones with a centred zero rectangle.

    The zero rectangle spans columns ``width // 4`` .. ``width - width // 4 - 1``
    and rows ``height // 4`` .. ``height - height // 4 - 1`` (inclusive),
    i.e. roughly the central half of the image in each direction.
    """
    mask = np.ones((height, width))
    up_left_corner = width // 4, height // 4  # cv2 points are (x, y)
    down_right_corner = (width - up_left_corner[0] - 1, height - up_left_corner[1] - 1)
    cv2.rectangle(mask, up_left_corner, down_right_corner, (0, 0, 0), thickness=cv2.FILLED)
    return mask


class Model:
    """Trivial baseline inpainter.

    Pixels where the mask is 0 are replaced by the per-channel mean of the
    pixels where the mask is 1; pixels with mask 1 are kept unchanged.

    Call with ``img_batch`` of shape ``(B, C, H, W)`` and ``mask_batch`` of
    shape ``(B, H, W)`` (the indexing below fixes these ranks). Returns a
    ``(B, C, H, W)`` tensor.
    """

    def __call__(self, img_batch, mask_batch):
        known = mask_batch[:, None, :, :]
        # Clamp so an all-zero mask gives a mean of 0 instead of 0/0 = NaN.
        known_count = mask_batch.sum(dim=(1, 2)).clamp(min=1e-8)[:, None]
        mean = (img_batch * known).sum(dim=(2, 3)) / known_count
        inpainted = mean[:, :, None, None] * (1 - known) + img_batch * known
        return inpainted


class SimpleImageSquareMaskDataset(Dataset):
    """Wrap an image dataset, pairing each image with a fixed centred mask.

    Each item is ``dict(image=(C, H, W), mask=(H, W), inpainted=(C, H, W))``,
    where ``inpainted`` is the output of :class:`Model` for that image.
    The wrapped dataset must expose ``image_size`` as ``(height, width)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
        self.model = Model()

    def __getitem__(self, index):
        img = self.dataset[index]
        mask = self.mask.clone()
        # Model works on batches: add a batch dim, then strip it again so that
        # 'inpainted' has the same shape as 'image' and collates to (B, C, H, W)
        # rather than (B, 1, C, H, W).
        inpainted = self.model(img[None, ...], mask[None, ...])[0]
        return dict(image=img, mask=mask, inpainted=inpainted)

    def __len__(self):
        return len(self.dataset)


# NOTE(review): in these masks 1 marks *known* pixels and 0 the hole. Confirm
# this matches the convention InpaintingEvaluator/area_grouping expect (many
# inpainting pipelines use 1 = hole), otherwise area bins will be inverted.
dataset = SimpleImageDataset('imgs')
mask_dataset = SimpleImageSquareMaskDataset(dataset)
model = Model()

metrics = {
    'ssim': SSIMScore(),
    'lpips': LPIPSScore(),
    'fid': FIDScore()
}

evaluator = InpaintingEvaluator(
    mask_dataset,
    scores=metrics,
    batch_size=3,
    area_grouping=True
)

results = evaluator.evaluate(model)
print(results)