from typing import Optional

from transformers import PretrainedConfig, RobertaConfig


class JapaneseCLIPVisionConfig(PretrainedConfig):
    """Configuration for the ViT vision tower of a JapaneseCLIP model."""

    model_type = "vit"

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        head_width: int,
        mlp_ratio: float,
        ls_init_value: Optional[float] = None,
        attentional_pool: bool = False,
        attn_pooler_queries: int = 256,
        attn_pooler_heads: int = 8,
        output_dim: int = 512,
        patch_dropout: float = 0.0,
        no_ln_pre: bool = False,
        pool_type: str = "tok",
        final_ln_after_pool: bool = False,
        output_tokens: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.head_width = head_width
        # Derived value: number of attention heads in the transformer.
        self.heads = width // head_width
        self.mlp_ratio = mlp_ratio
        self.ls_init_value = ls_init_value
        self.attentional_pool = attentional_pool
        self.attn_pooler_queries = attn_pooler_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.output_dim = output_dim
        self.patch_dropout = patch_dropout
        self.no_ln_pre = no_ln_pre
        self.pool_type = pool_type
        self.final_ln_after_pool = final_ln_after_pool
        self.output_tokens = output_tokens
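
# Usage sketch for JapaneseCLIPVisionConfig (the concrete values below are
# illustrative assumptions, not defaults of any released checkpoint): a
# ViT-B/16-style tower would be configured roughly as
#   JapaneseCLIPVisionConfig(image_size=224, patch_size=16, width=768,
#                            layers=12, head_width=64, mlp_ratio=4.0)
# which yields heads = 768 // 64 = 12.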


class JapaneseCLIPConfig(PretrainedConfig):
    """Configuration that pairs a JapaneseCLIPVisionConfig with a RobertaConfig for the text tower."""

    model_type = "japanese_clip"

    def __init__(
        self,
        max_length: int = 77,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.max_length = max_length

        # Both sub-configs are required; they arrive through **kwargs, e.g. when
        # the config is rebuilt from a serialized dict.
        if "vision_config" not in kwargs:
            raise ValueError("vision_config must be provided")
        if "text_config" not in kwargs:
            raise ValueError("text_config must be provided")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        self.vision_config = JapaneseCLIPVisionConfig(**vision_config)
        self.text_config = RobertaConfig(**text_config)

    @classmethod
    def from_vision_text_configs(
        cls,
        vision_config: PretrainedConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`JapaneseCLIPConfig`] (or a derived class) from a vision model configuration and a text
        model configuration.

        Returns:
            [`JapaneseCLIPConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
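

# Hedged end-to-end sketch: the hyperparameters and vocab size below are
# illustrative assumptions, not values tied to a specific pretrained checkpoint.
if __name__ == "__main__":
    vision_config = JapaneseCLIPVisionConfig(
        image_size=224,   # assumed ViT-B/16-style settings
        patch_size=16,
        width=768,
        layers=12,
        head_width=64,    # heads = width // head_width = 12
        mlp_ratio=4.0,
        output_dim=512,
    )
    text_config = RobertaConfig(vocab_size=32000)  # assumed tokenizer vocab size

    config = JapaneseCLIPConfig.from_vision_text_configs(
        vision_config, text_config, max_length=77
    )
    print(config.vision_config.heads)      # 12
    print(config.text_config.vocab_size)   # 32000
    print(config.max_length)               # 77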