|
from typing import Optional

import einops
import torch
import torch.nn.functional as F
from torch import nn


class QuantizationModule(nn.Module):
    """Quantizes encoder features by selecting one codeword per codebook via a
    (Gumbel-)softmax over projected logits and concatenating the selections."""
|
    def __init__(self, config):
        super().__init__()

        assert (
            config.d_model % config.num_codebooks == 0
        ), "d_model must be divisible by num_codebooks"

        self.num_codebooks = config.num_codebooks
        self.num_codewords = config.num_codewords
        self.d_model = config.d_model
        # Each codebook contributes an equal slice of the output dimension.
        self.codeword_dim = config.d_model // config.num_codebooks

        self.codebooks = self._init_codebooks()

        # One logit per codeword in every codebook.
        self.projection = nn.Linear(
            config.in_features, self.num_codebooks * self.num_codewords
        )

        # Gumbel-softmax temperature.
        self.tau = 1.0
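
    # Note: `tau` is left as a plain attribute so a training loop can anneal it
    # from the outside, as is common for Gumbel-softmax quantizers. A hypothetical
    # schedule (illustrative only, not part of this module) might look like:
    #     quantizer.tau = max(tau_min, tau_init * decay ** step)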
|
    def _init_codebooks(self):
        # Shape (1, 1, num_codebooks, num_codewords, codeword_dim); the two
        # leading singleton dims broadcast over (batch, num_frames) in forward().
        # `empty` suffices here since xavier_uniform_ overwrites every entry.
        codebooks = torch.empty(
            1, 1, self.num_codebooks, self.num_codewords, self.codeword_dim
        )
        nn.init.xavier_uniform_(codebooks)
        return nn.Parameter(codebooks)
|
    @property
    def total_codewords(self):
        return self.num_codebooks * self.num_codewords
|
    @staticmethod
    def _compute_perplexity(probs: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Computes the perplexity of the codeword distribution (used for the diversity loss).

        Args:
            probs (Tensor): The probability distribution over codewords. Shape: (batch, num_frames, num_codebooks, num_codewords)
            mask (BoolTensor): The mask for valid frames; `True` marks an invalid frame. Shape: (batch, num_frames)

        Returns:
            perplexity (Tensor): The perplexity of the marginal codeword distribution, summed over codebooks. Scalar.
        """
        if mask is not None:
            # Zero out invalid frames, then average only over the valid ones
            # (divide by the count of valid frames, not invalid ones).
            probs = probs * ~mask[..., None, None]
            marginal_probs = (
                einops.reduce(probs, "b nf nb nw -> nb nw", "sum") / (~mask).sum()
            )
        else:
            marginal_probs = einops.reduce(probs, "b nf nb nw -> nb nw", "mean")

        # exp(entropy) per codebook, summed over codebooks.
        perplexity = torch.exp(
            -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
        ).sum()
        return perplexity
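
    # In symbols, with p̄[g, v] the (masked) mean over batch and frames of the
    # probability of codeword v in codebook g, the method above computes
    #     perplexity = Σ_g exp(−Σ_v p̄[g, v] · log(p̄[g, v] + ε)),
    # which is maximized at num_codebooks * num_codewords when usage is uniform.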
|
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x (Tensor): The extracted features from waveforms. Shape: (batch, num_frames, in_features)
            mask (BoolTensor): The mask for valid frames; `True` marks an invalid frame. Shape: (batch, num_frames)

        Returns:
            quantized (Tensor): The quantized features. Shape: (batch, num_frames, d_model)
            perplexity (Tensor): The perplexity of the codeword distribution. Scalar.
        """
        batch_size, num_frames, _ = x.shape

        logits = self.projection(x)
        logits = logits.view(
            batch_size, num_frames, self.num_codebooks, self.num_codewords
        )
|
        if self.training:
            # Straight-through Gumbel-softmax: hard one-hot in the forward pass,
            # soft gradients in the backward pass.
            word_probs = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=-1)
            # Perplexity uses the soft distribution so the diversity signal
            # stays differentiable.
            word_soft_probs = F.softmax(logits, dim=-1)
            perplexity = self._compute_perplexity(word_soft_probs, mask=mask)
        else:
            # At inference, deterministically pick the most likely codeword
            # per codebook.
            word_ids = torch.argmax(logits, dim=-1, keepdim=True)
            word_probs = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)
            perplexity = self._compute_perplexity(word_probs, mask=mask)

        # One-hot weighting selects one codeword per codebook; the per-codebook
        # codewords are then concatenated along the feature dimension. Use the
        # out-of-place unsqueeze to keep the autograd graph intact.
        quantized = einops.reduce(
            word_probs.unsqueeze(-1) * self.codebooks,
            "b nf nb nw d -> b nf (nb d)",
            reduction="sum",
        )

        return quantized, perplexity
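

# A minimal usage sketch, not part of the module itself: it assumes a
# namespace-style config exposing the fields read in __init__; the shapes and
# values below are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        in_features=512, d_model=256, num_codebooks=2, num_codewords=320
    )
    quantizer = QuantizationModule(config)
    quantizer.eval()  # use the deterministic argmax path

    x = torch.randn(4, 100, config.in_features)
    mask = torch.zeros(4, 100, dtype=torch.bool)
    mask[:, 80:] = True  # mark the last 20 frames of each example as padding

    quantized, perplexity = quantizer(x, mask=mask)
    print(quantized.shape)  # torch.Size([4, 100, 256])
    print(perplexity.item())  # at most num_codebooks * num_codewords = 640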
|
|