Spaces:
Runtime error
Runtime error
File size: 7,053 Bytes
d380b77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import glob
import os
import cv2
import PIL.Image as Image
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
def load_image(fname, mode='RGB', return_orig=False):
    """Load an image file into a float32 array divided by 255.

    Args:
        fname: Path of the image, passed to ``PIL.Image.open``.
        mode: PIL mode the image is converted to before it becomes an array.
        return_orig: When true, additionally return the array before the
            float conversion and the division by 255.

    Returns:
        The float32 array. Three-dimensional arrays are transposed with
        ``(2, 0, 1)`` (channel axis moved to the front); other arrays keep
        their layout. With ``return_orig`` the result is the tuple
        ``(scaled, original)``.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been copied into the array instead of waiting for garbage collection.
    with Image.open(fname) as pil_img:
        img = np.array(pil_img.convert(mode))

    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))
    out_img = img.astype('float32') / 255

    if return_orig:
        return out_img, img
    return out_img
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    ``x`` is returned unchanged when it already is a multiple of ``mod``.
    """
    already_aligned = x % mod == 0
    return x if already_aligned else (x // mod + 1) * mod
def pad_img_to_modulo(img, mod):
    """Pad a three-dimensional array so its last two axes are multiples of ``mod``.

    Padding is appended at the end of the second and third axes only, using
    NumPy's ``'symmetric'`` mode; the first axis is left untouched.
    """
    # Unpacking three values also rejects arrays that are not three-dimensional.
    _, height, width = img.shape
    extra_rows = ceil_modulo(height, mod) - height
    extra_cols = ceil_modulo(width, mod) - width
    pad_width = ((0, 0), (0, extra_rows), (0, extra_cols))
    return np.pad(img, pad_width, mode='symmetric')
def pad_tensor_to_modulo(img, mod):
    """Pad a four-dimensional tensor so its last two axes are multiples of ``mod``.

    Padding is appended on the right and bottom sides only, using
    ``torch.nn.functional.pad`` with ``mode='reflect'``.
    """
    # Unpacking four values also rejects tensors that are not four-dimensional.
    _, _, height, width = img.shape
    pad_bottom = ceil_modulo(height, mod) - height
    pad_right = ceil_modulo(width, mod) - width
    # F.pad orders the amounts as (left, right, top, bottom) for the last two axes.
    return F.pad(img, pad=(0, pad_right, 0, pad_bottom), mode='reflect')
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize an image whose first axis is the channel axis by ``factor``.

    The array is rearranged for ``cv2.resize`` (the first axis is dropped
    when it has length 1, otherwise moved to the end), resized with the same
    factor along both axes, and rearranged back so the first axis is the
    channel axis again.
    """
    if img.shape[0] == 1:
        cv_input = img[0]
    else:
        cv_input = np.transpose(img, (1, 2, 0))

    resized = cv2.resize(cv_input, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    # A two-dimensional result gets its leading axis restored.
    if resized.ndim == 2:
        return resized[None, ...]
    return np.transpose(resized, (2, 0, 1))
class InpaintingDataset(Dataset):
    """Dataset of image/mask pairs discovered recursively under ``datadir``.

    Every file matching ``*mask*.png`` is treated as a mask. The matching image
    path is the mask path up to its last ``_mask`` occurrence plus
    ``img_suffix``.
    """

    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        self.datadir = datadir
        mask_pattern = os.path.join(self.datadir, '**', '*mask*.png')
        self.mask_filenames = sorted(glob.glob(mask_pattern, recursive=True))
        self.img_filenames = [
            mask_path.rsplit('_mask', 1)[0] + img_suffix
            for mask_path in self.mask_filenames
        ]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        """Return ``dict(image=..., mask=...)`` for sample ``i``."""
        image = load_image(self.img_filenames[i], mode='RGB')
        # The mask is loaded as a two-dimensional array; add a leading axis.
        mask = load_image(self.mask_filenames[i], mode='L')[None, ...]

        if self.scale_factor is not None:
            image = scale_image(image, self.scale_factor)
            mask = scale_image(mask, self.scale_factor, interpolation=cv2.INTER_NEAREST)

        modulo = self.pad_out_to_modulo
        if modulo is not None and modulo > 1:
            image = pad_img_to_modulo(image, modulo)
            mask = pad_img_to_modulo(mask, modulo)

        return dict(image=image, mask=mask)
class OurInpaintingDataset(Dataset):
    """Dataset of image/mask pairs stored in ``datadir/img`` and ``datadir/mask``.

    Masks are every ``*mask*.png`` found recursively under ``datadir/mask``.
    The image for a mask is ``datadir/img/<stem>.png`` where ``<stem>`` is the
    mask's file name cut at its last ``-`` and then at its last ``_``.

    Args:
        datadir: Root directory containing the ``img`` and ``mask`` folders.
        img_suffix: Accepted for signature parity with ``InpaintingDataset``
            but not used; images are always looked up with a ``.png`` suffix.
        pad_out_to_modulo: When greater than 1, image and mask are padded so
            their last two axes are multiples of this value.
        scale_factor: When not ``None``, image and mask are resized by it.
    """

    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        self.datadir = datadir
        mask_pattern = os.path.join(self.datadir, 'mask', '**', '*mask*.png')
        self.mask_filenames = sorted(glob.glob(mask_pattern, recursive=True))
        self.img_filenames = [
            os.path.join(self.datadir, 'img', self._image_stem(fname) + '.png')
            for fname in self.mask_filenames
        ]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    @staticmethod
    def _image_stem(mask_path):
        """Derive the image stem from a mask path."""
        # Split the file name only: a '-' or '_' in a parent directory must
        # not influence which image a mask is paired with.
        mask_name = os.path.basename(mask_path)
        return mask_name.rsplit('-', 1)[0].rsplit('_', 1)[0]

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        """Return ``dict(image=..., mask=...)`` for sample ``i``."""
        result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
                      mask=load_image(self.mask_filenames[i], mode='L')[None, ...])

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            # Nearest-neighbour keeps the mask values unblended, matching the
            # other datasets in this file.
            result['mask'] = scale_image(result['mask'], self.scale_factor,
                                         interpolation=cv2.INTER_NEAREST)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

        return result
class PrecomputedInpaintingResultsDataset(InpaintingDataset):
    """``InpaintingDataset`` that also loads an inpainting result per sample.

    The result for a mask is looked up in ``predictdir`` at the mask's path
    relative to ``datadir``, with the extension replaced by
    ``inpainted_suffix``.
    """

    def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
        super().__init__(datadir, **kwargs)
        # Length of the "datadir/" prefix that is cut off every mask path.
        prefix = datadir if datadir.endswith('/') else datadir + '/'
        prefix_len = len(prefix)
        self.predictdir = predictdir
        self.pred_filenames = []
        for mask_path in self.mask_filenames:
            relative_stem = os.path.splitext(mask_path[prefix_len:])[0]
            self.pred_filenames.append(os.path.join(predictdir, relative_stem + inpainted_suffix))

    def __getitem__(self, i):
        """Return the parent sample extended with an ``'inpainted'`` entry."""
        sample = super().__getitem__(i)
        inpainted = load_image(self.pred_filenames[i])
        modulo = self.pad_out_to_modulo
        # Only padding is applied here; scale_factor is not used on the result.
        if modulo is not None and modulo > 1:
            inpainted = pad_img_to_modulo(inpainted, modulo)
        sample['inpainted'] = inpainted
        return sample
class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
    """``OurInpaintingDataset`` that also loads an inpainting result per sample.

    The result for a mask is ``predictdir/<mask stem>_inpainted.<suffix>``,
    where ``<mask stem>`` is the mask's file name without directory and
    extension.

    Args:
        datadir: Root directory passed on to ``OurInpaintingDataset``.
        predictdir: Directory holding the inpainting results (flat layout).
        inpainted_suffix: File extension of the results, without the dot.
        **kwargs: Forwarded to ``OurInpaintingDataset``.
    """

    def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
        super().__init__(datadir, **kwargs)
        self.predictdir = predictdir
        self.pred_filenames = [
            os.path.join(
                predictdir,
                os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}',
            )
            for fname in self.mask_filenames
        ]

    def __getitem__(self, i):
        """Return the parent sample extended with an ``'inpainted'`` entry."""
        result = super().__getitem__(i)
        # No ``file_loader`` attribute is defined in this module, so calling it
        # unconditionally raised AttributeError. Honour one if a subclass or
        # caller has set it, otherwise fall back to the module's load_image.
        loader = getattr(self, 'file_loader', load_image)
        result['inpainted'] = loader(self.pred_filenames[i])
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
        return result
class InpaintingEvalOnlineDataset(Dataset):
    """Dataset that pairs every image under ``indir`` with a generated mask.

    Images are every file ending in ``img_suffix`` found recursively under
    ``indir``. The mask for a sample is produced by calling
    ``mask_generator(img, raw_image=raw_image)``. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
        self.indir = indir
        self.mask_generator = mask_generator
        image_pattern = os.path.join(self.indir, '**', f'*{img_suffix}')
        self.img_filenames = sorted(glob.glob(image_pattern, recursive=True))
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, i):
        """Return ``dict(image=..., mask=...)`` for sample ``i``."""
        image, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
        mask = self.mask_generator(image, raw_image=raw_image)

        if self.scale_factor is not None:
            image = scale_image(image, self.scale_factor)
            mask = scale_image(mask, self.scale_factor, interpolation=cv2.INTER_NEAREST)

        modulo = self.pad_out_to_modulo
        if modulo is not None and modulo > 1:
            image = pad_img_to_modulo(image, modulo)
            mask = pad_img_to_modulo(mask, modulo)

        return dict(image=image, mask=mask)