from sklearn.preprocessing import normalize
import torchvision.transforms as T
import open_clip
import torch
import math
from torch import nn
import torch.nn.functional as F


def get_final_transform():
    # Standard CLIP preprocessing: bicubic resize to 224x224, then
    # normalization with the CLIP training mean/std.
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform


class Clip_Products(nn.Module):
    """CLIP visual encoder with a sub-center ArcFace retrieval head."""

    def __init__(self, vit_backbone, head_size, k=3):
        super().__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual  # keep only the image tower

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)


class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace margin product: each class owns k weight
    sub-centers, and the logit for a class is the maximum cosine
    similarity over its sub-centers."""

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        # Cosine similarity between L2-normalized features and weights.
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        # Reduce over the k sub-centers of each class.
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class HeadV2(nn.Module):
    """Retrieval head: sub-center ArcFace logits over 9691 classes plus
    the L2-normalized embedding used for nearest-neighbour search."""

    def __init__(self, hidden_size, k=3):
        super().__init__()
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)

    def forward(self, x):
        output = self.arc(x)
        return output, F.normalize(x)


class Ranker:
    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build an uninitialized ViT-L-14 backbone (pretrained=None); the
        # weights come from the fine-tuned checkpoint loaded below.
        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
        self.model = Clip_Products(backbone, 768, 3)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)

    def predict(self, img):
        transform_img = get_final_transform()
        query = transform_img(img)
        self.model.eval()
        with torch.no_grad():
            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)
            query_embeddings = embeddings.detach().cpu().numpy()
        # Row-normalize once more on the CPU side before returning.
        return normalize(query_embeddings)
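
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how Ranker might be called, assuming a checkpoint
# exists at model/best_model.pt and a query image at "query.jpg" (a
# hypothetical path). predict() expects a PIL image and returns a
# (1, 768) L2-normalized numpy embedding suitable for cosine-similarity
# search against a gallery of product embeddings.
if __name__ == "__main__":
    from PIL import Image

    ranker = Ranker()
    img = Image.open("query.jpg").convert("RGB")  # hypothetical query image
    embedding = ranker.predict(img)
    print(embedding.shape)  # expected: (1, 768)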