# File size: 2,954 Bytes
# ef187eb e7915f0 0cffd40 ef187eb 2b0f02c 0cffd40 8b1e96d 0cffd40 8b1e96d 0cccf69 8b1e96d ec35e66 8b1e96d e7915f0 8b1e96d f286ae5 8b1e96d 3494613 6380dba 8b1e96d 3819ced 67399b5 fee8445 67399b5 0cffd40 ef187eb 8b1e96d 0cffd40 8b3ca8d 0cffd40 8b1e96d 0cffd40 2b0f02c 8b1e96d 3eaeeea 8b1e96d 0cffd40 8b1e96d fe16630 8b1e96d 8b3ca8d fe16630 8b3ca8d 8b1e96d |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
import spaces
from PIL import Image
# Constants
# Hugging Face Hub repositories: the SDXL base pipeline and the DMD2 UNet weights.
repo = "tianweiy/DMD2"
base = "stabilityai/stable-diffusion-xl-base-1.0"

# Dropdown label -> [checkpoint file name inside `repo`, number of inference steps].
# Expands to the keys "1-Step" and "4-Step".
checkpoints = {
    f"{steps}-Step": [f"dmd2_sdxl_{steps}step_unet_fp16.bin", steps]
    for steps in (1, 4)
}

# Step count of the checkpoint most recently loaded into the pipeline
# (None until the first generation request).
loaded = None

# Page-level stylesheet handed to gr.Blocks.
CSS = (
    "\n"
    ".gradio-container {\n"
    "max-width: 690px !important;\n"
    "}\n"
)
# Build the SDXL pipeline once, at import time, when a CUDA device is visible.
# NOTE: if CUDA is not available neither `unet` nor `pipe` is defined, and
# generate_image will fail with a NameError on its first call.
if torch.cuda.is_available():
    # from_config only instantiates the UNet architecture (no pretrained
    # weights); the DMD2 checkpoint is loaded into it by generate_image.
    unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
    # Hand that UNet to the pipeline. Previously it was created but never
    # used, so a second full-size UNet occupied GPU memory while the pipeline
    # also loaded base UNet weights that were overwritten on first use.
    pipe = DiffusionPipeline.from_pretrained(
        base,
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
# Function
@spaces.GPU()
def generate_image(prompt, ckpt="4-Step"):
    """Generate one image for *prompt* with the selected DMD2 checkpoint.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        ckpt: Key into ``checkpoints`` ("1-Step" or "4-Step"). Defaults to
            "4-Step" so the function can also be called with the prompt
            alone, as the example gallery does.

    Returns:
        The first image of the pipeline result (``results.images[0]``).

    Raises:
        KeyError: If *ckpt* is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    # Swap scheduler and UNet weights only when a different checkpoint is
    # requested than the one currently held by the pipeline.
    if loaded != num_inference_steps:
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        # NOTE(security): torch.load unpickles the downloaded file. Consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(hf_hub_download(repo, checkpoint), map_location="cuda")
        pipe.unet.load_state_dict(state_dict)
        # Recorded only after a successful load, so a failed download or load
        # is retried on the next request.
        loaded = num_inference_steps

    # Explicit timestep schedule handed to the pipeline for each variant.
    timesteps = [399] if num_inference_steps == 1 else [999, 749, 499, 249]
    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        timesteps=timesteps,
    )
    return results.images[0]
# Sample prompts shown in the example gallery below the output image.
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic",
    "Ironman VS Hulk, ultrarealistic",
    "a CUTE robot artist painting on an easel",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Gradio Interface
with gr.Blocks(css=CSS, theme="soft") as demo:
    # Page header: title and a link to the DMD2 model card.
    gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")

    # Input row: prompt box, checkpoint selector and the run button.
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            ckpt = gr.Dropdown(
                label='Steps',
                choices=['1-Step', '4-Step'],
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')

    img = gr.Image(label='DMD2 Generated Image')

    # Example prompts, rendered lazily (on first use) and then cached.
    # NOTE(review): only the prompt is wired as input here, so generate_image
    # is invoked with a single argument for examples.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=img,
        fn=generate_image,
        cache_examples="lazy",
    )

    # Pressing Enter in the prompt box and clicking the button both run the
    # same generation call with the same inputs and output.
    for trigger in (prompt.submit, submit.click):
        trigger(
            fn=generate_image,
            inputs=[prompt, ckpt],
            outputs=img,
        )

demo.queue().launch()