# easyGUI/rvc/f0/e2e.py
# Uploaded by Blane187 ("Upload 39 files"), commit c3b58fa (verified).
# Hosting page notes: no virus detected; file size 1.7 kB.
from typing import Tuple
import torch.nn as nn
from .deepunet import DeepUnet
class E2E(nn.Module):
    """End-to-end f0 model: DeepUnet -> 3x3 conv head -> per-frame classifier.

    The classifier ends in a sigmoid and emits ``N_CLASS`` activations per
    frame. When ``n_gru`` is non-zero a bidirectional GRU precedes the linear
    layer; otherwise a single linear layer is used.
    """

    # Number of mel bins the classifier input is sized for: the conv head
    # emits 3 channels per bin, flattened to 3 * N_MELS features per frame.
    N_MELS = 128
    # Number of sigmoid outputs per frame.
    N_CLASS = 360

    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: Tuple[int, int],
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        """Build the network.

        Args:
            n_blocks: forwarded to ``DeepUnet`` (as its 2nd positional arg).
            n_gru: number of stacked BiGRU layers; 0 selects the linear-only
                classifier head.
            kernel_size: forwarded to ``DeepUnet`` (as its 1st positional arg).
            en_de_layers, inter_layers, in_channels, en_out_channels:
                forwarded to ``DeepUnet``; ``en_out_channels`` is also the
                input channel count of the 3x3 conv head.
        """
        super().__init__()
        # NOTE: DeepUnet takes kernel_size before n_blocks, the reverse of
        # this constructor's parameter order.
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        # Collapse U-Net features to 3 channels; stride 1 with padding 1 keeps
        # the spatial size unchanged.
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        self.fc = self._build_fc(n_gru)

    @classmethod
    def _build_fc(cls, n_gru: int) -> nn.Sequential:
        """Return the per-frame classifier head.

        The Sequential layout (index 0 = GRU or Linear, then Dropout, Sigmoid)
        is kept identical to the original so ``state_dict`` keys do not change.
        """
        if n_gru:
            hidden = 256
            return nn.Sequential(
                cls.BiGRU(3 * cls.N_MELS, hidden, n_gru),
                # A bidirectional GRU concatenates both directions: 2 * hidden.
                nn.Linear(2 * hidden, cls.N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        # torch.nn defines no N_MELS / N_CLASS; the previous
        # `nn.N_MELS` / `nn.N_CLASS` raised AttributeError on this path.
        return nn.Sequential(
            nn.Linear(3 * cls.N_MELS, cls.N_CLASS),
            nn.Dropout(0.25),
            nn.Sigmoid(),
        )

    def forward(self, mel):
        """Map a batch of mel spectrograms to per-frame activations.

        Assumes ``mel`` is (batch, N_MELS, frames) and that ``DeepUnet``
        preserves the spatial size of its input (TODO: confirm against
        deepunet.py). Under those assumptions the result is
        (batch, frames, N_CLASS).
        """
        # Swap the last two axes and add a channel axis:
        # assumed (B, n_mels, T) -> (B, 1, T, n_mels).
        mel = mel.transpose(-1, -2).unsqueeze(1)
        # Assumed (B, 3, T, n_mels) -> (B, T, 3, n_mels) -> (B, T, 3 * n_mels).
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x

    class BiGRU(nn.Module):
        """Bidirectional, batch-first GRU that returns only its output sequence."""

        def __init__(
            self,
            input_features: int,
            hidden_features: int,
            num_layers: int,
        ):
            super().__init__()
            self.gru = nn.GRU(
                input_features,
                hidden_features,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=True,
            )

        def forward(self, x):
            """Return the GRU output sequence, dropping the final hidden state.

            With ``batch_first=True`` the input is (batch, seq, input_features)
            and the output is (batch, seq, 2 * hidden_features).
            """
            return self.gru(x)[0]