top001 commited on
Commit
3d2dbcf
1 Parent(s): 8787fc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -36
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import argparse
2
  import os
3
- from typing import Optional
4
- import io
5
 
6
  import gradio as gr
7
  import huggingface_hub
@@ -9,9 +7,6 @@ import numpy as np
9
  import onnxruntime as rt
10
  import pandas as pd
11
  from PIL import Image
12
- from fastapi import FastAPI, File, UploadFile, Form
13
- from fastapi.responses import JSONResponse
14
- import uvicorn
15
 
16
  TITLE = "WaifuDiffusion Tagger"
17
  DESCRIPTION = """
@@ -62,9 +57,6 @@ def load_labels(dataframe) -> list[str]:
62
  return tag_names, rating_indexes, general_indexes, character_indexes
63
 
64
  def mcut_threshold(probs):
65
- """
66
- Maximum Cut Thresholding (MCut)
67
- """
68
  sorted_probs = probs[probs.argsort()[::-1]]
69
  difs = sorted_probs[:-1] - sorted_probs[1:]
70
  t = difs.argmax()
@@ -127,11 +119,11 @@ class Predictor:
127
  def predict(
128
  self,
129
  image,
130
- model_repo=SWINV2_MODEL_DSV3_REPO,
131
- general_thresh=0.35,
132
- general_mcut_enabled=False,
133
- character_thresh=0.85,
134
- character_mcut_enabled=False,
135
  ):
136
  self.load_model(model_repo)
137
 
@@ -168,25 +160,10 @@ class Predictor:
168
 
169
  return sorted_general_strings, rating, character_res, general_res
170
 
171
- predictor = Predictor()
172
-
173
- @app.post("/tagging")
174
- async def tagging_endpoint(
175
- image: UploadFile = File(...),
176
- threshold: Optional[float] = Form(0.05)
177
- ):
178
- image_data = await image.read()
179
- pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
180
- sorted_general_strings, _, _, _ = predictor.predict(
181
- pil_image,
182
- general_thresh=threshold
183
- )
184
- tags = sorted_general_strings.split(", ")
185
- return JSONResponse(content={"tags": tags})
186
-
187
- def create_demo():
188
  args = parse_args()
189
-
 
190
  dropdown_list = [
191
  SWINV2_MODEL_DSV3_REPO,
192
  CONV_MODEL_DSV3_REPO,
@@ -287,10 +264,8 @@ def create_demo():
287
  )
288
 
289
  demo.queue(max_size=10)
290
- return demo
291
-
292
- app = FastAPI()
293
- app = gr.mount_gradio_app(app, create_demo(), path="/")
294
 
295
  if __name__ == "__main__":
296
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import argparse
2
  import os
 
 
3
 
4
  import gradio as gr
5
  import huggingface_hub
 
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
 
 
 
10
 
11
  TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
 
57
  return tag_names, rating_indexes, general_indexes, character_indexes
58
 
59
  def mcut_threshold(probs):
 
 
 
60
  sorted_probs = probs[probs.argsort()[::-1]]
61
  difs = sorted_probs[:-1] - sorted_probs[1:]
62
  t = difs.argmax()
 
119
  def predict(
120
  self,
121
  image,
122
+ model_repo,
123
+ general_thresh,
124
+ general_mcut_enabled,
125
+ character_thresh,
126
+ character_mcut_enabled,
127
  ):
128
  self.load_model(model_repo)
129
 
 
160
 
161
  return sorted_general_strings, rating, character_res, general_res
162
 
163
+ def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  args = parse_args()
165
+ predictor = Predictor()
166
+
167
  dropdown_list = [
168
  SWINV2_MODEL_DSV3_REPO,
169
  CONV_MODEL_DSV3_REPO,
 
264
  )
265
 
266
  demo.queue(max_size=10)
267
+
268
+ demo.launch()
 
 
269
 
270
  if __name__ == "__main__":
271
+ main()