import pathlib
from typing import Callable, Optional, Union

import numpy as np
import torch
import torchvision
from torch.hub import get_dir as get_hub_dir


def cache_embed_stats(embed_map: torch.Tensor):
    """Cache the per-channel mean/rstd of the embedding map under the torch hub directory."""
    mean = embed_map.mean(dim=0, keepdim=True)
    rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True) + 1e-8).rsqrt()

    cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
    path = pathlib.Path(get_hub_dir(), "embed_map_stats.torch")
    path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(cache, path)
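# Sketch of a counterpart loader (an assumption, not part of the original file):
# reads back the statistics written by cache_embed_stats from the same location.
def load_embed_stats() -> dict:
    path = pathlib.Path(get_hub_dir(), "embed_map_stats.torch")
    return torch.load(path, map_location="cpu")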


class CocoCSE(torch.utils.data.Dataset):
    """COCO dataset with continuous surface embedding (CSE) annotations stored as .npy files."""

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 normalize_E: bool):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath

        self.transform = transform
        # Accepted for config compatibility; the normalization below is currently unconditional.
        self.normalize_E = normalize_E
        assert self.dirpath.is_dir(), \
            f"Did not find dataset at: {dirpath}"
        self.image_paths, self.embedding_paths = self._load_impaths()
        # Standardize the embedding map (zero mean, unit variance per channel) and cache the stats.
        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
        mean = self.embed_map.mean(dim=0, keepdim=True)
        rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True) + 1e-8).rsqrt()
        self.embed_map = (self.embed_map - mean) * rstd
        cache_embed_stats(self.embed_map)

    def _load_impaths(self):
        # Pair every image with the embedding .npy file that shares its stem.
        image_dir = self.dirpath.joinpath("images")
        image_paths = list(image_dir.glob("*.png"))
        image_paths.sort()
        embedding_paths = [
            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
        ]
        return image_paths, embedding_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        im = torchvision.io.read_image(str(self.image_paths[idx]))
        # The embedding file stacks vertex indices, the mask, and the border
        # along the last axis; split it back into the three components.
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()
        # E_mask marks pixels covered by neither the mask nor the border.
        E_mask = 1 - mask - border
        batch = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "embed_map": self.embed_map,
            "border": border[None],
            "E_mask": E_mask[None]
        }
        if self.transform is None:
            return batch
        return self.transform(batch)
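
# Minimal usage sketch (not part of the dataset implementation). It assumes a
# local directory laid out as <root>/images/*.png, <root>/embedding/*.npy, and
# <root>/embed_map.npy, as CocoCSE expects; the path below is hypothetical.
if __name__ == "__main__":
    dataset = CocoCSE("data/coco_cse/train", transform=None, normalize_E=True)
    print(f"Loaded {len(dataset)} samples")
    batch = dataset[0]
    for key, value in batch.items():
        print(key, tuple(value.shape))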