# jasper_en_vision_language_v1 / configuration_jasper_vl.py
# Uploaded by infgrad
# Commit message: Upload 15 files
# Commit 9ec969b (verified)
from transformers import Qwen2Config, PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities, named after this module.
logger = logging.get_logger(__name__)
class JasperVLConfig(PretrainedConfig):
    """Configuration for the Jasper vision-language model.

    Holds a `Qwen2Config` (stored as `text_config`) and a `SiglipVisionConfig`
    (stored as `vision_config`) together with a few Jasper-specific settings.

    Args:
        is_text_encoder: Stored unchanged on the config.
            NOTE(review): presumably toggles a text-only encoding mode; confirm
            against the modeling code.
        vector_dim: Must be one of 12288, 1024, 512 or 256.
            NOTE(review): presumably the output embedding dimension; confirm.
        vector_dropout_p: Stored unchanged on the config.
            NOTE(review): presumably a dropout probability applied to the output
            vector; confirm.
        num_img_tokens: Stored unchanged on the config.
            NOTE(review): presumably the number of image placeholder tokens per
            image; confirm.
        img_start_token_id / img_start_token: Id and text of the image-start
            special token.
        img_token_id / img_token: Id and text of the image placeholder token.
        img_end_token_id / img_end_token: Id and text of the image-end special
            token.
        text_config: `None` (use `Qwen2Config` defaults), a dict of keyword
            arguments for `Qwen2Config`, or a config instance (converted with
            `to_dict()`).
        vision_config: `None` (use `SiglipVisionConfig` defaults), a dict of
            keyword arguments for `SiglipVisionConfig`, or a config instance
            (converted with `to_dict()`).
        **kwargs: Forwarded to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `vector_dim` is not one of the supported sizes.
    """

    model_type = "jasper_vl"

    def __init__(
        self,
        is_text_encoder: bool = True,
        vector_dim: int = 12288,
        vector_dropout_p: float = 0.2,
        num_img_tokens: int = 300,
        img_start_token_id: int = 151646,
        img_start_token: str = "<|jasper_img_start|>",
        img_token_id: int = 151647,
        img_token: str = "<|jasper_img_token|>",
        img_end_token_id: int = 151648,
        img_end_token: str = "<|jasper_img_end|>",
        text_config=None,
        vision_config=None,
        **kwargs
    ):
        # Validate first so an invalid value fails before any state is set.
        if vector_dim not in (12288, 1024, 512, 256):
            raise ValueError("vector_dim must be 12288, 1024, 512, 256")

        super().__init__(**kwargs)

        self.is_text_encoder = is_text_encoder
        self.vector_dim = vector_dim
        self.vector_dropout_p = vector_dropout_p
        self.num_img_tokens = num_img_tokens

        self.img_start_token_id = img_start_token_id
        self.img_start_token = img_start_token
        self.img_token_id = img_token_id
        self.img_token = img_token
        self.img_end_token_id = img_end_token_id
        self.img_end_token = img_end_token

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Qwen2Config` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # A config object cannot be unpacked with `**`; go through its dict
            # form, exactly as `from_text_vision_configs` does.
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
        elif isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()

        self.text_config = Qwen2Config(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: Qwen2Config, vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`JasperVLConfig`] (or a derived class) from a Qwen2 text model configuration and a
        Siglip vision model configuration.

        Both configurations are converted with `to_dict()` and passed to the constructor; any extra
        keyword arguments are forwarded to the constructor as well.

        Returns:
            [`JasperVLConfig`]: An instance of a configuration object
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)