|
import argparse |
|
import os |
|
|
|
import gradio as gr |
|
import huggingface_hub |
|
import numpy as np |
|
import onnxruntime as rt |
|
import pandas as pd |
|
from PIL import Image |
|
|
|
# Page title and Markdown blurb shown at the top of the demo.
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Demo for the WaifuDiffusion tagger models

Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""

# Hub access token handed to the model download calls. It is read with
# .get() so that a missing variable yields None instead of raising KeyError
# while the module is being imported.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Tagger repositories whose ids end in "-v3".
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Tagger repositories whose ids end in "-v2".
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files fetched from each model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
|
|
|
|
|
# Emoticon-style tag names. load_labels() keeps these verbatim instead of
# replacing their underscores with spaces.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.",
    "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3",
    "6_9", ">_o", "@_@", "^_^", "o_o",
    "u_u", "x_x", "|_|", "||_||",
]
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``
        and ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    cli = argparse.ArgumentParser()

    # (flag, default) for every float-valued option, in registration order.
    float_options = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.35),
        ("--score-character-threshold", 0.85),
    )
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)

    cli.add_argument("--share", action="store_true")

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tag table into display names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column of tag strings and a
            ``category`` column of integer category codes.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` holds every tag with underscores replaced by spaces,
        except tags listed in ``kaomojis``, which are kept verbatim. The
        three index lists hold the positional row numbers whose category is
        9, 0 and 4 respectively.
    """
    # Build the lookup set once rather than scanning the list for every tag.
    keep_verbatim = frozenset(kaomojis)
    tag_names = [
        name if name in keep_verbatim else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    # Compare on the raw array so the results are positions, not index labels.
    category = dataframe["category"].to_numpy()
    rating_indexes = np.flatnonzero(category == 9).tolist()
    general_indexes = np.flatnonzero(category == 0).tolist()
    character_indexes = np.flatnonzero(category == 4).tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes
|
|
|
|
|
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    The scores are sorted in descending order, the largest gap between two
    neighbouring scores is located, and the midpoint of the two scores
    around that gap is returned.

    Args:
        probs: One-dimensional numpy array (or array-like) of scores.

    Returns:
        The midpoint threshold. With fewer than two scores there is no gap
        to cut at, so 0.0 is returned instead of raising from ``argmax`` on
        an empty array.
    """
    probs = np.asarray(probs)
    if probs.size < 2:
        return 0.0

    sorted_probs = probs[probs.argsort()[::-1]]
    # difs[i] is the drop from the i-th to the (i+1)-th highest score.
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh
|
|
|
|
|
class Predictor:
    """Tag images with a WaifuDiffusion tagger ONNX model.

    The most recently loaded model and its label table are kept on the
    instance, so consecutive predictions with the same repository reuse them.
    """

    def __init__(self):
        # Taken from the ONNX input shape by load_model(); None until then.
        self.model_target_size = None
        # Repository id of the model currently held in self.model.
        self.last_loaded_repo = None
        # Declared up front so every instance has these attributes;
        # load_model() fills them in.
        self.model = None
        self.tag_names = []
        self.rating_indexes = []
        self.general_indexes = []
        self.character_indexes = []

    def download_model(self, model_repo):
        """Fetch the label CSV and the ONNX file of ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` as returned by ``hf_hub_download``.
        """
        # ``token`` is the documented keyword; ``use_auth_token`` is its
        # legacy spelling.
        csv_path = huggingface_hub.hf_hub_download(
            model_repo,
            LABEL_FILENAME,
            token=HF_TOKEN,
        )
        model_path = huggingface_hub.hf_hub_download(
            model_repo,
            MODEL_FILENAME,
            token=HF_TOKEN,
        )
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model, downloading it if needed.

        Does nothing when ``model_repo`` is already the loaded repository.
        """
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            pd.read_csv(csv_path)
        )

        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as four dimensions; the second one is
        # used as the side length of the square model input.
        _, height, _, _ = model.get_inputs()[0].shape

        # Commit the new state only after every step above succeeded, so a
        # failed load never pairs one model's labels with another's session.
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Convert a PIL image into the input batch for the loaded model.

        The image is flattened onto a white background, centred on a white
        square, resized to ``model_target_size`` and returned as a float32
        array with reversed channel order and a leading batch axis.
        """
        target_size = self.model_target_size

        # alpha_composite() only accepts RGBA images, so normalise the mode
        # first; this is a no-op in content for images that are already RGBA.
        image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad the shorter side so the picture sits centred on a square.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize only when the square is not already the model input size.
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )

        image_array = np.asarray(padded_image, dtype=np.float32)

        # Reverse the channel axis (RGB -> BGR).
        image_array = image_array[:, :, ::-1]

        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Tag ``image`` with the model from ``model_repo``.

        Args:
            image: PIL image to tag.
            model_repo: Hub repository id of the tagger model to use.
            general_thresh: Minimum score (exclusive) for general tags.
            general_mcut_enabled: When true, ``general_thresh`` is replaced
                by the MCut threshold of the general tag scores.
            character_thresh: Minimum score (exclusive) for character tags.
            character_mcut_enabled: When true, ``character_thresh`` is
                replaced by the MCut threshold of the character tag scores,
                but never less than 0.15.

        Returns:
            ``(sorted_general_strings, rating, character_res, general_res)``:
            the selected general tags as one comma-separated string ordered
            by descending score with parentheses backslash-escaped, followed
            by name-to-score dicts for ratings, selected character tags and
            selected general tags.
        """
        self.load_model(model_repo)

        image = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        # The batch holds a single image, so its scores are the first row.
        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings are reported in full, without thresholding.
        rating = dict(labels[i] for i in self.rating_indexes)

        general_res = self._select_tags(
            labels,
            self.general_indexes,
            general_thresh,
            general_mcut_enabled,
        )
        character_res = self._select_tags(
            labels,
            self.character_indexes,
            character_thresh,
            character_mcut_enabled,
            min_mcut_thresh=0.15,
        )

        sorted_general_strings = self._format_tag_string(general_res)

        return sorted_general_strings, rating, character_res, general_res

    @staticmethod
    def _select_tags(labels, indexes, thresh, mcut_enabled, min_mcut_thresh=None):
        """Return ``{name: score}`` for the indexed labels scoring above the threshold."""
        candidates = [labels[i] for i in indexes]

        if mcut_enabled:
            thresh = mcut_threshold(np.array([score for _, score in candidates]))
            if min_mcut_thresh is not None:
                thresh = max(min_mcut_thresh, thresh)

        return {name: score for name, score in candidates if score > thresh}

    @staticmethod
    def _format_tag_string(tag_scores):
        """Join tag names by descending score, backslash-escaping parentheses."""
        ranked = sorted(tag_scores.items(), key=lambda item: item[1], reverse=True)
        joined = ", ".join(name for name, _ in ranked)
        # "\\(" spells the two characters backslash + "(" explicitly; the
        # bare "\(" form is an invalid escape sequence in a normal string.
        return joined.replace("(", "\\(").replace(")", "\\)")
