|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Distributed Shampoo Implementation.""" |
|
|
|
import enum |
|
import functools |
|
import itertools |
|
from typing import Any, List, NamedTuple |
|
|
|
import chex |
|
from flax import struct |
|
import jax |
|
from jax import lax |
|
import jax.experimental.pjit as pjit |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import optax |
|
|
|
PartitionSpec = pjit.PartitionSpec |
|
|
|
|
|
|
|
@struct.dataclass
class QuantizedValue:
    """A float array stored in a (possibly) lower-precision representation.

    `quantized`, `diagonal` and `bucket_size` are pytree leaves; the remaining
    fields are static metadata (`pytree_node=False`).
    """
    # Stored values: float32/bfloat16 data as-is, or rounded int8/int16 codes.
    quantized: chex.Array
    # Diagonal removed before quantization ([] unless extract_diagonal is set).
    diagonal: chex.Array
    # Scale computed over axis 0 ([] for float32/bfloat16 storage).
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(pytree_node=False)
    extract_diagonal: bool = struct.field(pytree_node=False)
    shape: Any = struct.field(pytree_node=False)

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Builds a QuantizedValue from a float array (or an empty list).

        Args:
          fvalue: array to store, or `[]` as a placeholder for "no value".
          quantized_dtype: storage dtype (float32, bfloat16, int8 or int16).
          extract_diagonal: keep the diagonal of a 2D input in full precision.

        Returns:
          A QuantizedValue; its `shape` is the shape of the stored array.
        """
        placeholder = isinstance(fvalue, list) and not fvalue
        if placeholder:
            # Nothing to quantize: propagate empty lists for every array field.
            return QuantizedValue([], [], [], quantized_dtype, extract_diagonal,
                                  [])
        stored, diag, scale = QuantizedValue.quantize(
            fvalue, quantized_dtype, extract_diagonal)
        return QuantizedValue(stored, diag, scale, quantized_dtype,
                              extract_diagonal, list(stored.shape))

    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns `(quantized, diagonal, bucket_size)` for `fvalue`.

        float32 is passed through and bfloat16 is a plain cast; both return
        empty lists for the diagonal and the bucket size. int8/int16 use a
        symmetric scale derived from the max absolute value along axis 0.

        Raises:
          ValueError: for an unsupported dtype, for `extract_diagonal` with a
            non-2D input, or for a 0-dimensional input.
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        if quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        # Largest positive code of the signed integer type.
        if quantized_dtype == jnp.int8:
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')

        if extract_diagonal:
            if fvalue.ndim != 2:
                raise ValueError(
                    f'Input array {fvalue} must be 2D to work with extract_diagonal.')
            # Keep the diagonal aside and quantize only the off-diagonal part.
            diagonal_fvalue = jnp.diag(fvalue)
            fvalue = fvalue - jnp.diag(diagonal_fvalue)
        else:
            diagonal_fvalue = []

        if fvalue.ndim < 1:
            raise ValueError(
                f'Input array {fvalue} must have a strictly positive number of '
                'dimensions.')

        peak = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = peak / num_buckets
        scale = bucket_size[jnp.newaxis, Ellipsis]
        # Replace zero scales by one so the division below stays finite.
        safe_scale = jnp.where(scale > 0.0, scale, jnp.ones_like(scale))
        codes = jnp.round(fvalue / safe_scale).astype(quantized_dtype)
        return codes, diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the dequantized float value (or `[]` for a placeholder)."""
        is_placeholder = isinstance(self.quantized, list) and not self.quantized
        if is_placeholder or self.quantized_dtype == jnp.float32:
            return self.quantized
        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        scale = self.bucket_size[jnp.newaxis, Ellipsis]
        restored = self.quantized.astype(float_dtype) * scale
        if self.extract_diagonal:
            # Add back the diagonal that was kept in full precision.
            restored = restored + jnp.diag(self.diagonal)
        return restored
|
|
|
|
|
|
|
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""
    # Diagonal accumulator; initialised from an empty list when grafting is SGD.
    diagonal_statistics: QuantizedValue
    # One statistics matrix per preconditioner of this parameter.
    statistics: List[Any]
    # One preconditioner matrix per statistics matrix.
    preconditioners: List[Any]
    # Momentum buffers, stored as QuantizedValue.
    # NOTE(review): presumably one per update path (diagonal / Shampoo) - confirm.
    diagonal_momentum: QuantizedValue
    momentum: QuantizedValue
|
|
|
|
|
|
|
|
|
|
|
|
|
@struct.dataclass
class GlobalShardedParameterStats:
    """Statistics of all parameters stacked along a leading axis (sharded mode)."""
    # Stacked statistics matrices, each padded to a common size.
    statistics: chex.Array
    # Stacked preconditioner matrices, same layout as `statistics`.
    preconditioners: chex.Array
    # One inverse-root exponent per stacked matrix.
    exponents: chex.Array
|
|
|
|
|
|
|
|
|
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained."""
    # Diagonal accumulator; built from an empty list when grafting is SGD.
    diagonal_statistics: QuantizedValue
    # Momentum buffers, stored as QuantizedValue.
    diagonal_momentum: QuantizedValue
    momentum: QuantizedValue
    # Offset of this parameter's first matrix inside the global stacked arrays
    # (static metadata, not a pytree leaf).
    index_start: np.int32 = struct.field(
        pytree_node=False)
    # Unpadded size of each of this parameter's matrices; empty when the
    # parameter is not preconditioned (static metadata, not a pytree leaf).
    sizes: Any = struct.field(pytree_node=False)
|
|
|
|
|
class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""
    # A GlobalShardedParameterStats holding the stacked matrices.
    global_stats: Any
    # Pytree (same structure as the params) of LocalShardedParameterStats.
    local_stats: Any
|
|
|
|
|
class ShampooState(NamedTuple):
    """Top-level optimizer state."""
    # Scalar int32 counter; the sharded init sets it to zero.
    count: chex.Array
    # Per-parameter statistics (a ShardedShampooStats in sharded mode).
    stats: Any
|
|
|
|
|
class InitFnState(NamedTuple):
    """Bundle of three callables used to set up the optimizer state."""
    # NOTE(review): presumably the state initialiser, the PartitionSpec builder
    # and the shape/dtype builder for sharded mode - confirm against the caller.
    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any
|
|
|
|
|
class GraftingType(enum.IntEnum):
    """Selects the method whose layer-wise scale is grafted onto Shampoo."""
    # With SGD no diagonal statistics are allocated for the parameter.
    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
|
|
|
|
|
def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST):
    r"""Estimates the dominant eigenpair of a symmetric PSD matrix.

    Repeatedly applies `matrix` to a normalised vector and tracks the scalar
    `v^T A v`. Iteration stops after `num_iters` steps or once that scalar
    changes by no more than `error_tolerance` between consecutive steps.

    References:
      [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
      matrix: the symmetric PSD matrix.
      num_iters: Number of iterations.
      error_tolerance: Iterative exit condition.
      precision: precision XLA related flag, the available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
      eigen vector, eigen value
    """
    dim = matrix.shape[-1]

    def _keep_going(carry):
        step, _, _, _, still_changing = carry
        return jnp.logical_and(step < num_iters, still_changing)

    def _advance(carry):
        """One step of power iteration."""
        step, vec, prev_estimate, _, _ = carry
        unit_vec = vec / jnp.linalg.norm(vec)
        applied = jnp.einsum('ij,j->i', matrix, unit_vec, precision=precision)
        estimate = jnp.einsum('i,i->', unit_vec, applied, precision=precision)
        still_changing = jnp.greater(
            jnp.abs(estimate - prev_estimate), error_tolerance)
        return (step + 1, applied, estimate, applied, still_changing)

    # Fixed seed: the starting vector is the same on every call.
    rng = np.random.RandomState(1729)
    start = rng.uniform(-1.0, 1.0, dim).astype(matrix.dtype)

    carry0 = (0, start, jnp.zeros([], dtype=matrix.dtype), start, True)
    final = lax.while_loop(_keep_going, _advance, carry0)
    last_vec, eigenvalue = final[1], final[2]
    # The carried vector is `matrix @ v` and therefore not unit length.
    return last_vec / jnp.linalg.norm(last_vec), eigenvalue
