# TrCLIP / app.py — Hugging Face Space by yusufani
# Last commit 56aa817: "Date time added" (14.1 kB)
# Importing all the necessary libraries
import os
import gradio as gr
import torch
from PIL import Image
from tqdm import tqdm
from trclip.trclip import Trclip
from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize
print(f'gr version : {gr.__version__}')
import pickle
import random
import numpy as np
# %%
model_name = 'trclip-vitl14-e10'
# Clone the TrCLIP checkpoint repository on the first start-up only; later
# start-ups reuse the local copy.
if not os.path.exists(model_name):
    os.system(f'git clone https://huggingface.co/yusufani/{model_name} --progress')
# %%
# Clone the pre-computed TrCaption embeddings dataset on the first start-up and
# materialise its Git LFS files (a plain clone may leave only LFS pointers).
if not os.path.exists('TrCaption-trclip-vitl14-e10'):
    os.system('git clone https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
    # The git-lfs commands must run inside the clone, hence the temporary chdir.
    os.chdir('TrCaption-trclip-vitl14-e10')
    os.system('git lfs install')
    os.system(' git lfs fetch')
    os.system(' git lfs pull')
    os.chdir('..')
# %%
def _iter_embedding_parts(kind, total, desc, bs=100_000):
    """Lazily unpickle the TrCaption ``kind`` embedding shards one at a time.

    Shards live in ``TrCaption-trclip-vitl14-e10/{kind}_embeddings`` and are
    named ``{kind}_em_{start}.pkl`` for every ``start`` in
    ``range(0, total, bs)``. Streaming shard by shard lets callers avoid
    holding all embeddings at once (the Space has 16 GB of RAM while the
    embeddings total about 20 GB).

    Args:
        kind: File/directory prefix, ``'image'`` or ``'text'``.
        total: Exclusive upper bound for the shard start offsets.
        desc: Label of the tqdm progress bar.
        bs: Number of rows per shard file (the offset step in file names).

    Yields:
        The unpickled content of each shard file, in offset order.
    """
    path = os.path.join('TrCaption-trclip-vitl14-e10', f'{kind}_embeddings')
    for start in tqdm(range(0, total, bs), desc=desc):
        # NOTE(review): pickle.load can execute arbitrary code; these files must
        # come from a trusted source (the dataset repository cloned at start-up).
        with open(os.path.join(path, f'{kind}_em_{start}.pkl'), 'rb') as f:
            yield pickle.load(f)


def load_image_embeddings(load_batch=True):
    """Load the pre-computed TrCaption image embeddings.

    Args:
        load_batch: If True (default), return a lazy iterator that yields one
            shard at a time. If False, load every shard and return them
            concatenated along dim 0.

    Returns:
        An iterator over shards, or one tensor built with ``torch.cat``.
    """
    # The generator lives in a helper on purpose: a ``yield`` anywhere in this
    # function would make it a generator, so ``load_batch=False`` could never
    # return the concatenated tensor.
    parts = _iter_embedding_parts('image', 3_100_000, 'Loading TrCaption Image embeddings')
    if load_batch:
        return parts
    return torch.cat(list(parts), dim=0)


def load_text_embeddings(load_batch=True):
    """Load the pre-computed TrCaption text (caption) embeddings.

    Args:
        load_batch: If True (default), return a lazy iterator that yields one
            shard at a time. If False, load every shard and return them
            concatenated along dim 0.

    Returns:
        An iterator over shards, or one tensor built with ``torch.cat``.
    """
    parts = _iter_embedding_parts('text', 3_600_000, 'Loading TrCaption text embeddings')
    if load_batch:
        return parts
    return torch.cat(list(parts), dim=0)
def load_metadata():
    """Read the TrCaption metadata pickle.

    Returns:
        tuple: ``(texts, image_urls)``, taken from the ``'texts'`` and
        ``'image_urls'`` entries of ``TrCaption-trclip-vitl14-e10/metadata.pkl``.
    """
    metadata_file = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only load this file
    # from a trusted source.
    with open(metadata_file, 'rb') as handle:
        metadata = pickle.load(handle)
    return metadata['texts'], metadata['image_urls']
def load_spesific_tensor(index, type, bs=100_000):
    """Return the single embedding row at global position ``index``.

    Embeddings are sharded into pickle files of ``bs`` rows; the shard that
    holds ``index`` is named after its first global row
    (``{type}_em_{first_row}.pkl`` under ``{type}_embeddings``). The whole
    shard is unpickled and only the requested row is returned.

    Args:
        index: Global row index across all shards.
        type: Embedding family prefix (``'image'`` or ``'text'`` in this file).
            Shadows the builtin, but is kept for caller compatibility.
        bs: Rows per shard file.
    """
    shard, row = divmod(index, bs)
    shard_file = os.path.join(
        'TrCaption-trclip-vitl14-e10',
        f'{type}_embeddings',
        f'{type}_em_{shard * bs}.pkl',
    )
    with open(shard_file, 'rb') as handle:
        return pickle.load(handle)[row]
# %%
# Captions and image URLs for the whole TrCaption corpus. The retrieval
# handlers index these lists with the same positions they use for the
# embedding shards, so both are assumed to be aligned.
trcap_texts, trcap_urls = load_metadata()
# %%
print('INFO : Model loading')
# Checkpoint file inside the `model_name` repository cloned at start-up.
model_path = os.path.join(model_name, 'pytorch_model.bin')
trclip = Trclip(
    model_path,
    clip_model='ViT-L/14',
    device='cpu',
)
# %%
import datetime
# %%
def run_im(im1, use_trcap_images, text1, use_trcap_texts):
    """Run image retrieval (query texts -> ranked images) and build the figure.

    Args:
        im1: Uploaded files from the "Image input" component; each item's
            ``.name`` is used as an image path. Ignored when
            ``use_trcap_images`` is True.
        use_trcap_images: If True, search the pre-computed TrCaption image
            embeddings (streamed shard by shard) instead of ``im1``.
        text1: Newline-separated query texts. Only the first two lines are
            used, and blank lines among them are dropped. Ignored when
            ``use_trcap_texts`` is True.
        use_trcap_texts: If True, query with two randomly sampled TrCaption
            captions and their pre-computed embeddings instead of ``text1``.

    Returns:
        The value produced by ``image_retrieval_visualize`` for the ranking.
        It is shown in the ``gr.Image`` output component.

    Raises:
        ValueError: If own images or own texts are selected but none were given.
    """
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Image retrieval starting')
    f_texts_embeddings = None
    ims = None
    try:
        if use_trcap_images:
            print('INFO : TRCaption images used')
            im_paths = trcap_urls
        else:
            print('INFO : Own images used')
            # Fail with a readable message instead of iterating over None.
            if not im1:
                raise ValueError("No images uploaded: upload images or enable 'Use TRCaption Images'.")
            # Images taken from user
            im_paths = [i.name for i in im1]
            ims = [Image.open(i) for i in im_paths]
        if use_trcap_texts:
            print('INFO : TRCaption texts used')
            # MAX 2 texts are allowed by the image retrieval UI, so sample 2
            # captions and reuse their pre-computed embeddings.
            random_indexes = random.sample(range(len(trcap_texts)), 2)
            f_texts_embeddings = torch.stack([load_spesific_tensor(i, 'text') for i in random_indexes])
            texts = [trcap_texts[i] for i in random_indexes]
        else:
            print('INFO : Own texts used')
            # First two lines only (UI limit); blank lines are dropped.
            texts = [i.strip() for i in (text1 or '').split('\n')[:2] if i.strip() != '']
            if not texts:
                raise ValueError("No query text given: enter text or enable 'Use TrCaption Captions'.")
        if use_trcap_images:
            # The TrCaption image embeddings do not fit in the Space's 16 GB of
            # RAM, so score them shard by shard and normalise at the end.
            if not use_trcap_texts:
                f_texts_embeddings = trclip.get_text_features(texts)
            per_mode_probs = []
            for f_image_embeddings in tqdm(load_image_embeddings(load_batch=True), desc='Running image retrieval'):
                batch_probs = trclip.get_results(
                    text_features=f_texts_embeddings, image_features=f_image_embeddings,
                    mode='per_text', return_probs=True)
                per_mode_probs.append(batch_probs)
            # Softmax across the concatenated scores of all shards at once.
            per_mode_probs = torch.cat(per_mode_probs, dim=1)
            per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
            per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
        else:
            per_mode_indices, per_mode_probs = trclip.get_results(
                texts=texts, images=ims, text_features=f_texts_embeddings, mode='per_text')
    finally:
        # Image.open is lazy and keeps the file open until close(); the
        # visualiser below is given the paths, not these objects.
        for im in ims or ():
            im.close()
    print(f'per_mode_indices = {per_mode_indices}\n,per_mode_probs = {per_mode_probs} ')
    if use_trcap_images:
        # Never dump the multi-million-entry TrCaption URL list into the log.
        print(f'im_paths = <{len(im_paths)} TrCaption image URLs>')
    else:
        print(f'im_paths = {im_paths}')
    return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
                                     n_figure_in_column=2,
                                     n_images_in_figure=4, n_figure_in_row=1, save_fig=False,
                                     show=False,
                                     break_on_index=-1)
