from transformers import PreTrainedModel, PretrainedConfig

from .module import ConditionalViT


class CondViTConfig(PretrainedConfig):
    """Configuration holding the hyperparameters passed to ``ConditionalViT``.

    Each argument is stored as an attribute of the same name and later
    forwarded, unchanged, to the ``ConditionalViT`` constructor parameter of
    the same name by ``CondViTForEmbedding``.

    NOTE(review): the meanings below are inferred from the parameter names
    only; confirm them against ``ConditionalViT`` in ``.module``.

    Args:
        input_resolution: Presumably the input image side length. Default 224.
        patch_size: Presumably the patch side length. Default 16.
        width: Presumably the hidden width of the transformer. Default 768.
        layers: Presumably the number of transformer layers. Default 12.
        heads: Presumably the number of attention heads. Default 12.
        output_dim: Presumably the size of the output embedding. Default 512.
        n_categories: Presumably the number of conditioning categories.
            Default 10.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        **kwargs,
    ):
        # Input-related settings.
        self.input_resolution = input_resolution
        self.patch_size = patch_size

        # Backbone-related settings.
        self.width = width
        self.layers = layers
        self.heads = heads

        # Output / conditioning settings.
        self.output_dim = output_dim
        self.n_categories = n_categories

        # The base initialiser runs last, after the fields above are set.
        super().__init__(**kwargs)


class CondViTForEmbedding(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a ``ConditionalViT`` backbone.

    The backbone is stored as ``self.model`` and is constructed purely from
    the fields of a ``CondViTConfig``.
    """

    config_class = CondViTConfig

    def __init__(self, config: CondViTConfig):
        """Build the wrapped ``ConditionalViT`` from ``config``.

        Args:
            config: Configuration providing ``input_resolution``,
                ``patch_size``, ``width``, ``layers``, ``heads``,
                ``output_dim`` and ``n_categories``.
        """
        super().__init__(config)

        # Collect the backbone arguments first so the mapping from config
        # fields to constructor parameters is visible in one place.
        backbone_kwargs = {
            "input_resolution": config.input_resolution,
            "patch_size": config.patch_size,
            "width": config.width,
            "layers": config.layers,
            "heads": config.heads,
            "output_dim": config.output_dim,
            "n_categories": config.n_categories,
        }
        self.model = ConditionalViT(**backbone_kwargs)

    def forward(self, pixel_values, category_indices=None):
        """Run the wrapped backbone and return its output unchanged.

        Args:
            pixel_values: Passed to the backbone as ``imgs``.
            category_indices: Optional; passed to the backbone as ``c``.
                Defaults to ``None``.

        Returns:
            Whatever ``ConditionalViT`` returns for these inputs
            (presumably an embedding of size ``output_dim`` -- TODO confirm).
        """
        backbone_output = self.model(imgs=pixel_values, c=category_indices)
        return backbone_output