Spaces:
Runtime error
Runtime error
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import torch.utils.data as data | |
import numpy as np | |
import lmdb | |
import os | |
import io | |
from PIL import Image | |
def num_samples(dataset, train):
    """Return the number of samples in the requested split of a dataset.

    Args:
        dataset: Dataset name. Only ``'celeba'`` is recognised.
        train: Truthy for the training split, falsy for the validation split.

    Returns:
        27000 for the ``'celeba'`` training split, 3000 for its validation
        split.

    Raises:
        NotImplementedError: If ``dataset`` is not a recognised name.
    """
    if dataset == 'celeba':
        return 27000 if train else 3000
    # Format through a one-tuple so that a tuple-valued ``dataset`` is printed
    # as a single value instead of being unpacked as %-format arguments
    # (which would raise TypeError and mask the intended error).
    raise NotImplementedError('dataset %s is unknown' % (dataset,))
class LMDBDataset(data.Dataset):
    """Map-style dataset that reads images out of an LMDB database.

    Each sample is stored under the key ``str(index).encode()`` (``b'0'``,
    ``b'1'``, ...). A stored value is either an encoded image file
    (``is_encoded=True``) or a raw uint8 buffer that is interpreted as a
    square RGB image.

    Args:
        root: Directory that contains ``train.lmdb`` / ``validation.lmdb``.
        name: Dataset name forwarded to ``num_samples`` to obtain the length.
        train: Read ``train.lmdb`` when truthy, ``validation.lmdb`` otherwise.
        transform: Optional callable applied to the PIL image of each sample.
        is_encoded: True when values are encoded image files that must be
            decoded with PIL; False when they are raw pixel buffers.
    """

    def __init__(self, root, name='', train=True, transform=None, is_encoded=False):
        super().__init__()
        self.train = train
        self.name = name
        self.transform = transform
        self.is_encoded = is_encoded

        split_file = 'train.lmdb' if self.train else 'validation.lmdb'
        lmdb_path = os.path.join(root, split_file)
        # Opened read-only and without locking; the environment is kept open
        # for the lifetime of the dataset object.
        self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1,
                                   lock=False, readahead=False, meminit=False)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample stored under ``index``.

        ``image`` is a PIL RGB image, or whatever ``transform`` returns for
        it. ``target`` is always the constant list ``[0]``.

        Raises:
            IndexError: If no record exists for ``index``.
            ValueError: If a raw (non-encoded) record is not a non-empty
                square RGB buffer, i.e. its length is not ``3 * size**2``.
        """
        target = [0]
        key = str(index).encode()
        with self.data_lmdb.begin(write=False, buffers=True) as txn:
            raw = txn.get(key)
            if raw is None:
                # txn.get returns None for an absent key; fail here with a
                # clear error instead of deep inside PIL/numpy.
                raise IndexError('index %s not found in LMDB database' % (index,))

            # NOTE: with buffers=True the returned buffer is expected to be
            # valid only while the transaction is open, so the image is
            # fully built inside this block.
            if self.is_encoded:
                # convert() forces the (otherwise lazy) decode to happen now.
                img = Image.open(io.BytesIO(raw)).convert('RGB')
            else:
                pixels = np.frombuffer(raw, dtype=np.uint8)
                # Assume the data is a square RGB image: length == 3 * size^2.
                size = int(np.sqrt(pixels.size / 3))
                if size == 0 or size * size * 3 != pixels.size:
                    raise ValueError(
                        'record %s has %d bytes, which is not a square RGB image'
                        % (index, pixels.size))
                img = Image.fromarray(pixels.reshape(size, size, 3), mode='RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the sample count reported by ``num_samples`` for this split."""
        return num_samples(self.name, self.train)