Spaces:
Runtime error
Runtime error
import base64 | |
import io | |
import json | |
from pathlib import Path | |
from typing import Dict, Optional | |
import cv2 | |
import psutil | |
from PIL import Image | |
from loguru import logger | |
from rich.console import Console | |
from rich.progress import ( | |
Progress, | |
SpinnerColumn, | |
TimeElapsedColumn, | |
MofNCompleteColumn, | |
TextColumn, | |
BarColumn, | |
TaskProgressColumn, | |
) | |
from iopaint.helper import pil_to_bytes_single | |
from iopaint.model.utils import torch_gc | |
from iopaint.model_manager import ModelManager | |
from iopaint.schema import InpaintRequest | |
import numpy as np | |
def glob_images(path: Path) -> Dict[str, Path]:
    """Map file stems to image paths found at *path*.

    Args:
        path: A single file or a directory.

    Returns:
        * If *path* is a file: ``{path.stem: path}``. The suffix is not
          checked in this case.
        * If *path* is a directory: one entry per direct child whose suffix
          is ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive). The search
          is not recursive, and when two children share a stem the one
          yielded last by ``Path.glob`` wins.
        * Otherwise (e.g. the path does not exist): an empty dict, so that
          callers can always take ``len()`` of, or iterate over, the result.
    """
    if path.is_file():
        return {path.stem: path}
    if not path.is_dir():
        return {}

    image_suffixes = {".png", ".jpg", ".jpeg"}
    return {
        entry.stem: entry
        for entry in path.glob("*.*")
        if entry.suffix.lower() in image_suffixes
    }
# def batch_inpaint( | |
# model: str, | |
# device, | |
# image: Path, | |
# mask: Path, | |
# config: Optional[Path] = None, | |
# concat: bool = False, | |
# ): | |
# if config is None: | |
# inpaint_request = InpaintRequest() | |
# else: | |
# with open(config, "r", encoding="utf-8") as f: | |
# inpaint_request = InpaintRequest(**json.load(f)) | |
# | |
# model_manager = ModelManager(name=model, device=device) | |
# | |
# img = cv2.imread(str(image)) | |
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
# | |
# mask_img = cv2.imread(str(mask), cv2.IMREAD_GRAYSCALE) | |
# | |
# if mask_img.shape[:2] != img.shape[:2]: | |
# mask_img = cv2.resize( | |
# mask_img, | |
# (img.shape[1], img.shape[0]), | |
# interpolation=cv2.INTER_NEAREST, | |
# ) | |
# | |
# mask_img[mask_img >= 127] = 255 | |
# mask_img[mask_img < 127] = 0 | |
# | |
# # bgr | |
# inpaint_result = model_manager(img, mask_img, inpaint_request) | |
# inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB) | |
# | |
# if concat: | |
# mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB) | |
# inpaint_result = cv2.hconcat([img, mask_img, inpaint_result]) | |
# | |
# # Convert the NumPy array to PIL Image | |
# pil_image = Image.fromarray(inpaint_result) | |
# | |
# # Encode the PIL Image as base64 string | |
# with io.BytesIO() as output_buffer: | |
# pil_image.save(output_buffer, format='PNG') | |
# base64_image = base64.b64encode(output_buffer.getvalue()).decode('utf-8') | |
# | |
# return base64_image | |
def _decode_base64_image(encoded: str, flags: int, label: str) -> np.ndarray:
    """Decode base64 text holding an encoded image file into a pixel array.

    Raises:
        ValueError: If the payload is empty or OpenCV cannot decode it.
    """
    raw = base64.b64decode(encoded)
    if not raw:
        raise ValueError(f"{label} is empty")
    decoded = cv2.imdecode(np.frombuffer(raw, np.uint8), flags)
    # cv2.imdecode signals failure by returning None instead of raising.
    if decoded is None:
        raise ValueError(f"{label} does not contain a decodable image")
    return decoded


