Update model
model.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoModel
+from transformers import AutoModel, AutoConfig, PreTrainedModel
 import torch
 
 
@@ -15,7 +15,7 @@ class MultiLabelAttention(torch.nn.Module):
         return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
 
 
-class BertMesh(torch.nn.Module):
+class BertMesh(PreTrainedModel):
     def __init__(
         self,
         pretrained_model,
@@ -24,7 +24,8 @@ class BertMesh(torch.nn.Module):
         dropout=0,
         multilabel_attention=False,
     ):
-        super().__init__()
+        super().__init__(config=AutoConfig.from_pretrained(pretrained_model))
+        self.config.auto_map = {"AutoModel": "transformers_model.BertMesh"}
         self.pretrained_model = pretrained_model
         self.num_labels = num_labels
         self.hidden_size = hidden_size
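Note on the change: subclassing PreTrainedModel and setting config.auto_map = {"AutoModel": "transformers_model.BertMesh"} is the mechanism that lets the transformers Auto classes import the custom BertMesh class from a transformers_model.py file shipped in the model repository. A minimal loading sketch, assuming a hypothetical repo id and that the repository contains transformers_model.py (neither is part of this diff):

    # Sketch only: load the custom model through the Auto API.
    # "your-org/bertmesh" is a placeholder repo id, not a real repository.
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "your-org/bertmesh",     # hypothetical repo containing transformers_model.py
        trust_remote_code=True,  # required so the custom BertMesh code is imported
    )

Without trust_remote_code=True, the Auto classes will not execute the repository's custom module, so the auto_map entry added in this commit would have no effect at load time.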