boris committed on
Commit
34cf91c
·
1 Parent(s): e558000

feat: reduce artifact space + offset step

Browse files
Files changed (2) hide show
  1. src/dalle_mini/model/utils.py +14 -12
  2. tools/train/train.py +43 -19
src/dalle_mini/model/utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  import wandb
4
 
@@ -9,16 +10,17 @@ class PretrainedFromWandbMixin:
9
  """
10
  Initializes from a wandb artifact, or delegates loading to the superclass.
11
  """
12
- if ":" in pretrained_model_name_or_path and not os.path.isdir(
13
- pretrained_model_name_or_path
14
- ):
15
- # wandb artifact
16
- if wandb.run is not None:
17
- artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
18
- else:
19
- artifact = wandb.Api().artifact(pretrained_model_name_or_path)
20
- pretrained_model_name_or_path = artifact.download()
 
21
 
22
- return super(PretrainedFromWandbMixin, cls).from_pretrained(
23
- pretrained_model_name_or_path, *model_args, **kwargs
24
- )
 
1
  import os
2
+ import tempfile
3
 
4
  import wandb
5
 
 
10
  """
11
  Initializes from a wandb artifact, or delegates loading to the superclass.
12
  """
13
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
14
+ if ":" in pretrained_model_name_or_path and not os.path.isdir(
15
+ pretrained_model_name_or_path
16
+ ):
17
+ # wandb artifact
18
+ if wandb.run is not None:
19
+ artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
20
+ else:
21
+ artifact = wandb.Api().artifact(pretrained_model_name_or_path)
22
+ pretrained_model_name_or_path = artifact.download(tmp_dir)
23
 
24
+ return super(PretrainedFromWandbMixin, cls).from_pretrained(
25
+ pretrained_model_name_or_path, *model_args, **kwargs
26
+ )
tools/train/train.py CHANGED
@@ -22,6 +22,7 @@ import json
22
  import logging
23
  import os
24
  import sys
 
25
  import time
26
  from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
@@ -97,12 +98,10 @@ class ModelArguments:
97
  restore_state: Optional[bool] = field(
98
  default=False,
99
  metadata={
100
- "help": "Restore optimizer and training state associated with a wandb checkpoint."
101
  },
102
  )
103
 
104
- state_artifact: str = field(init=False)
105
-
106
  def __post_init__(self):
107
  if self.tokenizer_name is None:
108
  self.tokenizer_name == self.model_name_or_path
@@ -113,9 +112,28 @@ class ModelArguments:
113
  assert self.model_name_or_path is not None and (
114
  "/model-" in self.model_name_or_path
115
  ), "Restoring state only available with W&B artifact reference"
116
- self.state_artifact = self.model_name_or_path.replace(
117
- "/model-", "/state-", 1
118
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  @dataclass
@@ -521,6 +539,9 @@ def main():
521
  # update model config per training args
522
  model.config.gradient_checkpointing = training_args.gradient_checkpointing
523
 
 
 
 
524
  # get PartitionSpec for model params (required to be a dict)
525
  param_spec = set_partitions(model.params)
526
 
@@ -581,7 +602,7 @@ def main():
581
  logger.info(f" Batch size per update = {batch_size_per_step}")
582
  logger.info(f" Model parameters = {num_params:,}")
583
 
584
- # create wandb run
585
  if jax.process_index() == 0:
586
  # set default x-axis as 'train/step'
587
  wandb.define_metric("*", step_metric="train/step")
@@ -605,6 +626,12 @@ def main():
605
  end_value=training_args.learning_rate,
606
  transition_steps=training_args.warmup_steps,
607
  )
 
 
 
 
 
 
608
  if training_args.lr_decay is None:
609
  return warmup_fn
610
  elif training_args.lr_decay == "linear":
@@ -757,20 +784,17 @@ def main():
757
  )(model.params)
758
 
759
  else:
