File size: 4,425 Bytes
3f395b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import numpy as np
import threading
import queue
import multiprocessing
from collections import defaultdict
import jax
import jax.numpy as jnp
def make_batch(samples):
    """Convert a dict of array-likes into a batch of JAX arrays.

    Every value of ``samples`` is converted with ``jnp.array``. A ``'labels'``
    entry is then set to a copy of the converted ``'input_ids'`` (overwriting
    any existing ``'labels'``), so ``samples`` must contain ``'input_ids'``;
    otherwise ``KeyError`` is raised.
    """
    batch = dict((key, jnp.array(value)) for key, value in samples.items())
    input_ids = batch['input_ids']
    batch['labels'] = input_ids.copy()
    return batch
class PrefetchDataloaderTread(threading.Thread):
    """Prefetch dataloader for IterableDataset, running in a background thread.

    Samples drawn from ``dataset`` are concatenated and cut into chunks of
    ``batch_size * sequence_length`` elements, regrouped into rows of
    ``sequence_length``, converted with ``make_batch`` and put on a bounded
    ``queue.Queue`` of size ``prefetch_buffer``. Elements that do not fit into
    a batch are carried over to the next one. When the dataset is exhausted it
    is restarted (reshuffled with an incremented seed when ``shuffle`` is
    true). After ``max_steps`` batches, iteration ends with ``StopIteration``.
    An exception raised while producing a batch is re-raised by ``__next__``.

    The thread starts itself at the end of ``__init__``.
    """

    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        self.make_iter()
        self.queue = queue.Queue(prefetch_buffer)
        # Leftover elements from the previous batch; a defaultdict so the very
        # first lookup of "input_ids" yields an empty list.
        self.rem = defaultdict(list)
        # Set once the end-of-stream sentinel (or a forwarded error) was received,
        # so later __next__ calls raise instead of blocking on an empty queue.
        self._exhausted = False
        self.start()

    def make_iter(self):
        """(Re)create the dataset iterator, shuffling first when ``shuffle`` is set."""
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            # Bump the seed so the next pass shuffles with a different seed.
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def _next_sample(self):
        """Return the next sample, restarting the dataset once a pass is complete.

        Raises:
            ValueError: if a freshly created iterator yields nothing (empty dataset),
                which would otherwise make the producer loop forever.
        """
        try:
            return next(self.ds_iter)
        except StopIteration:
            pass
        # A pass through the dataset is completed: start a new one.
        self.make_iter()
        try:
            return next(self.ds_iter)
        except StopIteration:
            raise ValueError("dataset yielded no samples") from None

    def _next_samples(self):
        """Build one batch as a dict of NumPy arrays with rows of ``seq_len``.

        Leftover elements in ``self.rem`` are used first; the excess beyond
        ``max_length`` is stored back into ``self.rem``.
        """
        max_length = self.max_length
        sample = self.rem.copy()
        length = len(sample["input_ids"])
        while length < max_length:
            next_sample = self._next_sample()
            length += len(next_sample["input_ids"])
            # NOTE(review): assumes feature values are Python lists so that `+`
            # concatenates (NumPy arrays would add elementwise) — confirm.
            sample = {k: sample[k] + next_sample[k] for k in next_sample}
        self.rem = {k: v[max_length:] for k, v in sample.items()}
        sample = {k: v[:max_length] for k, v in sample.items()}
        # Regroup to shape [bs x seq_len].
        return {
            k: np.array([v[row * self.seq_len:(row + 1) * self.seq_len] for row in range(self.bs)])
            for k, v in sample.items()
        }

    def run(self):
        """Produce ``max_steps`` batches followed by a ``None`` end-of-stream sentinel."""
        try:
            for _ in range(self.max_steps):
                self.queue.put(make_batch(self._next_samples()))
        except Exception as err:
            # Hand the failure to the consumer instead of leaving it blocked on queue.get().
            self.queue.put(err)
            return
        self.queue.put(None)

    def __next__(self):
        """Return the next prefetched batch.

        Raises:
            StopIteration: once all ``max_steps`` batches have been consumed.
            Exception: any exception raised by the producer thread.
        """
        if self._exhausted:
            raise StopIteration
        batch = self.queue.get()
        if batch is None:
            self._exhausted = True
            raise StopIteration
        if isinstance(batch, BaseException):
            self._exhausted = True
            raise batch
        return batch

    def __iter__(self):
        return self
class PrefetchDataloader(multiprocessing.Process):
    """Prefetch dataloader for IterableDataset, running in a child process.

    The child process concatenates samples from ``dataset`` and cuts them into
    chunks of ``batch_size * sequence_length`` elements, regrouped into rows of
    ``sequence_length``. Elements that do not fit are carried over to the next
    batch. The resulting dicts of NumPy arrays go through a bounded
    ``multiprocessing.Queue`` of size ``prefetch_buffer``. The consumer converts
    them with ``make_batch`` in ``__next__``. When the dataset is exhausted it
    is restarted (reshuffled with an incremented seed when ``shuffle`` is
    true). After ``max_steps`` batches, iteration ends with ``StopIteration``.
    A failure in the child is re-raised in the consumer as ``RuntimeError``.

    The process starts itself at the end of ``__init__``.
    """

    def __init__(self, dataset, max_steps, batch_size, sequence_length, prefetch_buffer=1, shuffle=True, shuffle_buffer=1000, seed=0):
        super().__init__(daemon=True)
        self.max_steps = max_steps
        self.bs = batch_size
        self.seq_len = sequence_length
        self.max_length = batch_size * sequence_length
        self.prefetch_buffer = prefetch_buffer
        self.shuffle = shuffle
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.dataset = dataset
        # The dataset iterator is created lazily inside run(), in the child.
        # With the spawn/forkserver start methods this object is pickled on
        # start(), and dataset iterators (typically generators) are not picklable.
        self.ds_iter = None
        self.queue = multiprocessing.Queue(prefetch_buffer)
        # Leftover elements from the previous batch; a defaultdict so the very
        # first lookup of "input_ids" yields an empty list.
        self.rem = defaultdict(list)
        # Parent-side flag: set once the end sentinel (or an error) was received.
        self._exhausted = False
        self.start()

    def make_iter(self):
        """(Re)create the dataset iterator, shuffling first when ``shuffle`` is set."""
        if self.shuffle:
            shuffled_dataset = self.dataset.shuffle(self.shuffle_buffer, seed=self.seed)
            # Bump the seed so the next pass shuffles with a different seed.
            self.seed += 1
            self.ds_iter = iter(shuffled_dataset)
        else:
            self.ds_iter = iter(self.dataset)

    def _next_sample(self):
        """Return the next sample, restarting the dataset once a pass is complete.

        Raises:
            ValueError: if a freshly created iterator yields nothing (empty dataset),
                which would otherwise make the producer loop forever.
        """
        try:
            return next(self.ds_iter)
        except StopIteration:
            pass
        # Reset the generator: a pass through the dataset is completed.
        self.make_iter()
        try:
            return next(self.ds_iter)
        except StopIteration:
            raise ValueError("dataset yielded no samples") from None

    def _next_samples(self):
        """Build one batch as a dict of NumPy arrays with rows of ``seq_len``.

        Leftover elements in ``self.rem`` are used first; the excess beyond
        ``max_length`` is stored back into ``self.rem``.
        """
        max_length = self.max_length
        sample = self.rem.copy()
        length = len(sample["input_ids"])
        while length < max_length:
            next_sample = self._next_sample()
            length += len(next_sample["input_ids"])
            # NOTE(review): assumes feature values are Python lists so that `+`
            # concatenates (NumPy arrays would add elementwise) — confirm.
            sample = {k: sample[k] + next_sample[k] for k in next_sample}
        self.rem = {k: v[max_length:] for k, v in sample.items()}
        sample = {k: v[:max_length] for k, v in sample.items()}
        # Regroup to shape [bs x seq_len].
        return {
            k: np.array([v[row * self.seq_len:(row + 1) * self.seq_len] for row in range(self.bs)])
            for k, v in sample.items()
        }

    def run(self):
        """Produce ``max_steps`` batches followed by a ``None`` end-of-stream sentinel."""
        try:
            if self.ds_iter is None:
                self.make_iter()
            for _ in range(self.max_steps):
                self.queue.put(self._next_samples())
        except Exception as err:
            # The consumer lives in another process: forward a picklable summary
            # so __next__ raises instead of blocking forever on queue.get().
            self.queue.put(RuntimeError(f"PrefetchDataloader worker failed: {err!r}"))
            raise
        self.queue.put(None)

    def __next__(self):
        """Return the next prefetched batch converted with ``make_batch``.

        Raises:
            StopIteration: once all ``max_steps`` batches have been consumed.
            RuntimeError: if the producer process failed.
        """
        if self._exhausted:
            raise StopIteration
        samples = self.queue.get()
        if samples is None:
            self._exhausted = True
            raise StopIteration
        if isinstance(samples, BaseException):
            self._exhausted = True
            raise samples
        return make_batch(samples)

    def __iter__(self):
        return self