|
from transformers import Qwen2Config, PretrainedConfig, SiglipVisionConfig |
|
from transformers.utils import logging |
|
|
|
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
class JasperVLConfig(PretrainedConfig):
    r"""
    Configuration class for the Jasper vision-language model (`model_type="jasper_vl"`).

    It bundles a Qwen2 text configuration (`self.text_config`, a `Qwen2Config`) and a SigLIP
    vision configuration (`self.vision_config`, a `SiglipVisionConfig`), together with
    Jasper-specific settings and the ids/strings of the special image tokens.

    Args:
        is_text_encoder (`bool`, *optional*, defaults to `True`):
            Stored as `self.is_text_encoder`; how it is used is decided by the model code.
        vector_dim (`int`, *optional*, defaults to 12288):
            Vector dimensionality stored as `self.vector_dim`. Must be one of 12288, 1024, 512, 256.
        vector_dropout_p (`float`, *optional*, defaults to 0.2):
            Dropout probability stored as `self.vector_dropout_p` (presumably applied to the
            output vector; confirm against the model code).
        num_img_tokens (`int`, *optional*, defaults to 300):
            Number of image tokens, stored as `self.num_img_tokens`.
        img_start_token_id (`int`, *optional*, defaults to 151646):
            Token id of `img_start_token`.
        img_start_token (`str`, *optional*, defaults to `"<|jasper_img_start|>"`):
            Special token marking the start of an image span.
        img_token_id (`int`, *optional*, defaults to 151647):
            Token id of `img_token`.
        img_token (`str`, *optional*, defaults to `"<|jasper_img_token|>"`):
            Special image placeholder token.
        img_end_token_id (`int`, *optional*, defaults to 151648):
            Token id of `img_end_token`.
        img_end_token (`str`, *optional*, defaults to `"<|jasper_img_end|>"`):
            Special token marking the end of an image span.
        text_config (`dict` or `PretrainedConfig`, *optional*):
            Keyword arguments for `Qwen2Config`, or an existing config instance (converted
            through its `to_dict()`). When `None`, a default `Qwen2Config` is created.
        vision_config (`dict` or `PretrainedConfig`, *optional*):
            Keyword arguments for `SiglipVisionConfig`, or an existing config instance
            (converted through its `to_dict()`). When `None`, a default `SiglipVisionConfig`
            is created.
        kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `vector_dim` is not one of the supported sizes.
    """

    model_type = "jasper_vl"

    def __init__(
        self,
        is_text_encoder: bool = True,
        vector_dim: int = 12288,
        vector_dropout_p: float = 0.2,
        num_img_tokens: int = 300,
        img_start_token_id: int = 151646,
        img_start_token: str = "<|jasper_img_start|>",
        img_token_id: int = 151647,
        img_token: str = "<|jasper_img_token|>",
        img_end_token_id: int = 151648,
        img_end_token: str = "<|jasper_img_end|>",
        text_config=None,
        vision_config=None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Only these vector sizes are accepted by this configuration.
        supported_vector_dims = (12288, 1024, 512, 256)
        if vector_dim not in supported_vector_dims:
            raise ValueError(f"vector_dim must be 12288, 1024, 512, 256 (got {vector_dim!r})")

        self.is_text_encoder = is_text_encoder
        self.vector_dim = vector_dim
        self.vector_dropout_p = vector_dropout_p

        self.num_img_tokens = num_img_tokens

        self.img_start_token_id = img_start_token_id
        self.img_start_token = img_start_token

        self.img_token_id = img_token_id
        self.img_token = img_token

        self.img_end_token_id = img_end_token_id
        self.img_end_token = img_end_token

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Qwen2Config` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # `**` needs a mapping; a config object would raise TypeError, so use its dict form
            # (the same conversion `from_text_vision_configs` performs).
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
        elif isinstance(vision_config, PretrainedConfig):
            vision_config = vision_config.to_dict()

        self.text_config = Qwen2Config(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: Qwen2Config, vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`JasperVLConfig`] (or a derived class) from a Qwen2 text model configuration
        and a SigLIP vision model configuration.

        Args:
            text_config ([`Qwen2Config`]):
                Text model configuration; passed on as its `to_dict()` output.
            vision_config ([`SiglipVisionConfig`]):
                Vision model configuration; passed on as its `to_dict()` output.
            kwargs:
                Remaining keyword arguments forwarded to the class constructor.

        Returns:
            [`JasperVLConfig`]: An instance of a configuration object.
        """
        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
|
|
|