|
|
|
|
|
def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.

    References:
      [Functions of Matrices, Theory and Computation,
       Nicholas J Higham, Pg 184, Eq 7.18](
       https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
      matrix: the symmetric PSD matrix whose power is to be computed
      p: exponent, for p a positive integer.
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Error indicator, useful for early termination.
      precision: precision XLA related flag, the available options are:
        a) lax.Precision.DEFAULT (better step time, but not precise)
        b) lax.Precision.HIGH (increased precision, slower)
        c) lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
      A pair `(matrix^(-1/p), error)` where `error` is the max absolute
      deviation of the final iterate from the identity (0 for a 1x1 input).
    """

    assert matrix.shape[0] == matrix.shape[1]

    # The iteration itself is carried out in float32.
    matrix_size = matrix.shape[0]
    alpha = jnp.asarray(-1.0 / p, jnp.float32)
    identity = jnp.eye(matrix_size, dtype=jnp.float32)
    _, max_ev = power_iteration(
        matrix=matrix, num_iters=100,
        error_tolerance=1e-6, precision=precision)
    # Scale the ridge term by the (floored) largest eigenvalue estimate.
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)

    def mat_power(mat_m, p):
        """Computes mat_m^p for any positive integer p (p may be traced).

        Uses exponentiation by squaring inside a `lax.while_loop`, so the
        result is correct for exponents that are not powers of two (e.g. 6).

        Args:
          mat_m: a square matrix
          p: a positive integer

        Returns:
          mat_m^p
        """

        def _bits_left(state):
            remaining, _, _ = state
            return remaining > 0

        def _consume_bit(state):
            remaining, acc, base = state
            # Multiply the accumulator in only when the current bit is set.
            acc = lax.cond(
                remaining % 2 == 1,
                lambda ops: jnp.matmul(ops[0], ops[1], precision=precision),
                lambda ops: ops[1],
                (base, acc))
            squared = jnp.matmul(base, base, precision=precision)
            return remaining // 2, acc, squared

        start = (jnp.asarray(p, jnp.int32),
                 jnp.eye(mat_m.shape[0], dtype=mat_m.dtype), mat_m)
        _, result, _ = lax.while_loop(_bits_left, _consume_bit, start)
        return result

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
         run_step) = state
        error_above_threshold = jnp.logical_and(
            error > error_tolerance, run_step)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # The run flag turns False once the error grows by 20% or more, which
        # stops the loop; the previous mat_h is carried along as a fallback.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
                new_error < error * 1.2)

    if matrix_size == 1:
        # Scalar case: the power can be taken directly.
        resultant_mat_h = (matrix + ridge_epsilon)**alpha
        error = 0
    else:
        damped_matrix = matrix + ridge_epsilon * identity

        # Initial scaling based on the norm of the damped matrix.
        z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
        new_mat_m_0 = damped_matrix * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple(
            [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state)
        error = jnp.max(jnp.abs(mat_m - identity))
        # Use the last iterate unless the final step was flagged as diverging,
        # in which case fall back to the iterate before it.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
    return resultant_mat_h, error
|
|
|
|
|
def merge_small_dims(shape_to_merge, max_dim):
    """Merge small dimensions.

    If there are some small dimensions, we collapse them:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048]

    A non-empty shape made only of 1s (e.g. [1] or [1, 1]) merges to [1], so
    the result still has one axis rather than being an empty (rank-0) shape.

    Args:
      shape_to_merge: Shape to merge small dimensions.
      max_dim: Maximal dimension of output shape used in merging.

    Returns:
      Merged shape.
    """
    dims = list(shape_to_merge)
    if dims and all(d == 1 for d in dims):
        return [1]

    resulting_shape = []
    product = 1
    for d in dims:
        if product * d <= max_dim:
            product *= d
        else:
            # Flush the running product (size-1 products are dropped) and
            # start a new group at the current dimension.
            if product > 1:
                resulting_shape.append(product)
            product = d
    if product > 1:
        resulting_shape.append(product)
    return resulting_shape
|
|
|
|
|
def pad_matrix(mat, max_size):
    """Embeds a square matrix in the top-left corner of a larger one.

    Args:
      mat: a matrix to pad.
      max_size: matrix size requested.

    Returns:
      Given M returns [[M, 0], [0, I]]; `mat` itself if it already has
      `max_size` rows.
    """
    rows = mat.shape[0]
    assert rows <= max_size
    extra = max_size - rows
    if extra == 0:
        return mat
    dtype = mat.dtype
    top = jnp.concatenate(
        [mat, jnp.zeros([rows, extra], dtype=dtype)], 1)
    bottom = jnp.concatenate(
        [jnp.zeros([extra, rows], dtype=dtype), jnp.eye(extra, dtype=dtype)], 1)
    return jnp.concatenate([top, bottom], 0)
|
|
|
|
|
def pad_vector(vec, max_size):
    """Appends zeros to a vector until it has `max_size` entries.

    Args:
      vec: a vector to pad.
      max_size: vector length requested.

    Returns:
      Given V returns [V, 0]; `vec` itself if it already has `max_size`
      entries.
    """
    length = vec.shape[0]
    assert length <= max_size
    missing = max_size - length
    if not missing:
        return vec
    tail = jnp.zeros([missing], dtype=vec.dtype)
    return jnp.concatenate([vec, tail], 0)
|
|
|
|
|
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Avoids wasteful buffer allocation with XLA.

    Implements "run `compute_fn` only if `predicate`" as a while loop that
    executes at most once: the loop flag is set to False by the body.

    Args:
      predicate: boolean deciding whether `compute_fn` runs.
      compute_fn: callable returning an iterable matching `init_state`.
      init_state: list of values returned when `predicate` is False.
      *args: positional arguments forwarded to `compute_fn`.
      **kwargs: keyword arguments forwarded to `compute_fn`.

    Returns:
      A tuple holding either the outputs of `compute_fn` or `init_state`.
    """

    def _should_run(carry):
        return carry[0]

    def _run_once(unused_carry):
        outputs = list(compute_fn(*args, **kwargs))
        # Clear the flag so the loop exits after this single pass.
        return tuple([False] + outputs)

    start = tuple([predicate] + init_state)
    final = jax.lax.while_loop(_should_run, _run_once, start)
    return tuple(final[1:])
|
|
|
|
|
class BlockPartitioner:
    """Splits a tensor into blocks of at most `block_size` per axis and back."""

    def __init__(self, param, block_size):
        """Precomputes the split points for tensors shaped like `param`.

        Args:
          param: tensor whose shape defines the partitioning.
          block_size: maximal block extent along each axis; an axis is split
            only when `0 < block_size < dim`.
        """
        self._shape = param.shape
        self._splits = []
        per_axis_sizes = []
        for axis, dim in enumerate(param.shape):
            if not 0 < block_size < dim:
                # Axis is kept whole.
                per_axis_sizes.append(np.array([dim], dtype=np.int32))
                continue
            # Number of cut points; the last chunk takes whatever remains.
            num_cuts = (dim - 1) // block_size
            cut_points = (np.arange(num_cuts, dtype=np.int32) + 1) * block_size
            chunk_sizes = np.ones(num_cuts + 1, dtype=np.int32) * block_size
            chunk_sizes[-1] = dim - cut_points[-1]
            self._splits.append((axis, cut_points))
            per_axis_sizes.append(chunk_sizes)
        self._num_splits = len(per_axis_sizes)
        # One square [d, d] shape per axis of every block.
        self._preconditioner_shapes = [
            [d, d]
            for block_dims in itertools.product(*per_axis_sizes)
            for d in block_dims
        ]

    def shapes_for_preconditioners(self):
        """Returns the list of [d, d] shapes, block by block, axis by axis."""
        return self._preconditioner_shapes

    def num_splits(self):
        """Returns the number of axes of the partitioned tensor."""
        return self._num_splits

    def partition(self, tensor):
        """Partition tensor into blocks."""
        assert tensor.shape == self._shape
        pieces = [tensor]
        for axis, cut_points in self._splits:
            pieces = [
                part
                for piece in pieces
                for part in jnp.split(
                    piece, indices_or_sections=cut_points, axis=axis)
            ]
        return pieces

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""
        # Undo the splits in reverse order, gluing consecutive groups.
        for axis, cut_points in reversed(self._splits):
            group = len(cut_points) + 1
            partitions = [
                jnp.concatenate(partitions[start:start + group], axis=axis)
                for start in range(0, len(partitions), group)
            ]
        assert len(partitions) == 1
        return partitions[0]
