|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Flax Bart model.""" |
|
|
|
import math |
|
import random |
|
from functools import partial |
|
from typing import Optional, Tuple |
|
|
|
import numpy as np |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, unfreeze |
|
from flax.linen import combine_masks, make_causal_mask |
|
from flax.linen import partitioning as nn_partitioning |
|
from flax.linen.attention import dot_product_attention_weights |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
|
|
from transformers.modeling_flax_outputs import ( |
|
FlaxBaseModelOutputWithPastAndCrossAttentions, |
|
FlaxCausalLMOutputWithCrossAttentions, |
|
) |
|
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel |
|
|
|
from models import BartConfig |
|
|
|
|
|
# Module-level shorthands for the Flax partitioning helpers used when building the decoder layer stack
# (`scan_with_axes` to scan one layer definition over the depth, `remat` for gradient checkpointing).
scan_with_axes, remat = nn_partitioning.scan_with_axes, nn_partitioning.remat
|
|
|
|
|
def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.

    Every row loses its last token, gains `decoder_start_token_id` in the first position, and any `-100`
    entry carried over from `input_ids` (presumably the loss-ignore label value) is replaced by
    `pad_token_id`. The input is never modified in place.

    Args:
        input_ids: 2-D array (or array-like, e.g. nested lists) of token ids, indexed `[batch, position]`.
        pad_token_id: id substituted for `-100` entries in the shifted result.
        decoder_start_token_id: id written into the first position of every row.

    Returns:
        A new NumPy array with the same shape and dtype as `input_ids`.
    """
    # Accept any array-like (lists, JAX arrays) so that the 2-D slicing below is always valid.
    input_ids = np.asarray(input_ids)

    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    # A slice (instead of `[:, 0]`) turns this into a no-op for zero-length sequences rather than an IndexError.
    shifted_input_ids[:, :1] = decoder_start_token_id

    return np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
|
|
|
|
class FlaxBartAttention(nn.Module):
    """
    Multi-head attention layer, usable as causal self-attention (`causal=True`) or as cross-attention
    (when `key_value_states` is passed to `__call__`).

    When `config.fuse_matmuls` is set, the projections are computed with fused dense layers (`fused_proj`
    for self-attention; `q_proj` plus `fused_key_value` for cross-attention). Otherwise the separate
    `q_proj`, `k_proj` and `v_proj` layers are used.
    """

    config: BartConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # dtype passed to the dense layers and to the attention computation

    def setup(self) -> None:
        """
        Declare the projection layers and, for causal attention, the static causal mask.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Factory for the `embed_dim`-wide projections that share the same hyper-parameters.
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        # Single projection producing query, key and value at once (self-attention, fused mode).
        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # Single projection producing key and value at once (cross-attention, fused mode).
        self.fused_key_value = nn.Dense(
            self.embed_dim * 2,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # Causal mask covering the longest supported sequence; sliced to size in `__call__`.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        """Reshape the last axis of `hidden_states` into `(num_heads, head_dim)`, keeping the two leading axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of `_split_heads`: collapse the trailing head axes back into one `embed_dim` axis."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to
        cached states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        On the first call (cache variables not yet present) the cache variables are only created, as zeros
        shaped like `key` / `value`, and the inputs are returned unchanged.

        Returns:
            The `(key, value, attention_mask)` triple to use for attention. Once the cache is initialised,
            `key` and `value` are the full cached buffers and `attention_mask` additionally hides the cache
            slots that have not been written yet.
        """
        # Must be queried before the variables are declared below, otherwise it is always True.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape

            # Write the new key/value slices into the cache at the current position.
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors

            # Only cache slots filled so far (including this step) may be attended to.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: inputs that are projected to queries (and to keys/values for self-attention).
            key_value_states: if given, keys and values are projected from it instead (cross-attention).
            attention_mask: optional mask; entries `> 0` mark positions that may be attended to.
            init_cache: create the decoding cache variables (only relevant when `causal=True`).
            deterministic: when False and `dropout > 0`, dropout is applied to the attention weights.

        Returns:
            `(attn_output, attn_weights)`.
        """
        # Cross-attention is selected purely by the presence of `key_value_states`.
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if self.config.fuse_matmuls:
            if is_cross_attention:
                # Queries come from the decoder; keys and values come from one fused projection.
                query_states = self.q_proj(hidden_states)
                attention_states = self.fused_key_value(key_value_states)
                key_states, value_states = jnp.split(attention_states, 2, axis=-1)
            else:
                # One matmul yields query, key and value, which are then split along the feature axis.
                attention_states = self.fused_proj(hidden_states)
                query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
        else:
            query_states = self.q_proj(hidden_states)
            if is_cross_attention:
                key_states = self.k_proj(key_value_states)
                value_states = self.v_proj(key_value_states)
            else:
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # Prepare the causal mask, shifted by the cache position when decoding with a cache.
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # Combine the user-provided mask with the causal mask where applicable.
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, keys/values are accumulated in (and read back from) the cache.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean mask into an additive attention bias.
        if attention_mask is not None:
            # The most negative finite value is used instead of -inf: with -inf, a row whose positions are
            # all masked turns into NaN after the softmax, whereas a finite bias keeps the output finite.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
