import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class CXRMateEDConfig(PretrainedConfig):
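    """Configuration class for the CXRMate-ED encoder-decoder model."""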
    model_type = "cxrmate-ed"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
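        # Default decoder: a small Llama configuration, used when `decoder` is not in kwargs.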
        if 'decoder' not in kwargs:
            self.decoder = transformers.LlamaConfig(
                vocab_size=30000,
                hidden_size=768,
                intermediate_size=3072,
                num_attention_heads=12,
                num_hidden_layers=6,
                max_position_embeddings=2048,
            )
            self.decoder.is_decoder = True
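            # Non-standard decoder fields consumed by CXRMate-ED: tabular record encoding,
            # time-delta handling, and table/report-section selection for the prompt.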
            self.decoder.index_value_encoder_intermediate_size = 2048
            self.decoder.include_time_delta = True
            self.decoder.time_delta_monotonic_inversion = True
            self.decoder.add_time_deltas = True
            self.decoder.history = 0
            self.decoder.tables_filter = ["mimic_cxr_sectioned", "triage", "medrecon"]
            self.decoder.prompt_report_sections_filter = ["indication", "history"]
            self.decoder.pad_token_id = 4
        else:
            self.decoder = kwargs.pop("decoder")
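        # Default encoder: the aehrc/uniformer_base_tl_384 image encoder, used when
        # `encoder` is not in kwargs.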
        if 'encoder' not in kwargs:
            self.encoder = transformers.AutoConfig.from_pretrained(
                'aehrc/uniformer_base_tl_384',
                projection_size=768,
                trust_remote_code=True,
            )
        else:
            self.encoder = kwargs.pop("encoder")

        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
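        """Instantiate a CXRMateEDConfig from separate encoder and decoder configurations."""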
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
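

# Minimal usage sketch (assumes network access, since building the default config
# fetches the encoder config from aehrc/uniformer_base_tl_384 with trust_remote_code=True):
if __name__ == "__main__":
    config = CXRMateEDConfig()
    print(config.model_type)                 # "cxrmate-ed"
    print(config.decoder.num_hidden_layers)  # 6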