File size: 3,048 Bytes
c45703e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fe54e9
c45703e
 
7fe54e9
 
c45703e
 
 
 
 
 
 
 
7fe54e9
c45703e
 
 
 
 
 
 
7fe54e9
 
c45703e
 
 
 
7fe54e9
 
c45703e
 
 
 
 
7fe54e9
 
c45703e
 
 
 
 
7fe54e9
 
 
 
 
 
 
c45703e
7fe54e9
c45703e
7fe54e9
 
 
 
 
 
 
 
 
 
 
 
 
c45703e
 
7fe54e9
c45703e
 
 
7fe54e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import pandas as pd
import torch
import faiss
import gradio as gr

import base64

from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader

# Load Model
# Instantiate the categorical CondViT-B/16 and restore CPU weights from the
# checkpoint; eval() disables dropout/batchnorm updates for inference.
m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()

# Load data
# FAISS index over precomputed gallery embeddings, plus the parquet table
# mapping index rows back to gallery images (schema defined in src.process_images).
index = faiss.read_index("./artifacts/gallery_index.faiss")
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")
# Validation-time transform; 224x224 matches the ViT-B/16 input resolution.
tfs = valid_tf((224, 224))

# Number of nearest neighbors returned per query.
K = 5

# Preset demo queries: (image path, category name) pairs. The category strings
# must match entries of `categories` — the dropdown converts them to indices
# (type="index") before they reach `retrieval`.
examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]


@torch.inference_mode()
def retrieval(image, category):
    """Embed the query image conditioned on the category index and return
    the top-K gallery matches rendered as a single HTML string.

    Returns None (leaving the output unchanged) when either input is missing.
    """
    if image is None or category is None:
        return

    # Conditional query embedding for a batch of one.
    batch = tfs(image).unsqueeze(0)
    condition = torch.tensor([category])
    query = m(batch, condition)

    # FAISS returns (distances, indices); we only need the indices of the
    # single query row.
    _, ids = index.search(query, K)

    fragments = []
    for gallery_idx in ids[0]:
        fragments.append(make_img_html(process_img(gallery_idx, gal_imgs)))
    fragments.append("<p></p>")  # Avoid Gradio's last-child{margin-bottom:0!important;}

    return "\n".join(fragments)


# Side-effecting constructor: registers custom JS with the Gradio page.
JavaScriptLoader("src/custom_functions.js")
# Build the demo layout: header markdown, query inputs, preset examples,
# and an HTML results pane wired to `retrieval`.
with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown(
            """
        # Conditional ViT Demo 
        [[`Paper`](https://arxiv.org/abs/2306.02928)] 
        [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)] 
        [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)] 
        [[`Model`](https://huggingface.co/Slep/CondViT-B16-cat)] 


        *Running on 2 vCPU, 16Go RAM.*

        - **Model :** Categorical CondViT-B/16
        - **Gallery :** 93K images.
        """
        )

        # Input section
        with gr.Row():
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                # type="index" makes the dropdown pass the category's integer
                # position (not its label) to `retrieval`.
                cat = gr.Dropdown(
                    choices=categories,
                    label="Category",
                    value="Upper Body",
                    type="index",
                    elem_id="dropdown",
                )
                submit = gr.Button("Submit")

        # Examples
        # Built-in Gradio examples widget, plus a custom HTML rendering of the
        # same presets (presumably styled/scripted via custom_functions.js —
        # TODO confirm against src/examples.py).
        gr.Examples(
            examples,
            inputs=[img, cat],
            fn=retrieval,
            elem_id="preset_examples",
            examples_per_page=100,
        )
        gr.HTML(
            value=ExamplesHandler(examples).to_html(),
            label="examples",
            elem_id="html_examples",
        )

        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id="html_output")

        # Clicking Submit runs retrieval(img, cat) and fills the HTML pane.
        submit.click(fn=retrieval, inputs=[img, cat], outputs=out)

# Start the web server (blocking call).
demo.launch()