Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import argparse
|
2 |
import os
|
3 |
-
from typing import Optional
|
4 |
-
import io
|
5 |
|
6 |
import gradio as gr
|
7 |
import huggingface_hub
|
@@ -9,9 +7,6 @@ import numpy as np
|
|
9 |
import onnxruntime as rt
|
10 |
import pandas as pd
|
11 |
from PIL import Image
|
12 |
-
from fastapi import FastAPI, File, UploadFile, Form
|
13 |
-
from fastapi.responses import JSONResponse
|
14 |
-
import uvicorn
|
15 |
|
16 |
TITLE = "WaifuDiffusion Tagger"
|
17 |
DESCRIPTION = """
|
@@ -62,9 +57,6 @@ def load_labels(dataframe) -> list[str]:
|
|
62 |
return tag_names, rating_indexes, general_indexes, character_indexes
|
63 |
|
64 |
def mcut_threshold(probs):
|
65 |
-
"""
|
66 |
-
Maximum Cut Thresholding (MCut)
|
67 |
-
"""
|
68 |
sorted_probs = probs[probs.argsort()[::-1]]
|
69 |
difs = sorted_probs[:-1] - sorted_probs[1:]
|
70 |
t = difs.argmax()
|
@@ -127,11 +119,11 @@ class Predictor:
|
|
127 |
def predict(
|
128 |
self,
|
129 |
image,
|
130 |
-
model_repo
|
131 |
-
general_thresh
|
132 |
-
general_mcut_enabled
|
133 |
-
character_thresh
|
134 |
-
character_mcut_enabled
|
135 |
):
|
136 |
self.load_model(model_repo)
|
137 |
|
@@ -168,25 +160,10 @@ class Predictor:
|
|
168 |
|
169 |
return sorted_general_strings, rating, character_res, general_res
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
@app.post("/tagging")
|
174 |
-
async def tagging_endpoint(
|
175 |
-
image: UploadFile = File(...),
|
176 |
-
threshold: Optional[float] = Form(0.05)
|
177 |
-
):
|
178 |
-
image_data = await image.read()
|
179 |
-
pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
|
180 |
-
sorted_general_strings, _, _, _ = predictor.predict(
|
181 |
-
pil_image,
|
182 |
-
general_thresh=threshold
|
183 |
-
)
|
184 |
-
tags = sorted_general_strings.split(", ")
|
185 |
-
return JSONResponse(content={"tags": tags})
|
186 |
-
|
187 |
-
def create_demo():
|
188 |
args = parse_args()
|
189 |
-
|
|
|
190 |
dropdown_list = [
|
191 |
SWINV2_MODEL_DSV3_REPO,
|
192 |
CONV_MODEL_DSV3_REPO,
|
@@ -287,10 +264,8 @@ def create_demo():
|
|
287 |
)
|
288 |
|
289 |
demo.queue(max_size=10)
|
290 |
-
|
291 |
-
|
292 |
-
app = FastAPI()
|
293 |
-
app = gr.mount_gradio_app(app, create_demo(), path="/")
|
294 |
|
295 |
if __name__ == "__main__":
|
296 |
-
|
|
|
1 |
import argparse
|
2 |
import os
|
|
|
|
|
3 |
|
4 |
import gradio as gr
|
5 |
import huggingface_hub
|
|
|
7 |
import onnxruntime as rt
|
8 |
import pandas as pd
|
9 |
from PIL import Image
|
|
|
|
|
|
|
10 |
|
11 |
TITLE = "WaifuDiffusion Tagger"
|
12 |
DESCRIPTION = """
|
|
|
57 |
return tag_names, rating_indexes, general_indexes, character_indexes
|
58 |
|
59 |
def mcut_threshold(probs):
|
|
|
|
|
|
|
60 |
sorted_probs = probs[probs.argsort()[::-1]]
|
61 |
difs = sorted_probs[:-1] - sorted_probs[1:]
|
62 |
t = difs.argmax()
|
|
|
119 |
def predict(
|
120 |
self,
|
121 |
image,
|
122 |
+
model_repo,
|
123 |
+
general_thresh,
|
124 |
+
general_mcut_enabled,
|
125 |
+
character_thresh,
|
126 |
+
character_mcut_enabled,
|
127 |
):
|
128 |
self.load_model(model_repo)
|
129 |
|
|
|
160 |
|
161 |
return sorted_general_strings, rating, character_res, general_res
|
162 |
|
163 |
+
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
args = parse_args()
|
165 |
+
predictor = Predictor()
|
166 |
+
|
167 |
dropdown_list = [
|
168 |
SWINV2_MODEL_DSV3_REPO,
|
169 |
CONV_MODEL_DSV3_REPO,
|
|
|
264 |
)
|
265 |
|
266 |
demo.queue(max_size=10)
|
267 |
+
|
268 |
+
demo.launch()
|
|
|
|
|
269 |
|
270 |
if __name__ == "__main__":
|
271 |
+
main()
|