|
|
|
|
|
class Preconditioner:
    """Compute statistics/shape from gradients for preconditioning."""

    def __init__(self, param, block_size, best_effort_shape_interpretation):
        """Sets up the reshaping and block partitioning for `param`.

        Args:
          param: parameter tensor; only its shape drives the setup.
          block_size: maximal block extent handed to the partitioner.
          best_effort_shape_interpretation: when True, small dimensions are
            merged (with `block_size` as the limit) before partitioning.
        """
        self._original_shape = param.shape
        if best_effort_shape_interpretation:
            self._transformed_shape = merge_small_dims(
                self._original_shape, block_size)
        else:
            self._transformed_shape = param.shape
        self._partitioner = BlockPartitioner(
            jnp.reshape(param, self._transformed_shape), block_size)

    def statistics_from_grad(self, grad):
        """Compute statistics from gradients.

        Args:
          grad: Gradient to compute statistics from.

        Returns:
          A list of gradient statistics for each partition.
        """
        blocks = self._partitioner.partition(
            jnp.reshape(grad, self._transformed_shape))
        stats = []
        for block in blocks:
            rank = len(block.shape)
            for axis in range(rank):
                # Contract the block with itself over every axis but `axis`.
                others = [a for a in range(rank) if a != axis]
                stats.append(jnp.tensordot(block, block, axes=(others, others)))
        return stats

    def shapes_for_preconditioners(self):
        """Returns shape from statistics."""
        return self._partitioner.shapes_for_preconditioners()

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        rank = len(self._transformed_shape)
        return 2 * rank

    def preconditioned_grad(self, grad, preconditioners):
        """Precondition the gradient.

        Args:
          grad: A gradient tensor to precondition.
          preconditioners: A list of preconditioners to apply.

        Returns:
          A preconditioned gradient.
        """
        blocks = self._partitioner.partition(
            jnp.reshape(grad, self._transformed_shape))
        per_block = self._partitioner.num_splits()
        results = []
        for index, block in enumerate(blocks):
            own = preconditioners[index * per_block:(index + 1) * per_block]
            out = block
            # Each contraction consumes the leading axis and appends the new
            # one at the end, so after `rank` steps the axis order is restored.
            for axis in range(len(block.shape)):
                out = jnp.tensordot(out, own[axis], axes=[[0], [0]])
            results.append(out)
        merged = self._partitioner.merge_partitions(results)
        return jnp.reshape(merged, self._original_shape)
