# Source: kunstnerfrits/tools/train/scalable_shampoo/quantization_utils.py (uploaded by FritsLyneborg, "AI upload", revision 6742988)
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper routines for quantization."""
from typing import Any
import chex
import jax.numpy as jnp
from flax import struct
# pylint:disable=no-value-for-parameter
@struct.dataclass
class QuantizedValue:
    """State associated with a (possibly) quantized array.

    What is stored depends on ``quantized_dtype``:

    * ``jnp.float32``: ``quantized`` is the input array unchanged;
      ``diagonal`` and ``bucket_size`` are empty lists.
    * ``jnp.bfloat16``: ``quantized`` is the input cast to bfloat16;
      ``diagonal`` and ``bucket_size`` are empty lists.
    * ``jnp.int8`` / ``jnp.int16``: ``quantized`` holds rounded integer codes
      and ``bucket_size`` holds the scale obtained by reducing the absolute
      value over axis 0 with ``max`` and dividing by the largest positive
      code (127 or 32767). ``quantized * bucket_size`` then approximates the
      input.

    An instance built from an empty list keeps empty lists in every array
    field, and ``to_float`` returns that empty list.
    """

    # Quantized representation (see the class docstring for its meaning).
    quantized: chex.Array
    # Diagonal of the 2D input, stored unquantized, if extract_diagonal is
    # set; otherwise an empty list.
    diagonal: chex.Array
    # Per-slice scale (input reduced over axis 0) for integer dtypes;
    # otherwise an empty list.
    bucket_size: chex.Array
    # Dtype for the quantized value (static metadata, not a pytree leaf).
    quantized_dtype: jnp.dtype = struct.field(pytree_node=False)
    # Whether the diagonal of a 2D input was split off and kept unquantized.
    extract_diagonal: bool = struct.field(pytree_node=False)
    # Shape of the quantized tensor, as a list.
    shape: Any = struct.field(pytree_node=False)

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Builds an instance of ``cls`` by quantizing ``fvalue``.

        Args:
          fvalue: Array to quantize, or an empty list used as a placeholder.
          quantized_dtype: One of jnp.float32, jnp.bfloat16, jnp.int8 or
            jnp.int16.
          extract_diagonal: If True, ``fvalue`` must be 2D; its diagonal is
            stored separately and excluded from quantization.

        Returns:
          An instance of ``cls`` holding the quantized data and metadata.

        Raises:
          ValueError: Propagated from ``quantize`` for an unsupported dtype or
            an input of incompatible rank.
        """
        if isinstance(fvalue, list) and not fvalue:
            # Placeholder: nothing to quantize, so every array field is empty.
            return cls([], [], [], quantized_dtype, extract_diagonal, [])
        # Use ``cls`` (not the hard-coded class name) so subclasses can
        # override ``quantize`` and receive instances of their own type.
        quantized, diagonal_fvalue, bucket_size = cls.quantize(
            fvalue, quantized_dtype, extract_diagonal
        )
        return cls(
            quantized,
            diagonal_fvalue,
            bucket_size,
            quantized_dtype,
            extract_diagonal,
            list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns the quantized value, the extracted diagonal and the bucket.

        Args:
          fvalue: Array to quantize.
          quantized_dtype: One of jnp.float32, jnp.bfloat16, jnp.int8 or
            jnp.int16.
          extract_diagonal: For integer dtypes only: split off the diagonal of
            a 2D ``fvalue`` and quantize only the off-diagonal entries.

        Returns:
          A tuple ``(quantized, diagonal, bucket_size)``. For float32 and
          bfloat16, ``diagonal`` and ``bucket_size`` are empty lists.

        Raises:
          ValueError: If ``quantized_dtype`` is unsupported, if
            ``extract_diagonal`` is set for a non-2D input, or if the input is
            a scalar (integer dtypes only).
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")

        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries; they are returned separately,
            # unquantized, in diagonal_fvalue.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        # The max absolute value along axis 0 is mapped to num_buckets.
        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        # Re-insert the reduced axis so the scale broadcasts against fvalue.
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0: an all-zero slice has bucket size 0; dividing
        # its zeros by 1 instead still yields code 0.
        bs_nonzero = jnp.where(
            bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
        )
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value reconstructed from the stored representation.

        * An empty-list placeholder is returned as-is.
        * A float32 value is returned unchanged.
        * A bfloat16 value is cast to float32.
        * Integer codes are multiplied by ``bucket_size`` (in the bucket's
          dtype), and the separately stored diagonal is added back when
          ``extract_diagonal`` is set.
        """
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized
        if self.quantized_dtype == jnp.float32:
            return self.quantized
        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)
        float_dtype = self.bucket_size.dtype
        # Same broadcasting layout as used in ``quantize``.
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val