Spaces:
Running
Running
feat(model): clean way to load on cpu
Browse files- src/dalle_mini/model/modeling.py +10 -2
- tools/train/train.py +0 -4
src/dalle_mini/model/modeling.py
CHANGED
@@ -300,6 +300,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
300 |
- added num_params property
|
301 |
- config_class replaced to DalleBartConfig
|
302 |
- __init__ accepts abstract_init which uses parameter shape to initialize the model
|
|
|
303 |
"""
|
304 |
|
305 |
config_class = DalleBartConfig
|
@@ -311,6 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
311 |
seed: int = 0,
|
312 |
dtype: jnp.dtype = jnp.float32,
|
313 |
abstract_init: bool = False,
|
|
|
314 |
**kwargs,
|
315 |
):
|
316 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
@@ -330,15 +332,21 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
330 |
self.key = PRNGKey(seed)
|
331 |
self.dtype = dtype
|
332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
# randomly initialized parameters
|
334 |
if abstract_init:
|
335 |
# init the model weights only abstractly, eval_shape will return a pytree
|
336 |
# with the structure as weights but without any actual values, this will just contain
|
337 |
# the shape information. Weights need to be loaded later.
|
338 |
-
init_fn = partial(
|
339 |
random_params = jax.eval_shape(init_fn, self.key)
|
340 |
else:
|
341 |
-
random_params =
|
342 |
|
343 |
# save required_params as set
|
344 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
|
|
300 |
- added num_params property
|
301 |
- config_class replaced to DalleBartConfig
|
302 |
- __init__ accepts abstract_init which uses parameter shape to initialize the model
|
303 |
+
- init weights on CPU
|
304 |
"""
|
305 |
|
306 |
config_class = DalleBartConfig
|
|
|
312 |
seed: int = 0,
|
313 |
dtype: jnp.dtype = jnp.float32,
|
314 |
abstract_init: bool = False,
|
315 |
+
load_on_cpu: bool = True,
|
316 |
**kwargs,
|
317 |
):
|
318 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
|
|
332 |
self.key = PRNGKey(seed)
|
333 |
self.dtype = dtype
|
334 |
|
335 |
+
# init weights on CPU
|
336 |
+
if load_on_cpu:
|
337 |
+
init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
|
338 |
+
else:
|
339 |
+
init_fn = self.init_weights
|
340 |
+
|
341 |
# randomly initialized parameters
|
342 |
if abstract_init:
|
343 |
# init the model weights only abstractly, eval_shape will return a pytree
|
344 |
# with the structure as weights but without any actual values, this will just contain
|
345 |
# the shape information. Weights need to be loaded later.
|
346 |
+
init_fn = partial(init_fn, input_shape=input_shape)
|
347 |
random_params = jax.eval_shape(init_fn, self.key)
|
348 |
else:
|
349 |
+
random_params = init_fn(self.key, input_shape)
|
350 |
|
351 |
# save required_params as set
|
352 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
tools/train/train.py
CHANGED
@@ -702,10 +702,6 @@ def main():
|
|
702 |
)
|
703 |
return state
|
704 |
|
705 |
-
# hack: move the initial params to CPU to free up device memory
|
706 |
-
# TODO: allow loading weights on CPU in pre-trained model
|
707 |
-
model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
|
708 |
-
|
709 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
710 |
state = pjit(
|
711 |
init_state,
|
|
|
702 |
)
|
703 |
return state
|
704 |
|
|
|
|
|
|
|
|
|
705 |
with maps.mesh(mesh.devices, mesh.axis_names):
|
706 |
state = pjit(
|
707 |
init_state,
|