File size: 2,467 Bytes
b4da537
ba33264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4da537
ba33264
 
1b85e66
 
895295d
ba33264
 
 
 
b4da537
 
ba33264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26d4cd1
 
ba33264
26d4cd1
ba33264
 
 
 
 
 
11b09bf
ba33264
 
 
 
44b4eae
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from transformers import AutoModel, AutoConfig, PreTrainedModel
import torch


class MultiLabelAttention(torch.nn.Module):
    """Label-wise attention pooling over a sequence of vectors.

    Each of the ``num_labels`` columns of the learned matrix ``A`` scores every
    sequence position. The scores are squashed with ``tanh`` and
    softmax-normalised over the sequence axis. The resulting weights pool the
    input into one ``D_in``-dimensional vector per label.

    Args:
        D_in: Size of the last dimension of the input.
        num_labels: Number of labels, i.e. number of pooled vectors produced.
    """

    def __init__(self, D_in, num_labels):
        super().__init__()
        # (D_in, num_labels) scoring matrix, initialised uniformly in [-0.1, 0.1].
        self.A = torch.nn.Parameter(torch.empty(D_in, num_labels))
        torch.nn.init.uniform_(self.A, -0.1, 0.1)

    def forward(self, x, attention_mask=None):
        """Pool ``x`` into one vector per label.

        Args:
            x: Tensor of expected shape ``(batch, seq_len, D_in)``.
            attention_mask: Optional tensor broadcastable to
                ``(batch, seq_len)``. Positions where it equals 0 (e.g.
                padding) receive zero attention weight. ``None`` (the default)
                attends to every position, matching the behaviour before this
                argument existed.

        Returns:
            Tensor of shape ``(batch, num_labels, D_in)``.
        """
        # (batch, seq_len, num_labels): one score in [-1, 1] per position/label.
        scores = torch.tanh(torch.matmul(x, self.A))
        if attention_mask is not None:
            # Use the most negative *finite* value rather than -inf. Masked
            # positions still get ~0 weight after softmax, but a fully masked
            # row yields uniform weights instead of NaN.
            masked = attention_mask.to(device=scores.device).unsqueeze(-1) == 0
            scores = scores.masked_fill(masked, torch.finfo(scores.dtype).min)
        # Normalise over the sequence axis (dim=1), independently for each label.
        attention_weights = torch.nn.functional.softmax(scores, dim=1)
        # (batch, num_labels, seq_len) @ (batch, seq_len, D_in)
        return torch.matmul(attention_weights.transpose(1, 2), x)


class BertMesh(PreTrainedModel):
    """Multi-label classifier head on top of a pretrained BERT-style encoder.

    Layers for two heads are always constructed; ``multilabel_attention``
    selects which one ``forward`` uses:

    * ``multilabel_attention=False``: encoder output ``[1]`` (the pooled
      ``[CLS]`` output for BERT models) -> ``linear_1`` -> ReLU -> dropout ->
      ``linear_out`` -> sigmoid.
    * ``multilabel_attention=True``: encoder output ``[0]`` (token states) ->
      ``MultiLabelAttention`` (one pooled vector per label) -> ``linear_1`` ->
      ReLU -> dropout -> ``linear_2`` -> sigmoid.

    NOTE(review): the ``config`` argument is ignored. The configuration is
    always re-loaded from ``pretrained_model``, which may require network
    access. The head hyper-parameters (``num_labels``, ``hidden_size``,
    ``dropout``, ``multilabel_attention``) are not written into it, so they
    must be passed again whenever the model is re-instantiated. Confirm this
    is intended.

    Args:
        config: Accepted for the ``PreTrainedModel`` constructor protocol;
            unused (see note above).
        pretrained_model: Hub id or path of the encoder and its config.
        num_labels: Number of output labels.
        hidden_size: Width of the intermediate ``linear_1`` layer.
        dropout: Dropout probability applied after ``linear_1``.
        multilabel_attention: Use the label-attention head instead of the
            pooled-output head.
    """

    def __init__(
        self,
        config,
        pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        num_labels=28761,
        hidden_size=512,
        dropout=0,
        multilabel_attention=False,
    ):
        super().__init__(config=AutoConfig.from_pretrained(pretrained_model))
        # Points AutoModel remote-code loading at module ``transformers_model``,
        # class ``BertMesh``.
        self.config.auto_map = {"AutoModel": "transformers_model.BertMesh"}
        self.pretrained_model = pretrained_model
        self.num_labels = num_labels
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.multilabel_attention = multilabel_attention

        self.bert = AutoModel.from_pretrained(pretrained_model)
        # Width of the encoder outputs (768 for the default PubMedBERT-base).
        # It is read from the encoder config instead of being hard-coded, so
        # encoders of other sizes also work; 768 is the fallback when the
        # config has no ``hidden_size``.
        encoder_size = getattr(self.bert.config, "hidden_size", 768)

        # Label-attention head: (batch, seq, enc) -> (batch, num_labels, enc)
        # -> linear_1 -> (batch, num_labels, hidden_size) -> linear_2 -> (..., 1).
        self.multilabel_attention_layer = MultiLabelAttention(
            encoder_size, num_labels
        )
        self.linear_1 = torch.nn.Linear(encoder_size, hidden_size)
        self.linear_2 = torch.nn.Linear(hidden_size, 1)
        # Pooled-output head: (batch, hidden_size) -> (batch, num_labels).
        self.linear_out = torch.nn.Linear(hidden_size, num_labels)
        self.dropout_layer = torch.nn.Dropout(self.dropout)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return per-label sigmoid probabilities, shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids, either as a tensor or as anything
                ``torch.as_tensor`` accepts (e.g. nested lists). Presumably
                batched ``(batch, seq_len)``; confirm against callers.
            token_type_ids: Accepted so tokenizer outputs can be splatted in,
                but NOT forwarded to the encoder.
            attention_mask: Likewise accepted but NOT forwarded. The encoder
                and the label attention therefore see every position, padding
                included.

        NOTE(review): forwarding ``attention_mask`` would change outputs for
        padded batches. It is presumably left out to match how the weights
        were trained; confirm before changing.
        """
        # as_tensor reuses a tensor argument instead of copy-constructing it
        # (which torch.tensor warns about); lists are converted as before.
        input_ids = torch.as_tensor(input_ids)
        if self.multilabel_attention:
            hidden_states = self.bert(input_ids=input_ids)[0]
            attention_outs = self.multilabel_attention_layer(hidden_states)
            outs = torch.nn.functional.relu(self.linear_1(attention_outs))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_2(outs))
            # (batch, num_labels, 1) -> (batch, num_labels)
            outs = torch.flatten(outs, start_dim=1)
        else:
            cls = self.bert(input_ids=input_ids)[1]
            outs = torch.nn.functional.relu(self.linear_1(cls))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_out(outs))
        return outs

    def _init_weights(self, module):
        # Deliberate no-op override of PreTrainedModel._init_weights. This
        # presumably keeps Hugging Face weight-initialisation hooks from
        # overwriting the pretrained encoder loaded in __init__. Confirm.
        pass