|
|
|
|
|
def _convert_to_parameter_stats(global_stats, local_stat):
    """Creates parameter stats from sharded stats.

    Slices this parameter's matrices out of the global stacked arrays and
    crops each one to its unpadded size.
    """
    first = int(local_stat.index_start)
    last = first + int(len(local_stat.sizes))
    stats_slab = global_stats.statistics[first:last, :, :]
    precond_slab = global_stats.preconditioners[first:last, :, :]
    new_statistics = [
        stats_slab[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    new_preconditioners = [
        precond_slab[i][:size, :size] for i, size in enumerate(local_stat.sizes)
    ]
    return ParameterStats(
        local_stat.diagonal_statistics,
        new_statistics,
        new_preconditioners,
        local_stat.diagonal_momentum,
        local_stat.momentum)
|
|
|
|
|
def _convert_from_parameter_stats(parameter_stats, local_stats):
    """Creates sharded stats from parameter stats.

    The quantized buffers come from `parameter_stats`; the index bookkeeping
    (`index_start`, `sizes`) is carried over from `local_stats`.
    """
    return LocalShardedParameterStats(
        diagonal_statistics=parameter_stats.diagonal_statistics,
        diagonal_momentum=parameter_stats.diagonal_momentum,
        momentum=parameter_stats.momentum,
        index_start=local_stats.index_start,
        sizes=local_stats.sizes)
|
|
|
|
|
def batch(x, num_devices):
    """Batch `x` so that the leading axis is num_devices.

    Args:
      x: list of equally shaped arrays; its length must be a positive
        multiple of `num_devices`.
      num_devices: size of the leading axis of the result.

    Returns:
      An array of shape `(num_devices, len(x) // num_devices, ...)`.

    Raises:
      ValueError: if `x` cannot be split evenly across `num_devices`.
    """
    n = len(x)
    # Integer arithmetic; float division is not needed for a chunk size.
    b, leftover = divmod(n, num_devices)
    if leftover or b <= 0:
        raise ValueError(
            f'Cannot batch {n} elements evenly across {num_devices} devices.')
    return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
|
|
|
|
|
def unbatch(batched_values):
    """Unbatch values across the two leading axes and return a list of elements.

    Only the two batch axes are removed, so elements that themselves contain
    size-1 axes (e.g. 1x1 matrices) keep their shape.

    Args:
      batched_values: array of shape `(b1, b2, ...)`.

    Returns:
      A list of `b1 * b2` arrays of shape `batched_values.shape[2:]`, ordered
      with the first axis outermost.
    """
    b1, b2 = batched_values.shape[0], batched_values.shape[1]
    element_shape = tuple(batched_values.shape[2:])
    # Row-major flattening keeps the (outer b1, inner b2) ordering.
    flat = jnp.reshape(batched_values, (b1 * b2,) + element_shape)
    return [flat[i] for i in range(b1 * b2)]
|
|
|
|
|
def distributed_shampoo( |
|
learning_rate, |
|
block_size, |
|
beta1=0.9, |
|
beta2=0.999, |
|
diagonal_epsilon=1e-10, |
|
matrix_epsilon=1e-6, |
|
weight_decay=0.0, |
|
start_preconditioning_step=5, |
|
preconditioning_compute_steps=1, |
|
statistics_compute_steps=1, |
|
best_effort_shape_interpretation=True, |
|
graft_type=GraftingType.SGD, |
|
nesterov=True, |
|
exponent_override=0, |
|
|
|
batch_axis_name=None, |
|
|
|
|
|
statistics_partition_spec=None, |
|
preconditioner_partition_spec=None, |
|
num_devices_for_pjit=None, |
|
shard_optimizer_states=False, |
|
|
|
|
|
best_effort_memory_usage_reduction=False, |
|
|
|
inverse_failure_threshold=0.1, |
|
moving_average_for_momentum=False, |
|
skip_preconditioning_dim_size_gt=4096, |
|
clip_by_scaled_gradient_norm=None, |
|
precision=lax.Precision.HIGHEST): |
|
"""Distributed Shampoo optimizer. |
|
|
|
Distributed Shampoo is a second-order preconditioned method (concretely, a |
|
variant of full-matrix Adagrad), that provides significant convergence and |
|
wall-clock time improvements compared to conventional first-order methods, |
|
and that has been shown to scale to large state-of-the-art deep learning |
|
models. |
|
|
|
References: |
|
Scalable Second Order Optimization for Deep Learning, |
|
Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer |
|
|
|
Preprint: https://arxiv.org/abs/2002.09018 |
|
|
|
Args: |
|
learning_rate: the step size used to update the parameters. |
|
block_size: Block size for large layers (if > 0). Preconditioning compute |
|
operation is cubic in the dimension of the tensor. Block size allows us to |
|
chunk the layers into sub-layers of maximal dimension dictated by this |
|
value. Use 128 as default (increase if you have compute budget). |
|
beta1: momentum parameter. |
|
beta2: second moment averaging parameter. |
|
diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting |
|
to AdaGrad is enabled). |
|
matrix_epsilon: epsilon to add to statistics before computing inverse pth |
|
root. If you are running in f32 precision for inverse pth root |
|
(recommended today) this can go upto 1e-6. If you have latest hardware |
|
with native f64 precision, set this upto 1e-12. |
|
weight_decay: Weight decay for regularization. |
|
start_preconditioning_step: When to start Shampoo update before which |
|
diagonal update is used. This is because we dont have enough information |
|
to do stable inverse. |
|
preconditioning_compute_steps: How often to compute preconditioner. |
|
Performance tuning params for controlling memory and compute requirements. |
|
Ideally set this and statistics_compute_steps params to 1. |
|
statistics_compute_steps: How often to compute statistics. |
|
best_effort_shape_interpretation: If there are some small dimensions, |
|
collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if |
|
block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] |
|
graft_type: Grafting is a technique to fix the layerwise scale of Shampoo |
|
optimizer. This allows us to plugin the Shampoo optimizer into settings |
|
where SGD/AdaGrad is already well tuned. Available options are: |
|
GraftingType.SGD and GraftingType.ADAGRAD. |
|
nesterov: Nesterov momentum. |
|
exponent_override: Override the exponent used in matrix inverse. |
|
batch_axis_name: labeled axis over pmap for data-parallel training the |
|
optimizer used for. |
|
statistics_partition_spec: PartitionSpec to be used in sharded mode. |
|
preconditioner_partition_spec: PartitionSpec to be used in sharded mode. |
|
num_devices_for_pjit: Number of devices to parallelize over when using pjit. |
|
shard_optimizer_states: Shard optimizer states to save memory in model |
|
parallel training. |
|
best_effort_memory_usage_reduction: Best effort memory usage reduction. |
|
diagonal_statistics -> jnp.bfloat16 |
|
momentum buffers (2x) -> jnp.int8 |
|
statistics, preconditioners -> jnp.int16 + diagonals |
|
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we |
|
determine that using this threshold. |
|
moving_average_for_momentum: Whether to use moving average for momentum |
|
instead of exponential moving average. |
|
skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is |
|
greater than this value. |
|
clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful |
|
when using RMSProp Grafting). |
|
precision: precision XLA related flag, the available options are: a) |
|
lax.Precision.DEFAULT (better step time, but not precise) b) |
|
lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST |
|
(best possible precision, slowest) |
|
|
|
Returns: |
|
a GradientTransformation. |
|
""" |
|
|
|
def quantized_dtype_for_momentum_buffers(): |
|
return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32 |
|
|
|
|
|
def quantized_dtype_for_diagonal_statistics_buffers(): |
|
return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32 |
|
|
|
|
|
|
|
def quantized_dtype_for_second_moment_statistics_buffers(): |
|
return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 |
|
|
|
|
|
|
|
def quantized_dtype_for_second_moment_preconditioner_buffers(): |
|
return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32 |
|
|
|
def _to_float(maybe_quantized): |
|
if isinstance(maybe_quantized, QuantizedValue): |
|
return maybe_quantized.to_float() |
|
else: |
|
return maybe_quantized |
|
|
|
def _maybe_quantize_statistics(statistics_list): |
|
return _maybe_quantize_matrices_with_dtype( |
|
statistics_list, quantized_dtype_for_second_moment_statistics_buffers()) |
|
|
|
def _maybe_quantize_preconditioners(statistics_list): |
|
return _maybe_quantize_matrices_with_dtype( |
|
statistics_list, |
|
quantized_dtype_for_second_moment_preconditioner_buffers()) |
|
|
|
def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype): |
|
if quantized_dtype != jnp.float32: |
|
return ([ |
|
QuantizedValue.from_float_value( |
|
s, quantized_dtype, extract_diagonal=True) |
|
for s in statistics_list |
|
]) |
|
else: |
|
return statistics_list |
|
|
|
def _maybe_dequantize_preconditioners(preconditioner_list): |
|
return _maybe_dequantize_matrices_with_dtype( |
|
preconditioner_list, |
|
quantized_dtype_for_second_moment_preconditioner_buffers()) |
|
|
|
def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype): |
|
if quantized_dtype != jnp.float32: |
|
return [s.to_float() for s in statistics_list] |
|
else: |
|
return statistics_list |
|
|
|
def _quantize_diagonal_statistics(diagonal_statistics): |
|
return QuantizedValue.from_float_value( |
|
diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()) |
|
|
|
def _quantize_momentum(momentum_statistics): |
|
return QuantizedValue.from_float_value( |
|
momentum_statistics, quantized_dtype_for_momentum_buffers()) |
|
|
|
def sharded_init_fn(params): |
|
"""Returns optimizer state (for PJIT mode). |
|
|
|
Args: |
|
params: the parameters that should be updated. |
|
""" |
|
params_flat, treedef = jax.tree_flatten(params) |
|
|
|
max_size = 0 |
|
for param in params_flat: |
|
preconditioner = Preconditioner(param, block_size, |
|
best_effort_shape_interpretation) |
|
if not _skip_preconditioning(param): |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
sizes = [s[0] for s in shapes] |
|
max_size = max(max(sizes), max_size) |
|
|
|
padded_statistics = [] |
|
padded_preconditioners = [] |
|
local_stats_flat = [] |
|
exponents = [] |
|
for param in params_flat: |
|
preconditioner = Preconditioner(param, block_size, |
|
best_effort_shape_interpretation) |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
sizes = [] |
|
|
|
statistics = [] |
|
preconditioners = [] |
|
index_start = len(padded_statistics) |
|
if not _skip_preconditioning(param): |
|
sizes = [s[0] for s in shapes] |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes] |
|
preconditioners = [jnp.eye(max_size) for s in shapes] |
|
padded_statistics.extend(statistics) |
|
padded_preconditioners.extend(preconditioners) |
|
exponent = ( |
|
preconditioner.exponent_for_preconditioner() |
|
if exponent_override == 0 else exponent_override) |
|
exponents.extend([exponent] * len(shapes)) |
|
|
|
diagonal_statistics = [] |
|
if graft_type != GraftingType.SGD: |
|
diagonal_statistics = jnp.zeros_like(param) |
|
local_stats_flat.append( |
|
LocalShardedParameterStats( |
|
_quantize_diagonal_statistics(diagonal_statistics), |
|
_quantize_momentum(jnp.zeros_like(param)), |
|
_quantize_momentum(jnp.zeros_like(param)), index_start, sizes)) |
|
|
|
local_stats = jax.tree_unflatten(treedef, local_stats_flat) |
|
|
|
|
|
|
|
|
|
to_pad = -len(padded_statistics) % num_devices_for_pjit |
|
padded_statistics.extend([ |
|
jnp.eye(max_size, dtype=padded_statistics[0].dtype) |
|
for _ in range(to_pad) |
|
]) |
|
padded_preconditioners.extend([ |
|
jnp.eye(max_size, dtype=padded_statistics[0].dtype) |
|
for _ in range(to_pad) |
|
]) |
|
exponents.extend([1 for _ in range(to_pad)]) |
|
global_stats = GlobalShardedParameterStats( |
|
jnp.stack(padded_statistics), jnp.stack(padded_preconditioners), |
|
jnp.stack(exponents)) |
|
return ShampooState( |
|
count=jnp.zeros([], jnp.int32), |
|
stats=ShardedShampooStats(global_stats, local_stats)) |
|
|
|
def _max_statistics_size_from_params(params): |
|
max_size = 0 |
|
for param in params: |
|
param_clone = jnp.zeros(param.shape, dtype=param.dtype) |
|
preconditioner = Preconditioner(param_clone, block_size, |
|
best_effort_shape_interpretation) |
|
if not _skip_preconditioning(param): |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
sizes = [s[0] for s in shapes] |
|
max_size = max(max(sizes), max_size) |
|
return max_size |
|
|
|
def _remove_leading_sharding_annotation(pspec): |
|
"""Mapping from N-d to (N-1)-d, used for quantization, factoring etc.""" |
|
|
|
if pspec and len(pspec) > 1: |
|
return PartitionSpec(*pspec[1:]) |
|
else: |
|
return pspec |
|
|
|
def sharded_init_partition_spec_fn(params, params_partition_spec, |
|
partition_spec_for_statistics): |
|
"""Returns a parallel state tree with PartitionSpec associated with state. |
|
|
|
|
|
Args: |
|
params: A pytree with params. |
|
params_partition_spec: A pytree with PartitionSpec for params. |
|
partition_spec_for_statistics: PartitionSpec for the statistics. |
|
""" |
|
|
|
param_pspec_flat, _ = jax.tree_flatten(params_partition_spec) |
|
params_flat, treedef = jax.tree_flatten(params) |
|
assert param_pspec_flat |
|
assert params_flat |
|
|
|
|
|
local_stats_flat = [] |
|
num_statistics = 0 |
|
for param, param_pspec in zip(params_flat, param_pspec_flat): |
|
param_clone = jnp.zeros(param.shape, dtype=param.dtype) |
|
preconditioner = Preconditioner(param_clone, block_size, |
|
best_effort_shape_interpretation) |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
sizes = [] |
|
|
|
index_start = num_statistics |
|
if not _skip_preconditioning(param): |
|
sizes = [s[0] for s in shapes] |
|
shapes = preconditioner.shapes_for_preconditioners() |
|
num_statistics += len(shapes) |
|
|
|
diagonal_statistics_pspec = [] |
|
diagonal_statistics_scale_pspec = [] |
|
if graft_type != GraftingType.SGD: |
|
|
|
diagonal_statistics_pspec = param_pspec |
|
if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32: |
|
diagonal_statistics_scale_pspec = _remove_leading_sharding_annotation( |
|
param_pspec) |
|
|
|
m1_pspec = param_pspec |
|
m2_pspec = param_pspec |
|
|
|
m1_scale_pspec = [] |
|
m2_scale_pspec = [] |
|
|
|
if quantized_dtype_for_momentum_buffers() != jnp.float32: |
|
m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec) |
|
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec) |
|
|
|
local_stats_flat.append( |
|
LocalShardedParameterStats( |
|
QuantizedValue(diagonal_statistics_pspec, [], |
|
diagonal_statistics_scale_pspec, |
|
quantized_dtype_for_diagonal_statistics_buffers(), |
|
False, list(param.shape)), |
|
QuantizedValue(m1_pspec, [], m1_scale_pspec, |
|
quantized_dtype_for_momentum_buffers(), False, |
|
list(param.shape)), |
|
QuantizedValue(m2_pspec, [], m2_scale_pspec, |
|
quantized_dtype_for_momentum_buffers(), False, |
|
list(param.shape)), index_start, sizes)) |
|
|
|
local_stats = jax.tree_unflatten(treedef, local_stats_flat) |
|
global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, |
|
partition_spec_for_statistics, |
|
PartitionSpec()) |
|
count_pspec = PartitionSpec() |
|
return ShampooState( |
|
count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)) |
|
|
|
def sharded_init_shape_and_dtype_fn(params):
  """Returns a parallel state tree with shape, dtype associated with state.

  Each buffer of the sharded optimizer state is described by a
  `[shape, dtype]` pair (or `[]` when the buffer is not used), stored in the
  same `ShampooState` / `ShardedShampooStats` containers as the real state.

  Args:
    params: A pytree with params.

  Returns:
    A `ShampooState` whose `count` and `stats` hold `[shape, dtype]`
    descriptions rather than arrays.
  """
  params_flat, treedef = jax.tree_flatten(params)
  assert params_flat

  diagonal_qdtype = quantized_dtype_for_diagonal_statistics_buffers()
  momentum_qdtype = quantized_dtype_for_momentum_buffers()

  local_stats_flat = []
  num_statistics = 0
  for param in params_flat:
    # The Preconditioner is built from a zero array that only carries the
    # parameter's shape and dtype.
    param_clone = jnp.zeros(param.shape, dtype=param.dtype)
    preconditioner = Preconditioner(param_clone, block_size,
                                    best_effort_shape_interpretation)
    shapes = preconditioner.shapes_for_preconditioners()
    sizes = []

    # Offset of this parameter's statistics in the stacked global statistics.
    index_start = num_statistics
    if not _skip_preconditioning(param):
      sizes = [s[0] for s in shapes]
      num_statistics += len(shapes)

    diagonal_statistics_shape_and_dtype = []
    diagonal_statistics_scale_shape_and_dtype = []
    if graft_type != GraftingType.SGD:
      diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
      if diagonal_qdtype != jnp.float32:
        diagonal_statistics_shape_and_dtype = [
            list(param.shape), diagonal_qdtype
        ]
        diagonal_statistics_scale_shape_and_dtype = [
            list(param.shape)[1:], param.dtype
        ]

    m1_shape_and_dtype = [list(param.shape), param.dtype]
    m2_shape_and_dtype = [list(param.shape), param.dtype]

    m1_scale_shape_and_dtype = []
    m2_scale_shape_and_dtype = []

    if momentum_qdtype != jnp.float32:
      m1_shape_and_dtype = [list(param.shape), momentum_qdtype]
      m2_shape_and_dtype = [list(param.shape), momentum_qdtype]

      # The scale (bucket size) of a quantized buffer is a floating point
      # value, so it is described with the parameter dtype, exactly like the
      # diagonal statistics scale above, and not with the quantized dtype.
      m1_scale_shape_and_dtype = [list(param.shape)[1:], param.dtype]
      m2_scale_shape_and_dtype = [list(param.shape)[1:], param.dtype]

    local_stats_flat.append(
        LocalShardedParameterStats(
            QuantizedValue(diagonal_statistics_shape_and_dtype, [],
                           diagonal_statistics_scale_shape_and_dtype,
                           diagonal_qdtype, False, list(param.shape)),
            QuantizedValue(m1_shape_and_dtype, [], m1_scale_shape_and_dtype,
                           momentum_qdtype, False, list(param.shape)),
            QuantizedValue(m2_shape_and_dtype, [], m2_scale_shape_and_dtype,
                           momentum_qdtype, False, list(param.shape)),
            index_start, sizes))

  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
  max_statistics_size = _max_statistics_size_from_params(params_flat)
  # Round the number of statistics up to a multiple of the device count.
  to_pad = -num_statistics % num_devices_for_pjit
  num_statistics += to_pad
  statistics_shape = [
      num_statistics, max_statistics_size, max_statistics_size
  ]
  global_stats = GlobalShardedParameterStats([statistics_shape, jnp.float32],
                                             [statistics_shape, jnp.float32],
                                             [[num_statistics], jnp.int32])
  # NOTE(review): init_fn in this file creates `count` as int32 while float32
  # is declared here; confirm against the sharded init function.
  return ShampooState(
      count=[[], jnp.float32],
      stats=ShardedShampooStats(global_stats, local_stats))
