from typing import Optional, Union

from transformers import AutoConfig, PretrainedConfig

from .visual_tokenizer import ClipVisualTokenizerConfig


class OvisConfig(PretrainedConfig):
    """Configuration for the Ovis multimodal model.

    Bundles the config of the underlying LLM and of the visual tokenizer,
    plus a few Ovis-specific knobs. Sub-configs may be passed either as
    ``PretrainedConfig`` instances or as plain dicts (which must carry a
    ``model_type`` key so the concrete config class can be resolved via
    ``AutoConfig.for_model``).
    """

    model_type = "ovis"

    def __init__(self,
                 llm_config: Optional[Union[PretrainedConfig, dict]] = None,
                 visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
                 multimodal_max_length=2048,
                 hidden_size=None,
                 conversation_formatter_class=None,
                 **kwargs):
        """
        Args:
            llm_config: config of the backbone LLM, as ``PretrainedConfig``
                or dict; ``None`` is kept as-is.
            visual_tokenizer_config: config of the visual tokenizer, same
                accepted forms as ``llm_config``.
            multimodal_max_length: maximum combined (text + visual) sequence
                length.
            hidden_size: hidden size of the multimodal embedding space
                (``None`` means "derive elsewhere" — TODO confirm with callers).
            conversation_formatter_class: name of the conversation formatter
                class to use (semantics defined by the consumer of this config).
            **kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.llm_config = self._coerce_sub_config(llm_config, 'llm_config')
        self.visual_tokenizer_config = self._coerce_sub_config(
            visual_tokenizer_config, 'visual_tokenizer_config')
        self.multimodal_max_length = multimodal_max_length
        self.hidden_size = hidden_size
        self.conversation_formatter_class = conversation_formatter_class

    @staticmethod
    def _coerce_sub_config(config: Optional[Union[PretrainedConfig, dict]],
                           name: str) -> Optional[PretrainedConfig]:
        """Normalize a sub-config argument to a ``PretrainedConfig`` (or None).

        ``None`` and ``PretrainedConfig`` inputs are returned unchanged. A dict
        is resolved through ``AutoConfig.for_model`` using its ``model_type``
        key. The dict is shallow-copied first so the caller's object is not
        mutated (the previous implementation popped ``model_type`` from the
        caller's dict in place).
        """
        if config is None:
            return None
        assert isinstance(config, (PretrainedConfig, dict)), \
            f"expect `{name}` to be instance of PretrainedConfig or dict, but got {type(config)} type"
        if isinstance(config, PretrainedConfig):
            return config
        # Copy before popping so the caller's dict keeps its 'model_type' key.
        config = dict(config)
        model_type = config.pop('model_type')
        return AutoConfig.for_model(model_type, **config)