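"""Gradio demo for CondViT-B/16: category-conditional fashion image retrieval.

A query image and a target category are embedded together by the model, then
matched against a precomputed FAISS index of gallery images.
"""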
import pandas as pd
import torch
import faiss
import gradio as gr
import base64
from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader

# Load Model
m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()  # evaluation mode: disables dropout for deterministic inference

# Load data
index = faiss.read_index("./artifacts/gallery_index.faiss")  # precomputed gallery embeddings
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")  # gallery images/metadata

tfs = valid_tf((224, 224))  # validation-time preprocessing for query images
K = 5  # number of nearest neighbours to retrieve
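# Assumption: the FAISS index holds embeddings of these same gallery images,
# produced by this model and preprocessing pipeline.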

# Preset (image path, category) query pairs shown in the UI
examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]

@torch.inference_mode()
def retrieval(image, category):
    if image is None or category is None:
        return

    # Embed the query image, conditioned on the selected category index
    q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
    # FAISS expects NumPy input; search returns (distances, indices)
    r = index.search(q_emb.numpy(), K)

    # Render the K nearest gallery images as inline HTML
    imgs = [process_img(idx, gal_imgs) for idx in r[1][0]]
    html = [make_img_html(i) for i in imgs]
    html += ["<p></p>"]  # Avoid Gradio's last-child{margin-bottom:0!important;}
    return "\n".join(html)

# Register custom JavaScript helpers with the page
JavaScriptLoader("src/custom_functions.js")

with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown(
            """
            # Conditional ViT Demo

            [[`Paper`](https://arxiv.org/abs/2306.02928)]
            [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
            [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
            [[`Model`](https://huggingface.co/Slep/CondViT-B16-cat)]

            *Running on 2 vCPU, 16 GB RAM.*

            - **Model:** Categorical CondViT-B/16
            - **Gallery:** 93K images
            """
        )

        # Input section
        with gr.Row():
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                cat = gr.Dropdown(
                    choices=categories,
                    label="Category",
                    value="Upper Body",
                    type="index",  # the callback receives the category's index, not its label
                    elem_id="dropdown",
                )
                submit = gr.Button("Submit")

        # Examples
        gr.Examples(
            examples,
            inputs=[img, cat],
            fn=retrieval,
            elem_id="preset_examples",
            examples_per_page=100,
        )
        gr.HTML(
            value=ExamplesHandler(examples).to_html(),
            label="examples",
            elem_id="html_examples",
        )

        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id="html_output")

    submit.click(fn=retrieval, inputs=[img, cat], outputs=out)
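
# Running this file directly starts the demo and prints a local URL; it assumes
# the ./artifacts files and src/ helpers sit next to it.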
demo.launch()