# Arash
# initial code release
# c334626
# raw
# history blame
# 6.12 kB
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This file has been modified from a file in the torchvision library
# which was released under the BSD 3-Clause License.
#
# Source:
# https://github.com/pytorch/vision/blob/ea6b879e90459006e71a164dc76b7e2cc3bff9d9/torchvision/datasets/lsun.py
#
# The license for the original version of this file can be
# found in this directory (LICENSE_torchvision). The modifications
# to this file are subject to the same BSD 3-Clause License.
# ---------------------------------------------------------------
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import io
import string
from collections.abc import Iterable
import pickle
from torchvision.datasets.utils import verify_str_arg, iterable_to_str
class LSUNClass(VisionDataset):
    """Dataset over a single LSUN LMDB database.

    Every record of the database is decoded as an RGB PIL image. This class
    carries no label information of its own: the target it returns is always
    the placeholder ``-1`` (optionally passed through ``target_transform``).

    Args:
        root (string): Path of the LMDB directory to open.
        transform (callable, optional): A function/transform applied to the
            decoded PIL image.
        target_transform (callable, optional): A function/transform applied
            to the placeholder target.
    """

    def __init__(self, root, transform=None, target_transform=None):
        # lmdb is imported here rather than at module level, so it is only
        # required once a dataset is actually instantiated.
        import lmdb
        super().__init__(root, transform=transform,
                         target_transform=target_transform)

        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']

        # cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
        # av begin
        # We only modified the location of cache_file.
        cache_file = os.path.join(self.root, '_cache_')
        # av end
        if os.path.isfile(cache_file):
            # NOTE(security): unpickling executes arbitrary code if the cache
            # file has been tampered with; only use dataset directories that
            # you trust.
            with open(cache_file, "rb") as cache_fh:
                self.keys = pickle.load(cache_fh)
        else:
            # First use: enumerate every key once and store the list so that
            # later runs can skip the full cursor scan.
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            # The context manager guarantees the cache is flushed and the
            # handle closed even if pickling fails part-way.
            with open(cache_file, "wb") as cache_fh:
                pickle.dump(self.keys, cache_fh)

    def __getitem__(self, index):
        """
        Args:
            index (int): Position in the cached key list.

        Returns:
            tuple: Tuple (image, target) where image is the (optionally
            transformed) RGB image and target is ``-1`` (optionally
            transformed).

        Raises:
            KeyError: If the key stored at ``index`` is not present in the
                database (for example because the key cache is stale).
        """
        key = self.keys[index]
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(key)
        if imgbuf is None:
            raise KeyError(
                "Key {!r} not found in LMDB database '{}'; the '_cache_' file "
                "may be stale.".format(key, self.root))

        img = Image.open(io.BytesIO(imgbuf)).convert('RGB')
        target = -1

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of entries reported by the LMDB database."""
        return self.length
class LSUN(VisionDataset):
    """
    `LSUN <https://www.yf.io/p/lsun>`_ dataset.

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, classes='train', transform=None, target_transform=None):
        super().__init__(root, transform=transform,
                         target_transform=target_transform)
        self.classes = self._verify_classes(classes)

        # One LSUNClass database per requested class, stored at
        # ``<root>/<class>_lmdb``. target_transform is deliberately not
        # forwarded: the per-database target is discarded in __getitem__ and
        # the transform is applied there to the class index instead.
        self.dbs = [
            LSUNClass(root=os.path.join(root, c + '_lmdb'), transform=transform)
            for c in self.classes
        ]

        # indices[i] is the cumulative number of samples in dbs[0..i]; it is
        # used to map a global index onto (database, local index).
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count

    def _verify_classes(self, classes):
        """Validate ``classes`` and expand it to a list of database names.

        Args:
            classes (string or iterable of strings): Either a split name
                ('train', 'val' or 'test') or an iterable of names of the
                form ``<category>_<split>``.

        Returns:
            list: Names of the form ``<category>_<split>`` (or ``['test']``
            for the 'test' split).

        Raises:
            ValueError: If ``classes`` is neither a valid split name nor an
                iterable of valid ``<category>_<split>`` strings.
        """
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower', 'cat']
        dset_opts = ['train', 'val', 'test']

        try:
            verify_str_arg(classes, "classes", dset_opts)
        except ValueError:
            # Not a plain split name: treat it as an iterable of explicit
            # "<category>_<split>" entries and validate each one.
            if not isinstance(classes, Iterable):
                msg = ("Expected type str or Iterable for argument classes, "
                       "but got type {}.")
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            # Two separately named templates: reusing a single name for both
            # would rebind it inside the loop and break the type message
            # (IndexError from str.format) from the second element onwards.
            type_msg_fmtstr = ("Expected type str for elements in argument classes, "
                               "but got type {}.")
            value_msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
            for c in classes:
                verify_str_arg(c, custom_msg=type_msg_fmtstr.format(type(c)))
                # The split is the part after the last underscore; the
                # category (which may itself contain underscores) is the rest.
                c_short = c.split('_')
                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]

                msg = value_msg_fmtstr.format(category, "LSUN class",
                                              iterable_to_str(categories))
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = value_msg_fmtstr.format(dset_opt, "postfix",
                                              iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
        else:
            if classes == 'test':
                classes = [classes]
            else:
                classes = [c + '_' + classes for c in categories]

        return classes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index. Negative values count from the end of the
                whole dataset.

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        # Normalise negative indices against the full dataset; otherwise they
        # would silently be resolved inside the first database only.
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError("LSUN index out of range")

        # Find the first database whose cumulative count exceeds the index.
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        local_index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        # The per-database placeholder target is discarded in favour of the
        # class index computed above.
        img, _ = db[local_index]
        return img, target

    def __len__(self):
        """Return the total number of samples across all loaded databases."""
        return self.length

    def extra_repr(self):
        """Return the list of loaded class names for the dataset repr."""
        return "Classes: {classes}".format(**self.__dict__)