# (Web-page residue from the scrape: "Spaces: Running Running".)
# file from: https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py | |
# coding=utf-8 | |
# Copyright 2022 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# An implementation of distributed Shampoo optimizer from: | |
# | |
# Scalable Second Order Optimization for Deep Learning | |
# Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer | |
# Preprint Paper: https://arxiv.org/abs/2002.09018 | |
# | |
# This implementation moves computation of inverse pth root back to the | |
# accelerator (if higher precision is available). | |
# | |
# Authors: Rohan Anil (rohananil at google dot com) | |
# & Vineet Gupta (vineet at google dot com) | |
# | |
"""Distributed Shampoo Implementation.""" | |
import enum | |
import functools | |
import itertools | |
from typing import Any, List, NamedTuple | |
import chex | |
import jax | |
import jax.experimental.pjit as pjit | |
import jax.numpy as jnp | |
import numpy as np | |
import optax | |
from flax import struct | |
from jax import lax | |
# pylint:disable=no-value-for-parameter | |
@struct.dataclass
class QuantizedValue:
    """State associated with quantized value.

    Holds a (possibly) quantized tensor together with the information needed
    to turn it back into a float tensor via `to_float`.
    """

    quantized: chex.Array
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False
    )  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(pytree_node=False)  # In case its centered.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Builds a QuantizedValue from a float value.

        Args:
          fvalue: Float array to quantize, or an empty list (used as a
            placeholder for "no value").
          quantized_dtype: Target dtype (see `quantize` for supported dtypes).
          extract_diagonal: Whether to store the diagonal separately.

        Returns:
          A QuantizedValue; for an empty-list input all array fields are
          empty lists.
        """
        if isinstance(fvalue, list) and not fvalue:
            return cls([], [], [], quantized_dtype, extract_diagonal, [])
        quantized, diagonal_fvalue, bucket_size = cls.quantize(
            fvalue, quantized_dtype, extract_diagonal
        )
        return cls(
            quantized,
            diagonal_fvalue,
            bucket_size,
            quantized_dtype,
            extract_diagonal,
            list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns quantized value and the bucket.

        Args:
          fvalue: Float array to quantize.
          quantized_dtype: One of jnp.float32 (returned unchanged),
            jnp.bfloat16 (plain cast), jnp.int8 or jnp.int16 (bucketed).
          extract_diagonal: If True, `fvalue` must be 2D; its diagonal is
            removed before quantization and returned separately.

        Returns:
          A tuple (quantized, diagonal, bucket_size). `diagonal` and
          `bucket_size` are empty lists when they do not apply.

        Raises:
          ValueError: for an unsupported dtype, a non-2D input combined with
            `extract_diagonal`, or a 0-dimensional input.
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")

        # max value is mapped to num_buckets
        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        # One bucket size per slice along the leading axis reduction.
        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0
        bs_nonzero = jnp.where(
            bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
        )
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value.

        float32 values (and the empty-list placeholder) are returned as
        stored, bfloat16 is cast to float32, and bucketed integer values are
        rescaled by `bucket_size` with the diagonal added back if it was
        extracted.
        """
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized

        if self.quantized_dtype == jnp.float32:
            return self.quantized

        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val
@struct.dataclass
class TrainingMetrics:
    """Metrics tracked during training alongside the optimizer state."""

    inverse_pth_root_errors: chex.Array  # Error for inverse-pth roots.
    # TODO(rohananil): Add more important metrics to track during training.
# Per parameter optimizer state used in data-parallel training. | |
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
    preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
# For training extremely large model; We keep a global state with a concatenated | |
# statistics and preconditioner states for all vars. This is so that we can | |
# annotate the leading axis to be sharded to save memory at the cost of | |
# communication. | |
@struct.dataclass
class GlobalShardedParameterStats:
    """Global (concatenated over all variables) statistics state."""

    statistics: chex.Array  # Statistics
    preconditioners: chex.Array  # Preconditioners
    exponents: chex.Array  # exponents
# These are per-parameter local states; All statistics here mirror the parameter | |
# Thus the sharding is copied over from the param specification. | |
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
    index_start: np.int32 = struct.field(
        pytree_node=False
    )  # Index into global statistics array
    sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
def init_training_metrics(num_statistics):
    """Returns initial TrainingMetrics for `num_statistics` statistics.

    The error vector is a float32 zero array of length `num_statistics`, or
    an empty list when there are no statistics.
    """
    if not num_statistics:
        return TrainingMetrics([])
    return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
def init_training_metrics_shapes(num_statistics):
    """Returns TrainingMetrics holding a `[shape, dtype]` description.

    The shape entry is `[num_statistics]`, or None when there are no
    statistics; the dtype entry is always jnp.float32.
    """
    if not num_statistics:
        return TrainingMetrics([None, jnp.float32])
    return TrainingMetrics([[num_statistics], jnp.float32])
def init_training_metrics_pspec(num_statistics):
    """Returns TrainingMetrics holding the partition spec for the metrics.

    An empty `PartitionSpec()` is used when there are statistics, and None
    otherwise.
    """
    if not num_statistics:
        return TrainingMetrics(None)
    return TrainingMetrics(pjit.PartitionSpec())
class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""

    global_stats: Any  # Global statistics shared across parameters.
    local_stats: Any  # Per-parameter local statistics.
class ShampooState(NamedTuple):
    """Optimizer state: a counter plus the statistics container."""

    count: chex.Array
    stats: Any
class InitFnState(NamedTuple):
    """Bundle of the three initialization-related callables."""

    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any
class GraftingType(enum.IntEnum):
    """Identifiers for the grafting method selected via `graft_type`."""

    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6
def power_iteration(
    matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST
):
    r"""Power iteration algorithm.

    The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
    a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
    of `A`, and a vector v, which is the corresponding eigenvector of `A`.

    References:
      [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
      matrix: the symmetric PSD matrix.
      num_iters: Number of iterations.
      error_tolerance: Iterative exit condition.
      precision: precision XLA related flag, the available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
      eigen vector, eigen value
    """
    matrix_size = matrix.shape[-1]

    def _iter_condition(state):
        i, unused_v, unused_s, unused_s_v, run_step = state
        return jnp.logical_and(i < num_iters, run_step)

    def _iter_body(state):
        """One step of power iteration."""
        i, new_v, s, s_v, unused_run_step = state
        new_v = new_v / jnp.linalg.norm(new_v)

        s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
        s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
        # Keep iterating only while the eigenvalue estimate is still moving.
        keep_running = jnp.greater(jnp.abs(s_new - s), error_tolerance)
        return (i + 1, s_v, s_new, s_v, keep_running)

    # Figure out how to use step as seed for random.
    # Fixed seed: the starting vector is the same on every call.
    v_0 = (
        np.random.RandomState(1729)
        .uniform(-1.0, 1.0, matrix_size)
        .astype(matrix.dtype)
    )

    init_state = (0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True)
    _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
    v_out = v_out / jnp.linalg.norm(v_out)
    return v_out, s_out
