from sklearn.preprocessing import normalize
import torchvision.transforms as T
import open_clip
import torch
import math
from torch import nn
import torch.nn.functional as F


def get_final_transform():
    """Preprocessing pipeline matching CLIP's expected input (224x224, CLIP mean/std)."""
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform


class Clip_Products(nn.Module):
    """CLIP visual encoder followed by a sub-center ArcFace head."""

    def __init__(self, vit_backbone, head_size, k=3):
        super(Clip_Products, self).__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual  # use only the image tower of CLIP

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)


class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace margin product: k weight centers per class, best cosine kept."""

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        # Cosine similarity between L2-normalized features and class-center weights.
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        # For each class, keep the best-matching of its k sub-centers.
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class HeadV2(nn.Module):
    """Head that returns ArcFace logits plus the L2-normalized embedding."""

    def __init__(self, hidden_size, k=3):
        super(HeadV2, self).__init__()
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)

    def forward(self, x):
        output = self.arc(x)
        return output, F.normalize(x)


class Ranker:
    """Loads the fine-tuned CLIP model and produces normalized query embeddings."""

    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
        self.model = Clip_Products(backbone, 768, 3)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

    def predict(self, img):
        transform_img = get_final_transform()
        query = transform_img(img)
        with torch.no_grad():
            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)
        query_embeddings = embeddings.detach().cpu().numpy()
        return normalize(query_embeddings)
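

# Example usage (a minimal sketch, not part of the original file): it assumes the
# checkpoint "model/best_model.pt" is present and that "query.jpg" is an illustrative
# query image path supplied by the caller.
#
#     from PIL import Image
#
#     ranker = Ranker()
#     img = Image.open("query.jpg").convert("RGB")
#     embedding = ranker.predict(img)   # numpy array of shape (1, 768), L2-normalized
#
# Because the returned embedding is unit-normalized, it can be ranked against a
# gallery of product embeddings with a simple dot product (cosine similarity).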