boris committed
Commit 53dade7
1 Parent(s): eb24dbc

feat: minor improvements

dalle_mini/model/configuration.py CHANGED
@@ -80,7 +80,6 @@ class DalleBartConfig(PretrainedConfig):
         self.decoder_layerdrop = decoder_layerdrop
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
-        self.num_hidden_layers = encoder_layers
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = (
             scale_embedding  # scale factor will be sqrt(d_model) if True
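The removed line had stored the encoder depth twice on the config, as both encoder_layers and num_hidden_layers. A minimal sketch, not part of this commit and using illustrative names only, of how such a duplicate attribute can instead be derived on access from the single stored value:

class ConfigSketch:
    def __init__(self, encoder_layers: int = 12):
        self.encoder_layers = encoder_layers

    @property
    def num_hidden_layers(self) -> int:
        # Single source of truth: always reflects encoder_layers.
        return self.encoder_layers

cfg = ConfigSketch(encoder_layers=24)
assert cfg.num_hidden_layers == cfg.encoder_layers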
tools/train/train.py CHANGED
@@ -375,6 +375,9 @@ def main():
         datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
 
+    logger.info(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
+
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
 
@@ -443,9 +446,6 @@ def main():
         use_fast=True,
     )
 
-    logger.info(f"TPUs: {jax.device_count()}")
-    assert jax.device_count() == 8, "TPUs in use, please check running processes"
-
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
 
@@ -474,6 +474,7 @@ def main():
     num_train_steps = (
         steps_per_epoch * num_epochs if steps_per_epoch is not None else None
    )
+    num_params = model.num_params
 
     # Create learning rate schedule
     learning_rate_fn = create_learning_rate_fn(
@@ -602,6 +603,7 @@ def main():
     logger.info(
         f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
+    logger.info(f" Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )
@@ -616,7 +618,7 @@
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
             "batch_size_per_update": batch_size_per_update,
-            "num_params": model.num_params,
+            "num_params": num_params,
         }
     )
 
@@ -693,7 +695,7 @@
         c.cleanup(wandb.util.from_human_size("10GB"))
 
         metadata = dict(state_dict)
-        metadata["num_params"] = model.num_params
+        metadata["num_params"] = num_params
         if eval_metrics is not None:
            metadata["eval"] = eval_metrics
        artifact = wandb.Artifact(
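Taken together, the train.py changes move the TPU check to run right after logging verbosity is configured, and read model.num_params once into a local num_params that then feeds the console log, the wandb run config, and the checkpoint artifact metadata. A minimal standalone sketch of both patterns, assuming a JAX installation; the model stub and its parameter count are placeholders, not the project's values:

import jax

# Fail fast if the expected 8-device TPU topology is not visible
# (on a CPU-only machine this assert is expected to fail).
device_count = jax.device_count()
print(f"TPUs: {device_count}")
assert device_count == 8, "TPUs in use, please check running processes"

class ModelStub:
    # Placeholder standing in for the real model, which exposes num_params.
    num_params = 123_456_789

model = ModelStub()

# Read the parameter count once and reuse the local wherever it is reported.
num_params = model.num_params
print(f" Model parameters = {num_params:,}")   # console log
run_config = {"num_params": num_params}        # would be passed to the wandb run config
metadata = {"num_params": num_params}          # would be attached to the wandb Artifact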