zhangyang-0123 commited on
Commit
eb3568a
·
1 Parent(s): f001490

enable zero GPU

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from dataclasses import dataclass
3
-
4
  import torch
5
  from tqdm import tqdm
6
 
@@ -91,7 +91,7 @@ def prune_model(pipe, hookers):
91
  ffn_hook.clear_hooks()
92
  return pipe
93
 
94
-
95
  def binary_mask_eval(args):
96
  # load sdxl model
97
  pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -141,7 +141,7 @@ def binary_mask_eval(args):
141
  print("prune complete")
142
  return pipe, pruned_pipe
143
 
144
-
145
  def generate_images(prompt, seed, steps, pipe, pruned_pipe):
146
  # Run the model and return images directly
147
  g_cpu = torch.Generator("cuda:0").manual_seed(seed)
@@ -221,4 +221,4 @@ def create_demo():
221
 
222
  if __name__ == "__main__":
223
  demo = create_demo()
224
- demo.launch(share=True)
 
1
  import gradio as gr
2
  from dataclasses import dataclass
3
+ import spaces
4
  import torch
5
  from tqdm import tqdm
6
 
 
91
  ffn_hook.clear_hooks()
92
  return pipe
93
 
94
+ @spaces.GPU
95
  def binary_mask_eval(args):
96
  # load sdxl model
97
  pipe = StableDiffusionXLPipeline.from_pretrained(
 
141
  print("prune complete")
142
  return pipe, pruned_pipe
143
 
144
+ @spaces.GPU
145
  def generate_images(prompt, seed, steps, pipe, pruned_pipe):
146
  # Run the model and return images directly
147
  g_cpu = torch.Generator("cuda:0").manual_seed(seed)
 
221
 
222
  if __name__ == "__main__":
223
  demo = create_demo()
224
+ demo.launch()