# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Streaming images and labels from datasets created with dataset_tool.py."""
import json
import os
import zipfile

import cv2
import numpy as np
import PIL.Image
import torch
from torchvision import transforms
from pdb import set_trace as st  # debugging helper, used in a legacy code path below

import dnnlib
from .shapenet import LMDBDataset_MV_Compressed, decompress_array

try:
    import pyspng
except ImportError:
    pyspng = None
#----------------------------------------------------------------------------
# copied from eg3d/train.py


def init_dataset_kwargs(data,
                        class_name='datasets.eg3d_dataset.ImageFolderDataset',
                        reso_gt=128):
    # try:
    #     if data == 'None':
    #         dataset_kwargs = dnnlib.EasyDict({})
    #         dataset_kwargs.name = 'eg3d_dataset'
    #         dataset_kwargs.resolution = 128
    #         dataset_kwargs.use_labels = False
    #         dataset_kwargs.max_size = 70000
    #         return dataset_kwargs, 'eg3d_dataset'
    dataset_kwargs = dnnlib.EasyDict(class_name=class_name,
                                     reso_gt=reso_gt,
                                     path=data,
                                     use_labels=True,
                                     max_size=None,
                                     xflip=False)
    dataset_obj = dnnlib.util.construct_class_by_name(
        **dataset_kwargs)  # Subclass of training.dataset.Dataset.
    dataset_kwargs.resolution = dataset_obj.resolution  # Be explicit about resolution.
    dataset_kwargs.use_labels = dataset_obj.has_labels  # Be explicit about labels.
    dataset_kwargs.max_size = len(
        dataset_obj)  # Be explicit about dataset size.
    return dataset_kwargs, dataset_obj.name
    # except IOError as err:
    #     raise click.ClickException(f'--data: {err}')
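

# Example usage (a minimal sketch; the dataset path below is a placeholder and
# must point to an FFHQ-style image folder or zip produced by dataset_tool.py):
#
#   dataset_kwargs, dataset_name = init_dataset_kwargs(
#       data='./datasets/unzipped_ffhq_512', reso_gt=128)
#   dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
#   print(dataset_name, len(dataset), dataset.resolution)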


class Dataset(torch.utils.data.Dataset):

    def __init__(
        self,
        name,  # Name of the dataset.
        raw_shape,  # Shape of the raw image data (NCHW).
        reso_gt=128,  # Ground-truth (supervision) resolution.
        max_size=None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels=False,  # Enable conditioning labels? False = label dimension is zero.
        xflip=False,  # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed=0,  # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None
        self.reso_gt = reso_gt
        self.reso_encoder = 224  # ! hard-coded encoder input resolution

        # Apply max_size.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate(
                [self._xflip, np.ones_like(self._xflip)])

        # DINO encoder input normalizer.
        self.normalize_for_encoder_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
                              antialias=True),  # type: ignore
        ])
        self.normalize_for_gt = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=(self.reso_gt, self.reso_gt),
                              antialias=True),  # type: ignore
        ])
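
    # Both pipelines take an HWC float32 image in [0, 1] (as produced in
    # __getitem__ below) and return a CHW tensor in [-1, 1]. A sketch with a
    # dummy input:
    #
    #   x = np.random.rand(512, 512, 3).astype(np.float32)
    #   enc_in = self.normalize_for_encoder_input(x)  # (3, 224, 224)
    #   img_gt = self.normalize_for_gt(x)             # (3, reso_gt, reso_gt)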

    def _get_raw_labels(self):
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels(
            ) if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0],
                                            dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            # assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
            self._raw_labels_std = self._raw_labels.std(0)
        return self._raw_labels

    def close(self):  # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self):  # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        matte = self._load_raw_matte(self._raw_idx[idx])
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if self._xflip[idx]:
            assert matte.ndim == 3  # CHW
            matte = matte[:, :, ::-1]
        matte_orig = matte.copy().astype(np.float32)  # segmentation version
        matte = np.transpose(matte,
                             (1, 2, 0)).astype(np.float32)  # [0,1] range
        matte = cv2.resize(matte, (self.reso_gt, self.reso_gt),
                           interpolation=cv2.INTER_NEAREST)
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'
        if matte.ndim == 3:  # HWC => HW
            matte = matte[..., 0]

        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3  # CHW
            image = image[:, :, ::-1]

        # Optional foreground/background blending (disabled by default).
        blending = False
        if blending:
            image = image * matte_orig + (1 - matte_orig) * cv2.GaussianBlur(
                image, (5, 5), cv2.BORDER_DEFAULT)

        image = np.transpose(image, (1, 2, 0)).astype(
            np.float32
        ) / 255  # HWC for torchvision processing, normalize to [0,1]
        image_sr = torch.from_numpy(image)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]
        image_to_encoder = self.normalize_for_encoder_input(image)
        image_gt = cv2.resize(image, (self.reso_gt, self.reso_gt),
                              interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(image_gt)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]

        return dict(
            c=self.get_label(idx),
            img_to_encoder=image_to_encoder,  # 224x224, [-1,1]
            img_sr=image_sr,  # full resolution (e.g. 512), [-1,1]
            img=image_gt,  # reso_gt, [-1,1]
            depth=matte,
            depth_mask=matte,
        )
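
    # Typical consumption (a sketch; batch size and worker count are
    # illustrative):
    #
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=4,
    #                                        shuffle=True, num_workers=2)
    #   batch = next(iter(loader))
    #   batch['img'].shape             # (4, 3, reso_gt, reso_gt)
    #   batch['img_to_encoder'].shape  # (4, 3, 224, 224)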

    def get_label(self, idx):
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    def get_label_std(self):
        self._get_raw_labels()  # ensure _raw_labels_std is populated
        return self._raw_labels_std

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3  # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
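
    # Property access sketch (no parentheses; these are read as attributes
    # elsewhere in this file, e.g. self.image_shape in __getitem__ and
    # dataset_obj.resolution in init_dataset_kwargs):
    #
    #   dataset.resolution   # e.g. 512
    #   dataset.label_dim    # e.g. 25 for EG3D camera labels
    #   dataset.has_labels   # True when dataset.json provides labels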


#----------------------------------------------------------------------------


class ImageFolderDataset(Dataset):

    def __init__(
        self,
        path,  # Path to directory or zip.
        resolution=None,  # Ensure specific resolution, None = highest available.
        reso_gt=128,
        **super_kwargs,  # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._matte_path = path.replace('unzipped_ffhq_512',
                                        'unzipped_ffhq_matte')
        # alternative segmentation source:
        # self._matte_path = path.replace('unzipped_ffhq_512', 'ffhq_512_seg')
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self._path)
                for root, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(
            fname for fname in self._all_fnames
            if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(
            self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution
                                       or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name,
                         raw_shape=raw_shape,
                         reso_gt=reso_gt,
                         **super_kwargs)
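
    # Expected on-disk layout (a sketch; exact subfolder names depend on how
    # dataset_tool.py wrote the data): matching relative paths for images and
    # mattes, e.g.
    #
    #   unzipped_ffhq_512/00000/img00000000.png
    #   unzipped_ffhq_matte/00000/img00000000.png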

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def _open_matte_file(self, fname):
        # Mattes are only supported for directory datasets at the moment.
        if self._type == 'dir':
            return open(os.path.join(self._matte_path, fname), 'rb')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image

    def _load_raw_matte(self, raw_idx):
        # ! segmentation-based version
        fname = self._image_fnames[raw_idx]
        with self._open_matte_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        image = (image > 0).astype(np.float32)  # binarize segmentation
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image

    def _load_raw_matte_orig(self, raw_idx):
        # legacy alpha-matte loader, superseded by _load_raw_matte
        fname = self._image_fnames[raw_idx]
        with self._open_matte_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        st()  # NOTE: debugger breakpoint left in; this path is unused
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image

    def _load_raw_labels(self):
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Keep labels in file order; one could skip 'mirror' entries here.
        labels = list(labels.values())
        # original EG3D variant, ordered by the image file list:
        # labels = [
        #     labels[fname.replace('\\', '/')] for fname in self._image_fnames
        # ]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        self._raw_labels = labels
        return labels
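
    # Expected dataset.json layout (EG3D convention; each label is a 25-dim
    # camera vector, i.e. flattened 4x4 cam2world extrinsics + 3x3 intrinsics):
    #
    #   {"labels": [["00000/img00000000.png", [0.99, ...]], ...]}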


#----------------------------------------------------------------------------

# class ImageFolderDatasetUnzipped(ImageFolderDataset):

#     def __init__(self, path, resolution=None, **super_kwargs):
#         super().__init__(path, resolution, **super_kwargs)

# Label-only variant: __getitem__ returns just the conditioning label.
# class ImageFolderDatasetPose(ImageFolderDataset):

#     def __init__(
#             self,
#             path,  # Path to directory or zip.
#             resolution=None,  # Ensure specific resolution, None = highest available.
#             **super_kwargs,  # Additional arguments for the Dataset base class.
#     ):
#         super().__init__(path, resolution, **super_kwargs)

#     def __len__(self):
#         return self._raw_idx.size

#     def __getitem__(self, idx):
#         return dict(c=self.get_label(idx))


class ImageFolderDatasetLMDB(ImageFolderDataset):

    def __init__(self, path, resolution=None, reso_gt=128, **super_kwargs):
        super().__init__(path, resolution, reso_gt, **super_kwargs)

    def __getitem__(self, idx):
        matte = self._load_raw_matte(self._raw_idx[idx])
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if self._xflip[idx]:
            assert matte.ndim == 3  # CHW
            matte = matte[:, :, ::-1]
        matte_orig = matte.copy().astype(np.float32)  # segmentation version
        # Some FFHQ mattes are dirty and may be all zeros, hence <= rather than ==.
        assert matte_orig.max() <= 1
        # ! keep the matte at its native (512) resolution here; no resize.
        matte = np.transpose(matte,
                             (1, 2, 0)).astype(np.float32)  # [0,1] range
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'
        if matte.ndim == 3:  # HWC => HW
            matte = matte[..., 0]

        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3  # CHW
            image = image[:, :, ::-1]

        # Unlike the parent class, no normalization or resizing happens here;
        # the raw image is returned for offline LMDB packing.
        return dict(
            c=self.get_label(idx),
            img=image,  # uint8 CHW, [0, 255]
            depth_mask=matte,
        )
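

# A hedged sketch of how samples from ImageFolderDatasetLMDB could be packed
# into the LMDB layout read by LMDBDataset_MV_Compressed_eg3d below (keys
# '{idx}-img', '{idx}-depth_mask', '{idx}-c'). decompress_array is assumed to
# expect gzip-compressed raw bytes (per the "gzip uncompress" comment below);
# the '-img' encoding actually depends on self.load_image_fn, so this is
# illustrative only:
#
#   import gzip, lmdb
#   env = lmdb.open('./ffhq.lmdb', map_size=2**40)  # hypothetical path/size
#   with env.begin(write=True) as txn:
#       for i in range(len(dataset)):
#           s = dataset[i]
#           txn.put(f'{i}-img'.encode(), gzip.compress(s['img'].tobytes()))
#           txn.put(f'{i}-depth_mask'.encode(),
#                   gzip.compress(s['depth_mask'].astype(np.float32).tobytes()))
#           txn.put(f'{i}-c'.encode(), gzip.compress(s['c'].tobytes()))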


class LMDBDataset_MV_Compressed_eg3d(LMDBDataset_MV_Compressed):

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)
        self.normalize_for_encoder_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
                              antialias=True),  # type: ignore
        ])
        self.normalize_for_gt = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=(self.reso, self.reso),
                              antialias=True),  # type: ignore
        ])

    def __getitem__(self, idx):
        # Do gzip uncompress online.
        with self.env.begin(write=False) as txn:
            img_key = f'{idx}-img'.encode('utf-8')
            image = self.load_image_fn(txn.get(img_key))

            depth_key = f'{idx}-depth_mask'.encode('utf-8')
            # depth = decompress_array(txn.get(depth_key), (512, 512), np.float32)
            depth = decompress_array(txn.get(depth_key), (64, 64), np.float32)

            c_key = f'{idx}-c'.encode('utf-8')
            c = decompress_array(txn.get(c_key), (25, ), np.float32)

        # ! post processing, e.g., normalizing
        depth = cv2.resize(depth, (self.reso, self.reso),
                           interpolation=cv2.INTER_NEAREST)
        image = np.transpose(image, (1, 2, 0)).astype(
            np.float32
        ) / 255  # HWC for torchvision processing, normalize to [0,1]
        image_sr = torch.from_numpy(image)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]
        image_to_encoder = self.normalize_for_encoder_input(image)
        image_gt = cv2.resize(image, (self.reso, self.reso),
                              interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(image_gt)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]

        return {
            'img_to_encoder': image_to_encoder,  # 224
            'img_sr': image_sr,  # 512
            'img': image_gt,  # [-1,1] range
            'c': c,
            'depth': depth,
            'depth_mask': depth,
        }
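

if __name__ == '__main__':
    # Minimal smoke test (a sketch; the dataset path is a placeholder and must
    # point to an FFHQ-style folder with images, mattes, and dataset.json).
    dataset_kwargs, dataset_name = init_dataset_kwargs(
        data='./datasets/unzipped_ffhq_512', reso_gt=128)
    dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
    sample = dataset[0]
    print(f'{dataset_name}: {len(dataset)} images at {dataset.resolution}px')
    for k, v in sample.items():
        print(k, type(v).__name__, getattr(v, 'shape', None))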