import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class CXRMateEDConfig(PretrainedConfig):
    """Composite configuration for CXRMate-ED: an encoder config plus a decoder config.

    The two sub-configurations are stored on ``self.encoder`` and ``self.decoder``.
    When either is not passed as a keyword argument, a default is built by
    ``_default_encoder_config`` / ``_default_decoder_config``.
    """

    model_type = "cxrmate-ed"

    def __init__(self, **kwargs):
        """Build the configuration.

        Args:
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. The optional
                ``decoder`` and ``encoder`` entries supply the sub-configurations,
                either as config objects or as plain dicts (the form they take
                when the config is read back from JSON).
        """
        super().__init__(**kwargs)

        if "decoder" not in kwargs:
            self.decoder = self._default_decoder_config()
        else:
            self.decoder = self._as_sub_config(kwargs.pop("decoder"))

        if "encoder" not in kwargs:
            self.encoder = self._default_encoder_config()
        else:
            self.encoder = self._as_sub_config(kwargs.pop("encoder"))

        # Assigned after super().__init__ so it overrides any value passed in kwargs.
        self.is_encoder_decoder = True

    @staticmethod
    def _default_decoder_config():
        """Return the default Llama decoder config with the CXRMate-ED specific extras."""
        decoder = transformers.LlamaConfig(
            vocab_size=30000,
            hidden_size=768,
            intermediate_size=3072,
            num_attention_heads=12,
            num_hidden_layers=6,
            max_position_embeddings=2048,
        )
        decoder.is_decoder = True
        # The attributes below are not standard LlamaConfig fields; presumably they
        # are read by the CXRMate-ED modelling code (not visible here) — confirm.
        decoder.index_value_encoder_intermediate_size = 2048
        decoder.include_time_delta = True
        decoder.time_delta_monotonic_inversion = True
        decoder.add_time_deltas = True
        decoder.history = 0
        decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
        decoder.prompt_report_sections_filter = ["indication", "history"]
        decoder.pad_token_id = 4  # NOTE(review): presumably the tokenizer's padding id — confirm.
        return decoder

    @staticmethod
    def _default_encoder_config():
        """Return the default encoder config.

        Loaded with ``AutoConfig.from_pretrained`` from the
        ``aehrc/uniformer_base_tl_384`` repository with ``trust_remote_code=True``.
        Building the config without an ``encoder`` argument may therefore need
        network access, and it runs code from that repository.
        """
        return transformers.AutoConfig.from_pretrained(
            "aehrc/uniformer_base_tl_384",
            projection_size=768,
            trust_remote_code=True,
        )

    @staticmethod
    def _as_sub_config(value):
        """Turn a dict-form sub-config back into a config object when possible.

        A config read back from JSON holds its sub-configs as plain dicts. This
        rebuilds such a dict with ``AutoConfig.for_model`` so attribute access
        (e.g. ``config.decoder.is_decoder``) keeps working after a save/load
        round-trip. Non-dict values are returned unchanged. So are dicts with no
        ``model_type`` or with a ``model_type`` that ``AutoConfig`` cannot
        instantiate (e.g. a remote-code model that is not registered).
        """
        if not isinstance(value, dict) or "model_type" not in value:
            return value
        config_kwargs = dict(value)  # copy so the caller's dict is not mutated
        model_type = config_kwargs.pop("model_type")
        try:
            return transformers.AutoConfig.for_model(model_type, **config_kwargs)
        except ValueError:
            logger.warning(
                "Could not build a config for model_type %r; keeping the sub-config as a dict.",
                model_type,
            )
            return value

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        """Build a ``CXRMateEDConfig`` from existing encoder and decoder configs.

        ``decoder_config`` is modified in place: ``is_decoder`` and
        ``add_cross_attention`` are both set to True before it is stored.

        Args:
            encoder_config: Configuration for the encoder.
            decoder_config: Configuration for the decoder (mutated in place).
            **kwargs: Extra arguments forwarded to the constructor.

        Returns:
            A new ``CXRMateEDConfig`` wrapping the two configs.
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)