def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.

    References:
      [Functions of Matrices, Theory and Computation,
       Nicholas J Higham, Pg 184, Eq 7.18](
       https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
      matrix: the symmetric PSD matrix whose power it to be computed
      p: exponent, for p a positive integer (Python int or integer array).
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
        It is scaled by the largest eigenvalue estimate (floored at 1e-6).
      error_tolerance: Error indicator, useful for early termination.
      precision: precision XLA related flag, the available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
      matrix^(-1/p), and the final error max|M - I| of the iteration (0 for
      1x1 inputs, which are solved in closed form).
    """
    assert matrix.shape[0] == matrix.shape[1]

    # We use float32 for the matrix inverse pth root.
    # Switch to f64 if you have hardware that supports it.
    matrix_size = matrix.shape[0]
    alpha = jnp.asarray(-1.0 / p, jnp.float32)
    identity = jnp.eye(matrix_size, dtype=jnp.float32)
    _, max_ev = power_iteration(
        matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
    )
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)

    def mat_power(mat_m, p):
        """Computes mat_m^p for any positive integer p.

        Uses exponentiation by squaring inside a `lax.while_loop`, so `p` may
        be a traced value. The previous lookup on round(log2(p)) was only
        correct for p in {1, 2, 4, 8}.

        Args:
          mat_m: a square matrix
          p: a positive integer

        Returns:
          mat_m^p
        """

        def _pow_condition(state):
            remaining, unused_acc, unused_base = state
            return remaining > 0

        def _pow_body(state):
            remaining, acc, base = state
            # Fold the current power of the base in when the low bit is set.
            acc = lax.cond(
                remaining % 2 == 1,
                lambda ops: jnp.matmul(ops[1], ops[0], precision=precision),
                lambda ops: ops[0],
                (acc, base),
            )
            base = jnp.matmul(base, base, precision=precision)
            return (remaining // 2, acc, base)

        init_pow_state = (
            jnp.asarray(p, jnp.int32),
            jnp.eye(mat_m.shape[0], dtype=mat_m.dtype),
            mat_m,
        )
        _, result, _ = lax.while_loop(_pow_condition, _pow_body, init_pow_state)
        return result

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)

    if matrix_size == 1:
        # Scalar case: closed form, no iteration needed.
        resultant_mat_h = (matrix + ridge_epsilon) ** alpha
        error = 0
    else:
        damped_matrix = matrix + ridge_epsilon * identity

        z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
        new_mat_m_0 = damped_matrix * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = (0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True)
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state
        )
        error = jnp.max(jnp.abs(mat_m - identity))
        # If the last step made the error grow too much, fall back to the
        # previous iterate instead of the diverging one.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
    return resultant_mat_h, error
def merge_small_dims(shape_to_merge, max_dim):
    """Merge small dimensions.

    If there are some small dimensions, we collapse them:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048]

    A non-empty shape made only of ones (e.g. [1, 1]) is merged to [1] rather
    than to an empty shape, so that it still has one dimension to
    precondition.

    Args:
      shape_to_merge: Shape to merge small dimensions.
      max_dim: Maximal dimension of output shape used in merging.

    Returns:
      Merged shape.
    """
    if shape_to_merge and all(d == 1 for d in shape_to_merge):
        return [1]

    resulting_shape = []
    product = 1
    for d in shape_to_merge:
        if product * d <= max_dim:
            product *= d
        else:
            # Flush the accumulated product (skipping a pure run of ones).
            if product > 1:
                resulting_shape.append(product)
            product = d
    if product > 1:
        resulting_shape.append(product)
    return resulting_shape
def pad_matrix(mat, max_size):
    """Pad a matrix to a max_size.

    Args:
      mat: a matrix to pad.
      max_size: matrix size requested.

    Returns:
      Given M returns [[M, 0], [0, I]]; `mat` itself is returned when it
      already has `max_size` rows.
    """
    size = mat.shape[0]
    assert size <= max_size
    if size == max_size:
        return mat
    pad_size = max_size - size
    upper_right = jnp.zeros([size, pad_size], dtype=mat.dtype)
    lower_left = jnp.zeros([pad_size, size], dtype=mat.dtype)
    lower_right = jnp.eye(pad_size, dtype=mat.dtype)
    top = jnp.concatenate([mat, upper_right], 1)
    bottom = jnp.concatenate([lower_left, lower_right], 1)
    return jnp.concatenate([top, bottom], 0)
def pad_vector(vec, max_size):
    """Pad a vector to a max_size.

    Args:
      vec: a vector to pad.
      max_size: matrix size requested.

    Returns:
      Given V returns [V, 0]; `vec` itself is returned when it already has
      `max_size` entries.
    """
    size = vec.shape[0]
    assert size <= max_size
    if size == max_size:
        return vec
    trailing_zeros = jnp.zeros([max_size - size], dtype=vec.dtype)
    return jnp.concatenate([vec, trailing_zeros], 0)
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA.

    Implements "if predicate: compute_fn(*args, **kwargs) else init_state" as
    a `while_loop` that runs its body at most once.

    Args:
      predicate: Boolean deciding whether `compute_fn` is evaluated.
      compute_fn: Function returning a sequence of results whose structure
        matches `init_state`.
      init_state: Sequence (list or tuple) of values returned when
        `predicate` is False.
      *args: Positional arguments forwarded to `compute_fn`.
      **kwargs: Keyword arguments forwarded to `compute_fn`.

    Returns:
      A tuple with the results of `compute_fn` if `predicate` is True,
      otherwise the values of `init_state`.
    """

    def _iter_body(unused_state):
        results = compute_fn(*args, **kwargs)
        # The leading False stops the loop after this single evaluation.
        return (False, *results)

    def _iter_condition(state):
        return state[0]

    results = jax.lax.while_loop(
        _iter_condition, _iter_body, (predicate, *init_state)
    )
    return tuple(results[1:])
class BlockPartitioner:
    """Partitions a tensor into smaller tensors."""

    def __init__(self, param, block_size):
        """Records the split metadata for tensors shaped like `param`.

        Args:
          param: Object exposing `.shape`; only the shape is used.
          block_size: Maximal block length along each axis. Axes are split
            only when `0 < block_size < dim`.
        """
        self._shape = param.shape
        self._splits = []
        split_sizes = []
        # We split params into smaller blocks. Here we store the metadata to make
        # that split.
        for i, d in enumerate(param.shape):
            if 0 < block_size < d:
                # d-1, otherwise split appends a 0-size array.
                nsplit = (d - 1) // block_size
                indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
                sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
                # The last block takes whatever remains of the axis.
                sizes[-1] = d - indices[-1]
                self._splits.append((i, indices))
                split_sizes.append(sizes)
            else:
                split_sizes.append(np.array([d], dtype=np.int32))
        self._num_splits = len(split_sizes)
        self._preconditioner_shapes = []
        for t in itertools.product(*split_sizes):
            self._preconditioner_shapes.extend([[d, d] for d in t])

    def shapes_for_preconditioners(self):
        """Returns a `[d, d]` shape per axis of every block, in block order."""
        return self._preconditioner_shapes

    def num_splits(self):
        """Returns the number of axes of the partitioned tensor."""
        return self._num_splits

    def partition(self, tensor):
        """Partition tensor into blocks."""

        assert tensor.shape == self._shape
        tensors = [tensor]
        for (i, indices) in self._splits:
            tensors_local = []
            for t in tensors:
                tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
            tensors = tensors_local
        return tensors

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""

        # Undo the splits in reverse order of how `partition` applied them.
        for (i, indices) in reversed(self._splits):
            n = len(indices) + 1
            partial_merged_tensors = []
            ind = 0
            while ind < len(partitions):
                partial_merged_tensors.append(
                    jnp.concatenate(partitions[ind : ind + n], axis=i)
                )
                ind += n
            partitions = partial_merged_tensors
        assert len(partitions) == 1
        return partitions[0]
class Preconditioner:
    """Compute statistics/shape from gradients for preconditioning."""

    def __init__(self, param, block_size, best_effort_shape_interpretation):
        """Sets up the reshape and block partitioning used for `param`.

        Args:
          param: Parameter array; its shape drives all later reshapes.
          block_size: Block size handed to `merge_small_dims` (as the maximal
            merged dimension) and to `BlockPartitioner`.
          best_effort_shape_interpretation: If True, small dimensions are
            merged before partitioning.
        """
        self._original_shape = param.shape
        self._transformed_shape = param.shape
        if best_effort_shape_interpretation:
            self._transformed_shape = merge_small_dims(self._original_shape, block_size)
        reshaped_param = jnp.reshape(param, self._transformed_shape)
        self._partitioner = BlockPartitioner(reshaped_param, block_size)

    def statistics_from_grad(self, grad):
        """Compute statistics from gradients.

        Args:
          grad: Gradient to compute statistics from.

        Returns:
          A list of gradient statistics for each partition.
        """
        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        stats = []
        for g in partitioned_grads:
            rank = len(g.shape)
            for i in range(rank):
                # Contract every axis except `i`, leaving a square statistic
                # for axis `i`.
                axes = list(range(i)) + list(range(i + 1, rank))
                stats.append(jnp.tensordot(g, g, axes=(axes, axes)))
        return stats

    def shapes_for_preconditioners(self):
        """Returns shape from statistics."""
        return self._partitioner.shapes_for_preconditioners()

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        return 2 * len(self._transformed_shape)

    def preconditioned_grad(self, grad, preconditioners):
        """Precondition the gradient.

        Args:
          grad: A gradient tensor to precondition.
          preconditioners: A list of preconditioners to apply.

        Returns:
          A preconditioned gradient.
        """

        reshaped_grad = jnp.reshape(grad, self._transformed_shape)
        partitioned_grads = self._partitioner.partition(reshaped_grad)
        preconditioned_partitioned_grads = []
        num_splits = self._partitioner.num_splits()
        for i, g in enumerate(partitioned_grads):
            # Each block owns `num_splits` consecutive preconditioners.
            preconditioners_for_grad = preconditioners[
                i * num_splits : (i + 1) * num_splits
            ]
            rank = len(g.shape)
            precond_g = g
            for j in range(rank):
                # Contracting axis 0 moves the result axis to the end, so
                # after `rank` steps the axes are back in their original order.
                precond_g = jnp.tensordot(
                    precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
                )
            preconditioned_partitioned_grads.append(precond_g)
        merged_grad = self._partitioner.merge_partitions(
            preconditioned_partitioned_grads
        )
        return jnp.reshape(merged_grad, self._original_shape)
