takamarou commited on
Commit
347b364
·
1 Parent(s): 7b9da3f

link from the dashboard so we get higher quotas

Browse files
Files changed (1) hide show
  1. modules/app.py +28 -9
modules/app.py CHANGED
@@ -2,11 +2,12 @@ import spaces
2
  import os
3
  import torch
4
  import gradio as gr
5
- from fastapi import FastAPI
6
  from huggingface_hub import login
7
  from diffusers import StableDiffusion3Pipeline, DDPMScheduler
8
  from dotenv import load_dotenv
9
  import uvicorn
 
10
 
11
  login(token=os.getenv("HF_TOKEN"))
12
 
@@ -20,19 +21,37 @@ pipeline.to("cuda")
20
 
21
  @app.get("/")
22
  def index():
23
- return "Hello"
 
 
 
 
 
 
 
 
 
24
 
25
- @spaces.GPU
26
- def generate(label, prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
27
- print('start generate', prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
28
- return pipeline(
29
  prompt=prompt,
30
  negative_prompt=negative_prompt,
31
  num_inference_steps=num_inference_steps,
32
  height=height,
33
  width=width,
34
  guidance_scale=guidance_scale
35
- ).images
 
 
 
 
 
 
 
 
 
36
 
37
  play = gr.Interface(
38
  fn=generate,
@@ -45,7 +64,7 @@ play = gr.Interface(
45
  gr.Number(label="Width", value=1024),
46
  gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
47
  ],
48
- outputs=gr.Gallery(),
49
  )
50
 
51
  stickers = gr.Interface(
@@ -59,7 +78,7 @@ stickers = gr.Interface(
59
  gr.Number(label="Width", value=1024),
60
  gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
61
  ],
62
- outputs=gr.Gallery(),
63
  )
64
 
65
  app = gr.mount_gradio_app(app, play, path="/gradio")
 
2
  import os
3
  import torch
4
  import gradio as gr
5
+ from fastapi import FastAPI, Response
6
  from huggingface_hub import login
7
  from diffusers import StableDiffusion3Pipeline, DDPMScheduler
8
  from dotenv import load_dotenv
9
  import uvicorn
10
+ import time
11
 
12
  login(token=os.getenv("HF_TOKEN"))
13
 
 
21
 
22
  @app.get("/")
23
  def index():
24
+ content = """
25
+ <html>
26
+ <body>
27
+ <a href="https://takamarou-stickerparty.hf.space/stickers/">Playground</a>
28
+ <br />
29
+ <a href="https://takamarou-stickerparty.hf.space/stickers/">Stickers</a>
30
+ </body>
31
+ </html>
32
+ """
33
+ return Response(content=content, media_type="text/html")
34
 
35
+ @spaces.GPU(duration=30)
36
+ def sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
37
+ start_time = time.time()
38
+ response = pipeline(
39
  prompt=prompt,
40
  negative_prompt=negative_prompt,
41
  num_inference_steps=num_inference_steps,
42
  height=height,
43
  width=width,
44
  guidance_scale=guidance_scale
45
+ )
46
+ run_time = time.time() - start_time
47
+ return response, run_time
48
+
49
+ def generate(label, prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
50
+ print('start generate', prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
51
+ start_time = time.time()
52
+ generation, gen_time = sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
53
+ run_time = time.time() - start_time
54
+ return generation.images, run_time, gen_time
55
 
56
  play = gr.Interface(
57
  fn=generate,
 
64
  gr.Number(label="Width", value=1024),
65
  gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
66
  ],
67
+ outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
68
  )
69
 
70
  stickers = gr.Interface(
 
78
  gr.Number(label="Width", value=1024),
79
  gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
80
  ],
81
+ outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
82
  )
83
 
84
  app = gr.mount_gradio_app(app, play, path="/gradio")