# Hugging Face Spaces status: Paused
import torch | |
import gradio as gr | |
from PIL import Image | |
import spaces | |
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline | |
# Device the pipelines are moved to; the app assumes a CUDA GPU is present.
device = "cuda"
# Number of images requested from the prior per generation.
num_images_per_prompt = 1

# Stage C ("prior") is loaded in bfloat16 and Stage B ("decoder") in float16;
# the prior's image embeddings therefore have to be cast to the decoder's
# dtype before they are decoded.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to(device)

# Default negative prompt pre-filled in the "Advanced options" box.
# Each term is listed once: the CLIP text encoder truncates prompts to its
# tokenizer's maximum length (77 tokens for CLIP), so repeated terms
# ("error", "username" and "poorly drawn hands" were each listed twice)
# only pushed the trailing terms past the truncation point.
default_negative = ", ".join((
    "poorly Rendered face",
    "poorly drawn face",
    "poor facial details",
    "poorly drawn hands",
    "poorly rendered hands",
    "low resolution",
    "blurry image",
    "oversaturated",
    "bad anatomy",
    "signature",
    "watermark",
    "username",
    "error",
    "missing limbs",
    "out of frame",
    "extra fingers",
    "mutated hands",
    "malformed limbs",
    "missing arms",
    "missing legs",
    "extra arms",
    "extra legs",
    "fused fingers",
    "too many fingers",
    "long neck",
))
# Backward-compatible alias for the original, misspelled name.
deafult_negative = default_negative

# Style suffix appended after the user's prompt.
# NOTE(review): the "(term:weight)" / "((term))" emphasis syntax is an
# AUTOMATIC1111/Compel convention; as far as can be told the StableCascade
# pipelines tokenize the prompt string as-is, so the parentheses and weights
# reach the text encoder as literal text rather than as emphasis — confirm.
prompt_add = (
    "(dark shot:1.17), epic realistic, faded, ((neutral colors)), art, "
    "(hdr:1.5), (muted colors:1.2), hyperdetailed, (artstation:1.5), "
    "cinematic, warm lights, dramatic light, (intricate details:1.1), "
    "complex background, (rutkowski:0.8), (teal and orange:0.4)"
)

# Custom CSS for the Blocks UI. Rules keyed by "#id" only take effect on
# components created with a matching elem_id=...
css = """
footer {
    visibility: hidden;
}
#generate_button {
    color: white;
    border-color: #007bff;
    background: #2563eb;
}
#save_button {
    color: white;
    border-color: #028b40;
    background: #01b97c;
    width: 200px;
}
#settings_header {
    background: rgb(245, 105, 105);
}
"""
# On ZeroGPU Spaces a GPU is only attached while a @spaces.GPU-decorated
# function runs; outside ZeroGPU the decorator is documented to have no effect.
@spaces.GPU
def gen(prompt, negative, width, height):
    """Generate images for *prompt* with the two-stage Stable Cascade pipeline.

    The prior (Stage C) turns the text into image embeddings, and the decoder
    (Stage B) turns those embeddings into PIL images.

    Args:
        prompt: User prompt. The module-level ``prompt_add`` style suffix is
            appended to it.
        negative: Negative prompt passed to both stages.
        width: Output width in pixels, as delivered by the UI slider.
        height: Output height in pixels, as delivered by the UI slider.

    Returns:
        The decoder's ``.images`` list (PIL images, since ``output_type="pil"``).
    """
    # Build the combined prompt once. Skip an empty user prompt so the text
    # does not start with a dangling ", ".
    user_prompt = (prompt or "").strip()
    full_prompt = ", ".join(part for part in (user_prompt, prompt_add) if part)

    # Slider values may arrive as floats; the pipelines expect integer sizes.
    width, height = int(width), int(height)

    prior_output = prior(
        prompt=full_prompt,
        height=height,
        width=width,
        negative_prompt=negative,
        guidance_scale=4.0,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=25,
    )

    # The prior was loaded in bfloat16 and the decoder in float16, so cast the
    # embeddings to whatever dtype the decoder actually uses.
    image_embeddings = prior_output.image_embeddings.to(decoder.dtype)

    images = decoder(
        image_embeddings=image_embeddings,
        prompt=full_prompt,
        negative_prompt=negative,
        # Stage B runs without classifier-free guidance. With guidance
        # disabled, the negative prompt presumably has no effect here.
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=10,
    ).images
    return images
# Gradio UI: a prompt box with a Generate button, an accordion with the
# negative prompt and size sliders, and a gallery showing the results of gen().
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Stable Cascade ```DEMO```")
    with gr.Row():
        prompt = gr.Textbox(
            show_label=False,
            placeholder="Enter your prompt",
            max_lines=3,
            lines=1,
            interactive=True,
            scale=20,
        )
        # elem_id links this button to the "#generate_button" rule in `css`;
        # without it that styling is never applied.
        button = gr.Button(value="Generate", scale=1, elem_id="generate_button")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative = gr.Textbox(
                show_label=False,
                value=deafult_negative,
                placeholder="Enter a negative",
                max_lines=2,
                lines=1,
                interactive=True,
            )
        with gr.Row():
            width = gr.Slider(
                label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True
            )
            height = gr.Slider(
                label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True
            )
    with gr.Row():
        gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)

    gen_inputs = [prompt, negative, width, height]
    button.click(gen, inputs=gen_inputs, outputs=gallery)
    # Also generate when the user presses Enter in the prompt box.
    prompt.submit(gen, inputs=gen_inputs, outputs=gallery)

demo.launch(show_api=False)