from torch import nn
from transformers import ViTModel


class ViTImageSearchModel(nn.Module):
    """Wraps a pretrained ViT and exposes the [CLS] embedding for image search."""

    def __init__(self, pretrained_model_name="google/vit-base-patch32-224-in21k"):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

    def forward(self, x):
        # x: preprocessed pixel values of shape (batch, 3, 224, 224)
        outputs = self.vit(pixel_values=x)
        # Use the hidden state of the [CLS] token as the image embedding
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        return cls_hidden_state
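

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file).
    # A real pipeline would preprocess images with ViTImageProcessor from
    # transformers; a random tensor stands in for preprocessed pixel values.
    import torch

    model = ViTImageSearchModel()
    model.eval()
    pixel_values = torch.rand(1, 3, 224, 224)  # (batch, channels, height, width)
    with torch.no_grad():
        embedding = model(pixel_values)
    print(embedding.shape)  # torch.Size([1, 768]) for the base model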