SmilingWolf
commited on
Commit
·
f6dbb10
1
Parent(s):
f56e0f7
Add support for model selection
Browse files
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
|
|
4 |
colorFrom: blue
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
duplicated_from: NoCrypt/DeepDanbooru_string
|
|
|
4 |
colorFrom: blue
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.13
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
duplicated_from: NoCrypt/DeepDanbooru_string
|
app.py
CHANGED
@@ -20,7 +20,7 @@ from Utils import dbimutils
|
|
20 |
|
21 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
22 |
DESCRIPTION = """
|
23 |
-
Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) with "ready to copy" prompt and a prompt analyzer.
|
24 |
|
25 |
Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
|
26 |
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
|
@@ -31,7 +31,8 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
|
31 |
"""
|
32 |
|
33 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
34 |
-
|
|
|
35 |
MODEL_FILENAME = "model.onnx"
|
36 |
LABEL_FILENAME = "selected_tags.csv"
|
37 |
|
@@ -44,9 +45,9 @@ def parse_args() -> argparse.Namespace:
|
|
44 |
return parser.parse_args()
|
45 |
|
46 |
|
47 |
-
def load_model() -> rt.InferenceSession:
|
48 |
path = huggingface_hub.hf_hub_download(
|
49 |
-
|
50 |
)
|
51 |
model = rt.InferenceSession(path)
|
52 |
return model
|
@@ -54,7 +55,7 @@ def load_model() -> rt.InferenceSession:
|
|
54 |
|
55 |
def load_labels() -> list[str]:
|
56 |
path = huggingface_hub.hf_hub_download(
|
57 |
-
|
58 |
)
|
59 |
df = pd.read_csv(path)["name"].tolist()
|
60 |
return df
|
@@ -69,11 +70,14 @@ def plaintext_to_html(text):
|
|
69 |
|
70 |
def predict(
|
71 |
image: PIL.Image.Image,
|
|
|
72 |
score_threshold: float,
|
73 |
-
|
74 |
labels: list[str],
|
75 |
):
|
76 |
rawimage = image
|
|
|
|
|
77 |
_, height, width, _ = model.get_inputs()[0].shape
|
78 |
|
79 |
# Alpha to white
|
@@ -168,15 +172,19 @@ def predict(
|
|
168 |
|
169 |
def main():
|
170 |
args = parse_args()
|
171 |
-
|
|
|
172 |
labels = load_labels()
|
173 |
|
174 |
-
|
|
|
|
|
175 |
|
176 |
gr.Interface(
|
177 |
fn=func,
|
178 |
inputs=[
|
179 |
gr.Image(type="pil", label="Input"),
|
|
|
180 |
gr.Slider(
|
181 |
0,
|
182 |
1,
|
@@ -192,7 +200,7 @@ def main():
|
|
192 |
gr.Label(label="Output (label)"),
|
193 |
gr.HTML(),
|
194 |
],
|
195 |
-
examples=[["power.jpg", 0.5]],
|
196 |
title=TITLE,
|
197 |
description=DESCRIPTION,
|
198 |
allow_flagging="never",
|
|
|
20 |
|
21 |
TITLE = "WaifuDiffusion v1.4 Tags"
|
22 |
DESCRIPTION = """
|
23 |
+
Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) and [SmilingWolf/wd-v1-4-convnext-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger) with "ready to copy" prompt and a prompt analyzer.
|
24 |
|
25 |
Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
|
26 |
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
|
|
|
31 |
"""
|
32 |
|
33 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
34 |
+
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
|
35 |
+
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger"
|
36 |
MODEL_FILENAME = "model.onnx"
|
37 |
LABEL_FILENAME = "selected_tags.csv"
|
38 |
|
|
|
45 |
return parser.parse_args()
|
46 |
|
47 |
|
48 |
+
def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
|
49 |
path = huggingface_hub.hf_hub_download(
|
50 |
+
model_repo, model_filename, use_auth_token=HF_TOKEN
|
51 |
)
|
52 |
model = rt.InferenceSession(path)
|
53 |
return model
|
|
|
55 |
|
56 |
def load_labels() -> list[str]:
|
57 |
path = huggingface_hub.hf_hub_download(
|
58 |
+
VIT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
|
59 |
)
|
60 |
df = pd.read_csv(path)["name"].tolist()
|
61 |
return df
|
|
|
70 |
|
71 |
def predict(
|
72 |
image: PIL.Image.Image,
|
73 |
+
selected_model: str,
|
74 |
score_threshold: float,
|
75 |
+
models: dict,
|
76 |
labels: list[str],
|
77 |
):
|
78 |
rawimage = image
|
79 |
+
|
80 |
+
model = models[selected_model]
|
81 |
_, height, width, _ = model.get_inputs()[0].shape
|
82 |
|
83 |
# Alpha to white
|
|
|
172 |
|
173 |
def main():
|
174 |
args = parse_args()
|
175 |
+
vit_model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
|
176 |
+
conv_model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
177 |
labels = load_labels()
|
178 |
|
179 |
+
models = {"ViT": vit_model, "ConvNext": conv_model}
|
180 |
+
|
181 |
+
func = functools.partial(predict, models=models, labels=labels)
|
182 |
|
183 |
gr.Interface(
|
184 |
fn=func,
|
185 |
inputs=[
|
186 |
gr.Image(type="pil", label="Input"),
|
187 |
+
gr.Radio(["ViT", "ConvNext"], label="Model"),
|
188 |
gr.Slider(
|
189 |
0,
|
190 |
1,
|
|
|
200 |
gr.Label(label="Output (label)"),
|
201 |
gr.HTML(),
|
202 |
],
|
203 |
+
examples=[["power.jpg", "ViT", 0.5]],
|
204 |
title=TITLE,
|
205 |
description=DESCRIPTION,
|
206 |
allow_flagging="never",
|