import os
import pickle  # only needed for the legacy pickle-based search path below
import gradio as gr
from PIL import Image
from datasets import load_dataset
from huggingface_hub.hf_api import HfFolder
from sentence_transformers import SentenceTransformer, util

# Read the Hugging Face token from an environment variable (e.g. a Space secret
# named HF_TOKEN) instead of hard-coding it in the source.
HfFolder.save_token(os.environ["HF_TOKEN"])
## Define model
# Alternative: load the raw CLIP checkpoint directly with transformers.
# from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = SentenceTransformer('clip-ViT-B-32')
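# Note: the SentenceTransformer CLIP wrapper encodes both text and PIL images
# into the same embedding space, e.g. (illustrative only, hypothetical path):
# text_emb = model.encode("an electron microscopy image")
# image_emb = model.encode(Image.open("photos/example.png"))
# print(util.cos_sim(text_emb, image_emb))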
# Load the precomputed image embeddings from the Hub dataset.
ds_with_embeddings = load_dataset("kvriza8/image-embeddings", use_auth_token=True)

# Build the FAISS index over the embedding column once at startup so it is not
# rebuilt on every query.
ds_with_embeddings['train'].add_faiss_index(column='embeddings')

# Legacy alternatives for obtaining the embeddings:
# img_names, img_emb = ds_with_embeddings['train']['image'], ds_with_embeddings['train']['embeddings']
# emb_filename = 'unsplash-25k-photos-embeddings.pkl'
# with open(emb_filename, 'rb') as fIn:
#     img_names, img_emb = pickle.load(fIn)
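# For reference, a minimal sketch of how a dataset like kvriza8/image-embeddings
# could be produced with the same CLIP model (illustrative only; the layout is
# assumed from how the dataset is used in this script: an 'image' column plus an
# 'embeddings' column, and a 'caption_summary' column used in the plotting
# snippet further down).
#
# def add_clip_embedding(batch):
#     # model.encode accepts a list of PIL images and returns a float32 array.
#     batch['embeddings'] = model.encode(batch['image']).tolist()
#     return batch
#
# raw = load_dataset("imagefolder", data_dir="photos/")
# raw_with_emb = raw.map(add_clip_embedding, batched=True, batch_size=32)
# raw_with_emb.push_to_hub("kvriza8/image-embeddings")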
def search_text(query, top_k=1):
    """Search for images matching a text query (legacy pickle-based path).

    Note: this function is not wired into the Gradio interface below; it
    expects `img_names` and `img_emb` from the commented-out pickle loading
    above, plus the image files in a local photos/ directory.

    Args:
        query (str): text query to search for.
        top_k (int, optional): number of images to return. Defaults to 1.

    Returns:
        list: PIL images related to the query.
    """
    # Encode the query with the same CLIP model used for the image embeddings.
    query_emb = model.encode([query], convert_to_tensor=True)

    # util.semantic_search computes the cosine similarity between the query
    # embedding and all image embeddings and returns the top_k best matches.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    images = []
    for hit in hits:
        images.append(Image.open(os.path.join("photos/", img_names[hit['corpus_id']])))
    return images
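# Example (only after uncommenting the pickle loading above and with the image
# files available under photos/); the query string is just an illustration:
# results = search_text("TEM image", top_k=3)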
def get_image_from_text(text_prompt, number_to_retrieve=6):
    """Retrieve the images whose CLIP embeddings are closest to the text prompt."""
    # Encode the prompt; FAISS searches over the precomputed 'embeddings' column.
    prompt_emb = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'embeddings', prompt_emb, k=number_to_retrieve)
    # Return the PIL images so that gr.Gallery can render them.
    return retrieved_examples['image']
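# Example of calling the retrieval function directly, outside the Gradio UI
# (the prompt is taken from the interface examples below):
# images = get_image_from_text("ZnSe-ZnTe core-shell nanowire", number_to_retrieve=2)
# images[0].show()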
# Optional: inspect nearest examples locally with matplotlib, including the
# stored captions.
# import matplotlib.pyplot as plt
# scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
#     'embeddings', model.encode("TEM image"), k=6)
# plt.figure(figsize=(15, 15))
# for i in range(6):
#     plt.subplot(2, 3, i + 1)
#     plt.title(retrieved_examples['caption_summary'][i])
#     plt.imshow(retrieved_examples['image'][i])
# plt.show()
iface = gr.Interface(
    title="Text to Image Search using CLIP Model",
    description="Describe an image in text and retrieve the closest matches "
                "from a precomputed CLIP embedding index of microscopy images.",
    article="Embeddings are computed with the clip-ViT-B-32 SentenceTransformer "
            "model and searched with a FAISS index over the "
            "kvriza8/image-embeddings dataset.",
    fn=get_image_from_text,
    inputs=[gr.Textbox(lines=4,
                       label="Insert your prompt",
                       placeholder="Text Here..."),
            gr.Slider(1, 6, value=2, step=1, label="Number of images to retrieve")],
    outputs=[gr.Gallery(
        label="Retrieved images", show_label=False, elem_id="gallery"
    )],
    examples=[["TEM image", 2],
              ["Nanoparticles", 1],
              ["ZnSe-ZnTe core-shell nanowire", 2]]
).launch(debug=True)