# Copyright 2023 DeepMind Technologies Limited
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fast decoding routines for inference from a trained model.

Adapted to accommodate:
(a) continued decoding from a previous beam cache.
(b) initialization with a single beam, which is then expanded into beam_size
    beams.
"""
from typing import Any
import flax
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
# Constants
# "Effective negative infinity" constant for masking in beam search.
# It is used in this file both as an additive bias that pushes masked or
# finished candidates out of top-k selection and as the initial score of
# beam slots that do not hold a real hypothesis yet.
NEG_INF = np.array(-1.0e7)
# Beam search parameters
# Brevity penalty parameters
def brevity_penalty(alpha: float, length: int):
  """Brevity penalty function for beam search penalizing short sequences.

  Args:
    alpha: float: brevity-penalty scaling parameter.
    length: int: length of considered sequence.

  Returns:
    Brevity penalty score as jax scalar.
  """
  # NOTE(review): the original return expression was cut off right after
  # `jnp.power(`. It is restored here as the GNMT length penalty
  # ((5 + length) / 6) ** alpha; confirm the 5 / 6 bias constants against the
  # upstream definition before relying on exact scores.
  length_bias_numerator = 5.0
  length_bias_denominator = 6.0
  return jnp.power(
      (length_bias_numerator + length) / length_bias_denominator,
      alpha,
  )
# Beam handling utility functions:
def add_beam_dim(x: jnp.ndarray, beam_size: int) -> jnp.ndarray:
  """Inserts a beam axis at position 1 and repeats the data along it.

  Args:
    x: array of shape [batch, ...]; 0-d arrays (e.g. a cache index) are
      returned untouched.
    beam_size: int: number of copies to place along the new axis.

  Returns:
    Array of shape [batch, beam_size, ...], or `x` itself when it is 0-d.
  """
  if not x.ndim:  # scalars carry no batch axis to expand
    return x
  with_beam_axis = jnp.expand_dims(x, axis=1)
  # Repeat only along the freshly inserted axis; every other axis keeps 1 rep.
  repetitions = [
      beam_size if axis == 1 else 1 for axis in range(with_beam_axis.ndim)
  ]
  return jnp.tile(with_beam_axis, repetitions)
def add_beam_dim_cache(
    cache: tuple[dict[str, jnp.ndarray], ...], beam_size: int
) -> tuple[dict[str, jnp.ndarray], ...]:
  """Adds a beam dimension to the 'keys' and 'vals' entries of each layer.

  Args:
    cache: tuple with one dict per layer.
    beam_size: int: size of the beam dimension to create.

  Returns:
    A new tuple of new per-layer dicts in which the 'keys' and 'vals' entries
    have been passed through `add_beam_dim`; every other entry is carried over
    unchanged. The input cache is not modified.
  """
  # NOTE(review): this function selects 'vals' whereas `apply_on_cache`
  # selects 'values'; confirm which key name the cache layers really use.
  tiled_keys = ('keys', 'vals')
  new_cache = []
  for layer in cache:
    new_layer = {
        key: add_beam_dim(x, beam_size) if key in tiled_keys else x
        for key, x in layer.items()
    }
    # Without this append the function returned an empty tuple for every
    # input, silently dropping the whole cache.
    new_cache.append(new_layer)
  return tuple(new_cache)
def flatten_beam_dim(x):
  """Merges the two leading axes of an array into a single axis.

  Args:
    x: array of shape [batch, beam, ...]. Arrays with fewer than two axes
      (e.g. a scalar cache index) are returned as they are.

  Returns:
    Array of shape [batch * beam, ...], or `x` itself when `x.ndim < 2`.
  """
  if x.ndim >= 2:
    batch, beam, *trailing = x.shape
    return x.reshape((batch * beam, *trailing))
  return x
def unflatten_beam_dim(x, batch_size, beam_size):
  """Splits the leading batch*beam axis back into [batch, beam].

  Args:
    x: array of shape [batch_size * beam_size, ...]; 0-d arrays (e.g. a cache
      index) are returned as they are.
    batch_size: int: size of the batch axis to restore.
    beam_size: int: size of the beam axis to restore.

  Returns:
    Array of shape [batch_size, beam_size, ...], or `x` itself when it is 0-d.
  """
  if not x.ndim:  # scalars have no leading axis to split
    return x
  flat_dim, *trailing = x.shape
  assert batch_size * beam_size == flat_dim
  return x.reshape((batch_size, beam_size, *trailing))
def flat_batch_beam_expand(x, beam_size):
  """Expands each batch item by beam_size along the batch dimension.

  Args:
    x: array to expand; handed to `add_beam_dim` and then `flatten_beam_dim`.
    beam_size: int: number of copies made of each batch item.

  Returns:
    The result of tiling `x` into a new beam axis and folding that axis back
    into the leading axis.
  """
  with_beams = add_beam_dim(x, beam_size)
  return flatten_beam_dim(with_beams)
def gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Gathers the beam slices indexed by beam_indices into new beam array.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    beam_indices: array of beam_indices; indexed together with a
      [batch_size, new_beam_size] batch-index array.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ beam dimension.

  Returns:
    New pytree with new beam arrays.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  # Row i of batch_indices is filled with i, so pairing it element-wise with
  # beam_indices selects beam `beam_indices[i, j]` of batch item i.
  batch_indices = jnp.reshape(
      jnp.arange(batch_size * new_beam_size) // new_beam_size,
      (batch_size, new_beam_size),
  )

  def gather_fn(x):
    if x.ndim == 0:  # ignore scalars (e.g. cache index)
      return x
    return x[batch_indices, beam_indices]

  return jax.tree_util.tree_map(gather_fn, nested)
def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
  """Gathers the top-k beam slices given by score_or_log_prob array.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
      for top-k selection of beam slices.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ top-k selected beam dimension

  Returns:
    New pytree with new beam arrays containing top k new_beam_size slices.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size)
  # lax.top_k yields the largest value first; flipping puts the best beam in
  # the last slot of the new beam dimension.
  topk_indices = jnp.flip(topk_indices, axis=1)
  return gather_beams(nested, topk_indices, batch_size, new_beam_size)