def _convert_to_parameter_stats(global_stats, local_stat):
    """Creates parameter stats from sharded stats.

    Slices this parameter's rows out of the global statistics and
    preconditioner arrays and crops each one to its `size x size` top-left
    corner.

    Args:
      global_stats: Object with `statistics` and `preconditioners` arrays
        indexed along the leading axis.
      local_stat: Local stats giving `index_start` and `sizes` for this
        parameter.

    Returns:
      A ParameterStats for this parameter.
    """
    index_start = int(local_stat.index_start)
    index_end = len(local_stat.sizes) + index_start
    statistics = global_stats.statistics[index_start:index_end, :, :]
    preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
    new_statistics = [
        statistics[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    new_preconditioners = [
        preconditioners[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    return ParameterStats(
        local_stat.diagonal_statistics,
        new_statistics,
        new_preconditioners,
        local_stat.diagonal_momentum,
        local_stat.momentum,
        local_stat.training_metrics,
    )
def _convert_from_parameter_stats(parameter_stats, local_stats):
    """Creates sharded stats from parameter stats.

    The diagonal statistics, momenta and training metrics come from
    `parameter_stats`; `index_start` and `sizes` are carried over from
    `local_stats`.
    """
    return LocalShardedParameterStats(
        parameter_stats.diagonal_statistics,
        parameter_stats.diagonal_momentum,
        parameter_stats.momentum,
        parameter_stats.training_metrics,
        local_stats.index_start,
        local_stats.sizes,
    )
def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
    """Adds errors back into local statistics.

    Args:
      local_stats: Iterable of LocalShardedParameterStats.
      errors: Array of errors indexed like the global statistics.
      inverse_failure_threshold: Sentinel error value; entries equal to it
        (or not strictly positive) keep the previously stored error.

    Returns:
      A list of new LocalShardedParameterStats with updated training metrics.
    """
    new_local_stats = []
    for local_stat in local_stats:
        index_start = int(local_stat.index_start)
        index_end = len(local_stat.sizes) + index_start
        per_stat_error = errors[index_start:index_end]
        if local_stat.sizes:
            # Only overwrite with fresh, valid errors; otherwise keep the
            # value already recorded for that statistic.
            per_stat_error = jnp.where(
                jnp.logical_and(
                    per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
                ),
                per_stat_error,
                local_stat.training_metrics.inverse_pth_root_errors,
            )
        new_local_stats.append(
            LocalShardedParameterStats(
                local_stat.diagonal_statistics,
                local_stat.diagonal_momentum,
                local_stat.momentum,
                TrainingMetrics(per_stat_error),
                local_stat.index_start,
                local_stat.sizes,
            )
        )
    return new_local_stats
def batch(x, num_devices):
    """Batch `x` so that so that leading axis is num_devices.

    Args:
      x: List of equally shaped arrays; its length is expected to be a
        multiple of `num_devices`.
      num_devices: Number of groups along the leading axis.

    Returns:
      An array of shape `[num_devices, len(x) // num_devices, ...]`.
    """
    n = len(x)
    # Integer division: the per-device count never goes through a float.
    b = n // num_devices
    return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
def unbatch(batched_values):
    """Unbatch values across leading axis and return a list of elements.

    Only the two leading batch axes are removed. Size-1 dimensions that
    belong to the elements themselves (e.g. 1x1 matrices) are preserved.

    Args:
      batched_values: Array of shape `[b1, b2, ...]`, where b2 is the number
        of batches (preconditioner computations) per core.

    Returns:
      A list of `b1 * b2` arrays of shape `batched_values.shape[2:]`, ordered
      by the first axis and then the second.
    """
    b1, b2 = batched_values.shape[0], batched_values.shape[1]
    return [batched_values[i, j] for i in range(b1) for j in range(b2)]
def distributed_shampoo( | |
learning_rate, | |
block_size, | |
beta1=0.9, | |
beta2=0.999, | |
diagonal_epsilon=1e-10, | |
matrix_epsilon=1e-6, | |
weight_decay=0.0, | |
start_preconditioning_step=5, | |
preconditioning_compute_steps=1, | |
statistics_compute_steps=1, | |
best_effort_shape_interpretation=True, | |
graft_type=GraftingType.SGD, | |
nesterov=True, | |
exponent_override=0, | |
# Pass pmap 'batch axis name' in pmap mode. | |
batch_axis_name=None, | |
### Only set following 3 params in pjit/spmd mode. | |
### WARNING: Experimental | |
statistics_partition_spec=None, | |
preconditioner_partition_spec=None, | |
num_devices_for_pjit=None, | |
shard_optimizer_states=False, | |
### | |
### Experimental memory reduction mode | |
best_effort_memory_usage_reduction=False, | |
### | |
inverse_failure_threshold=0.1, | |
moving_average_for_momentum=False, | |
skip_preconditioning_dim_size_gt=4096, | |
clip_by_scaled_gradient_norm=None, | |
precision=lax.Precision.HIGHEST, | |
): | |
"""Distributed Shampoo optimizer. | |
Distributed Shampoo is a second-order preconditioned method (concretely, a | |
variant of full-matrix Adagrad), that provides significant convergence and | |
wall-clock time improvements compared to conventional first-order methods, | |
and that has been shown to scale to large state-of-the-art deep learning | |
models. | |
References: | |
Scalable Second Order Optimization for Deep Learning, | |
Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer | |
Preprint: https://arxiv.org/abs/2002.09018 | |
Args: | |
learning_rate: the step size used to update the parameters. | |
block_size: Block size for large layers (if > 0). Preconditioning compute | |
operation is cubic in the dimension of the tensor. Block size allows us to | |
chunk the layers into sub-layers of maximal dimension dictated by this | |
value. Use 128 as default (increase if you have compute budget). | |
beta1: momentum parameter. | |
beta2: second moment averaging parameter. | |
diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting | |
to AdaGrad is enabled). | |
matrix_epsilon: epsilon to add to statistics before computing inverse pth | |
root. If you are running in f32 precision for inverse pth root | |
(recommended today) this can go upto 1e-6. If you have latest hardware | |
with native f64 precision, set this upto 1e-12. | |
weight_decay: Weight decay for regularization. | |
start_preconditioning_step: When to start Shampoo update before which | |
diagonal update is used. This is because we dont have enough information | |
to do stable inverse. | |
preconditioning_compute_steps: How often to compute preconditioner. | |
Performance tuning params for controlling memory and compute requirements. | |
Ideally set this and statistics_compute_steps params to 1. | |
statistics_compute_steps: How often to compute statistics. | |
best_effort_shape_interpretation: If there are some small dimensions, | |
collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if | |
block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] | |
graft_type: Grafting is a technique to fix the layerwise scale of Shampoo | |
optimizer. This allows us to plugin the Shampoo optimizer into settings | |
where SGD/AdaGrad is already well tuned. | |
nesterov: Nesterov momentum. | |
exponent_override: Override the exponent used in matrix inverse. | |
batch_axis_name: labeled axis over pmap for data-parallel training the | |
optimizer used for. | |
statistics_partition_spec: PartitionSpec to be used in sharded mode. | |
preconditioner_partition_spec: PartitionSpec to be used in sharded mode. | |
num_devices_for_pjit: Number of devices to parallelize over when using pjit. | |
shard_optimizer_states: Shard optimizer states to save memory in model | |
parallel training. | |
best_effort_memory_usage_reduction: Best effort memory usage reduction. | |
diagonal_statistics -> jnp.bfloat16 | |
momentum buffers (2x) -> jnp.int8 | |
statistics, preconditioners -> jnp.int16 + diagonals | |
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we | |
determine that using this threshold. | |
moving_average_for_momentum: Whether to use moving average for momentum | |
instead of exponential moving average. | |
skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is | |
greater than this value. | |
clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful | |
when using RMSProp Grafting). | |
precision: precision XLA related flag, the available options are: a) | |
lax.Precision.DEFAULT (better step time, but not precise) b) | |
lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST | |
(best possible precision, slowest) | |
Returns: | |
a GradientTransformation. | |
""" | |
def _graft_type_has_diagonal_statistics(): | |
"""Returns True if using diagonal firt order method for grafting.""" | |
return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N | |
def _graft_type_has_diagonal_momentum_states(): | |
"""Returns False if using SQRT_N for grafting.""" | |
return graft_type != GraftingType.SQRT_N | |
def quantized_dtype_for_momentum_buffers(): | |
return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32 | |
# TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes. | |
def quantized_dtype_for_diagonal_statistics_buffers(): | |
return jnp.float32 | |
# Preconditioner and statistics are both stores as int16 in this mode. | |
# We take out the diagonal to make quantization easier. | |
def quantized_dtype_for_second_moment_statistics_buffers(): | |
return ( | |
jnp.int16 | |
if best_effort_memory_usage_reduction and batch_axis_name | |
else jnp.float32 | |
) | |
# Preconditioner and statistics are both stores as int16 in this mode. | |
# We take out the diagonal to make quantization easier. | |
def quantized_dtype_for_second_moment_preconditioner_buffers(): | |
return ( | |
jnp.int16 | |
if best_effort_memory_usage_reduction and batch_axis_name | |
else jnp.float32 | |
) | |
def _to_float(maybe_quantized): | |
if isinstance(maybe_quantized, QuantizedValue): | |
return maybe_quantized.to_float() | |
else: | |
return maybe_quantized | |
def _maybe_quantize_statistics(statistics_list): | |
return _maybe_quantize_matrices_with_dtype( | |
statistics_list, quantized_dtype_for_second_moment_statistics_buffers() | |
) | |
def _maybe_quantize_preconditioners(statistics_list): | |
return _maybe_quantize_matrices_with_dtype( | |
statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers() | |
) | |
def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype): | |
if quantized_dtype != jnp.float32: | |
return [ | |
QuantizedValue.from_float_value( | |
s, quantized_dtype, extract_diagonal=True | |
) | |
for s in statistics_list | |
] | |
else: | |
return statistics_list | |
def _maybe_dequantize_preconditioners(preconditioner_list): | |
return _maybe_dequantize_matrices_with_dtype( | |
preconditioner_list, | |
quantized_dtype_for_second_moment_preconditioner_buffers(), | |
) | |
def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype): | |
if quantized_dtype != jnp.float32: | |
return [s.to_float() for s in statistics_list] | |
else: | |
return statistics_list | |
def _quantize_diagonal_statistics(diagonal_statistics): | |
return QuantizedValue.from_float_value( | |
diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers() | |
) | |
def _quantize_momentum(momentum_statistics): | |
return QuantizedValue.from_float_value( | |
momentum_statistics, quantized_dtype_for_momentum_buffers() | |
) | |
def sharded_init_fn(params):
    """Returns optimizer state (for PJIT mode).

    Every preconditioned parameter contributes, per preconditioner shape, one
    `max_size x max_size` statistics matrix (`matrix_epsilon * I`) and one
    preconditioner matrix (`I`). The matrices of all parameters are stacked
    into global buffers that are padded with identity matrices until their
    count is a multiple of `num_devices_for_pjit`.

    Args:
      params: the parameters that should be updated.

    Returns:
      A `ShampooState` with a zero int32 `count` and `ShardedShampooStats`
      holding the stacked global statistics and the per-parameter local stats.
    """
    # `jax.tree_util` is used instead of the deprecated top-level aliases
    # (`jax.tree_flatten`, `jax.tree_unflatten`).
    params_flat, treedef = jax.tree_util.tree_flatten(params)

    # Find max size to pad to.
    max_size = 0
    for param in params_flat:
        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )
        if not _skip_preconditioning(param):
            shapes = preconditioner.shapes_for_preconditioners()
            sizes = [s[0] for s in shapes]
            max_size = max(max(sizes), max_size)

    padded_statistics = []
    padded_preconditioners = []
    local_stats_flat = []
    exponents = []
    for param in params_flat:
        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = []
        # Offset of this parameter's first matrix inside the global buffers.
        index_start = len(padded_statistics)
        if not _skip_preconditioning(param):
            sizes = [s[0] for s in shapes]
            padded_statistics.extend(
                [matrix_epsilon * jnp.eye(max_size) for _ in shapes]
            )
            padded_preconditioners.extend([jnp.eye(max_size) for _ in shapes])
            exponent = (
                preconditioner.exponent_for_preconditioner()
                if exponent_override == 0
                else exponent_override
            )
            exponents.extend([exponent] * len(shapes))

        diagonal_statistics = []
        if _graft_type_has_diagonal_statistics():
            diagonal_statistics = jnp.zeros_like(param)

        diagonal_momentum = _quantize_momentum([])
        momentum = _quantize_momentum(jnp.zeros_like(param))
        if _graft_type_has_diagonal_momentum_states():
            diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))

        local_stats_flat.append(
            LocalShardedParameterStats(
                _quantize_diagonal_statistics(diagonal_statistics),
                diagonal_momentum,
                momentum,
                init_training_metrics(len(sizes)),
                index_start,
                sizes,
            )
        )

    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)

    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    # NOTE(review): when no parameter is preconditioned the lists below are
    # empty and `jnp.stack` raises; confirm whether that case must be handled.
    to_pad = -len(padded_statistics) % num_devices_for_pjit
    padded_statistics.extend(
        [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
    )
    padded_preconditioners.extend(
        [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
    )
    # Padding entries use exponent 1.
    exponents.extend([1 for _ in range(to_pad)])

    global_stats = GlobalShardedParameterStats(
        jnp.stack(padded_statistics),
        jnp.stack(padded_preconditioners),
        jnp.stack(exponents),
    )
    return ShampooState(
        count=jnp.zeros([], jnp.int32),
        stats=ShardedShampooStats(global_stats, local_stats),
    )
def _max_statistics_size_from_params(params):
    """Returns the largest leading preconditioner dimension over `params`.

    Args:
      params: iterable of objects exposing `shape` and `dtype`.

    Returns:
      The maximum of `s[0]` over all preconditioner shapes of parameters that
      are not skipped by `_skip_preconditioning`; 0 if there are none.
    """
    largest = 0
    for param in params:
        # Only shape and dtype are read from `param`; a zero array with the
        # same shape and dtype is handed to `Preconditioner`.
        zeros_like_param = jnp.zeros(param.shape, dtype=param.dtype)
        preconditioner = Preconditioner(
            zeros_like_param, block_size, best_effort_shape_interpretation
        )
        if _skip_preconditioning(param):
            continue
        leading_dims = [
            shape[0] for shape in preconditioner.shapes_for_preconditioners()
        ]
        largest = max(max(leading_dims), largest)
    return largest
def _remove_leading_sharding_annotation(pspec): | |
"""Mapping from N-d to (N-1)-d, used for quantization, factoring etc.""" | |
# None and PSpec(None) are valid PSpecs. | |
if pspec and len(pspec) > 1: | |
return pjit.PartitionSpec(*pspec[1:]) | |
else: | |
return None | |
def sharded_init_partition_spec_fn(
    params, params_partition_spec, partition_spec_for_statistics
):
    """Returns a parallel state tree with PartitionSpec associated with state.

    The returned tree has the same structure as the optimizer state, but every
    array slot holds a partition spec instead of data.

    Args:
      params: A pytree with params.
      params_partition_spec: A pytree with PartitionSpec for params.
      partition_spec_for_statistics: PartitionSpec for the statistics.

    Returns:
      A `ShampooState` whose leaves are partition specs.
    """
    # Parallel lists of spec, and params. `jax.tree_util` is used instead of
    # the deprecated top-level aliases (`jax.tree_flatten`, ...).
    param_pspec_flat, _ = jax.tree_util.tree_flatten(
        params_partition_spec, is_leaf=lambda x: x is None
    )
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    assert param_pspec_flat
    assert params_flat

    local_stats_flat = []
    num_statistics = 0
    for param, param_pspec in zip(params_flat, param_pspec_flat):
        param_clone = jnp.zeros(param.shape, dtype=param.dtype)
        preconditioner = Preconditioner(
            param_clone, block_size, best_effort_shape_interpretation
        )
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = []
        index_start = num_statistics
        if not _skip_preconditioning(param):
            sizes = [s[0] for s in shapes]
            num_statistics += len(shapes)

        def _spec_holder(value_pspec, scale_pspec, dtype, shape=list(param.shape)):
            # Same positional layout as the `QuantizedValue`s in the real
            # state, with partition specs in place of arrays.
            return QuantizedValue(value_pspec, [], scale_pspec, dtype, False, shape)

        diagonal_statistics_pspec = []
        diagonal_statistics_scale_pspec = []
        if _graft_type_has_diagonal_statistics():
            # Identically shaped param.
            diagonal_statistics_pspec = param_pspec
            if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
                diagonal_statistics_scale_pspec = (
                    _remove_leading_sharding_annotation(param_pspec)
                )

        m1_pspec = []
        m1_scale_pspec = []
        if _graft_type_has_diagonal_momentum_states():
            m1_pspec = param_pspec
            if quantized_dtype_for_momentum_buffers() != jnp.float32:
                m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)

        m2_pspec = param_pspec
        m2_scale_pspec = []
        if quantized_dtype_for_momentum_buffers() != jnp.float32:
            m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)

        local_stats_flat.append(
            LocalShardedParameterStats(
                _spec_holder(
                    diagonal_statistics_pspec,
                    diagonal_statistics_scale_pspec,
                    quantized_dtype_for_diagonal_statistics_buffers(),
                ),
                _spec_holder(
                    m1_pspec, m1_scale_pspec, quantized_dtype_for_momentum_buffers()
                ),
                _spec_holder(
                    m2_pspec, m2_scale_pspec, quantized_dtype_for_momentum_buffers()
                ),
                init_training_metrics_pspec(len(sizes)),
                index_start,
                sizes,
            )
        )

    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
    # Statistics and preconditioners share one spec; the exponents get an
    # empty PartitionSpec.
    global_stats = GlobalShardedParameterStats(
        partition_spec_for_statistics,
        partition_spec_for_statistics,
        pjit.PartitionSpec(),
    )
    # Step is replicated across cores.
    count_pspec = pjit.PartitionSpec()
    return ShampooState(
        count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
    )
