Spaces:
Running
on
L4
Running
on
L4
from typing import Callable | |
import jax | |
import jax.numpy as jnp | |
from jaxtyping import Array | |
def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Build an initializer that samples i.i.d. values uniformly in ``[a, b)``.

    Args:
        a: Lower bound, forwarded as ``minval`` to ``jax.random.uniform``.
        b: Upper bound, forwarded as ``maxval`` to ``jax.random.uniform``.
        dtype: Default dtype of the generated array. The returned initializer
            still accepts a per-call ``dtype`` override.

    Returns:
        A function ``init(key, shape, dtype=...)`` that returns an array of the
        requested ``shape`` filled by ``jax.random.uniform``.
    """
    lower, upper, default_dtype = a, b, dtype

    def init(key, shape, dtype=default_dtype) -> Array:
        # Bounds are fixed by the enclosing call; only the dtype can change here.
        samples = jax.random.uniform(
            key, shape, dtype=dtype, minval=lower, maxval=upper)
        return samples

    return init
def linear_up(scale: float) -> Callable:
    """Build an initializer whose axis -2 holds 2-D ``(x, y)`` pairs.

    Each column along the last axis is an independent point drawn uniformly
    (by area) from a disk of radius ``pi * scale``. The radius is
    ``pi * scale * sqrt(u)`` and the angle is ``2 * pi * v``, with
    ``u, v ~ U[0, 1)``. The square root spreads points evenly over the disk
    instead of clustering them near the centre.

    Args:
        scale: Radius of the sampling disk, in units of ``pi``.

    Returns:
        A function ``init(key, shape, dtype=jnp.float32)``. ``shape`` must have
        at least two dimensions with ``shape[-2] == 2``. Any leading dimensions
        are treated as batch dimensions and get independent samples. Index 0
        of axis -2 holds the x (cosine) components and index 1 holds the
        y (sine) components.

    Raises:
        ValueError: (from ``init``) if ``shape`` has fewer than two dimensions
            or ``shape[-2] != 2``.
    """
    def init(key, shape, dtype=jnp.float32) -> Array:
        shape = tuple(shape)
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(shape) < 2 or shape[-2] != 2:
            raise ValueError(
                f"linear_up expects shape[-2] == 2, got shape {shape}")
        # One slot on axis -2 per component; x and y are concatenated below.
        # For 2-D shapes this is (1, shape[-1]), the same as before.
        sample_shape = shape[:-2] + (1, shape[-1])
        radius_key, angle_key = jax.random.split(key, 2)
        # Samples are drawn in JAX's default float dtype and cast to ``dtype`` at the end.
        norm = jnp.pi * scale * (
            jax.random.uniform(radius_key, shape=sample_shape) ** .5)
        theta = 2 * jnp.pi * jax.random.uniform(angle_key, shape=sample_shape)
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        return jnp.concatenate([x, y], axis=-2).astype(dtype)

    return init