760
- # get state files from artifact
761
- if jax.process_index() == 0:
762
- artifact = wandb.run.use_artifact(model_args.state_artifact)
763
- else:
764
- artifact = wandb.Api().artifact(model_args.state_artifact)
765
- artifact_dir = artifact.download()
766
-
767
- # restore opt_state
768
- with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
769
- opt_state = from_bytes(opt_state_shape, f.read())
770
 
771
  # restore other attributes
772
- with (Path(artifact_dir) / "training_state.json").open("r") as f:
773
- attr_state = json.load(f)
 
 
774
 
775
  def restore_state(params, opt_state):
776
  return TrainState(
 
22
  import logging
23
  import os
24
  import sys
25
+ import tempfile
26
  import time
27
  from dataclasses import asdict, dataclass, field
28
  from pathlib import Path
 
98
  restore_state: Optional[bool] = field(
99
  default=False,
100
  metadata={
101
+ "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
102
  },
103
  )
104
 
 
 
105
  def __post_init__(self):
106
  if self.tokenizer_name is None:
107
  self.tokenizer_name == self.model_name_or_path
 
112
  assert self.model_name_or_path is not None and (
113
  "/model-" in self.model_name_or_path
114
  ), "Restoring state only available with W&B artifact reference"
115
+
116
+ def get_metadata(self):
117
+ if self.restore_state:
118
+ if jax.process_index() == 0:
119
+ artifact = wandb.run.use_artifact(self.model_name_or_path)
120
+ else:
121
+ artifact = wandb.Api().artifact(self.model_name_or_path)
122
+ return artifact.metadata
123
+ else:
124
+ return dict()
125
+
126
+ def get_opt_state(self, tmp_dir):
127
+ if self.restore_state is True:
128
+ # wandb artifact
129
+ state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
130
+ if jax.process_index() == 0:
131
+ artifact = wandb.run.use_artifact(state_artifact)
132
+ else:
133
+ artifact = wandb.Api().artifact(state_artifact)
134
+ artifact_dir = artifact.download(tmp_dir)
135
+ self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
136
+ return Path(self.restore_state).open("rb")
137
 
138
 
139
  @dataclass
 
539
  # update model config per training args
540
  model.config.gradient_checkpointing = training_args.gradient_checkpointing
541
 
542
+ # get model metadata
543
+ model_metadata = model_args.get_metadata()
544
+
545
  # get PartitionSpec for model params (required to be a dict)
546
  param_spec = set_partitions(model.params)
547
 
 
602
  logger.info(f" Batch size per update = {batch_size_per_step}")
603
  logger.info(f" Model parameters = {num_params:,}")
604
 
605
+ # set up wandb run
606
  if jax.process_index() == 0:
607
  # set default x-axis as 'train/step'
608
  wandb.define_metric("*", step_metric="train/step")
 
626
  end_value=training_args.learning_rate,
627
  transition_steps=training_args.warmup_steps,
628
  )
629
+ # offset step when resuming
630
+ if model_metadata.get("step", 0):
631
+ warmup_fn = optax.join_schedules(
632
+ schedules=[optax.constant_schedule(0.0), warmup_fn],
633
+ boundaries=[model_metadata["step"]],
634
+ )
635
  if training_args.lr_decay is None:
636
  return warmup_fn
637
  elif training_args.lr_decay == "linear":
 
784
  )(model.params)
785
 
786
  else:
787
+ # load opt_state
788
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
789
+ opt_state_file = model_args.get_opt_state(tmp_dir)
790
+ opt_state = from_bytes(opt_state_shape, opt_state_file.read())
791
+ opt_state_file.close()
 
 
 
 
 
792
 
793
  # restore other attributes
794
+ attr_state = {
795
+ k: model_metadata[k]
796
+ for k in ["step", "epoch", "train_time", "train_samples"]
797
+ }
798
 
799
  def restore_state(params, opt_state):
800
  return TrainState(