Spaces:
Runtime error
Runtime error
initial
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +12 -0
- .gitignore +51 -0
- README.md +4 -4
- app.py +31 -0
- configs/anonymizers/FB_cse.py +28 -0
- configs/anonymizers/FB_cse_mask.py +29 -0
- configs/anonymizers/FB_cse_mask_face.py +29 -0
- configs/anonymizers/deep_privacy1.py +15 -0
- configs/anonymizers/face.py +17 -0
- configs/anonymizers/face_fdf128.py +18 -0
- configs/anonymizers/market1501/blackout.py +8 -0
- configs/anonymizers/market1501/person.py +6 -0
- configs/anonymizers/market1501/pixelation16.py +8 -0
- configs/anonymizers/market1501/pixelation8.py +8 -0
- configs/datasets/coco_cse.py +69 -0
- configs/datasets/fdf128.py +24 -0
- configs/datasets/fdf256.py +55 -0
- configs/datasets/fdh.py +90 -0
- configs/datasets/utils.py +21 -0
- configs/defaults.py +53 -0
- configs/discriminators/sg2_discriminator.py +43 -0
- configs/fdf/deep_privacy1.py +9 -0
- configs/fdf/stylegan.py +14 -0
- configs/fdf/stylegan_fdf128.py +17 -0
- configs/fdh/styleganL.py +16 -0
- configs/fdh/styleganL_nocse.py +14 -0
- configs/generators/stylegan_unet.py +22 -0
- dp2/__init__.py +0 -0
- dp2/anonymizer/__init__.py +1 -0
- dp2/anonymizer/anonymizer.py +163 -0
- dp2/anonymizer/histogram_match_anonymizers.py +93 -0
- dp2/data/__init__.py +0 -0
- dp2/data/build.py +40 -0
- dp2/data/datasets/__init__.py +0 -0
- dp2/data/datasets/coco_cse.py +68 -0
- dp2/data/datasets/fdf.py +128 -0
- dp2/data/datasets/fdf128_wds.py +96 -0
- dp2/data/datasets/fdh.py +142 -0
- dp2/data/transforms/__init__.py +2 -0
- dp2/data/transforms/functional.py +57 -0
- dp2/data/transforms/stylegan2_transform.py +394 -0
- dp2/data/transforms/transforms.py +277 -0
- dp2/data/utils.py +122 -0
- dp2/detection/__init__.py +3 -0
- dp2/detection/base.py +42 -0
- dp2/detection/box_utils.py +104 -0
- dp2/detection/box_utils_fdf.py +202 -0
- dp2/detection/cse_mask_face_detector.py +116 -0
- dp2/detection/deep_privacy1_detector.py +106 -0
- dp2/detection/face_detector.py +62 -0
.gitattributes
CHANGED
@@ -32,3 +32,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
torch_home/hub/checkpoints/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c filter=lfs diff=lfs merge=lfs -text
|
36 |
+
torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN-Human.yaml filter=lfs diff=lfs merge=lfs -text
|
37 |
+
torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN.yaml filter=lfs diff=lfs merge=lfs -text
|
38 |
+
torch_home/hub/checkpoints/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml filter=lfs diff=lfs merge=lfs -text
|
39 |
+
torch_home/hub/checkpoints/model_final_1d3314.pkl filter=lfs diff=lfs merge=lfs -text
|
40 |
+
torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
|
41 |
+
torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
|
42 |
+
media2/stylemc_example.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
media2/erling.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
media2/g7_leaders.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
media2/regjeringen.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
media/ filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FILES
|
2 |
+
*.flist
|
3 |
+
*.zip
|
4 |
+
*.out
|
5 |
+
*.npy
|
6 |
+
*.gz
|
7 |
+
*.ckpt
|
8 |
+
*.log
|
9 |
+
*.pyc
|
10 |
+
*.csv
|
11 |
+
*.yml
|
12 |
+
*.ods
|
13 |
+
*.ods#
|
14 |
+
*.json
|
15 |
+
build_docker.sh
|
16 |
+
|
17 |
+
# Images / Videos
|
18 |
+
#*.png
|
19 |
+
#*.jpg
|
20 |
+
*.jpeg
|
21 |
+
*.m4a
|
22 |
+
*.mkv
|
23 |
+
*.mp4
|
24 |
+
|
25 |
+
# Directories created by inpaintron
|
26 |
+
.cache/
|
27 |
+
test_examples/
|
28 |
+
.vscode
|
29 |
+
__pycache__
|
30 |
+
.debug/
|
31 |
+
**/.ipynb_checkpoints/**
|
32 |
+
outputs/
|
33 |
+
|
34 |
+
|
35 |
+
# From pip setup
|
36 |
+
build/
|
37 |
+
*.egg-info
|
38 |
+
*.egg
|
39 |
+
.npm/
|
40 |
+
|
41 |
+
# From dockerfile
|
42 |
+
.bash_history
|
43 |
+
.viminfo
|
44 |
+
.local/
|
45 |
+
*.pickle
|
46 |
+
*.onnx
|
47 |
+
|
48 |
+
|
49 |
+
sbatch_files/
|
50 |
+
figures/
|
51 |
+
image_dump/
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title: Deep Privacy2
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
1 |
---
|
2 |
+
title: Deep Privacy2
|
3 |
+
emoji: 📈
|
4 |
+
colorFrom: gray
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.9.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
app.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio
|
2 |
+
import os
|
3 |
+
from tops.config import instantiate
|
4 |
+
import gradio.inputs
|
5 |
+
os.system("pip install --upgrade pip")
|
6 |
+
os.system("pip install ftfy regex tqdm")
|
7 |
+
os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
|
8 |
+
os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
|
9 |
+
os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
|
10 |
+
os.environ["TORCH_HOME"] = "torch_home"
|
11 |
+
from dp2 import utils
|
12 |
+
from gradio_demos.modules import ExampleDemo, WebcamDemo
|
13 |
+
|
14 |
+
cfg_face = utils.load_config("configs/anonymizers/face.py")
|
15 |
+
|
16 |
+
anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
|
17 |
+
|
18 |
+
anonymizer_face.initialize_tracker(fps=1)
|
19 |
+
|
20 |
+
|
21 |
+
with gradio.Blocks() as demo:
|
22 |
+
gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
|
23 |
+
gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
|
24 |
+
gradio.Markdown("<center> See more information at: <a href='https://github.com/hukkelas/deep_privacy2'> https://github.com/hukkelas/deep_privacy2 </a> </center>")
|
25 |
+
gradio.Markdown("<center> For a demo of face anonymization, see: <a href='https://huggingface.co/spaces/haakohu/deep_privacy2_face'> https://huggingface.co/spaces/haakohu/deep_privacy2_face </a> </center>")
|
26 |
+
with gradio.Tab("Face Anonymization"):
|
27 |
+
ExampleDemo(anonymizer_face)
|
28 |
+
with gradio.Tab("Live Webcam"):
|
29 |
+
WebcamDemo(anonymizer_face)
|
30 |
+
|
31 |
+
demo.launch()
|
configs/anonymizers/FB_cse.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.anonymizer import Anonymizer
|
2 |
+
from dp2.detection.person_detector import CSEPersonDetector
|
3 |
+
from ..defaults import common
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
from dp2.generator.dummy_generators import MaskOutGenerator
|
6 |
+
|
7 |
+
|
8 |
+
maskout_G = L(MaskOutGenerator)(noise="constant")
|
9 |
+
|
10 |
+
detector = L(CSEPersonDetector)(
|
11 |
+
mask_rcnn_cfg=dict(),
|
12 |
+
cse_cfg=dict(),
|
13 |
+
cse_post_process_cfg=dict(
|
14 |
+
target_imsize=(288, 160),
|
15 |
+
exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
|
16 |
+
exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
|
17 |
+
iou_combine_threshold=0.4,
|
18 |
+
dilation_percentage=0.02,
|
19 |
+
normalize_embedding=False
|
20 |
+
),
|
21 |
+
score_threshold=0.3,
|
22 |
+
cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
|
23 |
+
)
|
24 |
+
|
25 |
+
anonymizer = L(Anonymizer)(
|
26 |
+
detector="${detector}",
|
27 |
+
cse_person_G_cfg="configs/fdh/styleganL.py",
|
28 |
+
)
|
configs/anonymizers/FB_cse_mask.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.anonymizer import Anonymizer
|
2 |
+
from dp2.detection.person_detector import CSEPersonDetector
|
3 |
+
from ..defaults import common
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
from dp2.generator.dummy_generators import MaskOutGenerator
|
6 |
+
|
7 |
+
|
8 |
+
maskout_G = L(MaskOutGenerator)(noise="constant")
|
9 |
+
|
10 |
+
detector = L(CSEPersonDetector)(
|
11 |
+
mask_rcnn_cfg=dict(),
|
12 |
+
cse_cfg=dict(),
|
13 |
+
cse_post_process_cfg=dict(
|
14 |
+
target_imsize=(288, 160),
|
15 |
+
exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
|
16 |
+
exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
|
17 |
+
iou_combine_threshold=0.4,
|
18 |
+
dilation_percentage=0.02,
|
19 |
+
normalize_embedding=False
|
20 |
+
),
|
21 |
+
score_threshold=0.3,
|
22 |
+
cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
|
23 |
+
)
|
24 |
+
|
25 |
+
anonymizer = L(Anonymizer)(
|
26 |
+
detector="${detector}",
|
27 |
+
person_G_cfg="configs/fdh/styleganL_nocse.py",
|
28 |
+
cse_person_G_cfg="configs/fdh/styleganL.py",
|
29 |
+
)
|
configs/anonymizers/FB_cse_mask_face.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.anonymizer import Anonymizer
|
2 |
+
from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
|
3 |
+
from ..defaults import common
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
|
6 |
+
detector = L(CSeMaskFaceDetector)(
|
7 |
+
mask_rcnn_cfg=dict(),
|
8 |
+
face_detector_cfg=dict(),
|
9 |
+
face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
|
10 |
+
cse_cfg=dict(),
|
11 |
+
cse_post_process_cfg=dict(
|
12 |
+
target_imsize=(288, 160),
|
13 |
+
exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
|
14 |
+
exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
|
15 |
+
iou_combine_threshold=0.4,
|
16 |
+
dilation_percentage=0.02,
|
17 |
+
normalize_embedding=False
|
18 |
+
),
|
19 |
+
score_threshold=0.3,
|
20 |
+
cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
|
21 |
+
)
|
22 |
+
|
23 |
+
anonymizer = L(Anonymizer)(
|
24 |
+
detector="${detector}",
|
25 |
+
face_G_cfg="configs/fdf/stylegan.py",
|
26 |
+
person_G_cfg="configs/fdh/styleganL_nocse.py",
|
27 |
+
cse_person_G_cfg="configs/fdh/styleganL.py",
|
28 |
+
car_G_cfg="configs/generators/dummy/pixelation8.py"
|
29 |
+
)
|
configs/anonymizers/deep_privacy1.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .face_fdf128 import anonymizer, common, detector
|
2 |
+
from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector
|
3 |
+
from tops.config import LazyCall as L
|
4 |
+
|
5 |
+
anonymizer.update(
|
6 |
+
face_G_cfg="configs/fdf/deep_privacy1.py",
|
7 |
+
)
|
8 |
+
|
9 |
+
anonymizer.detector = L(DeepPrivacy1Detector)(
|
10 |
+
face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
|
11 |
+
face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
|
12 |
+
score_threshold=0.3,
|
13 |
+
keypoint_threshold=0.3,
|
14 |
+
cache_directory=common.output_dir.joinpath("deep_privacy1_cache")
|
15 |
+
)
|
configs/anonymizers/face.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.anonymizer import Anonymizer
|
2 |
+
from dp2.detection.face_detector import FaceDetector
|
3 |
+
from ..defaults import common
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
|
6 |
+
|
7 |
+
detector = L(FaceDetector)(
|
8 |
+
face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
|
9 |
+
face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
|
10 |
+
score_threshold=0.3,
|
11 |
+
cache_directory=common.output_dir.joinpath("face_detection_cache"),
|
12 |
+
)
|
13 |
+
|
14 |
+
anonymizer = L(Anonymizer)(
|
15 |
+
detector="${detector}",
|
16 |
+
face_G_cfg="configs/fdf/stylegan.py",
|
17 |
+
)
|
configs/anonymizers/face_fdf128.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.anonymizer import Anonymizer
|
2 |
+
from dp2.detection.face_detector import FaceDetector
|
3 |
+
from ..defaults import common
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
|
6 |
+
|
7 |
+
detector = L(FaceDetector)(
|
8 |
+
face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
|
9 |
+
face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
|
10 |
+
score_threshold=0.3,
|
11 |
+
cache_directory=common.output_dir.joinpath("face_detection_cache")
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
anonymizer = L(Anonymizer)(
|
16 |
+
detector="${detector}",
|
17 |
+
face_G_cfg="configs/fdf/stylegan_fdf128.py",
|
18 |
+
)
|
configs/anonymizers/market1501/blackout.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..FB_cse_mask_face import anonymizer, detector, common
|
2 |
+
|
3 |
+
detector.score_threshold = .1
|
4 |
+
detector.face_detector_cfg.confidence_threshold = .5
|
5 |
+
detector.cse_cfg.score_thres = 0.3
|
6 |
+
anonymizer.generators.face_G_cfg = None
|
7 |
+
anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
|
8 |
+
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
|
configs/anonymizers/market1501/person.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..FB_cse_mask_face import anonymizer, detector, common
|
2 |
+
|
3 |
+
detector.score_threshold = .1
|
4 |
+
detector.face_detector_cfg.confidence_threshold = .5
|
5 |
+
detector.cse_cfg.score_thres = 0.3
|
6 |
+
anonymizer.generators.face_G_cfg = None
|
configs/anonymizers/market1501/pixelation16.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..FB_cse_mask_face import anonymizer, detector, common
|
2 |
+
|
3 |
+
detector.score_threshold = .1
|
4 |
+
detector.face_detector_cfg.confidence_threshold = .5
|
5 |
+
detector.cse_cfg.score_thres = 0.3
|
6 |
+
anonymizer.generators.face_G_cfg = None
|
7 |
+
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
|
8 |
+
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
|
configs/anonymizers/market1501/pixelation8.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..FB_cse_mask_face import anonymizer, detector, common
|
2 |
+
|
3 |
+
detector.score_threshold = .1
|
4 |
+
detector.face_detector_cfg.confidence_threshold = .5
|
5 |
+
detector.cse_cfg.score_thres = 0.3
|
6 |
+
anonymizer.generators.face_G_cfg = None
|
7 |
+
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
|
8 |
+
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
|
configs/datasets/coco_cse.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from tops.config import LazyCall as L
|
4 |
+
import torch
|
5 |
+
import functools
|
6 |
+
from dp2.data.datasets.coco_cse import CocoCSE
|
7 |
+
from dp2.data.build import get_dataloader
|
8 |
+
from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
|
9 |
+
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
|
10 |
+
from dp2.metrics.torch_metrics import compute_metrics_iteratively
|
11 |
+
from .utils import final_eval_fn
|
12 |
+
|
13 |
+
|
14 |
+
dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
|
15 |
+
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
|
16 |
+
data_dir = Path(dataset_base_dir, "coco_cse")
|
17 |
+
data = dict(
|
18 |
+
imsize=(288, 160),
|
19 |
+
im_channels=3,
|
20 |
+
semantic_nc=26,
|
21 |
+
cse_nc=16,
|
22 |
+
train=dict(
|
23 |
+
dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
|
24 |
+
loader=L(get_dataloader)(
|
25 |
+
shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
|
26 |
+
batch_size="${train.batch_size}",
|
27 |
+
dataset="${..dataset}",
|
28 |
+
infinite=True,
|
29 |
+
gpu_transform=L(torch.nn.Sequential)(*[
|
30 |
+
L(ToFloat)(),
|
31 |
+
L(StyleGANAugmentPipe)(
|
32 |
+
rotate=0.5, rotate_max=.05,
|
33 |
+
xint=.5, xint_max=0.05,
|
34 |
+
scale=.5, scale_std=.05,
|
35 |
+
aniso=0.5, aniso_std=.05,
|
36 |
+
xfrac=.5, xfrac_std=.05,
|
37 |
+
brightness=.5, brightness_std=.05,
|
38 |
+
contrast=.5, contrast_std=.1,
|
39 |
+
hue=.5, hue_max=.05,
|
40 |
+
saturation=.5, saturation_std=.5,
|
41 |
+
imgfilter=.5, imgfilter_std=.1),
|
42 |
+
L(RandomHorizontalFlip)(p=0.5),
|
43 |
+
L(CreateEmbedding)(),
|
44 |
+
L(Resize)(size="${data.imsize}"),
|
45 |
+
L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
|
46 |
+
L(CreateCondition)(),
|
47 |
+
])
|
48 |
+
)
|
49 |
+
),
|
50 |
+
val=dict(
|
51 |
+
dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
|
52 |
+
loader=L(get_dataloader)(
|
53 |
+
shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
|
54 |
+
batch_size="${train.batch_size}",
|
55 |
+
dataset="${..dataset}",
|
56 |
+
infinite=False,
|
57 |
+
gpu_transform=L(torch.nn.Sequential)(*[
|
58 |
+
L(ToFloat)(),
|
59 |
+
L(CreateEmbedding)(),
|
60 |
+
L(Resize)(size="${data.imsize}"),
|
61 |
+
L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
|
62 |
+
L(CreateCondition)(),
|
63 |
+
])
|
64 |
+
)
|
65 |
+
),
|
66 |
+
# Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
|
67 |
+
train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
|
68 |
+
evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
|
69 |
+
)
|
configs/datasets/fdf128.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from functools import partial
|
3 |
+
from dp2.data.datasets.fdf import FDFDataset
|
4 |
+
from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn
|
5 |
+
|
6 |
+
data_dir = Path(dataset_base_dir, "fdf")
|
7 |
+
data.train.dataset.dirpath = data_dir.joinpath("train")
|
8 |
+
data.val.dataset.dirpath = data_dir.joinpath("val")
|
9 |
+
data.imsize = (128, 128)
|
10 |
+
|
11 |
+
|
12 |
+
data.train_evaluation_fn = partial(
|
13 |
+
train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
|
14 |
+
data.evaluation_fn = partial(
|
15 |
+
final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
|
16 |
+
|
17 |
+
data.train.dataset.update(
|
18 |
+
_target_ = FDFDataset,
|
19 |
+
imsize="${data.imsize}"
|
20 |
+
)
|
21 |
+
data.val.dataset.update(
|
22 |
+
_target_ = FDFDataset,
|
23 |
+
imsize="${data.imsize}"
|
24 |
+
)
|
configs/datasets/fdf256.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from tops.config import LazyCall as L
|
4 |
+
import torch
|
5 |
+
import functools
|
6 |
+
from dp2.data.datasets.fdf import FDF256Dataset
|
7 |
+
from dp2.data.build import get_dataloader
|
8 |
+
from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
|
9 |
+
from .utils import final_eval_fn, train_eval_fn
|
10 |
+
|
11 |
+
|
12 |
+
dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
|
13 |
+
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
|
14 |
+
data_dir = Path(dataset_base_dir, "fdf256")
|
15 |
+
data = dict(
|
16 |
+
imsize=(256, 256),
|
17 |
+
im_channels=3,
|
18 |
+
semantic_nc=None,
|
19 |
+
cse_nc=None,
|
20 |
+
n_keypoints=None,
|
21 |
+
train=dict(
|
22 |
+
dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
|
23 |
+
loader=L(get_dataloader)(
|
24 |
+
shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
|
25 |
+
batch_size="${train.batch_size}",
|
26 |
+
dataset="${..dataset}",
|
27 |
+
infinite=True,
|
28 |
+
gpu_transform=L(torch.nn.Sequential)(*[
|
29 |
+
L(ToFloat)(),
|
30 |
+
L(RandomHorizontalFlip)(p=0.5),
|
31 |
+
L(Resize)(size="${data.imsize}"),
|
32 |
+
L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
|
33 |
+
L(CreateCondition)(),
|
34 |
+
])
|
35 |
+
)
|
36 |
+
),
|
37 |
+
val=dict(
|
38 |
+
dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
|
39 |
+
loader=L(get_dataloader)(
|
40 |
+
shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
|
41 |
+
batch_size="${train.batch_size}",
|
42 |
+
dataset="${..dataset}",
|
43 |
+
infinite=False,
|
44 |
+
gpu_transform=L(torch.nn.Sequential)(*[
|
45 |
+
L(ToFloat)(),
|
46 |
+
L(Resize)(size="${data.imsize}"),
|
47 |
+
L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
|
48 |
+
L(CreateCondition)(),
|
49 |
+
])
|
50 |
+
)
|
51 |
+
),
|
52 |
+
# Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
|
53 |
+
train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
|
54 |
+
evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
|
55 |
+
)
|
configs/datasets/fdh.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from tops.config import LazyCall as L
|
4 |
+
import torch
|
5 |
+
import functools
|
6 |
+
from dp2.data.datasets.fdh import get_dataloader_fdh_wds
|
7 |
+
from dp2.data.utils import get_coco_flipmap
|
8 |
+
from dp2.data.transforms.transforms import (
|
9 |
+
Normalize,
|
10 |
+
ToFloat,
|
11 |
+
CreateCondition,
|
12 |
+
RandomHorizontalFlip,
|
13 |
+
CreateEmbedding,
|
14 |
+
)
|
15 |
+
from dp2.metrics.torch_metrics import compute_metrics_iteratively
|
16 |
+
from dp2.metrics.fid_clip import compute_fid_clip
|
17 |
+
from dp2.metrics.ppl import calculate_ppl
|
18 |
+
from .utils import train_eval_fn
|
19 |
+
|
20 |
+
|
21 |
+
def final_eval_fn(*args, **kwargs):
|
22 |
+
result = compute_metrics_iteratively(*args, **kwargs)
|
23 |
+
result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
|
24 |
+
result3 = compute_fid_clip(*args, **kwargs)
|
25 |
+
assert all(key not in result for key in result2)
|
26 |
+
result.update(result2)
|
27 |
+
result.update(result3)
|
28 |
+
return result
|
29 |
+
|
30 |
+
|
31 |
+
def get_cache_directory(imsize, subset):
|
32 |
+
return Path(metrics_cache, f"{subset}{imsize[0]}")
|
33 |
+
|
34 |
+
dataset_base_dir = (
|
35 |
+
os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
|
36 |
+
)
|
37 |
+
metrics_cache = (
|
38 |
+
os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
|
39 |
+
)
|
40 |
+
data_dir = Path(dataset_base_dir, "fdh")
|
41 |
+
data = dict(
|
42 |
+
imsize=(288, 160),
|
43 |
+
im_channels=3,
|
44 |
+
cse_nc=16,
|
45 |
+
n_keypoints=17,
|
46 |
+
train=dict(
|
47 |
+
loader=L(get_dataloader_fdh_wds)(
|
48 |
+
path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
|
49 |
+
batch_size="${train.batch_size}",
|
50 |
+
num_workers=6,
|
51 |
+
transform=L(torch.nn.Sequential)(
|
52 |
+
L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
|
53 |
+
),
|
54 |
+
gpu_transform=L(torch.nn.Sequential)(
|
55 |
+
L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
|
56 |
+
L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
|
57 |
+
L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
|
58 |
+
L(CreateCondition)(),
|
59 |
+
),
|
60 |
+
infinite=True,
|
61 |
+
shuffle=True,
|
62 |
+
partial_batches=False,
|
63 |
+
load_embedding=True,
|
64 |
+
keypoints_split="train",
|
65 |
+
load_new_keypoints=False
|
66 |
+
)
|
67 |
+
),
|
68 |
+
val=dict(
|
69 |
+
loader=L(get_dataloader_fdh_wds)(
|
70 |
+
path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
|
71 |
+
batch_size="${train.batch_size}",
|
72 |
+
num_workers=6,
|
73 |
+
transform=None,
|
74 |
+
gpu_transform="${data.train.loader.gpu_transform}",
|
75 |
+
infinite=False,
|
76 |
+
shuffle=False,
|
77 |
+
partial_batches=True,
|
78 |
+
load_embedding=True,
|
79 |
+
keypoints_split="val",
|
80 |
+
load_new_keypoints="${data.train.loader.load_new_keypoints}"
|
81 |
+
)
|
82 |
+
),
|
83 |
+
# Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
|
84 |
+
train_evaluation_fn=L(functools.partial)(
|
85 |
+
train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
|
86 |
+
data_len=30_000),
|
87 |
+
evaluation_fn=L(functools.partial)(
|
88 |
+
final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
|
89 |
+
data_len=30_000)
|
90 |
+
)
|
configs/datasets/utils.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.metrics.ppl import calculate_ppl
|
2 |
+
from dp2.metrics.torch_metrics import compute_metrics_iteratively
|
3 |
+
from dp2.metrics.fid_clip import compute_fid_clip
|
4 |
+
|
5 |
+
|
6 |
+
def final_eval_fn(*args, **kwargs):
|
7 |
+
result = compute_metrics_iteratively(*args, **kwargs)
|
8 |
+
result2 = calculate_ppl(*args, **kwargs,)
|
9 |
+
result3 = compute_fid_clip(*args, **kwargs)
|
10 |
+
assert all(key not in result for key in result2)
|
11 |
+
result.update(result2)
|
12 |
+
result.update(result3)
|
13 |
+
return result
|
14 |
+
|
15 |
+
|
16 |
+
def train_eval_fn(*args, **kwargs):
|
17 |
+
result = compute_metrics_iteratively(*args, **kwargs)
|
18 |
+
result2 = compute_fid_clip(*args, **kwargs)
|
19 |
+
assert all(key not in result for key in result2)
|
20 |
+
result.update(result2)
|
21 |
+
return result
|
configs/defaults.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from tops.config import LazyCall as L
|
5 |
+
|
6 |
+
if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
|
7 |
+
PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
|
8 |
+
else:
|
9 |
+
PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
|
10 |
+
if "BASE_OUTPUT_DIR" in os.environ:
|
11 |
+
BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
|
12 |
+
else:
|
13 |
+
BASE_OUTPUT_DIR = pathlib.Path("outputs")
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
common = dict(
|
18 |
+
logger_backend=["wandb", "stdout", "json", "image_dumper"],
|
19 |
+
wandb_project="deep_privacy2",
|
20 |
+
output_dir=BASE_OUTPUT_DIR,
|
21 |
+
experiment_name=None, # Optional experiment name to show on wandb
|
22 |
+
)
|
23 |
+
|
24 |
+
train = dict(
|
25 |
+
batch_size=32,
|
26 |
+
seed=0,
|
27 |
+
ims_per_log=1024,
|
28 |
+
ims_per_val=int(200e3),
|
29 |
+
max_images_to_train=int(12e6),
|
30 |
+
amp=dict(
|
31 |
+
enabled=True,
|
32 |
+
scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
|
33 |
+
scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
|
34 |
+
),
|
35 |
+
fp16_ddp_accumulate=False, # All gather gradients in fp16?
|
36 |
+
broadcast_buffers=False,
|
37 |
+
bias_act_plugin_enabled=True,
|
38 |
+
grid_sample_gradfix_enabled=True,
|
39 |
+
conv2d_gradfix_enabled=False,
|
40 |
+
channels_last=False,
|
41 |
+
compile_G=dict(
|
42 |
+
enabled=False,
|
43 |
+
mode="default" # default, reduce-overhead or max-autotune
|
44 |
+
),
|
45 |
+
compile_D=dict(
|
46 |
+
enabled=False,
|
47 |
+
mode="default" # default, reduce-overhead or max-autotune
|
48 |
+
)
|
49 |
+
)
|
50 |
+
|
51 |
+
# exponential moving average
|
52 |
+
EMA = dict(rampup=0.05)
|
53 |
+
|
configs/discriminators/sg2_discriminator.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tops.config import LazyCall as L
|
2 |
+
from dp2.discriminator import SG2Discriminator
|
3 |
+
import torch
|
4 |
+
from dp2.loss import StyleGAN2Loss
|
5 |
+
|
6 |
+
|
7 |
+
discriminator = L(SG2Discriminator)(
|
8 |
+
imsize="${data.imsize}",
|
9 |
+
im_channels="${data.im_channels}",
|
10 |
+
min_fmap_resolution=4,
|
11 |
+
max_cnum_mul=8,
|
12 |
+
cnum=80,
|
13 |
+
input_condition=True,
|
14 |
+
conv_clamp=256,
|
15 |
+
input_cse=False,
|
16 |
+
cse_nc="${data.cse_nc}",
|
17 |
+
fix_residual=False,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
loss_fnc = L(StyleGAN2Loss)(
|
22 |
+
lazy_regularization=True,
|
23 |
+
lazy_reg_interval=16,
|
24 |
+
r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
|
25 |
+
EP_lambd=0.001,
|
26 |
+
pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
|
27 |
+
)
|
28 |
+
|
29 |
+
def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
|
30 |
+
if lazy_regularization:
|
31 |
+
# From Analyzing and improving the image quality of stylegan, CVPR 2020
|
32 |
+
c = lazy_reg_interval / (lazy_reg_interval + 1)
|
33 |
+
betas = [beta ** c for beta in betas]
|
34 |
+
lr *= c
|
35 |
+
print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
|
36 |
+
return type(lr=lr, betas=betas, **kwargs)
|
37 |
+
|
38 |
+
|
39 |
+
D_optim = L(build_D_optim)(
|
40 |
+
type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
|
41 |
+
lazy_regularization="${loss_fnc.lazy_regularization}",
|
42 |
+
lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
|
43 |
+
G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
|
configs/fdf/deep_privacy1.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tops.config import LazyCall as L
|
2 |
+
from dp2.generator.deep_privacy1 import MSGGenerator
|
3 |
+
from ..datasets.fdf128 import data
|
4 |
+
from ..defaults import common, train
|
5 |
+
|
6 |
+
generator = L(MSGGenerator)()
|
7 |
+
|
8 |
+
common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt"
|
9 |
+
common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6"
|
configs/fdf/stylegan.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..generators.stylegan_unet import generator
|
2 |
+
from ..datasets.fdf256 import data
|
3 |
+
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
|
4 |
+
from ..defaults import train, common, EMA
|
5 |
+
|
6 |
+
train.max_images_to_train = int(35e6)
|
7 |
+
G_optim.lr = 0.002
|
8 |
+
D_optim.lr = 0.002
|
9 |
+
generator.input_cse = False
|
10 |
+
loss_fnc.r1_opts.lambd = 1
|
11 |
+
train.ims_per_val = int(2e6)
|
12 |
+
|
13 |
+
common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
|
14 |
+
common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
|
configs/fdf/stylegan_fdf128.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
|
2 |
+
from ..datasets.fdf128 import data
|
3 |
+
from ..generators.stylegan_unet import generator
|
4 |
+
from ..defaults import train, common, EMA
|
5 |
+
from tops.config import LazyCall as L
|
6 |
+
|
7 |
+
G_optim.lr = 0.002
|
8 |
+
D_optim.lr = 0.002
|
9 |
+
generator.update(cnum=128, max_cnum_mul=4, input_cse=False)
|
10 |
+
loss_fnc.r1_opts.lambd = 0.1
|
11 |
+
|
12 |
+
train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6))
|
13 |
+
|
14 |
+
common.update(
|
15 |
+
model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c",
|
16 |
+
model_md5sum="bccd4403e7c9bca682566ff3319e8176"
|
17 |
+
)
|
configs/fdh/styleganL.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tops.config import LazyCall as L
|
2 |
+
from ..generators.stylegan_unet import generator
|
3 |
+
from ..datasets.fdh import data
|
4 |
+
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
|
5 |
+
from ..defaults import train, common, EMA
|
6 |
+
|
7 |
+
train.max_images_to_train = int(50e6)
|
8 |
+
train.batch_size = 64
|
9 |
+
G_optim.lr = 0.002
|
10 |
+
D_optim.lr = 0.002
|
11 |
+
data.train.loader.num_workers = 4
|
12 |
+
train.ims_per_val = int(1e6)
|
13 |
+
loss_fnc.r1_opts.lambd = .1
|
14 |
+
|
15 |
+
common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
|
16 |
+
common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
|
configs/fdh/styleganL_nocse.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tops.config import LazyCall as L
|
2 |
+
from ..generators.stylegan_unet import generator
|
3 |
+
from ..datasets.fdh import data
|
4 |
+
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
|
5 |
+
from ..defaults import train, common, EMA
|
6 |
+
|
7 |
+
train.max_images_to_train = int(50e6)
|
8 |
+
G_optim.lr = 0.002
|
9 |
+
D_optim.lr = 0.002
|
10 |
+
generator.input_cse = False
|
11 |
+
data.load_embeddings = False
|
12 |
+
common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
|
13 |
+
common.model_md5sum = "fda0d809741bc67487abada793975c37"
|
14 |
+
generator.fix_errors = False
|
configs/generators/stylegan_unet.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dp2.generator.stylegan_unet import StyleGANUnet
|
2 |
+
from tops.config import LazyCall as L
|
3 |
+
|
4 |
+
generator = L(StyleGANUnet)(
|
5 |
+
imsize="${data.imsize}",
|
6 |
+
im_channels="${data.im_channels}",
|
7 |
+
min_fmap_resolution=8,
|
8 |
+
cnum=64,
|
9 |
+
max_cnum_mul=8,
|
10 |
+
n_middle_blocks=0,
|
11 |
+
z_channels=512,
|
12 |
+
mask_output=True,
|
13 |
+
conv_clamp=256,
|
14 |
+
input_cse=True,
|
15 |
+
scale_grad=True,
|
16 |
+
cse_nc="${data.cse_nc}",
|
17 |
+
w_dim=512,
|
18 |
+
n_keypoints="${data.n_keypoints}",
|
19 |
+
input_keypoints=False,
|
20 |
+
input_keypoint_indices=[],
|
21 |
+
fix_errors=True
|
22 |
+
)
|
dp2/__init__.py
ADDED
File without changes
|
dp2/anonymizer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .anonymizer import Anonymizer
|
dp2/anonymizer/anonymizer.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Union, Optional
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import tops
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from motpy import Detection, MultiObjectTracker
|
8 |
+
from dp2.utils import load_config
|
9 |
+
from dp2.infer import build_trained_generator
|
10 |
+
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
|
11 |
+
|
12 |
+
|
13 |
+
def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
|
14 |
+
cfg = load_config(cfg_path)
|
15 |
+
G = build_trained_generator(cfg)
|
16 |
+
tops.logger.log(f"Loaded generator from: {cfg_path}")
|
17 |
+
return G
|
18 |
+
|
19 |
+
|
20 |
+
class Anonymizer:
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
detector,
|
25 |
+
load_cache: bool = False,
|
26 |
+
person_G_cfg: Optional[Union[str, Path]] = None,
|
27 |
+
cse_person_G_cfg: Optional[Union[str, Path]] = None,
|
28 |
+
face_G_cfg: Optional[Union[str, Path]] = None,
|
29 |
+
car_G_cfg: Optional[Union[str, Path]] = None,
|
30 |
+
) -> None:
|
31 |
+
self.detector = detector
|
32 |
+
self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
|
33 |
+
self.load_cache = load_cache
|
34 |
+
if cse_person_G_cfg is not None:
|
35 |
+
self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
|
36 |
+
if person_G_cfg is not None:
|
37 |
+
self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
|
38 |
+
if face_G_cfg is not None:
|
39 |
+
self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
|
40 |
+
if car_G_cfg is not None:
|
41 |
+
self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
|
42 |
+
|
43 |
+
def initialize_tracker(self, fps: float):
|
44 |
+
self.tracker = MultiObjectTracker(dt=1/fps)
|
45 |
+
self.track_to_z_idx = dict()
|
46 |
+
|
47 |
+
def reset_tracker(self):
|
48 |
+
self.track_to_z_idx = dict()
|
49 |
+
|
50 |
+
def forward_G(self,
|
51 |
+
G,
|
52 |
+
batch,
|
53 |
+
multi_modal_truncation: bool,
|
54 |
+
amp: bool,
|
55 |
+
z_idx: int,
|
56 |
+
truncation_value: float,
|
57 |
+
idx: int,
|
58 |
+
all_styles=None):
|
59 |
+
batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
|
60 |
+
batch["img"] = batch["img"].float()
|
61 |
+
batch["condition"] = batch["mask"].float() * batch["img"]
|
62 |
+
|
63 |
+
with torch.cuda.amp.autocast(amp):
|
64 |
+
z = None
|
65 |
+
if z_idx is not None:
|
66 |
+
state = np.random.RandomState(seed=z_idx[idx])
|
67 |
+
z = state.normal(size=(1, G.z_channels)).astype(np.float32)
|
68 |
+
z = tops.to_cuda(torch.from_numpy(z))
|
69 |
+
|
70 |
+
if all_styles is not None:
|
71 |
+
anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
|
72 |
+
elif multi_modal_truncation:
|
73 |
+
w_indices = None
|
74 |
+
if z_idx is not None:
|
75 |
+
w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
|
76 |
+
anonymized_im = G.multi_modal_truncate(
|
77 |
+
**batch, truncation_value=truncation_value,
|
78 |
+
w_indices=w_indices,
|
79 |
+
z=z
|
80 |
+
)["img"]
|
81 |
+
else:
|
82 |
+
anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
|
83 |
+
anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
|
84 |
+
return anonymized_im
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def anonymize_detections(self,
|
88 |
+
im, detection,
|
89 |
+
update_identity=None,
|
90 |
+
**synthesis_kwargs
|
91 |
+
):
|
92 |
+
G = self.generators[type(detection)]
|
93 |
+
if G is None:
|
94 |
+
return im
|
95 |
+
C, H, W = im.shape
|
96 |
+
if update_identity is None:
|
97 |
+
update_identity = [True for i in range(len(detection))]
|
98 |
+
for idx in range(len(detection)):
|
99 |
+
if not update_identity[idx]:
|
100 |
+
continue
|
101 |
+
batch = detection.get_crop(idx, im)
|
102 |
+
x0, y0, x1, y1 = batch.pop("boxes")[0]
|
103 |
+
batch = {k: tops.to_cuda(v) for k, v in batch.items()}
|
104 |
+
anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
|
105 |
+
|
106 |
+
gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
|
107 |
+
mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
|
108 |
+
# Remove padding
|
109 |
+
pad = [max(-x0, 0), max(-y0, 0)]
|
110 |
+
pad = [*pad, max(x1-W, 0), max(y1-H, 0)]
|
111 |
+
def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
|
112 |
+
|
113 |
+
gim = remove_pad(gim)
|
114 |
+
mask = remove_pad(mask) > 0.5
|
115 |
+
x0, y0 = max(x0, 0), max(y0, 0)
|
116 |
+
x1, y1 = min(x1, W), min(y1, H)
|
117 |
+
mask = mask.logical_not()[None].repeat(3, 1, 1)
|
118 |
+
|
119 |
+
im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
|
120 |
+
return im
|
121 |
+
|
122 |
+
def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
|
123 |
+
all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
|
124 |
+
im = im.cpu()
|
125 |
+
for det in all_detections:
|
126 |
+
im = det.visualize(im)
|
127 |
+
return im
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
|
131 |
+
assert im.dtype == torch.uint8
|
132 |
+
im = tops.to_cuda(im)
|
133 |
+
all_detections = detections
|
134 |
+
if detections is None:
|
135 |
+
if self.load_cache:
|
136 |
+
all_detections = self.detector.forward_and_cache(im, cache_id)
|
137 |
+
else:
|
138 |
+
all_detections = self.detector(im)
|
139 |
+
if hasattr(self, "tracker") and track:
|
140 |
+
[_.pre_process() for _ in all_detections]
|
141 |
+
boxes = np.concatenate([_.boxes for _ in all_detections])
|
142 |
+
boxes = [Detection(box) for box in boxes]
|
143 |
+
self.tracker.step(boxes)
|
144 |
+
track_ids = self.tracker.detections_matched_ids
|
145 |
+
z_idx = []
|
146 |
+
for track_id in track_ids:
|
147 |
+
if track_id not in self.track_to_z_idx:
|
148 |
+
self.track_to_z_idx[track_id] = np.random.randint(0, 2**32-1)
|
149 |
+
z_idx.append(self.track_to_z_idx[track_id])
|
150 |
+
z_idx = np.array(z_idx)
|
151 |
+
idx_offset = 0
|
152 |
+
|
153 |
+
for detection in all_detections:
|
154 |
+
zs = None
|
155 |
+
if hasattr(self, "tracker") and track:
|
156 |
+
zs = z_idx[idx_offset:idx_offset+len(detection)]
|
157 |
+
idx_offset += len(detection)
|
158 |
+
im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
|
159 |
+
|
160 |
+
return im.cpu()
|
161 |
+
|
162 |
+
def __call__(self, *args, **kwargs):
|
163 |
+
return self.forward(*args, **kwargs)
|
dp2/anonymizer/histogram_match_anonymizers.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import tops
|
4 |
+
import numpy as np
|
5 |
+
from kornia.color import rgb_to_hsv
|
6 |
+
from dp2 import utils
|
7 |
+
from kornia.enhance import histogram
|
8 |
+
from .anonymizer import Anonymizer
|
9 |
+
import torchvision.transforms.functional as F
|
10 |
+
from skimage.exposure import match_histograms
|
11 |
+
from kornia.filters import gaussian_blur2d
|
12 |
+
|
13 |
+
|
14 |
+
class LatentHistogramMatchAnonymizer(Anonymizer):
|
15 |
+
|
16 |
+
def forward_G(
|
17 |
+
self,
|
18 |
+
G,
|
19 |
+
batch,
|
20 |
+
multi_modal_truncation: bool,
|
21 |
+
amp: bool,
|
22 |
+
z_idx: int,
|
23 |
+
truncation_value: float,
|
24 |
+
idx: int,
|
25 |
+
n_sampling_steps: int = 1,
|
26 |
+
all_styles=None,
|
27 |
+
):
|
28 |
+
batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
|
29 |
+
batch["img"] = batch["img"].float()
|
30 |
+
batch["condition"] = batch["mask"].float() * batch["img"]
|
31 |
+
|
32 |
+
assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
|
33 |
+
real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"]))
|
34 |
+
real_hls[:, 0] /= 2 * torch.pi
|
35 |
+
indices = [1, 2]
|
36 |
+
hist_kwargs = dict(
|
37 |
+
bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
|
38 |
+
bandwidth=torch.tensor(1., device=tops.get_device()))
|
39 |
+
real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
|
40 |
+
for j in range(n_sampling_steps):
|
41 |
+
if j == 0:
|
42 |
+
if multi_modal_truncation:
|
43 |
+
w = G.style_net.multi_modal_truncate(
|
44 |
+
truncation_value=truncation_value, **batch, w_indices=None).detach()
|
45 |
+
else:
|
46 |
+
w = G.style_net.get_truncated(truncation_value, **batch).detach()
|
47 |
+
assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
|
48 |
+
w.requires_grad = True
|
49 |
+
optim = torch.optim.Adam([w])
|
50 |
+
with torch.set_grad_enabled(True):
|
51 |
+
with torch.cuda.amp.autocast(amp):
|
52 |
+
anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
|
53 |
+
fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5)
|
54 |
+
fake_hls[:, 0] /= 2 * torch.pi
|
55 |
+
fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
|
56 |
+
dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
|
57 |
+
dist.backward()
|
58 |
+
if w.grad.sum() == 0:
|
59 |
+
break
|
60 |
+
assert w.grad.sum() != 0
|
61 |
+
optim.step()
|
62 |
+
optim.zero_grad()
|
63 |
+
if dist < 0.02:
|
64 |
+
break
|
65 |
+
anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
|
66 |
+
return anonymized_im
|
67 |
+
|
68 |
+
|
69 |
+
class HistogramMatchAnonymizer(Anonymizer):
|
70 |
+
|
71 |
+
def forward_G(self, batch, *args, **kwargs):
|
72 |
+
rimg = batch["img"]
|
73 |
+
batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
|
74 |
+
batch["img"] = batch["img"].float()
|
75 |
+
batch["condition"] = batch["mask"].float() * batch["img"]
|
76 |
+
|
77 |
+
anonymized_im = super().forward_G(batch, *args, **kwargs)
|
78 |
+
|
79 |
+
equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg))
|
80 |
+
if equalized_gim.dtype != np.uint8:
|
81 |
+
equalized_gim = equalized_gim.astype(np.float32)
|
82 |
+
assert equalized_gim.dtype == np.float32, equalized_gim.dtype
|
83 |
+
equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
|
84 |
+
else:
|
85 |
+
equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
|
86 |
+
equalized_gim = equalized_gim.to(device=rimg.device)
|
87 |
+
assert equalized_gim.dtype == torch.float32
|
88 |
+
gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()
|
89 |
+
|
90 |
+
gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
|
91 |
+
gaussian_mask = gaussian_mask / gaussian_mask.max()
|
92 |
+
anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im
|
93 |
+
return anonymized_im
|
dp2/data/__init__.py
ADDED
File without changes
|
dp2/data/build.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import tops
|
3 |
+
from .utils import collate_fn
|
4 |
+
|
5 |
+
|
6 |
+
def get_dataloader(
|
7 |
+
dataset, gpu_transform: torch.nn.Module,
|
8 |
+
num_workers,
|
9 |
+
batch_size,
|
10 |
+
infinite: bool,
|
11 |
+
drop_last: bool,
|
12 |
+
prefetch_factor: int,
|
13 |
+
shuffle,
|
14 |
+
channels_last=False
|
15 |
+
):
|
16 |
+
sampler = None
|
17 |
+
dl_kwargs = dict(
|
18 |
+
pin_memory=True,
|
19 |
+
)
|
20 |
+
if infinite:
|
21 |
+
sampler = tops.InfiniteSampler(
|
22 |
+
dataset, rank=tops.rank(),
|
23 |
+
num_replicas=tops.world_size(),
|
24 |
+
shuffle=shuffle
|
25 |
+
)
|
26 |
+
elif tops.world_size() > 1:
|
27 |
+
sampler = torch.utils.data.DistributedSampler(
|
28 |
+
dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
|
29 |
+
dl_kwargs["drop_last"] = drop_last
|
30 |
+
else:
|
31 |
+
dl_kwargs["shuffle"] = shuffle
|
32 |
+
dl_kwargs["drop_last"] = drop_last
|
33 |
+
dataloader = torch.utils.data.DataLoader(
|
34 |
+
dataset, sampler=sampler, collate_fn=collate_fn,
|
35 |
+
batch_size=batch_size,
|
36 |
+
num_workers=num_workers, prefetch_factor=prefetch_factor,
|
37 |
+
**dl_kwargs
|
38 |
+
)
|
39 |
+
dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
|
40 |
+
return dataloader
|
dp2/data/datasets/__init__.py
ADDED
File without changes
|
dp2/data/datasets/coco_cse.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import torchvision
|
3 |
+
import torch
|
4 |
+
import pathlib
|
5 |
+
import numpy as np
|
6 |
+
from typing import Callable, Optional, Union
|
7 |
+
from torch.hub import get_dir as get_hub_dir
|
8 |
+
|
9 |
+
|
10 |
+
def cache_embed_stats(embed_map: torch.Tensor):
|
11 |
+
mean = embed_map.mean(dim=0, keepdim=True)
|
12 |
+
rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
|
13 |
+
|
14 |
+
cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
|
15 |
+
path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
|
16 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
17 |
+
torch.save(cache, path)
|
18 |
+
|
19 |
+
|
20 |
+
class CocoCSE(torch.utils.data.Dataset):
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
dirpath: Union[str, pathlib.Path],
|
24 |
+
transform: Optional[Callable],
|
25 |
+
normalize_E: bool,):
|
26 |
+
dirpath = pathlib.Path(dirpath)
|
27 |
+
self.dirpath = dirpath
|
28 |
+
|
29 |
+
self.transform = transform
|
30 |
+
assert self.dirpath.is_dir(),\
|
31 |
+
f"Did not find dataset at: {dirpath}"
|
32 |
+
self.image_paths, self.embedding_paths = self._load_impaths()
|
33 |
+
self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
|
34 |
+
mean = self.embed_map.mean(dim=0, keepdim=True)
|
35 |
+
rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
|
36 |
+
self.embed_map = (self.embed_map - mean) * rstd
|
37 |
+
cache_embed_stats(self.embed_map)
|
38 |
+
|
39 |
+
def _load_impaths(self):
|
40 |
+
image_dir = self.dirpath.joinpath("images")
|
41 |
+
image_paths = list(image_dir.glob("*.png"))
|
42 |
+
image_paths.sort()
|
43 |
+
embedding_paths = [
|
44 |
+
self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
|
45 |
+
]
|
46 |
+
return image_paths, embedding_paths
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.image_paths)
|
50 |
+
|
51 |
+
def __getitem__(self, idx):
|
52 |
+
im = torchvision.io.read_image(str(self.image_paths[idx]))
|
53 |
+
vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
|
54 |
+
vertices = torch.from_numpy(vertices.squeeze()).long()
|
55 |
+
mask = torch.from_numpy(mask.squeeze()).float()
|
56 |
+
border = torch.from_numpy(border.squeeze()).float()
|
57 |
+
E_mask = 1 - mask - border
|
58 |
+
batch = {
|
59 |
+
"img": im,
|
60 |
+
"vertices": vertices[None],
|
61 |
+
"mask": mask[None],
|
62 |
+
"embed_map": self.embed_map,
|
63 |
+
"border": border[None],
|
64 |
+
"E_mask": E_mask[None]
|
65 |
+
}
|
66 |
+
if self.transform is None:
|
67 |
+
return batch
|
68 |
+
return self.transform(batch)
|
dp2/data/datasets/fdf.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
from typing import Tuple
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import pathlib
|
6 |
+
try:
|
7 |
+
import pyspng
|
8 |
+
PYSPNG_IMPORTED = True
|
9 |
+
except ImportError:
|
10 |
+
PYSPNG_IMPORTED = False
|
11 |
+
print("Could not load pyspng. Defaulting to pillow image backend.")
|
12 |
+
from PIL import Image
|
13 |
+
from tops import logger
|
14 |
+
|
15 |
+
|
16 |
+
class FDFDataset:
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
dirpath,
|
20 |
+
imsize: Tuple[int],
|
21 |
+
load_keypoints: bool,
|
22 |
+
transform):
|
23 |
+
dirpath = pathlib.Path(dirpath)
|
24 |
+
self.dirpath = dirpath
|
25 |
+
self.transform = transform
|
26 |
+
self.imsize = imsize[0]
|
27 |
+
self.load_keypoints = load_keypoints
|
28 |
+
assert self.dirpath.is_dir(),\
|
29 |
+
f"Did not find dataset at: {dirpath}"
|
30 |
+
image_dir = self.dirpath.joinpath("images", str(self.imsize))
|
31 |
+
self.image_paths = list(image_dir.glob("*.png"))
|
32 |
+
assert len(self.image_paths) > 0,\
|
33 |
+
f"Did not find images in: {image_dir}"
|
34 |
+
self.image_paths.sort(key=lambda x: int(x.stem))
|
35 |
+
self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
|
36 |
+
|
37 |
+
self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
|
38 |
+
assert len(self.image_paths) == len(self.bounding_boxes)
|
39 |
+
assert len(self.image_paths) == len(self.landmarks)
|
40 |
+
logger.log(
|
41 |
+
f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
|
42 |
+
|
43 |
+
def get_mask(self, idx):
|
44 |
+
mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
|
45 |
+
bounding_box = self.bounding_boxes[idx]
|
46 |
+
x0, y0, x1, y1 = bounding_box
|
47 |
+
mask[:, y0:y1, x0:x1] = 0
|
48 |
+
return mask
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.image_paths)
|
52 |
+
|
53 |
+
def __getitem__(self, index):
|
54 |
+
impath = self.image_paths[index]
|
55 |
+
if PYSPNG_IMPORTED:
|
56 |
+
with open(impath, "rb") as fp:
|
57 |
+
im = pyspng.load(fp.read())
|
58 |
+
else:
|
59 |
+
with Image.open(impath) as fp:
|
60 |
+
im = np.array(fp)
|
61 |
+
im = torch.from_numpy(np.rollaxis(im, -1, 0))
|
62 |
+
masks = self.get_mask(index)
|
63 |
+
landmark = self.landmarks[index]
|
64 |
+
batch = {
|
65 |
+
"img": im,
|
66 |
+
"mask": masks,
|
67 |
+
}
|
68 |
+
if self.load_keypoints:
|
69 |
+
batch["keypoints"] = landmark
|
70 |
+
if self.transform is None:
|
71 |
+
return batch
|
72 |
+
return self.transform(batch)
|
73 |
+
|
74 |
+
|
75 |
+
class FDF256Dataset:
|
76 |
+
|
77 |
+
def __init__(self,
|
78 |
+
dirpath,
|
79 |
+
load_keypoints: bool,
|
80 |
+
transform):
|
81 |
+
dirpath = pathlib.Path(dirpath)
|
82 |
+
self.dirpath = dirpath
|
83 |
+
self.transform = transform
|
84 |
+
self.load_keypoints = load_keypoints
|
85 |
+
assert self.dirpath.is_dir(),\
|
86 |
+
f"Did not find dataset at: {dirpath}"
|
87 |
+
image_dir = self.dirpath.joinpath("images")
|
88 |
+
self.image_paths = list(image_dir.glob("*.png"))
|
89 |
+
assert len(self.image_paths) > 0,\
|
90 |
+
f"Did not find images in: {image_dir}"
|
91 |
+
self.image_paths.sort(key=lambda x: int(x.stem))
|
92 |
+
self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
|
93 |
+
self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
|
94 |
+
assert len(self.image_paths) == len(self.bounding_boxes)
|
95 |
+
assert len(self.image_paths) == len(self.landmarks)
|
96 |
+
logger.log(
|
97 |
+
f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
|
98 |
+
|
99 |
+
def get_mask(self, idx):
|
100 |
+
mask = torch.ones((1, 256, 256), dtype=torch.bool)
|
101 |
+
bounding_box = self.bounding_boxes[idx]
|
102 |
+
x0, y0, x1, y1 = bounding_box
|
103 |
+
mask[:, y0:y1, x0:x1] = 0
|
104 |
+
return mask
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.image_paths)
|
108 |
+
|
109 |
+
def __getitem__(self, index):
|
110 |
+
impath = self.image_paths[index]
|
111 |
+
if PYSPNG_IMPORTED:
|
112 |
+
with open(impath, "rb") as fp:
|
113 |
+
im = pyspng.load(fp.read())
|
114 |
+
else:
|
115 |
+
with Image.open(impath) as fp:
|
116 |
+
im = np.array(fp)
|
117 |
+
im = torch.from_numpy(np.rollaxis(im, -1, 0))
|
118 |
+
masks = self.get_mask(index)
|
119 |
+
landmark = self.landmarks[index]
|
120 |
+
batch = {
|
121 |
+
"img": im,
|
122 |
+
"mask": masks,
|
123 |
+
}
|
124 |
+
if self.load_keypoints:
|
125 |
+
batch["keypoints"] = landmark
|
126 |
+
if self.transform is None:
|
127 |
+
return batch
|
128 |
+
return self.transform(batch)
|
dp2/data/datasets/fdf128_wds.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import tops
|
3 |
+
import numpy as np
|
4 |
+
import io
|
5 |
+
import webdataset as wds
|
6 |
+
import os
|
7 |
+
from ..utils import png_decoder, get_num_workers, collate_fn
|
8 |
+
|
9 |
+
|
10 |
+
def kp_decoder(x):
|
11 |
+
# Keypoints are between [0, 1] for webdataset
|
12 |
+
keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1)
|
13 |
+
keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1)
|
14 |
+
return keypoints
|
15 |
+
|
16 |
+
|
17 |
+
def bbox_decoder(x):
|
18 |
+
return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4)
|
19 |
+
|
20 |
+
|
21 |
+
class BBoxToMask:
|
22 |
+
|
23 |
+
def __call__(self, sample):
|
24 |
+
imsize = sample["image.png"].shape[-1]
|
25 |
+
bbox = sample["bounding_box.npy"] * imsize
|
26 |
+
x0, y0, x1, y1 = np.round(bbox).astype(np.int64)
|
27 |
+
mask = torch.ones((1, imsize, imsize), dtype=torch.bool)
|
28 |
+
mask[:, y0:y1, x0:x1] = 0
|
29 |
+
sample["mask"] = mask
|
30 |
+
return sample
|
31 |
+
|
32 |
+
|
33 |
+
def get_dataloader_fdf_wds(
|
34 |
+
path,
|
35 |
+
batch_size: int,
|
36 |
+
num_workers: int,
|
37 |
+
transform: torch.nn.Module,
|
38 |
+
gpu_transform: torch.nn.Module,
|
39 |
+
infinite: bool,
|
40 |
+
shuffle: bool,
|
41 |
+
partial_batches: bool,
|
42 |
+
sample_shuffle=10_000,
|
43 |
+
tar_shuffle=100,
|
44 |
+
channels_last=False,
|
45 |
+
):
|
46 |
+
# Need to set this for split_by_node to work.
|
47 |
+
os.environ["RANK"] = str(tops.rank())
|
48 |
+
os.environ["WORLD_SIZE"] = str(tops.world_size())
|
49 |
+
if infinite:
|
50 |
+
pipeline = [wds.ResampledShards(str(path))]
|
51 |
+
else:
|
52 |
+
pipeline = [wds.SimpleShardList(str(path))]
|
53 |
+
if shuffle:
|
54 |
+
pipeline.append(wds.shuffle(tar_shuffle))
|
55 |
+
pipeline.extend([
|
56 |
+
wds.split_by_node,
|
57 |
+
wds.split_by_worker,
|
58 |
+
])
|
59 |
+
if shuffle:
|
60 |
+
pipeline.append(wds.shuffle(sample_shuffle))
|
61 |
+
|
62 |
+
decoder = [
|
63 |
+
wds.handle_extension("image.png", png_decoder),
|
64 |
+
wds.handle_extension("keypoints.npy", kp_decoder),
|
65 |
+
]
|
66 |
+
|
67 |
+
rename_keys = [
|
68 |
+
["img", "image.png"],
|
69 |
+
["keypoints", "keypoints.npy"],
|
70 |
+
["__key__", "__key__"],
|
71 |
+
["mask", "mask"]
|
72 |
+
]
|
73 |
+
|
74 |
+
pipeline.extend([
|
75 |
+
wds.tarfile_to_samples(),
|
76 |
+
wds.decode(*decoder),
|
77 |
+
])
|
78 |
+
pipeline.append(wds.map(BBoxToMask()))
|
79 |
+
pipeline.extend([
|
80 |
+
wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
|
81 |
+
wds.rename_keys(*rename_keys),
|
82 |
+
])
|
83 |
+
|
84 |
+
if transform is not None:
|
85 |
+
pipeline.append(wds.map(transform))
|
86 |
+
pipeline = wds.DataPipeline(*pipeline)
|
87 |
+
if infinite:
|
88 |
+
pipeline = pipeline.repeat(nepochs=1000000)
|
89 |
+
|
90 |
+
loader = wds.WebLoader(
|
91 |
+
pipeline, batch_size=None, shuffle=False,
|
92 |
+
num_workers=get_num_workers(num_workers),
|
93 |
+
persistent_workers=True,
|
94 |
+
)
|
95 |
+
loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
|
96 |
+
return loader
|
dp2/data/datasets/fdh.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import tops
|
3 |
+
import numpy as np
|
4 |
+
import io
|
5 |
+
import webdataset as wds
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
from pathlib import Path
|
9 |
+
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
|
10 |
+
|
11 |
+
|
12 |
+
def kp_decoder(x):
|
13 |
+
# Keypoints are between [0, 1] for webdataset
|
14 |
+
keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
|
15 |
+
def check_outside(x): return (x < 0).logical_or(x > 1)
|
16 |
+
is_outside = check_outside(keypoints[:, 0]).logical_or(
|
17 |
+
check_outside(keypoints[:, 1])
|
18 |
+
)
|
19 |
+
keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
|
20 |
+
return keypoints
|
21 |
+
|
22 |
+
|
23 |
+
def vertices_decoder(x):
|
24 |
+
vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
|
25 |
+
return vertices.squeeze()[None]
|
26 |
+
|
27 |
+
|
28 |
+
class InsertNewKeypoints:
|
29 |
+
|
30 |
+
def __init__(self, keypoints_path: Path) -> None:
|
31 |
+
with open(keypoints_path, "r") as fp:
|
32 |
+
self.keypoints = json.load(fp)
|
33 |
+
|
34 |
+
def __call__(self, sample):
|
35 |
+
key = sample["__key__"]
|
36 |
+
keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
|
37 |
+
def check_outside(x): return (x < 0).logical_or(x > 1)
|
38 |
+
is_outside = check_outside(keypoints[:, 0]).logical_or(
|
39 |
+
check_outside(keypoints[:, 1])
|
40 |
+
)
|
41 |
+
keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
|
42 |
+
|
43 |
+
sample["keypoints.npy"] = keypoints
|
44 |
+
return sample
|
45 |
+
|
46 |
+
|
47 |
+
def get_dataloader_fdh_wds(
|
48 |
+
path,
|
49 |
+
batch_size: int,
|
50 |
+
num_workers: int,
|
51 |
+
transform: torch.nn.Module,
|
52 |
+
gpu_transform: torch.nn.Module,
|
53 |
+
infinite: bool,
|
54 |
+
shuffle: bool,
|
55 |
+
partial_batches: bool,
|
56 |
+
load_embedding: bool,
|
57 |
+
sample_shuffle=10_000,
|
58 |
+
tar_shuffle=100,
|
59 |
+
read_condition=False,
|
60 |
+
channels_last=False,
|
61 |
+
load_new_keypoints=False,
|
62 |
+
keypoints_split=None,
|
63 |
+
):
|
64 |
+
# Need to set this for split_by_node to work.
|
65 |
+
os.environ["RANK"] = str(tops.rank())
|
66 |
+
os.environ["WORLD_SIZE"] = str(tops.world_size())
|
67 |
+
if infinite:
|
68 |
+
pipeline = [wds.ResampledShards(str(path))]
|
69 |
+
else:
|
70 |
+
pipeline = [wds.SimpleShardList(str(path))]
|
71 |
+
if shuffle:
|
72 |
+
pipeline.append(wds.shuffle(tar_shuffle))
|
73 |
+
pipeline.extend([
|
74 |
+
wds.split_by_node,
|
75 |
+
wds.split_by_worker,
|
76 |
+
])
|
77 |
+
if shuffle:
|
78 |
+
pipeline.append(wds.shuffle(sample_shuffle))
|
79 |
+
|
80 |
+
decoder = [
|
81 |
+
wds.handle_extension("image.png", png_decoder),
|
82 |
+
wds.handle_extension("mask.png", mask_decoder),
|
83 |
+
wds.handle_extension("maskrcnn_mask.png", mask_decoder),
|
84 |
+
wds.handle_extension("keypoints.npy", kp_decoder),
|
85 |
+
]
|
86 |
+
|
87 |
+
rename_keys = [
|
88 |
+
["img", "image.png"], ["mask", "mask.png"],
|
89 |
+
["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
|
90 |
+
["__key__", "__key__"]
|
91 |
+
]
|
92 |
+
if load_embedding:
|
93 |
+
decoder.extend([
|
94 |
+
wds.handle_extension("vertices.npy", vertices_decoder),
|
95 |
+
wds.handle_extension("E_mask.png", mask_decoder)
|
96 |
+
])
|
97 |
+
rename_keys.extend([
|
98 |
+
["vertices", "vertices.npy"],
|
99 |
+
["E_mask", "e_mask.png"]
|
100 |
+
])
|
101 |
+
|
102 |
+
if read_condition:
|
103 |
+
decoder.append(
|
104 |
+
wds.handle_extension("condition.png", png_decoder)
|
105 |
+
)
|
106 |
+
rename_keys.append(["condition", "condition.png"])
|
107 |
+
|
108 |
+
pipeline.extend([
|
109 |
+
wds.tarfile_to_samples(),
|
110 |
+
wds.decode(*decoder),
|
111 |
+
|
112 |
+
])
|
113 |
+
if load_new_keypoints:
|
114 |
+
assert keypoints_split in ["train", "val"]
|
115 |
+
keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
|
116 |
+
file_name = "fdh_keypoints_val-050133b34d.json"
|
117 |
+
if keypoints_split == "train":
|
118 |
+
keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
|
119 |
+
file_name = "fdh_keypoints_train-2cff11f69a.json"
|
120 |
+
# Set check_hash=True if you suspect download is incorrect.
|
121 |
+
filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
|
122 |
+
pipeline.append(
|
123 |
+
wds.map(InsertNewKeypoints(filepath))
|
124 |
+
)
|
125 |
+
pipeline.extend([
|
126 |
+
wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
|
127 |
+
wds.rename_keys(*rename_keys),
|
128 |
+
])
|
129 |
+
|
130 |
+
if transform is not None:
|
131 |
+
pipeline.append(wds.map(transform))
|
132 |
+
pipeline = wds.DataPipeline(*pipeline)
|
133 |
+
if infinite:
|
134 |
+
pipeline = pipeline.repeat(nepochs=1000000)
|
135 |
+
|
136 |
+
loader = wds.WebLoader(
|
137 |
+
pipeline, batch_size=None, shuffle=False,
|
138 |
+
num_workers=get_num_workers(num_workers),
|
139 |
+
persistent_workers=True,
|
140 |
+
)
|
141 |
+
loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
|
142 |
+
return loader
|
dp2/data/transforms/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
|
2 |
+
from .stylegan2_transform import StyleGANAugmentPipe
|
dp2/data/transforms/functional.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchvision.transforms.functional as F
|
2 |
+
import torch
|
3 |
+
import pickle
|
4 |
+
from tops import download_file, assert_shape
|
5 |
+
from typing import Dict
|
6 |
+
from functools import lru_cache
|
7 |
+
|
8 |
+
global symmetry_transform
|
9 |
+
|
10 |
+
|
11 |
+
@lru_cache(maxsize=1)
|
12 |
+
def get_symmetry_transform(symmetry_url):
|
13 |
+
file_name = download_file(symmetry_url)
|
14 |
+
with open(file_name, "rb") as fp:
|
15 |
+
symmetry = pickle.load(fp)
|
16 |
+
return torch.from_numpy(symmetry["vertex_transforms"]).long()
|
17 |
+
|
18 |
+
|
19 |
+
hflip_handled_cases = set([
|
20 |
+
"keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
|
21 |
+
"embedding", "vertx2cat", "maskrcnn_mask", "__key__"])
|
22 |
+
|
23 |
+
|
24 |
+
def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
|
25 |
+
container["img"] = F.hflip(container["img"])
|
26 |
+
if "condition" in container:
|
27 |
+
container["condition"] = F.hflip(container["condition"])
|
28 |
+
if "embedding" in container:
|
29 |
+
container["embedding"] = F.hflip(container["embedding"])
|
30 |
+
assert all([key in hflip_handled_cases for key in container]), container.keys()
|
31 |
+
if "keypoints" in container:
|
32 |
+
assert flip_map is not None
|
33 |
+
if container["keypoints"].ndim == 3:
|
34 |
+
keypoints = container["keypoints"][:, flip_map, :]
|
35 |
+
keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
|
36 |
+
else:
|
37 |
+
assert_shape(container["keypoints"], (None, 3))
|
38 |
+
keypoints = container["keypoints"][flip_map, :]
|
39 |
+
keypoints[:, 0] = 1 - keypoints[:, 0]
|
40 |
+
container["keypoints"] = keypoints
|
41 |
+
if "mask" in container:
|
42 |
+
container["mask"] = F.hflip(container["mask"])
|
43 |
+
if "border" in container:
|
44 |
+
container["border"] = F.hflip(container["border"])
|
45 |
+
if "semantic_mask" in container:
|
46 |
+
container["semantic_mask"] = F.hflip(container["semantic_mask"])
|
47 |
+
if "vertices" in container:
|
48 |
+
symmetry_transform = get_symmetry_transform(
|
49 |
+
"https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
|
50 |
+
container["vertices"] = F.hflip(container["vertices"])
|
51 |
+
symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
|
52 |
+
container["vertices"] = symmetry_transform_[container["vertices"].long()]
|
53 |
+
if "E_mask" in container:
|
54 |
+
container["E_mask"] = F.hflip(container["E_mask"])
|
55 |
+
if "maskrcnn_mask" in container:
|
56 |
+
container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
|
57 |
+
return container
|
dp2/data/transforms/stylegan2_transform.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy.signal
|
3 |
+
import torch
|
4 |
+
try:
|
5 |
+
from sg3_torch_utils import misc
|
6 |
+
from sg3_torch_utils.ops import upfirdn2d
|
7 |
+
from sg3_torch_utils.ops import grid_sample_gradfix
|
8 |
+
from sg3_torch_utils.ops import conv2d_gradfix
|
9 |
+
except:
|
10 |
+
pass
|
11 |
+
#----------------------------------------------------------------------------
|
12 |
+
# Coefficients of various wavelet decomposition low-pass filters.
|
13 |
+
|
14 |
+
wavelets = {
|
15 |
+
'haar': [0.7071067811865476, 0.7071067811865476],
|
16 |
+
'db1': [0.7071067811865476, 0.7071067811865476],
|
17 |
+
'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
|
18 |
+
'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
|
19 |
+
'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
|
20 |
+
'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
|
21 |
+
'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
|
22 |
+
'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
|
23 |
+
'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
|
24 |
+
'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
|
25 |
+
'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
|
26 |
+
'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
|
27 |
+
'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
|
28 |
+
'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
|
29 |
+
'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
|
30 |
+
'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
|
31 |
+
}
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
# Helpers for constructing transformation matrices.
|
35 |
+
|
36 |
+
|
37 |
+
def matrix(*rows, device=None):
|
38 |
+
assert all(len(row) == len(rows[0]) for row in rows)
|
39 |
+
elems = [x for row in rows for x in row]
|
40 |
+
ref = [x for x in elems if isinstance(x, torch.Tensor)]
|
41 |
+
if len(ref) == 0:
|
42 |
+
return misc.constant(np.asarray(rows), device=device)
|
43 |
+
assert device is None or device == ref[0].device
|
44 |
+
elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
|
45 |
+
return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
|
46 |
+
|
47 |
+
|
48 |
+
def translate2d(tx, ty, **kwargs):
|
49 |
+
return matrix(
|
50 |
+
[1, 0, tx],
|
51 |
+
[0, 1, ty],
|
52 |
+
[0, 0, 1],
|
53 |
+
**kwargs)
|
54 |
+
|
55 |
+
|
56 |
+
def translate3d(tx, ty, tz, **kwargs):
|
57 |
+
return matrix(
|
58 |
+
[1, 0, 0, tx],
|
59 |
+
[0, 1, 0, ty],
|
60 |
+
[0, 0, 1, tz],
|
61 |
+
[0, 0, 0, 1],
|
62 |
+
**kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
def scale2d(sx, sy, **kwargs):
|
66 |
+
return matrix(
|
67 |
+
[sx, 0, 0],
|
68 |
+
[0, sy, 0],
|
69 |
+
[0, 0, 1],
|
70 |
+
**kwargs)
|
71 |
+
|
72 |
+
|
73 |
+
def scale3d(sx, sy, sz, **kwargs):
|
74 |
+
return matrix(
|
75 |
+
[sx, 0, 0, 0],
|
76 |
+
[0, sy, 0, 0],
|
77 |
+
[0, 0, sz, 0],
|
78 |
+
[0, 0, 0, 1],
|
79 |
+
**kwargs)
|
80 |
+
|
81 |
+
|
82 |
+
def rotate2d(theta, **kwargs):
|
83 |
+
return matrix(
|
84 |
+
[torch.cos(theta), torch.sin(-theta), 0],
|
85 |
+
[torch.sin(theta), torch.cos(theta), 0],
|
86 |
+
[0, 0, 1],
|
87 |
+
**kwargs)
|
88 |
+
|
89 |
+
|
90 |
+
def rotate3d(v, theta, **kwargs):
|
91 |
+
vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
|
92 |
+
s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
|
93 |
+
return matrix(
|
94 |
+
[vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
|
95 |
+
[vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
|
96 |
+
[vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
|
97 |
+
[0, 0, 0, 1],
|
98 |
+
**kwargs)
|
99 |
+
|
100 |
+
|
101 |
+
def translate2d_inv(tx, ty, **kwargs):
|
102 |
+
return translate2d(-tx, -ty, **kwargs)
|
103 |
+
|
104 |
+
|
105 |
+
def scale2d_inv(sx, sy, **kwargs):
|
106 |
+
return scale2d(1 / sx, 1 / sy, **kwargs)
|
107 |
+
|
108 |
+
|
109 |
+
def rotate2d_inv(theta, **kwargs):
|
110 |
+
return rotate2d(-theta, **kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
class StyleGANAugmentPipe(torch.nn.Module):
|
114 |
+
def __init__(self,
|
115 |
+
rotate90=0, xint=0, xint_max=0.125,
|
116 |
+
scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
|
117 |
+
brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
|
118 |
+
hue_max=1, saturation_std=1,
|
119 |
+
imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
|
123 |
+
|
124 |
+
# Pixel blitting.
|
125 |
+
self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
|
126 |
+
self.xint = float(xint) # Probability multiplier for integer translation.
|
127 |
+
self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
|
128 |
+
|
129 |
+
# General geometric transformations.
|
130 |
+
self.scale = float(scale) # Probability multiplier for isotropic scaling.
|
131 |
+
self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
|
132 |
+
self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
|
133 |
+
self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
|
134 |
+
self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
|
135 |
+
self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
|
136 |
+
self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
|
137 |
+
self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
|
138 |
+
|
139 |
+
# Color transformations.
|
140 |
+
self.brightness = float(brightness) # Probability multiplier for brightness.
|
141 |
+
self.contrast = float(contrast) # Probability multiplier for contrast.
|
142 |
+
self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
|
143 |
+
self.hue = float(hue) # Probability multiplier for hue rotation.
|
144 |
+
self.saturation = float(saturation) # Probability multiplier for saturation.
|
145 |
+
self.brightness_std = float(brightness_std) # Standard deviation of brightness.
|
146 |
+
self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
|
147 |
+
self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
|
148 |
+
self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
|
149 |
+
|
150 |
+
# Image-space filtering.
|
151 |
+
self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
|
152 |
+
self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
|
153 |
+
self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
|
154 |
+
|
155 |
+
# Setup orthogonal lowpass filter for geometric augmentations.
|
156 |
+
self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
|
157 |
+
|
158 |
+
# Construct filter bank for image-space filtering.
|
159 |
+
Hz_lo = np.asarray(wavelets['sym2']) # H(z)
|
160 |
+
Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
|
161 |
+
Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
|
162 |
+
Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
|
163 |
+
Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
|
164 |
+
for i in range(1, Hz_fbank.shape[0]):
|
165 |
+
Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
|
166 |
+
Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
|
167 |
+
Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
|
168 |
+
self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
|
169 |
+
|
170 |
+
def forward(self, batch, debug_percentile=None):
|
171 |
+
images = batch["img"]
|
172 |
+
batch["vertices"] = batch["vertices"].float()
|
173 |
+
assert isinstance(images, torch.Tensor) and images.ndim == 4
|
174 |
+
batch_size, num_channels, height, width = images.shape
|
175 |
+
device = images.device
|
176 |
+
self.Hz_fbank = self.Hz_fbank.to(device)
|
177 |
+
self.Hz_geom = self.Hz_geom.to(device)
|
178 |
+
if debug_percentile is not None:
|
179 |
+
debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
|
180 |
+
|
181 |
+
# -------------------------------------
|
182 |
+
# Select parameters for pixel blitting.
|
183 |
+
# -------------------------------------
|
184 |
+
|
185 |
+
# Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
|
186 |
+
I_3 = torch.eye(3, device=device)
|
187 |
+
G_inv = I_3
|
188 |
+
|
189 |
+
# Apply integer translation with probability (xint * strength).
|
190 |
+
if self.xint > 0:
|
191 |
+
t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
|
192 |
+
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
|
193 |
+
if debug_percentile is not None:
|
194 |
+
t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
|
195 |
+
G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
|
196 |
+
|
197 |
+
# --------------------------------------------------------
|
198 |
+
# Select parameters for general geometric transformations.
|
199 |
+
# --------------------------------------------------------
|
200 |
+
|
201 |
+
# Apply isotropic scaling with probability (scale * strength).
|
202 |
+
if self.scale > 0:
|
203 |
+
s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
|
204 |
+
s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
|
205 |
+
if debug_percentile is not None:
|
206 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
|
207 |
+
G_inv = G_inv @ scale2d_inv(s, s)
|
208 |
+
|
209 |
+
# Apply pre-rotation with probability p_rot.
|
210 |
+
p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
|
211 |
+
if self.rotate > 0:
|
212 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
|
213 |
+
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
|
214 |
+
if debug_percentile is not None:
|
215 |
+
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
|
216 |
+
G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
|
217 |
+
|
218 |
+
# Apply anisotropic scaling with probability (aniso * strength).
|
219 |
+
if self.aniso > 0:
|
220 |
+
s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
|
221 |
+
s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
|
222 |
+
if debug_percentile is not None:
|
223 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
|
224 |
+
G_inv = G_inv @ scale2d_inv(s, 1 / s)
|
225 |
+
|
226 |
+
# Apply post-rotation with probability p_rot.
|
227 |
+
if self.rotate > 0:
|
228 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
|
229 |
+
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
|
230 |
+
if debug_percentile is not None:
|
231 |
+
theta = torch.zeros_like(theta)
|
232 |
+
G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
|
233 |
+
|
234 |
+
# Apply fractional translation with probability (xfrac * strength).
|
235 |
+
if self.xfrac > 0:
|
236 |
+
t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
|
237 |
+
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
|
238 |
+
if debug_percentile is not None:
|
239 |
+
t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
|
240 |
+
G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
|
241 |
+
|
242 |
+
# ----------------------------------
|
243 |
+
# Execute geometric transformations.
|
244 |
+
# ----------------------------------
|
245 |
+
|
246 |
+
# Execute if the transform is not identity.
|
247 |
+
if G_inv is not I_3:
|
248 |
+
# Calculate padding.
|
249 |
+
cx = (width - 1) / 2
|
250 |
+
cy = (height - 1) / 2
|
251 |
+
cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
|
252 |
+
cp = G_inv @ cp.t() # [batch, xyz, idx]
|
253 |
+
Hz_pad = self.Hz_geom.shape[0] // 4
|
254 |
+
margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
|
255 |
+
margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
|
256 |
+
margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
|
257 |
+
margin = margin.max(misc.constant([0, 0] * 2, device=device))
|
258 |
+
margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
|
259 |
+
mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
|
260 |
+
|
261 |
+
# Pad image and adjust origin.
|
262 |
+
images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
|
263 |
+
batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
|
264 |
+
batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
|
265 |
+
batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
|
266 |
+
G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
|
267 |
+
|
268 |
+
# Upsample.
|
269 |
+
images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
|
270 |
+
batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
|
271 |
+
batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
|
272 |
+
batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
|
273 |
+
G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
|
274 |
+
G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
|
275 |
+
|
276 |
+
# Execute transformation.
|
277 |
+
shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
|
278 |
+
G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
|
279 |
+
grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
|
280 |
+
images = grid_sample_gradfix.grid_sample(images, grid)
|
281 |
+
|
282 |
+
batch["mask"] = torch.nn.functional.grid_sample(
|
283 |
+
input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
284 |
+
batch["E_mask"] = torch.nn.functional.grid_sample(
|
285 |
+
input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
286 |
+
batch["vertices"] = torch.nn.functional.grid_sample(
|
287 |
+
input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
288 |
+
|
289 |
+
|
290 |
+
# Downsample and crop.
|
291 |
+
images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
|
292 |
+
batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
293 |
+
batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
294 |
+
batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
295 |
+
# --------------------------------------------
|
296 |
+
# Select parameters for color transformations.
|
297 |
+
# --------------------------------------------
|
298 |
+
|
299 |
+
# Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
|
300 |
+
I_4 = torch.eye(4, device=device)
|
301 |
+
C = I_4
|
302 |
+
|
303 |
+
# Apply brightness with probability (brightness * strength).
|
304 |
+
if self.brightness > 0:
|
305 |
+
b = torch.randn([batch_size], device=device) * self.brightness_std
|
306 |
+
b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
|
307 |
+
if debug_percentile is not None:
|
308 |
+
b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
|
309 |
+
C = translate3d(b, b, b) @ C
|
310 |
+
|
311 |
+
# Apply contrast with probability (contrast * strength).
|
312 |
+
if self.contrast > 0:
|
313 |
+
c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
|
314 |
+
c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
|
315 |
+
if debug_percentile is not None:
|
316 |
+
c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
|
317 |
+
C = scale3d(c, c, c) @ C
|
318 |
+
|
319 |
+
# Apply luma flip with probability (lumaflip * strength).
|
320 |
+
v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
|
321 |
+
|
322 |
+
# Apply hue rotation with probability (hue * strength).
|
323 |
+
if self.hue > 0 and num_channels > 1:
|
324 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
|
325 |
+
theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
|
326 |
+
if debug_percentile is not None:
|
327 |
+
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
|
328 |
+
C = rotate3d(v, theta) @ C # Rotate around v.
|
329 |
+
|
330 |
+
# Apply saturation with probability (saturation * strength).
|
331 |
+
if self.saturation > 0 and num_channels > 1:
|
332 |
+
s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
|
333 |
+
s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
|
334 |
+
if debug_percentile is not None:
|
335 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
|
336 |
+
C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
|
337 |
+
|
338 |
+
# ------------------------------
|
339 |
+
# Execute color transformations.
|
340 |
+
# ------------------------------
|
341 |
+
|
342 |
+
# Execute if the transform is not identity.
|
343 |
+
if C is not I_4:
|
344 |
+
images = images.reshape([batch_size, num_channels, height * width])
|
345 |
+
if num_channels == 3:
|
346 |
+
images = C[:, :3, :3] @ images + C[:, :3, 3:]
|
347 |
+
elif num_channels == 1:
|
348 |
+
C = C[:, :3, :].mean(dim=1, keepdims=True)
|
349 |
+
images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
|
350 |
+
else:
|
351 |
+
raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
|
352 |
+
images = images.reshape([batch_size, num_channels, height, width])
|
353 |
+
|
354 |
+
# ----------------------
|
355 |
+
# Image-space filtering.
|
356 |
+
# ----------------------
|
357 |
+
|
358 |
+
if self.imgfilter > 0:
|
359 |
+
num_bands = self.Hz_fbank.shape[0]
|
360 |
+
assert len(self.imgfilter_bands) == num_bands
|
361 |
+
expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
|
362 |
+
|
363 |
+
# Apply amplification for each band with probability (imgfilter * strength * band_strength).
|
364 |
+
g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
|
365 |
+
for i, band_strength in enumerate(self.imgfilter_bands):
|
366 |
+
t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
|
367 |
+
t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
|
368 |
+
if debug_percentile is not None:
|
369 |
+
t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
|
370 |
+
t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
|
371 |
+
t[:, i] = t_i # Replace i'th element.
|
372 |
+
t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
|
373 |
+
g = g * t # Accumulate into global gain.
|
374 |
+
|
375 |
+
# Construct combined amplification filter.
|
376 |
+
Hz_prime = g @ self.Hz_fbank # [batch, tap]
|
377 |
+
Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
|
378 |
+
Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
|
379 |
+
|
380 |
+
# Apply filter.
|
381 |
+
p = self.Hz_fbank.shape[1] // 2
|
382 |
+
images = images.reshape([1, batch_size * num_channels, height, width])
|
383 |
+
images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
|
384 |
+
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
|
385 |
+
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
|
386 |
+
images = images.reshape([batch_size, num_channels, height, width])
|
387 |
+
|
388 |
+
# ------------------------
|
389 |
+
# Image-space corruptions.
|
390 |
+
# ------------------------
|
391 |
+
batch["img"] = images
|
392 |
+
batch["vertices"] = batch["vertices"].long()
|
393 |
+
batch["border"] = 1 - batch["E_mask"] - batch["mask"]
|
394 |
+
return batch
|
dp2/data/transforms/transforms.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict, List
|
3 |
+
import torchvision
|
4 |
+
import torch
|
5 |
+
import tops
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from .functional import hflip
|
8 |
+
import numpy as np
|
9 |
+
from dp2.utils.vis_utils import get_coco_keypoints
|
10 |
+
from PIL import Image, ImageDraw
|
11 |
+
from typing import Tuple
|
12 |
+
|
13 |
+
|
14 |
+
class RandomHorizontalFlip(torch.nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, p: float, flip_map=None, **kwargs):
|
17 |
+
super().__init__()
|
18 |
+
self.flip_ratio = p
|
19 |
+
self.flip_map = flip_map
|
20 |
+
if self.flip_ratio is None:
|
21 |
+
self.flip_ratio = 0.5
|
22 |
+
assert 0 <= self.flip_ratio <= 1
|
23 |
+
|
24 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
25 |
+
if torch.rand(1) > self.flip_ratio:
|
26 |
+
return container
|
27 |
+
return hflip(container, self.flip_map)
|
28 |
+
|
29 |
+
|
30 |
+
class CenterCrop(torch.nn.Module):
|
31 |
+
"""
|
32 |
+
Performs the transform on the image.
|
33 |
+
NOTE: Does not transform the mask to improve runtime.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, size: List[int]):
|
37 |
+
super().__init__()
|
38 |
+
self.size = tuple(size)
|
39 |
+
|
40 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
41 |
+
min_size = min(container["img"].shape[1], container["img"].shape[2])
|
42 |
+
if min_size < self.size[0]:
|
43 |
+
container["img"] = F.center_crop(container["img"], min_size)
|
44 |
+
container["img"] = F.resize(container["img"], self.size)
|
45 |
+
return container
|
46 |
+
container["img"] = F.center_crop(container["img"], self.size)
|
47 |
+
return container
|
48 |
+
|
49 |
+
|
50 |
+
class Resize(torch.nn.Module):
|
51 |
+
"""
|
52 |
+
Performs the transform on the image.
|
53 |
+
NOTE: Does not transform the mask to improve runtime.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
|
57 |
+
super().__init__()
|
58 |
+
self.size = tuple(size)
|
59 |
+
self.interpolation = interpolation
|
60 |
+
|
61 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
62 |
+
container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
|
63 |
+
if "semantic_mask" in container:
|
64 |
+
container["semantic_mask"] = F.resize(
|
65 |
+
container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
|
66 |
+
if "embedding" in container:
|
67 |
+
container["embedding"] = F.resize(
|
68 |
+
container["embedding"], self.size, self.interpolation)
|
69 |
+
if "mask" in container:
|
70 |
+
container["mask"] = F.resize(
|
71 |
+
container["mask"], self.size, F.InterpolationMode.NEAREST)
|
72 |
+
if "E_mask" in container:
|
73 |
+
container["E_mask"] = F.resize(
|
74 |
+
container["E_mask"], self.size, F.InterpolationMode.NEAREST)
|
75 |
+
if "maskrcnn_mask" in container:
|
76 |
+
container["maskrcnn_mask"] = F.resize(
|
77 |
+
container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
|
78 |
+
if "vertices" in container:
|
79 |
+
container["vertices"] = F.resize(
|
80 |
+
container["vertices"], self.size, F.InterpolationMode.NEAREST)
|
81 |
+
return container
|
82 |
+
|
83 |
+
def __repr__(self):
|
84 |
+
repr = super().__repr__()
|
85 |
+
vars_ = dict(size=self.size, interpolation=self.interpolation)
|
86 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
87 |
+
|
88 |
+
|
89 |
+
class Normalize(torch.nn.Module):
|
90 |
+
"""
|
91 |
+
Performs the transform on the image.
|
92 |
+
NOTE: Does not transform the mask to improve runtime.
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, mean, std, inplace, keys=["img"]):
|
96 |
+
super().__init__()
|
97 |
+
self.mean = mean
|
98 |
+
self.std = std
|
99 |
+
self.inplace = inplace
|
100 |
+
self.keys = keys
|
101 |
+
|
102 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
103 |
+
for key in self.keys:
|
104 |
+
container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
|
105 |
+
return container
|
106 |
+
|
107 |
+
def __repr__(self):
|
108 |
+
repr = super().__repr__()
|
109 |
+
vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
|
110 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
111 |
+
|
112 |
+
|
113 |
+
class ToFloat(torch.nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, keys=["img"], norm=True) -> None:
|
116 |
+
super().__init__()
|
117 |
+
self.keys = keys
|
118 |
+
self.gain = 255 if norm else 1
|
119 |
+
|
120 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
121 |
+
for key in self.keys:
|
122 |
+
container[key] = container[key].float() / self.gain
|
123 |
+
return container
|
124 |
+
|
125 |
+
|
126 |
+
class RandomCrop(torchvision.transforms.RandomCrop):
|
127 |
+
"""
|
128 |
+
Performs the transform on the image.
|
129 |
+
NOTE: Does not transform the mask to improve runtime.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
133 |
+
container["img"] = super().forward(container["img"])
|
134 |
+
return container
|
135 |
+
|
136 |
+
|
137 |
+
class CreateCondition(torch.nn.Module):
|
138 |
+
|
139 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
140 |
+
if container["img"].dtype == torch.uint8:
|
141 |
+
container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
|
142 |
+
return container
|
143 |
+
container["condition"] = container["img"] * container["mask"]
|
144 |
+
return container
|
145 |
+
|
146 |
+
|
147 |
+
class CreateEmbedding(torch.nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, embed_path: Path, cuda=True) -> None:
|
150 |
+
super().__init__()
|
151 |
+
self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
|
152 |
+
if cuda:
|
153 |
+
self.embed_map = tops.to_cuda(self.embed_map)
|
154 |
+
|
155 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
156 |
+
vertices = container["vertices"]
|
157 |
+
if vertices.ndim == 3:
|
158 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=0)
|
159 |
+
embedding = embedding.permute(2, 0, 1) * container["E_mask"]
|
160 |
+
pass
|
161 |
+
else:
|
162 |
+
assert vertices.ndim == 4
|
163 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=1)
|
164 |
+
embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
|
165 |
+
container["embedding"] = embedding
|
166 |
+
container["embed_map"] = self.embed_map.clone()
|
167 |
+
return container
|
168 |
+
|
169 |
+
|
170 |
+
class InsertJointMap(torch.nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, imsize: Tuple) -> None:
|
173 |
+
super().__init__()
|
174 |
+
self.imsize = imsize
|
175 |
+
knames = get_coco_keypoints()[0]
|
176 |
+
knames = knames + ["neck", "mid_hip"]
|
177 |
+
connectivity = {
|
178 |
+
"nose": ["left_eye", "right_eye", "neck"],
|
179 |
+
"left_eye": ["right_eye", "left_ear"],
|
180 |
+
"right_eye": ["right_ear"],
|
181 |
+
"left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
|
182 |
+
"right_shoulder": ["right_elbow", "right_hip"],
|
183 |
+
"left_elbow": ["left_wrist"],
|
184 |
+
"right_elbow": ["right_wrist"],
|
185 |
+
"left_hip": ["right_hip", "left_knee"],
|
186 |
+
"right_hip": ["right_knee"],
|
187 |
+
"left_knee": ["left_ankle"],
|
188 |
+
"right_knee": ["right_ankle"],
|
189 |
+
"neck": ["mid_hip", "nose"],
|
190 |
+
}
|
191 |
+
category = {
|
192 |
+
("nose", "left_eye"): 0, # head
|
193 |
+
("nose", "right_eye"): 0, # head
|
194 |
+
("nose", "neck"): 0, # head
|
195 |
+
("left_eye", "right_eye"): 0, # head
|
196 |
+
("left_eye", "left_ear"): 0, # head
|
197 |
+
("right_eye", "right_ear"): 0, # head
|
198 |
+
("left_shoulder", "left_elbow"): 1, # left arm
|
199 |
+
("left_elbow", "left_wrist"): 1, # left arm
|
200 |
+
("right_shoulder", "right_elbow"): 2, # right arm
|
201 |
+
("right_elbow", "right_wrist"): 2, # right arm
|
202 |
+
("left_shoulder", "right_shoulder"): 3, # body
|
203 |
+
("left_shoulder", "left_hip"): 3, # body
|
204 |
+
("right_shoulder", "right_hip"): 3, # body
|
205 |
+
("left_hip", "right_hip"): 3, # body
|
206 |
+
("left_hip", "left_knee"): 4, # left leg
|
207 |
+
("left_knee", "left_ankle"): 4, # left leg
|
208 |
+
("right_hip", "right_knee"): 5, # right leg
|
209 |
+
("right_knee", "right_ankle"): 5, # right leg
|
210 |
+
("neck", "mid_hip"): 3, # body
|
211 |
+
("neck", "nose"): 0, # head
|
212 |
+
}
|
213 |
+
self.indices2category = {
|
214 |
+
tuple([knames.index(n) for n in k]): v for k, v in category.items()
|
215 |
+
}
|
216 |
+
self.connectivity_indices = {
|
217 |
+
knames.index(k): [knames.index(v_) for v_ in v]
|
218 |
+
for k, v in connectivity.items()
|
219 |
+
}
|
220 |
+
self.l_shoulder = knames.index("left_shoulder")
|
221 |
+
self.r_shoulder = knames.index("right_shoulder")
|
222 |
+
self.l_hip = knames.index("left_hip")
|
223 |
+
self.r_hip = knames.index("right_hip")
|
224 |
+
self.l_eye = knames.index("left_eye")
|
225 |
+
self.r_eye = knames.index("right_eye")
|
226 |
+
self.nose = knames.index("nose")
|
227 |
+
self.neck = knames.index("neck")
|
228 |
+
|
229 |
+
def create_joint_map(self, N, H, W, keypoints):
|
230 |
+
joint_maps = np.zeros((N, H, W), dtype=np.uint8)
|
231 |
+
for bidx, keypoints in enumerate(keypoints):
|
232 |
+
assert keypoints.shape == (17, 3), keypoints.shape
|
233 |
+
keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
|
234 |
+
visible = keypoints[:, -1] > 0
|
235 |
+
|
236 |
+
if visible[self.l_shoulder] and visible[self.r_shoulder]:
|
237 |
+
neck = (keypoints[self.l_shoulder]
|
238 |
+
+ (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
|
239 |
+
keypoints[-2] = neck
|
240 |
+
visible[-2] = 1
|
241 |
+
if visible[self.l_hip] and visible[self.r_hip]:
|
242 |
+
mhip = (keypoints[self.l_hip]
|
243 |
+
+ (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
|
244 |
+
)
|
245 |
+
keypoints[-1] = mhip
|
246 |
+
visible[-1] = 1
|
247 |
+
|
248 |
+
keypoints[:, 0] *= W
|
249 |
+
keypoints[:, 1] *= H
|
250 |
+
joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
|
251 |
+
draw = ImageDraw.Draw(joint_map)
|
252 |
+
for fidx in self.connectivity_indices.keys():
|
253 |
+
for tidx in self.connectivity_indices[fidx]:
|
254 |
+
if visible[fidx] == 0 or visible[tidx] == 0:
|
255 |
+
continue
|
256 |
+
c = self.indices2category[(fidx, tidx)]
|
257 |
+
s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
|
258 |
+
e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
|
259 |
+
draw.line((s, e), width=1, fill=c + 1)
|
260 |
+
if visible[self.nose] == 0 and visible[self.neck] == 1:
|
261 |
+
m_eye = (
|
262 |
+
keypoints[self.l_eye]
|
263 |
+
+ (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
|
264 |
+
)
|
265 |
+
s = tuple(m_eye[:2].round().long().numpy().tolist())
|
266 |
+
e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
|
267 |
+
c = self.indices2category[(self.nose, self.neck)]
|
268 |
+
draw.line((s, e), width=1, fill=c + 1)
|
269 |
+
joint_map = np.array(joint_map)
|
270 |
+
|
271 |
+
joint_maps[bidx] = np.array(joint_map)
|
272 |
+
return joint_maps[:, None]
|
273 |
+
|
274 |
+
def forward(self, batch):
|
275 |
+
batch["joint_map"] = torch.from_numpy(self.create_joint_map(
|
276 |
+
batch["img"].shape[0], *self.imsize, batch["keypoints"]))
|
277 |
+
return batch
|
dp2/data/utils.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import multiprocessing
|
5 |
+
import io
|
6 |
+
from tops import logger
|
7 |
+
from torch.utils.data._utils.collate import default_collate
|
8 |
+
|
9 |
+
try:
|
10 |
+
import pyspng
|
11 |
+
|
12 |
+
PYSPNG_IMPORTED = True
|
13 |
+
except ImportError:
|
14 |
+
PYSPNG_IMPORTED = False
|
15 |
+
print("Could not load pyspng. Defaulting to pillow image backend.")
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def get_fdf_keypoints():
|
20 |
+
return get_coco_keypoints()[:7]
|
21 |
+
|
22 |
+
|
23 |
+
def get_fdf_flipmap():
|
24 |
+
keypoints = get_fdf_keypoints()
|
25 |
+
keypoint_flip_map = {
|
26 |
+
"left_eye": "right_eye",
|
27 |
+
"left_ear": "right_ear",
|
28 |
+
"left_shoulder": "right_shoulder",
|
29 |
+
}
|
30 |
+
for key, value in list(keypoint_flip_map.items()):
|
31 |
+
keypoint_flip_map[value] = key
|
32 |
+
keypoint_flip_map["nose"] = "nose"
|
33 |
+
keypoint_flip_map_idx = []
|
34 |
+
for source in keypoints:
|
35 |
+
keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
|
36 |
+
return keypoint_flip_map_idx
|
37 |
+
|
38 |
+
|
39 |
+
def get_coco_keypoints():
|
40 |
+
return [
|
41 |
+
"nose",
|
42 |
+
"left_eye",
|
43 |
+
"right_eye", # 2
|
44 |
+
"left_ear",
|
45 |
+
"right_ear", # 4
|
46 |
+
"left_shoulder",
|
47 |
+
"right_shoulder", # 6
|
48 |
+
"left_elbow",
|
49 |
+
"right_elbow", # 8
|
50 |
+
"left_wrist",
|
51 |
+
"right_wrist", # 10
|
52 |
+
"left_hip",
|
53 |
+
"right_hip", # 12
|
54 |
+
"left_knee",
|
55 |
+
"right_knee", # 14
|
56 |
+
"left_ankle",
|
57 |
+
"right_ankle", # 16
|
58 |
+
]
|
59 |
+
|
60 |
+
|
61 |
+
def get_coco_flipmap():
|
62 |
+
keypoints = get_coco_keypoints()
|
63 |
+
keypoint_flip_map = {
|
64 |
+
"left_eye": "right_eye",
|
65 |
+
"left_ear": "right_ear",
|
66 |
+
"left_shoulder": "right_shoulder",
|
67 |
+
"left_elbow": "right_elbow",
|
68 |
+
"left_wrist": "right_wrist",
|
69 |
+
"left_hip": "right_hip",
|
70 |
+
"left_knee": "right_knee",
|
71 |
+
"left_ankle": "right_ankle",
|
72 |
+
}
|
73 |
+
for key, value in list(keypoint_flip_map.items()):
|
74 |
+
keypoint_flip_map[value] = key
|
75 |
+
keypoint_flip_map["nose"] = "nose"
|
76 |
+
keypoint_flip_map_idx = []
|
77 |
+
for source in keypoints:
|
78 |
+
keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
|
79 |
+
return keypoint_flip_map_idx
|
80 |
+
|
81 |
+
|
82 |
+
def mask_decoder(x):
|
83 |
+
mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
|
84 |
+
mask = mask > 0 # This fixes bug causing maskf.loat().max() == 255.
|
85 |
+
return mask
|
86 |
+
|
87 |
+
|
88 |
+
def png_decoder(x):
|
89 |
+
if PYSPNG_IMPORTED:
|
90 |
+
return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
|
91 |
+
with Image.open(io.BytesIO(x)) as im:
|
92 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
93 |
+
return im
|
94 |
+
|
95 |
+
|
96 |
+
def jpg_decoder(x):
|
97 |
+
with Image.open(io.BytesIO(x)) as im:
|
98 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
99 |
+
return im
|
100 |
+
|
101 |
+
|
102 |
+
def get_num_workers(num_workers: int):
|
103 |
+
n_cpus = multiprocessing.cpu_count()
|
104 |
+
if num_workers > n_cpus:
|
105 |
+
logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
|
106 |
+
return n_cpus
|
107 |
+
return num_workers
|
108 |
+
|
109 |
+
|
110 |
+
def collate_fn(batch):
|
111 |
+
elem = batch[0]
|
112 |
+
ignore_keys = set(["embed_map", "vertx2cat"])
|
113 |
+
batch_ = {
|
114 |
+
key: default_collate([d[key] for d in batch])
|
115 |
+
for key in elem
|
116 |
+
if key not in ignore_keys
|
117 |
+
}
|
118 |
+
if "embed_map" in elem:
|
119 |
+
batch_["embed_map"] = elem["embed_map"]
|
120 |
+
if "vertx2cat" in elem:
|
121 |
+
batch_["vertx2cat"] = elem["vertx2cat"]
|
122 |
+
return batch_
|
dp2/detection/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .cse_mask_face_detector import CSeMaskFaceDetector
|
2 |
+
from .person_detector import CSEPersonDetector
|
3 |
+
from .structures import PersonDetection, VehicleDetection, FaceDetection
|
dp2/detection/base.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import torch
|
3 |
+
import lzma
|
4 |
+
from pathlib import Path
|
5 |
+
from tops import logger
|
6 |
+
|
7 |
+
|
8 |
+
class BaseDetector:
|
9 |
+
|
10 |
+
def __init__(self, cache_directory: str) -> None:
|
11 |
+
if cache_directory is not None:
|
12 |
+
self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
|
13 |
+
self.cache_directory.mkdir(exist_ok=True, parents=True)
|
14 |
+
|
15 |
+
def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
|
16 |
+
logger.log(f"Caching detection to: {cache_path}")
|
17 |
+
with lzma.open(cache_path, "wb") as fp:
|
18 |
+
torch.save(
|
19 |
+
[det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
|
20 |
+
pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
21 |
+
|
22 |
+
def load_from_cache(self, cache_path: Path):
|
23 |
+
logger.log(f"Loading detection from cache path: {cache_path}")
|
24 |
+
with lzma.open(cache_path, "rb") as fp:
|
25 |
+
state_dict = torch.load(fp)
|
26 |
+
return [
|
27 |
+
state["cls"].from_state_dict(state_dict=state) for state in state_dict
|
28 |
+
]
|
29 |
+
|
30 |
+
def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
|
31 |
+
if cache_id is None:
|
32 |
+
return self.forward(im)
|
33 |
+
cache_path = self.cache_directory.joinpath(cache_id + ".torch")
|
34 |
+
if cache_path.is_file() and load_cache:
|
35 |
+
try:
|
36 |
+
return self.load_from_cache(cache_path)
|
37 |
+
except Exception as e:
|
38 |
+
logger.warn(f"The cache file was corrupted: {cache_path}")
|
39 |
+
exit()
|
40 |
+
detections = self.forward(im)
|
41 |
+
self.save_to_cache(detections, cache_path)
|
42 |
+
return detections
|
dp2/detection/box_utils.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
|
5 |
+
x0, y0, x1, y1 = [int(_) for _ in bbox]
|
6 |
+
h, w = y1 - y0, x1 - x0
|
7 |
+
cur_ratio = h / w
|
8 |
+
|
9 |
+
if cur_ratio == target_aspect_ratio:
|
10 |
+
return [x0, y0, x1, y1]
|
11 |
+
if cur_ratio < target_aspect_ratio:
|
12 |
+
target_height = int(w*target_aspect_ratio)
|
13 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
14 |
+
else:
|
15 |
+
target_width = int(h/target_aspect_ratio)
|
16 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
17 |
+
return x0, y0, x1, y1
|
18 |
+
|
19 |
+
|
20 |
+
def expand_axis(start, end, target_width, limit):
|
21 |
+
# Can return a bbox outside of limit
|
22 |
+
cur_width = end - start
|
23 |
+
start = start - (target_width-cur_width)//2
|
24 |
+
end = end + (target_width-cur_width)//2
|
25 |
+
if end - start != target_width:
|
26 |
+
end += 1
|
27 |
+
assert end - start == target_width
|
28 |
+
if start < 0 and end > limit:
|
29 |
+
return start, end
|
30 |
+
if start < 0 and end < limit:
|
31 |
+
to_shift = min(0 - start, limit - end)
|
32 |
+
start += to_shift
|
33 |
+
end += to_shift
|
34 |
+
if end > limit and start > 0:
|
35 |
+
to_shift = min(end - limit, start)
|
36 |
+
end -= to_shift
|
37 |
+
start -= to_shift
|
38 |
+
assert end - start == target_width
|
39 |
+
return start, end
|
40 |
+
|
41 |
+
|
42 |
+
def expand_box(bbox, imshape, mask, percentage_background: float):
|
43 |
+
assert isinstance(bbox[0], int)
|
44 |
+
assert 0 < percentage_background < 1
|
45 |
+
# Percentage in S
|
46 |
+
mask_pixels = mask.long().sum().cpu()
|
47 |
+
total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
48 |
+
percentage_mask = mask_pixels / total_pixels
|
49 |
+
if (1 - percentage_mask) > percentage_background:
|
50 |
+
return bbox
|
51 |
+
target_pixels = mask_pixels / (1 - percentage_background)
|
52 |
+
x0, y0, x1, y1 = bbox
|
53 |
+
H = y1 - y0
|
54 |
+
W = x1 - x0
|
55 |
+
p = np.sqrt(target_pixels/(H*W))
|
56 |
+
target_width = int(np.ceil(p * W))
|
57 |
+
target_height = int(np.ceil(p * H))
|
58 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
59 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
60 |
+
return [x0, y0, x1, y1]
|
61 |
+
|
62 |
+
|
63 |
+
def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
|
64 |
+
x0, y0, x1, y1 = bbox_XYXY
|
65 |
+
H = y1 - y0
|
66 |
+
W = x1 - x0
|
67 |
+
expansion = int(((H*W)**0.5) * percentage)
|
68 |
+
new_width = W + expansion
|
69 |
+
new_height = H + expansion
|
70 |
+
x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
|
71 |
+
y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
|
72 |
+
return [x0, y0, x1, y1]
|
73 |
+
|
74 |
+
|
75 |
+
def get_expanded_bbox(
|
76 |
+
bbox_XYXY,
|
77 |
+
imshape,
|
78 |
+
mask,
|
79 |
+
percentage_background: float,
|
80 |
+
axis_minimum_expansion: float,
|
81 |
+
target_aspect_ratio: float):
|
82 |
+
bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
|
83 |
+
# Expand each axis of the bounding box by a minimum percentage
|
84 |
+
bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
|
85 |
+
# Find the minimum bbox with the aspect ratio. Can be outside of imshape
|
86 |
+
bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
|
87 |
+
# Expands square box such that X% of the bbox is background
|
88 |
+
bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
|
89 |
+
assert isinstance(bbox_XYXY[0], (int, np.int64))
|
90 |
+
return bbox_XYXY
|
91 |
+
|
92 |
+
|
93 |
+
def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
|
94 |
+
def area_inside_ratio(bbox, imshape):
|
95 |
+
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
96 |
+
area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1]))
|
97 |
+
return area_inside / area
|
98 |
+
ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
|
99 |
+
area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
100 |
+
if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
|
101 |
+
return False
|
102 |
+
if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
|
103 |
+
return False
|
104 |
+
return True
|
dp2/detection/box_utils_fdf.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The FDF dataset expands bound boxes differently from what is used for CSE.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def quadratic_bounding_box(x0, y0, width, height, imshape):
|
9 |
+
# We assume that we can create a image that is quadratic without
|
10 |
+
# minimizing any of the sides
|
11 |
+
assert width <= min(imshape[:2])
|
12 |
+
assert height <= min(imshape[:2])
|
13 |
+
min_side = min(height, width)
|
14 |
+
if height != width:
|
15 |
+
side_diff = abs(height - width)
|
16 |
+
# Want to extend the shortest side
|
17 |
+
if min_side == height:
|
18 |
+
# Vertical side
|
19 |
+
height += side_diff
|
20 |
+
if height > imshape[0]:
|
21 |
+
# Take full frame, and shrink width
|
22 |
+
y0 = 0
|
23 |
+
height = imshape[0]
|
24 |
+
|
25 |
+
side_diff = abs(height - width)
|
26 |
+
width -= side_diff
|
27 |
+
x0 += side_diff // 2
|
28 |
+
else:
|
29 |
+
y0 -= side_diff // 2
|
30 |
+
y0 = max(0, y0)
|
31 |
+
else:
|
32 |
+
# Horizontal side
|
33 |
+
width += side_diff
|
34 |
+
if width > imshape[1]:
|
35 |
+
# Take full frame width, and shrink height
|
36 |
+
x0 = 0
|
37 |
+
width = imshape[1]
|
38 |
+
|
39 |
+
side_diff = abs(height - width)
|
40 |
+
height -= side_diff
|
41 |
+
y0 += side_diff // 2
|
42 |
+
else:
|
43 |
+
x0 -= side_diff // 2
|
44 |
+
x0 = max(0, x0)
|
45 |
+
# Check that bbox goes outside image
|
46 |
+
x1 = x0 + width
|
47 |
+
y1 = y0 + height
|
48 |
+
if imshape[1] < x1:
|
49 |
+
diff = x1 - imshape[1]
|
50 |
+
x0 -= diff
|
51 |
+
if imshape[0] < y1:
|
52 |
+
diff = y1 - imshape[0]
|
53 |
+
y0 -= diff
|
54 |
+
assert x0 >= 0, "Bounding box outside image."
|
55 |
+
assert y0 >= 0, "Bounding box outside image."
|
56 |
+
assert x0 + width <= imshape[1], "Bounding box outside image."
|
57 |
+
assert y0 + height <= imshape[0], "Bounding box outside image."
|
58 |
+
return x0, y0, width, height
|
59 |
+
|
60 |
+
|
61 |
+
def expand_bounding_box(bbox, percentage, imshape):
|
62 |
+
orig_bbox = bbox.copy()
|
63 |
+
x0, y0, x1, y1 = bbox
|
64 |
+
width = x1 - x0
|
65 |
+
height = y1 - y0
|
66 |
+
x0, y0, width, height = quadratic_bounding_box(
|
67 |
+
x0, y0, width, height, imshape)
|
68 |
+
expanding_factor = int(max(height, width) * percentage)
|
69 |
+
|
70 |
+
possible_max_expansion = [(imshape[0] - width) // 2,
|
71 |
+
(imshape[1] - height) // 2,
|
72 |
+
expanding_factor]
|
73 |
+
|
74 |
+
expanding_factor = min(possible_max_expansion)
|
75 |
+
# Expand height
|
76 |
+
|
77 |
+
if expanding_factor > 0:
|
78 |
+
|
79 |
+
y0 = y0 - expanding_factor
|
80 |
+
y0 = max(0, y0)
|
81 |
+
|
82 |
+
height += expanding_factor * 2
|
83 |
+
if height > imshape[0]:
|
84 |
+
y0 -= (imshape[0] - height)
|
85 |
+
height = imshape[0]
|
86 |
+
|
87 |
+
if height + y0 > imshape[0]:
|
88 |
+
y0 -= (height + y0 - imshape[0])
|
89 |
+
|
90 |
+
# Expand width
|
91 |
+
x0 = x0 - expanding_factor
|
92 |
+
x0 = max(0, x0)
|
93 |
+
|
94 |
+
width += expanding_factor * 2
|
95 |
+
if width > imshape[1]:
|
96 |
+
x0 -= (imshape[1] - width)
|
97 |
+
width = imshape[1]
|
98 |
+
|
99 |
+
if width + x0 > imshape[1]:
|
100 |
+
x0 -= (width + x0 - imshape[1])
|
101 |
+
y1 = y0 + height
|
102 |
+
x1 = x0 + width
|
103 |
+
assert y0 >= 0, "Y0 is minus"
|
104 |
+
assert height <= imshape[0], "Height is larger than image."
|
105 |
+
assert x0 + width <= imshape[1]
|
106 |
+
assert y0 + height <= imshape[0]
|
107 |
+
assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
|
108 |
+
assert x0 >= 0, "Y0 is minus"
|
109 |
+
assert width <= imshape[1], "Height is larger than image."
|
110 |
+
# Check that original bbox is within new
|
111 |
+
x0_o, y0_o, x1_o, y1_o = orig_bbox
|
112 |
+
assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}"
|
113 |
+
assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}"
|
114 |
+
assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}"
|
115 |
+
assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}"
|
116 |
+
|
117 |
+
x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
|
118 |
+
x1 = x0 + width
|
119 |
+
y1 = y0 + height
|
120 |
+
return np.array([x0, y0, x1, y1])
|
121 |
+
|
122 |
+
|
123 |
+
def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
|
124 |
+
keypoint = keypoint[:, :3] # only nose + eyes are relevant
|
125 |
+
kp_X = keypoint[0, :]
|
126 |
+
kp_Y = keypoint[1, :]
|
127 |
+
within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
|
128 |
+
within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
|
129 |
+
return within_X and within_Y
|
130 |
+
|
131 |
+
|
132 |
+
def expand_bbox_simple(bbox, percentage):
|
133 |
+
x0, y0, x1, y1 = bbox.astype(float)
|
134 |
+
width = x1 - x0
|
135 |
+
height = y1 - y0
|
136 |
+
x_c = int(x0) + width // 2
|
137 |
+
y_c = int(y0) + height // 2
|
138 |
+
avg_size = max(width, height)
|
139 |
+
new_width = avg_size * (1 + percentage)
|
140 |
+
x0 = x_c - new_width // 2
|
141 |
+
y0 = y_c - new_width // 2
|
142 |
+
x1 = x_c + new_width // 2
|
143 |
+
y1 = y_c + new_width // 2
|
144 |
+
return np.array([x0, y0, x1, y1]).astype(int)
|
145 |
+
|
146 |
+
|
147 |
+
def pad_image(im, bbox, pad_value):
|
148 |
+
x0, y0, x1, y1 = bbox
|
149 |
+
if x0 < 0:
|
150 |
+
pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
|
151 |
+
dtype=np.uint8) + pad_value
|
152 |
+
im = np.concatenate((pad_im, im), axis=1)
|
153 |
+
x1 += abs(x0)
|
154 |
+
x0 = 0
|
155 |
+
if y0 < 0:
|
156 |
+
pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
|
157 |
+
dtype=np.uint8) + pad_value
|
158 |
+
im = np.concatenate((pad_im, im), axis=0)
|
159 |
+
y1 += abs(y0)
|
160 |
+
y0 = 0
|
161 |
+
if x1 >= im.shape[1]:
|
162 |
+
pad_im = np.zeros(
|
163 |
+
(im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
|
164 |
+
dtype=np.uint8) + pad_value
|
165 |
+
im = np.concatenate((im, pad_im), axis=1)
|
166 |
+
if y1 >= im.shape[0]:
|
167 |
+
pad_im = np.zeros(
|
168 |
+
(y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
|
169 |
+
dtype=np.uint8) + pad_value
|
170 |
+
im = np.concatenate((im, pad_im), axis=0)
|
171 |
+
return im[y0:y1, x0:x1]
|
172 |
+
|
173 |
+
|
174 |
+
def clip_box(bbox, im):
|
175 |
+
bbox[0] = max(0, bbox[0])
|
176 |
+
bbox[1] = max(0, bbox[1])
|
177 |
+
bbox[2] = min(im.shape[1] - 1, bbox[2])
|
178 |
+
bbox[3] = min(im.shape[0] - 1, bbox[3])
|
179 |
+
return bbox
|
180 |
+
|
181 |
+
|
182 |
+
def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
|
183 |
+
outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
|
184 |
+
if simple_expand or (outside_im and pad_im):
|
185 |
+
return pad_image(im, bbox, pad_value)
|
186 |
+
bbox = clip_box(bbox, im)
|
187 |
+
x0, y0, x1, y1 = bbox
|
188 |
+
return im[y0:y1, x0:x1]
|
189 |
+
|
190 |
+
|
191 |
+
def expand_bbox(
|
192 |
+
bbox_ltrb, imshape, simple_expand, default_to_simple=False,
|
193 |
+
expansion_factor=0.35):
|
194 |
+
assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
|
195 |
+
bbox = bbox_ltrb.astype(float)
|
196 |
+
# FDF256 uses simple expand with ratio 0.4
|
197 |
+
if simple_expand:
|
198 |
+
return expand_bbox_simple(bbox, 0.4)
|
199 |
+
try:
|
200 |
+
return expand_bounding_box(bbox, expansion_factor, imshape)
|
201 |
+
except AssertionError:
|
202 |
+
return expand_bbox_simple(bbox, expansion_factor * 2)
|
dp2/detection/cse_mask_face_detector.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import lzma
|
3 |
+
import tops
|
4 |
+
from pathlib import Path
|
5 |
+
from dp2.detection.base import BaseDetector
|
6 |
+
from .utils import combine_cse_maskrcnn_dets
|
7 |
+
from face_detection import build_detector as build_face_detector
|
8 |
+
from .models.cse import CSEDetector
|
9 |
+
from .models.mask_rcnn import MaskRCNNDetector
|
10 |
+
from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
|
11 |
+
from tops import logger
|
12 |
+
|
13 |
+
|
14 |
+
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
|
15 |
+
assert len(box1.shape) == 2
|
16 |
+
assert len(box2.shape) == 2
|
17 |
+
box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
|
18 |
+
# This can be batched
|
19 |
+
for i, box in enumerate(box1):
|
20 |
+
is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
|
21 |
+
is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
|
22 |
+
is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
|
23 |
+
box1_inside[i] = is_outside.logical_not().any()
|
24 |
+
return box1_inside
|
25 |
+
|
26 |
+
|
27 |
+
class CSeMaskFaceDetector(BaseDetector):
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
mask_rcnn_cfg,
|
32 |
+
face_detector_cfg: dict,
|
33 |
+
cse_cfg: dict,
|
34 |
+
face_post_process_cfg: dict,
|
35 |
+
cse_post_process_cfg,
|
36 |
+
score_threshold: float,
|
37 |
+
**kwargs
|
38 |
+
) -> None:
|
39 |
+
super().__init__(**kwargs)
|
40 |
+
self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
|
41 |
+
if "confidence_threshold" not in face_detector_cfg:
|
42 |
+
face_detector_cfg["confidence_threshold"] = score_threshold
|
43 |
+
if "score_thres" not in cse_cfg:
|
44 |
+
cse_cfg["score_thres"] = score_threshold
|
45 |
+
self.cse_detector = CSEDetector(**cse_cfg)
|
46 |
+
self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
|
47 |
+
self.cse_post_process_cfg = cse_post_process_cfg
|
48 |
+
self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
|
49 |
+
self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
|
50 |
+
self.face_post_process_cfg = face_post_process_cfg
|
51 |
+
|
52 |
+
def __call__(self, *args, **kwargs):
|
53 |
+
return self.forward(*args, **kwargs)
|
54 |
+
|
55 |
+
def _detect_faces(self, im: torch.Tensor):
|
56 |
+
H, W = im.shape[1:]
|
57 |
+
im = im.float() - self.face_mean
|
58 |
+
im = self.face_detector.resize(im[None], 1.0)
|
59 |
+
boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
|
60 |
+
boxes_XYXY[:, [0, 2]] *= W
|
61 |
+
boxes_XYXY[:, [1, 3]] *= H
|
62 |
+
return boxes_XYXY.round().long()
|
63 |
+
|
64 |
+
def load_from_cache(self, cache_path: Path):
|
65 |
+
logger.log(f"Loading detection from cache path: {cache_path}",)
|
66 |
+
with lzma.open(cache_path, "rb") as fp:
|
67 |
+
state_dict = torch.load(fp, map_location="cpu")
|
68 |
+
kwargs = dict(
|
69 |
+
post_process_cfg=self.cse_post_process_cfg,
|
70 |
+
embed_map=self.cse_detector.embed_map,
|
71 |
+
**self.face_post_process_cfg
|
72 |
+
)
|
73 |
+
return [
|
74 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
75 |
+
for state in state_dict
|
76 |
+
]
|
77 |
+
|
78 |
+
@torch.no_grad()
|
79 |
+
def forward(self, im: torch.Tensor):
|
80 |
+
maskrcnn_dets = self.mask_rcnn(im)
|
81 |
+
cse_dets = self.cse_detector(im)
|
82 |
+
embed_map = self.cse_detector.embed_map
|
83 |
+
print("Calling face detector.")
|
84 |
+
face_boxes = self._detect_faces(im).cpu()
|
85 |
+
maskrcnn_person = {
|
86 |
+
k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
|
87 |
+
}
|
88 |
+
maskrcnn_other = {
|
89 |
+
k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
|
90 |
+
}
|
91 |
+
maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
|
92 |
+
combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
|
93 |
+
maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
|
94 |
+
|
95 |
+
persons_with_cse = CSEPersonDetection(
|
96 |
+
combined_segmentation, cse_dets, **self.cse_post_process_cfg,
|
97 |
+
embed_map=embed_map, orig_imshape_CHW=im.shape
|
98 |
+
)
|
99 |
+
persons_with_cse.pre_process()
|
100 |
+
not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
|
101 |
+
persons_without_cse = PersonDetection(
|
102 |
+
maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
|
103 |
+
orig_imshape_CHW=im.shape
|
104 |
+
)
|
105 |
+
persons_without_cse.pre_process()
|
106 |
+
|
107 |
+
face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
|
108 |
+
box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
|
109 |
+
)
|
110 |
+
face_boxes = face_boxes[face_boxes_covered.logical_not()]
|
111 |
+
face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
|
112 |
+
|
113 |
+
# Order matters. The anonymizer will anonymize FIFO.
|
114 |
+
# Later detections will overwrite.
|
115 |
+
all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
|
116 |
+
return all_detections
|
dp2/detection/deep_privacy1_detector.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import tops
|
3 |
+
import lzma
|
4 |
+
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
5 |
+
from .base import BaseDetector
|
6 |
+
from face_detection import build_detector as build_face_detector
|
7 |
+
from .structures import FaceDetection
|
8 |
+
from tops import logger
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
|
12 |
+
keypoint = keypoint[:3, :] # only nose + eyes are relevant
|
13 |
+
kp_X = keypoint[:, 0]
|
14 |
+
kp_Y = keypoint[:, 1]
|
15 |
+
within_X = (kp_X >= x0).all() and (kp_X <= x1).all()
|
16 |
+
within_Y = (kp_Y >= y0).all() and (kp_Y <= y1).all()
|
17 |
+
return within_X and within_Y
|
18 |
+
|
19 |
+
|
20 |
+
def match_bbox_keypoint(bounding_boxes, keypoints):
|
21 |
+
"""
|
22 |
+
bounding_boxes shape: [N, 5]
|
23 |
+
keypoints: [N persons, K keypoints, (x, y)]
|
24 |
+
"""
|
25 |
+
if len(bounding_boxes) == 0 or len(keypoints) == 0:
|
26 |
+
return torch.empty((0, 5)), torch.empty((0, 7, 2))
|
27 |
+
assert bounding_boxes.shape[1] == 4,\
|
28 |
+
f"Shape was : {bounding_boxes.shape}"
|
29 |
+
assert keypoints.shape[-1] == 2,\
|
30 |
+
f"Expected (x,y) in last axis, got: {keypoints.shape}"
|
31 |
+
assert keypoints.shape[1] in (5, 7),\
|
32 |
+
f"Expeted 5 or 7 keypoints. Keypoint shape was: {keypoints.shape}"
|
33 |
+
|
34 |
+
matches = []
|
35 |
+
for bbox_idx, bbox in enumerate(bounding_boxes):
|
36 |
+
keypoint = None
|
37 |
+
for kp_idx, keypoint in enumerate(keypoints):
|
38 |
+
if kp_idx in (x[1] for x in matches):
|
39 |
+
continue
|
40 |
+
if is_keypoint_within_bbox(*bbox, keypoint):
|
41 |
+
matches.append((bbox_idx, kp_idx))
|
42 |
+
break
|
43 |
+
keypoint_idx = [x[1] for x in matches]
|
44 |
+
bbox_idx = [x[0] for x in matches]
|
45 |
+
return bounding_boxes[bbox_idx], keypoints[keypoint_idx]
|
46 |
+
|
47 |
+
|
48 |
+
class DeepPrivacy1Detector(BaseDetector):
|
49 |
+
|
50 |
+
def __init__(self,
|
51 |
+
keypoint_threshold: float,
|
52 |
+
face_detector_cfg,
|
53 |
+
score_threshold: float,
|
54 |
+
face_post_process_cfg,
|
55 |
+
**kwargs):
|
56 |
+
super().__init__(**kwargs)
|
57 |
+
self.keypoint_detector = tops.to_cuda(keypointrcnn_resnet50_fpn(
|
58 |
+
weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1).eval())
|
59 |
+
self.keypoint_threshold = keypoint_threshold
|
60 |
+
self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
|
61 |
+
self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
|
62 |
+
self.face_post_process_cfg = face_post_process_cfg
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
def _detect_faces(self, im: torch.Tensor):
|
66 |
+
H, W = im.shape[1:]
|
67 |
+
im = im.float() - self.face_mean
|
68 |
+
im = self.face_detector.resize(im[None], 1.0)
|
69 |
+
boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
|
70 |
+
boxes_XYXY[:, [0, 2]] *= W
|
71 |
+
boxes_XYXY[:, [1, 3]] *= H
|
72 |
+
return boxes_XYXY.round().long().cpu()
|
73 |
+
|
74 |
+
@torch.no_grad()
|
75 |
+
def _detect_keypoints(self, img: torch.Tensor):
|
76 |
+
img = img.float() / 255
|
77 |
+
outputs = self.keypoint_detector([img])
|
78 |
+
|
79 |
+
# Shape: [N persons, K keypoints, (x,y,visibility)]
|
80 |
+
keypoints = outputs[0]["keypoints"]
|
81 |
+
scores = outputs[0]["scores"]
|
82 |
+
assert list(scores) == sorted(list(scores))[::-1]
|
83 |
+
mask = scores >= self.keypoint_threshold
|
84 |
+
keypoints = keypoints[mask, :, :2]
|
85 |
+
return keypoints[:, :7, :2]
|
86 |
+
|
87 |
+
def __call__(self, *args, **kwargs):
|
88 |
+
return self.forward(*args, **kwargs)
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def forward(self, im: torch.Tensor):
|
92 |
+
face_boxes = self._detect_faces(im)
|
93 |
+
keypoints = self._detect_keypoints(im)
|
94 |
+
face_boxes, keypoints = match_bbox_keypoint(face_boxes, keypoints)
|
95 |
+
face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg, keypoints=keypoints)
|
96 |
+
return [face_boxes]
|
97 |
+
|
98 |
+
def load_from_cache(self, cache_path: Path):
|
99 |
+
logger.log(f"Loading detection from cache path: {cache_path}",)
|
100 |
+
with lzma.open(cache_path, "rb") as fp:
|
101 |
+
state_dict = torch.load(fp, map_location="cpu")
|
102 |
+
kwargs = self.face_post_process_cfg
|
103 |
+
return [
|
104 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
105 |
+
for state in state_dict
|
106 |
+
]
|
dp2/detection/face_detector.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import lzma
|
3 |
+
import tops
|
4 |
+
from pathlib import Path
|
5 |
+
from dp2.detection.base import BaseDetector
|
6 |
+
from face_detection import build_detector as build_face_detector
|
7 |
+
from .structures import FaceDetection
|
8 |
+
from tops import logger
|
9 |
+
|
10 |
+
|
11 |
+
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
|
12 |
+
assert len(box1.shape) == 2
|
13 |
+
assert len(box2.shape) == 2
|
14 |
+
box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
|
15 |
+
# This can be batched
|
16 |
+
for i, box in enumerate(box1):
|
17 |
+
is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
|
18 |
+
is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
|
19 |
+
is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
|
20 |
+
box1_inside[i] = is_outside.logical_not().any()
|
21 |
+
return box1_inside
|
22 |
+
|
23 |
+
|
24 |
+
class FaceDetector(BaseDetector):
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
face_detector_cfg: dict,
|
29 |
+
score_threshold: float,
|
30 |
+
face_post_process_cfg: dict,
|
31 |
+
**kwargs
|
32 |
+
) -> None:
|
33 |
+
super().__init__(**kwargs)
|
34 |
+
self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
|
35 |
+
self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
|
36 |
+
self.face_post_process_cfg = face_post_process_cfg
|
37 |
+
|
38 |
+
def __call__(self, *args, **kwargs):
|
39 |
+
return self.forward(*args, **kwargs)
|
40 |
+
|
41 |
+
def _detect_faces(self, im: torch.Tensor):
|
42 |
+
H, W = im.shape[1:]
|
43 |
+
im = im.float() - self.face_mean
|
44 |
+
im = self.face_detector.resize(im[None], 1.0)
|
45 |
+
boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
|
46 |
+
boxes_XYXY[:, [0, 2]] *= W
|
47 |
+
boxes_XYXY[:, [1, 3]] *= H
|
48 |
+
return boxes_XYXY.round().long().cpu()
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def forward(self, im: torch.Tensor):
|
52 |
+
face_boxes = self._detect_faces(im)
|
53 |
+
face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
|
54 |
+
return [face_boxes]
|
55 |
+
|
56 |
+
def load_from_cache(self, cache_path: Path):
|
57 |
+
logger.log(f"Loading detection from cache path: {cache_path}")
|
58 |
+
with lzma.open(cache_path, "rb") as fp:
|
59 |
+
state_dict = torch.load(fp)
|
60 |
+
return [
|
61 |
+
state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
|
62 |
+
]
|