hysts (HF staff) committed
Commit 384ccc2
1 parent: 926874d
Files changed (7)
  1. .gitattributes +8 -2
  2. .pre-commit-config.yaml +60 -0
  3. .vscode/settings.json +30 -0
  4. README.md +1 -1
  5. app.py +50 -90
  6. requirements.txt +4 -4
  7. style.css +11 -0
.gitattributes CHANGED
@@ -1,22 +1,28 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bin.* filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -24,5 +30,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.8.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.2.0
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.7.1
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+{
+    "editor.formatOnSave": true,
+    "files.insertFinalNewline": false,
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter",
+        "editor.formatOnType": true,
+        "editor.codeActionsOnSave": {
+            "source.organizeImports": "explicit"
+        }
+    },
+    "[jupyter]": {
+        "files.insertFinalNewline": false
+    },
+    "black-formatter.args": [
+        "--line-length=119"
+    ],
+    "isort.args": ["--profile", "black"],
+    "flake8.args": [
+        "--max-line-length=119"
+    ],
+    "ruff.lint.args": [
+        "--line-length=119"
+    ],
+    "notebook.output.scrolling": true,
+    "notebook.formatOnCellExecution": true,
+    "notebook.formatOnSave.enabled": true,
+    "notebook.codeActionsOnSave": {
+        "source.organizeImports": "explicit"
+    }
+}
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💩
 colorFrom: red
 colorTo: gray
 sdk: gradio
-sdk_version: 3.0.5
+sdk_version: 4.19.2
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -2,8 +2,7 @@
 
 from __future__ import annotations
 
-import argparse
-import functools
+import os
 import pathlib
 import sys
 import urllib.request
@@ -13,49 +12,35 @@ import cv2
 import gradio as gr
 import numpy as np
 import torch
-import torch.nn as nn
 
-sys.path.insert(0, 'face_detection')
+sys.path.insert(0, "face_detection")
 
 from ibug.face_detection import RetinaFacePredictor, S3FDPredictor
 
-TITLE = 'ibug-group/face_detection'
-DESCRIPTION = 'This is an unofficial demo for https://github.com/ibug-group/face_detection.'
-ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.ibug-face_detection" alt="visitor badge"/></center>'
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--face-score-slider-step', type=float, default=0.05)
-    parser.add_argument('--face-score-threshold', type=float, default=0.8)
-    parser.add_argument('--device', type=str, default='cpu')
-    parser.add_argument('--theme', type=str)
-    parser.add_argument('--live', action='store_true')
-    parser.add_argument('--share', action='store_true')
-    parser.add_argument('--port', type=int)
-    parser.add_argument('--disable-queue',
-                        dest='enable_queue',
-                        action='store_false')
-    parser.add_argument('--allow-flagging', type=str, default='never')
-    return parser.parse_args()
-
-
-def load_model(
-        model_name: str, threshold: float,
-        device: torch.device) -> Union[RetinaFacePredictor, S3FDPredictor]:
-    if model_name == 's3fd':
+DESCRIPTION = "# [ibug-group/face_detection](https://github.com/ibug-group/face_detection)"
+
+
+def load_model(model_name: str, threshold: float, device: torch.device) -> Union[RetinaFacePredictor, S3FDPredictor]:
+    if model_name == "s3fd":
         model = S3FDPredictor(threshold=threshold, device=device)
     else:
-        model_name = model_name.replace('retinaface_', '')
+        model_name = model_name.replace("retinaface_", "")
         model = RetinaFacePredictor(
-            threshold=threshold,
-            device=device,
-            model=RetinaFacePredictor.get_model(model_name))
+            threshold=threshold, device=device, model=RetinaFacePredictor.get_model(model_name)
+        )
     return model
 
 
-def detect(image: np.ndarray, model_name: str, face_score_threshold: float,
-           detectors: dict[str, nn.Module]) -> np.ndarray:
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model_names = [
+    "retinaface_mobilenet0.25",
+    "retinaface_resnet50",
+    "s3fd",
+]
+detectors = {name: load_model(name, threshold=0.8, device=device) for name in model_names}
+
+
+def detect(image: np.ndarray, model_name: str, face_score_threshold: float) -> np.ndarray:
     model = detectors[model_name]
     model.threshold = face_score_threshold
 
@@ -68,8 +53,7 @@ def detect(image: np.ndarray, model_name: str, face_score_threshold: float,
         box = np.round(pred[:4]).astype(int)
 
         line_width = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
-        cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0),
-                      line_width)
+        cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), line_width)
 
         if len(pred) == 15:
             pts = pred[5:].reshape(-1, 2)
@@ -79,59 +63,35 @@ def detect(image: np.ndarray, model_name: str, face_score_threshold: float,
     return res[:, :, ::-1]
 
 
-def main():
-    args = parse_args()
-    device = torch.device(args.device)
-
-    model_names = [
-        'retinaface_mobilenet0.25',
-        'retinaface_resnet50',
-        's3fd',
-    ]
-    detectors = {
-        name: load_model(name,
-                         threshold=args.face_score_threshold,
-                         device=device)
-        for name in model_names
-    }
-
-    func = functools.partial(detect, detectors=detectors)
-    func = functools.update_wrapper(func, detect)
-
-    image_path = pathlib.Path('selfie.jpg')
-    if not image_path.exists():
-        url = 'https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg'
-        urllib.request.urlretrieve(url, image_path)
-    examples = [[image_path.as_posix(), model_names[1], 0.8]]
-
-    gr.Interface(
-        func,
-        [
-            gr.inputs.Image(type='numpy', label='Input'),
-            gr.inputs.Radio(model_names,
-                            type='value',
-                            default='retinaface_resnet50',
-                            label='Model'),
-            gr.inputs.Slider(0,
-                             1,
-                             step=args.face_score_slider_step,
-                             default=args.face_score_threshold,
-                             label='Face Score Threshold'),
-        ],
-        gr.outputs.Image(type='numpy', label='Output'),
-        examples=examples,
-        title=TITLE,
-        description=DESCRIPTION,
-        article=ARTICLE,
-        theme=args.theme,
-        allow_flagging=args.allow_flagging,
-        live=args.live,
-    ).launch(
-        enable_queue=args.enable_queue,
-        server_port=args.port,
-        share=args.share,
+example_image_path = pathlib.Path("selfie.jpg")
+if not example_image_path.exists():
+    url = "https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg"
+    urllib.request.urlretrieve(url, example_image_path)
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(type="numpy", label="Input")
+            model_name = gr.Radio(model_names, type="value", value="retinaface_resnet50", label="Model")
+            score_threshold = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.8, label="Face Score Threshold")
+            run_button = gr.Button()
+        with gr.Column():
+            result = gr.Image(label="Output")
+    gr.Examples(
+        examples=[[example_image_path.as_posix(), model_names[1], 0.8]],
+        inputs=[image, model_name, score_threshold],
+        outputs=result,
+        fn=detect,
+        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
+    )
+    run_button.click(
+        fn=detect,
+        inputs=[image, model_name, score_threshold],
+        outputs=result,
+        api_name="detect",
     )
 
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-numpy==1.22.3
-opencv-python-headless==4.5.5.64
-torch==1.11.0
-torchvision==0.12.0
+numpy==1.26.4
+opencv-python-headless==4.9.0.80
+torch==2.0.1
+torchvision==0.15.2
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+  display: block;
+}
+
+#duplicate-button {
+  margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
+}