Hei-Ha committed
Commit a0b14df · 1 Parent(s): f2afba6
Files changed (2)
  1. app.py +85 -17
  2. requirements.txt +0 -2
app.py CHANGED
@@ -1,35 +1,103 @@
 
+ import gradio as gr
  import torch
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image
+ 
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
  
  base = "stabilityai/stable-diffusion-xl-base-1.0"
  repo = "ByteDance/SDXL-Lightning"
- ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
- 
- # Load model.
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
- 
- # Ensure sampler uses "trailing" timesteps.
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
- 
- pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")
- 
- # with gr.Blocks() as demo:
- #     with gr.Gradio():
- #         with gr.Row():
- 
- # import gradio as gr
- # def greet(name):
- #     return "Hello " + name + "!!"
- # 
- # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- # iface.launch()
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+ }
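+ 
+ # Per the SDXL-Lightning model card, the 1-step checkpoint (the "_x0" file) is an
+ # experimental x0-prediction model, which is why generate_image() below switches
+ # the scheduler to prediction_type="sample" when it is selected.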
+ 
+ # Ensure model and scheduler are initialized in GPU-enabled function
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+ 
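+ # Optional NSFW filter, toggled by the SAFETY_CHECKER env var. The
+ # `safety_checker` import below presumably resolves to a module bundled
+ # with this Space (it is not listed in requirements.txt).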
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+ 
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
+     )
+ 
+     def check_nsfw_images(
+         images: list[Image.Image],
+     ) -> tuple[list[Image.Image], list[bool]]:
+         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+         has_nsfw_concepts = safety_checker(
+             images=[images],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
+         )
+ 
+         return images, has_nsfw_concepts
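+ 
+ # Each call re-creates the scheduler for the chosen step count and hot-swaps the
+ # matching Lightning UNet weights into the shared pipeline before sampling.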
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt):
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+ 
+     if num_inference_steps == 1:
+         # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+     else:
+         # Ensure sampler uses "trailing" timesteps.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+ 
+     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+ 
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             return Image.new("RGB", (512, 512))
+         return images[0]
+     return results.images[0]
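+ 
+ # Example (hypothetical prompt): generate_image("An astronaut riding a horse", "4-Step")
+ # would return one PIL image sampled in 4 steps with guidance disabled (scale 0).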
 
+ 
+ # Gradio Interface
+ description = """
+ This demo uses SDXL-Lightning by ByteDance, a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
+ As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+ """
+ 
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             submit = gr.Button(scale=1, variant='primary')
+         img = gr.Image(label='SDXL-Lightning Generated Image')
+ 
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )
+ 
+ demo.queue().launch()
requirements.txt CHANGED
@@ -1,6 +1,4 @@
  transformers
- huggingface_hub
- safetensors
  diffusers
  torch
  accelerate
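# Note: huggingface_hub and safetensors are pulled in transitively by diffusers,
# and gradio / spaces ship with the Hugging Face Spaces runtime, so none of them
# need to be pinned here.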