import spaces
import os
import torch
import gradio as gr
from fastapi import FastAPI, Response
from huggingface_hub import login
from diffusers import StableDiffusion3Pipeline
from dotenv import load_dotenv
import uvicorn
import time
# Load HF_TOKEN from a local .env file (if present) and authenticate with the Hub.
load_dotenv()
login(token=os.getenv("HF_TOKEN"))
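# Inductor/TF32 tuning flags; these pair with the torch.compile("max-autotune")
# calls in setup_pipe() below.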
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
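# FastAPI app that serves a landing page and hosts the mounted Gradio UIs.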
app = FastAPI(debug=True)
# Load SD3 Medium in fp16 and move it to the GPU.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")
@app.get("/")
def index():
content = """
<html>
<body>
<a href="https://takamarou-stickerparty.hf.space/stickers/">Playground</a>
<br />
<a href="https://takamarou-stickerparty.hf.space/stickers/">Stickers</a>
</body>
</html>
"""
return Response(content=content, media_type="text/html")
@spaces.GPU(duration=30)
def sd_on_gpu(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
    # Runs the diffusion pass inside a ZeroGPU allocation and reports GPU time.
    start_time = time.time()
    response = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(1),  # fixed seed for reproducible outputs
    )
    run_time = time.time() - start_time
    return response, run_time
def generate(label, prompt, negative_prompt, num_inference_steps, height, width, guidance_scale):
    # `label` is the gr.Label banner component; it is part of the inputs but unused here.
    print('start generate', prompt, negative_prompt, num_inference_steps, height, width, guidance_scale)
    start_time = time.time()
    # Gradio Sliders/Numbers return floats; diffusers expects ints for steps and size.
    generation, gen_time = sd_on_gpu(
        prompt, negative_prompt, int(num_inference_steps), int(height), int(width), guidance_scale
    )
    run_time = time.time() - start_time
    return generation.images, run_time, gen_time
play = gr.Interface(
    fn=generate,
    inputs=[
        gr.Label(value="Image Generation Playground"),
        gr.Textbox(label="Prompt", lines=3),
        gr.Textbox(label="Negative Prompt", lines=2),
        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
        gr.Number(label="Height", value=1024),
        gr.Number(label="Width", value=1024),
        gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1),
    ],
    outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
)

# Same inputs/outputs as `play`; only the banner label differs.
stickers = gr.Interface(
    fn=generate,
    inputs=[
        gr.Label(value="Sticker Optimization Console"),
        gr.Textbox(label="Prompt", lines=3),
        gr.Textbox(label="Negative Prompt", lines=2),
        gr.Slider(label="Inference Steps", value=13, minimum=1, maximum=30, step=1),
        gr.Number(label="Height", value=1024),
        gr.Number(label="Width", value=1024),
        gr.Slider(label="Guidance Scale", value=7, minimum=1, maximum=15, step=1),
    ],
    outputs=[gr.Gallery(), gr.Number(label="Total Generation Time"), gr.Number(label="GPU Time")],
)
@spaces.GPU(duration=10)
def setup_pipe():
    # channels_last memory format plus max-autotune compilation of the
    # transformer and VAE decoder speeds up repeated inference.
    pipeline.transformer.to(memory_format=torch.channels_last)
    pipeline.vae.to(memory_format=torch.channels_last)
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
    pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

setup_pipe()
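# Mount both Gradio UIs onto the FastAPI app under their own routes.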
app = gr.mount_gradio_app(app, play, path="/gradio")
app = gr.mount_gradio_app(app, stickers, path="/stickers")
print('mounted')
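
# `uvicorn` is imported above but never used; a minimal local entry point is
# sketched here as an assumption. On Hugging Face Spaces the platform serves
# `app` directly, and 7860 is the conventional Spaces port.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)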