# top001's picture
# Update app.py
# 8795ec2 verified
import imghdr
import io
from typing import Optional
from urllib.parse import quote

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import Response
from PIL import Image
# Format names accepted by the /remove-bg endpoint; compared against the
# lower-cased name returned by is_valid_image().
SUPPORTED_FORMATS = {'jpg', 'jpeg', 'png', 'bmp', 'webp', 'tiff'}


def is_valid_image(file_content: bytes) -> Optional[str]:
    """Sniff the image format of a raw byte payload.

    Args:
        file_content: Raw bytes of the uploaded file.

    Returns:
        The lower-cased format name reported by ``imghdr.what`` (for example
        ``"png"``), or ``None`` when the payload is not recognised.
    """
    # NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13.
    detected = imghdr.what(None, file_content)
    return detected.lower() if detected is not None else None
def process_image_bytes(image_bytes: bytes) -> np.ndarray:
    """Decode an encoded image into an ``H x W x 3`` RGB array.

    Args:
        image_bytes: Raw bytes of an encoded image file.

    Returns:
        The decoded pixels as a NumPy array in RGB channel order.

    Raises:
        HTTPException: With status 400 when the bytes cannot be decoded or
            converted.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # get_mask() writes the image into a 3-channel buffer, so every
        # non-RGB mode (RGBA, grayscale, palette, CMYK, ...) has to be
        # normalised here, not just RGBA.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return np.array(image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error: {str(e)}") from e
def get_mask(img, s=1024):
    """Run the segmentation model on ``img`` and return a mask at its size.

    The image is scaled so that its longer side equals ``s``, centred on a
    zero-filled ``s`` x ``s`` canvas, and fed to the module-level
    ``rmbg_model`` session under the input name ``'img'``. The prediction is
    cropped back to the un-padded region and resized to the original
    resolution.

    Note that ``rmbg_model`` is only created in this file's ``__main__``
    block, so this function requires the module to be run as a script.

    Args:
        img: ``H x W x 3`` array. It is divided by 255 before inference, so
            presumably holds 8-bit pixel values -- TODO confirm with callers.
        s: Side length of the square model input.

    Returns:
        Float array of shape ``(H, W, 1)``. Callers use it as a per-pixel
        foreground weight (presumably in ``[0, 1]`` -- confirm against the
        model).
    """
    img = (img / 255).astype(np.float32)
    h0, w0 = img.shape[:-1]

    # Scale the longer side to `s`, keeping the aspect ratio. The shorter
    # side is clamped to at least one pixel: for aspect ratios above s:1 the
    # truncating int() would otherwise yield 0 and cv2.resize would be asked
    # for an empty target size.
    if h0 > w0:
        h, w = s, max(1, int(s * w0 / h0))
    else:
        h, w = max(1, int(s * h0 / w0)), s

    # Offsets that centre the resized image on the square canvas.
    top, left = (s - h) // 2, (s - w) // 2

    img_input = np.zeros([s, s, 3], dtype=np.float32)
    img_input[top:top + h, left:left + w] = cv2.resize(img, (w, h))
    # HWC -> CHW, then add a leading batch axis.
    img_input = np.transpose(img_input, (2, 0, 1))[np.newaxis, :]

    mask = rmbg_model.run(None, {'img': img_input})[0][0]
    # CHW -> HWC, drop the padding, and return to the source resolution.
    mask = np.transpose(mask, (1, 2, 0))
    mask = mask[top:top + h, left:left + w]
    mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
    return mask
def rmbg_fn(img):
    """Remove the background of ``img`` using the mask from ``get_mask``.

    Args:
        img: Image array accepted by ``get_mask``.

    Returns:
        A ``(mask, result)`` tuple of ``uint8`` arrays: the mask replicated to
        three channels, and a four-channel image whose first three channels
        are the input blended onto white by the mask and whose last channel
        is the mask scaled to 0-255.
    """
    alpha = get_mask(img)
    # Weighted blend of the input with a constant 255 (white) background.
    blended = (alpha * img + 255 * (1 - alpha)).astype(np.uint8)
    alpha_u8 = (alpha * 255).astype(np.uint8)
    rgba = np.concatenate([blended, alpha_u8], axis=2, dtype=np.uint8)
    return alpha_u8.repeat(3, axis=2), rgba
# HTTP application that carries the /remove-bg endpoint; the Gradio UI is
# mounted onto it in the __main__ block.
app = FastAPI()

# Interactive demo page.
# NOTE(review): the indentation of this file was lost, so the Column/Row
# nesting below is the most plausible reading of the original -- confirm it
# against the intended layout.
with gr.Blocks() as gradio_app:
    gr.Markdown(
        "# Anime Remove Background\n\n"
        "![visitor badge](https://api.visitorbadge.io/api/visitors?path=skytnt.animeseg&countColor=%23263759&style=flat&labelStyle=lower)\n\n"
        "demo for [https://github.com/SkyTNT/anime-segmentation/](https://github.com/SkyTNT/anime-segmentation/)"
    )

    # Input side: the image, three bundled examples and the run button.
    with gr.Column():
        input_img = gr.Image(label="input image")
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Examples(examples=examples_data, inputs=[input_img])
        run_btn = gr.Button(variant="primary")

    # Output side: the mask and the RGBA result, both delivered as PNG.
    with gr.Row():
        output_mask = gr.Image(label="mask", format="png")
        output_img = gr.Image(label="result", image_mode="RGBA", format="png")

    run_btn.click(
        fn=rmbg_fn,
        inputs=[input_img],
        outputs=[output_mask, output_img],
    )
def _nobg_content_disposition(original_name: Optional[str]) -> str:
    """Build the Content-Disposition value for the ``<stem>_nobg.png`` download."""
    # Keep only the final path component; accept both separator styles.
    base_name = (original_name or "").replace("\\", "/").rsplit("/", 1)[-1]
    # Strip just the last extension so "my.photo.png" keeps "my.photo".
    stem = base_name.rsplit(".", 1)[0] or "image"
    download_name = f"{stem}_nobg.png"
    # Header values must be latin-1 encodable and must not contain quotes or
    # control characters, so send a sanitised ASCII fallback plus the exact
    # name percent-encoded as UTF-8 in the extended `filename*` parameter.
    ascii_name = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "-_.") else "_"
        for ch in download_name
    )
    return (
        f'attachment; filename="{ascii_name}"; '
        f"filename*=UTF-8''{quote(download_name, safe='')}"
    )


@app.post("/remove-bg")
async def remove_background(file: UploadFile = File(...)):
    """Remove the background of an uploaded image and return it as a PNG.

    Args:
        file: The uploaded image; its sniffed format must be one of
            ``SUPPORTED_FORMATS``.

    Returns:
        A ``image/png`` response containing the RGBA result, offered as an
        attachment named after the upload with a ``_nobg`` suffix.

    Raises:
        HTTPException: 400 when the upload is not a supported or decodable
            image, 500 when processing fails for any other reason.
    """
    contents = await file.read()
    # is_valid_image() returns None for unrecognised data, which is never a
    # member of SUPPORTED_FORMATS, so one membership test covers both cases.
    if is_valid_image(contents) not in SUPPORTED_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format: {', '.join(SUPPORTED_FORMATS)}"
        )
    try:
        img = process_image_bytes(contents)
        # Reuse the compositing done for the Gradio UI instead of repeating it.
        # NOTE: inference runs synchronously inside this coroutine.
        _, rgba = rmbg_fn(img)
        pil_image = Image.fromarray(rgba, 'RGBA')
        png_buffer = io.BytesIO()
        pil_image.save(png_buffer, format='PNG')
        return Response(
            content=png_buffer.getvalue(),
            media_type="image/png",
            headers={
                "Content-Disposition": _nobg_content_disposition(file.filename)
            }
        )
    except HTTPException:
        # Keep the status chosen further down (e.g. the 400 raised by
        # process_image_bytes) instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
if __name__ == "__main__":
    # Download the ONNX weights from the Hugging Face Hub and open the
    # inference session that get_mask() reads as the global `rmbg_model`.
    model_path = huggingface_hub.hf_hub_download(
        repo_id="skytnt/anime-seg",
        filename="isnetis.onnx",
    )
    # CUDA is listed first and CPU second.
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    rmbg_model = rt.InferenceSession(model_path, providers=providers)

    # Serve the Gradio UI at the root path of the FastAPI application.
    app = gr.mount_gradio_app(app=app, blocks=gradio_app, path="/")

    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)