Spaces:
Running
Running
File size: 2,978 Bytes
e212637 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import numpy as np
import torch
from transformers import (
AutoProcessor,
)
from PIL import Image
import gradio as gr
from modeling_siglip import SiglipForImageClassification
# Identifier passed to SiglipForImageClassification.from_pretrained below.
MODEL_NAME = "p1atdev/siglip-tagger-test-3"
# Identifier passed to AutoProcessor.from_pretrained for image preprocessing.
PROCESSOR_NAME = "google/siglip-so400m-patch14-384"

# Both objects are created once at import time and reused by every
# predict_tags() call.
model = SiglipForImageClassification.from_pretrained(
    MODEL_NAME,
)
# Compilation is currently disabled:
# model = torch.compile(model)
processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
def compose_text(results: dict[str, float], threshold: float = 0.3):
    """Join the tags scoring strictly above ``threshold`` into one string.

    Tags are listed from highest to lowest score and separated by ``", "``.
    An empty string is returned when no tag clears the threshold.
    """
    # Rank every (tag, score) pair by score, best first.
    ranked = sorted(results.items(), key=lambda pair: pair[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)
@torch.no_grad()
# The annotation is quoted so the `X | None` union is never evaluated at
# definition time.
def predict_tags(image: "Image.Image | None", threshold: float):
    """Tag an image and return the joined tag text plus the per-tag scores.

    Args:
        image: PIL image to tag, or ``None`` when no image was supplied.
        threshold: Score a tag must exceed to be listed in the returned text.

    Returns:
        A ``(text, scores)`` tuple. ``text`` is the comma-separated list of
        tags scoring above ``threshold``, highest first. ``scores`` maps every
        tag whose clamped score is positive to that score. Both are empty
        when ``image`` is ``None``.
    """
    if image is None:
        # Nothing to preprocess: return empty outputs instead of failing
        # inside the processor.
        return "", {}

    inputs = processor(images=image, return_tensors="pt")
    logits = model(**inputs.to(model.device, model.dtype)).logits.cpu()
    # Clamp with torch directly so the tensor never has to round-trip through
    # NumPy.
    scores = logits.clamp(min=0.0, max=1.0)

    id2label = model.config.id2label
    results = {}
    for prediction in scores:
        # tolist() converts the whole row to Python floats in one call.
        for i, prob in enumerate(prediction.tolist()):
            if prob > 0:
                results[id2label[i]] = prob
    return compose_text(results, threshold), results
# Custom stylesheet passed to gr.Blocks in demo().
# `.sticky` is applied to the input row via elem_classes so it stays 16px from
# the top while the page scrolls.
# NOTE(review): `overflow: clip` on the container is presumably there so sticky
# positioning keeps working — confirm.
css = """\
.sticky {
position: sticky;
top: 16px;
}
.gradio-container {
overflow: clip;
}
"""
def demo() -> None:
    """Build the Gradio Blocks UI for the tagger and launch it.

    The Start button calls ``predict_tags`` with the input image and the
    threshold slider value; its two return values fill the text box and the
    label component.
    """
    # NOTE(review): the nesting of the `with` blocks below was reconstructed
    # from a copy of this file that had lost its indentation — confirm layout.
    with gr.Blocks(css=css) as ui:
        gr.Markdown(
            """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.
Models:
- (soon)
Example images by NovelAI and niji・journey.
"""
        )

        with gr.Row():
            # Input side: image, threshold, Start button and example images.
            with gr.Column():
                # This row carries the `sticky` class defined in `css`.
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        # type="pil" makes the callback receive a PIL image.
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            # The default of 0.3 matches compose_text's
                            # default threshold.
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        # No fn is given and caching is off, so an example
                        # only populates the image input.
                        gr.Examples(
                            examples=[["./sample.jpg"], ["./sample2.webp"]],
                            inputs=[input_img],
                            cache_examples=False,
                        )

            # Output side: joined tag text and the per-tag score label.
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        # predict_tags returns (text, scores), matched in order to outputs.
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo()
|