# SigLIP_Tagger / app.py

import os
from PIL import Image
import numpy as np
import torch
from transformers import (
    AutoImageProcessor,
)
import gradio as gr
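
# SiglipForImageClassification is defined in the local modeling_siglip.py that
# ships with this Space (a SigLIP backbone with a classification head for tags).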
from modeling_siglip import SiglipForImageClassification
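
# Optional read-only token for fetching the model repos below.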
HF_TOKEN = os.environ.get("HF_READ_TOKEN")
EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
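
# Selectable tagger checkpoints; the loop below attaches a loaded "model" and
# "processor" to each entry.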
model_maps: dict[str, dict] = {
    "test2": {
        "repo": "p1atdev/siglip-tagger-test-2",
    },
    "test3": {
        "repo": "p1atdev/siglip-tagger-test-3",
    },
    # "test4": {
    #     "repo": "p1atdev/siglip-tagger-test-4",
    # },
}
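
# Load every model in bfloat16, together with its image processor, at startup.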
for key in model_maps.keys():
    model_maps[key]["model"] = SiglipForImageClassification.from_pretrained(
        model_maps[key]["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    model_maps[key]["processor"] = AutoImageProcessor.from_pretrained(
        model_maps[key]["repo"], token=HF_TOKEN
    )
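
# Markdown shown at the top of the UI, with a link to each available model repo.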
README_MD = (
    """\
## SigLIP Tagger Test 3

An experimental model for tagging images with Danbooru tags using SigLIP.

Model(s):
"""
    + "\n".join(
        f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
        for value in model_maps.values()
    )
    + "\n"
    + """
Example images by NovelAI and niji・journey.
"""
)


def compose_text(results: dict[str, float], threshold: float = 0.3):
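    """Join the tags scoring above `threshold` into a comma-separated
    string, highest score first."""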
    return ", ".join(
        [
            key
            for key, value in sorted(results.items(), key=lambda x: x[1], reverse=True)
            if value > threshold
        ]
    )


@torch.no_grad()
def predict_tags(image: Image.Image, model_name: str, threshold: float):
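    """Tag `image` with the selected model; returns the comma-separated
    tags above `threshold` plus the full {tag: score} mapping."""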
    if image is None:
        return None, None
    inputs = model_maps[model_name]["processor"](image, return_tensors="pt")
    logits = (
        model_maps[model_name]["model"](
            **inputs.to(
                model_maps[model_name]["model"].device,
                model_maps[model_name]["model"].dtype,
            )
        )
        .logits.detach()
        .cpu()
        .float()
    )
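    # Clip raw logits into [0, 1] and keep every tag with a positive score;
    # the user-facing threshold is applied later, in compose_text.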
    logits = np.clip(logits, 0.0, 1.0)

    results = {}
    for prediction in logits:
        for i, prob in enumerate(prediction):
            if prob.item() > 0:
                results[model_maps[model_name]["model"].config.id2label[i]] = (
                    prob.item()
                )

    return compose_text(results, threshold), results
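
# Styling: keep the input column pinned while the page scrolls.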
css = """\
.sticky {
    position: sticky;
    top: 16px;
}

.gradio-container {
    overflow: clip;
}
"""


def demo():
    with gr.Blocks(css=css) as ui:
        gr.Markdown(README_MD)

        with gr.Row():
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            model_name_radio = gr.Radio(
                                label="Model",
                                choices=list(model_maps.keys()),
                                value="test3",
                            )
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[input_img],
                            cache_examples=False,
                        )

            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, model_name_radio, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )


if __name__ == "__main__":
    demo()