def run_text(im1, use_trcap_images, text1, use_trcap_texts):
    """Run text retrieval (query images -> ranked texts) and build the figure.

    Args:
        im1: Uploaded files from the "Image input" component; the ``.name``
            paths of the first two are used. Ignored when
            ``use_trcap_images`` is True.
        use_trcap_images: If True, query with two randomly sampled TrCaption
            images (pre-computed embeddings plus their URLs) instead of ``im1``.
        text1: Newline-separated candidate texts. Only the first two lines
            are used, and blank lines among them are dropped. Ignored when
            ``use_trcap_texts`` is True.
        use_trcap_texts: If True, rank against every TrCaption caption,
            streaming the text-embedding shards one at a time.

    Returns:
        The value produced by ``text_retrieval_visualize`` for the ranking.

    Raises:
        ValueError: If own images or own texts are selected but none were given.
    """
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Text retrieval starting')
    f_image_embeddings = None
    ims = None
    try:
        if use_trcap_images:
            print('INFO : TRCaption images used')
            # Two random TrCaption images are the queries (as the checkbox
            # label states); reuse their pre-computed embeddings.
            random_indexes = random.sample(range(len(trcap_urls)), 2)
            f_image_embeddings = torch.stack([load_spesific_tensor(i, 'image') for i in random_indexes])
            print(f'f_image_embeddings = {f_image_embeddings}')
            # Images taken from TRCAPTION
            im_paths = [trcap_urls[i] for i in random_indexes]
            print(f'im_paths = {im_paths}')
        else:
            print('INFO : Own images used')
            # Fail with a readable message instead of slicing None.
            if not im1:
                raise ValueError("No images uploaded: upload images or enable 'Use TRCaption Images'.")
            # Images taken from user (only the first two are queried)
            im_paths = [i.name for i in im1[:2]]
            ims = [Image.open(i) for i in im_paths]

        if use_trcap_texts:
            # Rank against the whole caption corpus; its embeddings do not fit
            # in the Space's RAM, so they are streamed shard by shard.
            texts = trcap_texts
            if not use_trcap_images:
                f_image_embeddings = trclip.get_image_features(ims)
            per_mode_probs = []
            for f_texts_embeddings in tqdm(load_text_embeddings(load_batch=True), desc='Running text retrieval'):
                batch_probs = trclip.get_results(
                    text_features=f_texts_embeddings, image_features=f_image_embeddings,
                    mode='per_image', return_probs=True)
                per_mode_probs.append(batch_probs)
            # Softmax across the concatenated scores of all shards at once.
            per_mode_probs = torch.cat(per_mode_probs, dim=1)
            per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
            per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
        else:
            # First two lines only; blank lines are dropped.
            texts = [i.strip() for i in (text1 or '').split('\n')[:2] if i.strip() != '']
            if not texts:
                raise ValueError("No text given: enter text or enable 'Use TrCaption Captions'.")
            per_mode_indices, per_mode_probs = trclip.get_results(
                texts=texts, images=ims, image_features=f_image_embeddings, mode='per_image')
    finally:
        # Image.open is lazy and keeps the file open until close(); the
        # visualiser below is given the paths, not these objects.
        for im in ims or ():
            im.close()
    print(per_mode_indices)
    print(per_mode_probs)
    return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
                                    n_figure_in_column=4,
                                    n_texts_in_figure=min(4, len(texts)),
                                    n_figure_in_row=2,
                                    save_fig=False,
                                    show=False,
                                    break_on_index=-1,
                                    )
