"""Compare ``petrel_client``'s DataLoader with ``torch.utils.data.DataLoader``.

Both loaders are driven over the same artificially slow dataset with the same
arguments and the same torch seed.  Per-batch wait times are printed for each
loader, and the collected batches are then asserted to be identical.
"""
import random
import time
from functools import partial
from itertools import chain

import torch
from torch.utils.data import DataLoader
from torch.utils.data import get_worker_info

from petrel_client.utils.data import DataLoader as MyDataLoader

# The loader under test is always built with these two extra keyword arguments.
MyDataLoader = partial(MyDataLoader, prefetch_factor=4, persistent_workers=True)


def assert_equal(lhs, rhs):
    """Recursively assert that two nested batch structures are equal.

    Supported shapes are dicts, lists and ``torch.Tensor`` leaves; the branch
    taken is chosen by the type of ``lhs``.

    Raises:
        AssertionError: on any mismatch, or when ``lhs`` is of a type other
            than dict, list or ``torch.Tensor``.
    """
    if isinstance(lhs, dict):
        assert lhs.keys() == rhs.keys(), (
            'dict keys differ: {} != {}'.format(list(lhs.keys()), list(rhs.keys()))
        )
        for key in lhs:
            assert_equal(lhs[key], rhs[key])
    elif isinstance(lhs, list):
        assert len(lhs) == len(rhs), (
            'list lengths differ: {} != {}'.format(len(lhs), len(rhs))
        )
        for left, right in zip(lhs, rhs):
            assert_equal(left, right)
    elif isinstance(lhs, torch.Tensor):
        assert torch.equal(lhs, rhs), 'tensors differ'
    else:
        # Unknown leaf types are rejected rather than compared, as before, but
        # with a message naming the type (and not stripped under ``-O``).
        raise AssertionError(
            'unsupported type for comparison: {}'.format(type(lhs).__name__)
        )


def wait(dt):
    """Block the calling thread for ``dt`` seconds."""
    time.sleep(dt)


class Dataset(list):
    """List-backed dataset whose item access is deliberately slow.

    Each lookup sleeps for a pseudo-random time before returning the stored
    value wrapped in a dict under the key ``'val'``.
    """

    def __init__(self, *args, **kwargs):
        super(Dataset, self).__init__(*args, **kwargs)
        # ``random`` is seeded lazily on the first ``__getitem__`` call made
        # on this instance.
        self._seed_inited = False

    def __getitem__(self, *args, **kwargs):
        """Sleep for a random delay, then return ``{'val': item}``."""
        worker_info = get_worker_info()
        if not self._seed_inited:
            # Seed with the worker id (or 0 when not inside a worker).
            if worker_info is None:
                random.seed(0)
            else:
                random.seed(worker_info.id)
            self._seed_inited = True

        rand_int = random.randint(1, 4)
        time_to_sleep = rand_int * 0.05
        # Worker 0 is made twice as slow as the others.
        if worker_info is not None and worker_info.id == 0:
            time_to_sleep *= 2
        wait(time_to_sleep)

        val = super(Dataset, self).__getitem__(*args, **kwargs)
        return {'val': val}


def test(dataloader, result):
    """Iterate ``dataloader`` twice back to back and record every batch.

    For each batch the time spent waiting on the loader is printed in
    milliseconds, ten values per row.  After each batch the loop sleeps for
    ``0.05 * randrange(1, 4)`` seconds to simulate consumer work.

    Args:
        dataloader: iterable yielding batches; it is iterated two times.
        result: list that every received batch is appended to, in order.
    """
    print('\ntest')
    random.seed(0)
    data_time = 0
    tstart = t1 = time.time()
    for i, data in enumerate(chain(dataloader, dataloader), 1):
        t2 = time.time()
        d = t2 - t1  # time spent waiting for this batch
        print('{0:>5}'.format(int(d * 1000)), end='')
        # Tab-separate values, breaking the row after every tenth one.
        if i % 10:
            print('\t', end='')
        else:
            print('')
        result.append(data)
        data_time += d

        rand_int = random.randrange(1, 4)
        wait(0.05 * rand_int)
        t1 = time.time()
    tend = time.time()
    print('\ntotal time: %.3f' % (tend - tstart))
    print('total data time: %.3f' % data_time)
    print(type(dataloader))


def worker_init_fn(worker_id):
    """Announce the worker and sleep for 3 seconds to mimic slow start-up."""
    print('start worker:', worker_id)
    wait(3)


# Shared keyword arguments; both loaders receive the same dataset instance.
dataloader_args = {
    'dataset': Dataset(range(1024)),
    'drop_last': False,
    'shuffle': False,
    'batch_size': 32,
    'num_workers': 8,
    'worker_init_fn': worker_init_fn,
}


def main():
    """Run both loaders with identical seeds and compare their batches."""
    torch.manual_seed(0)
    l2 = MyDataLoader(**dataloader_args)
    r2 = []
    test(l2, r2)

    torch.manual_seed(0)
    l1 = DataLoader(**dataloader_args)
    r1 = []
    test(l1, r1)

    print('len l1:', len(l1))
    print('len l2:', len(l2))
    assert_equal(r1, r2)

    print(torch)


# Worker processes created with the ``spawn`` start method re-import the main
# module; the guard keeps them from re-running the whole comparison.
if __name__ == '__main__':
    main()