boris committed
Commit 12f323d
1 Parent(s): 3d43591

feat(model): clean way to load on cpu
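From the caller's side, the new flag would be used roughly like this. This is a hedged sketch, not code from this commit: the checkpoint name is a placeholder, and it assumes `from_pretrained` forwards extra keyword arguments to `__init__` the way the Flax base class does.

from dalle_mini.model import DalleBart

# Hypothetical call site; "dalle-mini/model" is a placeholder checkpoint name.
# load_on_cpu (new in this commit) jits the initializer on the CPU backend,
# so freshly created parameters never touch accelerator memory.
model = DalleBart.from_pretrained(
    "dalle-mini/model",
    load_on_cpu=True,
)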

src/dalle_mini/model/modeling.py CHANGED

@@ -300,6 +300,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which uses parameter shape to initialize the model
+    - init weights on CPU
     """
 
     config_class = DalleBartConfig
@@ -311,6 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         abstract_init: bool = False,
+        load_on_cpu: bool = True,
         **kwargs,
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -330,15 +332,21 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         self.key = PRNGKey(seed)
         self.dtype = dtype
 
+        # init weights on CPU
+        if load_on_cpu:
+            init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
+        else:
+            init_fn = self.init_weights
+
         # randomly initialized parameters
         if abstract_init:
             # init the model weights only abstractly, eval_shape will return a pytree
             # with the structure as weights but without any actual values, this will just contain
             # the shape information. Weights need to be loaded later.
-            init_fn = partial(self.init_weights, input_shape=input_shape)
+            init_fn = partial(init_fn, input_shape=input_shape)
             random_params = jax.eval_shape(init_fn, self.key)
         else:
-            random_params = self.init_weights(self.key, input_shape)
+            random_params = init_fn(self.key, input_shape)
 
         # save required_params as set
         self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
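The mechanism above is plain JAX: jax.jit(..., backend="cpu") commits the initializer's outputs to host memory, and jax.eval_shape traces the function without allocating any values. A minimal, self-contained sketch of both paths, where the toy init_weights stands in for the model's real initializer:

from functools import partial

import jax
import jax.numpy as jnp
from jax.random import PRNGKey


def init_weights(rng, input_shape):
    # stand-in for the model's initializer: returns a params pytree
    kernel = jax.random.normal(rng, (input_shape[-1], 8))
    return {"dense": {"kernel": kernel}}


# concrete init on the CPU backend: input_shape is a hashable tuple,
# so it can be marked static for jit
init_fn = jax.jit(init_weights, static_argnums=(1,), backend="cpu")
params = init_fn(PRNGKey(0), (1, 16))  # leaves live on a CPU device

# abstract init: nothing is materialized, only shape/dtype information
abstract_params = jax.eval_shape(partial(init_weights, input_shape=(1, 16)), PRNGKey(0))

print(jax.tree_util.tree_map(jnp.shape, params))
print(abstract_params)  # pytree of ShapeDtypeStruct leaves

Marking input_shape static is what makes the jit wrapper work here: shapes must be concrete at trace time, and a tuple of Python ints is hashable, so jit can specialize on it.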
tools/train/train.py CHANGED

@@ -702,10 +702,6 @@ def main():
         )
         return state
 
-    # hack: move the initial params to CPU to free up device memory
-    # TODO: allow loading weights on CPU in pre-trained model
-    model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
-
     with maps.mesh(mesh.devices, mesh.axis_names):
         state = pjit(
             init_state,
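The deleted train.py lines were the workaround this flag replaces: parameters were first allocated on the accelerator, then copied back to host memory leaf by leaf. A small sketch of that old pattern (the function name is illustrative):

import jax
import numpy as np

def move_params_to_host(params):
    # np.asarray forces a device-to-host copy for every leaf, freeing
    # accelerator memory only after it was already allocated there
    return jax.tree_map(lambda x: np.asarray(x), params)

With load_on_cpu=True the parameters are created on the CPU backend in the first place, so there is nothing to move and the accelerator memory is never claimed.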