Update WellcomeBertMesh with transformers based trained model
- config.json +0 -0
- model.py +19 -23
- pytorch_model.bin +2 -2
config.json
CHANGED
The diff for this file is too large to render. See raw diff.
model.py
CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModel, PreTrainedModel
+from transformers import AutoModel, PreTrainedModel, BertConfig
 import torch
 
 
@@ -16,34 +16,33 @@ class MultiLabelAttention(torch.nn.Module):
 
 
 class BertMesh(PreTrainedModel):
+    config_class = BertConfig
+
     def __init__(
         self,
         config,
-        pretrained_model="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
-        num_labels=28761,
-        hidden_size=1024,
-        dropout=0,
-        multilabel_attention=True,
     ):
         super().__init__(config=config)
-        self.config.auto_map = {"AutoModel": "
-        self.pretrained_model = pretrained_model
-        self.num_labels = num_labels
-        self.hidden_size = hidden_size
-        self.dropout = dropout
-        self.multilabel_attention = multilabel_attention
-
-        self.bert = AutoModel.from_pretrained(pretrained_model)  # 768
+        self.config.auto_map = {"AutoModel": "model.BertMesh"}
+        self.pretrained_model = self.config.pretrained_model
+        self.num_labels = self.config.num_labels
+        self.hidden_size = getattr(self.config, "hidden_size", 512)
+        self.dropout = getattr(self.config, "dropout", 0.1)
+        self.multilabel_attention = getattr(self.config, "multilabel_attention", False)
+
+        self.bert = AutoModel.from_pretrained(self.pretrained_model)  # 768
         self.multilabel_attention_layer = MultiLabelAttention(
-            768, num_labels
+            768, self.num_labels
         )  # num_labels, 768
-        self.linear_1 = torch.nn.Linear(768, hidden_size)  # num_labels, 512
-        self.linear_2 = torch.nn.Linear(hidden_size, 1)  # num_labels, 1
-        self.linear_out = torch.nn.Linear(hidden_size, num_labels)
+        self.linear_1 = torch.nn.Linear(768, self.hidden_size)  # num_labels, 512
+        self.linear_2 = torch.nn.Linear(self.hidden_size, 1)  # num_labels, 1
+        self.linear_out = torch.nn.Linear(self.hidden_size, self.num_labels)
         self.dropout_layer = torch.nn.Dropout(self.dropout)
 
-    def forward(self, input_ids,
-
+    def forward(self, input_ids, **kwargs):
+        if type(input_ids) is list:
+            # coming from tokenizer
+            input_ids = torch.tensor(input_ids)
         if self.multilabel_attention:
             hidden_states = self.bert(input_ids=input_ids)[0]
             attention_outs = self.multilabel_attention_layer(hidden_states)
@@ -57,6 +56,3 @@ class BertMesh(PreTrainedModel):
         outs = self.dropout_layer(outs)
         outs = torch.sigmoid(self.linear_out(outs))
         return outs
-
-    def _init_weights(self, module):
-        pass
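With this change, BertMesh reads all of its hyperparameters from the config and registers itself through auto_map, so it can be loaded with the regular transformers Auto classes plus trust_remote_code=True. A minimal loading sketch follows; it assumes the repository id Wellcome/WellcomeBertMesh, that config.json carries pretrained_model and num_labels as the new __init__ expects, and that the repository ships a tokenizer (otherwise the tokenizer of the base microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract checkpoint could be used instead):

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "Wellcome/WellcomeBertMesh"  # assumed repository id

# auto_map points AutoModel at model.BertMesh, so remote code must be trusted
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumes the repo includes tokenizer files
model.eval()

# Without return_tensors the tokenizer yields plain Python lists;
# forward() detects this and converts them to tensors itself
inputs = tokenizer(
    ["Malaria is a mosquito-borne infectious disease."],
    padding=True,
    truncation=True,
)

with torch.no_grad():
    probs = model(input_ids=inputs["input_ids"])  # shape: (batch_size, num_labels)

# Indices of labels whose sigmoid probability exceeds 0.5
predicted = (probs > 0.5).nonzero()[:, 1].tolist()
print(predicted)

Because forward() returns sigmoid scores rather than logits, thresholding at 0.5 is the natural decision rule here, though the threshold itself is a tuning choice rather than something fixed by the model.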
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1c80db3a392fe08b3faa111d46e48fef56eb2c0efe862f0a80cc7fe4da55baea
+size 647442531