nsorros commited on
Commit
b4da537
1 Parent(s): 9c3f2b9

Update model

Browse files
Files changed (1) hide show
  1. model.py +4 -3
model.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoModel
2
  import torch
3
 
4
 
@@ -15,7 +15,7 @@ class MultiLabelAttention(torch.nn.Module):
15
  return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
16
 
17
 
18
- class BertMesh(torch.nn.Module):
19
  def __init__(
20
  self,
21
  pretrained_model,
@@ -24,7 +24,8 @@ class BertMesh(torch.nn.Module):
24
  dropout=0,
25
  multilabel_attention=False,
26
  ):
27
- super().__init__()
 
28
  self.pretrained_model = pretrained_model
29
  self.num_labels = num_labels
30
  self.hidden_size = hidden_size
 
1
+ from transformers import AutoModel, AutoConfig, PreTrainedModel
2
  import torch
3
 
4
 
 
15
  return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
16
 
17
 
18
+ class BertMesh(PreTrainedModel):
19
  def __init__(
20
  self,
21
  pretrained_model,
 
24
  dropout=0,
25
  multilabel_attention=False,
26
  ):
27
+ super().__init__(config=AutoConfig.from_pretrained(pretrained_model))
28
+ self.config.auto_map = {"AutoModel": "transformers_model.BertMesh"}
29
  self.pretrained_model = pretrained_model
30
  self.num_labels = num_labels
31
  self.hidden_size = hidden_size