def batch_inpaint(
    model: str,
    device,
    input_base64: str,
    mask_base64: str,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint a base64-encoded image and return the result as base64 PNG.

    Args:
        model: Model name handed to ``ModelManager``.
        device: Device handed to ``ModelManager``.
        input_base64: Base64 text of an encoded image file.
        mask_base64: Base64 text of an encoded mask image. It is decoded as
            grayscale, resized (nearest neighbour) to the input size when the
            sizes differ, and binarised: values >= 127 become 255, the rest 0.
        config_base64: Optional base64 text of a JSON object whose members
            are passed as keyword arguments to ``InpaintRequest``. A default
            ``InpaintRequest()`` is used when ``None``.
        concat: When true, the returned picture is the input, the mask and
            the result placed side by side.

    Returns:
        The PNG-encoded picture as a base64 ``str``.

    Raises:
        ValueError: If the input or the mask cannot be decoded as an image.
    """
    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    # NOTE(review): a ModelManager is constructed on every call; if this is
    # invoked repeatedly, the caller may want to hoist/cache it.
    model_manager = ModelManager(name=model, device=device)

    # OpenCV decodes colour images as BGR; work in RGB from here on so the
    # model input, the concat panels and the PIL output all agree.
    input_bgr = _decode_base64_image(input_base64, cv2.IMREAD_COLOR, "input_base64")
    input_image = cv2.cvtColor(input_bgr, cv2.COLOR_BGR2RGB)

    mask_image = _decode_base64_image(
        mask_base64, cv2.IMREAD_GRAYSCALE, "mask_base64"
    )
    if mask_image.shape[:2] != input_image.shape[:2]:
        mask_image = cv2.resize(
            mask_image,
            (input_image.shape[1], input_image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    mask_image = np.where(mask_image >= 127, 255, 0).astype(np.uint8)

    # Run inpainting.
    # NOTE(review): assumes the model manager takes an RGB image and returns
    # a BGR one (the convention iopaint's own callers follow) — confirm.
    inpaint_result = model_manager(input_image, mask_image, inpaint_request)
    # Cast first: cvtColor, hconcat and Image.fromarray all need uint8 here.
    inpaint_result = cv2.cvtColor(
        inpaint_result.astype(np.uint8), cv2.COLOR_BGR2RGB
    )

    if concat:
        mask_rgb = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_image, mask_rgb, inpaint_result])

    # Encode the RGB array as PNG, then as base64 text.
    pil_image = Image.fromarray(inpaint_result)
    with io.BytesIO() as output_buffer:
        pil_image.save(output_buffer, format='PNG')
        base64_image = base64.b64encode(output_buffer.getvalue()).decode('utf-8')
    return base64_image
def batch_inpaint_cv2(
    model: str,
    device,
    input_base: np.ndarray,
    mask_base: np.ndarray,
    config_base64: Optional[str] = None,
    concat: bool = False,
):
    """Inpaint an already-decoded image array and return the result array.

    Unlike ``batch_inpaint``, the image and mask are passed as arrays (they
    are used via ``.shape`` and array indexing), not as base64 text, and the
    result is returned as an array rather than an encoded PNG.

    Args:
        model: Model name handed to ``ModelManager``.
        device: Device handed to ``ModelManager``.
        input_base: Image array; it is passed to the model unchanged.
        mask_base: Mask array. It is resized (nearest neighbour) to the
            input's height/width when they differ, and binarised: values
            >= 127 become 255, the rest 0. The caller's array is not
            modified.
        config_base64: Optional base64 text of a JSON object whose members
            are passed as keyword arguments to ``InpaintRequest``. A default
            ``InpaintRequest()`` is used when ``None``.
        concat: When true, the returned array is the input, the mask and the
            result concatenated horizontally.

    Returns:
        The array returned by the model manager, or the horizontal
        concatenation described above when ``concat`` is true.

    Raises:
        ValueError: If ``input_base`` or ``mask_base`` is ``None`` (for
            example, the result of a failed ``cv2.imread``).
    """
    if input_base is None:
        raise ValueError("input_base is None; expected an image array")
    if mask_base is None:
        raise ValueError("mask_base is None; expected a mask array")

    if config_base64 is None:
        inpaint_request = InpaintRequest()
    else:
        config_json = base64.b64decode(config_base64)
        inpaint_request = InpaintRequest(**json.loads(config_json))

    model_manager = ModelManager(name=model, device=device)

    input_image = input_base
    if mask_base.shape[:2] != input_image.shape[:2]:
        # cv2.resize allocates a new array, so the caller's mask is safe.
        mask_image = cv2.resize(
            mask_base,
            (input_image.shape[1], input_image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    else:
        # Copy so the in-place thresholding below does not alter the
        # caller's array.
        mask_image = mask_base.copy()
    mask_image[mask_image >= 127] = 255
    mask_image[mask_image < 127] = 0

    # Run inpainting.
    # NOTE(review): no colour conversion is done here; the channel order of
    # input_base and of the returned array is whatever the caller and the
    # model manager use — confirm they agree, especially when concat=True.
    inpaint_result = model_manager(input_image, mask_image, inpaint_request)

    if concat:
        mask_image = cv2.cvtColor(mask_image, cv2.COLOR_GRAY2RGB)
        inpaint_result = cv2.hconcat([input_image, mask_image, inpaint_result])
    return inpaint_result