File size: 1,110 Bytes
6a34fd4
 
5c72fe4
6a34fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c72fe4
 
 
 
6a34fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1922da0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.nn as nn

from .bert import BERT


class BERTSM(nn.Module):
    """
    BERT Sequence Model: a BERT encoder combined with a
    Masked Sequence Model prediction head.
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained; its ``hidden``
            attribute is read to size the prediction head
        :param vocab_size: total vocab size for masked_lm
        """
        super().__init__()
        self.bert = bert
        # The head's input width is taken from the encoder itself.
        encoder_width = self.bert.hidden
        self.mask_lm = MaskedSequenceModel(encoder_width, vocab_size)

    def forward(self, x, segment_label):
        """
        Run the encoder, then the masked-sequence head on its output.

        :param x: input passed through to the BERT model
        :param segment_label: segment labels passed through to the BERT model
        :return: tuple of (``mask_lm`` output for the encoded input,
            the encoded input sliced at index 0 of its second axis)
        """
        encoded = self.bert(x, segment_label)
        token_scores = self.mask_lm(encoded)
        # NOTE(review): presumably the second axis is the sequence position,
        # so this picks the first token's vector per batch item -- confirm.
        first_slice = encoded[:, 0]
        return token_scores, first_slice

    
class MaskedSequenceModel(nn.Module):
    """
    Head that predicts the origin token from a masked input sequence.
    Treated as an n-class classification problem with n-class = vocab_size:
    a linear projection followed by log-softmax over the last dimension.
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model (input width of the linear layer)
        :param vocab_size: total vocab size (output width of the linear layer)
        """
        super().__init__()
        self.linear = nn.Linear(in_features=hidden, out_features=vocab_size)
        # Normalise over the last axis, i.e. across the vocab_size scores.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        :param x: tensor whose last dimension has size ``hidden``
        :return: log-probabilities with last dimension of size ``vocab_size``
        """
        vocab_scores = self.linear(x)
        log_probs = self.softmax(vocab_scores)
        return log_probs