from typing import List

import torch
from torch import nn


class BaseLineModel(nn.Module):
    """Baseline stacked bidirectional LSTM that maps input token ids to
    per-position logits over the target (diacritics) vocabulary."""

    def __init__(
        self,
        inp_vocab_size: int,
        targ_vocab_size: int,
        embedding_dim: int = 512,
        layers_units: List[int] = [256, 256, 256],
        use_batch_norm: bool = False,
    ):
        super().__init__()
        self.targ_vocab_size = targ_vocab_size
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)

        # Prepend embedding_dim // 2 so the first LSTM's input size
        # (layers_units[0] * 2) equals the embedding dimension.
        layers_units = [embedding_dim // 2] + layers_units

        layers = []
        for i in range(1, len(layers_units)):
            # Each LSTM is bidirectional, so its output size is twice its
            # hidden size and feeds the next layer's input size of
            # layers_units[i - 1] * 2.
            layers.append(
                nn.LSTM(
                    layers_units[i - 1] * 2,
                    layers_units[i],
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(layers_units[i] * 2))

        self.layers = nn.ModuleList(layers)
        self.projections = nn.Linear(layers_units[-1] * 2, targ_vocab_size)
        self.layers_units = layers_units
        self.use_batch_norm = use_batch_norm

    def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
        # (batch, seq_len) token ids -> (batch, seq_len, embedding_dim)
        outputs = self.embedding(src)

        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d normalizes over the channel dimension, so move
                # the feature axis to dim 1, normalize, then move it back.
                outputs = layer(outputs.permute(0, 2, 1))
                outputs = outputs.permute(0, 2, 1)
                continue
            if i > 0:
                # Reuse the previous LSTM's final hidden and cell states as
                # the initial states of this LSTM.
                outputs, (hn, cn) = layer(outputs, (hn, cn))
            else:
                outputs, (hn, cn) = layer(outputs)

        # (batch, seq_len, 2 * hidden) -> (batch, seq_len, targ_vocab_size)
        predictions = self.projections(outputs)

        output = {"diacritics": predictions}

        return output
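

# Minimal usage sketch (assumption: not part of the original module). It
# instantiates the model with hypothetical vocabulary sizes and runs a forward
# pass on random token ids to show the expected input/output shapes; `lengths`
# is accepted but not used by forward() as written.
if __name__ == "__main__":
    model = BaseLineModel(inp_vocab_size=50, targ_vocab_size=20)
    src = torch.randint(0, 50, (4, 100))   # (batch, seq_len) token ids
    lengths = torch.full((4,), 100)        # per-example sequence lengths
    out = model(src, lengths)
    # Logits over the target (diacritics) vocabulary, one per input position.
    print(out["diacritics"].shape)         # torch.Size([4, 100, 20])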