diff --git a/.gitattributes b/.gitattributes
index c7d9f3332a950355d5a77d85000f05e6f45435ea..3ab2d8ef167e821dadfd2aaec2f962f72567b3a2 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,3 +32,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN-Human.yaml filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN.yaml filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/model_final_1d3314.pkl filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
+torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
+media2/stylemc_example.jpg filter=lfs diff=lfs merge=lfs -text
+media2/erling.jpg filter=lfs diff=lfs merge=lfs -text
+media2/g7_leaders.jpg filter=lfs diff=lfs merge=lfs -text
+media2/regjeringen.jpg filter=lfs diff=lfs merge=lfs -text
+media/ filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a1554df8c50bec1b98e6ee212f338b0bd4275164
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,51 @@
+# FILES
+*.flist
+*.zip
+*.out
+*.npy
+*.gz
+*.ckpt
+*.log
+*.pyc
+*.csv
+*.yml
+*.ods
+*.ods#
+*.json
+build_docker.sh
+
+# Images / Videos
+#*.png
+#*.jpg
+*.jpeg
+*.m4a
+*.mkv
+*.mp4
+
+# Directories created by inpaintron
+.cache/
+test_examples/
+.vscode
+__pycache__
+.debug/
+**/.ipynb_checkpoints/**
+outputs/
+
+
+# From pip setup
+build/
+*.egg-info
+*.egg
+.npm/
+
+# From dockerfile
+.bash_history
+.viminfo
+.local/
+*.pickle
+*.onnx
+
+
+sbatch_files/
+figures/
+image_dump/
\ No newline at end of file
diff --git a/README.md b/README.md
index 31e59290884a816c59fc4698bd93daa859fa24bc..6ea0221438aab42ad5a2c6c88a1a0299e60a19bc 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
---
-title: Deep Privacy2 Face
-emoji: 👀
-colorFrom: purple
+title: Deep Privacy2
+emoji: 📈
+colorFrom: gray
colorTo: indigo
sdk: gradio
-sdk_version: 3.23.0
+sdk_version: 3.9.1
app_file: app.py
pinned: false
---
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f0901281b5acda7b9d180302c4581d175d523d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,31 @@
+import gradio
+import os
+from tops.config import instantiate
+import gradio.inputs
+os.system("pip install --upgrade pip")
+os.system("pip install ftfy regex tqdm")
+os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
+os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
+os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
+os.environ["TORCH_HOME"] = "torch_home"
+from dp2 import utils
+from gradio_demos.modules import ExampleDemo, WebcamDemo
+
+cfg_face = utils.load_config("configs/anonymizers/face.py")
+
+anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
+
+anonymizer_face.initialize_tracker(fps=1)
+
+
+with gradio.Blocks() as demo:
+    gradio.Markdown("# DeepPrivacy2 - Realistic Image Anonymization ")
+ gradio.Markdown("### Håkon Hukkelås, Rudolf Mester, Frank Lindseth ")
+ gradio.Markdown(" See more information at: https://github.com/hukkelas/deep_privacy2 ")
+ gradio.Markdown(" For a demo of face anonymization, see: https://huggingface.co/spaces/haakohu/deep_privacy2_face ")
+ with gradio.Tab("Face Anonymization"):
+ ExampleDemo(anonymizer_face)
+ with gradio.Tab("Live Webcam"):
+ WebcamDemo(anonymizer_face)
+
+demo.launch()
diff --git a/configs/anonymizers/FB_cse.py b/configs/anonymizers/FB_cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff44a8ef9da980d545b09609de82e071edc912ac
--- /dev/null
+++ b/configs/anonymizers/FB_cse.py
@@ -0,0 +1,28 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.person_detector import CSEPersonDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+from dp2.generator.dummy_generators import MaskOutGenerator
+
+
+maskout_G = L(MaskOutGenerator)(noise="constant")
+
+detector = L(CSEPersonDetector)(
+ mask_rcnn_cfg=dict(),
+ cse_cfg=dict(),
+ cse_post_process_cfg=dict(
+ target_imsize=(288, 160),
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+ iou_combine_threshold=0.4,
+ dilation_percentage=0.02,
+ normalize_embedding=False
+ ),
+ score_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+ detector="${detector}",
+ cse_person_G_cfg="configs/fdh/styleganL.py",
+)
diff --git a/configs/anonymizers/FB_cse_mask.py b/configs/anonymizers/FB_cse_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5e3bfbefad8e1d6e480fa22256aff0f9647b35
--- /dev/null
+++ b/configs/anonymizers/FB_cse_mask.py
@@ -0,0 +1,29 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.person_detector import CSEPersonDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+from dp2.generator.dummy_generators import MaskOutGenerator
+
+
+maskout_G = L(MaskOutGenerator)(noise="constant")
+
+detector = L(CSEPersonDetector)(
+ mask_rcnn_cfg=dict(),
+ cse_cfg=dict(),
+ cse_post_process_cfg=dict(
+ target_imsize=(288, 160),
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+ iou_combine_threshold=0.4,
+ dilation_percentage=0.02,
+ normalize_embedding=False
+ ),
+ score_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+ detector="${detector}",
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
+ cse_person_G_cfg="configs/fdh/styleganL.py",
+)
diff --git a/configs/anonymizers/FB_cse_mask_face.py b/configs/anonymizers/FB_cse_mask_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..d411d66cc051f6b4c0d907551735e8f661cf17f1
--- /dev/null
+++ b/configs/anonymizers/FB_cse_mask_face.py
@@ -0,0 +1,29 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+
+detector = L(CSeMaskFaceDetector)(
+ mask_rcnn_cfg=dict(),
+ face_detector_cfg=dict(),
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
+ cse_cfg=dict(),
+ cse_post_process_cfg=dict(
+ target_imsize=(288, 160),
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
+ iou_combine_threshold=0.4,
+ dilation_percentage=0.02,
+ normalize_embedding=False
+ ),
+ score_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
+)
+
+anonymizer = L(Anonymizer)(
+ detector="${detector}",
+ face_G_cfg="configs/fdf/stylegan.py",
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
+ cse_person_G_cfg="configs/fdh/styleganL.py",
+ car_G_cfg="configs/generators/dummy/pixelation8.py"
+)
diff --git a/configs/anonymizers/deep_privacy1.py b/configs/anonymizers/deep_privacy1.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bf116cefdbe716a1f9ba56b7f55d5949560cfbc
--- /dev/null
+++ b/configs/anonymizers/deep_privacy1.py
@@ -0,0 +1,15 @@
+from .face_fdf128 import anonymizer, common, detector
+from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector
+from tops.config import LazyCall as L
+
+anonymizer.update(
+ face_G_cfg="configs/fdf/deep_privacy1.py",
+)
+
+anonymizer.detector = L(DeepPrivacy1Detector)(
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
+ score_threshold=0.3,
+ keypoint_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("deep_privacy1_cache")
+)
diff --git a/configs/anonymizers/face.py b/configs/anonymizers/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eed93b812de5166ecddce94e36a4cb1cf4777d8
--- /dev/null
+++ b/configs/anonymizers/face.py
@@ -0,0 +1,17 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.face_detector import FaceDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+
+
+detector = L(FaceDetector)(
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
+ score_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("face_detection_cache"),
+)
+
+anonymizer = L(Anonymizer)(
+ detector="${detector}",
+ face_G_cfg="configs/fdf/stylegan.py",
+)
diff --git a/configs/anonymizers/face_fdf128.py b/configs/anonymizers/face_fdf128.py
new file mode 100644
index 0000000000000000000000000000000000000000..327b7f5c5b2711bb59eb13489b44ad8a3c0f5f57
--- /dev/null
+++ b/configs/anonymizers/face_fdf128.py
@@ -0,0 +1,18 @@
+from dp2.anonymizer import Anonymizer
+from dp2.detection.face_detector import FaceDetector
+from ..defaults import common
+from tops.config import LazyCall as L
+
+
+detector = L(FaceDetector)(
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
+ score_threshold=0.3,
+ cache_directory=common.output_dir.joinpath("face_detection_cache")
+)
+
+
+anonymizer = L(Anonymizer)(
+ detector="${detector}",
+ face_G_cfg="configs/fdf/stylegan_fdf128.py",
+)
diff --git a/configs/anonymizers/market1501/blackout.py b/configs/anonymizers/market1501/blackout.py
new file mode 100644
index 0000000000000000000000000000000000000000..14da21e3c4b367a942f9a99796a1d9996b773522
--- /dev/null
+++ b/configs/anonymizers/market1501/blackout.py
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/person.py b/configs/anonymizers/market1501/person.py
new file mode 100644
index 0000000000000000000000000000000000000000..51fa99b21f068ce68f796fd32c85d37d9a22bec1
--- /dev/null
+++ b/configs/anonymizers/market1501/person.py
@@ -0,0 +1,6 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/pixelation16.py b/configs/anonymizers/market1501/pixelation16.py
new file mode 100644
index 0000000000000000000000000000000000000000..2569fc2abb91919f91dd12546c06a86624d235fc
--- /dev/null
+++ b/configs/anonymizers/market1501/pixelation16.py
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
\ No newline at end of file
diff --git a/configs/anonymizers/market1501/pixelation8.py b/configs/anonymizers/market1501/pixelation8.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef49cb613d09e972adf7b8136b632eb210420686
--- /dev/null
+++ b/configs/anonymizers/market1501/pixelation8.py
@@ -0,0 +1,8 @@
+from ..FB_cse_mask_face import anonymizer, detector, common
+
+detector.score_threshold = .1
+detector.face_detector_cfg.confidence_threshold = .5
+detector.cse_cfg.score_thres = 0.3
+anonymizer.generators.face_G_cfg = None
+anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
+anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
\ No newline at end of file
diff --git a/configs/datasets/coco_cse.py b/configs/datasets/coco_cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4834b64ebb77634ce5f1fb66a3cf9abafe26464
--- /dev/null
+++ b/configs/datasets/coco_cse.py
@@ -0,0 +1,69 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets.coco_cse import CocoCSE
+from dp2.data.build import get_dataloader
+from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
+from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from .utils import final_eval_fn
+
+
+dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+data_dir = Path(dataset_base_dir, "coco_cse")
+data = dict(
+ imsize=(288, 160),
+ im_channels=3,
+ semantic_nc=26,
+ cse_nc=16,
+ train=dict(
+ dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
+ loader=L(get_dataloader)(
+ shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
+ batch_size="${train.batch_size}",
+ dataset="${..dataset}",
+ infinite=True,
+ gpu_transform=L(torch.nn.Sequential)(*[
+ L(ToFloat)(),
+ L(StyleGANAugmentPipe)(
+ rotate=0.5, rotate_max=.05,
+ xint=.5, xint_max=0.05,
+ scale=.5, scale_std=.05,
+ aniso=0.5, aniso_std=.05,
+ xfrac=.5, xfrac_std=.05,
+ brightness=.5, brightness_std=.05,
+ contrast=.5, contrast_std=.1,
+ hue=.5, hue_max=.05,
+ saturation=.5, saturation_std=.5,
+ imgfilter=.5, imgfilter_std=.1),
+ L(RandomHorizontalFlip)(p=0.5),
+ L(CreateEmbedding)(),
+ L(Resize)(size="${data.imsize}"),
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+ L(CreateCondition)(),
+ ])
+ )
+ ),
+ val=dict(
+ dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
+ loader=L(get_dataloader)(
+ shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
+ batch_size="${train.batch_size}",
+ dataset="${..dataset}",
+ infinite=False,
+ gpu_transform=L(torch.nn.Sequential)(*[
+ L(ToFloat)(),
+ L(CreateEmbedding)(),
+ L(Resize)(size="${data.imsize}"),
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+ L(CreateCondition)(),
+ ])
+ )
+ ),
+    # Training-time evaluation may apply optimizations to reduce compute overhead, e.g. computing metrics with AMP.
+ train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
+)
diff --git a/configs/datasets/fdf128.py b/configs/datasets/fdf128.py
new file mode 100644
index 0000000000000000000000000000000000000000..c19fc0d30e2a2aa1b2ac04080a48479f1d190267
--- /dev/null
+++ b/configs/datasets/fdf128.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+from functools import partial
+from dp2.data.datasets.fdf import FDFDataset
+from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn
+
+data_dir = Path(dataset_base_dir, "fdf")
+data.train.dataset.dirpath = data_dir.joinpath("train")
+data.val.dataset.dirpath = data_dir.joinpath("val")
+data.imsize = (128, 128)
+
+
+data.train_evaluation_fn = partial(
+ train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
+data.evaluation_fn = partial(
+ final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
+
+data.train.dataset.update(
+    _target_=FDFDataset,
+ imsize="${data.imsize}"
+)
+data.val.dataset.update(
+    _target_=FDFDataset,
+ imsize="${data.imsize}"
+)
\ No newline at end of file
diff --git a/configs/datasets/fdf256.py b/configs/datasets/fdf256.py
new file mode 100644
index 0000000000000000000000000000000000000000..828802db23ca519cb13963a729c597e39ae8dee8
--- /dev/null
+++ b/configs/datasets/fdf256.py
@@ -0,0 +1,55 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets.fdf import FDF256Dataset
+from dp2.data.build import get_dataloader
+from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
+from .utils import final_eval_fn, train_eval_fn
+
+
+dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+data_dir = Path(dataset_base_dir, "fdf256")
+data = dict(
+ imsize=(256, 256),
+ im_channels=3,
+ semantic_nc=None,
+ cse_nc=None,
+ n_keypoints=None,
+ train=dict(
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
+ loader=L(get_dataloader)(
+ shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
+ batch_size="${train.batch_size}",
+ dataset="${..dataset}",
+ infinite=True,
+ gpu_transform=L(torch.nn.Sequential)(*[
+ L(ToFloat)(),
+ L(RandomHorizontalFlip)(p=0.5),
+ L(Resize)(size="${data.imsize}"),
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+ L(CreateCondition)(),
+ ])
+ )
+ ),
+ val=dict(
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
+ loader=L(get_dataloader)(
+ shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
+ batch_size="${train.batch_size}",
+ dataset="${..dataset}",
+ infinite=False,
+ gpu_transform=L(torch.nn.Sequential)(*[
+ L(ToFloat)(),
+ L(Resize)(size="${data.imsize}"),
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
+ L(CreateCondition)(),
+ ])
+ )
+ ),
+    # Training-time evaluation may apply optimizations to reduce compute overhead, e.g. computing metrics with AMP.
+ train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
+)
\ No newline at end of file
diff --git a/configs/datasets/fdh.py b/configs/datasets/fdh.py
new file mode 100644
index 0000000000000000000000000000000000000000..47faade58b49ef7fdf227b2b54fe89d1650080f6
--- /dev/null
+++ b/configs/datasets/fdh.py
@@ -0,0 +1,90 @@
+import os
+from pathlib import Path
+from tops.config import LazyCall as L
+import torch
+import functools
+from dp2.data.datasets.fdh import get_dataloader_fdh_wds
+from dp2.data.utils import get_coco_flipmap
+from dp2.data.transforms.transforms import (
+ Normalize,
+ ToFloat,
+ CreateCondition,
+ RandomHorizontalFlip,
+ CreateEmbedding,
+)
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from dp2.metrics.fid_clip import compute_fid_clip
+from dp2.metrics.ppl import calculate_ppl
+from .utils import train_eval_fn
+
+
+def final_eval_fn(*args, **kwargs):
+ result = compute_metrics_iteratively(*args, **kwargs)
+ result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
+ result3 = compute_fid_clip(*args, **kwargs)
+ assert all(key not in result for key in result2)
+ result.update(result2)
+ result.update(result3)
+ return result
+
+
+def get_cache_directory(imsize, subset):
+ return Path(metrics_cache, f"{subset}{imsize[0]}")
+
+dataset_base_dir = (
+ os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
+)
+metrics_cache = (
+ os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
+)
+data_dir = Path(dataset_base_dir, "fdh")
+data = dict(
+ imsize=(288, 160),
+ im_channels=3,
+ cse_nc=16,
+ n_keypoints=17,
+ train=dict(
+ loader=L(get_dataloader_fdh_wds)(
+ path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
+ batch_size="${train.batch_size}",
+ num_workers=6,
+ transform=L(torch.nn.Sequential)(
+ L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
+ ),
+ gpu_transform=L(torch.nn.Sequential)(
+ L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
+ L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
+ L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
+ L(CreateCondition)(),
+ ),
+ infinite=True,
+ shuffle=True,
+ partial_batches=False,
+ load_embedding=True,
+ keypoints_split="train",
+ load_new_keypoints=False
+ )
+ ),
+ val=dict(
+ loader=L(get_dataloader_fdh_wds)(
+ path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
+ batch_size="${train.batch_size}",
+ num_workers=6,
+ transform=None,
+ gpu_transform="${data.train.loader.gpu_transform}",
+ infinite=False,
+ shuffle=False,
+ partial_batches=True,
+ load_embedding=True,
+ keypoints_split="val",
+ load_new_keypoints="${data.train.loader.load_new_keypoints}"
+ )
+ ),
+    # Training-time evaluation may apply optimizations to reduce compute overhead, e.g. computing metrics with AMP.
+ train_evaluation_fn=L(functools.partial)(
+ train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
+ data_len=30_000),
+ evaluation_fn=L(functools.partial)(
+ final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
+ data_len=30_000)
+)
diff --git a/configs/datasets/utils.py b/configs/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6704b12a5e379b707b2c6d3dc9f78431ce01e61d
--- /dev/null
+++ b/configs/datasets/utils.py
@@ -0,0 +1,21 @@
+from dp2.metrics.ppl import calculate_ppl
+from dp2.metrics.torch_metrics import compute_metrics_iteratively
+from dp2.metrics.fid_clip import compute_fid_clip
+
+
+def final_eval_fn(*args, **kwargs):
+ result = compute_metrics_iteratively(*args, **kwargs)
+ result2 = calculate_ppl(*args, **kwargs,)
+ result3 = compute_fid_clip(*args, **kwargs)
+ assert all(key not in result for key in result2)
+ result.update(result2)
+ result.update(result3)
+ return result
+
+
+def train_eval_fn(*args, **kwargs):
+ result = compute_metrics_iteratively(*args, **kwargs)
+ result2 = compute_fid_clip(*args, **kwargs)
+ assert all(key not in result for key in result2)
+ result.update(result2)
+ return result
\ No newline at end of file
diff --git a/configs/defaults.py b/configs/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f831200940c0aa9658bab3a82d6ad5714048d3c
--- /dev/null
+++ b/configs/defaults.py
@@ -0,0 +1,53 @@
+import pathlib
+import os
+import torch
+from tops.config import LazyCall as L
+
+if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
+else:
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
+if "BASE_OUTPUT_DIR" in os.environ:
+ BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
+else:
+ BASE_OUTPUT_DIR = pathlib.Path("outputs")
+
+
+
+common = dict(
+ logger_backend=["wandb", "stdout", "json", "image_dumper"],
+ wandb_project="deep_privacy2",
+ output_dir=BASE_OUTPUT_DIR,
+ experiment_name=None, # Optional experiment name to show on wandb
+)
+
+train = dict(
+ batch_size=32,
+ seed=0,
+ ims_per_log=1024,
+ ims_per_val=int(200e3),
+ max_images_to_train=int(12e6),
+ amp=dict(
+ enabled=True,
+ scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
+ scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
+ ),
+ fp16_ddp_accumulate=False, # All gather gradients in fp16?
+ broadcast_buffers=False,
+ bias_act_plugin_enabled=True,
+ grid_sample_gradfix_enabled=True,
+ conv2d_gradfix_enabled=False,
+ channels_last=False,
+ compile_G=dict(
+ enabled=False,
+ mode="default" # default, reduce-overhead or max-autotune
+ ),
+ compile_D=dict(
+ enabled=False,
+ mode="default" # default, reduce-overhead or max-autotune
+ )
+)
+
+# exponential moving average
+EMA = dict(rampup=0.05)
+
diff --git a/configs/discriminators/sg2_discriminator.py b/configs/discriminators/sg2_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..081ebfb76b330ab99334725b9cc82db06bfb1c5f
--- /dev/null
+++ b/configs/discriminators/sg2_discriminator.py
@@ -0,0 +1,43 @@
+from tops.config import LazyCall as L
+from dp2.discriminator import SG2Discriminator
+import torch
+from dp2.loss import StyleGAN2Loss
+
+
+discriminator = L(SG2Discriminator)(
+ imsize="${data.imsize}",
+ im_channels="${data.im_channels}",
+ min_fmap_resolution=4,
+ max_cnum_mul=8,
+ cnum=80,
+ input_condition=True,
+ conv_clamp=256,
+ input_cse=False,
+ cse_nc="${data.cse_nc}",
+ fix_residual=False,
+)
+
+
+loss_fnc = L(StyleGAN2Loss)(
+ lazy_regularization=True,
+ lazy_reg_interval=16,
+ r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
+ EP_lambd=0.001,
+    pl_reg_opts=dict(weight=0, batch_shrink=2, start_nimg=int(1e6), pl_decay=0.01)
+)
+
+def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
+ if lazy_regularization:
+ # From Analyzing and improving the image quality of stylegan, CVPR 2020
+ c = lazy_reg_interval / (lazy_reg_interval + 1)
+ betas = [beta ** c for beta in betas]
+ lr *= c
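+        # With the defaults used below (lr=0.001, betas=(0.0, 0.99), interval=16):
+        # c = 16/17 ≈ 0.941, giving lr ≈ 9.4e-4 and betas ≈ [0.0, 0.9906].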
+ print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
+ return type(lr=lr, betas=betas, **kwargs)
+
+
+D_optim = L(build_D_optim)(
+ type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
+ lazy_regularization="${loss_fnc.lazy_regularization}",
+ lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
+G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
diff --git a/configs/fdf/deep_privacy1.py b/configs/fdf/deep_privacy1.py
new file mode 100644
index 0000000000000000000000000000000000000000..88c09bd4b5810182bdd3520dfdf4f98bdcea3829
--- /dev/null
+++ b/configs/fdf/deep_privacy1.py
@@ -0,0 +1,9 @@
+from tops.config import LazyCall as L
+from dp2.generator.deep_privacy1 import MSGGenerator
+from ..datasets.fdf128 import data
+from ..defaults import common, train
+
+generator = L(MSGGenerator)()
+
+common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt"
+common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6"
\ No newline at end of file
diff --git a/configs/fdf/stylegan.py b/configs/fdf/stylegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4da2c3ad76d3d1fb6e1d91e832cde5c735bf32a
--- /dev/null
+++ b/configs/fdf/stylegan.py
@@ -0,0 +1,14 @@
+from ..generators.stylegan_unet import generator
+from ..datasets.fdf256 import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(35e6)
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.input_cse = False
+loss_fnc.r1_opts.lambd = 1
+train.ims_per_val = int(2e6)
+
+common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
+common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
\ No newline at end of file
diff --git a/configs/fdf/stylegan_fdf128.py b/configs/fdf/stylegan_fdf128.py
new file mode 100644
index 0000000000000000000000000000000000000000..a47d6d2ee362c935e7879c9442c4dcd9aaf007c0
--- /dev/null
+++ b/configs/fdf/stylegan_fdf128.py
@@ -0,0 +1,17 @@
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..datasets.fdf128 import data
+from ..generators.stylegan_unet import generator
+from ..defaults import train, common, EMA
+from tops.config import LazyCall as L
+
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.update(cnum=128, max_cnum_mul=4, input_cse=False)
+loss_fnc.r1_opts.lambd = 0.1
+
+train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6))
+
+common.update(
+ model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c",
+ model_md5sum="bccd4403e7c9bca682566ff3319e8176"
+)
\ No newline at end of file
diff --git a/configs/fdh/styleganL.py b/configs/fdh/styleganL.py
new file mode 100644
index 0000000000000000000000000000000000000000..48fcf09b43a7141a270fbe5c69bd7932414270fe
--- /dev/null
+++ b/configs/fdh/styleganL.py
@@ -0,0 +1,16 @@
+from tops.config import LazyCall as L
+from ..generators.stylegan_unet import generator
+from ..datasets.fdh import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(50e6)
+train.batch_size = 64
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+data.train.loader.num_workers = 4
+train.ims_per_val = int(1e6)
+loss_fnc.r1_opts.lambd = .1
+
+common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
+common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
\ No newline at end of file
diff --git a/configs/fdh/styleganL_nocse.py b/configs/fdh/styleganL_nocse.py
new file mode 100644
index 0000000000000000000000000000000000000000..210fd68743f0b872f89f4407dfaac7c9bf5f0e32
--- /dev/null
+++ b/configs/fdh/styleganL_nocse.py
@@ -0,0 +1,14 @@
+from tops.config import LazyCall as L
+from ..generators.stylegan_unet import generator
+from ..datasets.fdh import data
+from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
+from ..defaults import train, common, EMA
+
+train.max_images_to_train = int(50e6)
+G_optim.lr = 0.002
+D_optim.lr = 0.002
+generator.input_cse = False
+data.load_embeddings = False
+common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
+common.model_md5sum = "fda0d809741bc67487abada793975c37"
+generator.fix_errors = False
\ No newline at end of file
diff --git a/configs/generators/stylegan_unet.py b/configs/generators/stylegan_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..638859263a1cb549f533b75b2b19609665b3443e
--- /dev/null
+++ b/configs/generators/stylegan_unet.py
@@ -0,0 +1,22 @@
+from dp2.generator.stylegan_unet import StyleGANUnet
+from tops.config import LazyCall as L
+
+generator = L(StyleGANUnet)(
+ imsize="${data.imsize}",
+ im_channels="${data.im_channels}",
+ min_fmap_resolution=8,
+ cnum=64,
+ max_cnum_mul=8,
+ n_middle_blocks=0,
+ z_channels=512,
+ mask_output=True,
+ conv_clamp=256,
+ input_cse=True,
+ scale_grad=True,
+ cse_nc="${data.cse_nc}",
+ w_dim=512,
+ n_keypoints="${data.n_keypoints}",
+ input_keypoints=False,
+ input_keypoint_indices=[],
+ fix_errors=True
+)
\ No newline at end of file
diff --git a/dp2/__init__.py b/dp2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dp2/anonymizer/__init__.py b/dp2/anonymizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32606aa927c8d593d64be02a499fba057b8ba6fa
--- /dev/null
+++ b/dp2/anonymizer/__init__.py
@@ -0,0 +1 @@
+from .anonymizer import Anonymizer
diff --git a/dp2/anonymizer/anonymizer.py b/dp2/anonymizer/anonymizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32e61a122c2e1916c0ce9baffecb095c5a136de
--- /dev/null
+++ b/dp2/anonymizer/anonymizer.py
@@ -0,0 +1,163 @@
+from pathlib import Path
+from typing import Union, Optional
+import numpy as np
+import torch
+import tops
+import torchvision.transforms.functional as F
+from motpy import Detection, MultiObjectTracker
+from dp2.utils import load_config
+from dp2.infer import build_trained_generator
+from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
+
+
+def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
+ cfg = load_config(cfg_path)
+ G = build_trained_generator(cfg)
+ tops.logger.log(f"Loaded generator from: {cfg_path}")
+ return G
+
+
+class Anonymizer:
+
+ def __init__(
+ self,
+ detector,
+ load_cache: bool = False,
+ person_G_cfg: Optional[Union[str, Path]] = None,
+ cse_person_G_cfg: Optional[Union[str, Path]] = None,
+ face_G_cfg: Optional[Union[str, Path]] = None,
+ car_G_cfg: Optional[Union[str, Path]] = None,
+ ) -> None:
+ self.detector = detector
+ self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
+ self.load_cache = load_cache
+ if cse_person_G_cfg is not None:
+ self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
+ if person_G_cfg is not None:
+ self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
+ if face_G_cfg is not None:
+ self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
+ if car_G_cfg is not None:
+ self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
+
+ def initialize_tracker(self, fps: float):
+ self.tracker = MultiObjectTracker(dt=1/fps)
+ self.track_to_z_idx = dict()
+
+ def reset_tracker(self):
+ self.track_to_z_idx = dict()
+
+ def forward_G(self,
+ G,
+ batch,
+ multi_modal_truncation: bool,
+ amp: bool,
+ z_idx: int,
+ truncation_value: float,
+ idx: int,
+ all_styles=None):
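+        # Map the uint8 crop from [0, 255] to [-1, 1] and build the masked condition image.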
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
+ batch["img"] = batch["img"].float()
+ batch["condition"] = batch["mask"].float() * batch["img"]
+
+ with torch.cuda.amp.autocast(amp):
+ z = None
+ if z_idx is not None:
+ state = np.random.RandomState(seed=z_idx[idx])
+ z = state.normal(size=(1, G.z_channels)).astype(np.float32)
+ z = tops.to_cuda(torch.from_numpy(z))
+
+ if all_styles is not None:
+ anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
+ elif multi_modal_truncation:
+ w_indices = None
+ if z_idx is not None:
+ w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
+ anonymized_im = G.multi_modal_truncate(
+ **batch, truncation_value=truncation_value,
+ w_indices=w_indices,
+ z=z
+ )["img"]
+ else:
+ anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
+ anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
+ return anonymized_im
+
+ @torch.no_grad()
+ def anonymize_detections(self,
+ im, detection,
+ update_identity=None,
+ **synthesis_kwargs
+ ):
+ G = self.generators[type(detection)]
+ if G is None:
+ return im
+ C, H, W = im.shape
+ if update_identity is None:
+ update_identity = [True for i in range(len(detection))]
+ for idx in range(len(detection)):
+ if not update_identity[idx]:
+ continue
+ batch = detection.get_crop(idx, im)
+ x0, y0, x1, y1 = batch.pop("boxes")[0]
+ batch = {k: tops.to_cuda(v) for k, v in batch.items()}
+ anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
+
+ gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
+ mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
+ # Remove padding
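+            # (the expanded crop box may extend beyond the image borders, so crop the
+            #  generated output back to the valid region before pasting it into im)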
+ pad = [max(-x0, 0), max(-y0, 0)]
+ pad = [*pad, max(x1-W, 0), max(y1-H, 0)]
+ def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
+
+ gim = remove_pad(gim)
+ mask = remove_pad(mask) > 0.5
+ x0, y0 = max(x0, 0), max(y0, 0)
+ x1, y1 = min(x1, W), min(y1, H)
+ mask = mask.logical_not()[None].repeat(3, 1, 1)
+
+ im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
+ return im
+
+ def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
+ all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
+ im = im.cpu()
+ for det in all_detections:
+ im = det.visualize(im)
+ return im
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
+ assert im.dtype == torch.uint8
+ im = tops.to_cuda(im)
+ all_detections = detections
+ if detections is None:
+ if self.load_cache:
+ all_detections = self.detector.forward_and_cache(im, cache_id)
+ else:
+ all_detections = self.detector(im)
+ if hasattr(self, "tracker") and track:
+ [_.pre_process() for _ in all_detections]
+ boxes = np.concatenate([_.boxes for _ in all_detections])
+ boxes = [Detection(box) for box in boxes]
+ self.tracker.step(boxes)
+ track_ids = self.tracker.detections_matched_ids
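+            # Assign each track id a fixed random seed so the same person keeps a
+            # consistent anonymized identity across video frames.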
+ z_idx = []
+ for track_id in track_ids:
+ if track_id not in self.track_to_z_idx:
+ self.track_to_z_idx[track_id] = np.random.randint(0, 2**32-1)
+ z_idx.append(self.track_to_z_idx[track_id])
+ z_idx = np.array(z_idx)
+ idx_offset = 0
+
+ for detection in all_detections:
+ zs = None
+ if hasattr(self, "tracker") and track:
+ zs = z_idx[idx_offset:idx_offset+len(detection)]
+ idx_offset += len(detection)
+ im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
+
+ return im.cpu()
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
diff --git a/dp2/anonymizer/histogram_match_anonymizers.py b/dp2/anonymizer/histogram_match_anonymizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..421c80d5624b113afdf9aa4908b5c9cd0ba33c94
--- /dev/null
+++ b/dp2/anonymizer/histogram_match_anonymizers.py
@@ -0,0 +1,93 @@
+
+import torch
+import tops
+import numpy as np
+from kornia.color import rgb_to_hsv
+from dp2 import utils
+from kornia.enhance import histogram
+from .anonymizer import Anonymizer
+import torchvision.transforms.functional as F
+from skimage.exposure import match_histograms
+from kornia.filters import gaussian_blur2d
+
+
+class LatentHistogramMatchAnonymizer(Anonymizer):
+
+ def forward_G(
+ self,
+ G,
+ batch,
+ multi_modal_truncation: bool,
+ amp: bool,
+ z_idx: int,
+ truncation_value: float,
+ idx: int,
+ n_sampling_steps: int = 1,
+ all_styles=None,
+ ):
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
+ batch["img"] = batch["img"].float()
+ batch["condition"] = batch["mask"].float() * batch["img"]
+
+ assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
+ real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"]))
+ real_hls[:, 0] /= 2 * torch.pi
+ indices = [1, 2]
+ hist_kwargs = dict(
+ bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
+ bandwidth=torch.tensor(1., device=tops.get_device()))
+ real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
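+        # Optimize the latent w with Adam so the saturation/value histograms of the
+        # generated crop match those of the real crop (sum of Wasserstein distances).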
+ for j in range(n_sampling_steps):
+ if j == 0:
+ if multi_modal_truncation:
+ w = G.style_net.multi_modal_truncate(
+ truncation_value=truncation_value, **batch, w_indices=None).detach()
+ else:
+ w = G.style_net.get_truncated(truncation_value, **batch).detach()
+ assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
+ w.requires_grad = True
+ optim = torch.optim.Adam([w])
+ with torch.set_grad_enabled(True):
+ with torch.cuda.amp.autocast(amp):
+ anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
+ fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5)
+ fake_hls[:, 0] /= 2 * torch.pi
+ fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
+ dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
+ dist.backward()
+ if w.grad.sum() == 0:
+ break
+ assert w.grad.sum() != 0
+ optim.step()
+ optim.zero_grad()
+ if dist < 0.02:
+ break
+ anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
+ return anonymized_im
+
+
+class HistogramMatchAnonymizer(Anonymizer):
+
+ def forward_G(self, batch, *args, **kwargs):
+ rimg = batch["img"]
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
+ batch["img"] = batch["img"].float()
+ batch["condition"] = batch["mask"].float() * batch["img"]
+
+ anonymized_im = super().forward_G(batch, *args, **kwargs)
+
+ equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg))
+ if equalized_gim.dtype != np.uint8:
+ equalized_gim = equalized_gim.astype(np.float32)
+ assert equalized_gim.dtype == np.float32, equalized_gim.dtype
+ equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
+ else:
+ equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
+ equalized_gim = equalized_gim.to(device=rimg.device)
+ assert equalized_gim.dtype == torch.float32
+ gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()
+
+ gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
+ gaussian_mask = gaussian_mask / gaussian_mask.max()
+ anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im
+ return anonymized_im
diff --git a/dp2/data/__init__.py b/dp2/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dp2/data/build.py b/dp2/data/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceab946b4da20467f879f3c6af0e9eb985465ac4
--- /dev/null
+++ b/dp2/data/build.py
@@ -0,0 +1,40 @@
+import torch
+import tops
+from .utils import collate_fn
+
+
+def get_dataloader(
+ dataset, gpu_transform: torch.nn.Module,
+ num_workers,
+ batch_size,
+ infinite: bool,
+ drop_last: bool,
+ prefetch_factor: int,
+ shuffle,
+ channels_last=False
+ ):
+ sampler = None
+ dl_kwargs = dict(
+ pin_memory=True,
+ )
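+    # Infinite sampling is used for iteration-based training; validation uses a
+    # finite (optionally distributed) sampler instead.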
+ if infinite:
+ sampler = tops.InfiniteSampler(
+ dataset, rank=tops.rank(),
+ num_replicas=tops.world_size(),
+ shuffle=shuffle
+ )
+ elif tops.world_size() > 1:
+ sampler = torch.utils.data.DistributedSampler(
+ dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
+ dl_kwargs["drop_last"] = drop_last
+ else:
+ dl_kwargs["shuffle"] = shuffle
+ dl_kwargs["drop_last"] = drop_last
+ dataloader = torch.utils.data.DataLoader(
+ dataset, sampler=sampler, collate_fn=collate_fn,
+ batch_size=batch_size,
+ num_workers=num_workers, prefetch_factor=prefetch_factor,
+ **dl_kwargs
+ )
+ dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
+ return dataloader
diff --git a/dp2/data/datasets/__init__.py b/dp2/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dp2/data/datasets/coco_cse.py b/dp2/data/datasets/coco_cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b240932da41b2db03c3807830935a78b50f84c4f
--- /dev/null
+++ b/dp2/data/datasets/coco_cse.py
@@ -0,0 +1,68 @@
+import pickle
+import torchvision
+import torch
+import pathlib
+import numpy as np
+from typing import Callable, Optional, Union
+from torch.hub import get_dir as get_hub_dir
+
+
+def cache_embed_stats(embed_map: torch.Tensor):
+ mean = embed_map.mean(dim=0, keepdim=True)
+ rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
+
+ cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
+ path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
+ path.parent.mkdir(exist_ok=True, parents=True)
+ torch.save(cache, path)
+
+
+class CocoCSE(torch.utils.data.Dataset):
+
+ def __init__(self,
+ dirpath: Union[str, pathlib.Path],
+ transform: Optional[Callable],
+ normalize_E: bool,):
+ dirpath = pathlib.Path(dirpath)
+ self.dirpath = dirpath
+
+ self.transform = transform
+ assert self.dirpath.is_dir(),\
+ f"Did not find dataset at: {dirpath}"
+ self.image_paths, self.embedding_paths = self._load_impaths()
+ self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
+ mean = self.embed_map.mean(dim=0, keepdim=True)
+ rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
+ self.embed_map = (self.embed_map - mean) * rstd
+ cache_embed_stats(self.embed_map)
+
+ def _load_impaths(self):
+ image_dir = self.dirpath.joinpath("images")
+ image_paths = list(image_dir.glob("*.png"))
+ image_paths.sort()
+ embedding_paths = [
+ self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
+ ]
+ return image_paths, embedding_paths
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def __getitem__(self, idx):
+ im = torchvision.io.read_image(str(self.image_paths[idx]))
+ vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
+ vertices = torch.from_numpy(vertices.squeeze()).long()
+ mask = torch.from_numpy(mask.squeeze()).float()
+ border = torch.from_numpy(border.squeeze()).float()
+ E_mask = 1 - mask - border
+ batch = {
+ "img": im,
+ "vertices": vertices[None],
+ "mask": mask[None],
+ "embed_map": self.embed_map,
+ "border": border[None],
+ "E_mask": E_mask[None]
+ }
+ if self.transform is None:
+ return batch
+ return self.transform(batch)
diff --git a/dp2/data/datasets/fdf.py b/dp2/data/datasets/fdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..23f68a52d4fb50143b2ef6720e126991b2981afc
--- /dev/null
+++ b/dp2/data/datasets/fdf.py
@@ -0,0 +1,128 @@
+import pathlib
+from typing import Tuple
+import numpy as np
+import torch
+import pathlib
+try:
+ import pyspng
+ PYSPNG_IMPORTED = True
+except ImportError:
+ PYSPNG_IMPORTED = False
+ print("Could not load pyspng. Defaulting to pillow image backend.")
+ from PIL import Image
+from tops import logger
+
+
+class FDFDataset:
+
+ def __init__(self,
+ dirpath,
+ imsize: Tuple[int],
+ load_keypoints: bool,
+ transform):
+ dirpath = pathlib.Path(dirpath)
+ self.dirpath = dirpath
+ self.transform = transform
+ self.imsize = imsize[0]
+ self.load_keypoints = load_keypoints
+ assert self.dirpath.is_dir(),\
+ f"Did not find dataset at: {dirpath}"
+ image_dir = self.dirpath.joinpath("images", str(self.imsize))
+ self.image_paths = list(image_dir.glob("*.png"))
+ assert len(self.image_paths) > 0,\
+ f"Did not find images in: {image_dir}"
+ self.image_paths.sort(key=lambda x: int(x.stem))
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
+
+ self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
+ assert len(self.image_paths) == len(self.bounding_boxes)
+ assert len(self.image_paths) == len(self.landmarks)
+ logger.log(
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
+
+ def get_mask(self, idx):
+ mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
+ bounding_box = self.bounding_boxes[idx]
+ x0, y0, x1, y1 = bounding_box
+ mask[:, y0:y1, x0:x1] = 0
+ return mask
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def __getitem__(self, index):
+ impath = self.image_paths[index]
+ if PYSPNG_IMPORTED:
+ with open(impath, "rb") as fp:
+ im = pyspng.load(fp.read())
+ else:
+ with Image.open(impath) as fp:
+ im = np.array(fp)
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
+ masks = self.get_mask(index)
+ landmark = self.landmarks[index]
+ batch = {
+ "img": im,
+ "mask": masks,
+ }
+ if self.load_keypoints:
+ batch["keypoints"] = landmark
+ if self.transform is None:
+ return batch
+ return self.transform(batch)
+
+
+class FDF256Dataset:
+
+ def __init__(self,
+ dirpath,
+ load_keypoints: bool,
+ transform):
+ dirpath = pathlib.Path(dirpath)
+ self.dirpath = dirpath
+ self.transform = transform
+ self.load_keypoints = load_keypoints
+ assert self.dirpath.is_dir(),\
+ f"Did not find dataset at: {dirpath}"
+ image_dir = self.dirpath.joinpath("images")
+ self.image_paths = list(image_dir.glob("*.png"))
+ assert len(self.image_paths) > 0,\
+ f"Did not find images in: {image_dir}"
+ self.image_paths.sort(key=lambda x: int(x.stem))
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
+ self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
+ assert len(self.image_paths) == len(self.bounding_boxes)
+ assert len(self.image_paths) == len(self.landmarks)
+ logger.log(
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
+
+ def get_mask(self, idx):
+ mask = torch.ones((1, 256, 256), dtype=torch.bool)
+ bounding_box = self.bounding_boxes[idx]
+ x0, y0, x1, y1 = bounding_box
+ mask[:, y0:y1, x0:x1] = 0
+ return mask
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def __getitem__(self, index):
+ impath = self.image_paths[index]
+ if PYSPNG_IMPORTED:
+ with open(impath, "rb") as fp:
+ im = pyspng.load(fp.read())
+ else:
+ with Image.open(impath) as fp:
+ im = np.array(fp)
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
+ masks = self.get_mask(index)
+ landmark = self.landmarks[index]
+ batch = {
+ "img": im,
+ "mask": masks,
+ }
+ if self.load_keypoints:
+ batch["keypoints"] = landmark
+ if self.transform is None:
+ return batch
+ return self.transform(batch)
diff --git a/dp2/data/datasets/fdf128_wds.py b/dp2/data/datasets/fdf128_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af477de9b1d2e2670bfae24930c9e698e95b975
--- /dev/null
+++ b/dp2/data/datasets/fdf128_wds.py
@@ -0,0 +1,96 @@
+import torch
+import tops
+import numpy as np
+import io
+import webdataset as wds
+import os
+from ..utils import png_decoder, get_num_workers, collate_fn
+
+
+def kp_decoder(x):
+ # Keypoints are between [0, 1] for webdataset
+ keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1)
+ keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1)
+ return keypoints
+
+
+def bbox_decoder(x):
+ return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4)
+
+
+class BBoxToMask:
+
+ def __call__(self, sample):
+ imsize = sample["image.png"].shape[-1]
+ bbox = sample["bounding_box.npy"] * imsize
+ x0, y0, x1, y1 = np.round(bbox).astype(np.int64)
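+        # The face bounding box is zeroed out in the mask (0 = region to generate, 1 = known context).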
+ mask = torch.ones((1, imsize, imsize), dtype=torch.bool)
+ mask[:, y0:y1, x0:x1] = 0
+ sample["mask"] = mask
+ return sample
+
+
+def get_dataloader_fdf_wds(
+ path,
+ batch_size: int,
+ num_workers: int,
+ transform: torch.nn.Module,
+ gpu_transform: torch.nn.Module,
+ infinite: bool,
+ shuffle: bool,
+ partial_batches: bool,
+ sample_shuffle=10_000,
+ tar_shuffle=100,
+ channels_last=False,
+ ):
+ # Need to set this for split_by_node to work.
+ os.environ["RANK"] = str(tops.rank())
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
+ if infinite:
+ pipeline = [wds.ResampledShards(str(path))]
+ else:
+ pipeline = [wds.SimpleShardList(str(path))]
+ if shuffle:
+ pipeline.append(wds.shuffle(tar_shuffle))
+ pipeline.extend([
+ wds.split_by_node,
+ wds.split_by_worker,
+ ])
+ if shuffle:
+ pipeline.append(wds.shuffle(sample_shuffle))
+
+ decoder = [
+ wds.handle_extension("image.png", png_decoder),
+ wds.handle_extension("keypoints.npy", kp_decoder),
+ ]
+
+ rename_keys = [
+ ["img", "image.png"],
+ ["keypoints", "keypoints.npy"],
+ ["__key__", "__key__"],
+ ["mask", "mask"]
+ ]
+
+ pipeline.extend([
+ wds.tarfile_to_samples(),
+ wds.decode(*decoder),
+ ])
+ pipeline.append(wds.map(BBoxToMask()))
+ pipeline.extend([
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
+ wds.rename_keys(*rename_keys),
+ ])
+
+ if transform is not None:
+ pipeline.append(wds.map(transform))
+ pipeline = wds.DataPipeline(*pipeline)
+ if infinite:
+ pipeline = pipeline.repeat(nepochs=1000000)
+
+ loader = wds.WebLoader(
+ pipeline, batch_size=None, shuffle=False,
+ num_workers=get_num_workers(num_workers),
+ persistent_workers=True,
+ )
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
+ return loader
diff --git a/dp2/data/datasets/fdh.py b/dp2/data/datasets/fdh.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5293b42874c644da4622687f407c069a0a8e07
--- /dev/null
+++ b/dp2/data/datasets/fdh.py
@@ -0,0 +1,142 @@
+import torch
+import tops
+import numpy as np
+import io
+import webdataset as wds
+import os
+import json
+from pathlib import Path
+from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
+
+
+def kp_decoder(x):
+ # Keypoints are between [0, 1] for webdataset
+ keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
+ def check_outside(x): return (x < 0).logical_or(x > 1)
+ is_outside = check_outside(keypoints[:, 0]).logical_or(
+ check_outside(keypoints[:, 1])
+ )
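+    # Mark keypoints that fall outside [0, 1] as not visible.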
+ keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
+ return keypoints
+
+
+def vertices_decoder(x):
+ vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
+ return vertices.squeeze()[None]
+
+
+class InsertNewKeypoints:
+
+ def __init__(self, keypoints_path: Path) -> None:
+ with open(keypoints_path, "r") as fp:
+ self.keypoints = json.load(fp)
+
+ def __call__(self, sample):
+ key = sample["__key__"]
+ keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
+ def check_outside(x): return (x < 0).logical_or(x > 1)
+ is_outside = check_outside(keypoints[:, 0]).logical_or(
+ check_outside(keypoints[:, 1])
+ )
+ keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
+
+ sample["keypoints.npy"] = keypoints
+ return sample
+
+
+def get_dataloader_fdh_wds(
+ path,
+ batch_size: int,
+ num_workers: int,
+ transform: torch.nn.Module,
+ gpu_transform: torch.nn.Module,
+ infinite: bool,
+ shuffle: bool,
+ partial_batches: bool,
+ load_embedding: bool,
+ sample_shuffle=10_000,
+ tar_shuffle=100,
+ read_condition=False,
+ channels_last=False,
+ load_new_keypoints=False,
+ keypoints_split=None,
+ ):
+ # Need to set this for split_by_node to work.
+ os.environ["RANK"] = str(tops.rank())
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
+ if infinite:
+ pipeline = [wds.ResampledShards(str(path))]
+ else:
+ pipeline = [wds.SimpleShardList(str(path))]
+ if shuffle:
+ pipeline.append(wds.shuffle(tar_shuffle))
+ pipeline.extend([
+ wds.split_by_node,
+ wds.split_by_worker,
+ ])
+ if shuffle:
+ pipeline.append(wds.shuffle(sample_shuffle))
+
+ decoder = [
+ wds.handle_extension("image.png", png_decoder),
+ wds.handle_extension("mask.png", mask_decoder),
+ wds.handle_extension("maskrcnn_mask.png", mask_decoder),
+ wds.handle_extension("keypoints.npy", kp_decoder),
+ ]
+
+ rename_keys = [
+ ["img", "image.png"], ["mask", "mask.png"],
+ ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
+ ["__key__", "__key__"]
+ ]
+ if load_embedding:
+ decoder.extend([
+ wds.handle_extension("vertices.npy", vertices_decoder),
+ wds.handle_extension("E_mask.png", mask_decoder)
+ ])
+ rename_keys.extend([
+ ["vertices", "vertices.npy"],
+ ["E_mask", "e_mask.png"]
+ ])
+
+ if read_condition:
+ decoder.append(
+ wds.handle_extension("condition.png", png_decoder)
+ )
+ rename_keys.append(["condition", "condition.png"])
+
+ pipeline.extend([
+ wds.tarfile_to_samples(),
+ wds.decode(*decoder),
+ ])
+ if load_new_keypoints:
+ assert keypoints_split in ["train", "val"]
+ keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
+ file_name = "fdh_keypoints_val-050133b34d.json"
+ if keypoints_split == "train":
+ keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
+ file_name = "fdh_keypoints_train-2cff11f69a.json"
+ # Set check_hash=True if you suspect download is incorrect.
+ filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
+ pipeline.append(
+ wds.map(InsertNewKeypoints(filepath))
+ )
+ pipeline.extend([
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
+ wds.rename_keys(*rename_keys),
+ ])
+
+ if transform is not None:
+ pipeline.append(wds.map(transform))
+ pipeline = wds.DataPipeline(*pipeline)
+ if infinite:
+ pipeline = pipeline.repeat(nepochs=1000000)
+
+ loader = wds.WebLoader(
+ pipeline, batch_size=None, shuffle=False,
+ num_workers=get_num_workers(num_workers),
+ persistent_workers=True,
+ )
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
+ return loader
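+
+
+# Minimal usage sketch (illustrative only; the real path, batch size and transforms come from the project configs):
+#
+#   loader = get_dataloader_fdh_wds(
+#       path="data/fdh/train/out-{000000..000099}.tar", batch_size=32, num_workers=4,
+#       transform=None, gpu_transform=torch.nn.Identity(), infinite=True, shuffle=True,
+#       partial_batches=False, load_embedding=True)
+#   batch = next(iter(loader))  # dict with "img", "mask", "keypoints", "maskrcnn_mask", ...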
diff --git a/dp2/data/transforms/__init__.py b/dp2/data/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee4bcf4e825af435ffde4c2b6e3c74112f8438f
--- /dev/null
+++ b/dp2/data/transforms/__init__.py
@@ -0,0 +1,2 @@
+from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
+from .stylegan2_transform import StyleGANAugmentPipe
diff --git a/dp2/data/transforms/functional.py b/dp2/data/transforms/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee57f27ad07e597098ce1de967c3a50a1d06d0a
--- /dev/null
+++ b/dp2/data/transforms/functional.py
@@ -0,0 +1,57 @@
+import torchvision.transforms.functional as F
+import torch
+import pickle
+from tops import download_file, assert_shape
+from typing import Dict
+from functools import lru_cache
+
+global symmetry_transform
+
+
+@lru_cache(maxsize=1)
+def get_symmetry_transform(symmetry_url):
+ file_name = download_file(symmetry_url)
+ with open(file_name, "rb") as fp:
+ symmetry = pickle.load(fp)
+ return torch.from_numpy(symmetry["vertex_transforms"]).long()
+
+
+hflip_handled_cases = set([
+ "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
+ "embedding", "vertx2cat", "maskrcnn_mask", "__key__"])
+
+
+def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
+ container["img"] = F.hflip(container["img"])
+ if "condition" in container:
+ container["condition"] = F.hflip(container["condition"])
+ if "embedding" in container:
+ container["embedding"] = F.hflip(container["embedding"])
+ assert all([key in hflip_handled_cases for key in container]), container.keys()
+ if "keypoints" in container:
+ assert flip_map is not None
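+ # flip_map maps each keypoint index to its left/right counterpart; x-coordinates are mirrored as 1 - x since they are normalized to [0, 1].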
+ if container["keypoints"].ndim == 3:
+ keypoints = container["keypoints"][:, flip_map, :]
+ keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
+ else:
+ assert_shape(container["keypoints"], (None, 3))
+ keypoints = container["keypoints"][flip_map, :]
+ keypoints[:, 0] = 1 - keypoints[:, 0]
+ container["keypoints"] = keypoints
+ if "mask" in container:
+ container["mask"] = F.hflip(container["mask"])
+ if "border" in container:
+ container["border"] = F.hflip(container["border"])
+ if "semantic_mask" in container:
+ container["semantic_mask"] = F.hflip(container["semantic_mask"])
+ if "vertices" in container:
+ symmetry_transform = get_symmetry_transform(
+ "https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
+ container["vertices"] = F.hflip(container["vertices"])
+ symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
+ container["vertices"] = symmetry_transform_[container["vertices"].long()]
+ if "E_mask" in container:
+ container["E_mask"] = F.hflip(container["E_mask"])
+ if "maskrcnn_mask" in container:
+ container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
+ return container
diff --git a/dp2/data/transforms/stylegan2_transform.py b/dp2/data/transforms/stylegan2_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a143cddf9673d079b87ac7d725c433713e54c5
--- /dev/null
+++ b/dp2/data/transforms/stylegan2_transform.py
@@ -0,0 +1,394 @@
+import numpy as np
+import scipy.signal
+import torch
+try:
+ from sg3_torch_utils import misc
+ from sg3_torch_utils.ops import upfirdn2d
+ from sg3_torch_utils.ops import grid_sample_gradfix
+ from sg3_torch_utils.ops import conv2d_gradfix
+except ImportError:
+ # sg3_torch_utils is optional; StyleGANAugmentPipe will fail at runtime if it is unavailable.
+ pass
+#----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+wavelets = {
+ 'haar': [0.7071067811865476, 0.7071067811865476],
+ 'db1': [0.7071067811865476, 0.7071067811865476],
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+}
+
+#----------------------------------------------------------------------------
+# Helpers for constructing transformation matrices.
+
+
+def matrix(*rows, device=None):
+ assert all(len(row) == len(rows[0]) for row in rows)
+ elems = [x for row in rows for x in row]
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
+ if len(ref) == 0:
+ return misc.constant(np.asarray(rows), device=device)
+ assert device is None or device == ref[0].device
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
+
+
+def translate2d(tx, ty, **kwargs):
+ return matrix(
+ [1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1],
+ **kwargs)
+
+
+def translate3d(tx, ty, tz, **kwargs):
+ return matrix(
+ [1, 0, 0, tx],
+ [0, 1, 0, ty],
+ [0, 0, 1, tz],
+ [0, 0, 0, 1],
+ **kwargs)
+
+
+def scale2d(sx, sy, **kwargs):
+ return matrix(
+ [sx, 0, 0],
+ [0, sy, 0],
+ [0, 0, 1],
+ **kwargs)
+
+
+def scale3d(sx, sy, sz, **kwargs):
+ return matrix(
+ [sx, 0, 0, 0],
+ [0, sy, 0, 0],
+ [0, 0, sz, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+
+def rotate2d(theta, **kwargs):
+ return matrix(
+ [torch.cos(theta), torch.sin(-theta), 0],
+ [torch.sin(theta), torch.cos(theta), 0],
+ [0, 0, 1],
+ **kwargs)
+
+
+def rotate3d(v, theta, **kwargs):
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
+ return matrix(
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
+ [0, 0, 0, 1],
+ **kwargs)
+
+
+def translate2d_inv(tx, ty, **kwargs):
+ return translate2d(-tx, -ty, **kwargs)
+
+
+def scale2d_inv(sx, sy, **kwargs):
+ return scale2d(1 / sx, 1 / sy, **kwargs)
+
+
+def rotate2d_inv(theta, **kwargs):
+ return rotate2d(-theta, **kwargs)
+
+
+class StyleGANAugmentPipe(torch.nn.Module):
+ def __init__(self,
+ rotate90=0, xint=0, xint_max=0.125,
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
+ hue_max=1, saturation_std=1,
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
+ ):
+ super().__init__()
+ self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
+
+ # Pixel blitting.
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
+ self.xint = float(xint) # Probability multiplier for integer translation.
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
+
+ # General geometric transformations.
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
+ self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
+
+ # Color transformations.
+ self.brightness = float(brightness) # Probability multiplier for brightness.
+ self.contrast = float(contrast) # Probability multiplier for contrast.
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
+ self.hue = float(hue) # Probability multiplier for hue rotation.
+ self.saturation = float(saturation) # Probability multiplier for saturation.
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
+
+ # Image-space filtering.
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
+
+ # Setup orthogonal lowpass filter for geometric augmentations.
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
+
+ # Construct filter bank for image-space filtering.
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
+ for i in range(1, Hz_fbank.shape[0]):
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
+
+ def forward(self, batch, debug_percentile=None):
+ images = batch["img"]
+ batch["vertices"] = batch["vertices"].float()
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
+ batch_size, num_channels, height, width = images.shape
+ device = images.device
+ self.Hz_fbank = self.Hz_fbank.to(device)
+ self.Hz_geom = self.Hz_geom.to(device)
+ if debug_percentile is not None:
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
+
+ # -------------------------------------
+ # Select parameters for pixel blitting.
+ # -------------------------------------
+
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
+ I_3 = torch.eye(3, device=device)
+ G_inv = I_3
+
+ # Apply integer translation with probability (xint * strength).
+ if self.xint > 0:
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
+
+ # --------------------------------------------------------
+ # Select parameters for general geometric transformations.
+ # --------------------------------------------------------
+
+ # Apply isotropic scaling with probability (scale * strength).
+ if self.scale > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
+ G_inv = G_inv @ scale2d_inv(s, s)
+
+ # Apply pre-rotation with probability p_rot.
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
+
+ # Apply anisotropic scaling with probability (aniso * strength).
+ if self.aniso > 0:
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
+
+ # Apply post-rotation with probability p_rot.
+ if self.rotate > 0:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.zeros_like(theta)
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
+
+ # Apply fractional translation with probability (xfrac * strength).
+ if self.xfrac > 0:
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
+ if debug_percentile is not None:
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ # Execute if the transform is not identity.
+ if G_inv is not I_3:
+ # Calculate padding.
+ cx = (width - 1) / 2
+ cy = (height - 1) / 2
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
+ Hz_pad = self.Hz_geom.shape[0] // 4
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
+
+ # Pad image and adjust origin.
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
+ batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
+ batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
+ batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
+
+ # Upsample.
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
+
+ # Execute transformation.
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
+ images = grid_sample_gradfix.grid_sample(images, grid)
+
+ batch["mask"] = torch.nn.functional.grid_sample(
+ input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
+ batch["E_mask"] = torch.nn.functional.grid_sample(
+ input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
+ batch["vertices"] = torch.nn.functional.grid_sample(
+ input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
+
+
+ # Downsample and crop.
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
+ I_4 = torch.eye(4, device=device)
+ C = I_4
+
+ # Apply brightness with probability (brightness * strength).
+ if self.brightness > 0:
+ b = torch.randn([batch_size], device=device) * self.brightness_std
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
+ if debug_percentile is not None:
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
+ C = translate3d(b, b, b) @ C
+
+ # Apply contrast with probability (contrast * strength).
+ if self.contrast > 0:
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
+ if debug_percentile is not None:
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
+ C = scale3d(c, c, c) @ C
+
+ # Luma axis, used by the hue and saturation transforms below (the luma flip itself is not applied in this version).
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
+
+ # Apply hue rotation with probability (hue * strength).
+ if self.hue > 0 and num_channels > 1:
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
+ if debug_percentile is not None:
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
+ C = rotate3d(v, theta) @ C # Rotate around v.
+
+ # Apply saturation with probability (saturation * strength).
+ if self.saturation > 0 and num_channels > 1:
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
+ if debug_percentile is not None:
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ # Execute if the transform is not identity.
+ if C is not I_4:
+ images = images.reshape([batch_size, num_channels, height * width])
+ if num_channels == 3:
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
+ elif num_channels == 1:
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
+ else:
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ----------------------
+ # Image-space filtering.
+ # ----------------------
+
+ if self.imgfilter > 0:
+ num_bands = self.Hz_fbank.shape[0]
+ assert len(self.imgfilter_bands) == num_bands
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
+
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
+ for i, band_strength in enumerate(self.imgfilter_bands):
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
+ if debug_percentile is not None:
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
+ t[:, i] = t_i # Replace i'th element.
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
+ g = g * t # Accumulate into global gain.
+
+ # Construct combined amplification filter.
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
+
+ # Apply filter.
+ p = self.Hz_fbank.shape[1] // 2
+ images = images.reshape([1, batch_size * num_channels, height, width])
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
+ images = images.reshape([batch_size, num_channels, height, width])
+
+ # ------------------------
+ # Image-space corruptions.
+ # ------------------------
+ batch["img"] = images
+ batch["vertices"] = batch["vertices"].long()
+ batch["border"] = 1 - batch["E_mask"] - batch["mask"]
+ return batch
diff --git a/dp2/data/transforms/transforms.py b/dp2/data/transforms/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd43e7a515deacca4be7242d065b4f3ccb6800e
--- /dev/null
+++ b/dp2/data/transforms/transforms.py
@@ -0,0 +1,277 @@
+from pathlib import Path
+from typing import Dict, List
+import torchvision
+import torch
+import tops
+import torchvision.transforms.functional as F
+from .functional import hflip
+import numpy as np
+from dp2.utils.vis_utils import get_coco_keypoints
+from PIL import Image, ImageDraw
+from typing import Tuple
+
+
+class RandomHorizontalFlip(torch.nn.Module):
+
+ def __init__(self, p: float, flip_map=None, **kwargs):
+ super().__init__()
+ self.flip_ratio = p
+ self.flip_map = flip_map
+ if self.flip_ratio is None:
+ self.flip_ratio = 0.5
+ assert 0 <= self.flip_ratio <= 1
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ if torch.rand(1) > self.flip_ratio:
+ return container
+ return hflip(container, self.flip_map)
+
+
+class CenterCrop(torch.nn.Module):
+ """
+ Performs the transform on the image.
+ NOTE: Does not transform the mask to improve runtime.
+ """
+
+ def __init__(self, size: List[int]):
+ super().__init__()
+ self.size = tuple(size)
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ min_size = min(container["img"].shape[1], container["img"].shape[2])
+ if min_size < self.size[0]:
+ container["img"] = F.center_crop(container["img"], min_size)
+ container["img"] = F.resize(container["img"], self.size)
+ return container
+ container["img"] = F.center_crop(container["img"], self.size)
+ return container
+
+
+class Resize(torch.nn.Module):
+ """
+ Resizes the image together with any auxiliary maps present in the sample.
+ NOTE: Masks, semantic maps and vertex maps are resized with nearest-neighbor interpolation to keep them discrete.
+ """
+
+ def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
+ super().__init__()
+ self.size = tuple(size)
+ self.interpolation = interpolation
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
+ if "semantic_mask" in container:
+ container["semantic_mask"] = F.resize(
+ container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
+ if "embedding" in container:
+ container["embedding"] = F.resize(
+ container["embedding"], self.size, self.interpolation)
+ if "mask" in container:
+ container["mask"] = F.resize(
+ container["mask"], self.size, F.InterpolationMode.NEAREST)
+ if "E_mask" in container:
+ container["E_mask"] = F.resize(
+ container["E_mask"], self.size, F.InterpolationMode.NEAREST)
+ if "maskrcnn_mask" in container:
+ container["maskrcnn_mask"] = F.resize(
+ container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
+ if "vertices" in container:
+ container["vertices"] = F.resize(
+ container["vertices"], self.size, F.InterpolationMode.NEAREST)
+ return container
+
+ def __repr__(self):
+ repr = super().__repr__()
+ vars_ = dict(size=self.size, interpolation=self.interpolation)
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
+
+
+class Normalize(torch.nn.Module):
+ """
+ Normalizes the tensors selected by `keys` (only the image by default).
+ """
+
+ def __init__(self, mean, std, inplace, keys=["img"]):
+ super().__init__()
+ self.mean = mean
+ self.std = std
+ self.inplace = inplace
+ self.keys = keys
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ for key in self.keys:
+ container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
+ return container
+
+ def __repr__(self):
+ repr = super().__repr__()
+ vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
+
+
+class ToFloat(torch.nn.Module):
+
+ def __init__(self, keys=["img"], norm=True) -> None:
+ super().__init__()
+ self.keys = keys
+ self.gain = 255 if norm else 1
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ for key in self.keys:
+ container[key] = container[key].float() / self.gain
+ return container
+
+
+class RandomCrop(torchvision.transforms.RandomCrop):
+ """
+ Performs the transform on the image.
+ NOTE: Does not transform the mask to improve runtime.
+ """
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ container["img"] = super().forward(container["img"])
+ return container
+
+
+class CreateCondition(torch.nn.Module):
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
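+ # The condition is the input image with the masked-out region removed (filled with mid-gray 127 for uint8 inputs, zeros for float inputs).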
+ if container["img"].dtype == torch.uint8:
+ container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
+ return container
+ container["condition"] = container["img"] * container["mask"]
+ return container
+
+
+class CreateEmbedding(torch.nn.Module):
+
+ def __init__(self, embed_path: Path, cuda=True) -> None:
+ super().__init__()
+ self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
+ if cuda:
+ self.embed_map = tops.to_cuda(self.embed_map)
+
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ vertices = container["vertices"]
+ if vertices.ndim == 3:
+ embedding = self.embed_map[vertices.long()].squeeze(dim=0)
+ embedding = embedding.permute(2, 0, 1) * container["E_mask"]
+ else:
+ assert vertices.ndim == 4
+ embedding = self.embed_map[vertices.long()].squeeze(dim=1)
+ embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
+ container["embedding"] = embedding
+ container["embed_map"] = self.embed_map.clone()
+ return container
+
+
+class InsertJointMap(torch.nn.Module):
+
+ def __init__(self, imsize: Tuple) -> None:
+ super().__init__()
+ self.imsize = imsize
+ knames = get_coco_keypoints()[0]
+ knames = knames + ["neck", "mid_hip"]
+ connectivity = {
+ "nose": ["left_eye", "right_eye", "neck"],
+ "left_eye": ["right_eye", "left_ear"],
+ "right_eye": ["right_ear"],
+ "left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
+ "right_shoulder": ["right_elbow", "right_hip"],
+ "left_elbow": ["left_wrist"],
+ "right_elbow": ["right_wrist"],
+ "left_hip": ["right_hip", "left_knee"],
+ "right_hip": ["right_knee"],
+ "left_knee": ["left_ankle"],
+ "right_knee": ["right_ankle"],
+ "neck": ["mid_hip", "nose"],
+ }
+ category = {
+ ("nose", "left_eye"): 0, # head
+ ("nose", "right_eye"): 0, # head
+ ("nose", "neck"): 0, # head
+ ("left_eye", "right_eye"): 0, # head
+ ("left_eye", "left_ear"): 0, # head
+ ("right_eye", "right_ear"): 0, # head
+ ("left_shoulder", "left_elbow"): 1, # left arm
+ ("left_elbow", "left_wrist"): 1, # left arm
+ ("right_shoulder", "right_elbow"): 2, # right arm
+ ("right_elbow", "right_wrist"): 2, # right arm
+ ("left_shoulder", "right_shoulder"): 3, # body
+ ("left_shoulder", "left_hip"): 3, # body
+ ("right_shoulder", "right_hip"): 3, # body
+ ("left_hip", "right_hip"): 3, # body
+ ("left_hip", "left_knee"): 4, # left leg
+ ("left_knee", "left_ankle"): 4, # left leg
+ ("right_hip", "right_knee"): 5, # right leg
+ ("right_knee", "right_ankle"): 5, # right leg
+ ("neck", "mid_hip"): 3, # body
+ ("neck", "nose"): 0, # head
+ }
+ self.indices2category = {
+ tuple([knames.index(n) for n in k]): v for k, v in category.items()
+ }
+ self.connectivity_indices = {
+ knames.index(k): [knames.index(v_) for v_ in v]
+ for k, v in connectivity.items()
+ }
+ self.l_shoulder = knames.index("left_shoulder")
+ self.r_shoulder = knames.index("right_shoulder")
+ self.l_hip = knames.index("left_hip")
+ self.r_hip = knames.index("right_hip")
+ self.l_eye = knames.index("left_eye")
+ self.r_eye = knames.index("right_eye")
+ self.nose = knames.index("nose")
+ self.neck = knames.index("neck")
+
+ def create_joint_map(self, N, H, W, keypoints):
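+ # Rasterizes the skeleton as line segments into a [N, 1, H, W] uint8 map, where 0 is background and 1-6 encode the body-part category.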
+ joint_maps = np.zeros((N, H, W), dtype=np.uint8)
+ for bidx, keypoints in enumerate(keypoints):
+ assert keypoints.shape == (17, 3), keypoints.shape
+ keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
+ visible = keypoints[:, -1] > 0
+
+ if visible[self.l_shoulder] and visible[self.r_shoulder]:
+ neck = (keypoints[self.l_shoulder]
+ + (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
+ keypoints[-2] = neck
+ visible[-2] = 1
+ if visible[self.l_hip] and visible[self.r_hip]:
+ mhip = (keypoints[self.l_hip]
+ + (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
+ )
+ keypoints[-1] = mhip
+ visible[-1] = 1
+
+ keypoints[:, 0] *= W
+ keypoints[:, 1] *= H
+ joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
+ draw = ImageDraw.Draw(joint_map)
+ for fidx in self.connectivity_indices.keys():
+ for tidx in self.connectivity_indices[fidx]:
+ if visible[fidx] == 0 or visible[tidx] == 0:
+ continue
+ c = self.indices2category[(fidx, tidx)]
+ s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
+ e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
+ draw.line((s, e), width=1, fill=c + 1)
+ if visible[self.nose] == 0 and visible[self.neck] == 1:
+ m_eye = (
+ keypoints[self.l_eye]
+ + (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
+ )
+ s = tuple(m_eye[:2].round().long().numpy().tolist())
+ e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
+ c = self.indices2category[(self.nose, self.neck)]
+ draw.line((s, e), width=1, fill=c + 1)
+
+ joint_maps[bidx] = np.array(joint_map)
+ return joint_maps[:, None]
+
+ def forward(self, batch):
+ batch["joint_map"] = torch.from_numpy(self.create_joint_map(
+ batch["img"].shape[0], *self.imsize, batch["keypoints"]))
+ return batch
diff --git a/dp2/data/utils.py b/dp2/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..392fe22f9f78daae0f65b1484311f7bb14381cbd
--- /dev/null
+++ b/dp2/data/utils.py
@@ -0,0 +1,122 @@
+import torch
+from PIL import Image
+import numpy as np
+import multiprocessing
+import io
+from tops import logger
+from torch.utils.data._utils.collate import default_collate
+
+try:
+ import pyspng
+
+ PYSPNG_IMPORTED = True
+except ImportError:
+ PYSPNG_IMPORTED = False
+ print("Could not load pyspng. Defaulting to pillow image backend.")
+
+
+def get_fdf_keypoints():
+ return get_coco_keypoints()[:7]
+
+
+def get_fdf_flipmap():
+ keypoints = get_fdf_keypoints()
+ keypoint_flip_map = {
+ "left_eye": "right_eye",
+ "left_ear": "right_ear",
+ "left_shoulder": "right_shoulder",
+ }
+ for key, value in list(keypoint_flip_map.items()):
+ keypoint_flip_map[value] = key
+ keypoint_flip_map["nose"] = "nose"
+ keypoint_flip_map_idx = []
+ for source in keypoints:
+ keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
+ return keypoint_flip_map_idx
+
+
+def get_coco_keypoints():
+ return [
+ "nose",
+ "left_eye",
+ "right_eye", # 2
+ "left_ear",
+ "right_ear", # 4
+ "left_shoulder",
+ "right_shoulder", # 6
+ "left_elbow",
+ "right_elbow", # 8
+ "left_wrist",
+ "right_wrist", # 10
+ "left_hip",
+ "right_hip", # 12
+ "left_knee",
+ "right_knee", # 14
+ "left_ankle",
+ "right_ankle", # 16
+ ]
+
+
+def get_coco_flipmap():
+ keypoints = get_coco_keypoints()
+ keypoint_flip_map = {
+ "left_eye": "right_eye",
+ "left_ear": "right_ear",
+ "left_shoulder": "right_shoulder",
+ "left_elbow": "right_elbow",
+ "left_wrist": "right_wrist",
+ "left_hip": "right_hip",
+ "left_knee": "right_knee",
+ "left_ankle": "right_ankle",
+ }
+ for key, value in list(keypoint_flip_map.items()):
+ keypoint_flip_map[value] = key
+ keypoint_flip_map["nose"] = "nose"
+ keypoint_flip_map_idx = []
+ for source in keypoints:
+ keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
+ return keypoint_flip_map_idx
+
+
+def mask_decoder(x):
+ mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
+ mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
+ return mask
+
+
+def png_decoder(x):
+ if PYSPNG_IMPORTED:
+ return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
+ with Image.open(io.BytesIO(x)) as im:
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
+ return im
+
+
+def jpg_decoder(x):
+ with Image.open(io.BytesIO(x)) as im:
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
+ return im
+
+
+def get_num_workers(num_workers: int):
+ n_cpus = multiprocessing.cpu_count()
+ if num_workers > n_cpus:
+ logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
+ return n_cpus
+ return num_workers
+
+
+def collate_fn(batch):
+ elem = batch[0]
+ ignore_keys = set(["embed_map", "vertx2cat"])
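+ # embed_map and vertx2cat are identical for every sample; a single copy is kept below instead of collating them.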
+ batch_ = {
+ key: default_collate([d[key] for d in batch])
+ for key in elem
+ if key not in ignore_keys
+ }
+ if "embed_map" in elem:
+ batch_["embed_map"] = elem["embed_map"]
+ if "vertx2cat" in elem:
+ batch_["vertx2cat"] = elem["vertx2cat"]
+ return batch_
diff --git a/dp2/detection/__init__.py b/dp2/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..613969b28384cd1c64fc8db685e7622f4cc02615
--- /dev/null
+++ b/dp2/detection/__init__.py
@@ -0,0 +1,3 @@
+from .cse_mask_face_detector import CSeMaskFaceDetector
+from .person_detector import CSEPersonDetector
+from .structures import PersonDetection, VehicleDetection, FaceDetection
diff --git a/dp2/detection/base.py b/dp2/detection/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab8b20c9474cbb6074b66a694d1a1a05df0c12c
--- /dev/null
+++ b/dp2/detection/base.py
@@ -0,0 +1,42 @@
+import pickle
+import torch
+import lzma
+from pathlib import Path
+from tops import logger
+
+
+class BaseDetector:
+
+ def __init__(self, cache_directory: str) -> None:
+ if cache_directory is not None:
+ self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
+ self.cache_directory.mkdir(exist_ok=True, parents=True)
+
+ def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
+ logger.log(f"Caching detection to: {cache_path}")
+ with lzma.open(cache_path, "wb") as fp:
+ torch.save(
+ [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
+ pickle_protocol=pickle.HIGHEST_PROTOCOL)
+
+ def load_from_cache(self, cache_path: Path):
+ logger.log(f"Loading detection from cache path: {cache_path}")
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp)
+ return [
+ state["cls"].from_state_dict(state_dict=state) for state in state_dict
+ ]
+
+ def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
+ if cache_id is None:
+ return self.forward(im)
+ cache_path = self.cache_directory.joinpath(cache_id + ".torch")
+ if cache_path.is_file() and load_cache:
+ try:
+ return self.load_from_cache(cache_path)
+ except Exception as e:
+ logger.warn(f"The cache file was corrupted ({e}): {cache_path}")
+ exit()
+ detections = self.forward(im)
+ self.save_to_cache(detections, cache_path)
+ return detections
diff --git a/dp2/detection/box_utils.py b/dp2/detection/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3e6b5a84f071c1b9e9a74f6adbbe49b3cd7610
--- /dev/null
+++ b/dp2/detection/box_utils.py
@@ -0,0 +1,104 @@
+import numpy as np
+
+
+def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
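+ # Expands the shorter axis so that height / width == target_aspect_ratio,
+ # e.g. a 100x50 (h x w) box with target_aspect_ratio=1 grows to 100x100 around the same center.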
+ x0, y0, x1, y1 = [int(_) for _ in bbox]
+ h, w = y1 - y0, x1 - x0
+ cur_ratio = h / w
+
+ if cur_ratio == target_aspect_ratio:
+ return [x0, y0, x1, y1]
+ if cur_ratio < target_aspect_ratio:
+ target_height = int(w*target_aspect_ratio)
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
+ else:
+ target_width = int(h/target_aspect_ratio)
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
+ return x0, y0, x1, y1
+
+
+def expand_axis(start, end, target_width, limit):
+ # Can return a bbox outside of limit
+ cur_width = end - start
+ start = start - (target_width-cur_width)//2
+ end = end + (target_width-cur_width)//2
+ if end - start != target_width:
+ end += 1
+ assert end - start == target_width
+ if start < 0 and end > limit:
+ return start, end
+ if start < 0 and end < limit:
+ to_shift = min(0 - start, limit - end)
+ start += to_shift
+ end += to_shift
+ if end > limit and start > 0:
+ to_shift = min(end - limit, start)
+ end -= to_shift
+ start -= to_shift
+ assert end - start == target_width
+ return start, end
+
+
+def expand_box(bbox, imshape, mask, percentage_background: float):
+ assert isinstance(bbox[0], int)
+ assert 0 < percentage_background < 1
+ # Fraction of the box area covered by the segmentation mask.
+ mask_pixels = mask.long().sum().cpu()
+ total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+ percentage_mask = mask_pixels / total_pixels
+ if (1 - percentage_mask) > percentage_background:
+ return bbox
+ target_pixels = mask_pixels / (1 - percentage_background)
+ x0, y0, x1, y1 = bbox
+ H = y1 - y0
+ W = x1 - x0
+ p = np.sqrt(target_pixels/(H*W))
+ target_width = int(np.ceil(p * W))
+ target_height = int(np.ceil(p * H))
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
+ return [x0, y0, x1, y1]
+
+
+def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
+ x0, y0, x1, y1 = bbox_XYXY
+ H = y1 - y0
+ W = x1 - x0
+ expansion = int(((H*W)**0.5) * percentage)
+ new_width = W + expansion
+ new_height = H + expansion
+ x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
+ y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
+ return [x0, y0, x1, y1]
+
+
+def get_expanded_bbox(
+ bbox_XYXY,
+ imshape,
+ mask,
+ percentage_background: float,
+ axis_minimum_expansion: float,
+ target_aspect_ratio: float):
+ bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
+ # Expand each axis of the bounding box by a minimum percentage
+ bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
+ # Find the minimum bbox with the aspect ratio. Can be outside of imshape
+ bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
+ # Expands square box such that X% of the bbox is background
+ bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
+ assert isinstance(bbox_XYXY[0], (int, np.int64))
+ return bbox_XYXY
+
+
+def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
+ def area_inside_ratio(bbox, imshape):
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+ area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1]))
+ return area_inside / area
+ ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
+ area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+ if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
+ return False
+ if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
+ return False
+ return True
diff --git a/dp2/detection/box_utils_fdf.py b/dp2/detection/box_utils_fdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1af1e3012f9a594fd43d2b17a9dabdb4cda6ba2
--- /dev/null
+++ b/dp2/detection/box_utils_fdf.py
@@ -0,0 +1,202 @@
+"""
+The FDF dataset expands bounding boxes differently from what is used for the CSE-based detections.
+"""
+
+import numpy as np
+
+
+def quadratic_bounding_box(x0, y0, width, height, imshape):
+ # We assume that a square crop can be created without
+ # shrinking either of the sides.
+ assert width <= min(imshape[:2])
+ assert height <= min(imshape[:2])
+ min_side = min(height, width)
+ if height != width:
+ side_diff = abs(height - width)
+ # Want to extend the shortest side
+ if min_side == height:
+ # Vertical side
+ height += side_diff
+ if height > imshape[0]:
+ # Take full frame, and shrink width
+ y0 = 0
+ height = imshape[0]
+
+ side_diff = abs(height - width)
+ width -= side_diff
+ x0 += side_diff // 2
+ else:
+ y0 -= side_diff // 2
+ y0 = max(0, y0)
+ else:
+ # Horizontal side
+ width += side_diff
+ if width > imshape[1]:
+ # Take full frame width, and shrink height
+ x0 = 0
+ width = imshape[1]
+
+ side_diff = abs(height - width)
+ height -= side_diff
+ y0 += side_diff // 2
+ else:
+ x0 -= side_diff // 2
+ x0 = max(0, x0)
+ # Shift the bbox back if it extends outside the image.
+ x1 = x0 + width
+ y1 = y0 + height
+ if imshape[1] < x1:
+ diff = x1 - imshape[1]
+ x0 -= diff
+ if imshape[0] < y1:
+ diff = y1 - imshape[0]
+ y0 -= diff
+ assert x0 >= 0, "Bounding box outside image."
+ assert y0 >= 0, "Bounding box outside image."
+ assert x0 + width <= imshape[1], "Bounding box outside image."
+ assert y0 + height <= imshape[0], "Bounding box outside image."
+ return x0, y0, width, height
+
+
+def expand_bounding_box(bbox, percentage, imshape):
+ orig_bbox = bbox.copy()
+ x0, y0, x1, y1 = bbox
+ width = x1 - x0
+ height = y1 - y0
+ x0, y0, width, height = quadratic_bounding_box(
+ x0, y0, width, height, imshape)
+ expanding_factor = int(max(height, width) * percentage)
+
+ possible_max_expansion = [(imshape[0] - width) // 2,
+ (imshape[1] - height) // 2,
+ expanding_factor]
+
+ expanding_factor = min(possible_max_expansion)
+ # Expand height
+
+ if expanding_factor > 0:
+
+ y0 = y0 - expanding_factor
+ y0 = max(0, y0)
+
+ height += expanding_factor * 2
+ if height > imshape[0]:
+ y0 -= (imshape[0] - height)
+ height = imshape[0]
+
+ if height + y0 > imshape[0]:
+ y0 -= (height + y0 - imshape[0])
+
+ # Expand width
+ x0 = x0 - expanding_factor
+ x0 = max(0, x0)
+
+ width += expanding_factor * 2
+ if width > imshape[1]:
+ x0 -= (imshape[1] - width)
+ width = imshape[1]
+
+ if width + x0 > imshape[1]:
+ x0 -= (width + x0 - imshape[1])
+ y1 = y0 + height
+ x1 = x0 + width
+ assert y0 >= 0, "Y0 is minus"
+ assert height <= imshape[0], "Height is larger than image."
+ assert x0 + width <= imshape[1]
+ assert y0 + height <= imshape[0]
+ assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
+ assert x0 >= 0, "Y0 is minus"
+ assert width <= imshape[1], "Height is larger than image."
+ # Check that original bbox is within new
+ x0_o, y0_o, x1_o, y1_o = orig_bbox
+ assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}"
+ assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}"
+ assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}"
+ assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}"
+
+ x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
+ x1 = x0 + width
+ y1 = y0 + height
+ return np.array([x0, y0, x1, y1])
+
+
+def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
+ keypoint = keypoint[:, :3] # only nose + eyes are relevant
+ kp_X = keypoint[0, :]
+ kp_Y = keypoint[1, :]
+ within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
+ within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
+ return within_X and within_Y
+
+
+def expand_bbox_simple(bbox, percentage):
+ x0, y0, x1, y1 = bbox.astype(float)
+ width = x1 - x0
+ height = y1 - y0
+ x_c = int(x0) + width // 2
+ y_c = int(y0) + height // 2
+ avg_size = max(width, height)
+ new_width = avg_size * (1 + percentage)
+ x0 = x_c - new_width // 2
+ y0 = y_c - new_width // 2
+ x1 = x_c + new_width // 2
+ y1 = y_c + new_width // 2
+ return np.array([x0, y0, x1, y1]).astype(int)
+
+
+def pad_image(im, bbox, pad_value):
+ x0, y0, x1, y1 = bbox
+ if x0 < 0:
+ pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
+ dtype=np.uint8) + pad_value
+ im = np.concatenate((pad_im, im), axis=1)
+ x1 += abs(x0)
+ x0 = 0
+ if y0 < 0:
+ pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
+ dtype=np.uint8) + pad_value
+ im = np.concatenate((pad_im, im), axis=0)
+ y1 += abs(y0)
+ y0 = 0
+ if x1 >= im.shape[1]:
+ pad_im = np.zeros(
+ (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
+ dtype=np.uint8) + pad_value
+ im = np.concatenate((im, pad_im), axis=1)
+ if y1 >= im.shape[0]:
+ pad_im = np.zeros(
+ (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
+ dtype=np.uint8) + pad_value
+ im = np.concatenate((im, pad_im), axis=0)
+ return im[y0:y1, x0:x1]
+
+
+def clip_box(bbox, im):
+ bbox[0] = max(0, bbox[0])
+ bbox[1] = max(0, bbox[1])
+ bbox[2] = min(im.shape[1] - 1, bbox[2])
+ bbox[3] = min(im.shape[0] - 1, bbox[3])
+ return bbox
+
+
+def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
+ outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
+ if simple_expand or (outside_im and pad_im):
+ return pad_image(im, bbox, pad_value)
+ bbox = clip_box(bbox, im)
+ x0, y0, x1, y1 = bbox
+ return im[y0:y1, x0:x1]
+
+
+def expand_bbox(
+ bbox_ltrb, imshape, simple_expand, default_to_simple=False,
+ expansion_factor=0.35):
+ assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
+ bbox = bbox_ltrb.astype(float)
+ # FDF256 uses simple expand with ratio 0.4
+ if simple_expand:
+ return expand_bbox_simple(bbox, 0.4)
+ try:
+ return expand_bounding_box(bbox, expansion_factor, imshape)
+ except AssertionError:
+ return expand_bbox_simple(bbox, expansion_factor * 2)
diff --git a/dp2/detection/cse_mask_face_detector.py b/dp2/detection/cse_mask_face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eccfc4ac885cfcca47fef389b99c1e35685579b
--- /dev/null
+++ b/dp2/detection/cse_mask_face_detector.py
@@ -0,0 +1,116 @@
+import torch
+import lzma
+import tops
+from pathlib import Path
+from dp2.detection.base import BaseDetector
+from .utils import combine_cse_maskrcnn_dets
+from face_detection import build_detector as build_face_detector
+from .models.cse import CSEDetector
+from .models.mask_rcnn import MaskRCNNDetector
+from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
+from tops import logger
+
+
+def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
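+ # Returns a bool tensor where entry i is True if box1[i] is fully contained in at least one box in box2.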
+ assert len(box1.shape) == 2
+ assert len(box2.shape) == 2
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
+ # This can be batched
+ for i, box in enumerate(box1):
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
+ box1_inside[i] = is_outside.logical_not().any()
+ return box1_inside
+
+
+class CSeMaskFaceDetector(BaseDetector):
+
+ def __init__(
+ self,
+ mask_rcnn_cfg,
+ face_detector_cfg: dict,
+ cse_cfg: dict,
+ face_post_process_cfg: dict,
+ cse_post_process_cfg,
+ score_threshold: float,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
+ if "confidence_threshold" not in face_detector_cfg:
+ face_detector_cfg["confidence_threshold"] = score_threshold
+ if "score_thres" not in cse_cfg:
+ cse_cfg["score_thres"] = score_threshold
+ self.cse_detector = CSEDetector(**cse_cfg)
+ self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
+ self.cse_post_process_cfg = cse_post_process_cfg
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
+ self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
+ self.face_post_process_cfg = face_post_process_cfg
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def _detect_faces(self, im: torch.Tensor):
+ H, W = im.shape[1:]
+ im = im.float() - self.face_mean
+ im = self.face_detector.resize(im[None], 1.0)
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
+ boxes_XYXY[:, [0, 2]] *= W
+ boxes_XYXY[:, [1, 3]] *= H
+ return boxes_XYXY.round().long()
+
+ def load_from_cache(self, cache_path: Path):
+ logger.log(f"Loading detection from cache path: {cache_path}",)
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp, map_location="cpu")
+ kwargs = dict(
+ post_process_cfg=self.cse_post_process_cfg,
+ embed_map=self.cse_detector.embed_map,
+ **self.face_post_process_cfg
+ )
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ maskrcnn_dets = self.mask_rcnn(im)
+ cse_dets = self.cse_detector(im)
+ embed_map = self.cse_detector.embed_map
+ print("Calling face detector.")
+ face_boxes = self._detect_faces(im).cpu()
+ maskrcnn_person = {
+ k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
+ }
+ maskrcnn_other = {
+ k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
+ }
+ maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
+ combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
+ maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
+
+ persons_with_cse = CSEPersonDetection(
+ combined_segmentation, cse_dets, **self.cse_post_process_cfg,
+ embed_map=embed_map, orig_imshape_CHW=im.shape
+ )
+ persons_with_cse.pre_process()
+ not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
+ persons_without_cse = PersonDetection(
+ maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
+ orig_imshape_CHW=im.shape
+ )
+ persons_without_cse.pre_process()
+
+ face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
+ box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
+ )
+ face_boxes = face_boxes[face_boxes_covered.logical_not()]
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
+
+ # Order matters. The anonymizer will anonymize FIFO.
+ # Later detections will overwrite.
+ all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
+ return all_detections
diff --git a/dp2/detection/deep_privacy1_detector.py b/dp2/detection/deep_privacy1_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..9657e7aaa993033c559310edcab0cc54ee2a03e1
--- /dev/null
+++ b/dp2/detection/deep_privacy1_detector.py
@@ -0,0 +1,106 @@
+import torch
+import tops
+import lzma
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from .base import BaseDetector
+from face_detection import build_detector as build_face_detector
+from .structures import FaceDetection
+from tops import logger
+from pathlib import Path
+
+def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
+ keypoint = keypoint[:3, :] # only nose + eyes are relevant
+ kp_X = keypoint[:, 0]
+ kp_Y = keypoint[:, 1]
+ within_X = (kp_X >= x0).all() and (kp_X <= x1).all()
+ within_Y = (kp_Y >= y0).all() and (kp_Y <= y1).all()
+ return within_X and within_Y
+
+
+def match_bbox_keypoint(bounding_boxes, keypoints):
+ """
+ bounding_boxes shape: [N, 4], (x0, y0, x1, y1)
+ keypoints: [N persons, K keypoints, (x, y)]
+ """
+ if len(bounding_boxes) == 0 or len(keypoints) == 0:
+ return torch.empty((0, 4)), torch.empty((0, 7, 2))
+ assert bounding_boxes.shape[1] == 4,\
+ f"Shape was : {bounding_boxes.shape}"
+ assert keypoints.shape[-1] == 2,\
+ f"Expected (x,y) in last axis, got: {keypoints.shape}"
+ assert keypoints.shape[1] in (5, 7),\
+ f"Expeted 5 or 7 keypoints. Keypoint shape was: {keypoints.shape}"
+
+ matches = []
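+ # Greedy matching: each box is assigned the first unmatched keypoint set whose nose/eye keypoints fall inside it.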
+ for bbox_idx, bbox in enumerate(bounding_boxes):
+ keypoint = None
+ for kp_idx, keypoint in enumerate(keypoints):
+ if kp_idx in (x[1] for x in matches):
+ continue
+ if is_keypoint_within_bbox(*bbox, keypoint):
+ matches.append((bbox_idx, kp_idx))
+ break
+ keypoint_idx = [x[1] for x in matches]
+ bbox_idx = [x[0] for x in matches]
+ return bounding_boxes[bbox_idx], keypoints[keypoint_idx]
+
+
+class DeepPrivacy1Detector(BaseDetector):
+
+ def __init__(self,
+ keypoint_threshold: float,
+ face_detector_cfg,
+ score_threshold: float,
+ face_post_process_cfg,
+ **kwargs):
+ super().__init__(**kwargs)
+ self.keypoint_detector = tops.to_cuda(keypointrcnn_resnet50_fpn(
+ weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1).eval())
+ self.keypoint_threshold = keypoint_threshold
+ self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
+ self.face_post_process_cfg = face_post_process_cfg
+
+ @torch.no_grad()
+ def _detect_faces(self, im: torch.Tensor):
+ H, W = im.shape[1:]
+ im = im.float() - self.face_mean
+ im = self.face_detector.resize(im[None], 1.0)
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
+ boxes_XYXY[:, [0, 2]] *= W
+ boxes_XYXY[:, [1, 3]] *= H
+ return boxes_XYXY.round().long().cpu()
+
+ @torch.no_grad()
+ def _detect_keypoints(self, img: torch.Tensor):
+ img = img.float() / 255
+ outputs = self.keypoint_detector([img])
+
+ # Shape: [N persons, K keypoints, (x,y,visibility)]
+ keypoints = outputs[0]["keypoints"]
+ scores = outputs[0]["scores"]
+ assert list(scores) == sorted(list(scores))[::-1]
+ mask = scores >= self.keypoint_threshold
+ keypoints = keypoints[mask, :, :2]
+ return keypoints[:, :7, :2]
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ face_boxes = self._detect_faces(im)
+ keypoints = self._detect_keypoints(im)
+ face_boxes, keypoints = match_bbox_keypoint(face_boxes, keypoints)
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg, keypoints=keypoints)
+ return [face_boxes]
+
+ def load_from_cache(self, cache_path: Path):
+ logger.log(f"Loading detection from cache path: {cache_path}",)
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp, map_location="cpu")
+ kwargs = self.face_post_process_cfg
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
\ No newline at end of file
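# A toy run of the greedy box/keypoint matching implemented by match_bbox_keypoint
# above, restated inline so it does not require the dp2 / face_detection imports.
# Only the nose + eye keypoints (first three) are checked, as in the detector.
import torch

boxes = torch.tensor([[0, 0, 10, 10], [20, 20, 40, 40]])      # XYXY
keypoints = torch.tensor([[[25., 25.]] * 7, [[2., 3.]] * 7])   # [N persons, 7, (x, y)]

matches = []
for b_idx, (x0, y0, x1, y1) in enumerate(boxes):
    for k_idx, kp in enumerate(keypoints):
        if k_idx in (m[1] for m in matches):
            continue
        nose_eyes = kp[:3]
        if (nose_eyes[:, 0] >= x0).all() and (nose_eyes[:, 0] <= x1).all() \
                and (nose_eyes[:, 1] >= y0).all() and (nose_eyes[:, 1] <= y1).all():
            matches.append((b_idx, k_idx))
            break
print(matches)  # [(0, 1), (1, 0)]: each face box is paired with the keypoints inside it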
diff --git a/dp2/detection/face_detector.py b/dp2/detection/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..8354d7b71ffd9147d7a7c1be462b6a1c30626672
--- /dev/null
+++ b/dp2/detection/face_detector.py
@@ -0,0 +1,62 @@
+import torch
+import lzma
+import tops
+from pathlib import Path
+from dp2.detection.base import BaseDetector
+from face_detection import build_detector as build_face_detector
+from .structures import FaceDetection
+from tops import logger
+
+
+def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
+ assert len(box1.shape) == 2
+ assert len(box2.shape) == 2
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
+ # This can be batched
+ for i, box in enumerate(box1):
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
+ box1_inside[i] = is_outside.logical_not().any()
+ return box1_inside
+
+
+class FaceDetector(BaseDetector):
+
+ def __init__(
+ self,
+ face_detector_cfg: dict,
+ score_threshold: float,
+ face_post_process_cfg: dict,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
+ self.face_post_process_cfg = face_post_process_cfg
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def _detect_faces(self, im: torch.Tensor):
+ H, W = im.shape[1:]
+ im = im.float() - self.face_mean
+ im = self.face_detector.resize(im[None], 1.0)
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
+ boxes_XYXY[:, [0, 2]] *= W
+ boxes_XYXY[:, [1, 3]] *= H
+ return boxes_XYXY.round().long().cpu()
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ face_boxes = self._detect_faces(im)
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
+ return [face_boxes]
+
+ def load_from_cache(self, cache_path: Path):
+ logger.log(f"Loading detection from cache path: {cache_path}")
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp)
+ return [
+ state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
+ ]
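# A quick check of the box1_inside_box2 convention used above: a box only counts as
# covered when it lies strictly inside at least one box of the second set (XYXY).
# Restated inline so it runs without the dp2 / face_detection dependencies.
import torch

def strictly_inside(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    inside = torch.zeros(len(box1), dtype=torch.bool)
    for i, b in enumerate(box1):
        fits = (b[0] > box2[:, 0]) & (b[1] > box2[:, 1]) \
            & (b[2] < box2[:, 2]) & (b[3] < box2[:, 3])
        inside[i] = fits.any()
    return inside

face_boxes = torch.tensor([[12, 12, 18, 18], [50, 50, 60, 60]])
person_boxes = torch.tensor([[10, 10, 30, 30]])
print(strictly_inside(face_boxes, person_boxes))  # tensor([ True, False])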
diff --git a/dp2/detection/models/__init__.py b/dp2/detection/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dp2/detection/models/cse.py b/dp2/detection/models/cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9fbf75fc8fd6c1905993316c03383b9e935564
--- /dev/null
+++ b/dp2/detection/models/cse.py
@@ -0,0 +1,134 @@
+import torch
+from typing import List
+import tops
+from torchvision.transforms.functional import InterpolationMode, resize
+from densepose.data.utils import get_class_to_mesh_name_mapping
+from densepose import add_densepose_config
+from densepose.structures import DensePoseEmbeddingPredictorOutput
+from densepose.vis.extractor import DensePoseOutputsExtractor
+from densepose.modeling import build_densepose_embedder
+from detectron2.config import get_cfg
+from detectron2.data.transforms import ResizeShortestEdge
+from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
+from detectron2.modeling import build_model
+
+
+model_urls = {
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
+}
+
+
+def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
+ assert len(S.shape) == 3
+ H, W = imshape
+ N = len(boxes_XYXY)
+ segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
+ boxes_XYXY = boxes_XYXY.long()
+ for i in range(N):
+ x0, y0, x1, y1 = boxes_XYXY[i]
+ assert x0 >= 0 and y0 >= 0
+ assert x1 <= imshape[1]
+ assert y1 <= imshape[0]
+ h = y1 - y0
+ w = x1 - x0
+ segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
+ return segmentation
+
+
+class CSEDetector:
+
+ def __init__(
+ self,
+ cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
+ cfg_2_download: List[str] = [
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
+ score_thres: float = 0.9,
+ nms_thresh: float = None,
+ ) -> None:
+ with tops.logger.capture_log_stdout():
+ cfg = get_cfg()
+ self.device = tops.get_device()
+ add_densepose_config(cfg)
+ cfg_path = tops.download_file(cfg_url)
+ for p in cfg_2_download:
+ tops.download_file(p)
+ with tops.logger.capture_log_stdout():
+ cfg.merge_from_file(cfg_path)
+ assert cfg_url in model_urls, cfg_url
+ model_path = tops.download_file(model_urls[cfg_url])
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
+ if nms_thresh is not None:
+ cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
+ cfg.MODEL.WEIGHTS = str(model_path)
+ cfg.MODEL.DEVICE = str(self.device)
+ cfg.freeze()
+ with tops.logger.capture_log_stdout():
+ self.model = build_model(cfg)
+ self.model.eval()
+ DetectionCheckpointer(self.model).load(str(model_path))
+ self.input_format = cfg.INPUT.FORMAT
+ self.densepose_extractor = DensePoseOutputsExtractor()
+ self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
+
+ self.embedder = build_densepose_embedder(cfg)
+ self.mesh_vertex_embeddings = {
+ mesh_name: self.embedder(mesh_name).to(self.device)
+ for mesh_name in self.class_to_mesh_name.values()
+ if self.embedder.has_embeddings(mesh_name)
+ }
+ self.cfg = cfg
+ self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
+ tops.logger.log("CSEDetector built.")
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def resize_im(self, im):
+ H, W = im.shape[1:]
+ newH, newW = ResizeShortestEdge.get_output_shape(
+ H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
+ return resize(
+ im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
+
+ @torch.no_grad()
+ def forward(self, im):
+ assert im.dtype == torch.uint8
+ if self.input_format == "BGR":
+ im = im.flip(0)
+ H, W = im.shape[1:]
+ im = self.resize_im(im)
+ output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
+ scores = output.get("scores")
+ if len(scores) == 0:
+ return dict(
+ instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
+ instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
+ embed_map=self.mesh_vertex_embeddings["smpl_27554"],
+ bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
+ im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
+ scores=torch.empty((0), dtype=torch.float, device=im.device)
+ )
+ pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
+ assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
+ S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes)
+ E = pred_densepose.embedding
+ mesh_name = self.class_to_mesh_name[classes[0]]
+ assert mesh_name == "smpl_27554"
+ x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
+ boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
+ boxes_XYXY = boxes_XYXY.round_().long()
+
+ non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
+ S = S[non_empty_boxes]
+ E = E[non_empty_boxes]
+ boxes_XYXY = boxes_XYXY[non_empty_boxes]
+ scores = scores[non_empty_boxes]
+ im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
+ return dict(
+ instance_segmentation=S, instance_embedding=E,
+ bbox_XYXY=boxes_XYXY,
+ im_segmentation=im_segmentation,
+ scores=scores.view(-1))
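# A small sketch of the box handling at the end of CSEDetector.forward: DensePose
# returns XYWH boxes, which are converted to XYXY, rounded, and degenerate
# (zero-width or zero-height) boxes are dropped before pasting masks to image size.
import torch

boxes_xywh = torch.tensor([[10., 20., 30., 40.], [5., 5., 0., 7.]])
x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
boxes_XYXY = torch.stack((x0, y0, x0 + w, y0 + h), dim=-1).round_().long()
non_empty = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
print(boxes_XYXY[non_empty])  # tensor([[10, 20, 40, 60]]); the zero-width box is removed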
diff --git a/dp2/detection/models/keypoint_maskrcnn.py b/dp2/detection/models/keypoint_maskrcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e35a9f44a7133e9a0418f5b73d2bc39229139b
--- /dev/null
+++ b/dp2/detection/models/keypoint_maskrcnn.py
@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
+from detectron2.data.transforms import ResizeShortestEdge
+from detectron2.structures import Instances
+from detectron2 import model_zoo
+from detectron2.config import instantiate
+from detectron2.config import LazyCall as L
+from PIL import Image
+import tops
+import functools
+from torchvision.transforms.functional import resize
+
+
+def get_rn50_fpn_keypoint_rcnn(weight_path: str):
+ from detectron2.modeling.poolers import ROIPooler
+ from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
+ from detectron2.layers import ShapeSpec
+ model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
+ model.roi_heads.update(
+ num_classes=1,
+ keypoint_in_features=["p2", "p3", "p4", "p5"],
+ keypoint_pooler=L(ROIPooler)(
+ output_size=14,
+ scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
+ sampling_ratio=0,
+ pooler_type="ROIAlignV2",
+ ),
+ keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
+ input_shape=ShapeSpec(channels=256, width=14, height=14),
+ num_keypoints=17,
+ conv_dims=[512] * 8,
+ loss_normalizer="visible",
+ ),
+ )
+
+ # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
+ # 1000 proposals per-image is found to hurt box AP.
+ # Therefore we increase it to 1500 per-image.
+ model.proposal_generator.post_nms_topk = (1500, 1000)
+
+ # Keypoint AP degrades (though box AP improves) when using plain L1 loss
+ model.roi_heads.box_predictor.smooth_l1_beta = 0.5
+ model = instantiate(model)
+
+ dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
+ test_transform = instantiate(dataloader.test.mapper.augmentations)
+ DetectionCheckpointer(model).load(weight_path)
+ return model, test_transform
+
+
+models = {
+ "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
+}
+
+
+class KeypointMaskRCNN:
+
+ def __init__(self, model_name: str, score_threshold: float) -> None:
+ assert model_name in models, f"Did not find {model_name} in models"
+ model, test_transform = models[model_name]()
+ self.model = model.eval().to(tops.get_device())
+ if isinstance(self.model.roi_heads, CascadeROIHeads):
+ for head in self.model.roi_heads.box_predictors:
+ assert hasattr(head, "test_score_thresh")
+ head.test_score_thresh = score_threshold
+ else:
+ assert isinstance(self.model.roi_heads, StandardROIHeads)
+ assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
+ self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
+
+ self.test_transform = test_transform
+ assert len(self.test_transform) == 1
+ self.test_transform = self.test_transform[0]
+ assert isinstance(self.test_transform, ResizeShortestEdge)
+ assert self.test_transform.interp == Image.BILINEAR
+ self.image_format = self.model.input_format
+
+ def resize_im(self, im):
+ H, W = im.shape[-2:]
+ if self.test_transform.is_range:
+ size = np.random.randint(
+ self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
+ else:
+ size = np.random.choice(self.test_transform.short_edge_length)
+ newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
+ return resize(
+ im, (newH, newW), antialias=True)
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ assert im.ndim == 3
+ if self.image_format == "BGR":
+ im = im.flip(0)
+ H, W = im.shape[-2:]
+ im = im.float()
+ im = self.resize_im(im)
+
+ inputs = dict(image=im, height=H, width=W)
+ # instances contains
+ # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
+ instances = self.model([inputs])[0]["instances"]
+ return dict(
+ scores=instances.get("scores").cpu(),
+ segmentation=instances.get("pred_masks").cpu(),
+ keypoints=instances.get("pred_keypoints").cpu()
+ )
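# The resize_im step above rescales so the shorter image side reaches the target size
# without the longer side exceeding max_size. The arithmetic of detectron2's
# ResizeShortestEdge.get_output_shape is restated here for intuition; the helper
# below is not the library function, just the same computation.
def shortest_edge_shape(h: int, w: int, size: int, max_size: int):
    scale = size / min(h, w)
    if max(h, w) * scale > max_size:
        scale = max_size / max(h, w)
    return int(round(h * scale)), int(round(w * scale))

print(shortest_edge_shape(1080, 1920, 800, 1333))  # (750, 1333): capped by max_size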
diff --git a/dp2/detection/models/mask_rcnn.py b/dp2/detection/models/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed64706c0036d6dcc2355c8ce2f830bd8a22c3e3
--- /dev/null
+++ b/dp2/detection/models/mask_rcnn.py
@@ -0,0 +1,78 @@
+import torch
+import tops
+from detectron2.modeling import build_model
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.structures import Boxes
+from detectron2.data import MetadataCatalog
+from detectron2 import model_zoo
+from typing import Dict
+from detectron2.data.transforms import ResizeShortestEdge
+from torchvision.transforms.functional import resize
+
+
+model_urls = {
+ "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
+
+}
+
+
+class MaskRCNNDetector:
+
+ def __init__(
+ self,
+ cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
+ score_thres: float = 0.9,
+ class_filter=["person"], # ["car", "bicycle","truck", "bus", "backpack"]
+ fp16_inference: bool = False
+ ) -> None:
+ cfg = model_zoo.get_config(cfg_name)
+ cfg.MODEL.DEVICE = str(tops.get_device())
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
+ cfg.freeze()
+ self.cfg = cfg
+ with tops.logger.capture_log_stdout():
+ self.model = build_model(cfg)
+ DetectionCheckpointer(self.model).load(model_urls[cfg_name])
+ self.model.eval()
+ self.input_format = cfg.INPUT.FORMAT
+ self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
+ self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
+ self.person_class = self.class_names.index("person")
+ self.fp16_inference = fp16_inference
+ tops.logger.log("Mask R-CNN built.")
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def resize_im(self, im):
+ H, W = im.shape[1:]
+ newH, newW = ResizeShortestEdge.get_output_shape(
+ H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
+ return resize(
+ im, (newH, newW), antialias=True)
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ if self.input_format == "BGR":
+ im = im.flip(0)
+ else:
+ assert self.input_format == "RGB"
+ H, W = im.shape[-2:]
+ im = self.resize_im(im)
+ with torch.cuda.amp.autocast(enabled=self.fp16_inference):
+ output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
+ scores = output.get("scores")
+ N = len(scores)
+ classes = output.get("pred_classes")
+ idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
+ classes = classes[idx2keep]
+ assert isinstance(output.get("pred_boxes"), Boxes)
+ segmentation = output.get("pred_masks")[idx2keep]
+ assert segmentation.dtype == torch.bool
+ is_person = classes == self.person_class
+ return {
+ "scores": output.get("scores")[idx2keep],
+ "segmentation": segmentation,
+ "classes": output.get("pred_classes")[idx2keep],
+ "is_person": is_person
+ }
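# A toy version of the class filtering above: detections are kept only when their
# predicted class index is in class_to_keep, and is_person marks which of the kept
# detections are people. COCO indexing is assumed here, with "person" at index 0.
import torch

pred_classes = torch.tensor([0, 2, 0, 56])  # person, car, person, chair (COCO ids)
class_to_keep = {0}
person_class = 0
idx2keep = [i for i in range(len(pred_classes)) if pred_classes[i].item() in class_to_keep]
kept_classes = pred_classes[idx2keep]
print(idx2keep, kept_classes == person_class)  # [0, 2] tensor([True, True])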
diff --git a/dp2/detection/models/vit_pose/backbone.py b/dp2/detection/models/vit_pose/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78712dd6131b02ea674fac452852274ae0dc5c2
--- /dev/null
+++ b/dp2/detection/models/vit_pose/backbone.py
@@ -0,0 +1,311 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self):
+ return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None,):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.dim = dim
+
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, attn_head_dim=None
+ ):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(
+ patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+# @BACKBONES.register_module()
+class ViT(nn.Module):
+
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+ frozen_stages=-1, ratio=1, last_norm=True,
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+ ):
+ # Protect mutable default arguments
+ super(ViT, self).__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.use_checkpoint = use_checkpoint
+ self.patch_padding = patch_padding
+ self.freeze_attn = freeze_attn
+ self.freeze_ffn = freeze_ffn
+ self.depth = depth
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+ num_patches = self.patch_embed.num_patches
+
+ # since the pretraining model has class token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ )
+ for i in range(depth)])
+
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if self.freeze_attn:
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.attn.eval()
+ m.norm1.eval()
+ for param in m.attn.parameters():
+ param.requires_grad = False
+ for param in m.norm1.parameters():
+ param.requires_grad = False
+
+ if self.freeze_ffn:
+ self.pos_embed.requires_grad = False
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.mlp.eval()
+ m.norm2.eval()
+ for param in m.mlp.parameters():
+ param.requires_grad = False
+ for param in m.norm2.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super().init_weights(pretrained, patch_padding=self.patch_padding)
+
+ if pretrained is None:
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.pos_embed is not None:
+            # Also add the class-token slice so every pos_embed element stays in the
+            # graph (needed for multi-GPU training); since that first (sin-cos) entry
+            # is zero, it does not change the result.
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.last_norm(x)
+
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+ return xp
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
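# Quick shape arithmetic for the ViT backbone above, using the 256x192 input and
# 16x16 patches from the ViTPose configs: the patch embedding yields a 16x12 token
# grid (192 tokens), pos_embed carries one extra class-token entry, and
# forward_features reshapes the tokens back into a [B, C, 16, 12] feature map.
img_h, img_w, patch = 256, 192, 16
hp, wp = img_h // patch, img_w // patch
num_patches = hp * wp
print(hp, wp, num_patches, num_patches + 1)  # 16 12 192 193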
diff --git a/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py b/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c08a76e4dab6d89f4b76183f327005211afd35
--- /dev/null
+++ b/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py
@@ -0,0 +1,505 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+# from mmpose.core.evaluation.top_down_eval import keypoints_from_heatmaps
+
+
+class TopdownHeatmapBaseHead(nn.Module):
+ """Base class for top-down heatmap heads.
+
+    All top-down heatmap heads should subclass it.
+    All subclasses should overwrite:
+
+    Methods: `get_loss`, to calculate the loss.
+    Methods: `get_accuracy`, to calculate the accuracy.
+    Methods: `forward`, to run the forward pass.
+    Methods: `inference_model`, to run inference.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def get_loss(self, **kwargs):
+ """Gets the loss."""
+
+ @abstractmethod
+ def get_accuracy(self, **kwargs):
+ """Gets the accuracy."""
+
+ @abstractmethod
+ def forward(self, **kwargs):
+ """Forward function."""
+
+ @abstractmethod
+ def inference_model(self, **kwargs):
+ """Inference function."""
+
+ def decode(self, img_metas, output, **kwargs):
+ """Decode keypoints from heatmaps.
+
+ Args:
+ img_metas (list(dict)): Information about data augmentation
+ By default this includes:
+
+ - "image_file: path to the image file
+ - "center": center of the bbox
+ - "scale": scale of the bbox
+ - "rotation": rotation of the bbox
+ - "bbox_score": score of bbox
+ output (np.ndarray[N, K, H, W]): model predicted heatmaps.
+ """
+ # batch_size = len(img_metas)
+
+ # if 'bbox_id' in img_metas[0]:
+ # bbox_ids = []
+ # else:
+ # bbox_ids = None
+
+ # c = np.zeros((batch_size, 2), dtype=np.float32)
+ # s = np.zeros((batch_size, 2), dtype=np.float32)
+ # image_paths = []
+ # score = np.ones(batch_size)
+ # for i in range(batch_size):
+ # c[i, :] = img_metas[i]['center']
+ # s[i, :] = img_metas[i]['scale']
+ # image_paths.append(img_metas[i]['image_file'])
+
+ # if 'bbox_score' in img_metas[i]:
+ # score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1)
+ # if bbox_ids is not None:
+ # bbox_ids.append(img_metas[i]['bbox_id'])
+
+ # preds, maxvals = keypoints_from_heatmaps(
+ # output,
+ # c,
+ # s,
+ # unbiased=self.test_cfg.get('unbiased_decoding', False),
+ # post_process=self.test_cfg.get('post_process', 'default'),
+ # kernel=self.test_cfg.get('modulate_kernel', 11),
+ # valid_radius_factor=self.test_cfg.get('valid_radius_factor',
+ # 0.0546875),
+ # use_udp=self.test_cfg.get('use_udp', False),
+ # target_type=self.test_cfg.get('target_type', 'GaussianHeatmap'))
+
+ # all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32)
+ # all_boxes = np.zeros((batch_size, 6), dtype=np.float32)
+ # all_preds[:, :, 0:2] = preds[:, :, 0:2]
+ # all_preds[:, :, 2:3] = maxvals
+ # all_boxes[:, 0:2] = c[:, 0:2]
+ # all_boxes[:, 2:4] = s[:, 0:2]
+ # all_boxes[:, 4] = np.prod(s * 200.0, axis=1)
+ # all_boxes[:, 5] = score
+
+ # result = {}
+
+ # result['preds'] = all_preds
+ # result['boxes'] = all_boxes
+ # result['image_paths'] = image_paths
+ # result['bbox_ids'] = bbox_ids
+
+ return None
+
+ @staticmethod
+ def _get_deconv_cfg(deconv_kernel):
+ """Get configurations for deconv layers."""
+ if deconv_kernel == 4:
+ padding = 1
+ output_padding = 0
+ elif deconv_kernel == 3:
+ padding = 1
+ output_padding = 1
+ elif deconv_kernel == 2:
+ padding = 0
+ output_padding = 0
+ else:
+ raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')
+
+ return deconv_kernel, padding, output_padding
+
+
+def build_conv_layer(cfg, *args, **kwargs) -> nn.Module:
+ """LICENSE"""
+
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type != 'Conv2d':
+ raise KeyError(f'Unrecognized layer type {layer_type}')
+ else:
+ conv_layer = nn.Conv2d
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
+
+
+def build_upsample_layer(cfg, *args, **kwargs) -> nn.Module:
+
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type != 'deconv':
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = nn.ConvTranspose2d
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
+
+# @HEADS.register_module()
+
+
+class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead):
+ """Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple
+ Baselines for Human Pose Estimation and Tracking``.
+
+ TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers
+ and a simple conv2d layer.
+
+ Args:
+ in_channels (int): Number of input channels
+ out_channels (int): Number of output channels
+ num_deconv_layers (int): Number of deconv layers.
+ num_deconv_layers should >= 0. Note that 0 means
+ no deconv layers.
+        num_deconv_filters (list|tuple): Number of filters.
+            If num_deconv_layers > 0, the length of num_deconv_filters
+            should equal num_deconv_layers.
+        num_deconv_kernels (list|tuple): Kernel sizes.
+            If num_deconv_layers > 0, the length of num_deconv_kernels
+            should equal num_deconv_layers.
+ in_index (int|Sequence[int]): Input feature index. Default: 0
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ Default: None.
+
+            - 'resize_concat': Multiple feature maps will be resized to the
+              same size as the first one and then concatenated together.
+              Usually used in FCN head of HRNet.
+            - 'multiple_select': Multiple feature maps will be bundled into
+              a list and passed into the decode head.
+            - None: Only a single feature map is selected.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ loss_keypoint (dict): Config for keypoint loss. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256, 256),
+ num_deconv_kernels=(4, 4, 4),
+ extra=None,
+ in_index=0,
+ input_transform=None,
+ align_corners=False,
+ loss_keypoint=None,
+ train_cfg=None,
+ test_cfg=None,
+ upsample=0,):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.loss = None
+ self.upsample = upsample
+
+ self.train_cfg = {} if train_cfg is None else train_cfg
+ self.test_cfg = {} if test_cfg is None else test_cfg
+ self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
+
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.in_index = in_index
+ self.align_corners = align_corners
+
+ if extra is not None and not isinstance(extra, dict):
+ raise TypeError('extra should be dict or None.')
+
+ if num_deconv_layers > 0:
+ self.deconv_layers = self._make_deconv_layer(
+ num_deconv_layers,
+ num_deconv_filters,
+ num_deconv_kernels,
+ )
+ elif num_deconv_layers == 0:
+ self.deconv_layers = nn.Identity()
+ else:
+ raise ValueError(
+ f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
+
+ identity_final_layer = False
+ if extra is not None and 'final_conv_kernel' in extra:
+ assert extra['final_conv_kernel'] in [0, 1, 3]
+ if extra['final_conv_kernel'] == 3:
+ padding = 1
+ elif extra['final_conv_kernel'] == 1:
+ padding = 0
+ else:
+ # 0 for Identity mapping.
+ identity_final_layer = True
+ kernel_size = extra['final_conv_kernel']
+ else:
+ kernel_size = 1
+ padding = 0
+
+ if identity_final_layer:
+ self.final_layer = nn.Identity()
+ else:
+ conv_channels = num_deconv_filters[
+ -1] if num_deconv_layers > 0 else self.in_channels
+
+ layers = []
+ if extra is not None:
+ num_conv_layers = extra.get('num_conv_layers', 0)
+ num_conv_kernels = extra.get('num_conv_kernels',
+ [1] * num_conv_layers)
+
+ for i in range(num_conv_layers):
+ layers.append(
+ build_conv_layer(
+ dict(type='Conv2d'),
+ in_channels=conv_channels,
+ out_channels=conv_channels,
+ kernel_size=num_conv_kernels[i],
+ stride=1,
+ padding=(num_conv_kernels[i] - 1) // 2))
+ layers.append(
+ nn.BatchNorm2d(conv_channels)
+ )
+ layers.append(nn.ReLU(inplace=True))
+
+ layers.append(
+ build_conv_layer(
+ cfg=dict(type='Conv2d'),
+ in_channels=conv_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding))
+
+ if len(layers) > 1:
+ self.final_layer = nn.Sequential(*layers)
+ else:
+ self.final_layer = layers[0]
+
+ def get_loss(self, output, target, target_weight):
+ """Calculate top-down keypoint loss.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps weight: W
+
+ Args:
+ output (torch.Tensor[N,K,H,W]): Output heatmaps.
+ target (torch.Tensor[N,K,H,W]): Target heatmaps.
+ target_weight (torch.Tensor[N,K,1]):
+ Weights across different joint types.
+ """
+
+ losses = dict()
+
+ assert not isinstance(self.loss, nn.Sequential)
+ assert target.dim() == 4 and target_weight.dim() == 3
+ losses['heatmap_loss'] = self.loss(output, target, target_weight)
+
+ return losses
+
+ def get_accuracy(self, output, target, target_weight):
+ """Calculate accuracy for top-down keypoint loss.
+
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - heatmaps height: H
+ - heatmaps weight: W
+
+ Args:
+ output (torch.Tensor[N,K,H,W]): Output heatmaps.
+ target (torch.Tensor[N,K,H,W]): Target heatmaps.
+ target_weight (torch.Tensor[N,K,1]):
+ Weights across different joint types.
+ """
+
+ accuracy = dict()
+
+ if self.target_type == 'GaussianHeatmap':
+ _, avg_acc, _ = pose_pck_accuracy(
+ output.detach().cpu().numpy(),
+ target.detach().cpu().numpy(),
+ target_weight.detach().cpu().numpy().squeeze(-1) > 0)
+ accuracy['acc_pose'] = float(avg_acc)
+
+ return accuracy
+
+ def forward(self, x):
+ """Forward function."""
+ x = self._transform_inputs(x)
+ x = self.deconv_layers(x)
+ x = self.final_layer(x)
+ return x
+
+ def inference_model(self, x, flip_pairs=None):
+ """Inference function.
+
+ Returns:
+ output_heatmap (np.ndarray): Output heatmaps.
+
+ Args:
+ x (torch.Tensor[N,K,H,W]): Input features.
+ flip_pairs (None | list[tuple]):
+ Pairs of keypoints which are mirrored.
+ """
+ output = self.forward(x)
+
+ if flip_pairs is not None:
+ output_heatmap = flip_back(
+ output.detach().cpu().numpy(),
+ flip_pairs,
+ target_type=self.target_type)
+ # feature is not aligned, shift flipped heatmap for higher accuracy
+ if self.test_cfg.get('shift_heatmap', False):
+ output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
+ else:
+ output_heatmap = output.detach().cpu().numpy()
+ return output_heatmap
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform is not None, in_channels and in_index must be
+ list or tuple, with the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+
+            - 'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in FCN head of HRNet.
+            - 'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into the decode head.
+            - None: Only a single feature map is selected.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor] | Tensor): multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+        # mmcv's `resize` helper is unavailable here; F.interpolate is its direct
+        # equivalent for the calls below.
+        if not isinstance(inputs, list):
+            if self.upsample > 0:
+                inputs = F.interpolate(
+                    input=F.relu(inputs),
+                    scale_factor=self.upsample,
+                    mode='bilinear',
+                    align_corners=self.align_corners
+                )
+            return inputs
+
+        if self.input_transform == 'resize_concat':
+            inputs = [inputs[i] for i in self.in_index]
+            upsampled_inputs = [
+                F.interpolate(
+                    input=x,
+                    size=inputs[0].shape[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners) for x in inputs
+            ]
+            inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+ """Make deconv layers."""
+ if num_layers != len(num_filters):
+ error_msg = f'num_layers({num_layers}) ' \
+ f'!= length of num_filters({len(num_filters)})'
+ raise ValueError(error_msg)
+ if num_layers != len(num_kernels):
+ error_msg = f'num_layers({num_layers}) ' \
+ f'!= length of num_kernels({len(num_kernels)})'
+ raise ValueError(error_msg)
+
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i])
+
+ planes = num_filters[i]
+ layers.append(
+ build_upsample_layer(
+ dict(type='deconv'),
+ in_channels=self.in_channels,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=False))
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ self.in_channels = planes
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self):
+ """Initialize model weights."""
+ for _, m in self.deconv_layers.named_modules():
+ if isinstance(m, nn.ConvTranspose2d):
+ normal_init(m, std=0.001)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ for m in self.final_layer.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001, bias=0)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
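# Shape sketch for the deconv stack built by _make_deconv_layer: every
# ConvTranspose2d(kernel=4, stride=2, padding=1, output_padding=0) doubles the
# spatial size, so the two deconv layers used by the ViTPose heads turn a 16x12
# backbone feature map into the 64x48 heatmap grid. Channel counts here are
# arbitrary demo values, not the real 1024/1280 -> 256 configuration.
import torch
import torch.nn as nn

deconv = nn.Sequential(
    nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, output_padding=0, bias=False),
    nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, output_padding=0, bias=False),
)
x = torch.randn(1, 32, 16, 12)
print(deconv(x).shape)  # torch.Size([1, 32, 64, 48])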
diff --git a/dp2/detection/models/vit_pose/vit_pose.py b/dp2/detection/models/vit_pose/vit_pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd6eafd75fd5208625123c2c74fa7510c871136
--- /dev/null
+++ b/dp2/detection/models/vit_pose/vit_pose.py
@@ -0,0 +1,218 @@
+# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch
+from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead
+import torch
+from .backbone import ViT
+import torchvision.transforms.functional as F
+import torch.nn as nn
+import tops
+
+model_large = dict(
+ type="TopDown",
+ pretrained=None,
+ backbone=dict(
+ type="ViT",
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.5,
+ ),
+ keypoint_head=dict(
+ type="TopdownHeatmapSimpleHead",
+ in_channels=1024,
+ num_deconv_layers=2,
+ num_deconv_filters=(256, 256),
+ num_deconv_kernels=(4, 4),
+ extra=dict(
+ final_conv_kernel=1,
+ ),
+ out_channels=17,
+ loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True),
+ ),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process="default",
+ shift_heatmap=False,
+ target_type="GaussianHeatmap",
+ modulate_kernel=11,
+ use_udp=True,
+ ),
+)
+
+
+model_base = dict(
+ type="TopDown",
+ pretrained=None,
+ backbone=dict(
+ type="ViT",
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.3,
+ ),
+ keypoint_head=dict(
+ type="TopdownHeatmapSimpleHead",
+ in_channels=768,
+ num_deconv_layers=2,
+ num_deconv_filters=(256, 256),
+ num_deconv_kernels=(4, 4),
+ extra=dict(
+ final_conv_kernel=1,
+ ),
+ out_channels=17,
+ loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True),
+ ),
+ train_cfg=dict(),
+ test_cfg=dict(),
+)
+model_huge = dict(
+ type='TopDown',
+ pretrained=None,
+ backbone=dict(
+ type='ViT',
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.55,
+ ),
+ keypoint_head=dict(
+ type='TopdownHeatmapSimpleHead',
+ in_channels=1280,
+ num_deconv_layers=2,
+ num_deconv_filters=(256, 256),
+ num_deconv_kernels=(4, 4),
+ extra=dict(final_conv_kernel=1, ),
+ out_channels=17,
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=False,
+ target_type="GaussianHeatmap",
+ modulate_kernel=11,
+ use_udp=True))
+
+
+class VitPoseModel(nn.Module):
+ def __init__(self, model_name):
+ super().__init__()
+ assert model_name in ["vit_base", "vit_large", "vit_huge"]
+ model = {
+ "vit_base": model_base,
+ "vit_large": model_large,
+ "vit_huge": model_huge
+ }[model_name]
+ weight_url = {
+ "vit_base": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/90235a26-3b8c-427d-a264-c68155abecdcfcfcd8a9-0388-4575-b85b-607d3c0a9b149bef8f0f-a0f9-4662-a561-1b47ba5f1636",
+ "vit_large": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a580a44c-0afd-43ac-a2cb-9956c32b1d1a78c51ecb-81bb-4345-8710-13904cb9dbbe0703db2d-8534-42e0-ac4d-518ab51fe7db",
+ "vit_huge": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a33b6ada-4d2f-4ef7-8f83-b33f58b69f5b2a62e181-2131-467d-a900-027157a08571d761fad4-785b-4b84-8596-8932c7857e44"
+ }[model_name]
+ file_name = {
+ "vit_base": "vit-b-multi-coco-595b5e128b.pth",
+ "vit_large": "vit-l-multi-coco-9475d27cec.pth",
+ "vit_huge": "vit-h-multi-coco-dbc06d4337.pth",
+ }[model_name]
+ # Set check_hash to true if you suspect a download error.
+ weight_path = tops.download_file(
+ weight_url, file_name=file_name, check_hash=True)
+
+ self.keypoint_head = tops.to_cuda(TopdownHeatmapSimpleHead(
+ in_channels=model["keypoint_head"]["in_channels"],
+ out_channels=model["keypoint_head"]["out_channels"],
+ num_deconv_filters=model["keypoint_head"]["num_deconv_filters"],
+ num_deconv_kernels=model["keypoint_head"]["num_deconv_kernels"],
+ num_deconv_layers=model["keypoint_head"]["num_deconv_layers"],
+ extra=model["keypoint_head"]["extra"],
+ ))
+ # print(head)
+ self.backbone = tops.to_cuda(ViT(
+ img_size=model["backbone"]["img_size"],
+ patch_size=model["backbone"]["patch_size"],
+ embed_dim=model["backbone"]["embed_dim"],
+ depth=model["backbone"]["depth"],
+ num_heads=model["backbone"]["num_heads"],
+ ratio=model["backbone"]["ratio"],
+ mlp_ratio=model["backbone"]["mlp_ratio"],
+ qkv_bias=model["backbone"]["qkv_bias"],
+ drop_path_rate=model["backbone"]["drop_path_rate"],
+ ))
+ ckpt = torch.load(weight_path, map_location=tops.get_device())
+ self.load_state_dict(ckpt["state_dict"])
+ self.backbone.eval()
+ self.keypoint_head.eval()
+
+ def forward(self, img: torch.Tensor, boxes_ltrb: torch.Tensor):
+ assert img.ndim == 3
+ assert img.dtype == torch.uint8
+ assert boxes_ltrb.ndim == 2 and boxes_ltrb.shape[1] == 4
+ assert boxes_ltrb.dtype == torch.long
+ boxes_ltrb = boxes_ltrb.clamp(0)
+ padded_boxes = torch.zeros_like(boxes_ltrb)
+ images = torch.zeros((len(boxes_ltrb), 3, 256, 192), device=img.device, dtype=torch.float32)
+
+ for i, (x0, y0, x1, y1) in enumerate(boxes_ltrb):
+ x1 = min(img.shape[-1], x1)
+ y1 = min(img.shape[-2], y1)
+ correction_factor = 256 / 192 * (x1 - x0) / (y1 - y0)
+ if correction_factor > 1:
+ # increase y side
+ center = y0 + (y1 - y0) // 2
+ length = (y1-y0).mul(correction_factor).round().long()
+ y0_new = center - length.div(2).long()
+ y1_new = center + length.div(2).long()
+ image_crop = img[:, y0:y1, x0:x1]
+ # print(y1,y2,x1,x2)
+ pad = ((y0_new-y0).abs(), (y1_new-y1).abs())
+# pad = (int(abs(y0_new-y0))), int(abs(y1_new-y1))
+ image_crop = torch.nn.functional.pad(image_crop, [*(0, 0), *pad])
+ padded_boxes[i] = torch.tensor([x0, y0_new, x1, y1_new])
+ else:
+ center = x0 + (x1 - x0) // 2
+ length = (x1-x0).div(correction_factor).round().long()
+ x0_new = center - length.div(2).long()
+ x1_new = center + length.div(2).long()
+ image_crop = img[:, y0:y1, x0:x1]
+ pad = ((x0_new-x0).abs(), (x1_new-x1).abs())
+ image_crop = torch.nn.functional.pad(image_crop, [*pad, ])
+ padded_boxes[i] = torch.tensor([x0_new, y0, x1_new, y1])
+ image_crop = F.resize(image_crop.float(), (256, 192), antialias=True)
+ image_crop = F.normalize(image_crop, mean=[0.485*255, 0.456*255,
+ 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
+ images[i] = image_crop
+
+ x = self.backbone(images)
+ out = self.keypoint_head(x)
+ pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32, device=img.device)
+ # For each human, for each joint: y, x, confidence
+ b, indices = torch.max(out, dim=2)
+ b, indices = torch.max(b, dim=2)
+
+ c, indicesc = torch.max(out, dim=3)
+ c, indicesc = torch.max(c, dim=2)
+ dim1 = torch.tensor(1./64, device=img.device)
+ dim2 = torch.tensor(1./48, device=img.device)
+ for i in range(0, out.shape[0]):
+ pts[i, :, 0] = indicesc[i, :] * dim1 * (padded_boxes[i][3] - padded_boxes[i][1]) + padded_boxes[i][1]
+ pts[i, :, 1] = indices[i, :] * dim2 * (padded_boxes[i][2] - padded_boxes[i][0]) + padded_boxes[i][0]
+ pts[i, :, 2] = c[i, :]
+ pts = pts[:, :, [1, 0, 2]]
+ return pts
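# A toy decode of the heatmap output, mirroring the end of VitPoseModel.forward:
# the (row, col) argmax of each 64x48 joint heatmap is scaled back into the padded
# person box to obtain image-space (x, y, confidence).
import torch

heatmap = torch.zeros(1, 1, 64, 48)       # [persons, joints, H, W]
heatmap[0, 0, 32, 12] = 1.0               # peak at row 32, col 12
box = torch.tensor([100, 200, 196, 328])  # padded box, XYXY (width 96, height 128)

conf, col = heatmap.amax(dim=2).max(dim=2)  # strongest column per joint
_, row = heatmap.amax(dim=3).max(dim=2)     # strongest row per joint
y = row / 64 * (box[3] - box[1]) + box[1]
x = col / 48 * (box[2] - box[0]) + box[0]
print(x.item(), y.item(), conf.item())      # 124.0 264.0 1.0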
diff --git a/dp2/detection/models/vit_pose_maskrcnn.py b/dp2/detection/models/vit_pose_maskrcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f801d186486ef43e01ed554b7db34a69a7de5e34
--- /dev/null
+++ b/dp2/detection/models/vit_pose_maskrcnn.py
@@ -0,0 +1,73 @@
+import torch
+import lzma
+from pathlib import Path
+from dp2.detection.base import BaseDetector
+from .mask_rcnn import MaskRCNNDetector
+from ..structures import PersonDetection
+from tops import logger
+from .vit_pose.vit_pose import VitPoseModel
+from ..utils import masks_to_boxes
+
+
+def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
+ assert len(box1.shape) == 2
+ assert len(box2.shape) == 2
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
+ # This can be batched
+ for i, box in enumerate(box1):
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
+ box1_inside[i] = is_outside.logical_not().any()
+ return box1_inside
+
+
+class MaskRCNNVitPose(BaseDetector):
+
+ def __init__(
+ self,
+ mask_rcnn_cfg,
+ cse_post_process_cfg,
+ score_threshold: float,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
+ self.vit_pose = VitPoseModel("vit_huge")
+
+ self.cse_post_process_cfg = cse_post_process_cfg
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def load_from_cache(self, cache_path: Path):
+ logger.log(f"Loading detection from cache path: {cache_path}",)
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp, map_location="cpu")
+ kwargs = dict(
+ post_process_cfg=self.cse_post_process_cfg,
+ )
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ maskrcnn_dets = self.mask_rcnn(im)
+
+ maskrcnn_person = {
+ k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
+ }
+ boxes = masks_to_boxes(maskrcnn_person["segmentation"])
+ keypoints = self.vit_pose(im, boxes).cpu()
+ keypoints[:, :, -1] = keypoints[:, :, -1] >= 0.3
+ persons_without_cse = PersonDetection(
+ maskrcnn_person["segmentation"], **self.cse_post_process_cfg,
+ orig_imshape_CHW=im.shape,
+ keypoints=keypoints
+ )
+ persons_without_cse.pre_process()
+
+ all_detections = [persons_without_cse]
+ return all_detections
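# The ViT-Pose confidences above are binarized in-place into visibility flags; a toy
# version of that thresholding (0.3 is the cut-off used in MaskRCNNVitPose.forward):
import torch

keypoints = torch.tensor([[[10., 20., 0.9],
                           [11., 21., 0.1]]])  # [persons, joints, (x, y, conf)]
keypoints[:, :, -1] = (keypoints[:, :, -1] >= 0.3).float()
print(keypoints[0, :, -1])  # tensor([1., 0.])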
diff --git a/dp2/detection/person_detector.py b/dp2/detection/person_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bbd0df8c2aa44839a5de8bd9a6aeede054ff2ee
--- /dev/null
+++ b/dp2/detection/person_detector.py
@@ -0,0 +1,135 @@
+import torch
+import lzma
+from dp2.detection.base import BaseDetector
+from .utils import combine_cse_maskrcnn_dets
+from .models.cse import CSEDetector
+from .models.mask_rcnn import MaskRCNNDetector
+from .models.keypoint_maskrcnn import KeypointMaskRCNN
+from .structures import CSEPersonDetection, PersonDetection
+from pathlib import Path
+
+
+class CSEPersonDetector(BaseDetector):
+ def __init__(
+ self,
+ score_threshold: float,
+ mask_rcnn_cfg: dict,
+ cse_cfg: dict,
+ cse_post_process_cfg: dict,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
+ self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
+ self.post_process_cfg = cse_post_process_cfg
+ self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def load_from_cache(self, cache_path: Path):
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp)
+ kwargs = dict(
+ post_process_cfg=self.post_process_cfg,
+ embed_map=self.cse_detector.embed_map,
+ )
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor, cse_dets=None):
+ mask_dets = self.mask_rcnn(im)
+ if cse_dets is None:
+ cse_dets = self.cse_detector(im)
+ segmentation = mask_dets["segmentation"]
+ segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
+ segmentation, cse_dets, self.iou_combine_threshold
+ )
+ det = CSEPersonDetection(
+ segmentation=segmentation,
+ cse_dets=cse_dets,
+ embed_map=self.cse_detector.embed_map,
+ orig_imshape_CHW=im.shape,
+ **self.post_process_cfg
+ )
+ return [det]
+
+
+class MaskRCNNPersonDetector(BaseDetector):
+ def __init__(
+ self,
+ score_threshold: float,
+ mask_rcnn_cfg: dict,
+ cse_post_process_cfg: dict,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
+ self.post_process_cfg = cse_post_process_cfg
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def load_from_cache(self, cache_path: Path):
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp)
+ kwargs = dict(
+ post_process_cfg=self.post_process_cfg,
+ )
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ mask_dets = self.mask_rcnn(im)
+ segmentation = mask_dets["segmentation"]
+ det = PersonDetection(
+ segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
+ )
+ return [det]
+
+
+class KeypointMaskRCNNPersonDetector(BaseDetector):
+ def __init__(
+ self,
+ score_threshold: float,
+ mask_rcnn_cfg: dict,
+ cse_post_process_cfg: dict,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.mask_rcnn = KeypointMaskRCNN(
+ **mask_rcnn_cfg, score_threshold=score_threshold
+ )
+ self.post_process_cfg = cse_post_process_cfg
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def load_from_cache(self, cache_path: Path):
+ with lzma.open(cache_path, "rb") as fp:
+ state_dict = torch.load(fp)
+ kwargs = dict(
+ post_process_cfg=self.post_process_cfg,
+ )
+ return [
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
+ for state in state_dict
+ ]
+
+ @torch.no_grad()
+ def forward(self, im: torch.Tensor):
+ mask_dets = self.mask_rcnn(im)
+ segmentation = mask_dets["segmentation"]
+ det = PersonDetection(
+ segmentation,
+ **self.post_process_cfg,
+ orig_imshape_CHW=im.shape,
+ keypoints=mask_dets["keypoints"]
+ )
+ return [det]
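# The load_from_cache methods above read an lzma-compressed torch.save file holding a
# list of per-detection state dicts (dp2 additionally stores the detection class under
# the "cls" key so the loader can dispatch). A minimal round-trip of that storage:
import lzma
import torch

states = [{"boxes": torch.tensor([[0, 0, 4, 4]]), "scores": torch.tensor([0.98])}]
with lzma.open("cache.xz", "wb") as fp:
    torch.save(states, fp)
with lzma.open("cache.xz", "rb") as fp:
    restored = torch.load(fp, map_location="cpu")
print(restored[0]["boxes"])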
diff --git a/dp2/detection/structures.py b/dp2/detection/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..3daf58f4617feedb7724137721e85e32e94b87b2
--- /dev/null
+++ b/dp2/detection/structures.py
@@ -0,0 +1,504 @@
+import torch
+import numpy as np
+from dp2 import utils
+from dp2.utils import vis_utils, crop_box
+from .utils import (
+ cut_pad_resize, masks_to_boxes,
+ get_kernel, transform_embedding, initialize_cse_boxes
+)
+from .box_utils import get_expanded_bbox, include_box
+import torchvision
+import tops
+from .box_utils_fdf import expand_bbox as expand_bbox_fdf
+
+
+class VehicleDetection:
+
+ def __init__(self, segmentation: torch.BoolTensor) -> None:
+ self.segmentation = segmentation
+ self.boxes = masks_to_boxes(segmentation)
+ assert self.boxes.shape[1] == 4, self.boxes.shape
+ self.n_detections = self.segmentation.shape[0]
+ area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
+
+ sorted_idx = torch.argsort(area, descending=True)
+ self.segmentation = self.segmentation[sorted_idx]
+ self.boxes = self.boxes[sorted_idx].cpu()
+
+ def pre_process(self):
+ pass
+
+ def get_crop(self, idx: int, im):
+ assert idx < len(self)
+ box = self.boxes[idx]
+        im = crop_box(im, box)
+        mask = crop_box(self.segmentation[idx], box)
+ mask = mask == 0
+ return dict(img=im, mask=mask.float(), boxes=box)
+
+ def visualize(self, im):
+ if len(self) == 0:
+ return im
+ im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
+ return im
+
+ def __len__(self):
+ return self.n_detections
+
+ @staticmethod
+ def from_state_dict(state_dict, **kwargs):
+ numel = np.prod(state_dict["shape"])
+ arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
+ segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
+ return VehicleDetection(segmentation)
+
+ def state_dict(self, **kwargs):
+ segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
+ return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
+
+
+class FaceDetection:
+
+ def __init__(self,
+ boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool,
+ keypoints: torch.Tensor = None,
+ **kwargs) -> None:
+
+ self.boxes = boxes_ltrb.cpu()
+ assert self.boxes.shape[1] == 4, self.boxes.shape
+ self.target_imsize = tuple(target_imsize)
+        # Sort by area to paste in the largest faces last
+        area = ((self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1])).view(-1)
+ idx = area.argsort(descending=False)
+ self.boxes = self.boxes[idx]
+ self.fdf128_expand = fdf128_expand
+ self.orig_keypoints = keypoints
+ if keypoints is not None:
+ self.orig_keypoints = self.orig_keypoints[idx]
+ assert keypoints.shape == (len(boxes_ltrb), 17, 2) or \
+ keypoints.shape == (len(boxes_ltrb), 7, 2), keypoints.shape
+
+ def visualize(self, im):
+ if len(self) == 0:
+ return im
+ orig_device = im.device
+ for box in self.boxes:
+            simple_expand = not self.fdf128_expand
+ e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
+ im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
+ im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
+ if self.orig_keypoints is not None:
+ im = vis_utils.draw_keypoints(im, self.orig_keypoints, radius=1)
+
+ return im.to(device=orig_device)
+
+ def get_crop(self, idx: int, im):
+ assert idx < len(self)
+ box = self.boxes[idx].numpy()
+        simple_expand = not self.fdf128_expand
+ expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand)
+ im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
+
+ # Find the square mask corresponding to box.
+ box_mask = box.copy().astype(float)
+ box_mask[[0, 2]] -= expanded_boxes[0]
+ box_mask[[1, 3]] -= expanded_boxes[1]
+
+ width = expanded_boxes[2] - expanded_boxes[0]
+ resize_factor = self.target_imsize[0] / width
+ box_mask = (box_mask * resize_factor).astype(int)
+ mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
+ crop_box(mask, box_mask).fill_(0)
+ if self.orig_keypoints is None:
+ return dict(
+ img=im[None], mask=mask[None],
+ boxes=torch.from_numpy(expanded_boxes).view(1, -1))
+
+ keypoint = self.orig_keypoints[idx, :7, :2].clone()
+ keypoint[:, 0] -= expanded_boxes[0]
+ keypoint[:, 1] -= expanded_boxes[1]
+ w = expanded_boxes[2] - expanded_boxes[0]
+ keypoint /= w
+ keypoint = keypoint.clamp(0, 1)
+ return dict(
+ img=im[None], mask=mask[None],
+ boxes=torch.from_numpy(expanded_boxes).view(1, -1),
+ keypoints=keypoint[None])
+
+ def __len__(self):
+ return len(self.boxes)
+
+ @staticmethod
+ def from_state_dict(state_dict, **kwargs):
+ return FaceDetection(
+ state_dict["boxes"].cpu(),
+ keypoints=state_dict["orig_keypoints"] if "orig_keypoints" in state_dict else None,
+ **kwargs)
+
+ def state_dict(self, **kwargs):
+ return dict(
+ boxes=self.boxes,
+ cls=self.__class__,
+ orig_keypoints=self.orig_keypoints)
+
+ def pre_process(self):
+ pass
+
+
+def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
+ """
+    Dilation is applied after padding, so dilated pixels can end up inside the padded region.
+    This function zeroes out any such pixels.
+ """
+ x0, y0, x1, y1 = exp_box
+ H, W = orig_imshape
+ # Padding in original image space
+ p_y0 = max(0, -y0)
+ p_y1 = max(y1 - H, 0)
+ p_x0 = max(0, -x0)
+ p_x1 = max(x1 - W, 0)
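+    # Convert the padding amounts from original-image pixels to the resized mask's resolution.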
+ resize_ratio = mask.shape[-2] / (y1-y0)
+ p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
+ mask[..., :p_y0, :] = 0
+ mask[..., :p_x0] = 0
+ mask[..., mask.shape[-2] - p_y1:, :] = 0
+ mask[..., mask.shape[-1] - p_x1:] = 0
+
+
+class CSEPersonDetection:
+
+ def __init__(self,
+ segmentation, cse_dets,
+ target_imsize,
+ exp_bbox_cfg, exp_bbox_filter,
+ dilation_percentage: float,
+ embed_map: torch.Tensor,
+ orig_imshape_CHW,
+ normalize_embedding: bool) -> None:
+ self.segmentation = segmentation
+ self.cse_dets = cse_dets
+ self.target_imsize = list(target_imsize)
+ self.pre_processed = False
+ self.exp_bbox_cfg = exp_bbox_cfg
+ self.exp_bbox_filter = exp_bbox_filter
+ self.dilation_percentage = dilation_percentage
+ self.embed_map = embed_map
+ self.embed_map_cpu = embed_map.cpu()
+ self.normalize_embedding = normalize_embedding
+ if self.normalize_embedding:
+ embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
+ embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
+ self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
+ self.orig_imshape_CHW = orig_imshape_CHW
+
+ @torch.no_grad()
+ def pre_process(self):
+ if self.pre_processed:
+ return
+ boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
+ expanded_boxes = []
+ included_boxes = []
+ for i in range(len(boxes)):
+ exp_box = get_expanded_bbox(
+ boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
+ target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
+ if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
+ continue
+ included_boxes.append(i)
+ expanded_boxes.append(exp_box)
+ expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
+ self.segmentation = self.segmentation[included_boxes]
+ self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
+
+ self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
+ area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
+ for i, box in enumerate(expanded_boxes):
+ self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
+
+ dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
+ self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
+ self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
+ for i in range(len(expanded_boxes)):
+ remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:])
+ self.boxes = expanded_boxes.cpu()
+ self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
+
+ self.pre_processed = True
+ self.n_detections = len(self.boxes)
+ self.mask = self.mask.logical_not()
+
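+        # Transfer each detection's CSE embedding/segmentation into its expanded crop and
+        # convert the continuous embeddings to discrete vertex indices via the embed_map.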
+ E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
+ self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
+ for i in range(self.n_detections):
+ E_, E_mask[i] = transform_embedding(
+ self.cse_dets["instance_embedding"][i],
+ self.cse_dets["instance_segmentation"][i],
+ self.boxes[i],
+ self.cse_dets["bbox_XYXY"][i].cpu(),
+ self.target_imsize
+ )
+ self.vertices[i] = utils.from_E_to_vertex(
+ E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
+ self.E_mask = E_mask
+
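+        # Order detections by ascending mask area so the largest detections are processed last.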
+ sorted_idx = torch.argsort(area, descending=False)
+ self.mask = self.mask[sorted_idx]
+ self.boxes = self.boxes[sorted_idx.cpu()]
+ self.vertices = self.vertices[sorted_idx]
+ self.E_mask = self.E_mask[sorted_idx]
+ self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
+
+ def get_crop(self, idx: int, im):
+ self.pre_process()
+ assert idx < len(self)
+ box = self.boxes[idx]
+ mask = self.mask[idx]
+ im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
+
+ vertices_ = self.vertices[idx]
+ E_mask_ = self.E_mask[idx].float()
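+        # Look up the per-pixel CSE embedding from the vertex indices; zero it outside the CSE mask.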
+ if self.normalize_embedding:
+ embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
+ else:
+ embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
+
+ return dict(
+ img=im,
+ mask=mask.float()[None],
+ boxes=box.reshape(1, -1),
+ E_mask=E_mask_[None],
+ vertices=vertices_[None],
+ embed_map=self.embed_map,
+ embedding=embedding[None],
+ maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
+ )
+
+ def __len__(self):
+ self.pre_process()
+ return self.n_detections
+
+ def state_dict(self, after_preprocess=False):
+ """
+ The processed annotations occupy more space than the original detections.
+ """
+ if not after_preprocess:
+ return {
+ "combined_segmentation": self.segmentation.bool(),
+ "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
+ "cse_instance_embedding": self.cse_dets["instance_embedding"],
+ "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
+ "cls": self.__class__,
+ "orig_imshape_CHW": self.orig_imshape_CHW
+ }
+ self.pre_process()
+ def compress_bool(x): return torch.from_numpy(np.packbits(x.bool().cpu().numpy()))
+ return dict(
+ E_mask=compress_bool(self.E_mask),
+ mask=compress_bool(self.mask),
+ maskrcnn_mask=compress_bool(self.maskrcnn_mask),
+ vertices=self.vertices.to(torch.int16).cpu(),
+ cls=self.__class__,
+ boxes=self.boxes,
+ orig_imshape_CHW=self.orig_imshape_CHW,
+ )
+
+ @staticmethod
+ def from_state_dict(
+ state_dict, embed_map,
+ post_process_cfg, **kwargs):
+ after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
+ if after_preprocess:
+ detection = CSEPersonDetection(
+ segmentation=None, cse_dets=None, embed_map=embed_map,
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
+ **post_process_cfg)
+ detection.vertices = tops.to_cuda(state_dict["vertices"].long())
+ numel = np.prod(detection.vertices.shape)
+
+ def unpack_bool(x):
+ x = torch.from_numpy(np.unpackbits(x.numpy(), count=numel))
+ return x.view(*detection.vertices.shape)
+ detection.E_mask = tops.to_cuda(unpack_bool(state_dict["E_mask"]))
+ detection.mask = tops.to_cuda(unpack_bool(state_dict["mask"]))
+ detection.maskrcnn_mask = tops.to_cuda(unpack_bool(state_dict["maskrcnn_mask"]))
+ detection.n_detections = len(detection.mask)
+ detection.pre_processed = True
+
+ if isinstance(state_dict["boxes"], np.ndarray):
+ state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
+ detection.boxes = state_dict["boxes"]
+ return detection
+
+ cse_dets = dict(
+ instance_segmentation=state_dict["cse_instance_segmentation"],
+ instance_embedding=state_dict["cse_instance_embedding"],
+ embed_map=embed_map,
+ bbox_XYXY=state_dict["cse_bbox_XYXY"])
+ cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
+
+ segmentation = state_dict["combined_segmentation"]
+ return CSEPersonDetection(
+ segmentation, cse_dets, embed_map=embed_map,
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
+ **post_process_cfg)
+
+ def visualize(self, im):
+ self.pre_process()
+ if len(self) == 0:
+ return im
+ im = vis_utils.draw_cropped_masks(
+ im.cpu(), self.mask.cpu(), self.boxes, visualize_instances=False)
+ E = self.embed_map_cpu[self.vertices.long().cpu()].squeeze(1).permute(0, 3, 1, 2)
+ im = vis_utils.draw_cse_all(
+ E, self.E_mask.squeeze(1).bool().cpu(), im,
+ self.boxes, self.embed_map_cpu)
+ im = torchvision.utils.draw_bounding_boxes(im, self.boxes, colors=(255, 0, 0), width=2)
+ return im
+
+
+def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
+ keypoints = keypoints.clone()
+ N = boxes.shape[0]
+ tops.assert_shape(keypoints, (N, None, 3))
+ tops.assert_shape(boxes, (N, 4))
+ x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
+
+ w = x1 - x0
+ h = y1 - y0
+ keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
+ keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
+ def check_outside(x): return (x < 0).logical_or(x > 1)
+ is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
+ keypoints[:, :, 2] = keypoints[:, :, 2] > 0
+ keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
+ return keypoints
+
+
+class PersonDetection:
+
+ def __init__(
+ self,
+ segmentation,
+ target_imsize,
+ exp_bbox_cfg, exp_bbox_filter,
+ dilation_percentage: float,
+ orig_imshape_CHW,
+ kp_vis_thr=None,
+ keypoints=None,
+ **kwargs) -> None:
+ self.segmentation = segmentation
+ self.target_imsize = list(target_imsize)
+ self.pre_processed = False
+ self.exp_bbox_cfg = exp_bbox_cfg
+ self.exp_bbox_filter = exp_bbox_filter
+ self.dilation_percentage = dilation_percentage
+ self.orig_imshape_CHW = orig_imshape_CHW
+        self.orig_keypoints = keypoints
+        # Always define self.keypoints; it is filled during pre_process when keypoints are given.
+        self.keypoints = None
+        if keypoints is not None:
+            assert kp_vis_thr is not None
+            self.kp_vis_thr = kp_vis_thr
+
+ @torch.no_grad()
+ def pre_process(self):
+ if self.pre_processed:
+ return
+ boxes = masks_to_boxes(self.segmentation).cpu()
+ expanded_boxes = []
+ included_boxes = []
+ for i in range(len(boxes)):
+ exp_box = get_expanded_bbox(
+ boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
+ target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
+ if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
+ continue
+ included_boxes.append(i)
+ expanded_boxes.append(exp_box)
+ expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
+ self.segmentation = self.segmentation[included_boxes]
+ if self.orig_keypoints is not None:
+ self.keypoints = self.orig_keypoints[included_boxes].clone()
+ self.keypoints[:, :, 2] = self.keypoints[:, :, 2] >= self.kp_vis_thr
+ area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)).cpu()
+ self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
+ for i, box in enumerate(expanded_boxes):
+ self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
+ if self.orig_keypoints is not None:
+ self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
+ dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
+ self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
+ self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
+ for i in range(len(expanded_boxes)):
+ remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:])
+ self.boxes = expanded_boxes
+ self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
+
+ self.pre_processed = True
+ self.n_detections = len(self.boxes)
+ self.mask = self.mask.logical_not()
+
+ sorted_idx = torch.argsort(area, descending=False)
+ self.mask = self.mask[sorted_idx]
+ self.boxes = self.boxes[sorted_idx.cpu()]
+ self.segmentation = self.segmentation[sorted_idx]
+ self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
+ if self.keypoints is not None:
+ self.keypoints = self.keypoints[sorted_idx.cpu()]
+
+ def get_crop(self, idx: int, im: torch.Tensor):
+ assert idx < len(self)
+ self.pre_process()
+ box = self.boxes[idx]
+ mask = self.mask[idx][None].float()
+ im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
+ batch = dict(
+ img=im, mask=mask, boxes=box.reshape(1, -1),
+ maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
+ if self.keypoints is not None:
+ batch["keypoints"] = self.keypoints[idx:idx+1]
+ return batch
+
+ def __len__(self):
+ self.pre_process()
+ return self.n_detections
+
+ def state_dict(self, **kwargs):
+ return dict(
+ segmentation=self.segmentation.bool(),
+ cls=self.__class__,
+ orig_imshape_CHW=self.orig_imshape_CHW,
+ keypoints=self.orig_keypoints
+ )
+
+ @staticmethod
+ def from_state_dict(
+ state_dict,
+ post_process_cfg, **kwargs):
+ return PersonDetection(
+ state_dict["segmentation"],
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
+ **post_process_cfg,
+ keypoints=state_dict["keypoints"])
+
+ def visualize(self, im):
+ self.pre_process()
+ im = im.cpu()
+ if len(self) == 0:
+ return im
+ im = vis_utils.draw_cropped_masks(im.clone(), self.mask.cpu(), self.boxes, visualize_instances=False)
+ if self.keypoints is not None:
+ im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
+ return im
+
+
+def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
+ """
+    Computes boxes from the resized (dilated) masks and maps them back to
+    original-image coordinates within the corresponding expanded boxes.
+
+    mask: resized mask
+ """
+ assert exp_bbox.shape[0] == mask.shape[0]
+ boxes = masks_to_boxes(mask.squeeze(1)).cpu()
+ H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
+ boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
+ boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
+ boxes[:, [0, 2]] += exp_bbox[:, 0:1]
+ boxes[:, [1, 3]] += exp_bbox[:, 1:2]
+ return boxes
diff --git a/dp2/detection/utils.py b/dp2/detection/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..31bd8cc40dceae5b83bb52e74cdf4be25e764487
--- /dev/null
+++ b/dp2/detection/utils.py
@@ -0,0 +1,179 @@
+import cv2
+import numpy as np
+import torch
+import tops
+from skimage.morphology import disk
+from torchvision.transforms.functional import resize, InterpolationMode
+from functools import lru_cache
+
+
+@lru_cache(maxsize=200)
+def get_kernel(n: int):
+ kernel = disk(n, dtype=bool)
+ return tops.to_cuda(torch.from_numpy(kernel).bool())
+
+
+def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
+ """
+ Transforms the detected embedding/mask directly to the target image shape
+ """
+
+ C, HE, WE = E.shape
+ assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
+ assert E_bbox[2] >= exp_bbox[0]
+ assert E_bbox[1] >= exp_bbox[1]
+ assert E_bbox[3] >= exp_bbox[1]
+ assert E_bbox[2] <= exp_bbox[2]
+ assert E_bbox[3] <= exp_bbox[3]
+
+ x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
+ x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
+ y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
+ y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
+ new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
+ new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool)
+
+ E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
+ new_E[:, y0:y1, x0:x1] = E
+ S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
+ new_S[y0:y1, x0:x1] = S
+ return new_E, new_S
+
+
+def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
+ """
+ mask: shape [N, H, W]
+ """
+ assert len(mask1.shape) == 3
+ assert len(mask2.shape) == 3
+ assert mask1.device == mask2.device, (mask1.device, mask2.device)
+    assert mask1.dtype == mask2.dtype
+ assert mask1.dtype == torch.bool
+ assert mask1.shape[1:] == mask2.shape[1:]
+ N1, H1, W1 = mask1.shape
+ N2, H2, W2 = mask2.shape
+ iou = torch.zeros((N1, N2), dtype=torch.float32)
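+    # Compute IoU row-by-row to avoid materializing a full (N1, N2, H, W) intersection tensor.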
+ for i in range(N1):
+ cur = mask1[i:i+1]
+ inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
+ union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
+ iou[i] = inter / union
+ return iou
+
+
+def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
+ N1 = mask1.shape[0]
+ N2 = mask2.shape[0]
+ ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
+ indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
+ ious = ious.flatten()
+ mask = ious >= iou_threshold
+ ious = ious[mask]
+ indices = indices[mask]
+
+    # Do not sort by IoU, to preserve the ordering of the Mask R-CNN / CSE detections.
+ taken1 = np.zeros((N1), dtype=bool)
+ taken2 = np.zeros((N2), dtype=bool)
+ matches = []
+ for i, j in indices:
+ if taken1[i].any() or taken2[j].any():
+ continue
+ matches.append((i, j))
+ taken1[i] = True
+ taken2[j] = True
+ return matches
+
+
+def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
+ assert 0 < iou_threshold <= 1
+ matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
+ H, W = segmentation.shape[1:]
+ new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
+ cse_im_seg = cse_dets["im_segmentation"]
+ for idx, (i, j) in enumerate(matches):
+ new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
+ cse_dets = dict(
+ instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]],
+ instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]],
+ bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]],
+ scores=cse_dets["scores"][[j for (i, j) in matches]],
+ )
+ return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
+
+
+def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
+ """
+    Initializes boxes as the union of the segmentation-derived boxes and the CSE boxes,
+    since cse_boxes can extend outside of the segmentation.
+ """
+ boxes = masks_to_boxes(segmentation)
+
+ assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
+ combined = torch.stack((boxes, cse_boxes), dim=-1)
+ boxes = torch.cat((
+ combined[:, :2].min(dim=2).values,
+ combined[:, 2:].max(dim=2).values,
+ ), dim=1)
+ return boxes
+
+
+def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
+ """
+ Crops or pads x to fit in the bbox and resize to target shape.
+ """
+ C, H, W = x.shape
+ x0, y0, x1, y1 = bbox
+
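+    # If the box lies fully inside the image, crop directly; otherwise zero-pad the regions that fall outside.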
+ if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
+ new_x = x[:, y0:y1, x0:x1]
+ else:
+ new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device)
+ y0_t = max(0, -y0)
+ y1_t = min(y1-y0, (y1-y0)-(y1-H))
+ x0_t = max(0, -x0)
+ x1_t = min(x1-x0, (x1-x0)-(x1-W))
+ x0 = max(0, x0)
+ y0 = max(0, y0)
+ x1 = min(x1, W)
+ y1 = min(y1, H)
+ new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
+    # Nearest upsampling often generates sharper synthesized identities.
+ interp = InterpolationMode.BICUBIC
+ if (y1-y0) < target_shape[0] and (x1-x0) < target_shape[1]:
+ interp = InterpolationMode.NEAREST
+ antialias = interp == InterpolationMode.BICUBIC
+ if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
+ return new_x
+ if x.dtype == torch.bool:
+ new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
+ elif x.dtype == torch.float32:
+ new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias)
+ elif x.dtype == torch.uint8:
+ if fdf_resize: # FDF dataset is created with cv2 INTER_AREA.
+            # Incorrect resizing generates noticeably poorer inpaintings.
+ upsampling = ((y1-y0) * (x1-x0)) < (target_shape[0] * target_shape[1])
+ if upsampling:
+ new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC,
+ antialias=True).round().clamp(0, 255).byte()
+ else:
+ device = new_x.device
+ new_x = new_x.permute(1, 2, 0).cpu().numpy()
+ new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
+ new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
+ else:
+ new_x = resize(new_x.float(), target_shape, interpolation=interp,
+ antialias=antialias).round().clamp(0, 255).byte()
+ else:
+ raise ValueError(f"Not supported dtype: {x.dtype}")
+ return new_x
+
+
+def masks_to_boxes(segmentation: torch.Tensor):
+ assert len(segmentation.shape) == 3
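+    # argmax on a byte tensor returns the first nonzero index; flipping the tensor gives
+    # the last nonzero index, yielding tight per-instance bounding boxes.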
+ x = segmentation.any(dim=1).byte() # Compress rows
+ x0 = x.argmax(dim=1)
+
+ x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
+ y = segmentation.any(dim=2).byte()
+ y0 = y.argmax(dim=1)
+ y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
+ return torch.stack([x0, y0, x1, y1], dim=1)
diff --git a/dp2/discriminator/__init__.py b/dp2/discriminator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2c41a0f895fcb39d80b0cfeec8795899a04fa9
--- /dev/null
+++ b/dp2/discriminator/__init__.py
@@ -0,0 +1 @@
+from .sg2_discriminator import SG2Discriminator
diff --git a/dp2/discriminator/sg2_discriminator.py b/dp2/discriminator/sg2_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..269675d44fec26f1838b56092bf98e28945a3462
--- /dev/null
+++ b/dp2/discriminator/sg2_discriminator.py
@@ -0,0 +1,79 @@
+from sg3_torch_utils.ops import upfirdn2d
+import torch
+import numpy as np
+import torch.nn as nn
+from .. import layers
+from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block
+
+
+class SG2Discriminator(layers.Module):
+
+ def __init__(
+ self,
+ cnum: int,
+ max_cnum_mul: int,
+ imsize,
+ min_fmap_resolution: int,
+ im_channels: int,
+ input_condition: bool,
+ conv_clamp: int,
+ input_cse: bool,
+ cse_nc: int,
+ fix_residual: bool,
+ ):
+ super().__init__()
+
+ cse_nc = 0 if cse_nc is None else cse_nc
+ self._max_imsize = max(imsize)
+ self._cnum = cnum
+ self._max_cnum_mul = max_cnum_mul
+ self._min_fmap_resolution = min_fmap_resolution
+ self._input_condition = input_condition
+ self.input_cse = input_cse
+ self.layers = nn.ModuleList()
+
+ out_ch = self.get_chsize(self._max_imsize)
+ self.from_rgb = Block(
+ im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1),
+ out_ch, conv_clamp=conv_clamp
+ )
+ n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
+
+ for i in range(n_levels):
+ resolution = [x//2**i for x in imsize]
+ in_ch = self.get_chsize(max(resolution))
+ out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution))
+
+ down = 2
+ if i == 0:
+ down = 1
+ block = ResidualBlock(
+ in_ch, out_ch, down=down, conv_clamp=conv_clamp,
+ fix_residual=fix_residual
+ )
+ self.layers.append(block)
+ self.output_layer = DiscriminatorEpilogue(
+ out_ch, resolution, conv_clamp=conv_clamp)
+
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
+
+ def forward(self, img, condition, mask, embedding=None, E_mask=None, **kwargs):
+ to_cat = [img]
+ if self._input_condition:
+ to_cat.extend([condition, mask, ])
+ if self.input_cse:
+ to_cat.extend([embedding, E_mask])
+ x = torch.cat(to_cat, dim=1)
+ x = self.from_rgb(x)
+
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+
+ x = self.output_layer(x)
+ return dict(score=x)
+
+ def get_chsize(self, imsize):
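+        # Double the channel count for each downsampling from the full resolution, capped at cnum * max_cnum_mul.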
+ n = int(np.log2(self._max_imsize) - np.log2(imsize))
+ mul = min(2 ** n, self._max_cnum_mul)
+ ch = self._cnum * mul
+ return int(ch)
diff --git a/dp2/gan_trainer.py b/dp2/gan_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..149e0f6be0602e90f26b67551615d0abb96aad56
--- /dev/null
+++ b/dp2/gan_trainer.py
@@ -0,0 +1,325 @@
+import atexit
+from collections import defaultdict
+import logging
+import typing
+import torch
+import time
+from dp2.utils import vis_utils
+from dp2 import utils
+from tops import logger, checkpointer
+import tops
+from easydict import EasyDict
+
+
+def accumulate_gradients(params, fp16_ddp_accumulate):
+ if len(params) == 0:
+ return
+ params = [param for param in params if param.grad is not None]
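+    # Flatten all gradients into a single buffer so they can be synchronized with one all_reduce call.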
+ flat = torch.cat([param.grad.flatten() for param in params])
+ orig_dtype = flat.dtype
+ if tops.world_size() > 1:
+ if fp16_ddp_accumulate:
+ flat = flat.half() / tops.world_size()
+ else:
+ flat /= tops.world_size()
+ torch.distributed.all_reduce(flat)
+ flat = flat.to(orig_dtype)
+ grads = flat.split([param.numel() for param in params])
+ for param, grad in zip(params, grads):
+ param.grad = grad.reshape(param.shape)
+
+
+def accumulate_buffers(module: torch.nn.Module):
+ buffers = [buf for buf in module.buffers()]
+ if len(buffers) == 0:
+ return
+ flat = torch.cat([buf.flatten() for buf in buffers])
+ if tops.world_size() > 1:
+ torch.distributed.all_reduce(flat)
+ flat /= tops.world_size()
+ bufs = flat.split([buf.numel() for buf in buffers])
+ for old, new in zip(buffers, bufs):
+ old.copy_(new.reshape(old.shape), non_blocking=True)
+
+
+def check_ddp_consistency(module):
+ if tops.world_size() == 1:
+ return
+    assert isinstance(module, torch.nn.Module)
+ params_buffs = list(module.named_parameters()) + list(module.named_buffers())
+ for name, tensor in params_buffs:
+ fullname = type(module).__name__ + '.' + name
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = torch.nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
+
+
+class AverageMeter():
+ def __init__(self) -> None:
+ self.to_log = dict()
+ self.n = defaultdict(int)
+
+ @torch.no_grad()
+ def update(self, values: dict):
+ for key, value in values.items():
+ self.n[key] += 1
+ if key in self.to_log:
+ self.to_log[key] += value.mean().detach()
+ else:
+ self.to_log[key] = value.mean().detach()
+
+ def get_average(self):
+ return {key: value / self.n[key] for key, value in self.to_log.items()}
+
+
+class GANTrainer:
+
+ def __init__(
+ self,
+ G: torch.nn.Module,
+ D: torch.nn.Module,
+ G_EMA: torch.nn.Module,
+ D_optim: torch.optim.Optimizer,
+ G_optim: torch.optim.Optimizer,
+ dl_train: typing.Iterator,
+ dl_val: typing.Iterable,
+ scaler_D: torch.cuda.amp.GradScaler,
+ scaler_G: torch.cuda.amp.GradScaler,
+ ims_per_log: int,
+ max_images_to_train: int,
+ loss_handler,
+ ims_per_val: int,
+ evaluate_fn,
+ batch_size: int,
+ broadcast_buffers: bool,
+ fp16_ddp_accumulate: bool,
+ save_state: bool,
+ *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.G = G
+ self.D = D
+ self.G_EMA = G_EMA
+ self.D_optim = D_optim
+ self.G_optim = G_optim
+ self.dl_train = dl_train
+ self.dl_val = dl_val
+ self.scaler_D = scaler_D
+ self.scaler_G = scaler_G
+ self.loss_handler = loss_handler
+ self.max_images_to_train = max_images_to_train
+ self.images_per_val = ims_per_val
+ self.images_per_log = ims_per_log
+ self.evaluate_fn = evaluate_fn
+ self.batch_size = batch_size
+ self.broadcast_buffers = broadcast_buffers
+ self.fp16_ddp_accumulate = fp16_ddp_accumulate
+
+ self.train_state = EasyDict(
+ next_log_step=0,
+ next_val_step=ims_per_val,
+ total_time=0
+ )
+
+ checkpointer.register_models(dict(
+ generator=G, discriminator=D, EMA_generator=G_EMA,
+ D_optimizer=D_optim,
+ G_optimizer=G_optim,
+ train_state=self.train_state,
+ scaler_D=self.scaler_D,
+ scaler_G=self.scaler_G
+ ))
+ if checkpointer.has_checkpoint():
+ checkpointer.load_registered_models()
+ logger.log(f"Resuming training from: global step: {logger.global_step()}")
+ else:
+ logger.add_dict({
+ "stats/discriminator_parameters": tops.num_parameters(self.D),
+ "stats/generator_parameters": tops.num_parameters(self.G),
+ }, commit=False)
+ if save_state:
+ # If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint.
+ atexit.register(checkpointer.save_registered_models)
+
+ self._ims_per_log = ims_per_log
+
+ self.to_log = AverageMeter()
+ self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
+ self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
+ logger.add_dict({
+ "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
+ "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
+ }, commit=False, level=logging.INFO)
+ check_ddp_consistency(self.D)
+ check_ddp_consistency(self.G)
+ check_ddp_consistency(self.G_EMA.generator)
+
+ def train_loop(self):
+ self.log_time()
+ while logger.global_step() <= self.max_images_to_train:
+ batch = next(self.dl_train)
+ self.G_EMA.update_beta()
+ self.to_log.update(self.step_D(batch))
+ self.to_log.update(self.step_G(batch))
+ self.G_EMA.update(self.G)
+
+ if logger.global_step() >= self.train_state.next_log_step:
+ to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
+ to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
+ to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
+ self.to_log = AverageMeter()
+ logger.add_dict(to_log, commit=True)
+ self.train_state.next_log_step += self.images_per_log
+ if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
+ print("Stopping training as gradient scale < 1e-8")
+ logger.log("Stopping training as gradient scale < 1e-8")
+ break
+
+ if logger.global_step() >= self.train_state.next_val_step:
+ self.evaluate()
+ self.log_time()
+ self.save_images()
+ self.train_state.next_val_step += self.images_per_val
+ logger.step(self.batch_size*tops.world_size())
+ logger.log(f"Reached end of training at step {logger.global_step()}.")
+ checkpointer.save_registered_models()
+
+ def estimate_ims_per_hour(self):
+ batch = next(self.dl_train)
+ n_ims = int(100e3)
+ n_steps = int(n_ims / (self.batch_size * tops.world_size()))
+ n_ims = n_steps * self.batch_size * tops.world_size()
+ for i in range(10): # Warmup
+ self.G_EMA.update_beta()
+ self.step_D(batch)
+ self.step_G(batch)
+ self.G_EMA.update(self.G)
+ start_time = time.time()
+ for i in utils.tqdm_(list(range(n_steps))):
+ self.G_EMA.update_beta()
+ self.step_D(batch)
+ self.step_G(batch)
+ self.G_EMA.update(self.G)
+ total_time = time.time() - start_time
+ ims_per_sec = n_ims / total_time
+ ims_per_hour = ims_per_sec * 60*60
+ ims_per_day = ims_per_hour * 24
+ logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
+ logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
+ import math
+ ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
+ logger.log(f"Images per 4 days: {ims_per_4_day}")
+ logger.add_dict({
+ "stats/ims_per_day": ims_per_day,
+ "stats/ims_per_4_day": ims_per_4_day
+ })
+
+ def log_time(self):
+ if not hasattr(self, "start_time"):
+ self.start_time = time.time()
+ self.last_time_step = logger.global_step()
+ return
+ n_images = logger.global_step() - self.last_time_step
+ if n_images == 0:
+ return
+ n_secs = time.time() - self.start_time
+ n_ims_per_sec = n_images / n_secs
+ training_time_hours = n_secs / 60 / 60
+ self.train_state.total_time += training_time_hours
+ remaining_images = self.max_images_to_train - logger.global_step()
+ remaining_time = remaining_images / n_ims_per_sec / 60 / 60
+ logger.add_dict({
+ "stats/n_ims_per_sec": n_ims_per_sec,
+ "stats/total_traing_time_hours": self.train_state.total_time,
+ "stats/remaining_time_hours": remaining_time
+ })
+ self.last_time_step = logger.global_step()
+ self.start_time = time.time()
+
+ def save_images(self):
+ dl_val = iter(self.dl_val)
+ batch = next(dl_val)
+ # TRUNCATED visualization
+ ims_to_log = 8
+ self.G_EMA.eval()
+ z = self.G.get_z(batch["img"])
+ fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
+ fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
+ if "__key__" in batch:
+ batch.pop("__key__")
+ real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
+ to_vis = torch.cat((real, fakes_truncated))
+ logger.add_images("images/truncated", to_vis, nrow=2)
+
+ # Diverse images
+ ims_diverse = 3
+ batch = next(dl_val)
+ to_vis = []
+
+ for i in range(ims_diverse):
+ z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
+ fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
+ to_vis.append(fakes)
+ if "__key__" in batch:
+ batch.pop("__key__")
+ reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
+ to_vis.insert(0, reals)
+ to_vis = torch.cat(to_vis)
+ logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1)
+
+ self.G_EMA.train()
+
+ def evaluate(self):
+ logger.log("Stating evaluation.")
+ self.G_EMA.eval()
+ try:
+ checkpointer.save_registered_models(max_keep=3)
+ except Exception:
+ logger.log("Could not save checkpoint.")
+ if self.broadcast_buffers:
+ check_ddp_consistency(self.G)
+ check_ddp_consistency(self.D)
+ metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val)
+ metrics = {f"metrics/{k}": v for k, v in metrics.items()}
+        logger.add_dict(metrics, level=logging.INFO)
+
+ def step_D(self, batch):
+ utils.set_requires_grad(self.trainable_params_D, True)
+ utils.set_requires_grad(self.trainable_params_G, False)
+ tops.zero_grad(self.D)
+ loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
+ with torch.autograd.profiler.record_function("D_step"):
+ self.scaler_D.scale(loss).backward()
+ accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
+ if self.broadcast_buffers:
+ accumulate_buffers(self.D)
+ accumulate_buffers(self.G)
+ # Step will not unscale if unscale is called previously.
+ self.scaler_D.step(self.D_optim)
+ self.scaler_D.update()
+ utils.set_requires_grad(self.trainable_params_D, False)
+ utils.set_requires_grad(self.trainable_params_G, False)
+ return to_log
+
+ def step_G(self, batch):
+ utils.set_requires_grad(self.trainable_params_D, False)
+ utils.set_requires_grad(self.trainable_params_G, True)
+ tops.zero_grad(self.G)
+ loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
+ with torch.autograd.profiler.record_function("G_step"):
+ self.scaler_G.scale(loss).backward()
+ accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
+ if self.broadcast_buffers:
+ accumulate_buffers(self.G)
+ accumulate_buffers(self.D)
+ self.scaler_G.step(self.G_optim)
+ self.scaler_G.update()
+ utils.set_requires_grad(self.trainable_params_D, False)
+ utils.set_requires_grad(self.trainable_params_G, False)
+ return to_log
diff --git a/dp2/generator/__init__.py b/dp2/generator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dp2/generator/base.py b/dp2/generator/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad403785887b8db971956a23ac8bdd330a04c509
--- /dev/null
+++ b/dp2/generator/base.py
@@ -0,0 +1,149 @@
+import torch
+import numpy as np
+import tqdm
+import tops
+from ..layers import Module
+from ..layers.sg2_layers import FullyConnectedLayer
+
+
+class BaseGenerator(Module):
+
+ def __init__(self, z_channels: int):
+ super().__init__()
+ self.z_channels = z_channels
+ self.latent_space = "Z"
+
+ @torch.no_grad()
+ def get_z(
+ self,
+ x: torch.Tensor = None,
+ z: torch.Tensor = None,
+ truncation_value: float = None,
+ batch_size: int = None,
+ dtype=None, device=None) -> torch.Tensor:
+ """Generates a latent variable for generator.
+ """
+ if z is not None:
+ return z
+ if x is not None:
+ batch_size = x.shape[0]
+ dtype = x.dtype
+ device = x.device
+ if device is None:
+ device = tops.get_device()
+ if truncation_value == 0:
+ return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
+ z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
+ if truncation_value is None:
+ return z
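+        # Resample entries whose magnitude exceeds the truncation threshold.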
+ while z.abs().max() > truncation_value:
+ m = z.abs() > truncation_value
+ z[m] = torch.rand_like(z)[m]
+ return z
+
+ def sample(self, truncation_value, z=None, **kwargs):
+ """
+        Samples by interpolating the latent toward the mean (0).
+ """
+ if truncation_value is None:
+ return self.forward(**kwargs)
+ truncation_value = max(0, truncation_value)
+ truncation_value = min(truncation_value, 1)
+ if z is None:
+ z = self.get_z(kwargs["condition"])
+ z = z * truncation_value
+ return self.forward(**kwargs, z=z)
+
+
+class SG2StyleNet(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_layers=2, # Number of mapping layers.
+ lr_multiplier=0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta=0.998, # Decay for tracking the moving average of W during training.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.w_dim = w_dim
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+ # Construct layers.
+ features = [self.z_dim] + [self.w_dim] * self.num_layers
+ for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
+ layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, update_emas=False, **kwargs):
+ tops.assert_shape(z, [None, self.z_dim])
+
+ # Embed, normalize, and concatenate inputs.
+ x = z.to(torch.float32)
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ # Execute layers.
+ for idx in range(self.num_layers):
+ x = getattr(self, f'fc{idx}')(x)
+ # Update moving average of W.
+ if update_emas:
+ self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'
+
+ def update_w(self, n=int(10e3), batch_size=32):
+ """
+        Re-estimates the moving average w_avg over n samples.
+        Useful in cases where w_avg was tracked incorrectly during training.
+ """
+ n = n // batch_size
+ for i in tqdm.trange(n, desc="Updating w"):
+ z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
+ self(z, update_emas=True)
+
+ def get_truncated(self, truncation_value, condition, z=None, **kwargs):
+ if z is None:
+ z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
+ w = self(z)
+ truncation_value = max(0, truncation_value)
+ truncation_value = min(truncation_value, 1)
+ return self.w_avg.to(w.dtype).lerp(w, truncation_value)
+
+ def multi_modal_truncate(self, truncation_value, condition, w_indices, z=None, **kwargs):
+ truncation_value = max(0, truncation_value)
+ truncation_value = min(truncation_value, 1)
+ if z is None:
+ z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
+ w = self(z)
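+        # Note: self.w_centers (cluster centers in W space) is not defined in this module and
+        # must be attached to the style network before multi-modal truncation is used.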
+ if w_indices is None:
+ w_indices = np.random.randint(0, len(self.w_centers), size=(len(w)))
+ w_centers = self.w_centers[w_indices].to(w.device)
+ w = w_centers.to(w.dtype).lerp(w, truncation_value)
+ return w
+
+class BaseStyleGAN(BaseGenerator):
+
+ def __init__(self, z_channels: int, w_dim: int):
+ super().__init__(z_channels)
+ self.style_net = SG2StyleNet(z_channels, w_dim)
+ self.latent_space = "W"
+
+ def get_w(self, z, update_emas):
+ return self.style_net(z, update_emas=update_emas)
+
+ @torch.no_grad()
+ def sample(self, truncation_value, **kwargs):
+ if truncation_value is None:
+ return self.forward(**kwargs)
+ w = self.style_net.get_truncated(truncation_value, **kwargs)
+ return self.forward(**kwargs, w=w)
+
+ def update_w(self, *args, **kwargs):
+ self.style_net.update_w(*args, **kwargs)
+
+ @torch.no_grad()
+ def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
+ w = self.style_net.multi_modal_truncate(truncation_value, w_indices=w_indices, **kwargs)
+ return self.forward(**kwargs, w=w)
diff --git a/dp2/generator/deep_privacy1.py b/dp2/generator/deep_privacy1.py
new file mode 100644
index 0000000000000000000000000000000000000000..531d780ad995081b233c8d8d73242571eb7d33c4
--- /dev/null
+++ b/dp2/generator/deep_privacy1.py
@@ -0,0 +1,648 @@
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+from .base import BaseGenerator
+import numpy as np
+from typing import List
+
+
+class LatentVariableConcat(nn.Module):
+
+ def __init__(self, conv2d_config):
+ super().__init__()
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ z = batch["z"]
+ x = torch.cat((x, z), dim=1)
+ return (x, mask, batch)
+
+
+def get_padding(kernel_size: int, dilation: int, stride: int):
+ out = (dilation * (kernel_size - 1) - 1) / 2 + 1
+ return int(np.floor(out))
+
+
+class Conv2d(nn.Conv2d):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=None, dilation=1, groups=1,
+ bias=True, padding_mode='zeros',
+ demodulation=False, wsconv=False, gain=1,
+ *args, **kwargs):
+ if padding is None:
+ padding = get_padding(kernel_size, dilation, stride)
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ groups, bias, padding_mode)
+ self.demodulation = demodulation
+ self.wsconv = wsconv
+ if self.wsconv:
+ fan_in = np.prod(self.weight.shape[1:]) / self.groups
+ self.ws_scale = gain / np.sqrt(fan_in)
+ nn.init.normal_(self.weight)
+ if bias:
+ nn.init.constant_(self.bias, val=0)
+ assert not self.padding_mode == "circular",\
+ "conv2d_forward does not support circular padding. Look at original pytorch code"
+
+ def _get_weight(self):
+ weight = self.weight
+ if self.wsconv:
+ weight = self.ws_scale * weight
+ if self.demodulation:
+ demod = torch.rsqrt(weight.pow(2).sum([1, 2, 3]) + 1e-7)
+ weight = weight * demod.view(self.out_channels, 1, 1, 1)
+ return weight
+
+ def conv2d_forward(self, x, weight, bias=True):
+ bias_ = None
+ if bias:
+ bias_ = self.bias
+ return nn.functional.conv2d(x, weight, bias_, self.stride,
+ self.padding, self.dilation, self.groups)
+
+ def forward(self, _inp):
+ x, mask = _inp
+ weight = self._get_weight()
+ return self.conv2d_forward(x, weight), mask
+
+ def __repr__(self):
+ return ", ".join([
+ super().__repr__(),
+ f"Demodulation={self.demodulation}",
+ f"Weight Scale={self.wsconv}",
+ f"Bias={self.bias is not None}"
+ ])
+
+
+class LeakyReLU(nn.LeakyReLU):
+
+ def forward(self, _inp):
+ x, mask = _inp
+ return super().forward(x), mask
+
+
+class AvgPool2d(nn.AvgPool2d):
+
+ def forward(self, _inp):
+ x, mask, *args = _inp
+ x = super().forward(x)
+ mask = super().forward(mask)
+ if len(args) > 0:
+ return (x, mask, *args)
+ return x, mask
+
+
+def up(x):
+ if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
+ # Analytical normalization
+ return x
+ return nn.functional.interpolate(
+ x, scale_factor=2, mode="nearest")
+
+
+class NearestUpsample(nn.Module):
+
+ def forward(self, _inp):
+ x, mask, *args = _inp
+ x = up(x)
+ mask = up(mask)
+ if len(args) > 0:
+ return (x, mask, *args)
+ return x, mask
+
+
+class PixelwiseNormalization(nn.Module):
+
+ def forward(self, _inp):
+ x, mask = _inp
+ norm = torch.rsqrt((x**2).mean(dim=1, keepdim=True) + 1e-7)
+ return x * norm, mask
+
+
+class Linear(nn.Linear):
+
+ def __init__(self, in_features, out_features):
+ super().__init__(in_features, out_features)
+ self.linear = nn.Linear(in_features, out_features)
+ fanIn = in_features
+ self.wtScale = 1 / np.sqrt(fanIn)
+
+ nn.init.normal_(self.weight)
+ nn.init.constant_(self.bias, val=0)
+
+ def _get_weight(self):
+ return self.weight * self.wtScale
+
+ def forward_linear(self, x, weight):
+ return nn.functional.linear(x, weight, self.bias)
+
+ def forward(self, x):
+ return self.forward_linear(x, self._get_weight())
+
+
+class OneHotPoseConcat(nn.Module):
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ landmarks = batch["landmarks_oh"]
+ res = x.shape[-1]
+ landmark = landmarks[res]
+ x = torch.cat((x, landmark), dim=1)
+ del batch["landmarks_oh"][res]
+ return x, mask, batch
+
+
+def transition_features(x_old, x_new, transition_variable):
+ assert x_old.shape == x_new.shape,\
+ "Old shape: {}, New: {}".format(x_old.shape, x_new.shape)
+ return torch.lerp(x_old.float(), x_new.float(), transition_variable)
+
+
+class TransitionBlock(nn.Module):
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ x = transition_features(
+ batch["x_old"], x, batch["transition_value"])
+ mask = transition_features(
+ batch["mask_old"], mask, batch["transition_value"])
+ del batch["x_old"]
+ del batch["mask_old"]
+ return x, mask, batch
+
+
+class UnetSkipConnection(nn.Module):
+
+ def __init__(self, conv2d_config: dict, in_channels: int,
+ out_channels: int, resolution: int,
+ residual: bool, enabled: bool):
+ super().__init__()
+ self.use_iconv = conv2d_config.conv.type == "iconv"
+ self._in_channels = in_channels
+ self._out_channels = out_channels
+ self._resolution = resolution
+ self._enabled = enabled
+ self._residual = residual
+ if self.use_iconv:
+ self.beta0 = torch.nn.Parameter(torch.tensor(1.))
+ self.beta1 = torch.nn.Parameter(torch.tensor(1.))
+ else:
+ if self._residual:
+ self.conv = build_base_conv(
+ conv2d_config, False, in_channels // 2,
+ out_channels, kernel_size=1, padding=0)
+ else:
+ self.conv = ConvAct(
+ conv2d_config, in_channels, out_channels,
+ kernel_size=1, padding=0)
+
+ def forward(self, _inp):
+ if not self._enabled:
+ return _inp
+ x, mask, batch = _inp
+ skip_x, skip_mask = batch["unet_features"][self._resolution]
+ assert x.shape == skip_x.shape, (x.shape, skip_x.shape)
+ del batch["unet_features"][self._resolution]
+ if self.use_iconv:
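+            # Blend skip and decoder features with learned, mask-weighted coefficients (iconv).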
+ denom = skip_mask * self.beta0.relu() + mask * self.beta1.relu() + 1e-8
+ gamma = skip_mask * self.beta0.relu() / denom
+ x = skip_x * gamma + (1 - gamma) * x
+ mask = skip_mask * gamma + (1 - gamma) * mask
+ else:
+ if self._residual:
+ skip_x, skip_mask = self.conv((skip_x, skip_mask))
+ x = (x + skip_x) / np.sqrt(2)
+ if self._probabilistic:
+ mask = (mask + skip_mask) / np.sqrt(2)
+ else:
+ x = torch.cat((x, skip_x), dim=1)
+ x, mask = self.conv((x, mask))
+ return x, mask, batch
+
+ def __repr__(self):
+ return " ".join([
+ self.__class__.__name__,
+ f"In channels={self._in_channels}",
+ f"Out channels={self._out_channels}",
+ f"Residual: {self._residual}",
+ f"Enabled: {self._enabled}"
+ f"IConv: {self.use_iconv}"
+ ])
+
+
+def get_conv(ctype, post_act):
+ type2conv = {
+ "conv": Conv2d,
+ "gconv": GatedConv
+ }
+ # Do not apply for output layer
+ if not post_act and ctype in ["gconv", "iconv"]:
+ return type2conv["conv"]
+ assert ctype in type2conv
+ return type2conv[ctype]
+
+
+def build_base_conv(
+ conv2d_config, post_act: bool, *args, **kwargs) -> nn.Conv2d:
+ for k, v in conv2d_config.conv.items():
+ assert k not in kwargs
+ kwargs[k] = v
+ # Demodulation should not be used for output layers.
+ demodulation = conv2d_config.normalization == "demodulation" and post_act
+ kwargs["demodulation"] = demodulation
+ conv = get_conv(conv2d_config.conv.type, post_act)
+ return conv(*args, **kwargs)
+
+
+def build_post_activation(in_channels, conv2d_config) -> List[nn.Module]:
+ _layers = []
+ negative_slope = conv2d_config.leaky_relu_nslope
+ _layers.append(LeakyReLU(negative_slope, inplace=True))
+ if conv2d_config.normalization == "pixel_wise":
+ _layers.append(PixelwiseNormalization())
+ return _layers
+
+
+def build_avgpool(conv2d_config, kernel_size) -> nn.AvgPool2d:
+ return AvgPool2d(kernel_size)
+
+
+def build_convact(conv2d_config, *args, **kwargs):
+ conv = build_base_conv(conv2d_config, True, *args, **kwargs)
+ out_channels = conv.out_channels
+ post_act = build_post_activation(out_channels, conv2d_config)
+ return nn.Sequential(conv, *post_act)
+
+
+class ConvAct(nn.Module):
+
+ def __init__(self, conv2d_config, *args, **kwargs):
+ super().__init__()
+ self._conv2d_config = conv2d_config
+ conv = build_base_conv(conv2d_config, True, *args, **kwargs)
+ self.in_channels = conv.in_channels
+ self.out_channels = conv.out_channels
+ _layers = [conv]
+ _layers.extend(build_post_activation(self.out_channels, conv2d_config))
+ self.layers = nn.Sequential(*_layers)
+
+ def forward(self, _inp):
+ return self.layers(_inp)
+
+
+class GatedConv(Conv2d):
+
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ out_channels *= 2
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ assert self.out_channels % 2 == 0
+ self.lrelu = nn.LeakyReLU(0.2, inplace=True)
+ self.sigmoid = nn.Sigmoid()
+
+ def conv2d_forward(self, x, weight, bias=True):
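+        # The convolution outputs 2x out_channels: the first half is the feature map (LeakyReLU),
+        # the second half the sigmoid gate; the result is their elementwise product.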
+ x_ = super().conv2d_forward(x, weight, bias)
+ x = x_[:, :self.out_channels // 2]
+ y = x_[:, self.out_channels // 2:]
+ x = self.lrelu(x)
+ y = y.sigmoid()
+ assert x.shape == y.shape, f"{x.shape}, {y.shape}"
+ return x * y
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(
+ self, conv2d_config, resolution: int, in_channels: int,
+ out_channels: List[int], residual: bool):
+ super().__init__()
+ assert len(out_channels) == 2
+ self._resolution = resolution
+ self._residual = residual
+ self.out_channels = out_channels
+ _layers = []
+ _in_channels = in_channels
+ for out_ch in out_channels:
+ conv = build_base_conv(
+ conv2d_config, True, _in_channels, out_ch, kernel_size=3,
+ resolution=resolution)
+ _layers.append(conv)
+ _layers.extend(build_post_activation(_in_channels, conv2d_config))
+ _in_channels = out_ch
+ self.layers = nn.Sequential(*_layers)
+ if self._residual:
+ self.residual_conv = build_base_conv(
+ conv2d_config, post_act=False, in_channels=in_channels,
+ out_channels=out_channels[-1],
+ kernel_size=1, padding=0)
+ self.const = 1 / np.sqrt(2)
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ y = x
+ mask_ = mask
+ assert y.shape[-1] == self._resolution or y.shape[-1] == 1
+ y, mask = self.layers((x, mask))
+ if self._residual:
+ residual, mask_ = self.residual_conv((x, mask_))
+ y = (y + residual) * self.const
+ mask = (mask + mask_) * self.const
+ return y, mask, batch
+
+ def extra_repr(self):
+ return f"Residual={self._residual}, Resolution={self._resolution}"
+
+
+class PoseNormalize(nn.Module):
+
+ @torch.no_grad()
+ def forward(self, x):
+ return x * 2 - 1
+
+
+class ScalarPoseFCNN(nn.Module):
+
+ def __init__(self, pose_size, hidden_size,
+ output_shape):
+ super().__init__()
+ pose_size = pose_size
+ self._hidden_size = hidden_size
+ output_size = np.prod(output_shape)
+ self.output_shape = output_shape
+ self.pose_preprocessor = nn.Sequential(
+ PoseNormalize(),
+ Linear(pose_size, hidden_size),
+ nn.LeakyReLU(.2),
+ Linear(hidden_size, output_size),
+ nn.LeakyReLU(.2)
+ )
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ pose_info = batch["landmarks"]
+ del batch["landmarks"]
+ pose = self.pose_preprocessor(pose_info)
+ pose = pose.view(-1, *self.output_shape)
+ if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1:
+ # Analytical normalization propagation
+            pose = pose.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
+ x = torch.cat((x, pose), dim=1)
+ return x, mask, batch
+
+ def __repr__(self):
+ return " ".join([
+ self.__class__.__name__,
+ f"hidden_size={self._hidden_size}",
+ f"output shape={self.output_shape}"
+ ])
+
+
+class Attention(nn.Module):
+
+ def __init__(self, in_channels):
+ super(Attention, self).__init__()
+ # Channel multiplier
+ self.in_channels = in_channels
+ self.theta = Conv2d(
+ self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
+ bias=False)
+ self.phi = Conv2d(
+ self.in_channels, self.in_channels // 8, kernel_size=1, padding=0,
+ bias=False)
+ self.g = Conv2d(
+ self.in_channels, self.in_channels // 2, kernel_size=1, padding=0,
+ bias=False)
+ self.o = Conv2d(
+ self.in_channels // 2, self.in_channels, kernel_size=1, padding=0,
+ bias=False)
+ # Learnable gain parameter
+ self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)
+
+ def forward(self, _inp):
+ x, mask, batch = _inp
+ # Apply convs
+ theta, _ = self.theta((x, None))
+ phi = nn.functional.max_pool2d(self.phi((x, None))[0], [2, 2])
+ g = nn.functional.max_pool2d(self.g((x, None))[0], [2, 2])
+ # Perform reshapes
+ theta = theta.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3])
+ phi = phi.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3] // 4)
+ g = g.view(-1, self.in_channels // 2, x.shape[2] * x.shape[3] // 4)
+ # Matmul and softmax to get attention maps
+ beta = nn.functional.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
+ # Attention map times g path
+
+ o = self.o((torch.bmm(g, beta.transpose(1, 2)).view(-1,
+ self.in_channels // 2, x.shape[2], x.shape[3]), None))[0]
+ return self.gamma * o + x, mask, batch
+
+
+class MSGGenerator(BaseGenerator):
+
+ def __init__(self):
+ super().__init__(512)
+ max_imsize = 128
+ unet = dict(enabled=True, residual=False)
+
+ min_fmap_resolution = 4
+ model_size = 512
+ image_channels = 3
+ pose_size = 14
+ residual = False
+ conv_size = {
+ 4: model_size,
+ 8: model_size,
+ 16: model_size,
+ 32: model_size,
+ 64: model_size//2,
+ 128: model_size//4,
+ 256: model_size//8,
+ 512: model_size//16
+ }
+ self.removable_hooks = []
+ self.rgb_convolutions = nn.ModuleDict()
+ self.max_imsize = max_imsize
+ self._image_channels = image_channels
+ self._min_fmap_resolution = min_fmap_resolution
+ self._residual = residual
+ self._pose_size = pose_size
+ self.current_imsize = max_imsize
+ self._unet_cfg = unet
+ self.concat_input_mask = True
+ self.res2channels = {int(k): v for k, v in conv_size.items()}
+
+ self.conv2d_config = EasyDict(
+ pixel_normalization=True,
+ leaky_relu_nslope=.2,
+ normalization="pixel_wise",
+ conv=dict(
+ type="conv",
+ wsconv=True,
+ gain=1,
+ )
+ )
+ self._init_decoder()
+ self._init_encoder()
+
+ def _init_encoder(self):
+ self.encoder = nn.ModuleList()
+ imsize = self.max_imsize
+ self.from_rgb = build_convact(
+ self.conv2d_config,
+ in_channels=self._image_channels + self.concat_input_mask*2,
+ out_channels=self.res2channels[imsize],
+ kernel_size=1)
+ while imsize >= self._min_fmap_resolution:
+ current_size = self.res2channels[imsize]
+ next_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)]
+ block = BasicBlock(
+ self.conv2d_config, imsize, current_size,
+ [current_size, next_size], self._residual)
+ self.encoder.add_module(f"basic_block{imsize}", block)
+ if imsize != self._min_fmap_resolution:
+ self.encoder.add_module(
+ f"downsample{imsize}", AvgPool2d(2))
+ imsize //= 2
+
+ def _init_decoder(self):
+ self.decoder = nn.ModuleList()
+ self.decoder.add_module(
+ "latent_concat", LatentVariableConcat(self.conv2d_config))
+ if self._pose_size > 0:
+ m = self._min_fmap_resolution
+ pose_shape = (16, m, m)
+ pose_fcnn = ScalarPoseFCNN(self._pose_size, 128, pose_shape)
+ self.decoder.add_module("pose_fcnn", pose_fcnn)
+ imsize = self._min_fmap_resolution
+ self.rgb_convolutions = nn.ModuleDict()
+ while imsize <= self.max_imsize:
+ current_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)]
+ start_size = current_size
+ if imsize == self._min_fmap_resolution:
+ start_size += 32
+ if self._pose_size > 0:
+ start_size += 16
+ else:
+ self.decoder.add_module(f"upsample{imsize}", NearestUpsample())
+ skip = UnetSkipConnection(
+ self.conv2d_config, current_size*2, current_size, imsize,
+ **self._unet_cfg)
+ self.decoder.add_module(f"skip_connection{imsize}", skip)
+ next_size = self.res2channels[imsize]
+ block = BasicBlock(
+ self.conv2d_config, imsize, start_size, [start_size, next_size],
+ residual=self._residual)
+ self.decoder.add_module(f"basic_block{imsize}", block)
+
+ to_rgb = build_base_conv(
+ self.conv2d_config, False, in_channels=next_size,
+ out_channels=self._image_channels, kernel_size=1)
+ self.rgb_convolutions[str(imsize)] = to_rgb
+ imsize *= 2
+ self.norm_constant = len(self.rgb_convolutions)
+
+ def forward_decoder(self, x, mask, batch):
+ imsize_start = max(x.shape[-1] // 2, 1)
+ rgb = torch.zeros(
+ (x.shape[0], self._image_channels,
+ imsize_start, imsize_start),
+ dtype=x.dtype, device=x.device)
+ mask_size = 1
+ mask_out = torch.zeros(
+ (x.shape[0], mask_size,
+ imsize_start, imsize_start),
+ dtype=x.dtype, device=x.device)
+ imsize = self._min_fmap_resolution // 2
+ for module in self.decoder:
+ x, mask, batch = module((x, mask, batch))
+ if isinstance(module, BasicBlock):
+ imsize *= 2
+ rgb = up(rgb)
+ mask_out = up(mask_out)
+ conv = self.rgb_convolutions[str(imsize)]
+ rgb_, mask_ = conv((x, mask))
+ assert rgb_.shape == rgb.shape,\
+ f"rgb_ {rgb_.shape}, rgb: {rgb.shape}"
+ rgb = rgb + rgb_
+ return rgb / self.norm_constant, mask_out
+
+ def forward_encoder(self, x, mask, batch):
+ if self.concat_input_mask:
+ x = torch.cat((x, mask, 1 - mask), dim=1)
+ unet_features = {}
+ x, mask = self.from_rgb((x, mask))
+ for module in self.encoder:
+ x, mask, batch = module((x, mask, batch))
+ if isinstance(module, BasicBlock):
+ unet_features[module._resolution] = (x, mask)
+ return x, mask, unet_features
+
+ def forward(
+ self,
+ condition,
+ mask, keypoints=None, z=None,
+ **kwargs):
+ keypoints = keypoints.flatten(start_dim=1).clip(-1, 1)
+ if z is None:
+ z = self.get_z(condition)
+ z = z.view(-1, 32, 4, 4)
+ batch = dict(
+ landmarks=keypoints,
+ z=z)
+ orig_mask = mask
+ x, mask, unet_features = self.forward_encoder(condition, mask, batch)
+ batch = dict(
+ landmarks=keypoints,
+ z=z,
+ unet_features=unet_features)
+ x, mask = self.forward_decoder(x, mask, batch)
+ x = condition * orig_mask + (1 - orig_mask) * x
+ return dict(img=x)
+
+ def load_state_dict(self, state_dict, strict=True):
+ if "parameters" in state_dict:
+ state_dict = state_dict["parameters"]
+ old_checkpoint = any("basic_block0" in key for key in state_dict)
+ if not old_checkpoint:
+ return super().load_state_dict(state_dict, strict=strict)
+ mapping = {}
+ imsize = self._min_fmap_resolution
+ i = 0
+ while imsize <= self.max_imsize:
+ old_key = f"decoder.basic_block{i}."
+ new_key = f"decoder.basic_block{imsize}."
+ mapping[old_key] = new_key
+ if i >= 1:
+ old_key = old_key.replace("basic_block", "skip_connection")
+ new_key = new_key.replace("basic_block", "skip_connection")
+ mapping[old_key] = new_key
+ old_key = f"encoder.basic_block{i}."
+ new_key = f"encoder.basic_block{imsize}."
+ mapping[old_key] = new_key
+ old_key = "from_rgb.conv.layers.0."
+ new_key = "from_rgb.0."
+ mapping[old_key] = new_key
+ i += 1
+ imsize *= 2
+ new_sd = {}
+ for key, value in state_dict.items():
+ old_key = key
+ if "from_rgb" in key:
+ new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value
+ continue
+ for subkey, new_subkey in mapping.items():
+ if subkey in key:
+ old_key = key
+ key = key.replace(subkey, new_subkey)
+
+ break
+ if "decoder.to_rgb" in key:
+ continue
+
+ new_sd[key] = value
+ return super().load_state_dict(new_sd, strict=strict)
+
+ def update_w(self, *args, **kwargs):
+ return
diff --git a/dp2/generator/dummy_generators.py b/dp2/generator/dummy_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81b4d4f70bd84fb42bfe8ab3d6bd06f918533c4
--- /dev/null
+++ b/dp2/generator/dummy_generators.py
@@ -0,0 +1,60 @@
+import torch
+from .base import BaseGenerator
+from torchvision.transforms.functional import gaussian_blur
+import torch.nn.functional as F
+
+
+class PixelationGenerator(BaseGenerator):
+
+ def __init__(self, pixelation_size, **kwargs):
+ super().__init__(z_channels=0)
+ self.pixelation_size = pixelation_size
+ self.z_channels = 0
+ self.latent_space = None
+
+ def forward(self, img, condition, mask, **kwargs):
+ old_shape = img.shape[-2:]
+ img = F.interpolate(img, size=(
+ self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True)
+ img = F.interpolate(img, size=old_shape, mode="bilinear", align_corners=True)
+ out = img*(1-mask) + condition*mask
+ return {"img": out}
+
+
+class MaskOutGenerator(BaseGenerator):
+
+ def __init__(self, noise: str, **kwargs):
+ super().__init__(z_channels=0)
+ self.noise = noise
+ self.z_channels = 0
+ assert self.noise in ["rand", "constant"]
+ self.latent_space = None
+
+ def forward(self, img, condition, mask, **kwargs):
+
+ if self.noise == "constant":
+ img = torch.zeros_like(img)
+ elif self.noise == "rand":
+ img = torch.rand_like(img)
+ out = img*(1-mask) + condition*mask
+ return {"img": out}
+
+
+class IdentityGenerator(BaseGenerator):
+
+ def __init__(self):
+ super().__init__(z_channels=0)
+
+ def forward(self, img, condition, mask, **kwargs):
+ return dict(img=img)
+
+
+class GaussianBlurGenerator(BaseGenerator):
+
+ def __init__(self):
+ super().__init__(z_channels=0)
+ self.sigma = 7
+
+ def forward(self, img, condition, mask, **kwargs):
+ img_blur = gaussian_blur(img, kernel_size=min(self.sigma*3, img.shape[-1]), sigma=self.sigma)
+ return dict(img=img * mask + (1-mask) * img_blur)
diff --git a/dp2/generator/stylegan_unet.py b/dp2/generator/stylegan_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3dfc46da323d04919cf5c166ec038820eac1ad
--- /dev/null
+++ b/dp2/generator/stylegan_unet.py
@@ -0,0 +1,211 @@
+import torch
+import numpy as np
+from dp2.layers import Sequential
+from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
+from .base import BaseStyleGAN
+from typing import List, Tuple
+from .utils import spatial_embed_keypoints, mask_output
+
+
+def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
+ n = int(np.log2(max_imsize) - np.log2(imsize))
+ mul = min(2**n, max_cnum_mul)
+ ch = cnum * mul
+ return int(ch)
+
+
+class StyleGANUnet(BaseStyleGAN):
+ def __init__(
+ self,
+ scale_grad: bool,
+ im_channels: int,
+ min_fmap_resolution: int,
+ imsize: List[int],
+ cnum: int,
+ max_cnum_mul: int,
+ mask_output: bool,
+ conv_clamp: int,
+ input_cse: bool,
+ cse_nc: int,
+ n_middle_blocks: int,
+ input_keypoints: bool,
+ n_keypoints: int,
+ input_keypoint_indices: Tuple[int],
+ fix_errors: bool,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.n_keypoints = n_keypoints
+ self.input_keypoint_indices = list(input_keypoint_indices)
+ self.input_keypoints = input_keypoints
+ assert not (input_cse and input_keypoints)
+ cse_nc = 0 if cse_nc is None else cse_nc
+ self.imsize = imsize
+ self._cnum = cnum
+ self._max_cnum_mul = max_cnum_mul
+ self._min_fmap_resolution = min_fmap_resolution
+ self._image_channels = im_channels
+ self._max_imsize = max(imsize)
+ self.input_cse = input_cse
+ self.gain_unet = np.sqrt(1/3)
+ n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
+ encoder_layers = []
+ self.from_rgb = Conv2d(
+ im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
+ cnum, 1
+ )
+ for i in range(n_levels): # Encoder layers
+ resolution = [x//2**i for x in imsize]
+ in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
+ second_ch = in_ch
+ out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
+ down = 2
+
+            if i == 0:  # First block: downsampling happens at the start of each block, so skip it here.
+ down = 1
+ if i == n_levels - 1:
+ out_ch = second_ch
+ block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
+ encoder_layers.append(block)
+ self._encoder_out_shape = [
+ get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
+ *resolution]
+
+ self.encoder = torch.nn.ModuleList(encoder_layers)
+
+ # initialize decoder
+ decoder_layers = []
+ for i in range(n_levels):
+ resolution = [x//2**(n_levels-1-i) for x in imsize]
+ in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
+ out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
+ if i == 0: # first (lowest) block
+ in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
+
+ up = 1
+ if i != n_levels - 1:
+ up = 2
+ block = ResidualBlock(
+ in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
+ w_dim=self.style_net.w_dim, norm=True, up=up,
+ fix_residual=fix_errors
+ )
+ decoder_layers.append(block)
+ if i != 0:
+ unet_block = Conv2d(
+ in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
+ gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
+ setattr(self, f"unet_block{i}", unet_block)
+
+ # Initialize "middle blocks" that do not have down/up sample
+ middle_blocks = []
+ for i in range(n_middle_blocks):
+ ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
+ block = ResidualBlock(
+ ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
+ w_dim=self.style_net.w_dim, norm=True,
+ )
+ middle_blocks.append(block)
+ if n_middle_blocks != 0:
+ self.middle_blocks = Sequential(*middle_blocks)
+ self.decoder = torch.nn.ModuleList(decoder_layers)
+ self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
+ self.scale_grad = scale_grad
+ self.mask_output = mask_output
+
+ def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
+ for i, layer in enumerate(self.decoder):
+ if i != 0:
+ unet_layer = getattr(self, f"unet_block{i}")
+ x = x + unet_layer(unet_features[-i])
+ x = layer(x, w=w, s=s)
+ x = self.to_rgb(x)
+ if self.mask_output:
+ x = mask_output(True, condition, x, mask)
+ return dict(img=x)
+
+ def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
+ if self.input_cse:
+ x = torch.cat((condition, mask, embedding, E_mask), dim=1)
+ else:
+ x = torch.cat((condition, mask), dim=1)
+ if self.input_keypoints:
+ keypoints = keypoints[:, self.input_keypoint_indices]
+ one_hot_pose = spatial_embed_keypoints(keypoints, x)
+ x = torch.cat((x, one_hot_pose), dim=1)
+ x = self.from_rgb(x)
+
+ unet_features = []
+ for i, layer in enumerate(self.encoder):
+ x = layer(x)
+ if i != len(self.encoder)-1:
+ unet_features.append(x)
+ if hasattr(self, "middle_blocks"):
+ for layer in self.middle_blocks:
+ x = layer(x)
+ return x, unet_features
+
+ def forward(
+ self, condition, mask,
+ z=None, embedding=None, w=None, update_emas=False, x=None,
+ s=None,
+ keypoints=None,
+ unet_features=None,
+ E_mask=None,
+ **kwargs):
+        # Allows skipping the encoder forward pass at inference time, e.g. for w projection.
+ if x is not None and unet_features is not None:
+ assert not self.training
+ else:
+ x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
+ if w is None:
+ if z is None:
+ z = self.get_z(condition)
+ w = self.get_w(z, update_emas=update_emas)
+ return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)
+
+
+class ComodStyleUNet(StyleGANUnet):
+
+ def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
+ super().__init__(**kwargs)
+ min_fmap = min(self._encoder_out_shape[1:])
+ enc_out_ch = self._encoder_out_shape[0]
+ n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
+ comod_layers = []
+ in_ch = enc_out_ch
+ for i in range(n_down):
+            comod_layers.append(Conv2d(in_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
+ in_ch = 256
+ if n_down == 0:
+ comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
+ comod_layers.append(torch.nn.Flatten())
+ out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
+ in_ch_fc = np.prod(out_res) * 256
+ comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
+ self.comod_block = Sequential(*comod_layers)
+ self.comod_fc = FullyConnectedLayer(
+ 512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)
+
+ def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
+ y = self.comod_block(x)
+ y = torch.cat((y, w), dim=1)
+ y = self.comod_fc(y)
+ for i, layer in enumerate(self.decoder):
+ if i != 0:
+ unet_layer = getattr(self, f"unet_block{i}")
+ x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
+ x = layer(x, w=y)
+ x = self.to_rgb(x)
+ if self.mask_output:
+ x = mask_output(True, condition, x, mask)
+ return dict(img=x)
+
+ def get_comod_y(self, batch, w):
+ x, unet_features = self.forward_enc(**batch)
+ y = self.comod_block(x)
+ y = torch.cat((y, w), dim=1)
+ y = self.comod_fc(y)
+ return y
diff --git a/dp2/generator/utils.py b/dp2/generator/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec5a983df28d81b10f81635881f77c4747162431
--- /dev/null
+++ b/dp2/generator/utils.py
@@ -0,0 +1,49 @@
+import torch
+import tops
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+@torch.no_grad()
+def spatial_embed_keypoints(keypoints: torch.Tensor, x):
+ tops.assert_shape(keypoints, (None, None, 3))
+ B, N_K, _ = keypoints.shape
+ H, W = x.shape[-2:]
+ x, y, visible = keypoints.chunk(3, dim=2)
+ x = (x * W).round().long().clamp(0, W-1)
+ y = (y * H).round().long().clamp(0, H-1)
+ kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
+ pos = (kp_idx*(H*W) + y*W + x + 1)
+ # Offset all by 1 to index invisible keypoints as 0
+ pos = (pos * visible.round().long()).squeeze(dim=-1)
+ keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
+ keypoint_spatial.scatter_(1, pos, 1)
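+    # Drop the dummy bin at index 0 (used for invisible keypoints) and reshape to per-keypoint one-hot maps.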
+ keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
+ return keypoint_spatial
+
+
+class MaskOutput(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x_real, x_fake, mask):
+ ctx.save_for_backward(mask)
+ out = x_real * mask + (1-mask) * x_fake
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ fake_grad = grad_output
+ mask, = ctx.saved_tensors
+ fake_grad = fake_grad * (1 - mask)
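+        # Rescale by the fraction of generated (unknown) pixels so the gradient magnitude is independent of mask size.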
+ known_percentage = mask.view(mask.shape[0], -1).mean(dim=1)
+ fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1)
+ return None, fake_grad, None
+
+
+def mask_output(scale_grad, x_real, x_fake, mask):
+ if scale_grad:
+ return MaskOutput.apply(x_real, x_fake, mask)
+ return x_real * mask + (1-mask) * x_fake
diff --git a/dp2/infer.py b/dp2/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff1ea39ed6919be6f4e282823e0c291e48e248ee
--- /dev/null
+++ b/dp2/infer.py
@@ -0,0 +1,78 @@
+import tops
+import torch
+from tops import checkpointer
+from tops.config import instantiate
+from tops.logger import warn
+from dp2.generator.deep_privacy1 import MSGGenerator
+
+
+def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
+ state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
+ if ckpt_mapper is not None:
+ state = ckpt_mapper(state)
+ if isinstance(G, MSGGenerator):
+ G.load_state_dict(state)
+ else:
+ load_state_dict(G, state)
+ tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
+ if "w_centers" in ckpt:
+ G.style_net.register_buffer("w_centers", ckpt["w_centers"])
+ tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
+ if "style_net.w_centers" in state:
+ G.style_net.register_buffer("w_centers", state["style_net.w_centers"])
+ tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
+
+
+def build_trained_generator(cfg, map_location=None):
+ map_location = map_location if map_location is not None else tops.get_device()
+ G = instantiate(cfg.generator)
+ G.eval()
+ G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
+ if hasattr(cfg, "ckpt_mapper"):
+ ckpt_mapper = instantiate(cfg.ckpt_mapper)
+ else:
+ ckpt_mapper = None
+ if "model_url" in cfg.common:
+ ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum)
+ load_generator_state(ckpt, G, ckpt_mapper)
+ return G.to(map_location)
+ try:
+ ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
+ load_generator_state(ckpt, G, ckpt_mapper)
+ except FileNotFoundError as e:
+ tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
+ return G.to(map_location)
+
+
+def build_trained_discriminator(cfg, map_location=None):
+ map_location = map_location if map_location is not None else tops.get_device()
+ D = instantiate(cfg.discriminator).to(map_location)
+ D.eval()
+ try:
+ ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
+ if hasattr(cfg, "ckpt_mapper_D"):
+ ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
+ D.load_state_dict(ckpt["discriminator"])
+ except FileNotFoundError as e:
+ tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
+ return D
+
+
+def load_state_dict(module: torch.nn.Module, state_dict: dict):
+ module_sd = module.state_dict()
+ to_remove = []
+ for key, item in state_dict.items():
+ if key not in module_sd:
+ continue
+ if item.shape != module_sd[key].shape:
+ to_remove.append(key)
+ warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
+ for key in to_remove:
+ state_dict.pop(key)
+ for key, item in state_dict.items():
+ if key not in module_sd:
+            warn(f"Did not find key in model state dict: {key}")
+ for key, item in module_sd.items():
+ if key not in state_dict:
+ warn(f"Did not find key in state dict: {key}")
+ module.load_state_dict(state_dict, strict=False)
diff --git a/dp2/layers/__init__.py b/dp2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..680be5614803ce6a74a2994c023ab90b7d27a74b
--- /dev/null
+++ b/dp2/layers/__init__.py
@@ -0,0 +1,22 @@
+from typing import Dict
+import torch
+import tops
+import torch.nn as nn
+
+
+class Sequential(nn.Sequential):
+
+ def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
+ for module in self:
+ x = module(x, **kwargs)
+ return x
+
+
+class Module(nn.Module):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def extra_repr(self):
+ num_params = tops.num_parameters(self) / 10**6
+ return f"Num params: {num_params:.3f}M"
diff --git a/dp2/layers/sg2_layers.py b/dp2/layers/sg2_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce111ffd5a4ce1d5a879ba49cd7a7108b35defdf
--- /dev/null
+++ b/dp2/layers/sg2_layers.py
@@ -0,0 +1,229 @@
+from typing import List
+import numpy as np
+import torch
+import tops
+import torch.nn.functional as F
+from sg3_torch_utils.ops import conv2d_resample
+from sg3_torch_utils.ops import upfirdn2d
+from sg3_torch_utils.ops import bias_act
+from sg3_torch_utils.ops.fma import fma
+
+
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias=True, # Apply additive bias before the activation function?
+ activation='linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier=1, # Learning rate multiplier.
+ bias_init=0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.repr = dict(
+ in_features=in_features, out_features=out_features, bias=bias,
+ activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
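+        # Equalized learning rate: weights are stored roughly ~N(0, 1) and rescaled by lr_multiplier / sqrt(fan_in) at runtime.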
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+ self.in_features = in_features
+ self.out_features = out_features
+
+ def forward(self, x):
+ w = self.weight * self.weight_gain
+ b = self.bias
+ if b is not None and self.bias_gain != 1:
+ b = b * self.bias_gain
+ x = F.linear(x, w)
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self) -> str:
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
+
+
+class Conv2d(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size=3, # Convolution kernel size.
+ up=1, # Integer upsampling factor.
+ down=1, # Integer downsampling factor
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ bias=True,
+ norm=False,
+ lr_multiplier=1,
+ bias_init=0,
+ w_dim=None,
+ gain=1,
+ ):
+ super().__init__()
+ if norm:
+ self.norm = torch.nn.InstanceNorm2d(None)
+ assert norm in [True, False]
+ self.up = up
+ self.down = down
+ self.activation = activation
+ self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
+ self.out_channels = out_channels
+ self.in_channels = in_channels
+ self.padding = kernel_size // 2
+
+ self.repr = dict(
+ in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, up=up, down=down,
+ activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
+ )
+
+ if self.up == 1 and self.down == 1:
+ self.resample_filter = None
+ else:
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
+ self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
+ self.bias_gain = lr_multiplier
+ if w_dim is not None:
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
+
+ def forward(self, x, w=None, s=None):
+ tops.assert_shape(x, [None, self.weight.shape[1], None, None])
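+        # Feature-wise modulation: scale/shift the input with gamma/beta from precomputed styles s, or from an affine projection of w.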
+ if s is not None:
+ s = s[..., :self.in_channels * 2]
+ gamma, beta = s.view(-1, self.in_channels * 2, 1, 1).chunk(2, dim=1)
+ x = fma(x, gamma, beta)
+ elif hasattr(self, "affine"):
+ gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
+ beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
+ x = fma(x, gamma, beta)
+ w = self.weight * self.weight_gain
+ # Removing flip weight is not safe.
+ x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up,
+ self.down, self.padding, flip_weight=self.up == 1)
+ if hasattr(self, "norm"):
+ x = self.norm(x)
+ b = self.bias * self.bias_gain if self.bias is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
+ return x
+
+ def extra_repr(self) -> str:
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
+
+
+class Block(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ up=1,
+ down=1,
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.down = down
+ self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
+ self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)
+
+ def forward(self, x, **layer_kwargs):
+ x = self.conv0(x, **layer_kwargs)
+ x = self.conv1(x, **layer_kwargs)
+ return x
+
+
+class ResidualBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ up=1,
+ down=1,
+ gain_out=np.sqrt(0.5),
+ fix_residual: bool = False,
+ **layer_kwargs, # Arguments for conv layer.
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.down = down
+ self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
+
+ self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs)
+
+ self.skip = Conv2d(
+ in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
+ activation="linear" if fix_residual else "lrelu",
+ gain=gain_out
+ )
+ self.gain_out = gain_out
+
+ def forward(self, x, w=None, s=None, **layer_kwargs):
+ y = self.skip(x)
+ s_ = next(s) if s is not None else None
+ x = self.conv0(x, w, s=s_, **layer_kwargs)
+ s_ = next(s) if s is not None else None
+ x = self.conv1(x, w, s=s_, **layer_kwargs)
+ x = y + x
+ return x
+
+
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = x.reshape(G, -1, F, c, H, W)
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
+ return x
+
+
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ resolution: List[int], # Resolution of this block.
+ mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
+ num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2d(
+ in_channels + mbstd_num_channels, in_channels,
+ kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * resolution[0] * resolution[1], in_channels, activation=activation)
+ self.out = FullyConnectedLayer(in_channels, 1)
+
+ def forward(self, x):
+ tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW]
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+ return x
diff --git a/dp2/loss/__init__.py b/dp2/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..16cbdd0051ff51bda0828f37c9ef5faed65a9ed7
--- /dev/null
+++ b/dp2/loss/__init__.py
@@ -0,0 +1 @@
+from .sg2_loss import StyleGAN2Loss
diff --git a/dp2/loss/pl_regularization.py b/dp2/loss/pl_regularization.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf47b9b4ea9f9a22eebae8d0d4be755f75c3878c
--- /dev/null
+++ b/dp2/loss/pl_regularization.py
@@ -0,0 +1,49 @@
+import torch
+import tops
+import numpy as np
+from sg3_torch_utils.ops import conv2d_gradfix
+
+pl_mean_total = torch.zeros([])
+
+
+class PLRegularization:
+
+ def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs):
+ self.pl_mean = torch.zeros([], device=tops.get_device())
+ self.pl_weight = weight
+ self.batch_shrink = batch_shrink
+ self.pl_decay = pl_decay
+ self.scale_by_mask = scale_by_mask
+
+ def __call__(self, G, batch, grad_scaler):
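+        # Path-length regularization (StyleGAN2): penalize deviations of the Jacobian norm
+        # ||J_w^T y|| from its exponential moving average, computed on a shrunken batch.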
+ batch_size = batch["img"].shape[0] // self.batch_shrink
+ batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
+ if "embed_map" in batch:
+ batch["embed_map"] = batch["embed_map"]
+ z = G.get_z(batch["img"])
+
+ with torch.cuda.amp.autocast(tops.AMP()):
+ gen_ws = G.style_net(z)
+ gen_img = G(**batch, w=gen_ws)["img"].float()
+ pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
+ with conv2d_gradfix.no_weight_gradients():
+ # Sums over HWC
+ pl_grads = torch.autograd.grad(
+ outputs=[grad_scaler.scale(gen_img * pl_noise)],
+ inputs=[gen_ws],
+ create_graph=True,
+ grad_outputs=torch.ones_like(gen_img),
+ only_inputs=True)[0]
+
+ pl_grads = pl_grads.float() / grad_scaler.get_scale()
+ if self.scale_by_mask:
+ # Percentage of pixels known
+ scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
+ pl_grads = pl_grads / scaling
+ pl_lengths = pl_grads.square().sum(1).sqrt()
+ pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
+ if not torch.isnan(pl_mean).any():
+ self.pl_mean.copy_(pl_mean.detach())
+ pl_penalty = (pl_lengths - pl_mean).square()
+ to_log = dict(pl_penalty=pl_penalty.mean().detach())
+ return pl_penalty.view(-1) * self.pl_weight, to_log
diff --git a/dp2/loss/r1_regularization.py b/dp2/loss/r1_regularization.py
new file mode 100644
index 0000000000000000000000000000000000000000..f974c5542bf49ed36b54b46cfc7c9c9bfaff9ce3
--- /dev/null
+++ b/dp2/loss/r1_regularization.py
@@ -0,0 +1,32 @@
+import torch
+import tops
+
+
+def r1_regularization(
+ real_img, real_score, mask, lambd: float, lazy_reg_interval: int,
+ lazy_regularization: bool,
+ scaler: torch.cuda.amp.GradScaler, mask_out: bool,
+ mask_out_scale: bool,
+ **kwargs
+):
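+    # R1 penalty (Mescheder et al. 2018): squared norm of the gradient of the real score w.r.t.
+    # the real image. The score is scaled before autograd and unscaled afterwards to be AMP-safe.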
+ grad = torch.autograd.grad(
+ outputs=scaler.scale(real_score),
+ inputs=real_img,
+ grad_outputs=torch.ones_like(real_score),
+ create_graph=True,
+ only_inputs=True,
+ )[0]
+ inv_scale = 1.0 / scaler.get_scale()
+ grad = grad * inv_scale
+ with torch.cuda.amp.autocast(tops.AMP()):
+ if mask_out:
+ grad = grad * (1 - mask)
+ grad = grad.square().sum(dim=[1, 2, 3])
+ if mask_out and mask_out_scale:
+ total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3]
+ n_fake = (1-mask).sum(dim=[1, 2, 3])
+ scaling = total_pixels / n_fake
+ grad = grad * scaling
+    if lazy_regularization:
+        lambd_ = lambd * lazy_reg_interval / 2  # From stylegan2, lazy regularization
+    else:
+        lambd_ = lambd
+    return grad * lambd_, grad.detach()
diff --git a/dp2/loss/sg2_loss.py b/dp2/loss/sg2_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763263e2e7cb9330f24265ba8008e152fa4110f0
--- /dev/null
+++ b/dp2/loss/sg2_loss.py
@@ -0,0 +1,96 @@
+import functools
+import torch
+import tops
+from tops import logger
+from dp2.utils import forward_D_fake
+from .utils import nsgan_d_loss, nsgan_g_loss
+from .r1_regularization import r1_regularization
+from .pl_regularization import PLRegularization
+
+
+class StyleGAN2Loss:
+
+ def __init__(
+ self,
+ D,
+ G,
+ r1_opts: dict,
+ EP_lambd: float,
+ lazy_reg_interval: int,
+ lazy_regularization: bool,
+ pl_reg_opts: dict,
+ ) -> None:
+ self.gradient_step_D = 0
+ self._lazy_reg_interval = lazy_reg_interval
+ self.D = D
+ self.G = G
+ self.EP_lambd = EP_lambd
+ self.lazy_regularization = lazy_regularization
+ self.r1_reg = functools.partial(
+ r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
+ lazy_regularization=lazy_regularization)
+ self.do_PL_Reg = False
+ if pl_reg_opts.weight > 0:
+ self.pl_reg = PLRegularization(**pl_reg_opts)
+ self.do_PL_Reg = True
+ self.pl_start_nimg = pl_reg_opts.start_nimg
+
+ def D_loss(self, batch: dict, grad_scaler):
+ to_log = {}
+ # Forward through G and D
+ do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
+ if do_GP:
+ batch["img"] = batch["img"].detach().requires_grad_(True)
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
+ with torch.no_grad():
+ G_fake = self.G(**batch, update_emas=True)
+ D_out_real = self.D(**batch)
+
+ D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
+
+ # Non saturating loss
+ nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
+ tops.assert_shape(nsgan_loss, (batch["img"].shape[0], ))
+ to_log["d_loss"] = nsgan_loss.mean()
+ total_loss = nsgan_loss
+ epsilon_penalty = D_out_real["score"].pow(2).view(-1)
+ to_log["epsilon_penalty"] = epsilon_penalty.mean()
+ tops.assert_shape(epsilon_penalty, total_loss.shape)
+ total_loss = total_loss + epsilon_penalty * self.EP_lambd
+
+ # Improved gradient penalty with lazy regularization
+ # Gradient penalty applies specialized autocast.
+ if do_GP:
+ gradient_pen, grad_unscaled = self.r1_reg(
+ batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
+ to_log["r1_gradient_penalty"] = grad_unscaled.mean()
+ tops.assert_shape(gradient_pen, total_loss.shape)
+ total_loss = total_loss + gradient_pen
+
+ batch["img"] = batch["img"].detach().requires_grad_(False)
+ if "score" in D_out_real:
+ to_log["real_scores"] = D_out_real["score"]
+ to_log["real_logits_sign"] = D_out_real["score"].sign()
+ to_log["fake_logits_sign"] = D_out_fake["score"].sign()
+ to_log["fake_scores"] = D_out_fake["score"]
+ to_log = {key: item.mean().detach() for key, item in to_log.items()}
+ self.gradient_step_D += 1
+ return total_loss.mean(), to_log
+
+ def G_loss(self, batch: dict, grad_scaler):
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
+ to_log = {}
+ # Forward through G and D
+ G_fake = self.G(**batch)
+ D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
+ # Adversarial Loss
+ total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
+ to_log["g_loss"] = total_loss.mean()
+ tops.assert_shape(total_loss, (batch["img"].shape[0], ))
+
+ if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
+ pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
+ total_loss = total_loss + pl_reg.mean()
+ to_log.update(to_log_)
+ to_log = {key: item.mean().detach() for key, item in to_log.items()}
+ return total_loss.mean(), to_log
diff --git a/dp2/loss/utils.py b/dp2/loss/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d6e19c3a0c4718412e6d83e3405c73029275f35
--- /dev/null
+++ b/dp2/loss/utils.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn.functional as F
+
+
+def nsgan_g_loss(fake_score):
+ """
+ Non-saturating criterion from Goodfellow et al. 2014
+ """
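+    # softplus(-D(G(z))) == -log sigmoid(D(G(z)))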
+ return torch.nn.functional.softplus(-fake_score)
+
+
+def nsgan_d_loss(real_score, fake_score):
+ """
+ Non-saturating criterion from Goodfellow et al. 2014
+ """
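+    # softplus(-D(x)) + softplus(D(G(z))) == -log sigmoid(D(x)) - log(1 - sigmoid(D(G(z))))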
+ d_loss = F.softplus(-real_score) + F.softplus(fake_score)
+ return d_loss.view(-1)
+
+
+def smooth_masked_l1_loss(x, target, mask):
+ """
+ Pixel-wise l1 loss for the area indicated by mask
+ """
+    # beta=.1 <-> squared loss when the absolute pixel difference is <= .1 (about 12.8 on a 0-255 scale)
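+    # The per-image sum is normalized by the number of pixels where mask == 1.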
+ l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1, 2, 3]) / mask.sum(dim=[1, 2, 3])
+ return l1
diff --git a/dp2/metrics/__init__.py b/dp2/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d8bbb1ddd31e4453ad3663805f9aa06c077cf21
--- /dev/null
+++ b/dp2/metrics/__init__.py
@@ -0,0 +1,3 @@
+from .torch_metrics import compute_metrics_iteratively
+from .fid import compute_fid
+from .ppl import calculate_ppl
diff --git a/dp2/metrics/fid.py b/dp2/metrics/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c8e0e6171fe723aa9f37735a216ec8c9d4e4a8f
--- /dev/null
+++ b/dp2/metrics/fid.py
@@ -0,0 +1,52 @@
+import tops
+from dp2 import utils
+from pathlib import Path
+from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
+import torch
+import torch_fidelity
+
+
+class GeneratorIteratorWrapper(GenerativeModelModuleWrapper):
+
+ def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int):
+ if isinstance(generator, utils.EMA):
+ generator = generator.generator
+ z_size = generator.z_channels
+ super().__init__(generator, z_size, "normal", 0)
+ self.zero_z = zero_z
+ self.dataloader = iter(dataloader)
+ self.n_diverse = n_diverse
+ self.cur_div_idx = 0
+
+ @torch.no_grad()
+ def forward(self, z, **kwargs):
+ if self.cur_div_idx == 0:
+ self.batch = next(self.dataloader)
+ if self.zero_z:
+ z = z.zero_()
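+        # Old checkpoints index blocks by position (basic_block0, basic_block1, ...); map those keys to the resolution-based names used here.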
+ self.cur_div_idx += 1
+ self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
+ img = self.module(**self.batch)["img"]
+ img = (utils.denormalize_img(img)*255).byte()
+ return img
+
+
+def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse):
+ generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse)
+ batch_size = dataloader.batch_size
+ num_samples = (n_source * n_diverse) // batch_size * batch_size
+ assert n_diverse >= 1
+ assert (not zero_z) or n_diverse == 1
+ assert num_samples % batch_size == 0
+ assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse)
+ metrics = torch_fidelity.calculate_metrics(
+ input1=generator,
+ input2=real_directory,
+ cuda=torch.cuda.is_available(),
+ fid=True,
+ input2_cache_name="_".join(Path(real_directory).parts) + "_cached",
+ input1_model_num_samples=int(num_samples),
+ batch_size=dataloader.batch_size
+ )
+ return metrics["frechet_inception_distance"]
diff --git a/dp2/metrics/fid_clip.py b/dp2/metrics/fid_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bde1bf74c69399308ed15ceda5aaeb59a69818
--- /dev/null
+++ b/dp2/metrics/fid_clip.py
@@ -0,0 +1,84 @@
+import pickle
+import torch
+import torchvision
+from pathlib import Path
+from dp2 import utils
+import tops
+try:
+ import clip
+except ImportError:
+ print("Could not import clip.")
+from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
+clip_model = None
+clip_preprocess = None
+
+
+@torch.no_grad()
+def compute_fid_clip(
+ dataloader, generator,
+ cache_directory,
+ data_len=None,
+ **kwargs
+ ) -> dict:
+    """
+    FID-CLIP, following "The Role of ImageNet Classes in Frechet Inception Distance",
+    Kynkäänniemi et al. 2022.
+    Args:
+        cache_directory: directory used to cache the FID statistics of the real data.
+        data_len (int): number of samples to evaluate. Defaults to len(dataloader) * batch_size.
+    """
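+    # Identical to standard FID, except that features are 512-dimensional CLIP ViT-B/32 image
+    # embeddings instead of InceptionV3 activations.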
+ global clip_model, clip_preprocess
+ if clip_model is None:
+ clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
+ normalize_fn = preprocess.transforms[-1]
+ img_mean = normalize_fn.mean
+ img_std = normalize_fn.std
+ clip_model = tops.to_cuda(clip_model.visual)
+ clip_preprocess = tops.to_cuda(torch.nn.Sequential(
+ torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+ torchvision.transforms.Normalize(img_mean, img_std)
+ ))
+ cache_directory = Path(cache_directory)
+ if data_len is None:
+ data_len = len(dataloader)*dataloader.batch_size
+ fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
+ has_fid_cache = fid_cache_path.is_file()
+ if not has_fid_cache:
+ fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
+ fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
+ eidx = 0
+ n_samples_seen = 0
+ for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
+ sidx = eidx
+ eidx = sidx + batch["img"].shape[0]
+ n_samples_seen += batch["img"].shape[0]
+ with torch.cuda.amp.autocast(tops.AMP()):
+ fakes = generator(**batch)["img"]
+ real_data = batch["img"]
+ fakes = utils.denormalize_img(fakes)
+ real_data = utils.denormalize_img(real_data)
+ if not has_fid_cache:
+ real_data = clip_preprocess(real_data)
+ fid_features_real[sidx:eidx] = clip_model(real_data)
+ fakes = clip_preprocess(fakes)
+ fid_features_fake[sidx:eidx] = clip_model(fakes)
+ fid_features_fake = fid_features_fake[:n_samples_seen]
+ fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
+ if has_fid_cache:
+ if tops.rank() == 0:
+ with open(fid_cache_path, "rb") as fp:
+ fid_stat_real = pickle.load(fp)
+ else:
+ fid_features_real = fid_features_real[:n_samples_seen]
+ fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
+ assert fid_features_real.shape == fid_features_fake.shape
+ if tops.rank() == 0:
+ fid_stat_real = fid_features_to_statistics(fid_features_real)
+ cache_directory.mkdir(exist_ok=True, parents=True)
+ with open(fid_cache_path, "wb") as fp:
+ pickle.dump(fid_stat_real, fp)
+
+ if tops.rank() == 0:
+ print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
+ fid_stat_fake = fid_features_to_statistics(fid_features_fake)
+ fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
+ return dict(fid_clip=fid_)
+ return dict(fid_clip=-1)
diff --git a/dp2/metrics/lpips.py b/dp2/metrics/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..397d1b12cd6952aafb6929bc3fa33f39ba509e33
--- /dev/null
+++ b/dp2/metrics/lpips.py
@@ -0,0 +1,77 @@
+import torch
+import tops
+import sys
+from contextlib import redirect_stdout
+from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
+
+
+class SampleSimilarityLPIPS(torch.nn.Module):
+ SUPPORTED_DTYPES = {
+ 'uint8': torch.uint8,
+ 'float32': torch.float32,
+ }
+
+ def __init__(self):
+
+ super().__init__()
+ self.chns = [64, 128, 256, 512, 512]
+ self.L = len(self.chns)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ with redirect_stdout(sys.stderr):
+ fp = tops.download_file(URL_VGG16_LPIPS)
+ state_dict = torch.load(fp, map_location="cpu")
+ self.load_state_dict(state_dict)
+ self.net = VGG16features()
+ self.eval()
+ for param in self.parameters():
+ param.requires_grad = False
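+        # LPIPS normalization constants, rescaled so the network accepts images in [0, 255].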
+ mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
+ inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
+ self.register_buffer("mean", mean_rescaled)
+ self.register_buffer("std", inv_std_rescaled)
+
+ def normalize(self, x):
+ # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
+ x = (x.float() - self.mean) * self.std
+ return x
+
+ @staticmethod
+ def resize(x, size):
+ if x.shape[-1] > size and x.shape[-2] > size:
+ x = torch.nn.functional.interpolate(x, (size, size), mode='area')
+ else:
+ x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
+ return x
+
+ def lpips_from_feats(self, feats0, feats1):
+ diffs = {}
+ for kk in range(self.L):
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
+ val = sum(res)
+ return val
+
+ def get_feats(self, x):
+ assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
+ if x.shape[-2] < 16 or x.shape[-1] < 16: # Resize images < 16x16
+ f = 2
+ size = tuple([int(f*_) for _ in x.shape[-2:]])
+ x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
+ in0_input = self.normalize(x)
+ outs0 = self.net.forward(in0_input)
+
+ feats = {}
+ for kk in range(self.L):
+ feats[kk] = normalize_tensor(outs0[kk])
+ return feats
+
+ def forward(self, in0, in1):
+ feats0 = self.get_feats(in0)
+ feats1 = self.get_feats(in1)
+ return self.lpips_from_feats(feats0, feats1), feats0, feats1
diff --git a/dp2/metrics/ppl.py b/dp2/metrics/ppl.py
new file mode 100644
index 0000000000000000000000000000000000000000..421aeafc5edc4647037fdc390737b269cdfbeae5
--- /dev/null
+++ b/dp2/metrics/ppl.py
@@ -0,0 +1,116 @@
+import numpy as np
+import torch
+import tops
+from dp2 import utils
+from torch_fidelity.helpers import get_kwarg, vassert
+from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
+from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
+from torchvision.transforms.functional import resize
+
+
+def slerp(a, b, t):
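+    # Spherical linear interpolation between (normalized) a and b with interpolation factor t.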
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+
+@torch.no_grad()
+def calculate_ppl(
+ dataloader,
+ generator,
+ latent_space=None,
+ data_len=None,
+ upsample_size=None,
+ **kwargs) -> dict:
+ """
+ Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
+ """
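+    # PPL: LPIPS distance between images generated from latent pairs that are epsilon apart,
+    # divided by epsilon^2 and averaged over the dataset (with optional percentile trimming).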
+ if latent_space is None:
+ latent_space = generator.latent_space
+ assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
+ assert len(upsample_size) == 2
+ epsilon = PPL_DEFAULTS["ppl_epsilon"]
+ interp = PPL_DEFAULTS['ppl_z_interp_mode']
+ similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
+ sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
+ sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
+ discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
+ discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
+
+ vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
+ vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
+ vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
+ if discard_percentile_lower is not None and discard_percentile_higher is not None:
+ vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')
+
+ sample_similarity = create_sample_similarity(
+ similarity_name,
+ sample_similarity_resize=sample_similarity_resize,
+ sample_similarity_dtype=sample_similarity_dtype,
+ cuda=False,
+ **kwargs
+ )
+ sample_similarity = tops.to_cuda(sample_similarity)
+ rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
+ distances = []
+ if data_len is None:
+ data_len = len(dataloader) * dataloader.batch_size
+ z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
+ z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
+ if latent_space == "Z":
+ z1 = batch_interp(z0, z1, epsilon, interp)
+ print("Computing PPL IN", latent_space)
+ distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
+ print(distances.shape)
+ end = 0
+ n_samples = 0
+ for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")):
+ start = end
+ end = start + batch["img"].shape[0]
+ n_samples += batch["img"].shape[0]
+ batch_lat_e0 = tops.to_cuda(z0[start:end])
+ batch_lat_e1 = tops.to_cuda(z1[start:end])
+ if latent_space == "W":
+ w0 = generator.get_w(batch_lat_e0, update_emas=False)
+ w1 = generator.get_w(batch_lat_e1, update_emas=False)
+ w1 = w0.lerp(w1, epsilon) # PPL end
+ rgb1 = generator(**batch, w=w0)["img"]
+ rgb2 = generator(**batch, w=w1)["img"]
+ else:
+ rgb1 = generator(**batch, z=batch_lat_e0)["img"]
+ rgb2 = generator(**batch, z=batch_lat_e1)["img"]
+ if rgb1.shape[-2] < upsample_size[0] or rgb1.shape[-1] < upsample_size[1]:
+ rgb1 = resize(rgb1, upsample_size, antialias=True)
+ rgb2 = resize(rgb2, upsample_size, antialias=True)
+ rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
+ rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
+
+ sim = sample_similarity(rgb1, rgb2)
+ dist_lat_e01 = sim / (epsilon ** 2)
+ distances[start:end] = dist_lat_e01.view(-1)
+ distances = distances[:n_samples]
+ distances = tops.all_gather_uneven(distances).cpu().numpy()
+ if tops.rank() != 0:
+ return {"ppl/mean": -1, "ppl/std": -1}
+ if tops.rank() == 0:
+ cond, lo, hi = None, None, None
+ if discard_percentile_lower is not None:
+ lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
+ cond = lo <= distances
+ if discard_percentile_higher is not None:
+ hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
+ cond = np.logical_and(cond, distances <= hi)
+ if cond is not None:
+ distances = np.extract(cond, distances)
+ return {
+ "ppl/mean": float(np.mean(distances)),
+ "ppl/std": float(np.std(distances)),
+ }
diff --git a/dp2/metrics/torch_metrics.py b/dp2/metrics/torch_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c8747b01a98f3aa4df9161e09c309c101b396f2
--- /dev/null
+++ b/dp2/metrics/torch_metrics.py
@@ -0,0 +1,177 @@
+import pickle
+import numpy as np
+import torch
+import time
+from pathlib import Path
+from dp2 import utils
+import tops
+from .lpips import SampleSimilarityLPIPS
+from torch_fidelity.defaults import DEFAULTS as trf_defaults
+from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
+from torch_fidelity.utils import create_feature_extractor
+lpips_model = None
+fid_model = None
+
+
+@torch.no_grad()
+def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+ se = (images1 - images2) ** 2
+ se = se.view(images1.shape[0], -1).mean(dim=1)
+ return se
+
+
+@torch.no_grad()
+def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+ mse_ = mse(images1, images2)
+ psnr = 10 * torch.log10(1 / mse_)
+ return psnr
+
+
+@torch.no_grad()
+def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+ return _lpips_w_grad(images1, images2)
+
+
+def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+ global lpips_model
+ if lpips_model is None:
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+
+ images1 = images1.mul(255)
+ images2 = images2.mul(255)
+ with torch.cuda.amp.autocast(tops.AMP()):
+ dists = lpips_model(images1, images2)[0].view(-1)
+ return dists
+
+
+@torch.no_grad()
+def compute_metrics_iteratively(
+ dataloader, generator,
+ cache_directory,
+ data_len=None,
+ truncation_value: float = None,
+) -> dict:
+    """
+    Computes FID, LPIPS and LPIPS diversity over the dataloader.
+    Args:
+        cache_directory: directory used to cache the FID statistics of the real data.
+        data_len (int): number of samples to evaluate. Defaults to len(dataloader) * batch_size.
+        truncation_value (float): optional truncation value used when sampling from the generator.
+    """
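+    # LPIPS is averaged over the real image vs. two independently sampled fakes; LPIPS diversity
+    # is measured between the two fakes. FID statistics of the real data are cached to disk.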
+
+ global lpips_model, fid_model
+ if lpips_model is None:
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+ if fid_model is None:
+ fid_model = create_feature_extractor(
+ trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
+ fid_model = tops.to_cuda(fid_model)
+ cache_directory = Path(cache_directory)
+ start_time = time.time()
+ lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
+ diversity_total = torch.zeros_like(lpips_total)
+ fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
+ has_fid_cache = fid_cache_path.is_file()
+ if data_len is None:
+ data_len = len(dataloader)*dataloader.batch_size
+ if not has_fid_cache:
+ fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
+ fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
+ n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
+ eidx = 0
+ for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
+ sidx = eidx
+ eidx = sidx + batch["img"].shape[0]
+ n_samples_seen += batch["img"].shape[0]
+ with torch.cuda.amp.autocast(tops.AMP()):
+ fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+ fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+ fakes1 = utils.denormalize_img(fakes1).mul(255)
+ fakes2 = utils.denormalize_img(fakes2).mul(255)
+ real_data = utils.denormalize_img(batch["img"]).mul(255)
+ lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
+ fake2_lpips_feats = lpips_model.get_feats(fakes2)
+ lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
+
+ lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
+ diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
+ if not has_fid_cache:
+ fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
+ fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
+ fid_features_fake = fid_features_fake[:n_samples_seen]
+ if has_fid_cache:
+ if tops.rank() == 0:
+ with open(fid_cache_path, "rb") as fp:
+ fid_stat_real = pickle.load(fp)
+ else:
+ fid_features_real = fid_features_real[:n_samples_seen]
+ fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
+ if tops.rank() == 0:
+ fid_stat_real = fid_features_to_statistics(fid_features_real)
+ cache_directory.mkdir(exist_ok=True, parents=True)
+ with open(fid_cache_path, "wb") as fp:
+ pickle.dump(fid_stat_real, fp)
+ fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
+ if tops.rank() == 0:
+ print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
+ fid_stat_fake = fid_features_to_statistics(fid_features_fake)
+ fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
+ tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
+ tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
+ tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
+ lpips_total = lpips_total / n_samples_seen
+ diversity_total = diversity_total / n_samples_seen
+ to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
+ if tops.rank() == 0:
+ to_return["fid"] = fid_
+ else:
+ to_return["fid"] = -1
+ to_return["validation_time_s"] = time.time() - start_time
+ return to_return
+
+
+@torch.no_grad()
+def compute_lpips(
+ dataloader, generator,
+ truncation_value: float = None,
+ data_len=None,
+ ) -> dict:
+    """
+    Computes LPIPS and LPIPS diversity over the dataloader.
+    Args:
+        truncation_value (float): optional truncation value used when sampling from the generator.
+        data_len (int): number of samples to evaluate. Defaults to len(dataloader) * batch_size.
+    """
+ global lpips_model, fid_model
+ if lpips_model is None:
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+ start_time = time.time()
+ lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
+ diversity_total = torch.zeros_like(lpips_total)
+ if data_len is None:
+ data_len = len(dataloader) * dataloader.batch_size
+ eidx = 0
+ n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
+ for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
+ sidx = eidx
+ eidx = sidx + batch["img"].shape[0]
+ n_samples_seen += batch["img"].shape[0]
+ with torch.cuda.amp.autocast(tops.AMP()):
+ fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+ fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+ real_data = batch["img"]
+ fakes1 = utils.denormalize_img(fakes1).mul(255)
+ fakes2 = utils.denormalize_img(fakes2).mul(255)
+ real_data = utils.denormalize_img(real_data).mul(255)
+ lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
+ fake2_lpips_feats = lpips_model.get_feats(fakes2)
+ lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
+
+ lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
+ diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
+ tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
+ tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
+ tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
+ lpips_total = lpips_total / n_samples_seen
+ diversity_total = diversity_total / n_samples_seen
+ to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
+ to_return = {k: v.cpu().item() for k, v in to_return.items()}
+ to_return["validation_time_s"] = time.time() - start_time
+ return to_return
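+
+
+# Example usage (illustrative sketch; assumes `dataloader` and `generator` are already
+# instantiated from a project config):
+#   metrics = compute_lpips(dataloader, generator, truncation_value=None)
+#   print(metrics["lpips"], metrics["lpips_diversity"], metrics["validation_time_s"])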
diff --git a/dp2/utils/__init__.py b/dp2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4edbacb6e8032ea081839f1a2408d4101868e79
--- /dev/null
+++ b/dp2/utils/__init__.py
@@ -0,0 +1,24 @@
+import pathlib
+from tops.config import LazyConfig
+from .torch_utils import (
+ im2torch, im2numpy, denormalize_img, set_requires_grad, forward_D_fake,
+ binary_dilation, crop_box, remove_pad,
+ torch_wasserstein_loss
+)
+from .ema import EMA
+from .utils import init_tops, tqdm_, print_config, config_to_str, trange_
+from .cse import from_E_to_vertex
+
+
+def load_config(config_path):
+ config_path = pathlib.Path(config_path)
+ assert config_path.is_file(), config_path
+ cfg = LazyConfig.load(str(config_path))
+ cfg.output_dir = pathlib.Path(str(config_path).replace("configs", str(cfg.common.output_dir)).replace(".py", ""))
+ if cfg.common.experiment_name is None:
+ cfg.experiment_name = str(config_path)
+ else:
+ cfg.experiment_name = cfg.common.experiment_name
+ cfg.checkpoint_dir = cfg.output_dir.joinpath("checkpoints")
+ print("Saving outputs to:", cfg.output_dir)
+ return cfg
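+
+
+# Example usage (illustrative sketch):
+#   cfg = load_config("configs/anonymizers/face.py")
+#   anonymizer = instantiate(cfg.anonymizer)  # instantiate from tops.config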
diff --git a/dp2/utils/bufferless_video_capture.py b/dp2/utils/bufferless_video_capture.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5e1006057706f32c6adaeb812bf4834bbdfd28
--- /dev/null
+++ b/dp2/utils/bufferless_video_capture.py
@@ -0,0 +1,32 @@
+import queue
+import threading
+import cv2
+
+
+class BufferlessVideoCapture:
+
+ def __init__(self, name, width=None, height=None):
+ self.cap = cv2.VideoCapture(name)
+ if width is not None and height is not None:
+ self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
+ self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+ self.q = queue.Queue()
+ t = threading.Thread(target=self._reader)
+ t.daemon = True
+ t.start()
+
+ # read frames as soon as they are available, keeping only most recent one
+ def _reader(self):
+ while True:
+ ret, frame = self.cap.read()
+ if not ret:
+ break
+ if not self.q.empty():
+ try:
+ self.q.get_nowait() # discard previous (unprocessed) frame
+ except queue.Empty:
+ pass
+ self.q.put((ret, frame))
+
+ def read(self):
+ return self.q.get()
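+
+
+# Example usage (illustrative sketch; assumes device 0 is a webcam):
+#   cap = BufferlessVideoCapture(0, width=1280, height=720)
+#   ret, frame = cap.read()  # always returns the most recently captured frame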
diff --git a/dp2/utils/cse.py b/dp2/utils/cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3e01d28ba10e6d4d14ecd49c3a70dcaaa194ce
--- /dev/null
+++ b/dp2/utils/cse.py
@@ -0,0 +1,21 @@
+import warnings
+import torch
+from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES
+
+
+def from_E_to_vertex(E, M, embed_map):
+ """
+    M is 1 for unknown regions
+ """
+ assert len(E.shape) == 4
+ assert len(E.shape) == len(M.shape), (E.shape, M.shape)
+ assert E.shape[0] == 1
+ M = M.float()
+ M = torch.cat([M, 1-M], dim=1)
+ with warnings.catch_warnings(): # Ignore userError for pytorch interpolate from detectron2
+ warnings.filterwarnings("ignore")
+ vertices, _ = get_closest_vertices_mask_from_ES(
+ E, M, E.shape[2], E.shape[3],
+ embed_map, device=E.device)
+
+ return vertices.long()
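+
+
+# Example (illustrative): for an embedding E of shape (1, D, H, W) and a mask M of
+# shape (1, 1, H, W), the result is an (H, W) LongTensor of vertex indices into embed_map.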
diff --git a/dp2/utils/ema.py b/dp2/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..475e6b5192575ad5a54541714b6c932227cbe7a3
--- /dev/null
+++ b/dp2/utils/ema.py
@@ -0,0 +1,80 @@
+import torch
+import copy
+import tops
+from tops import logger
+from .torch_utils import set_requires_grad
+
+
+class EMA:
+ """
+    Exponential moving average.
+ See:
+ Yazici, Y. et al.The unusual effectiveness of averaging in GAN training. ICLR 2019
+
+ """
+
+ def __init__(
+ self,
+ generator: torch.nn.Module,
+ batch_size: int,
+ rampup: float,
+ ):
+ self.rampup = rampup
+ self._nimg_half_time = batch_size * 10 / 32 * 1000
+ self._batch_size = batch_size
+ with torch.no_grad():
+ self.generator = copy.deepcopy(generator.cpu()).eval()
+ self.generator = tops.to_cuda(self.generator)
+ self.old_ra_beta = 0
+ set_requires_grad(self.generator, False)
+
+ def update_beta(self):
+ y = self._nimg_half_time
+ global_step = logger.global_step()
+        if self.rampup is not None:
+ y = min(y, global_step*self.rampup)
+ self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8))
+ if self.ra_beta != self.old_ra_beta:
+ logger.add_scalar("stats/EMA_beta", self.ra_beta)
+ self.old_ra_beta = self.ra_beta
+
+ @torch.no_grad()
+ def update(self, normal_G):
+ with torch.autograd.profiler.record_function("EMA_update"):
+ for ema_p, p in zip(self.generator.parameters(),
+ normal_G.parameters()):
+ ema_p.copy_(p.lerp(ema_p, self.ra_beta))
+ for ema_buf, buff in zip(self.generator.buffers(),
+ normal_G.buffers()):
+ ema_buf.copy_(buff)
+
+ def __call__(self, *args, **kwargs):
+ return self.generator(*args, **kwargs)
+
+ def __getattr__(self, name: str):
+ if hasattr(self.generator, name):
+ return getattr(self.generator, name)
+ raise AttributeError(f"Generator object has no attribute {name}")
+
+ def cuda(self, *args, **kwargs):
+ self.generator = self.generator.cuda()
+ return self
+
+ def state_dict(self, *args, **kwargs):
+ return self.generator.state_dict(*args, **kwargs)
+
+ def load_state_dict(self, *args, **kwargs):
+ return self.generator.load_state_dict(*args, **kwargs)
+
+ def eval(self):
+ self.generator.eval()
+
+ def train(self):
+ self.generator.train()
+
+ @property
+ def module(self):
+ return self.generator.module
+
+ def sample(self, *args, **kwargs):
+ return self.generator.sample(*args, **kwargs)
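+
+
+# Example usage (illustrative sketch; `G` is assumed to be the training generator):
+#   ema = EMA(G, batch_size=32, rampup=0.05)
+#   ema.update_beta()  # recompute the decay from the current global step
+#   ema.update(G)      # blend the current weights into the EMA copy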
diff --git a/dp2/utils/torch_utils.py b/dp2/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ab53e3dcedce4710a41d1d58bd292a9fa08432
--- /dev/null
+++ b/dp2/utils/torch_utils.py
@@ -0,0 +1,140 @@
+import torch
+import tops
+
+
+def denormalize_img(image, mean=0.5, std=0.5):
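+    # Maps images normalized with mean=0.5, std=0.5 (i.e. in [-1, 1]) back to [0, 1],
+    # rounded to the nearest 8-bit level.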
+ image = image * std + mean
+ image = torch.clamp(image.float(), 0, 1)
+ image = (image * 255)
+ image = torch.round(image)
+ return image / 255
+
+
+@torch.no_grad()
+def im2numpy(images, to_uint8=False, denormalize=False):
+ if denormalize:
+ images = denormalize_img(images)
+ if images.dtype != torch.uint8:
+ images = images.clamp(0, 1)
+ return tops.im2numpy(images, to_uint8=to_uint8)
+
+
+@torch.no_grad()
+def im2torch(im, cuda=False, normalize=True, to_float=True):
+ im = tops.im2torch(im, cuda=cuda, to_float=to_float)
+    if normalize:
+        assert im.min() >= 0.0 and im.max() <= 1.0
+        im = im * 2 - 1
+ return im
+
+
+@torch.no_grad()
+def binary_dilation(im: torch.Tensor, kernel: torch.Tensor):
+ assert len(im.shape) == 4
+ assert len(kernel.shape) == 2
+ kernel = kernel.unsqueeze(0).unsqueeze(0)
+ padding = kernel.shape[-1]//2
+ assert kernel.shape[-1] % 2 != 0
+ if isinstance(im, torch.cuda.FloatTensor):
+ im, kernel = im.half(), kernel.half()
+ else:
+ im, kernel = im.float(), kernel.float()
+ im = torch.nn.functional.conv2d(
+ im, kernel, groups=im.shape[1], padding=padding)
+ im = im > 0.5
+ return im
+
+
+@torch.no_grad()
+def binary_erosion(im: torch.Tensor, kernel: torch.Tensor):
+ assert len(im.shape) == 4
+ assert len(kernel.shape) == 2
+ kernel = kernel.unsqueeze(0).unsqueeze(0)
+ padding = kernel.shape[-1]//2
+ assert kernel.shape[-1] % 2 != 0
+ if isinstance(im, torch.cuda.FloatTensor):
+ im, kernel = im.half(), kernel.half()
+ else:
+ im, kernel = im.float(), kernel.float()
+ ksum = kernel.sum()
+ padding = (padding, padding, padding, padding)
+ im = torch.nn.functional.pad(im, padding, mode="reflect")
+ im = torch.nn.functional.conv2d(
+ im, kernel, groups=im.shape[1])
+ return im.round() == ksum
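+
+
+# Example usage (illustrative sketch): with a 3x3 all-ones kernel (as used by
+# vis_utils.draw_mask), a boolean mask of shape (N, C, H, W) is grown or shrunk by one pixel:
+#   kernel = torch.ones((3, 3))
+#   dilated = binary_dilation(mask, kernel)
+#   eroded = binary_erosion(mask, kernel)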
+
+
+def set_requires_grad(value: torch.nn.Module, requires_grad: bool):
+ if isinstance(value, (list, tuple)):
+ for param in value:
+ param.requires_grad = requires_grad
+ return
+ for p in value.parameters():
+ p.requires_grad = requires_grad
+
+
+def forward_D_fake(batch, fake_img, discriminator, **kwargs):
+ fake_batch = {k: v for k, v in batch.items() if k != "img"}
+ fake_batch["img"] = fake_img
+ return discriminator(**fake_batch, **kwargs)
+
+
+def remove_pad(x: torch.Tensor, bbox_XYXY, imshape):
+ """
+    Remove the part of a crop that corresponds to padding, i.e. where bbox_XYXY has
+    negative or out-of-bounds coordinates.
+ """
+ H, W = imshape
+ x0, y0, x1, y1 = bbox_XYXY
+ padding = [
+ max(0, -x0),
+ max(0, -y0),
+ max(x1 - W, 0),
+ max(y1 - H, 0)
+ ]
+ x0, y0 = padding[:2]
+ x1 = x.shape[2] - padding[2]
+ y1 = x.shape[1] - padding[3]
+ return x[:, y0:y1, x0:x1]
+
+
+def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor:
+ """
+ Crops x by bbox_XYXY.
+ """
+ x0, y0, x1, y1 = bbox_XYXY
+ x0 = max(x0, 0)
+ y0 = max(y0, 0)
+ x1 = min(x1, x.shape[-1])
+ y1 = min(y1, x.shape[-2])
+ return x[..., y0:y1, x0:x1]
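+
+
+# Example (illustrative sketch; `padded_crop` is a hypothetical tensor covering the full
+# box, including regions outside the image):
+#   visible_crop = crop_box(im, box)                     # clipped to the image bounds
+#   visible_part = remove_pad(padded_crop, box, im.shape[-2:])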
+
+
+def torch_wasserstein_loss(tensor_a, tensor_b):
+ # Compute the first Wasserstein distance between two 1D distributions.
+ return (torch_cdf_loss(tensor_a, tensor_b, p=1))
+
+
+def torch_cdf_loss(tensor_a, tensor_b, p=1):
+ # last-dimension is weight distribution
+ # p is the norm of the distance, p=1 --> First Wasserstein Distance
+ # to get a positive weight with our normalized distribution
+ # we recommend combining this loss with other difference-based losses like L1
+
+ # normalize distribution, add 1e-14 to divisor to avoid 0/0
+ tensor_a = tensor_a / (torch.sum(tensor_a, dim=-1, keepdim=True) + 1e-14)
+ tensor_b = tensor_b / (torch.sum(tensor_b, dim=-1, keepdim=True) + 1e-14)
+ # make cdf with cumsum
+ cdf_tensor_a = torch.cumsum(tensor_a, dim=-1)
+ cdf_tensor_b = torch.cumsum(tensor_b, dim=-1)
+
+ # choose different formulas for different norm situations
+ if p == 1:
+ cdf_distance = torch.sum(torch.abs((cdf_tensor_a-cdf_tensor_b)), dim=-1)
+ elif p == 2:
+ cdf_distance = torch.sqrt(torch.sum(torch.pow((cdf_tensor_a-cdf_tensor_b), 2), dim=-1))
+ else:
+ cdf_distance = torch.pow(torch.sum(torch.pow(torch.abs(cdf_tensor_a-cdf_tensor_b), p), dim=-1), 1/p)
+
+ cdf_loss = cdf_distance.mean()
+ return cdf_loss
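+
+
+# Worked example (illustrative): for a = [1, 0, 0] and b = [0, 0, 1] the CDFs are
+# [1, 1, 1] and [0, 0, 1], so the p=1 distance is |1-0| + |1-0| + |1-1| = 2:
+#   torch_wasserstein_loss(torch.tensor([1., 0., 0.]), torch.tensor([0., 0., 1.]))  # tensor(2.)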
diff --git a/dp2/utils/utils.py b/dp2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3dda0ecd189f2b1b21726a12ced68af348bdb5e
--- /dev/null
+++ b/dp2/utils/utils.py
@@ -0,0 +1,30 @@
+import tops
+import tqdm
+from tops import logger, highlight_py_str
+from tops.config import LazyConfig
+
+
+def print_config(cfg):
+ logger.log("\n" + highlight_py_str(LazyConfig.to_py(cfg, prefix="")))
+
+
+def config_to_str(cfg):
+ return LazyConfig.to_py(cfg, prefix=".")
+
+
+def init_tops(cfg, reinit=False):
+ tops.init(
+ cfg.output_dir, cfg.common.logger_backend, cfg.experiment_name,
+ cfg.common.wandb_project, dict(cfg), reinit)
+
+
+def tqdm_(iterator, *args, **kwargs):
+ if tops.rank() == 0:
+ return tqdm.tqdm(iterator, *args, **kwargs)
+ return iterator
+
+
+def trange_(*args, **kwargs):
+ if tops.rank() == 0:
+ return tqdm.trange(*args, **kwargs)
+ return range(*args)
diff --git a/dp2/utils/vis_utils.py b/dp2/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..996e062ce082cc4da8355551317f4678bc75feaa
--- /dev/null
+++ b/dp2/utils/vis_utils.py
@@ -0,0 +1,440 @@
+import torch
+import tops
+import cv2
+import torchvision.transforms.functional as F
+from typing import Optional, List, Union, Tuple
+from .cse import from_E_to_vertex
+import numpy as np
+from tops import download_file
+from .torch_utils import (
+ denormalize_img, binary_dilation, binary_erosion,
+ remove_pad, crop_box)
+from torchvision.utils import _generate_color_palette
+from PIL import Image, ImageColor, ImageDraw
+
+
+def get_coco_keypoints():
+ # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py
+ keypoints = [
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip',
+ 'right_hip',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle'
+ ]
+ keypoint_flip_map = {
+ 'left_eye': 'right_eye',
+ 'left_ear': 'right_ear',
+ 'left_shoulder': 'right_shoulder',
+ 'left_elbow': 'right_elbow',
+ 'left_wrist': 'right_wrist',
+ 'left_hip': 'right_hip',
+ 'left_knee': 'right_knee',
+ 'left_ankle': 'right_ankle'
+ }
+ connectivity = {
+ "nose": "left_eye",
+ "left_eye": "right_eye",
+ "right_eye": "nose",
+ "left_ear": "left_eye",
+ "right_ear": "right_eye",
+ "left_shoulder": "nose",
+ "right_shoulder": "nose",
+ "left_elbow": "left_shoulder",
+ "right_elbow": "right_shoulder",
+ "left_wrist": "left_elbow",
+ "right_wrist": "right_elbow",
+ "left_hip": "left_shoulder",
+ "right_hip": "right_shoulder",
+ "left_knee": "left_hip",
+ "right_knee": "right_hip",
+ "left_ankle": "left_knee",
+ "right_ankle": "right_knee"
+ }
+ connectivity_indices = [
+ (sidx, keypoints.index(connectivity[kp]))
+ for sidx, kp in enumerate(keypoints)
+ ]
+ return keypoints, keypoint_flip_map, connectivity_indices
+
+
+def get_coco_colors():
+ return [
+ *["red"]*5,
+ "blue",
+ "green",
+ "blue",
+ "green",
+ "blue",
+ "green",
+ "purple",
+ "orange",
+ "purple",
+ "orange",
+ "purple",
+ "orange",
+ ]
+
+
+@torch.no_grad()
+def draw_keypoints(
+ image: torch.Tensor,
+ keypoints: torch.Tensor,
+ connectivity: Optional[List[Tuple[int, int]]] = None,
+ visible: Optional[List[List[bool]]] = None,
+ colors: Optional[Union[str, Tuple[int, int, int]]] = None,
+ radius: int = None,
+ width: int = None,
+) -> torch.Tensor:
+ """
+ Function taken from torchvision source code. Added in torchvision 0.12
+
+ Draws Keypoints on given RGB image.
+ The values of the input image should be uint8 between 0 and 255.
+
+ Args:
+ image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
+ keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
+ in the format [x, y].
+ connectivity (List[Tuple[int, int]]]): A List of tuple where,
+ each tuple contains pair of keypoints to be connected.
+ colors (str, Tuple): The color can be represented as
+ PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
+ radius (int): Integer denoting radius of keypoint.
+ width (int): Integer denoting width of line connecting keypoints.
+
+ Returns:
+ img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
+ """
+
+ if not isinstance(image, torch.Tensor):
+ raise TypeError(f"The image must be a tensor, got {type(image)}")
+ elif image.dtype != torch.uint8:
+ raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
+ elif image.dim() != 3:
+ raise ValueError("Pass individual images, not batches")
+ elif image.size()[0] != 3:
+ raise ValueError("Pass an RGB image. Other Image formats are not supported")
+
+ if keypoints.ndim != 3:
+ raise ValueError("keypoints must be of shape (num_instances, K, 2)")
+ if width is None:
+ width = int(max(max(image.shape[-2:]) * 0.01, 1))
+ if radius is None:
+ radius = int(max(max(image.shape[-2:]) * 0.01, 1))
+
+ ndarr = image.permute(1, 2, 0).cpu().numpy()
+ img_to_draw = Image.fromarray(ndarr)
+ draw = ImageDraw.Draw(img_to_draw)
+ if isinstance(keypoints, torch.Tensor):
+ img_kpts = keypoints.to(torch.int64).tolist()
+ else:
+ assert isinstance(keypoints, np.ndarray)
+ img_kpts = keypoints.astype(int).tolist()
+ colors = get_coco_colors()
+ for inst_id, kpt_inst in enumerate(img_kpts):
+
+ for kpt_id, kpt in enumerate(kpt_inst):
+ if visible is not None and int(visible[inst_id][kpt_id]) == 0:
+ continue
+ x1 = kpt[0] - radius
+ x2 = kpt[0] + radius
+ y1 = kpt[1] - radius
+ y2 = kpt[1] + radius
+
+ draw.ellipse([x1, y1, x2, y2], fill=colors[kpt_id], outline=None, width=0)
+
+ if connectivity is not None:
+ for connection in connectivity:
+ if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst):
+ continue
+                if visible is not None and (int(visible[inst_id][connection[1]]) == 0 or int(visible[inst_id][connection[0]]) == 0):
+ continue
+
+ start_pt_x = kpt_inst[connection[0]][0]
+ start_pt_y = kpt_inst[connection[0]][1]
+
+ end_pt_x = kpt_inst[connection[1]][0]
+ end_pt_y = kpt_inst[connection[1]][1]
+
+ draw.line(
+ ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)),
+ width=width, fill=colors[connection[1]]
+ )
+
+ return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
+
+
+def visualize_keypoints(img, keypoints):
+ img = img.clone()
+ keypoints = keypoints.clone()
+ keypoints[:, :, 0] *= img.shape[-1]
+ keypoints[:, :, 1] *= img.shape[-2]
+ _, _, connectivity = get_coco_keypoints()
+ connectivity = np.array(connectivity)
+ visible = None
+ if keypoints.shape[-1] == 3:
+ visible = keypoints[:, :, 2] > 0
+ for idx in range(img.shape[0]):
+ img[idx] = draw_keypoints(
+ img[idx], keypoints[idx:idx+1].long(), colors="red",
+            connectivity=connectivity, visible=visible[idx:idx+1] if visible is not None else None)
+ return img
+
+
+def visualize_batch(
+ img: torch.Tensor, mask: torch.Tensor,
+ vertices: torch.Tensor = None,
+ E_mask: torch.Tensor = None,
+ embed_map: torch.Tensor = None,
+ semantic_mask: torch.Tensor = None,
+ embedding: torch.Tensor = None,
+ keypoints: torch.Tensor = None,
+ maskrcnn_mask: torch.Tensor = None,
+ **kwargs) -> torch.ByteTensor:
+ img = denormalize_img(img).mul(255).round().clamp(0, 255).byte()
+ img = draw_mask(img, mask)
+ if maskrcnn_mask is not None and maskrcnn_mask.shape == mask.shape:
+ img = draw_mask(img, maskrcnn_mask)
+ if vertices is not None or embedding is not None:
+ assert E_mask is not None
+ assert embed_map is not None
+ img, E_mask, embedding, embed_map, vertices = tops.to_cpu([
+ img, E_mask, embedding, embed_map, vertices
+ ])
+ img = draw_cse(img, E_mask, embedding, embed_map, vertices)
+ elif semantic_mask is not None:
+ img = draw_segmentation_masks(img, semantic_mask)
+ if keypoints is not None:
+ img = visualize_keypoints(img, keypoints)
+ return img
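+
+
+# Example usage (illustrative sketch; keys follow the batch format used in training):
+#   vis = visualize_batch(batch["img"], batch["mask"], keypoints=batch.get("keypoints"))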
+
+
+@torch.no_grad()
+def draw_cse(
+ img: torch.Tensor, E_seg: torch.Tensor,
+ embedding: torch.Tensor = None,
+ embed_map: torch.Tensor = None,
+ vertices: torch.Tensor = None, t=0.7
+):
+ """
+ E_seg: 1 for areas with embedding
+ """
+ assert img.dtype == torch.uint8
+ img = img.view(-1, *img.shape[-3:])
+ E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
+ if vertices is None:
+ assert embedding is not None
+ assert embed_map is not None
+ embedding = embedding.view(-1, *embedding.shape[-3:])
+ vertices = torch.stack(
+ [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map)
+ for e, e_seg in zip(embedding, E_seg)])
+
+ i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1)
+ colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0])
+ color_embed_map, _ = np.load(download_file(
+ "https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True)
+ color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0]
+ color_embed_map -= color_embed_map.min()
+ color_embed_map /= color_embed_map.max()
+ vertx2idx = (color_embed_map*255).long()
+ vertx2colormap = colormap_JET[vertx2idx]
+
+ vertices = vertices.view(-1, *vertices.shape[-2:])
+ E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:])
+ # This operation might be good to do on cpu...
+
+ E_color = vertx2colormap[vertices.long()]
+ E_color = E_color.to(E_seg.device)
+ E_color = E_color.permute(0, 3, 1, 2)
+ E_color = E_color*E_seg.byte()
+
+ m = E_seg.bool().repeat(1, 3, 1, 1)
+ img[m] = (img[m] * (1-t) + t * E_color[m]).byte()
+ return img
+
+
+def draw_cse_all(
+ embedding: List[torch.Tensor], E_mask: List[torch.Tensor],
+ im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7):
+ """
+    E_mask: 1 for areas with embedding
+ """
+ assert len(im.shape) == 3, im.shape
+ assert im.dtype == torch.uint8
+
+ N = len(E_mask)
+ im = im.clone()
+ for i in range(N):
+ assert len(E_mask[i].shape) == 2
+ assert len(embedding[i].shape) == 3
+ assert embed_map.shape[1] == embedding[i].shape[0]
+ assert len(boxes_XYXY[i]) == 4
+ E = embedding[i]
+ x0, y0, x1, y1 = boxes_XYXY[i]
+ E = F.resize(E, (y1-y0, x1-x0), antialias=True)
+ s = E_mask[i].float()
+ s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float()
+ box = boxes_XYXY[i]
+
+ im_ = crop_box(im, box)
+ s = remove_pad(s, box, im.shape[1:])
+ E = remove_pad(E, box, im.shape[1:])
+ E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None], embed_map=embed_map)[0]
+ E_color = E_color.to(im.device)
+ s = s.bool().repeat(3, 1, 1)
+ crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte()
+ return im
+
+
+@torch.no_grad()
+def draw_segmentation_masks(
+ image: torch.Tensor,
+ masks: torch.Tensor,
+ alpha: float = 0.8,
+ colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
+) -> torch.Tensor:
+ """
+ Draws segmentation masks on given RGB image.
+ The values of the input image should be uint8 between 0 and 255.
+
+ Args:
+ image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
+ masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
+ alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
+ 0 means full transparency, 1 means no transparency.
+ colors (list or None): List containing the colors of the masks. The colors can
+ be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
+ When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list
+ with one element. By default, random colors are generated for each mask.
+
+ Returns:
+ img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
+ """
+
+ if not isinstance(image, torch.Tensor):
+ raise TypeError(f"The image must be a tensor, got {type(image)}")
+ elif image.dtype != torch.uint8:
+ raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
+ elif image.dim() != 3:
+ raise ValueError("Pass individual images, not batches")
+ elif image.size()[0] != 3:
+ raise ValueError("Pass an RGB image. Other Image formats are not supported")
+ if masks.ndim == 2:
+ masks = masks[None, :, :]
+ if masks.ndim != 3:
+ raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
+ if masks.dtype != torch.bool:
+ raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
+ if masks.shape[-2:] != image.shape[-2:]:
+ raise ValueError("The image and the masks must have the same height and width")
+ num_masks = masks.size()[0]
+ if num_masks == 0:
+ return image
+ if colors is None:
+ colors = _generate_color_palette(num_masks)
+    if not isinstance(colors[0], (tuple, list)):
+ colors = [colors for i in range(num_masks)]
+ if colors is not None and num_masks > len(colors):
+ raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
+
+ if not isinstance(colors, list):
+ colors = [colors]
+ if not isinstance(colors[0], (tuple, str)):
+ raise ValueError("colors must be a tuple or a string, or a list thereof")
+ if isinstance(colors[0], tuple) and len(colors[0]) != 3:
+ raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
+
+ out_dtype = torch.uint8
+
+ colors_ = []
+ for color in colors:
+ if isinstance(color, str):
+ color = ImageColor.getrgb(color)
+ color = torch.tensor(color, dtype=out_dtype, device=masks.device)
+ colors_.append(color)
+ img_to_draw = image.detach().clone()
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors_):
+ img_to_draw[:, mask] = color[:, None]
+
+ out = image * (1 - alpha) + img_to_draw * alpha
+ return out.to(out_dtype)
+
+
+def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True):
+ """
+ Visualize mask where mask = 0.
+ Supports multiple instances.
+ mask shape: [N, C, H, W], where C is different instances in same image.
+ """
+ orig_imshape = im.shape
+ if mask.numel() == 0:
+ return im
+ assert len(mask.shape) in (3, 4), mask.shape
+ mask = mask.view(-1, *mask.shape[-3:])
+ im = im.view(-1, *im.shape[-3:])
+ assert im.dtype == torch.uint8, im.dtype
+ assert 0 <= t <= 1
+ if not visualize_instances:
+ mask = mask.any(dim=1, keepdim=True)
+ mask = mask.bool()
+ kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device)
+ outer_border = binary_dilation(mask, kernel).logical_xor(mask)
+ outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
+ inner_border = binary_erosion(mask, kernel).logical_xor(mask)
+ inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0
+ mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1)
+ color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1) # .repeat(1, *im.shape[1:])
+ color = color.repeat(im.shape[0], 1, *im.shape[-2:])
+ im[mask] = (im[mask] * (1-t) + t * color[mask]).byte()
+ im[outer_border] = 255
+ im[inner_border] = 0
+ return im.view(*orig_imshape)
+
+
+def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs):
+ for i, box in enumerate(boxes):
+ x0, y0, x1, y1 = boxes[i]
+ orig_shape = (y1-y0, x1-x0)
+ m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None]
+ m = remove_pad(m, boxes[i], im.shape[-2:])
+ crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m))
+ return im
+
+
+def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs):
+ n_boxes = boxes.shape[0]
+ tops.assert_shape(all_keypoints, (n_boxes, 17, 3))
+ im = im.clone()
+ for i, box in enumerate(boxes):
+
+ x0, y0, x1, y1 = boxes[i]
+ orig_shape = (y1-y0, x1-x0)
+ keypoints = all_keypoints[i].clone()
+ keypoints[:, 0] *= orig_shape[1]
+ keypoints[:, 1] *= orig_shape[0]
+ keypoints = keypoints.long()
+ _, _, connectivity = get_coco_keypoints()
+ connectivity = np.array(connectivity)
+ visible = (keypoints[:, 2] > .5)
+ # Remove padding from keypoints before visualization
+ keypoints[:, 0] += min(x0, 0)
+ keypoints[:, 1] += min(y0, 0)
+ im_with_kp = draw_keypoints(
+ crop_box(im, box), keypoints[None], colors="red", connectivity=connectivity, visible=visible[None])
+ crop_box(im, box).copy_(im_with_kp)
+ return im
diff --git a/gradio_demos/body_cse.py b/gradio_demos/body_cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ddf1859352841539fdf5c4b6e9d0f098264208d
--- /dev/null
+++ b/gradio_demos/body_cse.py
@@ -0,0 +1,23 @@
+import gradio
+from dp2 import utils
+from tops.config import instantiate
+import gradio.inputs
+from gradio_demos.modules import ExampleDemo, WebcamDemo
+
+
+cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
+anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
+anonymizer_body.initialize_tracker(fps=1)
+
+
+with gradio.Blocks() as demo:
+ gradio.Markdown("# DeepPrivacy2 - Realistic Image Anonymization ")
+ gradio.Markdown("### Håkon Hukkelås, Rudolf Mester, Frank Lindseth ")
+ with gradio.Tab("Full-Body CSE Anonymization"):
+ ExampleDemo(anonymizer_body)
+ with gradio.Tab("Full-body CSE Webcam"):
+ WebcamDemo(anonymizer_body)
+
+
+demo.launch()
+
diff --git a/gradio_demos/face.py b/gradio_demos/face.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e8692fd191a36c09b933b5e36f3e31d0f78341b
--- /dev/null
+++ b/gradio_demos/face.py
@@ -0,0 +1,22 @@
+import gradio
+from dp2 import utils
+from tops.config import instantiate
+import gradio.inputs
+from gradio_demos.modules import ExampleDemo, WebcamDemo
+
+cfg_face = utils.load_config("configs/anonymizers/face.py")
+anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
+print(anonymizer_face.detector)
+anonymizer_face.initialize_tracker(fps=1)
+
+
+with gradio.Blocks() as demo:
+ gradio.Markdown("# DeepPrivacy2 - Realistic Image Anonymization ")
+ gradio.Markdown("### Håkon Hukkelås, Rudolf Mester, Frank Lindseth ")
+ with gradio.Tab("Face Anonymization"):
+ ExampleDemo(anonymizer_face)
+ with gradio.Tab("Live Webcam"):
+ WebcamDemo(anonymizer_face)
+
+demo.launch()
+
diff --git a/gradio_demos/modules.py b/gradio_demos/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..3682f3d79688c3a55294a4f5ca4704112e0fd860
--- /dev/null
+++ b/gradio_demos/modules.py
@@ -0,0 +1,247 @@
+from collections import defaultdict
+import gradio
+import numpy as np
+import torch
+import cv2
+from PIL import Image
+from dp2 import utils
+from tops.config import instantiate
+import tops
+import gradio.inputs
+from stylemc import get_and_cache_direction, get_styles
+from sg3_torch_utils.ops import grid_sample_gradfix, bias_act, upfirdn2d
+
+grid_sample_gradfix.enabled = False
+bias_act.enabled = False
+upfirdn2d.enabled = False
+
+
+class GuidedDemo:
+ def __init__(self, face_anonymizer, cfg_face, multi_modal_truncation, truncation_value) -> None:
+ self.anonymizer = face_anonymizer
+ self.multi_modal_truncation = multi_modal_truncation
+ self.truncation_value = truncation_value
+ assert sum([x is not None for x in list(face_anonymizer.generators.values())]) == 1
+ self.generator = [x for x in list(face_anonymizer.generators.values()) if x is not None][0]
+ face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg)
+ face_G_cfg.train.batch_size = 1
+ self.dl = instantiate(face_G_cfg.data.val.loader)
+ self.cache_dir = face_G_cfg.output_dir
+ self.precompute_edits()
+
+ def precompute_edits(self):
+ self.precomputed_edits = set()
+ for edit in self.precomputed_edits:
+ get_and_cache_direction(self.cache_dir, self.dl, self.generator, edit)
+ if self.cache_dir.joinpath("stylemc_cache").is_dir():
+ for path in self.cache_dir.joinpath("stylemc_cache").iterdir():
+ text_prompt = path.stem.replace("_", " ")
+ self.precomputed_edits.add(text_prompt)
+ print(text_prompt)
+ self.edits = defaultdict(defaultdict)
+
+ def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None):
+ if not isinstance(img, torch.Tensor):
+ img, cache_id = pil2torch(img)
+ img = tops.to_cuda(img)
+
+ current_box_idx = current_box_idx % len(current_boxes)
+ edited_styles = [s.clone() for s in current_styles]
+ for face_idx, face_edits in edits.items():
+ for prompt, strength in face_edits.items():
+ direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt)
+ edited_styles[int(face_idx)] += direction * strength
+ update_identity[int(face_idx)] = True
+ assert img.dtype == torch.uint8
+ img = self.anonymizer(
+ img, truncation_value=self.truncation_value,
+ multi_modal_truncation=self.multi_modal_truncation, amp=True,
+ cache_id=cache_id,
+ all_styles=edited_styles,
+ update_identity=update_identity)
+ update_identity = [True for i in range(len(update_identity))]
+ img = utils.im2numpy(img)
+ if show_boxes:
+ x0, y0, x1, y1 = [int(_) for _ in current_boxes[int(current_box_idx)]]
+ img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1)
+ return img, update_identity
+
+ def update_image(self, img, show_boxes):
+ img, cache_id = pil2torch(img)
+ img = tops.to_cuda(img)
+ det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0]
+ current_styles = []
+ for i in range(len(det)):
+ s = get_styles(
+ np.random.randint(0, 999999), self.generator,
+ None, truncation_value=self.truncation_value)
+ current_styles.append(s)
+ update_identity = [True for i in range(len(det))]
+ current_boxes = np.array(det.boxes)
+ edits = defaultdict(defaultdict)
+ cur_face_idx = -1 % len(current_boxes)
+ img, update_identity = self.anonymize(
+ img, show_boxes, cur_face_idx,
+ current_styles, current_boxes, update_identity, edits, cache_id=cache_id)
+ return img, current_styles, current_boxes, update_identity, edits, cur_face_idx
+
+ def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits):
+ cur_face_idx = (cur_face_idx + change) % len(current_boxes)
+ img, update_identity = self.anonymize(
+ input_image, show_boxes, cur_face_idx,
+ current_styles, current_boxes, update_identity, edits)
+ return img, update_identity, cur_face_idx
+
+ def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits):
+ face_idx = face_idx % len(current_boxes)
+ edits[face_idx][prompt] = strength
+ img, update_identity = self.anonymize(
+ input_image, show_boxes, face_idx,
+ current_styles, current_boxes, update_identity, edits)
+ return img, update_identity, edits
+
+ def setup_interface(self):
+ current_styles = gradio.State()
+ current_boxes = gradio.State(None)
+ update_identity = gradio.State([])
+ edits = gradio.State([])
+ with gradio.Row():
+ input_image = gradio.Image(
+ type="pil", label="Upload your image or try the example below!", source="webcam")
+ output_image = gradio.Image(type="numpy", label="Output")
+ with gradio.Row():
+ update_btn = gradio.Button("Update Anonymization").style(full_width=True)
+ with gradio.Row():
+ show_boxes = gradio.Checkbox(value=True, label="Show Selected")
+ cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False)
+ previous = gradio.Button("Previous Person")
+ next_ = gradio.Button("Next Person")
+ with gradio.Row():
+ text_prompt = gradio.Textbox(
+ placeholder=" | ".join(list(self.precomputed_edits)),
+ label="Text Prompt for Edit")
+ edit_strength = gradio.Slider(0, 5, step=.01)
+ add_btn = gradio.Button("Add Edit")
+ add_btn.click(
+ self.add_style,
+            inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes, current_styles, current_boxes, update_identity, edits],
+ outputs=[output_image, update_identity, edits])
+ update_btn.click(
+ self.update_image,
+ inputs=[input_image, show_boxes],
+ outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
+ input_image.change(
+ self.update_image,
+ inputs=[input_image, show_boxes],
+ outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
+ previous.click(
+ self.change_face,
+ inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits],
+ outputs=[output_image, update_identity, cur_face_idx])
+ next_.click(
+ self.change_face,
+            inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits],
+ outputs=[output_image, update_identity, cur_face_idx])
+ show_boxes.change(
+ self.anonymize,
+ inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits],
+ outputs=[output_image, update_identity])
+
+
+class WebcamDemo:
+
+ def __init__(self, anonymizer) -> None:
+ self.anonymizer = anonymizer
+ with gradio.Row():
+ input_image = gradio.Image(type="pil", source="webcam", streaming=True)
+ output_image = gradio.Image(type="numpy", label="Output")
+ with gradio.Row():
+ truncation_value = gradio.Slider(0, 1, value=0, step=0.01)
+ truncation = gradio.Radio(["Multi-modal truncation", "Unimodal truncation"], value="Unimodal truncation")
+ with gradio.Row():
+ visualize_det = gradio.Checkbox(value=False, label="Show Detections")
+ track = gradio.Checkbox(value=False, label="Track detections (samples same latent variable per track)")
+ input_image.stream(
+ self.anonymize,
+            inputs=[input_image, visualize_det, truncation_value, truncation, track, gradio.Variable(False)],
+ outputs=[output_image])
+ self.track = True
+
+ def anonymize(self, img: Image, visualize_detection: bool, truncation_value, truncation_type, track, reset_track):
+ if reset_track:
+ self.anonymizer.reset_tracker()
+ mmt = truncation_type == "Multi-modal truncation"
+ img, cache_id = pil2torch(img)
+ img = tops.to_cuda(img)
+ if visualize_detection:
+ img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
+ else:
+ img = self.anonymizer(
+ img,
+ truncation_value=truncation_value,
+ multi_modal_truncation=mmt,
+ amp=True,
+ cache_id=cache_id,
+ track=track)
+ img = utils.im2numpy(img)
+ return img
+
+
+class ExampleDemo(WebcamDemo):
+
+ def __init__(self, anonymizer) -> None:
+ self.anonymizer = anonymizer
+ with gradio.Row():
+ input_image = gradio.Image(type="pil", source="webcam")
+ output_image = gradio.Image(type="numpy", label="Output")
+ with gradio.Row():
+ update_btn = gradio.Button("Update Anonymization").style(full_width=True)
+ resample = gradio.Button("Resample Latent Variables").style(full_width=True)
+ with gradio.Row():
+ truncation_value = gradio.Slider(0, 1, value=0, step=0.01)
+ truncation = gradio.Radio(["Multi-modal truncation", "Unimodal truncation"], value="Unimodal truncation")
+ visualize_det = gradio.Checkbox(value=False, label="Show Detections")
+ visualize_det.change(
+ self.anonymize,
+ inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(False)],
+ outputs=[output_image])
+ gradio.Examples(
+ ["media/erling.jpg", "media/regjeringen.jpg"], inputs=[input_image]
+ )
+
+ update_btn.click(
+ self.anonymize,
+ inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(False)],
+ outputs=[output_image])
+ resample.click(
+ self.anonymize,
+ inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(True)],
+ outputs=[output_image])
+ input_image.change(
+ self.anonymize,
+ inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(False), gradio.Variable(True)],
+ outputs=[output_image])
+ self.track = False
+ self.truncation_value = truncation_value
+
+
+class Information:
+
+ def __init__(self) -> None:
+ gradio.Markdown("## Face Anonymization Architecture ")
+ gradio.Markdown("---")
+ gradio.Image(value="media/overall_architecture.png")
+ gradio.Markdown("## Full-Body Anonymization Architecture ")
+ gradio.Markdown("---")
+ gradio.Image(value="media/full_body.png")
+ gradio.Markdown("### Generative Adversarial Networks ")
+ gradio.Markdown("---")
+ gradio.Image(value="media/gan_architecture.png")
+
+
+def pil2torch(img: Image.Image):
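+    # Convert a PIL image to a (3, H, W) uint8 tensor; the second return value is a
+    # placeholder cache id (None) expected by the anonymizer call sites above.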
+ img = img.convert("RGB")
+ img = np.array(img)
+ img = np.rollaxis(img, 2)
+ return torch.from_numpy(img), None
diff --git a/media/erling.jpg b/media/erling.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..78612fd7643f95d48c6188e8b3c7b6f5557aef30
--- /dev/null
+++ b/media/erling.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1731504c0bdad6ed94578d13bc53c65181fef3c13db1a9214c6e9a255d9b02e
+size 1227474
diff --git a/media/g7_leaders.jpg b/media/g7_leaders.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..573608b1f13612f30dfb6a4cd4b35db68543f64b
--- /dev/null
+++ b/media/g7_leaders.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:480fa6b02b77e12bb7e3b85c4c0c15797040fceb5f87eaaefdb626c2dfb49255
+size 2121459
diff --git a/media/regjeringen.jpg b/media/regjeringen.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f6ea34a1053634a1be00e06417aa6fddf838eea0
--- /dev/null
+++ b/media/regjeringen.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a59538772ccc8f9a0a8b490017ffd7e1dfad2527fac9bf8f49b3a227117335a
+size 526611
diff --git a/media/stylemc_example.jpg b/media/stylemc_example.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4b3f23f450425f3afd19b1cae94dc75c86e16b13
--- /dev/null
+++ b/media/stylemc_example.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c95012af3c4a18229ca0d3031e22cae182cd8e41518b51b26a0040408e8210a
+size 2513762
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8d33bc492f1ab5454ab5f85f560e4a4fa015181c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+numpy
+cython
+matplotlib
+tqdm
+tensorboard
+opencv-python
+torch_fidelity
+ninja
+moviepy
+pyspng
+wandb
+termcolor
+tops@git+https://github.com/hukkelas/torch_ops.git
+motpy@git+https://github.com/wmuron/motpy@c77f85d27e371c0a298e9a88ca99292d9b9cbe6b
+fast_pytorch_kmeans
+einops
+einops_exts
+regex
+setuptools
+resize_right
+pillow
+scipy
+webdataset
+scikit-image
\ No newline at end of file
diff --git a/sg3_torch_utils/LICENSE.txt b/sg3_torch_utils/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6b5ee9bf994cc9441cb659c3527160b4ee5bcb33
--- /dev/null
+++ b/sg3_torch_utils/LICENSE.txt
@@ -0,0 +1,97 @@
+Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
+
+
+NVIDIA Source Code License for StyleGAN3
+
+
+=======================================================================
+
+1. Definitions
+
+"Licensor" means any person or entity that distributes its Work.
+
+"Software" means the original work of authorship made available under
+this License.
+
+"Work" means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+The terms "reproduce," "reproduction," "derivative works," and
+"distribution" have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are "made available" under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+2. License Grants
+
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
+ License, each Licensor grants to you a perpetual, worldwide,
+ non-exclusive, royalty-free, copyright license to reproduce,
+ prepare derivative works of, publicly display, publicly perform,
+ sublicense and distribute its Work and any resulting derivative
+ works in any form.
+
+3. Limitations
+
+ 3.1 Redistribution. You may reproduce or distribute the Work only
+ if (a) you do so under this License, (b) you include a complete
+ copy of this License with your distribution, and (c) you retain
+ without modification any copyright, patent, trademark, or
+ attribution notices that are present in the Work.
+
+ 3.2 Derivative Works. You may specify that additional or different
+ terms apply to the use, reproduction, and distribution of your
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
+ provide that the use limitation in Section 3.3 applies to your
+ derivative works, and (b) you identify the specific derivative
+ works that are subject to Your Terms. Notwithstanding Your Terms,
+ this License (including the redistribution requirements in Section
+ 3.1) will continue to apply to the Work itself.
+
+ 3.3 Use Limitation. The Work and any derivative works thereof only
+ may be used or intended for use non-commercially. Notwithstanding
+ the foregoing, NVIDIA and its affiliates may use the Work and any
+ derivative works commercially. As used herein, "non-commercially"
+ means for research or evaluation purposes only.
+
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+ against any Licensor (including any claim, cross-claim or
+ counterclaim in a lawsuit) to enforce any patents that you allege
+ are infringed by any Work, then your rights under this License from
+ such Licensor (including the grant in Section 2.1) will terminate
+ immediately.
+
+ 3.5 Trademarks. This License does not grant any rights to use any
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
+ as necessary to reproduce the notices described in this License.
+
+ 3.6 Termination. If you violate any term of this License, then your
+ rights under this License (including the grant in Section 2.1) will
+ terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================
diff --git a/sg3_torch_utils/__init__.py b/sg3_torch_utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/sg3_torch_utils/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/sg3_torch_utils/custom_ops.py b/sg3_torch_utils/custom_ops.py
new file mode 100755
index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e
--- /dev/null
+++ b/sg3_torch_utils/custom_ops.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Compile and load.
+ verbose_build = (verbosity == 'full')
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+ # Compute a combined hash digest for all source files in the same
+ # custom op directory (usually .cu, .cpp, .py and .h files).
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+ if not os.path.isdir(digest_build_dir):
+ os.makedirs(digest_build_dir, exist_ok=True)
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+ if baton.try_acquire():
+ try:
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+ finally:
+ baton.release()
+ else:
+ # Someone else is copying source files under the digest dir,
+ # wait until done and continue.
+ baton.wait()
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/misc.py b/sg3_torch_utils/misc.py
new file mode 100755
index 0000000000000000000000000000000000000000..10d8e31880affdd185580b6f5b98e92c79597dc3
--- /dev/null
+++ b/sg3_torch_utils/misc.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+ def __enter__(self):
+ super().__enter__()
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+ return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
+
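+# Editorial note: illustrative usage sketch, not part of the upstream NVIDIA file.
+# Because the sampler never raises StopIteration, the DataLoader below can be
+# drawn from for any number of steps; rank/num_replicas shard the index stream
+# across DDP processes.
+def _infinite_sampler_example(dataset, batch_size=4):
+    sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
+    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)
+    iterator = iter(loader)
+    return [next(iterator) for _ in range(3)]  # three batches from the endless stream
+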
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
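+
+# Editorial note: illustrative sketch, not part of the upstream NVIDIA file.
+# Typical use of ddp_sync() for gradient accumulation: skip the gradient
+# all-reduce on every micro-batch except the last one.
+def _ddp_sync_example(module, micro_batches, loss_fn):
+    for i, batch in enumerate(micro_batches):
+        is_last = (i == len(micro_batches) - 1)
+        with ddp_sync(module, sync=is_last):  # wraps module.no_sync() unless is_last
+            loss_fn(module(batch)).backward()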
diff --git a/sg3_torch_utils/ops/__init__.py b/sg3_torch_utils/ops/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9
--- /dev/null
+++ b/sg3_torch_utils/ops/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
diff --git a/sg3_torch_utils/ops/bias_act.cpp b/sg3_torch_utils/ops/bias_act.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330
--- /dev/null
+++ b/sg3_torch_utils/ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/bias_act.cu b/sg3_torch_utils/ops/bias_act.cu
new file mode 100755
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/sg3_torch_utils/ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>     { typedef double scalar_t; };
+template <> struct InternalType<float>      { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/bias_act.h b/sg3_torch_utils/ops/bias_act.h
new file mode 100755
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/sg3_torch_utils/ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/bias_act.py b/sg3_torch_utils/ops/bias_act.py
new file mode 100755
index 0000000000000000000000000000000000000000..7c39717268055fafe737419486cf96f1f93f4fb5
--- /dev/null
+++ b/sg3_torch_utils/ops/bias_act.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import traceback
+
+from .. import custom_ops
+from easydict import EasyDict
+from torch.cuda.amp import custom_bwd, custom_fwd
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+enabled = False
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ _inited = True
+ sources = ['bias_act.cpp', 'bias_act.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
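+# Editorial note: illustrative usage sketch, not part of the upstream NVIDIA file.
+# bias_act() fuses the bias add, the activation, the gain and the optional clamp
+# into a single op; impl='ref' (or a failed CUDA build) routes through
+# _bias_act_ref() so the call also works on CPU.
+def _bias_act_example():
+    x = torch.randn(4, 512, 16, 16)
+    b = torch.zeros(512)
+    y = bias_act(x, b, dim=1, act='lrelu', gain=np.sqrt(2), clamp=256, impl='ref')
+    return y.shape  # same shape and dtype as x
+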
+#----------------------------------------------------------------------------
+
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/conv2d_gradfix.py b/sg3_torch_utils/ops/conv2d_gradfix.py
new file mode 100755
index 0000000000000000000000000000000000000000..e66591f19fad68760d3df7c9737a14574b70ee83
--- /dev/null
+++ b/sg3_torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,175 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import warnings
+import contextlib
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients():
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
+
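+# Editorial note: illustrative sketch, not part of the upstream NVIDIA file.
+# With `enabled = False` (the default above) both wrappers are thin pass-throughs
+# to torch.nn.functional; no_weight_gradients() is how regularization passes
+# (e.g. R1) can skip the expensive gradient w.r.t. the convolution weights on
+# the custom-op path.
+def _conv2d_gradfix_example():
+    x = torch.randn(2, 16, 32, 32, requires_grad=True)
+    w = torch.randn(32, 16, 3, 3, requires_grad=True)
+    with no_weight_gradients():  # only affects the custom op; the fallback ignores it
+        y = conv2d(x, w, padding=1)
+    return y.shape               # torch.Size([2, 32, 32, 32])
+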
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']):
+ return True
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
+ return False
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ if not transpose:
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+ else: # transpose
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ ctx.save_for_backward(input, weight)
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None)
+ assert grad_input.shape == input.shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float())
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.float().sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, grad_output, input):
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
+ assert grad_weight.shape == weight_shape
+ ctx.save_for_backward(grad_output, input)
+ return grad_weight
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output.shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input.shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/conv2d_resample.py b/sg3_torch_utils/ops/conv2d_resample.py
new file mode 100755
index 0000000000000000000000000000000000000000..4a999b58b36a5da53752024e86a9ebdf9c031d97
--- /dev/null
+++ b/sg3_torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ w = w.flip([2, 3])
+
+    # Execute using conv2d_gradfix (which itself falls back to torch.nn.functional when the custom op is disabled).
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4)
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
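+# Editorial note: illustrative sketch, not part of the upstream NVIDIA file; it
+# assumes the upfirdn2d.setup_filter() helper mentioned in the docstring above.
+# A StyleGAN-style 2x upsampling convolution: the low-pass filter is prepared
+# once and reused for every call.
+def _conv2d_resample_example():
+    x = torch.randn(1, 512, 8, 8)
+    w = torch.randn(256, 512, 3, 3)
+    f = upfirdn2d.setup_filter([1, 3, 3, 1])         # separable binomial low-pass filter
+    y = conv2d_resample(x, w, f=f, up=2, padding=1)  # (1, 256, 16, 16)
+    return y.shape
+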
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/fma.py b/sg3_torch_utils/ops/fma.py
new file mode 100755
index 0000000000000000000000000000000000000000..b4e8ef9169440d4c3bd95befae7d26e3c1e1f017
--- /dev/null
+++ b/sg3_torch_utils/ops/fma.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
+
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c): # => a * b + c
+ return _FusedMultiplyAdd.apply(a, b, c)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+ out = torch.addcmul(c, a, b)
+ ctx.save_for_backward(a, b)
+ ctx.c_shape = c.shape
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dout): # pylint: disable=arguments-differ
+ a, b = ctx.saved_tensors
+ c_shape = ctx.c_shape
+ da = None
+ db = None
+ dc = None
+
+ if ctx.needs_input_grad[0]:
+ da = _unbroadcast(dout * b, a.shape)
+
+ if ctx.needs_input_grad[1]:
+ db = _unbroadcast(dout * a, b.shape)
+
+ if ctx.needs_input_grad[2]:
+ dc = _unbroadcast(dout, c_shape)
+
+ return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+ extra_dims = x.ndim - len(shape)
+ assert extra_dims >= 0
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+ if len(dim):
+ x = x.sum(dim=dim, keepdim=True)
+ if extra_dims:
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
+ assert x.shape == shape
+ return x
+
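+# Editorial note: illustrative sketch, not part of the upstream NVIDIA file.
+# fma(a, b, c) == a * b + c with broadcasting; in the backward pass
+# _unbroadcast() reduces each gradient back to its operand's original shape.
+def _fma_example():
+    a = torch.randn(4, 8, requires_grad=True)
+    b = torch.randn(4, 8, requires_grad=True)
+    c = torch.randn(8, requires_grad=True)  # broadcast over the first dimension
+    fma(a, b, c).sum().backward()
+    return c.grad.shape                     # torch.Size([8])
+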
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/grid_sample_gradfix.py b/sg3_torch_utils/ops/grid_sample_gradfix.py
new file mode 100755
index 0000000000000000000000000000000000000000..87067e150c591b1ace91816e7a5c3ee3a4aeacd3
--- /dev/null
+++ b/sg3_torch_utils/ops/grid_sample_gradfix.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample` that
+supports arbitrarily high order gradients between the input and output.
+Only works on 2D images and assumes
+`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
+
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+from pkg_resources import parse_version
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
+
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+
+#----------------------------------------------------------------------------
+
+def grid_sample(input, grid):
+ if _should_use_custom_op():
+ return _GridSample2dForward.apply(input, grid)
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+
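+# Editorial note: illustrative sketch, not part of the upstream NVIDIA file.
+# With `enabled = False` this is plain F.grid_sample (bilinear, zero padding,
+# align_corners=False); setting `enabled = True` routes through the
+# double-differentiable custom op defined below.
+def _grid_sample_example():
+    image = torch.randn(1, 3, 64, 64)
+    grid = torch.rand(1, 32, 32, 2) * 2 - 1  # normalized sampling coords in [-1, 1]
+    return grid_sample(image, grid).shape    # torch.Size([1, 3, 32, 32])
+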
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op():
+ return enabled
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dForward(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, input, grid):
+ assert input.ndim == 4
+ assert grid.ndim == 4
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
+ ctx.save_for_backward(input, grid)
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+ input, grid = ctx.saved_tensors
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
+ return grad_input, grad_grid
+
+#----------------------------------------------------------------------------
+
+class _GridSample2dBackward(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx, grad_output, input, grid):
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
+ if _use_pytorch_1_11_api:
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
+ else:
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
+ ctx.save_for_backward(grid)
+ return grad_input, grad_grid
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ _ = grad2_grad_grid # unused
+ grid, = ctx.saved_tensors
+ grad2_grad_output = None
+ grad2_input = None
+ grad2_grid = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
+
+ assert not ctx.needs_input_grad[2]
+ return grad2_grad_output, grad2_input, grad2_grid
+
+#----------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/upfirdn2d.cpp b/sg3_torch_utils/ops/upfirdn2d.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/sg3_torch_utils/ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+ // Initialize CUDA kernel parameters.
+ upfirdn2d_kernel_params p;
+ p.x = x.data_ptr();
+ p.f = f.data_ptr();
+ p.y = y.data_ptr();
+ p.up = make_int2(upx, upy);
+ p.down = make_int2(downx, downy);
+ p.pad0 = make_int2(padx0, pady0);
+ p.flip = (flip) ? 1 : 0;
+ p.gain = gain;
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+ // Choose CUDA kernel.
+ upfirdn2d_kernel_spec spec;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ spec = choose_upfirdn2d_kernel(p);
+ });
+
+ // Set looping options.
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+ p.loopMinor = spec.loopMinor;
+ p.loopX = spec.loopX;
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+ // Compute grid size.
+ dim3 blockSize, gridSize;
+ if (spec.tileOutW < 0) // large
+ {
+ blockSize = dim3(4, 32, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+ p.launchMajor);
+ }
+ else // small
+ {
+ blockSize = dim3(256, 1, 1);
+ gridSize = dim3(
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+ p.launchMajor);
+ }
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/upfirdn2d.cu b/sg3_torch_utils/ops/upfirdn2d.cu
new file mode 100755
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/sg3_torch_utils/ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>     { typedef double scalar_t; };
+template <> struct InternalType<float>      { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+ int t = 1 - a / b;
+ return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+
+ // Calculate thread index.
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
+ int outY = minorBase / p.launchMinor;
+ minorBase -= outY * p.launchMinor;
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Setup Y receptive field.
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
+ if (p.flip)
+ filterY = p.filterSize.y - 1 - filterY;
+
+ // Loop over major, minor, and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
+ {
+ int nc = major * p.sizeMinor + minor;
+ int n = nc / p.inSize.z;
+ int c = nc - n * p.inSize.z;
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
+ {
+ // Setup X receptive field.
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
+ if (p.flip)
+ filterX = p.filterSize.x - 1 - filterX;
+
+ // Initialize pointers.
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
+
+ // Inner loop.
+ scalar_t v = 0;
+ for (int y = 0; y < h; y++)
+ {
+ for (int x = 0; x < w; x++)
+ {
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
+ xp += p.inStride.x;
+ fp += filterStepX;
+ }
+ xp += p.inStride.y - w * p.inStride.x;
+ fp += filterStepY - w * filterStepX;
+ }
+
+ // Store result.
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// Specialized CUDA implementation for small filters.
+
+template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
+static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
+ __shared__ volatile scalar_t sf[filterH][filterW];
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
+
+ // Calculate tile index.
+ int minorBase = blockIdx.x;
+ int tileOutY = minorBase / p.launchMinor;
+ minorBase -= tileOutY * p.launchMinor;
+ minorBase *= loopMinor;
+ tileOutY *= tileOutH;
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
+ int majorBase = blockIdx.z * p.loopMajor;
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
+ return;
+
+ // Load filter (flipped).
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
+ {
+ int fy = tapIdx / filterW;
+ int fx = tapIdx - fy * filterW;
+ scalar_t v = 0;
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
+ {
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
+ }
+ sf[fy][fx] = v;
+ }
+
+ // Loop over major and X.
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
+ {
+ int baseNC = major * p.sizeMinor + minorBase;
+ int n = baseNC / p.inSize.z;
+ int baseC = baseNC - n * p.inSize.z;
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
+ {
+ // Load input pixels.
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
+ int tileInX = floor_div(tileMidX, upx);
+ int tileInY = floor_div(tileMidY, upy);
+ __syncthreads();
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+        if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+        if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+        if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+        if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+        if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+        if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8, 64,16,1>, 64,16,1, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6, 64,16,1>, 64,16,1, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4, 64,16,1>, 64,16,1, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2, 64,16,1>, 64,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 8,8, 16,16,8>, 16,16,8, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 6,6, 16,16,8>, 16,16,8, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 4,4, 16,16,8>, 16,16,8, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2,1,1, 2,2, 16,16,8>, 16,16,8, 1};
+ }
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1,  128,8,1>, 128,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1,1,1, 8,1,  128,1,16>, 128,1,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+ {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,8,  32,32,1>, 32,32,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+ {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2,1,1, 1,8,  1,128,16>, 1,128,16, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+ {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8, 32,8,1>, 32,8,1, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 6,6, 32,8,1>, 32,8,1, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 4,4, 32,8,1>, 32,8,1, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 2,2, 32,8,1>, 32,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+ {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 8,8, 8,8,8>, 8,8,8, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 6,6, 8,8,8>, 8,8,8, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 4,4, 8,8,8>, 8,8,8, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,2, 2,2, 8,8,8>, 8,8,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+ {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 24,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 20,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 16,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 12,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 8,1,  64,8,1>, 64,8,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+ {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 24,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 20,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 16,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 12,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,2,1, 8,1,  64,1,8>, 64,1,8, 1};
+ }
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+ {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,24, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,20, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,16, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,12, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,8,  32,16,1>, 32,16,1, 1};
+ }
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+ {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,24, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,20, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,16, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,12, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1,1,2, 1,8,  1,64,8>, 1,64,8, 1};
+ }
+ return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/upfirdn2d.h b/sg3_torch_utils/ops/upfirdn2d.h
new file mode 100755
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/sg3_torch_utils/ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+ const void* x;
+ const float* f;
+ void* y;
+
+ int2 up;
+ int2 down;
+ int2 pad0;
+ int flip;
+ float gain;
+
+ int4 inSize; // [width, height, channel, batch]
+ int4 inStride;
+ int2 filterSize; // [width, height]
+ int2 filterStride;
+ int4 outSize; // [width, height, channel, batch]
+ int4 outStride;
+ int sizeMinor;
+ int sizeMajor;
+
+ int loopMinor;
+ int loopMajor;
+ int loopX;
+ int launchMinor;
+ int launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
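+// Note (added): "kernel" is the chosen __global__ entry point; tileOutW/tileOutH give the
+// output tile produced per thread block (-1,-1 is used for the generic "large" kernel),
+// while loopMinor/loopX control how much work each block loops over along the minor axis
+// and across output tiles in x.
+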
+
+struct upfirdn2d_kernel_spec
+{
+ void* kernel;
+ int tileOutW;
+ int tileOutH;
+ int loopMinor;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/sg3_torch_utils/ops/upfirdn2d.py b/sg3_torch_utils/ops/upfirdn2d.py
new file mode 100755
index 0000000000000000000000000000000000000000..a0bbd22d245481e7c5a19315e5cb3242b1278787
--- /dev/null
+++ b/sg3_torch_utils/ops/upfirdn2d.py
@@ -0,0 +1,388 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom PyTorch ops for efficient resampling of 2D images."""
+
+import os
+import warnings
+import numpy as np
+import torch
+import traceback
+
+from .. import custom_ops
+from .. import misc
+from . import conv2d_gradfix
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+enabled = False
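+# Added note: the compiled CUDA path is opt-in. `enabled` stays False until a caller flips it
+# (for example `upfirdn2d.enabled = True` once CUDA is known to be available); otherwise
+# upfirdn2d() below always uses the pure-PyTorch reference implementation.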
+
+def _init():
+ global _inited, _plugin
+ if not _inited:
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+ try:
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+ except:
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+ return _plugin is not None
+
+def _parse_scaling(scaling):
+ if isinstance(scaling, int):
+ scaling = [scaling, scaling]
+ assert isinstance(scaling, (list, tuple))
+ assert all(isinstance(x, int) for x in scaling)
+ sx, sy = scaling
+ assert sx >= 1 and sy >= 1
+ return sx, sy
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, int) for x in padding)
+ if len(padding) == 2:
+ padx, pady = padding
+ padding = [padx, padx, pady, pady]
+ padx0, padx1, pady0, pady1 = padding
+ return padx0, padx1, pady0, pady1
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ fw = f.shape[-1]
+ fh = f.shape[0]
+ with misc.suppress_tracer_warnings():
+ fw = int(fw)
+ fh = int(fh)
+ misc.assert_shape(f, [fh, fw][:f.ndim])
+ assert fw >= 1 and fh >= 1
+ return fw, fh
+
+#----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+
+ Args:
+ f: Torch tensor, numpy array, or python list of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable),
+ `[]` (impulse), or
+ `None` (identity).
+ device: Result device (default: cpu).
+ normalize: Normalize the filter so that it retains the magnitude
+ for constant input signal (DC)? (default: True).
+ flip_filter: Flip the filter? (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ separable: Return a separable filter? (default: select automatically).
+
+ Returns:
+ Float32 tensor of the shape
+ `[filter_height, filter_width]` (non-separable) or
+ `[filter_taps]` (separable).
+ """
+ # Validate.
+ if f is None:
+ f = 1
+ f = torch.as_tensor(f, dtype=torch.float32)
+ assert f.ndim in [0, 1, 2]
+ assert f.numel() > 0
+ if f.ndim == 0:
+ f = f[np.newaxis]
+
+ # Separable?
+ if separable is None:
+ separable = (f.ndim == 1 and f.numel() >= 8)
+ if f.ndim == 1 and not separable:
+ f = f.ger(f)
+ assert f.ndim == (1 if separable else 2)
+
+ # Apply normalize, flip, gain, and device.
+ if normalize:
+ f /= f.sum()
+ if flip_filter:
+ f = f.flip(list(range(f.ndim)))
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(device=device)
+ return f
+
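+# Illustrative usage (added sketch; the values below are chosen for the example only):
+#
+#   f = setup_filter([1, 3, 3, 1])   # numel < 8, so this expands to a 4x4 filter, sum 1
+#   x = torch.randn(1, 3, 64, 64)
+#   y = upsample2d(x, f)             # -> [1, 3, 128, 128]
+#   x2 = downsample2d(y, f)          # -> [1, 3, 64, 64]
+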
+#----------------------------------------------------------------------------
+
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 4. Downsample the image by keeping every Nth pixel (`down`).
+
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+ `[x, y]` (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+
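+# Added note: combining the four steps documented above, each spatial axis obeys
+#     out = (in * up + pad_before + pad_after - filter_size) // down + 1
+# e.g. in=64, up=2, padding=(2, 1), filter_size=4, down=1 gives out=128.
+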
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ assert f.dtype == torch.float32 and not f.requires_grad
+ batch_size, num_channels, in_height, in_width = x.shape
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Upsample by inserting zeros.
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
+
+ # Pad or crop.
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
+
+ # Setup filter.
+ f = f * (gain ** (f.ndim / 2))
+ f = f.to(x.dtype)
+ if not flip_filter:
+ f = f.flip(list(range(f.ndim)))
+
+ # Convolve with the filter.
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
+ if f.ndim == 4:
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
+ else:
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
+
+ # Downsample by throwing away pixels.
+ x = x[:, :, ::downy, ::downx]
+ return x
+
+#----------------------------------------------------------------------------
+
+_upfirdn2d_cuda_cache = dict()
+
+def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
+ """
+ # Parse arguments.
+ upx, upy = _parse_scaling(up)
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+
+ # Lookup from cache.
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ if key in _upfirdn2d_cuda_cache:
+ return _upfirdn2d_cuda_cache[key]
+
+ # Forward op.
+ class Upfirdn2dCuda(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ if f is None:
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+ y = x
+ if f.ndim == 2:
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
+ else:
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
+ ctx.save_for_backward(f)
+ ctx.x_shape = x.shape
+ return y
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ f, = ctx.saved_tensors
+ _, _, ih, iw = ctx.x_shape
+ _, _, oh, ow = dy.shape
+ fw, fh = _get_filter_size(f)
+ p = [
+ fw - padx0 - 1,
+ iw * upx - ow * downx + padx0 - upx + 1,
+ fh - pady0 - 1,
+ ih * upy - oh * downy + pady0 - upy + 1,
+ ]
+ dx = None
+ df = None
+
+ if ctx.needs_input_grad[0]:
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
+
+ assert not ctx.needs_input_grad[1]
+ return dx, df
+
+ # Add to cache.
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
+ return Upfirdn2dCuda
+
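+# Added note: for 1-D (separable) filters the plugin is invoked twice in forward(), once per
+# axis, with the gain split as sqrt(gain) per pass so the total matches the non-separable case.
+# Caching the autograd Function class per (up, down, padding, flip_filter, gain) tuple avoids
+# redefining it on every call.
+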
+#----------------------------------------------------------------------------
+
+def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape matches the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + fw // 2,
+ padx1 + (fw - 1) // 2,
+ pady0 + fh // 2,
+ pady1 + (fh - 1) // 2,
+ ]
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
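+# Added example: with a 3x3 filter and padding=0, p evaluates to [1, 1, 1, 1], so the output of
+# filter2d() has the same height and width as the input.
+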
+#----------------------------------------------------------------------------
+
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a multiple of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ up: Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+ padding: Padding with respect to the output. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ upx, upy = _parse_scaling(up)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw + upx - 1) // 2,
+ padx1 + (fw - upx) // 2,
+ pady0 + (fh + upy - 1) // 2,
+ pady1 + (fh - upy) // 2,
+ ]
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
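+# Added note: gain is multiplied by upx*upy because zero-insertion spreads each input pixel over
+# upx*upy output positions; without the extra gain the upsampled signal's DC magnitude would drop
+# by that factor.
+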
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+ By default, the result is padded so that its shape is a fraction of the input.
+ User-specified padding is applied on top of that, with negative values
+ indicating cropping. Pixels outside the image are assumed to be zero.
+
+ Args:
+ x: Float32/float64/float16 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ f: Float32 FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ down: Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+ padding: Padding with respect to the input. Can be a single number or a
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ flip_filter: False = convolution, True = correlation (default: False).
+ gain: Overall scaling factor for signal magnitude (default: 1).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ downx, downy = _parse_scaling(down)
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
+ fw, fh = _get_filter_size(f)
+ p = [
+ padx0 + (fw - downx + 1) // 2,
+ padx1 + (fw - downx) // 2,
+ pady0 + (fh - downy + 1) // 2,
+ pady1 + (fh - downy) // 2,
+ ]
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
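+# Added example: with a 4-tap filter and down=2, p evaluates to [1, 1, 1, 1] and the output is
+# exactly half the input resolution (e.g. 128 -> 64).
+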
+#----------------------------------------------------------------------------
diff --git a/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e b/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e
new file mode 100644
index 0000000000000000000000000000000000000000..b17574a8a99960be5cef91c3b96a1a3e3e233d79
--- /dev/null
+++ b/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1598f0dc475200dcbd0fcdfccab39c70b6023c4af47052cfa9cf867fa08fe047
+size 173628377
diff --git a/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth b/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3318ad1b61e111b3d24c265e1a84d26ae9420e6c
--- /dev/null
+++ b/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76ea889083f2dddaa87cbaf4cff6bb70e3e34c83fb0ab8de7b2df15eaf78caf1
+size 481004605