Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import typing as tp | |
import torch | |
from .base import BaseQuantizer, QuantizedResult | |
from .core_vq import ResidualVectorQuantization | |
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer built on top of ``ResidualVectorQuantization``.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on
            active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()

        # ``max_n_q`` is fixed at construction time; ``n_q`` is the number of
        # quantizers currently in use and can be lowered via ``set_num_codebooks``.
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout

        # Keep the codebook hyper-parameters around as plain attributes.
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        # The underlying stack of quantizers is built with every codebook (n_q).
        self.vq = ResidualVectorQuantization(
            dim=dimension,
            codebook_size=bins,
            num_quantizers=n_q,
            decay=decay,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            threshold_ema_dead_code=threshold_ema_dead_code,
            orthogonal_reg_weight=orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=orthogonal_reg_max_codes,
            channels_last=False,
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Quantize ``x`` and report the codes, bandwidth and commitment penalty.

        Args:
            x (torch.Tensor): Input passed to the underlying residual quantizers.
            frame_rate (int): Frame rate used to compute the bandwidth.

        Returns:
            QuantizedResult: Holds the quantized tensor, the codes transposed to
                [B, K, T], the bandwidth ``n_q * log2(bins) * frame_rate / 1000``
                as a tensor matching ``x``, and the mean commitment loss as ``penalty``.
        """
        # With quantizer dropout enabled at train time, draw how many quantizers
        # to use uniformly from [1, self.n_q]; otherwise use all active ones.
        active_n_q = (
            int(torch.randint(1, self.n_q + 1, (1,)).item())
            if self.training and self.q_dropout
            else self.n_q
        )
        bits_per_code = math.log2(self.bins)
        bw_per_q = bits_per_code * frame_rate / 1000

        quantized, codes, commit_loss = self.vq(x, n_q=active_n_q)
        # After the transpose, codes is [B, K, T], with T frames, K nb of codebooks.
        codes = codes.transpose(0, 1)

        # ``.to(x)`` aligns the bandwidth tensor with the dtype and device of ``x``.
        bw = torch.tensor(active_n_q * bw_per_q).to(x)
        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input tensor using the currently active number of quantizers.

        Returns the indices produced by each quantizer, shaped [B, K, T] with
        T frames and K codebooks.
        """
        return self.vq.encode(x, n_q=self.n_q).transpose(0, 1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        ``codes`` is [B, K, T] with T frames and K codebooks; it is transposed
        to the [K, B, T] layout that ``vq.decode`` expects.
        """
        return self.vq.decode(codes.transpose(0, 1))

    def total_codebooks(self):
        """Return the number of codebooks the quantizer was built with."""
        return self.max_n_q

    def num_codebooks(self):
        """Return the number of codebooks currently in use."""
        return self.n_q

    def set_num_codebooks(self, n: int):
        """Set the number of codebooks in use; ``n`` must lie in [1, max_n_q]."""
        assert 0 < n <= self.max_n_q
        self.n_q = n