yurii_l committed
Commit 4bb166c · Parent: f334b80

uploaded baseline app

Files changed (5)
  1. .gitattributes +1 -0
  2. app.py +95 -0
  3. labels.csv +0 -0
  4. model_prediction.py +88 -0
  5. requirements.txt +8 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+embeddings.index filter=lfs diff=lfs merge=lfs -text
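
The new rule tracks embeddings.index, the FAISS index that app.py reads at startup, with Git LFS. The script that builds it is not part of this commit; below is a minimal sketch of how such an index could be produced, assuming a precomputed (N, 768) float32 matrix of L2-normalized gallery embeddings whose row order matches labels.csv (the file name gallery_embeddings.npy and the IndexFlatIP choice are assumptions):

    # Hypothetical build step for embeddings.index -- not included in this commit.
    import faiss
    import numpy as np

    # Assumed: (N, 768) float32 matrix of L2-normalized gallery embeddings,
    # row-aligned with labels.csv so search positions map onto "path" rows.
    embeddings = np.load("gallery_embeddings.npy").astype("float32")

    index = faiss.IndexFlatIP(embeddings.shape[1])  # exact inner-product (= cosine) search
    index.add(embeddings)
    faiss.write_index(index, "embeddings.index")    # the LFS-tracked file above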
app.py CHANGED
@@ -0,0 +1,95 @@
+import os
+import boto3
+import streamlit as st
+import faiss
+import pandas as pd
+from PIL import Image
+from model_prediction import Ranker
+from io import BytesIO
+
+
+@st.cache
+def load_model():
+    return Ranker()
+
+
+def load_faiss_index():
+    return faiss.read_index('embeddings.index')
+
+
+def load_labels():
+    return pd.read_csv("labels.csv")
+
+
+class ModelLoader:
+    """Lazily loads and memoizes the model, FAISS index, and label table."""
+
+    model = None
+    index = None
+    labels = None
+
+    @classmethod
+    def get_model(cls):
+        if cls.model is None:
+            cls.model = load_model()
+        return cls.model
+
+    @classmethod
+    def get_index(cls):
+        if cls.index is None:
+            cls.index = load_faiss_index()
+        return cls.index
+
+    @classmethod
+    def get_labels(cls):
+        if cls.labels is None:
+            cls.labels = load_labels()
+        return cls.labels
+
+
+target_size = (224, 224)
+st.set_page_config(page_title="Product Retrieval App")
+st.title("Product Retrieval App")
+
+uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+loading_text = st.empty()
+
+# Credentials are read from the environment rather than hard-coded in the source.
+s3 = boto3.client(
+    's3',
+    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
+    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
+    region_name='eu-west-1'
+)
+
+bucket_name = "product-retrieval"
+
+if uploaded_file is not None:
+    # Force RGB: PNG uploads may carry an alpha channel the transform can't handle.
+    image = Image.open(uploaded_file).convert("RGB")
+    st.image(image, caption="Uploaded image", use_column_width=True)
+
+    loading_text.text("Loading predictions...")
+
+    model = ModelLoader.get_model()
+    index = ModelLoader.get_index()
+    labels = ModelLoader.get_labels()
+
+    # Embed the query and retrieve its 12 nearest neighbours from the index.
+    image_embedding = model.predict(image)
+    distances, indices = index.search(image_embedding, 12)
+    predicted_images = labels["path"].iloc[indices[0]].to_list()
+    loading_text.empty()
+
+    columns = st.columns(4)
+
+    for i, img_path in enumerate(predicted_images):
+        # Fetch each result from S3 by file name and lay it out in a 4-wide grid.
+        response = s3.get_object(Bucket=bucket_name, Key=img_path.split("/")[-1])
+        image_data = response['Body'].read()
+        img = Image.open(BytesIO(image_data)).resize(target_size)
+
+        with columns[i % 4]:
+            st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)
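
Hard-coding AWS keys in app.py would leak them with the repository, so the client above pulls them from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY in the environment. On Streamlit deployments the secrets store is a common alternative; a sketch, where the [aws] table and key names in .streamlit/secrets.toml are assumptions:

    # Alternative wiring via st.secrets; the "aws" table name is hypothetical.
    import boto3
    import streamlit as st

    s3 = boto3.client(
        's3',
        aws_access_key_id=st.secrets["aws"]["access_key_id"],
        aws_secret_access_key=st.secrets["aws"]["secret_access_key"],
        region_name=st.secrets["aws"]["region"],
    )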
labels.csv ADDED
The diff for this file is too large to render. See raw diff
 
model_prediction.py ADDED
@@ -0,0 +1,88 @@
+from sklearn.preprocessing import normalize
+import torchvision.transforms as T
+import open_clip
+import torch
+import math
+from torch import nn
+import torch.nn.functional as F
+
+
+def get_final_transform():
+    # CLIP preprocessing: 224x224 bicubic resize plus CLIP's RGB mean/std.
+    final_transform = T.Compose([
+        T.Resize(
+            size=(224, 224),
+            interpolation=T.InterpolationMode.BICUBIC,
+            antialias=True),
+        T.ToTensor(),
+        T.Normalize(
+            mean=(0.48145466, 0.4578275, 0.40821073),
+            std=(0.26862954, 0.26130258, 0.27577711)
+        )
+    ])
+    return final_transform
+
+
+class Clip_Products(nn.Module):
+    """CLIP visual encoder followed by a sub-center ArcFace head."""
+
+    def __init__(self, vit_backbone, head_size, k=3):
+        super(Clip_Products, self).__init__()
+        self.head = HeadV2(head_size, k)
+        self.encoder = vit_backbone.visual
+
+    def forward(self, x):
+        x = self.encoder(x)
+        return self.head(x)
+
+
+class ArcMarginProduct_subcenter(nn.Module):
+    """Sub-center ArcFace margin product: keeps the best of k centres per class."""
+
+    def __init__(self, in_features, out_features, k=3):
+        super().__init__()
+        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
+        self.reset_parameters()
+        self.k = k
+        self.out_features = out_features
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+
+    def forward(self, features):
+        # Cosine similarity of the features against all k centres of every class,
+        # then keep only the closest centre per class.
+        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
+        cosine_all = cosine_all.view(-1, self.out_features, self.k)
+        cosine, _ = torch.max(cosine_all, dim=2)
+        return cosine
+
+
+class HeadV2(nn.Module):
+    def __init__(self, hidden_size, k=3):
+        super(HeadV2, self).__init__()
+        self.arc = ArcMarginProduct_subcenter(hidden_size, 9691, k)
+
+    def forward(self, x):
+        output = self.arc(x)
+        # Return class logits and the L2-normalized embedding.
+        return output, F.normalize(x)
+
+
+class Ranker:
+    def __init__(self):
+        self.model_path = "model/best_model.pt"
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # ViT-L/14 backbone with randomly initialized weights; the fine-tuned
+        # checkpoint loaded below overwrites them.
+        backbone, _, _ = open_clip.create_model_and_transforms('ViT-L-14', None)
+        self.model = Clip_Products(backbone, 768, 3)
+
+        checkpoint = torch.load(self.model_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.model.to(self.device)
+
+    def predict(self, img):
+        transform_img = get_final_transform()
+        query = transform_img(img)
+
+        self.model.eval()
+        with torch.no_grad():
+            images = query.to(self.device, dtype=torch.float).unsqueeze(0)
+            _, embeddings = self.model(images)
+
+        query_embeddings = embeddings.detach().cpu().numpy()
+        return normalize(query_embeddings)
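
Ranker.predict maps a PIL image to one L2-normalized embedding row, the shape that index.search in app.py expects. A quick smoke test, with the image path as a placeholder:

    # Smoke test for Ranker; "query.jpg" is a placeholder path.
    from PIL import Image
    from model_prediction import Ranker

    ranker = Ranker()  # loads model/best_model.pt onto CPU or GPU
    embedding = ranker.predict(Image.open("query.jpg").convert("RGB"))
    print(embedding.shape)  # expected: (1, 768) for the ViT-L-14 backbone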
requirements.txt ADDED
@@ -0,0 +1,8 @@
+boto3
+faiss-cpu
+pandas
+Pillow
+scikit-learn
+torchvision
+torch
+open_clip_torch