Update app.py
Browse files
app.py
CHANGED
@@ -10,8 +10,8 @@ import PIL
|
|
10 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
11 |
repo = "tianweiy/DMD2"
|
12 |
checkpoints = {
|
13 |
-
"1-Step" : ["
|
14 |
-
"4-Step" : ["
|
15 |
}
|
16 |
loaded = None
|
17 |
|
@@ -37,7 +37,7 @@ def generate_image(prompt, ckpt):
|
|
37 |
num_inference_steps = checkpoints[ckpt][1]
|
38 |
|
39 |
if loaded != num_inference_steps:
|
40 |
-
unet.load_state_dict(torch.load(hf_hub_download(repo,
|
41 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
42 |
loaded = num_inference_steps
|
43 |
|
@@ -51,7 +51,7 @@ def generate_image(prompt, ckpt):
|
|
51 |
|
52 |
with gr.Blocks(css=CSS) as demo:
|
53 |
gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
|
54 |
-
gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>
|
55 |
with gr.Group():
|
56 |
with gr.Row():
|
57 |
prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
|
|
|
10 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
11 |
repo = "tianweiy/DMD2"
|
12 |
checkpoints = {
|
13 |
+
"1-Step" : ["dmd2_sdxl_1step_unet_fp16.bin", 1],
|
14 |
+
"4-Step" : ["dmd2_sdxl_4step_unet_fp16.bin", 4],
|
15 |
}
|
16 |
loaded = None
|
17 |
|
|
|
37 |
num_inference_steps = checkpoints[ckpt][1]
|
38 |
|
39 |
if loaded != num_inference_steps:
|
40 |
+
unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint)), map_location="cuda")
|
41 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
42 |
loaded = num_inference_steps
|
43 |
|
|
|
51 |
|
52 |
with gr.Blocks(css=CSS) as demo:
|
53 |
gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
|
54 |
+
gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
|
55 |
with gr.Group():
|
56 |
with gr.Row():
|
57 |
prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
|