top001 commited on
Commit
8787fc3
1 Parent(s): 53d40a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -93
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import argparse
2
  import os
 
 
3
 
4
  import gradio as gr
5
  import huggingface_hub
@@ -7,11 +9,13 @@ import numpy as np
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
 
 
 
10
 
11
  TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
13
  Demo for the WaifuDiffusion tagger models
14
-
15
  Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
16
  """
17
 
@@ -29,34 +33,14 @@ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
29
  CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
30
  VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
31
 
32
- # Files to download from the repos
33
  MODEL_FILENAME = "model.onnx"
34
  LABEL_FILENAME = "selected_tags.csv"
35
 
36
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
37
  kaomojis = [
38
- "0_0",
39
- "(o)_(o)",
40
- "+_+",
41
- "+_-",
42
- "._.",
43
- "<o>_<o>",
44
- "<|>_<|>",
45
- "=_=",
46
- ">_<",
47
- "3_3",
48
- "6_9",
49
- ">_o",
50
- "@_@",
51
- "^_^",
52
- "o_o",
53
- "u_u",
54
- "x_x",
55
- "|_|",
56
- "||_||",
57
  ]
58
 
59
-
60
  def parse_args() -> argparse.Namespace:
61
  parser = argparse.ArgumentParser()
62
  parser.add_argument("--score-slider-step", type=float, default=0.05)
@@ -65,7 +49,6 @@ def parse_args() -> argparse.Namespace:
65
  parser.add_argument("--share", action="store_true")
66
  return parser.parse_args()
67
 
68
-
69
  def load_labels(dataframe) -> list[str]:
70
  name_series = dataframe["name"]
71
  name_series = name_series.map(
@@ -78,13 +61,9 @@ def load_labels(dataframe) -> list[str]:
78
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
79
  return tag_names, rating_indexes, general_indexes, character_indexes
80
 
81
-
82
  def mcut_threshold(probs):
83
  """
84
  Maximum Cut Thresholding (MCut)
