CondViT-B16-cat / hf_model.py
Slep's picture
Upload CondViTForEmbedding
aeb008e verified
raw
history blame
1.31 kB
from transformers import PreTrainedModel, PretrainedConfig
from .module import ConditionalViT
class CondViTConfig(PretrainedConfig):
    """Hyperparameter container for the ``condvit`` model type.

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. ``CondViTForEmbedding`` reads these
    attributes and forwards them as keyword arguments to ``ConditionalViT``,
    which defines what they actually mean; the notes below are inferred from
    the names only.

    Args:
        input_resolution: Presumably the side length of the input image in
            pixels -- confirm against ``ConditionalViT``.
        patch_size: Presumably the side length of one image patch -- confirm.
        width: Presumably the transformer hidden size -- confirm.
        layers: Presumably the number of transformer blocks -- confirm.
        heads: Presumably the number of attention heads -- confirm.
        output_dim: Presumably the size of the produced embedding -- confirm.
        n_categories: Presumably the number of conditioning categories --
            confirm.
        **kwargs: Everything else is handed to ``PretrainedConfig.__init__``.
    """

    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        **kwargs
    ):
        hyperparameters = {
            "input_resolution": input_resolution,
            "patch_size": patch_size,
            "width": width,
            "layers": layers,
            "heads": heads,
            "output_dim": output_dim,
            "n_categories": n_categories,
        }
        # Model-specific fields are set first; the base-class initialiser
        # runs afterwards and consumes the remaining keyword arguments.
        for field_name, field_value in hyperparameters.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
class CondViTForEmbedding(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a ``ConditionalViT`` backbone.

    The backbone is built from the hyperparameters stored on a
    ``CondViTConfig`` and kept as ``self.model``; ``forward`` delegates to it
    directly.
    """

    config_class = CondViTConfig

    def __init__(self, config):
        """Build the wrapped ``ConditionalViT`` from ``config``.

        Args:
            config: Configuration object exposing ``input_resolution``,
                ``patch_size``, ``width``, ``layers``, ``heads``,
                ``output_dim`` and ``n_categories``.
        """
        super().__init__(config)
        backbone_fields = (
            "input_resolution",
            "patch_size",
            "width",
            "layers",
            "heads",
            "output_dim",
            "n_categories",
        )
        # Each config attribute maps one-to-one onto the ConditionalViT
        # keyword argument of the same name.
        backbone_kwargs = {
            field: getattr(config, field) for field in backbone_fields
        }
        self.model = ConditionalViT(**backbone_kwargs)

    def forward(self, pixel_values, category_indices=None):
        """Run the wrapped backbone and return its output unchanged.

        Args:
            pixel_values: Passed to the backbone as ``imgs``; presumably a
                batch of image tensors -- confirm against ``ConditionalViT``.
            category_indices: Passed to the backbone as ``c``. Defaults to
                ``None``, which is forwarded as-is; how the backbone treats
                ``None`` is not visible here.

        Returns:
            Whatever ``self.model`` returns for these inputs.
        """
        backbone_output = self.model(imgs=pixel_values, c=category_indices)
        return backbone_output