from typing import Optional
import torch
from torch import nn
import torch.nn.functional as F
from .transformers import EncoderLayer
class FeatureProjection(nn.Module):
def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
"""
Projects the extracted features to the encoder dimension.
Args:
x (Tensor): The input features. Shape: (batch, num_frames, in_features)
Returns:
hiddens (Tensor): The latent features. Shape: (batch, num_frames, out_features)
"""
super().__init__()
self.projection = nn.Linear(in_features, out_features)
self.layernorm = nn.LayerNorm(in_features)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
        hiddens = self.layernorm(x)
        hiddens = self.projection(hiddens)
        hiddens = self.dropout(hiddens)
return hiddens
class RelativePositionalEmbedding(nn.Module):
def __init__(
self, d_model: int, kernel_size: int, groups: int, dropout: float = 0.1
):
"""
Args:
x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
Returns:
            out (Tensor): The input with convolutional relative positional information added. Shape: (batch, num_frames, d_model)
"""
super().__init__()
self.conv = nn.Conv1d(
in_channels=d_model,
out_channels=d_model,
kernel_size=kernel_size,
padding=kernel_size // 2,
groups=groups,
)
self.dropout = nn.Dropout(dropout)
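        # With an even kernel size, "same" padding of kernel_size // 2 produces one
        # extra output frame, which is trimmed in forward().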
self.num_remove = 1 if kernel_size % 2 == 0 else 0
def forward(self, x: torch.Tensor):
# (batch, channels=d_model, num_frames)
out = x.transpose(1, 2)
out = self.conv(out)
if self.num_remove > 0:
out = out[..., : -self.num_remove]
out = F.gelu(out)
# (batch, num_frames, channels=d_model)
        out = out.transpose(1, 2)
out = out + x
out = self.dropout(out)
return out
class TransformerEncoder(nn.Module):
def __init__(self, config):
"""
Args:
x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
            mask (Tensor): Additive attention mask with large negative values at invalid positions, broadcastable to (batch, 1, num_frames, num_frames)
Returns:
out (Tensor): The output of the transformer encoder. Shape: (batch, num_frames, d_model)
"""
super().__init__()
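        # Convolutional relative positional embedding, used in place of fixed
        # absolute positional encodings as in wav2vec 2.0.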
self.pos_embedding = RelativePositionalEmbedding(**config.pos_embedding)
self.layernorm = nn.LayerNorm(config.d_model)
self.layer_drop = config.layer_drop
self.layers = nn.ModuleList(
EncoderLayer(**config.layer) for _ in range(config.num_layers)
)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
out = self.pos_embedding(x)
for layer in self.layers:
            # LayerDrop: randomly skip entire encoder layers during training.
            skip_layer = self.training and torch.rand(1).item() < self.layer_drop
            if skip_layer:
                continue
            out, _ = layer(out, attention_mask=mask)
out = self.layernorm(out)
return out
class ContextEncoder(nn.Module):
def __init__(self, config):
"""
Args:
x (Tensor): The extracted features. Shape: (batch, num_frames, in_features)
            attention_mask (BoolTensor): Padding mask where `True` marks invalid (padded) frames. Shape: (batch, num_frames)
            mask_time_indices (BoolTensor): `True` marks frames to be replaced by the learned mask embedding during pre-training. Shape: (batch, num_frames)
"""
super().__init__()
self.feature_projection = FeatureProjection(**config.feature_projection)
        self.encoder = TransformerEncoder(config.encoder)
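        # Learned embedding that replaces frames selected by `mask_time_indices`
        # during pre-training (time masking, as in wav2vec 2.0).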
self.masked_spec_embed = nn.Parameter(
torch.FloatTensor(config.feature_projection.out_features).uniform_()
)
def forward(
self,
x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.Tensor] = None,
):
x = self.feature_projection(x)
if mask_time_indices is not None:
x[mask_time_indices] = self.masked_spec_embed.to(x.dtype)
if attention_mask is not None:
            x[attention_mask] = 0.0  # zero out padded frames
            # (batch, num_frames) -> (batch, 1, 1, num_frames)
            attention_mask = attention_mask[:, None, None, :].to(x.dtype)
            # (batch, 1, num_frames, num_frames): a position receives a large negative
            # bias if either the query or the key frame is padding.
            attention_mask = (
                torch.maximum(attention_mask, attention_mask.transpose(2, 3)) * -1e6
            )
x = self.encoder(x, mask=attention_mask)
return x
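

if __name__ == "__main__":
    # Minimal shape check for the two self-contained modules above (a sketch:
    # the hyperparameters 512 -> 768, kernel_size=128, groups=16 are typical
    # wav2vec 2.0 base values assumed here for illustration, not values taken
    # from this project's config).
    proj = FeatureProjection(in_features=512, out_features=768)
    pos_emb = RelativePositionalEmbedding(d_model=768, kernel_size=128, groups=16)
    feats = torch.randn(2, 50, 512)  # (batch, num_frames, in_features)
    hiddens = proj(feats)  # (2, 50, 768)
    out = pos_emb(hiddens)  # (2, 50, 768)
    print(hiddens.shape, out.shape)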