85
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
86
- for Multi-label Classification. In 11th International Symposium, IDA 2012
87
- (pp. 172-183).
88
  """
89
  sorted_probs = probs[probs.argsort()[::-1]]
90
  difs = sorted_probs[:-1] - sorted_probs[1:]
@@ -92,21 +71,14 @@ def mcut_threshold(probs):
92
  thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
93
  return thresh
94
 
95
-
96
  class Predictor:
97
  def __init__(self):
98
  self.model_target_size = None
99
  self.last_loaded_repo = None
100
-
101
  def download_model(self, model_repo):
102
- csv_path = huggingface_hub.hf_hub_download(
103
- model_repo,
104
- LABEL_FILENAME,
105
- )
106
- model_path = huggingface_hub.hf_hub_download(
107
- model_repo,
108
- MODEL_FILENAME,
109
- )
110
  return csv_path, model_path
111
 
112
  def load_model(self, model_repo):
@@ -114,7 +86,6 @@ class Predictor:
114
  return
115
 
116
  csv_path, model_path = self.download_model(model_repo)
117
-
118
  tags_df = pd.read_csv(csv_path)
119
  sep_tags = load_labels(tags_df)
120
 
@@ -132,12 +103,11 @@ class Predictor:
132
 
133
  def prepare_image(self, image):
134
  target_size = self.model_target_size
135
-
136
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
137
  canvas.alpha_composite(image)
138
  image = canvas.convert("RGB")
139
 
140
- # Pad image to square
141
  image_shape = image.size
142
  max_dim = max(image_shape)
143
  pad_left = (max_dim - image_shape[0]) // 2
@@ -146,47 +116,36 @@ class Predictor:
146
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
147
  padded_image.paste(image, (pad_left, pad_top))
148
 
149
- # Resize
150
  if max_dim != target_size:
151
- padded_image = padded_image.resize(
152
- (target_size, target_size),
153
- Image.BICUBIC,
154
- )
155
 
156
- # Convert to numpy array
157
  image_array = np.asarray(padded_image, dtype=np.float32)
158
-
159
- # Convert PIL-native RGB to BGR
160
  image_array = image_array[:, :, ::-1]
161
-
162
  return np.expand_dims(image_array, axis=0)
163
 
164
  def predict(
165
  self,
166
  image,
167
- model_repo,
168
- general_thresh,
169
- general_mcut_enabled,
170
- character_thresh,
171
- character_mcut_enabled,
172
  ):
173
  self.load_model(model_repo)
174
-
175
  image = self.prepare_image(image)
176
-
177
  input_name = self.model.get_inputs()[0].name
178
  label_name = self.model.get_outputs()[0].name
179
  preds = self.model.run([label_name], {input_name: image})[0]
180
 
181
  labels = list(zip(self.tag_names, preds[0].astype(float)))
182
-
183
- # First 4 labels are actually ratings: pick one with argmax
184
  ratings_names = [labels[i] for i in self.rating_indexes]
185
  rating = dict(ratings_names)
186
 
187
- # Then we have general tags: pick any where prediction confidence > threshold
188
  general_names = [labels[i] for i in self.general_indexes]
189
-
190
  if general_mcut_enabled:
191
  general_probs = np.array([x[1] for x in general_names])
192
  general_thresh = mcut_threshold(general_probs)
@@ -194,9 +153,7 @@ class Predictor:
194
  general_res = [x for x in general_names if x[1] > general_thresh]
195
  general_res = dict(general_res)
196
 
197
- # Everything else is characters: pick any where prediction confidence > threshold
198
  character_names = [labels[i] for i in self.character_indexes]
199
-
200
  if character_mcut_enabled:
201
  character_probs = np.array([x[1] for x in character_names])
202
  character_thresh = mcut_threshold(character_probs)
@@ -205,24 +162,31 @@ class Predictor:
205
  character_res = [x for x in character_names if x[1] > character_thresh]
206
  character_res = dict(character_res)
207
 
208
- sorted_general_strings = sorted(
209
- general_res.items(),
210
- key=lambda x: x[1],
211
- reverse=True,
212
- )
213
- sorted_general_strings = [x[0] for x in sorted_general_strings]
214
- sorted_general_strings = (
215
- ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
216
- )
217
 
218
  return sorted_general_strings, rating, character_res, general_res
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
- def main():
222
  args = parse_args()
223
-
224
- predictor = Predictor()
225
-
226
  dropdown_list = [
227
  SWINV2_MODEL_DSV3_REPO,
228
  CONV_MODEL_DSV3_REPO,
@@ -238,10 +202,8 @@ def main():
238
 
239
  with gr.Blocks(title=TITLE) as demo:
240
  with gr.Column():
241
- gr.Markdown(
242
- value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
243
- )
244
- gr.Markdown(value=DESCRIPTION)
245
  with gr.Row():
246
  with gr.Column(variant="panel"):
247
  image = gr.Image(type="pil", image_mode="RGBA", label="Input")
@@ -297,14 +259,7 @@ def main():
297
  rating = gr.Label(label="Rating")
298
  character_res = gr.Label(label="Output (characters)")
299
  general_res = gr.Label(label="Output (tags)")
300
- clear.add(
301
- [
302
- sorted_general_strings,
303
- rating,
304
- character_res,
305
- general_res,
306
- ]
307
- )
308
 
309
  submit.click(
310
  predictor.predict,
@@ -318,7 +273,7 @@ def main():
318
  ],
319
  outputs=[sorted_general_strings, rating, character_res, general_res],
320
  )
321
-
322
  gr.Examples(
323
  [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
324
  inputs=[
@@ -330,10 +285,12 @@ def main():
330
  character_mcut_enabled,
331
  ],
332
  )
 
 
 
333
 
334
- demo.queue(max_size=10)
335
- demo.launch()
336
-
337
 
338
  if __name__ == "__main__":
339
- main()
 
1
  import argparse
2
  import os
3
+ from typing import Optional
4
+ import io
5
 
6
  import gradio as gr
7
  import huggingface_hub
 
9
  import onnxruntime as rt
10
  import pandas as pd
11
  from PIL import Image
12
+ from fastapi import FastAPI, File, UploadFile, Form
13
+ from fastapi.responses import JSONResponse
14
+ import uvicorn
15
 
16
  TITLE = "WaifuDiffusion Tagger"
17
  DESCRIPTION = """
18
  Demo for the WaifuDiffusion tagger models
 
19
  Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
20
  """
21
 
 
33
  CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
34
  VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
35
 
 
36
  MODEL_FILENAME = "model.onnx"
37
  LABEL_FILENAME = "selected_tags.csv"
38
 
 
39
  kaomojis = [
40
+ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
41
+ "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  ]
43
 
 
44
  def parse_args() -> argparse.Namespace:
45
  parser = argparse.ArgumentParser()
46
  parser.add_argument("--score-slider-step", type=float, default=0.05)
 
49
  parser.add_argument("--share", action="store_true")
50
  return parser.parse_args()
51
 
 
52
  def load_labels(dataframe) -> list[str]:
53
  name_series = dataframe["name"]
54
  name_series = name_series.map(
 
61
  character_indexes = list(np.where(dataframe["category"] == 4)[0])
62
  return tag_names, rating_indexes, general_indexes, character_indexes
63
 
 
64
  def mcut_threshold(probs):
65
  """
66
  Maximum Cut Thresholding (MCut)
 
 
 