|
|
|
class FlaxBartDecoderLayer(nn.Module):
    """
    A single decoder block: causal self-attention, optional cross-attention over `encoder_hidden_states`,
    and a two-layer feed-forward network. Each sub-block applies dropout, adds the residual and then
    applies its layer norm.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Declare the attention modules, feed-forward layers, dropouts and layer norms of the block."""
        self.embed_dim = self.config.d_model
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        self.activation_fn = ACT2FN[self.config.activation_function]
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.encoder_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # The decoder feed-forward width comes from `decoder_ffn_dim`. Reading `encoder_ffn_dim` here
        # mis-sizes the layer whenever the encoder and decoder widths differ in the config.
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Run the decoder block.

        Args:
            hidden_states: block input. When `config.use_scan` is set, this is a 1-tuple (the scan carry)
                wrapping the actual array.
            attention_mask: mask for the self-attention.
            encoder_hidden_states: if given, cross-attention over these states is applied.
            encoder_attention_mask: mask for the cross-attention.
            init_cache: forwarded to the self-attention to create its decoding cache.
            output_attentions: append the self- and cross-attention weights to the returned tuple.
            deterministic: disables the dropout layers of this block when True.

        Returns:
            `(hidden_states,)`, extended with `(self_attn_weights, cross_attn_weights)` when
            `output_attentions` is set. When `config.use_scan` is set, the result is `(outputs, None)`,
            i.e. a `(carry, per-step output)` pair.
        """
        if self.config.use_scan:
            # Under scan the carry is a tuple; unwrap the hidden states.
            hidden_states = hidden_states[0]

        residual = hidden_states

        # Self-attention.
        # NOTE(review): `deterministic` is not forwarded to the attention modules, so dropout on the attention
        # weights stays disabled (FlaxBartAttention defaults to `deterministic=True`) — confirm this is intended.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-attention (only when encoder states are provided).
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Feed-forward network.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if self.config.use_scan:
            # Scan expects `(carry, y)`; there is no per-layer output to stack.
            outputs = (outputs, None)

        return outputs
|
|
|
|
|
class FlaxBartDecoderLayerCollection(nn.Module):
    """
    The stack of `config.decoder_layers` decoder layers.

    Depending on the config, the layers are wrapped in `remat` (`config.gradient_checkpointing`) and are
    either scanned over as one stacked module (`config.use_scan`) or instantiated one by one in a Python
    loop, named `"0"`, `"1"`, ...
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Apply all decoder layers to `hidden_states`.

        Args:
            hidden_states: input to the first decoder layer.
            attention_mask: mask for the decoder self-attention.
            encoder_hidden_states: optional states for cross-attention.
            encoder_attention_mask: mask for the cross-attention.
            deterministic: disables dropout and LayerDrop when True.
            init_cache: forwarded to every layer to create the decoding cache.
            output_attentions: collect per-layer attention weights (not supported with `config.use_scan`).
            output_hidden_states: collect the input of every layer plus the final output (not supported
                with `config.use_scan`).
            return_dict: return a `FlaxBaseModelOutputWithPastAndCrossAttentions` instead of a tuple.

        Returns:
            The model output object, or, when `return_dict` is False, a tuple of the non-None values among
            `(hidden_states, all_hidden_states, all_self_attns, all_cross_attentions)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        num_decoder_layers = self.config.decoder_layers
        BlockDecoderLayer = (
            remat(
                FlaxBartDecoderLayer,
                # `init_cache`, `output_attentions` and `deterministic` in the positional call below.
                static_argnums=(4, 5, 6),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxBartDecoderLayer
        )

        if self.config.use_scan:
            # Since all decoder layers are the same, one layer definition is scanned over the depth.
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
            # The scan carry is a tuple; the layer unwraps it on entry.
            hidden_states = (hidden_states,)

            hidden_states, _ = scan_with_axes(
                BlockDecoderLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                # Every non-carry argument is identical for all layers.
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_decoder_layers,
            )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )
            hidden_states = hidden_states[0]

        else:
            for layer in range(num_decoder_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                # LayerDrop: while training, skip the layer with probability `decoder_layerdrop`.
                dropout_probability = random.uniform(0, 1)
                if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
                    # A skipped layer acts as the identity, so the hidden states must be passed through
                    # unchanged. Returning None here would break every subsequent layer.
                    layer_outputs = (hidden_states, None, None)
                else:
                    layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        init_cache,
                        output_attentions,
                        deterministic,
                    )

                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                    if encoder_hidden_states is not None:
                        all_cross_attentions += (layer_outputs[2],)

            # Add the hidden states produced by the last decoder layer.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
class FlaxBartDecoder(nn.Module):
    """
    BART decoder: token plus learned position embeddings, an embedding layer norm and dropout, followed by
    the decoder layer stack. The token embedding module is supplied by the caller via `embed_tokens`.
    """

    config: BartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Declare the position embeddings, the layer stack and the embedding layer norm / dropout."""
        cfg = self.config

        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        self.padding_idx = cfg.pad_token_id
        self.max_target_positions = cfg.max_position_embeddings
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(cfg.d_model)
        else:
            self.embed_scale = 1.0

        # Position ids are shifted by this offset before the lookup, so the embedding table is enlarged
        # by the same amount.
        self.offset = 2
        self.embed_positions = nn.Embed(
            cfg.max_position_embeddings + self.offset,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )

        self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Embed `input_ids` and run the result through the decoder layer stack.

        Returns a `FlaxBaseModelOutputWithPastAndCrossAttentions` when `return_dict` is True; otherwise
        the tuple produced by the layer stack is returned as is.
        """
        # Flatten any leading axes so that the ids are indexed as (rows, sequence).
        flat_input_ids = input_ids.reshape(-1, input_ids.shape[-1])

        token_embeddings = self.embed_tokens(flat_input_ids) * self.embed_scale
        position_embeddings = self.embed_positions(position_ids + self.offset)

        hidden_states = self.layernorm_embedding(token_embeddings + position_embeddings)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        return outputs
|
|
|
|
|
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    """
    Base class for decoder-only BART models: builds the Flax module from `module_class`, initialises
    parameters and the decoding cache, and exposes a `__call__` that applies the module.
    """

    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        """
        Build the module and hand it to `FlaxPreTrainedModel`.

        Note that `config` is modified in place: it is flagged as a decoder and as not encoder-decoder.
        Extra `kwargs` are forwarded to `module_class`.
        """
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """
        Initialise the module parameters by running `module.init` on dummy inputs of shape `input_shape`.

        Dummy encoder hidden states and mask are passed as well, so that the cross-attention parameters
        are created too.

        Returns:
            The `"params"` collection produced by `module.init`.
        """
        # Dummy input tensors.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.

        Returns:
            The unfrozen `"cache"` collection created by running `module.init` with `init_cache=True`.
        """
        # Dummy inputs of the full cache size; only their shapes matter for the cache variables.
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Args:
            input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
                Indices of decoder input sequence tokens in the vocabulary.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are decoder input IDs?](../glossary#decoder-input-ids)
            attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
                Mask for the decoder inputs. Defaults to all ones (nothing masked). A causal mask is always
                applied in addition.
            position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
                Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in
                the range `[0, config.max_position_embeddings - 1]`. Defaults to `0 .. sequence_length - 1` for
                every row.
            encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
                A sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
            encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                Defaults to all ones when `encoder_hidden_states` is given.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned
                tensors for more detail. Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.return_dict`.
            train (`bool`, *optional*, defaults to `False`):
                The module is applied with `deterministic=not train`.
            params (`dict`, *optional*):
                Parameters to use instead of `self.params`.
            past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing
                previous `past_key_values`):
                Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used
                for fast auto-regressive decoding. When given (and non-empty), the updated cache is returned
                alongside the model outputs.
            dropout_rng (`PRNGKey`, *optional*):
                RNG passed to the module as the `"dropout"` stream.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # Prepare the decoder inputs.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed.
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # A passed cache is already initialised. It must be marked mutable so that the attention modules
        # can update it, and the updated collection is then returned by `apply`. One flag drives both the
        # `apply` call and the result unpacking, so the two can never disagree (e.g. for an empty dict).
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        if not use_cache:
            return outputs

        # With a mutable collection, `apply` returns `(outputs, mutated_variables)`.
        outputs, mutated_variables = outputs
        updated_cache = unfreeze(mutated_variables["cache"])

        if return_dict:
            outputs["past_key_values"] = updated_cache
            return outputs

        # Tuple output: insert the cache right after the first element.
        return outputs[:1] + (updated_cache,) + outputs[1:]
