RemBG / rembg /commands /s_command.py
KenjieDec's picture
3faa99b
raw
history blame
7.57 kB
import json
from typing import Annotated, Optional, Tuple, cast

import aiohttp
import click
import uvicorn
from asyncer import asyncify
from fastapi import Depends, FastAPI, File, Form, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import Response

from .._version import get_versions
from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names
from ..sessions.base import BaseSession
@click.command(
    name="s",
    help="for a http server",
)
@click.option(
    "-p",
    "--port",
    default=5000,
    type=int,
    show_default=True,
    help="port",
)
@click.option(
    "-l",
    "--log_level",
    default="info",
    type=str,
    show_default=True,
    help="log level",
)
@click.option(
    "-t",
    "--threads",
    default=None,
    type=int,
    show_default=True,
    help="number of worker threads",
)
def s_command(port: int, log_level: str, threads: Optional[int]) -> None:
    """Start an HTTP server that removes image backgrounds.

    Exposes ``GET /`` (image fetched from a URL) and ``POST /`` (image
    uploaded in the request body); both return the processed image with
    media type ``image/png``.

    Args:
        port: TCP port uvicorn listens on (bound to all interfaces, 0.0.0.0).
        log_level: uvicorn log level name.
        threads: when given, the maximum number of worker threads in anyio's
            default thread limiter; ``None`` leaves that limiter unchanged.
    """
    # Model name -> loaded session, shared across requests so each model is
    # created at most once per server process.
    sessions: dict[str, BaseSession] = {}

    # Anchored so that only exact model names validate; without ^...$ the
    # alternation also accepts values that merely start with (or contain) a
    # known model name.
    model_pattern = r"^(" + "|".join(sessions_names) + ")$"

    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]

    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "danielgatis@gmail.com",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    def _parse_bgc(bgc: Optional[str]) -> Optional[Tuple[int, int, int, int]]:
        """Parse a comma-separated ``bgc`` value (e.g. ``"255,0,0,255"``).

        Returns ``None`` when *bgc* is empty or missing.

        Raises:
            HTTPException: 422 when a component is not an integer, so a bad
                client value is reported as a client error rather than a 500.
        """
        if not bgc:
            return None
        try:
            return cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
        except ValueError as err:
            raise HTTPException(
                status_code=422,
                detail="bgc must be a comma-separated list of integers",
            ) from err

    def _get_session(model: str) -> BaseSession:
        """Return the cached session for *model*, creating it on first use."""
        session = sessions.get(model)
        if session is None:
            # setdefault keeps the first stored session if two requests create
            # the same model concurrently (handlers run via asyncify).
            session = sessions.setdefault(model, new_session(model))
        return session

    class CommonQueryParams:
        """Processing options for ``GET /``, all read from the query string."""

        def __init__(
            self,
            # The validation pattern lives in the single default Query:
            # FastAPI rejects a param declared both in Annotated and as default.
            model: str = Query(
                default=...,
                regex=model_pattern,
                description="Model to use when processing image",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = _parse_bgc(bgc)

    class CommonQueryPostParams:
        """Processing options for ``POST /``.

        Most options are multipart form fields; ``bgc`` and ``extras`` are
        read from the query string, as in the GET variant.
        """

        def __init__(
            self,
            model: str = Form(
                default=...,
                regex=model_pattern,
                description="Model to use when processing image",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = _parse_bgc(bgc)

    def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
        """Run ``remove`` on *content* with the request's options.

        ``commons.extras`` may hold a JSON object whose keys are forwarded as
        extra keyword arguments to ``remove``. Malformed JSON, or JSON that
        is not an object, is ignored (best effort). The explicit options are
        applied last, so they win over same-named keys in ``extras``.
        """
        kwargs: dict = {}
        if commons.extras:
            try:
                extras = json.loads(commons.extras)
            except json.JSONDecodeError:
                extras = None
            if isinstance(extras, dict):
                # NOTE(review): arbitrary client-supplied keys reach remove();
                # consider an allow-list if the server is publicly exposed.
                kwargs.update(extras)
        # Applied after extras so a colliding key is overridden instead of
        # raising a duplicate-keyword TypeError in the remove() call.
        kwargs.update(
            session=_get_session(commons.model),
            alpha_matting=commons.a,
            alpha_matting_foreground_threshold=commons.af,
            alpha_matting_background_threshold=commons.ab,
            alpha_matting_erode_size=commons.ae,
            only_mask=commons.om,
            post_process_mask=commons.ppm,
            bgcolor=commons.bgc,
        )
        return Response(remove(content, **kwargs), media_type="image/png")

    @app.on_event("startup")
    def startup():
        """Apply ``--threads`` to anyio's default worker-thread limiter."""
        if threads is not None:
            from anyio import to_thread

            # Public anyio API for resizing the default thread pool limit
            # (instead of overwriting its private RunVar).
            to_thread.current_default_thread_limiter().total_tokens = threads

    @app.get(
        path="/",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving an URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        # NOTE(review): the server fetches any caller-supplied URL, including
        # internal addresses (SSRF); restrict it if exposed to untrusted users.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    # An error page is not an image; report it to the caller.
                    response.raise_for_status()
                    file = await response.read()
        except aiohttp.ClientError as err:
            raise HTTPException(
                status_code=400, detail="Could not fetch image from URL"
            ) from err
        # Processing happens after the HTTP connection has been released.
        return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        return await asyncify(im_without_bg)(file, commons)  # type: ignore

    uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)