|
import asyncio |
|
import os |
|
import threading |
|
import time |
|
import traceback |
|
from pathlib import Path |
|
from typing import Optional, Dict, List |
|
|
|
import cv2 |
|
import numpy as np |
|
import socketio |
|
import torch |
|
|
|
def _disable_torch_jit_fusers() -> None:
    """Best-effort: switch off the TorchScript JIT fusers.

    Each private ``torch._C`` toggle is attempted independently. A setter
    that does not exist in the installed torch version (they are private
    APIs), or that raises, is skipped without preventing the remaining
    toggles from being applied.
    """
    for setter_name in (
        "_jit_override_can_fuse_on_cpu",
        "_jit_override_can_fuse_on_gpu",
        "_jit_set_texpr_fuser_enabled",
        "_jit_set_nvfuser_enabled",
    ):
        setter = getattr(torch._C, setter_name, None)
        if setter is None:
            continue
        try:
            setter(False)
        except Exception:
            # Deliberately best-effort: a failing toggle must never abort
            # module import. (Unlike a bare ``except:``, this does not
            # swallow KeyboardInterrupt / SystemExit.)
            pass


_disable_torch_jit_fusers()
|
|
|
|
|
import uvicorn |
|
from PIL import Image |
|
from fastapi import APIRouter, FastAPI, Request, UploadFile |
|
from fastapi.encoders import jsonable_encoder |
|
from fastapi.exceptions import HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse, FileResponse, Response |
|
from fastapi.staticfiles import StaticFiles |
|
from loguru import logger |
|
from socketio import AsyncServer |
|
|
|
from iopaint.file_manager import FileManager |
|
from iopaint.helper import ( |
|
load_img, |
|
decode_base64_to_image, |
|
pil_to_bytes, |
|
numpy_to_bytes, |
|
concat_alpha_channel, |
|
gen_frontend_mask, |
|
adjust_mask, |
|
) |
|
from iopaint.model.utils import torch_gc |
|
from iopaint.model_manager import ModelManager |
|
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg |
|
from iopaint.plugins.base_plugin import BasePlugin |
|
from iopaint.plugins.remove_bg import RemoveBG |
|
from iopaint.schema import ( |
|
GenInfoResponse, |
|
ApiConfig, |
|
ServerConfigResponse, |
|
SwitchModelRequest, |
|
InpaintRequest, |
|
RunPluginRequest, |
|
SDSampler, |
|
PluginInfo, |
|
AdjustMaskRequest, |
|
RemoveBGModel, |
|
SwitchPluginModelRequest, |
|
ModelInfo, |
|
InteractiveSegModel, |
|
RealESRGANModel, |
|
) |
|
|
|
# Absolute, symlink-resolved directory that contains this module.
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
# Bundled frontend build directory, located next to this module.
WEB_APP_DIR = CURRENT_DIR.joinpath("web_app")
|
|
|
|
|
def api_middleware(app: FastAPI):
    """Attach JSON error handling and a permissive CORS policy to ``app``.

    Any exception reaching the HTTP middleware or one of the registered
    exception handlers is converted into a ``JSONResponse``. Its body has the
    keys ``error``, ``detail``, ``body`` and ``errors``. The status code is the
    exception's ``status_code`` attribute, or 500 when it has none.

    Exceptions that are not ``HTTPException`` also get their traceback
    printed. This goes through ``rich`` when the ``WEBUI_RICH_EXCEPTIONS``
    environment variable is set and rich can be set up; otherwise it uses
    ``traceback.print_exc()``.

    Args:
        app: The FastAPI application, configured in place.
    """
    rich_available = False
    try:
        # Opt-in pretty tracebacks; any failure while setting them up keeps
        # the plain-traceback fallback.
        if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None:
            import anyio
            import starlette
            from rich.console import Console

            console = Console()
            rich_available = True
    except Exception:
        pass

    def report_error(request: Request, err: dict):
        """Print a summary of the failed request plus the active traceback."""
        message = f"API error: {request.method}: {request.url} {err}"
        if not rich_available:
            traceback.print_exc()
            return
        print(message)
        console.print_exception(
            show_locals=True,
            max_frames=2,
            extra_lines=1,
            # Hide frames from the async/ASGI machinery itself.
            suppress=[anyio, starlette],
            word_wrap=False,
            width=min(console.width, 200),
        )

    def handle_exception(request: Request, e: Exception):
        attrs = vars(e)
        err = {
            "error": type(e).__name__,
            "detail": attrs.get("detail", ""),
            "body": attrs.get("body", ""),
            "errors": str(e),
        }
        # HTTPExceptions are expected control flow; only log the rest.
        if not isinstance(e, HTTPException):
            report_error(request, err)
        return JSONResponse(
            status_code=attrs.get("status_code", 500),
            content=jsonable_encoder(err),
        )

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

    app.add_middleware(
        CORSMiddleware,
        allow_methods=["*"],
        allow_headers=["*"],
        allow_origins=["*"],
        allow_credentials=True,
    )
