"""The CBHG model implementation."""
from typing import Dict, List, Optional

import torch
from torch import nn

from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet


class CBHGModel(nn.Module):
    """CBHG model implementation as described in the paper:
    https://ieeexplore.ieee.org/document/9274427

    Args:
        inp_vocab_size (int): the number of input symbols
        targ_vocab_size (int): the number of target symbols (diacritics)
        embedding_dim (int): the embedding size
        use_prenet (bool): whether to use a prenet before the CBHG module
        prenet_sizes (List[int]): the layer sizes of the prenet
        cbhg_gru_units (int): the number of units of the CBHG GRU, which is the
            last layer of the CBHG module
        cbhg_filters (int): the number of filters in the CBHG convolution bank
        cbhg_projections (List[int]): the projection sizes used in the CBHG module
        post_cbhg_layers_units (List[int]): the hidden sizes of the bidirectional
            LSTM layers stacked after the CBHG module
        post_cbhg_use_batch_norm (bool): whether to apply batch normalization
            after each post-CBHG LSTM layer

    Returns:
        Dict[str, torch.Tensor]: the forward pass returns a dictionary with the
            key "diacritics" holding the per-character logits
    """

    def __init__(
        self,
        inp_vocab_size: int,
        targ_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: List[int] = [512, 256],
        cbhg_gru_units: int = 512,
        cbhg_filters: int = 16,
        cbhg_projections: List[int] = [128, 256],
        post_cbhg_layers_units: List[int] = [256, 256],
        post_cbhg_use_batch_norm: bool = True,
    ):
        super().__init__()
        self.use_prenet = use_prenet
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)
        if self.use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)

        self.cbhg = CBHG(
            prenet_sizes[-1] if self.use_prenet else embedding_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

        # Bidirectional LSTM stack (optionally interleaved with batch norm)
        # applied on top of the CBHG output. Input sizes are doubled because
        # the preceding recurrent layer is bidirectional.
        layers = []
        post_cbhg_layers_units = [cbhg_gru_units] + post_cbhg_layers_units

        for i in range(1, len(post_cbhg_layers_units)):
            layers.append(
                nn.LSTM(
                    post_cbhg_layers_units[i - 1] * 2,
                    post_cbhg_layers_units[i],
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if post_cbhg_use_batch_norm:
                layers.append(nn.BatchNorm1d(post_cbhg_layers_units[i] * 2))

        self.post_cbhg_layers = nn.ModuleList(layers)
        self.projections = nn.Linear(post_cbhg_layers_units[-1] * 2, targ_vocab_size)
        self.post_cbhg_layers_units = post_cbhg_layers_units
        self.post_cbhg_use_batch_norm = post_cbhg_use_batch_norm

    def forward(
        self,
        src: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the forward pass.

        Args:
            src (torch.Tensor): input character ids of shape [batch, seq_len]
            lengths (Optional[torch.Tensor]): lengths of the input sequences
            target (Optional[torch.Tensor]): not used by this model

        Returns:
            Dict[str, torch.Tensor]: {"diacritics": logits of shape
                [batch, seq_len, targ_vocab_size]}
        """
        embedding_out = self.embedding(src)

        cbhg_input = embedding_out
        if self.use_prenet:
            cbhg_input = self.prenet(embedding_out)

        outputs = self.cbhg(cbhg_input, lengths)

        # Placeholder states; both are overwritten by the first LSTM layer
        # before they are passed to any subsequent layer.
        hn = torch.zeros((2, 2, 2))
        cn = torch.zeros((2, 2, 2))

        for i, layer in enumerate(self.post_cbhg_layers):
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d expects [batch, channels, seq_len].
                outputs = layer(outputs.permute(0, 2, 1))
                outputs = outputs.permute(0, 2, 1)
                continue
            if i > 0:
                # Later LSTM layers start from the final hidden state of the
                # previous LSTM layer.
                outputs, (hn, cn) = layer(outputs, (hn, cn))
            else:
                outputs, (hn, cn) = layer(outputs)

        predictions = self.projections(outputs)

        output = {"diacritics": predictions}

        return output
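

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # hypothetical vocabulary sizes (50 input symbols, 17 diacritic classes)
    # and that the poetry_diacritizer package is importable; it simply runs a
    # random batch through the model to check the output shape.
    model = CBHGModel(inp_vocab_size=50, targ_vocab_size=17)
    src = torch.randint(0, 50, (4, 100))  # [batch, seq_len] of character ids
    out = model(src)
    print(out["diacritics"].shape)  # expected: torch.Size([4, 100, 17])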