# wd-tagger-v3 / app.py
# Author: CodeChris
# Commit 09d71ca (verified): Add tag string format presets and comma-sep option.
import argparse
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
# Text rendered at the top of the Gradio page.
TITLE = "Image Tagger"
DESCRIPTION = (
    "Modified from: "
    "[SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) "
    "(8279aed)"
)

# Hugging Face Hub repositories of the dataset-v3 tagger models offered in the UI.
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Older dataset-v2 models, currently disabled (also commented out in main()):
# MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
# SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
# CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
# CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
# VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files fetched from each model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
# Kaomoji tags whose underscores are part of the face; load_labels() leaves
# these untouched instead of turning "_" into spaces. List taken from:
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tagger app.

    Float options tune the threshold sliders (step and initial values); the
    boolean flags toggle confidence-sorting of the tag string by default and
    Gradio's public share link.
    """
    cli = argparse.ArgumentParser()
    slider_defaults = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.35),
        ("--score-character-threshold", 0.80),
    )
    for flag, default in slider_defaults:
        cli.add_argument(flag, type=float, default=default)
    for flag in ("--sort-tag-string-by-confidence", "--share"):
        cli.add_argument(flag, action="store_true")
    return cli.parse_args()
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split the label table into display names and per-category row indexes.

    Args:
        dataframe: table read from the model's label CSV; must provide the
            ``name`` and ``category`` columns.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` has every "_" replaced by a space, except for entries in
        the module-level ``kaomojis`` list, which are kept verbatim. Each index
        list holds the row positions whose ``category`` is 9 (rating),
        0 (general) or 4 (character) respectively.
    """
    # Build the lookup once: one O(1) membership test per tag instead of a list scan.
    kaomoji_set = set(kaomojis)
    tag_names = [
        name if name in kaomoji_set else name.replace("_", " ")
        for name in dataframe["name"]
    ]
    categories = dataframe["category"].to_numpy()
    # .tolist() yields plain Python ints, matching the annotated return type.
    rating_indexes = np.flatnonzero(categories == 9).tolist()
    general_indexes = np.flatnonzero(categories == 0).tolist()
    character_indexes = np.flatnonzero(categories == 4).tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are ordered from highest to lowest. The threshold is the
    midpoint of the largest drop between two neighbouring scores.

    Args:
        probs: 1-D numpy array of scores. It needs at least two entries;
            otherwise np.argmax receives an empty array and raises ValueError.

    Returns:
        The threshold as a numpy scalar.
    """
    descending = np.sort(probs)[::-1]
    # np.diff gives (next - current), which is <= 0 for a descending
    # sequence, so the biggest drop is where its negation peaks.
    cut = np.argmax(-np.diff(descending))
    upper = descending[cut]
    lower = descending[cut + 1]
    return (upper + lower) / 2
class Predictor:
    """Load a WD tagger ONNX model on demand and turn its scores into tags.

    Only the most recently requested repository is kept loaded. Calling
    ``predict`` with a different ``model_repo`` downloads and loads that
    model instead.
    """

    def __init__(self):
        # All of these are filled in by load_model(); they are declared here
        # so the instance's attributes are visible in one place.
        self.model_target_size = None
        self.last_loaded_repo = None
        self.model = None
        self.tag_names = []
        self.rating_indexes = []
        self.general_indexes = []
        self.character_indexes = []

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model of ``model_repo``.

        Returns:
            ``(csv_path, model_path)``, the local paths returned by
            ``huggingface_hub.hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model; no-op if it already is."""
        if model_repo == self.last_loaded_repo:
            return
        csv_path, model_path = self.download_model(model_repo)
        tags_df = pd.read_csv(csv_path)
        tag_names, rating_idx, general_idx, character_idx = load_labels(tags_df)
        model = rt.InferenceSession(model_path)
        # Input is (batch, height, width, channels). prepare_image() pads and
        # resizes to a square, so only the height is kept.
        _, height, _, _ = model.get_inputs()[0].shape

        # Commit state only after both labels and session loaded successfully.
        # A failed load then never leaves labels that don't match the model.
        self.tag_names = tag_names
        self.rating_indexes = rating_idx
        self.general_indexes = general_idx
        self.character_indexes = character_idx
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into a model input batch.

        The steps are:
        1. Flatten transparency onto white.
        2. Centre the image on a white square canvas.
        3. Resize to ``model_target_size`` with bicubic resampling.

        Returns:
            A float32 array of shape (1, size, size, 3) in BGR channel order.
        """
        target_size = self.model_target_size
        # alpha_composite requires both operands to be RGBA.
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def tag_dict_to_sorted_string(self, dict_res: dict, sort_by_confidence: bool,
                                  descending: bool, remove_underlines: bool,
                                  escape_parens: bool, comma_sep: bool) -> str:
        """Join the tag names of ``dict_res`` into one prompt-style string.

        Args:
            dict_res: mapping of tag name -> score.
            sort_by_confidence: order by score; otherwise order by tag name.
            descending: reverse the chosen order.
            remove_underlines: keep spaces in names. When False, spaces are
                turned back into "_" (booru style).
            escape_parens: prefix "(" and ")" with a backslash.
            comma_sep: join with ", " instead of a single space.
        """
        sep = ", " if comma_sep else " "
        if sort_by_confidence:
            ordered = sorted(dict_res.items(), key=lambda item: item[1], reverse=descending)
        else:
            ordered = sorted(dict_res.items(), reverse=descending)
        names = [name for name, _ in ordered]
        if not remove_underlines:  # Add back underlines
            names = [name.replace(" ", "_") for name in names]
        result = sep.join(names)
        if escape_parens:
            result = result.replace("(", "\\(").replace(")", "\\)")
        return result

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        sort_by_confidence_enabled,
        sort_descending_enabled,
        preset_checkboxgroup
    ):
        """Tag ``image`` with ``model_repo`` and format the results.

        ``preset_checkboxgroup`` holds the indices of the ticked formatting
        options: 0 = remove underlines, 1 = escape parens, 2 = comma
        separator (the order of PRESET_CHECKBOX_CHOICES in main()).

        Returns:
            ``(general_string, character_string, rating, character_res,
            general_res)``. The last three are dicts mapping tag name to score.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")

        # Decouple the checkbox-group selection into three flags.
        selected = set(preset_checkboxgroup or ())
        remove_underline_enabled = 0 in selected
        escape_parens_enabled = 1 in selected
        comma_sep_enabled = 2 in selected

        self.load_model(model_repo)
        batch = self.prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: batch})[0]
        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings: every rating score is returned unfiltered.
        rating = dict(labels[i] for i in self.rating_indexes)

        # General tags: keep those scoring above the (fixed or MCut) threshold.
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_thresh = mcut_threshold(np.array([p for _, p in general_names]))
        general_res = {name: p for name, p in general_names if p > general_thresh}

        # Character tags: same, but an MCut threshold is floored at 0.15.
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_thresh = mcut_threshold(np.array([p for _, p in character_names]))
            character_thresh = max(0.15, character_thresh)
        character_res = {name: p for name, p in character_names if p > character_thresh}

        format_options = dict(
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            remove_underlines=remove_underline_enabled,
            escape_parens=escape_parens_enabled,
            comma_sep=comma_sep_enabled,
        )
        sorted_general_strings = self.tag_dict_to_sorted_string(general_res, **format_options)
        sorted_character_strings = self.tag_dict_to_sorted_string(character_res, **format_options)
        return sorted_general_strings, sorted_character_strings, rating, character_res, general_res
