hysts HF staff commited on
Commit
af898ba
1 Parent(s): f2b5ccc
Files changed (4) hide show
  1. .gitmodules +3 -0
  2. app.py +184 -0
  3. bizarre-pose-estimator +1 -0
  4. requirements.txt +2 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "bizarre-pose-estimator"]
2
+ path = bizarre-pose-estimator
3
+ url = https://github.com/ShuhongChen/bizarre-pose-estimator
app.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import subprocess
10
+ import sys
11
+ import tarfile
12
+ from typing import Callable
13
+
14
+ # workaround for https://github.com/gradio-app/gradio/issues/483
15
+ command = 'pip install -U gradio==2.7.0'
16
+ subprocess.call(command.split())
17
+
18
+ import gradio as gr
19
+ import huggingface_hub
20
+ import numpy as np
21
+ import PIL.Image
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision
25
+ import torchvision.transforms as T
26
+
27
+ sys.path.insert(0, 'bizarre-pose-estimator')
28
+
29
+ from _util.twodee_v0 import I as ImageWrapper
30
+
31
+ TOKEN = os.environ['TOKEN']
32
+
33
+ MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
34
+ MODEL_FILENAME = 'segmenter.pth'
35
+
36
+
37
+ def parse_args() -> argparse.Namespace:
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument('--device', type=str, default='cpu')
40
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
41
+ parser.add_argument('--score-threshold', type=float, default=0.5)
42
+ parser.add_argument('--theme', type=str)
43
+ parser.add_argument('--live', action='store_true')
44
+ parser.add_argument('--share', action='store_true')
45
+ parser.add_argument('--port', type=int)
46
+ parser.add_argument('--disable-queue',
47
+ dest='enable_queue',
48
+ action='store_false')
49
+ parser.add_argument('--allow-flagging', type=str, default='never')
50
+ parser.add_argument('--allow-screenshot', action='store_true')
51
+ return parser.parse_args()
52
+
53
+
54
+ def load_sample_image_paths() -> list[pathlib.Path]:
55
+ image_dir = pathlib.Path('images')
56
+ if not image_dir.exists():
57
+ dataset_repo = 'hysts/sample-images-TADNE'
58
+ path = huggingface_hub.hf_hub_download(dataset_repo,
59
+ 'images.tar.gz',
60
+ repo_type='dataset',
61
+ use_auth_token=TOKEN)
62
+ with tarfile.open(path) as f:
63
+ f.extractall()
64
+ return sorted(image_dir.glob('*'))
65
+
66
+
67
+ def load_model(
68
+ device: torch.device) -> tuple[torch.nn.Module, torch.nn.Module]:
69
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
70
+ MODEL_FILENAME,
71
+ use_auth_token=TOKEN)
72
+ ckpt = torch.load(path)
73
+
74
+ model = torchvision.models.segmentation.deeplabv3_resnet101()
75
+ model.classifier = nn.Sequential(
76
+ torchvision.models.segmentation.deeplabv3.ASPP(2048, [12, 24, 36]),
77
+ nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1),
78
+ nn.BatchNorm2d(64),
79
+ nn.LeakyReLU(),
80
+ nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
81
+ nn.BatchNorm2d(16),
82
+ nn.LeakyReLU(),
83
+ )
84
+ final_head = nn.Sequential(
85
+ nn.Conv2d(16 + 3, 16, kernel_size=3, stride=1, padding=1),
86
+ nn.BatchNorm2d(16),
87
+ nn.LeakyReLU(),
88
+ nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
89
+ nn.BatchNorm2d(8),
90
+ nn.LeakyReLU(),
91
+ nn.Conv2d(8, 2, kernel_size=1, stride=1),
92
+ )
93
+ model.load_state_dict(ckpt['model'])
94
+ final_head.load_state_dict(ckpt['final_head'])
95
+ model.to(device)
96
+ model.eval()
97
+ final_head.to(device)
98
+ final_head.eval()
99
+ return model, final_head
100
+
101
+
102
+ @torch.inference_mode()
103
+ def predict(image: PIL.Image.Image, score_threshold: float,
104
+ transform: Callable, device: torch.device, model: torch.nn.Module,
105
+ final_head: torch.nn.Module) -> np.ndarray:
106
+ data = ImageWrapper(image).resize_min(256).convert('RGBA').alpha_bg(
107
+ 1).convert('RGB').pil()
108
+ data = torchvision.transforms.functional.to_tensor(data)
109
+ data = transform(data)
110
+ data = data.to(device).unsqueeze(0)
111
+
112
+ out = model(data)['out']
113
+ out_fin = final_head(torch.cat([
114
+ out,
115
+ data,
116
+ ], dim=1))
117
+ probs = torch.softmax(out_fin, dim=1)[0]
118
+ probs = probs[1] # foreground
119
+ probs = PIL.Image.fromarray(probs.cpu().numpy()).resize(image.size)
120
+
121
+ mask = np.asarray(probs)
122
+ mask[mask < score_threshold] = 0
123
+ mask[mask > 0] = 1
124
+ mask = mask.astype(bool)
125
+
126
+ res = np.asarray(image)
127
+ res[~mask] = 255
128
+ return res
129
+
130
+
131
+ def main():
132
+ gr.close_all()
133
+
134
+ args = parse_args()
135
+ device = torch.device(args.device)
136
+
137
+ image_paths = load_sample_image_paths()
138
+ examples = [[path.as_posix(), args.score_threshold]
139
+ for path in image_paths]
140
+
141
+ model, final_head = load_model(device)
142
+ transform = T.Normalize(mean=[0.485, 0.456, 0.406],
143
+ std=[0.229, 0.224, 0.225])
144
+
145
+ func = functools.partial(predict,
146
+ transform=transform,
147
+ device=device,
148
+ model=model,
149
+ final_head=final_head)
150
+ func = functools.update_wrapper(func, predict)
151
+
152
+ repo_url = 'https://github.com/ShuhongChen/bizarre-pose-estimator'
153
+ title = 'ShuhongChen/bizarre-pose-estimator (segmenter)'
154
+ description = f'A demo for {repo_url}'
155
+ article = None
156
+
157
+ gr.Interface(
158
+ func,
159
+ [
160
+ gr.inputs.Image(type='pil', label='Input'),
161
+ gr.inputs.Slider(0,
162
+ 1,
163
+ step=args.score_slider_step,
164
+ default=args.score_threshold,
165
+ label='Score Threshold'),
166
+ ],
167
+ gr.outputs.Image(label='Masked'),
168
+ theme=args.theme,
169
+ title=title,
170
+ description=description,
171
+ article=article,
172
+ examples=examples,
173
+ allow_screenshot=args.allow_screenshot,
174
+ allow_flagging=args.allow_flagging,
175
+ live=args.live,
176
+ ).launch(
177
+ enable_queue=args.enable_queue,
178
+ server_port=args.port,
179
+ share=args.share,
180
+ )
181
+
182
+
183
+ if __name__ == '__main__':
184
+ main()
bizarre-pose-estimator ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 7382ec234fa40cd8a6ec4a28b4639209199bc035
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=1.10.1
2
+ torchvision>=0.11.2