hysts HF staff committed on
Commit
49ff668
·
1 Parent(s): 9013cd2
Files changed (5) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -29
  4. app.py +46 -81
  5. requirements.txt +3 -3
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,35 +4,7 @@ emoji: 🌍
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- # Configuration
13
-
14
- `title`: _string_
15
- Display title for the Space
16
-
17
- `emoji`: _string_
18
- Space emoji (emoji-only character allowed)
19
-
20
- `colorFrom`: _string_
21
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
-
23
- `colorTo`: _string_
24
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
-
26
- `sdk`: _string_
27
- Can be either `gradio`, `streamlit`, or `static`
28
-
29
- `sdk_version` : _string_
30
- Only applicable for `streamlit` SDK.
31
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
-
33
- `app_file`: _string_
34
- Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
35
- Path is relative to the root of the repository.
36
-
37
- `pinned`: _boolean_
38
- Whether the Space stays on top of your list.
 
4
  colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -2,20 +2,21 @@
2
 
3
  from __future__ import annotations
4
 
5
- import argparse
6
  import functools
7
  import os
8
  import pathlib
 
9
  import subprocess
10
  import tarfile
11
 
12
- if os.environ.get('SYSTEM') == 'spaces':
13
  subprocess.call(
14
- 'pip install git+https://github.com/facebookresearch/detectron2@v0.6'.
15
- split())
16
- subprocess.call(
17
- 'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'.split(
18
  ))
 
 
 
19
 
20
  import gradio as gr
21
  import huggingface_hub
@@ -26,34 +27,15 @@ from detectron2.data.detection_utils import read_image
26
  from detectron2.engine.defaults import DefaultPredictor
27
  from detectron2.utils.visualizer import Visualizer
28
 
29
- TITLE = 'zymk9/Yet-Another-Anime-Segmenter'
30
  DESCRIPTION = 'This is an unofficial demo for https://github.com/zymk9/Yet-Another-Anime-Segmenter.'
31
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.yet-another-anime-segmenter" alt="visitor badge"/></center>'
32
 
33
- TOKEN = os.environ['TOKEN']
34
  MODEL_REPO = 'hysts/Yet-Another-Anime-Segmenter'
35
  MODEL_FILENAME = 'SOLOv2.pth'
36
  CONFIG_FILENAME = 'SOLOv2.yaml'
37
 
38
 
39
- def parse_args() -> argparse.Namespace:
40
- parser = argparse.ArgumentParser()
41
- parser.add_argument('--device', type=str, default='cpu')
42
- parser.add_argument('--class-score-slider-step', type=float, default=0.05)
43
- parser.add_argument('--class-score-threshold', type=float, default=0.1)
44
- parser.add_argument('--mask-score-slider-step', type=float, default=0.05)
45
- parser.add_argument('--mask-score-threshold', type=float, default=0.5)
46
- parser.add_argument('--theme', type=str)
47
- parser.add_argument('--live', action='store_true')
48
- parser.add_argument('--share', action='store_true')
49
- parser.add_argument('--port', type=int)
50
- parser.add_argument('--disable-queue',
51
- dest='enable_queue',
52
- action='store_false')
53
- parser.add_argument('--allow-flagging', type=str, default='never')
54
- return parser.parse_args()
55
-
56
-
57
  def load_sample_image_paths() -> list[pathlib.Path]:
58
  image_dir = pathlib.Path('images')
59
  if not image_dir.exists():
@@ -61,7 +43,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
61
  path = huggingface_hub.hf_hub_download(dataset_repo,
62
  'images.tar.gz',
63
  repo_type='dataset',
64
- use_auth_token=TOKEN)
65
  with tarfile.open(path) as f:
66
  f.extractall()
67
  return sorted(image_dir.glob('*'))
@@ -70,10 +52,10 @@ def load_sample_image_paths() -> list[pathlib.Path]:
70
  def load_model(device: torch.device) -> DefaultPredictor:
71
  config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
72
  CONFIG_FILENAME,
73
- use_auth_token=TOKEN)
74
  model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
75
  MODEL_FILENAME,
76
- use_auth_token=TOKEN)
77
  cfg = get_cfg()
78
  cfg.merge_from_file(config_path)
79
  cfg.MODEL.WEIGHTS = model_path
@@ -82,11 +64,12 @@ def load_model(device: torch.device) -> DefaultPredictor:
82
  return DefaultPredictor(cfg)
83
 
84
 
85
- def predict(image, class_score_threshold: float, mask_score_threshold: float,
 
86
  model: DefaultPredictor) -> tuple[np.ndarray, np.ndarray]:
87
  model.score_threshold = class_score_threshold
88
  model.mask_threshold = mask_score_threshold
89
- image = read_image(image.name, format='BGR')
90
  preds = model(image)
91
  instances = preds['instances'].to('cpu')
92
 
