# from src.model.modules.gemma import GemmaConfig
# from src.model.modules.siglip import SiglipVisionConfig
from src.model.modules.voicecraftconfig import VoiceCraftConfig

from transformers import SiglipVisionConfig, GemmaConfig, PretrainedConfig


class ImageCraftConfig(PretrainedConfig):
    """Configuration for the ImageCraft multimodal model.

    Bundles three sub-configurations under one ``PretrainedConfig``:
    a SigLIP vision encoder, a Gemma text decoder, and a VoiceCraft
    speech module.
    """

    model_type = "imagecraft"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        voicecraft_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id=None,
        **kwargs
    ):
        """Build the composite config.

        Args:
            vision_config: dict of SiglipVisionConfig kwargs, or None for defaults.
            text_config: dict of GemmaConfig kwargs, or None for defaults.
            voicecraft_config: dict of VoiceCraftConfig kwargs, or None for defaults.
            ignore_index: label value ignored by the loss.
            image_token_index: token id marking image positions in the text stream.
            vocab_size: nominal vocab size; superseded by the text config's value.
            projection_dim: dimension of the vision->text projection.
            hidden_size: hidden size of the language model.
            pad_token_id: padding token id; falls back to -1 when None.
            **kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.is_encoder_decoder = False

        # Fall back to each sub-config's own defaults when no dict is given;
        # the original `**None` unpacking raised TypeError for the default args.
        self.vision_config = SiglipVisionConfig(**(vision_config or {}))

        self.text_config = GemmaConfig(**(text_config or {}), pad_token_id=pad_token_id)
        # The text model dictates the effective vocabulary size.
        self.vocab_size = self.text_config.vocab_size

        # One image token per vision patch (square grid).
        self.text_config.num_image_tokens = (
            self.vision_config.image_size // self.vision_config.patch_size
        ) ** 2
        self.vision_config.projection_dim = projection_dim

        self.voicecraft_config = VoiceCraftConfig(**(voicecraft_config or {}))

        # Single base-class init (the original called it twice). It must run
        # before the pad_token_id assignment: PretrainedConfig.__init__ pops
        # "pad_token_id" from kwargs with a default of None, which previously
        # clobbered the -1 fallback set earlier in this method.
        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id if pad_token_id is not None else -1