File size: 2,304 Bytes
875ae89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import functools
import os
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor

from create_image_embeddnigs import create_embeddings
from download_dataset import download_images

warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"  # silence HF tokenizers fork warning

# Fetch the image dataset (assumed to be a no-op when already present —
# TODO confirm download_images is idempotent).
download_images()

# Load cached image embeddings if available; otherwise compute AND cache them.
# Bug fix: the original never wrote the .npy file after computing, so the
# "just load it" branch could never be taken on later runs.
if os.path.exists("image_embeddings.npy"):
    image_embeddings = np.load("image_embeddings.npy")
else:
    image_dir = "data/pictures"
    batch_size = 32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    image_embeddings = create_embeddings(image_dir, batch_size, device)
    np.save("image_embeddings.npy", image_embeddings)  # persist so the next run loads instantly

# L2-normalize each row so cosine similarity against a unit query is a plain dot product.
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)

@functools.lru_cache(maxsize=1)
def _load_clip():
    """Load the CLIP model and processor once and memoize them."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor


def get_text_embeddings(input_text):
    """Embed *input_text* with CLIP and return a unit-norm 1-D numpy vector.

    Fix: the original re-instantiated (and potentially re-downloaded) the CLIP
    model and processor on every call; they are now loaded once via _load_clip.

    Args:
        input_text: a string (or list of strings) to embed.

    Returns:
        numpy.ndarray: flattened text features divided by their L2 norm.
    """
    model, processor = _load_clip()
    inputs = processor(text=input_text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        embeddings = model.get_text_features(**inputs)
    vector = embeddings.numpy().ravel()
    return vector / np.linalg.norm(vector)

def cosine_similarity(a, b):
    """Cosine of the angle between vectors *a* and *b* (1.0 = identical direction)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denom

def find_similar_images(text_embedding, image_embeddings, top_k=4):
    """Return indices of the ``top_k`` rows of *image_embeddings* most
    cosine-similar to *text_embedding*, most similar first.

    Vectorized fix: a single matrix-vector product plus norm division replaces
    the original per-row Python loop (identical results, but the arithmetic
    runs in NumPy instead of n Python-level cosine_similarity calls).

    Args:
        text_embedding: 1-D query vector.
        image_embeddings: 2-D array, one embedding per row.
        top_k: number of indices to return.

    Returns:
        numpy.ndarray of ``top_k`` row indices, descending by similarity.
    """
    row_norms = np.linalg.norm(image_embeddings, axis=1)
    query_norm = np.linalg.norm(text_embedding)
    similarities = (image_embeddings @ text_embedding) / (row_norms * query_norm)
    # argsort is ascending: take the last top_k and reverse for descending order.
    return np.argsort(similarities)[-top_k:][::-1]

def get_similar_images(input_text):
    """Return opened PIL images for the top gallery matches of *input_text*.

    NOTE(review): the directory listing below relies on os.listdir order
    matching the row order used when the image embeddings were created —
    confirm create_embeddings enumerates files the same way (ideally both
    sides should sort the filenames).
    """
    query = get_text_embeddings(input_text)
    best = find_similar_images(query, image_embeddings)
    pictures_dir = "data/pictures"
    candidates = [
        os.path.join(pictures_dir, name)
        for name in os.listdir(pictures_dir)
        if name.endswith((".png", ".jpg", ".jpeg"))
    ]
    return [Image.open(candidates[i]) for i in best]


if __name__ == "__main__":
    # Launch the text-to-image search UI with a publicly shareable link.
    demo = gr.Interface(
        fn=get_similar_images,
        inputs="text",
        outputs="gallery",
        title="Find Similar Images",
    )
    demo.launch(share=True)