# Source: katerinavr — "update the code" (commit 9a86ed3, unverified)
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import sentence_transformers
from sentence_transformers import SentenceTransformer, util
import pickle
from PIL import Image
import os
from datasets import load_dataset
from huggingface_hub.hf_api import HfFolder
import numpy as np
import torch
import os
from PIL import Image
import io
def convert_to_image(byte_data):
    """Decode a raw byte string into a PIL Image object."""
    buffer = io.BytesIO(byte_data)
    return Image.open(buffer)
# Load the retrieval model and the pre-embedded dataset.
# SentenceTransformer's CLIP wrapper encodes text prompts into the same
# embedding space as the stored image embeddings.
model = SentenceTransformer('clip-ViT-B-32')
# Dataset with precomputed image/text embeddings for microscopy images.
ds_with_embeddings = load_dataset("kvriza8/clip_microscopy_image_text_embeddings")
# Initialize FAISS index once at module load so every query reuses it.
ds_with_embeddings['train'].add_faiss_index(column='img_embeddings')
def get_image_from_text(text_prompt, number_to_retrieve=1):
    """Return the dataset images (and captions) nearest to a text prompt.

    The prompt is embedded with the module-level SentenceTransformer and
    matched against the FAISS index built over ``img_embeddings``.
    """
    query_embedding = model.encode(text_prompt)
    scores, hits = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', query_embedding, k=number_to_retrieve
    )
    # Stored images are raw byte strings; decode each into a PIL Image.
    retrieved_images = [convert_to_image(raw) for raw in hits['images']]
    return retrieved_images, hits['caption_summary']
# Separate HF CLIP model/processor pair used for image-to-image queries
# (the SentenceTransformer above handles text queries).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def get_image_from_image(query_image, number_to_retrieve=1):
    """Return the dataset images (and captions) nearest to a query image.

    Parameters
    ----------
    query_image : array-like
        The uploaded image as a numpy array. May arrive either as floats
        in [0, 1] or as uint8 values in [0, 255] depending on the Gradio
        component configuration.
    number_to_retrieve : int
        Number of nearest neighbours to fetch; must be positive.

    Raises
    ------
    ValueError
        If ``number_to_retrieve`` is not a positive integer.
    """
    if number_to_retrieve <= 0:
        raise ValueError("Number to retrieve must be a positive integer")
    arr = np.asarray(query_image)
    # BUG FIX: the original unconditionally did `(query_image * 255)`,
    # which wraps around (mod 256) when the array is already uint8.
    # Only rescale when the input is floating point.
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
    else:
        arr = arr.astype(np.uint8)
    image = Image.fromarray(arr)
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        image_features = clip_model.get_image_features(**inputs)
    image_features_numpy = image_features.cpu().detach().numpy()
    scores, retrieved_examples = ds_with_embeddings['train'].get_nearest_examples(
        'img_embeddings', image_features_numpy, k=number_to_retrieve
    )
    images = [convert_to_image(img) for img in retrieved_examples['images']]
    captions = retrieved_examples['caption_summary']
    return images, captions
def plot_images(text_prompt="", number_to_retrieve=1, query_image=None):
    """Route a query to image- or text-based retrieval and format results.

    An uploaded image takes priority over the text prompt; with neither
    input supplied, an empty gallery and an explanatory message are
    returned.
    """
    # Guard clause: nothing to search with.
    if query_image is None and not text_prompt:
        return [], "No input provided"

    if query_image is not None:
        retrieved, titles = get_image_from_image(query_image, number_to_retrieve)
    else:
        retrieved, titles = get_image_from_text(text_prompt, number_to_retrieve)

    return retrieved, "\n".join(titles)
# Build and launch the Gradio app.
# FIXES vs. original:
#  - Slider minimum raised from 0 to 1: k=0 is an invalid retrieval count
#    (get_image_from_image raises ValueError on it), so the UI should not
#    offer it. A label and default value are also added.
#  - Each example row now supplies a value for every input component
#    (prompt, count, image); the original rows only had two entries.
iface = gr.Interface(
    title="Microscopy image retrieval",
    fn=plot_images,
    inputs=[
        gr.Textbox(lines=4, label="Insert your prompt", placeholder="Text Here..."),
        gr.Slider(1, 8, step=1, value=1, label="Number of images to retrieve"),
        gr.Image(label="Or Upload an Image")
    ],
    outputs=[gr.Gallery(label="Retrieved Images"), gr.Textbox(label="Image Captions")],
    examples=[
        ["TEM image", 2, None],
        ["Nanoparticles", 1, None],
        ["ZnSe-ZnTe core-shell nanowire", 2, None]
    ]
).launch(debug=True)