from __future__ import annotations

import contextlib
import functools
import os
from typing import Callable, Iterator, Optional, Sequence

import torch
import torch.nn.functional as F

ellipsis = type(...)


def get_mean_nbits_by_codebook(codes: torch.IntTensor, huffman_group_size: int = 2):
    """
    Calculates the average Huffman code length for each group of codebooks.
    :param codes: tensor of integer codebook codes
    :param huffman_group_size: number of codebooks that are compressed jointly by one Huffman code
    """
    import huffman

    _, codebook_size, num_codebooks = codes.shape
    flat_codes_by_codebook = codes.permute(2, 0, 1).flatten(1, 2)
    code_counts = torch.zeros(
        num_codebooks, codebook_size, device=flat_codes_by_codebook.device, dtype=flat_codes_by_codebook.dtype
    ).scatter_add(
        -1, flat_codes_by_codebook, torch.ones_like(flat_codes_by_codebook)
    )  # [num_codebooks, codebook_size], number of occurrences of each code
    code_probs = code_counts / code_counts.sum(dim=-1, keepdim=True).float()
    code_probs = code_probs.cpu().numpy()
    assert num_codebooks % huffman_group_size == 0

    mean_code_lengths = []
    for group_index in range(num_codebooks // huffman_group_size):
        group_code_probs = {(): 1}

        for codebook_index in range(group_index * huffman_group_size, (group_index + 1) * huffman_group_size):
            new_group_code_probs = {}
            for group, group_prob in group_code_probs.items():
                for code, code_prob in tuple(enumerate(code_probs[codebook_index])):
                    new_group_code_probs[group + (code,)] = group_prob * code_prob
            group_code_probs = new_group_code_probs

        huffman_codebook_i = huffman.codebook(list(group_code_probs.items()))
        codebook_mean_code_length_i = sum(
            len(huffman_codebook_i[code]) * prob for code, prob in group_code_probs.items()
        )
        mean_code_lengths.append(codebook_mean_code_length_i)
    return mean_code_lengths


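# Illustrative sketch (not part of the original module): estimating the average
# bits-per-code after Huffman-coding pairs of codebooks. The shapes and sizes
# below are assumptions made up for this example, and the `huffman` pip package
# must be installed.
def _example_mean_nbits_by_codebook() -> list:
    codebook_size, num_codebooks = 16, 4
    # random codes with values in [0, codebook_size); the second dim is treated
    # as the codebook size by get_mean_nbits_by_codebook
    codes = torch.randint(0, codebook_size, (1024, codebook_size, num_codebooks))
    # with huffman_group_size=2, codes from pairs of codebooks share one Huffman table
    return get_mean_nbits_by_codebook(codes, huffman_group_size=2)

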
def get_int_dtype(nbits: int) -> torch.dtype:
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    data[data >= 2 ** (nbits - 1)] -= 2**nbits
    return data.to(get_int_dtype(nbits))


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    return data.to(torch.int64) % (2**nbits)


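# Illustrative sketch (not part of the original module): round-tripping unsigned
# codes through pack_int_data / unpack_int_data. The 8-bit width and tensor shape
# are assumptions for this example; note that pack_int_data modifies its input
# in place, hence the .clone().
def _example_pack_unpack_roundtrip() -> bool:
    nbits = 8
    codes = torch.randint(0, 2**nbits, (4, 4))    # unsigned codes in [0, 256)
    packed = pack_int_data(codes.clone(), nbits)  # stored as torch.int8
    restored = unpack_int_data(packed, nbits)     # back to [0, 256), as int64
    return bool(torch.equal(restored, codes))

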
@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
    """Apply torch.jit.script to a function unless running on TPU, which does not support torch.jit.script."""
    using_tpu = bool(os.environ.get("TPU_NAME"))
    # allow manual override via AQ_USE_JIT; otherwise script everywhere except on TPU
    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
    return torch.jit.script(fn) if should_script else fn


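# Illustrative sketch (not part of the original module): maybe_script can be used
# as a decorator; the decorated function is TorchScript-compiled unless the
# TPU_NAME / AQ_USE_JIT environment variables disable it. The function below is
# a made-up example.
@maybe_script
def _example_scaled_sum(x: torch.Tensor, y: torch.Tensor, alpha: float = 2.0) -> torch.Tensor:
    # computes alpha * x + y elementwise
    return alpha * x + y

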
@contextlib.contextmanager
def using_tf32(enabled: bool):
    was_cudnn = torch.backends.cudnn.allow_tf32
    was_matmul = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cudnn.allow_tf32 = enabled
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        # restore the previous TF32 settings even if the body raised
        torch.backends.cudnn.allow_tf32 = was_cudnn
        torch.backends.cuda.matmul.allow_tf32 = was_matmul


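# Illustrative sketch (not part of the original module): temporarily disabling
# TF32 around a matmul whose result should be accumulated in full float32
# precision. The function name and arguments are assumptions for this example.
def _example_exact_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    with using_tf32(enabled=False):
        return a @ b

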
def iterate_minibatches(
    *tensors: torch.Tensor,
    batch_size: int,
    allow_incomplete: bool = True,
    device: Optional[torch.device] = None,
    callback: Callable[[Sequence[torch.Tensor]], Sequence[torch.Tensor]] = lambda x: x,
) -> Iterator[Sequence[torch.Tensor]]:
    """
    Samples data points *forever*, in random order, with less overhead than DataLoader;
    Adapted from https://github.com/stanis-morozov/unq/blob/master/lib/utils.py
    probably implemented over9000 times in transformers, torch, etc
    :param tensors: one or more tensors with the same 0-th dimension
    :param batch_size: sample this many points with each yield
    :param allow_incomplete: if True and the dataset size is not divisible by batch size, the last batch
        may have fewer than :batch_size: samples to cover the entire dataset. If False, the last batch is dropped
    :param device: if specified, each batch is moved to this device before it is yielded
    :param callback: optional function to be called on each batch of tensors before it is yielded to the user
    :returns: generates a tuple of minibatches from each tensor, same length as input *tensors
        If only one tensor is given, yields that tensor's minibatch directly (not wrapped in a tuple/list)
    """
    num_samples = len(tensors[0])
    assert all(len(x) == num_samples for x in tensors)
    indices = torch.randperm(num_samples, device=tensors[0].device)
    while True:
        prev_batch = None
        for batch_start in range(0, len(indices), batch_size):
            if not allow_incomplete and batch_start + batch_size > len(indices):
                break
            batch_ix = indices[batch_start : batch_start + batch_size]
            batch = callback(tuple(tensor[batch_ix].to(device, non_blocking=True) for tensor in tensors))
            if prev_batch is not None:
                yield prev_batch
            # keep tuples for multiple tensors; unwrap to a bare tensor when only one tensor was given
            prev_batch = batch if isinstance(batch, (list, tuple)) and len(tensors) > 1 else batch[0]
            del batch
        yield prev_batch


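# Illustrative sketch (not part of the original module): drawing a few minibatches
# from the infinite generator above. The toy data, batch size, device, and the use
# of itertools.islice to stop after three batches are assumptions for this example.
def _example_iterate_minibatches() -> None:
    import itertools

    xs = torch.randn(10, 3)
    ys = torch.randn(10)
    batches = iterate_minibatches(xs, ys, batch_size=4, device=torch.device("cpu"))
    for batch_x, batch_y in itertools.islice(batches, 3):
        print(batch_x.shape, batch_y.shape)  # e.g. torch.Size([4, 3]) torch.Size([4])

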
@maybe_script
def _dequantize_weight(
    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.
    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastable with
        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size, device=codes.device
    )  # [num_codebooks]
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets, codebooks.flatten(0, 1).flatten(-2, -1), mode="sum"
    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3]) + [num_out_groups, num_in_groups, out_group_size, in_group_size]
    )
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
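

# Illustrative sketch (not part of the original module): a shape sanity check for
# _dequantize_weight. All sizes below are assumptions made up for this example.
def _example_dequantize_weight_shapes() -> torch.Tensor:
    num_codebooks, codebook_size = 2, 256
    out_group_size, in_group_size = 1, 8
    num_out_groups, num_in_groups = 16, 8
    codes = torch.randint(0, codebook_size, (num_out_groups, num_in_groups, num_codebooks))
    codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
    scales = torch.ones(num_out_groups, 1, 1, 1)  # broadcasts over in_groups and group dims
    weight = _dequantize_weight(codes, codebooks, scales)
    assert weight.shape == (num_out_groups * out_group_size, num_in_groups * in_group_size)
    return weight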