yusufani committed
Commit
94078d1
1 Parent(s): 3789cfb

Initial Release

Files changed (5)
  1. .gitignore +5 -0
  2. README.md +7 -6
  3. app.py +263 -0
  4. not-found.png +0 -0
  5. requirements.txt +24 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ TrCaption/*
+ trclip-vitl14-e10/*
+ TrCaption-trclip-vitl14-e10/*
+ TrCaption-trclip-vitl14-e10-old/*
+ .idea/*
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: TrCLIP
- emoji: 📉
- colorFrom: pink
- colorTo: yellow
+ title: Trclip
+ emoji: 📈
+ colorFrom: indigo
+ colorTo: purple
  sdk: gradio
- sdk_version: 3.1.7
+ sdk_version: 3.0.20
  app_file: app.py
  pinned: false
+ license: afl-3.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,263 @@
+ # Importing all the necessary libraries
+ import os
+
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ from trclip.trclip import Trclip
+ from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize
+
+ print(f'gr version : {gr.__version__}')
+ import pickle
+ import random
+
+ # %%
+ model_name = 'trclip-vitl14-e10'
+ if not os.path.exists(model_name):
+     os.system(f'git clone https://huggingface.co/yusufani/{model_name} --progress')
+ # %%
+ if not os.path.exists('TrCaption-trclip-vitl14-e10'):
+     os.system('git clone https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
+     os.chdir('TrCaption-trclip-vitl14-e10')
+     os.system('git lfs install')
+     os.system('git lfs fetch')
+     os.system('git lfs pull')
+     os.chdir('..')
+
+
+ # %%
+
+ def load_image_embeddings():
+     path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
+     bs = 100_000
+     embeddings = []
+
+     for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
+         with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
+             embeddings.append(pickle.load(f))
+     return torch.cat(embeddings, dim=0)
+
+ def load_text_embeddings():
+     path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
+     bs = 100_000
+     embeddings = []
+     for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
+         with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
+             embeddings.append(pickle.load(f))
+     return torch.cat(embeddings, dim=0)
+
+
+ def load_metadata():
+     path = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
+     with open(path, 'rb') as f:
+         metadata = pickle.load(f)
+     trcap_texts = metadata['texts']
+     trcap_urls = metadata['image_urls']
+     return trcap_texts, trcap_urls
+
+ def load_spesific_tensor(index, type, bs=100_000):
+     # Embeddings are stored in chunks of `bs` vectors per pickle file; locate the chunk, then the row.
+     part = index // bs
+     idx = index % bs
+     with open(os.path.join('TrCaption-trclip-vitl14-e10', f'{type}_embeddings', f'{type}_em_{part*bs}.pkl'), 'rb') as f:
+         embeddings = pickle.load(f)
+     return embeddings[idx]
+
+ # %%
+
+ image_embeddings = None
+ text_embeddings = None
+
+ # %%
+ trcap_texts, trcap_urls = load_metadata()
+ # %%
+ model_path = os.path.join(model_name, 'pytorch_model.bin')
+ trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
+ # %%
+ import psutil
+
+ print(f"First used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+ # %%
+
+ def run_im(im1, use_trcap_images, text1, use_trcap_texts):
+     f_texts_embeddings = None
+     f_image_embeddings = None
+     global image_embeddings
+     global text_embeddings
+     ims = None
+     print("im2", use_trcap_images)
+     if use_trcap_images:
+         print('TRCaption images used')
+         # Images taken from TrCaption
+         im_paths = trcap_urls
+         if image_embeddings is None:
+             print(f"Used memory before loading image embeddings {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+             text_embeddings = None
+             image_embeddings = load_image_embeddings()
+             print(f"Used memory after loading image embeddings {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+         f_image_embeddings = image_embeddings
+     else:
+         # Images taken from the user
+         im_paths = [i.name for i in im1]
+         ims = [Image.open(i) for i in im_paths]
+     if use_trcap_texts:
+         random_indexes = random.sample(range(len(trcap_texts)), 2)  # At most 2 texts are used in image retrieval (UI limit)
+         f_texts_embeddings = []
+         for i in random_indexes:
+             f_texts_embeddings.append(load_spesific_tensor(i, 'text'))
+         f_texts_embeddings = torch.stack(f_texts_embeddings)
+         texts = [trcap_texts[i] for i in random_indexes]
+     else:
+         texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']
+
+     per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_text')
+
+     print(f'per_mode_indices = {per_mode_indices}\nper_mode_probs = {per_mode_probs}')
+     print(f'im_paths = {im_paths}')
+     return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
+                                      n_figure_in_column=2,
+                                      n_images_in_figure=4, n_figure_in_row=1, save_fig=False,
+                                      show=False,
+                                      break_on_index=-1)
+
+
+ def run_text(im1, use_trcap_images, text1, use_trcap_texts):
+     f_texts_embeddings = None
+     f_image_embeddings = None
+     global image_embeddings
+     global text_embeddings
+     ims = None
+     if use_trcap_images:
+         random_indexes = random.sample(range(len(trcap_urls)), 2)  # At most 2 images are used in text retrieval (UI limit)
+         f_image_embeddings = []
+         for i in random_indexes:
+             f_image_embeddings.append(load_spesific_tensor(i, 'image'))
+         f_image_embeddings = torch.stack(f_image_embeddings)
+         print('TRCaption images used')
+         # Images taken from TrCaption
+         im_paths = [trcap_urls[i] for i in random_indexes]
+     else:
+         # Images taken from the user
+         im_paths = [i.name for i in im1[:2]]
+         ims = [Image.open(i) for i in im_paths]
+
+     if use_trcap_texts:
+         if text_embeddings is None:
+             print(f"Used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+             image_embeddings = None
+             print(f"Image embeddings deleted, used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+             text_embeddings = load_text_embeddings()
+             print(f"Text embeddings loaded, used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB")
+
+         f_texts_embeddings = text_embeddings
+         texts = trcap_texts
+     else:
+         texts = [i.strip() for i in text1.split('\n') if i.strip() != '']
+
+     per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, image_features=f_image_embeddings, text_features=f_texts_embeddings, mode='per_image')
+     print(per_mode_indices)
+     print(per_mode_probs)
+     return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
+                                     n_figure_in_column=4,
+                                     n_texts_in_figure=4 if len(texts) > 4 else len(texts),
+                                     n_figure_in_row=2,
+                                     save_fig=False,
+                                     show=False,
+                                     break_on_index=-1,
+                                     )
+
+
+ def change_textbox(choice):
+     if choice == "Use Own Images":
+         return gr.Image.update(visible=True)
+     else:
+         return gr.Image.update(visible=False)
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""
+         <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+           <div
+             style="
+               display: inline-flex;
+               align-items: center;
+               gap: 0.8rem;
+               font-size: 1.75rem;
+             "
+           >
+             <svg
+               width="0.65em"
+               height="0.65em"
+               viewBox="0 0 115 115"
+               fill="none"
+               xmlns="http://www.w3.org/2000/svg"
+             >
+               <rect width="23" height="23" fill="white"></rect>
+               <rect y="69" width="23" height="23" fill="white"></rect>
+               <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="46" width="23" height="23" fill="white"></rect>
+               <rect x="46" y="69" width="23" height="23" fill="white"></rect>
+               <rect x="69" width="23" height="23" fill="black"></rect>
+               <rect x="69" y="69" width="23" height="23" fill="black"></rect>
+               <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
+               <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="115" y="46" width="23" height="23" fill="white"></rect>
+               <rect x="115" y="115" width="23" height="23" fill="white"></rect>
+               <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
+               <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="92" y="69" width="23" height="23" fill="white"></rect>
+               <rect x="69" y="46" width="23" height="23" fill="white"></rect>
+               <rect x="69" y="115" width="23" height="23" fill="white"></rect>
+               <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
+               <rect x="46" y="46" width="23" height="23" fill="black"></rect>
+               <rect x="46" y="115" width="23" height="23" fill="black"></rect>
+               <rect x="46" y="69" width="23" height="23" fill="black"></rect>
+               <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
+               <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
+               <rect x="23" y="69" width="23" height="23" fill="black"></rect>
+             </svg>
+             <h1 style="font-weight: 900; margin-bottom: 7px;">
+               Trclip Demo
+               <a
+                 href="https://github.com/yusufani/TrCLIP"
+                 style="text-decoration: underline;"
+                 target="_blank"
+               >Github: Trclip</a>
+             </h1>
+           </div>
+           <p style="margin-bottom: 10px; font-size: 94%">
+             Trclip is a Turkish port of the original CLIP. In this space you can try your own images and/or texts,
+             or use the pre-calculated TrCaption embeddings.
+             Number of texts = 3533312
+             Number of images = 3070976
+           </p>
+         </div>
+     """)
+
+     with gr.Tabs():
+         with gr.TabItem("Use Own Images"):
+             im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
+             is_trcap_ims = gr.Checkbox(label="Use TrCaption Images\nNote: 2 random samples are selected in text retrieval mode")
+
+     with gr.Tabs():
+         with gr.TabItem("Input texts (separated by new lines, max 2 for image retrieval)"):
+             text_input = gr.components.Textbox(label="Text input", optional=True)
+             is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions\nNote: 2 random samples are selected in image retrieval mode")
+
+     im_ret_but = gr.Button("Image Retrieval")
+     text_ret_but = gr.Button("Text Retrieval")
+
+     im_out = gr.components.Image()
+
+     im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
+     text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
+
+ demo.launch()
+
+ # %%
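
As a reference for the TrCaption embedding layout used by app.py above: each pickle holds 100_000 vectors and is named {type}_em_{offset}.pkl, so a single embedding is found by simple chunk arithmetic, as in load_spesific_tensor. A minimal sketch (the helper name locate_embedding is illustrative, not part of the repo):

    import os

    def locate_embedding(index, kind, root='TrCaption-trclip-vitl14-e10', bs=100_000):
        offset = (index // bs) * bs   # first index stored in the containing pickle
        row = index % bs              # position inside that pickle
        return os.path.join(root, f'{kind}_embeddings', f'{kind}_em_{offset}.pkl'), row

    # Example: text index 3_250_041 resolves to text_em_3200000.pkl, row 50041.
    print(locate_embedding(3_250_041, 'text'))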
not-found.png ADDED
requirements.txt ADDED
@@ -0,0 +1,24 @@
+ ftfy
+ regex
+ tqdm
+ omegaconf
+ pytorch-lightning
+ kornia
+ imageio-ffmpeg
+ einops
+ torch
+ torchvision
+ Pillow
+ numpy
+ imageio
+ trclip
+ torch>= 0.7
+ transformers>=4
+ numpy>=1.20
+ git+https://github.com/openai/CLIP.git
+ tqdm
+ more_itertools
+ cairosvg
+ gradio==3.0.19
+ gdown
+ psutil
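
For a quick local check outside the Space, a minimal smoke test (assuming the model repository has been cloned as app.py does and trclip is installed from the requirements above) might be:

    import os
    from trclip.trclip import Trclip

    model_name = 'trclip-vitl14-e10'
    model_path = os.path.join(model_name, 'pytorch_model.bin')
    assert os.path.exists(model_path), 'clone https://huggingface.co/yusufani/trclip-vitl14-e10 first'

    # Same constructor call app.py makes at startup; CPU only.
    trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
    print('Trclip checkpoint loaded')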