manishjaiswal commited on
Commit
3ed150b
·
1 Parent(s): 4a84a41

Create new file

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import re
3
+ import streamlit as st
4
+ import pandas as pd, numpy as np
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from st_clickable_images import clickable_images
7
+
8
+ @st.cache(
9
+ show_spinner=False,
10
+ hash_funcs={
11
+ CLIPModel: lambda _: None,
12
+ CLIPProcessor: lambda _: None,
13
+ dict: lambda _: None,
14
+ },
15
+ )
16
+ def load():
17
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
18
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
19
+ df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
20
+ embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
21
+ for k in [0, 1]:
22
+ embeddings[k] = embeddings[k] / np.linalg.norm(
23
+ embeddings[k], axis=1, keepdims=True
24
+ )
25
+ return model, processor, df, embeddings
26
+
27
+
28
+ model, processor, df, embeddings = load()
29
+ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
30
+
31
+
32
+ def compute_text_embeddings(list_of_strings):
33
+ inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
34
+ result = model.get_text_features(**inputs).detach().numpy()
35
+ return result / np.linalg.norm(result, axis=1, keepdims=True)
36
+
37
+
38
+ def image_search(query, corpus, n_results=24):
39
+ positive_embeddings = None
40
+
41
+ def concatenate_embeddings(e1, e2):
42
+ if e1 is None:
43
+ return e2
44
+ else:
45
+ return np.concatenate((e1, e2), axis=0)
46
+
47
+ splitted_query = query.split("EXCLUDING ")
48
+ dot_product = 0
49
+ k = 0 if corpus == "Unsplash" else 1
50
+ if len(splitted_query[0]) > 0:
51
+ positive_queries = splitted_query[0].split(";")
52
+ for positive_query in positive_queries:
53
+ match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
54
+ if match:
55
+ corpus2, idx, remainder = match.groups()
56
+ idx, remainder = int(idx), remainder.strip()
57
+ k2 = 0 if corpus2 == "Unsplash" else 1
58
+ positive_embeddings = concatenate_embeddings(
59
+ positive_embeddings, embeddings[k2][idx : idx + 1, :]
60
+ )
61
+ if len(remainder) > 0:
62
+ positive_embeddings = concatenate_embeddings(
63
+ positive_embeddings, compute_text_embeddings([remainder])
64
+ )
65
+ else:
66
+ positive_embeddings = concatenate_embeddings(
67
+ positive_embeddings, compute_text_embeddings([positive_query])
68
+ )
69
+ dot_product = embeddings[k] @ positive_embeddings.T
70
+ dot_product = dot_product - np.median(dot_product, axis=0)
71
+ dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
72
+ dot_product = np.min(dot_product, axis=1)
73
+
74
+ if len(splitted_query) > 1:
75
+ negative_queries = (" ".join(splitted_query[1:])).split(";")
76
+ negative_embeddings = compute_text_embeddings(negative_queries)
77
+ dot_product2 = embeddings[k] @ negative_embeddings.T
78
+ dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
79
+ dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
80
+ dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
81
+
82
+ results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
83
+ return [
84
+ (
85
+ df[k].iloc[i]["path"],
86
+ df[k].iloc[i]["tooltip"] + source[k],
87
+ i,
88
+ )
89
+ for i in results
90
+ ]
91
+
92
+
93
+ description = """
94
+ # Semantic image search
95
+ **Enter your query and hit enter**
96
+ """
97
+
98
+ howto = """
99
+ - Click image to find similar images
100
+ - Use "**;**" to combine multiple queries)
101
+ - Use "**EXCLUDING**", to exclude a query
102
+ """
103
+
104
+
105
+ def main():
106
+ st.markdown(
107
+ """
108
+ <style>
109
+ .block-container{
110
+ max-width: 1200px;
111
+ }
112
+ div.row-widget.stRadio > div{
113
+ flex-direction:row;
114
+ display: flex;
115
+ justify-content: center;
116
+ }
117
+ div.row-widget.stRadio > div > label{
118
+ margin-left: 5px;
119
+ margin-right: 5px;
120
+ }
121
+ section.main>div:first-child {
122
+ padding-top: 0px;
123
+ }
124
+ section:not(.main)>div:first-child {
125
+ padding-top: 30px;
126
+ }
127
+ div.reportview-container > section:first-child{
128
+ max-width: 320px;
129
+ }
130
+ #MainMenu {
131
+ visibility: hidden;
132
+ }
133
+ footer {
134
+ visibility: hidden;
135
+ }
136
+ </style>""",
137
+ unsafe_allow_html=True,
138
+ )
139
+ st.sidebar.markdown(description)
140
+ with st.sidebar.expander("Advanced use"):
141
+ st.markdown(howto)
142
+
143
+
144
+ st.sidebar.markdown(f"Unsplash has categories that match: backgrounds, photos, nature, iphone, etc")
145
+ st.sidebar.markdown(f"Unsplash images contain animals, apps, events, feelings, food, travel, nature, people, religion, sports, things, stock")
146
+ st.sidebar.markdown(f"Unsplash things include flag, tree, clock, money, tattoo, arrow, book, car, fireworks, ghost, health, kiss, dance, balloon, crown, eye, house, music, airplane, lighthouse, typewriter, toys")
147
+ st.sidebar.markdown(f"unsplash feelings include funny, heart, love, cool, congratulations, love, scary, cute, friendship, inspirational, hug, sad, cursed, beautiful, crazy, respect, transformation, peaceful, happy")
148
+ st.sidebar.markdown(f"unsplash people contain baby, life, women, family, girls, pregnancy, society, old people, musician, attractive, bohemian")
149
+ st.sidebar.markdown(f"imagenet queries include: photo of, photo of many, sculpture of, rendering of, graffiti of, tattoo of, embroidered, drawing of, plastic, black and white, painting, video game, doodle, origami, sketch, etc")
150
+
151
+
152
+ _, c, _ = st.columns((1, 3, 1))
153
+ if "query" in st.session_state:
154
+ query = c.text_input("", value=st.session_state["query"])
155
+ else:
156
+
157
+ query = c.text_input("", value="lighthouse")
158
+ corpus = st.radio("", ["Unsplash"])
159
+ #corpus = st.radio("", ["Unsplash", "Movies"])
160
+ if len(query) > 0:
161
+ results = image_search(query, corpus)
162
+ clicked = clickable_images(
163
+ [result[0] for result in results],
164
+ titles=[result[1] for result in results],
165
+ div_style={
166
+ "display": "flex",
167
+ "justify-content": "center",
168
+ "flex-wrap": "wrap",
169
+ },
170
+ img_style={"margin": "2px", "height": "200px"},
171
+ )
172
+ if clicked >= 0:
173
+ change_query = False
174
+ if "last_clicked" not in st.session_state:
175
+ change_query = True
176
+ else:
177
+ if clicked != st.session_state["last_clicked"]:
178
+ change_query = True
179
+ if change_query:
180
+ st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
181
+ st.experimental_rerun()
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()