Spaces:
Paused
Paused
import numbers | |
import os | |
import queue as Queue | |
import threading | |
import mxnet as mx | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision import transforms | |
class BackgroundGenerator(threading.Thread):
    """Iterator that drains ``generator`` on a daemon thread into a bounded queue.

    The worker thread starts immediately on construction. It first calls
    ``torch.cuda.set_device(local_rank)`` and then pushes every item produced
    by ``generator`` into a queue holding at most ``max_prefetch`` items.
    The consumer pulls items with ``next()`` / iteration.

    Args:
        generator: Any iterable/iterator to consume in the background.
        local_rank: Device index passed to ``torch.cuda.set_device`` inside
            the worker thread.
        max_prefetch: Maximum number of items buffered ahead of the consumer.

    Raises (from ``next``):
        StopIteration: Once the wrapped generator is exhausted (and on every
            later call).
        Exception: Any exception raised on the worker thread is re-raised in
            the consumer instead of leaving it blocked on an empty queue.
    """

    # Unique end-of-stream marker. A private object is used so that a ``None``
    # item yielded by the wrapped generator is not mistaken for exhaustion.
    _SENTINEL = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        # Both fields must exist before start(): the worker may write
        # _producer_error, and next() reads them.
        self._producer_error = None
        self._stream_done = False
        self.daemon = True
        self.start()

    def run(self):
        """Worker body: bind the CUDA device, then forward items to the queue."""
        try:
            torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure to the consumer; it is re-raised in next().
            self._producer_error = exc
        finally:
            # Always signal end-of-stream so the consumer never blocks
            # forever, even when the producer failed.
            self.queue.put(self._SENTINEL)

    def next(self):
        """Return the next prefetched item, blocking until one is available."""
        if self._stream_done:
            # The sentinel was already consumed; the queue is empty for good,
            # so a further get() would block indefinitely.
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._SENTINEL:
            self._stream_done = True
            if self._producer_error is not None:
                error, self._producer_error = self._producer_error, None
                raise error
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches batches and copies them to a CUDA device.

    Iterating wraps the parent loader's iterator in a ``BackgroundGenerator``
    and keeps one batch "in flight": while the caller works on the batch just
    returned, the following batch is already being copied to ``local_rank``
    on a dedicated CUDA stream.

    Args:
        local_rank: Device index used for the copy stream and as the target
            device of every batch element.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Note:
        ``preload`` assigns into the batch by index and calls ``.to`` on each
        element, so the batch must be a mutable sequence of tensors.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        # Side stream on which host-to-device copies are issued.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a new pass: spin up background loading and stage the first batch."""
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch and enqueue its device copy on the side stream.

        Sets ``self.batch`` to ``None`` when the underlying iterator is
        exhausted.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k, tensor in enumerate(self.batch):
                self.batch[k] = tensor.to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        """Return the staged batch and begin staging the one after it.

        Raises:
            StopIteration: When no further batch is available.
        """
        current = torch.cuda.current_stream()
        # Make the consumer stream wait for the pending copies to finish.
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # The tensors were allocated on the side stream but are consumed on
        # ``current``. Recording that use keeps the caching allocator from
        # handing their memory to a later prefetch copy while kernels queued
        # on ``current`` may still be reading it.
        for tensor in batch:
            tensor.record_stream(current)
        self.preload()
        return batch
class MXFaceDataset(Dataset):
    """Dataset backed by an MXNet indexed RecordIO pair under ``root_dir``.

    Reads ``train.rec`` / ``train.idx`` from ``root_dir``. Every sample is
    decoded, converted to a PIL image, randomly flipped horizontally, turned
    into a tensor and normalised with mean 0.5 / std 0.5 per channel.

    Args:
        root_dir: Directory containing ``train.rec`` and ``train.idx``.
        local_rank: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, root_dir, local_rank):
        super(MXFaceDataset, self).__init__()
        pipeline = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)
        self.root_dir = root_dir
        self.local_rank = local_rank

        rec_path = os.path.join(root_dir, 'train.rec')
        idx_path = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')

        # Record 0 is inspected first: its header decides how the list of
        # sample indices is built.
        meta_header, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if meta_header.flag > 0:
            # Samples occupy indices 1 .. label[0] - 1 of the record file.
            self.header0 = (int(meta_header.label[0]), int(meta_header.label[1]))
            self.imgidx = np.array(range(1, int(meta_header.label[0])))
        else:
            # No usable meta header: fall back to every key in the index.
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the ``index``-th entry of ``imgidx``."""
        record = self.imgrec.read_idx(self.imgidx[index])
        header, encoded = mx.recordio.unpack(record)

        # The header label is either a plain number or a sequence whose first
        # element is used.
        raw_label = header.label
        if isinstance(raw_label, numbers.Number):
            class_id = raw_label
        else:
            class_id = raw_label[0]
        label = torch.tensor(class_id, dtype=torch.long)

        sample = mx.image.imdecode(encoded).asnumpy()
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of addressable samples."""
        return len(self.imgidx)
class SyntheticDataset(Dataset):
    """Dataset that serves one fixed random image with a constant label.

    A single 112x112 RGB image is generated at construction time, laid out
    channel-first as a float tensor and scaled from [0, 255] to [-1, 1].
    Every index returns that same tensor together with label ``1``.

    Args:
        local_rank: Accepted for signature parity with the other datasets in
            this file; not used.
    """

    def __init__(self, local_rank):
        super(SyntheticDataset, self).__init__()
        # randint's upper bound is exclusive, so 256 is required for the
        # full 8-bit range 0..255 (and thus the normalised value 1.0).
        pixels = np.random.randint(0, 256, size=(112, 112, 3), dtype=np.int32)
        # HWC -> CHW.
        pixels = np.transpose(pixels, (2, 0, 1))
        img = torch.from_numpy(pixels).float()
        # Map [0, 255] to [-1, 1].
        img = ((img / 255) - 0.5) / 0.5
        self.img = img
        self.label = 1

    def __getitem__(self, index):
        """Return the shared ``(image, label)`` pair regardless of ``index``."""
        return self.img, self.label

    def __len__(self):
        """Fixed nominal length of the synthetic dataset."""
        return 1000000