# Import all the necessary libraries
import os
import pickle
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from trclip.trclip import Trclip
from trclip.visualizer import image_retrieval_visualize, text_retrieval_visualize

print(f'gr version : {gr.__version__}')

# %%
# Download the fine-tuned TrCLIP checkpoint from the Hugging Face Hub if it is not present.
model_name = 'trclip-vitl14-e10'
if not os.path.exists(model_name):
    os.system(f'git clone https://huggingface.co/yusufani/{model_name} --progress')

# %%
# Download the pre-computed TrCaption embeddings (stored with Git LFS) if they are not present.
if not os.path.exists('TrCaption-trclip-vitl14-e10'):
    os.system('git clone https://huggingface.co/datasets/yusufani/TrCaption-trclip-vitl14-e10/ --progress')
    os.chdir('TrCaption-trclip-vitl14-e10')
    os.system('git lfs install')
    os.system('git lfs fetch')
    os.system('git lfs pull')
    os.chdir('..')

# %%
def load_image_embeddings(load_batch=True):
    # The TrCaption image embeddings are stored in shards of 100k rows
    # (image_em_0.pkl, image_em_100000.pkl, ...). This generator yields one shard at a time.
    # Note: because the function contains `yield`, calling it with load_batch=False also
    # returns a generator; only the batched path is used in this app.
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'image_embeddings')
    bs = 100_000
    if load_batch:
        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
                yield pickle.load(f)
        return
    else:
        embeddings = []
        for i in tqdm(range(0, 3_100_000, bs), desc='Loading TrCaption Image embeddings'):
            with open(os.path.join(path, f'image_em_{i}.pkl'), 'rb') as f:
                embeddings.append(pickle.load(f))
        return torch.cat(embeddings, dim=0)


def load_text_embeddings(load_batch=True):
    # Same sharding scheme as the image embeddings, but for the caption embeddings.
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'text_embeddings')
    bs = 100_000
    if load_batch:
        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
                yield pickle.load(f)
        return
    else:
        embeddings = []
        for i in tqdm(range(0, 3_600_000, bs), desc='Loading TrCaption text embeddings'):
            with open(os.path.join(path, f'text_em_{i}.pkl'), 'rb') as f:
                embeddings.append(pickle.load(f))
        return torch.cat(embeddings, dim=0)


def load_metadata():
    # metadata.pkl holds the raw captions and the source image URLs of TrCaption.
    path = os.path.join('TrCaption-trclip-vitl14-e10', 'metadata.pkl')
    with open(path, 'rb') as f:
        metadata = pickle.load(f)
    trcap_texts = metadata['texts']
    trcap_urls = metadata['image_urls']
    return trcap_texts, trcap_urls


def load_spesific_tensor(index, type, bs=100_000):
    # Fetch a single pre-computed embedding: open the shard that contains `index`
    # and return the row at the in-shard offset.
    part = index // bs
    idx = index % bs
    with open(os.path.join('TrCaption-trclip-vitl14-e10', f'{type}_embeddings', f'{type}_em_{part * bs}.pkl'), 'rb') as f:
        embeddings = pickle.load(f)
    return embeddings[idx]


# %%
trcap_texts, trcap_urls = load_metadata()
# %%
print('INFO : Model loading')
model_path = os.path.join(model_name, 'pytorch_model.bin')
trclip = Trclip(model_path, clip_model='ViT-L/14', device='cpu')
# %%
import datetime
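# Illustrative sketch (kept as a comment, not executed by the app): scoring one
# pre-computed TrCaption image embedding against a couple of query texts with the
# helpers defined above. The index and texts are made-up examples; the calls mirror
# the ones used in the UI callbacks below.
#
#   query_texts = ['kedi', 'köpek']
#   text_feats = trclip.get_text_features(query_texts)
#   image_feat = load_spesific_tensor(2_345_678, 'image')  # shard image_em_2300000.pkl, row 45_678
#   probs = trclip.get_results(text_features=text_feats,
#                              image_features=image_feat.unsqueeze(0),
#                              mode='per_text', return_probs=True)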
# %%
def run_im(im1, use_trcap_images, text1, use_trcap_texts):
    # Image retrieval: for each query text, rank the candidate images.
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Image retrieval starting')
    f_texts_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TRCaption images used')
        im_paths = trcap_urls
    else:
        print('INFO : Own images used')
        # Images supplied by the user
        im_paths = [i.name for i in im1]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        print('INFO : TRCaption texts used')
        # UI limit: at most 2 texts are used in image retrieval mode, picked at random.
        random_indexes = random.sample(range(len(trcap_texts)), 2)
        f_texts_embeddings = []
        for i in random_indexes:
            f_texts_embeddings.append(load_spesific_tensor(i, 'text'))
        f_texts_embeddings = torch.stack(f_texts_embeddings)
        texts = [trcap_texts[i] for i in random_indexes]
    else:
        print('INFO : Own texts used')
        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']

    if use_trcap_images:
        # Iterate over the pre-computed embedding shards because the Hugging Face Space
        # only has 16 GB of RAM (see the standalone sketch at the end of this file).
        per_mode_probs = []
        f_texts_embeddings = f_texts_embeddings if use_trcap_texts else trclip.get_text_features(texts)
        for f_image_embeddings in tqdm(load_image_embeddings(load_batch=True), desc='Running image retrieval'):
            batch_probs = trclip.get_results(text_features=f_texts_embeddings,
                                             image_features=f_image_embeddings,
                                             mode='per_text', return_probs=True)
            per_mode_probs.append(batch_probs)
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims,
                                                              text_features=f_texts_embeddings,
                                                              mode='per_text')
    print(f'per_mode_indices = {per_mode_indices}\nper_mode_probs = {per_mode_probs}')
    print(f'im_paths = {im_paths}')
    return image_retrieval_visualize(per_mode_indices, per_mode_probs, texts, im_paths,
                                     n_figure_in_column=2, n_images_in_figure=4, n_figure_in_row=1,
                                     save_fig=False, show=False, break_on_index=-1)


def run_text(im1, use_trcap_images, text1, use_trcap_texts):
    # Text retrieval: for each query image, rank the candidate texts.
    print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} INFO : Text retrieval starting')
    f_image_embeddings = None
    ims = None
    if use_trcap_images:
        print('INFO : TRCaption images used')
        # UI limit: at most 2 images are used in text retrieval mode, picked at random.
        random_indexes = random.sample(range(len(trcap_urls)), 2)
        f_image_embeddings = []
        for i in random_indexes:
            f_image_embeddings.append(load_spesific_tensor(i, 'image'))
        f_image_embeddings = torch.stack(f_image_embeddings)
        print(f'f_image_embeddings = {f_image_embeddings}')
        # Images taken from TrCaption
        im_paths = [trcap_urls[i] for i in random_indexes]
        print(f'im_paths = {im_paths}')
    else:
        print('INFO : Own images used')
        # Images supplied by the user
        im_paths = [i.name for i in im1[:2]]
        ims = [Image.open(i) for i in im_paths]

    if use_trcap_texts:
        texts = trcap_texts
    else:
        texts = [i.strip() for i in text1.split('\n')[:2] if i.strip() != '']

    if use_trcap_texts:
        # Same batched strategy as in run_im, this time iterating over the text embedding shards.
        f_image_embeddings = f_image_embeddings if use_trcap_images else trclip.get_image_features(ims)
        per_mode_probs = []
        for f_texts_embeddings in tqdm(load_text_embeddings(load_batch=True), desc='Running text retrieval'):
            batch_probs = trclip.get_results(text_features=f_texts_embeddings,
                                             image_features=f_image_embeddings,
                                             mode='per_image', return_probs=True)
            per_mode_probs.append(batch_probs)
        per_mode_probs = torch.cat(per_mode_probs, dim=1)
        per_mode_probs = per_mode_probs.softmax(dim=-1).cpu().detach().numpy()
        per_mode_indices = [np.argsort(prob)[::-1] for prob in per_mode_probs]
    else:
        per_mode_indices, per_mode_probs = trclip.get_results(texts=texts, images=ims,
                                                              image_features=f_image_embeddings,
                                                              mode='per_image')
    print(per_mode_indices)
    print(per_mode_probs)
    return text_retrieval_visualize(per_mode_indices, per_mode_probs, im_paths, texts,
                                    n_figure_in_column=4,
                                    n_texts_in_figure=4 if len(texts) > 4 else len(texts),
                                    n_figure_in_row=2,
                                    save_fig=False, show=False, break_on_index=-1)


def change_textbox(choice):
    # Show or hide the image input (currently not wired to any UI event).
    if choice == "Use Own Images":
        return gr.Image.update(visible=True)
    else:
        return gr.Image.update(visible=False)


with gr.Blocks() as demo:
    gr.HTML("""

    <h1>Trclip Demo</h1>
    <p>Trclip is a Turkish port of CLIP. In this space you can try your own images and/or texts,
    or use the pre-computed TrCaption embeddings.</p>
    <p>Number of texts = 3533312<br>
    Number of images = 3070976</p>
    <p>Some of the images are no longer available on the internet, because I downloaded them and computed
    the TrCaption embeddings a long time ago. Don't be surprised if you run into 'Image not found' :D</p>
    <p>GitHub Repository --- Paper (not available yet)</p>
    <p>Hugging Face Space containers have 16 GB of RAM, while the TrCaption embeddings take about 20 GB in total.
    I had to do a lot of reading from and writing to disk to make this space work, which is why it runs much
    slower when you use the TrCaption embeddings.</p>
""") with gr.Tabs(): with gr.TabItem("Upload a Images"): im_input = gr.components.File(label="Image input", optional=True, file_count='multiple') is_trcap_ims = gr.Checkbox(label="Use TRCaption Images\n[Note: Random 2 sample selected in text retrieval mode]",default=True) with gr.Tabs(): with gr.TabItem("Input a text (Seperated by new line Max 2 for Image retrieval)"): text_input = gr.components.Textbox(label="Text input", optional=True , placeholder = "kedi\nköpek\nGemi\nKahvesini içmekte olan bir adam\n Kahvesini içmekte olan bir kadın\nAraba") is_trcap_texts = gr.Checkbox(label="Use TrCaption Captions \n[Note: Random 2 sample selected in image retrieval mode]",default=True) im_ret_but = gr.Button("Image Retrieval") text_ret_but = gr.Button("Text Retrieval") im_out = gr.components.Image() im_ret_but.click(run_im, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out) text_ret_but.click(run_text, inputs=[im_input, is_trcap_ims, text_input, is_trcap_texts], outputs=im_out) demo.launch() # %%