# wav2vec2/src/model/modules/quantization.py
from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
import einops


class QuantizationModule(nn.Module):
    def __init__(self, config):
        """
        Args:
            config: Configuration object exposing `in_features`, `d_model`,
                `num_codebooks` and `num_codewords`.
        """
super().__init__()
assert (
config.d_model % config.num_codebooks == 0
), "d_model must be divisible by num_codebooks"
self.num_codebooks = config.num_codebooks
self.num_codewords = config.num_codewords
self.d_model = config.d_model
self.codeword_dim = config.d_model // config.num_codebooks
self.codebooks = self._init_codebooks()
self.projection = nn.Linear(
config.in_features, self.num_codebooks * self.num_codewords
)
        self.tau = 1.0  # Gumbel-softmax temperature (annealed during training in wav2vec 2.0)
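
    # Illustrative helper (not in the original file): the wav2vec 2.0 paper
    # anneals the Gumbel-softmax temperature from 2.0 to a minimum of 0.5 by a
    # factor of 0.999995 at every update. A minimal sketch, assuming the
    # training loop calls it once per optimization step:
    def anneal_tau(self, decay: float = 0.999995, tau_min: float = 0.5) -> None:
        self.tau = max(self.tau * decay, tau_min)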

    def _init_codebooks(self):
        # Shape: (1, 1, num_codebooks, num_codewords, codeword_dim); the two
        # leading singleton dims broadcast over (batch, num_frames) in `forward`.
        codebooks = torch.empty(
            1, 1, self.num_codebooks, self.num_codewords, self.codeword_dim
        )
        nn.init.xavier_uniform_(codebooks)
        return nn.Parameter(codebooks)

    @property
    def total_codewords(self):
        return self.num_codebooks * self.num_codewords

    @staticmethod
    def _compute_perplexity(probs: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Computes the perplexity of the codeword distribution, from which the
        diversity loss is derived.

        Args:
            probs (Tensor): The probability distribution over codewords. Shape: (batch, num_frames, num_codebooks, num_codewords)
            mask (BoolTensor): The mask for the frames. `True` marks an invalid frame. Shape: (batch, num_frames)

        Returns:
            perplexity (Tensor): The per-codebook perplexities, summed over codebooks. Scalar.
        """
        if mask is not None:
            # Zero out the probabilities of invalid frames...
            probs = probs * ~mask[..., None, None]
            # ...and average over the valid frames only.
            marginal_probs = (
                einops.reduce(probs, "b nf nb nw -> nb nw", "sum") / (~mask).sum()
            )
else:
marginal_probs = einops.reduce(probs, "b nf nb nw -> nb nw", "mean")
perplexity = torch.exp(
-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
).sum()
return perplexity
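
    # Illustrative helper (not in the original file): in wav2vec 2.0 the
    # diversity loss is derived from the summed perplexity as
    #   L_d = (G * V - sum_g perplexity_g) / (G * V),
    # with G = num_codebooks and V = num_codewords. A minimal sketch:
    def diversity_loss(self, perplexity: torch.Tensor) -> torch.Tensor:
        return (self.total_codewords - perplexity) / self.total_codewords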

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x (Tensor): The extracted features from the waveforms. Shape: (batch, num_frames, in_features)
            mask (BoolTensor): The mask for the frames. `True` marks an invalid frame. Shape: (batch, num_frames)

        Returns:
            quantized (Tensor): The quantized features. Shape: (batch, num_frames, d_model)
            perplexity (Tensor): The summed perplexity of the codeword distributions. Scalar.
        """
        batch_size, num_frames, _ = x.shape
        # Project each frame to one logit per codeword in every codebook.
        logits = self.projection(x)
        logits = logits.view(
            batch_size, num_frames, self.num_codebooks, self.num_codewords
        )
        if self.training:
            # Differentiable hard codeword selection via straight-through Gumbel-softmax.
            word_probs = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=-1)
            # Perplexity is computed on the soft distribution so it stays differentiable.
            word_soft_probs = F.softmax(logits, dim=-1)
            perplexity = self._compute_perplexity(word_soft_probs, mask=mask)
        else:
            # At inference, deterministically pick the most likely codeword.
            word_ids = torch.argmax(logits, dim=-1, keepdim=True)
            word_probs = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)  # One-hot
            perplexity = self._compute_perplexity(word_probs, mask=mask)
        # (batch, num_frames, num_codebooks, num_codewords, 1) * (1, 1, num_codebooks, num_codewords, codeword_dim)
        # -> sum over codewords -> (batch, num_frames, num_codebooks * codeword_dim)
        quantized = einops.reduce(
            word_probs.unsqueeze(-1) * self.codebooks,
            "b nf nb nw d -> b nf (nb d)",
            reduction="sum",
        )
return quantized, perplexity
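

# Minimal usage sketch (not part of the original file). `QuantizerConfig` is a
# hypothetical stand-in for the project's real config object; only the
# attributes read in `__init__` are assumed.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class QuantizerConfig:
        in_features: int = 512
        d_model: int = 256
        num_codebooks: int = 2
        num_codewords: int = 320

    quantizer = QuantizationModule(QuantizerConfig())
    x = torch.randn(4, 50, 512)  # (batch, num_frames, in_features)
    mask = torch.zeros(4, 50, dtype=torch.bool)  # True would mark invalid frames
    quantized, perplexity = quantizer(x, mask=mask)
    print(quantized.shape)    # torch.Size([4, 50, 256])
    print(perplexity.item())  # summed perplexity across the 2 codebooks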