def apply_on_cache(fn, cache, *args, **kwargs):
  """Applies fn to the decoding-state entries of every cache layer.

  Only entries stored under 'keys', 'values', 'current_index' or
  'relative_position_bias' are transformed; all other entries are copied over
  unchanged.

  Args:
    fn: callable invoked as `fn(val, *args, **kwargs)` on each selected entry.
    cache: tuple with one dict per layer.
    *args: extra positional arguments forwarded to `fn`.
    **kwargs: extra keyword arguments forwarded to `fn`.

  Returns:
    A new tuple of new per-layer dicts. The input cache is not modified.
  """
  transformed_keys = (
      'keys',
      'values',
      'current_index',
      'relative_position_bias',
  )
  new_cache = []
  for layer in cache:
    new_layer = {}
    for key, val in layer.items():
      if key in transformed_keys:
        val = fn(val, *args, **kwargs)
      new_layer[key] = val
    # Without this append the function returned an empty tuple for every
    # input, silently dropping the whole cache.
    new_cache.append(new_layer)
  return tuple(new_cache)
# Beam search state:
@flax.struct.dataclass
class BeamState:
  """Holds beam search state data.

  The class is built with keyword arguments and carried through
  `lax.while_loop` in this file, so it has to be a dataclass that JAX can
  treat as a pytree; `flax.struct.dataclass` provides both.
  """

  # The position of the decoding loop in the length dimension.
  cur_index: jax.Array  # scalar int32: current decoded length index
  # The active sequence log probabilities and finished sequence scores.
  live_logprobs: jax.Array  # float32: [batch_size, beam_size]
  finished_scores: jax.Array  # float32: [batch_size, beam_size]
  # The current active-beam-searching and finished sequences.
  live_seqs: jax.Array  # int32: [batch_size, beam_size, max_decode_len]
  finished_seqs: jax.Array  # int32: [batch_size, beam_size, max_decode_len]
  # Records which of the 'finished_seqs' is occupied and not a filler slot.
  finished_flags: jax.Array  # bool: [batch_size, beam_size]
  # The current state of the autoregressive decoding caches.
  cache: Any  # Any pytree of arrays, e.g. flax attention Cache object
def beam_init(seed_token, batch_size, beam_size, max_decode_len, cache):
  """Initializes the beam search state data structure.

  Args:
    seed_token: first token of every beam; reshaped to
      [batch_size, beam_size, 1].
    batch_size: int: size of batch.
    beam_size: int: number of beams per batch item.
    max_decode_len: int: length of the sequence buffers.
    cache: per-layer decoding cache; a leading axis of size 1 is added to the
      entries selected by `apply_on_cache`.

  Returns:
    The initial `BeamState`.
  """
  cur_index0 = jnp.array(0)
  # Only the first beam starts "alive" (log-prob 0); the others start at
  # NEG_INF so they cannot win a top-k selection.
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
  )
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  # Seed token in position 0, zero padding for the remaining positions.
  live_seqs0 = jnp.concatenate(
      [
          jnp.reshape(seed_token, (batch_size, beam_size, 1)),
          jnp.zeros((batch_size, beam_size, max_decode_len - 1), jnp.int32),
      ],
      axis=-1,
  )  # (batch, beam, max_decode_len)
  finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  beam_cache0 = apply_on_cache(lambda x: jnp.expand_dims(x, axis=0), cache)
  return BeamState(
      cur_index=cur_index0,
      live_logprobs=live_logprobs0,
      finished_scores=finished_scores0,
      live_seqs=live_seqs0,
      finished_seqs=finished_seqs0,
      finished_flags=finished_flags0,
      cache=beam_cache0,
  )