@@ -101,52 +84,34 @@ def predict(image, class_score_threshold: float, mask_score_threshold: float,
101
  return vis, masked
102
 
103
 
104
- def main():
105
- args = parse_args()
106
- device = torch.device(args.device)
107
-
108
- image_paths = load_sample_image_paths()
109
- examples = [[
110
- path.as_posix(), args.class_score_threshold, args.mask_score_threshold
111
- ] for path in image_paths]
112
-
113
- model = load_model(device)
114
-
115
- func = functools.partial(predict, model=model)
116
- func = functools.update_wrapper(func, predict)
117
-
118
- gr.Interface(
119
- func,
120
- [
121
- gr.inputs.Image(type='file', label='Input'),
122
- gr.inputs.Slider(0,
123
- 1,
124
- step=args.class_score_slider_step,
125
- default=args.class_score_threshold,
126
- label='Class Score Threshold'),
127
- gr.inputs.Slider(0,
128
- 1,
129
- step=args.mask_score_slider_step,
130
- default=args.mask_score_threshold,
131
- label='Mask Score Threshold'),
132
- ],
133
- [
134
- gr.outputs.Image(label='Instances'),
135
- gr.outputs.Image(label='Masked'),
136
- ],
137
- examples=examples,
138
- title=TITLE,
139
- description=DESCRIPTION,
140
- article=ARTICLE,
141
- theme=args.theme,
142
- allow_flagging=args.allow_flagging,
143
- live=args.live,
144
- ).launch(
145
- enable_queue=args.enable_queue,
146
- server_port=args.port,
147
- share=args.share,
148
- )
149
-
150
-
151
- if __name__ == '__main__':
152
- main()
 
2
 
3
  from __future__ import annotations
4
 
 
5
  import functools
6
  import os
7
  import pathlib
8
+ import shlex
9
  import subprocess
10
  import tarfile
11
 
12
+ if os.getenv('SYSTEM') == 'spaces':
13
  subprocess.call(
14
+ shlex.split(
15
+ 'pip install git+https://github.com/facebookresearch/detectron2@v0.6'
 
 
16
  ))
17
+ subprocess.call(
18
+ shlex.split(
19
+ 'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
20
 
21
  import gradio as gr
22
  import huggingface_hub
 
27
  from detectron2.engine.defaults import DefaultPredictor
28
  from detectron2.utils.visualizer import Visualizer
29
 
30
+ TITLE = 'Yet-Another-Anime-Segmenter'
31
  DESCRIPTION = 'This is an unofficial demo for https://github.com/zymk9/Yet-Another-Anime-Segmenter.'
 
32
 
33
+ HF_TOKEN = os.getenv('HF_TOKEN')
34
  MODEL_REPO = 'hysts/Yet-Another-Anime-Segmenter'
35
  MODEL_FILENAME = 'SOLOv2.pth'
36
  CONFIG_FILENAME = 'SOLOv2.yaml'
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def load_sample_image_paths() -> list[pathlib.Path]:
40
  image_dir = pathlib.Path('images')
41
  if not image_dir.exists():
 
43
  path = huggingface_hub.hf_hub_download(dataset_repo,
44
  'images.tar.gz',
45
  repo_type='dataset',
46
+ use_auth_token=HF_TOKEN)
47
  with tarfile.open(path) as f:
48
  f.extractall()
49
  return sorted(image_dir.glob('*'))
 
52
  def load_model(device: torch.device) -> DefaultPredictor:
53
  config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
54
  CONFIG_FILENAME,
55
+ use_auth_token=HF_TOKEN)
56
  model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
57
  MODEL_FILENAME,
58
+ use_auth_token=HF_TOKEN)
59
  cfg = get_cfg()
60
  cfg.merge_from_file(config_path)
61
  cfg.MODEL.WEIGHTS = model_path
 
64
  return DefaultPredictor(cfg)
65
 
66
 
67
+ def predict(image_path: str, class_score_threshold: float,
68
+ mask_score_threshold: float,
69
  model: DefaultPredictor) -> tuple[np.ndarray, np.ndarray]:
70
  model.score_threshold = class_score_threshold
71
  model.mask_threshold = mask_score_threshold
72
+ image = read_image(image_path, format='BGR')
73
  preds = model(image)
74
  instances = preds['instances'].to('cpu')
75
 
 
84
  return vis, masked
85
 
86
 
87
# Demo entry point: build the example list, load the model once at startup,
# and launch the Gradio interface.
# Each example row is [image_path, class_score_threshold, mask_score_threshold].
image_paths = load_sample_image_paths()
examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]

# Use a GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = load_model(device)

# Bind the loaded model so the Gradio callback only receives the UI inputs.
func = functools.partial(predict, model=model)

gr.Interface(
    fn=func,
    inputs=[
        gr.Image(label='Input', type='filepath'),
        gr.Slider(label='Class Score Threshold',
                  minimum=0,
                  maximum=1,
                  step=0.05,
                  value=0.1),
        gr.Slider(label='Mask Score Threshold',
                  minimum=0,
                  maximum=1,
                  step=0.05,
                  # Fixed: was `default=0.5` — gradio 3.x `gr.Slider` takes
                  # `value`, not the removed `default` kwarg (matches the
                  # sibling slider above).
                  value=0.5),
    ],
    outputs=[
        gr.Image(label='Instances'),
        gr.Image(label='Masked'),
    ],
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
).queue().launch(show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- opencv-python-headless>=4.5.5.62
2
- torch>=1.10.1
3
- torchvision>=0.11.2
 
1
+ opencv-python-headless==4.5.5.62
2
+ torch==1.10.1
3
+ torchvision==0.11.2