fix(data): no shuffling of validation data
src/dalle_mini/data.py  CHANGED  (+11 -7)
@@ -182,15 +182,20 @@ class Dataset:
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, batch_size: int, epoch: int
+            dataset: Dataset, split: str, batch_size: int, epoch: int
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            first_loop = True
-            while self.multi_hosts or first_loop:
+            first_loop = True  # stop after one loop in some cases
+            while (self.multi_hosts and split == "train") or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
-                # at the same time and
-                dataset.set_epoch(epoch)
+                # at the same time and training data may not be split equally
+                # For validation data we put the entire set on each host as we could lose
+                # too many samples on pods
+                if epoch is not None:
+                    # reshuffle training data at each epoch (not applicable with validation set)
+                    dataset.set_epoch(epoch)
+                    epoch += 1
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
@@ -199,7 +204,6 @@ class Dataset:
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
-                epoch += 1
                 first_loop = False
 
         if split == "train":
@@ -210,7 +214,7 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(ds, batch_size, epoch)
+            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
         else:
            if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
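
For context, here is a minimal, self-contained sketch of the control flow this commit introduces. ToyStreamingDataset, fake_shard and dataloader_streaming are hypothetical stand-ins invented for illustration; the real code iterates a streaming 🤗 Datasets dataset, batches dicts of input_ids / attention_mask / labels / decoder_input_ids, and shards them across devices. The point shown: the training split is reshuffled each epoch via set_epoch and loops forever only on multi-host setups, while the validation split (epoch=None) is read once and never reshuffled.

import random


class ToyStreamingDataset:
    """Toy stand-in for a streaming dataset with optional per-epoch shuffling."""

    def __init__(self, items, shuffle=False, seed=0):
        self.items = list(items)
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # mirrors the idea behind set_epoch on a streaming dataset:
        # the shuffling order is re-seeded per epoch (here: seed + epoch)
        self.epoch = epoch

    def __iter__(self):
        order = list(self.items)
        if self.shuffle:
            random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)


def fake_shard(batch):
    # stand-in for sharding a batch across local devices
    return batch


def dataloader_streaming(dataset, split, batch_size, epoch, multi_hosts=False):
    batch = []
    first_loop = True  # stop after one loop in some cases
    while (multi_hosts and split == "train") or first_loop:
        if epoch is not None:
            # reshuffle training data at each epoch; validation passes epoch=None,
            # so its order is never touched
            dataset.set_epoch(epoch)
            epoch += 1
        for item in dataset:
            batch.append(item)
            if len(batch) == batch_size:
                yield fake_shard(batch)
                batch = []
        first_loop = False


if __name__ == "__main__":
    train_ds = ToyStreamingDataset(range(8), shuffle=True, seed=42)
    eval_ds = ToyStreamingDataset(range(8))
    print(list(dataloader_streaming(train_ds, "train", 4, epoch=0)))   # shuffled batches
    print(list(dataloader_streaming(eval_ds, "eval", 4, epoch=None)))  # in-order batches

After this fix, the eval split gets a single ordered pass: set_epoch is skipped because epoch is None, and the while condition only keeps looping for the train split on multi-host setups, so every host sees the full validation set in the same order.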