File size: 3,451 Bytes
1d409a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
# https://github.com/pythongosssss/ComfyUI-WD14-Tagger/blob/main/wd14tagger.py

# {
#     "wd-v1-4-moat-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2",
#     "wd-v1-4-convnextv2-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
#     "wd-v1-4-convnext-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2",
#     "wd-v1-4-convnext-tagger": "https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger",
#     "wd-v1-4-vit-tagger-v2": "https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2"
# }


import numpy as np
import csv
import onnxruntime as ort

from PIL import Image
from onnxruntime import InferenceSession
from modules.config import path_clip_vision
from modules.model_loader import load_file_from_url


# Process-wide caches, filled lazily by default_interrogator() on first use:
#   global_model - the ONNX InferenceSession for the tagger model
#   global_csv   - the tag CSV rows (header row skipped)
global_model = global_csv = None


def default_interrogator(image_rgb, threshold=0.35, character_threshold=0.85, exclude_tags=""):
    """Tag an RGB image with the WD14 MOAT tagger and return a prompt string.

    The ONNX model and its tag CSV are fetched with ``load_file_from_url``
    into ``path_clip_vision``. The inference session and the parsed CSV rows
    are cached in the module globals ``global_model`` / ``global_csv`` so they
    are only built once per process.

    Args:
        image_rgb: Array accepted by ``PIL.Image.fromarray``; treated as RGB.
        threshold: A general tag is kept when its score is strictly greater
            than this value.
        character_threshold: A character tag is kept when its score is
            strictly greater than this value.
        exclude_tags: Comma-separated tag names to drop from the result.
            Entries are stripped and matched case-insensitively; both the
            underscore spelling used in the CSV (``long_hair``) and the
            space spelling used in the returned string (``long hair``) match.

    Returns:
        A ``", "``-joined string of tag names, character tags first and
        general tags after, with parentheses backslash-escaped and
        underscores replaced by spaces.
    """
    global global_model, global_csv

    model_name = "wd-v1-4-moat-tagger-v2"

    model_onnx_filename = load_file_from_url(
        url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
        model_dir=path_clip_vision,
        file_name=f'{model_name}.onnx',
    )

    model_csv_filename = load_file_from_url(
        url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
        model_dir=path_clip_vision,
        file_name=f'{model_name}.csv',
    )

    # Build the inference session once and reuse it on later calls.
    if global_model is None:
        global_model = InferenceSession(model_onnx_filename, providers=ort.get_available_providers())
    model = global_model

    model_input = model.get_inputs()[0]
    # NOTE(review): assumes the model input is laid out as
    # (batch, height, width, channels) with height == width - confirm.
    side = model_input.shape[1]

    # Scale the longest edge to the model's input size, keeping aspect ratio,
    # then centre the result on a white square canvas.
    image = Image.fromarray(image_rgb)  # RGB
    ratio = float(side) / max(image.size)
    # Clamp to 1 px: for very elongated images the short edge would otherwise
    # truncate to 0 and make Image.resize raise.
    new_size = tuple(max(1, int(x * ratio)) for x in image.size)
    image = image.resize(new_size, Image.LANCZOS)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, ((side - new_size[0]) // 2, (side - new_size[1]) // 2))

    batch = np.array(square).astype(np.float32)
    batch = batch[:, :, ::-1]  # RGB -> BGR
    batch = np.expand_dims(batch, 0)  # add the leading batch axis

    # Parse the tag table once and reuse it on later calls.
    if global_csv is None:
        # newline="" is what the csv module expects; an explicit encoding
        # keeps parsing independent of the platform's locale default.
        with open(model_csv_filename, newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            global_csv = list(reader)
    csv_lines = global_csv

    # Column 1 holds the tag name, column 2 the category code. Record where
    # the first general ("0") and first character ("4") rows appear.
    # NOTE(review): the slicing below assumes rows are grouped by category
    # with general tags before character tags - confirm against the CSV.
    tags = []
    general_index = None
    character_index = None
    for line_num, row in enumerate(csv_lines):
        if general_index is None and row[2] == "0":
            general_index = line_num
        elif character_index is None and row[2] == "4":
            character_index = line_num
        tags.append(row[1])

    # A missing category must yield an empty slice; a None bound would make
    # the slice span the whole list instead.
    if character_index is None:
        character_index = len(tags)
    if general_index is None:
        general_index = character_index

    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {model_input.name: batch})[0]

    scored = list(zip(tags, probs[0]))

    general = [item for item in scored[general_index:character_index] if item[1] > threshold]
    character = [item for item in scored[character_index:] if item[1] > character_threshold]

    # Accept each excluded entry both as typed and with spaces turned into
    # underscores, so names copied from this function's output also match.
    excluded = set()
    for entry in exclude_tags.lower().split(","):
        entry = entry.strip()
        excluded.add(entry)
        excluded.add(entry.replace(" ", "_"))

    kept = [name for name, _ in character + general if name.lower() not in excluded]

    res = ", ".join(name.replace("(", "\\(").replace(")", "\\)") for name in kept).replace('_', ' ')
    return res