|
from transformers import PretrainedConfig |
|
|
|
class MultiHeadConfig(PretrainedConfig):
    """Configuration for the "multihead" model.

    Stores the name of the encoder checkpoint and the classifier dropout,
    and supplies binary-classification defaults (label maps and tokenizer
    class) for the options that ``PretrainedConfig`` manages itself.
    """

    model_type = "multihead"

    def __init__(
        self,
        encoder_name: str = "microsoft/deberta-v3-small",
        **kwargs,
    ) -> None:
        """Build the configuration.

        Args:
            encoder_name: Identifier of the encoder checkpoint.
            **kwargs: Extra options. ``classifier_dropout`` (default ``0.1``)
                is stored on this config; everything else is forwarded to
                ``PretrainedConfig.__init__``. When omitted,
                ``tokenizer_class`` defaults to ``"DebertaV2TokenizerFast"``
                and the label maps default to
                ``{0: "irrelevant", 1: "relevant"}`` and its inverse.
        """
        self.encoder_name = encoder_name
        self.classifier_dropout = kwargs.pop("classifier_dropout", 0.1)

        # NOTE: the base initializer (transformers 4.x) reads `id2label`,
        # `label2id`, `num_labels` and `tokenizer_class` from kwargs and
        # assigns those attributes itself, so values set on `self` before
        # calling it are overwritten. Defaults are therefore injected into
        # kwargs and applied by the base class.
        caller_has_label_maps = "id2label" in kwargs or "label2id" in kwargs
        wants_binary_labels = kwargs.get("num_labels") in (None, 2)
        if wants_binary_labels and not caller_has_label_maps:
            # Only install the two-class maps when they cannot conflict with
            # a label set or label count requested by the caller.
            kwargs["id2label"] = {0: "irrelevant", 1: "relevant"}
            kwargs["label2id"] = {"irrelevant": 0, "relevant": 1}
        kwargs.setdefault("tokenizer_class", "DebertaV2TokenizerFast")

        super().__init__(**kwargs)
|
|