def sharded_init_shape_and_dtype_fn(params):
    """Returns a parallel state tree with shape, dtype associated with state.

    The returned tree has the same structure as the optimizer state, but every
    array slot holds a `[shape, dtype]` pair (or `[]` when the slot is unused).

    Args:
      params: A pytree with params.

    Returns:
      A `ShampooState` whose leaves are `[shape, dtype]` descriptions.
    """
    # `jax.tree_util` is used instead of the deprecated top-level aliases
    # (`jax.tree_flatten`, `jax.tree_unflatten`).
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    assert params_flat

    local_stats_flat = []
    num_statistics = 0
    for param in params_flat:
        param_clone = jnp.zeros(param.shape, dtype=param.dtype)
        preconditioner = Preconditioner(
            param_clone, block_size, best_effort_shape_interpretation
        )
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = []
        index_start = num_statistics
        if not _skip_preconditioning(param):
            sizes = [s[0] for s in shapes]
            num_statistics += len(shapes)

        full_shape = list(param.shape)
        # The scale entries below drop the leading dimension.
        scale_shape = list(param.shape)[1:]

        diagonal_statistics_shape_and_dtype = []
        diagonal_statistics_scale_shape_and_dtype = []
        if _graft_type_has_diagonal_statistics():
            diagonal_statistics_shape_and_dtype = [full_shape, param.dtype]
            diagonal_qdtype = quantized_dtype_for_diagonal_statistics_buffers()
            if diagonal_qdtype != jnp.float32:
                diagonal_statistics_shape_and_dtype = [full_shape, diagonal_qdtype]
                diagonal_statistics_scale_shape_and_dtype = [
                    scale_shape,
                    param.dtype,
                ]

        momentum_qdtype = quantized_dtype_for_momentum_buffers()
        m1_shape_and_dtype = []
        m1_scale_shape_and_dtype = []
        if _graft_type_has_diagonal_momentum_states():
            m1_shape_and_dtype = [full_shape, momentum_qdtype]
            if quantized_dtype_for_momentum_buffers() != jnp.float32:
                m1_scale_shape_and_dtype = [scale_shape, momentum_qdtype]

        m2_shape_and_dtype = [full_shape, param.dtype]
        m2_scale_shape_and_dtype = []
        if momentum_qdtype != jnp.float32:
            m2_shape_and_dtype = [full_shape, momentum_qdtype]
            m2_scale_shape_and_dtype = [scale_shape, momentum_qdtype]

        local_stats_flat.append(
            LocalShardedParameterStats(
                QuantizedValue(
                    diagonal_statistics_shape_and_dtype,
                    [],
                    diagonal_statistics_scale_shape_and_dtype,
                    quantized_dtype_for_diagonal_statistics_buffers(),
                    False,
                    list(param.shape),
                ),
                QuantizedValue(
                    m1_shape_and_dtype,
                    [],
                    m1_scale_shape_and_dtype,
                    quantized_dtype_for_momentum_buffers(),
                    False,
                    list(param.shape),
                ),
                QuantizedValue(
                    m2_shape_and_dtype,
                    [],
                    m2_scale_shape_and_dtype,
                    quantized_dtype_for_momentum_buffers(),
                    False,
                    list(param.shape),
                ),
                init_training_metrics_shapes(len(sizes)),
                index_start,
                sizes,
            )
        )

    local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
    max_statistics_size = _max_statistics_size_from_params(params_flat)
    # Round the number of matrices up to a multiple of the device count.
    to_pad = -num_statistics % num_devices_for_pjit
    num_statistics += to_pad
    statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
    global_stats = GlobalShardedParameterStats(
        [statistics_shape, jnp.float32],
        [statistics_shape, jnp.float32],
        [[num_statistics], jnp.int32],
    )
    # NOTE(review): `count` is described as float32 here although
    # `sharded_init_fn` creates it as int32 -- confirm which one is intended.
    return ShampooState(
        count=[[], jnp.float32],
        stats=ShardedShampooStats(global_stats, local_stats),
    )
