from transformers import AutoModel, AutoConfig, PreTrainedModel
import torch


class MultiLabelAttention(torch.nn.Module):
    def __init__(self, D_in, num_labels):
        super().__init__()
        # One learned attention vector per label, initialised uniformly in [-0.1, 0.1].
        self.A = torch.nn.Parameter(torch.empty(D_in, num_labels))
        torch.nn.init.uniform_(self.A, -0.1, 0.1)

    def forward(self, x):
        # x: (batch, seq_len, D_in) -> attention_weights: (batch, seq_len, num_labels),
        # softmax-normalised over the sequence dimension.
        attention_weights = torch.nn.functional.softmax(
            torch.tanh(torch.matmul(x, self.A)), dim=1
        )
        # Per-label weighted sum of token states: (batch, num_labels, D_in).
        return torch.matmul(torch.transpose(attention_weights, 2, 1), x)


class BertMesh(PreTrainedModel):
    def __init__(
        self,
        config,
        pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        num_labels=28761,
        hidden_size=1024,
        dropout=0,
        multilabel_attention=True,
    ):
        super().__init__(config=config)
        self.config.auto_map = {"AutoModel": "transformers_model.BertMesh"}
        self.pretrained_model = pretrained_model
        self.num_labels = num_labels
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.multilabel_attention = multilabel_attention

        self.bert = AutoModel.from_pretrained(pretrained_model)
        # Use the encoder's own hidden size (768 for PubMedBERT-base) instead of
        # hard-coding it, so a different encoder can be swapped in.
        bert_hidden_size = self.bert.config.hidden_size
        self.multilabel_attention_layer = MultiLabelAttention(
            bert_hidden_size, num_labels
        )
        self.linear_1 = torch.nn.Linear(bert_hidden_size, hidden_size)
        self.linear_2 = torch.nn.Linear(hidden_size, 1)
        self.linear_out = torch.nn.Linear(hidden_size, num_labels)
        self.dropout_layer = torch.nn.Dropout(self.dropout)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # Accept plain Python lists as well as tensors.
        input_ids = torch.as_tensor(input_ids)
        if self.multilabel_attention:
            # Token-level hidden states; pass the mask so padding is ignored.
            hidden_states = self.bert(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )[0]
            attention_outs = self.multilabel_attention_layer(hidden_states)
            outs = torch.nn.functional.relu(self.linear_1(attention_outs))
            outs = self.dropout_layer(outs)
            # One sigmoid score per label: (batch, num_labels, 1) -> (batch, num_labels).
            outs = torch.sigmoid(self.linear_2(outs))
            outs = torch.flatten(outs, start_dim=1)
        else:
            # Pooled [CLS] representation followed by a dense classification head.
            cls = self.bert(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )[1]
            outs = torch.nn.functional.relu(self.linear_1(cls))
            outs = self.dropout_layer(outs)
            outs = torch.sigmoid(self.linear_out(outs))
        return outs

    def _init_weights(self, module):
        # Weights either come from the pretrained encoder or use the default
        # initialisation of the layers defined above, so nothing to do here.
        pass
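

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the original
# module): it loads the tokenizer and config matching the default encoder and
# uses a small placeholder num_labels to keep the example lightweight; the
# model's default above (28761) is the full label set. The example sentence
# and variable names are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    encoder_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    tokenizer = AutoTokenizer.from_pretrained(encoder_name)
    config = AutoConfig.from_pretrained(encoder_name)

    # Placeholder label count for illustration only.
    model = BertMesh(config, pretrained_model=encoder_name, num_labels=8)
    model.eval()

    batch = tokenizer(
        ["Malaria transmission dynamics in sub-Saharan Africa."],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        probs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    # probs: (batch_size, num_labels) with independent sigmoid scores per label.
    print(probs.shape)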