from typing import List

from torch import nn

from poetry_diacritizer.models.seq2seq import Seq2Seq, Decoder as Seq2SeqDecoder
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet


class Tacotron(Seq2Seq):
    """Thin alias of the base Seq2Seq model, exposed under the Tacotron name."""


class Encoder(nn.Module):
    """Tacotron-style character encoder: embedding -> optional prenet -> CBHG."""

    def __init__(
        self,
        inp_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: List[int] = [256, 128],
        cbhg_gru_units: int = 128,
        cbhg_filters: int = 16,
        cbhg_projections: List[int] = [128, 128],
        padding_idx: int = 0,
    ):
        super().__init__()
        self.use_prenet = use_prenet

        # Character embedding; padding_idx pins the pad token's vector to zeros.
        self.embedding = nn.Embedding(
            inp_vocab_size, embedding_dim, padding_idx=padding_idx
        )
        # Optional prenet bottleneck applied before the CBHG stack.
        if use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
        # The CBHG input width depends on whether the prenet narrowed the features.
        self.cbhg = CBHG(
            prenet_sizes[-1] if use_prenet else embedding_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

    def forward(self, inputs, input_lengths=None):
        # inputs: (batch, time) LongTensor of character ids.
        outputs = self.embedding(inputs)
        if self.use_prenet:
            outputs = self.prenet(outputs)
        # The CBHG stack produces the encoder states the decoder attends over.
        return self.cbhg(outputs, input_lengths)


class Decoder(Seq2SeqDecoder):
    """Thin alias of the base Seq2Seq decoder, re-exported for this model."""
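

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model itself. It assumes a
    # hypothetical 40-symbol input vocabulary and the default hyperparameters,
    # and that CBHG accepts a (batch,) tensor of sequence lengths, as implied
    # by the forward() call above.
    import torch

    encoder = Encoder(inp_vocab_size=40)
    token_ids = torch.randint(1, 40, (2, 11))  # (batch=2, time=11) ids, 0 = pad
    lengths = torch.tensor([11, 7])
    states = encoder(token_ids, lengths)
    print(states.shape if hasattr(states, "shape") else type(states))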