# product-retrieval/model_prediction.py

import math

import open_clip
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from sklearn.preprocessing import normalize
from torch import nn


def get_final_transform():
    """Image preprocessing that mirrors CLIP's training setup."""
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        # Standard OpenAI CLIP normalization statistics.
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform


class Clip_Products(nn.Module):
    """CLIP visual encoder followed by a sub-center ArcFace head."""

    def __init__(self, vit_backbone, head_size, k=3):
        super().__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual  # image tower of the CLIP model

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)


class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace margin product: each class owns k weight vectors
    (sub-centers), and the logit for a class is the maximum cosine
    similarity over its sub-centers."""

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        # Cosine similarity between L2-normalized features and all sub-center weights.
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        # Keep only the best-matching sub-center per class.
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class HeadV2(nn.Module):
    def __init__(self, hidden_size, k=3):
        super().__init__()
        # 9691 is presumably the number of product classes in the training data.
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)

    def forward(self, x):
        output = self.arc(x)
        # Return class logits and the L2-normalized embedding used for retrieval.
        return output, F.normalize(x)
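

# A quick shape check for the head (illustrative sketch; the 768-dim embedding
# and 9691 classes mirror the dimensions used above, nothing else is assumed):
#
#     head = HeadV2(hidden_size=768, k=3)
#     logits, emb = head(torch.randn(2, 768))
#     assert logits.shape == (2, 9691) and emb.shape == (2, 768)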


class Ranker:
    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build an uninitialized ViT-L/14 backbone; weights come from the checkpoint.
        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=None)
        self.model = Clip_Products(backbone, 768, 3)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()  # inference only; set once here instead of per call

    def predict(self, img):
        """Embed a single PIL image and return its L2-normalized embedding."""
        transform_img = get_final_transform()
        query = transform_img(img)
        with torch.no_grad():
            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)
        query_embeddings = embeddings.cpu().numpy()
        # Embeddings are already unit-norm from HeadV2; re-normalizing is a cheap safeguard.
        return normalize(query_embeddings)
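

if __name__ == "__main__":
    # Smoke test (a sketch: "query.jpg" is a hypothetical path, and the
    # checkpoint at model/best_model.pt must exist for this to run).
    from PIL import Image

    ranker = Ranker()
    image = Image.open("query.jpg").convert("RGB")
    embedding = ranker.predict(image)  # numpy array of shape (1, 768)
    print(embedding.shape)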