import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image, ImageFilter
from typing import List, Tuple
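# Optional NSFW safety checker, enabled when the SAFETY_CHECKER env var is set to "1"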
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
checkpoints = {
"1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
"2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
"4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
"8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
}
aspect_ratios = {
"21:9": (21, 9),
"2:1": (2, 1),
"16:9": (16, 9),
"5:4": (5, 4),
"4:3": (4, 3),
"3:2": (3, 2),
"1:1": (1, 1),
}
# Function to calculate resolution
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
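    """Derive an SDXL-friendly (width, height) for the requested aspect ratio.

    Aims for a total pixel budget of roughly total_pixels and rounds both sides
    down to a multiple of `divisibility`; 'portrait' mode inverts the ratio.
    For example, calculate_resolution("16:9") returns (1360, 768).
    """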
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")
    width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    ratio = width_multiplier / height_multiplier
    if mode == 'portrait':
        # Swap the ratio for portrait mode
        ratio = 1 / ratio
    height = int((total_pixels / ratio) ** 0.5)
    height -= height % divisibility
    width = int(height * ratio)
    width -= width % divisibility
    while width * height > total_pixels:
        height -= divisibility
        width = int(height * ratio)
        width -= width % divisibility
    return width, height
# Example prompts with ckpt, aspect, and mode
examples = [
{"prompt": "A futuristic cityscape at sunset", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
{"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A portrait of a robot in the style of Renaissance art", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A serene landscape with mountains and a river", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
{"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"}
]
# Define a function to set the example inputs
def set_example(selected_prompt):
    # Find the example that matches the selected prompt
    for example in examples:
        if example["prompt"] == selected_prompt:
            return example["prompt"], example["negative_prompt"], example["ckpt"], example["aspect"], example["mode"]
    return None, None, None, None, None  # Default values if not found
# Check if CUDA is available (GPU support), and set the appropriate device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pipeline for the specified device
# For GPU, use torch_dtype=torch.float16 for better performance
if device == "cuda":
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
else:
pipe = StableDiffusionXLPipeline.from_pretrained(base).to(device)
if SAFETY_CHECKER:
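    # Optional NSFW filtering: load the safety checker and its CLIP feature extractor once at startup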
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    def check_nsfw_images(
        images: List[Image.Image]
    ) -> Tuple[List[Image.Image], List[bool]]:
        # Convert each PIL image into CLIP pixel values for the safety checker
        safety_checker_inputs = [feature_extractor(image, return_tensors="pt").to(device) for image in images]
        # Each safety checker call returns (images, has_nsfw_concepts); keep only the per-image flag
        has_nsfw_concepts = [
            safety_checker(
                images=[image],
                clip_input=safety_checker_input.pixel_values.to(device)
            )[1][0]
            for image, safety_checker_input in zip(images, safety_checker_inputs)
        ]
        return images, has_nsfw_concepts
# Main text-to-image generation function
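# @spaces.GPU requests a GPU for the duration of each call when running on a ZeroGPU Space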
@spaces.GPU(enable_queue=True)
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    width, height = calculate_resolution(aspect_ratio, mode)  # Calculate resolution based on the aspect ratio
    checkpoint = checkpoints[ckpt][0]
    num_inference_steps = checkpoints[ckpt][1]

    if num_inference_steps == 1:
        # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
    else:
        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))
    results = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=0, width=width, height=height)

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Apply a blur filter to the first image in the results
            blurred_image = images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
            return blurred_image
        return images[0]
    return results.images[0]
# Gradio Interface
description = """
Demo of ByteDance's SDXL-Lightning text-to-image model. Model card: https://huggingface.co/ByteDance/SDXL-Lightning
"""
with gr.Blocks(css="style.css") as demo:
gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
gr.Markdown(description)
with gr.Group():
with gr.Row():
prompt = gr.Textbox(label='Enter you image prompt:', scale=8)
with gr.Row():
negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
with gr.Row():
ckpt = gr.Dropdown(label='Select inference steps',choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape') # Mode as a dropdown
submit = gr.Button(scale=1, variant='primary')
img = gr.Image(label='SDXL-Lightning Generated Image')
prompt.submit(fn=generate_image,
inputs=[prompt, negative_prompt, ckpt, aspect, mode],
outputs=img,
)
submit.click(fn=generate_image,
inputs=[prompt, negative_prompt, ckpt, aspect, mode],
outputs=img,
)
# Dropdown for selecting examples
example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, negative_prompt, ckpt, aspect, mode])
demo.queue().launch()