def sharded_update_fn(grads, state, params):
    """Transform the input gradient and update all statistics in sharded mode.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    grads_flat = treedef.flatten_up_to(grads)

    global_stats = state.stats.global_stats
    local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
    stats_flat = [
        _convert_to_parameter_stats(global_stats, local_stat)
        for local_stat in local_stats_flat
    ]
    # `jax.tree_util.tree_map` accepts several trees and replaces
    # `jax.tree_multimap`, which no longer exists in current JAX.
    new_stats_flat = jax.tree_util.tree_map(
        lambda g, s, p: _compute_stats(g, s, p, state.count),
        grads_flat,
        stats_flat,
        params_flat,
    )

    outputs = jax.tree_util.tree_map(
        lambda g, s, p: _transform_grad(g, s, p, state.count),
        grads_flat,
        new_stats_flat,
        params_flat,
    )
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
    # Create new local_stats
    new_local_stats_flat = [
        _convert_from_parameter_stats(new_stat, local_stat)
        for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
    ]

    max_size = global_stats.statistics.shape[1]
    new_padded_statistics = []
    for param_stats in new_stats_flat:
        new_padded_statistics.extend(
            [pad_matrix(stat, max_size) for stat in param_stats.statistics]
        )

    # Create global stats
    # TODO(rohananil): Preconditioner is not updated every step, so cost of
    # stack/pad can be obviated away.
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    to_pad = -len(new_padded_statistics) % num_devices_for_pjit
    new_padded_statistics.extend(
        [
            jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
            for _ in range(to_pad)
        ]
    )
    new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
    new_stacked_padded_statistics = pjit.with_sharding_constraint(
        new_stacked_padded_statistics, statistics_partition_spec
    )

    def _internal_inverse_pth_root_all():
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            new_stacked_padded_statistics,
            global_stats.exponents,
            statistics_partition_spec,
        )
        return preconditioners, errors

    if preconditioning_compute_steps == 1:
        new_preconditioners, errors = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = new_stacked_padded_statistics
        n = new_stacked_padded_statistics.shape[0]
        errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
        init_state = [preconditioners_init, errors_init]
        perform_step = state.count % preconditioning_compute_steps == 0
        new_preconditioners, errors = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state
        )

    new_local_stats_flat = _add_error_into_local_stats(
        new_local_stats_flat, errors, inverse_failure_threshold
    )
    new_local_stats = jax.tree_util.tree_unflatten(treedef, new_local_stats_flat)

    # Where the error is NaN or at/above the failure threshold, keep the
    # previous preconditioner; the reshape lets the mask broadcast over the
    # stacked matrices.
    errors = errors.reshape((-1, 1, 1))
    predicate = jnp.logical_or(
        jnp.isnan(errors), errors >= inverse_failure_threshold
    ).astype(new_preconditioners.dtype)
    # TODO(rohananil): Check for numerical instabilities.
    new_conditional_preconditioners = (
        predicate * global_stats.preconditioners
        + (1.0 - predicate) * new_preconditioners
    )
    new_global_stats = GlobalShardedParameterStats(
        new_stacked_padded_statistics,
        new_conditional_preconditioners,
        global_stats.exponents,
    )
    new_shampoo_state = ShampooState(
        count=state.count + 1,
        stats=ShardedShampooStats(new_global_stats, new_local_stats),
    )
    return updates, new_shampoo_state
def init_fn(params):
    """Initialise the optimiser's state.

    Args:
      params: a pytree of parameters.

    Returns:
      A `ShampooState` with a zero int32 `count` and one `ParameterStats` per
      parameter leaf.
    """

    def _init(param):
        """Builds the initial `ParameterStats` for a single parameter."""
        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )
        statistics = []
        preconditioners = []
        if not _skip_preconditioning(param):
            shapes = preconditioner.shapes_for_preconditioners()
            # Statistics start at `matrix_epsilon * I`, preconditioners at `I`.
            statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
            preconditioners = [jnp.eye(s[0]) for s in shapes]

        diagonal_statistics = []
        if _graft_type_has_diagonal_statistics():
            diagonal_statistics = jnp.zeros_like(param)

        diagonal_momentum = _quantize_momentum([])
        momentum = _quantize_momentum(jnp.zeros_like(param))
        if _graft_type_has_diagonal_momentum_states():
            diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))

        return ParameterStats(
            _quantize_diagonal_statistics(diagonal_statistics),
            _maybe_quantize_statistics(statistics),
            _maybe_quantize_preconditioners(preconditioners),
            diagonal_momentum,
            momentum,
            init_training_metrics(len(statistics)),
        )

    # `jax.tree_util.tree_map` is used instead of the deprecated top-level
    # `jax.tree_map` alias.
    return ShampooState(
        count=jnp.zeros([], jnp.int32),
        stats=jax.tree_util.tree_map(_init, params),
    )
def _skip_preconditioning(param):
    """Returns True when `param` should not be preconditioned.

    A parameter is skipped when it has no dimensions or when any of its
    dimensions is larger than `skip_preconditioning_dim_size_gt`.
    """
    if len(param.shape) < 1:
        return True
    for dim in param.shape:
        if dim > skip_preconditioning_dim_size_gt:
            return True
    return False
def _compute_stats(grad, state, param, step):
    """Compute per-parameter statistics.

    Args:
      grad: gradient for `param`.
      state: the `ParameterStats` of `param`.
      param: the parameter itself.
      step: current step, used to decide whether statistics are refreshed.

    Returns:
      A `ParameterStats` equal to `state` except for its statistics.
    """
    preconditioner = Preconditioner(
        param, block_size, best_effort_shape_interpretation
    )
    # Placeholder used when preconditioning is skipped for this parameter.
    updated_statistics = [[]] * len(state.statistics)
    decay_weight = beta2
    grad_weight = beta2 if beta2 == 1.0 else (1.0 - beta2)

    if not _skip_preconditioning(param):

        def compute_updated_statistics():
            fresh_stats = preconditioner.statistics_from_grad(grad)
            accumulated = [
                decay_weight * _to_float(previous) + grad_weight * fresh
                for fresh, previous in zip(fresh_stats, state.statistics)
            ]
            return _maybe_quantize_statistics(accumulated)

        if statistics_compute_steps > 1:
            # Refresh only every `statistics_compute_steps` steps; otherwise
            # the existing statistics are passed through.
            should_update = step % statistics_compute_steps == 0
            fallback = state.statistics
            updated_statistics = list(
                efficient_cond(should_update, compute_updated_statistics, fallback)
            )
        else:
            updated_statistics = compute_updated_statistics()

    return ParameterStats(
        state.diagonal_statistics,
        updated_statistics,
        state.preconditioners,
        state.diagonal_momentum,
        state.momentum,
        state.training_metrics,
    )
def _matrix_inverse_pth_root_vmap(xs, ps):
    """Applies `matrix_inverse_pth_root` over the leading axis of `xs`, `ps`.

    Args:
      xs: batched first argument of `matrix_inverse_pth_root`.
      ps: batched second argument of `matrix_inverse_pth_root`.

    Returns:
      The vmapped result of `matrix_inverse_pth_root`, called with
      `ridge_epsilon=matrix_epsilon` and `precision=precision`.
    """

    def _single_inverse_pth_root(x, p):
        return matrix_inverse_pth_root(
            x, p, ridge_epsilon=matrix_epsilon, precision=precision
        )

    batched_inverse_pth_root = jax.vmap(_single_inverse_pth_root)
    return batched_inverse_pth_root(xs, ps)
def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
    """Batched inverse pth root on quantized statistics.

    Each batch entry is dequantized, passed to `matrix_inverse_pth_root`, and
    the resulting preconditioner is quantized again with the input dtype.

    Args:
      qxs: batched quantized matrices.
      qds: batched diagonals belonging to `qxs`.
      qbs: batched bucket sizes belonging to `qxs`.
      ps: batched exponents.

    Returns:
      A tuple (quantized, diagonal, bucket_size, error), each batched.
    """

    def _inverse_pth_root_one(qx, qd, qb, p):
        dequantized = QuantizedValue(
            qx, qd, qb, qx.dtype, True, list(qx.shape)
        ).to_float()
        preconditioner, error = matrix_inverse_pth_root(
            dequantized, p, ridge_epsilon=matrix_epsilon, precision=precision
        )
        requantized = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
        return (
            requantized.quantized,
            requantized.diagonal,
            requantized.bucket_size,
            error,
        )

    return jax.vmap(_inverse_pth_root_one)(qxs, qds, qbs, ps)
def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
    """Computes inverse pth roots of stacked statistics under pjit.

    Args:
      xs: stacked statistics matrices.
      ps: stacked exponents.
      statistics_partition_spec: sharding constraint applied to the returned
        preconditioners.

    Returns:
      A tuple (preconditioners, errors).
    """
    # Partition the concatenated statistics matrix across all cores.
    shard_spec = preconditioner_partition_spec
    exponent_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
    sharded_xs = pjit.with_sharding_constraint(xs, shard_spec)
    sharded_ps = pjit.with_sharding_constraint(ps, exponent_spec)

    # Run matrix inverse pth root on each shard.
    sharded_preconditioners, sharded_errors = _matrix_inverse_pth_root_vmap(
        sharded_xs, sharded_ps
    )

    # Reshard output to have the same PSpec as input. This is required to avoid
    # vmap seeing the full set of statistics.
    sharded_preconditioners = pjit.with_sharding_constraint(
        sharded_preconditioners, shard_spec
    )

    # Recombine the outputs at each core.
    preconditioners = pjit.with_sharding_constraint(
        sharded_preconditioners, statistics_partition_spec
    )
    errors = pjit.with_sharding_constraint(sharded_errors, pjit.PartitionSpec())
    return preconditioners, errors
def _pmap_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Computes preconditioners for given statistics in states in PMAP mode.

    Args:
      states: A list of optimizer states.
      step: Current step number
      statistics: A list of statistics for all variables (for every dim)
      num_statistics_per_state: Number of statistics per state to reconstruct
        output states.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. Not modified.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioner.

    Returns:
      New optimizer states after computing the preconditioner.
    """
    num_devices = lax.psum(1, batch_axis_name)
    num_statistics = len(statistics)
    # Pad statistics and exponents to next multiple of num_devices.
    packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
    to_pad = -num_statistics % num_devices
    packed_statistics.extend(
        [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
    )
    # Pad a copy so the caller's `exponents` list is left untouched.
    padded_exponents = list(exponents)
    padded_exponents.extend([1 for _ in range(to_pad)])

    if not packed_statistics:
        return states

    all_statistics = batch(packed_statistics, num_devices)
    all_exponents = batch(padded_exponents, num_devices)

    def _internal_inverse_pth_root_all():
        # Each replica handles its own slice, then results are gathered.
        current_replica = lax.axis_index(batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root_vmap(
            all_statistics[current_replica], all_exponents[current_replica]
        )
        preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
        errors = jax.lax.all_gather(errors, batch_axis_name)
        preconditioners_flat = unbatch(preconditioners)
        errors_flat = unbatch(errors)
        return preconditioners_flat, errors_flat

    if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = packed_statistics
        errors_init = [inverse_failure_threshold] * len(packed_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state
        )

    def _skip(error):
        condition = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        # Keep the previous preconditioner when the error is NaN or at/above
        # the failure threshold.
        return lax.cond(
            _skip(error), lambda _: old_p, lambda _: new_p, operand=None
        )

    new_preconditioners_flat = []
    new_errors_flat = []
    # `zip` stops at the shortest input, so padding entries are dropped here.
    for p, shape, prev_p, error in zip(
        preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
    ):
        # Crop the padded result back to the statistic's original shape.
        new_preconditioners_flat.append(
            _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
        )
        new_errors_flat.append(error)

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics
    assert len(new_errors_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    idx = 0
    for state_num_statistics, state in zip(num_statistics_per_state, states):
        if state_num_statistics == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
            continue
        stop = idx + state_num_statistics
        preconditioners_for_state = new_preconditioners_flat[idx:stop]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)

        errors_for_state = jnp.stack(new_errors_flat[idx:stop])
        assert len(state.statistics) == len(errors_for_state)
        errors_for_states.append(errors_for_state)

        idx = stop

    new_states = []
    for state, new_preconditioners, new_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Report a new error only when it is positive and differs from
            # the failure threshold; otherwise keep the stored one.
            new_errors = jnp.where(
                jnp.logical_and(
                    new_errors > 0.0, new_errors != inverse_failure_threshold
                ),
                new_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_training_metrics = TrainingMetrics(new_errors)
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                new_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                new_training_metrics,
            )
        )

    return new_states
def _pmap_quantized_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Computes preconditioners for given statistics in states in PMAP mode.

    For quantization, each statistic is represented by three values:
    quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
    without ever recreating the original matrix in f32.

    Args:
      states: A list of optimizer states.
      step: Current step number
      statistics: A list of statistics for all variables (for every dim)
      num_statistics_per_state: Number of statistics per state to reconstruct
        output states.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. Not modified.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioner.

    Returns:
      New optimizer states after computing the preconditioner.
    """
    num_devices = lax.psum(1, batch_axis_name)
    num_statistics = len(statistics)
    quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
    # Complexity here is around: shapes needing be statically shaped,
    # our custom quantization type requires a different type of packing.

    # Parallel tensors:
    # quantized [dxd]
    # diagonals [d] f32
    # bucket_sizes [d] f32
    packed_quantized_statistics = [
        pad_matrix(stat.quantized, max_size) for stat in statistics
    ]
    packed_quantized_diagonals = [
        pad_vector(stat.diagonal, max_size) for stat in statistics
    ]
    packed_quantized_bucket_sizes = [
        pad_vector(stat.bucket_size, max_size) for stat in statistics
    ]

    # Pad with a quantized identity up to a multiple of num_devices.
    to_pad = -num_statistics % num_devices
    padded_eye = jnp.eye(max_size, dtype=jnp.float32)
    quantized_eye = QuantizedValue.from_float_value(
        padded_eye, quantized_dtype, True
    )
    packed_quantized_statistics.extend(
        [quantized_eye.quantized for _ in range(to_pad)]
    )
    packed_quantized_diagonals.extend(
        [quantized_eye.diagonal for _ in range(to_pad)]
    )
    packed_quantized_bucket_sizes.extend(
        [quantized_eye.bucket_size for _ in range(to_pad)]
    )
    # Pad a copy so the caller's `exponents` list is left untouched.
    padded_exponents = list(exponents)
    padded_exponents.extend([1 for _ in range(to_pad)])

    if not packed_quantized_statistics:
        return states

    all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
    all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
    all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
    all_exponents = batch(padded_exponents, num_devices)

    def _internal_inverse_pth_root_all():
        # Each replica handles its own slice, then results are gathered.
        current_replica = lax.axis_index(batch_axis_name)
        (
            quantized_preconditioners,
            quantized_diagonals,
            quantized_bucket_sizes,
            errors,
        ) = _quantized_matrix_inverse_pth_root_vmap(
            all_quantized_statistics[current_replica],
            all_quantized_diagonals[current_replica],
            all_quantized_bucket_sizes[current_replica],
            all_exponents[current_replica],
        )
        quantized_preconditioners = jax.lax.all_gather(
            quantized_preconditioners, batch_axis_name
        )
        quantized_diagonals = jax.lax.all_gather(
            quantized_diagonals, batch_axis_name
        )
        quantized_bucket_sizes = jax.lax.all_gather(
            quantized_bucket_sizes, batch_axis_name
        )
        errors = jax.lax.all_gather(errors, batch_axis_name)
        quantized_preconditioners_flat = unbatch(quantized_preconditioners)
        quantized_diagonals_flat = unbatch(quantized_diagonals)
        quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
        errors_flat = unbatch(errors)
        return (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        )

    if preconditioning_compute_steps == 1:
        (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        ) = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        quantized_preconditioners_init = packed_quantized_statistics
        quantized_diagonals_init = packed_quantized_diagonals
        quantized_bucket_sizes_init = packed_quantized_bucket_sizes
        errors_init = [inverse_failure_threshold] * len(
            quantized_preconditioners_init
        )
        init_state = [
            quantized_preconditioners_init,
            quantized_diagonals_init,
            quantized_bucket_sizes_init,
            errors_init,
        ]
        perform_step = step % preconditioning_compute_steps == 0
        (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
        condition = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        # Keep the previous value when the error is NaN or at/above the
        # failure threshold.
        return lax.cond(
            _skip(error), lambda _: old_p, lambda _: new_p, operand=None
        )

    new_quantized_preconditioners_flat = []
    new_quantized_diagonals_flat = []
    new_quantized_bucket_sizes_flat = []
    new_errors_flat = []
    # `zip` stops at the shortest input, so padding entries are dropped here.
    for p, d, b, shape, prev_p, error in zip(
        quantized_preconditioners_flat,
        quantized_diagonals_flat,
        quantized_bucket_sizes_flat,
        original_shapes,
        prev_preconditioners,
        errors_flat,
    ):
        # Crop the padded results back to the statistic's original shape.
        new_quantized_preconditioners_flat.append(
            _select_preconditioner(
                error, p[: shape[0], : shape[1]], prev_p.quantized
            )
        )
        new_quantized_diagonals_flat.append(
            _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
        )
        new_quantized_bucket_sizes_flat.append(
            _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
        )
        new_errors_flat.append(error)

    assert len(states) == len(num_statistics_per_state)
    assert len(new_quantized_preconditioners_flat) == num_statistics
    assert len(new_quantized_diagonals_flat) == num_statistics
    assert len(new_quantized_bucket_sizes_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    idx = 0
    for state_num_statistics, state in zip(num_statistics_per_state, states):
        if state_num_statistics == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
            continue
        stop = idx + state_num_statistics
        quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
            idx:stop
        ]
        quantized_diagonals_for_state = new_quantized_diagonals_flat[idx:stop]
        quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[idx:stop]
        errors_for_state = jnp.stack(new_errors_flat[idx:stop])

        assert len(state.statistics) == len(quantized_preconditioners_for_state)
        assert len(state.statistics) == len(quantized_diagonals_for_state)
        assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
        assert len(state.statistics) == len(errors_for_state)

        # Reassemble the three parallel tensors into `QuantizedValue`s.
        quantized_preconditioners = [
            QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
            for qv, qd, qb in zip(
                quantized_preconditioners_for_state,
                quantized_diagonals_for_state,
                quantized_bucket_sizes_for_state,
            )
        ]
        preconditioners_for_states.append(quantized_preconditioners)
        errors_for_states.append(errors_for_state)
        idx = stop

    new_states = []
    for state, new_preconditioners, new_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Report a new error only when it is positive and differs from
            # the failure threshold; otherwise keep the stored one.
            new_errors = jnp.where(
                jnp.logical_and(
                    new_errors > 0.0, new_errors != inverse_failure_threshold
                ),
                new_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_training_metrics = TrainingMetrics(new_errors)
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                new_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                new_training_metrics,
            )
        )

    return new_states
def _pjit_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Computes preconditioners for given statistics in states in PJIT mode.

    Args:
      states: A list of optimizer states.
      step: Current step number
      statistics: A list of statistics for all variables (for every dim)
      num_statistics_per_state: Number of statistics per state to reconstruct
        output states.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. The list is
        read only; it is not modified.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioner.

    Returns:
      New optimizer states after computing the preconditioner.
    """
    num_statistics = len(statistics)
    # Number of dummy entries needed to make the batch size a multiple of
    # num_devices_for_pjit.
    to_pad = -num_statistics % num_devices_for_pjit
    padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
    padded_statistics.extend(
        [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
    )
    # Build a padded copy instead of extending `exponents` in place, so the
    # caller's list is left untouched.
    padded_exponents = list(exponents) + [1] * to_pad
    all_statistics = jnp.stack(padded_statistics)
    all_exponents = jnp.stack(padded_exponents)

    def _internal_inverse_pth_root_all():
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            all_statistics, all_exponents
        )
        b1 = preconditioners.shape[0]

        def split(batched_values):
            # Unbatch along the leading axis into a list of b1 entries.
            return [
                jnp.squeeze(v)
                for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
            ]

        return split(preconditioners), split(errors)

    if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = padded_statistics
        errors_init = [inverse_failure_threshold] * len(padded_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state
        )

    def _skip(error):
        # Non-zero when the error is NaN or at/above the failure threshold.
        condition = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        # Fall back to the previous preconditioner when the new one is skipped.
        return lax.cond(
            _skip(error), lambda _: old_p, lambda _: new_p, operand=None
        )

    new_preconditioners_flat = []
    new_errors_flat = []
    # zip() stops at the shortest input, so the padding entries appended above
    # are dropped here.
    for p, shape, prev_p, error in zip(
        preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
    ):
        new_preconditioners_flat.append(
            _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
        )
        new_errors_flat.append(error)

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    idx = 0
    # `count` is deliberately a distinct name so the total `num_statistics`
    # computed above is not rebound by this loop.
    for count, state in zip(num_statistics_per_state, states):
        if count == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
        else:
            preconditioners_for_state = new_preconditioners_flat[idx : idx + count]
            assert len(state.statistics) == len(preconditioners_for_state)
            preconditioners_for_states.append(preconditioners_for_state)

            errors_for_state = jnp.stack(new_errors_flat[idx : idx + count])
            assert len(state.statistics) == len(errors_for_state)
            errors_for_states.append(errors_for_state)

            idx += count

    new_states = []
    for state, new_preconditioners, new_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Keep the previously recorded error wherever the new error is not
            # positive or equals the failure threshold.
            new_errors = jnp.where(
                jnp.logical_and(
                    new_errors > 0.0, new_errors != inverse_failure_threshold
                ),
                new_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_training_metrics = TrainingMetrics(new_errors)
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                new_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                new_training_metrics,
            )
        )

    return new_states
