from transformers import BertConfig


class VGCNConfig(BertConfig):
    """Configuration for a VGCN-BERT model.

    Extends `BertConfig` with the settings used to build the vocabulary
    graph convolution on top of a pretrained BERT encoder.

    Args:
        bert_model: Name or path of the underlying pretrained BERT model.
        gcn_adj_matrix: Path to the precomputed vocabulary graph adjacency matrix.
        max_seq_len: Maximum input sequence length (1-512).
        npmi_threshold: NPMI threshold (0.0-1.0) applied to the vocabulary graph.
        tf_threshold: Term-frequency threshold (0.0-1.0) applied to the vocabulary graph.
        vocab_type: Which vocabulary graph to use: "all", "pmi" or "tf".
        gcn_embedding_dim: Size of the GCN embedding.
    """

    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = "readerbench/RoBERT-base",
        gcn_adj_matrix: str = "",
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        # Validate arguments before storing them on the config.
        if vocab_type not in ("all", "pmi", "tf"):
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")
        if max_seq_len < 1 or max_seq_len > 512:
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")
        if npmi_threshold < 0.0 or npmi_threshold > 1.0:
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")
        if tf_threshold < 0.0 or tf_threshold > 1.0:
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model

        # All remaining kwargs (hidden size, number of layers, etc.) are
        # handled by the standard BertConfig initializer.
        super().__init__(**kwargs)
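

# Minimal usage sketch, assuming the adjacency-matrix path and output
# directory below are placeholders and that the "vgcn" model type has not
# already been registered elsewhere.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Make the custom config discoverable through the Auto API.
    AutoConfig.register("vgcn", VGCNConfig)

    config = VGCNConfig(
        gcn_adj_matrix="path/to/vgcn_adj_matrix.pkl",  # placeholder path
        npmi_threshold=0.3,
        gcn_embedding_dim=16,
    )

    # Round-trip through the standard config serialization (config.json).
    config.save_pretrained("./vgcn-config")
    reloaded = AutoConfig.from_pretrained("./vgcn-config")
    print(type(reloaded).__name__, reloaded.max_seq_len, reloaded.vocab_type)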