File size: 3,920 Bytes
5112867 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""
The CBHG model implementation
"""
from typing import List, Optional
from torch import nn
import torch
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
class CBHGModel(nn.Module):
    """CBHG model implementation as described in the paper:
    https://ieeexplore.ieee.org/document/9274427

    Pipeline: embedding -> optional prenet -> CBHG -> stack of bidirectional
    LSTM layers (each optionally followed by batch normalisation) -> linear
    projection onto the target vocabulary.

    Args:
        inp_vocab_size (int): the number of the input symbols
        targ_vocab_size (int): the number of the target symbols (diacritics)
        embedding_dim (int): the embedding size
        use_prenet (bool): whether to use prenet or not
        prenet_sizes (List[int]): the sizes of the prenet networks;
            ``None`` selects ``[512, 256]``
        cbhg_gru_units (int): the number of units of the CBHG GRU, which is the
            last layer of the CBHG Model.
        cbhg_filters (int): number of filters used in the CBHG module
        cbhg_projections (List[int]): projections used in the CBHG module;
            ``None`` selects ``[128, 256]``
        post_cbhg_layers_units (List[int]): hidden size of each bidirectional
            LSTM stacked after the CBHG module; ``None`` selects ``[256, 256]``.
            The LSTM state is handed from one layer to the next in ``forward``,
            so the sizes have to be equal when more than one layer is used.
        post_cbhg_use_batch_norm (bool): whether every post-CBHG LSTM is
            followed by a ``BatchNorm1d`` layer

    Returns:
        diacritics Dict[str, Tensor]: ``forward`` returns a dictionary with the
        single key ``"diacritics"`` holding the projected predictions.
    """

    def __init__(
        self,
        inp_vocab_size: int,
        targ_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: Optional[List[int]] = None,
        cbhg_gru_units: int = 512,
        cbhg_filters: int = 16,
        cbhg_projections: Optional[List[int]] = None,
        post_cbhg_layers_units: Optional[List[int]] = None,
        post_cbhg_use_batch_norm: bool = True,
    ):
        super().__init__()

        # ``None`` sentinels instead of list literals in the signature: a list
        # default is shared by every call, and two of these lists are handed to
        # external modules. Caller-supplied sequences are copied for the same
        # reason.
        prenet_sizes = [512, 256] if prenet_sizes is None else list(prenet_sizes)
        cbhg_projections = (
            [128, 256] if cbhg_projections is None else list(cbhg_projections)
        )
        post_cbhg_layers_units = (
            [256, 256]
            if post_cbhg_layers_units is None
            else list(post_cbhg_layers_units)
        )

        self.use_prenet = use_prenet
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)
        if self.use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)

        self.cbhg = CBHG(
            prenet_sizes[-1] if self.use_prenet else embedding_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

        # Prepend the CBHG width so that entry ``i - 1`` is always the
        # (per-direction) input width of the LSTM built for entry ``i``.
        post_cbhg_layers_units = [cbhg_gru_units] + post_cbhg_layers_units

        layers = []
        for i in range(1, len(post_cbhg_layers_units)):
            layers.append(
                nn.LSTM(
                    # ``* 2``: the preceding layer's features are doubled
                    # (bidirectional LSTM; presumably the CBHG GRU is
                    # bidirectional as well -- confirm against CBHG).
                    post_cbhg_layers_units[i - 1] * 2,
                    post_cbhg_layers_units[i],
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if post_cbhg_use_batch_norm:
                layers.append(nn.BatchNorm1d(post_cbhg_layers_units[i] * 2))

        self.post_cbhg_layers = nn.ModuleList(layers)
        self.projections = nn.Linear(post_cbhg_layers_units[-1] * 2, targ_vocab_size)
        # NOTE: the stored list includes the prepended ``cbhg_gru_units`` entry.
        self.post_cbhg_layers_units = post_cbhg_layers_units
        self.post_cbhg_use_batch_norm = post_cbhg_use_batch_norm

    def forward(
        self,
        src: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,  # not required in this model
    ):
        """Compute forward propagation.

        Args:
            src: input symbol ids, ``[batch_size, src len]``
            lengths: ``[batch_size]``; forwarded unchanged to the CBHG module
            target: ``[batch_size, trg len]``; accepted but unused by this model

        Returns:
            A dictionary ``{"diacritics": predictions}`` where ``predictions``
            is ``[batch_size, src len, targ_vocab_size]``.
        """
        embedding_out = self.embedding(src)
        # embedding_out = [batch_size, src_len, embedding_dim]

        cbhg_input = self.prenet(embedding_out) if self.use_prenet else embedding_out
        # cbhg_input = [batch_size, src_len, prenet_sizes[-1]] when the prenet is used

        outputs = self.cbhg(cbhg_input, lengths)

        # ``None`` lets the first LSTM start from its default zero state; every
        # following LSTM is seeded with the final (h_n, c_n) of the previous one.
        state = None
        for layer in self.post_cbhg_layers:
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d normalises over dim 1, so move the feature axis
                # there and back again.
                outputs = layer(outputs.permute(0, 2, 1)).permute(0, 2, 1)
                continue
            outputs, state = layer(outputs, state)

        predictions = self.projections(outputs)
        # predictions = [batch_size, src len, targ_vocab_size]

        return {"diacritics": predictions}
|