import os

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from img2art_search.data.dataset import ImageRetrievalDataset
from img2art_search.data.transforms import transform
from img2art_search.models.model import ViTImageSearchModel


def extract_embedding(image_data, fine_tuned_model):
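    """Embed a single preprocessed image tensor with the fine-tuned model."""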
    image = image_data.unsqueeze(0)
    with torch.no_grad():
        embedding = fine_tuned_model(image).cpu().numpy()
    return embedding


def load_fine_tuned_model():
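    """Load the fine-tuned ViT weights and put the model in eval mode."""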
    fine_tuned_model = ViTImageSearchModel()
    # NOTE: if the weights were saved on GPU but inference runs on CPU,
    # pass map_location="cpu" to torch.load.
    fine_tuned_model.load_state_dict(torch.load("results/model.pth"))
    fine_tuned_model.eval()
    return fine_tuned_model


def create_gallery(dataset, fine_tuned_model, save=True):
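    """Embed every image in `dataset` and stack the results into an (N, D) array."""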
    gallery_embeddings = []
    # The dataset yields (image, target) pairs of transformed tensors;
    # only the image is needed here.
    for image, _ in tqdm(dataset):
        embedding = extract_embedding(image, fine_tuned_model)
        gallery_embeddings.append(embedding)
    gallery_embeddings = np.vstack(gallery_embeddings)
    if save:
        # np.save appends the suffix, producing results/embeddings.npy
        np.save("results/embeddings", gallery_embeddings)
    return gallery_embeddings


def search_image(query_image, gallery_embeddings, k=4):
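    """Return indices and distances of the k nearest gallery embeddings to the query."""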
    # `query_image` must already be a preprocessed tensor (see `transform`),
    # not a file path: extract_embedding calls .unsqueeze(0) on it directly.
    fine_tuned_model = load_fine_tuned_model()
    query_embedding = extract_embedding(query_image, fine_tuned_model)
    neighbors = NearestNeighbors(n_neighbors=k, metric="euclidean")
    neighbors.fit(gallery_embeddings)
    distances, indices = neighbors.kneighbors(query_embedding)
    return indices, distances


def create_gallery_embeddings(folder): # noqa
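    """Index every image in `folder` and save the gallery embeddings to disk."""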
    x = np.array([f"{folder}/{file}" for file in os.listdir(folder)])
    # ImageRetrievalDataset appears to expect paired (input, target) paths;
    # for plain indexing, each image is paired with itself.
    gallery_data = np.array([x, x])
    gallery_dataset = ImageRetrievalDataset(gallery_data, transform=transform)
    fine_tuned_model = load_fine_tuned_model()
    create_gallery(gallery_dataset, fine_tuned_model)
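

# Minimal usage sketch (illustrative, not part of the module): index a local
# folder, then query it. "data/gallery" and "query.jpg" are hypothetical
# paths, and `transform` is assumed to accept a PIL image and return a
# model-ready tensor.
if __name__ == "__main__":
    from PIL import Image

    create_gallery_embeddings("data/gallery")
    gallery_embeddings = np.load("results/embeddings.npy")
    query_image = transform(Image.open("query.jpg").convert("RGB"))
    indices, distances = search_image(query_image, gallery_embeddings)
    print(indices, distances)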