import numpy as np
import threading
import queue
import multiprocessing
from collections import defaultdict
import jax.numpy as jnp



def make_batch(samples):
    "Convert a dict of numpy sample arrays to jax arrays and add labels."
    batch = {k: jnp.array(v) for k, v in samples.items()}
    # for causal language modeling the labels are simply a copy of the inputs
    batch['labels'] = batch['input_ids'].copy()
    return batch
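
# A sketch of the input `make_batch` expects (the toy shapes below are
# illustrative assumptions, not part of the original module):
#   samples = {"input_ids": np.zeros((8, 512), dtype=np.int32),
#              "attention_mask": np.ones((8, 512), dtype=np.int32)}
#   batch = make_batch(samples)  # jax arrays, batch["labels"] == batch["input_ids"]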

class PrefetchDataloaderThread(threading.Thread):
    "Prefetch dataloader for IterableDataset, packing batches in a background thread"
    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = queue.Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
        # (re)create the sample iterator, reshuffling with a fresh seed each epoch
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)
        
    def __next__(self):
        batch = self.queue.get()
        if batch is None:
            raise StopIteration
        return batch
        
    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare the next batch: pack token lists until we have
            # batch_size * sequence_length tokens
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset (and reshuffle) once a pass through the dataset completes
                    self.make_iter()
                    continue
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}

            # keep the overflow for the next batch
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}

            self.queue.put(make_batch(samples))
        # sentinel signalling that max_steps batches have been produced
        self.queue.put(None)
    
    def __iter__(self):
        return self
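
# Example usage, a hedged sketch: it assumes a Hugging Face `datasets` streaming
# dataset whose samples are dicts of token lists with an "input_ids" key, and a
# pre-built `tokenizer`; these names are illustrative, not part of this module:
#
#   from datasets import load_dataset
#   raw = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
#   dataset = raw.map(lambda x: tokenizer(x["text"]), remove_columns=["text"])
#   loader = PrefetchDataloaderThread(dataset, max_steps=1000, batch_size=8, sequence_length=512)
#   batch = next(loader)  # dict of [8, 512] jax arrays, including "labels"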


class PrefetchDataloader(multiprocessing.Process):
    "Prefetch dataloader for IterableDataset, packing batches in a separate process"
    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = multiprocessing.Queue(prefetch_buffer)
        self.rem = defaultdict(list)
        self.start()

    def make_iter(self):
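        # (re)create the sample iterator, reshuffling with a fresh seed each epoch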
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def __next__(self):
        samples = self.queue.get()
        if samples is None:
            raise StopIteration
        # convert to jax arrays on the consumer side; the worker process only
        # ships numpy arrays through the queue
        return make_batch(samples)
        
    def run(self):
        i = 0
        while i < self.max_steps:
            i += 1
            # prepare the next batch: pack token lists until we have
            # batch_size * sequence_length tokens
            sample = self.rem.copy()
            l = len(sample["input_ids"])
            max_length = self.max_length
            while l < max_length:
                try:
                    next_sample = next(self.ds_iter)
                except StopIteration:
                    # reset (and reshuffle) once a pass through the dataset completes
                    self.make_iter()
                    continue
                l += len(next_sample["input_ids"])
                sample = {k: sample[k] + next_sample[k] for k in next_sample.keys()}

            # keep the overflow for the next batch
            self.rem = {k: v[max_length:] for k, v in sample.items()}
            sample = {k: v[:max_length] for k, v in sample.items()}
            # regroup to shape [bs x seq_len]
            samples = {k: np.array([v[i*self.seq_len:(i+1)*self.seq_len] for i in range(self.bs)]) for k, v in sample.items()}

            self.queue.put(samples)
        # sentinel signalling that max_steps batches have been produced
        self.queue.put(None)
    
    def __iter__(self):
        return self
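

class _ToyIterableDataset:
    "Tiny synthetic dataset for the smoke test below (illustrative only)."

    def __iter__(self):
        rng = np.random.default_rng(0)
        while True:
            n = int(rng.integers(5, 20))
            yield {"input_ids": rng.integers(0, 100, size=n).tolist(),
                   "attention_mask": [1] * n}


if __name__ == "__main__":
    # Smoke test, a minimal sketch: with shuffle=False the loaders only need
    # `iter(dataset)`, so any iterable of dicts of token lists works. Note the
    # process-based loader builds its iterator before start(), so this assumes
    # a fork-based start method (the Linux default).
    loader = PrefetchDataloader(_ToyIterableDataset(), max_steps=3,
                                batch_size=2, sequence_length=8, shuffle=False)
    for batch in loader:
        print({k: v.shape for k, v in batch.items()})  # (2, 8) arrays, incl. "labels"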