# Source: brunorosilva, commit f1a0ba2 ("chore: change repo name"), 640 bytes.
from PIL import Image
from torch.utils.data import Dataset
class ImageRetrievalDataset(Dataset):
    """Paired-image dataset yielding ``(input_image, label_image)`` samples.

    ``data`` must support 2-D indexing (e.g. a NumPy array) and store the
    pairs column-wise: row 0 holds input image paths and row 1 holds the
    matching label image paths, so column ``idx`` is one sample.
    """

    def __init__(self, data, transform=None):
        """Store the path table and an optional transform.

        Args:
            data: 2 x N table of file paths; ``data[0, i]`` is the input and
                ``data[1, i]`` the label of sample ``i``.
            transform: Optional callable applied to each loaded image.
        """
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of pairs, i.e. the number of columns in ``data``."""
        # Samples live along axis 1 (see ``__getitem__``); ``len(self.data)``
        # would count axis 0 and always report 2 for a valid 2 x N table.
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(input_image, label_image)``.

        Both images are converted to RGB. If a transform was given, it is
        applied to each image and the transformed values are returned.
        """
        # Column ``idx`` yields exactly one input path and one label path.
        input_path, label_path = self.data[:, idx]
        input_image = self._load_rgb(input_path)
        label_image = self._load_rgb(label_path)
        if self.transform is not None:
            # NOTE(review): the transform is invoked once per image, so any
            # random augmentation is not shared between input and label --
            # confirm this is intended.
            input_image = self.transform(input_image)
            label_image = self.transform(label_image)
        return input_image, label_image

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        # ``convert`` returns a new in-memory image, so the source file can be
        # closed deterministically instead of relying on PIL's lazy handling.
        with Image.open(path) as image:
            return image.convert("RGB")