3ie-intervention-outcome-entity-linking
/
relik
/reader
/pytorch_modules
/hf
/configuration_relik.py
from typing import Optional | |
from transformers import AutoConfig | |
from transformers.configuration_utils import PretrainedConfig | |
class RelikReaderConfig(PretrainedConfig): | |
model_type = "relik-reader" | |
def __init__( | |
self, | |
transformer_model: str = "microsoft/deberta-v3-base", | |
additional_special_symbols: int = 101, | |
additional_special_symbols_types: Optional[int] = 0, | |
num_layers: Optional[int] = None, | |
activation: str = "gelu", | |
linears_hidden_size: Optional[int] = 512, | |
use_last_k_layers: int = 1, | |
entity_type_loss: bool = False, | |
add_entity_embedding: bool = None, | |
training: bool = False, | |
default_reader_class: Optional[str] = None, | |
**kwargs | |
) -> None: | |
# TODO: add name_or_path to kwargs | |
self.transformer_model = transformer_model | |
self.additional_special_symbols = additional_special_symbols | |
self.additional_special_symbols_types = additional_special_symbols_types | |
self.num_layers = num_layers | |
self.activation = activation | |
self.linears_hidden_size = linears_hidden_size | |
self.use_last_k_layers = use_last_k_layers | |
self.entity_type_loss = entity_type_loss | |
self.add_entity_embedding = ( | |
True | |
if add_entity_embedding is None and entity_type_loss | |
else add_entity_embedding | |
) | |
self.training = training | |
self.default_reader_class = default_reader_class | |
super().__init__(**kwargs) | |
AutoConfig.register("relik-reader", RelikReaderConfig) | |