from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class EncoderDecoderConfig(PretrainedConfig):
    model_type = "encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be instantiated because it requires "
                f"both `encoder` and `decoder` sub-configurations, but received only {kwargs}"
            )

        # Store the two sub-configurations and mark this config as encoder-decoder.
        self.encoder = kwargs.pop("encoder")
        self.decoder = kwargs.pop("decoder")
        self.is_encoder_decoder = True

    @classmethod
    def from_encoder_decoder_configs(
        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate an [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
        configuration and decoder model configuration.

        Returns:
            [`EncoderDecoderConfig`]: An instance of a configuration object
        """
        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config, decoder=decoder_config, **kwargs)
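

# A minimal usage sketch (illustration only, not part of the library module): build an
# EncoderDecoderConfig from two standalone configurations, here two BertConfig objects.
# The classmethod marks the decoder config with `is_decoder=True` and `add_cross_attention=True`
# before wrapping both configs in a single composite configuration.
if __name__ == "__main__":
    from transformers import BertConfig

    encoder_cfg = BertConfig()
    decoder_cfg = BertConfig()
    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_cfg, decoder_cfg)

    # The decoder sub-config has been adjusted for cross-attention decoding.
    assert config.decoder.is_decoder and config.decoder.add_cross_attention
    assert config.is_encoder_decoder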