stazizov commited on
Commit
35d14ab
1 Parent(s): 2458a22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -5,6 +5,8 @@ import tempfile
5
  import subprocess
6
  from pathlib import Path
7
 
 
 
8
  import spaces
9
  import gradio as gr
10
  import torch
@@ -205,15 +207,27 @@ def create_demo(
205
 
206
  return demo
207
 
208
- if __name__ == "__main__":
209
- import argparse
 
 
 
 
 
 
 
210
  parser = argparse.ArgumentParser(description="Flux")
211
  parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
212
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
213
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
214
  parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
215
  parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
 
216
  args = parser.parse_args()
 
217
 
218
- demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
219
- demo.launch(share=args.share)
 
 
 
 
5
  import subprocess
6
  from pathlib import Path
7
 
8
+ from dataclasses import dataclass
9
+ import torch.multiprocessing as mp
10
  import spaces
11
  import gradio as gr
12
  import torch
 
207
 
208
  return demo
209
 
210
+ @dataclass
211
+ class Config:
212
+ name: str = "flux-dev"
213
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
214
+ offload: bool = False
215
+ share: bool = False
216
+ ckpt_dir: str = "."
217
+
218
+ def parse_args() -> Config:
219
  parser = argparse.ArgumentParser(description="Flux")
220
  parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
221
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
222
  parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
223
  parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
224
  parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
225
+
226
  args = parser.parse_args()
227
+ return Config(**vars(args))
228
 
229
+ if __name__ == "__main__":
230
+ mp.set_start_method("spawn")
231
+ config = Config()
232
+ demo = create_demo(config.name, config.device, config.offload, config.ckpt_dir)
233
+ demo.launch(share=config.share)