File size: 2,304 Bytes
875ae89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import warnings
from create_image_embeddnigs import create_embeddings
from download_dataset import download_images
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true" # or "false"
download_images()
# get image embeddings. If file «image_embeddings.npy» exists, just load it, otherwise create it
if os.path.exists("image_embeddings.npy"):
image_embeddings = np.load("image_embeddings.npy")
else:
image_dir = "data/pictures"
batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
image_embeddings = create_embeddings(image_dir, batch_size, device)
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
def get_text_embeddings(input_text):
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
inputs = processor(text=input_text, return_tensors="pt", padding=True, truncation=True)
embeddings = model.get_text_features(**inputs)
vector = embeddings.detach().numpy().ravel()
return vector / np.linalg.norm(vector)
def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def find_similar_images(text_embedding, image_embeddings, top_k=4):
similarities = np.array([cosine_similarity(text_embedding, image_embedding) for image_embedding in image_embeddings])
top_k_indices = np.argsort(similarities)[-top_k:][::-1]
return top_k_indices
def get_similar_images(input_text):
text_embedding = get_text_embeddings(input_text)
top_k_indices = find_similar_images(text_embedding, image_embeddings)
image_paths = [os.path.join("data/pictures", f) for f in os.listdir("data/pictures") if f.endswith(('.png', '.jpg', '.jpeg'))]
similar_images = [image_paths[i] for i in top_k_indices]
return [Image.open(image_path) for image_path in similar_images]
if __name__ == "__main__":
iface = gr.Interface(fn=get_similar_images, inputs="text", outputs="gallery", title="Find Similar Images")
iface.launch(share=True)
|