thera / model /init.py
Alexander Becker
Add code
b139995
raw
history blame contribute delete
808 Bytes
from typing import Callable
import jax
import jax.numpy as jnp
from jaxtyping import Array
def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Create an initializer that draws values uniformly from ``[a, b)``.

    Args:
        a: Lower bound, passed to ``jax.random.uniform`` as ``minval``.
        b: Upper bound, passed to ``jax.random.uniform`` as ``maxval``.
        dtype: Default dtype of the produced array; callers of the returned
            initializer may override it per call.

    Returns:
        A function ``(key, shape, dtype=dtype) -> Array``.
    """
    lower, upper = a, b

    def _sample(key, shape, dtype=dtype) -> Array:
        """Draw an array of ``shape`` uniformly between the captured bounds."""
        return jax.random.uniform(
            key,
            shape,
            dtype=dtype,
            minval=lower,
            maxval=upper,
        )

    return _sample
def linear_up(scale: float) -> Callable:
    """Create an initializer of 2-vectors sampled uniformly inside a disk.

    Along axis ``-2`` the output holds an ``(x, y)`` pair.
    Each pair is drawn with radius ``pi * scale * sqrt(u1)`` and angle
    ``2 * pi * u2``, where ``u1, u2 ~ U[0, 1)`` are independent. The square
    root makes the points uniform over the area of the disk of radius
    ``pi * |scale|``, rather than uniform in radius.

    Args:
        scale: Multiplier on the disk radius ``pi * scale``.

    Returns:
        A function ``(key, shape, dtype=jnp.float32) -> Array``. ``shape`` must
        have at least two dimensions with ``shape[-2] == 2``. Any leading
        dimensions receive independent samples.

    Raises:
        ValueError: (from the returned initializer) if ``shape`` is not of the
            form ``(..., 2, n)``.
    """
    def init(key, shape, dtype=jnp.float32) -> Array:
        shape = tuple(shape)
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if len(shape) < 2 or shape[-2] != 2:
            raise ValueError(
                f"linear_up expects shape (..., 2, n), got {shape}")
        # One row of (radius, angle) samples per leading index. For a 2-D
        # shape this is (1, n), i.e. exactly the original draw.
        sample_shape = shape[:-2] + (1, shape[-1])
        radius_key, angle_key = jax.random.split(key, 2)
        norm = jnp.pi * scale * (
            jax.random.uniform(radius_key, shape=sample_shape) ** .5)
        theta = 2 * jnp.pi * jax.random.uniform(angle_key, shape=sample_shape)
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        # Stack the x row and the y row along axis -2 -> (..., 2, n).
        return jnp.concatenate([x, y], axis=-2).astype(dtype)
    return init