Spaces: Runtime error

frankleeeee committed · commit a097e62 · 1 parent: 348ea80
Commit message: "udpated"
app.py CHANGED
@@ -14,6 +14,7 @@ import sys
 import spaces
 import gradio as gr
 import torch
+from opensora.datasets import save_sample
 
 
 
@@ -29,7 +30,7 @@ HF_STDIT_MAP = {
     "v1-HQ-16x512x512": "hpcai-tech/OpenSora-STDiT-v1-HQ-16x512x512",
 }
 
-def install_dependencies():
+def install_dependencies(enable_optimization=False):
     """
     Install the required dependencies for the demo if they are not already installed.
     """
@@ -41,7 +42,9 @@ def install_dependencies():
     except (ImportError, ModuleNotFoundError):
         return False
 
-    #
+    # flash attention is needed no matter optimization is enabled or not
+    # because Hugging Face transformers detects flash_attn is a dependency in STDiT
+    # thus, we need to install it no matter what
     if not _is_package_available("flash_attn"):
         subprocess.run(
             f"{sys.executable} -m pip install flash-attn --no-build-isolation",
@@ -49,6 +52,25 @@ def install_dependencies():
             shell=True,
         )
 
+    if enable_optimization:
+        # install ape
+        if not _is_package_available("apex"):
+            subprocess.run(
+                f'{sys.executable} -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" git+https://github.com/NVIDIA/apex.git',
+                shell=True,
+            )
+
+        # install ninja
+        if not _is_package_available("ninja"):
+            subprocess.run(f"{sys.executable} -m pip install ninja", shell=True)
+
+        # install xformers
+        if not _is_package_available("xformers"):
+            subprocess.run(
+                f"{sys.executable} -m pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers",
+                shell=True,
+            )
+
 def read_config(config_path):
     """
     Read the configuration file.
@@ -114,6 +136,7 @@ def parse_args():
     parser.add_argument("--port", default=None, type=int, help="The port to run the Gradio App on.")
     parser.add_argument("--host", default=None, type=str, help="The host to run the Gradio App on.")
     parser.add_argument("--share", action="store_true", help="Whether to share this gradio demo.")
+    parser.add_argument("--enable-optimization", action="store_true", help="Whether to enable optimization such as flash attention and fused layernorm")
     return parser.parse_args()
 
 
@@ -128,11 +151,11 @@ config = read_config(CONFIG_MAP[args.model_type])
 os.makedirs(args.output, exist_ok=True)
 
 # disable torch jit as it can cause failure in gradio SDK
-#
+# gradio sdk uses torch with cuda 11.3
 torch.jit._state.disable()
 
 # set up
-install_dependencies()
+install_dependencies(enable_optimization=args.enable_optimization)
 
 # build model
 vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
@@ -141,7 +164,6 @@ vae, text_encoder, stdit, scheduler = build_models(args.model_type, config)
 def run_inference(prompt_text):
     latent_size = get_latent_size(config, vae)
 
-    from opensora.datasets import save_sample
     samples = scheduler.sample(
         stdit,
         text_encoder,
@@ -204,6 +226,5 @@ with gr.Blocks() as demo:
     )
 
     # launch
-
-    demo.launch()
+    demo.launch(server_port=args.port, server_name=args.host, share=args.share)
 
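
Note: the hunks above call a helper `_is_package_available` whose definition lives elsewhere in app.py and is not shown in this commit. The `except (ImportError, ModuleNotFoundError)` branch visible in the diff suggests it simply attempts the import; the following is a minimal sketch of that pattern, with the function body assumed rather than taken from the commit:

import importlib

def _is_package_available(package_name):
    # Hypothetical reconstruction: report whether the package can be imported,
    # mirroring the ImportError/ModuleNotFoundError handling shown in the diff.
    try:
        importlib.import_module(package_name)
        return True
    except (ImportError, ModuleNotFoundError):
        return False

With this change the parsed CLI flags are also forwarded to demo.launch, so a local run of the updated demo would look something like python app.py --enable-optimization --port 7860 --host 0.0.0.0 (any flags beyond --port, --host, --share, and --enable-optimization shown above are not part of this commit).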