# Page-header residue from the hosting site, kept as a comment: "Spaces: Runtime error".
import argparse | |
import gradio as gr | |
import huggingface_hub | |
import numpy as np | |
import onnxruntime as rt | |
import pandas as pd | |
from PIL import Image | |
# Title used for the browser tab and the page heading of the UI.
TITLE = "Image Tagger"
# Markdown attribution line rendered under the heading.
DESCRIPTION = "Modified from: [SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) (8279aed)"

# Repository ids passed to huggingface_hub.hf_hub_download().
# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models (disabled; the matching dropdown entries in
# main() are commented out as well):
# MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
# SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
# CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
# CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
# VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files to download from the repos
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Emoticon tag names. load_labels() leaves these untouched instead of
# replacing their underscores with spaces.
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the UI defaults.

    Returns:
        The namespace produced by ``argparse`` with three float options
        (slider step and the two default thresholds) and two boolean
        switches (default sort mode and public sharing).
    """
    cli = argparse.ArgumentParser()

    # Float-valued options, registered in the order they appear in --help.
    float_options = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.35),
        ("--score-character-threshold", 0.80),
    )
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)

    # Boolean switches; absent means False.
    for flag in ("--sort-tag-string-by-confidence", "--share"):
        cli.add_argument(flag, action="store_true")

    return cli.parse_args()
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into display names and per-category row indexes.

    Args:
        dataframe: Table with a ``name`` column (tag strings) and a
            ``category`` column (integer category codes).

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``. ``tag_names`` holds every name with underscores
        replaced by spaces, except names listed in ``kaomojis``, which are
        kept verbatim. The three index lists contain the row positions whose
        category code is 9, 0 and 4 respectively.
    """
    # Build the lookup once instead of scanning the list for every tag.
    kaomoji_set = set(kaomojis)
    tag_names = (
        dataframe["name"]
        .map(lambda x: x if x in kaomoji_set else x.replace("_", " "))
        .tolist()
    )

    categories = dataframe["category"]
    # .tolist() yields plain Python ints, matching the declared return type.
    rating_indexes = np.where(categories == 9)[0].tolist()
    general_indexes = np.where(categories == 0)[0].tolist()
    character_indexes = np.where(categories == 4)[0].tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """
    Compute the Maximum Cut (MCut) threshold for an array of scores.

    The scores are ordered from highest to lowest, the largest drop between
    two neighbouring scores is located (the first one on ties), and the
    midpoint of that pair is returned.

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).
    """
    descending = probs[np.argsort(probs)[::-1]]
    upper, lower = descending[:-1], descending[1:]
    cut = np.argmax(upper - lower)
    return (upper[cut] + lower[cut]) / 2
