Fix AutoModel not loading model correctly due to config_class inconsistency
#11
by
liamclarkza
- opened
- dnabert_layer.py +16 -0
dnabert_layer.py
CHANGED
@@ -9,22 +9,38 @@ from transformers.models.bert.modeling_bert import BertForPreTraining as Transfo
|
|
9 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
11 |
|
|
|
|
|
12 |
|
13 |
class BertModel(TransformersBertModel):
|
|
|
|
|
|
|
14 |
def __init__(self, config):
|
15 |
super().__init__(config)
|
|
|
16 |
|
17 |
class BertForMaskedLM(TransformersBertForMaskedLM):
|
|
|
|
|
|
|
18 |
def __init__(self, config):
|
19 |
super().__init__(config)
|
|
|
20 |
|
21 |
class BertForPreTraining(TransformersBertForPreTraining):
|
|
|
|
|
|
|
22 |
def __init__(self, config):
|
23 |
super().__init__(config)
|
24 |
|
25 |
|
26 |
|
27 |
class DNABertForSequenceClassification(BertPreTrainedModel):
|
|
|
|
|
|
|
28 |
def __init__(self, config):
|
29 |
super().__init__(config)
|
30 |
self.num_labels = config.num_labels
|
|
|
9 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
10 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
11 |
|
12 |
+
from .configuration_bert import BertConfig
|
13 |
+
|
14 |
|
15 |
class BertModel(TransformersBertModel):
|
16 |
+
|
17 |
+
config_class = BertConfig
|
18 |
+
|
19 |
def __init__(self, config):
|
20 |
super().__init__(config)
|
21 |
+
|
22 |
|
23 |
class BertForMaskedLM(TransformersBertForMaskedLM):
|
24 |
+
|
25 |
+
config_class = BertConfig
|
26 |
+
|
27 |
def __init__(self, config):
|
28 |
super().__init__(config)
|
29 |
+
|
30 |
|
31 |
class BertForPreTraining(TransformersBertForPreTraining):
|
32 |
+
|
33 |
+
config_class = BertConfig
|
34 |
+
|
35 |
def __init__(self, config):
|
36 |
super().__init__(config)
|
37 |
|
38 |
|
39 |
|
40 |
class DNABertForSequenceClassification(BertPreTrainedModel):
|
41 |
+
|
42 |
+
config_class = BertConfig
|
43 |
+
|
44 |
def __init__(self, config):
|
45 |
super().__init__(config)
|
46 |
self.num_labels = config.num_labels
|