File size: 3,532 Bytes
9ee0e66
b0f3145
 
187715d
b0f3145
 
187715d
b0f3145
e9034bb
88dc089
187715d
b0f3145
 
187715d
 
 
 
 
b0f3145
f151301
e08daf1
187715d
 
 
59f3984
187715d
 
 
59f3984
187715d
 
 
 
 
0b120d5
b0f3145
187715d
 
 
 
 
 
 
0ce0e61
b0f3145
187715d
75859e2
187715d
 
 
f151301
2a8c2d1
bef93f3
187715d
f151301
187715d
f151301
 
 
 
187715d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8c2d1
e9028a2
187715d
 
dd0e9d5
187715d
 
 
 
 
 
 
 
 
 
 
 
b0f3145
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import spaces
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# Set to False to skip loading the NSFW checker models and filtering outputs.
SAFETY_CHECKER = True

# Constants
# Hub repos: the SDXL base pipeline, and the repo holding the Lightning UNets.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# Dropdown label -> [UNet weights filename inside `repo`, inference steps].
# Only the 1-step weights file carries the "_x0" suffix.
checkpoints = {
    f"{steps}-Step": [
        f"sdxl_lightning_{steps}step_unet{'_x0' if steps == 1 else ''}.safetensors",
        steps,
    ]
    for steps in (1, 2, 4, 8)
}

# Step count of the UNet weights currently loaded into `pipe`
# (None until generate_image() loads one).
loaded = None

# The pipeline is built once, at import time, and only when CUDA is visible;
# otherwise `pipe` is left undefined.
if torch.cuda.is_available():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")

if SAFETY_CHECKER:
    # Presumably a safety_checker.py module shipped next to this app (its
    # forward() signature is not visible here) — TODO confirm.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    # Both models are fetched via from_pretrained when this module is imported.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to("cuda")
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        """Run the safety checker over *images*.

        Returns the input list unchanged, paired with whatever the checker
        returns; the caller treats that second value as per-image NSFW flags.
        """
        encoded = feature_extractor(images, return_tensors="pt").to("cuda")
        # NOTE(review): the image list is wrapped in a second list here. That
        # only matters if the local checker actually reads `images` — confirm.
        flags = safety_checker(
            images=[images],
            clip_input=encoded.pixel_values.to("cuda"),
        )
        return images, flags

# Function 
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
    """Generate one image for *prompt* with the SDXL-Lightning UNet chosen by *ckpt*.

    Args:
        prompt: Text prompt passed to the SDXL pipeline.
        ckpt: Key into ``checkpoints`` (e.g. ``"4-Step"``). It selects both the
            UNet weights file and the number of inference steps.

    Returns:
        The first generated PIL image. When ``SAFETY_CHECKER`` is enabled and
        any output is flagged, a black placeholder of the same size is
        returned instead.

    Raises:
        KeyError: if *ckpt* is not a key of ``checkpoints``.
    """
    global loaded
    print(prompt, ckpt)

    checkpoint, num_inference_steps = checkpoints[ckpt]

    if loaded != num_inference_steps:
        # Invalidate the cache marker before mutating the shared pipeline.
        # If the download or weight load below raises, the scheduler/UNet pair
        # may be half-swapped. The next call must then reload rather than
        # trust a stale `loaded` value.
        loaded = None
        # The 1-step weights are the "_x0" file and use sample prediction;
        # the other checkpoints use epsilon prediction.
        prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type=prediction_type,
        )
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, checkpoint), device="cuda")
        )
        loaded = num_inference_steps

    # guidance_scale=0 turns off classifier-free guidance in diffusers pipelines.
    results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Black placeholder at the same resolution as the generated image.
            return Image.new("RGB", images[0].size)
        return images[0]
    return results.images[0]



# Gradio Interface

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
    gr.HTML("<p><center>Lightning-fast text-to-image generation</center></p><p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Choices are derived from the `checkpoints` table, so the dropdown
            # can never offer a label that generate_image() cannot look up.
            ckpt = gr.Dropdown(
                label='Select inference steps',
                choices=list(checkpoints),
                value='4-Step',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the textbox and clicking the button run the same handler.
    for trigger in (prompt.submit, submit.click):
        trigger(fn=generate_image, inputs=[prompt, ckpt], outputs=img)

demo.queue().launch()