top001 committed
Commit 0a85798
1 Parent(s): 3d2dbcf

Update app.py

Files changed (1)
  1. app.py +106 -189
app.py CHANGED (previous version shown first, updated version below; unchanged regions are collapsed)
@@ -1,5 +1,7 @@
  import argparse
  import os
 
  import gradio as gr
  import huggingface_hub
@@ -7,21 +9,22 @@ import numpy as np
  import onnxruntime as rt
  import pandas as pd
  from PIL import Image
 
  TITLE = "WaifuDiffusion Tagger"
- DESCRIPTION = """
- Demo for the WaifuDiffusion tagger models
- Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
- """
 
- # Dataset v3 series of models:
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
- # Dataset v2 series of models:
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
  SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
  CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -31,37 +34,8 @@ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
  MODEL_FILENAME = "model.onnx"
  LABEL_FILENAME = "selected_tags.csv"
 
- kaomojis = [
-     "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
-     "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
- ]
-
- def parse_args() -> argparse.Namespace:
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--score-slider-step", type=float, default=0.05)
-     parser.add_argument("--score-general-threshold", type=float, default=0.35)
-     parser.add_argument("--score-character-threshold", type=float, default=0.85)
-     parser.add_argument("--share", action="store_true")
-     return parser.parse_args()
-
- def load_labels(dataframe) -> list[str]:
-     name_series = dataframe["name"]
-     name_series = name_series.map(
-         lambda x: x.replace("_", " ") if x not in kaomojis else x
-     )
-     tag_names = name_series.tolist()
-
-     rating_indexes = list(np.where(dataframe["category"] == 9)[0])
-     general_indexes = list(np.where(dataframe["category"] == 0)[0])
-     character_indexes = list(np.where(dataframe["category"] == 4)[0])
-     return tag_names, rating_indexes, general_indexes, character_indexes
-
- def mcut_threshold(probs):
-     sorted_probs = probs[probs.argsort()[::-1]]
-     difs = sorted_probs[:-1] - sorted_probs[1:]
-     t = difs.argmax()
-     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
-     return thresh
 
  class Predictor:
      def __init__(self):
@@ -79,52 +53,40 @@ class Predictor:
 
          csv_path, model_path = self.download_model(model_repo)
          tags_df = pd.read_csv(csv_path)
-         sep_tags = load_labels(tags_df)
-
-         self.tag_names = sep_tags[0]
-         self.rating_indexes = sep_tags[1]
-         self.general_indexes = sep_tags[2]
-         self.character_indexes = sep_tags[3]
 
-         model = rt.InferenceSession(model_path)
-         _, height, width, _ = model.get_inputs()[0].shape
          self.model_target_size = height
-
          self.last_loaded_repo = model_repo
-         self.model = model
 
      def prepare_image(self, image):
-         target_size = self.model_target_size
-
          canvas = Image.new("RGBA", image.size, (255, 255, 255))
         canvas.alpha_composite(image)
          image = canvas.convert("RGB")
 
-         image_shape = image.size
-         max_dim = max(image_shape)
-         pad_left = (max_dim - image_shape[0]) // 2
-         pad_top = (max_dim - image_shape[1]) // 2
 
          padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
          padded_image.paste(image, (pad_left, pad_top))
 
-         if max_dim != target_size:
-             padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
 
          image_array = np.asarray(padded_image, dtype=np.float32)
          image_array = image_array[:, :, ::-1]
 
          return np.expand_dims(image_array, axis=0)
 
-     def predict(
-         self,
-         image,
-         model_repo,
-         general_thresh,
-         general_mcut_enabled,
-         character_thresh,
-         character_mcut_enabled,
-     ):
          self.load_model(model_repo)
 
          image = self.prepare_image(image)
@@ -133,139 +95,94 @@ class Predictor:
          preds = self.model.run([label_name], {input_name: image})[0]
 
          labels = list(zip(self.tag_names, preds[0].astype(float)))
-
-         ratings_names = [labels[i] for i in self.rating_indexes]
-         rating = dict(ratings_names)
-
          general_names = [labels[i] for i in self.general_indexes]
-         if general_mcut_enabled:
-             general_probs = np.array([x[1] for x in general_names])
-             general_thresh = mcut_threshold(general_probs)
-
-         general_res = [x for x in general_names if x[1] > general_thresh]
          general_res = dict(general_res)
 
-         character_names = [labels[i] for i in self.character_indexes]
-         if character_mcut_enabled:
-             character_probs = np.array([x[1] for x in character_names])
-             character_thresh = mcut_threshold(character_probs)
-             character_thresh = max(0.15, character_thresh)
-
-         character_res = [x for x in character_names if x[1] > character_thresh]
-         character_res = dict(character_res)
-
          sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
-         sorted_general_strings = [x[0] for x in sorted_general]
-         sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
-
-         return sorted_general_strings, rating, character_res, general_res
-
- def main():
-     args = parse_args()
-     predictor = Predictor()
-
-     dropdown_list = [
-         SWINV2_MODEL_DSV3_REPO,
-         CONV_MODEL_DSV3_REPO,
-         VIT_MODEL_DSV3_REPO,
-         VIT_LARGE_MODEL_DSV3_REPO,
-         EVA02_LARGE_MODEL_DSV3_REPO,
-         MOAT_MODEL_DSV2_REPO,
-         SWIN_MODEL_DSV2_REPO,
-         CONV_MODEL_DSV2_REPO,
-         CONV2_MODEL_DSV2_REPO,
-         VIT_MODEL_DSV2_REPO,
-     ]
-
      with gr.Blocks(title=TITLE) as demo:
-         with gr.Column():
-             gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
-             gr.Markdown(DESCRIPTION)
-             with gr.Row():
-                 with gr.Column(variant="panel"):
-                     image = gr.Image(type="pil", image_mode="RGBA", label="Input")
-                     model_repo = gr.Dropdown(
-                         dropdown_list,
-                         value=SWINV2_MODEL_DSV3_REPO,
-                         label="Model",
-                     )
-                     with gr.Row():
-                         general_thresh = gr.Slider(
-                             0,
-                             1,
-                             step=args.score_slider_step,
-                             value=args.score_general_threshold,
-                             label="General Tags Threshold",
-                             scale=3,
-                         )
-                         general_mcut_enabled = gr.Checkbox(
-                             value=False,
-                             label="Use MCut threshold",
-                             scale=1,
-                         )
-                     with gr.Row():
-                         character_thresh = gr.Slider(
-                             0,
-                             1,
-                             step=args.score_slider_step,
-                             value=args.score_character_threshold,
-                             label="Character Tags Threshold",
-                             scale=3,
-                         )
-                         character_mcut_enabled = gr.Checkbox(
-                             value=False,
-                             label="Use MCut threshold",
-                             scale=1,
-                         )
-                     with gr.Row():
-                         clear = gr.ClearButton(
-                             components=[
-                                 image,
-                                 model_repo,
-                                 general_thresh,
-                                 general_mcut_enabled,
-                                 character_thresh,
-                                 character_mcut_enabled,
-                             ],
-                             variant="secondary",
-                             size="lg",
-                         )
-                         submit = gr.Button(value="Submit", variant="primary", size="lg")
-                 with gr.Column(variant="panel"):
-                     sorted_general_strings = gr.Textbox(label="Output (string)")
-                     rating = gr.Label(label="Rating")
-                     character_res = gr.Label(label="Output (characters)")
-                     general_res = gr.Label(label="Output (tags)")
-                     clear.add([sorted_general_strings, rating, character_res, general_res])
 
          submit.click(
-             predictor.predict,
-             inputs=[
-                 image,
-                 model_repo,
-                 general_thresh,
-                 general_mcut_enabled,
-                 character_thresh,
-                 character_mcut_enabled,
-             ],
-             outputs=[sorted_general_strings, rating, character_res, general_res],
          )
-
-         gr.Examples(
-             [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
-             inputs=[
-                 image,
-                 model_repo,
-                 general_thresh,
-                 general_mcut_enabled,
-                 character_thresh,
-                 character_mcut_enabled,
-             ],
-         )
-
      demo.queue(max_size=10)
-
-     demo.launch()
 
  if __name__ == "__main__":
-     main()
 
 
app.py (updated):

  import argparse
  import os
+ from typing import Optional
+ import io
 
  import gradio as gr
  import huggingface_hub
  import onnxruntime as rt
  import pandas as pd
  from PIL import Image
+ from fastapi import FastAPI, File, UploadFile, Form
+ from fastapi.responses import JSONResponse
+
+ app = FastAPI()
 
  TITLE = "WaifuDiffusion Tagger"
+ DESCRIPTION = "Demo for the WaifuDiffusion tagger models"
 
+ # Dataset v3 models
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
+ # Dataset v2 models
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
  SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
  CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
 
  MODEL_FILENAME = "model.onnx"
  LABEL_FILENAME = "selected_tags.csv"
 
+ kaomojis = ["0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
+             "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||"]
 
  class Predictor:
      def __init__(self):
 
          csv_path, model_path = self.download_model(model_repo)
          tags_df = pd.read_csv(csv_path)
+         name_series = tags_df["name"]
+         name_series = name_series.map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
+
+         self.tag_names = name_series.tolist()
+         self.rating_indexes = list(np.where(tags_df["category"] == 9)[0])
+         self.general_indexes = list(np.where(tags_df["category"] == 0)[0])
+         self.character_indexes = list(np.where(tags_df["category"] == 4)[0])
 
+         self.model = rt.InferenceSession(model_path)
+         _, height, width, _ = self.model.get_inputs()[0].shape
          self.model_target_size = height
          self.last_loaded_repo = model_repo
 
      def prepare_image(self, image):
          canvas = Image.new("RGBA", image.size, (255, 255, 255))
          canvas.alpha_composite(image)
          image = canvas.convert("RGB")
 
+         max_dim = max(image.size)
+         pad_left = (max_dim - image.size[0]) // 2
+         pad_top = (max_dim - image.size[1]) // 2
 
          padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
          padded_image.paste(image, (pad_left, pad_top))
 
+         if max_dim != self.model_target_size:
+             padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
 
          image_array = np.asarray(padded_image, dtype=np.float32)
          image_array = image_array[:, :, ::-1]
 
          return np.expand_dims(image_array, axis=0)
 
+     def predict(self, image, model_repo=SWINV2_MODEL_DSV3_REPO, threshold=0.05):
          self.load_model(model_repo)
 
          image = self.prepare_image(image)
 
          preds = self.model.run([label_name], {input_name: image})[0]
 
          labels = list(zip(self.tag_names, preds[0].astype(float)))
          general_names = [labels[i] for i in self.general_indexes]
+         general_res = [x for x in general_names if x[1] > threshold]
          general_res = dict(general_res)
 
          sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
+         return sorted_general, labels
+
+ predictor = Predictor()
+
+ @app.post("/tagging")
+ async def tagging_endpoint(
+     image: UploadFile = File(...),
+     threshold: Optional[float] = Form(0.05)
+ ):
+     image_data = await image.read()
+     pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
+     sorted_general, _ = predictor.predict(pil_image, threshold=threshold)
+     return JSONResponse(content={"tags": [x[0] for x in sorted_general]})
+
+ def ui_predict(
+     image,
+     model_repo,
+     general_thresh,
+     general_mcut_enabled,
+     character_thresh,
+     character_mcut_enabled,
+ ):
+     sorted_general, all_labels = predictor.predict(image, model_repo, general_thresh)
+
+     # Ratings
+     ratings = {all_labels[i][0]: all_labels[i][1] for i in predictor.rating_indexes}
+
+     # Characters
+     character_labels = [all_labels[i] for i in predictor.character_indexes]
+     if character_mcut_enabled:
+         character_probs = np.array([x[1] for x in character_labels])
+         character_thresh = max(0.15, np.mean(character_probs))
+     character_res = {x[0]: x[1] for x in character_labels if x[1] > character_thresh}
+
+     # Format output
+     sorted_general_strings = ", ".join(x[0] for x in sorted_general).replace("(", "\(").replace(")", "\)")
+     return sorted_general_strings, ratings, character_res, dict(sorted_general)
+
+ def create_demo():
      with gr.Blocks(title=TITLE) as demo:
+         gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
+         gr.Markdown(DESCRIPTION)
+
+         with gr.Row():
+             with gr.Column(variant="panel"):
+                 image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+                 model_repo = gr.Dropdown(
+                     choices=[
+                         SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO,
+                         VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO,
+                         EVA02_LARGE_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO,
+                         SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO,
+                         CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO
+                     ],
+                     value=SWINV2_MODEL_DSV3_REPO,
+                     label="Model"
+                 )
+                 with gr.Row():
+                     general_thresh = gr.Slider(0, 1, value=0.35, step=0.05, label="General Tags Threshold")
+                     general_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
+                 with gr.Row():
+                     character_thresh = gr.Slider(0, 1, value=0.85, step=0.05, label="Character Tags Threshold")
+                     character_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
+                 submit = gr.Button(value="Submit", variant="primary")
+
+             with gr.Column(variant="panel"):
+                 text_output = gr.Textbox(label="Output (string)")
+                 rating_output = gr.Label(label="Rating")
+                 character_output = gr.Label(label="Characters")
+                 general_output = gr.Label(label="Tags")
 
          submit.click(
+             ui_predict,
+             inputs=[image, model_repo, general_thresh, general_mcut,
+                     character_thresh, character_mcut],
+             outputs=[text_output, rating_output, character_output, general_output]
          )
+
      demo.queue(max_size=10)
+     return demo
+
+ app = gr.mount_gradio_app(app, create_demo(), path="/")
 
  if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
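
With this commit the Space serves a FastAPI app, with the Gradio UI mounted at "/" and a JSON tagging endpoint at "/tagging". A minimal client sketch follows; it assumes the server is running locally on port 7860 (as in the uvicorn.run call above), that the requests package is installed on the client, and that example.jpg is a placeholder filename. The multipart field "image" and the form field "threshold" come from the endpoint signature.

import requests  # hypothetical client-side script; requests is not part of app.py

# POST an image to the new /tagging endpoint and print the returned tag list.
with open("example.jpg", "rb") as f:  # example.jpg is a placeholder path
    resp = requests.post(
        "http://localhost:7860/tagging",
        files={"image": f},
        data={"threshold": "0.35"},  # parsed by FastAPI into the Form(0.05) parameter
    )
resp.raise_for_status()
print(resp.json()["tags"])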