import os

import gradio as gr
from datasets import load_dataset
from huggingface_hub.hf_api import HfFolder
from PIL import Image
from sentence_transformers import SentenceTransformer, util

# Read the Hugging Face token from the environment (e.g. a Space secret named HF_TOKEN)
# instead of hardcoding it in the source.
HfFolder.save_token(os.environ["HF_TOKEN"])

## Define model
# The raw CLIP checkpoint could also be loaded directly via transformers:
# from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = SentenceTransformer('clip-ViT-B-32')

# Load the dataset with precomputed image embeddings and build a FAISS index over them.
ds_with_embeddings = load_dataset("kvriza8/image-embeddings", use_auth_token=True)
ds_with_embeddings['train'].add_faiss_index(column='embeddings')

# Legacy path: load precomputed embeddings from a local pickle file instead of the dataset.
# import pickle
# emb_filename = 'unsplash-25k-photos-embeddings.pkl'
# with open(emb_filename, 'rb') as fIn:
#     img_names, img_emb = pickle.load(fIn)


def search_text(query, top_k=1):
    """Search for images matching a text query (legacy pickle-based path).

    Requires `img_names` and `img_emb` from the commented-out pickle loading above.

    Args:
        query (str): The text query to search for.
        top_k (int, optional): Number of images to return. Defaults to 1.

    Returns:
        list: Images related to the query.
    """
    # First, we encode the query with the same CLIP model used for the image embeddings.
    query_emb = model.encode([query], convert_to_tensor=True)

    # Then, we use the util.semantic_search function, which computes the cosine-similarity
    # between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output.
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    images = []
    for hit in hits:
        images.append(Image.open(os.path.join("photos/", img_names[hit['corpus_id']])))
    return images


def get_image_from_text(text_prompt, number_to_retrieve=6):
    """Encode the prompt and retrieve the nearest images from the FAISS index."""
    prompt = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'embeddings', prompt, k=number_to_retrieve)

    # Notebook preview of the retrieved examples (requires matplotlib), kept for reference:
    # plt.figure(figsize=(15, 15))
    # for i in range(number_to_retrieve):
    #     plt.subplot(2, 3, i + 1)
    #     plt.title(retrieved_examples['caption_summary'][i])
    #     plt.imshow(retrieved_examples['image'][i])

    # gr.Gallery expects a list of images, so return only the 'image' column.
    return retrieved_examples['image']


iface = gr.Interface(
    title="Text to Image using CLIP Model 📸",
    description="Retrieve images from a precomputed CLIP embedding index by describing them in text.",
    article="Images and embeddings are loaded from the kvriza8/image-embeddings dataset.",
    fn=get_image_from_text,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        gr.Slider(1, 6, step=1, value=2, label="Number of images to retrieve"),
    ],
    outputs=[
        gr.Gallery(label="Retrieved images", show_label=False, elem_id="gallery"),
    ],
    examples=[
        ["TEM image", 2],
        ["Nanoparticles", 1],
        ["ZnSe-ZnTe core-shell nanowire", 2],
    ],
).launch(debug=True)
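
# Example of exercising the retrieval function without the Gradio UI (hypothetical local
# test; run it in place of iface.launch() above):
#
#   images = get_image_from_text("ZnSe-ZnTe core-shell nanowire", number_to_retrieve=2)
#   for img in images:
#       print(img.size)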