# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Llama3 multimodal (vision + audio + text) model configuration, built on the Mllama sub-configs."""

import os
from typing import Dict, List, Optional, Union

import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import Wav2Vec2BertConfig, AutoConfig
from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig


logger = logging.get_logger(__name__)


class MllamaAudioConfig(Wav2Vec2BertConfig):
    """Audio-encoder configuration used as the ``audio_config`` of `Llama3Config`.

    A thin subclass of `Wav2Vec2BertConfig` that forwards every keyword argument
    unchanged. `Llama3Config` checks audio sub-config instances against this class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


def _resolve_sub_config(value, config_cls, name, label):
    """Normalize a sub-config argument into an instance of ``config_cls``.

    Args:
        value: ``None`` (use ``config_cls()`` defaults), a ``dict`` of keyword
            arguments for ``config_cls``, or an existing ``config_cls`` instance.
        config_cls: The `PretrainedConfig` subclass to build or accept.
        name: The argument name (e.g. ``"vision_config"``), used in messages.
        label: The modality word used in the default-config log message.

    Returns:
        An instance of ``config_cls``.

    Raises:
        TypeError: If ``value`` is not ``None``, a ``dict``, or a ``config_cls``
            instance. Previously such values were silently ignored, leaving the
            attribute unset and causing an ``AttributeError`` much later.
    """
    if value is None:
        config = config_cls()
        logger.info("%s is None, using default mllama %s config", name, label)
        return config
    if isinstance(value, dict):
        return config_cls(**value)
    if isinstance(value, config_cls):
        return value
    raise TypeError(
        f"`{name}` must be None, a dict, or a {config_cls.__name__} instance, "
        f"got {type(value).__name__}"
    )


class Llama3Config(PretrainedConfig):
    r"""
    Configuration class for the "llama3" multimodal model, composed of a vision, an audio and a
    text sub-configuration.

    Instantiating a configuration with the defaults yields default `MllamaVisionConfig`,
    `MllamaAudioConfig` (i.e. `Wav2Vec2BertConfig` defaults) and `MllamaTextConfig` sub-configs.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[MllamaVisionConfig, dict]`, *optional*, defaults to `MllamaVisionConfig()`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[MllamaTextConfig, dict]`, *optional*, defaults to `MllamaTextConfig()`):
            The config object or dictionary of the text backbone.
        audio_config (`Union[MllamaAudioConfig, dict]`, *optional*, defaults to `MllamaAudioConfig()`):
            The config object or dictionary of the audio backbone.
        image_token_index (`int`, *optional*, defaults to 128256):
            The image token index to encode the image prompt.
        audio_token_index (`int`, *optional*, defaults to 128257):
            The audio token index; presumably marks audio positions in the prompt, mirroring
            `image_token_index`.
        kwargs (*optional*):
            Forwarded unchanged to [`PretrainedConfig.__init__`].

    Raises:
        TypeError: If a sub-config argument is not `None`, a `dict`, or an instance of its
            expected config class.

    Example:

    ```python
    >>> from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig

    >>> vision_config = MllamaVisionConfig()
    >>> text_config = MllamaTextConfig()
    >>> audio_config = MllamaAudioConfig()

    >>> configuration = Llama3Config(
    ...     vision_config=vision_config, text_config=text_config, audio_config=audio_config
    ... )
    ```"""

    # Must match the model type passed to `AutoConfig.register` below.
    model_type = "llama3"
    # Marks this as a composite config assembled from sub-configs (see `PretrainedConfig.is_composition`).
    is_composition = True

    def __init__(
        self,
        vision_config: Optional[Union[MllamaVisionConfig, dict]] = None,
        text_config: Optional[Union[MllamaTextConfig, dict]] = None,
        audio_config: Optional[Union[MllamaAudioConfig, dict]] = None,
        image_token_index: int = 128256,
        audio_token_index: int = 128257,
        **kwargs,
    ):
        self.vision_config = _resolve_sub_config(vision_config, MllamaVisionConfig, "vision_config", "vision")
        self.image_token_index = image_token_index

        self.audio_config = _resolve_sub_config(audio_config, MllamaAudioConfig, "audio_config", "audio")
        self.audio_token_index = audio_token_index

        self.text_config = _resolve_sub_config(text_config, MllamaTextConfig, "text_config", "text")

        super().__init__(**kwargs)


# Register the "llama3" model type so `AutoConfig` can resolve it, e.g. from a config.json
# whose "model_type" is "llama3".
AutoConfig.register("llama3", Llama3Config)
# Expose the class as an attribute of the top-level `transformers` module. Presumably this is so
# code that looks config classes up by name on `transformers` can find it — confirm with callers.
transformers.Llama3Config = Llama3Config