"""Dataset classes and dataset factory helpers for NamedCurves (data/datasets.py)."""
import os
from glob import glob

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor

from data.image_transformations import get_transforms


class MIT5KDataset(Dataset):
    """MIT-Adobe FiveK dataset returning (input, expert-retouched target) image pairs."""

    def __init__(self, input_path, target_path, img_ids_filepath, transform=None):
        self.input_path = input_path
        self.target_path = target_path
        self.transform = transform
        self.img_ids = self._read_img_ids(img_ids_filepath)
        self.data = self._create_data_list()

        # Build the paired-image transformation pipeline only when one is requested
        if transform is not None:
            self.image_transforms = get_transforms(transform)
        else:
            self.image_transforms = None

    def _read_img_ids(self, img_ids_filepath):
        """Read the image IDs (one per line) used to select the files of this split."""
        with open(img_ids_filepath, 'r') as f:
            img_ids = [line.strip() for line in f.readlines()]
        return img_ids

    def _create_data_list(self):
        # Build a list of dicts with 'input_path', 'target_path' and 'name' for the
        # selected IDs; the ID is the filename prefix before the first hyphen.
        data_list = []
        for input_file in glob(os.path.join(self.input_path, "*")):
            img_id = os.path.basename(input_file).split('-')[0]
            if img_id in self.img_ids:
                target_file = os.path.join(self.target_path, os.path.basename(input_file))
                if not os.path.exists(target_file):
                    raise FileNotFoundError(
                        f"Target file {target_file} not found for input file {input_file}."
                    )
                data_list.append({'input_path': input_file, 'target_path': target_file, 'name': img_id})
        return data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        input_image, target_image = self._load_image_pair(data['input_path'], data['target_path'])
        return {'input_image': input_image, 'target_image': target_image, 'name': data['name']}

    def _load_image_pair(self, img1_path, img2_path):
        # Load both images as RGB float tensors in [0, 1] and apply the paired
        # transforms jointly so input and target stay aligned.
        img1_tensor = to_tensor(np.array(Image.open(img1_path).convert('RGB')))
        img2_tensor = to_tensor(np.array(Image.open(img2_path).convert('RGB')))
        if self.image_transforms is not None:
            for image_transform in self.image_transforms:
                img1_tensor, img2_tensor = image_transform(img1_tensor, img2_tensor)
        return img1_tensor, img2_tensor
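

# Illustrative usage sketch (not part of the original module): a minimal example
# of how MIT5KDataset is typically consumed for training, assuming a standard
# torch DataLoader. The function name and the default batch_size / num_workers
# values are placeholders, not project settings.
def example_dataloader(dataset, batch_size=8, num_workers=4):
    """Wrap a dataset in a DataLoader yielding dict batches.

    Because __getitem__ returns a dict, the default collate function produces
    batches with 'input_image' / 'target_image' tensors (stacked along a new
    batch dimension, assuming the configured transforms yield equally sized
    images) and a list of 'name' strings.
    """
    from torch.utils.data import DataLoader  # local import keeps this sketch self-contained
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=True)

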
#class PPR10KDataset(Dataset):


def get_single_dataset(dataset_type, params):
    """Instantiate a single dataset from its type string and keyword parameters."""
    if dataset_type == 'mit5k':
        return MIT5KDataset(**params)
    elif dataset_type == 'ppr10k':
        # TODO: return PPR10KDataset(**params) once the class above is implemented
        raise NotImplementedError("PPR10KDataset is not implemented yet.")
    else:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")


def get_datasets(config):
    """Return the datasets described by the configuration file.

    A 2-entry config (train/test) yields (train, None, test); a 3-entry config
    (train/valid/test) yields (train, valid, test).
    """
    if len(config) == 2:
        train_dataset = get_single_dataset(config.train.target, config.train.params)
        test_dataset = get_single_dataset(config.test.target, config.test.params)
        return train_dataset, None, test_dataset
    elif len(config) == 3:
        train_dataset = get_single_dataset(config.train.target, config.train.params)
        val_dataset = get_single_dataset(config.valid.target, config.valid.params)
        test_dataset = get_single_dataset(config.test.target, config.test.params)
        return train_dataset, val_dataset, test_dataset
    else:
        raise ValueError("The number of datasets should be 2 (train/test) or 3 (train/valid/test).")
if __name__ == "__main__":
from omegaconf import OmegaConf
config = OmegaConf.load("../configs/mit5k_upe_config.yaml")
dataset = MIT5KDataset(**config.data.train.params)
input_img, target_img, name = dataset[0]
import matplotlib.pyplot as plt
plt.subplot(1, 2, 1)
plt.imshow(input_img.squeeze().permute(1, 2, 0).numpy())
plt.title("Input Image")
plt.subplot(1, 2, 2)
plt.imshow(target_img.squeeze().permute(1, 2, 0).numpy())
plt.title("Target Image")
plt.show()