# Spaces: Runtime error
from os import getenv | |
from pathlib import Path | |
from typing import Optional | |
import gradio as gr | |
import numpy as np | |
import onnxruntime as rt | |
from PIL import Image | |
from tagger.common import LabelData, load_labels_hf, preprocess_image | |
from tagger.model import create_session | |
# Page title; passed to gr.Blocks(title=...) when the UI is built.
TITLE: str = "WaifuDiffusion Tagger"

# Short blurb describing the app.
DESCRIPTION: str = """
Tag images with the WaifuDiffusion Tagger models!
Primarily used as a backend for a Discord bot.
"""

# Optional Hugging Face access token read from the environment (None when unset).
HF_TOKEN: Optional[str] = getenv("HF_TOKEN")
# Two-level lookup table: model version -> variant name -> Hugging Face repo id.
# The values are themselves dicts, so the mapping is annotated as nested.
MODEL_VARIANTS: dict[str, dict[str, str]] = {
    "v3": {
        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
        "ViT": "SmilingWolf/wd-vit-tagger-v3",
    },
    "v2": {
        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
    },
}
# Build one "<version>-<variant>" key per known model and give each an empty
# (None) slot in the session cache.
cache_keys = [
    f"{version_name}-{variant_name}"
    for version_name, variants in MODEL_VARIANTS.items()
    for variant_name in variants
]
loaded_models: dict[str, Optional[rt.InferenceSession]] = dict.fromkeys(cache_keys)
# Directory containing this file, or the current working directory when
# __file__ is not defined (e.g. when running in ipython).
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()

