import pickle import torchvision import torch import pathlib import numpy as np from typing import Callable, Optional, Union from torch.hub import get_dir as get_hub_dir def cache_embed_stats(embed_map: torch.Tensor): mean = embed_map.mean(dim=0, keepdim=True) rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() cache = dict(mean=mean, rstd=rstd, embed_map=embed_map) path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch") path.parent.mkdir(exist_ok=True, parents=True) torch.save(cache, path) class CocoCSE(torch.utils.data.Dataset): def __init__(self, dirpath: Union[str, pathlib.Path], transform: Optional[Callable], normalize_E: bool,): dirpath = pathlib.Path(dirpath) self.dirpath = dirpath self.transform = transform assert self.dirpath.is_dir(),\ f"Did not find dataset at: {dirpath}" self.image_paths, self.embedding_paths = self._load_impaths() self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) mean = self.embed_map.mean(dim=0, keepdim=True) rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() self.embed_map = (self.embed_map - mean) * rstd cache_embed_stats(self.embed_map) def _load_impaths(self): image_dir = self.dirpath.joinpath("images") image_paths = list(image_dir.glob("*.png")) image_paths.sort() embedding_paths = [ self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths ] return image_paths, embedding_paths def __len__(self): return len(self.image_paths) def __getitem__(self, idx): im = torchvision.io.read_image(str(self.image_paths[idx])) vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) vertices = torch.from_numpy(vertices.squeeze()).long() mask = torch.from_numpy(mask.squeeze()).float() border = torch.from_numpy(border.squeeze()).float() E_mask = 1 - mask - border batch = { "img": im, "vertices": vertices[None], "mask": mask[None], "embed_map": self.embed_map, "border": border[None], "E_mask": E_mask[None] } if self.transform is None: return batch return self.transform(batch)