from transformers import PreTrainedModel, PretrainedConfig
from sentence_transformers import SentenceTransformer

from .module import ConditionalViT


class CondViTConfig(PretrainedConfig):
    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        lm_backbone: str = "sentence-transformers/sentence-t5-xl",
        lm_revision: str = "e0976ba9afd18be963c22c680367a3928c44fd22",
        device: str = "cpu",
        **kwargs,
    ):
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.output_dim = output_dim
        self.n_categories = n_categories
        self.lm_backbone = lm_backbone
        self.lm_revision = lm_revision
        self.device = device

        super().__init__(**kwargs)


class CondViTForEmbedding(PreTrainedModel):
    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision backbone producing the image embedding, optionally
        # conditioned on a text embedding.
        self.condvit = ConditionalViT(
            input_resolution=config.input_resolution,
            patch_size=config.patch_size,
            width=config.width,
            layers=config.layers,
            heads=config.heads,
            output_dim=config.output_dim,
        )
        if config.device:
            self.condvit.to(config.device)

        # Sentence-Transformers language model used to embed the
        # conditioning texts.
        self.lm = SentenceTransformer(
            config.lm_backbone,
            revision=config.lm_revision,
            device=config.device,
        )

    def forward(self, pixel_values, texts=None):
        if texts is not None:
            # Encode the conditioning texts and move the embeddings onto the
            # same device as the images.
            text_embeddings = self.lm.encode(
                texts,
                convert_to_tensor=True,
                convert_to_numpy=False,
            )
            text_embeddings = text_embeddings.to(pixel_values.device)
        else:
            text_embeddings = None

        return self.condvit(imgs=pixel_values, c=text_embeddings)
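

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, illustrative example of how the wrapper might be called, assuming
# a batch of 224x224 RGB images and free-text conditioning strings. It only
# relies on the `CondViTConfig` defaults and the `forward` signature defined
# above; the batch size, example strings, and expected output shape are
# assumptions, not guarantees from the original code.
#
#   import torch
#
#   config = CondViTConfig()
#   model = CondViTForEmbedding(config).eval()
#
#   pixel_values = torch.randn(2, 3, config.input_resolution, config.input_resolution)
#   texts = ["a red handbag", "a pair of leather boots"]
#
#   with torch.no_grad():
#       # Conditioned embeddings: one vector per image, guided by its text.
#       embeddings = model(pixel_values=pixel_values, texts=texts)
#       # Unconditioned embeddings: omit `texts` to embed the images alone.
#       plain_embeddings = model(pixel_values=pixel_values)
#
#   # Both outputs are expected to have shape (2, config.output_dim).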