import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image, ImageFilter
from typing import List, Tuple
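# Set the SAFETY_CHECKER environment variable to "1" to enable NSFW filtering of generated images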
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
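# Each entry maps a UI label to [checkpoint filename in the SDXL-Lightning repo, inference step count]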
checkpoints = {
"1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
"2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
"4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
"8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
}
aspect_ratios = {
"21:9": (21, 9),
"2:1": (2, 1),
"16:9": (16, 9),
"5:4": (5, 4),
"4:3": (4, 3),
"3:2": (3, 2),
"1:1": (1, 1),
}
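# Ratios are listed in landscape orientation; calculate_resolution() inverts them for portrait mode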
# Function to calculate resolution
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")
    width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
    ratio = width_multiplier / height_multiplier
    if mode == 'portrait':
        # Swap the ratio for portrait mode
        ratio = 1 / ratio
    height = int((total_pixels / ratio) ** 0.5)
    height -= height % divisibility
    width = int(height * ratio)
    width -= width % divisibility
    while width * height > total_pixels:
        height -= divisibility
        width = int(height * ratio)
        width -= width % divisibility
    return width, height
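# Worked example: with the default ~1 MP budget, calculate_resolution("16:9", "landscape")
# should give height = int((1048576 / (16/9)) ** 0.5) = 768 and width = 1365, rounded down
# to 1360 for divisibility by 8, i.e. a result of (1360, 768).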
# Example prompts with ckpt, aspect, and mode
examples = [
{"prompt": "A futuristic cityscape at sunset", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
{"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A portrait of a robot in the style of Renaissance art", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},
{"prompt": "A serene landscape with mountains and a river", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
{"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"}
]
# Define a function to set the example inputs
def set_example(selected_prompt):
    # Find the example that matches the selected prompt
    for example in examples:
        if example["prompt"] == selected_prompt:
            return example["prompt"], example["ckpt"], example["aspect"], example["mode"]
    return None, None, None, None  # Default values if not found
# Initialize the base SDXL pipeline once at startup when a CUDA device is available
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    def check_nsfw_images(
        images: List[Image.Image]
    ) -> Tuple[List[Image.Image], List[bool]]:
        # Convert each PIL image into the tensor inputs the safety checker expects
        safety_checker_inputs = [feature_extractor(image, return_tensors="pt").to("cuda") for image in images]
        # Run the safety checker on each image individually
        has_nsfw_concepts = [safety_checker(
            images=[image],
            clip_input=safety_checker_input.pixel_values.to("cuda")
        ) for image, safety_checker_input in zip(images, safety_checker_inputs)]
        # Flatten the per-image results into a single list of flags
        has_nsfw_concepts = [item for sublist in has_nsfw_concepts for item in sublist]
        return images, has_nsfw_concepts
# Generation function; runs on the GPU allocated by the `spaces` decorator
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt, aspect_ratio, mode):
    width, height = calculate_resolution(aspect_ratio, mode)  # Calculate resolution based on the aspect ratio
    checkpoint = checkpoints[ckpt][0]
    num_inference_steps = checkpoints[ckpt][1]
    if num_inference_steps == 1:
        # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
    else:
        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
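    # SDXL-Lightning is distilled for guidance-free sampling, hence guidance_scale=0 below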
    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0, width=width, height=height)
    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Apply a blur filter to the first image in the results
            blurred_image = images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
            return blurred_image
        return images[0]
    return results.images[0]
# Gradio Interface
description = """
Demo of ByteDance's SDXL-Lightning model. Model card: https://huggingface.co/ByteDance/SDXL-Lightning
"""
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Orientation as a dropdown
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')
    prompt.submit(fn=generate_image,
                  inputs=[prompt, ckpt, aspect, mode],
                  outputs=img,
                  )
    submit.click(fn=generate_image,
                 inputs=[prompt, ckpt, aspect, mode],
                 outputs=img,
                 )
    # Dropdown for selecting examples
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, ckpt, aspect, mode])

demo.queue().launch()