from fastapi import FastAPI, Request, Form, File, UploadFile
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
import torch
from PIL import Image
from io import BytesIO
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the SDXL-Turbo text-to-image pipeline once at startup and reuse its
    # weights for the image-to-image and inpainting pipelines via from_pipe().
    # Note: float16 weights generally require a GPU; swap "cpu" for "cuda" here
    # if one is available.
    text2img = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
    ).to("cpu")
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
    inpaint = AutoPipelineForInpainting.from_pipe(text2img).to("cpu")
    # The yielded dict becomes lifespan state, exposed to handlers as
    # request.state.<key>.
    yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
    # Release the pipelines on shutdown.
    del text2img
    del img2img
    del inpaint


app = FastAPI(lifespan=lifespan)

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {"Hello": "World"}


@app.post("/text-to-image/")
async def text_to_image(request: Request, prompt: str = Form(...)):
    # SDXL-Turbo is distilled for single-step generation without guidance.
    image = request.state.text2img(
        prompt=prompt, num_inference_steps=1, guidance_scale=0.0
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")


@app.post("/image-to-image/")
async def image_to_image(
    request: Request, prompt: str = Form(...), init_image: UploadFile = File(...)
):
    image_bytes = await init_image.read()
    init_image = Image.open(BytesIO(image_bytes))
    init_image = init_image.convert("RGB").resize((512, 512))

    # The pipeline object itself is callable; with strength=0.5 the effective
    # step count is num_inference_steps * strength = 1.
    image = request.state.img2img(
        prompt,
        image=init_image,
        num_inference_steps=2,
        strength=0.5,
        guidance_scale=0.0,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")


@app.post("/inpainting/")
async def inpainting(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
):
    image_bytes = await init_image.read()
    init_image = Image.open(BytesIO(image_bytes))
    init_image = init_image.convert("RGB").resize((512, 512))

    mask_bytes = await mask_image.read()
    mask_image = Image.open(BytesIO(mask_bytes))
    mask_image = mask_image.convert("RGB").resize((512, 512))

    image = request.state.inpaint(
        prompt,
        image=init_image,
        mask_image=mask_image,
        num_inference_steps=3,
        strength=0.5,
        guidance_scale=0.0,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
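For reference, here is a minimal client sketch for exercising the endpoints above. It assumes the server was saved as main.py and started with `uvicorn main:app` on the default port 8000; the module name, port, prompt text, and output filenames are illustrative assumptions, not part of the listing.

import requests

# Text-to-image: the prompt is sent as a form field, and the response body is
# the PNG bytes streamed back by the server.
resp = requests.post(
    "http://localhost:8000/text-to-image/",
    data={"prompt": "a photo of an astronaut riding a horse"},
)
resp.raise_for_status()
with open("astronaut.png", "wb") as f:
    f.write(resp.content)

# Image-to-image: the init image is uploaded as a multipart file field named
# init_image, matching the UploadFile parameter in the endpoint.
with open("astronaut.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/image-to-image/",
        data={"prompt": "the same scene at sunset"},
        files={"init_image": f},
    )
resp.raise_for_status()
with open("astronaut_sunset.png", "wb") as f:
    f.write(resp.content)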