ghlee94's picture
Init
2a13495
raw
history blame contribute delete
No virus
4.21 kB
import os
import torch
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from urllib.request import urlretrieve
class OxfordPetDataset(torch.utils.data.Dataset):
def __init__(self, root, mode="train", transform=None):
assert mode in {"train", "valid", "test"}
self.root = root
self.mode = mode
self.transform = transform
self.images_directory = os.path.join(self.root, "images")
self.masks_directory = os.path.join(self.root, "annotations", "trimaps")
self.filenames = self._read_split() # read train/valid/test splits
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
filename = self.filenames[idx]
image_path = os.path.join(self.images_directory, filename + ".jpg")
mask_path = os.path.join(self.masks_directory, filename + ".png")
image = np.array(Image.open(image_path).convert("RGB"))
trimap = np.array(Image.open(mask_path))
mask = self._preprocess_mask(trimap)
sample = dict(image=image, mask=mask, trimap=trimap)
if self.transform is not None:
sample = self.transform(**sample)
return sample
@staticmethod
def _preprocess_mask(mask):
mask = mask.astype(np.float32)
mask[mask == 2.0] = 0.0
mask[(mask == 1.0) | (mask == 3.0)] = 1.0
return mask
def _read_split(self):
split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
split_filepath = os.path.join(self.root, "annotations", split_filename)
with open(split_filepath) as f:
split_data = f.read().strip("\n").split("\n")
filenames = [x.split(" ")[0] for x in split_data]
if self.mode == "train": # 90% for train
filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
elif self.mode == "valid": # 10% for validation
filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
return filenames
@staticmethod
def download(root):
# load images
filepath = os.path.join(root, "images.tar.gz")
download_url(
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
filepath=filepath,
)
extract_archive(filepath)
# load annotations
filepath = os.path.join(root, "annotations.tar.gz")
download_url(
url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
filepath=filepath,
)
extract_archive(filepath)
class SimpleOxfordPetDataset(OxfordPetDataset):
def __getitem__(self, *args, **kwargs):
sample = super().__getitem__(*args, **kwargs)
# resize images
image = np.array(
Image.fromarray(sample["image"]).resize((256, 256), Image.LINEAR)
)
mask = np.array(
Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST)
)
trimap = np.array(
Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST)
)
# convert to other format HWC -> CHW
sample["image"] = np.moveaxis(image, -1, 0)
sample["mask"] = np.expand_dims(mask, 0)
sample["trimap"] = np.expand_dims(trimap, 0)
return sample
class TqdmUpTo(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
def download_url(url, filepath):
directory = os.path.dirname(os.path.abspath(filepath))
os.makedirs(directory, exist_ok=True)
if os.path.exists(filepath):
return
with TqdmUpTo(
unit="B",
unit_scale=True,
unit_divisor=1024,
miniters=1,
desc=os.path.basename(filepath),
) as t:
urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)
t.total = t.n
def extract_archive(filepath):
extract_dir = os.path.dirname(os.path.abspath(filepath))
dst_dir = os.path.splitext(filepath)[0]
if not os.path.exists(dst_dir):
shutil.unpack_archive(filepath, extract_dir)