feat: update shampoo
tools/train/scalable_shampoo/README.md
CHANGED
```diff
@@ -4,4 +4,4 @@ Files copied from [google-research/scalable_shampoo/optax](https://github.com/go
 
 Imports have been modified to be relative.
 
-This will be replaced with `optax-shampoo` package
+This will eventually be replaced with `optax-shampoo` package.
```
tools/train/scalable_shampoo/distributed_shampoo.py
CHANGED
```diff
@@ -25,13 +25,12 @@
 # Authors: Rohan Anil (rohananil at google dot com)
 # & Vineet Gupta (vineet at google dot com)
 #
-
 """Distributed Shampoo Implementation."""
 
 import enum
 import functools
 import itertools
-from typing import Any, List, NamedTuple
+from typing import Any, List, NamedTuple, Tuple
 
 import chex
 import jax
@@ -43,6 +42,7 @@ from flax import struct
 from jax import lax
 
 from .quantization_utils import QuantizedValue
+from .symmetric_matrices import symmetric_matrices
 
 # Dtype for inverse-pth root routine
 # Switch to f64 if you have hardware that supports it. Enable the jax flag
@@ -141,7 +141,10 @@ class GraftingType(enum.IntEnum):
 
 
 def power_iteration(
-    matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST
+    matrix,
+    num_iters=100,
+    error_tolerance=1e-6,
+    precision=lax.Precision.HIGHEST,
 ):
     r"""Power iteration algorithm.
 
@@ -156,10 +159,10 @@ def power_iteration(
       matrix: the symmetric PSD matrix.
       num_iters: Number of iterations.
       error_tolerance: Iterative exit condition.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
       eigen vector, eigen value
@@ -196,7 +199,11 @@ def power_iteration(
     return v_out, s_out
 
 
-def mat_power(mat_m, p, precision=lax.Precision.HIGHEST):
+def mat_power(
+    mat_m,
+    p,
+    precision=lax.Precision.HIGHEST,
+):
     """A simple matrix power method. M^p where p can be TracedValue."""
     power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
 
@@ -245,15 +252,19 @@ def matrix_inverse_pth_root(
       num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
       error_tolerance: Error indicator, useful for early termination.
-      precision: precision XLA related flag, the available options are:
-        a) lax.Precision.DEFAULT (better step time, but not precise)
-        b) lax.Precision.HIGH (increased precision, slower)
-        c) lax.Precision.HIGHEST (best possible precision, slowest)
+      precision: precision XLA related flag, the available options are: a)
+        lax.Precision.DEFAULT (better step time, but not precise) b)
+        lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+        (best possible precision, slowest)
 
     Returns:
       matrix^(-1/p)
     """
 
+    # If the input is not square, materialize it from the concatenated form.
+    if matrix.shape[0] != matrix.shape[1]:
+        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)
+
     assert matrix.shape[0] == matrix.shape[1]
 
     # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
@@ -336,8 +347,8 @@ def merge_small_dims(shape_to_merge, max_dim):
     return resulting_shape
 
 
-def pad_matrix(mat, max_size):
-    """Pad a matrix to a max_size.
+def pad_square_matrix(mat, max_size):
+    """Pad a square matrix up to max_size.
 
     Args:
       mat: a matrix to pad.
@@ -346,19 +357,132 @@ def pad_matrix(mat, max_size):
     Returns:
       Given M returns [[M, 0], [0, I]]
     """
-    size = mat.shape[0]
-    assert size <= max_size
-    if size == max_size:
+    rows, cols = mat.shape
+    if rows != cols:
+        raise ValueError(
+            "Must have rows == cols, instead got " f"rows={rows}, cols={cols}"
+        )
+    if cols > max_size:
+        raise ValueError(
+            "Must have cols <= max_size. Instead got "
+            f"cols={cols}, max_size={max_size}."
+        )
+    if rows == max_size:
         return mat
-    pad_size = max_size - size
-    zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
-    zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
+    pad_size = max_size - rows
+
+    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
+    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
     eye = jnp.eye(pad_size, dtype=mat.dtype)
     mat = jnp.concatenate([mat, zs1], 1)
     mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
     return mat
 
 
+def make_sliced_padding(
+    symmetric_block_size,
+    num_blocks,
+    starting_block,
+    dtype,
+):
+    """Returns padding for symmetric block matrix.
+
+    Specifically, the padding is given concatenated rectangular matrices
+    representing the lower-triangular rows below the starting block. For example,
+    if we want to pad the symmetric matrix
+
+    M = [[A, B^T]
+         [B, C]],
+
+    the desired output (in terms of the full matrix) with num_blocks = 4 is
+
+    M_padded = [[A, B^T, 0, 0]
+                [B, C, 0, 0]
+                [0, 0, I, 0]
+                 0, 0, 0, I].
+
+    We would represent M as the block matrix mat = [A, B, C]. In this form, the
+    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
+    triangular parts in the third and fourth rows).
+
+    Args:
+      symmetric_block_size: The size of each block.
+      num_blocks: The total number of blocks.
+      starting_block: The block where to start the padding.
+      dtype: The type to use for the blocks.
+    """
+    if starting_block == num_blocks:
+        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)
+
+    blocks = []
+    for i in range(starting_block, num_blocks):
+        blocks.append(
+            jnp.zeros(
+                shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
+            )
+        )
+        blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def pad_block_symmetric_matrix(
+    mat,
+    symmetric_block_size,
+    max_num_blocks,
+):
+    """Returns the padded blocked symmetric matrix.
+
+    The size of the padded matrix will be:
+      [symmetric_block_size, symmetric_block_size * max_num_blocks]
+
+    The input matrix can either:
+      - Be square with size less or equal to symmetric_block_size. In this case,
+        mat will first be padded to a square matrix of size symmetric_block_size,
+        and then be padded again up to the full size of the blocked matrix.
+      - Be a rectangle with number of rows equal to block size.
+        In this case, number of columns must be a multiple of number of rows, and
+        the ratio must correspond to a block representation of a symmetric matrix.
+        That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
+        number of block rows represented by the matrix.
+
+    Args:
+      mat: The input block matrix.
+      symmetric_block_size: The size of blocks.
+      max_num_blocks: The largest number of blocks to pad to.
+    """
+    rows, cols = mat.shape
+    if rows > symmetric_block_size:
+        raise ValueError(
+            "Must have rows <= symmetric_block_size. Instead got "
+            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
+        )
+    if rows > cols:
+        raise ValueError(
+            "Must have rows <= cols, instead got " f"rows={rows}, cols={cols}."
+        )
+    if cols > symmetric_block_size * max_num_blocks:
+        raise ValueError(
+            "Must have cols <= symmetric_block_size * max_num_blocks "
+            f"Instead got cols={cols}, "
+            f"symmetric_block_size={symmetric_block_size}, "
+            f"max_num_blocks={max_num_blocks}."
+        )
+    if rows < symmetric_block_size:
+        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
+    # Update rows and cols after possibly padding in pad_square_matrix.
+    rows, cols = mat.shape
+    assert rows == symmetric_block_size
+    assert cols % rows == 0
+    filled_blocks = cols // rows
+    padding_blocks = make_sliced_padding(
+        symmetric_block_size=symmetric_block_size,
+        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
+        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
+        dtype=mat.dtype,
+    )
+    return jnp.concatenate([mat, padding_blocks], axis=-1)
+
+
 def pad_vector(vec, max_size):
     """Pad a vector to a max_size.
 
```
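The padding rule introduced above is easy to check by hand. The sketch below is not part of the diff: it is a trimmed copy of the `pad_square_matrix` function added above (shape validation omitted, only `jax.numpy` assumed), showing the `[[M, 0], [0, I]]` layout it produces.

```python
# Illustrative only: trimmed copy of pad_square_matrix from the hunk above,
# with the input validation removed.
import jax.numpy as jnp


def pad_square_matrix(mat, max_size):
    # Pad M with zero off-diagonal blocks and an identity tail: [[M, 0], [0, I]].
    rows = mat.shape[0]
    pad_size = max_size - rows
    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
    eye = jnp.eye(pad_size, dtype=mat.dtype)
    mat = jnp.concatenate([mat, zs1], 1)
    return jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)


m = jnp.array([[2.0, 1.0], [1.0, 2.0]])
print(pad_square_matrix(m, 4))
# [[2. 1. 0. 0.]
#  [1. 2. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]
```

Padding with an identity tail keeps the extra rows and columns inert: the inverse pth root of the padded matrix is the padded inverse pth root of M, so batching differently sized statistics together does not change the resulting preconditioners.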
```diff
@@ -694,18 +818,17 @@ def distributed_shampoo(
       num_devices_for_pjit: Number of devices to parallelize over when using pjit.
       shard_optimizer_states: Shard optimizer states to save memory in model
         parallel training.
-      best_effort_memory_usage_reduction: Best effort memory usage reduction.
-        diagonal_statistics -> jnp.bfloat16
-        momentum buffers (2x) -> jnp.int8
+      best_effort_memory_usage_reduction: Best effort memory usage reduction. -
+        diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
         statistics, preconditioners -> jnp.int16 + diagonals
       inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
         determine that using this threshold.
       moving_average_for_momentum: Whether to use moving average for momentum
         instead of exponential moving average.
       skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
-        greater than this value.
-      clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
-        when using RMSProp Grafting).
+        greater than this value.
+      clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
+        using RMSProp Grafting).
       precision: precision XLA related flag, the available options are: a)
         lax.Precision.DEFAULT (better step time, but not precise) b)
         lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
@@ -1167,7 +1290,7 @@ def distributed_shampoo(
         new_padded_statistics = []
         for stat in new_stats_flat:
             new_padded_statistics.extend(
-                [pad_matrix(stat, max_size) for stat in stat.statistics]
+                [pad_square_matrix(stat, max_size) for stat in stat.statistics]
             )
 
         # Create global stats
@@ -1388,7 +1511,7 @@ def distributed_shampoo(
         num_devices = lax.psum(1, batch_axis_name)
         num_statistics = len(statistics)
         # Pad statistics and exponents to next multiple of num_devices.
-        packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+        packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
         to_pad = -num_statistics % num_devices
         packed_statistics.extend(
             [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
@@ -1540,7 +1663,7 @@ def distributed_shampoo(
             # diagonals [d] f32
             # bucket_sizes [d] f32
             packed_quantized_statistics = [
-                pad_matrix(stat.quantized, max_size) for stat in statistics
+                pad_square_matrix(stat.quantized, max_size) for stat in statistics
             ]
             packed_quantized_diagonals = [
                 pad_vector(stat.diagonal, max_size) for stat in statistics
@@ -1772,7 +1895,7 @@ def distributed_shampoo(
         """
         num_statistics = len(statistics)
         to_pad = -num_statistics % num_devices_for_pjit
-        padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+        padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
         padded_statistics.extend(
             [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
         )
```
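Several of the call sites above also pad the list of statistics up to the next multiple of the device count, filling the extra slots with `jnp.eye(max_size)` matrices, via the `-n % d` idiom. A small sketch of that arithmetic, independent of the diff (the variable names are local to the example):

```python
# Illustrative only: how "-num_statistics % num_devices" rounds the number of
# statistics up to the next multiple of the device count, so every device gets
# the same number of matrices when inverse pth roots are computed in parallel.
num_devices = 8
for num_statistics in (5, 8, 13):
    to_pad = -num_statistics % num_devices
    print(num_statistics, to_pad, num_statistics + to_pad)
# 5 3 8
# 8 0 8
# 13 3 16
```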
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
CHANGED
```diff
@@ -16,7 +16,7 @@
 """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
 
 import functools
-from typing import Any, List, Sequence, Union
+from typing import Any, List, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -192,7 +192,7 @@ def materialize_matrix(symmetric_matrix):
 @functools.partial(jax.jit, static_argnames=("num_blocks"))
 def materialize_matrix_from_concat(
     block_rows_concat,
-    num_blocks,
+    num_blocks=None,
 ):
     """Returns a materialized symmetric matrix from concatenated slices.
 
@@ -200,7 +200,11 @@ def materialize_matrix_from_concat(
       block_rows_concat: The matrix represented as the concatenated
         lower-triangular blocks.
       num_blocks: The number of block-rows used to represent the symmetric matrix.
+        If not specified, it is inferred from the shape of block_rows_concat.
     """
+    if num_blocks is None:
+        num_blocks = find_num_blocks(block_rows_concat)
+
     block_size = block_rows_concat.shape[-2]
 
     block_rows = [
@@ -251,6 +255,28 @@ def update_sliced_rows(
     )
 
 
+def num_blocks_from_total_blocks(total_blocks):
+    """Returns the number of blocks (i.e.
+
+    block rows) from the total blocks.
+
+    This is the inverse of the function x -> x*(x+1)/2.
+
+    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
+    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.
+
+    Args:
+      total_blocks: The total blocks used to represent the matrix.
+    """
+    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
+    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
+        raise ValueError(
+            f"total_blocks={total_blocks} does not correspond to "
+            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
+        )
+    return num_blocks
+
+
 def find_num_blocks(block_rows_concat):
     """Returns the number of (row) blocks representing the concatenated matrix.
 
```
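The new `num_blocks_from_total_blocks` helper factors the triangular-number inversion out of `find_num_blocks`. The following standalone sketch is not part of the diff: it mirrors the helper with a shortened error message and lists the values it yields for small inputs.

```python
# Illustrative only: a standalone copy of the triangular-number inversion used
# by num_blocks_from_total_blocks (error message shortened).
import numpy as np


def num_blocks_from_total_blocks(total_blocks):
    # Invert total_blocks = x * (x + 1) / 2 to recover the number of block rows x.
    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
        raise ValueError(f"total_blocks={total_blocks} is not of the form x*(x+1)/2.")
    return num_blocks


for t in (1, 3, 6, 10):
    print(t, num_blocks_from_total_blocks(t))
# 1 1
# 3 2
# 6 3
# 10 4
```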
```diff
@@ -270,11 +296,147 @@ def find_num_blocks(block_rows_concat):
     # Compute the number of square blocks used to represent the matrix.
     total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
     # Determine the number of block rows by inverting y = x*(x+1)/2.
-    num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
-    if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
+    return num_blocks_from_total_blocks(total_blocks)
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix(
+    mat,
+    block_size,
+):
+    """Returns sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    num_rows = mat.shape[-2]
+    num_cols = mat.shape[-1]
+    if num_rows != num_cols:
+        raise ValueError("mat is not square.")
+    if num_rows % block_size != 0:
         raise ValueError(
-            f"total_blocks={total_blocks} does not correspond to "
-            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
+            "block size does not evenly divide rows. "
+            f"num_rows={num_rows}, block_size={block_size}"
         )
-    return num_blocks
-
+    return SlicedSymmetricMatrix(
+        block_rows=[
+            mat[
+                Ellipsis,
+                i * block_size : (i + 1) * block_size,
+                0 : (i + 1) * block_size,
+            ]
+            for i in range(num_rows // block_size)
+        ]
+    )
+
+
+@functools.partial(jax.jit, static_argnames=("block_size"))
+def slice_symmetric_matrix_concat(
+    mat,
+    block_size,
+):
+    """Returns the concatenated sliced row blocks.
+
+    Args:
+      mat: A symmetric matrix.
+      block_size: The size of the row slices.
+    """
+    sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size)
+    return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
+
+
+def sliced_matrix_diag(mat):
+    """Returns the diagonal of the symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented in concatenated block form.
+    """
+    rows, cols = mat.shape
+    total_blocks = cols // rows
+    num_blocks = num_blocks_from_total_blocks(total_blocks)
+    diags = []
+    for i in range(num_blocks):
+        last_index = rows * ((i + 2) * (i + 1)) // 2
+        first_index = last_index - rows
+        diags.append(jnp.diag(mat[Ellipsis, first_index:last_index]))
+    return jnp.concatenate(diags, axis=-1)
+
+
+def diag_as_concat(diag, block_size):
+    """Returns the representation of a diagonal matrix in symmetric block form.
+
+    Args:
+      diag: The 1D array for the diagonals.
+      block_size: The size of blocks to use. Must divide the length of diag.
+    """
+    assert len(diag.shape) == 1  # diag must be 1D.
+    assert len(diag) % block_size == 0
+    num_diag_blocks = len(diag) // block_size
+    blocks = []
+    for i in range(num_diag_blocks):
+        blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype))
+        blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size]))
+    return jnp.concatenate(blocks, axis=-1)
+
+
+def row_abs_maxes(mat):
+    """Returns the max of the absolute values of the rows of the full matrix.
+
+    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
+    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
+    aboslute row maxes of the original symmetric matrix, [6, 6].
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+    """
+    rows, cols = mat.shape
+
+    # Find col and row max for each block.
+    col_maxes = []
+    row_maxes = []
+    for i in range(cols // rows):
+        block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
+        col_maxes.append(jnp.max(block, axis=1))
+        row_maxes.append(jnp.max(block, axis=0))
+
+    # global row max from block maxes.
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    maxes = []
+    for i in range(num_blocks):
+        maxes.append(
+            jnp.concatenate(
+                row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
+                + [
+                    col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
+                    for j in range(i + 1, num_blocks)
+                ],
+                axis=-1,
+            )
+        )
+
+    return jnp.max(jnp.stack(maxes), axis=0)
+
+
+def times_vector(mat, vec):
+    """Returns the symmetric block-concatenated matrix multiplied by a vector.
+
+    Specifically, each value in the vector is multiplied by a row of the full
+    matrix. That is, the vector is broadcast and multiplied element-wise. Note
+    this would be the transpose of full_mat * vec if full_mat represented the full
+    symmetric matrix.
+
+    Args:
+      mat: The symmetric matrix represented as the concatenated blocks.
+      vec: The vector, having the same dimension as the materialized matrix.
+    """
+    rows, cols = mat.shape
+    num_blocks = num_blocks_from_total_blocks(cols // rows)
+    multiplied = []
+    for i in range(num_blocks):
+        mat_block = mat[
+            Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
+        ]
+        vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
+        multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block))
+    return jnp.concatenate(multiplied, axis=-1)
```
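Taken together, the slicing and materialization helpers form a round trip: a symmetric matrix can be stored as the concatenation of its lower-triangular row blocks and rebuilt on demand, which is what `matrix_inverse_pth_root` now relies on for non-square inputs. The usage sketch below is illustrative rather than taken from the diff, and it assumes the repository root is on `PYTHONPATH` so the module imports under the path shown.

```python
# Illustrative usage sketch: slice a symmetric matrix into its lower-triangular
# row blocks, then materialize it back from the concatenated representation.
import jax.numpy as jnp
from tools.train.scalable_shampoo.symmetric_matrices import symmetric_matrices as sm

x = jnp.arange(8.0).reshape(4, 2)
full = x @ x.T                                                  # symmetric 4x4 matrix
concat = sm.slice_symmetric_matrix_concat(full, block_size=2)   # blocks [A, B, C], shape (2, 6)
recovered = sm.materialize_matrix_from_concat(concat)           # num_blocks inferred from the shape
print(jnp.allclose(full, recovered))                            # True
```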