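"""CLIP-based product image embedder.

Wraps an open_clip ViT-L-14 visual backbone with a sub-center ArcFace
head; Ranker loads a fine-tuned checkpoint and maps an image to an
L2-normalized embedding suitable for similarity search.
"""
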
import math

import open_clip
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from sklearn.preprocessing import normalize
from torch import nn


def get_final_transform():
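    """Evaluation preprocessing matching CLIP: bicubic resize to 224x224,
    tensor conversion, and normalization with the OpenAI CLIP mean/std.
    """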
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform


class Clip_Products(nn.Module):
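    """CLIP visual encoder followed by a sub-center ArcFace head."""
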
    def __init__(self, vit_backbone, head_size, k=3):
        super().__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)


class ArcMarginProduct_subcenter(nn.Module):
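    """Sub-center ArcFace margin product (Deng et al., ECCV 2020).

    Holds k weight vectors ("sub-centers") per class; a sample is scored
    against each class by its best cosine similarity among that class's
    sub-centers, which makes the margin loss more tolerant of label noise.
    """
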
    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.out_features = out_features
        self.k = k
        # One weight row per (class, sub-center) pair.
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        # Cosine similarity between the normalized features and every
        # (class, sub-center) weight vector.
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        # Score each class by its best-matching sub-center.
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class HeadV2(nn.Module):
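    """Returns sub-center ArcFace logits over the 9691 training classes
    together with the L2-normalized embedding.
    """
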
    def __init__(self, hidden_size, k=3):
        super().__init__()
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)

    def forward(self, x):
        output = self.arc(x)
        return output, F.normalize(x)


class Ranker:
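    """Loads the fine-tuned checkpoint and embeds query images."""
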
    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # No pretrained weights requested: the fine-tuned checkpoint below
        # supplies the full state dict.
        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
        self.model = Clip_Products(backbone, 768, 3)

        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        # Inference-only: set eval mode once here instead of on every call,
        # and build the preprocessing pipeline a single time.
        self.model.eval()
        self.transform = get_final_transform()

    def predict(self, img):
        query = self.transform(img)

        with torch.no_grad():
            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)

        query_embeddings = embeddings.cpu().numpy()
        # HeadV2 already L2-normalizes the embedding, so this normalize is a
        # defensive no-op.
        return normalize(query_embeddings)
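

# Example usage (a sketch: assumes the checkpoint at "model/best_model.pt"
# is present, and "query.jpg" stands in for any RGB image path):
if __name__ == "__main__":
    from PIL import Image

    ranker = Ranker()
    embedding = ranker.predict(Image.open("query.jpg").convert("RGB"))
    # A (1, 768) L2-normalized vector, ready for cosine / nearest-neighbour
    # search against a precomputed index of product embeddings.
    print(embedding.shape)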