File size: 2,713 Bytes
2abfccb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import torch
from torch.utils.data import get_worker_info
from torch.utils.data import DataLoader
import random
import time
from functools import partial
from itertools import chain
# Third-party prefetching DataLoader under test; aliased so it can be
# swapped against torch's stock DataLoader below.
from petrel_client.utils.data import DataLoader as MyDataLoader
# Pin the loader-under-test's tuning knobs so both loaders are otherwise
# constructed from the same `dataloader_args` dict.
MyDataLoader = partial(MyDataLoader, prefetch_factor=4, persistent_workers=True)
def assert_equal(lhs, rhs):
    """Recursively assert that two nested structures are identical.

    Supported node types are dict (same keys, values compared recursively),
    list (same length, elements compared recursively) and torch.Tensor
    (compared with torch.equal).  Any other type fails the assertion
    outright — leaves are expected to be tensors.
    """
    if isinstance(lhs, dict):
        assert lhs.keys() == rhs.keys()
        for key in lhs:
            assert_equal(lhs[key], rhs[key])
    elif isinstance(lhs, list):
        assert len(lhs) == len(rhs)
        for left, right in zip(lhs, rhs):
            assert_equal(left, right)
    elif isinstance(lhs, torch.Tensor):
        assert torch.equal(lhs, rhs)
    else:
        assert False
def wait(dt):
    """Block the calling thread for *dt* seconds."""
    time.sleep(dt)
class Dataset(list):
    """List-backed dataset that simulates variable per-item load latency.

    Each ``__getitem__`` sleeps for a pseudo-random 50-200 ms before
    wrapping the stored value in ``{'val': value}``.  On the first access
    in a process the global ``random`` module is seeded — with 0 in the
    main process, or with the worker id inside a DataLoader worker — so
    the sleep schedule is reproducible across runs.  Worker 0's delay is
    doubled, making its items consistently slower than the other workers'.
    """

    def __init__(self, *args, **kwargs):
        super(Dataset, self).__init__(*args, **kwargs)
        # Flips to True after the one-time per-process RNG seeding below.
        self._seed_inited = False

    def __getitem__(self, *args, **kwargs):
        info = get_worker_info()
        if not self._seed_inited:
            # Seed exactly once per process: 0 on the main process,
            # otherwise the DataLoader worker's id.
            random.seed(0 if info is None else info.id)
            self._seed_inited = True
        delay = 0.05 * random.randint(1, 4)
        if info is not None and info.id == 0:
            # Worker 0 is made twice as slow as the rest.
            delay *= 2
        time.sleep(delay)
        value = super(Dataset, self).__getitem__(*args, **kwargs)
        return {'val': value}
def test(dataloader, result):
print('\ntest')
random.seed(0)
data_time = 0
tstart = t1 = time.time()
for i, data in enumerate(chain(dataloader, dataloader), 1):
t2 = time.time()
d = t2 - t1
print('{0:>5}' .format(int((t2 - t1)*1000)), end='')
if i % 10:
print('\t', end='')
else:
print('')
result.append(data)
data_time += d
rand_int = random.randrange(1, 4)
wait(0.05 * rand_int)
t1 = time.time()
tend = time.time()
print('\ntotal time: %.3f' % (tend - tstart))
print('total data time: %.3f' % data_time)
print(type(dataloader))
def worker_init_fn(worker_id):
    """DataLoader worker start-up hook.

    Logs the worker id and then stalls for three seconds — presumably to
    make worker start-up cost visible in the fetch timings; confirm intent
    with the benchmark's author.
    """
    print('start worker:', worker_id)
    time.sleep(3)
# Identical configuration for both loaders so any difference in output
# comes from the loader implementation, not its arguments.
dataloader_args = {
    'dataset': Dataset(range(1024)),
    'drop_last': False,
    'shuffle': False,
    'batch_size': 32,
    'num_workers': 8,
    'worker_init_fn': worker_init_fn,
}
# Run the petrel_client loader first, then torch's stock DataLoader, each
# from the same torch seed, collecting every batch for comparison.
torch.manual_seed(0)
l2 = MyDataLoader(**dataloader_args)
r2 = []
test(l2, r2)
torch.manual_seed(0)
l1 = DataLoader(**dataloader_args)
r1 = []
test(l1, r1)
print('len l1:', len(l1))
print('len l2:', len(l2))
# Both loaders must have produced exactly the same sequence of batches.
assert_equal(r1, r2)
print(torch)
|