import asyncio
import os
import threading
import time
import traceback
from pathlib import Path
from typing import Optional, Dict, List

import cv2
import numpy as np
import socketio
import torch

try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    # Best effort: these are private torch hooks. If any of them fails the
    # server still starts with torch's default JIT fuser settings.
    pass

import uvicorn
from PIL import Image
from fastapi import APIRouter, FastAPI, Request, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, Response
from fastapi.staticfiles import StaticFiles
from loguru import logger
from socketio import AsyncServer

from iopaint.file_manager import FileManager
from iopaint.helper import (
    load_img,
    decode_base64_to_image,
    pil_to_bytes,
    numpy_to_bytes,
    concat_alpha_channel,
    gen_frontend_mask,
    adjust_mask,
)
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import (
    GenInfoResponse,
    ApiConfig,
    ServerConfigResponse,
    SwitchModelRequest,
    InpaintRequest,
    RunPluginRequest,
    SDSampler,
    PluginInfo,
    AdjustMaskRequest,
    RemoveBGModel,
    SwitchPluginModelRequest,
    ModelInfo,
    InteractiveSegModel,
    RealESRGANModel,
)

CURRENT_DIR = Path(__file__).parent.absolute().resolve()
WEB_APP_DIR = CURRENT_DIR / "web_app"


def api_middleware(app: FastAPI):
    """Install exception handling and CORS middleware on ``app``.

    Every exception raised while serving a request is converted into a JSON
    response with the keys ``error``, ``detail``, ``body`` and ``errors``.
    The status code is taken from the exception's ``status_code`` attribute
    when present, otherwise 500.

    When the ``WEBUI_RICH_EXCEPTIONS`` environment variable is set and
    ``anyio``, ``starlette`` and ``rich`` can be imported, tracebacks are
    rendered with ``rich``; otherwise the standard ``traceback`` module is
    used.
    """
    rich_available = False
    try:
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        # Rich output is optional; fall back to plain tracebacks.
        pass

    def handle_exception(request: Request, e: Exception):
        """Log ``e`` (unless it is an ``HTTPException``) and build the JSON error response."""
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get("detail", ""),
            "body": vars(e).get("body", ""),
            "errors": str(e),
        }
        if not isinstance(
            e, HTTPException
        ):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            # Print the request context in both branches; previously it was
            # only shown when rich was enabled.
            print(message)
            if rich_available:
                console.print_exception(
                    show_locals=True,
                    max_frames=2,
                    extra_lines=1,
                    suppress=[anyio, starlette],
                    word_wrap=False,
                    width=min([console.width, 200]),
                )
            else:
                traceback.print_exc()
        return JSONResponse(
            status_code=vars(e).get("status_code", 500), content=jsonable_encoder(err)
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    cors_options = {
        "allow_methods": ["*"],
        "allow_headers": ["*"],
        "allow_origins": ["*"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)


# Set by Api.__init__ so that diffuser_callback can emit progress events.
global_sio: AsyncServer = None


def diffuser_callback(
    pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None
):
    """Emit a ``diffusion_progress`` socket.io event carrying the current step.

    Signature expected by the caller:
    ``self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict``.
    ``pipe``, ``timestep`` and ``callback_kwargs`` are accepted but not used.

    Returns:
        An empty dict.
    """
    # We use asyncio loops for task processing. Perhaps in the future, we can
    # add a processing queue similar to InvokeAI, but for now let's just start
    # a separate event loop. It shouldn't make a difference for single person
    # use.
    if global_sio is not None:
        # global_sio stays None until an Api instance has been constructed.
        asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}


class Api:
    """HTTP and socket.io API of the server.

    Construction installs the middleware, builds the file manager, plugins
    and model manager from ``config``, registers the ``/api/v1/*`` routes,
    mounts the static web app at ``/`` and wraps the FastAPI app in a
    socket.io ASGI app. ``launch`` serves that combined app with uvicorn.
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        # NOTE(review): not used by any code visible in this file.
        self.queue_lock = threading.Lock()

        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # fmt: off
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
        # fmt: on

        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` for ``path`` directly on the FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Write the uploaded file into the configured output directory.

        Only the final path component of the client-supplied filename is
        used, so the file cannot be written outside ``output_dir``.

        Raises:
            HTTPException: 400 when no output directory is configured or the
                filename has no usable final component.
        """
        output_dir = self.config.output_dir
        if output_dir is None:
            raise HTTPException(
                status_code=400, detail="Output directory is not configured"
            )

        # The filename is client-controlled: normalise Windows separators and
        # drop every directory component to prevent path traversal.
        raw_name = (file.filename or "").replace("\\", "/")
        safe_name = Path(raw_name).name
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")

        origin_image_bytes = file.file.read()
        with open(output_dir / safe_name, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return the model manager's current model."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Switch to ``req.name`` unless it is already active; return the current model."""
        if req.name == self.model_manager.name:
            return self.model_manager.current_model
        self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model of a loaded plugin and record the choice in the config.

        Requests naming a plugin that is not loaded are ignored.
        """
        if req.plugin_name in self.plugins:
            self.plugins[req.plugin_name].switch_model(req.model_name)
            if req.plugin_name == RemoveBG.name:
                self.config.remove_bg_model = req.model_name
            if req.plugin_name == RealESRGANUpscaler.name:
                self.config.realesrgan_model = req.model_name
            if req.plugin_name == InteractiveSeg.name:
                self.config.interactive_seg_model = req.model_name
            torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe loaded plugins, available models and feature flags."""
        plugins = [
            PluginInfo(
                name=it.name,
                support_gen_image=it.support_gen_image,
                support_gen_mask=it.support_gen_mask,
            )
            for it in self.plugins.values()
        ]

        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve the configured input file.

        Raises:
            HTTPException: 404 when ``config.input`` is unset or not a file.
        """
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompt and negative prompt from the upload's ``parameters`` info.

        The text before ``"Negative prompt: "`` is the prompt; the first line
        after it is the negative prompt (empty when the marker is absent).
        """
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Run the current model on ``req.image`` / ``req.mask`` and return a PNG.

        The mask is binarised at a threshold of 127. After inference a
        ``diffusion_finish`` socket.io event is emitted. The response carries
        the request seed in the ``X-Seed`` header.

        Raises:
            HTTPException: 400 when image and mask differ in height/width.
        """
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )

        if req.paint_by_example_example_image:
            # NOTE(review): the decoded example image is not used below; the
            # call is kept so malformed input still fails at this point.
            paint_by_example_image, _, _ = decode_base64_to_image(
                req.paint_by_example_example_image
            )

        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run an image-producing plugin on ``req.image`` and return a PNG.

        A 4-channel plugin result is used as is; any other result is
        converted with ``COLOR_BGR2RGB`` and the input's alpha channel is
        re-attached.

        Raises:
            HTTPException: 422 when the plugin is not loaded or does not
                support image output.
        """
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run a mask-producing plugin on ``req.image`` and return a PNG mask.

        Raises:
            HTTPException: 422 when the plugin is not loaded or does not
                support mask output.
        """
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            # Message previously said "output image" (copied from gen_image).
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """Return the value of every ``SDSampler`` member."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Apply ``adjust_mask`` with the requested kernel size and operation; return a PNG."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Serve the combined socket.io + FastAPI app with uvicorn (blocking call)."""
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a ``FileManager`` when ``config.input`` is a directory, else return None."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )

            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Build the plugin mapping from the plugin-related config options."""
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Create the ``ModelManager`` from config, wired to ``diffuser_callback``."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )