File size: 1,500 Bytes
685ecb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# from src.model.modules.gemma import GemmaConfig
# from src.model.modules.siglip import SiglipVisionConfig
from src.model.modules.voicecraftconfig import VoiceCraftConfig
from transformers import SiglipVisionConfig, GemmaConfig, PretrainedConfig
class ImageCraftConfig(PretrainedConfig):
    """Configuration for the ImageCraft model.

    Composes three sub-configurations: a SigLIP vision encoder
    (``vision_config``), a Gemma language model (``text_config``), and a
    VoiceCraft speech module (``voicecraft_config``), plus the projection
    and tokenization parameters that tie them together.
    """

    model_type = "imagecraft"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        voicecraft_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        pad_token_id=None,
        **kwargs
    ):
        """Build the composite config.

        Args:
            vision_config: dict of kwargs for ``SiglipVisionConfig``
                (``None`` → library defaults).
            text_config: dict of kwargs for ``GemmaConfig``
                (``None`` → library defaults).
            voicecraft_config: dict of kwargs for ``VoiceCraftConfig``
                (``None`` → library defaults).
            ignore_index: label value ignored by the loss.
            image_token_index: token id that marks image placeholders.
            vocab_size: nominal vocab size; superseded below by the text
                config's ``vocab_size`` (kept for interface compatibility).
            projection_dim: dimension of the vision→text projection.
            hidden_size: hidden size of the combined model.
            pad_token_id: padding token id; forwarded to the text config.
            **kwargs: forwarded to ``PretrainedConfig.__init__``.
        """
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.is_encoder_decoder = False
        # NOTE(review): this -1 fallback is overwritten when
        # PretrainedConfig.__init__ runs below (it re-sets pad_token_id
        # from kwargs) — kept to match the original assignment order.
        self.pad_token_id = pad_token_id if pad_token_id is not None else -1

        # ``or {}`` lets the documented None defaults actually work instead
        # of raising TypeError on ``**None``.
        self.vision_config = SiglipVisionConfig(**(vision_config or {}))
        self.text_config = GemmaConfig(**(text_config or {}), pad_token_id=pad_token_id)

        # The effective vocab size always comes from the text model; the
        # ``vocab_size`` parameter above is only a nominal default.
        self.vocab_size = self.text_config.vocab_size

        # One image token per vision patch: (image_size / patch_size)^2.
        self.text_config.num_image_tokens = (
            self.vision_config.image_size // self.vision_config.patch_size
        ) ** 2
        self.vision_config.projection_dim = projection_dim
        self.voicecraft_config = VoiceCraftConfig(**(voicecraft_config or {}))

        # Single base-class init (the original called super().__init__()
        # twice; the first bare call was fully shadowed by this one).
        super().__init__(**kwargs)
|