Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
4
  from huggingface_hub import hf_hub_download
 
5
  import spaces
6
 
7
 
@@ -9,10 +10,10 @@ import spaces
9
  base = "stabilityai/stable-diffusion-xl-base-1.0"
10
  repo = "ByteDance/SDXL-Lightning"
11
  checkpoints = {
12
- "1-Step" : ["sdxl_lightning_1step_unet_x0.pth", 1],
13
- "2-Step" : ["sdxl_lightning_2step_unet.pth", 2],
14
- "4-Step" : ["sdxl_lightning_4step_unet.pth", 4],
15
- "8-Step" : ["sdxl_lightning_8step_unet.pth", 8],
16
  }
17
 
18
 
@@ -35,7 +36,7 @@ def generate_image(prompt, ckpt):
35
  # Ensure sampler uses "trailing" timesteps.
36
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
37
 
38
- pipe.unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint), map_location="cuda"))
39
  image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0).images[0]
40
  return image
41
 
 
2
  import torch
3
  from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
4
  from huggingface_hub import hf_hub_download
5
+ from safetensors.torch import load_file
6
  import spaces
7
 
8
 
 
10
  base = "stabilityai/stable-diffusion-xl-base-1.0"
11
  repo = "ByteDance/SDXL-Lightning"
12
  checkpoints = {
13
+ "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
14
+ "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
15
+ "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
16
+ "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
17
  }
18
 
19
 
 
36
  # Ensure sampler uses "trailing" timesteps.
37
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
38
 
39
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
40
  image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0).images[0]
41
  return image
42