PereLluis13's picture
Upload model
cb9ce74 verified
raw
history blame
1.7 kB
from typing import Optional
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
class RelikReaderConfig(PretrainedConfig):
    """Configuration object for the ReLiK reader model.

    Every constructor argument is stored unchanged as an attribute of the
    same name, with one exception: ``add_entity_embedding`` is resolved to
    ``True`` when it is left as ``None`` while ``entity_type_loss`` is
    enabled. Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the precise meaning of each option is defined by the
    reader model code that consumes this config, which is not part of this
    file; the parameter descriptions below are hedged accordingly.
    """

    # Identifier under which this configuration type is known to transformers.
    model_type = "relik-reader"

    def __init__(
        self,
        transformer_model: str = "microsoft/deberta-v3-base",
        additional_special_symbols: int = 101,
        additional_special_symbols_types: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        entity_type_loss: bool = False,
        add_entity_embedding: Optional[bool] = None,
        binary_end_logits: bool = False,
        training: bool = False,
        default_reader_class: Optional[str] = None,
        threshold: Optional[float] = 0.5,
        **kwargs
    ) -> None:
        """Store the reader hyper-parameters on the config instance.

        Args:
            transformer_model: Identifier of the underlying transformer
                (presumably a Hub model name or local path -- confirm
                against the model code).
            additional_special_symbols: Count of additional special symbols.
            additional_special_symbols_types: Count of additional special
                symbols for types.
            num_layers: Optional layer count; ``None`` leaves it unset.
            activation: Name of the activation function.
            linears_hidden_size: Hidden size of the linear layers.
            use_last_k_layers: Number of final transformer layers to use.
            entity_type_loss: Whether the entity-type loss is enabled.
            add_entity_embedding: Whether to add an entity embedding. When
                ``None`` and ``entity_type_loss`` is true it becomes
                ``True``; otherwise the given value (possibly ``None``) is
                stored as-is.
            binary_end_logits: Flag stored as-is on the config.
            training: Flag stored as-is on the config.
            default_reader_class: Optional name of the default reader class.
            threshold: Threshold value stored as-is on the config.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # TODO: add name_or_path to kwargs
        self.transformer_model = transformer_model
        self.additional_special_symbols = additional_special_symbols
        self.additional_special_symbols_types = additional_special_symbols_types
        self.num_layers = num_layers
        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers
        self.entity_type_loss = entity_type_loss

        # An unspecified value follows entity_type_loss only when that flag
        # is on; with the flag off, None is deliberately left as None so an
        # explicit False remains distinguishable from "not specified".
        if add_entity_embedding is None and entity_type_loss:
            add_entity_embedding = True
        self.add_entity_embedding = add_entity_embedding

        self.threshold = threshold
        self.binary_end_logits = binary_end_logits
        self.training = training
        self.default_reader_class = default_reader_class

        # Attributes are assigned before delegating to the base class,
        # preserving the original statement order.
        super().__init__(**kwargs)