ro-offense-model / configuration_vgcn.py
andyP's picture
Initial commit
aff5ec5
raw
history blame
1.3 kB
from transformers import PretrainedConfig, BertConfig
from typing import List
class VGCNConfig(BertConfig):
    """Configuration for the ``vgcn`` model type, layered on top of ``BertConfig``.

    Adds a handful of GCN/vocabulary-related settings to the regular BERT
    configuration. The constructor validates the new settings, stores them on
    the instance and then forwards every remaining keyword argument to
    ``BertConfig.__init__``.
    """

    model_type = "vgcn"

    def __init__(
        self,
        gcn_adj_matrix: str = '',
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        """Validate and store the VGCN-specific options.

        Args:
            gcn_adj_matrix: Stored as-is; presumably identifies or encodes the
                GCN adjacency matrix (TODO confirm against the model code).
            max_seq_len: Must lie in the inclusive range 1..512.
            npmi_threshold: Must lie in the inclusive range 0.0..1.0.
            tf_threshold: Must lie in the inclusive range 0.0..1.0.
            vocab_type: One of ``"all"``, ``"pmi"`` or ``"tf"``.
            gcn_embedding_dim: Stored as-is; not validated here.
            **kwargs: Passed through unchanged to ``BertConfig.__init__``.

        Raises:
            ValueError: If ``vocab_type`` is not an accepted value, or if
                ``max_seq_len``, ``npmi_threshold`` or ``tf_threshold`` falls
                outside its allowed range.
        """
        # Checks run in a fixed order, so the first offending argument is the
        # one reported.
        accepted_vocab_types = ("all", "pmi", "tf")
        if vocab_type not in accepted_vocab_types:
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")

        seq_len_out_of_range = max_seq_len < 1 or max_seq_len > 512
        if seq_len_out_of_range:
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")

        npmi_out_of_range = npmi_threshold < 0.0 or npmi_threshold > 1.0
        if npmi_out_of_range:
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")

        tf_out_of_range = tf_threshold < 0.0 or tf_threshold > 1.0
        if tf_out_of_range:
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")

        # GCN-related settings.
        self.gcn_adj_matrix = gcn_adj_matrix
        self.gcn_embedding_dim = gcn_embedding_dim

        # Sequence-length and vocabulary-selection settings.
        self.max_seq_len = max_seq_len
        self.vocab_type = vocab_type
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold

        # The parent constructor runs last, after the custom fields are set.
        super().__init__(**kwargs)