import transformers
from transformers.models.auto import CONFIG_MAPPING


class CXRMateEDConfig(transformers.PretrainedConfig):
    """Configuration class for the CXRMate-ED model."""

    model_type = 'cxrmate-ed'

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        index_value_encoder_intermediate_size: int = 2048,
        include_time_delta: bool = True,
        time_delta_monotonic_inversion: bool = True,
        add_time_deltas: bool = True,
        history: int = 0,
        tables_filter: list = ['mimic_cxr_sectioned', 'triage', 'medrecon'],
        prompt_report_sections_filter: list = ['indication', 'history'],
        pad_token_id: int = 4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.index_value_encoder_intermediate_size = index_value_encoder_intermediate_size
        self.include_time_delta = include_time_delta
        self.time_delta_monotonic_inversion = time_delta_monotonic_inversion
        self.add_time_deltas = add_time_deltas
        self.history = history
        self.tables_filter = tables_filter
        self.prompt_report_sections_filter = prompt_report_sections_filter
        self.pad_token_id = pad_token_id

        # If the vision config is given as a dict, resolve it via AutoConfig so that the
        # custom UniFormer encoder configuration is loaded; otherwise keep it as passed in.
        if isinstance(vision_config, dict):
            vision_config = transformers.AutoConfig.from_pretrained(
                'aehrc/uniformer_base_tl_384',
                trust_remote_code=True,
                **vision_config,
            )
        self.vision_config = vision_config

        # If the text config is given as a dict, instantiate the matching config class
        # from CONFIG_MAPPING, defaulting the decoder type to 'llama'.
        if isinstance(text_config, dict):
            text_config['model_type'] = text_config.get('model_type', 'llama')
            text_config = CONFIG_MAPPING[text_config['model_type']](**text_config)
        self.text_config = text_config
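

# A minimal usage sketch (illustrative, not part of the original module): it assumes only
# that `transformers` is installed. Passing `text_config` as a dict exercises the
# CONFIG_MAPPING path above; `vision_config` is left as None to avoid fetching the remote
# 'aehrc/uniformer_base_tl_384' configuration. The Llama hyperparameters are arbitrary
# example values, not the model's actual settings.
if __name__ == '__main__':
    config = CXRMateEDConfig(
        text_config={'model_type': 'llama', 'hidden_size': 768, 'num_hidden_layers': 4},
        history=1,
    )
    print(config.text_config.model_type)  # 'llama'
    print(config.tables_filter)           # ['mimic_cxr_sectioned', 'triage', 'medrecon']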