hysts (HF staff) committed
Commit ec93f77 · 1 Parent(s): d7f54c5
Files changed (2)
  1. app.py +166 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import functools
+import os
+import pathlib
+import subprocess
+import tarfile
+
+# workaround for https://github.com/gradio-app/gradio/issues/483
+command = 'pip install -U gradio==2.7.0'
+subprocess.call(command.split())
+
+try:
+    import detectron2
+except:
+    command = 'pip install git+https://github.com/facebookresearch/detectron2@v0.6'
+    subprocess.call(command.split())
+
+try:
+    import adet
+except:
+    command = 'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'
+    subprocess.call(command.split())
+
+import gradio as gr
+import huggingface_hub
+import numpy as np
+import torch
+from adet.config import get_cfg
+from detectron2.data.detection_utils import read_image
+from detectron2.engine.defaults import DefaultPredictor
+from detectron2.utils.visualizer import Visualizer
+
+TOKEN = os.environ['TOKEN']
+
+MODEL_REPO = 'hysts/Yet-Another-Anime-Segmenter'
+MODEL_FILENAME = 'SOLOv2.pth'
+CONFIG_FILENAME = 'SOLOv2.yaml'
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument('--class-score-slider-step', type=float, default=0.05)
+    parser.add_argument('--class-score-threshold', type=float, default=0.1)
+    parser.add_argument('--mask-score-slider-step', type=float, default=0.05)
+    parser.add_argument('--mask-score-threshold', type=float, default=0.5)
+    parser.add_argument('--theme', type=str)
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    parser.add_argument('--allow-screenshot', action='store_true')
+    return parser.parse_args()
+
+
+def load_sample_image_paths() -> list[pathlib.Path]:
+    image_dir = pathlib.Path('images')
+    if not image_dir.exists():
+        dataset_repo = 'hysts/sample-images-TADNE'
+        path = huggingface_hub.hf_hub_download(dataset_repo,
+                                               'images.tar.gz',
+                                               repo_type='dataset',
+                                               use_auth_token=TOKEN)
+        with tarfile.open(path) as f:
+            f.extractall()
+    return sorted(image_dir.glob('*'))
+
+
+def load_model(device: torch.device) -> DefaultPredictor:
+    config_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                  CONFIG_FILENAME,
+                                                  use_auth_token=TOKEN)
+    model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                 MODEL_FILENAME,
+                                                 use_auth_token=TOKEN)
+    cfg = get_cfg()
+    cfg.merge_from_file(config_path)
+    cfg.MODEL.WEIGHTS = model_path
+    cfg.MODEL.DEVICE = device.type
+    cfg.freeze()
+    return DefaultPredictor(cfg)
+
+
+def predict(image, class_score_threshold: float, mask_score_threshold: float,
+            model: DefaultPredictor) -> tuple[np.ndarray, np.ndarray]:
+    model.score_threshold = class_score_threshold
+    model.mask_threshold = mask_score_threshold
+    image = read_image(image.name, format='BGR')
+    preds = model(image)
+    instances = preds['instances'].to('cpu')
+
+    visualizer = Visualizer(image[:, :, ::-1])
+    vis = visualizer.draw_instance_predictions(predictions=instances)
+    vis = vis.get_image()
+
+    masked = image.copy()[:, :, ::-1]
+    mask = instances.pred_masks.cpu().numpy().astype(int).max(axis=0)
+    masked[mask == 0] = 255
+
+    return vis, masked
+
+
+def main():
+    gr.close_all()
+
+    args = parse_args()
+    device = torch.device(args.device)
+
+    image_paths = load_sample_image_paths()
+    examples = [[
+        path.as_posix(), args.class_score_threshold, args.mask_score_threshold
+    ] for path in image_paths]
+
+    model = load_model(device)
+
+    func = functools.partial(predict, model=model)
+    func = functools.update_wrapper(func, predict)
+
+    repo_url = 'https://github.com/zymk9/Yet-Another-Anime-Segmenter'
+    title = 'zymk9/Yet-Another-Anime-Segmenter'
+    description = f'A demo for {repo_url}'
+    article = None
+
+    gr.Interface(
+        func,
+        [
+            gr.inputs.Image(type='file', label='Input'),
+            gr.inputs.Slider(0,
+                             1,
+                             step=args.class_score_slider_step,
+                             default=args.class_score_threshold,
+                             label='Class Score Threshold'),
+            gr.inputs.Slider(0,
+                             1,
+                             step=args.mask_score_slider_step,
+                             default=args.mask_score_threshold,
+                             label='Mask Score Threshold'),
+        ],
+        [
+            gr.outputs.Image(label='Instances'),
+            gr.outputs.Image(label='Masked'),
+        ],
+        theme=args.theme,
+        title=title,
+        description=description,
+        article=article,
+        examples=examples,
+        allow_screenshot=args.allow_screenshot,
+        allow_flagging=args.allow_flagging,
+        live=args.live,
+    ).launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
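
Since app.py reads its Hugging Face token from the TOKEN environment variable and takes its options from argparse, a local launch of this Space would look roughly like the line below (the token value is a placeholder and the flags shown are optional):

    TOKEN=<your-hf-token> python app.py --device cpu --share
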
requirements.txt ADDED
@@ -0,0 +1,3 @@
+opencv-python-headless>=4.5.5.62
+torch>=1.10.1
+torchvision>=0.11.2