"""FastAPI service that generates images with Stable Cascade and caches them.

Generated images are written as JPEG files under ``IMGS_PATH`` and recorded in
a project ``Database`` keyed by (prompt, negative_prompt, seed), so repeated
requests for the same key are served from disk instead of re-running the
diffusion pipelines.
"""

import io
import logging
import os
import threading
import uuid
from pathlib import Path
from typing import List  # noqa: F401  (kept from the original import list)

import numpy as np
import PIL.Image
import torch
import uvicorn
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil  # noqa: F401  (kept from the original import list)
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import (  # noqa: F401  (StreamingResponse kept from the original)
    FileResponse,
    JSONResponse,
    RedirectResponse,
    Response,
    StreamingResponse,
)

from db import Database

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

MAX_SEED = np.iinfo(np.int32).max
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
SPACE_ID = os.environ.get("SPACE_ID", "")

# A persistent /data volume is used when running inside a Space (SPACE_ID set);
# otherwise the cache lives next to the working directory.
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
IMGS_PATH = DB_PATH / "imgs"
DB_PATH.mkdir(exist_ok=True, parents=True)
IMGS_PATH.mkdir(exist_ok=True, parents=True)

database = Database(DB_PATH)

dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The pipelines are only loaded when CUDA is present. Defining them as None up
# front lets callers detect the "no GPU" case instead of hitting a NameError.
prior_pipeline = None
decoder_pipeline = None

if torch.cuda.is_available():
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=dtype
    )
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=dtype
    )
    prior_pipeline.to(device)
    decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(
            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
        )
        decoder_pipeline.decoder = torch.compile(
            decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
        )
else:
    logging.warning("CUDA is not available; image generation is disabled.")

# Serializes pipeline calls. The endpoint runs generation in a worker thread so
# the event loop stays responsive; this lock keeps only one generation on the
# GPU at a time.
_GENERATE_LOCK = threading.Lock()


def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 20,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 10,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> PIL.Image.Image:
    """Run the Stable Cascade prior + decoder and return the first image.

    Args:
        prompt: Text prompt passed to both the prior and the decoder.
        negative_prompt: Negative text prompt passed to both stages.
        seed: Seed for the CPU ``torch.Generator`` shared by both stages.
        width, height: Output size requested from the prior.
        prior_num_inference_steps: Step count passed to the prior.
        prior_guidance_scale: Guidance scale passed to the prior.
        decoder_num_inference_steps: Step count passed to the decoder.
        decoder_guidance_scale: Guidance scale passed to the decoder.
        num_images_per_prompt: Number of images the prior produces. Only the
            first decoded image is returned, so values above 1 cost extra
            compute without changing what the caller receives.

    Returns:
        The first decoded image as a PIL image.

    Raises:
        RuntimeError: If the pipelines were not loaded (no CUDA device).
    """
    if prior_pipeline is None or decoder_pipeline is None:
        raise RuntimeError(
            "Stable Cascade pipelines are not loaded (CUDA is not available)."
        )

    generator = torch.Generator().manual_seed(seed)
    with _GENERATE_LOCK:
        prior_output = prior_pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=prior_num_inference_steps,
            timesteps=DEFAULT_STAGE_C_TIMESTEPS,
            negative_prompt=negative_prompt,
            guidance_scale=prior_guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
        )
        decoder_output = decoder_pipeline(
            image_embeddings=prior_output.image_embeddings,
            prompt=prompt,
            num_inference_steps=decoder_num_inference_steps,
            guidance_scale=decoder_guidance_scale,
            negative_prompt=negative_prompt,
            generator=generator,
            output_type="pil",
        ).images
    return decoder_output[0]


app = FastAPI()

# NOTE(review): huggingface.co is served over HTTPS, so browsers normally send
# "Origin: https://huggingface.co". Confirm whether this "http://" entry
# matches the real callers.
origins = [
    "http://huggingface.co",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def validate_origin(request: Request, call_next):
    """Reject any request whose ``Origin`` header is not in ``origins``.

    Requests without an ``Origin`` header are rejected as well, because
    ``None`` is not in the allow-list.
    """
    origin = request.headers.get("origin")
    # Log only the origin: the full header set can include cookies and auth
    # tokens that should not end up in logs.
    logging.info("Request origin: %s", origin)
    if origin not in origins:
        # An HTTPException raised from an @app.middleware function is not
        # handled by FastAPI's exception handlers and surfaces as a 500, so
        # build the 403 response directly.
        return JSONResponse(status_code=403, content={"detail": "Forbidden"})
    return await call_next(request)


@app.get("/image")
async def generate_image(prompt: str, negative_prompt: str = "", seed: int = 2134213213):
    """Return a JPEG for (prompt, negative_prompt, seed), generating it on a cache miss."""
    cached_img = database.check(prompt, negative_prompt, seed)
    if cached_img:
        cached_path = Path(cached_img[0])
        if cached_path.is_file():
            logging.info("Image found in cache: %s", cached_img[0])
            # FileResponse opens and closes the file itself.
            return FileResponse(cached_path, media_type="image/jpeg")
        # NOTE(review): the cache entry points at a missing file; we regenerate
        # and insert a new row. Confirm how Database.check picks among rows.
        logging.warning("Cached image missing on disk, regenerating: %s", cached_img[0])

    if prior_pipeline is None or decoder_pipeline is None:
        raise HTTPException(
            status_code=503,
            detail="Image generation is unavailable: no CUDA device.",
        )

    logging.info("Image not found in cache, generating new image")
    # Generation is blocking GPU work; running it in the threadpool keeps the
    # event loop free for other requests. Only one image is requested because
    # generate() returns just the first one.
    pil_image = await run_in_threadpool(
        generate, prompt, negative_prompt, seed, num_images_per_prompt=1
    )

    # Encode once and reuse the bytes for both the cache file and the response.
    img_io = io.BytesIO()
    pil_image.save(img_io, "JPEG")
    img_bytes = img_io.getvalue()

    img_path = IMGS_PATH / f"{uuid.uuid4()}.jpg"
    img_path.write_bytes(img_bytes)
    database.insert(prompt, negative_prompt, str(img_path), seed)
    return Response(content=img_bytes, media_type="image/jpeg")


@app.get("/")
async def main():
    """Redirect the root URL to the Stable Cascade Space UI."""
    return RedirectResponse(
        "https://multimodalart-stable-cascade.hf.space/?__theme=system"
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)