"""FastAPI entry point that redirects visitors to the hosted Stable Cascade demo.

The module builds a FastAPI application with CORS enabled for a fixed list of
local origins and registers a single ``GET /`` route that issues a redirect.
When run as a script it serves the application with uvicorn on port 7860.

The Gradio/diffusers demo that follows the ``__main__`` guard is kept as
commented-out reference code; none of it is executed.
"""

import os
import random
from typing import List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers.utils import numpy_to_pil
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse


class GenerateRequest(BaseModel):
    """Request body describing a single generation call.

    NOTE(review): no route in the visible file accepts this model; it looks
    like a leftover from (or preparation for) a generation endpoint -- confirm
    before removing.
    """

    # Text prompt; the only required field.
    prompt: str
    # Optional negative prompt, empty by default.
    negative_prompt: str = ""
    # Seed value, 0 by default.
    seed: int = 0


app = FastAPI()

# Origins that are allowed to make cross-origin requests to this app.
origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def main():
    """Redirect the root path to the hosted Stable Cascade Space.

    Returns:
        A ``RedirectResponse`` pointing at the Space's ``*.hf.space`` host,
        with ``__theme=system`` in the query string.
    """
    # The target is the direct hf.space host of the
    # multimodalart/stable-cascade Space, not its huggingface.co page.
    return RedirectResponse("https://multimodalart-stable-cascade.hf.space/?__theme=system")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)


# ---------------------------------------------------------------------------
# Disabled Gradio demo (kept for reference only; nothing below is executed).
# NOTE(review): the disabled code references ``DESCRIPTION``, which is not
# defined anywhere in the visible file -- define it before re-enabling.
# ---------------------------------------------------------------------------

# MAX_SEED = np.iinfo(np.int32).max
# USE_TORCH_COMPILE = False

# dtype = torch.bfloat16
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# if torch.cuda.is_available():
#     prior_pipeline = StableCascadePriorPipeline.from_pretrained(
#         "stabilityai/stable-cascade-prior", torch_dtype=dtype)  # .to(device)
#     decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
#         "stabilityai/stable-cascade", torch_dtype=dtype)  # .to(device)
#     prior_pipeline.to(device)
#     decoder_pipeline.to(device)

#     if USE_TORCH_COMPILE:
#         prior_pipeline.prior = torch.compile(
#             prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
#         decoder_pipeline.decoder = torch.compile(
#             decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
# else:
#     prior_pipeline = None
#     decoder_pipeline = None


# def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
#     if randomize_seed:
#         seed = random.randint(0, MAX_SEED)
#     return seed


# def generate(
#     prompt: str,
#     negative_prompt: str = "",
#     seed: int = 0,
#     width: int = 1024,
#     height: int = 1024,
#     prior_num_inference_steps: int = 30,
#     # prior_timesteps: List[float] = None,
#     prior_guidance_scale: float = 4.0,
#     decoder_num_inference_steps: int = 12,
#     # decoder_timesteps: List[float] = None,
#     decoder_guidance_scale: float = 0.0,
#     num_images_per_prompt: int = 2,
#     progress=gr.Progress(track_tqdm=True),
# ) -> PIL.Image.Image:
#     generator = torch.Generator().manual_seed(seed)
#     prior_output = prior_pipeline(
#         prompt=prompt,
#         height=height,
#         width=width,
#         num_inference_steps=prior_num_inference_steps,
#         timesteps=DEFAULT_STAGE_C_TIMESTEPS,
#         negative_prompt=negative_prompt,
#         guidance_scale=prior_guidance_scale,
#         num_images_per_prompt=num_images_per_prompt,
#         generator=generator,
#     )
#     decoder_output = decoder_pipeline(
#         image_embeddings=prior_output.image_embeddings,
#         prompt=prompt,
#         num_inference_steps=decoder_num_inference_steps,
#         # timesteps=decoder_timesteps,
#         guidance_scale=decoder_guidance_scale,
#         negative_prompt=negative_prompt,
#         generator=generator,
#         output_type="pil",
#     ).images
#     return decoder_output[0]


# examples = [
#     "An astronaut riding a green horse",
#     "A mecha robot in a favela by Tarsila do Amaral",
#     "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
#     "A delicious feijoada ramen dish"
# ]

# with gr.Blocks() as demo:
#     gr.Markdown(DESCRIPTION)
#     gr.DuplicateButton(
#         value="Duplicate Space for private use",
#         elem_id="duplicate-button",
#         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
#     )
#     with gr.Group():
#         with gr.Row():
#             prompt = gr.Text(
#                 label="Prompt",
#                 show_label=False,
#                 max_lines=1,
#                 placeholder="Enter your prompt",
#                 container=False,
#             )
#             run_button = gr.Button("Run", scale=0)
#         result = gr.Image(label="Result", show_label=False)
#     with gr.Accordion("Advanced options", open=False):
#         negative_prompt = gr.Text(
#             label="Negative prompt",
#             max_lines=1,
#             placeholder="Enter a Negative Prompt",
#         )

#         seed = gr.Slider(
#             label="Seed",
#             minimum=0,
#             maximum=MAX_SEED,
#             step=1,
#             value=0,
#         )
#         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
#         with gr.Row():
#             width = gr.Slider(
#                 label="Width",
#                 minimum=1024,
#                 maximum=1536,
#                 step=512,
#                 value=1024,
#             )
#             height = gr.Slider(
#                 label="Height",
#                 minimum=1024,
#                 maximum=1536,
#                 step=512,
#                 value=1024,
#             )
#             num_images_per_prompt = gr.Slider(
#                 label="Number of Images",
#                 minimum=1,
#                 maximum=2,
#                 step=1,
#                 value=1,
#             )
#         with gr.Row():
#             prior_guidance_scale = gr.Slider(
#                 label="Prior Guidance Scale",
#                 minimum=0,
#                 maximum=20,
#                 step=0.1,
#                 value=4.0,
#             )
#             prior_num_inference_steps = gr.Slider(
#                 label="Prior Inference Steps",
#                 minimum=10,
#                 maximum=30,
#                 step=1,
#                 value=20,
#             )

#             decoder_guidance_scale = gr.Slider(
#                 label="Decoder Guidance Scale",
#                 minimum=0,
#                 maximum=0,
#                 step=0.1,
#                 value=0.0,
#             )
#             decoder_num_inference_steps = gr.Slider(
#                 label="Decoder Inference Steps",
#                 minimum=4,
#                 maximum=12,
#                 step=1,
#                 value=10,
#             )

#     gr.Examples(
#         examples=examples,
#         inputs=prompt,
#         outputs=result,
#         fn=generate,
#         cache_examples=False,
#     )

#     inputs = [
#         prompt,
#         negative_prompt,
#         seed,
#         width,
#         height,
#         prior_num_inference_steps,
#         # prior_timesteps,
#         prior_guidance_scale,
#         decoder_num_inference_steps,
#         # decoder_timesteps,
#         decoder_guidance_scale,
#         num_images_per_prompt,
#     ]
#     gr.on(
#         triggers=[prompt.submit, negative_prompt.submit, run_button.click],
#         fn=randomize_seed_fn,
#         inputs=[seed, randomize_seed],
#         outputs=seed,
#         queue=False,
#         api_name=False,
#     ).then(
#         fn=generate,
#         inputs=inputs,
#         outputs=result,
#         api_name="run",
#     )

# if __name__ == "__main__":
#     demo.queue(max_size=20).launch()