Spaces:
Runtime error
Runtime error
File size: 6,094 Bytes
f0aa930 35dc227 b31f197 35dc227 f9309dc 20e6fde 35dc227 a064d1f 80eae63 35dc227 7e1ec9f 35dc227 a064d1f 35dc227 80eae63 35dc227 a064d1f 35dc227 a064d1f d1b6cf8 a064d1f 35dc227 80eae63 35dc227 c383967 35dc227 edf4c40 35dc227 c383967 35dc227 edf4c40 35dc227 c7facaa 35dc227 0cc2ea6 80eae63 35dc227 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch, os
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gradio as gr
# Turn off parallelism inside the Hugging Face `tokenizers` library before
# any pipeline is instantiated.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Stage 1 of Stable Cascade: the prior pipeline, loaded in bfloat16 and then
# moved to the GPU.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
)
prior = prior.to("cuda")

# Stage 2 of Stable Cascade: the decoder pipeline, loaded in float16 and then
# moved to the GPU.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
)
decoder = decoder.to("cuda")
def generate_images(
    prompt="a photo of a girl",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    seed=42,
    num_images_per_prompt=1,
    prior_inference_steps=20,
    decoder_inference_steps=10,
):
    """Generate images for a prompt with the two-stage Stable Cascade pipelines.

    The module-level ``prior`` pipeline turns the prompt into image
    embeddings, which the module-level ``decoder`` pipeline then renders as
    PIL images. Both stages share one CUDA random generator seeded from
    ``seed``.

    Parameters:
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to guide generation away from.
    - height (int): The height of the generated images, in pixels.
    - width (int): The width of the generated images, in pixels.
    - guidance_scale (float): Guidance scale passed to the prior stage.
    - seed (int): Seed for the CUDA random generator.
    - num_images_per_prompt (int): Number of images requested from the prior.
    - prior_inference_steps (int): Number of inference steps for the prior.
    - decoder_inference_steps (int): Number of inference steps for the decoder.

    Integer-valued arguments are coerced with ``int()`` so values delivered
    as floats by UI widgets are accepted.

    Returns:
    - List[PIL.Image]: The images produced by the decoder pipeline.
    """
    # UI widgets can deliver whole numbers as floats (e.g. 42.0); the
    # pipelines expect real integers for sizes, counts and step numbers.
    height = int(height)
    width = int(width)
    num_images_per_prompt = int(num_images_per_prompt)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Stage 1: generate image embeddings using the prior model.
    prior_output = prior(
        prompt=prompt,
        generator=generator,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_inference_steps,
    )

    # Stage 2: decode the embeddings into images. The embeddings are cast to
    # half precision because the decoder pipeline is loaded in float16.
    decoder_output = decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        generator=generator,
        negative_prompt=negative_prompt,
        # Guidance scale typically set to 0 for the decoder, as guidance is
        # applied in the prior.
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    )
    return decoder_output.images
def web_demo():
    """Build the Gradio Blocks interface for text-to-image generation.

    The left column holds the prompt, negative prompt and seed inputs, the
    sliders for image count, size, guidance scale and step counts, and the
    "Generate Image" button. The right column holds the output gallery.
    Clicking the button calls ``generate_images`` with the widget values and
    shows the returned images in the gallery.

    Returns:
        gr.Blocks: The assembled interface. It is not launched here; call
        ``.launch()`` on the returned object to serve it.
    """
    # NOTE(review): the row/column nesting below was reconstructed from a
    # source whose indentation was lost -- confirm it matches the intended
    # layout.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                text2image_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Prompt",
                    show_label=False,
                )
                text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Negative Prompt",
                    show_label=False,
                )
                text2image_seed = gr.Number(
                    value=42,
                    label="Seed",
                )

                with gr.Row():
                    with gr.Column():
                        text2image_num_images_per_prompt = gr.Slider(
                            minimum=1,
                            maximum=4,
                            step=1,
                            value=1,
                            label="Number Image",
                        )
                        text2image_height = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Height",
                        )
                        text2image_width = gr.Slider(
                            minimum=128,
                            maximum=1024,
                            step=32,
                            value=1024,
                            label="Image Width",
                        )

                with gr.Row():
                    with gr.Column():
                        text2image_guidance_scale = gr.Slider(
                            minimum=0.1,
                            maximum=15,
                            step=0.1,
                            value=4.0,
                            label="Guidance Scale",
                        )
                        text2image_prior_inference_step = gr.Slider(
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=20,
                            label="Prior Inference Step",
                        )
                        text2image_decoder_inference_step = gr.Slider(
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=10,
                            label="Decoder Inference Step",
                        )

                text2image_predict = gr.Button(value="Generate Image")

            with gr.Column():
                # `.style()` was deprecated in Gradio 3.x and removed in 4.0,
                # where layout options are constructor arguments instead.
                # Keep the legacy call only where the method still exists so
                # the demo works with either generation of the library.
                if hasattr(gr.Gallery, "style"):
                    output_image = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                    ).style(grid=(1, 2), height=300)
                else:
                    output_image = gr.Gallery(
                        label="Generated images",
                        show_label=False,
                        elem_id="gallery",
                        columns=2,
                        height=300,
                    )

        # The order of `inputs` must match the positional parameters of
        # generate_images.
        text2image_predict.click(
            fn=generate_images,
            inputs=[
                text2image_prompt,
                text2image_negative_prompt,
                text2image_height,
                text2image_width,
                text2image_guidance_scale,
                text2image_seed,
                text2image_num_images_per_prompt,
                text2image_prior_inference_step,
                text2image_decoder_inference_step,
            ],
            outputs=output_image,
        )

    # Hand the Blocks object back so the caller can launch it; previously it
    # was discarded when the `with` block ended.
    return demo