File size: 1,796 Bytes
c4bc1f2
f1a0ba2
c4bc1f2
 
f1a0ba2
 
c4bc1f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from img2art_search.models.model import ViTImageSearchModel
import numpy as np
from sklearn.neighbors import NearestNeighbors
from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from tqdm import tqdm
import os

def extract_embedding(image_data, fine_tuned_model):
    """Compute the embedding of a single image tensor.

    Args:
        image_data: Image tensor without a batch axis; one is added here.
        fine_tuned_model: Callable model mapping a batched tensor to embeddings.

    Returns:
        numpy.ndarray: Embedding with a leading batch dimension of 1.
    """
    batched = image_data.unsqueeze(0)
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        result = fine_tuned_model(batched).cpu().numpy()
    return result


def load_fine_tuned_model(weights_path="results/model.pth"):
    """Load the fine-tuned ViT search model in eval mode.

    Args:
        weights_path: Path to the saved state dict. Defaults to the
            training output location used elsewhere in this module.

    Returns:
        ViTImageSearchModel: Model with fine-tuned weights, in eval mode.
    """
    fine_tuned_model = ViTImageSearchModel()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on
    # CPU-only machines; callers can .to(device) afterwards if needed.
    state_dict = torch.load(weights_path, map_location="cpu")
    fine_tuned_model.load_state_dict(state_dict)
    fine_tuned_model.eval()
    return fine_tuned_model


def create_gallery(dataset, fine_tuned_model, save=True):
    """Embed every image in *dataset* and stack the results row-wise.

    Args:
        dataset: Iterable of (image_tensor, label) pairs; labels are ignored.
        fine_tuned_model: Model used to compute each embedding.
        save: When True, persist the stacked embeddings to
            ``results/embeddings.npy``.

    Returns:
        numpy.ndarray: All gallery embeddings stacked vertically.
    """
    per_image = [
        extract_embedding(image, fine_tuned_model)
        for image, _ in tqdm(dataset)
    ]
    stacked = np.vstack(per_image)
    if save:
        np.save("results/embeddings", stacked)
    return stacked


def search_image(query_image_path, gallery_embeddings, k=4, fine_tuned_model=None):
    """Find the k nearest gallery embeddings to a query image.

    Args:
        query_image_path: Query image tensor (passed to extract_embedding,
            which adds the batch axis).
        gallery_embeddings: 2-D array of gallery embeddings, one row each.
        k: Number of neighbours to return.
        fine_tuned_model: Optional pre-loaded model. When None (the
            backward-compatible default) the model is loaded from disk —
            pass one in when querying repeatedly to avoid reloading the
            weights on every call.

    Returns:
        tuple: (indices, distances) arrays from NearestNeighbors.kneighbors.
    """
    if fine_tuned_model is None:
        fine_tuned_model = load_fine_tuned_model()
    query_embedding = extract_embedding(query_image_path, fine_tuned_model)
    neighbors = NearestNeighbors(n_neighbors=k, metric="euclidean")
    neighbors.fit(gallery_embeddings)
    distances, indices = neighbors.kneighbors(query_embedding)
    return indices, distances


def create_gallery_embeddings(folder):  # noqa
    """Build and persist embeddings for every image file in *folder*.

    Args:
        folder: Directory containing the gallery image files.

    Returns:
        numpy.ndarray: The stacked gallery embeddings (also saved to disk
            by create_gallery).
    """
    # Sort the listing: os.listdir order is arbitrary, and the saved
    # embedding rows must line up deterministically with the file list so
    # indices returned by search can be mapped back to files across runs.
    paths = np.array([f"{folder}/{file}" for file in sorted(os.listdir(folder))])
    # The dataset appears to expect an (inputs, targets) pair; the gallery
    # reuses the same paths for both — TODO confirm against
    # ImageRetrievalDataset's constructor.
    gallery_data = np.array([paths, paths])
    gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
    fine_tuned_model = load_fine_tuned_model()
    return create_gallery(gallery_dataset, fine_tuned_model)