def main():
    """Parse CLI options, build the Gradio interface and launch it.

    The left panel holds the input image, the model choice, the thresholds
    (fixed slider or MCut) and the Clear button. The right panel holds the
    tag-string format presets, the sort options and the outputs returned by
    ``Predictor.predict``.
    """
    args = parse_args()
    predictor = Predictor()
    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # MOAT_MODEL_DSV2_REPO,
        # SWIN_MODEL_DSV2_REPO,
        # CONV_MODEL_DSV2_REPO,
        # CONV2_MODEL_DSV2_REPO,
        # VIT_MODEL_DSV2_REPO,
    ]

    # Tag-string formatting options. Their order matters: the CheckboxGroup
    # below uses type='index', so predict() receives 0/1/2 for these choices.
    PRESET_CHECKBOX_CHOICES = ["Remove Underlines", "Escape Parens", "Comma Separator"]
    # Boxes ticked by each named preset; "Custom" is deliberately absent.
    PRESET_CHECKBOX_DICT = {
        "Normal": [PRESET_CHECKBOX_CHOICES[i] for i in [0, 2]],
        "Booru": [],
    }
    default_tag_preset = "Normal"

    # Define widget update functions
    def update_preset_checkboxes(preset_radio, preset_checkbox_indices):
        """Change checkboxgroup according to the radio selected preset.

        For "Custom" (not in PRESET_CHECKBOX_DICT) the current ticks are kept.
        """
        current_checks = [
            PRESET_CHECKBOX_CHOICES[i] for i in (preset_checkbox_indices or [])
        ]
        return PRESET_CHECKBOX_DICT.get(preset_radio, current_checks)

    def update_tag_preset():
        """Whenever the checkboxgroup is manually changed, set preset to 'Custom'."""
        return "Custom"

    with gr.Blocks(title=TITLE, theme=gr.themes.Soft(primary_hue="teal")) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary"
                        )
                with gr.Column(variant="panel"):
                    with gr.Row():
                        tag_format_preset = gr.Radio(
                            ["Normal", "Booru", "Custom"],
                            value=default_tag_preset,
                            label="Tagging Format Presets"
                        )
                    with gr.Row():
                        preset_checkboxgroup = gr.CheckboxGroup(
                            choices=PRESET_CHECKBOX_CHOICES,
                            value=PRESET_CHECKBOX_DICT[default_tag_preset],
                            type='index',
                            show_label=False
                        )
                    with gr.Row():
                        sort_by_confidence_enabled = gr.Checkbox(
                            value=args.sort_tag_string_by_confidence,
                            label="Sort By Confidence"
                        )
                        sort_descending_enabled = gr.Checkbox(
                            value=False,
                            label="Descending"
                        )
                    sorted_general_strings = gr.Textbox(
                        label="Output (string)",
                        show_copy_button=True
                    )
                    sorted_character_strings = gr.Textbox(
                        label="Characters (string)",
                        show_copy_button=True
                    )
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # Clear every output component, including the character string.
                    clear.add(
                        [
                            sorted_general_strings,
                            sorted_character_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        # Update gradio widgets
        tag_format_preset.change(
            fn=update_preset_checkboxes,
            inputs=[tag_format_preset, preset_checkboxgroup],
            outputs=preset_checkboxgroup
        )
        # .input fires only on user interaction, so the programmatic update
        # above does not bounce the preset back to "Custom".
        preset_checkboxgroup.input(
            fn=update_tag_preset,
            outputs=tag_format_preset
        )
        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                sort_by_confidence_enabled,
                sort_descending_enabled,
                preset_checkboxgroup
            ],
            outputs=[sorted_general_strings, sorted_character_strings,
                     rating, character_res, general_res],
        )

    demo.queue(max_size=10)
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()