|
|
|
|
|
# Socket.IO server shared with module-level callbacks (diffuser_callback).
# Assigned by Api.__init__; None until an Api instance has been constructed.
global_sio: Optional[AsyncServer] = None
|
|
|
|
|
def diffuser_callback(
    pipe, step: int, timestep: int, callback_kwargs: Optional[Dict] = None
):
    """Per-step diffusion callback that broadcasts progress over Socket.IO.

    Emits a ``"diffusion_progress"`` event with payload ``{"step": step}`` on
    the module-level ``global_sio`` server. If no server has been registered
    yet (``global_sio`` is still ``None``), the event is skipped instead of
    raising ``AttributeError``.

    Args:
        pipe: The object invoking the callback (unused here).
        step: Index of the step that just finished.
        timestep: Timestep value for this step (unused here).
        callback_kwargs: Extra keyword data from the caller (unused here).
            Defaults to ``None`` rather than a shared mutable ``{}``.

    Returns:
        Always an empty dict.
    """
    if global_sio is not None:
        # asyncio.run spins up a fresh event loop for this single emit. It
        # raises RuntimeError if the calling thread already has a running
        # loop. This assumes the callback runs on a synchronous worker
        # thread — TODO confirm.
        asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
|
|
|
|
|
class Api:
    """HTTP + Socket.IO front end for the inpainting server.

    Construction does the following:
      * installs error handling and CORS on ``app`` (via ``api_middleware``);
      * builds the optional file manager, the plugins and the model manager
        from ``config``;
      * registers the ``/api/v1/*`` routes;
      * serves ``WEB_APP_DIR`` as static files at ``/``;
      * creates the Socket.IO server used for progress events.
    """

    def __init__(self, app: FastAPI, config: ApiConfig):
        self.app = app
        self.config = config
        self.router = APIRouter()
        # NOTE(review): this lock is not acquired anywhere in this class.
        self.queue_lock = threading.Lock()
        api_middleware(self.app)

        self.file_manager = self._build_file_manager()
        self.plugins = self._build_plugins()
        self.model_manager = self._build_model_manager()

        # API routes are registered before the static-files mount at "/"
        # below, so they are matched first.
        self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
        self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
        self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
        self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
        self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
        self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
        self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
        self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
        self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
        self.add_api_route("/api/v1/adjust_mask", self.api_adjust_mask, methods=["POST"])
        self.add_api_route("/api/v1/save_image", self.api_save_image, methods=["POST"])
        self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")

        # Publish the Socket.IO server module-wide so diffuser_callback, which
        # has no access to this instance, can emit progress events.
        global global_sio
        self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
        # NOTE(review): this mount is added after the "/" mount above and may
        # be shadowed by it. launch() serves combined_asgi_app directly, which
        # is what actually routes Socket.IO traffic — confirm before relying
        # on "/ws".
        self.app.mount("/ws", self.combined_asgi_app)
        global_sio = self.sio

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register ``endpoint`` at ``path`` directly on the FastAPI app."""
        return self.app.add_api_route(path, endpoint, **kwargs)

    def api_save_image(self, file: UploadFile):
        """Write an uploaded image into ``config.output_dir``.

        Only the final path component of the client-supplied filename is
        kept. Names such as ``../../x.png`` or ``/etc/x`` therefore cannot
        write outside the output directory.

        Raises:
            HTTPException: 400 if no output directory is configured or the
                upload carries no usable filename.
        """
        if self.config.output_dir is None:
            raise HTTPException(
                status_code=400, detail="Output directory is not configured"
            )
        # file.filename is untrusted client input; drop any directory parts.
        filename = Path(file.filename or "").name
        if filename in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Invalid filename")
        origin_image_bytes = file.file.read()
        with open(self.config.output_dir / filename, "wb") as fw:
            fw.write(origin_image_bytes)

    def api_current_model(self) -> ModelInfo:
        """Return the model manager's current model info."""
        return self.model_manager.current_model

    def api_switch_model(self, req: SwitchModelRequest) -> ModelInfo:
        """Switch to ``req.name`` unless it is already active.

        Returns the resulting current model info.
        """
        if req.name != self.model_manager.name:
            self.model_manager.switch(req.name)
        return self.model_manager.current_model

    def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
        """Switch the model used by a loaded plugin and record it in config.

        Unknown plugin names are ignored. ``torch_gc()`` runs in either case.
        """
        if req.plugin_name in self.plugins:
            self.plugins[req.plugin_name].switch_model(req.model_name)
            # Keep config in sync so api_server_config reports the new model.
            if req.plugin_name == RemoveBG.name:
                self.config.remove_bg_model = req.model_name
            if req.plugin_name == RealESRGANUpscaler.name:
                self.config.realesrgan_model = req.model_name
            if req.plugin_name == InteractiveSeg.name:
                self.config.interactive_seg_model = req.model_name
        torch_gc()

    def api_server_config(self) -> ServerConfigResponse:
        """Describe the server's plugins, models and feature flags."""
        plugins = [
            PluginInfo(
                name=it.name,
                support_gen_image=it.support_gen_image,
                support_gen_mask=it.support_gen_mask,
            )
            for it in self.plugins.values()
        ]

        return ServerConfigResponse(
            plugins=plugins,
            modelInfos=self.model_manager.scan_models(),
            removeBGModel=self.config.remove_bg_model,
            removeBGModels=RemoveBGModel.values(),
            realesrganModel=self.config.realesrgan_model,
            realesrganModels=RealESRGANModel.values(),
            interactiveSegModel=self.config.interactive_seg_model,
            interactiveSegModels=InteractiveSegModel.values(),
            enableFileManager=self.file_manager is not None,
            enableAutoSaving=self.config.output_dir is not None,
            enableControlnet=self.model_manager.enable_controlnet,
            controlnetMethod=self.model_manager.controlnet_method,
            disableModelSwitch=False,
            isDesktop=False,
            samplers=self.api_samplers(),
        )

    def api_input_image(self) -> FileResponse:
        """Serve ``config.input`` when it is a file.

        Raises:
            HTTPException: 404 if ``config.input`` is not set or is not a file.
        """
        if self.config.input and self.config.input.is_file():
            return FileResponse(self.config.input)
        raise HTTPException(status_code=404, detail="Input image not found")

    def api_geninfo(self, file: UploadFile) -> GenInfoResponse:
        """Extract prompt / negative prompt from an image's "parameters" info.

        The text before ``"Negative prompt: "`` is the prompt. The first line
        after it is the negative prompt, or empty when the marker is absent.
        """
        _, _, info = load_img(file.file.read(), return_info=True)
        parts = info.get("parameters", "").split("Negative prompt: ")
        prompt = parts[0].strip()
        negative_prompt = ""
        if len(parts) > 1:
            negative_prompt = parts[1].split("\n")[0].strip()
        return GenInfoResponse(prompt=prompt, negative_prompt=negative_prompt)

    def api_inpaint(self, req: InpaintRequest):
        """Run the current model on ``req.image`` / ``req.mask``.

        Processing steps:
          * binarize the mask at 127;
          * run the model and convert its output with ``COLOR_BGR2RGB``;
          * re-attach the input alpha channel and encode as PNG;
          * emit ``"diffusion_finish"`` over Socket.IO.

        The response carries the seed in an ``X-Seed`` header.

        Raises:
            HTTPException: 400 if image and mask sizes differ.
        """
        image, alpha_channel, infos = decode_base64_to_image(req.image)
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)

        mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        if image.shape[:2] != mask.shape[:2]:
            raise HTTPException(
                400,
                detail=f"Image size({image.shape[:2]}) and mask size({mask.shape[:2]}) not match.",
            )

        if req.paint_by_example_example_image:
            # NOTE(review): the decoded result is not used below; the call
            # only validates the payload. Confirm whether the model reads
            # the image from req instead.
            paint_by_example_image, _, _ = decode_base64_to_image(
                req.paint_by_example_example_image
            )

        start = time.time()
        rgb_np_img = self.model_manager(image, mask, req)
        logger.info(f"process time: {(time.time() - start) * 1000:.2f}ms")
        torch_gc()

        rgb_np_img = cv2.cvtColor(rgb_np_img.astype(np.uint8), cv2.COLOR_BGR2RGB)
        rgb_res = concat_alpha_channel(rgb_np_img, alpha_channel)

        ext = "png"
        res_img_bytes = pil_to_bytes(
            Image.fromarray(rgb_res),
            ext=ext,
            quality=self.config.quality,
            infos=infos,
        )

        asyncio.run(self.sio.emit("diffusion_finish"))

        return Response(
            content=res_img_bytes,
            media_type=f"image/{ext}",
            headers={"X-Seed": str(req.sd_seed)},
        )

    def api_run_plugin_gen_image(self, req: RunPluginRequest):
        """Run an image-producing plugin and return the result as PNG.

        A 4-channel plugin output is returned as-is. Otherwise the output is
        converted with ``COLOR_BGR2RGB`` and the input's alpha is re-attached.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot gen images.
        """
        ext = "png"
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_image:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output image"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_rgba_np_img = self.plugins[req.name].gen_image(rgb_np_img, req)
        torch_gc()

        if bgr_or_rgba_np_img.shape[2] == 4:
            rgba_np_img = bgr_or_rgba_np_img
        else:
            rgba_np_img = cv2.cvtColor(bgr_or_rgba_np_img, cv2.COLOR_BGR2RGB)
            rgba_np_img = concat_alpha_channel(rgba_np_img, alpha_channel)

        return Response(
            content=pil_to_bytes(
                Image.fromarray(rgba_np_img),
                ext=ext,
                quality=self.config.quality,
                infos=infos,
            ),
            media_type=f"image/{ext}",
        )

    def api_run_plugin_gen_mask(self, req: RunPluginRequest):
        """Run a mask-producing plugin and return the frontend mask as PNG.

        Raises:
            HTTPException: 422 if the plugin is unknown or cannot gen masks.
        """
        if req.name not in self.plugins:
            raise HTTPException(status_code=422, detail="Plugin not found")
        if not self.plugins[req.name].support_gen_mask:
            raise HTTPException(
                status_code=422, detail="Plugin does not support output mask"
            )
        rgb_np_img, alpha_channel, infos = decode_base64_to_image(req.image)
        bgr_or_gray_mask = self.plugins[req.name].gen_mask(rgb_np_img, req)
        torch_gc()
        res_mask = gen_frontend_mask(bgr_or_gray_mask)
        return Response(
            content=numpy_to_bytes(res_mask, "png"),
            media_type="image/png",
        )

    def api_samplers(self) -> List[str]:
        """List the values of all ``SDSampler`` members."""
        return [member.value for member in SDSampler.__members__.values()]

    def api_adjust_mask(self, req: AdjustMaskRequest):
        """Apply ``adjust_mask`` with the requested kernel size/operation."""
        mask, _, _ = decode_base64_to_image(req.mask, gray=True)
        mask = adjust_mask(mask, req.kernel_size, req.operate)
        return Response(content=numpy_to_bytes(mask, "png"), media_type="image/png")

    def launch(self):
        """Serve the combined Socket.IO + FastAPI ASGI app with uvicorn."""
        # self.router has no routes of its own (add_api_route registers on
        # self.app directly); included for completeness.
        self.app.include_router(self.router)
        uvicorn.run(
            self.combined_asgi_app,
            host=self.config.host,
            port=self.config.port,
            timeout_keep_alive=999999999,
        )

    def _build_file_manager(self) -> Optional[FileManager]:
        """Create a FileManager when ``config.input`` is a directory, else None."""
        if self.config.input and self.config.input.is_dir():
            logger.info(
                f"Input is directory, initialize file manager {self.config.input}"
            )
            return FileManager(
                app=self.app,
                input_dir=self.config.input,
                output_dir=self.config.output_dir,
            )
        return None

    def _build_plugins(self) -> Dict[str, BasePlugin]:
        """Build the enabled plugins from config (positional order matters)."""
        return build_plugins(
            self.config.enable_interactive_seg,
            self.config.interactive_seg_model,
            self.config.interactive_seg_device,
            self.config.enable_remove_bg,
            self.config.remove_bg_model,
            self.config.enable_anime_seg,
            self.config.enable_realesrgan,
            self.config.realesrgan_device,
            self.config.realesrgan_model,
            self.config.enable_gfpgan,
            self.config.gfpgan_device,
            self.config.enable_restoreformer,
            self.config.restoreformer_device,
            self.config.no_half,
        )

    def _build_model_manager(self):
        """Create the ModelManager, wiring in ``diffuser_callback`` for progress."""
        return ModelManager(
            name=self.config.model,
            device=torch.device(self.config.device),
            no_half=self.config.no_half,
            low_mem=self.config.low_mem,
            disable_nsfw=self.config.disable_nsfw_checker,
            sd_cpu_textencoder=self.config.cpu_textencoder,
            local_files_only=self.config.local_files_only,
            cpu_offload=self.config.cpu_offload,
            callback=diffuser_callback,
        )
|
|