Fix AutoModel not loading model correctly due to config_class inconsistency

#11
by liamclarkza - opened
Files changed (1) hide show
  1. dnabert_layer.py +16 -0
dnabert_layer.py CHANGED
@@ -9,22 +9,38 @@ from transformers.models.bert.modeling_bert import BertForPreTraining as Transfo
9
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
10
  from transformers.modeling_outputs import SequenceClassifierOutput
11
 
 
 
12
 
13
  class BertModel(TransformersBertModel):
 
 
 
14
  def __init__(self, config):
15
  super().__init__(config)
 
16
 
17
  class BertForMaskedLM(TransformersBertForMaskedLM):
 
 
 
18
  def __init__(self, config):
19
  super().__init__(config)
 
20
 
21
  class BertForPreTraining(TransformersBertForPreTraining):
 
 
 
22
  def __init__(self, config):
23
  super().__init__(config)
24
 
25
 
26
 
27
  class DNABertForSequenceClassification(BertPreTrainedModel):
 
 
 
28
  def __init__(self, config):
29
  super().__init__(config)
30
  self.num_labels = config.num_labels
 
9
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
10
  from transformers.modeling_outputs import SequenceClassifierOutput
11
 
12
+ from .configuration_bert import BertConfig
13
+
14
 
15
  class BertModel(TransformersBertModel):
16
+
17
+ config_class = BertConfig
18
+
19
  def __init__(self, config):
20
  super().__init__(config)
21
+
22
 
23
  class BertForMaskedLM(TransformersBertForMaskedLM):
24
+
25
+ config_class = BertConfig
26
+
27
  def __init__(self, config):
28
  super().__init__(config)
29
+
30
 
31
  class BertForPreTraining(TransformersBertForPreTraining):
32
+
33
+ config_class = BertConfig
34
+
35
  def __init__(self, config):
36
  super().__init__(config)
37
 
38
 
39
 
40
  class DNABertForSequenceClassification(BertPreTrainedModel):
41
+
42
+ config_class = BertConfig
43
+
44
  def __init__(self, config):
45
  super().__init__(config)
46
  self.num_labels = config.num_labels