hysts HF staff commited on
Commit
32925c4
1 Parent(s): 6091f66
Files changed (2) hide show
  1. app.py +54 -3
  2. model.py +0 -55
app.py CHANGED
@@ -2,13 +2,64 @@
2
 
3
  from __future__ import annotations
4
 
 
 
 
 
5
  import gradio as gr
 
 
 
 
 
 
 
 
 
6
 
7
- from model import Model
8
 
9
  DESCRIPTION = "# [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)"
10
 
11
- model = Model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  with gr.Blocks(css="style.css") as demo:
14
  gr.Markdown(DESCRIPTION)
@@ -19,7 +70,7 @@ with gr.Blocks(css="style.css") as demo:
19
  with gr.Column():
20
  result = gr.Image(label="Result", elem_id="result")
21
  run_button.click(
22
- fn=model.predict,
23
  inputs=input_image,
24
  outputs=result,
25
  )
 
2
 
3
  from __future__ import annotations
4
 
5
+ import pathlib
6
+ import sys
7
+
8
+ import cv2
9
  import gradio as gr
10
+ import numpy as np
11
+ import spaces
12
+ import torch
13
+ import torch.nn as nn
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ current_dir = pathlib.Path(__file__).parent
17
+ submodule_dir = current_dir / "MangaLineExtraction_PyTorch"
18
+ sys.path.insert(0, submodule_dir.as_posix())
19
 
20
+ from model_torch import res_skip
21
 
22
  DESCRIPTION = "# [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)"
23
 
24
+
25
+ def load_model(device: torch.device) -> nn.Module:
26
+ ckpt_path = hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth")
27
+ state_dict = torch.load(ckpt_path)
28
+ model = res_skip()
29
+ model.load_state_dict(state_dict)
30
+ model.to(device)
31
+ model.eval()
32
+ return model
33
+
34
+
35
+ MAX_SIZE = 1000
36
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
+ model = load_model(device)
38
+
39
+
40
+ @spaces.GPU
41
+ @torch.inference_mode()
42
+ def predict(image: np.ndarray) -> np.ndarray:
43
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
44
+
45
+ if max(gray.shape) > MAX_SIZE:
46
+ scale = MAX_SIZE / max(gray.shape)
47
+ gray = cv2.resize(gray, None, fx=scale, fy=scale)
48
+
49
+ h, w = gray.shape
50
+ size = 16
51
+ new_w = (w + size - 1) // size * size
52
+ new_h = (h + size - 1) // size * size
53
+
54
+ patch = np.ones((1, 1, new_h, new_w), dtype=np.float32)
55
+ patch[0, 0, :h, :w] = gray
56
+ tensor = torch.from_numpy(patch).to(device)
57
+ out = model(tensor)
58
+
59
+ res = out.cpu().numpy()[0, 0, :h, :w]
60
+ res = np.clip(res, 0, 255).astype(np.uint8)
61
+ return res
62
+
63
 
64
  with gr.Blocks(css="style.css") as demo:
65
  gr.Markdown(DESCRIPTION)
 
70
  with gr.Column():
71
  result = gr.Image(label="Result", elem_id="result")
72
  run_button.click(
73
+ fn=predict,
74
  inputs=input_image,
75
  outputs=result,
76
  )
model.py DELETED
@@ -1,55 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import pathlib
4
- import sys
5
-
6
- import cv2
7
- import huggingface_hub
8
- import numpy as np
9
- import torch
10
- import torch.nn as nn
11
-
12
- current_dir = pathlib.Path(__file__).parent
13
- submodule_dir = current_dir / "MangaLineExtraction_PyTorch"
14
- sys.path.insert(0, submodule_dir.as_posix())
15
-
16
- from model_torch import res_skip
17
-
18
- MAX_SIZE = 1000
19
-
20
-
21
- class Model:
22
- def __init__(self):
23
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
- self.model = self._load_model()
25
-
26
- def _load_model(self) -> nn.Module:
27
- ckpt_path = huggingface_hub.hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth")
28
- state_dict = torch.load(ckpt_path)
29
- model = res_skip()
30
- model.load_state_dict(state_dict)
31
- model.to(self.device)
32
- model.eval()
33
- return model
34
-
35
- @torch.inference_mode()
36
- def predict(self, image: np.ndarray) -> np.ndarray:
37
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
38
-
39
- if max(gray.shape) > MAX_SIZE:
40
- scale = MAX_SIZE / max(gray.shape)
41
- gray = cv2.resize(gray, None, fx=scale, fy=scale)
42
-
43
- h, w = gray.shape
44
- size = 16
45
- new_w = (w + size - 1) // size * size
46
- new_h = (h + size - 1) // size * size
47
-
48
- patch = np.ones((1, 1, new_h, new_w), dtype=np.float32)
49
- patch[0, 0, :h, :w] = gray
50
- tensor = torch.from_numpy(patch).to(self.device)
51
- out = self.model(tensor)
52
-
53
- res = out.cpu().numpy()[0, 0, :h, :w]
54
- res = np.clip(res, 0, 255).astype(np.uint8)
55
- return res