feat: update distributed_shampoo
tools/train/distributed_shampoo.py
CHANGED
@@ -34,13 +34,13 @@ import itertools
from typing import Any, List, NamedTuple

import chex
+from flax import struct
import jax
+from jax import lax
import jax.experimental.pjit as pjit
import jax.numpy as jnp
import numpy as np
import optax
-from flax import struct
-from jax import lax


# pylint:disable=no-value-for-parameter
@@ -234,6 +234,8 @@ class GraftingType(enum.IntEnum):
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
+    SQRT_N = 5
+    ADAGRAD_NORMALIZED = 6


def power_iteration(
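The two new grafting types set the layerwise step scale from a sign update (SQRT_N) or from AdaGrad run on the L2-normalized gradient (ADAGRAD_NORMALIZED). A minimal sketch of the directions they produce, on a toy gradient with plain JAX; nothing here is part of the file's API, and the 1e-10 merely stands in for the optimizer's diagonal epsilon:

```python
# Sketch only: the grafting directions behind the two new enum members.
import jax.numpy as jnp

grad = jnp.array([[0.3, -1.2], [2.0, -0.1]])

# SQRT_N: a unit-magnitude sign update; its L2 norm is sqrt(#elements),
# which is where the name comes from.
sqrt_n_update = jnp.ones_like(grad) * jnp.sign(grad)
print(float(jnp.linalg.norm(sqrt_n_update)))  # sqrt(4) == 2.0

# ADAGRAD_NORMALIZED: AdaGrad on the L2-normalized gradient, so the
# accumulator grows at a rate independent of the raw gradient scale.
scaled_grad = grad / jnp.linalg.norm(grad)
accumulator = jnp.square(scaled_grad)         # zero accumulator + first step
adagrad_normalized_update = scaled_grad / (jnp.sqrt(accumulator) + 1e-10)
```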
@@ -336,7 +338,7 @@ def matrix_inverse_pth_root(
    _, max_ev = power_iteration(
        matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
    )
-    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
+    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)

    def _unrolled_mat_pow_1(mat_m):
        """Computes mat_m^1."""
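The added line makes the ridge term relative to the statistics matrix's largest eigenvalue (estimated by power_iteration just above) rather than an absolute constant, with a 1e-6 floor. A hedged sketch of the idea, using a hand-rolled power iteration in place of the file's power_iteration:

```python
# Sketch only: damping a statistics matrix with a ridge term scaled by an
# estimate of its top eigenvalue, mirroring the added line.
import jax.numpy as jnp

def top_eigenvalue(mat, num_iters=100):
    """Plain power iteration; stands in for the file's power_iteration()."""
    v = jnp.ones((mat.shape[0],)) / jnp.sqrt(mat.shape[0])
    for _ in range(num_iters):
        v = mat @ v
        v = v / jnp.linalg.norm(v)
    return v @ (mat @ v)  # Rayleigh quotient ~= largest eigenvalue

statistics = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # toy second-moment matrix
ridge_epsilon = 1e-6
max_ev = top_eigenvalue(statistics)
# Relative damping: bigger statistics get a proportionally bigger ridge;
# the floor keeps the ridge positive if max_ev is ~0.
ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
damped = statistics + ridge_epsilon * jnp.eye(statistics.shape[0])
```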
@@ -791,8 +793,7 @@ def distributed_shampoo(
        block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
      graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
        optimizer. This allows us to plugin the Shampoo optimizer into settings
-        where SGD/AdaGrad is already well tuned. Available options are:
-        GraftingType.SGD and GraftingType.ADAGRAD.
+        where SGD/AdaGrad is already well tuned.
      nesterov: Nesterov momentum.
      exponent_override: Override the exponent used in matrix inverse.
      batch_axis_name: labeled axis over pmap for data-parallel training the
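For context, grafting keeps Shampoo's preconditioned direction but borrows the per-layer step size from a simpler, well-tuned optimizer. A hypothetical caller-side sketch; the keyword names mirror parameters of this file's distributed_shampoo, but the import path and the hyperparameter values are illustrative assumptions, not taken from this diff:

```python
# Sketch only: selecting a grafting type from the training script.
from distributed_shampoo import GraftingType, distributed_shampoo

optimizer = distributed_shampoo(
    learning_rate=1e-3,
    block_size=1024,
    beta1=0.9,
    beta2=0.99,
    graft_type=GraftingType.RMSPROP_NORMALIZED,  # or SQRT_N / ADAGRAD_NORMALIZED
    nesterov=False,
    batch_axis_name="batch",
)
# The returned object is an optax GradientTransformation:
#   state = optimizer.init(params)
#   updates, state = optimizer.update(grads, state, params)
```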
@@ -823,12 +824,20 @@ def distributed_shampoo(
      a GradientTransformation.
    """

+    def _graft_type_has_diagonal_statistics():
+        """Returns True if using a diagonal first-order method for grafting."""
+        return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
+
+    def _graft_type_has_diagonal_momentum_states():
+        """Returns False if using SQRT_N for grafting."""
+        return graft_type != GraftingType.SQRT_N
+
    def quantized_dtype_for_momentum_buffers():
        return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32

    # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
    def quantized_dtype_for_diagonal_statistics_buffers():
-        return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
+        return jnp.float32

    # Preconditioner and statistics are both stores as int16 in this mode.
    # We take out the diagonal to make quantization easier.
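The two new predicates centralize which per-parameter buffers a grafting type needs; the rest of this diff branches on them. A self-contained sketch of the resulting buffer layout (the enum is redeclared so the snippet runs on its own; the helper name is illustrative):

```python
# Sketch only: which optimizer buffers get materialized per grafting type,
# mirroring the two predicates added above.
import enum

class GraftingType(enum.IntEnum):
    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6

def buffers_for(graft_type):
    has_diag_stats = graft_type not in (GraftingType.SGD, GraftingType.SQRT_N)
    has_diag_momentum = graft_type != GraftingType.SQRT_N
    return {
        "diagonal_statistics": has_diag_stats,   # AdaGrad/RMSProp accumulator
        "diagonal_momentum": has_diag_momentum,  # grafting momentum buffer
        "momentum": True,                        # Shampoo momentum, always kept
    }

for gt in GraftingType:
    print(gt.name, buffers_for(gt))
# SQRT_N drops both extra buffers and reuses the Shampoo momentum
# (see the "Share the momentum buffer" branch later in this diff).
```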
@@ -944,13 +953,19 @@ def distributed_shampoo(
            exponents.extend([exponent] * len(shapes))

            diagonal_statistics = []
-            if graft_type != GraftingType.SGD:
+            if _graft_type_has_diagonal_statistics():
                diagonal_statistics = jnp.zeros_like(param)
+
+            diagonal_momentum = _quantize_momentum([])
+            momentum = _quantize_momentum(jnp.zeros_like(param))
+            if _graft_type_has_diagonal_momentum_states():
+                diagonal_momentum = _quantize_momentum((jnp.zeros_like(param)))
+
            local_stats_flat.append(
                LocalShardedParameterStats(
                    _quantize_diagonal_statistics(diagonal_statistics),
-                    _quantize_momentum(jnp.zeros_like(param)),
-                    _quantize_momentum(jnp.zeros_like(param)),
+                    diagonal_momentum,
+                    momentum,
                    init_training_metrics(len(sizes)),
                    index_start,
                    sizes,
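One practical effect of the conditional initialization above is optimizer-state memory: with SQRT_N grafting neither the diagonal statistics nor the diagonal momentum buffer is materialized. A rough, purely illustrative estimate that only counts these diagonal/momentum buffers in unquantized float32:

```python
# Sketch only: back-of-the-envelope cost of the per-parameter diagonal buffers.
# Ignores the blocked statistics/preconditioners and any quantization.
def diagonal_state_bytes(num_params, has_diag_stats, has_diag_momentum):
    buffers = int(has_diag_stats) + int(has_diag_momentum) + 1  # +1 Shampoo momentum
    return buffers * num_params * 4  # 4 bytes per float32

n = 1_000_000_000  # a 1B-parameter model
print(diagonal_state_bytes(n, True, True) / 2**30)    # ADAGRAD grafting: ~11.2 GiB
print(diagonal_state_bytes(n, False, False) / 2**30)  # SQRT_N grafting:  ~3.7 GiB
```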
@@ -1039,7 +1054,7 @@ def distributed_shampoo(

            diagonal_statistics_pspec = []
            diagonal_statistics_scale_pspec = []
-            if graft_type != GraftingType.SGD:
+            if _graft_type_has_diagonal_statistics():
                # Identically shaped param.
                diagonal_statistics_pspec = param_pspec
                if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
@@ -1047,14 +1062,16 @@ def distributed_shampoo(
                    _remove_leading_sharding_annotation(param_pspec)
                )

-            m1_pspec = param_pspec
-            m2_pspec = param_pspec
-
+            m1_pspec = []
            m1_scale_pspec = []
-            m2_scale_pspec = []
+            if _graft_type_has_diagonal_momentum_states():
+                m1_pspec = param_pspec
+                if quantized_dtype_for_momentum_buffers() != jnp.float32:
+                    m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)

+            m2_pspec = param_pspec
+            m2_scale_pspec = []
            if quantized_dtype_for_momentum_buffers() != jnp.float32:
-                m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
                m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)

            local_stats_flat.append(
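The m1/m2 scale pspecs drop the leading axis because, under quantization, the scale buffer has shape param.shape[1:] (one scale per slice along the first dimension), so the first sharding annotation no longer applies. A minimal sketch of that kind of symmetric int8 scheme; illustrative only, the file's QuantizedValue class is the real implementation:

```python
# Sketch only: why quantization-scale buffers have shape param.shape[1:]
# and therefore lose the leading sharding annotation.
import jax.numpy as jnp

param = jnp.linspace(-1.0, 1.0, 6 * 4).reshape(6, 4)

scale = jnp.max(jnp.abs(param), axis=0) / 127.0         # shape (4,) == param.shape[1:]
quantized = jnp.round(param / scale).astype(jnp.int8)   # shape (6, 4), int8 payload
dequantized = quantized.astype(jnp.float32) * scale

print(scale.shape, quantized.dtype, float(jnp.max(jnp.abs(dequantized - param))))
```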
@@ -1130,7 +1147,7 @@ def distributed_shampoo(

            diagonal_statistics_shape_and_dtype = []
            diagonal_statistics_scale_shape_and_dtype = []
-            if graft_type != GraftingType.SGD:
+            if _graft_type_has_diagonal_statistics():
                diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
                qdtype = quantized_dtype_for_diagonal_statistics_buffers()
                if qdtype != jnp.float32:
@@ -1140,18 +1157,18 @@ def distributed_shampoo(
                    param.dtype,
                ]

-            m1_shape_and_dtype = [list(param.shape), param.dtype]
-            m2_shape_and_dtype = [list(param.shape), param.dtype]
-
+            qdtype = quantized_dtype_for_momentum_buffers()
+            m1_shape_and_dtype = []
            m1_scale_shape_and_dtype = []
-            m2_scale_shape_and_dtype = []
+            if _graft_type_has_diagonal_momentum_states():
+                m1_shape_and_dtype = [list(param.shape), qdtype]
+                if quantized_dtype_for_momentum_buffers() != jnp.float32:
+                    m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]

-            qdtype = quantized_dtype_for_momentum_buffers()
+            m2_shape_and_dtype = [list(param.shape), param.dtype]
+            m2_scale_shape_and_dtype = []
            if qdtype != jnp.float32:
-                m1_shape_and_dtype = [list(param.shape), qdtype]
                m2_shape_and_dtype = [list(param.shape), qdtype]
-
-                m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
                m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]

            local_stats_flat.append(
@@ -1331,14 +1348,20 @@ def distributed_shampoo(
        preconditioners = [jnp.eye(s[0]) for s in shapes]

        diagonal_statistics = []
-        if graft_type != GraftingType.SGD:
+        if _graft_type_has_diagonal_statistics():
            diagonal_statistics = jnp.zeros_like(param)
+
+        diagonal_momentum = _quantize_momentum([])
+        momentum = _quantize_momentum(jnp.zeros_like(param))
+        if _graft_type_has_diagonal_momentum_states():
+            diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
+
        return ParameterStats(
            _quantize_diagonal_statistics(diagonal_statistics),
            _maybe_quantize_statistics(statistics),
            _maybe_quantize_preconditioners(preconditioners),
-            _quantize_momentum(jnp.zeros_like(param)),
-            _quantize_momentum(jnp.zeros_like(param)),
+            diagonal_momentum,
+            momentum,
            init_training_metrics(len(statistics)),
        )

@@ -2037,11 +2060,19 @@ def distributed_shampoo(
        )
        sgd_update = grad
        new_diagonal_statistics = state.diagonal_statistics.to_float()
-        if graft_type == GraftingType.ADAGRAD:
+        if (
+            graft_type == GraftingType.ADAGRAD
+            or graft_type == GraftingType.ADAGRAD_NORMALIZED
+        ):
+
+            scaled_grad = grad
+            if graft_type == GraftingType.ADAGRAD_NORMALIZED:
+                scaled_grad = grad / jnp.linalg.norm(grad)
+
            new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
-                grad
+                scaled_grad
            )
-            adagrad_update = grad / (
+            adagrad_update = scaled_grad / (
                jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
            )
            grafting_update = adagrad_update
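With ADAGRAD_NORMALIZED the accumulator receives a unit-norm vector each step, so its total mass grows by exactly 1 per step no matter how large or small the raw gradients are, which makes the grafted step size decay predictably. A quick numeric check, illustrative only:

```python
# Sketch only: each L2-normalized gradient adds exactly 1.0 to the total
# AdaGrad accumulator mass, independent of the raw gradient scale.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
accumulator = jnp.zeros((3, 5))
for step, scale in enumerate([1e-3, 10.0, 500.0]):
    key, sub = jax.random.split(key)
    grad = scale * jax.random.normal(sub, accumulator.shape)
    scaled_grad = grad / jnp.linalg.norm(grad)
    accumulator = accumulator + jnp.square(scaled_grad)
    print(step, float(jnp.sum(accumulator)))  # ~1.0, ~2.0, ~3.0
```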
@@ -2074,8 +2105,10 @@ def distributed_shampoo(
            rmsprop_update /= clipping_denom

            grafting_update = rmsprop_update
-        else:
+        elif graft_type == GraftingType.SGD:
            grafting_update = sgd_update
+        else:
+            grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)

        precond_grad = grad
        if not _skip_preconditioning(param):
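The new else branch is the SQRT_N grafting update: sign(grad), whose L2 norm for a tensor with N nonzero-gradient elements is sqrt(N), hence the name; grafting then uses that norm to set the layerwise scale of the Shampoo step. A quick check that the norm depends only on the shape, not on the gradient magnitude:

```python
# Sketch only: the SQRT_N grafting update has norm sqrt(#elements),
# regardless of how large the gradient entries are.
import jax.numpy as jnp

for shape in [(768,), (768, 3072), (2, 768, 2048)]:
    grad = jnp.full(shape, 1e-4)                 # tiny but nonzero everywhere
    update = jnp.ones_like(grad) * jnp.sign(grad)
    print(shape, float(jnp.linalg.norm(update)), grad.size ** 0.5)
```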
@@ -2098,12 +2131,20 @@ def distributed_shampoo(
        grafting_update_with_wd = grafting_update + weight_decay * param

        w = (1.0 - beta1) if moving_average_for_momentum else 1.0
+
        shampoo_update_with_wd_momentum = (
            state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
        )
-        grafting_update_with_wd_momentum = (
-            state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
-        )
+
+        if _graft_type_has_diagonal_momentum_states():
+            grafting_update_with_wd_momentum = (
+                state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
+            )
+        else:
+            # Share the momentum buffer
+            grafting_update_with_wd_momentum = (
+                state.momentum.to_float() * beta1 + w * grafting_update_with_wd
+            )

        run_shampoo = (step >= start_preconditioning_step).astype(
            grafting_update_with_wd_momentum.dtype
@@ -2119,20 +2160,27 @@ def distributed_shampoo(
            + (1.0 - run_shampoo) * grafting_update_with_wd
        )

+        nesterov_momentum_update = momentum_update
        if nesterov:
-            momentum_update = w * wd_update + beta1 * momentum_update
+            nesterov_momentum_update = w * wd_update + beta1 * momentum_update

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(step)
-        transformed_update = -1.0 * lr * momentum_update
+        transformed_update = -1.0 * lr * nesterov_momentum_update
+
+        new_diagonal_momentum = grafting_update_with_wd_momentum
+        new_momentum = shampoo_update_with_wd_momentum
+        if not _graft_type_has_diagonal_momentum_states():
+            new_diagonal_momentum = []
+            new_momentum = momentum_update

        param_stats = ParameterStats(
            _quantize_diagonal_statistics(new_diagonal_statistics),
            state.statistics,
            state.preconditioners,
-            _quantize_momentum(grafting_update_with_wd_momentum),
-            _quantize_momentum(shampoo_update_with_wd_momentum),
+            _quantize_momentum(new_diagonal_momentum),
+            _quantize_momentum(new_momentum),
            state.training_metrics,
        )

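The final hunk separates the Nesterov lookahead from the stored momentum buffer and decides which buffers to write back: the separate grafting momentum, or the raw blended update when SQRT_N shares the Shampoo buffer. A scalar sketch of the heavy-ball vs. Nesterov-style step as written here; wd_update stands for the blended update with weight decay from the surrounding code:

```python
# Sketch only: the momentum / Nesterov logic of the last hunk, reduced to scalars.
import jax.numpy as jnp

beta1 = 0.9
moving_average_for_momentum = False
w = (1.0 - beta1) if moving_average_for_momentum else 1.0

momentum_prev = jnp.array(0.5)  # buffer carried over from the previous step
wd_update = jnp.array(0.2)      # current blended update (with weight decay)

# This is what gets stored back into the state, unchanged by Nesterov.
momentum_update = momentum_prev * beta1 + w * wd_update

# Heavy-ball applies the buffer directly; the Nesterov branch applies one more
# lookahead step but leaves the stored buffer untouched.
nesterov = True
nesterov_momentum_update = momentum_update
if nesterov:
    nesterov_momentum_update = w * wd_update + beta1 * momentum_update

lr = 1e-3
transformed_update = -1.0 * lr * nesterov_momentum_update
print(float(transformed_update))
```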