Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ import tempfile
|
|
5 |
import subprocess
|
6 |
from pathlib import Path
|
7 |
|
|
|
|
|
8 |
import spaces
|
9 |
import gradio as gr
|
10 |
import torch
|
@@ -205,15 +207,27 @@ def create_demo(
|
|
205 |
|
206 |
return demo
|
207 |
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
parser = argparse.ArgumentParser(description="Flux")
|
211 |
parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
|
212 |
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
|
213 |
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
|
214 |
parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
|
215 |
parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
|
|
|
216 |
args = parser.parse_args()
|
|
|
217 |
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
5 |
import subprocess
|
6 |
from pathlib import Path
|
7 |
|
8 |
+
from dataclasses import dataclass
|
9 |
+
import torch.multiprocessing as mp
|
10 |
import spaces
|
11 |
import gradio as gr
|
12 |
import torch
|
|
|
207 |
|
208 |
return demo
|
209 |
|
210 |
+
@dataclass
|
211 |
+
class Config:
|
212 |
+
name: str = "flux-dev"
|
213 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
214 |
+
offload: bool = False
|
215 |
+
share: bool = False
|
216 |
+
ckpt_dir: str = "."
|
217 |
+
|
218 |
+
def parse_args() -> Config:
|
219 |
parser = argparse.ArgumentParser(description="Flux")
|
220 |
parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
|
221 |
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
|
222 |
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
|
223 |
parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
|
224 |
parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
|
225 |
+
|
226 |
args = parser.parse_args()
|
227 |
+
return Config(**vars(args))
|
228 |
|
229 |
+
if __name__ == "__main__":
|
230 |
+
mp.set_start_method("spawn")
|
231 |
+
config = Config()
|
232 |
+
demo = create_demo(config.name, config.device, config.offload, config.ckpt_dir)
|
233 |
+
demo.launch(share=config.share)
|