File size: 1,310 Bytes
11f6a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeb008e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from transformers import PreTrainedModel, PretrainedConfig
from .module import ConditionalViT


class CondViTConfig(PretrainedConfig):
    """Hugging Face configuration holding the ``ConditionalViT`` hyperparameters.

    Each named argument is stored as an attribute of the same name. Any
    remaining keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        input_resolution: Presumably the side length (in pixels) of input
            images — confirm against ``ConditionalViT``.
        patch_size: Presumably the side length of each image patch — confirm.
        width: Presumably the transformer hidden size — confirm.
        layers: Presumably the number of transformer layers — confirm.
        heads: Presumably the number of attention heads — confirm.
        output_dim: Presumably the size of the output embedding — confirm.
        n_categories: Presumably the number of conditioning categories
            available to the model — confirm.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    # Identifier stored in the serialized config.
    model_type = "condvit"

    def __init__(
        self,
        input_resolution: int = 224,
        patch_size: int = 16,
        width: int = 768,
        layers: int = 12,
        heads: int = 12,
        output_dim: int = 512,
        n_categories: int = 10,
        **kwargs
    ):
        hyperparameters = {
            "input_resolution": input_resolution,
            "patch_size": patch_size,
            "width": width,
            "layers": layers,
            "heads": heads,
            "output_dim": output_dim,
            "n_categories": n_categories,
        }
        # Assign via setattr so PretrainedConfig's attribute handling applies,
        # exactly as with plain attribute assignment.
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)

        # Base-class init runs after the model-specific attributes are set.
        super().__init__(**kwargs)


class CondViTForEmbedding(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds and calls a ``ConditionalViT``.

    The backbone is constructed from the hyperparameters on a
    ``CondViTConfig``. It is stored on the ``model`` attribute.
    """

    config_class = CondViTConfig

    def __init__(self, config):
        super().__init__(config)

        # Config attributes forwarded verbatim, as same-named keyword
        # arguments, to the ConditionalViT constructor.
        field_names = (
            "input_resolution",
            "patch_size",
            "width",
            "layers",
            "heads",
            "output_dim",
            "n_categories",
        )
        backbone_kwargs = {name: getattr(config, name) for name in field_names}

        # NOTE(review): if ConditionalViT is an nn.Module, the attribute name
        # "model" becomes the state-dict key prefix, so renaming it would break
        # existing checkpoints.
        self.model = ConditionalViT(**backbone_kwargs)

    def forward(self, pixel_values, category_indices=None):
        """Run the wrapped ``ConditionalViT`` and return its output unchanged.

        Args:
            pixel_values: Image batch, passed to the backbone as ``imgs``
                (presumably a (batch, channels, H, W) tensor — confirm).
            category_indices: Optional conditioning, passed to the backbone
                as ``c``. It defaults to ``None`` and is forwarded as-is.

        Returns:
            Whatever ``ConditionalViT`` returns for these inputs.
        """
        backbone = self.model
        return backbone(imgs=pixel_values, c=category_indices)