def _compute_preconditioners(states, params, step):
    """Computes preconditioners for given statistics in states.

    Flattens the per-parameter statistics, exponents, shapes and previous
    preconditioners into parallel lists and hands them to the pmap, quantized
    pmap, or pjit implementation.

    Args:
      states: A list of optimizer states.
      params: A list of params.
      step: Current step number

    Returns:
      New optimizer states after computing the preconditioner.
    """
    flat_statistics = []
    counts_per_state = []
    flat_shapes = []
    flat_exponents = []
    flat_prev_preconditioners = []
    largest_dim = 0

    for state, param in zip(states, params):
        count = len(state.statistics)
        counts_per_state.append(count)
        if count == 0:
            # Nothing to precondition for this parameter.
            continue

        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )
        for statistic in state.statistics:
            # A non-zero exponent_override replaces the per-parameter exponent.
            if exponent_override == 0:
                exponent = preconditioner.exponent_for_preconditioner()
            else:
                exponent = exponent_override
            flat_exponents.append(exponent)
            flat_shapes.append(statistic.shape)
            largest_dim = max(largest_dim, statistic.shape[0])

        flat_statistics.extend(state.statistics)
        flat_prev_preconditioners.extend(state.preconditioners)

    # All three implementations take the same positional arguments.
    shared_args = (
        states,
        step,
        flat_statistics,
        counts_per_state,
        flat_shapes,
        flat_exponents,
        largest_dim,
        flat_prev_preconditioners,
    )

    if not batch_axis_name:
        return _pjit_compute_preconditioners(*shared_args)

    # The pmap paths (including the quantized one) are only reachable when
    # batch_axis_name is set; the buffer dtype picks between them.
    quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
    if quantized_dtype == jnp.float32:
        return _pmap_compute_preconditioners(*shared_args)
    return _pmap_quantized_compute_preconditioners(*shared_args)
