from jax import random
import jax.numpy as np
from jax.nn.initializers import lecun_normal
from jax.numpy.linalg import eigh


def make_HiPPO(N):
    """
    Create a HiPPO-LegS matrix.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix. It is lower triangular and negated, so its
        diagonal is -(1, 2, ..., N).
    """
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    # Keep the lower triangle of sqrt(2n+1) * sqrt(2k+1); subtracting n on the
    # diagonal turns (2n+1) into n+1.
    A = np.tril(A) - np.diag(np.arange(N))
    return -A


def make_NPLR_HiPPO(N):
    """
    Makes components needed for NPLR representation of HiPPO-LegS

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Args:
        N (int32): state size
    Returns:
        hippo: N x N HiPPO LegS matrix (see make_HiPPO)
        P: (N,) low-rank factor; hippo + P P^T is normal
        B: (N,) HiPPO input matrix
    """
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B


def make_DPLR_HiPPO(N):
    """
    Makes components needed for DPLR representation of HiPPO-LegS

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part

    Args:
        N (int32): state size
    Returns:
        Lambda: (N,) complex eigenvalues of the normal matrix S = A + P P^T
        P: (N,) low-rank term projected into the eigenbasis, V^* P
        B: (N,) HiPPO input matrix projected into the eigenbasis, V^* B
        V: (N, N) eigenvectors of S
        B_orig: (N,) HiPPO input matrix before projection
    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    # For HiPPO-LegS the diagonal of S is constant: -(n+1) + (n+0.5) = -0.5.
    # The mean therefore just broadcasts that value. It supplies the real part
    # of the eigenvalues.
    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V Lambda V^*. Off the diagonal, S is skew-symmetric,
    # which becomes Hermitian after scaling by -1j. eigh returns that part's
    # real eigenvalues (the imaginary part of Lambda) together with the shared
    # eigenvectors V.
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig


def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """
    Initialize the learnable timescale Delta by sampling log-uniformly between
    dt_min and dt_max.

    The returned values are log-timescales drawn uniformly from
    [log(dt_min), log(dt_max)).

    Args:
        dt_min (float32): minimum value
        dt_max (float32): maximum value
    Returns:
        init function with signature (key, shape) -> log_step array
    """
    log_lo = np.log(dt_min)
    log_hi = np.log(dt_max)

    def init(key, shape):
        """
        Init function

        Args:
            key: jax random key
            shape (tuple): desired shape
        Returns:
            sampled log_step (float32) of the given shape
        """
        return random.uniform(key, shape) * (log_hi - log_lo) + log_lo

    return init


def init_log_steps(key, input):
    """
    Initialize an array of learnable timescale parameters

    Args:
        key: jax random key
        input: tuple (H, dt_min, dt_max) giving the number of timescales and
            the sampling range
    Returns:
        initialized array of log timescales (float32) of shape (H, 1)
    """
    # `input` shadows the builtin, but the parameter name is kept so existing
    # callers remain compatible.
    H, dt_min, dt_max = input
    sample_log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)
    log_steps = []
    # Split the key once per row rather than sampling (H, 1) in one call. This
    # keeps the sampled values identical to the established key sequence.
    for _ in range(H):
        key, skey = random.split(key)
        log_steps.append(sample_log_step(skey, shape=(1,)))
    return np.array(log_steps)


def init_VinvB(init_fun, rng, shape, Vinv):
    """
    Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
    Note we will parameterize this with two different matrices for complex
    numbers.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (P,H)
        Vinv: (complex64) the inverse eigenvectors used for initialization,
            presumably (P, P)
    Returns:
        B_tilde as a real array of shape (P,H,2). [..., 0] holds the real
        part and [..., 1] the imaginary part.
    """
    B = init_fun(rng, shape)
    VinvB = Vinv @ B
    return np.stack((VinvB.real, VinvB.imag), axis=-1)


def trunc_standard_normal(key, shape):
    """
    Sample C row by row with JAX's lecun_normal initializer.

    Despite the function name, this is not a unit-variance truncated normal.
    lecun_normal is a truncated normal whose variance is scaled by the
    fan-in. For each (1, P, 2) draw the fan-in is P, so entries have a
    standard deviation of about 1/sqrt(P).

    Args:
        key: jax random key
        shape (tuple): desired shape, of length 3, (H,P,_). The last entry is
            ignored; 2 is always used.
    Returns:
        sampled C matrix (float32) of shape (H,P,2) (for complex
        parameterization)
    """
    H, P, _ = shape
    rows = []
    # One key split per row keeps the sampled values identical to the
    # established key sequence.
    for _ in range(H):
        key, skey = random.split(key)
        rows.append(lecun_normal()(skey, shape=(1, P, 2)))
    # Each row is (1, P, 2); joining along axis 0 yields (H, P, 2).
    return np.concatenate(rows, axis=0)


def init_CV(init_fun, rng, shape, V):
    """
    Initialize C_tilde=CV. First sample C. Then compute CV.
    Note we will parameterize this with two different matrices for complex
    numbers.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (H,P,2). The trailing 2 holds the real
            and imaginary parts of C.
        V: (complex64) the eigenvectors used for initialization
    Returns:
        C_tilde as a real array of shape (H,P,2). [..., 0] holds the real
        part and [..., 1] the imaginary part.
    """
    C_ = init_fun(rng, shape)
    C = C_[..., 0] + 1j * C_[..., 1]
    CV = C @ V
    return np.stack((CV.real, CV.imag), axis=-1)