File size: 1,637 Bytes
0a6ee48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from dataclasses import dataclass

import torch
import torch.nn as nn

from transformers import SiglipVisionModel, SiglipPreTrainedModel, SiglipVisionConfig
from transformers.utils import ModelOutput


@dataclass
class SiglipForImageClassifierOutput(ModelOutput):
    """Output container returned by ``SiglipForImageClassification.forward``.

    Every field defaults to ``None``; a field carries a value only when the
    model that builds this object supplies one.
    """

    # Loss value reported by the model; ``None`` when no loss was computed.
    loss: torch.FloatTensor | None = None
    # Result of applying the model's classifier head to the pooled output.
    logits: torch.FloatTensor | None = None
    # ``pooler_output`` taken from the vision backbone's output.
    pooler_output: torch.FloatTensor | None = None
    # ``hidden_states`` forwarded as-is from the vision backbone's output.
    # NOTE(review): presumably one tensor per layer, and only populated when
    # the backbone is configured to return them; confirm against transformers.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # ``attentions`` forwarded as-is from the vision backbone's output.
    # NOTE(review): presumably only populated when the backbone is configured
    # to return attention weights; confirm against transformers.
    attentions: tuple[torch.FloatTensor, ...] | None = None


class SiglipForImageClassification(SiglipPreTrainedModel):
    """SigLIP vision backbone with a classification head on its pooled output.

    The head is a single ``nn.Linear(hidden_size, num_labels)`` layer, or
    ``nn.Identity`` when ``config.num_labels`` is not positive (in which case
    ``logits`` is simply the pooled backbone output).
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config,
    ):
        """Build the backbone and the classifier head from ``config``.

        Args:
            config: Model configuration; ``hidden_size`` and ``num_labels``
                are read from it to size the classifier head.
        """
        super().__init__(config)

        self.num_labels = config.num_labels
        self.siglip = SiglipVisionModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Return the training loss for ``logits`` against ``labels``.

        The loss follows the ``problem_type`` convention used by Transformers
        classification models. When ``config.problem_type`` is unset it is
        inferred: one label means regression, integer labels mean single-label
        classification, anything else means multi-label classification.

        Raises:
            ValueError: If the model has no classification head
                (``num_labels <= 0``) or ``problem_type`` is not recognised.
        """
        if self.num_labels <= 0:
            raise ValueError(
                "Cannot compute a loss: the model was built with "
                f"num_labels={self.num_labels}, so it has no classification head."
            )

        # Labels may live on another device (e.g. with model parallelism).
        labels = labels.to(logits.device)

        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            targets = labels.to(logits.dtype)
            if self.num_labels == 1:
                # Drop the singleton label axis so both sides line up.
                return nn.functional.mse_loss(logits.squeeze(), targets.squeeze())
            return nn.functional.mse_loss(logits, targets)

        if problem_type == "single_label_classification":
            return nn.functional.cross_entropy(
                logits.view(-1, self.num_labels), labels.view(-1)
            )

        if problem_type == "multi_label_classification":
            # BCE-with-logits needs floating-point targets of the logits' dtype.
            return nn.functional.binary_cross_entropy_with_logits(
                logits, labels.to(logits.dtype)
            )

        raise ValueError(f"Unsupported problem_type: {problem_type!r}")

    def forward(
        self, pixel_values: torch.FloatTensor, labels: torch.LongTensor | None = None
    ):
        """Classify a batch of images and optionally compute the loss.

        Args:
            pixel_values: Image batch, passed unchanged to the SigLIP vision
                backbone.
            labels: Optional targets. When given, a loss is computed from the
                logits (see ``_compute_loss``); when omitted, ``loss`` is
                ``None``.

        Returns:
            SiglipForImageClassifierOutput: ``loss``, ``logits``, and the
            backbone's ``pooler_output``, ``hidden_states`` and ``attentions``.

        Raises:
            ValueError: If ``labels`` is given but no loss can be computed
                (no classification head, or unsupported ``problem_type``).
        """
        outputs = self.siglip(pixel_values)
        pooler_output = outputs.pooler_output
        logits = self.classifier(pooler_output)

        # Previously ``labels`` was accepted but ignored, so ``loss`` was
        # always ``None`` and the model could not be trained via ``outputs.loss``.
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return SiglipForImageClassifierOutput(
            loss=loss,
            logits=logits,
            pooler_output=pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )