import math

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch import nn

import open_clip
from sklearn.preprocessing import normalize
def get_final_transform():
    # Preprocessing pipeline using the standard CLIP image statistics.
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform
class Clip_Products(nn.Module):
    """CLIP visual encoder followed by a sub-center ArcFace head."""

    def __init__(self, vit_backbone, head_size, k=3):
        super(Clip_Products, self).__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual  # use only the image tower of CLIP

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace margin product: keeps k weight vectors per class and
    scores each sample against its best-matching sub-center."""

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        # Cosine similarity between L2-normalised features and all sub-center weights.
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        # For each class, keep the highest similarity across its k sub-centers.
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine
class HeadV2(nn.Module):
    """Head returning class cosines and the L2-normalised embedding."""

    def __init__(self, hidden_size, k=3):
        super(HeadV2, self).__init__()
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)  # 9691 target classes

    def forward(self, x):
        output = self.arc(x)
        return output, F.normalize(x)
class Ranker:
    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build an uninitialised ViT-L/14 backbone, then load fine-tuned weights from the checkpoint.
        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
        self.model = Clip_Products(backbone, 768, 3)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()  # inference-only wrapper, so switch to eval mode once here

    def predict(self, img):
        transform_img = get_final_transform()
        query = transform_img(img)
        with torch.no_grad():
            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)
        query_embeddings = embeddings.detach().cpu().numpy()
        # L2-normalise so downstream retrieval can use plain dot products.
        return normalize(query_embeddings)
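

# Minimal usage sketch (not part of the original Space): "query.jpg" is a hypothetical
# example path, and the checkpoint is assumed to exist at model/best_model.pt as above.
# Ranker.predict expects a PIL image and returns a (1, 768) L2-normalised embedding
# suitable for cosine / dot-product retrieval against a gallery of product embeddings.
if __name__ == "__main__":
    from PIL import Image

    ranker = Ranker()
    query_image = Image.open("query.jpg").convert("RGB")
    embedding = ranker.predict(query_image)
    print(embedding.shape)  # expected: (1, 768)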