style-similarity / main.py
gustproof's picture
Create main.py
04573a7 verified
raw
history blame
1.61 kB
import numpy as np
import faiss
import torch
from torchvision.transforms import (
Compose,
Resize,
ToTensor,
Normalize,
InterpolationMode,
)
from PIL import Image
import gradio as gr
# Pre-computed style embeddings of the reference images, one row per image.
# The archive must hold exactly one array; the tuple-unpacking enforces that.
# Using the NpzFile as a context manager closes the underlying zip file once
# the array has been read into memory.
with np.load("embs.npz") as npz:
    (ys,) = npz.values()

# The checkpoint is a pickled model object (it is called as a module elsewhere
# in this file), so a full unpickle is required. `weights_only=False` must be
# explicit because PyTorch >= 2.6 defaults to `weights_only=True`, which
# refuses such checkpoints. Unpickling can execute arbitrary code: only load
# checkpoints from a trusted source.
model = torch.load(
    "style-extractor-v0.2.0.ckpt",
    map_location="cpu",
    weights_only=False,
)

# One URL per line; line k is the image whose embedding is row k of `ys`.
with open("urls.txt", encoding="utf-8") as fh:
    urls = fh.read().splitlines()
# Explicit check (not `assert`) so it still runs under `python -O`.
if len(urls) != len(ys):
    raise ValueError(
        f"urls.txt has {len(urls)} entries but embs.npz has {len(ys)} embeddings"
    )

# Exact (brute-force) L2 nearest-neighbour index over all embeddings. A flat
# index needs no training, so the vectors are added directly.
d = ys.shape[1]
index = faiss.IndexFlatL2(d)
index.add(ys)
# Query preprocessing pipeline:
#   1. bicubic, antialiased resize so the shorter edge becomes 336 px
#      (aspect ratio kept, no crop, no cap on the longer edge);
#   2. conversion of the PIL image to a float tensor;
#   3. per-channel normalisation. The mean/std values match the usual
#      ImageNet statistics -- presumably what the style extractor was
#      trained with; confirm against the model's training setup.
_resize = Resize(
    size=336,
    interpolation=InterpolationMode.BICUBIC,
    max_size=None,
    antialias=True,
)
_normalize = Normalize(
    mean=[0.4850, 0.4560, 0.4060],
    std=[0.2290, 0.2240, 0.2250],
)
tf = Compose([_resize, ToTensor(), _normalize])
def get_emb(im: Image.Image) -> torch.Tensor:
    """Compute the style embedding of a single PIL image.

    Args:
        im: Query image. Any PIL mode is accepted; it is converted to RGB
            before preprocessing.

    Returns:
        The model output for a one-image batch (the batch dimension added
        by ``unsqueeze(0)`` is kept), computed with autograd disabled.
    """
    model.eval()
    # `tf` normalises with three per-channel means/stds, so RGBA, greyscale
    # ("L") or palette ("P") uploads would fail there; force 3 channels first.
    # Converting RGBA to RGB discards the alpha channel.
    batch = tf(im.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        return model(batch)
# How many nearest neighbours are returned per query, and how many result
# panels the UI places on each row.
n_outputs, row_size = 50, 5
def f(im):
    """Find the indexed images whose style is closest to ``im``.

    Args:
        im: Query image from the Gradio image input (``None`` when the user
            has not uploaded anything).

    Returns:
        A list of ``n_outputs`` Markdown strings, nearest match first. Each
        one shows the distance reported by the index (squared L2 for a faiss
        ``IndexFlatL2``) followed by the matched image. Slots with no match
        are empty strings.

    Raises:
        gr.Error: if no query image was supplied.
    """
    if im is None:
        # Clicking "search" without an image passes None; report it in the UI
        # instead of failing deep inside preprocessing.
        raise gr.Error("Please upload a query image first.")
    distances, ids = index.search(get_emb(im), n_outputs)
    # faiss pads results with id -1 when fewer than n_outputs vectors exist;
    # leave those panels empty rather than silently showing urls[-1].
    return [
        f"Distance: {dist}\n![]({urls[idx]})" if idx >= 0 else ""
        for dist, idx in zip(distances[0], ids[0])
    ]
# UI layout: a title, the query image, a search button, and a grid of
# `n_outputs` Markdown panels with `row_size` panels per row. Clicking the
# button runs `f` on the query image; its results fill the panels in order.
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a small database (10k artists * 6img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    outputs = []
    for row_start in range(0, n_outputs, row_size):
        # The final row is shorter when n_outputs is not a multiple of row_size.
        cells_in_row = min(row_size, n_outputs - row_start)
        with gr.Row():
            for col in range(cells_in_row):
                # Panels are labelled #1, #2, ... in result order.
                outputs.append(gr.Markdown(label=f"#{row_start + col + 1}"))
    btn.click(f, img, outputs)
demo.launch()