wav2vec2 / src /model /wav2vec2.py
hoang1007
init
5381499
raw
history blame contribute delete
No virus
10.3 kB
"""
A wrapper around Wav2Vec2 for the pretraining phase.
"""
from typing import Tuple, Optional
import torch
from pytorch_lightning import LightningModule
import einops
from torchmetrics import MeanMetric
from .modules import (
ContextEncoder,
FeatureExtractor,
QuantizationModule,
Wav2Vec2Processor,
)
from src.utils import init_module_weights
class Wav2Vec2PretrainingModule(LightningModule):
    """
    Lightning module that pretrains Wav2Vec2 with a contrastive objective.

    The forward pass turns raw waveforms into frame-level features, selects
    spans of frames to mask, feeds the features and the span mask to the
    context encoder, quantizes the features, and computes a contrastive loss
    over the masked frames plus a codebook diversity loss.

    Mask convention used throughout this class: an ``attention_mask`` is a
    boolean tensor of shape (batch_size, num_frames) in which ``True`` marks
    an *invalid* (padded) frame.
    """

    def __init__(self, config):
        """
        Args:
            config: The configuration. It is stored through
                `save_hyperparameters` and must expose the `context_encoder`,
                `feature_extractor` and `quantizer` sub-configurations.
                `forward` additionally reads `mask_prob`, `mask_length`,
                `min_masks`, `num_negatives`, `contrastive_logits_temperature`
                and `diversity_loss_weight` from the saved hyper-parameters.
        """
        super().__init__()

        self.save_hyperparameters(config)

        self.processor = Wav2Vec2Processor()
        self.context_encoder = ContextEncoder(config.context_encoder)
        self.feature_extractor = FeatureExtractor(config.feature_extractor)
        self.quantizer = QuantizationModule(config.quantizer)

        # Running mean of the training loss, used for logging only.
        self.train_loss = MeanMetric()

    def forward(self, waveforms: Tuple[torch.Tensor, ...]):
        """
        Args:
            waveforms (Tuple[torch.Tensor]): The waveforms. Shape: (batch_size, wave_length).

        Returns:
            loss: The loss of the model. Contrastive loss + Diversity loss.
        """
        waveforms, wave_lengths = self.processor(waveforms)

        # features.shape == (batch_size, num_frames, hidden_size)
        features, num_frames = self.feature_extractor(waveforms, wave_lengths)

        # True == padded frame
        attention_mask = self._compute_attention_mask(num_frames)
        mask_time_indices = self._compute_mask_span(
            shape=features.shape[:-1],
            mask_prob=self.hparams.mask_prob,
            mask_length=self.hparams.mask_length,
            attention_mask=attention_mask,
            device=features.device,
            min_masks=self.hparams.min_masks,
        )

        context_features = self.context_encoder(
            features, attention_mask=attention_mask, mask_time_indices=mask_time_indices
        )

        quantized_features, perplexity = self.quantizer(features, attention_mask)

        negative_quantized_features = self._sample_negatives(
            quantized_features,
            num_negatives=self.hparams.num_negatives,
            attention_mask=attention_mask,
        )

        # (batch_size, num_frames, num_negatives + 1) -> (batch_size * num_frames, num_negatives + 1)
        contrastive_logits = self._compute_contrastive_logits(
            context_features,
            quantized_features,
            negative_quantized_features,
            self.hparams.contrastive_logits_temperature,
        ).flatten(0, -2)

        # compute contrastive loss
        # The positive is always candidate 0, so masked frames get target 0.
        # Unmasked frames get -100, the default `ignore_index` of
        # `cross_entropy`, and therefore do not contribute to the loss.
        targets = (1 - mask_time_indices.long().flatten()) * -100

        contrastive_loss = torch.nn.functional.cross_entropy(
            contrastive_logits, targets, reduction="sum"
        )

        # compute diversity loss
        diversity_loss = 1 - perplexity / self.quantizer.total_codewords

        loss = contrastive_loss + diversity_loss * self.hparams.diversity_loss_weight

        return loss

    @staticmethod
    def _sample_negatives(
        features: torch.Tensor,
        num_negatives: int,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Sampling negative features from quantized features to compute the contrastive loss.

        For every frame, `num_negatives` other frames of the same sample are
        drawn uniformly from the valid frames, never picking the frame itself.

        Args:
            features (torch.Tensor): The quantized features. Shape: (batch_size, num_frames, d_model).
            num_negatives (int): The number of negative samples.
            attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).

        Returns:
            sampled_negatives (torch.Tensor): The sampled negative features. Shape: (batch_size, num_frames, num_negatives, d_model).

        Raises:
            ValueError: If a sample has fewer than two valid frames, in which
                case no negative distinct from the positive exists.
        """
        batch_size, num_frames, d_model = features.shape

        # (batch_size * num_frames, d_model); `reshape` also handles
        # non-contiguous inputs, which `view` would reject.
        flat_features = features.reshape(-1, d_model)

        # Only integer index computation happens here, no gradient is needed.
        with torch.no_grad():
            if attention_mask is None:
                valid_frame_counts = [num_frames] * batch_size
            else:
                valid_frame_counts = (~attention_mask.bool()).sum(dim=1).tolist()

            sampled_ids = []
            for num_valid_frames in valid_frame_counts:
                if num_valid_frames < 2:
                    raise ValueError(
                        "Sampling negatives requires at least 2 valid frames "
                        f"per sample, got {num_valid_frames}."
                    )

                # `high` is exclusive, so ids lie in [0, num_valid_frames - 2];
                # the shift below extends that to [0, num_valid_frames - 1]
                # with the positive index left out.
                sampled_ids.append(
                    torch.randint(
                        0,
                        num_valid_frames - 1,
                        (num_frames * num_negatives,),
                        device=features.device,
                    )
                )

            sampled_ids = torch.stack(
                sampled_ids, dim=0
            )  # (batch_size, num_frames * num_negatives)

            # Index of the positive frame for every sampled slot.
            feature_ids = einops.repeat(
                torch.arange(num_frames, device=features.device),
                "f -> (f n)",
                n=num_negatives,
            )

            # avoid sampling the same positive vector, but keep the distribution uniform
            sampled_ids[sampled_ids >= feature_ids] += 1

            # correct for batch size
            # E.g [[0, 1, 2], [0, 1, 2]] -> [0, 1, 2, 3, 4, 5]
            sampled_ids += torch.arange(
                0, batch_size * num_frames, num_frames, device=features.device
            ).unsqueeze(-1)

        sampled_negatives = flat_features[sampled_ids.view(-1)]
        sampled_negatives = einops.rearrange(
            sampled_negatives,
            "(b f n) d -> b f n d",
            b=batch_size,
            f=num_frames,
            n=num_negatives,
        )

        return sampled_negatives

    @staticmethod
    def _compute_contrastive_logits(
        predicted_features: torch.Tensor,
        target_features: torch.Tensor,
        negative_features: torch.Tensor,
        temperature: float = 1,
    ):
        """
        Compute the logits for contrastive loss.

        The inputs are not modified.

        Args:
            predicted_features (torch.Tensor): The predicted features. Shape: (batch_size, num_frames, d_model).
            target_features (torch.Tensor): The target features. Shape: (batch_size, num_frames, d_model).
            negative_features (torch.Tensor): The negative features. Shape: (batch_size, num_frames, num_negatives, d_model).
            temperature (float): The temperature for contrastive loss.

        Returns:
            logits (torch.Tensor): The logits for contrastive loss. Index 0 of the
                last dimension is the positive. Shape: (batch_size, num_frames, num_negatives + 1).
        """
        # Out-of-place `unsqueeze` on purpose: the in-place variant would
        # change the shape of the caller's tensors.
        # (batch_size, num_frames, num_negatives + 1, d_model)
        candidate_features = torch.cat(
            (target_features.unsqueeze(2), negative_features), dim=2
        )

        # (batch_size, num_frames, 1, d_model)
        predicted_features = predicted_features.unsqueeze(2)

        # (batch_size, num_frames, num_negatives + 1)
        logits = torch.cosine_similarity(
            predicted_features, candidate_features, dim=-1
        )

        return logits / temperature

    @staticmethod
    def _compute_mask_span(
        shape: Tuple[int, int],
        mask_prob: float = 0.065,
        mask_length: int = 10,
        attention_mask: Optional[torch.Tensor] = None,
        device: torch.device = torch.device("cpu"),
        min_masks: int = 0,
    ):
        """
        Compute the mask span for contrastive task.

        Args:
            shape (Tuple[int, int]): The shape of the mask span. Shape: (batch_size, num_frames).
            mask_prob (float): The probability of choosing a frame to be the start of masking position.
            mask_length (int): The length of the mask span.
            attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).
            device (torch.device): The device of the mask span.
            min_masks (int): The minimum number of masks.

        Returns:
            mask_span (torch.Tensor): The mask span, `True` where a frame is masked.
                All `False` when no span fits. Shape: (batch_size, num_frames).
        """
        batch_size, num_frames = shape
        mask_span = torch.zeros(
            (batch_size, num_frames), device=device, dtype=torch.bool
        )

        # NOTE: num_frames / mask_length: the number of spans in one waveform
        # The random term rounds the expected span count up or down at random.
        num_masked_spans = int(
            mask_prob * num_frames / mask_length + torch.rand(1).item()
        )
        num_masked_spans = max(num_masked_spans, min_masks)

        # make sure num masked indices <= num frames
        if num_masked_spans * mask_length > num_frames:
            num_masked_spans = num_frames // mask_length

        # `torch.multinomial` rejects a sample count of 0, and no span fits
        # when num_frames < mask_length: nothing to mask in either case.
        if num_masked_spans <= 0:
            return mask_span

        # uniform distribution to sample from
        # NOTE: num_frames - (mask_length - 1): the number of start positions of the span
        uniform_dist = torch.ones(
            (batch_size, num_frames - (mask_length - 1)), device=device
        )

        # (batch_size, num_masked_spans), distinct start positions per sample
        mask_span_ids = torch.multinomial(uniform_dist, num_masked_spans)

        # (batch_size, num_masked_spans * mask_length)
        mask_span_ids = einops.repeat(mask_span_ids, "b n -> b (n l)", l=mask_length)

        # 0 .. mask_length - 1 for every span, turning starts into full spans
        offsets = einops.repeat(
            torch.arange(mask_length, device=device),
            "l -> b (n l)",
            b=batch_size,
            n=num_masked_spans,
        )

        mask_span_ids = mask_span_ids + offsets

        mask_span = mask_span.scatter_(1, mask_span_ids, True)

        if attention_mask is not None:
            # Make sure the invalid frames are not masked
            # (`True` in attention_mask means invalid, hence the negation).
            mask_span = mask_span & ~attention_mask.bool()

        return mask_span

    @staticmethod
    def _compute_attention_mask(length: torch.Tensor):
        """
        Build the padding mask from the number of valid frames per sample.

        Args:
            length (Tensor): The length of valid frames. Shape: (batch)

        Returns:
            attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid.
                Shape: (batch, num_frames), where num_frames is `length.max()`.
        """
        max_length = length.max().item()

        # Frame index >= valid length  ->  padded frame
        mask = (
            torch.arange(max_length, device=length.device).expand(
                length.size(0), max_length
            )
            >= length[:, None]
        )

        return mask

    def training_step(self, batch, batch_idx):
        """
        Run one training step and return the loss.

        The running mean of the loss is updated on every step but only
        logged on every 100th batch.
        """
        loss = self(batch)

        self.train_loss(loss)

        if batch_idx % 100 == 0:
            self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with a learning rate of 1e-4."""
        return torch.optim.AdamW(self.parameters(), lr=1e-4)