from os import getenv
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import onnxruntime as rt
from PIL import Image

from tagger.common import LabelData, load_labels, preprocess_image
from tagger.model import create_session

HF_TOKEN = getenv("HF_TOKEN", None)
WORK_DIR = Path.cwd().resolve()

MODEL_VARIANTS: dict[str, str] = {
    "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
    "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
    "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
    "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
    "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
}

# allowed extensions for example images
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# model input shape (the actual size is read from the loaded model at runtime)
IMAGE_SIZE = 448

example_images = sorted(
    [
        str(x.relative_to(WORK_DIR))
        for x in WORK_DIR.joinpath("examples").iterdir()
        if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
    ]
)

# cache of loaded ONNX sessions, one slot per variant
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in MODEL_VARIANTS}


def load_model(variant: str) -> rt.InferenceSession:
    global loaded_models

    # resolve the repo name
    model_repo = MODEL_VARIANTS.get(variant, None)
    if model_repo is None:
        raise ValueError(f"Unknown model variant: {variant}")

    # create the session on first use and cache it
    if loaded_models.get(variant, None) is None:
        loaded_models[variant] = create_session(model_repo, token=HF_TOKEN)

    return loaded_models[variant]


def predict(
    image: Image.Image,
    variant: str,
    general_threshold: float = 0.35,
    character_threshold: float = 0.85,
):
    # load model and labels
    model: rt.InferenceSession = load_model(variant)
    labels: LabelData = load_labels()

    # get input size and tensor names
    _, h, w, _ = model.get_inputs()[0].shape
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name

    # preprocess image to the model's input size
    image = preprocess_image(image, (h, w))

    # turn into a float32 N,H,W,C array in BGR channel order, which these models expect
    # (Pillow 10 removed the "BGR;24" mode, so flip the channel axis with numpy instead)
    inputs = np.asarray(image.convert("RGB"), dtype=np.float32)
    inputs = np.ascontiguousarray(inputs[:, :, ::-1])
    inputs = np.expand_dims(inputs, axis=0)

    # run the ONNX model
    probs = model.run([output_name], {input_name: inputs})

    # pair each label name with its predicted confidence
    probs = list(zip(labels.names, probs[0][0].astype(float)))

    # the first 4 labels are actually ratings
    rating_labels = dict([probs[i] for i in labels.rating])

    # general labels: keep any with confidence above the threshold
    gen_labels = [probs[i] for i in labels.general]
    gen_labels = dict([x for x in gen_labels if x[1] > general_threshold])
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # character labels: keep any with confidence above the threshold
    char_labels = [probs[i] for i in labels.character]
    char_labels = dict([x for x in char_labels if x[1] > character_threshold])
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # combine general and character labels (each already sorted by confidence)
    combined_names = list(gen_labels)
    combined_names.extend(char_labels)

    # convert to a training caption and a booru-style tag string
    # (drop underscores, escape parentheses; note the doubled backslashes, since "\(" is an invalid escape)
    caption = ", ".join(combined_names)
    booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return image, caption, booru, rating_labels, char_labels, gen_labels
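# Hedged usage sketch: predict() can also be called directly from Python, e.g. to
# batch-tag files without the UI. The image path below is hypothetical; substitute
# one of your own files.
#
#   img = Image.open("examples/some_image.png")
#   _, caption, booru, ratings, chars, gens = predict(img, "MOAT", 0.35, 0.85)
#   print(booru)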
value="MOAT") gen_thresh = gr.Slider(0.0, 1.0, value=0.35, label="General Tag Threshold") char_thresh = gr.Slider(0.0, 1.0, value=0.85, label="Character Tag Threshold") show_processed = gr.Checkbox(label="Show Preprocessed", value=False) with gr.Row(): submit = gr.Button(value="Submit", variant="primary", size="lg") clear = gr.ClearButton( components=[], variant="secondary", size="lg", ) with gr.Row(): examples = gr.Examples( examples=[ [imgpath, var, 0.35, 0.85] for imgpath in example_images for var in ["MOAT", "ConvNeXTv2"] ], inputs=[img_input, variant, gen_thresh, char_thresh], ) with gr.Column(): img_output = gr.Image(label="Preprocessed", type="pil", image_mode="RGB", scale=1, visible=False) with gr.Group(): tags_string = gr.Textbox( label="Caption", placeholder="Caption will appear here", show_copy_button=True ) tags_booru = gr.Textbox( label="Tags", placeholder="Tag string will appear here", show_copy_button=True ) rating = gr.Label(label="Rating") character = gr.Label(label="Character") general = gr.Label(label="General") # tell clear button which components to clear clear.add([img_input, img_output, tags_string, rating, character, general]) # show/hide processed image def on_select_show_processed(evt: gr.SelectData): return gr.update(visible=evt.selected) show_processed.select(on_select_show_processed, inputs=[], outputs=[img_output]) submit.click( predict, inputs=[img_input, variant, gen_thresh, char_thresh], outputs=[img_output, tags_string, tags_booru, rating, character, general], api_name="predict", ) if __name__ == "__main__": demo.queue(max_size=10) demo.launch(server_name="0.0.0.0", server_port=7871)