Spaces:
Runtime error
Runtime error
import torchvision.datasets as datasets | |
import numpy as np | |
import clip | |
import torch | |
def get_similiarity(prompt, model_resnet, model_vit, top_k=3): | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
data_dir = 'sample/sample/data' | |
image_arr = np.loadtxt("embeddings.csv", delimiter=",") | |
raw_dataset = datasets.ImageFolder(data_dir) | |
# получите список всех изображений | |
# create transformer-readable tokens | |
inputs = clip.tokenize(prompt).to(device) | |
text_emb = model_resnet.encode_text(inputs) | |
text_emb = text_emb.cpu().detach().numpy() | |
scores = np.dot(text_emb, image_arr.T) | |
# score_vit | |
# get the top k indices for most similar vecs | |
idx = np.argsort(-scores[0])[:top_k] | |
image_files = [] | |
for i in idx: | |
image_files.append(raw_dataset.imgs[i][0]) | |
image_arr_vit = np.loadtxt('embeddings_vit.csv', delimiter=",") | |
inputs_vit = clip.tokenize(prompt).to(device) | |
text_emb_vit = model_vit.encode_text(inputs_vit) | |
text_emb_vit = text_emb_vit.cpu().detach().numpy() | |
scores_vit = np.dot(text_emb_vit, image_arr_vit.T) | |
idx_vit = np.argsort(-scores_vit[0])[:top_k] | |
image_files_vit = [] | |
for i in idx_vit: | |
image_files_vit.append(raw_dataset.imgs[i][0]) | |
return image_files, image_files_vit | |
# def get_text_enc(input_text: str): | |
# text = clip.tokenize([input_text]).to(device) | |
# text_features = model.encode_text(text).cpu() | |
# text_features = text_features.cpu().detach().numpy() | |
# return text_features |