import argparse
import time

import gradio as gr
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel

from scheduling_dmd import DMDScheduler
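
# Gradio demo: one-step text-to-image generation with a DMD-distilled DreamShaper 8 UNet.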

# --unet-path points at the DMD-distilled UNet checkpoint; --model-path at the base
# DreamShaper 8 pipeline that the distilled UNet is loaded into.
parser = argparse.ArgumentParser()
parser.add_argument("--unet-path", type=str, default="aaronb/dreamshaper-8-dmd-kl-only-6kstep")
parser.add_argument("--model-path", type=str, default="Lykon/dreamshaper-8")
args = parser.parse_args()
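
# Example launch (the script filename is assumed):
#   python demo_dmd.py --unet-path aaronb/dreamshaper-8-dmd-kl-only-6kstep --model-path Lykon/dreamshaper-8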

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the distilled UNet, drop it into the base pipeline, and switch to the DMD scheduler.
unet = UNet2DConditionModel.from_pretrained(args.unet_path)
pipe = DiffusionPipeline.from_pretrained(args.model_path, unet=unet)
pipe.scheduler = DMDScheduler.from_config(pipe.scheduler.config)
pipe.to(device=device, dtype=torch.float16)
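

# DMD distils the multi-step teacher into a one-step generator, so the pipeline below is
# called with a single denoising step and no classifier-free guidance (guidance_scale=0.0).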
def predict(prompt, seed=1231231):
    generator = torch.manual_seed(seed)
    last_time = time.time()

    image = pipe(
        prompt,
        num_inference_steps=1,
        guidance_scale=0.0,
        generator=generator,
    ).images[0]

    print(f"Pipe took {time.time() - last_time} seconds")
    return image
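

# Quick sanity check without the UI (prompt and output filename are illustrative; uncomment to try):
# predict("a portrait photo of an astronaut in a garden, 8k").save("dmd_sample.png")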


css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# Distribution Matching Distillation
            """,
            elem_id="intro",
        )
        with gr.Row():
            prompt = gr.Textbox(placeholder="Insert your prompt here:", scale=5, container=False)
            generate_bt = gr.Button("Generate", scale=1)

        image = gr.Image(type="filepath")
        with gr.Accordion("Advanced options", open=False):
            seed = gr.Slider(randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1)

    inputs = [prompt, seed]
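    # Regenerate on every interaction: the button click, live edits to the prompt, and seed changes.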
    generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
    prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
    seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)

# Keep the programmatic API closed; the demo is meant to be driven from the web UI.
demo.queue(api_open=False)
demo.launch(show_api=False)