import random
from math import ceil  # needed by round_up_multiple below
from typing import List

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
from torch.nn import Module
from torch.amp import autocast

from einx import get_at
from einops import rearrange, reduce, pack, unpack

from sparktts.modules.fsq.finite_scalar_quantization import FSQ
# helper functions


def exists(val):
    return val is not None


def first(l):
    return l[0]


def default(val, d):
    return val if exists(val) else d


def round_up_multiple(num, mult):
    return ceil(num / mult) * mult


# distributed helpers


def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1


def get_maybe_sync_seed(device, max_size=10_000):
    """Draw a random int, all-reduced across processes so every rank ends up with the same seed."""
    rand_int = torch.randint(0, max_size, (), device=device)

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()
class ResidualFSQ(Module):
    """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(
        self,
        *,
        levels: List[int],
        num_quantizers,
        dim=None,
        is_channel_first=False,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        **kwargs,
    ):
        super().__init__()
        codebook_dim = len(levels)
        dim = default(dim, codebook_dim)

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.has_projections = requires_projection

        self.is_channel_first = is_channel_first
        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        for ind in range(num_quantizers):
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)

            self.layers.append(fsq)

        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        self.register_buffer("scales", torch.stack(scales), persistent=False)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index

        # encodec paper proposes structured dropout, believe this was set to 4
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of
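
        # note on scales: scales[i] = (levels - 1) ** -i, so in forward() quantizer i
        # operates on residual / scales[i] (the residual magnified by (levels - 1) ** i)
        # and its output is shrunk back by the same factor, so that successive stages
        # work at progressively finer effective resolution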
    @property
    def codebooks(self):
        # accessed as an attribute below (e.g. in get_codes_from_indices),
        # so this must be a property
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        return codebooks
    def get_codes_from_indices(self, indices):
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert (
                self.quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # take care of quantizer dropout

        mask = indices == -1
        indices = indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)

        # scale the codes

        scales = rearrange(self.scales, "q d -> q 1 1 d")
        all_codes = all_codes * scales

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes
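
    # note: both index helpers expect the quantizer dimension last, i.e. indices of
    # shape (b, n, q) (or (b, h, w, q) for image feature maps); when the model was run
    # with is_channel_first=True, forward() returns indices as (b, q, n), so transpose
    # back first (as the __main__ demo at the bottom of this file does)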
    def get_output_from_indices(self, indices):
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, "q ... -> ...", "sum")
        return self.project_out(codes_summed)
    def forward(
        self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None
    ):
        num_quant, quant_dropout_multiple_of, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            x.device,
        )

        # handle channel first

        if self.is_channel_first:
            x = rearrange(x, "b d ... -> b ... d")
            x, ps = pack([x], "b * d")

        # maybe project in

        x = self.project_in(x)

        quantized_out = 0.0
        residual = x

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices

        if should_quantize_dropout:
            # check if seed is manually passed in

            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    round_up_multiple(
                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
                    )
                    - 1
                )

            null_indices = torch.full(
                x.shape[:2], -1.0, device=device, dtype=torch.long
            )

        # go through the layers

        with autocast("cuda", enabled=False):
            for quantizer_index, (layer, scale) in enumerate(
                zip(self.layers, self.scales)
            ):
                if (
                    should_quantize_dropout
                    and quantizer_index > rand_quantize_dropout_index
                ):
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)

                quantized = quantized * scale

                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # stack all indices

        all_indices = torch.stack(all_indices, dim=-1)

        # channel first out

        if self.is_channel_first:
            (quantized_out,) = unpack(quantized_out, ps, "b * d")
            (all_indices,) = unpack(all_indices, ps, "b * d")

            quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
            all_indices = rearrange(all_indices, "b ... d -> b d ...")

        # return

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
# grouped residual fsq


class GroupedResidualFSQ(Module):
    def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    @property
    def split_dim(self):
        # accessed as an attribute (self.split_dim) below, so kept as a property
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        codes = tuple(
            rvq.get_codes_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        outputs = tuple(
            rvq.get_output_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.cat(outputs, dim=self.split_dim)

    def forward(self, x, return_all_codes=False):
        shape, split_dim, device = x.shape, self.split_dim, x.device
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups

        x = x.chunk(self.groups, dim=split_dim)

        forward_kwargs = dict(
            return_all_codes=return_all_codes,
            rand_quantize_dropout_fixed_seed=(
                get_maybe_sync_seed(device) if self.training else None
            ),
        )

        # invoke residual vq on each group

        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        # gather all the zipped outputs and combine them

        quantized, all_indices, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim=split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_all_codes)
        return ret
if __name__ == "__main__":
    model = ResidualFSQ(
        levels=[4, 4, 4, 4, 4, 4],
        num_quantizers=1,
        dim=30,
        is_channel_first=True,
        quantize_dropout=False,
    )
    x = torch.randn(2, 30, 10)

    quantize, embed_ind = model(x)

    emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))

    print(quantize == emb_from_ind.transpose(1, 2))

    print("quantize shape", quantize.shape)
    print("embed_ind", embed_ind)