File size: 1,812 Bytes
6c333b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9dfa9e
6c333b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9dfa9e
 
6c333b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import PreTrainedModel

from torch import nn
import torch

class BiLSTM(PreTrainedModel):
    """Bidirectional-LSTM token-level predictor.

    Per token: embedding -> multi-layer BiLSTM -> dropout ->
    Linear(LSTM output -> output_dim) + ELU -> Linear(output_dim -> predict_output).

    Config attributes read here: ``hidden_dim``, ``predict_output``,
    ``vocab_size``, ``embed_dim``, ``num_layers``, ``output_dim``, ``dropout``
    and ``device``.

    NOTE(review): no ``config_class`` is set on this class; confirm whether
    ``from_pretrained`` is used without an explicit config.
    """

    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.predict_output = config.predict_output

        # padding_idx=0: index 0 is the padding token and its embedding receives no gradient.
        self.embed_layer = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=0)
        # Each direction gets hidden_dim // 2 units; the two directions are
        # concatenated, giving an output width of 2 * (hidden_dim // 2).
        self.biLSTM = nn.LSTM(input_size=config.embed_dim,
                              hidden_size=config.hidden_dim // 2,
                              num_layers=config.num_layers,
                              bidirectional=True,
                              batch_first=True)
        # Size from the real LSTM output width (equals hidden_dim when it is
        # even) so an odd hidden_dim no longer causes a shape mismatch.
        lstm_out_dim = 2 * self.biLSTM.hidden_size
        self.linear = nn.Linear(lstm_out_dim, config.output_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.elu = nn.ELU()
        self.fc = nn.Linear(config.output_dim, config.predict_output)
        # Kept for backward compatibility; used as init_hidden's default device.
        self.device_ = config.device

    def forward(self, input):  # name kept: callers may pass input=...
        """Run the model on a batch of token indices.

        Args:
            input: tensor of token indices, shape (batch_size, seq_len)
                (batch-first, matching ``batch_first=True`` on the LSTM).

        Returns:
            ``(out, (h_n, c_n))``: ``out`` has shape
            (batch_size, seq_len, predict_output); ``h_n`` and ``c_n`` are the
            final LSTM states, shape (2 * num_layers, batch_size, hidden_dim // 2).
        """
        x = self.embed_layer(input)                     # (batch_size, seq_len, embed_dim)
        # Build the initial states where the activations actually live, so the
        # model keeps working after .to(device) or a dtype cast.
        hidden, cell = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)

        out, hidden = self.biLSTM(x, (hidden, cell))    # (batch_size, seq_len, 2 * (hidden_dim // 2))

        out = self.dropout(out)

        out = self.elu(self.linear(out))                # (batch_size, seq_len, output_dim)

        out = self.fc(out)                              # (batch_size, seq_len, predict_output)

        return out, hidden

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero initial states ``(h_0, c_0)`` for the BiLSTM.

        nn.LSTM expects one state per layer per direction, so the shape is
        (2 * num_layers, batch_size, hidden_dim // 2).

        Args:
            batch_size: number of sequences in the batch.
            device: device for the states; defaults to ``self.device_``
                (``config.device``), preserving the previous behaviour.
            dtype: dtype for the states; ``None`` uses torch's default dtype.

        Returns:
            Tuple ``(hidden, cell)`` of zero tensors.
        """
        if device is None:
            device = self.device_
        # 2 directions x num_layers stacked layers.
        num_states = 2 * self.biLSTM.num_layers
        shape = (num_states, batch_size, self.biLSTM.hidden_size)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell