aps commited on
Commit
73d70e7
·
0 Parent(s):

Init commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.npy filter=lfs diff=lfs merge=lfs -text
29
+ embeddings-vit-base-patch32.npy filter=lfs diff=lfs merge=lfs -text
30
+ embeddings-vit-large-patch14-336.npy filter=lfs diff=lfs merge=lfs -text
31
+ embeddings-vit-large-patch14.npy filter=lfs diff=lfs merge=lfs -text
32
+ embeddings2-vit-base-patch32.npy filter=lfs diff=lfs merge=lfs -text
33
+ embeddings2-vit-large-patch14-336.npy filter=lfs diff=lfs merge=lfs -text
34
+ embeddings2-vit-large-patch14.npy filter=lfs diff=lfs merge=lfs -text
35
+ embeddings-vit-base-patch16.npy filter=lfs diff=lfs merge=lfs -text
36
+ embeddings2-flava-full.npy filter=lfs diff=lfs merge=lfs -text
37
+ embeddings2-vit-base-patch16.npy filter=lfs diff=lfs merge=lfs -text
38
+ embeddings-flava-full.npy filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode/
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FLAVA Semantic Image Text Search Demo
3
+ emoji: 👁
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.2.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio` or `streamlit`
28
+
29
+ `sdk_version` : _string_
30
+ Only applicable for `streamlit` SDK.
31
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
+
33
+ `app_file`: _string_
34
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
35
+ Path is relative to the root of the repository.
36
+
37
+ `pinned`: _boolean_
38
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import re
3
+ import streamlit as st
4
+ import pandas as pd, numpy as np
5
+ from transformers import CLIPProcessor, CLIPModel, FlavaModel, FlavaProcessor
6
+ from st_clickable_images import clickable_images
7
+
8
+ MODEL_NAMES = ["flava-full", "vit-base-patch32", "vit-base-patch16", "vit-large-patch14", "vit-large-patch14-336"]
9
+
10
+
11
+ @st.cache(allow_output_mutation=True)
12
+ def load():
13
+ df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
14
+ models = {}
15
+ processors = {}
16
+ embeddings = {}
17
+ for name in MODEL_NAMES:
18
+ if "flava" not in name:
19
+ model = CLIPModel
20
+ processor = CLIPProcessor
21
+ prefix = "openai/clip-"
22
+ else:
23
+ model = FlavaModel
24
+ processor = FlavaProcessor
25
+ prefix = "facebook/"
26
+ models[name] = model.from_pretrained(f"{prefix}{name}")
27
+ processors[name] = processor.from_pretrained(f"{prefix}{name}")
28
+ embeddings[name] = {
29
+ 0: np.load(f"embeddings-{name}.npy"),
30
+ 1: np.load(f"embeddings2-{name}.npy"),
31
+ }
32
+ for k in [0, 1]:
33
+ embeddings[name][k] = embeddings[name][k] / np.linalg.norm(
34
+ embeddings[name][k], axis=1, keepdims=True
35
+ )
36
+ return models, processors, df, embeddings
37
+
38
+
39
+ models, processors, df, embeddings = load()
40
+ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
41
+
42
+
43
+ def compute_text_embeddings(list_of_strings, name):
44
+ inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
45
+ result = models[name].get_text_features(**inputs)
46
+ if "flava" in name:
47
+ result = result[:, 0, :]
48
+ result = result.detach().numpy()
49
+ return result / np.linalg.norm(result, axis=1, keepdims=True)
50
+
51
+
52
+ def image_search(query, corpus, name, n_results=24):
53
+ positive_embeddings = None
54
+
55
+ def concatenate_embeddings(e1, e2):
56
+ if e1 is None:
57
+ return e2
58
+ else:
59
+ return np.concatenate((e1, e2), axis=0)
60
+
61
+ splitted_query = query.split("EXCLUDING ")
62
+ dot_product = 0
63
+ k = 0 if corpus == "Unsplash" else 1
64
+ if len(splitted_query[0]) > 0:
65
+ positive_queries = splitted_query[0].split(";")
66
+ for positive_query in positive_queries:
67
+ match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
68
+ if match:
69
+ corpus2, idx, remainder = match.groups()
70
+ idx, remainder = int(idx), remainder.strip()
71
+ k2 = 0 if corpus2 == "Unsplash" else 1
72
+ positive_embeddings = concatenate_embeddings(
73
+ positive_embeddings, embeddings[name][k2][idx : idx + 1, :]
74
+ )
75
+ if len(remainder) > 0:
76
+ positive_embeddings = concatenate_embeddings(
77
+ positive_embeddings, compute_text_embeddings([remainder], name)
78
+ )
79
+ else:
80
+ positive_embeddings = concatenate_embeddings(
81
+ positive_embeddings, compute_text_embeddings([positive_query], name)
82
+ )
83
+ dot_product = embeddings[name][k] @ positive_embeddings.T
84
+ dot_product = dot_product - np.median(dot_product, axis=0)
85
+ dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
86
+ dot_product = np.min(dot_product, axis=1)
87
+
88
+ if len(splitted_query) > 1:
89
+ negative_queries = (" ".join(splitted_query[1:])).split(";")
90
+ negative_embeddings = compute_text_embeddings(negative_queries, name)
91
+ dot_product2 = embeddings[name][k] @ negative_embeddings.T
92
+ dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
93
+ dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
94
+ dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
95
+
96
+ results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
97
+ return [
98
+ (
99
+ df[k].iloc[i]["path"],
100
+ df[k].iloc[i]["tooltip"] + source[k],
101
+ i,
102
+ )
103
+ for i in results
104
+ ]
105
+
106
+
107
+ description = """
108
+ # FLAVA Semantic Image-Text Search
109
+ """
110
+ instruction= """
111
+ **Enter your query and hit enter**
112
+ """
113
+
114
+ credit = """
115
+ *Built with FAIR's [FLAVA](https://arxiv.org/abs/2112.04482) models, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
116
+
117
+ *Forked and inspired from a similar app available [here](https://huggingface.co/spaces/vivien/clip/)*
118
+ """
119
+
120
+ options = """
121
+ ## Compare
122
+ Check results for a single model or compare two models by using the dropdown below:
123
+ """
124
+
125
+ howto = """
126
+ ## Advanced Use
127
+ - Click on an image to use it as a query and find similar images
128
+ - Several queries, including one based on an image, can be combined (use "**;**" as a separator).
129
+ - Try "sunset at beach; small children".
130
+ - If the input includes "**EXCLUDING**", text following it will be used as a negative query.
131
+ - Try "a busy city street with dogs" and "a busy city street EXCLUDING dogs".
132
+ """
133
+
134
+ div_style = {
135
+ "display": "flex",
136
+ "justify-content": "center",
137
+ "flex-wrap": "wrap",
138
+ }
139
+
140
+
141
+ def main():
142
+ st.markdown(
143
+ """
144
+ <style>
145
+ .block-container{
146
+ max-width: 1200px;
147
+ }
148
+ div.row-widget.stRadio > div{
149
+ flex-direction:row;
150
+ display: flex;
151
+ justify-content: center;
152
+ }
153
+ div.row-widget.stRadio > div > label{
154
+ margin-left: 5px;
155
+ margin-right: 5px;
156
+ }
157
+ .row-widget {
158
+ margin-top: -25px;
159
+ }
160
+ section>div:first-child {
161
+ padding-top: 30px;
162
+ }
163
+ div.reportview-container > section:first-child{
164
+ max-width: 320px;
165
+ }
166
+ #MainMenu {
167
+ visibility: hidden;
168
+ }
169
+ footer {
170
+ visibility: hidden;
171
+ }
172
+ </style>""",
173
+ unsafe_allow_html=True,
174
+ )
175
+
176
+ st.sidebar.markdown(description)
177
+ st.sidebar.markdown(options)
178
+ mode = st.sidebar.selectbox(
179
+ "", ["Results for FLAVA full", "Comparison of 2 models"], index=0
180
+ )
181
+ st.sidebar.markdown(howto)
182
+ st.sidebar.markdown(credit)
183
+ _, c, _ = st.columns((1, 3, 1))
184
+ c.markdown(instruction)
185
+ if "query" in st.session_state:
186
+ query = c.text_input("", value=st.session_state["query"])
187
+ else:
188
+ query = c.text_input("", value="a busy city with tall buildings")
189
+ corpus = st.radio("", ["Unsplash", "Movies"])
190
+
191
+ models_dict = {
192
+ "FLAVA": "flava-full",
193
+ "ViT-B/32 (quickest)": "vit-base-patch32",
194
+ "ViT-B/16 (quick)": "vit-base-patch16",
195
+ "ViT-L/14 (slow)": "vit-large-patch14",
196
+ "ViT-L/14@336px (slowest)": "vit-large-patch14-336",
197
+ }
198
+
199
+ if "Comparison" in mode:
200
+ c1, c2 = st.columns((1, 1))
201
+ selection1 = c1.selectbox("", models_dict.keys(), index=0)
202
+ selection2 = c2.selectbox("", models_dict.keys(), index=3)
203
+ name1 = models_dict[selection1]
204
+ name2 = models_dict[selection2]
205
+ else:
206
+ name1 = MODEL_NAMES[0]
207
+
208
+ if len(query) > 0:
209
+ results1 = image_search(query, corpus, name1)
210
+ if "Comparison" in mode:
211
+ with c1:
212
+ clicked1 = clickable_images(
213
+ [result[0] for result in results1],
214
+ titles=[result[1] for result in results1],
215
+ div_style=div_style,
216
+ img_style={"margin": "2px", "height": "150px"},
217
+ key=query + corpus + name1 + "1",
218
+ )
219
+ results2 = image_search(query, corpus, name2)
220
+ with c2:
221
+ clicked2 = clickable_images(
222
+ [result[0] for result in results2],
223
+ titles=[result[1] for result in results2],
224
+ div_style=div_style,
225
+ img_style={"margin": "2px", "height": "150px"},
226
+ key=query + corpus + name2 + "2",
227
+ )
228
+ else:
229
+ clicked1 = clickable_images(
230
+ [result[0] for result in results1],
231
+ titles=[result[1] for result in results1],
232
+ div_style=div_style,
233
+ img_style={"margin": "2px", "height": "200px"},
234
+ key=query + corpus + name1 + "1",
235
+ )
236
+ clicked2 = -1
237
+
238
+ if clicked2 >= 0 or clicked1 >= 0:
239
+ change_query = False
240
+ if "last_clicked" not in st.session_state:
241
+ change_query = True
242
+ else:
243
+ if max(clicked2, clicked1) != st.session_state["last_clicked"]:
244
+ change_query = True
245
+ if change_query:
246
+ if clicked1 >= 0:
247
+ st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
248
+ elif clicked2 >= 0:
249
+ st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
250
+ st.experimental_rerun()
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data2.csv ADDED
The diff for this file is too large to render. See raw diff
 
embeddings-flava-full.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17f7b7a1f297f314f3728eb50e16a18780263fa9ec99b8286c58c5fb4b6853df
3
+ size 153354368
embeddings-vit-base-patch16.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:125430e11a4a415ec0c0fc5339f97544f0447e4b0a24c20f2e59f8852e706afc
3
+ size 51200128
embeddings-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f7ebdff24079665faf58d07045056a63b5499753e3ffbda479691d53de3ab38
3
+ size 51200128
embeddings-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f79f10ebe267b4ee7acd553dfe0ee31df846123630058a6d58c04bf22e0ad068
3
+ size 76800128
embeddings-vit-large-patch14.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64515f7d3d71137e2944f2c3d72c8df3e684b5d6a6ff7dcebb92370f7326ccfd
3
+ size 76800128
embeddings2-flava-full.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:089b694a9552c65f3fdf81a0d41df299bb00cf199ab0b59fe4dc7ac0ba5e0c31
3
+ size 49545344
embeddings2-vit-base-patch16.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:153cf3fae2385d51fe8729d3a1c059f611ca47a3fc501049708114d1bbf79049
3
+ size 16732288
embeddings2-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7d545bed86121dac1cedcc1de61ea5295f5840c1eb751637e6628ac54faef81
3
+ size 16732288
embeddings2-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e66eb377465fbfaa56cec079aa3e214533ceac43646f2ca78028ae4d8ad6d03
3
+ size 25098368
embeddings2-vit-large-patch14.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d730b33e758c2648419a96ac86d39516c59795e613c35700d3a64079e5a9a27
3
+ size 25098368
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ ftfy
4
+ numpy
5
+ pandas
6
+ st-clickable-images