yurii_l committed · Commit 4bb166c
1 Parent(s): f334b80
uploaded baseline app

Files changed:
- .gitattributes +1 -0
- app.py +95 -0
- labels.csv +0 -0
- model_prediction.py +88 -0
- requirements.txt +8 -0
.gitattributes
CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+embeddings.index filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -0,0 +1,95 @@
import os
import boto3
import streamlit as st
import faiss
import pandas as pd
from PIL import Image
from model_prediction import Ranker
from io import BytesIO


@st.cache
def load_model():
    return Ranker()


def load_faiss_index():
    return faiss.read_index('embeddings.index')


def load_labels():
    return pd.read_csv("labels.csv")


class ModelLoader:
    # Process-level cache so the model, FAISS index and label table are loaded only once.
    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels


target_size = (224, 224)
st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
loading_text = st.empty()

# AWS credentials are read from the environment rather than hard-coded in the source.
s3 = boto3.client(
    's3',
    aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
    aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
    region_name='eu-west-1'
)

bucket_name = "product-retrieval"

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")  # force 3 channels (PNG uploads may be RGBA)
    st.image(image, caption="Uploaded image", use_column_width=True)

    loading_text.text("Loading predictions...")

    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()

    # Embed the query image and retrieve the 12 nearest catalog items from the FAISS index.
    image_embedding = model.predict(image)
    distances, indices = index.search(image_embedding, 12)
    predicted_images = labels["path"][indices[0]].to_list()
    loading_text.empty()

    col1, col2, col3, col4 = st.columns(4)

    # Fetch each retrieved image from S3 and display the results in a 4-column grid.
    for i, img_path in enumerate(predicted_images):
        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
        image_data = response['Body'].read()
        img = Image.open(BytesIO(image_data)).resize(target_size)

        if i % 4 == 0:
            column = col1
        elif i % 4 == 1:
            column = col2
        elif i % 4 == 2:
            column = col3
        else:
            column = col4
        with column:
            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)
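Note: app.py expects two artifacts that this commit does not generate, embeddings.index (the FAISS index of catalog embeddings, tracked via LFS per the .gitattributes change above) and labels.csv (which must expose a "path" column mapping index positions back to image paths). Below is a minimal offline sketch of how they could be built with the Ranker from model_prediction.py; the script name build_index.py, the catalog_images/ folder and the use of IndexFlatL2 are illustrative assumptions, not part of the repository.

# build_index.py -- illustrative sketch, not part of this commit.
import glob

import faiss
import numpy as np
import pandas as pd
from PIL import Image

from model_prediction import Ranker

ranker = Ranker()
paths = sorted(glob.glob("catalog_images/*.jpg"))  # hypothetical catalog location

# Embed every catalog image with the same model the app uses at query time.
embeddings = np.vstack([ranker.predict(Image.open(p).convert("RGB")) for p in paths])

# app.py ranks by plain L2 distance; because the embeddings are L2-normalized,
# this ordering is equivalent to cosine similarity.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings.astype(np.float32))
faiss.write_index(index, "embeddings.index")

# app.py looks up labels["path"][indices[0]] and uses the last path component
# as the S3 object key when fetching result images.
pd.DataFrame({"path": paths}).to_csv("labels.csv", index=False)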
labels.csv
ADDED
The diff for this file is too large to render.
See raw diff
model_prediction.py
ADDED
@@ -0,0 +1,88 @@
from sklearn.preprocessing import normalize
import torchvision.transforms as T
import open_clip
import torch
import math
from torch import nn
import torch.nn.functional as F


def get_final_transform():
    # CLIP-style preprocessing: 224x224 bicubic resize followed by CLIP's normalization statistics.
    final_transform = T.Compose([
        T.Resize(
            size=(224, 224),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711)
        )
    ])
    return final_transform


class Clip_Products(nn.Module):
    # CLIP visual encoder followed by a metric-learning head.
    def __init__(self, vit_backbone, head_size, k=3):
        super(Clip_Products, self).__init__()
        self.head = HeadV2(head_size, k)
        self.encoder = vit_backbone.visual

    def forward(self, x):
        x = self.encoder(x)
        return self.head(x)


class ArcMarginProduct_subcenter(nn.Module):
    # Sub-center ArcFace margin product: cosine similarity against k sub-centers per class,
    # reduced with a max over the sub-centers.
    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class HeadV2(nn.Module):
    # Returns cosine logits over 9691 classes plus the L2-normalized embedding used for retrieval.
    def __init__(self, hidden_size, k=3):
        super(HeadV2, self).__init__()
        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)

    def forward(self, x):
        output = self.arc(x)
        return output, F.normalize(x)


class Ranker:
    # Wraps the fine-tuned CLIP ViT-L/14 checkpoint and turns a PIL image into a normalized embedding.
    def __init__(self):
        self.model_path = "model/best_model.pt"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
        self.model = Clip_Products(backbone, 768, 3)

        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)

    def predict(self, img):
        transform_img = get_final_transform()
        query = transform_img(img)

        with torch.no_grad():
            self.model.eval()

            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
            _, embeddings = self.model(images)

        query_embeddings = embeddings.detach().cpu().numpy()
        return normalize(query_embeddings)
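For reference, a minimal usage sketch of Ranker outside the Streamlit app; query.jpg is a placeholder path, and the checkpoint model/best_model.pt must already exist locally.

from PIL import Image

from model_prediction import Ranker

ranker = Ranker()
query = Image.open("query.jpg").convert("RGB")  # placeholder query image

embedding = ranker.predict(query)
# head_size is 768 (the ViT-L/14 visual width), so this is a single
# L2-normalized vector of shape (1, 768), ready to pass to index.search().
print(embedding.shape)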
requirements.txt
ADDED
@@ -0,0 +1,8 @@
boto3
faiss-cpu
pandas
Pillow
scikit-learn
torchvision
torch
open_clip_torch