# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FLAX layers for on-TPU memory.""" | |
import abc | |
import functools | |
from typing import Callable, Sequence, Tuple, TypeVar, Union | |
from absl import logging | |
from flax import linen | |
import gin | |
import jax | |
from jax import lax | |
import jax.numpy as jnp | |
import numpy as np # use with care! | |
Shape = Sequence[int] | |
Dtype = jnp.dtype | |
Array = jnp.ndarray | |
Axes = Union[int, Tuple[int, ...]] | |
F = TypeVar('F', bound=Callable) | |


class MemoryLayer(linen.Module, metaclass=abc.ABCMeta):
  """Internal interface for memory layers without a batch dimension.

  See BatchedMemory for a layer that can be used in Flax models.
  """

  num_datasets: int

  def update(self, key: Array, value: Array) -> int:
    """Adds key/value pairs to memory.

    Args:
      key: of shape (num_kv, num_datasets, k_features)
      value: of shape (num_kv, num_datasets, v_features)

    Returns:
      Dummy value so that TPU operations can wait for the update to finish if
      desired.
    """
    raise NotImplementedError()

  def topk_retrieval(self, query: Array,
                     num_neighbors: int) -> Tuple[Array, Array]:
    """Retrieves the nearest neighbors for each query.

    Args:
      query: of shape (num_queries, num_datasets, k_features)
      num_neighbors: int indicating the number of neighbors to retrieve

    Returns:
      Tuple of selected keys and selected values of shapes
      (num_queries, num_datasets, num_neighbors, k_features) and
      (num_queries, num_datasets, num_neighbors, v_features).
    """
    raise NotImplementedError()

  def reset(self, datasets: Array) -> int:
    """Resets some or all of the datasets in the memory.

    Args:
      datasets: A vector of shape (num_datasets,) of type bool. Each position
        indicates whether the dataset with the same index should be reset.

    Returns:
      Dummy value so that TPU operations can wait for the reset to finish if
      desired.
    """
    raise NotImplementedError()

  def __call__(self, query, num_neighbors):
    return self.topk_retrieval(query, num_neighbors)
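

# A concrete MemoryLayer maintains `num_datasets` independent key/value
# databases (typically one per attention head). Summarizing the shapes from
# the docstrings above:
#   update:         key (num_kv, num_datasets, k_features),
#                   value (num_kv, num_datasets, v_features)
#   topk_retrieval: query (num_queries, num_datasets, k_features) ->
#                   (num_queries, num_datasets, num_neighbors, {k,v}_features)
#   reset:          datasets (num_datasets,) bool mask
# See the illustrative usage sketch at the end of this file for how these are
# called through BatchedMemory.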


def _target_dimensions(shape: Shape,
                       source_dimensions: Sequence[int]) -> Sequence[int]:
  target_dimensions = range(-2, -2 - len(source_dimensions), -1)
  assert len(source_dimensions) == len(target_dimensions)
  return sorted(d % len(shape) for d in target_dimensions)


def _rearrange_dimensions_shapes(
    shape: Shape, split_dimensions: Sequence[int]) -> Tuple[Shape, Shape]:
  split_shape = tuple(shape[d] for d in split_dimensions)
  remaining_shape = tuple(
      shape[d] for d in range(len(shape)) if d not in split_dimensions)
  batch_shape = remaining_shape[:-1]
  return split_shape, batch_shape


def _rearrange_dimensions(x: Array, split_dimensions: Sequence[int]) -> Array:
  """Rearranges an array so that we can split by a single dimension.

  Turns an array of shape [d1, ..., dn, features] into an array of shape
  [prod(remaining_dimensions), prod(split_dimensions), features], where the
  dimensions listed in `split_dimensions` end up in dimension -2.

  Args:
    x: array of shape [d1, ..., dn, features]
    split_dimensions: list of dimensions that should end up in dimension -2.

  Returns:
    Rearranged array as described above.
  """
  split_dimensions = [d % len(x.shape) for d in split_dimensions]
  split_dimensions = sorted(split_dimensions)
  split_shape, batch_shape = _rearrange_dimensions_shapes(
      x.shape, split_dimensions)

  target_dimensions = _target_dimensions(x.shape, split_dimensions)
  x = jnp.moveaxis(x, split_dimensions, target_dimensions)

  assert len(x.shape) > len(split_dimensions)
  assert all(isinstance(d, int) and d >= 0 for d in batch_shape)
  assert all(isinstance(d, int) and d >= 0 for d in split_shape)

  new_shape = [
      # The use of numpy is okay here, since shapes are concrete at jit time.
      np.prod(batch_shape),
      np.prod(split_shape),
      x.shape[-1]  # features dimension
  ]
  res = x.reshape(new_shape)
  assert res.ndim == 3
  return res


def _restore_dimensions(x: Array, original_shape: Shape,
                        split_dimensions: Sequence[int]) -> Array:
  """Restores arrays encoded with _rearrange_dimensions.

  Args:
    x: Array of shape [prod(batch_shape), prod(split_shape), features...]
    original_shape: Shape of the array to restore to.
    split_dimensions: Dimensions of the original array that were collapsed into
      the second axis of `x` by _rearrange_dimensions.

  Returns:
    Array of the original shape and axis order for all dimensions in
    batch_shape and split_shape. Feature dimensions may have changed (they can
    include additional dimensions, for example for neighbors).
  """
  split_dimensions = [d % len(original_shape) for d in split_dimensions]
  split_dimensions = sorted(split_dimensions)
  split_shape, batch_shape = _rearrange_dimensions_shapes(
      original_shape, split_dimensions)

  features_shape = x.shape[2:]
  x = x.reshape((*batch_shape, *split_shape, *features_shape))

  # Restore the original axis order.
  target_dimensions = _target_dimensions(original_shape, split_dimensions)
  x = jnp.moveaxis(x, target_dimensions, split_dimensions)
  return x
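

# Shape example (illustrative values): with x of shape
# (batch=2, q_len=16, num_heads=8, k_features=64) and split_dimensions=(-2,),
# _rearrange_dimensions(x, (-2,)) has shape (2 * 16, 8, 64) = (32, 8, 64),
# i.e. one "dataset" slice per head. After retrieval adds a neighbors axis,
# e.g. (32, 8, num_neighbors, 64), _restore_dimensions(., x.shape, (-2,))
# brings it back to (2, 16, 8, num_neighbors, 64).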


class BatchedMemory(linen.Module):
  """Equips a memory module with a batch dimension."""

  # We wrap this linen.Module:
  wrapped: MemoryLayer

  # `split_dimensions` indicates the dimensions of the query and update tensors
  # that will go to separate databases. By default, we use a separate database
  # for each head.
  # Note that some implementations of the memory share memory across all hosts
  # and devices (memory_on_borg, unless configured otherwise) or just across
  # the devices of each host (memory_on_host).
  # The default (-2,) splits by head only; use (0, -2) to also split by the
  # batch dimension.
  split_dimensions: Tuple[int, ...] = (-2,)
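
  # Note on strides (derived from the code below): with update_stride=n only
  # every n-th key/value position along the flattened (batch * length) axis is
  # written to memory, and with query_stride=n only every n-th query is
  # retrieved; the skipped query positions receive zero keys and values.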
  query_stride: int = 1
  update_stride: int = 1

  def update(self, key: Array, value: Array):
    """Adds key/value pairs to memory.

    Args:
      key: typically of shape (batch, kv_len, num_heads, k_features). This
        tensor is split up into datasets according to `split_dimensions`.
      value: typically of shape (batch, kv_len, num_heads, v_features). This
        tensor is split up into datasets according to `split_dimensions`.

    Returns:
      A dummy value 0, once the operation has completed.
    """
    if key.ndim != 4 or value.ndim != 4:
      raise ValueError('Expected batched inputs; got shapes: %s and %s.' %
                       (key.shape, value.shape))
    key = _rearrange_dimensions(key, self.split_dimensions)
    value = _rearrange_dimensions(value, self.split_dimensions)

    update_stride = self.update_stride
    if update_stride == 1:
      return self.wrapped.update(key, value)
    return self.wrapped.update(key[update_stride - 1::update_stride, ...],
                               value[update_stride - 1::update_stride, ...])

  def topk_retrieval(self, query: Array, num_neighbors: int):
    """Retrieves the nearest neighbors for each query.

    Args:
      query: typically of shape (batch, q_len, num_heads, k_features). This
        tensor is split up into datasets according to `split_dimensions`.
      num_neighbors: number of neighbors to retrieve.

    Returns:
      Tuple of tensors with the retrieved keys and values, shaped like the
      query but with an extra dimension of length num_neighbors - typically
      (batch, q_len, num_heads, num_neighbors, k_features) for keys and
      (batch, q_len, num_heads, num_neighbors, v_features) for values.
    """
    if query.ndim != 4:
      raise ValueError('Expected batched inputs; got shape: %s.' % query.shape)
    query_stride = self.query_stride
    original_shape = query.shape
    query = _rearrange_dimensions(query, self.split_dimensions)

    if query_stride == 1:
      key, value = self.wrapped.topk_retrieval(query, num_neighbors)
    else:
      num_queries, num_heads, k_features = query.shape
      throttled_query = query[0::query_stride, ...]
      key = jnp.zeros(
          shape=(num_queries, num_heads, num_neighbors, k_features),
          dtype=query.dtype)
      throttled_key, throttled_value = (
          self.wrapped.topk_retrieval(throttled_query, num_neighbors))
      _, _, _, v_features = throttled_value.shape
      value = jnp.zeros(
          shape=(num_queries, num_heads, num_neighbors, v_features),
          dtype=query.dtype)
      key = key.at[0::query_stride, ...].set(throttled_key)
      value = value.at[0::query_stride, ...].set(throttled_value)

    key = _restore_dimensions(key, original_shape, self.split_dimensions)
    # Note that `original_shape` here may have the wrong feature dimension (if
    # k_features != v_features), but `_restore_dimensions` does not depend on
    # that dimension; the tests cover this case.
    value = _restore_dimensions(value, original_shape, self.split_dimensions)
    assert key.ndim == len(original_shape) + 1
    return key, value

  def reset(self, datasets: Array) -> int:
    """Resets the memory.

    Args:
      datasets: of shape (num_datasets,), typically the same as (num_heads,).

    Returns:
      A dummy value 0, once the operation has completed.
    """
    return self.wrapped.reset(datasets)


def _chunking_sparsify(query: Array, key: Array, num_buckets: int,
                       bucket_size: int) -> Tuple[Array, Array, Array]:
  """Approximate top k operation for a single head."""
  # q = q_length, f = qk features, d = database_size
  scores = jnp.einsum('qf,df->qd', query, key)
  mask = (key.sum(-1) == 0).astype(jnp.bfloat16) * -1e6
  scores += mask

  num_queries, _ = scores.shape
  reshaped_scores = jnp.reshape(scores, (num_queries, bucket_size, num_buckets))
  sparse_scores = linen.softmax(reshaped_scores * 1e6, axis=1)

  # topk_scores and topk_indices will only be computed if we depend on their
  # results.
  topk_scores = jnp.max(reshaped_scores, axis=1)
  local_indices = jnp.argmax(reshaped_scores, axis=1)
  topk_indices = (
      local_indices * num_buckets + jnp.arange(num_buckets).reshape(
          (1, num_buckets)))
  return sparse_scores, topk_scores, topk_indices
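

# How the approximation works (illustrative numbers): for database_size=1024
# and num_neighbors=num_buckets=128, bucket j groups the bucket_size =
# 1024 // 128 = 8 database entries whose index is congruent to j modulo 128.
# The softmax over the bucket axis, with scores scaled by 1e6, is effectively
# a one-hot argmax within each bucket, so exactly one (approximate) neighbor
# is selected from every bucket instead of a true global top-k.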


def _retrieve_topk_gatherless(
    query: Array, key: Array, value: Array,
    num_neighbors: int) -> Tuple[Array, Array, Array, Array]:
  """Retrieves for a single head - used to simplify array accesses."""
  num_kv, query_features = query.shape
  database_size, key_features = key.shape
  _, value_features = value.shape
  assert query_features == key_features

  num_buckets = num_neighbors
  if num_buckets > database_size:
    raise ValueError('More buckets than items in database. %s > %s' %
                     (num_buckets, database_size))
  if database_size % num_buckets:
    raise ValueError('Buckets must divide database: %s %% %s.' %
                     (database_size, num_buckets))
  bucket_size = database_size // num_buckets

  sparse_scores, topk_scores, topk_indices = _chunking_sparsify(
      query, key, num_buckets, bucket_size)
  key = key.reshape(bucket_size, num_buckets, key_features)
  value = value.reshape(bucket_size, num_buckets, value_features)

  selected_keys = jnp.einsum('qbn,bnd->qnd', sparse_scores, key)
  selected_values = jnp.einsum('qbn,bnd->qnd', sparse_scores, value)

  assert selected_keys.shape == (num_kv, num_neighbors, key_features)
  assert selected_values.shape == (num_kv, num_neighbors, value_features)
  return selected_keys, selected_values, topk_scores, topk_indices
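

# "Gatherless" refers to how the selected entries are materialized: instead of
# a gather (e.g. jnp.take) on the winning indices, the near-one-hot
# `sparse_scores` are contracted against the reshaped key/value databases with
# an einsum, which is typically friendlier to TPU matrix units than a gather.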


class MemoryOnTpu(MemoryLayer):
  """Approximate top K search on TPU."""

  # database_size must be integer multiple of prod(batch_dims) * num_neighbors.
  database_size: int
  dtype: Dtype = jnp.float32  # pylint: disable=g-bare-generic
  key_features: int = 64
  value_features: int = 64
  report_scores_and_indices: bool = False

  def setup(self):
    self.db_index = self.variable('database', 'database_index',
                                  functools.partial(jnp.zeros, dtype=jnp.int32),
                                  (self.num_datasets,))
    self.key_db = self.variable(
        'database', 'key_db', functools.partial(jnp.zeros, dtype=self.dtype),
        (self.num_datasets, self.database_size, self.key_features))
    self.value_db = self.variable(
        'database', 'value_db', functools.partial(jnp.zeros, dtype=self.dtype),
        (self.num_datasets, self.database_size, self.value_features))
    self.retrieved_indices = self.variable(
        'database', 'retrieved_indices',
        functools.partial(jnp.zeros, dtype=jnp.int32), (0, 0, 0))
    self.retrieved_indices_scores = self.variable(
        'database', 'retrieved_indices_scores',
        functools.partial(jnp.zeros, dtype=jnp.float32), (0, 0, 0))

  def _update_kv_database(self, database, new_values, start_index):
    num_datasets, database_size, _ = database.shape
    assert database_size == self.database_size, (
        f'{database_size} vs {self.database_size}')
    assert num_datasets == self.num_datasets
    assert new_values.ndim == 3
    assert start_index.shape == (self.num_datasets,)

    def _update(database, new_values, start_index):
      return lax.dynamic_update_slice(
          database, new_values, start_indices=(start_index, 0))

    return jax.vmap(
        _update, in_axes=(0, 0, 0), out_axes=0)(database, new_values,
                                                start_index)

  def update(self, key: Array, value: Array) -> int:
    """Add keys and values to the memory; overwrite oldest if memory is full."""
    key = lax.stop_gradient(key)
    value = lax.stop_gradient(value)
    assert len(key.shape) == len(value.shape)
    assert key.shape[:-1] == value.shape[:-1]
    num_kv, num_datasets, key_features = key.shape
    assert num_datasets == self.num_datasets
    assert key_features == self.key_features
    assert value.shape[-1] == self.value_features
    assert self.database_size % num_kv == 0, (
        'Database size must be integer multiple of num_kv.')

    key = jnp.moveaxis(key, source=1, destination=0)      # split by dataset
    value = jnp.moveaxis(value, source=1, destination=0)  # split by dataset

    # db_index can grow beyond the database size - we use that to detect which
    # entries have not been written to yet.
    start_index = self.db_index.value % self.database_size
    self.key_db.value = self._update_kv_database(self.key_db.value, key,
                                                 start_index)
    self.value_db.value = self._update_kv_database(self.value_db.value, value,
                                                   start_index)
    self.db_index.value = self.db_index.value + num_kv
    return 0
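
  # Write-pointer example (illustrative numbers, assuming num_kv is the same
  # on every call): with database_size=1024 and num_kv=32, the first call
  # writes rows [0, 32), the second [32, 64), and so on; after 32 calls
  # db_index reaches 1024 and the next write wraps back to row 0, overwriting
  # the oldest entries. Because database_size % num_kv == 0 is enforced above,
  # a single write never straddles the wrap-around point.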
  def topk_retrieval(self, query: Array,
                     num_neighbors: int) -> Tuple[Array, Array]:
    """Nearest neighbors by full multiplication and approximate top k on TPU."""
    query = lax.stop_gradient(query)
    unused_num_kv, num_datasets, query_features = query.shape
    assert num_datasets == self.num_datasets
    assert query_features == self.key_features
    query = jnp.moveaxis(query, source=1, destination=0)

    # Process different heads sequentially.
    selected_keys, selected_values, topk_scores, topk_indices = lax.map(
        lambda x: _retrieve_topk_gatherless(*x, num_neighbors),
        (query, self.key_db.value, self.value_db.value))

    if self.report_scores_and_indices:
      # TODO(mrabe): These variable updates may not work perfectly yet. Find
      # out why Flax does not like them.
      self.retrieved_indices.value = topk_indices
      self.retrieved_indices_scores.value = topk_scores

    assert selected_keys.ndim == selected_values.ndim == 4
    selected_keys = jnp.moveaxis(selected_keys, source=0, destination=1)
    selected_values = jnp.moveaxis(selected_values, source=0, destination=1)
    return selected_keys, selected_values

  def reset(self, datasets: Array) -> int:
    """Resets specified datasets."""
    datasets = lax.stop_gradient(datasets)
    assert datasets.shape == (self.num_datasets,)
    assert datasets.dtype == jnp.bool_

    def _reset_single_dataset(input_tuple):
      """Resets a single head; reset is a single bool."""
      database, reset = input_tuple
      assert reset.shape == tuple(), reset.shape
      assert reset.dtype == jnp.bool_
      return database * (1 - reset)

    self.db_index.value = self.db_index.value * (1 - datasets)
    self.key_db.value = lax.map(
        _reset_single_dataset, xs=(self.key_db.value, datasets))
    self.value_db.value = lax.map(
        _reset_single_dataset, xs=(self.value_db.value, datasets))
    return 0
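

# The following usage sketch is illustrative only and is not part of the
# original API; the shapes and constructor arguments are assumptions chosen to
# satisfy the divisibility checks above (database_size % num_kv == 0 and
# database_size % num_neighbors == 0).
def _example_batched_memory_usage():
  """Sketch: wrap MemoryOnTpu in BatchedMemory, write to it, then query it."""
  memory = BatchedMemory(
      wrapped=MemoryOnTpu(
          num_datasets=8,  # one database per head (split_dimensions=(-2,))
          database_size=1024,
          key_features=64,
          value_features=64))

  # (batch, length, num_heads, features) inputs.
  key = jnp.ones((2, 16, 8, 64))
  value = jnp.ones((2, 16, 8, 64))
  query = jnp.ones((2, 16, 8, 64))

  # Create the 'database' variable collection.
  variables = memory.init(
      jax.random.PRNGKey(0), query, num_neighbors=2,
      method=BatchedMemory.topk_retrieval)

  # Writing mutates the 'database' collection, so it must be marked mutable.
  # The module has no 'params', so the returned mutated collection is the full
  # variable dict.
  _, variables = memory.apply(
      variables, key, value,
      method=BatchedMemory.update, mutable=['database'])

  # Retrieval returns keys/values of shape (2, 16, 8, num_neighbors, 64).
  retrieved_keys, retrieved_values = memory.apply(
      variables, query, num_neighbors=2,
      method=BatchedMemory.topk_retrieval)
  return retrieved_keys, retrieved_values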