from typing import Optional

import torch
import torch.nn as nn

from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize


class ResidualVQ(nn.Module):
    """
    Residual vector quantization (RVQ): a cascade of quantizers in which
    each stage quantizes the residual error left by the previous one.

    Introduced in SoundStream: An End-to-End Neural Audio Codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout

        # Select the quantizer class shared by all stages.
        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        elif quantizer_type == "lfq":
            VQ = LookupFreeQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )
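
    # Construction sketch (illustrative): `quantizer_type` selects the
    # per-stage quantizer class, and extra **kwargs are forwarded unchanged
    # to every stage, e.g.
    #
    #   rvq = ResidualVQ(input_dim=256, num_quantizers=8,
    #                    codebook_size=1024, quantizer_type="fvq")
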
    def forward(self, z, n_quantizers: Optional[int] = None):
        """
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use
            (n_quantizers <= self.num_quantizers, e.g. for quantizer dropout).
            Note: in training mode this argument is ignored; a random number
            of quantizers is used for a `quantizer_dropout` fraction of the
            batch, and all quantizers are used for the rest.

        Returns
        -------
        "quantized_out" : Tensor[B x D x T]
            Quantized continuous representation of input
        "all_indices" : Tensor[N x B x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "all_commit_losses" : Tensor[N]
        "all_codebook_losses" : Tensor[N]
        "all_quantized" : Tensor[N x B x D x T]
        """
        quantized_out = 0.0
        residual = z

        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        if self.training:
            # Quantizer dropout: the first `n_dropout` batch items get a
            # random depth in [1, num_quantizers]; the rest keep all stages
            # (num_quantizers + 1 makes their mask below always true).
            # E.g. B=4, quantizer_dropout=0.5 -> n_dropout=2: items 0-1 get
            # a random depth, items 2-3 use every quantizer.
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break

            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Mask out this stage for batch items whose assigned depth has
            # been reached; only the output and losses are masked, while the
            # shared residual keeps being refined for every item.
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()

            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)

        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )

        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )
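
    # Sketch of consuming the outputs (illustrative; the weights
    # `lambda_commit` / `lambda_codebook` are hypothetical and belong to the
    # training recipe, not to this module):
    #
    #   out, idx, commit, cb, q = rvq(z)
    #   loss = lambda_commit * commit.sum() + lambda_codebook * cb.sum()
    #
    # `idx` (N x B x T) is the discrete token representation a codec would
    # transmit.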
    def vq2emb(self, vq, n_quantizers=None):
        # Decode stacked codebook indices (one entry per quantizer stage)
        # back to the continuous embedding by summing each stage's lookup.
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out += quantizer.vq2emb(vq[idx])
        return quantized_out
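
    # Hedged round-trip sketch: if `all_indices` came from `forward`
    # (shape N x B x T), then
    #
    #   emb = rvq.vq2emb(all_indices)
    #
    # should reconstruct the summed quantized embedding, assuming each
    # stage's `vq2emb` maps indices back to B x D x T vectors.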
    def latent2dist(self, z, n_quantizers=None):
        # Like `forward`, but collect each stage's distance distribution and
        # selected indices instead of losses.
        quantized_out = 0.0
        residual = z

        all_dists = []
        all_indices = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if not self.training and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)

            quantized_out = quantized_out + z_q_i
            residual = residual - z_q_i

        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)

        return all_dists, all_indices
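

# Minimal smoke-test sketch (assumes the default "vq" stages return the five
# values unpacked in `forward`; expected shapes noted in comments):
if __name__ == "__main__":
    rvq = ResidualVQ(input_dim=256, num_quantizers=8, codebook_size=1024)
    rvq.eval()
    z = torch.randn(2, 256, 50)  # B x D x T
    with torch.no_grad():
        out, idx, commit, cb, q = rvq(z, n_quantizers=4)
    print(out.shape)  # expected: (2, 256, 50)
    print(idx.shape)  # expected: (4, 2, 50) with 4 active quantizers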