File size: 1,820 Bytes
5112867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import List, Optional

import torch
from torch import nn


class BaseLineModel(nn.Module):
    """Stacked bidirectional-LSTM token tagger.

    Pipeline: ``nn.Embedding`` -> one or more bidirectional ``nn.LSTM`` layers
    (each optionally followed by ``nn.BatchNorm1d``) -> ``nn.Linear`` projection
    to ``targ_vocab_size`` scores per input position.

    Args:
        inp_vocab_size: Number of rows in the input embedding table.
        targ_vocab_size: Number of output classes scored per position.
        embedding_dim: Width of each token embedding; it is the input size of
            the first LSTM.
        layers_units: Hidden size of each BiLSTM layer, one entry per layer.
            Each layer emits ``2 * units`` features (forward + backward).
            Defaults to ``[256, 256, 256]``.
        use_batch_norm: If True, a ``BatchNorm1d`` over the feature axis is
            inserted after every BiLSTM layer.
    """

    def __init__(
        self,
        inp_vocab_size: int,
        targ_vocab_size: int,
        embedding_dim: int = 512,
        layers_units: Optional[List[int]] = None,
        use_batch_norm: bool = False,
    ):
        super().__init__()
        # None sentinel instead of a shared mutable list default.
        if layers_units is None:
            layers_units = [256, 256, 256]
        self.targ_vocab_size = targ_vocab_size
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)

        layers = []
        # The first LSTM consumes the raw embedding; later LSTMs consume the
        # concatenated forward/backward output (2 * units) of the previous one.
        # Using embedding_dim directly (rather than (embedding_dim // 2) * 2)
        # keeps odd embedding sizes working; for even sizes it is identical.
        input_size = embedding_dim
        for units in layers_units:
            layers.append(
                nn.LSTM(
                    input_size,
                    units,
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(units * 2))
            input_size = units * 2

        self.layers = nn.ModuleList(layers)
        # input_size is now the feature width fed to the projection (the
        # embedding width when layers_units is empty).
        self.projections = nn.Linear(input_size, targ_vocab_size)
        # Attribute layout kept as before: a leading embedding_dim // 2 entry
        # followed by the per-layer hidden sizes.
        self.layers_units = [embedding_dim // 2] + list(layers_units)
        self.use_batch_norm = use_batch_norm

    def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
        """Score every position of ``src``.

        Args:
            src: Token indices for ``nn.Embedding``; the LSTMs are built with
                ``batch_first=True``, so presumably ``(batch, seq_len)``.
            lengths: Currently unused (sequences are not packed, so padded
                positions are processed like real ones).
            target: Currently unused.

        Returns:
            ``{"diacritics": scores}`` where ``scores`` has a last dimension of
            ``targ_vocab_size``.
        """
        outputs = self.embedding(src)
        state = None

        for layer in self.layers:
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d normalizes dim 1, so move features there and back.
                outputs = layer(outputs.permute(0, 2, 1)).permute(0, 2, 1)
                continue
            # Each LSTM is seeded with the previous LSTM's final (h, c) when the
            # hidden sizes agree (the original behavior); otherwise it starts
            # from zeros instead of raising a shape error.
            init = state
            if init is not None and init[0].size(-1) != layer.hidden_size:
                init = None
            outputs, state = layer(outputs, init)

        predictions = self.projections(outputs)

        return {"diacritics": predictions}