boris committed
Commit 3d43591
Parent: 2d212d8

feat(train): load model on CPU

Files changed (1):
  1. tools/train/train.py +24 -23
tools/train/train.py CHANGED
@@ -679,39 +679,40 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    # Setup train state
-    def init_state(params, opt_state):
-        return TrainState(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=params,
-            opt_state=opt_state,
-            dropout_rng=dropout_rng,
-            step=0,
-        )
-
-    state_spec = init_state(param_spec, opt_state_spec)
-    state_spec = state_spec.replace(
+    # Create state spec
+    state_spec = TrainState(
+        params=param_spec,
+        opt_state=opt_state_spec,
         dropout_rng=None,
         step=None,
         epoch=None,
         train_time=None,
         train_samples=None,
+        apply_fn=model.__call__,
+        tx=optimizer,
     )
 
+    # create training state
+    def init_state(params):
+        state = TrainState.create(
+            apply_fn=model.__call__,
+            tx=optimizer,
+            params=freeze(params),
+            dropout_rng=dropout_rng,
+        )
+        return state
+
+    # hack: move the initial params to CPU to free up device memory
+    # TODO: allow loading weights on CPU in pre-trained model
+    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+
     with maps.mesh(mesh.devices, mesh.axis_names):
-        # move params & init opt_state over specified devices
-        params, opt_state = pjit(
-            lambda x: (x, optimizer.init(x)),
-            in_axis_resources=None,
-            out_axis_resources=(param_spec, opt_state_spec),
-        )(freeze(model.params))
-        # create training state
         state = pjit(
             init_state,
-            in_axis_resources=(param_spec, opt_state_spec),
+            in_axis_resources=None,
             out_axis_resources=state_spec,
-        )(params, opt_state)
+            donate_argnums=(0,),
+        )(freeze(model.params))
 
     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
@@ -793,7 +794,7 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, None, None),
+        in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
    )
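
Note on the state spec: pjit's out_axis_resources must be a pytree with the same structure as the function's output, so rather than the old trick of calling init_state on the specs, the commit builds the spec directly as a TrainState whose array fields hold PartitionSpecs and whose replicated fields hold None. A minimal sketch of the pattern, using a simplified stand-in for the repo's TrainState (the struct and its field set here are assumptions, and the import paths match the pjit-era JAX this commit targets):

from typing import Any, Callable

from flax import struct
from jax.experimental import PartitionSpec


@struct.dataclass
class TrainState:
    # pytree fields: leaves that pjit shards or replicates
    params: Any = None
    opt_state: Any = None
    dropout_rng: Any = None
    step: Any = None
    # non-pytree fields: plain attributes, invisible to pjit
    apply_fn: Callable = struct.field(pytree_node=False, default=None)
    tx: Any = struct.field(pytree_node=False, default=None)


# In the real script these are pytrees of PartitionSpec derived from the
# model; single specs stand in for them here.
param_spec = PartitionSpec("mp")
opt_state_spec = PartitionSpec("mp")

# The spec is just a TrainState holding PartitionSpecs (None = replicate),
# structurally identical to the state that init_state will return.
state_spec = TrainState(params=param_spec, opt_state=opt_state_spec,
                        dropout_rng=None, step=None)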
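Note on the CPU move itself: from_pretrained materializes the full weights on a single device, so the jax.tree_map(lambda x: np.asarray(x), ...) line copies every leaf into host memory, letting the single-device buffers be freed before pjit re-shards the parameters across the mesh; donate_argnums=(0,) then tells XLA it may reuse the input buffers for the output state, keeping peak memory near one copy of the weights. A rough self-contained sketch of that flow (the toy pytree and shard_params are illustrative, not from the repo):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# Pretend these came from from_pretrained and sit on a single device.
params = {"kernel": jnp.ones((8, 8)), "bias": jnp.zeros((8,))}

# The hack from the diff: copy every leaf to host RAM as np.ndarray so
# the device buffers can be garbage-collected.
params = jax.tree_map(lambda x: np.asarray(x), params)

param_spec = {"kernel": PartitionSpec(None, "mp"), "bias": None}

devices = np.asarray(jax.devices()).reshape(1, -1)  # ("batch", "mp")
mesh = maps.Mesh(devices, ("batch", "mp"))

with maps.mesh(mesh.devices, mesh.axis_names):
    # in_axis_resources=None treats the (host) inputs as replicated;
    # donate_argnums=(0,) lets XLA reuse them for the sharded output.
    shard_params = pjit(lambda p: p,
                        in_axis_resources=None,
                        out_axis_resources=param_spec,
                        donate_argnums=(0,))
    params = shard_params(params)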
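Note on the last hunk: the batch argument of train_step previously had spec None (fully replicated); PartitionSpec("batch", None) now splits its leading dimension across the "batch" mesh axis while replicating the trailing one. The same idea on a toy step function (toy_step is a hypothetical stand-in for train_step):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit


def toy_step(batch):
    # stand-in for train_step: some per-example work, then a reduction
    return jnp.mean(batch ** 2)


devices = np.asarray(jax.devices()).reshape(-1, 1)  # ("batch", "mp")
mesh = maps.Mesh(devices, ("batch", "mp"))

with maps.mesh(mesh.devices, mesh.axis_names):
    p_step = pjit(toy_step,
                  # rows of the batch land on different devices along
                  # "batch"; the feature dimension is replicated
                  in_axis_resources=PartitionSpec("batch", None),
                  out_axis_resources=None)
    loss = p_step(jnp.ones((8, 4)))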