File size: 2,388 Bytes
9ec969b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import Qwen2Config, PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging

# Module-level logger, created through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)


class JasperVLConfig(PretrainedConfig):
    """Composite configuration holding a Qwen2 text config and a Siglip vision config.

    Besides the two sub-configurations, the object stores the output vector
    settings and the ids / literal strings of the three image marker tokens.

    Args:
        is_text_encoder: Flag stored unchanged on the config
            (NOTE(review): presumably selects a text-only mode - confirm
            against the modeling code).
        vector_dim: Size of the output vector. Must be one of
            12288, 1024, 512 or 256.
        vector_dropout_p: Stored unchanged (presumably the dropout probability
            applied to the output vector - confirm against the modeling code).
        num_img_tokens: Stored unchanged (presumably the number of image
            placeholder tokens per image - confirm against the modeling code).
        img_start_token_id: Token id paired with ``img_start_token``.
        img_start_token: Literal string of the image-start marker.
        img_token_id: Token id paired with ``img_token``.
        img_token: Literal string of the image placeholder token.
        img_end_token_id: Token id paired with ``img_end_token``.
        img_end_token: Literal string of the image-end marker.
        text_config: ``None`` (use ``Qwen2Config`` defaults), a dict of
            ``Qwen2Config`` keyword arguments, or a ``PretrainedConfig``
            instance (converted with ``to_dict()``).
        vision_config: ``None`` (use ``SiglipVisionConfig`` defaults), a dict
            of ``SiglipVisionConfig`` keyword arguments, or a
            ``PretrainedConfig`` instance (converted with ``to_dict()``).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``vector_dim`` is not one of the supported sizes.
    """

    model_type = "jasper_vl"

    def __init__(
            self,
            is_text_encoder: bool = True,
            vector_dim: int = 12288,
            vector_dropout_p: float = 0.2,

            num_img_tokens: int = 300,

            img_start_token_id: int = 151646,
            img_start_token: str = "<|jasper_img_start|>",

            img_token_id: int = 151647,
            img_token: str = "<|jasper_img_token|>",

            img_end_token_id: int = 151648,
            img_end_token: str = "<|jasper_img_end|>",

            text_config=None,
            vision_config=None,

            **kwargs
    ):
        super().__init__(**kwargs)
        if vector_dim not in (12288, 1024, 512, 256):
            raise ValueError("vector_dim must be 12288, 1024, 512, 256")
        self.is_text_encoder = is_text_encoder
        self.vector_dim = vector_dim
        self.vector_dropout_p = vector_dropout_p

        self.num_img_tokens = num_img_tokens

        self.img_start_token_id = img_start_token_id
        self.img_start_token = img_start_token

        self.img_token_id = img_token_id
        self.img_token = img_token

        self.img_end_token_id = img_end_token_id
        self.img_end_token = img_end_token

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Qwen2Config` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            # Accept an already-built config object as well as a plain dict;
            # going through to_dict() also avoids sharing the caller's instance.
            text_config = text_config.to_dict()

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
        elif isinstance(vision_config, PretrainedConfig):
            # Same normalisation as for `text_config`.
            vision_config = vision_config.to_dict()

        # Sub-configs are always rebuilt from keyword arguments, so the
        # attributes are real config objects regardless of the input form.
        self.text_config = Qwen2Config(**text_config)
        self.vision_config = SiglipVisionConfig(**vision_config)

    @classmethod
    def from_text_vision_configs(cls, text_config: Qwen2Config, vision_config: SiglipVisionConfig, **kwargs):
        r"""
        Instantiate a [`JasperVLConfig`] (or a derived class) from a Qwen2 text model configuration and a Siglip
        vision model configuration.

        Args:
            text_config: Text model configuration; serialized with `to_dict()` before being passed on.
            vision_config: Vision model configuration; serialized with `to_dict()` before being passed on.
            **kwargs: Additional keyword arguments forwarded to the constructor.

        Returns:
            [`JasperVLConfig`]: An instance of a configuration object
        """

        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)