|
|
|
|
|
class FlaxBartDecoderWrapper(nn.Module):
    """
    Thin wrapper that nests the decoder under a `decoder` attribute, so that pretrained checkpoints load
    correctly when the causal language model is used in combination with the [`EncoderDecoderModel`]
    framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the token embedding table and the decoder that consumes it."""
        cfg = self.config
        token_embedding = nn.Embed(
            cfg.vocab_size,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )
        self.decoder = FlaxBartDecoder(config=cfg, embed_tokens=token_embedding, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the wrapped decoder and return its result unchanged."""
        return self.decoder(*args, **kwargs)
|
|
|
|
|
class FlaxBartForCausalLMModule(nn.Module):
    """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input
    embeddings when `config.tie_word_embeddings` is set) e.g. for autoregressive tasks.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Declare the wrapped decoder and the bias-free projection onto the vocabulary."""
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the decoder and project its final hidden states to vocabulary logits.

        Returns a `FlaxCausalLMOutputWithCrossAttentions` when `return_dict` is True; otherwise a tuple
        whose first element is the logits, followed by the remaining decoder outputs.
        """
        decoder_outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        final_hidden = decoder_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(final_hidden)
        else:
            # Tied weights: reuse the transposed token embedding matrix as the head's kernel.
            embedding_matrix = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": embedding_matrix.T}}, final_hidden)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )

        return (lm_logits,) + decoder_outputs[1:]
|
|
|
|
|
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input
    embeddings) e.g. for autoregressive tasks.
    """

    module_class = FlaxBartForCausalLMModule

    # `jnp.ndarray` (already used throughout this file) replaces `jnp.DeviceArray` in the annotation below.
    # The annotation is evaluated when the class body runs, and `DeviceArray` is no longer available in
    # current JAX releases, which would make the whole module fail to import.
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.ndarray] = None):
        """
        Build the keyword arguments for the first generation step.

        Args:
            input_ids: prompt token ids of shape `(batch_size, seq_length)`.
            max_length: maximum total length to decode; sets the size of the cache and of the mask.
            attention_mask: optional mask for the prompt tokens.

        Returns:
            A dict with a freshly initialised `past_key_values` cache, an `attention_mask` of shape
            `(batch_size, max_length)` and the `position_ids` of the prompt tokens.
        """
        # Initialise the decoding cache.
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)

        # A single static mask over the full `max_length` is built up front. It starts as all ones, and
        # the prompt's mask (if any) is written into its leading positions.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Positions count only the unmasked tokens seen so far.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """
        Carry state over to the next generation step: store the updated cache and advance `position_ids`
        to the position following the last one.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
|
|