class Predictor:
    """Loads a tagger ONNX model on demand and turns images into tag results.

    The most recently loaded repository is remembered so that repeated
    predictions with the same model skip the download and session setup.
    """

    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None
        # Populated by load_model(); declared here so the attributes always exist.
        self.model = None
        self.tag_names = None
        self.rating_indexes = None
        self.general_indexes = None
        self.character_indexes = None

    def download_model(self, model_repo):
        """Fetch the label CSV and the ONNX file for ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by
            ``huggingface_hub.hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model; no-op if it already is."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tags_df = pd.read_csv(csv_path)
        (
            self.tag_names,
            self.rating_indexes,
            self.general_indexes,
            self.character_indexes,
        ) = load_labels(tags_df)

        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as (batch, height, width, channels);
        # only the height is kept and used for both sides of the square input.
        _, height, _width, _ = model.get_inputs()[0].shape
        self.model_target_size = height

        # Recorded last so a failed load is retried on the next call.
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Convert a PIL image into the model's input array.

        The image is flattened onto a white background, padded to a white
        square, resized to the model's target size, and returned as a
        float32 array of shape ``(1, size, size, 3)`` in BGR channel order.
        Pixel values are not rescaled.
        """
        target_size = self.model_target_size

        # alpha_composite() needs RGBA on both sides; accept other modes too.
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)

        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def tag_dict_to_sorted_string(self, dict_res: dict, sort_by_confidence, descending,
                                  remove_underlines, escape_parens, comma_sep):
        """Render a ``{tag: confidence}`` dict as a single tag string.

        Args:
            dict_res: Mapping of tag name to confidence.
            sort_by_confidence: Sort by confidence when true, otherwise by
                the ``(name, confidence)`` pairs, i.e. alphabetically.
            descending: Reverse the chosen sort order.
            remove_underlines: Keep names as stored (with spaces) when true;
                otherwise spaces are turned back into underscores.
            escape_parens: Backslash-escape ``(`` and ``)`` in the result.
            comma_sep: Join with ``", "`` when true, otherwise with ``" "``.

        Returns:
            The joined tag string (empty for an empty dict).
        """
        sep = ", " if comma_sep else " "

        # key=None makes sorted() compare the (name, confidence) pairs directly.
        sort_key = (lambda item: item[1]) if sort_by_confidence else None
        ordered = sorted(dict_res.items(), key=sort_key, reverse=descending)

        names = [name for name, _ in ordered]
        if not remove_underlines:
            # Add back underlines
            names = [name.replace(" ", "_") for name in names]

        sorted_string = sep.join(names)
        if escape_parens:
            sorted_string = sorted_string.replace("(", "\\(").replace(")", "\\)")
        return sorted_string

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        sort_by_confidence_enabled,
        sort_descending_enabled,
        preset_checkboxgroup
    ):
        """Tag ``image`` with the model from ``model_repo``.

        Args:
            image: PIL image to tag.
            model_repo: Repository id of the model to use.
            general_thresh: Minimum confidence (exclusive) for general tags;
                replaced by the MCut threshold when ``general_mcut_enabled``.
            general_mcut_enabled: Derive the general threshold with MCut.
            character_thresh: Minimum confidence (exclusive) for character
                tags; replaced by ``max(0.15, MCut)`` when
                ``character_mcut_enabled``.
            character_mcut_enabled: Derive the character threshold with MCut.
            sort_by_confidence_enabled: Sort the output strings by confidence.
            sort_descending_enabled: Reverse the sort order.
            preset_checkboxgroup: Selected formatting option indexes
                (0 = remove underlines, 1 = escape parens, 2 = comma separator).

        Returns:
            ``(general_string, character_string, rating, character_res,
            general_res)`` where the last three are ``{name: confidence}`` dicts.

        Raises:
            gr.Error: If no image was supplied.
        """
        # Fail early with a readable message instead of an AttributeError,
        # and before any model download is attempted.
        if image is None:
            raise gr.Error("Please provide an image before submitting.")

        # Decouple the checkgroup status into 3
        selected = preset_checkboxgroup or ()
        remove_underline_enabled, escape_parens_enabled, comma_sep_enabled = (
            index in selected for index in range(3)
        )

        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings: every rating score is returned; no thresholding is applied.
        rating = dict(labels[i] for i in self.rating_indexes)

        # Then we have general tags: pick any where prediction confidence > threshold
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_probs = np.array([x[1] for x in general_names])
            general_thresh = mcut_threshold(general_probs)
        general_res = {
            name: score for name, score in general_names if score > general_thresh
        }

        # Character tags: pick any where prediction confidence > threshold
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([x[1] for x in character_names])
            character_thresh = mcut_threshold(character_probs)
            # Keep a floor so MCut cannot push the character threshold too low.
            character_thresh = max(0.15, character_thresh)
        character_res = {
            name: score for name, score in character_names if score > character_thresh
        }

        format_options = dict(
            sort_by_confidence=sort_by_confidence_enabled,
            descending=sort_descending_enabled,
            remove_underlines=remove_underline_enabled,
            escape_parens=escape_parens_enabled,
            comma_sep=comma_sep_enabled,
        )
        sorted_general_strings = self.tag_dict_to_sorted_string(
            general_res, **format_options
        )
        sorted_character_strings = self.tag_dict_to_sorted_string(
            character_res, **format_options
        )

        return sorted_general_strings, sorted_character_strings, rating, character_res, general_res
