import numpy as np
import faiss
import torch
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    Normalize,
    InterpolationMode,
    CenterCrop,
)
from PIL import Image
import gradio as gr

print("starting...")
(ys,) = np.load("embs.npz").values()
print("loaded embs")
model = torch.load(
    "style-extractor-v0.3.0.ckpt",
    map_location="cpu",
)
print("loaded extractor")
with open("urls.txt") as f:
    urls = f.read().splitlines()
print("loaded urls")
assert len(urls) == len(ys)
d = ys.shape[1]
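# HNSW graph index over the raw vectors (32 links per node): approximate
# nearest-neighbor search, with no training pass needed before add().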
index = faiss.IndexHNSWFlat(d, 32)
print("building index")
index.add(ys)
print("index built")


def MyResize(area, d):
    """Resize so the output has ~`area` total pixels and sides divisible by `d`."""

    def resize(im: Image.Image):
        w, h = im.size
        # Uniform scale factor that maps the image to the target area.
        s = (area / w / h) ** 0.5
        # Target width/height measured in multiples of d (rounded down).
        wd, hd = int(s * w / d), int(s * h / d)
        # Relative error between two aspect ratios.
        e = lambda a, b: 1 - min(a, b) / max(a, b)
        # Of the four floor/ceil candidates that still fit within `area`,
        # keep the pixel dimensions closest to the original aspect ratio.
        wd, hd = min(
            (
                (ww * d, hh * d)
                for ww, hh in [(wd + i, hd + j) for i in (0, 1) for j in (0, 1)]
                if ww * d * hh * d <= area
            ),
            key=lambda wh: e(wh[0] / wh[1], w / h),
        )

        # Scale so the image covers the (hd, wd) target, then center-crop the rest.
        return Compose(
            [
                Resize(
                    (int(h * wd / w), wd) if wd / w > hd / h else (hd, int(w * hd / h)),
                    InterpolationMode.BICUBIC,
                ),
                CenterCrop((hd, wd)),
            ]
        )(im)

    return resize


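# Full preprocessing pipeline: resize to at most (518 * 1.3)^2 total pixels with
# both sides divisible by 14 (presumably the model's ViT patch size), then
# normalize with ImageNet mean/std statistics.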
tf = Compose(
    [
        MyResize((518 * 1.3) ** 2, 14),
        ToTensor(),
        Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
)


def get_emb(im: Image.Image):
    model.eval()
    with torch.no_grad():
        # Convert to numpy: faiss's search() takes float32 numpy arrays, not tensors.
        return model(tf(im).unsqueeze(0)).numpy()


n_outputs = 50
row_size = 5


def search(im):
    # Embed the query and fetch the n_outputs nearest neighbors; each hit is
    # rendered as a markdown panel showing its distance and the stored image URL.
    D, I = index.search(get_emb(im), n_outputs)
    return [f"Distance: {d:.1f}\n![]({urls[i]})" for d, i in zip(D[0], I[0])]


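# Minimal gradio UI: a query image input, a search button, and a grid of
# markdown panels that display the results.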
print("preparing gradio")
with gr.Blocks() as demo:
    gr.Markdown(
        "# Style Similarity Search\n\nFind artworks with a similar style from a medium-sized database (10k artists * 30 img/artist)"
    )
    img = gr.Image(type="pil", label="Query", height=500)
    btn = gr.Button(variant="primary", value="search")
    outputs = []
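    # Lay out ceil(n_outputs / row_size) rows of up to row_size result panels.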
    for i in range(-(n_outputs // (-row_size))):
        with gr.Row():
            for _ in range(min(row_size, n_outputs - i * row_size)):
                outputs.append(gr.Markdown(label=f"#{len(outputs) + 1}"))
    btn.click(search, img, outputs)
print("starting gradio")
demo.launch()