import argparse
import io
import os
from typing import Optional

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image

TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Demo for the WaifuDiffusion tagger models

Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files downloaded from each model repo
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the kaomoji itself and must not be
# replaced with spaces
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def load_labels(
    dataframe: pd.DataFrame,
) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split the label CSV into tag names and per-category index lists."""
    name_series = dataframe["name"]
    name_series = name_series.map(
        lambda x: x.replace("_", " ") if x not in kaomojis else x
    )
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
    general_indexes = list(np.where(dataframe["category"] == 0)[0])
    character_indexes = list(np.where(dataframe["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut): place the threshold at the midpoint of
    the largest gap between consecutive sorted probabilities.
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh


class Predictor:
    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        # Skip reloading if the requested model is already in memory
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)

        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]

        model = rt.InferenceSession(model_path)
        _, height, width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        target_size = self.model_target_size

        # Flatten any transparency onto a white background
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad to a square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize to the model's expected input size
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size), Image.BICUBIC
            )

        # Convert to float32 and swap RGB -> BGR channel order
        image_array = np.asarray(padded_image, dtype=np.float32)
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo=SWINV2_MODEL_DSV3_REPO,
        general_thresh=0.35,
        general_mcut_enabled=False,
        character_thresh=0.85,
        character_mcut_enabled=False,
    ):
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Rating tags (category 9)
        ratings_names = [labels[i] for i in self.rating_indexes]
        rating = dict(ratings_names)

        # General tags (category 0): keep those above the threshold
        general_names = [labels[i] for i in self.general_indexes]

        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)

        general_res = [x for x in general_names if x[1] > general_thresh]
        general_res = dict(general_res)

        # Character tags (category 4): keep those above the threshold,
        # with a 0.15 floor when MCut is enabled
        character_names = [labels[i] for i in self.character_indexes]

        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            character_thresh = max(0.15, character_thresh)

        character_res = [x for x in character_names if x[1] > character_thresh]
        character_res = dict(character_res)

        # Build a comma-separated tag string, escaping parentheses for
        # prompt-style consumers
        sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
        sorted_general_strings = [x[0] for x in sorted_general]
        sorted_general_strings = (
            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
        )

        return sorted_general_strings, rating, character_res, general_res


predictor = Predictor()

# The FastAPI app must exist before the endpoint below is registered;
# the Gradio UI is mounted onto it at the bottom of the file.
app = FastAPI()


@app.post("/tagging")
async def tagging_endpoint(
    image: UploadFile = File(...),
    threshold: Optional[float] = Form(0.05),
):
    image_data = await image.read()
    pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
    sorted_general_strings, _, _, _ = predictor.predict(
        pil_image, general_thresh=threshold
    )
    tags = sorted_general_strings.split(", ")
    return JSONResponse(content={"tags": tags})


def create_demo():
    args = parse_args()

    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():

            gr.Markdown(f"# {TITLE}")
            gr.Markdown(DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")

                with gr.Column(variant="panel"):
                    sorted_general_strings = gr.Textbox(label="Output (string)")
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    clear.add(
                        [sorted_general_strings, rating, character_res, general_res]
                    )

            submit.click(
                predictor.predict,
                inputs=[
                    image,
                    model_repo,
                    general_thresh,
                    general_mcut_enabled,
                    character_thresh,
                    character_mcut_enabled,
                ],
                outputs=[sorted_general_strings, rating, character_res, general_res],
            )

            gr.Examples(
                [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
                inputs=[
                    image,
                    model_repo,
                    general_thresh,
                    general_mcut_enabled,
                    character_thresh,
                    character_mcut_enabled,
                ],
            )

    demo.queue(max_size=10)
    return demo


# Mount the Gradio UI onto the FastAPI app so the web demo ("/") and the
# JSON endpoint ("/tagging") are served by the same server.
app = gr.mount_gradio_app(app, create_demo(), path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)