def _transform_grad(grad, state, param, step):
    """Transform per-parameter gradients.

    Computes the grafting update selected by `graft_type` and the Shampoo
    preconditioned update rescaled to the grafting update's norm, applies
    weight decay and momentum to both, and blends them depending on whether
    `step` has reached `start_preconditioning_step`.

    Args:
      grad: Gradient for a single parameter.
      state: Optimizer state (`ParameterStats`) for the parameter.
      param: The parameter the gradient belongs to.
      step: Current step number.

    Returns:
      A tuple `(transformed_update, param_stats)`: the update scaled by the
      negative learning rate, and the `ParameterStats` with refreshed diagonal
      statistics and momentum (statistics, preconditioners and training
      metrics are carried over unchanged).
    """
    preconditioner = Preconditioner(
        param, block_size, best_effort_shape_interpretation
    )
    sgd_update = grad
    prev_diagonal_statistics = state.diagonal_statistics.to_float()
    new_diagonal_statistics = prev_diagonal_statistics

    if (
        graft_type == GraftingType.ADAGRAD
        or graft_type == GraftingType.ADAGRAD_NORMALIZED
    ):
        scaled_grad = grad
        if graft_type == GraftingType.ADAGRAD_NORMALIZED:
            # The epsilon keeps an all-zero gradient from producing NaNs that
            # would then be accumulated into the diagonal statistics.
            scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)

        new_diagonal_statistics = prev_diagonal_statistics + jnp.square(scaled_grad)
        adagrad_update = scaled_grad / (
            jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
        )
        grafting_update = adagrad_update
    elif (
        graft_type == GraftingType.RMSPROP
        or graft_type == GraftingType.RMSPROP_NORMALIZED
    ):
        scaled_grad = grad
        if graft_type == GraftingType.RMSPROP_NORMALIZED:
            # Same zero-norm guard as the normalized AdaGrad branch.
            scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)

        # With beta2 == 1.0 both weights are 1.0, so squared gradients are
        # accumulated without decay.
        w1 = beta2
        w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)

        new_diagonal_statistics = (
            w1 * prev_diagonal_statistics + w2 * jnp.square(scaled_grad)
        )
        rmsprop_update = scaled_grad / (
            jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
        )

        if clip_by_scaled_gradient_norm:
            # Norm of the update divided by sqrt(number of elements).
            scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
                jnp.sqrt(float(rmsprop_update.size))
            )
            clipping_denom = jnp.maximum(
                1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
            )
            rmsprop_update /= clipping_denom

        grafting_update = rmsprop_update
    elif graft_type == GraftingType.SGD:
        grafting_update = sgd_update
    else:
        # Any other graft type: use the sign of the gradient.
        grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)

    precond_grad = grad
    if not _skip_preconditioning(param):
        precond_grad = preconditioner.preconditioned_grad(
            precond_grad, _maybe_dequantize_preconditioners(state.preconditioners)
        )
    else:
        precond_grad = grafting_update

    # Rescale the preconditioned gradient to the norm of the grafting update.
    grafting_update_norm = jnp.linalg.norm(grafting_update)
    precond_grad_norm = jnp.linalg.norm(precond_grad)
    multiplier = grafting_update_norm / (precond_grad_norm + 1e-16)
    shampoo_update = precond_grad * multiplier

    shampoo_update_with_wd = shampoo_update
    grafting_update_with_wd = grafting_update
    if weight_decay != 0:
        shampoo_update_with_wd = shampoo_update + weight_decay * param
        grafting_update_with_wd = grafting_update + weight_decay * param

    w = (1.0 - beta1) if moving_average_for_momentum else 1.0

    prev_momentum = state.momentum.to_float()
    shampoo_update_with_wd_momentum = (
        prev_momentum * beta1 + w * shampoo_update_with_wd
    )

    has_diagonal_momentum = _graft_type_has_diagonal_momentum_states()
    if has_diagonal_momentum:
        grafting_update_with_wd_momentum = (
            state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
        )
    else:
        # Share the momentum buffer
        grafting_update_with_wd_momentum = (
            prev_momentum * beta1 + w * grafting_update_with_wd
        )

    # 1.0 once preconditioning has started, 0.0 before; used to select between
    # the Shampoo and grafting updates.
    run_shampoo = (step >= start_preconditioning_step).astype(
        grafting_update_with_wd_momentum.dtype
    )

    momentum_update = (
        run_shampoo * shampoo_update_with_wd_momentum
        + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
    )

    wd_update = (
        run_shampoo * shampoo_update_with_wd
        + (1.0 - run_shampoo) * grafting_update_with_wd
    )

    nesterov_momentum_update = momentum_update
    if nesterov:
        nesterov_momentum_update = w * wd_update + beta1 * momentum_update

    # learning_rate may be a constant or a schedule taking the step.
    lr = learning_rate
    if callable(learning_rate):
        lr = learning_rate(step)

    transformed_update = -1.0 * lr * nesterov_momentum_update

    new_diagonal_momentum = grafting_update_with_wd_momentum
    new_momentum = shampoo_update_with_wd_momentum
    if not has_diagonal_momentum:
        # Only one momentum buffer is kept: store the blended update in it.
        new_diagonal_momentum = []
        new_momentum = momentum_update

    param_stats = ParameterStats(
        _quantize_diagonal_statistics(new_diagonal_statistics),
        state.statistics,
        state.preconditioners,
        _quantize_momentum(new_diagonal_momentum),
        _quantize_momentum(new_momentum),
        state.training_metrics,
    )

    return transformed_update, param_stats
