|
from typing import List |
|
|
|
import numpy as np |
|
from fastapi import FastAPI, Body |
|
from fastapi.exceptions import HTTPException |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
from modules.api import api |
|
from .global_state import ( |
|
get_all_preprocessor_names, |
|
get_all_controlnet_names, |
|
get_preprocessor, |
|
) |
|
from .utils import judge_image_type |
|
from .logging import logger |
|
|
|
|
|
def encode_to_base64(image):
    """Turn a detection result into a string suitable for the JSON response.

    Args:
        image: The value returned by a preprocessor. Handled cases are a
            ``str``, a ``PIL.Image.Image`` and a ``numpy.ndarray``.

    Returns:
        * ``image`` itself when it is already a string;
        * the literal message ``"Detect result is not image"`` when
          ``judge_image_type`` rejects the value;
        * the base64 encoding produced by ``api.encode_pil_to_base64`` for
          PIL images, or by ``encode_np_to_base64`` for NumPy arrays;
        * an empty string, after logging a warning, for anything else.
    """
    # Strings are passed through untouched.
    if isinstance(image, str):
        return image

    # NOTE(review): this check runs before the PIL / ndarray branches, so
    # those branches are only reached for values judge_image_type accepts.
    # Confirm that ordering is intended against judge_image_type's contract.
    if not judge_image_type(image):
        return "Detect result is not image"

    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)

    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)

    # `warn` is the deprecated alias of `warning` on logging.Logger.
    logger.warning("Unable to encode image.")
    return ""
|
|
|
|
|
def encode_np_to_base64(image):
    """Encode an array as base64 by converting it to a PIL image first.

    The array is wrapped with ``Image.fromarray`` and the resulting PIL image
    is handed to ``api.encode_pil_to_base64``, whose return value is returned
    unchanged.
    """
    return api.encode_pil_to_base64(Image.fromarray(image))
|
|
|
|
|
def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ControlNet HTTP endpoints on ``app``.

    Three routes are added:

    * ``GET /controlnet/model_list``: result of ``get_all_controlnet_names()``
      under the ``"model_list"`` key.
    * ``GET /controlnet/module_list``: result of
      ``get_all_preprocessor_names()`` under the ``"module_list"`` key.
    * ``POST /controlnet/detect``: runs the requested preprocessor on every
      submitted base64 image and returns the encoded results.

    Args:
        _: Gradio ``Blocks`` instance; not used by this function.
        app: FastAPI application the routes are attached to.
    """

    # Defined once at registration time instead of once per processed image.
    class JsonAcceptor:
        """Callback sink that stores the most recent dict passed to ``accept``."""

        def __init__(self) -> None:
            self.value = None

        def accept(self, json_dict: dict) -> None:
            self.value = json_dict

    # Lists the ControlNet model names known at request time.
    @app.get("/controlnet/model_list")
    async def model_list():
        up_to_date_model_list = get_all_controlnet_names()
        logger.debug(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    # Lists the preprocessor (module) names.
    @app.get("/controlnet/module_list")
    async def module_list():
        preprocessor_names = get_all_preprocessor_names()
        logger.debug(preprocessor_names)
        return {"module_list": preprocessor_names}

    # Runs a preprocessor over each input image.
    #
    # Responds with {"images": [...], "info": "Success"}; a "poses" list is
    # added when the module name contains "openpose" and pose data was
    # collected. Responds 422 when the module is unknown or no image is sent,
    # and 500 when an openpose module did not report pose data.
    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title="Controlnet Module"),
        controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
        controlnet_processor_res: int = Body(
            512, title="Controlnet Processor Resolution"
        ),
        controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
        controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
    ):
        processor_module = get_preprocessor(controlnet_module)
        if processor_module is None:
            raise HTTPException(status_code=422, detail="Module not available")

        if not controlnet_input_images:
            raise HTTPException(status_code=422, detail="No image selected")

        logger.debug(
            f"Detecting {len(controlnet_input_images)} images with the {controlnet_module} module."
        )

        # Loop-invariant: decided by the module name only.
        collect_poses = "openpose" in controlnet_module

        results = []
        poses = []

        for input_image in controlnet_input_images:
            img = np.array(api.decode_base64_to_image(input_image)).astype("uint8")

            # Fresh acceptor per image so a stale pose is never reused.
            json_acceptor = JsonAcceptor()

            results.append(
                processor_module(
                    img,
                    resolution=controlnet_processor_res,
                    slider_1=controlnet_threshold_a,
                    slider_2=controlnet_threshold_b,
                    json_pose_callback=json_acceptor.accept,
                )
            )

            if collect_poses:
                # Explicit check instead of `assert`: asserts are removed
                # under `python -O`, which would let None slip into `poses`.
                if json_acceptor.value is None:
                    raise HTTPException(
                        status_code=500,
                        detail="Preprocessor did not return pose data",
                    )
                poses.append(json_acceptor.value)

        results64 = [encode_to_base64(result) for result in results]
        res = {"images": results64, "info": "Success"}
        if poses:
            res["poses"] = poses

        return res
|
|
|
|