def change_textbox(choice):
    """Build a ``gr.Image`` update that shows the component only for own images.

    Args:
        choice: Selected option label; ``"Use Own Images"`` makes the image
            component visible, and any other value hides it.

    NOTE(review): not wired to any event in the visible code.
    """
    show_image = choice == "Use Own Images"
    return gr.Image.update(visible=show_image)
# Gradio UI: HTML header, image and text inputs (each with a "use TrCaption"
# toggle), two retrieval buttons and one shared image output.
with gr.Blocks() as demo:
    gr.HTML("""
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<svg
width="0.65em"
height="0.65em"
viewBox="0 0 115 115"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<rect width="23" height="23" fill="white"></rect>
<rect y="69" width="23" height="23" fill="white"></rect>
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="46" width="23" height="23" fill="white"></rect>
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" width="23" height="23" fill="black"></rect>
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
</svg>
<h1 style="font-weight: 1500; margin-bottom: 7px;">
Trclip Demo
<a
href="https://github.com/yusufani/TrCLIP"
style="text-decoration: underline;"
target="_blank"
>Github Trclip</a>
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
Trclip is Turkish port of real clip. In this space you can try your images or/and texts.
<br>Also you can use pre calculated TrCaption embeddings.
<br>Number of texts = 3533312
<br>Number of images = 3070976
<br>
Some images are not available in the internet because I downloaded and calculated TrCaption embeddings long time ago. Don't be suprise if you encounter with Image not found :D
<div style="text-align: center;font-size: 100%">
<p><strong><span style="background-color: #000000; color: #ffffff;"><a style="background-color: #000000; color: #ffffff;" href="https://github.com/yusufani/TrCLIP">A GitHub Repository</a> </span>--- <span style="background-color: #000000;"><span style="color: #ffffff;">Paper( Not available yet )&nbsp;</span></span></strong></p>
</div>
</p>
</div>
<div style="text-align: center; margin: 0 auto;">
<p style="margin-bottom: 10px; font-size: 75%" ><em>Huggingface Space containers has 16 gb ram. TrCaption embeddings are totaly 20 gb. </em><em>I did a lot of writing and reading to files to make this space workable. That's why<span style="background-color: #ff6600; color: #ffffff;"> <strong>it's running much slower if you're using TrCaption Embeddig</strong>s</span>.</em></p>
<div class="sc-jSFjdj sc-iCoGMd jcTaHb kMthTr">
<div class="sc-iqAclL xfxEN">
<div class="sc-bdnxRM fJdnBK sc-crzoAE DykGo">
<div class="sc-gtsrHT gfuSqG">&nbsp;</div>
</div>
</div>
</div>
</div>
""")
    # In Gradio 3.x Blocks components the initial state is `value`; `default`
    # is not a Checkbox parameter (it is only warned about and dropped), so the
    # boxes used to start unticked despite the intent. `optional` had no effect.
    with gr.Tabs():
        with gr.TabItem("Upload a Images"):
            im_input = gr.components.File(label="Image input", file_count='multiple')
            is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\n[Note: Random 2 sample selected in text retrieval mode]", value=True)
    with gr.Tabs():
        with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"):
            text_input = gr.components.Textbox(label="Text input", placeholder="kedi\nköpek\nGemi\nKahvesini içmekte olan bir adam\n Kahvesini içmekte olan bir kadın\nAraba")
            is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \n[Note: Random 2 sample selected in image retrieval mode]", value=True)
    im_ret_but = gr.Button("Image Retrieval")
    text_ret_but = gr.Button("Text Retrieval")
    # Both retrieval modes render their result into the same image output.
    im_out = gr.components.Image()
    im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
    text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out)
demo.launch()
# %%