Hei-Ha committed
Commit 9094355
1 Parent(s): a1740b4
Files changed (3)
  1. README.md +0 -2
  2. app.py +63 -167
  3. requirements.txt +4 -15
README.md CHANGED
@@ -8,6 +8,4 @@ sdk_version: 4.19.2
 app_file: app.py
 pinned: false
 license: mit
-disable_embedding: true
-header: mini
 ---
app.py CHANGED
@@ -1,79 +1,36 @@
-from diffusers import (
-    StableDiffusionXLPipeline,
-    EulerDiscreteScheduler,
-    UNet2DConditionModel,
-    AutoencoderTiny,
-)
+import gradio as gr
 import torch
-import os
+from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
 from huggingface_hub import hf_hub_download
-
-
-from PIL import Image
-import gradio as gr
-import time
 from safetensors.torch import load_file
-import time
-import tempfile
-from pathlib import Path
-
-# Constants
-BASE = "stabilityai/stable-diffusion-xl-base-1.0"
-REPO = "ByteDance/SDXL-Lightning"
-# 1-step
-CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
-taesd_model = "madebyollin/taesdxl"
-
-# {
-#     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
-#     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
-#     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
-#     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
-# }
-
+import spaces
+import os
+from PIL import Image
 
-SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
 SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
-USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
-
-# check if MPS is available OSX only M1/M2/M3 chips
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-torch_device = device
-torch_dtype = torch.float16
 
-print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
-print(f"SFAST_COMPILE: {SFAST_COMPILE}")
-print(f"USE_TAESD: {USE_TAESD}")
-print(f"device: {device}")
-
-
-unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(
-    "cuda", torch.float16
-)
-unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device="cuda"))
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
-).to("cuda")
+# Constants
+base = "stabilityai/stable-diffusion-xl-base-1.0"
+repo = "ByteDance/SDXL-Lightning"
+checkpoints = {
+    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+}
 
-if USE_TAESD:
-    pipe.vae = AutoencoderTiny.from_pretrained(
-        taesd_model, torch_dtype=torch_dtype, use_safetensors=True
-    ).to(device)
 
+# Ensure model and scheduler are initialized in a GPU-enabled context
+if torch.cuda.is_available():
+    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
-# Ensure sampler uses "trailing" timesteps.
-pipe.scheduler = EulerDiscreteScheduler.from_config(
-    pipe.scheduler.config, timestep_spacing="trailing"
-)
-pipe.set_progress_bar_config(disable=True)
 if SAFETY_CHECKER:
     from safety_checker import StableDiffusionSafetyChecker
     from transformers import CLIPFeatureExtractor
 
     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
         "CompVis/stable-diffusion-safety-checker"
-    ).to(device)
+    ).to("cuda")
     feature_extractor = CLIPFeatureExtractor.from_pretrained(
         "openai/clip-vit-base-patch32"
    )
@@ -81,125 +38,64 @@ if SAFETY_CHECKER:
     def check_nsfw_images(
         images: list[Image.Image],
     ) -> tuple[list[Image.Image], list[bool]]:
-        safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
+        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
         has_nsfw_concepts = safety_checker(
             images=[images],
-            clip_input=safety_checker_input.pixel_values.to(torch_device),
+            clip_input=safety_checker_input.pixel_values.to("cuda"),
         )
 
         return images, has_nsfw_concepts
 
-if SFAST_COMPILE:
-    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
-
-    # sfast compilation
-    config = CompilationConfig.Default()
-    try:
-        import xformers
-
-        config.enable_xformers = True
-    except ImportError:
-        print("xformers not installed, skip")
-    try:
-        import triton
-
-        config.enable_triton = True
-    except ImportError:
-        print("Triton not installed, skip")
-    # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead.
-    # But it can increase the amount of GPU memory used.
-    # For StableVideoDiffusionPipeline it is not needed.
-    config.enable_cuda_graph = True
-
-    pipe = compile(pipe, config)
-
-
-def predict(prompt, seed=1231231):
-    generator = torch.manual_seed(seed)
-    last_time = time.time()
-    results = pipe(
-        prompt=prompt,
-        generator=generator,
-        num_inference_steps=2,
-        guidance_scale=0.0,
-        # width=768,
-        # height=768,
-        output_type="pil",
-    )
-    print(f"Pipe took {time.time() - last_time} seconds")
+# Inference function
+@spaces.GPU(enable_queue=True)
+def generate_image(prompt, ckpt):
+    checkpoint = checkpoints[ckpt][0]
+    num_inference_steps = checkpoints[ckpt][1]
+
+    if num_inference_steps == 1:
+        # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+    else:
+        # Ensure sampler uses "trailing" timesteps.
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
     if SAFETY_CHECKER:
         images, has_nsfw_concepts = check_nsfw_images(results.images)
         if any(has_nsfw_concepts):
             gr.Warning("NSFW content detected.")
             return Image.new("RGB", (512, 512))
-    image = results.images[0]
-    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
-        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
-        return Path(tmpfile.name)
+        return images[0]
+    return results.images[0]
 
 
-css = """
-#container{
-    margin: 0 auto;
-    max-width: 40rem;
-}
-#intro{
-    max-width: 100%;
-    margin: 0 auto;
-}
+
+# Gradio interface
+description = """
+This demo utilizes the SDXL-Lightning model by ByteDance, a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
+As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
 """
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="container"):
-        gr.Markdown(
-            """
-            # SDXL-Lightning- Text To Image 2-Steps
-            **Model**: https://huggingface.co/ByteDance/SDXL-Lightning
-            """,
-            elem_id="intro",
-        )
-        with gr.Row():
-            with gr.Row():
-                prompt = gr.Textbox(
-                    placeholder="Insert your prompt here:", scale=5, container=False
-                )
-                generate_bt = gr.Button("Generate", scale=1)
-
-        image = gr.Image(type="filepath")
-        with gr.Accordion("Advanced options", open=False):
-            seed = gr.Slider(
-                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
-            )
-        with gr.Accordion("Run with diffusers"):
-            gr.Markdown(
-                """## Running SDXL-Lightning with `diffusers`
-            ```py
-            import torch
-            from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
-            from huggingface_hub import hf_hub_download
-            from safetensors.torch import load_file
-            base = "stabilityai/stable-diffusion-xl-base-1.0"
-            repo = "ByteDance/SDXL-Lightning"
-            ckpt = "sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!
-            # Load model.
-            unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
-            unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
-            pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
-            # Ensure sampler uses "trailing" timesteps.
-            pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-            # Ensure using the same inference steps as the loaded model and CFG set to 0.
-            pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("output.png")
-            ```
-            """
-            )
-
-        inputs = [prompt, seed]
-        outputs = [image]
-        generate_bt.click(
-            fn=predict, inputs=inputs, outputs=outputs, show_progress=False
-        )
-        prompt.input(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
-        seed.change(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
 
-demo.queue()
-demo.launch()
+with gr.Blocks(css="style.css") as demo:
+    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+    gr.Markdown(description)
+    with gr.Group():
+        with gr.Row():
+            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+            ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')
+
+    prompt.submit(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )
+    submit.click(fn=generate_image,
+                 inputs=[prompt, ckpt],
+                 outputs=img,
+                 )
+
+demo.queue().launch()
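
Note: the rewritten app.py keeps one SDXL pipeline resident and, per request, swaps in the Lightning UNet weights for the selected step count, after resetting the scheduler to "trailing" timestep spacing (plus "sample" prediction type for the 1-step checkpoint). Below is a minimal sketch of that checkpoint-swapping pattern, runnable outside the Space (the `generate` helper name is hypothetical; the `spaces.GPU` decorator and the Gradio UI are omitted):

```py
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# One resident fp16 pipeline; only the UNet weights change between requests.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")

def generate(prompt, ckpt="sdxl_lightning_4step_unet.safetensors", steps=4):
    # 1-step Lightning checkpoints predict the sample directly; every step
    # count needs "trailing" timestep spacing.
    extra = {"prediction_type": "sample"} if steps == 1 else {}
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", **extra
    )
    # Swap in the distilled UNet weights for the chosen step count.
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
    # Lightning models are distilled for CFG-free sampling, hence guidance_scale=0.
    return pipe(prompt, num_inference_steps=steps, guidance_scale=0).images[0]

generate("A girl smiling").save("output.png")
```

Swapping only the UNet state dict lets a single pipeline serve all four step counts without holding four full pipelines in GPU memory.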
requirements.txt CHANGED
@@ -1,16 +1,5 @@
-diffusers==0.26.3
 transformers
-gradio==4.19.2
-torch==2.1.0
-fastapi==0.104.0
-uvicorn==0.23.2
-Pillow==10.1.0
-accelerate==0.24.0
-compel==2.0.2
-controlnet-aux==0.0.7
-peft==0.6.0
-xformers
-hf_transfer
-huggingface_hub
-safetensors
-stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v1.0.2/stable_fast-1.0.2+torch211cu121-cp310-cp310-manylinux2014_x86_64.whl; sys_platform != 'darwin' or platform_machine != 'arm64'
+diffusers
+torch
+accelerate
+gradio
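
Note: huggingface_hub and safetensors drop off the list even though app.py still imports them; they should arrive as transitive dependencies of diffusers and transformers, and the spaces package is assumed to be preinstalled by the Hugging Face Spaces runtime (neither point is stated in the diff). A small sketch to confirm the slimmed list still covers app.py's imports:

```py
# Sanity check: every module app.py imports should resolve from the five
# unpinned requirements alone; huggingface_hub and safetensors are expected
# as transitive dependencies of diffusers/transformers (an assumption).
import importlib

for mod in ("gradio", "torch", "diffusers", "transformers", "accelerate",
            "PIL", "huggingface_hub", "safetensors.torch"):
    importlib.import_module(mod)
# "spaces" is omitted here: it is only available on the Spaces runtime itself.
print("all app.py imports resolve")
```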