yusufani committed
Commit 0fe83e2
Parent(s): 94078d1

Initial Release

Files changed (1): app.py (+93 -55)
app.py CHANGED
@@ -12,6 +12,8 @@ print(f'gr version : {gr.__version__}')
 import pickle
 import random

 # %%
 model_name = 'trclip-vitl14-e10'
 if not os.path.exists(model_name):
@@ -28,24 +30,37 @@ if not os.path.exists('TrCaption-trclip-vitl14-e10'):

 # %%

-def load_image_embeddings():
     path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
     bs = 100_000
-    embeddings = []
-    for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
-        with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
-            embeddings.append(pickle.load(f))
-    return torch.cat(embeddings, dim=0)

-def load_text_embeddings():
     path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
     bs = 100_000
-    embeddings = []
-    for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
-        with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
-            embeddings.append(pickle.load(f))
-    return torch.cat(embeddings, dim=0)


 def load_metadata():
@@ -56,61 +71,64 @@ def load_metadata():
     trcap_urls = metadata['image_urls']
     return trcap_texts, trcap_urls

-def load_spesific_tensor(index , type , bs= 100_000):
     part = index // bs
     idx = index % bs
-    with open(os.path.join('TrCaption-trclip-vitl14-e10', f'{type}_embeddings', f'{type}_em_{part*bs}.pkl'), 'rb') as f:
         embeddings = pickle.load(f)
     return embeddings[idx]

-# %%

-image_embeddings = None
-text_embeddings = None
-
-#%%
 trcap_texts, trcap_urls = load_metadata()
 # %%
 model_path = os.path.join(model_name, 'pytorch_model.bin')
 trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
-#%%
-import psutil
-
-print(f"First used memory {psutil.virtual_memory().used/float(1<<30):,.0f} GB" , )
 # %%

 def run_im(im1, use_trcap_images, text1, use_trcap_texts):
     f_texts_embeddings = None
-    f_image_embeddings = None
-    global image_embeddings
-    global text_embeddings
     ims = None
-    print("im2", use_trcap_images)
     if use_trcap_images:
-        print('TRCaption images used')
-        # Images taken from TRCAPTION
         im_paths = trcap_urls
-        if image_embeddings is None:
-            print(f"First used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
-            text_embeddings = None
-            image_embeddings = load_image_embeddings()
-            print(f"First used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
-        f_image_embeddings = image_embeddings
     else:
         # Images taken from user
         im_paths = [i.name for i in im1]
         ims = [Image.open(i) for i in im_paths]
     if use_trcap_texts:
         random_indexes = random.sample(range(len(trcap_texts)), 2) # MAX 2 text are allowed in image retrieval UI limit
         f_texts_embeddings = []
         for i in random_indexes:
             f_texts_embeddings.append(load_spesific_tensor(i, 'text'))
         f_texts_embeddings = torch.stack(f_texts_embeddings)
         texts = [trcap_texts[i] for i in random_indexes]
     else:
-        texts = [i.trim() for i in text1.split('\n')[:2] if i.trim() != '']

-    per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_text')

     print(f'per_mode_indices = {per_mode_indices}\n,per_mode_probs = {per_mode_probs} ')
     print(f'im_paths = {im_paths}')
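For reference, the load_spesific_tensor helper above turns a global TrCaption index into a shard file plus an offset inside it. A quick worked example of that arithmetic with the 100_000-embedding shards used here (illustrative only):

# index = 1_234_567, bs = 100_000
# part = 1_234_567 // 100_000 = 12     -> opens text_em_1200000.pkl (part * bs = 1_200_000)
# idx  = 1_234_567 %  100_000 = 34_567 -> row 34_567 of that shard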
@@ -122,39 +140,45 @@ def run_im(im1, use_trcap_images, text1, use_trcap_texts):


 def run_text(im1, use_trcap_images, text1, use_trcap_texts):
-    f_texts_embeddings = None
     f_image_embeddings = None
-    global image_embeddings
-    global text_embeddings
     ims = None
     if use_trcap_images:
         random_indexes = random.sample(range(len(trcap_urls)), 2) # MAX 2 text are allowed in image retrieval UI limit
         f_image_embeddings = []
         for i in random_indexes:
             f_image_embeddings.append(load_spesific_tensor(i, 'image'))
         f_image_embeddings = torch.stack(f_image_embeddings)
-        print('TRCaption images used')
         # Images taken from TRCAPTION
         im_paths = [trcap_urls[i] for i in random_indexes]
     else:
         # Images taken from user
         im_paths = [i.name for i in im1[:2]]
         ims = [Image.open(i) for i in im_paths]

     if use_trcap_texts:
-        if text_embeddings is None:
-            print(f"Used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
-            image_embeddings = None
-            print(f"Image embd deleted used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
-            text_embeddings = load_text_embeddings()
-            print(f"Text embed used memory {psutil.virtual_memory().used / float(1 << 30):,.0f} GB", )
-
-        f_texts_embeddings = text_embeddings
         texts = trcap_texts
     else:
-        texts = [i.trim() for i in text1.split('\n') if i.trim() != '']

-    per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, image_features=f_image_embeddings, text_features=f_texts_embeddings, mode='per_image')

     print(per_mode_indices)
     print(per_mode_probs)
     return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
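The lines removed in this hunk tracked RAM with psutil while whole embedding sets were swapped in and out of memory. A minimal sketch of that logging pattern (not part of the commit; assumes only that psutil is installed):

import psutil

def log_used_memory(tag=''):
    # psutil.virtual_memory().used is system-wide used RAM in bytes; 1 << 30 converts bytes to GiB
    used_gb = psutil.virtual_memory().used / float(1 << 30)
    print(f'{tag} used memory: {used_gb:,.1f} GB')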
@@ -219,7 +243,7 @@ with gr.Blocks() as demo:
                 <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                 <rect x="23" y="69" width="23" height="23" fill="black"></rect>
               </svg>
-              <h1 style="font-weight: 900; margin-bottom: 7px;">
                 Trclip Demo
                 <a
                   href="https://github.com/yusufani/TrCLIP"
@@ -234,21 +258,35 @@ with gr.Blocks() as demo:
               Also you can use pre calculated TrCaption embeddings.
               Number of texts = 3533312
               Number of images = 3070976

-              >
             </p>
           </div>
           """)

     with gr.Tabs():
         with gr.TabItem("Use Own Images"):
             im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
-            is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\nNote: ( Random 2 sample selected in text retrieval mode )")

     with gr.Tabs():
         with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"):
             text_input = gr.components.Textbox(label="Text input", optional=True)
-            is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \nNote: ( Random 2 sample selected in image retrieval mode")

     im_ret_but = gr.Button("Image Retrieval")
     text_ret_but = gr.Button("Text Retrieval")
 
 import pickle
 import random

+import numpy as np
+
 # %%
 model_name = 'trclip-vitl14-e10'
 if not os.path.exists(model_name):
 

 # %%

+def load_image_embeddings(load_batch=True):
     path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
     bs = 100_000
+    if load_batch:
+        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
+            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
+                yield pickle.load(f)
+        return
+
+    else:
+        embeddings = []
+        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
+            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
+                embeddings.append(pickle.load(f))
+        return torch.cat(embeddings, dim=0)

+def load_text_embeddings(load_batch=True):
     path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
     bs = 100_000
+    if load_batch:
+        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
+            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
+                yield pickle.load(f)
+        return
+    else:
+        embeddings = []
+        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
+            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
+                embeddings.append(pickle.load(f))
+        return torch.cat(embeddings, dim=0)


 def load_metadata():
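The rewritten loaders above stream the pickled 100_000-row shards one at a time instead of torch.cat-ing roughly 20 GB of embeddings into RAM. Because the functions now contain yield, they are generator functions, so calling them with load_batch=False still returns a generator object rather than a concatenated tensor; an eager variant would need to live in a separate function. A minimal sketch of the streaming pattern itself, with hypothetical names (not the app's API):

import os
import pickle

def iter_embedding_shards(folder, prefix, total, shard_size=100_000):
    # Yield one pickled (shard_size, dim) tensor at a time; only a single shard is ever resident in memory.
    for start in range(0, total, shard_size):
        with open(os.path.join(folder, f'{prefix}_em_{start}.pkl'), 'rb') as f:
            yield pickle.load(f)

# usage: for shard in iter_embedding_shards('image_embeddings', 'image', 3_100_000): ...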
 
     trcap_urls = metadata['image_urls']
     return trcap_texts, trcap_urls

+
+def load_spesific_tensor(index, type, bs=100_000):
     part = index // bs
     idx = index % bs
+    with open(os.path.join('TrCaption-trclip-vitl14-e10', f'{type}_embeddings', f'{type}_em_{part * bs}.pkl'), 'rb') as f:
         embeddings = pickle.load(f)
     return embeddings[idx]


+# %%
 trcap_texts, trcap_urls = load_metadata()
 # %%
+print(f'INFO : Model loading')
 model_path = os.path.join(model_name, 'pytorch_model.bin')
 trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
 # %%

+
+
+# %%
 def run_im(im1, use_trcap_images, text1, use_trcap_texts):
+    print(f'INFO : Image retrieval starting')
     f_texts_embeddings = None
     ims = None
     if use_trcap_images:
+        print('INFO : TRCaption images used')
         im_paths = trcap_urls
     else:
+        print('INFO : Own images used')
         # Images taken from user
         im_paths = [i.name for i in im1]
         ims = [Image.open(i) for i in im_paths]
     if use_trcap_texts:
+        print(f'INFO : TRCaption texts used')
         random_indexes = random.sample(range(len(trcap_texts)), 2) # MAX 2 text are allowed in image retrieval UI limit
         f_texts_embeddings = []
         for i in random_indexes:
             f_texts_embeddings.append(load_spesific_tensor(i, 'text'))
         f_texts_embeddings = torch.stack(f_texts_embeddings)
         texts = [trcap_texts[i] for i in random_indexes]
+
     else:
+        print(f'INFO : Own texts used')
+        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']
+
+    if use_trcap_images: # This means that we will iterate over batches because Huggingface space has 16 gb limit :///
+        per_mode_probs = []
+        f_texts_embeddings = f_texts_embeddings if use_trcap_texts else trclip.get_text_features(texts)
+        for f_image_embeddings in tqdm(load_image_embeddings(load_batch=True), desc='Running image retrieval'):
+            batch_probs = trclip.get_results(
+                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_text', return_probs=True)
+            per_mode_probs.append(batch_probs)
+        per_mode_probs = torch.cat(per_mode_probs, dim=1)
+        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
+        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]

+    else:
+        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, text_features=f_texts_embeddings, mode='per_text')

     print(f'per_mode_indices = {per_mode_indices}\n,per_mode_probs = {per_mode_probs} ')
     print(f'im_paths = {im_paths}')
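In the use_trcap_images branch above, the query texts are scored against one 100_000-image shard at a time; only the per-shard score matrices are kept and concatenated, and softmax plus argsort run once over the full (num_texts, num_images) result. trclip.get_results(..., return_probs=True) is assumed here to return that raw per-shard score matrix, which is why normalization happens only after concatenation. The same ranking pattern with plain cosine similarities, as an illustrative sketch:

import torch

def rank_images_batched(text_features, image_shards):
    # text_features: (T, D) tensor; image_shards: iterable of (B, D) tensors
    text_features = torch.nn.functional.normalize(text_features, dim=-1)
    scores = []
    for shard in image_shards:
        shard = torch.nn.functional.normalize(shard, dim=-1)
        scores.append(text_features @ shard.T)        # (T, B) cosine similarities
    scores = torch.cat(scores, dim=1)                 # (T, N_total)
    probs = scores.softmax(dim=-1)                    # normalize over all images at once, as in the diff
    order = probs.argsort(dim=-1, descending=True)    # best-matching image indices per text
    return order, probs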
 


 def run_text(im1, use_trcap_images, text1, use_trcap_texts):
+    print(f'INFO : Image retrieval starting')
     f_image_embeddings = None
     ims = None
     if use_trcap_images:
+        print('INFO : TRCaption images used')
         random_indexes = random.sample(range(len(trcap_urls)), 2) # MAX 2 text are allowed in image retrieval UI limit
         f_image_embeddings = []
         for i in random_indexes:
             f_image_embeddings.append(load_spesific_tensor(i, 'image'))
         f_image_embeddings = torch.stack(f_image_embeddings)
+        print(f'f_image_embeddings = {f_image_embeddings}')
         # Images taken from TRCAPTION
         im_paths = [trcap_urls[i] for i in random_indexes]
+        print(f'im_paths = {im_paths}')
+
     else:
+        print('INFO : Own images used')
         # Images taken from user
         im_paths = [i.name for i in im1[:2]]
         ims = [Image.open(i) for i in im_paths]

     if use_trcap_texts:
         texts = trcap_texts
     else:
+        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']
+
+    if use_trcap_texts:
+        f_image_embeddings = f_image_embeddings if use_trcap_images else trclip.get_image_features(ims)
+        per_mode_probs = []
+        for f_texts_embeddings in tqdm(load_text_embeddings(load_batch=True), desc='Running text retrieval'):
+            batch_probs = trclip.get_results(
+                text_features=f_texts_embeddings, image_features=f_image_embeddings, mode='per_image', return_probs=True)
+            per_mode_probs.append(batch_probs)
+        per_mode_probs = torch.cat(per_mode_probs, dim=1)
+        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
+        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]

+    else:
+        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims, image_features=f_image_embeddings, mode='per_image')
     print(per_mode_indices)
     print(per_mode_probs)
     return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
 
                 <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
                 <rect x="23" y="69" width="23" height="23" fill="black"></rect>
               </svg>
+              <h1 style="font-weight: 1500; margin-bottom: 7px;">
                 Trclip Demo
                 <a
                   href="https://github.com/yusufani/TrCLIP"
 
               Also you can use pre calculated TrCaption embeddings.
               Number of texts = 3533312
               Number of images = 3070976
+
+              Some images are not available in the internet because I downloaded and calculated TrCaption embeddings long time ago. Don't be suprise if you encounter with Image not found :D
+

+
             </p>
+            <p style="margin-bottom: 10px; font-size: 75%" ><em>Huggingface Space containers has 16 gb ram. TrCaption embeddings are totaly 20 gb. </em><em>I did a lot of writing and reading to files to make this space workable. That's why<span style="background-color: #ff6600; color: #ffffff;"> <strong>it's running much slower if you're using TrCaption Embeddig</strong>s</span>.</em></p>
+            <div class="sc-jSFjdj sc-iCoGMd jcTaHb kMthTr">
+              <div class="sc-iqAclL xfxEN">
+                <div class="sc-bdnxRM fJdnBK sc-crzoAE DykGo">
+                  <div class="sc-gtsrHT gfuSqG">&nbsp;</div>
+                </div>
+              </div>
+            </div>
+            <div class="sc-jSFjdj sc-gKAaRy jcTaHb hydYaP">
+              <div class="sc-pNWdM lfZLSv">&nbsp;</div>
+            </div>
           </div>
           """)

     with gr.Tabs():
         with gr.TabItem("Use Own Images"):
             im_input = gr.components.File(label="Image input", optional=True, file_count='multiple')
+            is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\n[Note: Random 2 sample selected in text retrieval mode )]")

     with gr.Tabs():
         with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"):
             text_input = gr.components.Textbox(label="Text input", optional=True)
+            is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \n[Note: Random 2 sample selected in image retrieval mode]")

     im_ret_but = gr.Button("Image Retrieval")
     text_ret_but = gr.Button("Text Retrieval")