|
|
|
def sharded_update_fn(grads, state, params):
  """Transform the input gradient and update all statistics in sharded mode.

  Args:
    grads: the gradient tensors for the parameters.
    state: a named tuple containing the state of the optimizer
    params: the parameters that should be updated.

  Returns:
    A tuple containing the transformed updates and the new optimizer state.
  """
  flat_params, treedef = jax.tree_flatten(params)
  flat_grads = treedef.flatten_up_to(grads)

  global_stats = state.stats.global_stats
  flat_local_stats = treedef.flatten_up_to(state.stats.local_stats)
  step = state.count

  def _stats_step(g, s, p):
    return _compute_stats(g, s, p, step)

  def _grad_step(g, s, p):
    return _transform_grad(g, s, p, step)

  flat_stats = [
      _convert_to_parameter_stats(global_stats, local_stat)
      for local_stat in flat_local_stats
  ]
  refreshed_stats_flat = jax.tree_multimap(_stats_step, flat_grads,
                                           flat_stats, flat_params)

  outputs = jax.tree_multimap(_grad_step, flat_grads, refreshed_stats_flat,
                              flat_params)
  if outputs:
    updates_flat, final_stats_flat = list(zip(*outputs))
  else:
    updates_flat, final_stats_flat = (), ()

  updates = jax.tree_unflatten(treedef, updates_flat)

  new_local_stats = jax.tree_unflatten(treedef, [
      _convert_from_parameter_stats(param_stats, local_stat)
      for param_stats, local_stat in zip(final_stats_flat, flat_local_stats)
  ])

  # Pad every per-parameter statistic to the common square size, then pad the
  # list itself to a multiple of the device count with identity matrices.
  max_size = global_stats.statistics.shape[1]
  padded = [
      pad_matrix(matrix, max_size)
      for param_stats in final_stats_flat
      for matrix in param_stats.statistics
  ]
  to_pad = -len(padded) % num_devices_for_pjit
  padded.extend([
      jnp.eye(max_size, dtype=padded[0].dtype) for _ in range(to_pad)
  ])
  stacked_statistics = pjit.with_sharding_constraint(
      jnp.stack(padded), statistics_partition_spec)

  def _recompute_all_preconditioners():
    return _matrix_inverse_pth_root_pjit(stacked_statistics,
                                         global_stats.exponents,
                                         statistics_partition_spec)

  if preconditioning_compute_steps == 1:
    new_preconditioners, errors = _recompute_all_preconditioners()
  else:
    # On steps where the roots are not recomputed, the fallback carries the
    # statistics together with errors equal to the failure threshold, so the
    # selection below keeps the previous preconditioners.
    batch_size = stacked_statistics.shape[0]
    fallback = [
        stacked_statistics,
        jnp.ones([batch_size], jnp.float32) * inverse_failure_threshold,
    ]
    perform_step = step % preconditioning_compute_steps == 0
    new_preconditioners, errors = efficient_cond(
        perform_step, _recompute_all_preconditioners, fallback)

  errors = errors.reshape((-1, 1, 1))
  failed = jnp.logical_or(jnp.isnan(errors),
                          errors >= inverse_failure_threshold)
  predicate = failed.astype(new_preconditioners.dtype)

  # Keep the old preconditioner wherever the new one failed.
  new_conditional_preconditioners = (
      predicate * global_stats.preconditioners +
      (1.0 - predicate) * new_preconditioners)
  new_global_stats = GlobalShardedParameterStats(
      stacked_statistics, new_conditional_preconditioners,
      global_stats.exponents)
  new_shampoo_state = ShampooState(
      count=step + 1,
      stats=ShardedShampooStats(new_global_stats, new_local_stats))
  return updates, new_shampoo_state
