Merge pull request #122 from borisdayma/feat-acccum
Files changed:
- src/dalle_mini/data.py (+41 -6)
- tools/train/train.py (+93 -46)
src/dalle_mini/data.py
CHANGED
@@ -153,16 +153,24 @@ class Dataset:
             ),
         )
 
-    def dataloader(
+    def dataloader(
+        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
+    ):
+        num_devices = jax.local_device_count()
+
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
+            batch_size = (
+                per_device_batch_size * num_devices * gradient_accumulation_steps
+            )
             steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:

@@ -178,11 +186,20 @@ class Dataset:
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
+                if gradient_accumulation_steps is not None:
+                    batch = jax.tree_map(
+                        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
+                        batch,
+                    )
                 batch = shard(batch)
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset,
+            dataset: Dataset,
+            split: str,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
+            epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}

@@ -199,8 +216,22 @@ class Dataset:
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
-
+                    # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
+                    # (40, 3, 3) -> shard 8 x (5, 3, 3)
+                    # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
+                    if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
+                        gradient_accumulation_steps
+                        if gradient_accumulation_steps is not None
+                        else 1
+                    ):
                         batch = {k: jnp.array(v) for k, v in batch.items()}
+                        if gradient_accumulation_steps is not None:
+                            batch = jax.tree_map(
+                                lambda x: x.reshape(
+                                    (-1, per_device_batch_size) + x.shape[1:]
+                                ),
+                                batch,
+                            )
                         batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}

@@ -214,11 +245,15 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(
+            return _dataloader_datasets_streaming(
+                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
+            )
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(
+            return _dataloader_datasets_non_streaming(
+                ds, per_device_batch_size, gradient_accumulation_steps, input_rng
+            )
 
     @property
     def length(self):
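Both dataloaders now group per_device_batch_size * num_devices * gradient_accumulation_steps examples into one batch and reshape it so that, after shard(), each device receives gradient_accumulation_steps minibatches of per_device_batch_size examples. Below is a minimal standalone sketch of that layout, not part of the PR, using the figures from the comment in the diff (per_device_batch_size=5, 8 devices, 2 accumulation steps) and jax.tree_util.tree_map in place of the older jax.tree_map alias used above:

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

per_device_batch_size = 5
gradient_accumulation_steps = 2
num_devices = jax.local_device_count()  # 8 in the example from the diff comment

# one batch holds per_device_batch_size * num_devices * accumulation-steps examples
batch_size = per_device_batch_size * num_devices * gradient_accumulation_steps
batch = {"input_ids": jnp.zeros((batch_size, 3, 3))}  # (80, 3, 3) with 8 devices

# group examples into minibatches of per_device_batch_size
batch = jax.tree_util.tree_map(
    lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]), batch
)  # (16, 5, 3, 3) with 8 devices

# shard() adds the leading device axis expected by pmap
batch = shard(batch)  # (8, 2, 5, 3, 3) with 8 devices
print(jax.tree_util.tree_map(lambda x: x.shape, batch))

Without accumulation the reshape is skipped and shard() simply splits the (40, 3, 3) batch into 8 x (5, 3, 3), as in the first shape comment of the diff.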
tools/train/train.py
CHANGED
@@ -277,8 +277,8 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs:
-        default=3
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
     )
     warmup_steps: int = field(
         default=0, metadata={"help": "Linear warmup over warmup_steps."}

@@ -310,12 +310,40 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The wandb entity to use (for teams)."},
+    )
+    wandb_project: str = field(
+        default="dalle-mini",
+        metadata={"help": "The name of the wandb project."},
+    )
+    wandb_job_type: str = field(
+        default="Seq2Seq",
+        metadata={"help": "The name of the wandb job type."},
+    )
+
+    assert_TPU_available: bool = field(
+        default=False,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if (
+            os.path.exists(self.output_dir)
+            and os.listdir(self.output_dir)
+            and self.do_train
+            and not self.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty."
+                "Use --overwrite_output_dir to overcome."
+            )
 
 
 class TrainState(train_state.TrainState):

@@ -396,17 +424,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

@@ -433,14 +450,18 @@ def main():
     )
 
     logger.info(f"Local TPUs: {jax.local_device_count()}")
-
+    logger.info(f"Global TPUs: {jax.device_count()}")
+    if training_args.assert_TPU_available:
+        assert (
+            jax.local_device_count() == 8
+        ), "TPUs in use, please check running processes"
 
     # Set up wandb run
     if jax.process_index() == 0:
         wandb.init(
-            entity=
-            project=
-            job_type=
+            entity=training_args.wandb_entity,
+            project=training_args.wandb_project,
+            job_type=training_args.wandb_job_type,
             config=parser.parse_args(),
         )
 

@@ -515,22 +536,19 @@ def main():
     rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
-    num_epochs =
+    num_epochs = training_args.num_train_epochs
     # batch size per node
     train_batch_size = (
-
-    )
-    batch_size_per_update = (
-        train_batch_size
-        * training_args.gradient_accumulation_steps
-        * jax.process_count()
+        training_args.per_device_train_batch_size * jax.local_device_count()
     )
+    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
-
+        training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
-        len_train_dataset //
+        len_train_dataset // batch_size_per_node
         if len_train_dataset is not None
         else None
     )

@@ -645,12 +663,6 @@ def main():
             clipping_threshold=training_args.max_grad_norm,
         )
 
-    # add gradient accumulation
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.chain(
-            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
-        )
-
     # Setup train state
     state = TrainState.create(
         apply_fn=model.__call__,

@@ -673,22 +685,48 @@ def main():
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
-        def compute_loss(params,
-            labels =
+        def compute_loss(params, minibatch):
+            labels = minibatch.pop("labels")
             logits = state.apply_fn(
-                **
+                **minibatch, params=params, dropout_rng=dropout_rng, train=True
             )[0]
-
-            return loss
+            return loss_fn(logits, labels)
 
         grad_fn = jax.value_and_grad(compute_loss)
-
+
+        if training_args.gradient_accumulation_steps == 1:
+            minibatch = jax.tree_map(lambda x: x[0], batch)
+            loss, grads = grad_fn(state.params, minibatch)
+        else:
+
+            def _cumul_loss_grads(i, cumul_loss_grads):
+                minibatch = jax.tree_map(lambda x: x[i], batch)
+                return jax.tree_map(
+                    lambda x, y: x + y,
+                    cumul_loss_grads,
+                    grad_fn(state.params, minibatch),
+                )
+
+            init_loss_grads = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            loss, grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps,
+                jax.lax.fori_loop(
+                    0,
+                    training_args.gradient_accumulation_steps,
+                    _cumul_loss_grads,
+                    init_loss_grads,
+                ),
+            )
+
         grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
             dropout_rng=new_dropout_rng,
             train_time=state.train_time + delta_time,
-            train_samples=state.train_samples +
+            train_samples=state.train_samples + batch_size_per_step,
         )
 
         metrics = {

@@ -711,19 +749,20 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len_train_dataset}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(
-        f"
+        f" Batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(f" Number of devices = {jax.device_count()}")
     logger.info(
-        f"
+        f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
    )
+    logger.info(f" Batch size per update = {batch_size_per_step}")
     logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0

@@ -740,8 +779,9 @@ def main():
             {
                 "len_train_dataset": len_train_dataset,
                 "len_eval_dataset": len_eval_dataset,
-                "
+                "batch_size_per_step": batch_size_per_step,
                 "num_params": num_params,
+                "num_devices": jax.device_count(),
             }
         )
 

@@ -752,7 +792,9 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_loader = dataset.dataloader(
+            eval_loader = dataset.dataloader(
+                "eval", training_args.per_device_eval_batch_size
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size
                 if len_eval_dataset is not None

@@ -869,7 +911,12 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader(
+        train_loader = dataset.dataloader(
+            "train",
+            training_args.per_device_train_batch_size,
+            training_args.gradient_accumulation_steps,
+            epoch,
+        )
         # train
         for batch in tqdm(
             train_loader,
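With the optax.apply_every wrapper removed, accumulation now happens inside train_step: each per-device batch carries a leading axis of gradient_accumulation_steps minibatches, and jax.lax.fori_loop sums loss and gradients over that axis before dividing by the number of steps. Below is a minimal standalone sketch of the same pattern; the linear model, loss, and shapes are illustrative assumptions only, and jax.tree_util.tree_map stands in for the older jax.tree_map alias used in the diff:

import jax
import jax.numpy as jnp

accumulation_steps = 4
params = {"w": jnp.ones((3,))}
# leading axis = accumulation_steps, i.e. one minibatch per accumulation step
batch = {
    "x": jnp.ones((accumulation_steps, 2, 3)),
    "y": jnp.ones((accumulation_steps, 2)),
}

def loss_fn(params, minibatch):
    pred = minibatch["x"] @ params["w"]  # toy linear model
    return jnp.mean((pred - minibatch["y"]) ** 2)

grad_fn = jax.value_and_grad(loss_fn)

def cumul_loss_grads(i, cumul):
    # pick the i-th minibatch and add its (loss, grads) to the running totals
    minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
    return jax.tree_util.tree_map(lambda x, y: x + y, cumul, grad_fn(params, minibatch))

init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
loss, grads = jax.tree_util.tree_map(
    lambda x: x / accumulation_steps,
    jax.lax.fori_loop(0, accumulation_steps, cumul_loss_grads, init),
)
print(loss, grads["w"].shape)

In the PR itself the accumulated gradients are then averaged across devices with jax.lax.pmean, and the sample counter advances by batch_size_per_step, i.e. per_device_train_batch_size * jax.local_device_count() * gradient_accumulation_steps * jax.process_count().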