Commit 1a510b7 by takamarou · Parent(s): 01783ae

offload some work to CPU

Files changed (1):
  1. modules/app.py +21 -3
modules/app.py CHANGED
@@ -11,6 +11,13 @@ import time
 
 login(token=os.getenv("HF_TOKEN"))
 
+torch.set_float32_matmul_precision("high")
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+
 app = FastAPI(debug=True)
 
 pipeline = StableDiffusion3Pipeline.from_pretrained(
@@ -42,6 +49,7 @@ def sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guida
         height=height,
         width=width,
         guidance_scale=guidance_scale,
+        generator=torch.manual_seed(1),
     )
     run_time = time.time() - start_time
     return response, run_time
@@ -59,7 +67,7 @@ play = gr.Interface(
         gr.Label(value="Image Generation Playground"),
         gr.Textbox(label="Prompt", lines=3),
         gr.Textbox(label="Negative Prompt", lines=2),
-        gr.Slider(label="Inference Steps", value=20, minimum=1, maximum=30, step=1),
+        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
         gr.Number(label="Height", value=1024),
         gr.Number(label="Width", value=1024),
         gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
@@ -73,7 +81,7 @@ stickers = gr.Interface(
         gr.Label(value="Sticker Optimization Console"),
         gr.Textbox(label="Prompt", lines=3),
         gr.Textbox(label="Negative Prompt", lines=2),
-        gr.Slider(label="Inference Steps", value=20, minimum=1, maximum=30, step=1),
+        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
         gr.Number(label="Height", value=1024),
         gr.Number(label="Width", value=1024),
         gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
@@ -81,5 +89,15 @@ stickers = gr.Interface(
     outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
 )
 
+@spaces.GPU(duration=10)
+def setup_pipe():
+    pipeline.transformer.to(memory_format=torch.channels_last)
+    pipeline.vae.to(memory_format=torch.channels_last)
+
+    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+    pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
+
+setup_pipe()
 app = gr.mount_gradio_app(app, play, path="/gradio")
-app = gr.mount_gradio_app(app, stickers, path="/stickers")
+app = gr.mount_gradio_app(app, stickers, path="/stickers")
+print('mounted')
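
For context, the added inductor settings and the new setup_pipe() follow the usual Diffusers recipe for speeding up Stable Diffusion 3 inference with torch.compile: channels_last memory format plus "max-autotune" compilation of the transformer and the VAE decoder. A minimal standalone sketch of the same recipe is below; the checkpoint id, the fp16 dtype, and the explicit warm-up call are illustrative assumptions and are not part of this commit.

# Sketch of the optimization pattern used in this commit.
# Assumptions (not from the diff): checkpoint id, torch_dtype, and the warm-up call.
import torch
from diffusers import StableDiffusion3Pipeline

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# channels_last + torch.compile, mirroring setup_pipe() above
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# torch.compile is lazy: the first forward pass triggers (slow) compilation,
# so a throwaway run front-loads that cost before real requests arrive.
_ = pipe(prompt="warm-up", num_inference_steps=2, height=1024, width=1024)

Because compilation only happens on the first forward pass, the @spaces.GPU-decorated setup_pipe() itself does not run any inference, and the first real generation request will still absorb the compile time unless a warm-up call like the one sketched above is added.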