update layers.py for JAX deprecate "shape"

#4
by zxyse - opened
Files changed (1) hide show
  1. whisper_jax/layers.py +2 -2
whisper_jax/layers.py CHANGED
@@ -60,7 +60,7 @@ default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", o
60
  # Temporary inlined JAX N-d initializer code
61
  # TODO(levskaya): remove once new JAX release is out.
62
  # ------------------------------------------------------------------------------
63
- def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
64
  """Inlined JAX `nn.initializer._compute_fans`."""
65
  if isinstance(in_axis, int):
66
  in_size = shape[in_axis]
@@ -70,7 +70,7 @@ def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
70
  out_size = shape[out_axis]
71
  else:
72
  out_size = int(np.prod([shape[i] for i in out_axis]))
73
- receptive_field_size = shape.total / in_size / out_size
74
  fan_in = in_size * receptive_field_size
75
  fan_out = out_size * receptive_field_size
76
  return fan_in, fan_out
 
60
  # Temporary inlined JAX N-d initializer code
61
  # TODO(levskaya): remove once new JAX release is out.
62
  # ------------------------------------------------------------------------------
63
+ def _compute_fans(shape: tuple, in_axis=-2, out_axis=-1):
64
  """Inlined JAX `nn.initializer._compute_fans`."""
65
  if isinstance(in_axis, int):
66
  in_size = shape[in_axis]
 
70
  out_size = shape[out_axis]
71
  else:
72
  out_size = int(np.prod([shape[i] for i in out_axis]))
73
+ receptive_field_size = np.prod(shape) / in_size / out_size
74
  fan_in = in_size * receptive_field_size
75
  fan_out = out_size * receptive_field_size
76
  return fan_in, fan_out