def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    # Use jax.tree_util rather than the top-level jax.tree_* aliases;
    # tree_map accepts several trees, which is what tree_multimap did.
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)

    new_stats_flat = jax.tree_util.tree_map(
        lambda g, s, p: _compute_stats(g, s, p, state.count),
        grads_flat,
        stats_flat,
        params_flat,
    )
    new_stats_flat = _compute_preconditioners(
        new_stats_flat, params_flat, state.count
    )

    outputs = jax.tree_util.tree_map(
        lambda g, s, p: _transform_grad(g, s, p, state.count),
        grads_flat,
        new_stats_flat,
        params_flat,
    )
    # Each output is an (update, stats) pair; transpose into two sequences.
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)

    new_state = ShampooState(count=state.count + 1, stats=new_stats)
    return updates, new_state
if shard_optimizer_states: | |
# Hijacks the init_fn signature so we can return an OptState with | |
# appropriate init_fns. | |
def _init_fns(unused_params): | |
return InitFnState( | |
init_fn=sharded_init_fn, | |
pspec_fn=sharded_init_partition_spec_fn, | |
shape_and_dtype_fn=sharded_init_shape_and_dtype_fn, | |
) | |
return optax.GradientTransformation(_init_fns, sharded_update_fn) | |
else: | |
return optax.GradientTransformation(init_fn, update_fn) | |