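"""Oxford-IIIT Pet dataset utilities: a PyTorch Dataset over the images and
trimap annotations, a 256x256 resized variant, and download/extract helpers."""
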
import os
import torch
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from urllib.request import urlretrieve


class OxfordPetDataset(torch.utils.data.Dataset):
    """Oxford-IIIT Pet dataset returning full-resolution image, mask and trimap arrays."""

    def __init__(self, root, mode="train", transform=None):
        assert mode in {"train", "valid", "test"}

        self.root = root
        self.mode = mode
        self.transform = transform

        self.images_directory = os.path.join(self.root, "images")
        self.masks_directory = os.path.join(self.root, "annotations", "trimaps")

        self.filenames = self._read_split()  # read train/valid/test splits

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        image_path = os.path.join(self.images_directory, filename + ".jpg")
        mask_path = os.path.join(self.masks_directory, filename + ".png")

        image = np.array(Image.open(image_path).convert("RGB"))

        trimap = np.array(Image.open(mask_path))
        mask = self._preprocess_mask(trimap)

        sample = dict(image=image, mask=mask, trimap=trimap)
        if self.transform is not None:
            sample = self.transform(**sample)

        return sample

    @staticmethod
    def _preprocess_mask(mask):
        # Trimap labels: 1 = pet, 2 = background, 3 = border.
        # Collapse to a binary mask: background -> 0.0, pet and border -> 1.0.
        mask = mask.astype(np.float32)
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        return mask

    def _read_split(self):
        split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
        split_filepath = os.path.join(self.root, "annotations", split_filename)
        with open(split_filepath) as f:
            split_data = f.read().strip("\n").split("\n")
        filenames = [x.split(" ")[0] for x in split_data]
        if self.mode == "train":  # 90% for train
            filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
        elif self.mode == "valid":  # 10% for validation
            filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
        return filenames

    @staticmethod
    def download(root):
        # download and extract images
        filepath = os.path.join(root, "images.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)

        # download and extract annotations
        filepath = os.path.join(root, "annotations.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)


class SimpleOxfordPetDataset(OxfordPetDataset):
    """Variant that resizes every sample to 256x256 and returns CHW arrays."""

    def __getitem__(self, *args, **kwargs):
        sample = super().__getitem__(*args, **kwargs)

        # resize images
        image = np.array(
            Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR)
        )
        mask = np.array(
            Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST)
        )
        trimap = np.array(
            Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST)
        )

        # convert to other format HWC -> CHW
        sample["image"] = np.moveaxis(image, -1, 0)
        sample["mask"] = np.expand_dims(mask, 0)
        sample["trimap"] = np.expand_dims(trimap, 0)

        return sample


class TqdmUpTo(tqdm):
    """tqdm subclass whose `update_to` method can be passed as a urlretrieve reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)  # also sets self.n = b * bsize


def download_url(url, filepath):
    directory = os.path.dirname(os.path.abspath(filepath))
    os.makedirs(directory, exist_ok=True)
    if os.path.exists(filepath):
        return

    with TqdmUpTo(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        miniters=1,
        desc=os.path.basename(filepath),
    ) as t:
        urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)
        t.total = t.n


def extract_archive(filepath):
    extract_dir = os.path.dirname(os.path.abspath(filepath))
    dst_dir = os.path.splitext(filepath)[0]
    if not os.path.exists(dst_dir):
        shutil.unpack_archive(filepath, extract_dir)
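

# A minimal usage sketch (not part of the original module): download the
# Oxford-IIIT Pet data into an assumed "./oxford_pet" directory and wrap the
# resized dataset in a DataLoader. The root path and batch size are
# illustrative placeholders.
if __name__ == "__main__":
    root = "./oxford_pet"
    SimpleOxfordPetDataset.download(root)

    train_dataset = SimpleOxfordPetDataset(root, mode="train")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=8, shuffle=True
    )

    batch = next(iter(train_loader))
    # expected shapes: image (8, 3, 256, 256); mask and trimap (8, 1, 256, 256)
    print(batch["image"].shape, batch["mask"].shape, batch["trimap"].shape)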