from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F

from .transformers import EncoderLayer


class FeatureProjection(nn.Module):
    """Projects the extracted features to the encoder dimension."""

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
        super().__init__()

        self.layernorm = nn.LayerNorm(in_features)
        self.projection = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (Tensor): The input features. Shape: (batch, num_frames, in_features)

        Returns:
            hiddens (Tensor): The latent features. Shape: (batch, num_frames, out_features)
        """
        # Normalize first, then project the normalized features (the projection
        # must consume the LayerNorm output, not the raw input).
        hiddens = self.layernorm(x)
        hiddens = self.projection(hiddens)
        hiddens = self.dropout(hiddens)
        return hiddens
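

# A minimal usage sketch for FeatureProjection. The 512/768 sizes below are
# illustrative placeholders, not values taken from this repository's configs:
#
#     proj = FeatureProjection(in_features=512, out_features=768)
#     hiddens = proj(torch.randn(2, 100, 512))  # -> (2, 100, 768)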


class RelativePositionalEmbedding(nn.Module):
    """Encodes relative positional information with a grouped convolution."""

    def __init__(
        self, d_model: int, kernel_size: int, groups: int, dropout: float = 0.1
    ):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=groups,
        )
        self.dropout = nn.Dropout(dropout)
        # With padding of kernel_size // 2, an even kernel produces one extra
        # frame, which is trimmed off the end; an odd kernel preserves the length.
        self.num_remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)

        Returns:
            out (Tensor): The output encoding the relative positional information.
                Shape: (batch, num_frames, d_model)
        """
        # Conv1d expects (batch, channels, num_frames).
        out = x.transpose(1, 2)

        out = self.conv(out)

        if self.num_remove > 0:
            out = out[..., : -self.num_remove]

        out = F.gelu(out)

        # Back to (batch, num_frames, d_model), then add the residual connection.
        out = out.transpose(1, 2)
        out = out + x
        out = self.dropout(out)

        return out
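

# A minimal usage sketch for RelativePositionalEmbedding. The kernel_size and
# groups values are illustrative placeholders (not taken from this repository's
# configs); the frame dimension is preserved for both even and odd kernels:
#
#     pos = RelativePositionalEmbedding(d_model=768, kernel_size=128, groups=16)
#     out = pos(torch.randn(2, 100, 768))  # -> (2, 100, 768)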


class TransformerEncoder(nn.Module):
    """Stack of encoder layers with convolutional relative positional embeddings and LayerDrop."""

    def __init__(self, config):
        super().__init__()

        self.pos_embedding = RelativePositionalEmbedding(**config.pos_embedding)
        self.layernorm = nn.LayerNorm(config.d_model)
        self.layer_drop = config.layer_drop

        self.layers = nn.ModuleList(
            EncoderLayer(**config.layer) for _ in range(config.num_layers)
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
            mask (Tensor): The attention mask passed to each encoder layer
                (`ContextEncoder` below builds an additive mask of shape
                (batch, 1, num_frames, num_frames)).

        Returns:
            out (Tensor): The output of the transformer encoder.
                Shape: (batch, num_frames, d_model)
        """
        out = self.pos_embedding(x)

        for layer in self.layers:
            # LayerDrop: during training, randomly skip whole layers with
            # probability `layer_drop`.
            skip_layer = self.training and torch.rand(1).item() < self.layer_drop
            if skip_layer:
                continue
            out, _ = layer(out, attention_mask=mask)

        out = self.layernorm(out)

        return out


class ContextEncoder(nn.Module):
    """Projects the extracted features and encodes them with the transformer encoder."""

    def __init__(self, config):
        super().__init__()

        self.feature_projection = FeatureProjection(**config.feature_projection)
        self.encoder = TransformerEncoder(config.encoder)
        self.masked_spec_embed = nn.Parameter(
            torch.FloatTensor(config.feature_projection.out_features).uniform_()
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x (Tensor): The extracted features. Shape: (batch, num_frames, in_features)
            attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid.
                Shape: (batch, num_frames)
            mask_time_indices (BoolTensor): Frames to replace with the learned mask
                embedding. Shape: (batch, num_frames)
        """
        x = self.feature_projection(x)

        # Replace the masked frames with the learned mask embedding.
        if mask_time_indices is not None:
            x[mask_time_indices] = self.masked_spec_embed.to(x.dtype)

        if attention_mask is not None:
            # Zero out the invalid (padded) frames.
            x[attention_mask] = 0.0

            # Expand the padding mask to an additive attention bias of shape
            # (batch, 1, num_frames, num_frames): a large negative value wherever
            # either the query or the key frame is padding.
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = (
                torch.maximum(attention_mask, attention_mask.transpose(2, 3)) * -1e6
            )

        x = self.encoder(x, mask=attention_mask)

        return x
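

# A rough sketch of the config layout these constructors expect, inferred from
# the attribute accesses above (the keyword arguments under `config.encoder.layer`
# depend on `EncoderLayer` in .transformers and are therefore not spelled out):
#
#     config.feature_projection: {in_features, out_features, dropout}
#     config.encoder.pos_embedding: {d_model, kernel_size, groups, dropout}
#     config.encoder.d_model, config.encoder.layer_drop, config.encoder.num_layers
#     config.encoder.layer: kwargs forwarded to EncoderLayer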