import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import (
    Compose,
    InterpolationMode,
    Normalize,
    Resize,
    ToTensor,
)
# Load the precomputed style-embedding database (the .npz holds one array).
(ys,) = np.load("embs.npz").values()
# Load the full pickled style-extractor module onto the CPU; weights_only=False
# is required for non-weights checkpoints on newer PyTorch versions.
model = torch.load(
    "style-extractor-v0.2.0.ckpt",
    map_location="cpu",
    weights_only=False,
)
model.eval()  # inference only
# One image URL per database row, in the same order as the embeddings.
with open("urls.txt") as f:
    urls = f.read().splitlines()
assert len(urls) == len(ys)
# Exact (brute-force) L2 index over the embeddings; a flat index needs no
# training, so is_trained is always True. faiss expects float32 input.
d = ys.shape[1]
index = faiss.IndexFlatL2(d)
assert index.is_trained
index.add(ys.astype(np.float32))
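# A quick sanity check of the index (hypothetical, for illustration only):
# because IndexFlatL2 searches exhaustively and exactly, querying with a
# stored vector should return a match at squared L2 distance zero.
#
#   D, I = index.search(ys[:1].astype(np.float32), 3)
#   assert D[0][0] == 0.0  # the stored vector matches itself exactly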
# Preprocessing to match the extractor's expected input: bicubic resize of the
# shorter edge to 336 px, then normalization with ImageNet mean/std.
tf = Compose(
    [
        Resize(
            size=336,
            interpolation=InterpolationMode.BICUBIC,
            max_size=None,
            antialias=True,
        ),
        ToTensor(),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)
def get_emb(im: Image.Image) -> np.ndarray:
    # Embed a single image. faiss needs a float32 numpy array of shape (1, d),
    # so the output tensor is converted; convert("RGB") guards against RGBA
    # or grayscale uploads.
    with torch.no_grad():
        return model(tf(im.convert("RGB")).unsqueeze(0)).cpu().numpy()
# Result gallery layout: 50 matches, 5 per row.
n_outputs = 50
row_size = 5
def f(im):
    # k-nearest-neighbour search: D holds squared L2 distances, I the row
    # indices of the matches, both sorted from best to worst.
    D, I = index.search(get_emb(im), n_outputs)
    return [f"Distance: {d:.2f}\n![]({urls[i]})" for d, i in zip(D[0], I[0])]
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\n"
        "Find artworks with a similar style in a small database "
        "(10k artists × 6 images per artist)."
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="Search")
    # One Markdown slot per result, laid out in rows; -(a // -b) is ceil(a / b).
    outputs = []
    for i in range(-(n_outputs // -row_size)):
        with gr.Row():
            for _ in range(min(row_size, n_outputs - i * row_size)):
                outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
    btn.click(f, img, outputs)
demo.launch()
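# The app assumes `embs.npz` and `urls.txt` were produced offline. A minimal
# sketch of how such a database could be built with the same model and
# transform (hypothetical -- the real indexing script is not part of this
# Space; `paths` and `image_urls` are placeholder inputs):
#
#   paths = [...]       # local image files, one per database row
#   image_urls = [...]  # matching URLs, same order as `paths`
#   embs = np.concatenate(
#       [get_emb(Image.open(p).convert("RGB")) for p in paths]
#   )
#   np.savez("embs.npz", embs)
#   with open("urls.txt", "w") as f:
#       f.write("\n".join(image_urls))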