liamclarkza committed
Commit a791476
1 Parent(s): 55e0c0e

Fix AutoModel not loading model correctly due to config_class inconsistency


This fixes an issue where instantiating the model through `AutoModel` fails because the config class paired with the model comes from the `transformers` library rather than from the model's own module. As a result, the config class resolved for the checkpoint does not match the `config_class` declared on the model classes, and loading fails with the error below. See [this GitHub issue](https://github.com/huggingface/transformers/issues/31068) for more details.

```
Traceback (most recent call last):
    model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 560, in from_pretrained
    cls.register(config.__class__, model_class, exist_ok=True)
  File ".../lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 586, in register
    raise ValueError(
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers.models.bert.configuration_bert.BertConfig'> and you passed <class 'transformers_modules.zhihan1996.DNA_bert_6.55e0c0eb7b734c8b9b77bc083bf89eb6fbda1341.configuration_bert.BertConfig'>. Fix one of those so they match!
```

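For context, the two classes the error compares come from different places: `AutoModel` resolves the checkpoint's config through the repo's own `configuration_bert.py`, while the model classes in `dnabert_layer.py` previously inherited `config_class` from the stock `transformers` BERT classes. A minimal sketch of how to observe the mismatch (this assumes, as the traceback indicates, that the repo's config maps to the remote `BertConfig`):

```python
from transformers import AutoConfig
from transformers.models.bert.modeling_bert import BertModel as TransformersBertModel

# The checkpoint's config resolves to the repo-local class under
# transformers_modules.* (the "you passed" class in the error above) ...
config = AutoConfig.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)

# ... while, before this fix, the remote model classes simply inherited
# config_class from the stock transformers BERT classes (the "model has" class).
print(type(config))
print(TransformersBertModel.config_class)
print(type(config) is TransformersBertModel.config_class)  # False -> register() raises
```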
Files changed (1)
  1. dnabert_layer.py +16 -0
dnabert_layer.py CHANGED
```diff
@@ -9,22 +9,38 @@ from transformers.models.bert.modeling_bert import BertForPreTraining as TransformersBertForPreTraining
 from transformers.models.bert.modeling_bert import BertPreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
 
+from .configuration_bert import BertConfig
+
 
 class BertModel(TransformersBertModel):
+
+    config_class = BertConfig
+
     def __init__(self, config):
         super().__init__(config)
+
 
 class BertForMaskedLM(TransformersBertForMaskedLM):
+
+    config_class = BertConfig
+
     def __init__(self, config):
         super().__init__(config)
+
 
 class BertForPreTraining(TransformersBertForPreTraining):
+
+    config_class = BertConfig
+
     def __init__(self, config):
         super().__init__(config)
 
 
 
 class DNABertForSequenceClassification(BertPreTrainedModel):
+
+    config_class = BertConfig
+
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
```
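
With `config_class` pinned to the repo-local `BertConfig` on each exported class, the consistency check in `AutoModel.register()` should pass. A quick smoke test (repo id and `trust_remote_code` flag are the same as in the traceback above):

```python
from transformers import AutoModel

# After this fix, the config class registered alongside the remote model matches
# the repo-local BertConfig, so loading via AutoModel should no longer raise.
model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
print(type(model).__name__)
```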