|
|
|
def init_fn(params):
  """Initialise the optimiser's state."""

  def _initial_stats(param):
    """Builds the ParameterStats for a single parameter."""
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    if _skip_preconditioning(param):
      statistics, preconditioners = [], []
    else:
      dims = [
          shape[0] for shape in preconditioner.shapes_for_preconditioners()
      ]
      # Statistics start as epsilon * I, preconditioners as the identity.
      statistics = [matrix_epsilon * jnp.eye(dim) for dim in dims]
      preconditioners = [jnp.eye(dim) for dim in dims]

    if graft_type != GraftingType.SGD:
      diagonal_statistics = jnp.zeros_like(param)
    else:
      diagonal_statistics = []

    return ParameterStats(
        _quantize_diagonal_statistics(diagonal_statistics),
        _maybe_quantize_statistics(statistics),
        _maybe_quantize_preconditioners(preconditioners),
        _quantize_momentum(jnp.zeros_like(param)),
        _quantize_momentum(jnp.zeros_like(param)))

  initial_stats = jax.tree_map(_initial_stats, params)
  return ShampooState(count=jnp.zeros([], jnp.int32), stats=initial_stats)
|
|
|
def _skip_preconditioning(param):
  """Returns True for rank-0 params or params with an over-sized dimension."""
  shape = param.shape
  if len(shape) < 1:
    return True
  return any(dim > skip_preconditioning_dim_size_gt for dim in shape)
|
|
|
def _compute_stats(grad, state, param, step):
  """Compute per-parameter statistics."""
  preconditioner = Preconditioner(param, block_size,
                                  best_effort_shape_interpretation)
  decay = beta2
  # With beta2 == 1.0 the statistics are a plain sum; otherwise a moving
  # average weighted by (1 - beta2).
  grad_weight = beta2 if beta2 == 1.0 else (1.0 - beta2)

  def _accumulate():
    fresh_stats = preconditioner.statistics_from_grad(grad)
    return _maybe_quantize_statistics([
        decay * _to_float(accumulator) + grad_weight * fresh
        for fresh, accumulator in zip(fresh_stats, state.statistics)
    ])

  if _skip_preconditioning(param):
    new_statistics = [[]] * len(state.statistics)
  elif statistics_compute_steps > 1:
    # Only refresh the statistics every `statistics_compute_steps` steps.
    perform_step = step % statistics_compute_steps == 0
    new_statistics = list(
        efficient_cond(perform_step, _accumulate, state.statistics))
  else:
    new_statistics = _accumulate()

  return ParameterStats(state.diagonal_statistics, new_statistics,
                        state.preconditioners, state.diagonal_momentum,
                        state.momentum)
|
|
|
def _matrix_inverse_pth_root_vmap(xs, ps):
  """Maps matrix_inverse_pth_root over a batch of matrices and exponents."""

  def _single(x, p):
    return matrix_inverse_pth_root(
        x, p, ridge_epsilon=matrix_epsilon, precision=precision)

  return jax.vmap(_single)(xs, ps)
|
|
|
def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
  """Maps the inverse pth root over a batch of quantized statistics.

  Each statistic is dequantized, passed to matrix_inverse_pth_root and the
  result is quantized again with the dtype of the input.

  Args:
    qxs: Batched quantized matrices.
    qds: Batched diagonals.
    qbs: Batched bucket sizes.
    ps: Batched exponents.

  Returns:
    Batched (quantized, diagonal, bucket_size, error) values.
  """

  def _one(qx, qd, qb, p):
    matrix = QuantizedValue(qx, qd, qb, qx.dtype, True,
                            list(qx.shape)).to_float()
    preconditioner, error = matrix_inverse_pth_root(
        matrix, p, ridge_epsilon=matrix_epsilon, precision=precision)
    requantized = QuantizedValue.from_float_value(preconditioner, qx.dtype,
                                                  True)
    return (requantized.quantized, requantized.diagonal,
            requantized.bucket_size, error)

  return jax.vmap(_one)(qxs, qds, qbs, ps)
|
|
|
def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
  """Computes batched inverse pth roots under pjit sharding constraints.

  Args:
    xs: Stacked statistics.
    ps: Stacked exponents.
    statistics_partition_spec: Partition spec applied to the returned
      preconditioners.

  Returns:
    A (preconditioners, errors) tuple.
  """
  compute_pspec = preconditioner_partition_spec
  sharded_xs = pjit.with_sharding_constraint(xs, compute_pspec)
  sharded_ps = pjit.with_sharding_constraint(
      ps, pjit.PartitionSpec(preconditioner_partition_spec[0]))

  sharded_preconditioners, sharded_errors = _matrix_inverse_pth_root_vmap(
      sharded_xs, sharded_ps)

  # Constrain the result to the compute layout first, then to the layout
  # requested for the statistics.
  sharded_preconditioners = pjit.with_sharding_constraint(
      sharded_preconditioners, compute_pspec)
  preconditioners = pjit.with_sharding_constraint(sharded_preconditioners,
                                                  statistics_partition_spec)
  errors = pjit.with_sharding_constraint(sharded_errors,
                                         pjit.PartitionSpec())
  return preconditioners, errors
