#@title Load models
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
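# Pick the best available device: CUDA if present, else Apple MPS, else CPU.
# (The CUDA check runs last, so it wins when both CUDA and MPS are available.)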
device = torch.device("cpu")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
if torch.cuda.is_available():
    device = torch.device("cuda")
print("RUNNING ON:", device)
c_dtype = torch.bfloat16 if device.type == "cpu" else torch.float
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=c_dtype)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.half)
prior.to(device)
decoder.to(device)
import random
import gc
import numpy as np
import gradio as gr
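# Seed ceiling is the largest 32-bit signed int; width/height sliders cap at 1536 px.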
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1536
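# Draw a fresh random seed when requested; otherwise keep the caller's seed.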
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
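# Stage C: the prior maps the text prompt to compressed image embeddings.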
def generate_prior(prompt, negative_prompt, generator, width, height, num_inference_steps, guidance_scale, num_images_per_prompt):
    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,  # seed the prior so results are reproducible
    )
    # Release cached GPU memory between stages (a no-op off-GPU).
    torch.cuda.empty_cache()
    gc.collect()
    return prior_output.image_embeddings
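# Stage B: the decoder renders the prior's embeddings into RGB images.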
def generate_decoder(prior_embeds, prompt, negative_prompt, generator, num_inference_steps, guidance_scale):
    decoder_output = decoder(
        image_embeddings=prior_embeds.to(device=device, dtype=decoder.dtype),
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images
    torch.cuda.empty_cache()
    gc.collect()
    return decoder_output
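# End-to-end generation: run the prior (Stage C), then the decoder (Stage B).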
@torch.inference_mode()
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    randomize_seed: bool = True,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 20,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 10,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
):
    """Generate images using Stable Cascade."""
    seed = randomize_seed_fn(seed, randomize_seed)
    print("seed:", seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    prior_embeds = generate_prior(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        width=width,
        height=height,
        num_inference_steps=prior_num_inference_steps,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    )
    decoder_output = generate_decoder(
        prior_embeds=prior_embeds,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=decoder_num_inference_steps,
        guidance_scale=decoder_guidance_scale,
    )
    return decoder_output
examples = [
    "An astronaut riding a green horse",
    "A mecha robot in a favela by Tarsila do Amaral",
    "The spirit of a Tamagotchi wandering in the city of Los Angeles",
    "A delicious feijoada ramen dish",
]
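# Gradio UI: prompt box, advanced options, and a result gallery wired to generate().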
with gr.Blocks(css="gradio_app/style.css") as demo:
    with gr.Column():
        prompt = gr.Text(
            label="Prompt",
            show_label=False,
            placeholder="Enter your prompt",
        )
        run_button = gr.Button("Run")
        with gr.Accordion("Advanced options", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=128,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=128,
                value=1024,
            )
            num_images_per_prompt = gr.Slider(
                label="Number of Images",
                minimum=1,
                maximum=2,
                step=1,
                value=2,
            )
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=4.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=10,
                maximum=30,
                step=1,
                value=20,
            )
            # The Stable Cascade decoder is meant to run without classifier-free
            # guidance, so this slider is deliberately pinned to 0 (minimum == maximum).
            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=0,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=12,
                step=1,
                value=10,
            )
    with gr.Column():
        result = gr.Gallery(label="Result", show_label=False)
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
    )
    inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        prior_num_inference_steps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]
    prompt.submit(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    negative_prompt.submit(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    run_button.click(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
# Cap the request queue at 20 pending jobs; max_size= is unambiguous across Gradio versions.
demo.queue(max_size=20).launch()