#!/usr/bin/env python from __future__ import annotations import pathlib import sys import cv2 import gradio as gr import numpy as np import spaces import torch import torch.nn as nn from huggingface_hub import hf_hub_download current_dir = pathlib.Path(__file__).parent submodule_dir = current_dir / "MangaLineExtraction_PyTorch" sys.path.insert(0, submodule_dir.as_posix()) from model_torch import res_skip DESCRIPTION = "# [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)" def load_model(device: torch.device) -> nn.Module: ckpt_path = hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth") state_dict = torch.load(ckpt_path) model = res_skip() model.load_state_dict(state_dict) model.to(device) model.eval() return model MAX_SIZE = 1000 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = load_model(device) @spaces.GPU @torch.inference_mode() def predict(image: np.ndarray) -> np.ndarray: gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if max(gray.shape) > MAX_SIZE: scale = MAX_SIZE / max(gray.shape) gray = cv2.resize(gray, None, fx=scale, fy=scale) h, w = gray.shape size = 16 new_w = (w + size - 1) // size * size new_h = (h + size - 1) // size * size patch = np.ones((1, 1, new_h, new_w), dtype=np.float32) patch[0, 0, :h, :w] = gray tensor = torch.from_numpy(patch).to(device) out = model(tensor) res = out.cpu().numpy()[0, 0, :h, :w] res = np.clip(res, 0, 255).astype(np.uint8) return res with gr.Blocks(css="style.css") as demo: gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(): input_image = gr.Image(label="Input", type="numpy") run_button = gr.Button() with gr.Column(): result = gr.Image(label="Result", elem_id="result") run_button.click( fn=predict, inputs=input_image, outputs=result, ) if __name__ == "__main__": demo.queue(max_size=20).launch()