# File suffixes (lower-case, with leading dot) accepted as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Sorted paths (relative to WORK_DIR) of the image files found directly inside
# WORK_DIR/examples. A missing examples directory yields an empty list instead
# of raising FileNotFoundError while the module is being imported.
_examples_dir = WORK_DIR.joinpath("examples")
example_images = sorted(
    str(path.relative_to(WORK_DIR))
    for path in (_examples_dir.iterdir() if _examples_dir.is_dir() else ())
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)
def load_model(version: str, variant: str) -> rt.InferenceSession:
    """
    Return the inference session for a model, creating it on first use.

    The repo id is looked up in MODEL_VARIANTS; the session is stored in the
    module-level loaded_models cache under the key "<version>-<variant>" so
    later calls reuse it.

    Raises:
        ValueError: if the version/variant pair is not in MODEL_VARIANTS.
    """
    repo_id = MODEL_VARIANTS.get(version, {}).get(variant, None)
    if repo_id is None:
        raise ValueError(f"Unknown model variant: {version}-{variant}")

    key = f"{version}-{variant}"
    session = loaded_models.get(key, None)
    if session is None:
        # first request for this model: create the session and remember it
        session = create_session(repo_id, token=HF_TOKEN)
        loaded_models[key] = session
    return session
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are sorted in descending order, the largest gap between two
    neighbouring scores is located, and the midpoint of that gap is returned.

    Args:
        probs: Scores to threshold (array or array-like; flattened before
            sorting). At least two scores are required.

    Returns:
        The midpoint between the two adjacent sorted scores with the largest
        difference, as a Python float.

    Raises:
        ValueError: if fewer than two scores are supplied, since no gap
            exists to cut at.
    """
    # axis=None flattens, so a plain 1-D array is sorted exactly as before
    ordered = np.sort(np.asarray(probs), axis=None)[::-1]
    if ordered.size < 2:
        raise ValueError(
            f"mcut_threshold requires at least two scores, got {ordered.size}"
        )
    gaps = ordered[:-1] - ordered[1:]
    cut = gaps.argmax()  # first occurrence wins when several gaps tie
    return float((ordered[cut] + ordered[cut + 1]) / 2)
def predict(
    image: Image.Image,
    version: str,
    variant: str,
    gen_threshold: float = 0.35,
    gen_use_mcut: bool = False,
    char_threshold: float = 0.85,
    char_use_mcut: bool = False,
):
    """
    Tag an image with the selected tagger model.

    Args:
        image: Input image.
        version: Model version key in MODEL_VARIANTS.
        variant: Model variant key within that version.
        gen_threshold: Minimum score (exclusive) for a general tag to be kept.
        gen_use_mcut: When true, replace gen_threshold with the MCut threshold
            computed from the general-tag scores.
        char_threshold: Minimum score (exclusive) for a character tag to be kept.
        char_use_mcut: When true, replace char_threshold with the MCut
            threshold computed from the character-tag scores, rounded to two
            decimals.

    Returns:
        A tuple of (preprocessed image, caption, booru-style tag string,
        rating scores, character tags, character threshold used, general tags,
        general threshold used). The tag dicts map name -> score and are
        ordered by descending score.

    Raises:
        ValueError: from load_model when the version/variant pair is unknown.
    """
    model: rt.InferenceSession = load_model(version, variant)
    labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])

    # input size and tensor names come from the model itself
    model_input = model.get_inputs()[0]
    _, h, w, _ = model_input.shape
    input_name = model_input.name
    output_name = model.get_outputs()[0].name

    image = preprocess_image(image, (h, w))

    # Build an N,H,W,C float32 array in BGR channel order, since that is what
    # these models want. The channel axis is reversed with NumPy rather than
    # via Pillow's "BGR;24" mode, which is deprecated in recent Pillow releases.
    rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
    inputs = np.expand_dims(np.ascontiguousarray(rgb[:, :, ::-1]), axis=0)

    # Run the ONNX model
    outputs = model.run([output_name], {input_name: inputs})

    # Pair every label name with its score
    probs = list(zip(labels.names, outputs[0][0].astype(float)))

    # Rating labels are always returned in full, no threshold applied
    rating_labels = {probs[i][0]: probs[i][1] for i in labels.rating}

    # General labels: keep any whose score is strictly above the threshold
    general = [probs[i] for i in labels.general]
    if gen_use_mcut:
        gen_threshold = mcut_threshold(np.array([score for _, score in general]))
    gen_labels = {name: score for name, score in general if score > gen_threshold}
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    # Character labels: same selection rule.
    # NOTE(review): only the character MCut threshold is rounded to 2 decimals;
    # the general one is used unrounded. Kept as-is; confirm this is intended.
    characters = [probs[i] for i in labels.character]
    if char_use_mcut:
        char_threshold = round(mcut_threshold(np.array([score for _, score in characters])), 2)
    char_labels = {name: score for name, score in characters if score > char_threshold}
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # General tags first, then character tags; each group keeps its own
    # descending-score order (the combined list is not re-sorted).
    combined_names = [*gen_labels, *char_labels]

    # Convert to a string suitable for use as a training caption
    caption = ", ".join(combined_names)
    # Booru-style variant: underscores become spaces and parentheses are
    # backslash-escaped. "\\(" is spelled out because "\(" is an invalid
    # escape sequence.
    booru = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
# Extra CSS handed to gr.Blocks(css=...): pads the Max-Cut checkboxes and
# darkens a threshold slider while it carries the "dimmed" class.
css: str = """
#gen_mcut, #char_mcut {
padding-top: var(--scale-3);
}
#gen_threshold.dimmed, #char_threshold.dimmed {
filter: brightness(75%);
}
"""
# Gradio UI definition.
# NOTE(review): the layout nesting below was reconstructed from a copy of this
# file that had lost its indentation; confirm it against the intended layout.
with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: image input, model selection, thresholds, actions
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
                show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
            with gr.Row():
                version = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()),
                    label="Model Version",
                    value="v3",
                    min_width=160,
                    scale=1,
                )
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS[version.value].keys()),
                    label="Model Variant",
                    value="SwinV2",
                    min_width=560,
                )
            with gr.Group():
                with gr.Row():
                    gen_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="General Tag Threshold",
                        scale=5,
                        elem_id="gen_threshold",
                    )
                    gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
                with gr.Row():
                    char_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.85,
                        step=0.01,
                        label="Character Tag Threshold",
                        scale=5,
                        elem_id="char_threshold",
                    )
                    char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
            with gr.Row():
                # components are attached further down via clear.add(...)
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")
        # Right column: outputs
        with gr.Column(min_width=720):
            img_output = gr.Image(
                label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
            )
            with gr.Group():
                caption = gr.Textbox(label="Caption", show_copy_button=True)
                tags = gr.Textbox(label="Tags", show_copy_button=True)
            with gr.Group():
                rating = gr.Label(label="Rating")
            with gr.Group():
                char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                character = gr.Label(label="Character")
            with gr.Group():
                gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                general = gr.Label(label="General")
    with gr.Row():
        # every example image is listed twice: once with fixed thresholds,
        # once with Max-Cut enabled for both tag groups
        example_rows = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]
        examples = gr.Examples(
            examples=example_rows,
            inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
        )

    # tell clear button which components to clear; the Tags box is included so
    # it is reset together with the Caption box
    clear.add([img_input, img_output, caption, tags, rating, character, general])

    # when a model version is picked, offer that version's variants and
    # select the first one
    def on_select_variant(evt: gr.SelectData, selected_version: str):
        if evt.selected:
            choices = list(MODEL_VARIANTS[selected_version])
            return gr.update(choices=choices, value=choices[0])
        return gr.update()

    version.select(on_select_variant, inputs=[version], outputs=[variant])

    # show/hide processed image
    def on_change_show(val: bool):
        return gr.update(visible=val)

    show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])

    # handle mcut thresholding (auto-calculate threshold from probs, disable slider)
    def on_change_mcut(val: bool):
        return (
            gr.update(interactive=not val, elem_classes=["dimmed"] if val else []),
            gr.update(visible=val),
        )

    gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
    char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])

    # NOTE(review): the thresholds returned by predict are written back to the
    # sliders; gen_mcut_out / char_mcut_out are only shown or hidden here and
    # are never given a value. Confirm whether that is intended.
    submit.click(
        predict,
        inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
        outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
        api_name="predict",
    )
if __name__ == "__main__":
    demo.queue(max_size=10)
    # With SPACE_ID set (presumably by the Hugging Face Spaces runtime) the
    # default launch settings are used; otherwise listen on all interfaces
    # on port 7871.
    launch_kwargs = {}
    if getenv("SPACE_ID", None) is None:
        launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7871}
    demo.launch(**launch_kwargs)