feat: minor improvements
dalle_mini/model/configuration.py
CHANGED
@@ -80,7 +80,6 @@ class DalleBartConfig(PretrainedConfig):
         self.decoder_layerdrop = decoder_layerdrop
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
-        self.num_hidden_layers = encoder_layers
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = (
             scale_embedding  # scale factor will be sqrt(d_model) if True
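This hunk removes the duplicated num_hidden_layers attribute, which was just a copy of encoder_layers and could drift out of sync with it. For context, transformers configs often alias such names through the class-level attribute_map instead of storing a second copy; a minimal sketch of that pattern (the mapping below is illustrative, not necessarily what DalleBartConfig does):

from transformers import PretrainedConfig

class ToyBartLikeConfig(PretrainedConfig):
    # Hypothetical alias: reads of num_hidden_layers resolve to encoder_layers,
    # so there is no duplicate attribute to keep in sync.
    attribute_map = {"num_hidden_layers": "encoder_layers"}

    def __init__(self, encoder_layers=12, **kwargs):
        self.encoder_layers = encoder_layers
        super().__init__(**kwargs)

config = ToyBartLikeConfig(encoder_layers=6)
assert config.num_hidden_layers == 6  # resolved through attribute_map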
tools/train/train.py
CHANGED
@@ -375,6 +375,9 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()

+    logger.info(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")

@@ -443,9 +446,6 @@
         use_fast=True,
     )

-    logger.info(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
-
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.

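Taken together, these first two hunks move the TPU sanity check from after tokenizer loading to the top of main(), so a machine whose TPU cores are held by another process fails fast instead of partway through setup. A standalone sketch of the same check (the device listing shown is illustrative for a TPU v3-8):

import jax

# Count the accelerator devices visible to this process.
print(jax.devices())       # illustrative: [TpuDevice(id=0), ..., TpuDevice(id=7)]
print(jax.device_count())  # 8 on a fully available TPU v3-8

# Fewer than 8 cores usually means another process is holding the TPU.
assert jax.device_count() == 8, "TPUs in use, please check running processes"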
@@ -474,6 +474,7 @@
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
    )
+    num_params = model.num_params

     # Create learning rate schedule
     learning_rate_fn = create_learning_rate_fn(
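model.num_params is read once here and reused below for the progress log, the W&B run config, and the artifact metadata. Its implementation is not shown in this diff; as a rough idea, a parameter count over a Flax/JAX parameter pytree is typically the sum of the leaf array sizes (count_params below is a hypothetical helper, not dalle-mini's code):

import jax

def count_params(params) -> int:
    # Sum the element counts of every array leaf in the parameter pytree.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))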
@@ -602,6 +603,7 @@
     logger.info(
         f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
+    logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )
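The :, in the new log line is Python's thousands-separator format spec, which keeps large parameter counts readable:

num_params = 437_000_000  # illustrative value
print(f" Model parameters = {num_params:,}")  # -> " Model parameters = 437,000,000"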
@@ -616,7 +618,7 @@
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
             "batch_size_per_update": batch_size_per_update,
-            "num_params":
+            "num_params": num_params,
         }
     )

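This hunk completes an entry that a previous commit left without a value. The keys are being added to the W&B run config (the surrounding } and ) suggest a wandb.config.update(...) call, though the call site is outside this hunk). A minimal sketch of recording run hyperparameters that way (project name and values are illustrative):

import wandb

wandb.init(project="dalle-mini")
wandb.config.update(
    {
        "batch_size_per_update": 512,   # illustrative
        "num_params": 437_000_000,      # illustrative
    }
)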
@@ -693,7 +695,7 @@
         c.cleanup(wandb.util.from_human_size("10GB"))

     metadata = dict(state_dict)
-    metadata["num_params"] =
+    metadata["num_params"] = num_params
     if eval_metrics is not None:
         metadata["eval"] = eval_metrics
     artifact = wandb.Artifact(
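Like the previous hunk, this completes an assignment left unfinished earlier, feeding the parameter count into the metadata dict attached to the model artifact. A minimal sketch of attaching metadata to a wandb.Artifact (names and values are illustrative):

import wandb

metadata = {"num_params": 437_000_000, "eval": {"loss": 2.31}}  # illustrative
artifact = wandb.Artifact("model-checkpoint", type="model", metadata=metadata)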
|