from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
from torch import nn

class SimBertModel(PreTrainedModel):
    """ SimBert Model
    """

    config_class = BertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # BERT encoder with the pooling layer, so `pooler_output` is available.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # Project the pooled representation to two classes (dissimilar/similar).
        self.fc = nn.Linear(config.hidden_size, 2)
        # Regress the positive-class probability instead of classifying:
        # self.loss_fct = nn.CrossEntropyLoss()
        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Keep only the probability of the "similar" class as the score.
        logits = self.softmax(logits)[:, 1]
        if labels is not None:
            # MSE against float similarity targets in [0, 1]; cast in case
            # integer labels are passed, since MSELoss requires float targets.
            loss = self.loss_fct(logits.view(-1), labels.view(-1).float())
            return loss, logits
        return None, logits
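

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal smoke test with randomly initialized weights, assuming the
# standard `bert-base-uncased` tokenizer; the sentence pair and the
# similarity target below are hypothetical placeholders.
if __name__ == "__main__":
    import torch
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SimBertModel(BertConfig())

    # Encode a sentence pair; token_type_ids mark the two segments.
    enc = tokenizer(
        "How old are you?",
        "What is your age?",
        return_tensors="pt",
    )
    labels = torch.tensor([1.0])  # hypothetical similarity target in [0, 1]

    loss, score = model(
        input_ids=enc["input_ids"],
        token_type_ids=enc["token_type_ids"],
        attention_mask=enc["attention_mask"],
        labels=labels,
    )
    print(loss.item(), score.item())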