67
  """
68
  sorted_probs = probs[probs.argsort()[::-1]]
69
  difs = sorted_probs[:-1] - sorted_probs[1:]
 
71
  thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
72
  return thresh
73
 
 
74
  class Predictor:
75
  def __init__(self):
76
  self.model_target_size = None
77
  self.last_loaded_repo = None
78
+
79
  def download_model(self, model_repo):
80
+ csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
81
+ model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
 
 
 
 
 
 
82
  return csv_path, model_path
83
 
84
  def load_model(self, model_repo):
 
86
  return
87
 
88
  csv_path, model_path = self.download_model(model_repo)
 
89
  tags_df = pd.read_csv(csv_path)
90
  sep_tags = load_labels(tags_df)
91
 
 
103
 
104
  def prepare_image(self, image):
105
  target_size = self.model_target_size
106
+
107
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
108
  canvas.alpha_composite(image)
109
  image = canvas.convert("RGB")
110
 
 
111
  image_shape = image.size
112
  max_dim = max(image_shape)
113
  pad_left = (max_dim - image_shape[0]) // 2
 
116
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
117
  padded_image.paste(image, (pad_left, pad_top))
118
 
 
119
  if max_dim != target_size:
120
+ padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
 
 
 
121
 
 
122
  image_array = np.asarray(padded_image, dtype=np.float32)
 
 
123
  image_array = image_array[:, :, ::-1]
124
+
125
  return np.expand_dims(image_array, axis=0)
126
 
127
  def predict(
128
  self,
129
  image,
130
+ model_repo=SWINV2_MODEL_DSV3_REPO,
131
+ general_thresh=0.35,
132
+ general_mcut_enabled=False,
133
+ character_thresh=0.85,
134
+ character_mcut_enabled=False,
135
  ):
136
  self.load_model(model_repo)
137
+
138
  image = self.prepare_image(image)
 
139
  input_name = self.model.get_inputs()[0].name
140
  label_name = self.model.get_outputs()[0].name
141
  preds = self.model.run([label_name], {input_name: image})[0]
142
 
143
  labels = list(zip(self.tag_names, preds[0].astype(float)))
144
+
 
145
  ratings_names = [labels[i] for i in self.rating_indexes]
146
  rating = dict(ratings_names)
147
 
 
148
  general_names = [labels[i] for i in self.general_indexes]
 
149
  if general_mcut_enabled:
150
  general_probs = np.array([x[1] for x in general_names])
151
  general_thresh = mcut_threshold(general_probs)
 
153
  general_res = [x for x in general_names if x[1] > general_thresh]
154
  general_res = dict(general_res)
155
 
 
156
  character_names = [labels[i] for i in self.character_indexes]
 
157
  if character_mcut_enabled:
158
  character_probs = np.array([x[1] for x in character_names])
159
  character_thresh = mcut_threshold(character_probs)
 
162
  character_res = [x for x in character_names if x[1] > character_thresh]
163
  character_res = dict(character_res)
164
 
165
+ sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
166
+ sorted_general_strings = [x[0] for x in sorted_general]
167
+ sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
 
 
 
 
 
 
168
 
169
  return sorted_general_strings, rating, character_res, general_res
170
 
171
+ predictor = Predictor()
172
+
173
+ @app.post("/tagging")
174
+ async def tagging_endpoint(
175
+ image: UploadFile = File(...),
176
+ threshold: Optional[float] = Form(0.05)
177
+ ):
178
+ image_data = await image.read()
179
+ pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
180
+ sorted_general_strings, _, _, _ = predictor.predict(
181
+ pil_image,
182
+ general_thresh=threshold
183
+ )
184
+ tags = sorted_general_strings.split(", ")
185
+ return JSONResponse(content={"tags": tags})
186
 
187
+ def create_demo():
188
  args = parse_args()
189
+
 
 
190
  dropdown_list = [
191
  SWINV2_MODEL_DSV3_REPO,
192
  CONV_MODEL_DSV3_REPO,
 
202
 
203
  with gr.Blocks(title=TITLE) as demo:
204
  with gr.Column():
205
+ gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
206
+ gr.Markdown(DESCRIPTION)
 
 
207
  with gr.Row():
208
  with gr.Column(variant="panel"):
209
  image = gr.Image(type="pil", image_mode="RGBA", label="Input")
 
259
  rating = gr.Label(label="Rating")
260
  character_res = gr.Label(label="Output (characters)")
261
  general_res = gr.Label(label="Output (tags)")
262
+ clear.add([sorted_general_strings, rating, character_res, general_res])
 
 
 
 
 
 
 
263
 
264
  submit.click(
265
  predictor.predict,
 
273
  ],
274
  outputs=[sorted_general_strings, rating, character_res, general_res],
275
  )
276
+
277
  gr.Examples(
278
  [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
279
  inputs=[
 
285
  character_mcut_enabled,
286
  ],
287
  )
288
+
289
+ demo.queue(max_size=10)
290
+ return demo
291
 
292
+ app = FastAPI()
293
+ app = gr.mount_gradio_app(app, create_demo(), path="/")
 
294
 
295
  if __name__ == "__main__":
296
+ uvicorn.run(app, host="0.0.0.0", port=7860)