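"""Gradio demo for text- and image-based retrieval over a microscopy dataset.

Text queries are embedded with a SentenceTransformer CLIP model and query
images with the corresponding CLIP vision encoder; a FAISS index over the
dataset's precomputed image embeddings returns the nearest images and captions.
"""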
import io

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor
def convert_to_image(byte_data):
    """Convert raw image bytes from the dataset into a PIL image."""
    return Image.open(io.BytesIO(byte_data))
# Load the embedding model and the dataset of precomputed image embeddings
model = SentenceTransformer('clip-ViT-B-32')
ds_with_embeddings = load_dataset("kvriza8/clip_microscopy_image_text_embeddings")

# Initialize the FAISS index over the stored image embeddings once at startup
ds_with_embeddings['train'].add_faiss_index(column='img_embeddings')
def get_image_from_text(text_prompt, number_to_retrieve=1):
    """Embed a text prompt and return the nearest images and their captions."""
    prompt = model.encode(text_prompt)
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', prompt, k=number_to_retrieve
    )
    # Convert the stored byte strings to PIL images
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions
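# Illustrative usage (the query string is an arbitrary example, not taken from the dataset):
# images, captions = get_image_from_text("TEM image of nanoparticles", number_to_retrieve=2)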
# The same CLIP ViT-B/32 checkpoint, loaded via transformers, is used to embed query images
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def get_image_from_image(query_image, number_to_retrieve=1):
    """Embed a query image with CLIP and return the nearest images and their captions."""
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")
    # Gradio typically passes images as uint8 arrays; also handle float arrays scaled to [0, 1]
    if query_image.dtype != np.uint8:
        query_image = (query_image * 255).clip(0, 255).astype(np.uint8)
    image = Image.fromarray(query_image)
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    image_features_numpy = image_features.cpu().numpy()
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', image_features_numpy, k=number_to_retrieve
    )
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions
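# Illustrative usage with a local file (the path is a placeholder, not part of the dataset):
# query = np.array(Image.open("example_micrograph.png").convert("RGB"))
# images, captions = get_image_from_image(query, number_to_retrieve=1)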
def plot_images(text_prompt="", number_to_retrieve=1, query_image=None):
    """Dispatch to image- or text-based retrieval depending on the inputs provided."""
    if query_image is not None:
        # Handle image input
        sample_images, sample_titles = get_image_from_image(query_image, number_to_retrieve)
    elif text_prompt:
        # Handle text input
        sample_images, sample_titles = get_image_from_text(text_prompt, number_to_retrieve)
    else:
        # Handle empty input
        return [], "No input provided"
    concatenated_captions = "\n".join(sample_titles)
    return sample_images, concatenated_captions
iface = gr.Interface(
    title="Microscopy image retrieval",
    fn=plot_images,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        gr.Slider(1, 8, step=1, value=1, label="Number of images to retrieve"),
        gr.Image(label="Or Upload an Image")
    ],
    outputs=[gr.Gallery(label="Retrieved Images"), gr.Textbox(label="Image Captions")],
    examples=[["TEM image", 2, None], ["Nanoparticles", 1, None], ["ZnSe-ZnTe core-shell nanowire", 2, None]]
).launch(debug=True)