boris commited on
Commit
68cc185
1 Parent(s): 44b7c3e
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +1 -1
src/dalle_mini/model/modeling.py CHANGED
@@ -337,7 +337,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
337
  # init weights on CPU
338
  init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
339
  else:
340
- init_fn = self.init_weigths
341
 
342
  # randomly initialized parameters
343
  random_params = self.init_weights(self.key, input_shape)
 
337
  # init weights on CPU
338
  init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
339
  else:
340
+ init_fn = self.init_weights
341
 
342
  # randomly initialized parameters
343
  random_params = self.init_weights(self.key, input_shape)