Spaces:
Runtime error
Runtime error
chore: remove flag of model backend
Browse files- app.py +38 -23
- requirements.txt +3 -1
app.py
CHANGED
@@ -5,7 +5,7 @@ import os
|
|
5 |
import torch
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
|
8 |
-
|
9 |
|
10 |
from huggingface_hub import login
|
11 |
|
@@ -17,8 +17,14 @@ MODEL_NAME = (
|
|
17 |
else "p1atdev/dart-v1-sft"
|
18 |
)
|
19 |
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
assert isinstance(MODEL_NAME, str)
|
|
|
22 |
|
23 |
tokenizer = AutoTokenizer.from_pretrained(
|
24 |
MODEL_NAME,
|
@@ -30,19 +36,19 @@ model = {
|
|
30 |
MODEL_NAME,
|
31 |
token=HF_READ_TOKEN,
|
32 |
),
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
}
|
41 |
|
42 |
MODEL_BACKEND_MAP = {
|
43 |
"Default": "default",
|
44 |
-
|
45 |
-
|
46 |
}
|
47 |
|
48 |
try:
|
@@ -288,7 +294,7 @@ def handle_inputs(
|
|
288 |
top_p: float = 1.0,
|
289 |
top_k: int = 20,
|
290 |
num_beams: int = 1,
|
291 |
-
model_backend: str = "Default",
|
292 |
):
|
293 |
"""
|
294 |
Returns:
|
@@ -340,7 +346,7 @@ def handle_inputs(
|
|
340 |
|
341 |
generated_ids = generate(
|
342 |
prompt,
|
343 |
-
model_backend=
|
344 |
max_new_tokens=max_new_tokens,
|
345 |
min_new_tokens=min_new_tokens,
|
346 |
do_sample=True,
|
@@ -395,21 +401,30 @@ def demo():
|
|
395 |
with gr.Blocks() as ui:
|
396 |
gr.Markdown(
|
397 |
"""\
|
398 |
-
# Danbooru Tags Transformer Demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
)
|
400 |
|
401 |
with gr.Row():
|
402 |
with gr.Column():
|
403 |
|
404 |
-
with gr.Group(
|
405 |
-
|
406 |
-
):
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
|
414 |
with gr.Group():
|
415 |
rating_dropdown = gr.Dropdown(
|
@@ -663,7 +678,7 @@ def demo():
|
|
663 |
top_p_slider,
|
664 |
top_k_slider,
|
665 |
num_beams_slider,
|
666 |
-
model_backend_radio,
|
667 |
],
|
668 |
outputs=[
|
669 |
output_tags_natural,
|
|
|
5 |
import torch
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
|
8 |
+
from optimum.onnxruntime import ORTModelForCausalLM
|
9 |
|
10 |
from huggingface_hub import login
|
11 |
|
|
|
17 |
else "p1atdev/dart-v1-sft"
|
18 |
)
|
19 |
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
|
20 |
+
MODEL_BACKEND = (
|
21 |
+
os.environ.get("MODEL_BACKEND")
|
22 |
+
if os.environ.get("MODEL_BACKEND") is not None
|
23 |
+
else "ONNX (quantized)"
|
24 |
+
)
|
25 |
|
26 |
assert isinstance(MODEL_NAME, str)
|
27 |
+
assert isinstance(MODEL_BACKEND, str)
|
28 |
|
29 |
tokenizer = AutoTokenizer.from_pretrained(
|
30 |
MODEL_NAME,
|
|
|
36 |
MODEL_NAME,
|
37 |
token=HF_READ_TOKEN,
|
38 |
),
|
39 |
+
"ort": ORTModelForCausalLM.from_pretrained(
|
40 |
+
MODEL_NAME,
|
41 |
+
),
|
42 |
+
"ort_qantized": ORTModelForCausalLM.from_pretrained(
|
43 |
+
MODEL_NAME,
|
44 |
+
file_name="model_quantized.onnx",
|
45 |
+
),
|
46 |
}
|
47 |
|
48 |
MODEL_BACKEND_MAP = {
|
49 |
"Default": "default",
|
50 |
+
"ONNX (normal)": "ort",
|
51 |
+
"ONNX (quantized)": "ort_qantized",
|
52 |
}
|
53 |
|
54 |
try:
|
|
|
294 |
top_p: float = 1.0,
|
295 |
top_k: int = 20,
|
296 |
num_beams: int = 1,
|
297 |
+
# model_backend: str = "Default",
|
298 |
):
|
299 |
"""
|
300 |
Returns:
|
|
|
346 |
|
347 |
generated_ids = generate(
|
348 |
prompt,
|
349 |
+
model_backend=MODEL_BACKEND,
|
350 |
max_new_tokens=max_new_tokens,
|
351 |
min_new_tokens=min_new_tokens,
|
352 |
do_sample=True,
|
|
|
401 |
with gr.Blocks() as ui:
|
402 |
gr.Markdown(
|
403 |
"""\
|
404 |
+
# Danbooru Tags Transformer Demo
|
405 |
+
|
406 |
+
Collection: [Dart (Danbooru Tags Transformer)](https://huggingface.co/collections/p1atdev/dart-danbooru-tags-transformer-65d687604ff57dc62ae40945)
|
407 |
+
|
408 |
+
Models:
|
409 |
+
|
410 |
+
- [p1atdev/dart-v1-sft](https://huggingface.co/p1atdev/dart-v1-sft)
|
411 |
+
- [p1atdev/dart-v1-base](https://huggingface.co/p1atdev/dart-v1-base)
|
412 |
+
|
413 |
+
"""
|
414 |
)
|
415 |
|
416 |
with gr.Row():
|
417 |
with gr.Column():
|
418 |
|
419 |
+
# with gr.Group(
|
420 |
+
# visible=False,
|
421 |
+
# ):
|
422 |
+
# model_backend_radio = gr.Radio(
|
423 |
+
# label="Model backend",
|
424 |
+
# choices=list(MODEL_BACKEND_MAP.keys()),
|
425 |
+
# value="Default",
|
426 |
+
# interactive=True,
|
427 |
+
# )
|
428 |
|
429 |
with gr.Group():
|
430 |
rating_dropdown = gr.Dropdown(
|
|
|
678 |
top_p_slider,
|
679 |
top_k_slider,
|
680 |
num_beams_slider,
|
681 |
+
# model_backend_radio,
|
682 |
],
|
683 |
outputs=[
|
684 |
output_tags_natural,
|
requirements.txt
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
torch==2.1.0
|
2 |
-
|
|
|
|
|
|
1 |
torch==2.1.0
|
2 |
+
accelerate==0.26.1
|
3 |
+
transformers==4.38.0
|
4 |
+
optimum[onnxruntime]==1.17.1
|