from transformers import Qwen2Config, PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Values accepted for ``vector_dim``; anything else is rejected in ``__init__``.
_SUPPORTED_VECTOR_DIMS = (12288, 1024, 512, 256)


class JasperVLConfig(PretrainedConfig):
    """Configuration for the ``jasper_vl`` model type.

    Bundles a text sub-configuration ([`Qwen2Config`]) and a vision
    sub-configuration ([`SiglipVisionConfig`]) together with the settings
    specific to this model (output vector size, image placeholder tokens).

    Args:
        is_text_encoder (`bool`, defaults to `True`):
            Stored unchanged on the config. NOTE(review): presumably selects a
            text-only mode in the model code, which is not visible here.
        vector_dim (`int`, defaults to 12288):
            Must be one of 12288, 1024, 512 or 256.
        vector_dropout_p (`float`, defaults to 0.2):
            Stored unchanged. NOTE(review): presumably a dropout probability
            applied to the output vector -- confirm against the model code.
        num_img_tokens (`int`, defaults to 300):
            Stored unchanged. NOTE(review): presumably the number of image
            placeholder tokens per image -- confirm against the model code.
        img_start_token_id / img_start_token:
            Id and string form of the image-start marker token.
        img_token_id / img_token:
            Id and string form of the image placeholder token.
        img_end_token_id / img_end_token:
            Id and string form of the image-end marker token.
        text_config (`dict` or `PretrainedConfig`, *optional*):
            Keyword arguments for [`Qwen2Config`]. A config instance is
            converted with `to_dict()` first. `None` uses the defaults.
        vision_config (`dict` or `PretrainedConfig`, *optional*):
            Keyword arguments for [`SiglipVisionConfig`]. A config instance is
            converted with `to_dict()` first. `None` uses the defaults.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `vector_dim` is not one of the supported sizes.
    """

    model_type = "jasper_vl"

    def __init__(
        self,
        is_text_encoder: bool = True,
        vector_dim: int = 12288,
        vector_dropout_p: float = 0.2,
        num_img_tokens: int = 300,
        img_start_token_id: int = 151646,
        img_start_token: str = "<|jasper_img_start|>",
        img_token_id: int = 151647,
        img_token: str = "<|jasper_img_token|>",
        img_end_token_id: int = 151648,
        img_end_token: str = "<|jasper_img_end|>",
        text_config=None,
        vision_config=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        if vector_dim not in _SUPPORTED_VECTOR_DIMS:
            raise ValueError("vector_dim must be 12288, 1024, 512, 256")

        self.is_text_encoder = is_text_encoder
        self.vector_dim = vector_dim
        self.vector_dropout_p = vector_dropout_p
        self.num_img_tokens = num_img_tokens

        self.img_start_token_id = img_start_token_id
        self.img_start_token = img_start_token
        self.img_token_id = img_token_id
        self.img_token = img_token
        self.img_end_token_id = img_end_token_id
        self.img_end_token = img_end_token

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Qwen2Config` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # Accept an already-built config object: `**` unpacking below
            # needs a mapping, so normalise to a plain dict first.
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
        elif isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()

        self.text_config = Qwen2Config(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: Qwen2Config, vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`JasperVLConfig`] (or a derived class) from a Qwen2 text model configuration and a Siglip
        vision model configuration.

        Args:
            text_config ([`Qwen2Config`]): Text configuration; serialised with `to_dict()` before use.
            vision_config ([`SiglipVisionConfig`]): Vision configuration; serialised with `to_dict()` before use.
            **kwargs: Extra keyword arguments forwarded to the class constructor.

        Returns:
            [`JasperVLConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)