Spaces:
Running
Running
add property to get num params
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
@@ -24,6 +24,7 @@ import flax.linen as nn
|
|
24 |
import jax
|
25 |
import jax.numpy as jnp
|
26 |
from flax.core.frozen_dict import FrozenDict, unfreeze
|
|
|
27 |
from flax.linen import combine_masks, make_causal_mask
|
28 |
from flax.linen.attention import dot_product_attention_weights
|
29 |
from jax import lax
|
@@ -622,6 +623,11 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|
622 |
module = self.module_class(config=config, dtype=dtype)
|
623 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
|
624 |
|
|
|
|
|
|
|
|
|
|
|
625 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
626 |
# init input tensors
|
627 |
input_ids = jnp.zeros(input_shape, dtype="i4")
|
|
|
24 |
import jax
|
25 |
import jax.numpy as jnp
|
26 |
from flax.core.frozen_dict import FrozenDict, unfreeze
|
27 |
+
from flax.traverse_util import flatten_dict
|
28 |
from flax.linen import combine_masks, make_causal_mask
|
29 |
from flax.linen.attention import dot_product_attention_weights
|
30 |
from jax import lax
|
|
|
623 |
module = self.module_class(config=config, dtype=dtype)
|
624 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, **kwargs)
|
625 |
|
626 |
+
@property
|
627 |
+
def num_params(self):
|
628 |
+
num_params = jax.tree_map(lambda param: param.size, flatten_dict(unfreeze(self.params))).values()
|
629 |
+
return sum(list(num_params))
|
630 |
+
|
631 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
632 |
# init input tensors
|
633 |
input_ids = jnp.zeros(input_shape, dtype="i4")
|