hysts HF staff commited on
Commit
b9829b9
1 Parent(s): 980e614
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +32 -32
  3. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -22,12 +22,7 @@ sys.path.insert(0, 'bizarre-pose-estimator')
22
 
23
  from _util.twodee_v0 import I as ImageWrapper
24
 
25
- TITLE = 'ShuhongChen/bizarre-pose-estimator (segmenter)'
26
- DESCRIPTION = 'This is an unofficial demo for https://github.com/ShuhongChen/bizarre-pose-estimator.'
27
-
28
- HF_TOKEN = os.getenv('HF_TOKEN')
29
- MODEL_REPO = 'hysts/bizarre-pose-estimator-models'
30
- MODEL_FILENAME = 'segmenter.pth'
31
 
32
 
33
  def load_sample_image_paths() -> list[pathlib.Path]:
@@ -36,8 +31,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
36
  dataset_repo = 'hysts/sample-images-TADNE'
37
  path = huggingface_hub.hf_hub_download(dataset_repo,
38
  'images.tar.gz',
39
- repo_type='dataset',
40
- use_auth_token=HF_TOKEN)
41
  with tarfile.open(path) as f:
42
  f.extractall()
43
  return sorted(image_dir.glob('*'))
@@ -45,9 +39,8 @@ def load_sample_image_paths() -> list[pathlib.Path]:
45
 
46
  def load_model(
47
  device: torch.device) -> tuple[torch.nn.Module, torch.nn.Module]:
48
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
49
- MODEL_FILENAME,
50
- use_auth_token=HF_TOKEN)
51
  ckpt = torch.load(path)
52
 
53
  model = torchvision.models.segmentation.deeplabv3_resnet101()
@@ -114,24 +107,31 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
114
  model, final_head = load_model(device)
115
  transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
116
 
117
- func = functools.partial(predict,
118
- transform=transform,
119
- device=device,
120
- model=model,
121
- final_head=final_head)
122
-
123
- gr.Interface(
124
- fn=func,
125
- inputs=[
126
- gr.Image(type='pil', label='Input'),
127
- gr.Slider(label='Score Threshold',
128
- minimum=0,
129
- maximum=1,
130
- step=0.05,
131
- value=0.5),
132
- ],
133
- outputs=gr.Image(label='Masked'),
134
- examples=examples,
135
- title=TITLE,
136
- description=DESCRIPTION,
137
- ).queue().launch(show_api=False)
 
 
 
 
 
 
 
 
22
 
23
  from _util.twodee_v0 import I as ImageWrapper
24
 
25
+ DESCRIPTION = '# [ShuhongChen/bizarre-pose-estimator (segmenter)](https://github.com/ShuhongChen/bizarre-pose-estimator)'
 
 
 
 
 
26
 
27
 
28
  def load_sample_image_paths() -> list[pathlib.Path]:
 
31
  dataset_repo = 'hysts/sample-images-TADNE'
32
  path = huggingface_hub.hf_hub_download(dataset_repo,
33
  'images.tar.gz',
34
+ repo_type='dataset')
 
35
  with tarfile.open(path) as f:
36
  f.extractall()
37
  return sorted(image_dir.glob('*'))
 
39
 
40
  def load_model(
41
  device: torch.device) -> tuple[torch.nn.Module, torch.nn.Module]:
42
+ path = huggingface_hub.hf_hub_download(
43
+ 'public-data/bizarre-pose-estimator-models', 'segmenter.pth')
 
44
  ckpt = torch.load(path)
45
 
46
  model = torchvision.models.segmentation.deeplabv3_resnet101()
 
107
  model, final_head = load_model(device)
108
  transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
109
 
110
+ fn = functools.partial(predict,
111
+ transform=transform,
112
+ device=device,
113
+ model=model,
114
+ final_head=final_head)
115
+
116
+ with gr.Blocks(css='style.css') as demo:
117
+ gr.Markdown(DESCRIPTION)
118
+ with gr.Row():
119
+ with gr.Column():
120
+ image = gr.Image(label='Input', type='pil')
121
+ threshold = gr.Slider(label='Score Threshold',
122
+ minimum=0,
123
+ maximum=1,
124
+ step=0.05,
125
+ value=0.5)
126
+ run_button = gr.Button('Run')
127
+ with gr.Column():
128
+ result = gr.Image(label='Masked')
129
+
130
+ inputs = [image, threshold]
131
+ gr.Examples(examples=examples,
132
+ inputs=inputs,
133
+ outputs=result,
134
+ fn=fn,
135
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
136
+ run_button.click(fn=fn, inputs=inputs, outputs=result, api_name='predict')
137
+ demo.queue(max_size=15).launch()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }