from flax import linen as nn
import jax


class SequenceLayer(nn.Module):
    """Defines a single S5 layer, with S5 SSM, nonlinearity,
    dropout, batch/layer norm, etc.

    Args:
        ssm (nn.Module): the SSM to be used (i.e. the S5 SSM)
        dropout (float32): dropout rate
        d_model (int32): the feature size of the layer inputs and outputs;
            we usually refer to this size as H
        activation (string): type of activation function to use
        training (bool): whether in training mode or not
        prenorm (bool): apply prenorm if true, or postnorm if false
        batchnorm (bool): apply batchnorm if true, or layernorm if false
        bn_momentum (float32): the batchnorm momentum if batchnorm is used
        step_rescale (float32): allows for uniformly changing the timescale
            parameter, e.g. after training on a different resolution for the
            speech commands benchmark
    """
    ssm: nn.Module
    dropout: float
    d_model: int
    activation: str = "gelu"
    training: bool = True
    prenorm: bool = False
    batchnorm: bool = False
    bn_momentum: float = 0.90
    step_rescale: float = 1.0

    def setup(self):
        """Initializes the SSM, batch/layer norm and dropout."""
        self.seq = self.ssm(step_rescale=self.step_rescale)

        # The GLU variants need output projections back to d_model.
        if self.activation in ["full_glu"]:
            self.out1 = nn.Dense(self.d_model)
            self.out2 = nn.Dense(self.d_model)
        elif self.activation in ["half_glu1", "half_glu2"]:
            self.out2 = nn.Dense(self.d_model)

        if self.batchnorm:
            # axis_name='batch' assumes the layer is vmapped/pmapped over a
            # batch axis with that name, so statistics are averaged across it.
            self.norm = nn.BatchNorm(use_running_average=not self.training,
                                     momentum=self.bn_momentum, axis_name='batch')
        else:
            self.norm = nn.LayerNorm()

        # Dropout mask is broadcast along the sequence (time) dimension.
        self.drop = nn.Dropout(
            self.dropout,
            broadcast_dims=[0],
            deterministic=not self.training,
        )

    def __call__(self, x):
        """
        Compute the LxH output of the S5 layer given an LxH input.
        Args:
            x (float32): input sequence (L, d_model)
        Returns:
            output sequence (float32): (L, d_model)
        """
        skip = x
        if self.prenorm:
            x = self.norm(x)
        x = self.seq(x)

        if self.activation in ["full_glu"]:
            # Full GLU: two projections of the activated output, one gating the other.
            x = self.drop(nn.gelu(x))
            x = self.out1(x) * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu1"]:
            # Half GLU (1): the GELU-activated output gates itself via out2.
            x = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x))
            x = self.drop(x)
        elif self.activation in ["half_glu2"]:
            # Half GLU (2): only the gate path goes through the GELU;
            # the value path is the raw SSM output.
            x1 = self.drop(nn.gelu(x))
            x = x * jax.nn.sigmoid(self.out2(x1))
            x = self.drop(x)
        elif self.activation in ["gelu"]:
            x = self.drop(nn.gelu(x))
        else:
            raise NotImplementedError(
                "Activation: {} not implemented".format(self.activation))

        # Residual connection, then postnorm if prenorm was not applied.
        x = skip + x
        if not self.prenorm:
            x = self.norm(x)
        return x
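

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): the real S5 SSM is defined
# elsewhere in the repo, so `PlaceholderSSM` below is a hypothetical stand-in
# that only mimics its calling convention; it is constructed with
# `step_rescale` and maps an (L, H) sequence to an (L, H) sequence. This is
# just a smoke test of SequenceLayer's plumbing, not the S5 model itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import jax.numpy as jnp

    class PlaceholderSSM(nn.Module):
        """Hypothetical stand-in for the S5 SSM, used only in this sketch."""
        step_rescale: float = 1.0

        @nn.compact
        def __call__(self, x):
            # A single Dense layer in place of the real state-space recurrence.
            return nn.Dense(x.shape[-1])(x)

    L, H = 16, 8
    layer = SequenceLayer(ssm=PlaceholderSSM,
                          dropout=0.1,
                          d_model=H,
                          activation="half_glu2",
                          training=False)  # eval mode: dropout is deterministic

    x = jnp.ones((L, H))
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    print(y.shape)  # (16, 8), i.e. (L, d_model)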