|
|
|
def _pmap_compute_preconditioners(states, step, statistics,
                                  num_statistics_per_state, original_shapes,
                                  exponents, max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PMAP mode.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistics per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots. Extended in place
      with padding entries.
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner.
  """
  num_devices = lax.psum(1, batch_axis_name)
  num_statistics = len(statistics)

  # Pad each statistic to max_size, and the list to a multiple of the device
  # count with identity matrices (exponent 1).
  packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
  to_pad = -num_statistics % num_devices
  packed_statistics.extend([
      jnp.eye(max_size, dtype=packed_statistics[0].dtype)
      for _ in range(to_pad)
  ])
  exponents.extend([1] * to_pad)

  if not packed_statistics:
    return states

  all_statistics = batch(packed_statistics, num_devices)
  all_exponents = batch(exponents, num_devices)

  def _roots_across_replicas():
    # Each replica handles its own slice, then results are gathered.
    replica = lax.axis_index(batch_axis_name)
    local_preconditioners, local_errors = _matrix_inverse_pth_root_vmap(
        all_statistics[replica], all_exponents[replica])
    gathered_preconditioners = jax.lax.all_gather(local_preconditioners,
                                                  batch_axis_name)
    gathered_errors = jax.lax.all_gather(local_errors, batch_axis_name)
    return unbatch(gathered_preconditioners), unbatch(gathered_errors)

  if preconditioning_compute_steps == 1:
    preconditioners_flat, errors_flat = _roots_across_replicas()
  else:
    # Fallback for skipped steps: errors equal to the failure threshold make
    # the selection below keep the previous preconditioners.
    fallback = [
        packed_statistics,
        [inverse_failure_threshold] * len(packed_statistics),
    ]
    perform_step = step % preconditioning_compute_steps == 0
    preconditioners_flat, errors_flat = efficient_cond(
        perform_step, _roots_across_replicas, fallback)

  def _failed(error):
    bad = jnp.logical_or(
        jnp.isnan(error), error >= inverse_failure_threshold)
    return bad.astype(error.dtype)

  def _pick(error, candidate, previous):
    return lax.cond(
        _failed(error), lambda _: previous, lambda _: candidate, operand=None)

  new_preconditioners_flat = [
      _pick(error, p[:shape[0], :shape[1]], prev_p)
      for p, shape, prev_p, error in zip(preconditioners_flat,
                                         original_shapes,
                                         prev_preconditioners, errors_flat)
  ]

  assert len(states) == len(num_statistics_per_state)
  assert len(new_preconditioners_flat) == num_statistics

  # Hand the flat list back to the states it came from.
  preconditioners_for_states = []
  offset = 0
  for count, state in zip(num_statistics_per_state, states):
    if count == 0:
      preconditioners_for_states.append([])
      continue
    chunk = new_preconditioners_flat[offset:offset + count]
    assert len(state.statistics) == len(chunk)
    preconditioners_for_states.append(chunk)
    offset += count

  return [
      ParameterStats(state.diagonal_statistics, state.statistics,
                     new_preconditioners, state.diagonal_momentum,
                     state.momentum)
      for state, new_preconditioners in zip(states,
                                            preconditioners_for_states)
  ]
|
|
|
def _pmap_quantized_compute_preconditioners(states, step, statistics,
                                            num_statistics_per_state,
                                            original_shapes, exponents,
                                            max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PMAP mode.

  For quantization, each statistic is represented by three values:
  quantized matrix, diagonal, and bucket sizes; these three are padded,
  batched and gathered separately.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistics per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots. Extended in place
      with padding entries.
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner.
  """
  num_devices = lax.psum(1, batch_axis_name)
  num_statistics = len(statistics)
  quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()

  packed_quantized_statistics = [
      pad_matrix(stat.quantized, max_size) for stat in statistics
  ]
  packed_quantized_diagonals = [
      pad_vector(stat.diagonal, max_size) for stat in statistics
  ]
  packed_quantized_bucket_sizes = [
      pad_vector(stat.bucket_size, max_size) for stat in statistics
  ]

  # Pad the lists to a multiple of the device count with a quantized identity.
  to_pad = -num_statistics % num_devices
  quantized_eye = QuantizedValue.from_float_value(
      jnp.eye(max_size, dtype=jnp.float32), quantized_dtype, True)
  packed_quantized_statistics.extend([quantized_eye.quantized] * to_pad)
  packed_quantized_diagonals.extend([quantized_eye.diagonal] * to_pad)
  packed_quantized_bucket_sizes.extend([quantized_eye.bucket_size] * to_pad)
  exponents.extend([1] * to_pad)

  if not packed_quantized_statistics:
    return states

  all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
  all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
  all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
                                     num_devices)
  all_exponents = batch(exponents, num_devices)

  def _roots_across_replicas():
    # Each replica handles its own slice, then results are gathered.
    replica = lax.axis_index(batch_axis_name)
    local_q, local_d, local_b, local_errors = (
        _quantized_matrix_inverse_pth_root_vmap(
            all_quantized_statistics[replica],
            all_quantized_diagonals[replica],
            all_quantized_bucket_sizes[replica], all_exponents[replica]))
    gathered_q = jax.lax.all_gather(local_q, batch_axis_name)
    gathered_d = jax.lax.all_gather(local_d, batch_axis_name)
    gathered_b = jax.lax.all_gather(local_b, batch_axis_name)
    gathered_errors = jax.lax.all_gather(local_errors, batch_axis_name)
    return (unbatch(gathered_q), unbatch(gathered_d), unbatch(gathered_b),
            unbatch(gathered_errors))

  if preconditioning_compute_steps == 1:
    (quantized_preconditioners_flat, quantized_diagonals_flat,
     quantized_bucket_sizes_flat, errors_flat) = _roots_across_replicas()
  else:
    # Fallback for skipped steps: errors equal to the failure threshold make
    # the selection below keep the previous preconditioners.
    fallback = [
        packed_quantized_statistics,
        packed_quantized_diagonals,
        packed_quantized_bucket_sizes,
        [inverse_failure_threshold] * len(packed_quantized_statistics),
    ]
    perform_step = step % preconditioning_compute_steps == 0
    (quantized_preconditioners_flat, quantized_diagonals_flat,
     quantized_bucket_sizes_flat, errors_flat) = efficient_cond(
         perform_step, _roots_across_replicas, fallback)

  def _failed(error):
    bad = jnp.logical_or(
        jnp.isnan(error), error >= inverse_failure_threshold)
    return bad.astype(error.dtype)

  def _pick(error, candidate, previous):
    return lax.cond(
        _failed(error), lambda _: previous, lambda _: candidate, operand=None)

  new_quantized_preconditioners_flat = []
  new_quantized_diagonals_flat = []
  new_quantized_bucket_sizes_flat = []
  per_statistic = zip(quantized_preconditioners_flat,
                      quantized_diagonals_flat, quantized_bucket_sizes_flat,
                      original_shapes, prev_preconditioners, errors_flat)
  for p, d, b, shape, prev_p, error in per_statistic:
    rows = shape[0]
    new_quantized_preconditioners_flat.append(
        _pick(error, p[:rows, :shape[1]], prev_p.quantized))
    new_quantized_diagonals_flat.append(
        _pick(error, d[:rows], prev_p.diagonal))
    new_quantized_bucket_sizes_flat.append(
        _pick(error, b[:rows], prev_p.bucket_size))

  assert len(states) == len(num_statistics_per_state)
  assert len(new_quantized_preconditioners_flat) == num_statistics
  assert len(new_quantized_diagonals_flat) == num_statistics
  assert len(new_quantized_bucket_sizes_flat) == num_statistics

  # Hand the flat lists back to the states they came from.
  preconditioners_for_states = []
  offset = 0
  for count, state in zip(num_statistics_per_state, states):
    if count == 0:
      preconditioners_for_states.append([])
      continue
    window = slice(offset, offset + count)
    state_q = new_quantized_preconditioners_flat[window]
    state_d = new_quantized_diagonals_flat[window]
    state_b = new_quantized_bucket_sizes_flat[window]

    assert len(state.statistics) == len(state_q)
    assert len(state.statistics) == len(state_d)
    assert len(state.statistics) == len(state_b)

    preconditioners_for_states.append([
        QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
        for qv, qd, qb in zip(state_q, state_d, state_b)
    ])
    offset += count

  return [
      ParameterStats(state.diagonal_statistics, state.statistics,
                     new_preconditioners, state.diagonal_momentum,
                     state.momentum)
      for state, new_preconditioners in zip(states,
                                            preconditioners_for_states)
  ]
