|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from transformers import CLIPProcessor, CLIPModel |
|
from torch.utils.data import Dataset, DataLoader |
|
import os |
|
import numpy as np |
|
import warnings |
|
from create_image_embeddnigs import create_embeddings |
|
from download_dataset import download_images |
|
|
|
# Suppress noisy UserWarnings emitted by transformers/torch during model loading.
warnings.filterwarnings("ignore", category=UserWarning)

# Let the HuggingFace tokenizers library use parallelism (silences its fork warning).
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Fetch the image dataset to data/pictures if it is not already present (project helper).
download_images()
|
|
|
|
|
# Load precomputed image embeddings from disk if available; otherwise compute
# them from the downloaded images with CLIP (project helper `create_embeddings`).
# NOTE(review): search results assume the row order of this matrix matches the
# order of os.listdir("data/pictures") used later in get_similar_images — verify
# that create_embeddings enumerates files in the same order.
if os.path.exists("image_embeddings.npy"):

    image_embeddings = np.load("image_embeddings.npy")

else:

    image_dir = "data/pictures"

    batch_size = 32

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image_embeddings = create_embeddings(image_dir, batch_size, device)


# L2-normalize each row so a dot product with a unit query vector is a cosine similarity.
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
|
|
def get_text_embeddings(input_text):
    """Embed `input_text` with CLIP and return a unit-norm 1-D numpy vector.

    The CLIP model and processor are loaded lazily on the first call and
    cached on the function object, so repeated queries do not reload the
    weights from disk (the original reloaded them on every call).

    Args:
        input_text: query string (or list of strings; output is flattened).

    Returns:
        1-D float numpy array normalized to unit L2 norm.
    """
    if not hasattr(get_text_embeddings, "_model"):
        # One-time load; subsequent calls reuse the cached objects.
        get_text_embeddings._model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        get_text_embeddings._processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    model = get_text_embeddings._model
    processor = get_text_embeddings._processor

    inputs = processor(text=input_text, return_tensors="pt", padding=True, truncation=True)

    # Inference only — no_grad avoids building an autograd graph.
    with torch.no_grad():
        embeddings = model.get_text_features(**inputs)

    vector = embeddings.detach().cpu().numpy().ravel()
    return vector / np.linalg.norm(vector)
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity between vectors `a` and `b`.

    Computed as dot(a, b) divided by the product of the L2 norms.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom
|
|
|
def find_similar_images(text_embedding, image_embeddings, top_k=4):
    """Return the indices of the `top_k` images most similar to the query.

    Replaces the original per-row Python loop with a single vectorized
    matrix-vector product; results are identical but computed in native code.

    Args:
        text_embedding: 1-D query vector.
        image_embeddings: 2-D array, one embedding per row.
        top_k: number of indices to return (clamped to the number of rows).

    Returns:
        Array of row indices sorted by descending cosine similarity.
    """
    # Cosine similarity of the query against every row at once.
    norms = np.linalg.norm(image_embeddings, axis=1) * np.linalg.norm(text_embedding)
    similarities = (image_embeddings @ text_embedding) / norms

    # argsort ascending, take the last top_k, reverse to descending order.
    # (Slicing already clamps when top_k exceeds the number of rows.)
    return np.argsort(similarities)[-top_k:][::-1]
|
|
|
def get_similar_images(input_text):
    """Return the PIL images most similar to the text query.

    Embeds the query, ranks the precomputed image embeddings, and opens the
    matching files from data/pictures.

    NOTE(review): this maps embedding-row indices onto an unsorted
    os.listdir() listing — correct only if create_embeddings processed the
    files in the same order. Verify against that helper.
    """
    query_vector = get_text_embeddings(input_text)
    best_indices = find_similar_images(query_vector, image_embeddings)

    folder = "data/pictures"
    candidates = []
    for fname in os.listdir(folder):
        if fname.endswith(('.png', '.jpg', '.jpeg')):
            candidates.append(os.path.join(folder, fname))

    return [Image.open(candidates[i]) for i in best_indices]
|
|
|
|
|
# Script entry point: launch the Gradio text-to-image-search demo.
if __name__ == "__main__":

    # Text box in, image gallery out; get_similar_images does the search.
    iface = gr.Interface(fn=get_similar_images, inputs="text", outputs="gallery", title="Find Similar Images")

    # share=True creates a temporary public URL in addition to the local server.
    iface.launch(share=True)
|
|