from typing import Tuple

import torch.nn as nn

from .deepunet import DeepUnet

# Mel bins in the input spectrogram and pitch bins in the output. The values
# mirror the literals used by the GRU head (3 * 128 input features, 360
# output classes).
N_MELS = 128
N_CLASS = 360


class E2E(nn.Module):
    """End-to-end pitch estimator: DeepUnet features -> CNN -> (Bi)GRU head."""

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: Tuple[int, int],
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        # Collapse the U-Net feature maps to 3 channels before the dense head.
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )

    def forward(self, mel):
        # (batch, mels, frames) -> (batch, 1, frames, mels) for the 2D U-Net.
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # (batch, 3, frames, mels) -> (batch, frames, 3 * mels).
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x


class BiGRU(nn.Module):
    """Bidirectional GRU that returns only the output sequence."""

    def __init__(
        self,
        input_features: int,
        hidden_features: int,
        num_layers: int,
    ):
        super().__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # nn.GRU returns (output, hidden); keep only the per-frame outputs.
        return self.gru(x)[0]
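

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The constructor
    # arguments (4 blocks, 1 GRU layer, (2, 2) kernel) and the 128-frame input
    # are illustrative assumptions; the frame count is kept a multiple of
    # 2 ** en_de_layers so DeepUnet's pooling/upsampling round-trips cleanly.
    # Run as a module (python -m <package>.<this file>) so the relative import
    # of DeepUnet resolves.
    import torch

    model = E2E(n_blocks=4, n_gru=1, kernel_size=(2, 2)).eval()
    mel = torch.randn(1, N_MELS, 128)  # (batch, mel bins, frames)
    with torch.no_grad():
        salience = model(mel)
    print(salience.shape)  # expected: (1, 128, 360) = (batch, frames, pitch bins)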