|
|
|
def _pjit_compute_preconditioners(states, step, statistics,
                                  num_statistics_per_state, original_shapes,
                                  exponents, max_size, prev_preconditioners):
  """Computes preconditioners for given statistics in states in PJIT mode.

  Args:
    states: A list of optimizer states.
    step: Current step number
    statistics: A list of statistics for all variables (for every dim)
    num_statistics_per_state: Number of statistics per state to reconstruct
      output states.
    original_shapes: A list of shapes of the statistics.
    exponents: Exponent power to use for inverse-pth roots. Extended in place
      with padding entries.
    max_size: Maximum dim of the statistics to pad.
    prev_preconditioners: Previously available preconditioner.

  Returns:
    New optimizer states after computing the preconditioner. When there are
    no statistics at all, `states` is returned unchanged.
  """
  num_statistics = len(statistics)
  if num_statistics == 0:
    # Nothing to precondition. Without this guard jnp.stack below is called
    # on an empty list and raises; the PMAP variants return early as well.
    return states

  # Pad each statistic to max_size, and the list to a multiple of the device
  # count with identity matrices (exponent 1).
  to_pad = -num_statistics % num_devices_for_pjit
  padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
  padded_statistics.extend([
      jnp.eye(max_size, dtype=padded_statistics[0].dtype)
      for _ in range(to_pad)
  ])
  exponents.extend([1 for _ in range(to_pad)])
  all_statistics = jnp.stack(padded_statistics)
  all_exponents = jnp.stack(exponents)

  def _internal_inverse_pth_root_all():
    preconditioners, errors = _matrix_inverse_pth_root_pjit(
        all_statistics, all_exponents)
    b1 = preconditioners.shape[0]

    def split(batched_values):
      # Drop only the leading batch axis. An axis-less squeeze would also
      # collapse 1x1 preconditioners to scalars and break the 2-D slicing
      # performed below.
      return [
          jnp.squeeze(v, axis=0)
          for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
      ]

    return split(preconditioners), split(errors)

  if preconditioning_compute_steps == 1:
    preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
  else:
    # Fallback for skipped steps: errors equal to the failure threshold make
    # the selection below keep the previous preconditioners.
    preconditioners_init = padded_statistics
    errors_init = [inverse_failure_threshold] * len(padded_statistics)
    init_state = [preconditioners_init, errors_init]
    perform_step = step % preconditioning_compute_steps == 0
    preconditioners_flat, errors_flat = efficient_cond(
        perform_step, _internal_inverse_pth_root_all, init_state)

  def _skip(error):
    condition = jnp.logical_or(
        jnp.isnan(error), error >= inverse_failure_threshold)
    return condition.astype(error.dtype)

  def _select_preconditioner(error, new_p, old_p):
    return lax.cond(
        _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

  new_preconditioners_flat = []
  for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                     prev_preconditioners, errors_flat):
    new_preconditioners_flat.append(
        _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

  assert len(states) == len(num_statistics_per_state)
  assert len(new_preconditioners_flat) == num_statistics

  # Hand the flat list back to the states it came from.
  preconditioners_for_states = []
  idx = 0
  for state_num_statistics, state in zip(num_statistics_per_state, states):
    if state_num_statistics == 0:
      preconditioners_for_states.append([])
    else:
      preconditioners_for_state = new_preconditioners_flat[
          idx:idx + state_num_statistics]
      assert len(state.statistics) == len(preconditioners_for_state)
      preconditioners_for_states.append(preconditioners_for_state)
      idx += state_num_statistics

  new_states = []
  for state, new_preconditioners in zip(states, preconditioners_for_states):
    new_states.append(
        ParameterStats(state.diagonal_statistics, state.statistics,
                       new_preconditioners, state.diagonal_momentum,
                       state.momentum))

  return new_states
|
|
|
def _compute_preconditioners(states, params, step):
  """Computes preconditioners for given statistics in states.

  Flattens the statistics of all states into single lists and dispatches to
  the PMAP (float or quantized) or PJIT implementation.

  Args:
    states: A list of optimizer states.
    params: A list of params.
    step: Current step number

  Returns:
    New optimizer states after computing the preconditioner.
  """
  statistics = []
  prev_preconditioners = []
  original_shapes = []
  exponents = []
  num_statistics_per_state = []
  max_size = 0

  for state, param in zip(states, params):
    state_statistics = state.statistics
    count = len(state_statistics)
    num_statistics_per_state.append(count)
    if count == 0:
      continue

    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    for statistic in state_statistics:
      if exponent_override == 0:
        exponents.append(preconditioner.exponent_for_preconditioner())
      else:
        exponents.append(exponent_override)
      original_shapes.append(statistic.shape)
      max_size = max(max_size, statistic.shape[0])

    statistics.extend(state_statistics)
    prev_preconditioners.extend(state.preconditioners)

  shared_args = (states, step, statistics, num_statistics_per_state,
                 original_shapes, exponents, max_size, prev_preconditioners)

  if not batch_axis_name:
    return _pjit_compute_preconditioners(*shared_args)

  # PMAP mode: pick the implementation matching the statistics storage dtype.
  if quantized_dtype_for_second_moment_statistics_buffers() == jnp.float32:
    return _pmap_compute_preconditioners(*shared_args)
  return _pmap_quantized_compute_preconditioners(*shared_args)
|
|
|
def _transform_grad(grad, state, param, step):
  """Transform per-parameter gradients.

  Computes a grafting update (SGD, AdaGrad or an RMSProp variant, selected by
  `graft_type`) and a preconditioned update rescaled to the norm of the
  grafting update. Weight decay and momentum are applied to both, and the
  preconditioned one is used once `step >= start_preconditioning_step`.

  Args:
    grad: Gradient for `param`.
    state: `ParameterStats` holding the statistics, preconditioners and
      momentum buffers for `param`.
    param: The parameter the gradient belongs to.
    step: Current step number.

  Returns:
    A `(transformed_update, param_stats)` tuple: the update scaled by the
    negative learning rate, and the `ParameterStats` with refreshed diagonal
    statistics and momentum buffers.
  """
  preconditioner = Preconditioner(param, block_size,
                                  best_effort_shape_interpretation)
  sgd_update = grad
  new_diagonal_statistics = state.diagonal_statistics.to_float()
  if graft_type == GraftingType.ADAGRAD:
    new_diagonal_statistics = state.diagonal_statistics.to_float(
    ) + jnp.square(grad)
    adagrad_update = grad / (
        jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
    grafting_update = adagrad_update
  elif (graft_type == GraftingType.RMSPROP or
        graft_type == GraftingType.RMSPROP_NORMALIZED):

    scaled_grad = grad
    if graft_type == GraftingType.RMSPROP_NORMALIZED:
      # Guard the normalisation: an all-zero gradient would otherwise give
      # 0 / 0 = NaN, which would then be stored in the diagonal statistics
      # and the momentum buffers.
      scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)

    w1 = beta2
    w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)

    new_diagonal_statistics = (
        w1 * state.diagonal_statistics.to_float() +
        w2 * jnp.square(scaled_grad))
    rmsprop_update = scaled_grad / (
        jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)

    if clip_by_scaled_gradient_norm:
      scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
          jnp.sqrt(float(rmsprop_update.size)))
      clipping_denom = jnp.maximum(
          1., scaled_grad_norm / clip_by_scaled_gradient_norm)
      rmsprop_update /= clipping_denom

    grafting_update = rmsprop_update
  else:
    grafting_update = sgd_update

  precond_grad = grad
  if not _skip_preconditioning(param):
    precond_grad = preconditioner.preconditioned_grad(
        precond_grad,
        _maybe_dequantize_preconditioners(state.preconditioners))
  else:
    precond_grad = grafting_update

  # Rescale the preconditioned gradient to the norm of the grafting update.
  grafting_update_norm = jnp.linalg.norm(grafting_update)
  precond_grad_norm = jnp.linalg.norm(precond_grad)

  multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
  shampoo_update = precond_grad * multiplier

  shampoo_update_with_wd = shampoo_update
  grafting_update_with_wd = grafting_update
  if weight_decay != 0:
    shampoo_update_with_wd = shampoo_update + weight_decay * param
    grafting_update_with_wd = grafting_update + weight_decay * param

  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
  shampoo_update_with_wd_momentum = (
      state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
  grafting_update_with_wd_momentum = (
      state.diagonal_momentum.to_float() * beta1 +
      w * grafting_update_with_wd)

  # 1.0 once preconditioning has started, 0.0 before; used as a blend weight.
  run_shampoo = (step >= start_preconditioning_step).astype(
      grafting_update_with_wd_momentum.dtype)

  momentum_update = (
      run_shampoo * shampoo_update_with_wd_momentum +
      (1.0 - run_shampoo) * grafting_update_with_wd_momentum)

  wd_update = (
      run_shampoo * shampoo_update_with_wd +
      (1.0 - run_shampoo) * grafting_update_with_wd)

  if nesterov:
    momentum_update = w * wd_update + beta1 * momentum_update

  lr = learning_rate
  if callable(learning_rate):
    lr = learning_rate(step)
  transformed_update = -1.0 * lr * momentum_update

  param_stats = ParameterStats(
      _quantize_diagonal_statistics(new_diagonal_statistics),
      state.statistics, state.preconditioners,
      _quantize_momentum(grafting_update_with_wd_momentum),
      _quantize_momentum(shampoo_update_with_wd_momentum))
  return transformed_update, param_stats
|
|
|
def update_fn(grads, state, params):
  """Transform the input gradient and update all statistics.

  Args:
    grads: the gradient tensors for the parameters.
    state: a named tuple containing the state of the optimizer
    params: the parameters that should be updated.

  Returns:
    A tuple containing the transformed updates and the new optimizer state.
  """
  flat_params, treedef = jax.tree_flatten(params)
  flat_stats = treedef.flatten_up_to(state.stats)
  flat_grads = treedef.flatten_up_to(grads)
  step = state.count

  def _stats_step(g, s, p):
    return _compute_stats(g, s, p, step)

  def _grad_step(g, s, p):
    return _transform_grad(g, s, p, step)

  # Statistics first, then preconditioners, then the actual updates.
  refreshed_stats = jax.tree_multimap(_stats_step, flat_grads, flat_stats,
                                      flat_params)
  refreshed_stats = _compute_preconditioners(refreshed_stats, flat_params,
                                             step)

  outputs = jax.tree_multimap(_grad_step, flat_grads, refreshed_stats,
                              flat_params)
  if outputs:
    updates_flat, final_stats_flat = list(zip(*outputs))
  else:
    updates_flat, final_stats_flat = (), ()

  updates = jax.tree_unflatten(treedef, updates_flat)
  new_stats = jax.tree_unflatten(treedef, final_stats_flat)

  return updates, ShampooState(count=step + 1, stats=new_stats)
|
|
|
if shard_optimizer_states: |
|
|
|
|
|
def _init_fns(unused_params):
  """Returns the sharded init, partition-spec and shape/dtype functions."""
  sharded_fns = dict(
      init_fn=sharded_init_fn,
      pspec_fn=sharded_init_partition_spec_fn,
      shape_and_dtype_fn=sharded_init_shape_and_dtype_fn)
  return InitFnState(**sharded_fns)
|
|
|
return optax.GradientTransformation(_init_fns, sharded_update_fn) |
|
else: |
|
return optax.GradientTransformation(init_fn, update_fn) |