"""Gradio demo that tags images with SigLIP-based image classification models."""

import os

from PIL import Image

import numpy as np
import torch
from transformers import (
    AutoImageProcessor,
)

import gradio as gr

from modeling_siglip import SiglipForImageClassification

# Read token for the Hugging Face Hub; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_READ_TOKEN")

# Example inputs offered in the UI (one image path per example row).
EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]

# Maps the model name shown in the UI to its Hub repo. The loading loop below
# adds the loaded "model" and "processor" objects to each entry.
model_maps: dict[str, dict] = {
    "test2": {
        "repo": "p1atdev/siglip-tagger-test-2",
    },
    "test3": {
        "repo": "p1atdev/siglip-tagger-test-3",
    },
    # "test4": {
    #     "repo": "p1atdev/siglip-tagger-test-4",
    # },
}

# Load every configured model and its image processor once, at import time,
# so that requests only have to run inference.
for entry in model_maps.values():
    entry["model"] = SiglipForImageClassification.from_pretrained(
        entry["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    entry["processor"] = AutoImageProcessor.from_pretrained(
        entry["repo"], token=HF_TOKEN
    )

# Markdown header of the page; the model list is generated from model_maps.
README_MD = (
    """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.

Model(s):
"""
    + "\n".join(
        f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
        for value in model_maps.values()
    )
    + "\n"
    + """
Example images by NovelAI and niji・journey.
"""
)


def compose_text(results: dict[str, float], threshold: float = 0.3) -> str:
    """Join the tags scoring above *threshold* into a comma-separated string.

    Args:
        results: Mapping of tag name to score.
        threshold: Only tags whose score is strictly greater than this value
            are kept.

    Returns:
        The kept tag names, ordered by descending score, joined with ", ".
        An empty string when no tag passes the threshold.
    """
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)


@torch.no_grad()
def predict_tags(image: Image.Image, model_name: str, threshold: float):
    """Run the selected model on *image* and collect its tag scores.

    Args:
        image: Input image, or None when nothing has been provided.
        model_name: Key of the model to use in ``model_maps``.
        threshold: Minimum score (exclusive) for a tag to appear in the
            returned text.

    Returns:
        A ``(text, scores)`` pair: the comma-separated tags above
        *threshold*, and a dict of every tag whose clipped score is greater
        than zero. ``(None, None)`` when *image* is None.
    """
    if image is None:
        return None, None

    entry = model_maps[model_name]
    model = entry["model"]
    id2label = model.config.id2label

    inputs = entry["processor"](image, return_tensors="pt")
    # Move the inputs to the model's device and dtype before the forward
    # pass, then bring the output back to CPU as float32.
    logits = (
        model(**inputs.to(model.device, model.dtype)).logits.detach().cpu().float()
    )

    # Convert to NumPy explicitly instead of handing a torch tensor to
    # np.clip, which only works through NumPy's implicit fallback path.
    scores = np.clip(logits.numpy(), 0.0, 1.0)

    results = {}
    # Assumes logits is (batch, num_labels); tolist() yields plain floats.
    for prediction in scores.tolist():
        for index, score in enumerate(prediction):
            if score > 0:
                results[id2label[index]] = score

    return compose_text(results, threshold), results


css = """\
.sticky {
  position: sticky;
  top: 16px;
}

.gradio-container {
  overflow: clip;
}
"""


def demo():
    """Build the Gradio interface and launch it."""
    with gr.Blocks(css=css) as ui:
        gr.Markdown(README_MD)

        with gr.Row():
            # Left column: inputs, kept in view by the "sticky" CSS class.
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            model_name_radio = gr.Radio(
                                label="Model",
                                choices=list(model_maps.keys()),
                                value="test3",
                            )
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[input_img],
                            cache_examples=False,
                        )

            # Right column: the joined tag text and the per-tag scores.
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, model_name_radio, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )


if __name__ == "__main__":
    demo()