# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core NN components used in models.
"""
from typing import Any, Callable, Optional, Tuple, Union
from absl import logging
from flax import linen as nn
import gin
import jax
from jax import lax
from jax.nn import initializers
import jax.numpy as jnp
PRNGKey = Any
Array = jnp.ndarray
Shape = Tuple[int, ...]
Dtype = Union[jnp.dtype, str]
def scalar_initializer(x):
"""Like linen.zeros, but initializes a parameter to a scalar value."""
def init_fun(key, shape, dtype):
del key
return jnp.broadcast_to(jnp.array(x, dtype=dtype), shape)
return init_fun
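# Illustrative sketch (not part of the original module): scalar_initializer
# can be used wherever an initializer is expected, e.g. as the init argument
# of Module.param. The helper name below is hypothetical.
def _example_scalar_initializer() -> Array:
  init_fn = scalar_initializer(0.5)
  # The key is ignored; the parameter is simply filled with the scalar value.
  return init_fn(jax.random.PRNGKey(0), (3,), jnp.float32)  # [0.5, 0.5, 0.5]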
def swish(x: Array) -> Array:
"""Swish function, which is very similar to gelu."""
return x * nn.sigmoid(x)
def soft_abs(x: Array) -> Array:
"""Soft version of absolute value, that is smoothly differentiable."""
return jnp.sqrt(jnp.square(x) + 1) - 1
def get_activation_function(fname: Optional[str]) -> Callable[[Array], Array]:
"""Get activation function from the specified string."""
if fname is None:
return lambda x: x
elif fname == "relu":
return nn.relu
elif fname == "swish":
return swish
elif fname == "sigmoid":
return nn.sigmoid
elif fname == "tanh":
return nn.tanh
else:
raise ValueError("Unknown activation function %s" % fname)
# Adapted from flax.linen.softmax.
def safe_softmax(x: Array,
axis: Optional[Union[int, Tuple[int, ...]]] = -1,
min_x: Optional[Array] = None) -> Array:
r"""Softmax function.
Computes the function which rescales elements to the range :math:`[0, 1]`
such that the elements along :code:`axis` sum to :math:`1`.
  This version of softmax is intended for use with causal attention masks, and
  safely covers the situation where all elements are masked out. If min_x is
  not None, then probability will be distributed between the values in x and
  min_x. If x >> min_x, then the probability allocated to min_x will be zero,
  and this function will be the same as the usual softmax. However, if
  x << min_x (because all the values in x are masked out), then probability
  will be allocated to min_x instead, and the probability allocated to x will
  be 0. I.e., attention will attend to nothing if everything is masked out.
.. math ::
    \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Args:
x: input array
axis: the axis or axes along which the softmax should be computed. The
softmax output summed across these dimensions should sum to :math:`1`.
Either an integer or a tuple of integers.
min_x: the value of a minimum element which will be included in the
softmax sum. The value of min_x should be small when compared to the
expected values in x. If all of the values in x are smaller than
min_x, then probability will be allocated to the minimum element
instead, and the result of softmax will sum to less than 1.
Returns:
An array of the same shape as x.
"""
# Subtract maximum value in x for numerical stability, so that the exponent
# never exceeds numerical precision.
x_max = lax.stop_gradient(jnp.max(x, axis, initial=min_x, keepdims=True))
if min_x is not None:
min_x = jnp.asarray(min_x, dtype=x.dtype)
x_max = jnp.maximum(x_max, min_x)
unnormalized = jnp.exp(x - x_max)
x_sum = jnp.sum(unnormalized, axis=axis, keepdims=True)
if min_x is not None:
x_sum = x_sum + jnp.exp(min_x - x_max)
return unnormalized / x_sum
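# Illustrative sketch (not part of the original module): with min_x set, a row
# of logits that is entirely masked out (i.e. very negative) receives roughly
# zero total probability, instead of being renormalized to sum to 1. The
# helper name and the scalar values below are hypothetical.
def _example_safe_softmax() -> Tuple[Array, Array]:
  # Ordinary logits dominate min_x, so this matches the usual softmax.
  unmasked = safe_softmax(jnp.array([1.0, 2.0, 3.0]), axis=-1, min_x=-1.0e5)
  # A fully masked row falls below min_x, so the result sums to ~0.
  masked = safe_softmax(jnp.full((3,), -1.0e6), axis=-1, min_x=-1.0e5)
  return unmasked, masked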
def dropout_multiplier_mask(rng, dropout_rate: float, shape: Shape,
dtype: Dtype):
"""Returns an array which can be multiplied by an input to perform dropout.
Args:
rng: A random number generator.
dropout_rate: The rate at which to drop.
shape: The shape of the output array.
dtype: The type of the output array.
Returns:
    An array of the given shape, where values are { 0.0, 1.0/keep_probability }.
"""
if dropout_rate <= 0.0:
return jnp.ones(shape, dtype=dtype)
logging.info("dropout mask: %s", shape)
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(rng, keep_prob, shape)
dropout_multiplier = (keep.astype(dtype) / jnp.asarray(keep_prob, dtype))
return dropout_multiplier
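# Illustrative sketch (not part of the original module): the returned mask is
# simply multiplied into the activations; surviving values are scaled by
# 1/keep_prob so the expected value is preserved. The helper name below is
# hypothetical.
def _example_dropout_multiplier_mask(x: Array, rng: PRNGKey) -> Array:
  mask = dropout_multiplier_mask(rng, dropout_rate=0.1, shape=x.shape,
                                 dtype=x.dtype)
  return x * mask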
def tiled_dropout(x: Array, shape: Shape, dropout_rate: float,
rng_function: Callable[[], jax.random.KeyArray],
deterministic: bool) -> Array:
"""Tiles a dropout mask over a larger array.
This will generate a smaller dropout mask of the given shape, and tile it
over a larger array, which reduces the computational cost and memory
associated with generating a large dropout mask.
Args:
x: The input array.
shape: The shape of the dropout mask to tile.
dropout_rate: The rate at which to drop.
    rng_function: A function which returns a random number generator, e.g.
                  lambda: self.make_rng("dropout").  The function will not
                  be called if dropout is not enabled.
deterministic: If True, don't do dropout.
Returns:
An array of the same shape as x, with some values dropped out.
"""
if deterministic or dropout_rate <= 0.0:
return x
if x.ndim != len(shape):
raise ValueError("Shapes must have same number of dimensions %r, %r." %
(x.shape, shape))
for (xd, sd) in zip(x.shape, shape):
if (xd % sd) != 0:
raise ValueError("Incompatible shapes %r, %r" % (x.shape, shape))
# Get random number generator for dropout.
rng = rng_function()
repeats = [(1 if sd == 1 else xd // sd) for (xd, sd) in zip(x.shape, shape)]
logging.info("tiled dropout %r, tile: %r", x.shape, shape)
dtype = x.dtype
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(rng, keep_prob, shape)
keep = jnp.tile(keep, repeats)
keep = jnp.broadcast_to(keep, x.shape)
x_scaled = x / jnp.asarray(keep_prob, dtype=dtype)
return lax.select(keep, x_scaled, jnp.zeros_like(x, dtype=dtype))
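# Illustrative sketch (not part of the original module): a small mask is drawn
# once and tiled over the input, so dimensions with tile size 1 are broadcast
# and the remaining tile sizes must divide the input shape. Inside a flax
# module, rng_function would typically be lambda: self.make_rng("dropout");
# the helper name, shapes, and fixed PRNGKey below are hypothetical.
def _example_tiled_dropout(x: Array) -> Array:
  # e.g. x of shape (batch, sequence_length, num_features) = (2, 8, 16),
  # with a (1, 4, 16) dropout tile repeated along the sequence dimension
  # and broadcast over the batch dimension.
  return tiled_dropout(x, shape=(1, 4, x.shape[-1]), dropout_rate=0.1,
                       rng_function=lambda: jax.random.PRNGKey(0),
                       deterministic=False)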
@gin.configurable
class MLP(nn.Module):
"""Implements a multi-layer perceptron, with optional resnet or gate."""
# Arguments to module.
num_output_features: int # Length of output vectors.
# Gin configurable parameters.
num_layers: int = gin.REQUIRED # Number of layers in the MLP.
num_hidden_units: int = gin.REQUIRED # Length of hidden unit vectors.
hidden_activation: Optional[str] = "relu" # Hidden layer activation fn.
final_activation: Optional[str] = None # Final layer activation fn.
use_bias: bool = True # Use a bias in each dense layer.
  gate_type: Optional[str] = None     # { "residual", "bias", "full", "lstm" }
initializer_scale: float = 1.0 # Scale of initial values.
dtype: Any = jnp.float32
def setup(self):
kernel_init = jax.nn.initializers.variance_scaling(
scale=self.initializer_scale, mode="fan_in",
distribution="truncated_normal")
assert self.num_layers > 0
hlayers = []
for i in range(0, self.num_layers - 1):
assert self.num_hidden_units > 0
hlayer = nn.Dense(self.num_hidden_units,
use_bias=self.use_bias,
kernel_init=kernel_init,
dtype=self.dtype,
name=f"hidden{i}")
hlayers.append(hlayer)
self.hidden_layers = hlayers
self.output_layer = nn.Dense(self.num_output_features,
use_bias=self.use_bias,
kernel_init=kernel_init,
dtype=self.dtype)
if self.gate_type is None or self.gate_type == "residual":
return
# We use a low but non-zero bias so that adafactor knows how to scale it.
gate_bias_init = jax.nn.initializers.normal(stddev=0.1)
    # Also use a lower-than-normal scale for the kernel initializer.
gate_kernel_init = jax.nn.initializers.variance_scaling(
scale=0.1, mode="fan_in", distribution="truncated_normal")
if self.gate_type == "bias":
self.gate_bias = self.param("gate_bias", gate_bias_init,
(self.num_output_features,), jnp.float32)
elif self.gate_type == "full":
self.gate_layer = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
elif self.gate_type == "lstm":
self.input_gate = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
self.forget_gate = nn.Dense(self.num_output_features,
use_bias=True,
bias_init=gate_bias_init,
kernel_init=gate_kernel_init,
dtype=self.dtype)
else:
raise ValueError("Unsupported gate_type: %s" % self.gate_type)
def _gate(self, y_hidden: Array, state: Array, y_out: Array) -> Array:
"""Compute the value to use for the gate."""
if self.gate_type == "residual":
# Residual connection: just add y_out to the state.
logging.info("mlp: residual")
return state + y_out
elif self.gate_type == "bias":
      # Simple gate: use a GRU-style gate with a learned bias (no kernel).
bias = jnp.asarray(self.gate_bias, dtype=self.dtype)
bias = jnp.reshape(bias, (1,) * (y_out.ndim - 1) + (-1,)) # batch dims.
g = jax.nn.sigmoid(bias)
logging.info("mlp: gate bias = %r", g)
return (state * g) + (y_out * (1 - g))
elif self.gate_type == "full":
# Normal GRU style gate -- compute g using both a kernel and bias.
g = jax.nn.sigmoid(self.gate_layer(y_hidden) + 1) # biased to remember
logging.info("mlp: gate full = %r", g)
return (state * g) + (y_out * (1 - g))
elif self.gate_type == "lstm":
# LSTM style gate with input and forget gates.
fg = jax.nn.sigmoid(self.forget_gate(y_hidden) + 1) # biased to remember
ig = jax.nn.sigmoid(self.input_gate(y_hidden) - 1)
logging.info("mlp: gate lstm = %r, %r", ig, fg)
return (state * fg) + (y_out * ig)
else:
raise ValueError("Unsupported gate type %s" % self.gate_type)
def __call__(self, x: Array, state: Optional[Array],
apply_dropout: bool = False,
dropout_rate: float = 0.0,
drop_tile_shape: Optional[Shape] = None,
rng_function: Optional[Callable[[], Any]] = None) -> Array:
"""Apply the multi-layer perceptron to the input x.
For simple MLPs, returns f(x), where f is the MLP function.
For resnets and gated architectures, it returns
state + f(x) -- for resnet.
g*state + (1-g)*f(x) -- for gated architecture, where g is the gate.
Args:
x: The input to the MLP.
state: The prior value, if this MLP is used as part of a resnet or gated
architecture.
apply_dropout: If true, applies dropout to the result.
dropout_rate: The dropout rate to use.
drop_tile_shape: The dropout tile shape.
rng_function: Gets a random number seed for dropout.
Returns:
The combination of f(x) and the (optional) prior state.
"""
x = jnp.asarray(x, self.dtype)
hidden_act_fun = get_activation_function(self.hidden_activation)
final_act_fun = get_activation_function(self.final_activation)
if self.hidden_layers:
# Apply some number of hidden layers.
y = x
for layer in self.hidden_layers:
logging.info("mlp: hidden %d, %s", self.num_hidden_units,
self.hidden_activation)
y = hidden_act_fun(layer(y))
else:
# Apply the hidden activation function to the input.
logging.info("mlp: activation = %s", self.hidden_activation)
y = hidden_act_fun(x)
y_hidden = y # The hidden layer right before the output.
logging.info("mlp: final activation = %s", self.final_activation)
y_out = self.output_layer(y_hidden) # The MLP final output.
y_out = final_act_fun(y_out) # Apply final activation function.
logging.info("mlp: final = %r", y_out)
# Optionally apply dropout to the output.
if apply_dropout:
if drop_tile_shape is None:
raise ValueError("drop_tile_shape must be specified for dropout.")
if rng_function is None:
raise ValueError("rng_function must be specified for dropout.")
logging.info("mlp: dropout rate = %s", dropout_rate)
y_out = tiled_dropout(
y_out, shape=drop_tile_shape, dropout_rate=dropout_rate,
rng_function=rng_function, deterministic=False)
if state is None:
# Simple MLP. No gate to combine y_out with the state.
assert self.gate_type is None
logging.info("mlp: gate type = None.")
return y_out
# When using state, gate_type must be specified.
assert self.gate_type is not None
return self._gate(y_hidden, state, y_out)
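# Illustrative sketch (not part of the original module): the gin-configured
# fields can also be passed directly as constructor arguments. With state=None
# and gate_type=None this is a plain MLP that simply returns f(x). The helper
# name and layer sizes below are hypothetical.
def _example_mlp_apply(x: Array) -> Array:
  mlp = MLP(num_output_features=16, num_layers=2, num_hidden_units=64,
            gate_type=None)
  params = mlp.init(jax.random.PRNGKey(0), x, None)
  return mlp.apply(params, x, None)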
# Modified slightly from the flax implementation.
@gin.configurable
class LayerNorm(nn.Module):
"""Layer normalization (https://arxiv.org/abs/1607.06450).
Operates on the last axis of the input data.
It normalizes the activations of the layer for each given example in a
batch independently, rather than across a batch like Batch Normalization.
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
Attributes:
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
use_bias: If True, bias (beta) is added.
use_scale: If True, multiply by scale (gamma).
use_mean: If True, compute and adjust for the mean.
      Note that the T5X layernorm does not use the mean.
Empirically, ignoring the mean can stabilize learning in transformers.
    use_scalar_scale_bias: If True, use a single scalar for scale & bias.
enable_layernorm: If False, does not perform layernorm.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
"""
epsilon: float = 1e-6
dtype: Any = jnp.float32
use_scale: bool = True # Apply a learned scale.
use_bias: bool = False # Apply a learned bias.
use_mean: bool = False # Calculate and adjust for the mean.
use_scalar_scale_bias: bool = False # Learn a single scalar scale & bias.
enable_layernorm: bool = True # Turn off layernorm if false.
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
@nn.compact
def __call__(self, x):
"""Applies layer normalization on the input.
Args:
x: the inputs
Returns:
Normalized inputs (the same shape as inputs).
"""
if not self.enable_layernorm:
return x
x = jnp.asarray(x)
# Calculate mean and variance at higher precision.
xf = jnp.asarray(x, jnp.float32)
if self.use_mean:
mean = jnp.mean(xf, axis=-1, keepdims=True)
xf = xf - mean
var = jnp.mean(lax.square(xf), axis=-1, keepdims=True)
mul = lax.rsqrt(var + self.epsilon)
# Rescale x
# if not use_mean, then rescale around zero instead. (A simplification.)
if self.use_mean:
y = (x - mean) * mul
else:
y = x * mul
if self.use_scalar_scale_bias:
# Learn a single scalar value for bias and scale.
# (Which mirrors the single value for mean and stddev above.)
num_scale_bias_features = 1
else:
# Learn a different value per neuron/feature for bias and scale.
num_scale_bias_features = x.shape[-1]
# Apply learned scale and bias.
if self.use_scale:
y = y * jnp.asarray(
self.param("scale", self.scale_init, (num_scale_bias_features,)),
dtype=self.dtype)
if self.use_bias:
y = y + jnp.asarray(
self.param("bias", self.bias_init, (num_scale_bias_features,)),
dtype=self.dtype)
return y.astype(self.dtype)
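# Illustrative sketch (not part of the original module): with the defaults
# (use_mean=False, use_bias=False, use_scale=True), this is an RMS-style
# normalization with a learned per-feature scale, as in T5X. The helper name
# below is hypothetical.
def _example_layernorm_apply(x: Array) -> Array:
  layer_norm = LayerNorm()
  params = layer_norm.init(jax.random.PRNGKey(0), x)
  return layer_norm.apply(params, x)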