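# NOTE: the lines below are residue from the web file viewer this source was
# copied from; they are kept as comments so the module can be parsed.
#   mrfakename's picture
#   Upload 43 files
#   d93aca0 verified
#   raw / history blame
#   10.3 kB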
import random
import torch
import torch.nn.functional as F
import torch.distributed as dist
from typing import List
from torch import nn
from torch.nn import Module
from torch.amp import autocast
from einx import get_at
from einops import rearrange, reduce, pack, unpack
from sparktts.modules.fsq.finite_scalar_quantization import FSQ
def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    return not (val is None)
def first(l):
    """Return the item stored at index 0 of ``l``."""
    head = l[0]
    return head
def default(val, d):
    """Return ``val`` when it is set (per ``exists``); otherwise fall back to ``d``."""
    if exists(val):
        return val
    return d
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``.

    Args:
        num: The value to round up.
        mult: The step the result must be a multiple of.

    Returns:
        The smallest multiple of ``mult`` that is greater than or equal to
        ``num``.
    """
    # ``ceil`` was never imported at module level, so the original call raised
    # NameError; import locally to keep this helper self-contained.
    import math

    return math.ceil(num / mult) * mult
# distributed helpers
def is_distributed():
    """Return whether a process group is initialized with more than one rank."""
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1
def get_maybe_sync_seed(device, max_size=10_000):
    """Draw a random integer seed, combining it across ranks when distributed.

    Args:
        device: Device on which the scalar random tensor is created.
        max_size: Exclusive upper bound passed to ``torch.randint``.

    Returns:
        The seed as a Python number (via ``Tensor.item()``).
    """
    seed = torch.randint(0, max_size, (), device=device)

    # All-reduce in place so every rank ends up holding the same value.
    if is_distributed():
        dist.all_reduce(seed)

    return seed.item()
class ResidualFSQ(Module):
    """Residual quantizer built from a stack of ``FSQ`` layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Each layer quantizes the residual left over by the previous layers. Layer
    ``i`` works on the residual divided by ``(levels - 1) ** -i`` and its output
    is multiplied back by the same scale before being accumulated.
    """

    def __init__(
        self,
        *,
        levels: List[int],
        num_quantizers,
        dim=None,
        is_channel_first=False,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        **kwargs,
    ):
        """Build the projections, the FSQ layers and the per-layer scales.

        Args:
            levels: Per-dimension level counts handed to every ``FSQ`` layer;
                its length is used as the codebook dimension.
            num_quantizers: Number of residual ``FSQ`` layers.
            dim: Feature dimension of the input. Defaults to ``len(levels)``;
                when it differs, linear in/out projections are created.
            is_channel_first: If true, ``forward`` expects ``(b, d, ...)`` input
                and returns channel-first outputs.
            quantize_dropout: Enable random truncation of the quantizer stack
                during training (only effective when ``num_quantizers > 1``).
            quantize_dropout_cutoff_index: Lowest layer index that may be
                sampled as the last active layer.
            quantize_dropout_multiple_of: When not 1, the number of active
                layers is rounded up to a multiple of this value.
            **kwargs: Forwarded unchanged to each ``FSQ`` constructor.
        """
        super().__init__()
        codebook_dim = len(levels)
        dim = default(dim, codebook_dim)

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.has_projections = requires_projection

        self.is_channel_first = is_channel_first
        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        for ind in range(num_quantizers):
            # scale for layer `ind` is (levels - 1) ** -ind, i.e. all ones for layer 0
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)

            self.layers.append(fsq)

        # projections are owned by this module, not by the individual layers
        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        # non-persistent: recomputed from `levels`, not stored in the state dict
        self.register_buffer("scales", torch.stack(scales), persistent=False)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

    @property
    def codebooks(self):
        """Stack every layer's ``implicit_codebook`` along a new leading axis."""
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        return codebooks

    def get_codes_from_indices(self, indices):
        """Look up the scaled per-layer codes for a tensor of indices.

        Args:
            indices: Integer tensor with batch first and the quantizer axis
                last (``b ... q``). Entries equal to ``-1`` mark dropped
                quantizers and yield all-zero codes. Fewer than
                ``num_quantizers`` entries on the last axis are allowed only
                when quantize dropout is enabled; the missing layers are
                treated as dropped.

        Returns:
            Codes with the quantizer axis first (``q b ... d``).
        """
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert (
                self.quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # take care of quantizer dropout

        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)

        # scale the codes

        scales = rearrange(self.scales, "q d -> q 1 1 d")
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def get_output_from_indices(self, indices):
        """Sum the per-layer codes for ``indices`` and apply ``project_out``."""
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, "q ... -> ...", "sum")
        return self.project_out(codes_summed)

    def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None):
        """Residually quantize ``x``.

        Args:
            x: Input features; channel-first (``b d ...``) when
                ``is_channel_first`` is set, otherwise feature-last.
            return_all_codes: Also return the per-layer codes.
            rand_quantize_dropout_fixed_seed: Optional seed for sampling the
                quantize-dropout layer index; when omitted a (possibly
                rank-synchronized) seed is drawn.

        Returns:
            ``(quantized_out, all_indices)``, plus ``all_codes`` when
            ``return_all_codes`` is true. Dropped layers are reported with
            index ``-1``.
        """
        num_quant, quant_dropout_multiple_of, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            x.device,
        )

        # handle channel first

        if self.is_channel_first:
            x = rearrange(x, "b d ... -> b ... d")
            x, ps = pack([x], "b * d")

        # maybe project in

        x = self.project_in(x)

        quantized_out = 0.0
        residual = x

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices

        if should_quantize_dropout:

            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    round_up_multiple(
                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
                    )
                    - 1
                )

            # integer fill value to match the long dtype
            null_indices = torch.full(
                x.shape[:2], -1, device=device, dtype=torch.long
            )

        # go through the layers

        with autocast("cuda", enabled=False):
            for quantizer_index, (layer, scale) in enumerate(
                zip(self.layers, self.scales)
            ):

                if (
                    should_quantize_dropout
                    and quantizer_index > rand_quantize_dropout_index
                ):
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)

                quantized = quantized * scale

                # detach so the next layer's input carries no gradient through this layer's output
                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # stack all indices

        all_indices = torch.stack(all_indices, dim=-1)

        # `get_codes_from_indices` expects the quantizer axis last, so keep a
        # channel-last reference before the channel-first rearrange below
        channel_last_indices = all_indices

        # channel first out

        if self.is_channel_first:
            (quantized_out,) = unpack(quantized_out, ps, "b * d")
            (all_indices,) = unpack(all_indices, ps, "b * d")

            channel_last_indices = all_indices

            quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
            all_indices = rearrange(all_indices, "b ... d -> b d ...")

        # return

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers
        # (looked up from the channel-last indices; passing the channel-first
        # tensor would make the lookup treat the wrong axis as the quantizer axis)

        all_codes = self.get_codes_from_indices(channel_last_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
# grouped residual fsq
class GroupedResidualFSQ(Module):
    """Split the feature axis into equal groups, each with its own ``ResidualFSQ``."""

    def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
        """Create one ``ResidualFSQ`` of width ``dim // groups`` per group.

        Args:
            dim: Total feature dimension; must be divisible by ``groups``.
            groups: Number of feature groups.
            accept_image_fmap: When true, the feature axis is taken to be
                axis 1 instead of the last axis.
            **kwargs: Forwarded unchanged to every ``ResidualFSQ``.
        """
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0

        self.accept_image_fmap = accept_image_fmap

        per_group_dim = dim // groups
        self.rvqs = nn.ModuleList(
            [ResidualFSQ(dim=per_group_dim, **kwargs) for _ in range(groups)]
        )

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        """Codebooks of every group stacked along a new leading axis."""
        per_group = tuple(rvq.codebooks for rvq in self.rvqs)
        return torch.stack(per_group)

    @property
    def split_dim(self):
        """Axis along which features are split: 1 for image feature maps, else -1."""
        if self.accept_image_fmap:
            return 1
        return -1

    def get_codes_from_indices(self, indices):
        """Look up codes group by group; ``indices`` is iterated along its first axis."""
        per_group = tuple(
            rvq.get_codes_from_indices(group_indices)
            for rvq, group_indices in zip(self.rvqs, indices)
        )
        return torch.stack(per_group)

    def get_output_from_indices(self, indices):
        """Decode each group's indices and concatenate the results on the feature axis."""
        per_group = tuple(
            rvq.get_output_from_indices(group_indices)
            for rvq, group_indices in zip(self.rvqs, indices)
        )
        return torch.cat(per_group, dim=self.split_dim)

    def forward(self, x, return_all_codes=False):
        """Quantize every feature group of ``x`` with its own residual quantizer.

        Returns:
            ``(quantized, all_indices)`` where ``quantized`` is concatenated on
            the feature axis and ``all_indices`` is stacked along a new leading
            group axis; when ``return_all_codes`` is true, a tuple with each
            group's codes is appended.
        """
        feature_axis = self.split_dim
        device = x.device
        assert x.shape[feature_axis] == self.dim

        # split the feature dimension into groups
        chunks = x.chunk(self.groups, dim=feature_axis)

        # one seed shared by all groups so they sample the same dropout index
        shared_seed = get_maybe_sync_seed(device) if self.training else None

        # invoke residual vq on each group
        per_group = tuple(
            rvq(
                chunk,
                return_all_codes=return_all_codes,
                rand_quantize_dropout_fixed_seed=shared_seed,
            )
            for rvq, chunk in zip(self.rvqs, chunks)
        )

        # transpose: tuple of per-group results -> tuple of per-output groups
        quantized, all_indices, *maybe_all_codes = tuple(zip(*per_group))

        return (
            torch.cat(quantized, dim=feature_axis),
            torch.stack(all_indices),
            *maybe_all_codes,
        )
if __name__ == "__main__":
    # Smoke test: quantize a random channel-first batch, then rebuild the
    # output from the returned indices and compare element-wise.
    fsq_config = dict(
        levels=[4] * 6,
        num_quantizers=1,
        dim=30,
        is_channel_first=True,
        quantize_dropout=False,
    )
    model = ResidualFSQ(**fsq_config)

    x = torch.randn(2, 30, 10)
    quantize, embed_ind = model(x)

    # get_output_from_indices expects the quantizer axis last, so swap axes
    # on the way in and swap the reconstruction back to channel-first.
    indices_channel_last = embed_ind.transpose(1, 2)
    emb_from_ind = model.get_output_from_indices(indices_channel_last)
    reconstructed = emb_from_ind.transpose(1, 2)

    print(quantize == reconstructed)
    print("quantize shape", quantize.shape)
    print("embed_ind", embed_ind)