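# FastAPI + Gradio app for a Hugging Face ZeroGPU Space: serves a small
# landing page plus two Stable Diffusion 3 image-generation UIs.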
import spaces
import os
import torch
import gradio as gr
from fastapi import FastAPI, Response
from huggingface_hub import login
from diffusers import StableDiffusion3Pipeline
from dotenv import load_dotenv
import time

# Load environment variables from a local .env file (if present) so
# HF_TOKEN is available before logging in to the Hub.
load_dotenv()
login(token=os.getenv("HF_TOKEN"))

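# Allow TF32 for float32 matmuls on supported GPUs (trades a little
# precision for speed).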
torch.set_float32_matmul_precision("high")

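# Inductor (torch.compile) settings commonly recommended for diffusion
# models: lower 1x1 convs to matmuls and autotune more aggressively.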
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

app = FastAPI(debug=True)

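# Load SD3 Medium in fp16 and keep the pipeline resident on the GPU.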
pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")

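# Landing page linking to the two mounted Gradio UIs.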
@app.get("/")
def index():
    content = """
<html>
    <body>
        <a href="https://takamarou-stickerparty.hf.space/stickers/">Playground</a>
        <br />
        <a href="https://takamarou-stickerparty.hf.space/stickers/">Stickers</a>
    </body>
</html>
"""
    return Response(content=content, media_type="text/html")

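# On ZeroGPU Spaces, @spaces.GPU attaches a GPU to this call for at most
# `duration` seconds.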
@spaces.GPU(duration=30)
def sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
    start_time = time.time()
    response = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(1),
    )
    run_time = time.time() - start_time
    return response, run_time

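# Gradio callback shared by both UIs. The first input (gr.Label) acts as an
# on-page header, so `label` is received but unused.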
def generate(label, prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
    print('start generate', prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
    start_time = time.time()
    generation, gen_time = sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
    run_time = time.time() - start_time
    return generation.images, run_time, gen_time

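# Two near-identical Interfaces, mounted at /gradio and /stickers below;
# only the gr.Label header text differs.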
play = gr.Interface(
    fn=generate,
    inputs=[
        gr.Label(value="Image Generation Playground"),
        gr.Textbox(label="Prompt", lines=3),
        gr.Textbox(label="Negative Prompt", lines=2),
        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
        gr.Number(label="Height", value=1024),
        gr.Number(label="Width", value=1024),
        gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
    ],
    outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
)

stickers = gr.Interface(
    fn=generate,
    inputs=[
        gr.Label(value="Sticker Optimization Console"),
        gr.Textbox(label="Prompt", lines=3),
        gr.Textbox(label="Negative Prompt", lines=2),
        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
        gr.Number(label="Height", value=1024),
        gr.Number(label="Width", value=1024),
        gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1)
    ],
    outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
)

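# One-time warm-up on the GPU worker: switch the transformer and VAE to
# channels_last memory layout and torch.compile them with max-autotune.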
@spaces.GPU(duration=10)
def setup_pipe():
    pipeline.transformer.to(memory_format=torch.channels_last)
    pipeline.vae.to(memory_format=torch.channels_last)

    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

setup_pipe()
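
# Mount the two Gradio UIs onto the FastAPI app at their respective paths.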
app = gr.mount_gradio_app(app, play, path="/gradio")
app = gr.mount_gradio_app(app, stickers, path="/stickers")
print('mounted')
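
# Local run (a sketch, assuming this file is saved as app.py; on Spaces the
# platform launches `app` itself):
#   uvicorn app:app --host 0.0.0.0 --port 7860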