File size: 2,066 Bytes
3b6fea8
 
 
 
32925c4
 
 
 
3b6fea8
32925c4
 
 
 
 
 
 
 
 
3b6fea8
32925c4
3b6fea8
ec4a1e7
3b6fea8
32925c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b6fea8
ec4a1e7
1af5cc8
 
 
ec4a1e7
6091f66
1af5cc8
6091f66
 
32925c4
6091f66
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python

from __future__ import annotations

import pathlib
import sys

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

# Put the bundled MangaLineExtraction_PyTorch directory (next to this file)
# at the front of sys.path so `model_torch` resolves from it. This must run
# before the `from model_torch import ...` line below.
current_dir = pathlib.Path(__file__).parent
submodule_dir = current_dir / "MangaLineExtraction_PyTorch"
sys.path.insert(0, submodule_dir.as_posix())

from model_torch import res_skip  # network class whose weights load_model() restores

# Markdown header rendered at the top of the demo page.
DESCRIPTION = "# [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)"


def load_model(device: torch.device) -> nn.Module:
    """Download the pretrained ``erika.pth`` weights and build the network.

    Args:
        device: Device that the model's parameters are moved to.

    Returns:
        A ``res_skip`` network with the downloaded state dict loaded, moved
        to ``device`` and switched to eval mode.
    """
    ckpt_path = hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth")
    # map_location remaps every stored tensor onto `device`. Without it, a
    # checkpoint whose tensors were saved on a GPU fails to load on the
    # CPU-only fallback path, because torch.load tries to restore them onto
    # their original CUDA device.
    # NOTE(review): torch.load unpickles the file. The source here is a fixed
    # Hub repo, but consider weights_only=True if the pinned torch supports it.
    state_dict = torch.load(ckpt_path, map_location=device)
    model = res_skip()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference only; switches layers such as dropout/batch-norm to eval behaviour
    return model


# Upper bound (pixels) on the longest image side handed to the model;
# predict() downscales anything larger.
MAX_SIZE = 1000
# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Loaded once at import time and shared by every request.
model = load_model(device)


@spaces.GPU
@torch.inference_mode()
def predict(image: np.ndarray | None) -> np.ndarray:
    """Run the line-extraction network on one image.

    Args:
        image: Pixel array from the ``gr.Image(type="numpy")`` input.
            ``(H, W, 3)`` RGB is the usual case. ``(H, W)``, ``(H, W, 1)``
            and ``(H, W, 4)`` RGBA arrays are also accepted.

    Returns:
        A ``uint8`` 2-D array holding the model output, clipped to
        ``[0, 255]``. Its size is the input size, or the downscaled size
        when the longest side exceeded ``MAX_SIZE``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the button is clicked with an empty input;
        # surface a readable message instead of an OpenCV assertion.
        raise gr.Error("Please upload an image first.")

    # Reduce the input to a single 2-D luminance channel.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    longest_side = max(gray.shape)
    if longest_side > MAX_SIZE:
        scale = MAX_SIZE / longest_side
        # INTER_AREA is OpenCV's recommended filter for shrinking and avoids
        # the aliasing that the default bilinear filter produces on fine lines.
        gray = cv2.resize(gray, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)

    h, w = gray.shape
    # Pad each side up to a multiple of 16. Presumably the network's
    # down/up-sampling path needs this alignment (TODO confirm in res_skip).
    multiple = 16
    new_w = (w + multiple - 1) // multiple * multiple
    new_h = (h + multiple - 1) // multiple * multiple

    # Single-item batch, one channel. Raw 0-255 intensities go in unscaled.
    # The padded margin is filled with 1.0 and cropped away after inference.
    patch = np.ones((1, 1, new_h, new_w), dtype=np.float32)
    patch[0, 0, :h, :w] = gray
    tensor = torch.from_numpy(patch).to(device)
    out = model(tensor)

    # Assumes the output keeps the (1, 1, new_h, new_w) layout of the input.
    # Drop the padding and clamp to the displayable 8-bit range.
    res = out.cpu().numpy()[0, 0, :h, :w]
    res = np.clip(res, 0, 255).astype(np.uint8)
    return res


# Gradio UI. A row holds two columns: the input image and Run button, then
# the result image. Layout follows the nesting and order of the `with`
# blocks, so statement order here is significant.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # type="numpy" makes Gradio hand predict() a numpy array.
            input_image = gr.Image(label="Input", type="numpy")
            run_button = gr.Button()
        with gr.Column():
            # elem_id gives this component a stable id, presumably targeted by style.css.
            result = gr.Image(label="Result", elem_id="result")
    # On click, call predict() with the input image and display its return
    # value in the result image.
    run_button.click(
        fn=predict,
        inputs=input_image,
        outputs=result,
    )

if __name__ == "__main__":
    # Enable the request queue with at most 20 pending events, then start
    # the web server. queue() configures `demo` in place.
    demo.queue(max_size=20)
    demo.launch()