Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import logging | |
import numpy as np | |
from glob import glob | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
from torchvision.transforms import functional as TF | |
from torchvision.transforms.functional import to_tensor | |
from data.image_transformations import get_transforms | |
class MIT5KDataset(Dataset):
    """Dataset of paired input/target images selected by a list of image IDs.

    Every file in ``input_path`` whose ID (the part of its basename before the
    first ``'-'``) appears in the ID file is paired with the file of the same
    basename in ``target_path``.

    Each item is a dict with keys ``'input_image'``, ``'target_image'`` and
    ``'name'`` (the image ID).
    """

    def __init__(self, input_path, target_path, img_ids_filepath, transform=None):
        """Index the image pairs and build the optional transform pipeline.

        Args:
            input_path: Directory containing the input images.
            target_path: Directory containing the target images; file names
                must match those in ``input_path``.
            img_ids_filepath: Text file with one image ID per line.
            transform: Transform specification forwarded to ``get_transforms``;
                ``None`` disables transforms.

        Raises:
            FileNotFoundError: If a selected input image has no target file.
        """
        self.input_path = input_path
        self.target_path = target_path
        self.transform = transform
        self.img_ids = self._read_img_ids(img_ids_filepath)
        self.data = self._create_data_list()
        if transform is not None:
            self.image_transforms = get_transforms(transform)
        else:
            self.image_transforms = None

    def _read_img_ids(self, img_ids_filepath):
        """Return the image IDs listed in the text file, one per line.

        Surrounding whitespace is stripped and blank lines are skipped so that
        an empty string never ends up being treated as a valid ID.
        """
        with open(img_ids_filepath, 'r') as f:
            stripped = (line.strip() for line in f)
            return [img_id for img_id in stripped if img_id]

    def _create_data_list(self):
        """Build the list of ``{'input_path', 'target_path', 'name'}`` records.

        Raises:
            FileNotFoundError: If a selected input image has no target file.
        """
        # Build the lookup set once: membership tests against the list would
        # cost O(len(img_ids)) for every file found in the input directory.
        wanted_ids = set(self.img_ids)
        data_list = []
        # glob() yields files in arbitrary order; sorting keeps sample indices
        # reproducible across runs and machines.
        for input_file in sorted(glob(os.path.join(self.input_path, "*"))):
            file_name = os.path.basename(input_file)
            img_id = file_name.split('-')[0]
            if img_id not in wanted_ids:
                continue
            target_file = os.path.join(self.target_path, file_name)
            if not os.path.exists(target_file):
                raise FileNotFoundError(f"Target file {target_file} not found. While input file {input_file} was found.")
            data_list.append({'input_path': input_file, 'target_path': target_file, 'name': img_id})
        return data_list

    def __len__(self):
        """Return the number of indexed image pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image pair at ``idx`` and return it together with its ID."""
        data = self.data[idx]
        input_image, target_image = self._load_image_pair(data['input_path'], data['target_path'])
        return {'input_image': input_image, 'target_image': target_image, 'name': data['name']}

    @staticmethod
    def _load_rgb_tensor(img_path):
        """Read one image file as RGB and convert it with ``to_tensor``."""
        # The context manager closes the underlying file handle right away
        # instead of leaving it open until garbage collection.
        with Image.open(img_path) as img:
            return to_tensor(np.array(img.convert('RGB')))

    def _load_image_pair(self, img1_path, img2_path):
        """Load both images as tensors and apply the paired transforms in order.

        Each transform is called with both tensors and must return the
        transformed pair.
        """
        img1_tensor = self._load_rgb_tensor(img1_path)
        img2_tensor = self._load_rgb_tensor(img2_path)
        if self.image_transforms is not None:
            for image_transform in self.image_transforms:
                img1_tensor, img2_tensor = image_transform(img1_tensor, img2_tensor)
        return img1_tensor, img2_tensor
#class PPR10KDataset(Dataset): | |
def get_single_dataset(type, params):
    """Instantiate the dataset identified by ``type``.

    Args:
        type: Dataset identifier; ``'mit5k'`` is currently the only
            implemented one. (The name shadows the builtin but is kept because
            it is part of the function's signature.)
        params: Mapping of keyword arguments forwarded to the dataset
            constructor.

    Returns:
        The constructed dataset.

    Raises:
        NotImplementedError: If ``type`` is ``'ppr10k'``, which is planned but
            has no dataset class yet.
        ValueError: If ``type`` is not a known dataset identifier.
    """
    if type == 'mit5k':
        return MIT5KDataset(**params)
    if type == 'ppr10k':
        # TODO: implement PPR10KDataset. Until then fail with an explicit
        # error instead of a NameError on the undefined class.
        raise NotImplementedError("Dataset type 'ppr10k' is not implemented yet.")
    raise ValueError(f"Unsupported dataset type: {type}")
def get_datasets(config):
    """Build the train, validation and test datasets described by ``config``.

    ``config`` must hold either two entries (``train`` and ``test``) or three
    entries (``train``, ``valid`` and ``test``); each entry provides a
    ``target`` dataset identifier and a ``params`` mapping.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple, where
        ``val_dataset`` is ``None`` when only two entries are configured.

    Raises:
        ValueError: If ``config`` has neither two nor three entries.
    """
    num_splits = len(config)
    if num_splits not in (2, 3):
        raise ValueError("The number of datasets should be 2 (train/test) or 3 (train/valid/test).")

    train_split = get_single_dataset(config.train.target, config.train.params)
    # The validation split only exists in the three-entry layout.
    valid_split = None
    if num_splits == 3:
        valid_split = get_single_dataset(config.valid.target, config.valid.params)
    test_split = get_single_dataset(config.test.target, config.test.params)
    return train_split, valid_split, test_split
if __name__ == "__main__":
    # Manual smoke test: load the first training pair and show it side by side.
    from omegaconf import OmegaConf
    import matplotlib.pyplot as plt

    config = OmegaConf.load("../configs/mit5k_upe_config.yaml")
    dataset = MIT5KDataset(**config.data.train.params)

    # __getitem__ returns a dict; tuple-unpacking it would yield its keys
    # (strings) rather than the tensors, so index by key instead.
    sample = dataset[0]
    input_img = sample['input_image']
    target_img = sample['target_image']

    # permute(1, 2, 0) reorders the axes for imshow
    # (assumes channel-first tensors from to_tensor - TODO confirm).
    plt.subplot(1, 2, 1)
    plt.imshow(input_img.squeeze().permute(1, 2, 0).numpy())
    plt.title("Input Image")
    plt.subplot(1, 2, 2)
    plt.imshow(target_img.squeeze().permute(1, 2, 0).numpy())
    plt.title("Target Image")
    plt.show()