cxrmate-ed / configuration_cxrmate_ed.py
anicolson's picture
Upload model
688909e verified
raw
history blame
2.04 kB
import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from transformers' logging utilities, named after this
# module; CXRMateEDConfig.from_encoder_decoder_configs logs through it.
logger = logging.get_logger(name=__name__)
class CXRMateEDConfig(PretrainedConfig):
    """Composite configuration for the CXRMate-ED encoder-decoder model.

    It holds two nested configurations:

    * ``encoder``: the ``encoder`` keyword argument when one is supplied.
      Otherwise it is loaded with ``transformers.AutoConfig.from_pretrained``
      from ``aehrc/uniformer_base_tl_384``, using ``projection_size=768`` and
      ``trust_remote_code=True``.
    * ``decoder``: the ``decoder`` keyword argument when one is supplied.
      Otherwise it is a ``transformers.LlamaConfig`` with these settings:
      vocab size 30000, hidden size 768, intermediate size 3072, 12 attention
      heads, 6 hidden layers and 2048 max positions. It also carries the
      extra CXRMate-ED attributes set in ``__init__``.

    Supplied ``encoder``/``decoder`` values are stored exactly as given; no
    conversion is applied. NOTE(review): when the config is re-loaded from a
    saved ``config.json`` they are presumably plain dicts rather than config
    objects. TODO: confirm how the model code consumes them.

    ``is_encoder_decoder`` is always forced to ``True``.
    """

    model_type = "cxrmate-ed"

    def __init__(self, **kwargs):
        # The base initialiser runs first, with the full kwargs.
        # NOTE(review): PretrainedConfig is expected to copy unrecognised
        # kwargs (such as encoder/decoder) onto ``self``. The assignments
        # below overwrite those attributes either way.
        super().__init__(**kwargs)

        if "decoder" in kwargs:
            self.decoder = kwargs.pop("decoder")
        else:
            llama_config = transformers.LlamaConfig(
                vocab_size=30000,
                hidden_size=768,
                intermediate_size=3072,
                num_attention_heads=12,
                num_hidden_layers=6,
                max_position_embeddings=2048,
            )
            # Applied after construction, in this exact order. The dict
            # literal is rebuilt on every call, so the list values are never
            # shared between config instances.
            decoder_overrides = {
                "is_decoder": True,
                "index_value_encoder_intermediate_size": 2048,
                "include_time_delta": True,
                "time_delta_monotonic_inversion": True,
                "add_time_deltas": True,
                "history": 0,
                "tables_filter": ["mimic_cxr_sectioned", "triage", "medrecon"],
                "prompt_report_sections_filter": ["indication", "history"],
                "pad_token_id": 4,
            }
            for attr_name, attr_value in decoder_overrides.items():
                setattr(llama_config, attr_name, attr_value)
            self.decoder = llama_config

        if "encoder" in kwargs:
            self.encoder = kwargs.pop("encoder")
        else:
            # NOTE(review): this resolves the config through the Hugging Face
            # Hub (or the local cache) and enables remote code. Constructing
            # the config without an explicit encoder may therefore need
            # network access.
            self.encoder = transformers.AutoConfig.from_pretrained(
                "aehrc/uniformer_base_tl_384",
                projection_size=768,
                trust_remote_code=True,
            )

        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        """Create a config of this class from separate encoder and decoder configs.

        ``decoder_config`` is modified in place: its ``is_decoder`` and
        ``add_cross_attention`` attributes are both set to ``True`` before it
        is stored.

        Args:
            encoder_config: Stored as the ``encoder`` sub-config.
            decoder_config: Stored as the ``decoder`` sub-config (mutated as
                described above).
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        for flag_name in ("is_decoder", "add_cross_attention"):
            setattr(decoder_config, flag_name, True)
        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)