File size: 2,713 Bytes
2abfccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from torch.utils.data import get_worker_info
from torch.utils.data import DataLoader

import random
import time

from functools import partial

from itertools import chain


from petrel_client.utils.data import DataLoader as MyDataLoader
# Pre-bind the benchmark-wide loader options: prefetch 4 batches per worker
# and keep worker processes alive across epochs (persistent_workers), so the
# two loaders below are constructed with identical settings.
MyDataLoader = partial(MyDataLoader, prefetch_factor=4, persistent_workers=True)


def assert_equal(lhs, rhs):
    """Recursively assert that two (possibly nested) structures are equal.

    Handles dicts (same keys, equal values), lists/tuples (same length,
    equal elements), ``torch.Tensor`` (via ``torch.equal``), and falls back
    to plain ``==`` for any other type (scalars, strings, ...).

    Raises:
        AssertionError: if the structures differ in type, shape, or value.
    """
    if isinstance(lhs, dict):
        # Check rhs is also a dict so a mismatch raises AssertionError,
        # not AttributeError from .keys().
        assert isinstance(rhs, dict)
        assert lhs.keys() == rhs.keys()
        for k in lhs:
            assert_equal(lhs[k], rhs[k])
    elif isinstance(lhs, (list, tuple)):
        assert isinstance(rhs, (list, tuple))
        assert len(lhs) == len(rhs)
        for left, right in zip(lhs, rhs):
            assert_equal(left, right)
    elif isinstance(lhs, torch.Tensor):
        assert isinstance(rhs, torch.Tensor)
        assert torch.equal(lhs, rhs)
    else:
        # Previously this branch was `assert False`, which rejected equal
        # scalars/strings outright; compare them instead.
        assert lhs == rhs


def wait(dt):
    """Block the calling thread for *dt* seconds (simulated work)."""
    time.sleep(dt)


class Dataset(list):
    """List-backed dataset whose item access simulates variable-latency I/O.

    On the first access in each process the global ``random`` module is
    seeded once (0 in the main process, the worker id inside a DataLoader
    worker).  Every access then sleeps a pseudo-random 0.05-0.2 s — doubled
    for worker 0 to make it the straggler — and returns ``{'val': item}``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flipped to True once the per-process RNG seed has been set.
        self._seed_inited = False

    def __getitem__(self, *args, **kwargs):
        info = get_worker_info()
        if not self._seed_inited:
            # Seed exactly once per process for reproducible delays.
            random.seed(0 if info is None else info.id)
            self._seed_inited = True
        delay = 0.05 * random.randint(1, 4)
        if info is not None and info.id == 0:
            delay *= 2  # worker 0 is deliberately slower
        time.sleep(delay)
        return {'val': super().__getitem__(*args, **kwargs)}


def test(dataloader, result):
    print('\ntest')
    random.seed(0)
    data_time = 0
    tstart = t1 = time.time()
    for i, data in enumerate(chain(dataloader, dataloader), 1):
        t2 = time.time()
        d = t2 - t1
        print('{0:>5}' .format(int((t2 - t1)*1000)), end='')
        if i % 10:
            print('\t', end='')
        else:
            print('')

        result.append(data)

        data_time += d

        rand_int = random.randrange(1, 4)
        wait(0.05 * rand_int)

        t1 = time.time()
    tend = time.time()
    print('\ntotal time: %.3f' % (tend - tstart))
    print('total data time: %.3f' % data_time)
    print(type(dataloader))


def worker_init_fn(worker_id):
    """Announce worker startup, then stall 3 s to simulate slow init."""
    print('start worker:', worker_id)
    time.sleep(3)


# Shared construction arguments so both loaders iterate the identical
# 1024-item dataset deterministically (no shuffling, fixed batching).
dataloader_args = {
    'dataset': Dataset(range(1024)),
    'drop_last': False,
    'shuffle': False,   # keep ordering identical across the two loaders
    'batch_size': 32,   # 1024 / 32 = 32 batches per epoch
    'num_workers': 8,
    'worker_init_fn': worker_init_fn,
}


# Run the petrel_client loader first, then the reference torch loader,
# reseeding torch before each so any seed-dependent behavior matches.
torch.manual_seed(0)
l2 = MyDataLoader(**dataloader_args)
r2 = []
test(l2, r2)

torch.manual_seed(0)
l1 = DataLoader(**dataloader_args)
r1 = []
test(l1, r1)


print('len l1:', len(l1))
print('len l2:', len(l2))

# Both loaders must yield identical batch sequences.
assert_equal(r1, r2)
print(torch)