Spaces:
Running
Running
feat: update distributed_shampoo
Browse files- tools/train/distributed_shampoo.py +684 -207
tools/train/distributed_shampoo.py
CHANGED
@@ -33,7 +33,7 @@
|
|
33 |
import enum
|
34 |
import functools
|
35 |
import itertools
|
36 |
-
from typing import Any, NamedTuple
|
37 |
|
38 |
import chex
|
39 |
from flax import struct
|
@@ -46,16 +46,105 @@ import optax
|
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
# Per parameter optimizer state used in data-parallel training.
|
52 |
class ParameterStats(NamedTuple):
|
53 |
"""State associated to each parameter of the model being trained."""
|
54 |
-
diagonal_statistics:
|
55 |
-
statistics:
|
56 |
-
preconditioners:
|
57 |
-
diagonal_momentum:
|
58 |
-
momentum:
|
59 |
|
60 |
|
61 |
# For training extremely large model; We keep a global state with a concatenated
|
@@ -73,9 +162,9 @@ class GlobalShardedParameterStats:
|
|
73 |
@struct.dataclass
|
74 |
class LocalShardedParameterStats:
|
75 |
"""State associated to each parameter of the model being trained."""
|
76 |
-
diagonal_statistics:
|
77 |
-
diagonal_momentum:
|
78 |
-
momentum:
|
79 |
index_start: np.int32 = struct.field(
|
80 |
pytree_node=False) # Index into global statistics array
|
81 |
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
|
@@ -141,7 +230,8 @@ def power_iteration(
|
|
141 |
jnp.greater(jnp.abs(s_new - s), error_tolerance))
|
142 |
|
143 |
# Figure out how to use step as seed for random.
|
144 |
-
v_0 = np.random.uniform(-1.0, 1.0,
|
|
|
145 |
|
146 |
init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
|
147 |
_, v_out, s_out, _, _ = lax.while_loop(
|
@@ -323,6 +413,25 @@ def pad_matrix(mat, max_size):
|
|
323 |
return mat
|
324 |
|
325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
|
327 |
"""Avoids wasteful buffer allocation with XLA."""
|
328 |
|
@@ -492,33 +601,59 @@ def _convert_from_parameter_stats(parameter_stats, local_stats):
|
|
492 |
local_stats.index_start, local_stats.sizes)
|
493 |
|
494 |
|
495 |
-
def
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
"""Distributed Shampoo optimizer.
|
523 |
|
524 |
Distributed Shampoo is a second-order preconditioned method (concretely, a
|
@@ -570,6 +705,10 @@ def distributed_shampoo(learning_rate,
|
|
570 |
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
|
571 |
shard_optimizer_states: Shard optimizer states to save memory in model
|
572 |
parallel training.
|
|
|
|
|
|
|
|
|
573 |
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
|
574 |
determine that using this threshold.
|
575 |
moving_average_for_momentum: Whether to use moving average for momentum
|
@@ -587,6 +726,67 @@ def distributed_shampoo(learning_rate,
|
|
587 |
a GradientTransformation.
|
588 |
"""
|
589 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
590 |
def sharded_init_fn(params):
|
591 |
params_flat, treedef = jax.tree_flatten(params)
|
592 |
# Find max size to pad to.
|
@@ -619,12 +819,14 @@ def distributed_shampoo(learning_rate,
|
|
619 |
padded_statistics.extend(statistics)
|
620 |
padded_preconditioners.extend(preconditioners)
|
621 |
|
622 |
-
|
623 |
if graft_type != GraftingType.SGD:
|
624 |
-
|
625 |
local_stats_flat.append(
|
626 |
-
LocalShardedParameterStats(
|
627 |
-
|
|
|
|
|
628 |
|
629 |
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
630 |
# Pad the statistics and preconditioner matrices to be a multiple of
|
@@ -769,12 +971,15 @@ def distributed_shampoo(learning_rate,
|
|
769 |
statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
|
770 |
preconditioners = [jnp.eye(s[0]) for s in shapes]
|
771 |
|
772 |
-
|
773 |
if graft_type != GraftingType.SGD:
|
774 |
-
|
775 |
-
return ParameterStats(
|
776 |
-
|
777 |
-
|
|
|
|
|
|
|
778 |
return ShampooState(
|
779 |
count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
|
780 |
|
@@ -795,8 +1000,9 @@ def distributed_shampoo(learning_rate,
|
|
795 |
new_stats = preconditioner.statistics_from_grad(grad)
|
796 |
new_stats_accumulators = []
|
797 |
for stat, stat_accumulator in zip(new_stats, state.statistics):
|
798 |
-
new_stats_accumulators.append(w1 * stat_accumulator +
|
799 |
-
|
|
|
800 |
|
801 |
if statistics_compute_steps > 1:
|
802 |
perform_step = step % statistics_compute_steps == 0
|
@@ -810,164 +1016,375 @@ def distributed_shampoo(learning_rate,
|
|
810 |
state.preconditioners, state.diagonal_momentum,
|
811 |
state.momentum)
|
812 |
|
813 |
-
def
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
836 |
num_statistics = len(statistics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
837 |
|
838 |
-
if
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
n = len(statistics)
|
857 |
-
b = int(n / num_devices)
|
858 |
-
batched_statistics = [
|
859 |
-
jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
|
860 |
-
]
|
861 |
-
batched_exponents = [
|
862 |
-
jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
|
863 |
-
]
|
864 |
-
return jnp.stack(batched_statistics), jnp.stack(batched_exponents)
|
865 |
-
|
866 |
-
# Unbatch values across leading axis and return a list of elements.
|
867 |
-
def _unbatch(batched_values):
|
868 |
-
b1, b2 = batched_values.shape[0], batched_values.shape[1]
|
869 |
-
results = []
|
870 |
-
for v_array in jnp.split(
|
871 |
-
batched_values, indices_or_sections=b1, axis=0):
|
872 |
-
v_array = jnp.squeeze(v_array)
|
873 |
-
# b2 = batches (number of preconditioner computation) per core.
|
874 |
-
if b2 > 1:
|
875 |
-
for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
|
876 |
-
results.append(jnp.squeeze(v))
|
877 |
-
else:
|
878 |
-
results.append(v_array)
|
879 |
-
return results
|
880 |
-
|
881 |
-
all_statistics, all_exponents = _batch(packed_statistics, exponents,
|
882 |
-
num_devices)
|
883 |
else:
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
]
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
|
894 |
-
def
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
precision=precision)
|
899 |
-
preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
|
900 |
-
return preconditioners, errors
|
901 |
|
902 |
-
def
|
903 |
-
|
904 |
-
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
lambda x, y: (x, y),
|
915 |
-
in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
|
916 |
-
pjit.PartitionSpec(mesh_axis_names_tuple,)),
|
917 |
-
out_axis_resources=(None, None))(partitioned_preconditioners,
|
918 |
-
partitioned_errors)
|
919 |
-
return preconditioners, errors
|
920 |
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
return [
|
928 |
-
jnp.squeeze(v) for v in jnp.split(
|
929 |
-
batched_values, indices_or_sections=b1, axis=0)
|
930 |
-
]
|
931 |
-
|
932 |
-
return split(preconditioners), split(errors)
|
933 |
-
|
934 |
-
if preconditioning_compute_steps == 1:
|
935 |
-
preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
|
936 |
else:
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
946 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
|
948 |
-
|
949 |
-
|
950 |
-
|
951 |
-
|
952 |
-
|
953 |
-
|
954 |
-
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
961 |
else:
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
971 |
|
972 |
def _skip(error):
|
973 |
condition = jnp.logical_or(
|
@@ -1008,14 +1425,70 @@ def distributed_shampoo(learning_rate,
|
|
1008 |
|
1009 |
return new_states
|
1010 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1011 |
def _transform_grad(grad, state, param, step):
|
1012 |
"""Transform per-parameter gradients."""
|
1013 |
preconditioner = Preconditioner(param, block_size,
|
1014 |
best_effort_shape_interpretation)
|
1015 |
sgd_update = grad
|
1016 |
-
new_diagonal_statistics = state.diagonal_statistics
|
1017 |
if graft_type == GraftingType.ADAGRAD:
|
1018 |
-
new_diagonal_statistics = state.diagonal_statistics
|
|
|
1019 |
adagrad_update = grad / (
|
1020 |
jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
|
1021 |
grafting_update = adagrad_update
|
@@ -1030,7 +1503,8 @@ def distributed_shampoo(learning_rate,
|
|
1030 |
w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
|
1031 |
|
1032 |
new_diagonal_statistics = (
|
1033 |
-
w1 * state.diagonal_statistics
|
|
|
1034 |
rmsprop_update = scaled_grad / (
|
1035 |
jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
|
1036 |
|
@@ -1047,8 +1521,9 @@ def distributed_shampoo(learning_rate,
|
|
1047 |
|
1048 |
precond_grad = grad
|
1049 |
if not _skip_preconditioning(param):
|
1050 |
-
precond_grad = preconditioner.preconditioned_grad(
|
1051 |
-
|
|
|
1052 |
else:
|
1053 |
precond_grad = grafting_update
|
1054 |
|
@@ -1066,9 +1541,10 @@ def distributed_shampoo(learning_rate,
|
|
1066 |
|
1067 |
w = (1.0 - beta1) if moving_average_for_momentum else 1.0
|
1068 |
shampoo_update_with_wd_momentum = (
|
1069 |
-
state.momentum * beta1 + w * shampoo_update_with_wd)
|
1070 |
grafting_update_with_wd_momentum = (
|
1071 |
-
state.diagonal_momentum * beta1 +
|
|
|
1072 |
|
1073 |
run_shampoo = (step >= start_preconditioning_step).astype(
|
1074 |
grafting_update_with_wd_momentum.dtype)
|
@@ -1089,10 +1565,11 @@ def distributed_shampoo(learning_rate,
|
|
1089 |
lr = learning_rate(step)
|
1090 |
transformed_update = -1.0 * lr * momentum_update
|
1091 |
|
1092 |
-
param_stats = ParameterStats(
|
1093 |
-
|
1094 |
-
|
1095 |
-
|
|
|
1096 |
return transformed_update, param_stats
|
1097 |
|
1098 |
def update_fn(grads, state, params):
|
|
|
33 |
import enum
|
34 |
import functools
|
35 |
import itertools
|
36 |
+
from typing import Any, List, NamedTuple
|
37 |
|
38 |
import chex
|
39 |
from flax import struct
|
|
|
46 |
|
47 |
|
48 |
# pylint:disable=no-value-for-parameter
|
49 |
+
@struct.dataclass
|
50 |
+
class QuantizedValue:
|
51 |
+
"""State associated with quantized value."""
|
52 |
+
quantized: chex.Array
|
53 |
+
diagonal: chex.Array # Diagonal (if extract_diagonal is set)
|
54 |
+
bucket_size: chex.Array
|
55 |
+
quantized_dtype: jnp.dtype = struct.field(
|
56 |
+
pytree_node=False) # Dtype for the quantized value.
|
57 |
+
extract_diagonal: bool = struct.field(
|
58 |
+
pytree_node=False) # In case its centered.
|
59 |
+
shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
|
60 |
+
|
61 |
+
@classmethod
|
62 |
+
def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
63 |
+
if isinstance(fvalue, list) and not fvalue:
|
64 |
+
return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
|
65 |
+
quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
|
66 |
+
fvalue, quantized_dtype, extract_diagonal)
|
67 |
+
return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
|
68 |
+
quantized_dtype, extract_diagonal,
|
69 |
+
list(quantized.shape))
|
70 |
+
|
71 |
+
# Quantization is from Lingvo JAX optimizers.
|
72 |
+
# We extend it for int16 quantization of PSD matrices.
|
73 |
+
@classmethod
|
74 |
+
def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
|
75 |
+
"""Returns quantized value and the bucket."""
|
76 |
+
if quantized_dtype == jnp.float32:
|
77 |
+
return fvalue, [], []
|
78 |
+
elif quantized_dtype == jnp.bfloat16:
|
79 |
+
return fvalue.astype(jnp.bfloat16), [], []
|
80 |
+
|
81 |
+
float_dtype = fvalue.dtype
|
82 |
+
if quantized_dtype == jnp.int8:
|
83 |
+
# value -128 is not used.
|
84 |
+
num_buckets = jnp.array(127.0, dtype=float_dtype)
|
85 |
+
elif quantized_dtype == jnp.int16:
|
86 |
+
# value -32768 is not used.
|
87 |
+
num_buckets = jnp.array(32767.0, dtype=float_dtype)
|
88 |
+
else:
|
89 |
+
raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
|
90 |
+
# max value is mapped to num_buckets
|
91 |
+
|
92 |
+
if extract_diagonal and fvalue.ndim != 2:
|
93 |
+
raise ValueError(
|
94 |
+
f'Input array {fvalue} must be 2D to work with extract_diagonal.')
|
95 |
+
|
96 |
+
diagonal_fvalue = []
|
97 |
+
if extract_diagonal:
|
98 |
+
diagonal_fvalue = jnp.diag(fvalue)
|
99 |
+
# Remove the diagonal entries.
|
100 |
+
fvalue = fvalue - jnp.diag(diagonal_fvalue)
|
101 |
+
|
102 |
+
# TODO(rohananil): Extend this by making use of information about the blocks
|
103 |
+
# SM3 style which will be useful for diagonal statistics
|
104 |
+
# We first decide the scale.
|
105 |
+
if fvalue.ndim < 1:
|
106 |
+
raise ValueError(
|
107 |
+
f'Input array {fvalue} must have a strictly positive number of '
|
108 |
+
'dimensions.')
|
109 |
+
|
110 |
+
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
|
111 |
+
bucket_size = max_abs / num_buckets
|
112 |
+
bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
|
113 |
+
# To avoid divide by 0.0
|
114 |
+
bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
|
115 |
+
jnp.ones_like(bs_expanded))
|
116 |
+
ratio = fvalue / bs_nonzero
|
117 |
+
# We use rounding to remove bias.
|
118 |
+
quantized = jnp.round(ratio)
|
119 |
+
return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
|
120 |
+
|
121 |
+
def to_float(self):
|
122 |
+
"""Returns the float value."""
|
123 |
+
if isinstance(self.quantized, list) and not self.quantized:
|
124 |
+
return self.quantized
|
125 |
+
|
126 |
+
if self.quantized_dtype == jnp.float32:
|
127 |
+
return self.quantized
|
128 |
+
|
129 |
+
if self.quantized_dtype == jnp.bfloat16:
|
130 |
+
return self.quantized.astype(jnp.float32)
|
131 |
+
|
132 |
+
float_dtype = self.bucket_size.dtype
|
133 |
+
bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
|
134 |
+
val = self.quantized.astype(float_dtype) * bucket_size
|
135 |
+
if self.extract_diagonal:
|
136 |
+
val += jnp.diag(self.diagonal)
|
137 |
+
return val
|
138 |
|
139 |
|
140 |
# Per parameter optimizer state used in data-parallel training.
|
141 |
class ParameterStats(NamedTuple):
|
142 |
"""State associated to each parameter of the model being trained."""
|
143 |
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
|
144 |
+
statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
|
145 |
+
preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
|
146 |
+
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
147 |
+
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
148 |
|
149 |
|
150 |
# For training extremely large model; We keep a global state with a concatenated
|
|
|
162 |
@struct.dataclass
|
163 |
class LocalShardedParameterStats:
|
164 |
"""State associated to each parameter of the model being trained."""
|
165 |
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
|
166 |
+
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
167 |
+
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
168 |
index_start: np.int32 = struct.field(
|
169 |
pytree_node=False) # Index into global statistics array
|
170 |
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
|
|
|
230 |
jnp.greater(jnp.abs(s_new - s), error_tolerance))
|
231 |
|
232 |
# Figure out how to use step as seed for random.
|
233 |
+
v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
|
234 |
+
matrix_size).astype(matrix.dtype)
|
235 |
|
236 |
init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
|
237 |
_, v_out, s_out, _, _ = lax.while_loop(
|
|
|
413 |
return mat
|
414 |
|
415 |
|
416 |
+
def pad_vector(vec, max_size):
|
417 |
+
"""Pad a vector to a max_size.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
vec: a vector to pad.
|
421 |
+
max_size: matrix size requested.
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
Given V returns [V, 0]
|
425 |
+
"""
|
426 |
+
size = vec.shape[0]
|
427 |
+
assert size <= max_size
|
428 |
+
if size == max_size:
|
429 |
+
return vec
|
430 |
+
pad_size = max_size - size
|
431 |
+
zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
|
432 |
+
return jnp.concatenate([vec, zs1], 0)
|
433 |
+
|
434 |
+
|
435 |
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
|
436 |
"""Avoids wasteful buffer allocation with XLA."""
|
437 |
|
|
|
601 |
local_stats.index_start, local_stats.sizes)
|
602 |
|
603 |
|
604 |
+
def batch(x, num_devices):
|
605 |
+
"""Batch `x` so that so that leading axis is num_devices."""
|
606 |
+
n = len(x)
|
607 |
+
b = int(n / num_devices)
|
608 |
+
return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
|
609 |
+
|
610 |
+
|
611 |
+
def unbatch(batched_values):
|
612 |
+
"""Unbatch values across leading axis and return a list of elements."""
|
613 |
+
b1, b2 = batched_values.shape[0], batched_values.shape[1]
|
614 |
+
results = []
|
615 |
+
for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
|
616 |
+
v_array = jnp.squeeze(v_array)
|
617 |
+
# b2 = batches (number of preconditioner computation) per core.
|
618 |
+
if b2 > 1:
|
619 |
+
for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
|
620 |
+
results.append(jnp.squeeze(v))
|
621 |
+
else:
|
622 |
+
results.append(v_array)
|
623 |
+
return results
|
624 |
+
|
625 |
+
|
626 |
+
def distributed_shampoo(
|
627 |
+
learning_rate,
|
628 |
+
block_size,
|
629 |
+
beta1=0.9,
|
630 |
+
beta2=0.999,
|
631 |
+
diagonal_epsilon=1e-10,
|
632 |
+
matrix_epsilon=1e-6,
|
633 |
+
weight_decay=0.0,
|
634 |
+
start_preconditioning_step=5,
|
635 |
+
preconditioning_compute_steps=1,
|
636 |
+
statistics_compute_steps=1,
|
637 |
+
best_effort_shape_interpretation=True,
|
638 |
+
graft_type=GraftingType.SGD,
|
639 |
+
nesterov=True,
|
640 |
+
exponent_override=0,
|
641 |
+
# Pass pmap 'batch axis name' in pmap mode.
|
642 |
+
batch_axis_name=None,
|
643 |
+
### Only set following 3 params in pjit/spmd mode.
|
644 |
+
### WARNING: Experimental
|
645 |
+
mesh_axis_names=None,
|
646 |
+
num_devices_for_pjit=None,
|
647 |
+
shard_optimizer_states=False,
|
648 |
+
###
|
649 |
+
### Experimental memory reduction mode
|
650 |
+
best_effort_memory_usage_reduction=False,
|
651 |
+
###
|
652 |
+
inverse_failure_threshold=0.1,
|
653 |
+
moving_average_for_momentum=False,
|
654 |
+
skip_preconditioning_dim_size_gt=4096,
|
655 |
+
clip_by_scaled_gradient_norm=None,
|
656 |
+
precision=lax.Precision.HIGHEST):
|
657 |
"""Distributed Shampoo optimizer.
|
658 |
|
659 |
Distributed Shampoo is a second-order preconditioned method (concretely, a
|
|
|
705 |
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
|
706 |
shard_optimizer_states: Shard optimizer states to save memory in model
|
707 |
parallel training.
|
708 |
+
best_effort_memory_usage_reduction: Best effort memory usage reduction.
|
709 |
+
diagonal_statistics -> jnp.bfloat16
|
710 |
+
momentum buffers (2x) -> jnp.int8
|
711 |
+
statistics, preconditioners -> jnp.int16 + diagonals
|
712 |
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
|
713 |
determine that using this threshold.
|
714 |
moving_average_for_momentum: Whether to use moving average for momentum
|
|
|
726 |
a GradientTransformation.
|
727 |
"""
|
728 |
|
729 |
+
def quantized_dtype_for_momentum_buffers():
|
730 |
+
return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
|
731 |
+
|
732 |
+
# TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
|
733 |
+
def quantized_dtype_for_diagonal_statistics_buffers():
|
734 |
+
return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
|
735 |
+
|
736 |
+
# Preconditioner and statistics are both stores as int16 in this mode.
|
737 |
+
# We take out the diagonal to make quantization easier.
|
738 |
+
def quantized_dtype_for_second_moment_statistics_buffers():
|
739 |
+
return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
|
740 |
+
|
741 |
+
# Preconditioner and statistics are both stores as int16 in this mode.
|
742 |
+
# We take out the diagonal to make quantization easier.
|
743 |
+
def quantized_dtype_for_second_moment_preconditioner_buffers():
|
744 |
+
return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
|
745 |
+
|
746 |
+
def _to_float(maybe_quantized):
|
747 |
+
if isinstance(maybe_quantized, QuantizedValue):
|
748 |
+
return maybe_quantized.to_float()
|
749 |
+
else:
|
750 |
+
return maybe_quantized
|
751 |
+
|
752 |
+
def _maybe_quantize_statistics(statistics_list):
|
753 |
+
return _maybe_quantize_matrices_with_dtype(
|
754 |
+
statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
|
755 |
+
|
756 |
+
def _maybe_quantize_preconditioners(statistics_list):
|
757 |
+
return _maybe_quantize_matrices_with_dtype(
|
758 |
+
statistics_list,
|
759 |
+
quantized_dtype_for_second_moment_preconditioner_buffers())
|
760 |
+
|
761 |
+
def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
|
762 |
+
if quantized_dtype != jnp.float32:
|
763 |
+
return ([
|
764 |
+
QuantizedValue.from_float_value(
|
765 |
+
s, quantized_dtype, extract_diagonal=True)
|
766 |
+
for s in statistics_list
|
767 |
+
])
|
768 |
+
else:
|
769 |
+
return statistics_list
|
770 |
+
|
771 |
+
def _maybe_dequantize_preconditioners(preconditioner_list):
|
772 |
+
return _maybe_dequantize_matrices_with_dtype(
|
773 |
+
preconditioner_list,
|
774 |
+
quantized_dtype_for_second_moment_preconditioner_buffers())
|
775 |
+
|
776 |
+
def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
|
777 |
+
if quantized_dtype != jnp.float32:
|
778 |
+
return [s.to_float() for s in statistics_list]
|
779 |
+
else:
|
780 |
+
return statistics_list
|
781 |
+
|
782 |
+
def _quantize_diagonal_statistics(diagonal_statistics):
|
783 |
+
return QuantizedValue.from_float_value(
|
784 |
+
diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
|
785 |
+
|
786 |
+
def _quantize_momentum(momentum_statistics):
|
787 |
+
return QuantizedValue.from_float_value(
|
788 |
+
momentum_statistics, quantized_dtype_for_momentum_buffers())
|
789 |
+
|
790 |
def sharded_init_fn(params):
|
791 |
params_flat, treedef = jax.tree_flatten(params)
|
792 |
# Find max size to pad to.
|
|
|
819 |
padded_statistics.extend(statistics)
|
820 |
padded_preconditioners.extend(preconditioners)
|
821 |
|
822 |
+
diagonal_statistics = []
|
823 |
if graft_type != GraftingType.SGD:
|
824 |
+
diagonal_statistics = jnp.zeros_like(param)
|
825 |
local_stats_flat.append(
|
826 |
+
LocalShardedParameterStats(
|
827 |
+
_quantize_diagonal_statistics(diagonal_statistics),
|
828 |
+
_quantize_momentum(jnp.zeros_like(param)),
|
829 |
+
_quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
|
830 |
|
831 |
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
832 |
# Pad the statistics and preconditioner matrices to be a multiple of
|
|
|
971 |
statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
|
972 |
preconditioners = [jnp.eye(s[0]) for s in shapes]
|
973 |
|
974 |
+
diagonal_statistics = []
|
975 |
if graft_type != GraftingType.SGD:
|
976 |
+
diagonal_statistics = jnp.zeros_like(param)
|
977 |
+
return ParameterStats(
|
978 |
+
_quantize_diagonal_statistics(diagonal_statistics),
|
979 |
+
_maybe_quantize_statistics(statistics),
|
980 |
+
_maybe_quantize_preconditioners(preconditioners),
|
981 |
+
_quantize_momentum(jnp.zeros_like(param)),
|
982 |
+
_quantize_momentum(jnp.zeros_like(param)))
|
983 |
return ShampooState(
|
984 |
count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
|
985 |
|
|
|
1000 |
new_stats = preconditioner.statistics_from_grad(grad)
|
1001 |
new_stats_accumulators = []
|
1002 |
for stat, stat_accumulator in zip(new_stats, state.statistics):
|
1003 |
+
new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
|
1004 |
+
w2 * stat)
|
1005 |
+
return _maybe_quantize_statistics(new_stats_accumulators)
|
1006 |
|
1007 |
if statistics_compute_steps > 1:
|
1008 |
perform_step = step % statistics_compute_steps == 0
|
|
|
1016 |
state.preconditioners, state.diagonal_momentum,
|
1017 |
state.momentum)
|
1018 |
|
1019 |
+
def _matrix_inverse_pth_root_vmap(xs, ps):
|
1020 |
+
mi_pth_root = functools.partial(
|
1021 |
+
matrix_inverse_pth_root,
|
1022 |
+
ridge_epsilon=matrix_epsilon,
|
1023 |
+
precision=precision)
|
1024 |
+
return jax.vmap(mi_pth_root)(xs, ps)
|
1025 |
+
|
1026 |
+
def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
|
1027 |
+
|
1028 |
+
def _quantized_to_float(qx, qd, qb):
|
1029 |
+
qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
|
1030 |
+
return qv.to_float()
|
1031 |
+
|
1032 |
+
def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
|
1033 |
+
v = _quantized_to_float(qx, qd, qb)
|
1034 |
+
preconditioner, error = matrix_inverse_pth_root(
|
1035 |
+
v, p, ridge_epsilon=matrix_epsilon, precision=precision)
|
1036 |
+
qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
|
1037 |
+
return qp.quantized, qp.diagonal, qp.bucket_size, error
|
1038 |
+
|
1039 |
+
return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
|
1040 |
+
|
1041 |
+
def _matrix_inverse_pth_root_pjit(xs, ps):
|
1042 |
+
mesh_axis_names_tuple = tuple(mesh_axis_names)
|
1043 |
+
# Partition the concatenated statistics matrix across all cores.
|
1044 |
+
partitioned_xs, partitioned_ps = pjit.pjit(
|
1045 |
+
lambda x, y: (x, y),
|
1046 |
+
in_axis_resources=None,
|
1047 |
+
out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
|
1048 |
+
# Run matrix inverse pth root on each shard.
|
1049 |
+
partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
|
1050 |
+
partitioned_xs, partitioned_ps)
|
1051 |
+
# Recombine the outputs at each core.
|
1052 |
+
preconditioners, errors = pjit.pjit(
|
1053 |
+
lambda x, y: (x, y),
|
1054 |
+
in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
|
1055 |
+
pjit.PartitionSpec(mesh_axis_names_tuple,)),
|
1056 |
+
out_axis_resources=(None, None))(partitioned_preconditioners,
|
1057 |
+
partitioned_errors)
|
1058 |
+
return preconditioners, errors
|
1059 |
+
|
1060 |
+
def _pmap_compute_preconditioners(states, step, statistics,
|
1061 |
+
num_statistics_per_state, original_shapes,
|
1062 |
+
exponents, max_size, prev_preconditioners):
|
1063 |
+
"""Computes preconditioners for given statistics in states in PMAP mode.
|
1064 |
+
|
1065 |
+
Args:
|
1066 |
+
states: A list of optimizer states.
|
1067 |
+
step: Current step number
|
1068 |
+
statistics: A list of statistics for all variables (for every dim)
|
1069 |
+
num_statistics_per_state: Number of statistis per state to reconstruct
|
1070 |
+
output states.
|
1071 |
+
original_shapes: A list of shapes of the statistics.
|
1072 |
+
exponents: Exponent power to use for inverse-pth roots.
|
1073 |
+
max_size: Maximum dim of the statistics to pad.
|
1074 |
+
prev_preconditioners: Previously available preconditioner.
|
1075 |
+
|
1076 |
+
Returns:
|
1077 |
+
New optimizer states after computing the preconditioner.
|
1078 |
+
"""
|
1079 |
+
num_devices = lax.psum(1, batch_axis_name)
|
1080 |
num_statistics = len(statistics)
|
1081 |
+
# Pad statistics and exponents to next multiple of num_devices.
|
1082 |
+
packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
|
1083 |
+
to_pad = -num_statistics % num_devices
|
1084 |
+
packed_statistics.extend([
|
1085 |
+
jnp.eye(max_size, dtype=packed_statistics[0].dtype)
|
1086 |
+
for _ in range(to_pad)
|
1087 |
+
])
|
1088 |
+
exponents.extend([1 for _ in range(to_pad)])
|
1089 |
|
1090 |
+
if not packed_statistics:
|
1091 |
+
return states
|
1092 |
+
|
1093 |
+
all_statistics = batch(packed_statistics, num_devices)
|
1094 |
+
all_exponents = batch(exponents, num_devices)
|
1095 |
+
|
1096 |
+
def _internal_inverse_pth_root_all():
|
1097 |
+
current_replica = lax.axis_index(batch_axis_name)
|
1098 |
+
preconditioners, errors = _matrix_inverse_pth_root_vmap(
|
1099 |
+
all_statistics[current_replica], all_exponents[current_replica])
|
1100 |
+
preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
|
1101 |
+
errors = jax.lax.all_gather(errors, batch_axis_name)
|
1102 |
+
preconditioners_flat = unbatch(preconditioners)
|
1103 |
+
errors_flat = unbatch(errors)
|
1104 |
+
return preconditioners_flat, errors_flat
|
1105 |
+
|
1106 |
+
if preconditioning_compute_steps == 1:
|
1107 |
+
preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1108 |
else:
|
1109 |
+
# Passing statistics instead of preconditioners as they are similarly
|
1110 |
+
# shaped tensors. Note statistics will be ignored as we are passing in
|
1111 |
+
# a large init value for error.
|
1112 |
+
preconditioners_init = packed_statistics
|
1113 |
+
errors_init = ([inverse_failure_threshold] * len(packed_statistics))
|
1114 |
+
init_state = [preconditioners_init, errors_init]
|
1115 |
+
perform_step = step % preconditioning_compute_steps == 0
|
1116 |
+
preconditioners_flat, errors_flat = efficient_cond(
|
1117 |
+
perform_step, _internal_inverse_pth_root_all, init_state)
|
1118 |
|
1119 |
+
def _skip(error):
|
1120 |
+
condition = jnp.logical_or(
|
1121 |
+
jnp.isnan(error), error >= inverse_failure_threshold)
|
1122 |
+
return condition.astype(error.dtype)
|
|
|
|
|
|
|
1123 |
|
1124 |
+
def _select_preconditioner(error, new_p, old_p):
|
1125 |
+
return lax.cond(
|
1126 |
+
_skip(error), lambda _: old_p, lambda _: new_p, operand=None)
|
1127 |
+
|
1128 |
+
new_preconditioners_flat = []
|
1129 |
+
for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
|
1130 |
+
prev_preconditioners, errors_flat):
|
1131 |
+
new_preconditioners_flat.append(
|
1132 |
+
_select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
|
1133 |
+
|
1134 |
+
assert len(states) == len(num_statistics_per_state)
|
1135 |
+
assert len(new_preconditioners_flat) == num_statistics
|
|
|
|
|
|
|
|
|
|
|
|
|
1136 |
|
1137 |
+
# Add back empty preconditioners so we that we can set the optimizer state.
|
1138 |
+
preconditioners_for_states = []
|
1139 |
+
idx = 0
|
1140 |
+
for num_statistics, state in zip(num_statistics_per_state, states):
|
1141 |
+
if num_statistics == 0:
|
1142 |
+
preconditioners_for_states.append([])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1143 |
else:
|
1144 |
+
preconditioners_for_state = new_preconditioners_flat[idx:idx +
|
1145 |
+
num_statistics]
|
1146 |
+
assert len(state.statistics) == len(preconditioners_for_state)
|
1147 |
+
preconditioners_for_states.append(preconditioners_for_state)
|
1148 |
+
idx += num_statistics
|
1149 |
+
new_states = []
|
1150 |
+
for state, new_preconditioners in zip(states, preconditioners_for_states):
|
1151 |
+
new_states.append(
|
1152 |
+
ParameterStats(state.diagonal_statistics, state.statistics,
|
1153 |
+
new_preconditioners, state.diagonal_momentum,
|
1154 |
+
state.momentum))
|
1155 |
+
|
1156 |
+
return new_states
|
1157 |
+
|
1158 |
+
def _pmap_quantized_compute_preconditioners(states, step, statistics,
|
1159 |
+
num_statistics_per_state,
|
1160 |
+
original_shapes, exponents,
|
1161 |
+
max_size, prev_preconditioners):
|
1162 |
+
"""Computes preconditioners for given statistics in states in PMAP mode.
|
1163 |
+
|
1164 |
+
For quantization, each statistic is represented by three values:
|
1165 |
+
quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
|
1166 |
+
without ever recreating the original matrix in f32.
|
1167 |
+
|
1168 |
+
Args:
|
1169 |
+
states: A list of optimizer states.
|
1170 |
+
step: Current step number
|
1171 |
+
statistics: A list of statistics for all variables (for every dim)
|
1172 |
+
num_statistics_per_state: Number of statistis per state to reconstruct
|
1173 |
+
output states.
|
1174 |
+
original_shapes: A list of shapes of the statistics.
|
1175 |
+
exponents: Exponent power to use for inverse-pth roots.
|
1176 |
+
max_size: Maximum dim of the statistics to pad.
|
1177 |
+
prev_preconditioners: Previously available preconditioner.
|
1178 |
+
|
1179 |
+
Returns:
|
1180 |
+
New optimizer states after computing the preconditioner.
|
1181 |
+
"""
|
1182 |
+
num_devices = lax.psum(1, batch_axis_name)
|
1183 |
+
num_statistics = len(statistics)
|
1184 |
+
quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
|
1185 |
+
# Complexity here is around: shapes needing be statically shaped,
|
1186 |
+
# our custom quantization type requires a different type of packing.
|
1187 |
+
|
1188 |
+
# Parallel tensors:
|
1189 |
+
# quantized [dxd]
|
1190 |
+
# diagonals [d] f32
|
1191 |
+
# bucket_sizes [d] f32
|
1192 |
+
packed_quantized_statistics = [
|
1193 |
+
pad_matrix(stat.quantized, max_size) for stat in statistics
|
1194 |
+
]
|
1195 |
+
packed_quantized_diagonals = [
|
1196 |
+
pad_vector(stat.diagonal, max_size) for stat in statistics
|
1197 |
+
]
|
1198 |
+
packed_quantized_bucket_sizes = [
|
1199 |
+
pad_vector(stat.bucket_size, max_size) for stat in statistics
|
1200 |
+
]
|
1201 |
+
|
1202 |
+
to_pad = -num_statistics % num_devices
|
1203 |
+
padded_eye = jnp.eye(max_size, dtype=jnp.float32)
|
1204 |
+
quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
|
1205 |
+
True)
|
1206 |
+
packed_quantized_statistics.extend(
|
1207 |
+
[quantized_eye.quantized for _ in range(to_pad)])
|
1208 |
+
packed_quantized_diagonals.extend(
|
1209 |
+
[quantized_eye.diagonal for _ in range(to_pad)])
|
1210 |
+
packed_quantized_bucket_sizes.extend(
|
1211 |
+
[quantized_eye.bucket_size for _ in range(to_pad)])
|
1212 |
+
exponents.extend([1 for _ in range(to_pad)])
|
1213 |
+
|
1214 |
+
if not packed_quantized_statistics:
|
1215 |
+
return states
|
1216 |
+
|
1217 |
+
all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
|
1218 |
+
all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
|
1219 |
+
all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
|
1220 |
+
num_devices)
|
1221 |
+
all_exponents = batch(exponents, num_devices)
|
1222 |
+
|
1223 |
+
def _internal_inverse_pth_root_all():
|
1224 |
+
current_replica = lax.axis_index(batch_axis_name)
|
1225 |
+
quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = (
|
1226 |
+
_quantized_matrix_inverse_pth_root_vmap(
|
1227 |
+
all_quantized_statistics[current_replica],
|
1228 |
+
all_quantized_diagonals[current_replica],
|
1229 |
+
all_quantized_bucket_sizes[current_replica],
|
1230 |
+
all_exponents[current_replica]))
|
1231 |
+
quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
|
1232 |
+
batch_axis_name)
|
1233 |
+
quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
|
1234 |
+
batch_axis_name)
|
1235 |
+
quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
|
1236 |
+
batch_axis_name)
|
1237 |
+
errors = jax.lax.all_gather(errors, batch_axis_name)
|
1238 |
+
quantized_preconditioners_flat = unbatch(quantized_preconditioners)
|
1239 |
+
quantized_diagonals_flat = unbatch(quantized_diagonals)
|
1240 |
+
quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
|
1241 |
+
errors_flat = unbatch(errors)
|
1242 |
+
return (quantized_preconditioners_flat, quantized_diagonals_flat,
|
1243 |
+
quantized_bucket_sizes_flat, errors_flat)
|
1244 |
+
|
1245 |
+
if preconditioning_compute_steps == 1:
|
1246 |
+
(quantized_preconditioners_flat, quantized_diagonals_flat,
|
1247 |
+
quantized_bucket_sizes_flat, errors_flat) = (
|
1248 |
+
_internal_inverse_pth_root_all())
|
1249 |
else:
|
1250 |
+
# Passing statistics instead of preconditioners as they are similarly
|
1251 |
+
# shaped tensors. Note statistics will be ignored as we are passing in
|
1252 |
+
# a large init value for error.
|
1253 |
+
quantized_preconditioners_init = packed_quantized_statistics
|
1254 |
+
quantized_diagonals_init = packed_quantized_diagonals
|
1255 |
+
quantized_bucket_sizes_init = packed_quantized_bucket_sizes
|
1256 |
+
errors_init = ([inverse_failure_threshold] *
|
1257 |
+
len(quantized_preconditioners_init))
|
1258 |
+
init_state = [
|
1259 |
+
quantized_preconditioners_init, quantized_diagonals_init,
|
1260 |
+
quantized_bucket_sizes_init, errors_init
|
1261 |
+
]
|
1262 |
+
perform_step = step % preconditioning_compute_steps == 0
|
1263 |
+
(quantized_preconditioners_flat, quantized_diagonals_flat,
|
1264 |
+
quantized_bucket_sizes_flat, errors_flat) = (
|
1265 |
+
efficient_cond(perform_step, _internal_inverse_pth_root_all,
|
1266 |
+
init_state))
|
1267 |
|
1268 |
+
def _skip(error):
|
1269 |
+
condition = jnp.logical_or(
|
1270 |
+
jnp.isnan(error), error >= inverse_failure_threshold)
|
1271 |
+
return condition.astype(error.dtype)
|
1272 |
+
|
1273 |
+
def _select_preconditioner(error, new_p, old_p):
|
1274 |
+
return lax.cond(
|
1275 |
+
_skip(error), lambda _: old_p, lambda _: new_p, operand=None)
|
1276 |
+
|
1277 |
+
new_quantized_preconditioners_flat = []
|
1278 |
+
new_quantized_diagonals_flat = []
|
1279 |
+
new_quantized_bucket_sizes_flat = []
|
1280 |
+
for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
|
1281 |
+
quantized_diagonals_flat,
|
1282 |
+
quantized_bucket_sizes_flat,
|
1283 |
+
original_shapes,
|
1284 |
+
prev_preconditioners, errors_flat):
|
1285 |
+
new_quantized_preconditioners_flat.append(
|
1286 |
+
_select_preconditioner(error, p[:shape[0], :shape[1]],
|
1287 |
+
prev_p.quantized))
|
1288 |
+
new_quantized_diagonals_flat.append(
|
1289 |
+
_select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
|
1290 |
+
new_quantized_bucket_sizes_flat.append(
|
1291 |
+
_select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
|
1292 |
+
|
1293 |
+
assert len(states) == len(num_statistics_per_state)
|
1294 |
+
assert len(new_quantized_preconditioners_flat) == num_statistics
|
1295 |
+
assert len(new_quantized_diagonals_flat) == num_statistics
|
1296 |
+
assert len(new_quantized_bucket_sizes_flat) == num_statistics
|
1297 |
+
|
1298 |
+
# Add back empty preconditioners so we that we can set the optimizer state.
|
1299 |
+
preconditioners_for_states = []
|
1300 |
+
idx = 0
|
1301 |
+
for num_statistics, state in zip(num_statistics_per_state, states):
|
1302 |
+
if num_statistics == 0:
|
1303 |
+
preconditioners_for_states.append([])
|
1304 |
else:
|
1305 |
+
quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
|
1306 |
+
idx:idx + num_statistics]
|
1307 |
+
quantized_diagonals_for_state = new_quantized_diagonals_flat[
|
1308 |
+
idx:idx + num_statistics]
|
1309 |
+
quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
|
1310 |
+
idx:idx + num_statistics]
|
1311 |
+
|
1312 |
+
assert len(state.statistics) == len(quantized_preconditioners_for_state)
|
1313 |
+
assert len(state.statistics) == len(quantized_diagonals_for_state)
|
1314 |
+
assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
|
1315 |
+
|
1316 |
+
quantized_preconditioners = []
|
1317 |
+
for qv, qd, qb in zip(quantized_preconditioners_for_state,
|
1318 |
+
quantized_diagonals_for_state,
|
1319 |
+
quantized_bucket_sizes_for_state):
|
1320 |
+
quantized_preconditioners.append(
|
1321 |
+
QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
|
1322 |
+
preconditioners_for_states.append(quantized_preconditioners)
|
1323 |
+
idx += num_statistics
|
1324 |
+
new_states = []
|
1325 |
+
for state, new_preconditioners in zip(states, preconditioners_for_states):
|
1326 |
+
new_states.append(
|
1327 |
+
ParameterStats(state.diagonal_statistics, state.statistics,
|
1328 |
+
new_preconditioners, state.diagonal_momentum,
|
1329 |
+
state.momentum))
|
1330 |
+
|
1331 |
+
return new_states
|
1332 |
+
|
1333 |
+
def _pjit_compute_preconditioners(states, step, statistics,
|
1334 |
+
num_statistics_per_state, original_shapes,
|
1335 |
+
exponents, max_size, prev_preconditioners):
|
1336 |
+
"""Computes preconditioners for given statistics in states in PJIT mode.
|
1337 |
+
|
1338 |
+
Args:
|
1339 |
+
states: A list of optimizer states.
|
1340 |
+
step: Current step number
|
1341 |
+
statistics: A list of statistics for all variables (for every dim)
|
1342 |
+
num_statistics_per_state: Number of statistis per state to reconstruct
|
1343 |
+
output states.
|
1344 |
+
original_shapes: A list of shapes of the statistics.
|
1345 |
+
exponents: Exponent power to use for inverse-pth roots.
|
1346 |
+
max_size: Maximum dim of the statistics to pad.
|
1347 |
+
prev_preconditioners: Previously available preconditioner.
|
1348 |
+
|
1349 |
+
Returns:
|
1350 |
+
New optimizer states after computing the preconditioner.
|
1351 |
+
"""
|
1352 |
+
num_statistics = len(statistics)
|
1353 |
+
to_pad = -num_statistics % num_devices_for_pjit
|
1354 |
+
padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
|
1355 |
+
padded_statistics.extend([
|
1356 |
+
jnp.eye(max_size, dtype=padded_statistics[0].dtype)
|
1357 |
+
for _ in range(to_pad)
|
1358 |
+
])
|
1359 |
+
exponents.extend([1 for _ in range(to_pad)])
|
1360 |
+
all_statistics = jnp.stack(padded_statistics)
|
1361 |
+
all_exponents = jnp.stack(exponents)
|
1362 |
+
|
1363 |
+
def _internal_inverse_pth_root_all():
|
1364 |
+
preconditioners, errors = _matrix_inverse_pth_root_pjit(
|
1365 |
+
all_statistics, all_exponents)
|
1366 |
+
b1 = preconditioners.shape[0]
|
1367 |
+
|
1368 |
+
def split(batched_values):
|
1369 |
+
return [
|
1370 |
+
jnp.squeeze(v)
|
1371 |
+
for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
|
1372 |
+
]
|
1373 |
+
|
1374 |
+
return split(preconditioners), split(errors)
|
1375 |
+
|
1376 |
+
if preconditioning_compute_steps == 1:
|
1377 |
+
preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
|
1378 |
+
else:
|
1379 |
+
# Passing statistics instead of preconditioners as they are similarly
|
1380 |
+
# shaped tensors. Note statistics will be ignored as we are passing in
|
1381 |
+
# a large init value for error.
|
1382 |
+
preconditioners_init = padded_statistics
|
1383 |
+
errors_init = [inverse_failure_threshold] * len(padded_statistics)
|
1384 |
+
init_state = [preconditioners_init, errors_init]
|
1385 |
+
perform_step = step % preconditioning_compute_steps == 0
|
1386 |
+
preconditioners_flat, errors_flat = efficient_cond(
|
1387 |
+
perform_step, _internal_inverse_pth_root_all, init_state)
|
1388 |
|
1389 |
def _skip(error):
|
1390 |
condition = jnp.logical_or(
|
|
|
1425 |
|
1426 |
return new_states
|
1427 |
|
1428 |
+
def _compute_preconditioners(states, params, step):
|
1429 |
+
"""Computes preconditioners for given statistics in states.
|
1430 |
+
|
1431 |
+
Args:
|
1432 |
+
states: A list of optimizer states.
|
1433 |
+
params: A list of params.
|
1434 |
+
step: Current step number
|
1435 |
+
|
1436 |
+
Returns:
|
1437 |
+
New optimizer states after computing the preconditioner.
|
1438 |
+
"""
|
1439 |
+
statistics = []
|
1440 |
+
num_statistics_per_state = []
|
1441 |
+
original_shapes = []
|
1442 |
+
exponents = []
|
1443 |
+
max_size = 0
|
1444 |
+
prev_preconditioners = []
|
1445 |
+
|
1446 |
+
for state, param in zip(states, params):
|
1447 |
+
num_statistics = len(state.statistics)
|
1448 |
+
num_statistics_per_state.append(num_statistics)
|
1449 |
+
original_shapes_for_state = []
|
1450 |
+
if num_statistics > 0:
|
1451 |
+
preconditioner = Preconditioner(param, block_size,
|
1452 |
+
best_effort_shape_interpretation)
|
1453 |
+
for statistic in state.statistics:
|
1454 |
+
exponents.append(preconditioner.exponent_for_preconditioner(
|
1455 |
+
) if exponent_override == 0 else exponent_override)
|
1456 |
+
original_shapes_for_state.append(statistic.shape)
|
1457 |
+
max_size = max(max_size, statistic.shape[0])
|
1458 |
+
|
1459 |
+
statistics.extend(state.statistics)
|
1460 |
+
prev_preconditioners.extend(state.preconditioners)
|
1461 |
+
original_shapes.extend(original_shapes_for_state)
|
1462 |
+
|
1463 |
+
if batch_axis_name:
|
1464 |
+
# Quantization is only enabled if batch_axis_name is not set.
|
1465 |
+
quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
|
1466 |
+
|
1467 |
+
if quantized_dtype == jnp.float32:
|
1468 |
+
return _pmap_compute_preconditioners(states, step, statistics,
|
1469 |
+
num_statistics_per_state,
|
1470 |
+
original_shapes, exponents,
|
1471 |
+
max_size, prev_preconditioners)
|
1472 |
+
else:
|
1473 |
+
return _pmap_quantized_compute_preconditioners(
|
1474 |
+
states, step, statistics, num_statistics_per_state, original_shapes,
|
1475 |
+
exponents, max_size, prev_preconditioners)
|
1476 |
+
|
1477 |
+
else:
|
1478 |
+
return _pjit_compute_preconditioners(states, step, statistics,
|
1479 |
+
num_statistics_per_state,
|
1480 |
+
original_shapes, exponents, max_size,
|
1481 |
+
prev_preconditioners)
|
1482 |
+
|
1483 |
def _transform_grad(grad, state, param, step):
|
1484 |
"""Transform per-parameter gradients."""
|
1485 |
preconditioner = Preconditioner(param, block_size,
|
1486 |
best_effort_shape_interpretation)
|
1487 |
sgd_update = grad
|
1488 |
+
new_diagonal_statistics = state.diagonal_statistics.to_float()
|
1489 |
if graft_type == GraftingType.ADAGRAD:
|
1490 |
+
new_diagonal_statistics = state.diagonal_statistics.to_float(
|
1491 |
+
) + jnp.square(grad)
|
1492 |
adagrad_update = grad / (
|
1493 |
jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
|
1494 |
grafting_update = adagrad_update
|
|
|
1503 |
w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
|
1504 |
|
1505 |
new_diagonal_statistics = (
|
1506 |
+
w1 * state.diagonal_statistics.to_float() +
|
1507 |
+
w2 * jnp.square(scaled_grad))
|
1508 |
rmsprop_update = scaled_grad / (
|
1509 |
jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
|
1510 |
|
|
|
1521 |
|
1522 |
precond_grad = grad
|
1523 |
if not _skip_preconditioning(param):
|
1524 |
+
precond_grad = preconditioner.preconditioned_grad(
|
1525 |
+
precond_grad,
|
1526 |
+
_maybe_dequantize_preconditioners(state.preconditioners))
|
1527 |
else:
|
1528 |
precond_grad = grafting_update
|
1529 |
|
|
|
1541 |
|
1542 |
w = (1.0 - beta1) if moving_average_for_momentum else 1.0
|
1543 |
shampoo_update_with_wd_momentum = (
|
1544 |
+
state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
|
1545 |
grafting_update_with_wd_momentum = (
|
1546 |
+
state.diagonal_momentum.to_float() * beta1 +
|
1547 |
+
w * grafting_update_with_wd)
|
1548 |
|
1549 |
run_shampoo = (step >= start_preconditioning_step).astype(
|
1550 |
grafting_update_with_wd_momentum.dtype)
|
|
|
1565 |
lr = learning_rate(step)
|
1566 |
transformed_update = -1.0 * lr * momentum_update
|
1567 |
|
1568 |
+
param_stats = ParameterStats(
|
1569 |
+
_quantize_diagonal_statistics(new_diagonal_statistics),
|
1570 |
+
state.statistics, state.preconditioners,
|
1571 |
+
_quantize_momentum(grafting_update_with_wd_momentum),
|
1572 |
+
_quantize_momentum(shampoo_update_with_wd_momentum))
|
1573 |
return transformed_update, param_stats
|
1574 |
|
1575 |
def update_fn(grads, state, params):
|