# Hugging Face Spaces app — status: Running
import numpy as np | |
import torch | |
from transformers import ( | |
AutoProcessor, | |
) | |
from PIL import Image | |
import gradio as gr | |
from modeling_siglip import SiglipForImageClassification | |
# Hugging Face Hub ids: the fine-tuned tagger weights, and the processor of the
# upstream SigLIP checkpoint (presumably the tagger's base model) used for
# image preprocessing.
MODEL_NAME = "p1atdev/siglip-tagger-test-3"
PROCESSOR_NAME = "google/siglip-so400m-patch14-384"

# Both objects are loaded once at import time and reused by every
# predict_tags call.
model = SiglipForImageClassification.from_pretrained(MODEL_NAME)
processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
def compose_text(results: dict[str, float], threshold: float = 0.3):
    """Join the tags scoring above ``threshold`` into a single string.

    Args:
        results: Mapping of tag name to score.
        threshold: Exclusive lower bound; a tag whose score equals the
            threshold is left out.

    Returns:
        The kept tags, highest score first, separated by ", ". Tags with
        equal scores keep their insertion order (the sort is stable). An
        empty string is returned when nothing passes the threshold.
    """
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)
def predict_tags(image: Image.Image, threshold: float):
    """Tag ``image`` with the SigLIP tagger and format the result for the UI.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user pressed Start without uploading anything.
        threshold: Exclusive minimum score for a tag to appear in the text
            output; forwarded to ``compose_text``.

    Returns:
        A ``(text, results)`` pair: the comma-separated tag string produced by
        ``compose_text``, and a ``{tag: score}`` dict containing every tag
        whose clamped score is above 0 (used by the ``Label`` output).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty Image input; fail with a readable
        # message instead of an opaque error from the processor.
        raise gr.Error("Please upload an input image first.")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: don't build an autograd graph (which would keep the
    # forward-pass activations alive) for a large vision model.
    with torch.inference_mode():
        logits = model(**inputs.to(model.device, model.dtype)).logits
        # Raw outputs are clamped into [0, 1] and used directly as scores.
        # NOTE(review): no sigmoid is applied, so this assumes the tagger emits
        # scores roughly in that range already — confirm against
        # modeling_siglip / the training setup.
        # Clamp with torch rather than np.clip on a tensor. Cast to float32 so
        # half-precision outputs convert cleanly. A single .tolist() replaces a
        # per-element .item() call over every label.
        scores = logits.float().clamp(0.0, 1.0).cpu().tolist()

    results = {}
    for prediction in scores:  # one row per image in the batch
        for i, prob in enumerate(prediction):
            if prob > 0:
                results[model.config.id2label[i]] = prob
    return compose_text(results, threshold), results
# Extra stylesheet passed to gr.Blocks in demo(). `.sticky` is applied to the
# input row via elem_classes, so that panel stays pinned 16px from the top
# while the page scrolls. `overflow: clip` on the container presumably avoids
# creating a scroll container that would break sticky positioning — confirm.
css = (
    ".sticky {\n"
    "position: sticky;\n"
    "top: 16px;\n"
    "}\n"
    ".gradio-container {\n"
    "overflow: clip;\n"
    "}\n"
)
def demo() -> None:
    """Build the Gradio Blocks UI for the tagger and launch it.

    Layout:
        Left column: image input, threshold slider, Start button and example
            images.
        Right column: the comma-separated tag text and a label widget with
            per-tag scores.

    Clicking Start calls ``predict_tags(image, threshold)``. Its two return
    values feed the two output components in order.

    NOTE(review): the nesting of the ``with`` layout blocks below was
    reconstructed from a copy whose indentation was lost — confirm against
    the deployed app.
    """
    with gr.Blocks(css=css) as ui:
        gr.Markdown(
            """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.
Models:
- (soon)
Example images by NovelAI and niji・journey.
"""
        )
        with gr.Row():
            with gr.Column():
                # The "sticky" class is styled by the module-level `css` string
                # so the input panel stays in view while scrolling.
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        # type="pil" makes predict_tags receive a PIL image.
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )
                        with gr.Group():
                            # Default 0.3 matches compose_text's default threshold.
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )
                        start_btn = gr.Button(value="Start", variant="primary")
                        # Relative paths; presumably resolved against the current
                        # working directory, so launch from the app's folder.
                        gr.Examples(
                            examples=[["./sample.jpg"], ["./sample2.webp"]],
                            inputs=[input_img],
                            cache_examples=False,
                        )
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")
        # predict_tags returns (text, results) -> (output_tags, output_label).
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )
    # Start the server; the public share link is left disabled.
    ui.launch(
        debug=True,
        # share=True
    )
if __name__ == "__main__":
    # Build and launch the UI only when run as a script, not when imported.
    demo()