def main():
    """Build the Gradio interface for the tagger and launch it.

    Command-line options supply the slider step, the default thresholds, the
    default sort mode and whether a public share link is requested.

    NOTE(review): the source this was restored from had lost its indentation;
    the widget nesting below is reconstructed from statement order. Confirm
    the row/column grouping against the intended layout.
    """
    args = parse_args()

    predictor = Predictor()

    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # MOAT_MODEL_DSV2_REPO,
        # SWIN_MODEL_DSV2_REPO,
        # CONV_MODEL_DSV2_REPO,
        # CONV2_MODEL_DSV2_REPO,
        # VIT_MODEL_DSV2_REPO,
    ]

    # Define widget update functions
    PRESET_CHECKBOX_CHOICES = ["Remove Underlines", "Escape Parens", "Comma Separator"]
    PRESET_CHECKBOX_DICT = {
        "Normal": [PRESET_CHECKBOX_CHOICES[i] for i in [0, 2]],
        "Booru": [],
    }

    def update_preset_checkboxes(preset_radio, preset_checkbox_indices):
        """Change checkboxgroup according to the radio selected preset."""
        # Presets without an entry (i.e. "Custom") keep the current selection.
        current_checks = [PRESET_CHECKBOX_CHOICES[i] for i in preset_checkbox_indices]
        updated_checks = PRESET_CHECKBOX_DICT.get(preset_radio, current_checks)
        return updated_checks

    def update_tag_preset():
        """Whenever the checkboxgroup is manually changed, set preset to 'Custom'."""
        return "Custom"

    with gr.Blocks(title=TITLE, theme=gr.themes.Soft(primary_hue="teal")) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    submit = gr.Button(value="Submit", variant="primary")
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary"
                        )
                with gr.Column(variant="panel"):
                    default_tag_preset = "Normal"
                    with gr.Row():
                        tag_format_preset = gr.Radio(
                            ["Normal", "Booru", "Custom"],
                            value=default_tag_preset,
                            label="Tagging Format Presets"
                        )
                    with gr.Row():
                        # type='index' makes callbacks receive the selected
                        # positions rather than the choice labels.
                        preset_checkboxgroup = gr.CheckboxGroup(
                            choices=PRESET_CHECKBOX_CHOICES,
                            value=PRESET_CHECKBOX_DICT[default_tag_preset],
                            type='index',
                            show_label=False
                        )
                    with gr.Row():
                        sort_by_confidence_enabled = gr.Checkbox(
                            value=args.sort_tag_string_by_confidence,
                            label="Sort By Confidence"
                        )
                        sort_descending_enabled = gr.Checkbox(
                            value=False,
                            label="Descending"
                        )
                    sorted_general_strings = gr.Textbox(
                        label="Output (string)",
                        show_copy_button=True
                    )
                    sorted_character_strings = gr.Textbox(
                        label="Characters (string)",
                        show_copy_button=True
                    )
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # Every output of predict() is registered, including the
                    # character string, so Clear leaves no stale result behind.
                    clear.add(
                        [
                            sorted_general_strings,
                            sorted_character_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        # Update gradio widgets
        tag_format_preset.change(
            fn=update_preset_checkboxes,
            inputs=[tag_format_preset, preset_checkboxgroup],
            outputs=preset_checkboxgroup
        )
        preset_checkboxgroup.input(
            fn=update_tag_preset,
            outputs=tag_format_preset
        )

        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
                sort_by_confidence_enabled,
                sort_descending_enabled,
                preset_checkboxgroup
            ],
            outputs=[sorted_general_strings, sorted_character_strings,
                     rating, character_res, general_res],
        )

    demo.queue(max_size=10)
    demo.launch(share=args.share)
# Start the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()