feat(train): load model on CPU
tools/train/train.py  CHANGED  (+24 -23)
@@ -679,39 +679,40 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    # create state spec
-    def init_state(params, opt_state):
-        return TrainState(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=params,
-            opt_state=opt_state,
-            dropout_rng=dropout_rng,
-            step=0,
-        )
-
-    state_spec = init_state(param_spec, opt_state_spec)
-    state_spec = state_spec.replace(
+    # Create state spec
+    state_spec = TrainState(
+        params=param_spec,
+        opt_state=opt_state_spec,
         dropout_rng=None,
         step=None,
         epoch=None,
         train_time=None,
         train_samples=None,
+        apply_fn=model.__call__,
+        tx=optimizer,
     )
 
+    # create training state
+    def init_state(params):
+        state = TrainState.create(
+            apply_fn=model.__call__,
+            tx=optimizer,
+            params=freeze(params),
+            dropout_rng=dropout_rng,
+        )
+        return state
+
+    # hack: move the initial params to CPU to free up device memory
+    # TODO: allow loading weights on CPU in pre-trained model
+    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+
     with maps.mesh(mesh.devices, mesh.axis_names):
-        # move params & init opt_state over specified devices
-        params, opt_state = pjit(
-            lambda x: (x, optimizer.init(x)),
-            in_axis_resources=None,
-            out_axis_resources=(param_spec, opt_state_spec),
-        )(freeze(model.params))
-        # create training state
         state = pjit(
             init_state,
-            in_axis_resources=(param_spec, opt_state_spec),
+            in_axis_resources=None,
             out_axis_resources=state_spec,
-        )(params, opt_state)
+            donate_argnums=(0,),
+        )(freeze(model.params))
 
     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters
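Note on the hunk above: the pretrained weights are converted to plain numpy arrays (host memory) before entering the mesh context, and the sharded training state is then built inside a single pjit call with donate_argnums=(0,), so the host copy can be released once the devices hold the partitioned parameters. Below is a minimal, self-contained sketch of that pattern using the same pre-0.4 jax.experimental.pjit / maps API as this file; the toy parameter tree, shapes, and param_spec are made up for illustration and are not taken from train.py.

```python
import numpy as np
import jax
import jax.numpy as jnp
import optax
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# Toy "pretrained" parameters, standing in for model.params.
params = {"dense": {"kernel": jnp.ones((8, 8)), "bias": jnp.zeros((8,))}}

# 1) Keep the initial weights in host memory (numpy) so they do not occupy
#    device memory before pjit lays them out across the mesh.
params = jax.tree_map(lambda x: np.asarray(x), params)

optimizer = optax.adamw(learning_rate=1e-3)

# Shard the kernel along the "mp" mesh axis, replicate everything else.
param_spec = {"dense": {"kernel": PartitionSpec(None, "mp"), "bias": PartitionSpec()}}

def init_state(p):
    # Parameters and optimizer state are created directly in their final
    # sharded layout, as dictated by out_axis_resources below.
    return p, optimizer.init(p)

devices = np.asarray(jax.devices()).reshape(1, -1)  # ("batch", "mp") mesh
mesh = maps.Mesh(devices, ("batch", "mp"))

with maps.mesh(mesh.devices, mesh.axis_names):
    params, opt_state = pjit(
        init_state,
        in_axis_resources=None,                 # inputs arrive as replicated host arrays
        out_axis_resources=(param_spec, None),  # shard params, replicate opt_state
        donate_argnums=(0,),                    # donate the host copy, mirroring the script
    )(params)
```

The same idea is used in the diff with a full TrainState instead of a (params, opt_state) pair: the only live copy of the weights outside the mesh is the cheap numpy one.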
@@ -793,7 +794,7 @@ def main():
     # Create parallel version of the train and eval step
     p_train_step = pjit(
         train_step,
-        in_axis_resources=(state_spec, None, None),
+        in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
         out_axis_resources=(state_spec, None),
         donate_argnums=(0,),
     )
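In this second hunk the batch argument of train_step gets PartitionSpec("batch", None): its leading (example) dimension is split across the "batch" axis of the mesh while the feature dimension stays whole on every device. The snippet below is a standalone illustration of what that spec does to an array, again with the older pjit API; the array shape and the identity function are placeholders, not code from the repository.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# Same ("batch", "mp") mesh layout as the training script.
devices = np.asarray(jax.devices()).reshape(-1, 1)
mesh = maps.Mesh(devices, ("batch", "mp"))

# A dummy batch whose first dimension is a multiple of the "batch" axis size.
batch = jnp.ones((devices.shape[0] * 4, 32))

with maps.mesh(mesh.devices, mesh.axis_names):
    # PartitionSpec("batch", None): rows are spread over the data-parallel
    # devices, columns are kept whole on each of them.
    shard = pjit(
        lambda x: x,
        in_axis_resources=(PartitionSpec("batch", None),),
        out_axis_resources=PartitionSpec("batch", None),
    )
    sharded_batch = shard(batch)
    print(sharded_batch.shape)  # logically unchanged; rows live on different devices
```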