# Beam search routine:
def beam_search_flat(
    seed_token,
    cache,
    tokens_to_logits,
    alpha=0.6,
    eos=None,
    max_decode_len=32,
    mask=None,
):
  """Beam search for LM.

  Inputs and cache are already flat, i.e. first dimension == batch * beam.

  NOTE(review): the parameter list of this function was missing from the
  reviewed text. Names and order are taken from the argument documentation
  below; the defaults (alpha=0.6, max_decode_len=32, eos=None, mask=None) are
  assumptions to confirm against callers. `eos` and `mask` are reshaped
  unconditionally, so both must be supplied.

  Args:
    seed_token: array: [beam_size, 1] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    alpha: float: scaling factor for brevity penalty.
    eos: array: [vocab] 1 for end-of-sentence tokens, 0 for not.
    max_decode_len: int: maximum length of decoded translations.
    mask: array: [vocab] binary mask for vocab. 1 to keep the prob, 0 to set the
      prob := 0.

  Returns:
    Tuple of:
      [beam_size, max_decode_len] top-scoring sequences
      [beam_size] beam-search scores.
      the final decoding cache with its beam dimension flattened again.
  """
  # We liberally annotate shape information for clarity below.
  batch_size, beam_size = 1, seed_token.shape[0]
  mask = mask.reshape((1, 1, -1))
  eos = eos.reshape((1, 1, -1))
  mask_bias = (1 - mask) * NEG_INF

  # initialize beam search state
  beam_search_init_state = beam_init(
      seed_token, batch_size, beam_size, max_decode_len, cache
  )

  def beam_search_loop_cond_fn(state):
    """Beam search loop termination condition."""
    # Have we reached max decoding length?
    not_at_end = state.cur_index < max_decode_len - 1

    # Is no further progress in the beam search possible?
    # Get the best possible scores from alive sequences.
    min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
    best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
    # Get the worst scores from finished sequences.
    worst_finished_scores = jnp.min(
        state.finished_scores, axis=1, keepdims=True
    )
    # Mask out scores from slots without any actual finished sequences.
    worst_finished_scores = jnp.where(
        state.finished_flags, worst_finished_scores, NEG_INF
    )
    # If no best possible live score is better than current worst finished
    # scores, the search cannot improve the finished set further.
    search_terminated = jnp.all(worst_finished_scores > best_live_scores)

    # If we're not at the max decode length, and the search hasn't terminated,
    # continue looping.
    return not_at_end & (~search_terminated)

  def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model. Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(
        lax.dynamic_slice(
            state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1)
        )
    )
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = apply_on_cache(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = apply_on_cache(
        unflatten_beam_dim, new_flat_cache, batch_size, beam_size
    )

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = candidate_log_probs + jnp.expand_dims(
        state.live_logprobs, axis=2
    )

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # mask away some tokens.
    log_probs += mask_bias  # [batch,beam,vocab]+[1,1,vocab]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(
        state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
    )

    # Append the most probable 2*K token IDs to the top 2*K sequences
    # Recover token id by modulo division and expand Id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(
        topk_seq, topk_ids, (0, 0, state.cur_index + 1)
    )

    # Update LIVE (in-progress) sequences:

    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    last_token = topk_seq[:, :, state.cur_index + 1]
    last_token = jax.nn.one_hot(last_token, vocab_size, dtype=jnp.bfloat16)

    # any([batch, 2b, vocab] * [1, 1, vocab], axis=-1) == [batch, 2b]
    newly_finished = jnp.any(last_token * eos, axis=-1)

    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
    )

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(
        topk_beam_indices, new_topk_indices, batch_size, beam_size
    )
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = apply_on_cache(
        gather_beams, new_cache, top_alive_indices, batch_size, beam_size
    )

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1
    )
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1
    )
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1
    )
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams(
            [finished_seqs, finished_scores, finished_flags],
            finished_scores,
            batch_size,
            beam_size,
        )
    )

    return BeamState(
        cur_index=state.cur_index + 1,
        live_logprobs=top_alive_log_probs,
        finished_scores=top_finished_scores,
        live_seqs=top_alive_seq,
        finished_seqs=top_finished_seq,
        finished_flags=top_finished_flags,
        cache=top_alive_cache,
    )

  # Run while loop and get final beam search state.
  final_state = lax.while_loop(
      beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state
  )

  # Account for the edge-case where there are no finished sequences for a
  # particular batch item. If so, return live sequences for that batch item.
  # True for batch items that own at least one finished sequence.
  # --> [batch]
  any_finished = jnp.any(final_state.finished_flags, axis=1)
  # --> [batch, beams, length]
  finished_seqs = jnp.where(
      any_finished[:, None, None],
      final_state.finished_seqs,
      final_state.live_seqs,
  )
  # --> [batch, beams]
  finished_scores = jnp.where(
      any_finished[:, None],
      final_state.finished_scores,
      final_state.live_logprobs,
  )

  # batch_size is fixed to 1 above, so the batch axis can be dropped.
  finished_seqs = jnp.reshape(finished_seqs, (beam_size, max_decode_len))
  finished_scores = jnp.reshape(finished_scores, (beam_size,))
  final_cache = apply_on_cache(flatten_beam_dim, final_state.cache)

  return finished_seqs, finished_scores, final_cache