from functools import partial

import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal

from .ssm_init import init_CV, init_VinvB, init_log_steps, trunc_standard_normal


def discretize_bilinear(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using the bilinear transform method.

    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P, H)
    """
    Identity = np.ones(Lambda.shape[0])

    # Bilinear (Tustin) transform of the diagonal system:
    #   Lambda_bar = (I - Delta/2 * Lambda)^{-1} (I + Delta/2 * Lambda)
    #   B_bar      = (I - Delta/2 * Lambda)^{-1} * Delta * B_tilde
    BL = 1 / (Identity - (Delta / 2.0) * Lambda)
    Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda)
    B_bar = (BL * Delta)[..., None] * B_tilde
    return Lambda_bar, B_bar

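# Illustrative usage sketch (not part of the original source): discretizing a
# tiny diagonal system with the bilinear transform. The values are arbitrary;
# shapes follow the docstring above.
#
#   P, H = 4, 3
#   Lambda = -0.5 * np.ones(P) + 1j * np.arange(P)   # (P,) diagonal state matrix
#   B_tilde = np.ones((P, H), dtype=np.complex64)    # (P, H) input matrix
#   Delta = 0.01 * np.ones(P)                        # (P,) step sizes
#   Lambda_bar, B_bar = discretize_bilinear(Lambda, B_tilde, Delta)
#   # Lambda_bar.shape == (P,); B_bar.shape == (P, H)
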
def discretize_zoh(Lambda, B_tilde, Delta):
    """Discretize a diagonalized, continuous-time linear SSM
    using the zero-order hold method.

    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P, H)
    """
    Identity = np.ones(Lambda.shape[0])

    # Zero-order hold of the diagonal system:
    #   Lambda_bar = exp(Lambda * Delta)
    #   B_bar      = Lambda^{-1} (Lambda_bar - I) B_tilde
    Lambda_bar = np.exp(Lambda * Delta)
    B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
    return Lambda_bar, B_bar

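# Illustrative sketch (not part of the original source): for small step sizes
# the two discretizations should agree closely, since both approximate the
# same matrix exponential.
#
#   P = 4
#   Lambda = -1.0 * np.ones(P) + 1j * np.arange(P)
#   B_tilde = np.ones((P, 2), dtype=np.complex64)
#   Delta = 1e-3 * np.ones(P)
#   Lz, Bz = discretize_zoh(Lambda, B_tilde, Delta)
#   Lb, Bb = discretize_bilinear(Lambda, B_tilde, Delta)
#   # np.allclose(Lz, Lb, atol=1e-6) and np.allclose(Bz, Bb, atol=1e-6)
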
@jax.vmap
def binary_operator(q_i, q_j):
    """Binary operator for the parallel scan of a linear recurrence. Assumes a diagonal matrix A.

    Args:
        q_i: tuple containing A_i and Bu_i at position i       (P,), (P,)
        q_j: tuple containing A_j and Bu_j at position j       (P,), (P,)
    Returns:
        new element ( A_out, Bu_out )
    """
    A_i, b_i = q_i
    A_j, b_j = q_j
    # Composing two steps of x_k = A_k * x_{k-1} + Bu_k gives
    # (A_j, b_j) o (A_i, b_i) = (A_j * A_i, A_j * b_i + b_j).
    return A_j * A_i, A_j * b_i + b_j

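# Illustrative sketch (not part of the original source): associative_scan with
# this operator computes the linear recurrence x_k = A_k * x_{k-1} + Bu_k
# (with x_{-1} = 0) in parallel over the sequence dimension.
#
#   L, P = 5, 2
#   A = 0.9 * np.ones((L, P))       # per-step diagonal transitions
#   Bu = np.ones((L, P))            # per-step inputs
#   _, xs = jax.lax.associative_scan(binary_operator, (A, Bu))
#   # xs[0] == Bu[0]; xs[k] == A[k] * xs[k-1] + Bu[k]
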
def apply_ssm(Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym, bidirectional):
    """Compute the LxH output of a discretized SSM given an LxH input.

    Args:
        Lambda_bar (complex64): discretized diagonal state matrix    (P,)
        B_bar (complex64): discretized input matrix                  (P, H)
        C_tilde (complex64): output matrix                           (H, P)
        input_sequence (float32): input sequence of features         (L, H)
        conj_sym (bool): whether conjugate symmetry is enforced
        bidirectional (bool): whether the bidirectional setup is used;
                              note that in this case C_tilde has 2P columns
    Returns:
        ys (float32): the SSM outputs (S5 layer preactivations)      (L, H)
    """
    # Broadcast Lambda_bar across the sequence and form Bu_k for each step.
    Lambda_elements = Lambda_bar * np.ones((input_sequence.shape[0],
                                            Lambda_bar.shape[0]))
    Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)

    # Parallel scan over the recurrence x_k = Lambda_bar * x_{k-1} + Bu_k.
    _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements))

    if bidirectional:
        _, xs2 = jax.lax.associative_scan(binary_operator,
                                          (Lambda_elements, Bu_elements),
                                          reverse=True)
        xs = np.concatenate((xs, xs2), axis=-1)

    # With conjugate symmetry, only half of the state is stored, so the real
    # part of the output is doubled to account for the conjugate pairs.
    if conj_sym:
        return jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs)
    else:
        return jax.vmap(lambda x: (C_tilde @ x).real)(xs)

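# Illustrative sketch (not part of the original source): applying a discretized
# SSM to a random input sequence, with conj_sym=False so C_tilde has P columns.
#
#   L, P, H = 16, 4, 3
#   Lambda = -0.5 * np.ones(P) + 1j * np.arange(P)
#   B_tilde = np.ones((P, H), dtype=np.complex64)
#   C_tilde = np.ones((H, P), dtype=np.complex64)
#   Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, 0.01 * np.ones(P))
#   u = jax.random.normal(jax.random.PRNGKey(0), (L, H))
#   ys = apply_ssm(Lambda_bar, B_bar, C_tilde, u, conj_sym=False, bidirectional=False)
#   # ys.shape == (L, H)
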
class S5SSM(nn.Module):
    """The S5 SSM.

    Args:
        Lambda_re_init (float32): real part of init diagonal state matrix  (P,)
        Lambda_im_init (float32): imag part of init diagonal state matrix  (P,)
        V (complex64): eigenvectors used for init                          (P, P)
        Vinv (complex64): inverse eigenvectors used for init               (P, P)
        H (int32): number of features of input seq
        P (int32): state size
        C_init (string): specifies how C is initialized
                 Options: [trunc_standard_normal: sample from truncated standard
                           normal and then multiply by V, i.e. C_tilde=CV.
                           lecun_normal: sample from Lecun_normal and then
                           multiply by V.
                           complex_normal: directly sample a complex-valued
                           output matrix from standard normal; does not
                           multiply by V]
        conj_sym (bool): whether conjugate symmetry is enforced
        clip_eigs (bool): whether to enforce the left-half-plane condition, i.e.
                          constrain real parts of eigenvalues to be negative.
                          True recommended for autoregressive tasks/unbounded
                          sequence lengths. Discussed in
                          https://arxiv.org/pdf/2206.11893.pdf.
        bidirectional (bool): whether the model is bidirectional;
                              if True, uses two C matrices
        discretization (string): specifies the discretization method
                                 options: [zoh: zero-order hold,
                                           bilinear: bilinear transform]
        dt_min (float32): minimum value to draw timescale values from when
                          initializing log_step
        dt_max (float32): maximum value to draw timescale values from when
                          initializing log_step
        step_rescale (float32): allows for uniformly changing the timescale
                                parameter, e.g. after training on a different
                                resolution for the Speech Commands benchmark
    """
    Lambda_re_init: jax.Array
    Lambda_im_init: jax.Array
    V: jax.Array
    Vinv: jax.Array

    H: int
    P: int
    C_init: str
    discretization: str
    dt_min: float
    dt_max: float
    conj_sym: bool = True
    clip_eigs: bool = False
    bidirectional: bool = False
    step_rescale: float = 1.0

    def setup(self):
        """Initializes parameters once and performs discretization each time
        the SSM is applied to a sequence.
        """
        if self.conj_sym:
            # With conjugate symmetry, only half of the state is stored, but B
            # and C are sampled at the full size before projection by Vinv / V.
            local_P = 2 * self.P
        else:
            local_P = self.P

        # Initialize diagonal state matrix Lambda (eigenvalues).
        self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,))
        self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,))
        if self.clip_eigs:
            # Constrain eigenvalues to the left half-plane for stability.
            self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        else:
            self.Lambda = self.Lambda_re + 1j * self.Lambda_im

        # Initialize input-to-state (B) matrix.
        B_init = lecun_normal()
        B_shape = (local_P, self.H)
        self.B = self.param("B",
                            lambda rng, shape: init_VinvB(B_init,
                                                          rng,
                                                          shape,
                                                          self.Vinv),
                            B_shape)
        B_tilde = self.B[..., 0] + 1j * self.B[..., 1]

        # Initialize state-to-output (C) matrix.
        if self.C_init in ["trunc_standard_normal"]:
            C_init = trunc_standard_normal
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["lecun_normal"]:
            C_init = lecun_normal()
            C_shape = (self.H, local_P, 2)
        elif self.C_init in ["complex_normal"]:
            C_init = normal(stddev=0.5 ** 0.5)
        else:
            raise NotImplementedError(
                "C_init method {} not implemented".format(self.C_init))

        if self.C_init in ["complex_normal"]:
            if self.bidirectional:
                C = self.param("C", C_init, (self.H, 2 * self.P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]
            else:
                C = self.param("C", C_init, (self.H, self.P, 2))
                self.C_tilde = C[..., 0] + 1j * C[..., 1]
        else:
            if self.bidirectional:
                # Two C matrices: one for the forward scan, one for the backward scan.
                self.C1 = self.param("C1",
                                     lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                                     C_shape)
                self.C2 = self.param("C2",
                                     lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                                     C_shape)

                C1 = self.C1[..., 0] + 1j * self.C1[..., 1]
                C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
                self.C_tilde = np.concatenate((C1, C2), axis=-1)
            else:
                self.C = self.param("C",
                                    lambda rng, shape: init_CV(C_init, rng, shape, self.V),
                                    C_shape)
                self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]

        # Initialize feedthrough (D) matrix.
        self.D = self.param("D", normal(stddev=1.0), (self.H,))

        # Initialize learnable discretization timescales.
        self.log_step = self.param("log_step",
                                   init_log_steps,
                                   (self.P, self.dt_min, self.dt_max))
        step = self.step_rescale * np.exp(self.log_step[:, 0])

        # Discretize the continuous-time parameters.
        if self.discretization in ["zoh"]:
            self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step)
        elif self.discretization in ["bilinear"]:
            self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step)
        else:
            raise NotImplementedError("Discretization method {} not implemented".format(self.discretization))

    def __call__(self, input_sequence):
        """Compute the LxH output of the S5 SSM given an LxH input sequence
        using a parallel scan.

        Args:
            input_sequence (float32): input sequence (L, H)
        Returns:
            output sequence (float32): (L, H)
        """
        ys = apply_ssm(self.Lambda_bar,
                       self.B_bar,
                       self.C_tilde,
                       input_sequence,
                       self.conj_sym,
                       self.bidirectional)

        # Add the feedthrough (D) term.
        Du = jax.vmap(lambda u: self.D * u)(input_sequence)
        return ys + Du

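# Illustrative sketch (not part of the original source): initializing and
# applying the S5SSM with identity eigenvector matrices and arbitrary
# eigenvalues, with conj_sym disabled. In practice Lambda_re_init,
# Lambda_im_init, V, and Vinv typically come from a HiPPO-based
# initialization (see ssm_init).
#
#   P, H, L = 4, 3, 16
#   Lambda = -0.5 * np.ones(P) + 1j * np.arange(P)
#   V = np.eye(P, dtype=np.complex64)
#   model = S5SSM(Lambda_re_init=Lambda.real, Lambda_im_init=Lambda.imag,
#                 V=V, Vinv=V, H=H, P=P, C_init="lecun_normal",
#                 discretization="zoh", dt_min=0.001, dt_max=0.1,
#                 conj_sym=False)
#   params = model.init(jax.random.PRNGKey(0), np.ones((L, H)))
#   ys = model.apply(params, np.ones((L, H)))   # (L, H)
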
def init_S5SSM(H,
               P,
               Lambda_re_init,
               Lambda_im_init,
               V,
               Vinv,
               C_init,
               discretization,
               dt_min,
               dt_max,
               conj_sym,
               clip_eigs,
               bidirectional
               ):
    """Convenience function that will be used to initialize the SSM.
    Same arguments as defined in S5SSM above."""
    return partial(S5SSM,
                   H=H,
                   P=P,
                   Lambda_re_init=Lambda_re_init,
                   Lambda_im_init=Lambda_im_init,
                   V=V,
                   Vinv=Vinv,
                   C_init=C_init,
                   discretization=discretization,
                   dt_min=dt_min,
                   dt_max=dt_max,
                   conj_sym=conj_sym,
                   clip_eigs=clip_eigs,
                   bidirectional=bidirectional)

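# Illustrative sketch (not part of the original source): the returned partial
# leaves step_rescale unbound, so callers can instantiate the module later, e.g.
#
#   ssm_init_fn = init_S5SSM(H=H, P=P,
#                            Lambda_re_init=Lambda.real, Lambda_im_init=Lambda.imag,
#                            V=V, Vinv=Vinv, C_init="lecun_normal",
#                            discretization="zoh", dt_min=0.001, dt_max=0.1,
#                            conj_sym=False, clip_eigs=False, bidirectional=False)
#   model = ssm_init_fn(step_rescale=1.0)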