"""FastAPI service that generates Stable Cascade images and caches them on disk.

``GET /image`` returns a JPEG for a (prompt, negative_prompt, seed) triple,
serving it from the on-disk cache when present and generating it otherwise.
``GET /`` redirects to the hosted Stable Cascade demo.
"""

import asyncio
import io
import logging
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import numpy as np
import PIL.Image
import torch
import uvicorn
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (
    FileResponse,
    RedirectResponse,
    Response,
    StreamingResponse,
)

from db import Database

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

MAX_SEED = np.iinfo(np.int32).max
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
SPACE_ID = os.environ.get("SPACE_ID", "")

# When SPACE_ID is set (presumably running as a Hugging Face Space) the cache
# lives under /data/cache; otherwise under ./cache relative to the CWD.
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
IMGS_PATH = DB_PATH / "imgs"
DB_PATH.mkdir(exist_ok=True, parents=True)
IMGS_PATH.mkdir(exist_ok=True, parents=True)

database = Database(DB_PATH)
with database() as db:
    cursor = db.cursor()
    # Querying the table at startup fails fast if the cache table is missing.
    cursor.execute("SELECT * FROM cache")
    # The table grows with every generated image, so only dump it when
    # debug logging is enabled.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug("Cache table contents: %s", cursor.fetchall())

dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pipelines are only loaded when CUDA is available. Binding them to None first
# keeps the names defined on CPU-only hosts, so requests get a clear error
# instead of a NameError.
prior_pipeline = None
decoder_pipeline = None
if torch.cuda.is_available():
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=dtype
    )
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=dtype
    )
    prior_pipeline.to(device)
    decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(
            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
        )
        decoder_pipeline.decoder = torch.compile(
            decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
        )

# A single dedicated worker thread runs all GPU work. This has three effects:
# - it keeps generation off the asyncio event loop;
# - it serializes access to the shared module-level pipelines, which are not
#   assumed to be thread-safe;
# - it always runs the (optionally torch.compile'd) modules on the same thread.
_GPU_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="stable-cascade")


def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 20,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 10,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> PIL.Image.Image:
    """Run the two-stage Stable Cascade pipeline and return the first image.

    The prior pipeline turns the prompt into image embeddings; the decoder
    pipeline turns those embeddings into PIL images. Both stages draw from one
    CPU ``torch.Generator`` seeded with *seed*, so the sampled noise is
    reproducible for a given seed and argument set.

    Only the first of the ``num_images_per_prompt`` generated images is
    returned; any others are discarded.

    Raises:
        RuntimeError: if the pipelines were not loaded (no CUDA device).
    """
    if prior_pipeline is None or decoder_pipeline is None:
        raise RuntimeError(
            "Stable Cascade pipelines are not loaded (no CUDA device available)"
        )

    generator = torch.Generator().manual_seed(seed)
    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    )
    images = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    return images[0]


def _generate_jpeg(prompt: str, negative_prompt: str, seed: int) -> bytes:
    """Generate one image and return it JPEG-encoded (runs on _GPU_EXECUTOR)."""
    # The endpoint only ever returns one image, so request exactly one instead
    # of generating a second image that would be thrown away.
    pil_image = generate(prompt, negative_prompt, seed, num_images_per_prompt=1)
    buffer = io.BytesIO()
    # Encode once; the same bytes are written to disk and sent to the client.
    pil_image.save(buffer, "JPEG")
    return buffer.getvalue()


app = FastAPI()

# Browsers send the page's scheme in the Origin header. huggingface.co is
# served over HTTPS, so the https origin must be listed for CORS to match.
origins = [
    "http://huggingface.co",
    "https://huggingface.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _cached_response(prompt: str, negative_prompt: str, seed: int):
    """Return a FileResponse for the cached image of this key, or None on a miss."""
    cached_img = database.check(prompt, negative_prompt, seed)
    if not cached_img:
        return None
    logging.info("Image found in cache: %s", cached_img[0])
    # FileResponse opens and closes the file itself, so no handle is leaked.
    return FileResponse(cached_img[0], media_type="image/jpeg")


@app.get("/image")
async def generate_image(prompt: str, negative_prompt: str, seed: int = 2134213213):
    """Return a JPEG for (prompt, negative_prompt, seed), generating it on a cache miss.

    Responds with 503 when the image is not cached and no GPU pipelines are
    loaded.
    """
    cached = _cached_response(prompt, negative_prompt, seed)
    if cached is not None:
        return cached

    if prior_pipeline is None or decoder_pipeline is None:
        raise HTTPException(
            status_code=503, detail="Image generation requires a CUDA device"
        )

    logging.info("Image not found in cache, generating new image")
    loop = asyncio.get_running_loop()
    # Run the blocking GPU work off the event loop so other requests (cache
    # hits, redirects) keep being served meanwhile.
    jpeg_bytes = await loop.run_in_executor(
        _GPU_EXECUTOR, _generate_jpeg, prompt, negative_prompt, seed
    )

    # While this request waited for the GPU, a concurrent request for the same
    # key may have finished and cached its image. Serve that copy instead of
    # inserting a duplicate row. The check and insert below have no await
    # between them, so they run atomically with respect to the event loop.
    cached = _cached_response(prompt, negative_prompt, seed)
    if cached is not None:
        return cached

    img_path = IMGS_PATH / f"{uuid.uuid4()}.jpg"
    img_path.write_bytes(jpeg_bytes)
    database.insert(prompt, negative_prompt, str(img_path), seed)
    return Response(content=jpeg_bytes, media_type="image/jpeg")


@app.get("/")
async def main():
    """Redirect the root URL to the hosted Stable Cascade demo.

    The demo is https://huggingface.co/spaces/multimodalart/stable-cascade.
    """
    return RedirectResponse(
        "https://multimodalart-stable-cascade.hf.space/?__theme=system"
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)