|
|
|
|
|
def main():
    """Build the Gradio interface for the tagger and launch it.

    Slider step and default thresholds come from the command-line
    arguments, and ``--share`` is forwarded to ``demo.launch``.
    """
    args = parse_args()

    predictor = Predictor()

    dropdown_list = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]

    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a copy of this file without indentation - confirm the layout.
    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                # Left panel: inputs.
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=SWINV2_MODEL_DSV3_REPO,
                        label="Model",
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label="General Tags Threshold",
                            scale=3,
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0,
                            1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label="Character Tags Threshold",
                            scale=3,
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label="Use MCut threshold",
                            scale=1,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                                model_repo,
                                general_thresh,
                                general_mcut_enabled,
                                character_thresh,
                                character_mcut_enabled,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                # Right panel: outputs.
                with gr.Column(variant="panel"):
                    sorted_general_strings = gr.Textbox(label="Output (string)")
                    rating = gr.Label(label="Rating")
                    character_res = gr.Label(label="Output (characters)")
                    general_res = gr.Label(label="Output (tags)")
                    # The clear button also resets the outputs, which did not
                    # exist yet when the button was created.
                    clear.add(
                        [
                            sorted_general_strings,
                            rating,
                            character_res,
                            general_res,
                        ]
                    )

        submit.click(
            predictor.predict,
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
            outputs=[sorted_general_strings, rating, character_res, general_res],
        )

        gr.Examples(
            [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
            inputs=[
                image,
                model_repo,
                general_thresh,
                general_mcut_enabled,
                character_thresh,
                character_mcut_enabled,
            ],
        )

    demo.queue(max_size=10)
    # Honour --share, which was parsed but previously never used. None (not
    # False) is passed when the flag is absent so Gradio applies its own
    # default exactly as the bare launch() call did.
    demo.launch(share=True if args.share else None)
|
|
|
|
|
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|