diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..3ab2d8ef167e821dadfd2aaec2f962f72567b3a2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN-Human.yaml filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN.yaml filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/model_final_1d3314.pkl filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text +torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text +media2/stylemc_example.jpg filter=lfs diff=lfs merge=lfs -text +media2/erling.jpg filter=lfs diff=lfs merge=lfs -text +media2/g7_leaders.jpg filter=lfs diff=lfs merge=lfs -text +media2/regjeringen.jpg filter=lfs diff=lfs merge=lfs -text +media/ filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a1554df8c50bec1b98e6ee212f338b0bd4275164 --- /dev/null +++ b/.gitignore @@ -0,0 +1,51 @@ +# FILES +*.flist +*.zip +*.out +*.npy +*.gz +*.ckpt +*.log +*.pyc +*.csv +*.yml +*.ods +*.ods# +*.json +build_docker.sh + +# Images / Videos +#*.png +#*.jpg +*.jpeg +*.m4a +*.mkv +*.mp4 + +# Directories created by inpaintron +.cache/ +test_examples/ +.vscode +__pycache__ +.debug/ +**/.ipynb_checkpoints/** +outputs/ + + +# From pip setup +build/ +*.egg-info +*.egg +.npm/ + +# From dockerfile +.bash_history +.viminfo +.local/ +*.pickle +*.onnx + + +sbatch_files/ +figures/ +image_dump/ \ No newline at end of file diff --git a/README.md b/README.md index 31e59290884a816c59fc4698bd93daa859fa24bc..6ea0221438aab42ad5a2c6c88a1a0299e60a19bc 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- -title: Deep Privacy2 Face -emoji: 👀 -colorFrom: purple +title: Deep Privacy2 +emoji: 📈 +colorFrom: gray colorTo: indigo sdk: gradio -sdk_version: 3.23.0 +sdk_version: 3.9.1 app_file: app.py pinned: false --- diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f0901281b5acda7b9d180302c4581d175d523d --- /dev/null +++ b/app.py @@ -0,0 +1,31 @@ +import gradio +import os +from tops.config import instantiate +import gradio.inputs +os.system("pip install --upgrade pip") +os.system("pip install ftfy regex tqdm") +os.system("pip install --no-deps git+https://github.com/openai/CLIP.git") +os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose") +os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference") +os.environ["TORCH_HOME"] = "torch_home" +from dp2 import utils +from gradio_demos.modules import ExampleDemo, WebcamDemo + +cfg_face = utils.load_config("configs/anonymizers/face.py") + +anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False) + 
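+# Note (descriptive comment): initialize_tracker sets up a motpy MultiObjectTracker so each tracked identity keeps the same latent code across webcam frames (see dp2/anonymizer/anonymizer.py below).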
+anonymizer_face.initialize_tracker(fps=1)
+
+
+with gradio.Blocks() as demo:
+    gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
+    gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
+    gradio.Markdown("<center> See more information at: https://github.com/hukkelas/deep_privacy2 </center>")
+    gradio.Markdown("<center> For a demo of face anonymization, see: https://huggingface.co/spaces/haakohu/deep_privacy2_face </center>
") + with gradio.Tab("Face Anonymization"): + ExampleDemo(anonymizer_face) + with gradio.Tab("Live Webcam"): + WebcamDemo(anonymizer_face) + +demo.launch() diff --git a/configs/anonymizers/FB_cse.py b/configs/anonymizers/FB_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..ff44a8ef9da980d545b09609de82e071edc912ac --- /dev/null +++ b/configs/anonymizers/FB_cse.py @@ -0,0 +1,28 @@ +from dp2.anonymizer import Anonymizer +from dp2.detection.person_detector import CSEPersonDetector +from ..defaults import common +from tops.config import LazyCall as L +from dp2.generator.dummy_generators import MaskOutGenerator + + +maskout_G = L(MaskOutGenerator)(noise="constant") + +detector = L(CSEPersonDetector)( + mask_rcnn_cfg=dict(), + cse_cfg=dict(), + cse_post_process_cfg=dict( + target_imsize=(288, 160), + exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), + exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), + iou_combine_threshold=0.4, + dilation_percentage=0.02, + normalize_embedding=False + ), + score_threshold=0.3, + cache_directory=common.output_dir.joinpath("cse_person_detection_cache") +) + +anonymizer = L(Anonymizer)( + detector="${detector}", + cse_person_G_cfg="configs/fdh/styleganL.py", +) diff --git a/configs/anonymizers/FB_cse_mask.py b/configs/anonymizers/FB_cse_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5e3bfbefad8e1d6e480fa22256aff0f9647b35 --- /dev/null +++ b/configs/anonymizers/FB_cse_mask.py @@ -0,0 +1,29 @@ +from dp2.anonymizer import Anonymizer +from dp2.detection.person_detector import CSEPersonDetector +from ..defaults import common +from tops.config import LazyCall as L +from dp2.generator.dummy_generators import MaskOutGenerator + + +maskout_G = L(MaskOutGenerator)(noise="constant") + +detector = L(CSEPersonDetector)( + mask_rcnn_cfg=dict(), + cse_cfg=dict(), + cse_post_process_cfg=dict( + target_imsize=(288, 160), + exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), + exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), + iou_combine_threshold=0.4, + dilation_percentage=0.02, + normalize_embedding=False + ), + score_threshold=0.3, + cache_directory=common.output_dir.joinpath("cse_person_detection_cache") +) + +anonymizer = L(Anonymizer)( + detector="${detector}", + person_G_cfg="configs/fdh/styleganL_nocse.py", + cse_person_G_cfg="configs/fdh/styleganL.py", +) diff --git a/configs/anonymizers/FB_cse_mask_face.py b/configs/anonymizers/FB_cse_mask_face.py new file mode 100644 index 0000000000000000000000000000000000000000..d411d66cc051f6b4c0d907551735e8f661cf17f1 --- /dev/null +++ b/configs/anonymizers/FB_cse_mask_face.py @@ -0,0 +1,29 @@ +from dp2.anonymizer import Anonymizer +from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector +from ..defaults import common +from tops.config import LazyCall as L + +detector = L(CSeMaskFaceDetector)( + mask_rcnn_cfg=dict(), + face_detector_cfg=dict(), + face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False), + cse_cfg=dict(), + cse_post_process_cfg=dict( + target_imsize=(288, 160), + exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1), + exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]), + iou_combine_threshold=0.4, + dilation_percentage=0.02, + normalize_embedding=False + ), + score_threshold=0.3, + 
cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache") +) + +anonymizer = L(Anonymizer)( + detector="${detector}", + face_G_cfg="configs/fdf/stylegan.py", + person_G_cfg="configs/fdh/styleganL_nocse.py", + cse_person_G_cfg="configs/fdh/styleganL.py", + car_G_cfg="configs/generators/dummy/pixelation8.py" +) diff --git a/configs/anonymizers/deep_privacy1.py b/configs/anonymizers/deep_privacy1.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf116cefdbe716a1f9ba56b7f55d5949560cfbc --- /dev/null +++ b/configs/anonymizers/deep_privacy1.py @@ -0,0 +1,15 @@ +from .face_fdf128 import anonymizer, common, detector +from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector +from tops.config import LazyCall as L + +anonymizer.update( + face_G_cfg="configs/fdf/deep_privacy1.py", +) + +anonymizer.detector = L(DeepPrivacy1Detector)( + face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True), + face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True), + score_threshold=0.3, + keypoint_threshold=0.3, + cache_directory=common.output_dir.joinpath("deep_privacy1_cache") +) diff --git a/configs/anonymizers/face.py b/configs/anonymizers/face.py new file mode 100644 index 0000000000000000000000000000000000000000..1eed93b812de5166ecddce94e36a4cb1cf4777d8 --- /dev/null +++ b/configs/anonymizers/face.py @@ -0,0 +1,17 @@ +from dp2.anonymizer import Anonymizer +from dp2.detection.face_detector import FaceDetector +from ..defaults import common +from tops.config import LazyCall as L + + +detector = L(FaceDetector)( + face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True), + face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False), + score_threshold=0.3, + cache_directory=common.output_dir.joinpath("face_detection_cache"), +) + +anonymizer = L(Anonymizer)( + detector="${detector}", + face_G_cfg="configs/fdf/stylegan.py", +) diff --git a/configs/anonymizers/face_fdf128.py b/configs/anonymizers/face_fdf128.py new file mode 100644 index 0000000000000000000000000000000000000000..327b7f5c5b2711bb59eb13489b44ad8a3c0f5f57 --- /dev/null +++ b/configs/anonymizers/face_fdf128.py @@ -0,0 +1,18 @@ +from dp2.anonymizer import Anonymizer +from dp2.detection.face_detector import FaceDetector +from ..defaults import common +from tops.config import LazyCall as L + + +detector = L(FaceDetector)( + face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True), + face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True), + score_threshold=0.3, + cache_directory=common.output_dir.joinpath("face_detection_cache") +) + + +anonymizer = L(Anonymizer)( + detector="${detector}", + face_G_cfg="configs/fdf/stylegan_fdf128.py", +) diff --git a/configs/anonymizers/market1501/blackout.py b/configs/anonymizers/market1501/blackout.py new file mode 100644 index 0000000000000000000000000000000000000000..14da21e3c4b367a942f9a99796a1d9996b773522 --- /dev/null +++ b/configs/anonymizers/market1501/blackout.py @@ -0,0 +1,8 @@ +from ..FB_cse_mask_face import anonymizer, detector, common + +detector.score_threshold = .1 +detector.face_detector_cfg.confidence_threshold = .5 +detector.cse_cfg.score_thres = 0.3 +anonymizer.generators.face_G_cfg = None +anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py" +anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py" \ No newline at end of file diff --git a/configs/anonymizers/market1501/person.py b/configs/anonymizers/market1501/person.py new file mode 100644 
index 0000000000000000000000000000000000000000..51fa99b21f068ce68f796fd32c85d37d9a22bec1 --- /dev/null +++ b/configs/anonymizers/market1501/person.py @@ -0,0 +1,6 @@ +from ..FB_cse_mask_face import anonymizer, detector, common + +detector.score_threshold = .1 +detector.face_detector_cfg.confidence_threshold = .5 +detector.cse_cfg.score_thres = 0.3 +anonymizer.generators.face_G_cfg = None \ No newline at end of file diff --git a/configs/anonymizers/market1501/pixelation16.py b/configs/anonymizers/market1501/pixelation16.py new file mode 100644 index 0000000000000000000000000000000000000000..2569fc2abb91919f91dd12546c06a86624d235fc --- /dev/null +++ b/configs/anonymizers/market1501/pixelation16.py @@ -0,0 +1,8 @@ +from ..FB_cse_mask_face import anonymizer, detector, common + +detector.score_threshold = .1 +detector.face_detector_cfg.confidence_threshold = .5 +detector.cse_cfg.score_thres = 0.3 +anonymizer.generators.face_G_cfg = None +anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py" +anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py" \ No newline at end of file diff --git a/configs/anonymizers/market1501/pixelation8.py b/configs/anonymizers/market1501/pixelation8.py new file mode 100644 index 0000000000000000000000000000000000000000..ef49cb613d09e972adf7b8136b632eb210420686 --- /dev/null +++ b/configs/anonymizers/market1501/pixelation8.py @@ -0,0 +1,8 @@ +from ..FB_cse_mask_face import anonymizer, detector, common + +detector.score_threshold = .1 +detector.face_detector_cfg.confidence_threshold = .5 +detector.cse_cfg.score_thres = 0.3 +anonymizer.generators.face_G_cfg = None +anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py" +anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py" \ No newline at end of file diff --git a/configs/datasets/coco_cse.py b/configs/datasets/coco_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..f4834b64ebb77634ce5f1fb66a3cf9abafe26464 --- /dev/null +++ b/configs/datasets/coco_cse.py @@ -0,0 +1,69 @@ +import os +from pathlib import Path +from tops.config import LazyCall as L +import torch +import functools +from dp2.data.datasets.coco_cse import CocoCSE +from dp2.data.build import get_dataloader +from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip +from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe +from dp2.metrics.torch_metrics import compute_metrics_iteratively +from .utils import final_eval_fn + + +dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" +metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" +data_dir = Path(dataset_base_dir, "coco_cse") +data = dict( + imsize=(288, 160), + im_channels=3, + semantic_nc=26, + cse_nc=16, + train=dict( + dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False), + loader=L(get_dataloader)( + shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2, + batch_size="${train.batch_size}", + dataset="${..dataset}", + infinite=True, + gpu_transform=L(torch.nn.Sequential)(*[ + L(ToFloat)(), + L(StyleGANAugmentPipe)( + rotate=0.5, rotate_max=.05, + xint=.5, xint_max=0.05, + scale=.5, scale_std=.05, + aniso=0.5, aniso_std=.05, + xfrac=.5, xfrac_std=.05, + brightness=.5, brightness_std=.05, + contrast=.5, contrast_std=.1, + hue=.5, hue_max=.05, + saturation=.5, 
saturation_std=.5, + imgfilter=.5, imgfilter_std=.1), + L(RandomHorizontalFlip)(p=0.5), + L(CreateEmbedding)(), + L(Resize)(size="${data.imsize}"), + L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), + L(CreateCondition)(), + ]) + ) + ), + val=dict( + dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False), + loader=L(get_dataloader)( + shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2, + batch_size="${train.batch_size}", + dataset="${..dataset}", + infinite=False, + gpu_transform=L(torch.nn.Sequential)(*[ + L(ToFloat)(), + L(CreateEmbedding)(), + L(Resize)(size="${data.imsize}"), + L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), + L(CreateCondition)(), + ]) + ) + ), + # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. + train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False), + evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True) +) diff --git a/configs/datasets/fdf128.py b/configs/datasets/fdf128.py new file mode 100644 index 0000000000000000000000000000000000000000..c19fc0d30e2a2aa1b2ac04080a48479f1d190267 --- /dev/null +++ b/configs/datasets/fdf128.py @@ -0,0 +1,24 @@ +from pathlib import Path +from functools import partial +from dp2.data.datasets.fdf import FDFDataset +from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn + +data_dir = Path(dataset_base_dir, "fdf") +data.train.dataset.dirpath = data_dir.joinpath("train") +data.val.dataset.dirpath = data_dir.joinpath("val") +data.imsize = (128, 128) + + +data.train_evaluation_fn = partial( + train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train")) +data.evaluation_fn = partial( + final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final")) + +data.train.dataset.update( + _target_ = FDFDataset, + imsize="${data.imsize}" +) +data.val.dataset.update( + _target_ = FDFDataset, + imsize="${data.imsize}" +) \ No newline at end of file diff --git a/configs/datasets/fdf256.py b/configs/datasets/fdf256.py new file mode 100644 index 0000000000000000000000000000000000000000..828802db23ca519cb13963a729c597e39ae8dee8 --- /dev/null +++ b/configs/datasets/fdf256.py @@ -0,0 +1,55 @@ +import os +from pathlib import Path +from tops.config import LazyCall as L +import torch +import functools +from dp2.data.datasets.fdf import FDF256Dataset +from dp2.data.build import get_dataloader +from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip +from .utils import final_eval_fn, train_eval_fn + + +dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" +metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" +data_dir = Path(dataset_base_dir, "fdf256") +data = dict( + imsize=(256, 256), + im_channels=3, + semantic_nc=None, + cse_nc=None, + n_keypoints=None, + train=dict( + dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False), + loader=L(get_dataloader)( + shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2, + batch_size="${train.batch_size}", + dataset="${..dataset}", + infinite=True, + gpu_transform=L(torch.nn.Sequential)(*[ + L(ToFloat)(), + L(RandomHorizontalFlip)(p=0.5), + L(Resize)(size="${data.imsize}"), + L(Normalize)(mean=[.5, .5, 
.5], std=[.5, .5, .5], inplace=True), + L(CreateCondition)(), + ]) + ) + ), + val=dict( + dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False), + loader=L(get_dataloader)( + shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2, + batch_size="${train.batch_size}", + dataset="${..dataset}", + infinite=False, + gpu_transform=L(torch.nn.Sequential)(*[ + L(ToFloat)(), + L(Resize)(size="${data.imsize}"), + L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True), + L(CreateCondition)(), + ]) + ) + ), + # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. + train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")), + evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val")) +) \ No newline at end of file diff --git a/configs/datasets/fdh.py b/configs/datasets/fdh.py new file mode 100644 index 0000000000000000000000000000000000000000..47faade58b49ef7fdf227b2b54fe89d1650080f6 --- /dev/null +++ b/configs/datasets/fdh.py @@ -0,0 +1,90 @@ +import os +from pathlib import Path +from tops.config import LazyCall as L +import torch +import functools +from dp2.data.datasets.fdh import get_dataloader_fdh_wds +from dp2.data.utils import get_coco_flipmap +from dp2.data.transforms.transforms import ( + Normalize, + ToFloat, + CreateCondition, + RandomHorizontalFlip, + CreateEmbedding, +) +from dp2.metrics.torch_metrics import compute_metrics_iteratively +from dp2.metrics.fid_clip import compute_fid_clip +from dp2.metrics.ppl import calculate_ppl +from .utils import train_eval_fn + + +def final_eval_fn(*args, **kwargs): + result = compute_metrics_iteratively(*args, **kwargs) + result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160)) + result3 = compute_fid_clip(*args, **kwargs) + assert all(key not in result for key in result2) + result.update(result2) + result.update(result3) + return result + + +def get_cache_directory(imsize, subset): + return Path(metrics_cache, f"{subset}{imsize[0]}") + +dataset_base_dir = ( + os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data" +) +metrics_cache = ( + os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache" +) +data_dir = Path(dataset_base_dir, "fdh") +data = dict( + imsize=(288, 160), + im_channels=3, + cse_nc=16, + n_keypoints=17, + train=dict( + loader=L(get_dataloader_fdh_wds)( + path=data_dir.joinpath("train", "out-{000000..001423}.tar"), + batch_size="${train.batch_size}", + num_workers=6, + transform=L(torch.nn.Sequential)( + L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()), + ), + gpu_transform=L(torch.nn.Sequential)( + L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]), + L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")), + L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True), + L(CreateCondition)(), + ), + infinite=True, + shuffle=True, + partial_batches=False, + load_embedding=True, + keypoints_split="train", + load_new_keypoints=False + ) + ), + val=dict( + loader=L(get_dataloader_fdh_wds)( + path=data_dir.joinpath("val", "out-{000000..000023}.tar"), + batch_size="${train.batch_size}", + num_workers=6, + transform=None, + gpu_transform="${data.train.loader.gpu_transform}", + infinite=False, + shuffle=False, + partial_batches=True, + load_embedding=True, + keypoints_split="val", + 
load_new_keypoints="${data.train.loader.load_new_keypoints}" + ) + ), + # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP. + train_evaluation_fn=L(functools.partial)( + train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"), + data_len=30_000), + evaluation_fn=L(functools.partial)( + final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"), + data_len=30_000) +) diff --git a/configs/datasets/utils.py b/configs/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6704b12a5e379b707b2c6d3dc9f78431ce01e61d --- /dev/null +++ b/configs/datasets/utils.py @@ -0,0 +1,21 @@ +from dp2.metrics.ppl import calculate_ppl +from dp2.metrics.torch_metrics import compute_metrics_iteratively +from dp2.metrics.fid_clip import compute_fid_clip + + +def final_eval_fn(*args, **kwargs): + result = compute_metrics_iteratively(*args, **kwargs) + result2 = calculate_ppl(*args, **kwargs,) + result3 = compute_fid_clip(*args, **kwargs) + assert all(key not in result for key in result2) + result.update(result2) + result.update(result3) + return result + + +def train_eval_fn(*args, **kwargs): + result = compute_metrics_iteratively(*args, **kwargs) + result2 = compute_fid_clip(*args, **kwargs) + assert all(key not in result for key in result2) + result.update(result2) + return result \ No newline at end of file diff --git a/configs/defaults.py b/configs/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..4f831200940c0aa9658bab3a82d6ad5714048d3c --- /dev/null +++ b/configs/defaults.py @@ -0,0 +1,53 @@ +import pathlib +import os +import torch +from tops.config import LazyCall as L + +if "PRETRAINED_CHECKPOINTS_PATH" in os.environ: + PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"]) +else: + PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints") +if "BASE_OUTPUT_DIR" in os.environ: + BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"]) +else: + BASE_OUTPUT_DIR = pathlib.Path("outputs") + + + +common = dict( + logger_backend=["wandb", "stdout", "json", "image_dumper"], + wandb_project="deep_privacy2", + output_dir=BASE_OUTPUT_DIR, + experiment_name=None, # Optional experiment name to show on wandb +) + +train = dict( + batch_size=32, + seed=0, + ims_per_log=1024, + ims_per_val=int(200e3), + max_images_to_train=int(12e6), + amp=dict( + enabled=True, + scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"), + scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"), + ), + fp16_ddp_accumulate=False, # All gather gradients in fp16? 
+ broadcast_buffers=False, + bias_act_plugin_enabled=True, + grid_sample_gradfix_enabled=True, + conv2d_gradfix_enabled=False, + channels_last=False, + compile_G=dict( + enabled=False, + mode="default" # default, reduce-overhead or max-autotune + ), + compile_D=dict( + enabled=False, + mode="default" # default, reduce-overhead or max-autotune + ) +) + +# exponential moving average +EMA = dict(rampup=0.05) + diff --git a/configs/discriminators/sg2_discriminator.py b/configs/discriminators/sg2_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..081ebfb76b330ab99334725b9cc82db06bfb1c5f --- /dev/null +++ b/configs/discriminators/sg2_discriminator.py @@ -0,0 +1,43 @@ +from tops.config import LazyCall as L +from dp2.discriminator import SG2Discriminator +import torch +from dp2.loss import StyleGAN2Loss + + +discriminator = L(SG2Discriminator)( + imsize="${data.imsize}", + im_channels="${data.im_channels}", + min_fmap_resolution=4, + max_cnum_mul=8, + cnum=80, + input_condition=True, + conv_clamp=256, + input_cse=False, + cse_nc="${data.cse_nc}", + fix_residual=False, +) + + +loss_fnc = L(StyleGAN2Loss)( + lazy_regularization=True, + lazy_reg_interval=16, + r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False), + EP_lambd=0.001, + pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01) +) + +def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs): + if lazy_regularization: + # From Analyzing and improving the image quality of stylegan, CVPR 2020 + c = lazy_reg_interval / (lazy_reg_interval + 1) + betas = [beta ** c for beta in betas] + lr *= c + print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}") + return type(lr=lr, betas=betas, **kwargs) + + +D_optim = L(build_D_optim)( + type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99), + lazy_regularization="${loss_fnc.lazy_regularization}", + lazy_reg_interval="${loss_fnc.lazy_reg_interval}") +G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99)) diff --git a/configs/fdf/deep_privacy1.py b/configs/fdf/deep_privacy1.py new file mode 100644 index 0000000000000000000000000000000000000000..88c09bd4b5810182bdd3520dfdf4f98bdcea3829 --- /dev/null +++ b/configs/fdf/deep_privacy1.py @@ -0,0 +1,9 @@ +from tops.config import LazyCall as L +from dp2.generator.deep_privacy1 import MSGGenerator +from ..datasets.fdf128 import data +from ..defaults import common, train + +generator = L(MSGGenerator)() + +common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt" +common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6" \ No newline at end of file diff --git a/configs/fdf/stylegan.py b/configs/fdf/stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..a4da2c3ad76d3d1fb6e1d91e832cde5c735bf32a --- /dev/null +++ b/configs/fdf/stylegan.py @@ -0,0 +1,14 @@ +from ..generators.stylegan_unet import generator +from ..datasets.fdf256 import data +from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc +from ..defaults import train, common, EMA + +train.max_images_to_train = int(35e6) +G_optim.lr = 0.002 +D_optim.lr = 0.002 +generator.input_cse = False +loss_fnc.r1_opts.lambd = 1 +train.ims_per_val = int(2e6) + +common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e" +common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07" \ No 
newline at end of file diff --git a/configs/fdf/stylegan_fdf128.py b/configs/fdf/stylegan_fdf128.py new file mode 100644 index 0000000000000000000000000000000000000000..a47d6d2ee362c935e7879c9442c4dcd9aaf007c0 --- /dev/null +++ b/configs/fdf/stylegan_fdf128.py @@ -0,0 +1,17 @@ +from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc +from ..datasets.fdf128 import data +from ..generators.stylegan_unet import generator +from ..defaults import train, common, EMA +from tops.config import LazyCall as L + +G_optim.lr = 0.002 +D_optim.lr = 0.002 +generator.update(cnum=128, max_cnum_mul=4, input_cse=False) +loss_fnc.r1_opts.lambd = 0.1 + +train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6)) + +common.update( + model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c", + model_md5sum="bccd4403e7c9bca682566ff3319e8176" +) \ No newline at end of file diff --git a/configs/fdh/styleganL.py b/configs/fdh/styleganL.py new file mode 100644 index 0000000000000000000000000000000000000000..48fcf09b43a7141a270fbe5c69bd7932414270fe --- /dev/null +++ b/configs/fdh/styleganL.py @@ -0,0 +1,16 @@ +from tops.config import LazyCall as L +from ..generators.stylegan_unet import generator +from ..datasets.fdh import data +from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc +from ..defaults import train, common, EMA + +train.max_images_to_train = int(50e6) +train.batch_size = 64 +G_optim.lr = 0.002 +D_optim.lr = 0.002 +data.train.loader.num_workers = 4 +train.ims_per_val = int(1e6) +loss_fnc.r1_opts.lambd = .1 + +common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c" +common.model_md5sum = "3411478b5ec600a4219cccf4499732bd" \ No newline at end of file diff --git a/configs/fdh/styleganL_nocse.py b/configs/fdh/styleganL_nocse.py new file mode 100644 index 0000000000000000000000000000000000000000..210fd68743f0b872f89f4407dfaac7c9bf5f0e32 --- /dev/null +++ b/configs/fdh/styleganL_nocse.py @@ -0,0 +1,14 @@ +from tops.config import LazyCall as L +from ..generators.stylegan_unet import generator +from ..datasets.fdh import data +from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc +from ..defaults import train, common, EMA + +train.max_images_to_train = int(50e6) +G_optim.lr = 0.002 +D_optim.lr = 0.002 +generator.input_cse = False +data.load_embeddings = False +common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt" +common.model_md5sum = "fda0d809741bc67487abada793975c37" +generator.fix_errors = False \ No newline at end of file diff --git a/configs/generators/stylegan_unet.py b/configs/generators/stylegan_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..638859263a1cb549f533b75b2b19609665b3443e --- /dev/null +++ b/configs/generators/stylegan_unet.py @@ -0,0 +1,22 @@ +from dp2.generator.stylegan_unet import StyleGANUnet +from tops.config import LazyCall as L + +generator = L(StyleGANUnet)( + imsize="${data.imsize}", + im_channels="${data.im_channels}", + min_fmap_resolution=8, + cnum=64, + max_cnum_mul=8, + n_middle_blocks=0, + z_channels=512, + mask_output=True, + conv_clamp=256, + input_cse=True, + scale_grad=True, + 
cse_nc="${data.cse_nc}", + w_dim=512, + n_keypoints="${data.n_keypoints}", + input_keypoints=False, + input_keypoint_indices=[], + fix_errors=True +) \ No newline at end of file diff --git a/dp2/__init__.py b/dp2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/anonymizer/__init__.py b/dp2/anonymizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32606aa927c8d593d64be02a499fba057b8ba6fa --- /dev/null +++ b/dp2/anonymizer/__init__.py @@ -0,0 +1 @@ +from .anonymizer import Anonymizer diff --git a/dp2/anonymizer/anonymizer.py b/dp2/anonymizer/anonymizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a32e61a122c2e1916c0ce9baffecb095c5a136de --- /dev/null +++ b/dp2/anonymizer/anonymizer.py @@ -0,0 +1,163 @@ +from pathlib import Path +from typing import Union, Optional +import numpy as np +import torch +import tops +import torchvision.transforms.functional as F +from motpy import Detection, MultiObjectTracker +from dp2.utils import load_config +from dp2.infer import build_trained_generator +from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection + + +def load_generator_from_cfg_path(cfg_path: Union[str, Path]): + cfg = load_config(cfg_path) + G = build_trained_generator(cfg) + tops.logger.log(f"Loaded generator from: {cfg_path}") + return G + + +class Anonymizer: + + def __init__( + self, + detector, + load_cache: bool = False, + person_G_cfg: Optional[Union[str, Path]] = None, + cse_person_G_cfg: Optional[Union[str, Path]] = None, + face_G_cfg: Optional[Union[str, Path]] = None, + car_G_cfg: Optional[Union[str, Path]] = None, + ) -> None: + self.detector = detector + self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]} + self.load_cache = load_cache + if cse_person_G_cfg is not None: + self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg) + if person_G_cfg is not None: + self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg) + if face_G_cfg is not None: + self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg) + if car_G_cfg is not None: + self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg) + + def initialize_tracker(self, fps: float): + self.tracker = MultiObjectTracker(dt=1/fps) + self.track_to_z_idx = dict() + + def reset_tracker(self): + self.track_to_z_idx = dict() + + def forward_G(self, + G, + batch, + multi_modal_truncation: bool, + amp: bool, + z_idx: int, + truncation_value: float, + idx: int, + all_styles=None): + batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255]) + batch["img"] = batch["img"].float() + batch["condition"] = batch["mask"].float() * batch["img"] + + with torch.cuda.amp.autocast(amp): + z = None + if z_idx is not None: + state = np.random.RandomState(seed=z_idx[idx]) + z = state.normal(size=(1, G.z_channels)).astype(np.float32) + z = tops.to_cuda(torch.from_numpy(z)) + + if all_styles is not None: + anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"] + elif multi_modal_truncation: + w_indices = None + if z_idx is not None: + w_indices = [z_idx[idx] % len(G.style_net.w_centers)] + anonymized_im = G.multi_modal_truncate( + **batch, truncation_value=truncation_value, + w_indices=w_indices, + z=z + )["img"] + else: + anonymized_im = G.sample(**batch, 
truncation_value=truncation_value, z=z)["img"] + anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255) + return anonymized_im + + @torch.no_grad() + def anonymize_detections(self, + im, detection, + update_identity=None, + **synthesis_kwargs + ): + G = self.generators[type(detection)] + if G is None: + return im + C, H, W = im.shape + if update_identity is None: + update_identity = [True for i in range(len(detection))] + for idx in range(len(detection)): + if not update_identity[idx]: + continue + batch = detection.get_crop(idx, im) + x0, y0, x1, y1 = batch.pop("boxes")[0] + batch = {k: tops.to_cuda(v) for k, v in batch.items()} + anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx) + + gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True) + mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0) + # Remove padding + pad = [max(-x0, 0), max(-y0, 0)] + pad = [*pad, max(x1-W, 0), max(y1-H, 0)] + def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]] + + gim = remove_pad(gim) + mask = remove_pad(mask) > 0.5 + x0, y0 = max(x0, 0), max(y0, 0) + x1, y1 = min(x1, W), min(y1, H) + mask = mask.logical_not()[None].repeat(3, 1, 1) + + im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte() + return im + + def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor: + all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache) + im = im.cpu() + for det in all_detections: + im = det.visualize(im) + return im + + @torch.no_grad() + def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor: + assert im.dtype == torch.uint8 + im = tops.to_cuda(im) + all_detections = detections + if detections is None: + if self.load_cache: + all_detections = self.detector.forward_and_cache(im, cache_id) + else: + all_detections = self.detector(im) + if hasattr(self, "tracker") and track: + [_.pre_process() for _ in all_detections] + boxes = np.concatenate([_.boxes for _ in all_detections]) + boxes = [Detection(box) for box in boxes] + self.tracker.step(boxes) + track_ids = self.tracker.detections_matched_ids + z_idx = [] + for track_id in track_ids: + if track_id not in self.track_to_z_idx: + self.track_to_z_idx[track_id] = np.random.randint(0, 2**32-1) + z_idx.append(self.track_to_z_idx[track_id]) + z_idx = np.array(z_idx) + idx_offset = 0 + + for detection in all_detections: + zs = None + if hasattr(self, "tracker") and track: + zs = z_idx[idx_offset:idx_offset+len(detection)] + idx_offset += len(detection) + im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs) + + return im.cpu() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) diff --git a/dp2/anonymizer/histogram_match_anonymizers.py b/dp2/anonymizer/histogram_match_anonymizers.py new file mode 100644 index 0000000000000000000000000000000000000000..421c80d5624b113afdf9aa4908b5c9cd0ba33c94 --- /dev/null +++ b/dp2/anonymizer/histogram_match_anonymizers.py @@ -0,0 +1,93 @@ + +import torch +import tops +import numpy as np +from kornia.color import rgb_to_hsv +from dp2 import utils +from kornia.enhance import histogram +from .anonymizer import Anonymizer +import torchvision.transforms.functional as F +from skimage.exposure import match_histograms +from kornia.filters import gaussian_blur2d + + +class 
LatentHistogramMatchAnonymizer(Anonymizer): + + def forward_G( + self, + G, + batch, + multi_modal_truncation: bool, + amp: bool, + z_idx: int, + truncation_value: float, + idx: int, + n_sampling_steps: int = 1, + all_styles=None, + ): + batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255]) + batch["img"] = batch["img"].float() + batch["condition"] = batch["mask"].float() * batch["img"] + + assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1." + real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"])) + real_hls[:, 0] /= 2 * torch.pi + indices = [1, 2] + hist_kwargs = dict( + bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()), + bandwidth=torch.tensor(1., device=tops.get_device())) + real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices] + for j in range(n_sampling_steps): + if j == 0: + if multi_modal_truncation: + w = G.style_net.multi_modal_truncate( + truncation_value=truncation_value, **batch, w_indices=None).detach() + else: + w = G.style_net.get_truncated(truncation_value, **batch).detach() + assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1." + w.requires_grad = True + optim = torch.optim.Adam([w]) + with torch.set_grad_enabled(True): + with torch.cuda.amp.autocast(amp): + anonymized_im = G(**batch, truncation_value=None, w=w)["img"] + fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5) + fake_hls[:, 0] /= 2 * torch.pi + fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices] + dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)]) + dist.backward() + if w.grad.sum() == 0: + break + assert w.grad.sum() != 0 + optim.step() + optim.zero_grad() + if dist < 0.02: + break + anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255) + return anonymized_im + + +class HistogramMatchAnonymizer(Anonymizer): + + def forward_G(self, batch, *args, **kwargs): + rimg = batch["img"] + batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255]) + batch["img"] = batch["img"].float() + batch["condition"] = batch["mask"].float() * batch["img"] + + anonymized_im = super().forward_G(batch, *args, **kwargs) + + equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg)) + if equalized_gim.dtype != np.uint8: + equalized_gim = equalized_gim.astype(np.float32) + assert equalized_gim.dtype == np.float32, equalized_gim.dtype + equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0] + else: + equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0] + equalized_gim = equalized_gim.to(device=rimg.device) + assert equalized_gim.dtype == torch.float32 + gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float() + + gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0] + gaussian_mask = gaussian_mask / gaussian_mask.max() + anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im + return anonymized_im diff --git a/dp2/data/__init__.py b/dp2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/data/build.py b/dp2/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..ceab946b4da20467f879f3c6af0e9eb985465ac4 --- /dev/null 
+++ b/dp2/data/build.py @@ -0,0 +1,40 @@ +import torch +import tops +from .utils import collate_fn + + +def get_dataloader( + dataset, gpu_transform: torch.nn.Module, + num_workers, + batch_size, + infinite: bool, + drop_last: bool, + prefetch_factor: int, + shuffle, + channels_last=False + ): + sampler = None + dl_kwargs = dict( + pin_memory=True, + ) + if infinite: + sampler = tops.InfiniteSampler( + dataset, rank=tops.rank(), + num_replicas=tops.world_size(), + shuffle=shuffle + ) + elif tops.world_size() > 1: + sampler = torch.utils.data.DistributedSampler( + dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank()) + dl_kwargs["drop_last"] = drop_last + else: + dl_kwargs["shuffle"] = shuffle + dl_kwargs["drop_last"] = drop_last + dataloader = torch.utils.data.DataLoader( + dataset, sampler=sampler, collate_fn=collate_fn, + batch_size=batch_size, + num_workers=num_workers, prefetch_factor=prefetch_factor, + **dl_kwargs + ) + dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last) + return dataloader diff --git a/dp2/data/datasets/__init__.py b/dp2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/data/datasets/coco_cse.py b/dp2/data/datasets/coco_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..b240932da41b2db03c3807830935a78b50f84c4f --- /dev/null +++ b/dp2/data/datasets/coco_cse.py @@ -0,0 +1,68 @@ +import pickle +import torchvision +import torch +import pathlib +import numpy as np +from typing import Callable, Optional, Union +from torch.hub import get_dir as get_hub_dir + + +def cache_embed_stats(embed_map: torch.Tensor): + mean = embed_map.mean(dim=0, keepdim=True) + rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + + cache = dict(mean=mean, rstd=rstd, embed_map=embed_map) + path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch") + path.parent.mkdir(exist_ok=True, parents=True) + torch.save(cache, path) + + +class CocoCSE(torch.utils.data.Dataset): + + def __init__(self, + dirpath: Union[str, pathlib.Path], + transform: Optional[Callable], + normalize_E: bool,): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + + self.transform = transform + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + self.image_paths, self.embedding_paths = self._load_impaths() + self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) + mean = self.embed_map.mean(dim=0, keepdim=True) + rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + self.embed_map = (self.embed_map - mean) * rstd + cache_embed_stats(self.embed_map) + + def _load_impaths(self): + image_dir = self.dirpath.joinpath("images") + image_paths = list(image_dir.glob("*.png")) + image_paths.sort() + embedding_paths = [ + self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths + ] + return image_paths, embedding_paths + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + im = torchvision.io.read_image(str(self.image_paths[idx])) + vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) + vertices = torch.from_numpy(vertices.squeeze()).long() + mask = torch.from_numpy(mask.squeeze()).float() + border = torch.from_numpy(border.squeeze()).float() + E_mask = 1 - mask - border + batch = { + "img": im, + "vertices": vertices[None], + "mask": mask[None], + "embed_map": 
self.embed_map, + "border": border[None], + "E_mask": E_mask[None] + } + if self.transform is None: + return batch + return self.transform(batch) diff --git a/dp2/data/datasets/fdf.py b/dp2/data/datasets/fdf.py new file mode 100644 index 0000000000000000000000000000000000000000..23f68a52d4fb50143b2ef6720e126991b2981afc --- /dev/null +++ b/dp2/data/datasets/fdf.py @@ -0,0 +1,128 @@ +import pathlib +from typing import Tuple +import numpy as np +import torch +import pathlib +try: + import pyspng + PYSPNG_IMPORTED = True +except ImportError: + PYSPNG_IMPORTED = False + print("Could not load pyspng. Defaulting to pillow image backend.") + from PIL import Image +from tops import logger + + +class FDFDataset: + + def __init__(self, + dirpath, + imsize: Tuple[int], + load_keypoints: bool, + transform): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + self.transform = transform + self.imsize = imsize[0] + self.load_keypoints = load_keypoints + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + image_dir = self.dirpath.joinpath("images", str(self.imsize)) + self.image_paths = list(image_dir.glob("*.png")) + assert len(self.image_paths) > 0,\ + f"Did not find images in: {image_dir}" + self.image_paths.sort(key=lambda x: int(x.stem)) + self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) + + self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch")) + assert len(self.image_paths) == len(self.bounding_boxes) + assert len(self.image_paths) == len(self.landmarks) + logger.log( + f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}") + + def get_mask(self, idx): + mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool) + bounding_box = self.bounding_boxes[idx] + x0, y0, x1, y1 = bounding_box + mask[:, y0:y1, x0:x1] = 0 + return mask + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, index): + impath = self.image_paths[index] + if PYSPNG_IMPORTED: + with open(impath, "rb") as fp: + im = pyspng.load(fp.read()) + else: + with Image.open(impath) as fp: + im = np.array(fp) + im = torch.from_numpy(np.rollaxis(im, -1, 0)) + masks = self.get_mask(index) + landmark = self.landmarks[index] + batch = { + "img": im, + "mask": masks, + } + if self.load_keypoints: + batch["keypoints"] = landmark + if self.transform is None: + return batch + return self.transform(batch) + + +class FDF256Dataset: + + def __init__(self, + dirpath, + load_keypoints: bool, + transform): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + self.transform = transform + self.load_keypoints = load_keypoints + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + image_dir = self.dirpath.joinpath("images") + self.image_paths = list(image_dir.glob("*.png")) + assert len(self.image_paths) > 0,\ + f"Did not find images in: {image_dir}" + self.image_paths.sort(key=lambda x: int(x.stem)) + self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) + self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy"))) + assert len(self.image_paths) == len(self.bounding_boxes) + assert len(self.image_paths) == len(self.landmarks) + logger.log( + f"Dataset loaded from: {dirpath}. 
Number of samples:{len(self)}") + + def get_mask(self, idx): + mask = torch.ones((1, 256, 256), dtype=torch.bool) + bounding_box = self.bounding_boxes[idx] + x0, y0, x1, y1 = bounding_box + mask[:, y0:y1, x0:x1] = 0 + return mask + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, index): + impath = self.image_paths[index] + if PYSPNG_IMPORTED: + with open(impath, "rb") as fp: + im = pyspng.load(fp.read()) + else: + with Image.open(impath) as fp: + im = np.array(fp) + im = torch.from_numpy(np.rollaxis(im, -1, 0)) + masks = self.get_mask(index) + landmark = self.landmarks[index] + batch = { + "img": im, + "mask": masks, + } + if self.load_keypoints: + batch["keypoints"] = landmark + if self.transform is None: + return batch + return self.transform(batch) diff --git a/dp2/data/datasets/fdf128_wds.py b/dp2/data/datasets/fdf128_wds.py new file mode 100644 index 0000000000000000000000000000000000000000..5af477de9b1d2e2670bfae24930c9e698e95b975 --- /dev/null +++ b/dp2/data/datasets/fdf128_wds.py @@ -0,0 +1,96 @@ +import torch +import tops +import numpy as np +import io +import webdataset as wds +import os +from ..utils import png_decoder, get_num_workers, collate_fn + + +def kp_decoder(x): + # Keypoints are between [0, 1] for webdataset + keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1) + keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1) + return keypoints + + +def bbox_decoder(x): + return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4) + + +class BBoxToMask: + + def __call__(self, sample): + imsize = sample["image.png"].shape[-1] + bbox = sample["bounding_box.npy"] * imsize + x0, y0, x1, y1 = np.round(bbox).astype(np.int64) + mask = torch.ones((1, imsize, imsize), dtype=torch.bool) + mask[:, y0:y1, x0:x1] = 0 + sample["mask"] = mask + return sample + + +def get_dataloader_fdf_wds( + path, + batch_size: int, + num_workers: int, + transform: torch.nn.Module, + gpu_transform: torch.nn.Module, + infinite: bool, + shuffle: bool, + partial_batches: bool, + sample_shuffle=10_000, + tar_shuffle=100, + channels_last=False, + ): + # Need to set this for split_by_node to work. 
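+    # webdataset's split_by_node shards the tar files across distributed ranks; it relies on the RANK and WORLD_SIZE environment variables set below.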
+ os.environ["RANK"] = str(tops.rank()) + os.environ["WORLD_SIZE"] = str(tops.world_size()) + if infinite: + pipeline = [wds.ResampledShards(str(path))] + else: + pipeline = [wds.SimpleShardList(str(path))] + if shuffle: + pipeline.append(wds.shuffle(tar_shuffle)) + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + ]) + if shuffle: + pipeline.append(wds.shuffle(sample_shuffle)) + + decoder = [ + wds.handle_extension("image.png", png_decoder), + wds.handle_extension("keypoints.npy", kp_decoder), + ] + + rename_keys = [ + ["img", "image.png"], + ["keypoints", "keypoints.npy"], + ["__key__", "__key__"], + ["mask", "mask"] + ] + + pipeline.extend([ + wds.tarfile_to_samples(), + wds.decode(*decoder), + ]) + pipeline.append(wds.map(BBoxToMask())) + pipeline.extend([ + wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), + wds.rename_keys(*rename_keys), + ]) + + if transform is not None: + pipeline.append(wds.map(transform)) + pipeline = wds.DataPipeline(*pipeline) + if infinite: + pipeline = pipeline.repeat(nepochs=1000000) + + loader = wds.WebLoader( + pipeline, batch_size=None, shuffle=False, + num_workers=get_num_workers(num_workers), + persistent_workers=True, + ) + loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) + return loader diff --git a/dp2/data/datasets/fdh.py b/dp2/data/datasets/fdh.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5293b42874c644da4622687f407c069a0a8e07 --- /dev/null +++ b/dp2/data/datasets/fdh.py @@ -0,0 +1,142 @@ +import torch +import tops +import numpy as np +import io +import webdataset as wds +import os +import json +from pathlib import Path +from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn + + +def kp_decoder(x): + # Keypoints are between [0, 1] for webdataset + keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float() + def check_outside(x): return (x < 0).logical_or(x > 1) + is_outside = check_outside(keypoints[:, 0]).logical_or( + check_outside(keypoints[:, 1]) + ) + keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not()) + return keypoints + + +def vertices_decoder(x): + vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32)) + return vertices.squeeze()[None] + + +class InsertNewKeypoints: + + def __init__(self, keypoints_path: Path) -> None: + with open(keypoints_path, "r") as fp: + self.keypoints = json.load(fp) + + def __call__(self, sample): + key = sample["__key__"] + keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32) + def check_outside(x): return (x < 0).logical_or(x > 1) + is_outside = check_outside(keypoints[:, 0]).logical_or( + check_outside(keypoints[:, 1]) + ) + keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not()) + + sample["keypoints.npy"] = keypoints + return sample + + +def get_dataloader_fdh_wds( + path, + batch_size: int, + num_workers: int, + transform: torch.nn.Module, + gpu_transform: torch.nn.Module, + infinite: bool, + shuffle: bool, + partial_batches: bool, + load_embedding: bool, + sample_shuffle=10_000, + tar_shuffle=100, + read_condition=False, + channels_last=False, + load_new_keypoints=False, + keypoints_split=None, + ): + # Need to set this for split_by_node to work. 
+ os.environ["RANK"] = str(tops.rank()) + os.environ["WORLD_SIZE"] = str(tops.world_size()) + if infinite: + pipeline = [wds.ResampledShards(str(path))] + else: + pipeline = [wds.SimpleShardList(str(path))] + if shuffle: + pipeline.append(wds.shuffle(tar_shuffle)) + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + ]) + if shuffle: + pipeline.append(wds.shuffle(sample_shuffle)) + + decoder = [ + wds.handle_extension("image.png", png_decoder), + wds.handle_extension("mask.png", mask_decoder), + wds.handle_extension("maskrcnn_mask.png", mask_decoder), + wds.handle_extension("keypoints.npy", kp_decoder), + ] + + rename_keys = [ + ["img", "image.png"], ["mask", "mask.png"], + ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"], + ["__key__", "__key__"] + ] + if load_embedding: + decoder.extend([ + wds.handle_extension("vertices.npy", vertices_decoder), + wds.handle_extension("E_mask.png", mask_decoder) + ]) + rename_keys.extend([ + ["vertices", "vertices.npy"], + ["E_mask", "e_mask.png"] + ]) + + if read_condition: + decoder.append( + wds.handle_extension("condition.png", png_decoder) + ) + rename_keys.append(["condition", "condition.png"]) + + pipeline.extend([ + wds.tarfile_to_samples(), + wds.decode(*decoder), + + ]) + if load_new_keypoints: + assert keypoints_split in ["train", "val"] + keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f" + file_name = "fdh_keypoints_val-050133b34d.json" + if keypoints_split == "train": + keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58" + file_name = "fdh_keypoints_train-2cff11f69a.json" + # Set check_hash=True if you suspect download is incorrect. 
+ filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False) + pipeline.append( + wds.map(InsertNewKeypoints(filepath)) + ) + pipeline.extend([ + wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), + wds.rename_keys(*rename_keys), + ]) + + if transform is not None: + pipeline.append(wds.map(transform)) + pipeline = wds.DataPipeline(*pipeline) + if infinite: + pipeline = pipeline.repeat(nepochs=1000000) + + loader = wds.WebLoader( + pipeline, batch_size=None, shuffle=False, + num_workers=get_num_workers(num_workers), + persistent_workers=True, + ) + loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) + return loader diff --git a/dp2/data/transforms/__init__.py b/dp2/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee4bcf4e825af435ffde4c2b6e3c74112f8438f --- /dev/null +++ b/dp2/data/transforms/__init__.py @@ -0,0 +1,2 @@ +from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize +from .stylegan2_transform import StyleGANAugmentPipe diff --git a/dp2/data/transforms/functional.py b/dp2/data/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee57f27ad07e597098ce1de967c3a50a1d06d0a --- /dev/null +++ b/dp2/data/transforms/functional.py @@ -0,0 +1,57 @@ +import torchvision.transforms.functional as F +import torch +import pickle +from tops import download_file, assert_shape +from typing import Dict +from functools import lru_cache + +global symmetry_transform + + +@lru_cache(maxsize=1) +def get_symmetry_transform(symmetry_url): + file_name = download_file(symmetry_url) + with open(file_name, "rb") as fp: + symmetry = pickle.load(fp) + return torch.from_numpy(symmetry["vertex_transforms"]).long() + + +hflip_handled_cases = set([ + "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition", + "embedding", "vertx2cat", "maskrcnn_mask", "__key__"]) + + +def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]: + container["img"] = F.hflip(container["img"]) + if "condition" in container: + container["condition"] = F.hflip(container["condition"]) + if "embedding" in container: + container["embedding"] = F.hflip(container["embedding"]) + assert all([key in hflip_handled_cases for key in container]), container.keys() + if "keypoints" in container: + assert flip_map is not None + if container["keypoints"].ndim == 3: + keypoints = container["keypoints"][:, flip_map, :] + keypoints[:, :, 0] = 1 - keypoints[:, :, 0] + else: + assert_shape(container["keypoints"], (None, 3)) + keypoints = container["keypoints"][flip_map, :] + keypoints[:, 0] = 1 - keypoints[:, 0] + container["keypoints"] = keypoints + if "mask" in container: + container["mask"] = F.hflip(container["mask"]) + if "border" in container: + container["border"] = F.hflip(container["border"]) + if "semantic_mask" in container: + container["semantic_mask"] = F.hflip(container["semantic_mask"]) + if "vertices" in container: + symmetry_transform = get_symmetry_transform( + "https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl") + container["vertices"] = F.hflip(container["vertices"]) + symmetry_transform_ = symmetry_transform.to(container["vertices"].device) + container["vertices"] = symmetry_transform_[container["vertices"].long()] + if "E_mask" in container: + container["E_mask"] = F.hflip(container["E_mask"]) + if "maskrcnn_mask" in 
container: + container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"]) + return container diff --git a/dp2/data/transforms/stylegan2_transform.py b/dp2/data/transforms/stylegan2_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..49a143cddf9673d079b87ac7d725c433713e54c5 --- /dev/null +++ b/dp2/data/transforms/stylegan2_transform.py @@ -0,0 +1,394 @@ +import numpy as np +import scipy.signal +import torch +try: + from sg3_torch_utils import misc + from sg3_torch_utils.ops import upfirdn2d + from sg3_torch_utils.ops import grid_sample_gradfix + from sg3_torch_utils.ops import conv2d_gradfix +except: + pass +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. + +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 
0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. + + +def matrix(*rows, device=None): + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + + +def translate2d(tx, ty, **kwargs): + return matrix( + [1, 0, tx], + [0, 1, ty], + [0, 0, 1], + **kwargs) + + +def translate3d(tx, ty, tz, **kwargs): + return matrix( + [1, 0, 0, tx], + [0, 1, 0, ty], + [0, 0, 1, tz], + [0, 0, 0, 1], + **kwargs) + + +def scale2d(sx, sy, **kwargs): + return matrix( + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1], + **kwargs) + + +def scale3d(sx, sy, sz, **kwargs): + return matrix( + [sx, 0, 0, 0], + [0, sy, 0, 0], + [0, 0, sz, 0], + [0, 0, 0, 1], + **kwargs) + + +def rotate2d(theta, **kwargs): + return matrix( + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + **kwargs) + + +def rotate3d(v, theta, **kwargs): + vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] + s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c + return matrix( + [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], + [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], + [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], + [0, 0, 0, 1], + **kwargs) + + +def translate2d_inv(tx, ty, **kwargs): + return translate2d(-tx, -ty, **kwargs) + + +def scale2d_inv(sx, sy, **kwargs): + return scale2d(1 / sx, 1 / sy, **kwargs) + + +def rotate2d_inv(theta, **kwargs): + return rotate2d(-theta, **kwargs) + + +class StyleGANAugmentPipe(torch.nn.Module): + def __init__(self, + rotate90=0, xint=0, xint_max=0.125, + scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, + hue_max=1, saturation_std=1, + imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, + ): + super().__init__() + self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float(xint) # Probability multiplier for integer translation. + self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. 
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. + Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, batch, debug_percentile=None): + images = batch["img"] + batch["vertices"] = batch["vertices"].float() + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + self.Hz_fbank = self.Hz_fbank.to(device) + self.Hz_geom = self.Hz_geom.to(device) + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply integer translation with probability (xint * strength). 
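+        # Each augmentation below follows the same gating pattern: a parameter is sampled per
+        # image and kept only where a uniform draw falls below (multiplier * self.p), so e.g.
+        # xint=1 with p=0.5 should translate roughly half of the batch and leave the rest
+        # untouched. Augmentations that are not selected leave G_inv at the identity.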
+ if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) + s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. + p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. + + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) + s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + # Calculate padding. 
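+            # Roughly: the four image corners are pushed through G_inv to see how far the warped
+            # content can reach, the largest offset per side becomes the margin, the margin is
+            # clamped to the image size, and the support of the geometric low-pass filter
+            # (Hz_pad) is added so the reflection padding covers the resampling kernel.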
+ cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0) + batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) + batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. + images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest") + batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest") + batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest") + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. + shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + batch["mask"] = torch.nn.functional.grid_sample( + input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + batch["E_mask"] = torch.nn.functional.grid_sample( + input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + batch["vertices"] = torch.nn.functional.grid_sample( + input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) + batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + # -------------------------------------------- + # Select parameters for color transformations. 
+ # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). + if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) + c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. + + # Apply hue rotation with probability (hue * strength). + if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max + theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). + if self.saturation > 0 and num_channels > 1: + s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) + s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). 
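+            # For each band, a log-normal gain t_i is sampled and gated by
+            # (imgfilter * p * band_strength); the gain vector is then renormalised against the
+            # expected 1/f power spectrum so overall image power stays roughly constant, and the
+            # per-band gains accumulate multiplicatively into g.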
+ for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) + t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] + + # Apply filter. + p = self.Hz_fbank.shape[1] // 2 + images = images.reshape([1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. + # ------------------------ + batch["img"] = images + batch["vertices"] = batch["vertices"].long() + batch["border"] = 1 - batch["E_mask"] - batch["mask"] + return batch diff --git a/dp2/data/transforms/transforms.py b/dp2/data/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd43e7a515deacca4be7242d065b4f3ccb6800e --- /dev/null +++ b/dp2/data/transforms/transforms.py @@ -0,0 +1,277 @@ +from pathlib import Path +from typing import Dict, List +import torchvision +import torch +import tops +import torchvision.transforms.functional as F +from .functional import hflip +import numpy as np +from dp2.utils.vis_utils import get_coco_keypoints +from PIL import Image, ImageDraw +from typing import Tuple + + +class RandomHorizontalFlip(torch.nn.Module): + + def __init__(self, p: float, flip_map=None, **kwargs): + super().__init__() + self.flip_ratio = p + self.flip_map = flip_map + if self.flip_ratio is None: + self.flip_ratio = 0.5 + assert 0 <= self.flip_ratio <= 1 + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if torch.rand(1) > self.flip_ratio: + return container + return hflip(container, self.flip_map) + + +class CenterCrop(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. + """ + + def __init__(self, size: List[int]): + super().__init__() + self.size = tuple(size) + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + min_size = min(container["img"].shape[1], container["img"].shape[2]) + if min_size < self.size[0]: + container["img"] = F.center_crop(container["img"], min_size) + container["img"] = F.resize(container["img"], self.size) + return container + container["img"] = F.center_crop(container["img"], self.size) + return container + + +class Resize(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. 
+ """ + + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR): + super().__init__() + self.size = tuple(size) + self.interpolation = interpolation + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) + if "semantic_mask" in container: + container["semantic_mask"] = F.resize( + container["semantic_mask"], self.size, F.InterpolationMode.NEAREST) + if "embedding" in container: + container["embedding"] = F.resize( + container["embedding"], self.size, self.interpolation) + if "mask" in container: + container["mask"] = F.resize( + container["mask"], self.size, F.InterpolationMode.NEAREST) + if "E_mask" in container: + container["E_mask"] = F.resize( + container["E_mask"], self.size, F.InterpolationMode.NEAREST) + if "maskrcnn_mask" in container: + container["maskrcnn_mask"] = F.resize( + container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST) + if "vertices" in container: + container["vertices"] = F.resize( + container["vertices"], self.size, F.InterpolationMode.NEAREST) + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(size=self.size, interpolation=self.interpolation) + return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) + + +class Normalize(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. + """ + + def __init__(self, mean, std, inplace, keys=["img"]): + super().__init__() + self.mean = mean + self.std = std + self.inplace = inplace + self.keys = keys + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for key in self.keys: + container[key] = F.normalize(container[key], self.mean, self.std, self.inplace) + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace) + return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) + + +class ToFloat(torch.nn.Module): + + def __init__(self, keys=["img"], norm=True) -> None: + super().__init__() + self.keys = keys + self.gain = 255 if norm else 1 + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for key in self.keys: + container[key] = container[key].float() / self.gain + return container + + +class RandomCrop(torchvision.transforms.RandomCrop): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. 
+ """ + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["img"] = super().forward(container["img"]) + return container + + +class CreateCondition(torch.nn.Module): + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if container["img"].dtype == torch.uint8: + container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127 + return container + container["condition"] = container["img"] * container["mask"] + return container + + +class CreateEmbedding(torch.nn.Module): + + def __init__(self, embed_path: Path, cuda=True) -> None: + super().__init__() + self.embed_map = torch.load(embed_path, map_location=torch.device("cpu")) + if cuda: + self.embed_map = tops.to_cuda(self.embed_map) + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + vertices = container["vertices"] + if vertices.ndim == 3: + embedding = self.embed_map[vertices.long()].squeeze(dim=0) + embedding = embedding.permute(2, 0, 1) * container["E_mask"] + pass + else: + assert vertices.ndim == 4 + embedding = self.embed_map[vertices.long()].squeeze(dim=1) + embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"] + container["embedding"] = embedding + container["embed_map"] = self.embed_map.clone() + return container + + +class InsertJointMap(torch.nn.Module): + + def __init__(self, imsize: Tuple) -> None: + super().__init__() + self.imsize = imsize + knames = get_coco_keypoints()[0] + knames = knames + ["neck", "mid_hip"] + connectivity = { + "nose": ["left_eye", "right_eye", "neck"], + "left_eye": ["right_eye", "left_ear"], + "right_eye": ["right_ear"], + "left_shoulder": ["right_shoulder", "left_elbow", "left_hip"], + "right_shoulder": ["right_elbow", "right_hip"], + "left_elbow": ["left_wrist"], + "right_elbow": ["right_wrist"], + "left_hip": ["right_hip", "left_knee"], + "right_hip": ["right_knee"], + "left_knee": ["left_ankle"], + "right_knee": ["right_ankle"], + "neck": ["mid_hip", "nose"], + } + category = { + ("nose", "left_eye"): 0, # head + ("nose", "right_eye"): 0, # head + ("nose", "neck"): 0, # head + ("left_eye", "right_eye"): 0, # head + ("left_eye", "left_ear"): 0, # head + ("right_eye", "right_ear"): 0, # head + ("left_shoulder", "left_elbow"): 1, # left arm + ("left_elbow", "left_wrist"): 1, # left arm + ("right_shoulder", "right_elbow"): 2, # right arm + ("right_elbow", "right_wrist"): 2, # right arm + ("left_shoulder", "right_shoulder"): 3, # body + ("left_shoulder", "left_hip"): 3, # body + ("right_shoulder", "right_hip"): 3, # body + ("left_hip", "right_hip"): 3, # body + ("left_hip", "left_knee"): 4, # left leg + ("left_knee", "left_ankle"): 4, # left leg + ("right_hip", "right_knee"): 5, # right leg + ("right_knee", "right_ankle"): 5, # right leg + ("neck", "mid_hip"): 3, # body + ("neck", "nose"): 0, # head + } + self.indices2category = { + tuple([knames.index(n) for n in k]): v for k, v in category.items() + } + self.connectivity_indices = { + knames.index(k): [knames.index(v_) for v_ in v] + for k, v in connectivity.items() + } + self.l_shoulder = knames.index("left_shoulder") + self.r_shoulder = knames.index("right_shoulder") + self.l_hip = knames.index("left_hip") + self.r_hip = knames.index("right_hip") + self.l_eye = knames.index("left_eye") + self.r_eye = knames.index("right_eye") + self.nose = knames.index("nose") + self.neck = knames.index("neck") + + def create_joint_map(self, N, H, W, keypoints): + joint_maps = np.zeros((N, H, W), 
dtype=np.uint8) + for bidx, keypoints in enumerate(keypoints): + assert keypoints.shape == (17, 3), keypoints.shape + keypoints = torch.cat((keypoints, torch.zeros(2, 3))) + visible = keypoints[:, -1] > 0 + + if visible[self.l_shoulder] and visible[self.r_shoulder]: + neck = (keypoints[self.l_shoulder] + + (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2) + keypoints[-2] = neck + visible[-2] = 1 + if visible[self.l_hip] and visible[self.r_hip]: + mhip = (keypoints[self.l_hip] + + (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2 + ) + keypoints[-1] = mhip + visible[-1] = 1 + + keypoints[:, 0] *= W + keypoints[:, 1] *= H + joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8)) + draw = ImageDraw.Draw(joint_map) + for fidx in self.connectivity_indices.keys(): + for tidx in self.connectivity_indices[fidx]: + if visible[fidx] == 0 or visible[tidx] == 0: + continue + c = self.indices2category[(fidx, tidx)] + s = tuple(keypoints[fidx, :2].round().long().numpy().tolist()) + e = tuple(keypoints[tidx, :2].round().long().numpy().tolist()) + draw.line((s, e), width=1, fill=c + 1) + if visible[self.nose] == 0 and visible[self.neck] == 1: + m_eye = ( + keypoints[self.l_eye] + + (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2 + ) + s = tuple(m_eye[:2].round().long().numpy().tolist()) + e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist()) + c = self.indices2category[(self.nose, self.neck)] + draw.line((s, e), width=1, fill=c + 1) + joint_map = np.array(joint_map) + + joint_maps[bidx] = np.array(joint_map) + return joint_maps[:, None] + + def forward(self, batch): + batch["joint_map"] = torch.from_numpy(self.create_joint_map( + batch["img"].shape[0], *self.imsize, batch["keypoints"])) + return batch diff --git a/dp2/data/utils.py b/dp2/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..392fe22f9f78daae0f65b1484311f7bb14381cbd --- /dev/null +++ b/dp2/data/utils.py @@ -0,0 +1,122 @@ +import torch +from PIL import Image +import numpy as np +import multiprocessing +import io +from tops import logger +from torch.utils.data._utils.collate import default_collate + +try: + import pyspng + + PYSPNG_IMPORTED = True +except ImportError: + PYSPNG_IMPORTED = False + print("Could not load pyspng. 
Defaulting to pillow image backend.") + from PIL import Image + + +def get_fdf_keypoints(): + return get_coco_keypoints()[:7] + + +def get_fdf_flipmap(): + keypoints = get_fdf_keypoints() + keypoint_flip_map = { + "left_eye": "right_eye", + "left_ear": "right_ear", + "left_shoulder": "right_shoulder", + } + for key, value in list(keypoint_flip_map.items()): + keypoint_flip_map[value] = key + keypoint_flip_map["nose"] = "nose" + keypoint_flip_map_idx = [] + for source in keypoints: + keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source])) + return keypoint_flip_map_idx + + +def get_coco_keypoints(): + return [ + "nose", + "left_eye", + "right_eye", # 2 + "left_ear", + "right_ear", # 4 + "left_shoulder", + "right_shoulder", # 6 + "left_elbow", + "right_elbow", # 8 + "left_wrist", + "right_wrist", # 10 + "left_hip", + "right_hip", # 12 + "left_knee", + "right_knee", # 14 + "left_ankle", + "right_ankle", # 16 + ] + + +def get_coco_flipmap(): + keypoints = get_coco_keypoints() + keypoint_flip_map = { + "left_eye": "right_eye", + "left_ear": "right_ear", + "left_shoulder": "right_shoulder", + "left_elbow": "right_elbow", + "left_wrist": "right_wrist", + "left_hip": "right_hip", + "left_knee": "right_knee", + "left_ankle": "right_ankle", + } + for key, value in list(keypoint_flip_map.items()): + keypoint_flip_map[value] = key + keypoint_flip_map["nose"] = "nose" + keypoint_flip_map_idx = [] + for source in keypoints: + keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source])) + return keypoint_flip_map_idx + + +def mask_decoder(x): + mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None] + mask = mask > 0 # This fixes bug causing maskf.loat().max() == 255. + return mask + + +def png_decoder(x): + if PYSPNG_IMPORTED: + return torch.from_numpy(np.rollaxis(pyspng.load(x), 2)) + with Image.open(io.BytesIO(x)) as im: + im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) + return im + + +def jpg_decoder(x): + with Image.open(io.BytesIO(x)) as im: + im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) + return im + + +def get_num_workers(num_workers: int): + n_cpus = multiprocessing.cpu_count() + if num_workers > n_cpus: + logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}") + return n_cpus + return num_workers + + +def collate_fn(batch): + elem = batch[0] + ignore_keys = set(["embed_map", "vertx2cat"]) + batch_ = { + key: default_collate([d[key] for d in batch]) + for key in elem + if key not in ignore_keys + } + if "embed_map" in elem: + batch_["embed_map"] = elem["embed_map"] + if "vertx2cat" in elem: + batch_["vertx2cat"] = elem["vertx2cat"] + return batch_ diff --git a/dp2/detection/__init__.py b/dp2/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..613969b28384cd1c64fc8db685e7622f4cc02615 --- /dev/null +++ b/dp2/detection/__init__.py @@ -0,0 +1,3 @@ +from .cse_mask_face_detector import CSeMaskFaceDetector +from .person_detector import CSEPersonDetector +from .structures import PersonDetection, VehicleDetection, FaceDetection diff --git a/dp2/detection/base.py b/dp2/detection/base.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab8b20c9474cbb6074b66a694d1a1a05df0c12c --- /dev/null +++ b/dp2/detection/base.py @@ -0,0 +1,42 @@ +import pickle +import torch +import lzma +from pathlib import Path +from tops import logger + + +class BaseDetector: + + def __init__(self, cache_directory: str) -> None: + if cache_directory is not None: + 
self.cache_directory = Path(cache_directory, str(self.__class__.__name__)) + self.cache_directory.mkdir(exist_ok=True, parents=True) + + def save_to_cache(self, detection, cache_path: Path, after_preprocess=True): + logger.log(f"Caching detection to: {cache_path}") + with lzma.open(cache_path, "wb") as fp: + torch.save( + [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp, + pickle_protocol=pickle.HIGHEST_PROTOCOL) + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}") + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + return [ + state["cls"].from_state_dict(state_dict=state) for state in state_dict + ] + + def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool): + if cache_id is None: + return self.forward(im) + cache_path = self.cache_directory.joinpath(cache_id + ".torch") + if cache_path.is_file() and load_cache: + try: + return self.load_from_cache(cache_path) + except Exception as e: + logger.warn(f"The cache file was corrupted: {cache_path}") + exit() + detections = self.forward(im) + self.save_to_cache(detections, cache_path) + return detections diff --git a/dp2/detection/box_utils.py b/dp2/detection/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3e6b5a84f071c1b9e9a74f6adbbe49b3cd7610 --- /dev/null +++ b/dp2/detection/box_utils.py @@ -0,0 +1,104 @@ +import numpy as np + + +def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio): + x0, y0, x1, y1 = [int(_) for _ in bbox] + h, w = y1 - y0, x1 - x0 + cur_ratio = h / w + + if cur_ratio == target_aspect_ratio: + return [x0, y0, x1, y1] + if cur_ratio < target_aspect_ratio: + target_height = int(w*target_aspect_ratio) + y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) + else: + target_width = int(h/target_aspect_ratio) + x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) + return x0, y0, x1, y1 + + +def expand_axis(start, end, target_width, limit): + # Can return a bbox outside of limit + cur_width = end - start + start = start - (target_width-cur_width)//2 + end = end + (target_width-cur_width)//2 + if end - start != target_width: + end += 1 + assert end - start == target_width + if start < 0 and end > limit: + return start, end + if start < 0 and end < limit: + to_shift = min(0 - start, limit - end) + start += to_shift + end += to_shift + if end > limit and start > 0: + to_shift = min(end - limit, start) + end -= to_shift + start -= to_shift + assert end - start == target_width + return start, end + + +def expand_box(bbox, imshape, mask, percentage_background: float): + assert isinstance(bbox[0], int) + assert 0 < percentage_background < 1 + # Percentage in S + mask_pixels = mask.long().sum().cpu() + total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + percentage_mask = mask_pixels / total_pixels + if (1 - percentage_mask) > percentage_background: + return bbox + target_pixels = mask_pixels / (1 - percentage_background) + x0, y0, x1, y1 = bbox + H = y1 - y0 + W = x1 - x0 + p = np.sqrt(target_pixels/(H*W)) + target_width = int(np.ceil(p * W)) + target_height = int(np.ceil(p * H)) + x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) + y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) + return [x0, y0, x1, y1] + + +def expand_axises_by_percentage(bbox_XYXY, imshape, percentage): + x0, y0, x1, y1 = bbox_XYXY + H = y1 - y0 + W = x1 - x0 + expansion = int(((H*W)**0.5) * percentage) + new_width = W + expansion + new_height = H + expansion + x0, 
x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1]) + y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0]) + return [x0, y0, x1, y1] + + +def get_expanded_bbox( + bbox_XYXY, + imshape, + mask, + percentage_background: float, + axis_minimum_expansion: float, + target_aspect_ratio: float): + bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist() + # Expand each axis of the bounding box by a minimum percentage + bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion) + # Find the minimum bbox with the aspect ratio. Can be outside of imshape + bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio) + # Expands square box such that X% of the bbox is background + bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background) + assert isinstance(bbox_XYXY[0], (int, np.int64)) + return bbox_XYXY + + +def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape): + def area_inside_ratio(bbox, imshape): + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1])) + return area_inside / area + ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0]) + area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside: + return False + if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area: + return False + return True diff --git a/dp2/detection/box_utils_fdf.py b/dp2/detection/box_utils_fdf.py new file mode 100644 index 0000000000000000000000000000000000000000..e1af1e3012f9a594fd43d2b17a9dabdb4cda6ba2 --- /dev/null +++ b/dp2/detection/box_utils_fdf.py @@ -0,0 +1,202 @@ +""" +The FDF dataset expands bound boxes differently from what is used for CSE. +""" + +import numpy as np + + +def quadratic_bounding_box(x0, y0, width, height, imshape): + # We assume that we can create a image that is quadratic without + # minimizing any of the sides + assert width <= min(imshape[:2]) + assert height <= min(imshape[:2]) + min_side = min(height, width) + if height != width: + side_diff = abs(height - width) + # Want to extend the shortest side + if min_side == height: + # Vertical side + height += side_diff + if height > imshape[0]: + # Take full frame, and shrink width + y0 = 0 + height = imshape[0] + + side_diff = abs(height - width) + width -= side_diff + x0 += side_diff // 2 + else: + y0 -= side_diff // 2 + y0 = max(0, y0) + else: + # Horizontal side + width += side_diff + if width > imshape[1]: + # Take full frame width, and shrink height + x0 = 0 + width = imshape[1] + + side_diff = abs(height - width) + height -= side_diff + y0 += side_diff // 2 + else: + x0 -= side_diff // 2 + x0 = max(0, x0) + # Check that bbox goes outside image + x1 = x0 + width + y1 = y0 + height + if imshape[1] < x1: + diff = x1 - imshape[1] + x0 -= diff + if imshape[0] < y1: + diff = y1 - imshape[0] + y0 -= diff + assert x0 >= 0, "Bounding box outside image." + assert y0 >= 0, "Bounding box outside image." + assert x0 + width <= imshape[1], "Bounding box outside image." + assert y0 + height <= imshape[0], "Bounding box outside image." 
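+    # Given the asserts at the top of this function (both sides fit within the smaller image
+    # dimension), the returned box should be square (width == height) and lie fully inside
+    # the image.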
+ return x0, y0, width, height + + +def expand_bounding_box(bbox, percentage, imshape): + orig_bbox = bbox.copy() + x0, y0, x1, y1 = bbox + width = x1 - x0 + height = y1 - y0 + x0, y0, width, height = quadratic_bounding_box( + x0, y0, width, height, imshape) + expanding_factor = int(max(height, width) * percentage) + + possible_max_expansion = [(imshape[0] - width) // 2, + (imshape[1] - height) // 2, + expanding_factor] + + expanding_factor = min(possible_max_expansion) + # Expand height + + if expanding_factor > 0: + + y0 = y0 - expanding_factor + y0 = max(0, y0) + + height += expanding_factor * 2 + if height > imshape[0]: + y0 -= (imshape[0] - height) + height = imshape[0] + + if height + y0 > imshape[0]: + y0 -= (height + y0 - imshape[0]) + + # Expand width + x0 = x0 - expanding_factor + x0 = max(0, x0) + + width += expanding_factor * 2 + if width > imshape[1]: + x0 -= (imshape[1] - width) + width = imshape[1] + + if width + x0 > imshape[1]: + x0 -= (width + x0 - imshape[1]) + y1 = y0 + height + x1 = x0 + width + assert y0 >= 0, "Y0 is minus" + assert height <= imshape[0], "Height is larger than image." + assert x0 + width <= imshape[1] + assert y0 + height <= imshape[0] + assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!" + assert x0 >= 0, "Y0 is minus" + assert width <= imshape[1], "Height is larger than image." + # Check that original bbox is within new + x0_o, y0_o, x1_o, y1_o = orig_bbox + assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}" + assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}" + assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}" + assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}" + + x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]] + x1 = x0 + width + y1 = y0 + height + return np.array([x0, y0, x1, y1]) + + +def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint): + keypoint = keypoint[:, :3] # only nose + eyes are relevant + kp_X = keypoint[0, :] + kp_Y = keypoint[1, :] + within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1) + within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1) + return within_X and within_Y + + +def expand_bbox_simple(bbox, percentage): + x0, y0, x1, y1 = bbox.astype(float) + width = x1 - x0 + height = y1 - y0 + x_c = int(x0) + width // 2 + y_c = int(y0) + height // 2 + avg_size = max(width, height) + new_width = avg_size * (1 + percentage) + x0 = x_c - new_width // 2 + y0 = y_c - new_width // 2 + x1 = x_c + new_width // 2 + y1 = y_c + new_width // 2 + return np.array([x0, y0, x1, y1]).astype(int) + + +def pad_image(im, bbox, pad_value): + x0, y0, x1, y1 = bbox + if x0 < 0: + pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((pad_im, im), axis=1) + x1 += abs(x0) + x0 = 0 + if y0 < 0: + pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((pad_im, im), axis=0) + y1 += abs(y0) + y0 = 0 + if x1 >= im.shape[1]: + pad_im = np.zeros( + (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((im, pad_im), axis=1) + if y1 >= im.shape[0]: + pad_im = np.zeros( + (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((im, pad_im), axis=0) + return im[y0:y1, x0:x1] + + +def clip_box(bbox, im): + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = min(im.shape[1] - 1, bbox[2]) + bbox[3] = min(im.shape[0] - 1, bbox[3]) + return bbox + + +def 
cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True): + outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0] + if simple_expand or (outside_im and pad_im): + return pad_image(im, bbox, pad_value) + bbox = clip_box(bbox, im) + x0, y0, x1, y1 = bbox + return im[y0:y1, x0:x1] + + +def expand_bbox( + bbox_ltrb, imshape, simple_expand, default_to_simple=False, + expansion_factor=0.35): + assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}" + bbox = bbox_ltrb.astype(float) + # FDF256 uses simple expand with ratio 0.4 + if simple_expand: + return expand_bbox_simple(bbox, 0.4) + try: + return expand_bounding_box(bbox, expansion_factor, imshape) + except AssertionError: + return expand_bbox_simple(bbox, expansion_factor * 2) diff --git a/dp2/detection/cse_mask_face_detector.py b/dp2/detection/cse_mask_face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..5eccfc4ac885cfcca47fef389b99c1e35685579b --- /dev/null +++ b/dp2/detection/cse_mask_face_detector.py @@ -0,0 +1,116 @@ +import torch +import lzma +import tops +from pathlib import Path +from dp2.detection.base import BaseDetector +from .utils import combine_cse_maskrcnn_dets +from face_detection import build_detector as build_face_detector +from .models.cse import CSEDetector +from .models.mask_rcnn import MaskRCNNDetector +from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection +from tops import logger + + +def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): + assert len(box1.shape) == 2 + assert len(box2.shape) == 2 + box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) + # This can be batched + for i, box in enumerate(box1): + is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) + is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) + is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) + box1_inside[i] = is_outside.logical_not().any() + return box1_inside + + +class CSeMaskFaceDetector(BaseDetector): + + def __init__( + self, + mask_rcnn_cfg, + face_detector_cfg: dict, + cse_cfg: dict, + face_post_process_cfg: dict, + cse_post_process_cfg, + score_threshold: float, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + if "confidence_threshold" not in face_detector_cfg: + face_detector_cfg["confidence_threshold"] = score_threshold + if "score_thres" not in cse_cfg: + cse_cfg["score_thres"] = score_threshold + self.cse_detector = CSEDetector(**cse_cfg) + self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True) + self.cse_post_process_cfg = cse_post_process_cfg + self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) + self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold") + self.face_post_process_cfg = face_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _detect_faces(self, im: torch.Tensor): + H, W = im.shape[1:] + im = im.float() - self.face_mean + im = self.face_detector.resize(im[None], 1.0) + boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score + boxes_XYXY[:, [0, 2]] *= W + boxes_XYXY[:, [1, 3]] *= H + return boxes_XYXY.round().long() + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}",) + with lzma.open(cache_path, "rb") 
as fp: + state_dict = torch.load(fp, map_location="cpu") + kwargs = dict( + post_process_cfg=self.cse_post_process_cfg, + embed_map=self.cse_detector.embed_map, + **self.face_post_process_cfg + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + maskrcnn_dets = self.mask_rcnn(im) + cse_dets = self.cse_detector(im) + embed_map = self.cse_detector.embed_map + print("Calling face detector.") + face_boxes = self._detect_faces(im).cpu() + maskrcnn_person = { + k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items() + } + maskrcnn_other = { + k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items() + } + maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"]) + combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets( + maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold) + + persons_with_cse = CSEPersonDetection( + combined_segmentation, cse_dets, **self.cse_post_process_cfg, + embed_map=embed_map, orig_imshape_CHW=im.shape + ) + persons_with_cse.pre_process() + not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]] + persons_without_cse = PersonDetection( + maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg, + orig_imshape_CHW=im.shape + ) + persons_without_cse.pre_process() + + face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or( + box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes) + ) + face_boxes = face_boxes[face_boxes_covered.logical_not()] + face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg) + + # Order matters. The anonymizer will anonymize FIFO. + # Later detections will overwrite. + all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse] + return all_detections diff --git a/dp2/detection/deep_privacy1_detector.py b/dp2/detection/deep_privacy1_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..9657e7aaa993033c559310edcab0cc54ee2a03e1 --- /dev/null +++ b/dp2/detection/deep_privacy1_detector.py @@ -0,0 +1,106 @@ +import torch +import tops +import lzma +from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights +from .base import BaseDetector +from face_detection import build_detector as build_face_detector +from .structures import FaceDetection +from tops import logger +from pathlib import Path + +def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint): + keypoint = keypoint[:3, :] # only nose + eyes are relevant + kp_X = keypoint[:, 0] + kp_Y = keypoint[:, 1] + within_X = (kp_X >= x0).all() and (kp_X <= x1).all() + within_Y = (kp_Y >= y0).all() and (kp_Y <= y1).all() + return within_X and within_Y + + +def match_bbox_keypoint(bounding_boxes, keypoints): + """ + bounding_boxes shape: [N, 5] + keypoints: [N persons, K keypoints, (x, y)] + """ + if len(bounding_boxes) == 0 or len(keypoints) == 0: + return torch.empty((0, 5)), torch.empty((0, 7, 2)) + assert bounding_boxes.shape[1] == 4,\ + f"Shape was : {bounding_boxes.shape}" + assert keypoints.shape[-1] == 2,\ + f"Expected (x,y) in last axis, got: {keypoints.shape}" + assert keypoints.shape[1] in (5, 7),\ + f"Expeted 5 or 7 keypoints. 
Keypoint shape was: {keypoints.shape}" + + matches = [] + for bbox_idx, bbox in enumerate(bounding_boxes): + keypoint = None + for kp_idx, keypoint in enumerate(keypoints): + if kp_idx in (x[1] for x in matches): + continue + if is_keypoint_within_bbox(*bbox, keypoint): + matches.append((bbox_idx, kp_idx)) + break + keypoint_idx = [x[1] for x in matches] + bbox_idx = [x[0] for x in matches] + return bounding_boxes[bbox_idx], keypoints[keypoint_idx] + + +class DeepPrivacy1Detector(BaseDetector): + + def __init__(self, + keypoint_threshold: float, + face_detector_cfg, + score_threshold: float, + face_post_process_cfg, + **kwargs): + super().__init__(**kwargs) + self.keypoint_detector = tops.to_cuda(keypointrcnn_resnet50_fpn( + weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1).eval()) + self.keypoint_threshold = keypoint_threshold + self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold) + self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) + self.face_post_process_cfg = face_post_process_cfg + + @torch.no_grad() + def _detect_faces(self, im: torch.Tensor): + H, W = im.shape[1:] + im = im.float() - self.face_mean + im = self.face_detector.resize(im[None], 1.0) + boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score + boxes_XYXY[:, [0, 2]] *= W + boxes_XYXY[:, [1, 3]] *= H + return boxes_XYXY.round().long().cpu() + + @torch.no_grad() + def _detect_keypoints(self, img: torch.Tensor): + img = img.float() / 255 + outputs = self.keypoint_detector([img]) + + # Shape: [N persons, K keypoints, (x,y,visibility)] + keypoints = outputs[0]["keypoints"] + scores = outputs[0]["scores"] + assert list(scores) == sorted(list(scores))[::-1] + mask = scores >= self.keypoint_threshold + keypoints = keypoints[mask, :, :2] + return keypoints[:, :7, :2] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def forward(self, im: torch.Tensor): + face_boxes = self._detect_faces(im) + keypoints = self._detect_keypoints(im) + face_boxes, keypoints = match_bbox_keypoint(face_boxes, keypoints) + face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg, keypoints=keypoints) + return [face_boxes] + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}",) + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp, map_location="cpu") + kwargs = self.face_post_process_cfg + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] \ No newline at end of file diff --git a/dp2/detection/face_detector.py b/dp2/detection/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8354d7b71ffd9147d7a7c1be462b6a1c30626672 --- /dev/null +++ b/dp2/detection/face_detector.py @@ -0,0 +1,62 @@ +import torch +import lzma +import tops +from pathlib import Path +from dp2.detection.base import BaseDetector +from face_detection import build_detector as build_face_detector +from .structures import FaceDetection +from tops import logger + + +def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): + assert len(box1.shape) == 2 + assert len(box2.shape) == 2 + box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) + # This can be batched + for i, box in enumerate(box1): + is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) + is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) + 
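+        # box1[i] counts as contained when its top-left corner lies strictly inside box2's
+        # top-left and its bottom-right lies strictly inside box2's bottom-right; the trailing
+        # .any() flags the box if at least one box2 fully contains it.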
is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) + box1_inside[i] = is_outside.logical_not().any() + return box1_inside + + +class FaceDetector(BaseDetector): + + def __init__( + self, + face_detector_cfg: dict, + score_threshold: float, + face_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold) + self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) + self.face_post_process_cfg = face_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _detect_faces(self, im: torch.Tensor): + H, W = im.shape[1:] + im = im.float() - self.face_mean + im = self.face_detector.resize(im[None], 1.0) + boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score + boxes_XYXY[:, [0, 2]] *= W + boxes_XYXY[:, [1, 3]] *= H + return boxes_XYXY.round().long().cpu() + + @torch.no_grad() + def forward(self, im: torch.Tensor): + face_boxes = self._detect_faces(im) + face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg) + return [face_boxes] + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}") + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + return [ + state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict + ] diff --git a/dp2/detection/models/__init__.py b/dp2/detection/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/detection/models/cse.py b/dp2/detection/models/cse.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9fbf75fc8fd6c1905993316c03383b9e935564 --- /dev/null +++ b/dp2/detection/models/cse.py @@ -0,0 +1,134 @@ +import torch +from typing import List +import tops +from torchvision.transforms.functional import InterpolationMode, resize +from densepose.data.utils import get_class_to_mesh_name_mapping +from densepose import add_densepose_config +from densepose.structures import DensePoseEmbeddingPredictorOutput +from densepose.vis.extractor import DensePoseOutputsExtractor +from densepose.modeling import build_densepose_embedder +from detectron2.config import get_cfg +from detectron2.data.transforms import ResizeShortestEdge +from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer +from detectron2.modeling import build_model + + +model_urls = { + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl", +} + + +def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape): + assert len(S.shape) == 3 + H, W = imshape + N = len(boxes_XYXY) + segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device) + boxes_XYXY = boxes_XYXY.long() + for i in range(N): + x0, y0, x1, y1 = boxes_XYXY[i] + assert x0 >= 0 and y0 >= 0 + assert x1 <= imshape[1] + assert y1 <= imshape[0] + h = y1 - y0 + w = x1 - x0 + segmentation[i:i+1, y0:y1, x0:x1] = 
resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0 + return segmentation + + +class CSEDetector: + + def __init__( + self, + cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", + cfg_2_download: List[str] = [ + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"], + score_thres: float = 0.9, + nms_thresh: float = None, + ) -> None: + with tops.logger.capture_log_stdout(): + cfg = get_cfg() + self.device = tops.get_device() + add_densepose_config(cfg) + cfg_path = tops.download_file(cfg_url) + for p in cfg_2_download: + tops.download_file(p) + with tops.logger.capture_log_stdout(): + cfg.merge_from_file(cfg_path) + assert cfg_url in model_urls, cfg_url + model_path = tops.download_file(model_urls[cfg_url]) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres + if nms_thresh is not None: + cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh + cfg.MODEL.WEIGHTS = str(model_path) + cfg.MODEL.DEVICE = str(self.device) + cfg.freeze() + with tops.logger.capture_log_stdout(): + self.model = build_model(cfg) + self.model.eval() + DetectionCheckpointer(self.model).load(str(model_path)) + self.input_format = cfg.INPUT.FORMAT + self.densepose_extractor = DensePoseOutputsExtractor() + self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg) + + self.embedder = build_densepose_embedder(cfg) + self.mesh_vertex_embeddings = { + mesh_name: self.embedder(mesh_name).to(self.device) + for mesh_name in self.class_to_mesh_name.values() + if self.embedder.has_embeddings(mesh_name) + } + self.cfg = cfg + self.embed_map = self.mesh_vertex_embeddings["smpl_27554"] + tops.logger.log("CSEDetector built.") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def resize_im(self, im): + H, W = im.shape[1:] + newH, newW = ResizeShortestEdge.get_output_shape( + H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) + return resize( + im, (newH, newW), InterpolationMode.BILINEAR, antialias=True) + + @torch.no_grad() + def forward(self, im): + assert im.dtype == torch.uint8 + if self.input_format == "BGR": + im = im.flip(0) + H, W = im.shape[1:] + im = self.resize_im(im) + output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] + scores = output.get("scores") + if len(scores) == 0: + return dict( + instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device), + instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device), + embed_map=self.mesh_vertex_embeddings["smpl_27554"], + bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device), + im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device), + scores=torch.empty((0), dtype=torch.float, device=im.device) + ) + pred_densepose, boxes_xywh, classes = self.densepose_extractor(output) + assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose + S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes) + E = pred_densepose.embedding + mesh_name = self.class_to_mesh_name[classes[0]] + 
assert mesh_name == "smpl_27554" + x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)] + boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1) + boxes_XYXY = boxes_XYXY.round_().long() + + non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not() + S = S[non_empty_boxes] + E = E[non_empty_boxes] + boxes_XYXY = boxes_XYXY[non_empty_boxes] + scores = scores[non_empty_boxes] + im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W]) + return dict( + instance_segmentation=S, instance_embedding=E, + bbox_XYXY=boxes_XYXY, + im_segmentation=im_segmentation, + scores=scores.view(-1)) diff --git a/dp2/detection/models/keypoint_maskrcnn.py b/dp2/detection/models/keypoint_maskrcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e35a9f44a7133e9a0418f5b73d2bc39229139b --- /dev/null +++ b/dp2/detection/models/keypoint_maskrcnn.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads +from detectron2.data.transforms import ResizeShortestEdge +from detectron2.structures import Instances +from detectron2 import model_zoo +from detectron2.config import instantiate +from detectron2.config import LazyCall as L +from PIL import Image +import tops +import functools +from torchvision.transforms.functional import resize + + +def get_rn50_fpn_keypoint_rcnn(weight_path: str): + from detectron2.modeling.poolers import ROIPooler + from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead + from detectron2.layers import ShapeSpec + model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model + model.roi_heads.update( + num_classes=1, + keypoint_in_features=["p2", "p3", "p4", "p5"], + keypoint_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + keypoint_head=L(KRCNNConvDeconvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_keypoints=17, + conv_dims=[512] * 8, + loss_normalizer="visible", + ), + ) + + # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. + # 1000 proposals per-image is found to hurt box AP. + # Therefore we increase it to 1500 per-image. 
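+    # The tuple is interpreted by detectron2's RPN as (training, testing)
+    # proposals kept per image after NMS, so only the training-time budget
+    # is raised here; testing stays at 1000.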
+ model.proposal_generator.post_nms_topk = (1500, 1000) + + # Keypoint AP degrades (though box AP improves) when using plain L1 loss + model.roi_heads.box_predictor.smooth_l1_beta = 0.5 + model = instantiate(model) + + dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader + test_transform = instantiate(dataloader.test.mapper.augmentations) + DetectionCheckpointer(model).load(weight_path) + return model, test_transform + + +models = { + "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth") +} + + +class KeypointMaskRCNN: + + def __init__(self, model_name: str, score_threshold: float) -> None: + assert model_name in models, f"Did not find {model_name} in models" + model, test_transform = models[model_name]() + self.model = model.eval().to(tops.get_device()) + if isinstance(self.model.roi_heads, CascadeROIHeads): + for head in self.model.roi_heads.box_predictors: + assert hasattr(head, "test_score_thresh") + head.test_score_thresh = score_threshold + else: + assert isinstance(self.model.roi_heads, StandardROIHeads) + assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh") + self.model.roi_heads.box_predictor.test_score_thresh = score_threshold + + self.test_transform = test_transform + assert len(self.test_transform) == 1 + self.test_transform = self.test_transform[0] + assert isinstance(self.test_transform, ResizeShortestEdge) + assert self.test_transform.interp == Image.BILINEAR + self.image_format = self.model.input_format + + def resize_im(self, im): + H, W = im.shape[-2:] + if self.test_transform.is_range: + size = np.random.randint( + self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1) + else: + size = np.random.choice(self.test_transform.short_edge_length) + newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size) + return resize( + im, (newH, newW), antialias=True) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def forward(self, im: torch.Tensor): + assert im.ndim == 3 + if self.image_format == "BGR": + im = im.flip(0) + H, W = im.shape[-2:] + im = im.float() + im = self.resize_im(im) + + inputs = dict(image=im, height=H, width=W) + # instances contains + # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps']) + instances = self.model([inputs])[0]["instances"] + return dict( + scores=instances.get("scores").cpu(), + segmentation=instances.get("pred_masks").cpu(), + keypoints=instances.get("pred_keypoints").cpu() + ) diff --git a/dp2/detection/models/mask_rcnn.py b/dp2/detection/models/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ed64706c0036d6dcc2355c8ce2f830bd8a22c3e3 --- /dev/null +++ b/dp2/detection/models/mask_rcnn.py @@ -0,0 +1,78 @@ +import torch +import tops +from detectron2.modeling import build_model +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.structures import Boxes +from detectron2.data import MetadataCatalog +from detectron2 import model_zoo +from typing import Dict +from detectron2.data.transforms import ResizeShortestEdge +from torchvision.transforms.functional import resize + + +model_urls = { + "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": 
"https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl", + +} + + +class MaskRCNNDetector: + + def __init__( + self, + cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml", + score_thres: float = 0.9, + class_filter=["person"], # ["car", "bicycle","truck", "bus", "backpack"] + fp16_inference: bool = False + ) -> None: + cfg = model_zoo.get_config(cfg_name) + cfg.MODEL.DEVICE = str(tops.get_device()) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres + cfg.freeze() + self.cfg = cfg + with tops.logger.capture_log_stdout(): + self.model = build_model(cfg) + DetectionCheckpointer(self.model).load(model_urls[cfg_name]) + self.model.eval() + self.input_format = cfg.INPUT.FORMAT + self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes + self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter]) + self.person_class = self.class_names.index("person") + self.fp16_inference = fp16_inference + tops.logger.log("Mask R-CNN built.") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def resize_im(self, im): + H, W = im.shape[1:] + newH, newW = ResizeShortestEdge.get_output_shape( + H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) + return resize( + im, (newH, newW), antialias=True) + + @torch.no_grad() + def forward(self, im: torch.Tensor): + if self.input_format == "BGR": + im = im.flip(0) + else: + assert self.input_format == "RGB" + H, W = im.shape[-2:] + im = self.resize_im(im) + with torch.cuda.amp.autocast(enabled=self.fp16_inference): + output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] + scores = output.get("scores") + N = len(scores) + classes = output.get("pred_classes") + idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep] + classes = classes[idx2keep] + assert isinstance(output.get("pred_boxes"), Boxes) + segmentation = output.get("pred_masks")[idx2keep] + assert segmentation.dtype == torch.bool + is_person = classes == self.person_class + return { + "scores": output.get("scores")[idx2keep], + "segmentation": segmentation, + "classes": output.get("pred_classes")[idx2keep], + "is_person": is_person + } diff --git a/dp2/detection/models/vit_pose/backbone.py b/dp2/detection/models/vit_pose/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..f78712dd6131b02ea674fac452852274ae0dc5c2 --- /dev/null +++ b/dp2/detection/models/vit_pose/backbone.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch +import torch +from functools import partial +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None,): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=( + patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +# @BACKBONES.register_module() +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token 
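+        # The extra (+1) slot keeps the class-token position of the pretrained
+        # checkpoint; forward_features adds pos_embed[:, 1:] + pos_embed[:, :1],
+        # so the class-token embedding acts as a shared offset for all patches.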
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + super().init_weights(pretrained, patch_padding=self.patch_padding) + + if pretrained is None: + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py b/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py new file mode 100644 index 0000000000000000000000000000000000000000..85c08a76e4dab6d89f4b76183f327005211afd35 --- /dev/null +++ b/dp2/detection/models/vit_pose/topdown_heatmap_simple_head.py @@ -0,0 +1,505 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch +import torch +import torch.nn as nn +import torch.nn.functional as F +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +# from mmpose.core.evaluation.top_down_eval import keypoints_from_heatmaps + + +class TopdownHeatmapBaseHead(nn.Module): + """Base class for top-down heatmap heads. + + All top-down heatmap heads should subclass it. + All subclass should overwrite: + + Methods:`get_loss`, supporting to calculate loss. + Methods:`get_accuracy`, supporting to calculate accuracy. + Methods:`forward`, supporting to forward model. + Methods:`inference_model`, supporting to inference model. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def get_loss(self, **kwargs): + """Gets the loss.""" + + @abstractmethod + def get_accuracy(self, **kwargs): + """Gets the accuracy.""" + + @abstractmethod + def forward(self, **kwargs): + """Forward function.""" + + @abstractmethod + def inference_model(self, **kwargs): + """Inference function.""" + + def decode(self, img_metas, output, **kwargs): + """Decode keypoints from heatmaps. + + Args: + img_metas (list(dict)): Information about data augmentation + By default this includes: + + - "image_file: path to the image file + - "center": center of the bbox + - "scale": scale of the bbox + - "rotation": rotation of the bbox + - "bbox_score": score of bbox + output (np.ndarray[N, K, H, W]): model predicted heatmaps. + """ + # batch_size = len(img_metas) + + # if 'bbox_id' in img_metas[0]: + # bbox_ids = [] + # else: + # bbox_ids = None + + # c = np.zeros((batch_size, 2), dtype=np.float32) + # s = np.zeros((batch_size, 2), dtype=np.float32) + # image_paths = [] + # score = np.ones(batch_size) + # for i in range(batch_size): + # c[i, :] = img_metas[i]['center'] + # s[i, :] = img_metas[i]['scale'] + # image_paths.append(img_metas[i]['image_file']) + + # if 'bbox_score' in img_metas[i]: + # score[i] = np.array(img_metas[i]['bbox_score']).reshape(-1) + # if bbox_ids is not None: + # bbox_ids.append(img_metas[i]['bbox_id']) + + # preds, maxvals = keypoints_from_heatmaps( + # output, + # c, + # s, + # unbiased=self.test_cfg.get('unbiased_decoding', False), + # post_process=self.test_cfg.get('post_process', 'default'), + # kernel=self.test_cfg.get('modulate_kernel', 11), + # valid_radius_factor=self.test_cfg.get('valid_radius_factor', + # 0.0546875), + # use_udp=self.test_cfg.get('use_udp', False), + # target_type=self.test_cfg.get('target_type', 'GaussianHeatmap')) + + # all_preds = np.zeros((batch_size, preds.shape[1], 3), dtype=np.float32) + # all_boxes = np.zeros((batch_size, 6), dtype=np.float32) + # all_preds[:, :, 0:2] = preds[:, :, 0:2] + # all_preds[:, :, 2:3] = maxvals + # all_boxes[:, 0:2] = c[:, 0:2] + # all_boxes[:, 2:4] = s[:, 0:2] + # all_boxes[:, 4] = np.prod(s * 200.0, axis=1) + # all_boxes[:, 5] = score + + # result = {} + + # result['preds'] = all_preds + # result['boxes'] = all_boxes + # result['image_paths'] = image_paths + # result['bbox_ids'] = bbox_ids + + return None + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + +def build_conv_layer(cfg, *args, **kwargs) -> 
nn.Module: + """LICENSE""" + + if cfg is None: + cfg_ = dict(type='Conv2d') + else: + if not isinstance(cfg, dict): + raise TypeError('cfg must be a dict') + if 'type' not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type != 'Conv2d': + raise KeyError(f'Unrecognized layer type {layer_type}') + else: + conv_layer = nn.Conv2d + + layer = conv_layer(*args, **kwargs, **cfg_) + + return layer + + +def build_upsample_layer(cfg, *args, **kwargs) -> nn.Module: + + if not isinstance(cfg, dict): + raise TypeError(f'cfg must be a dict, but got {type(cfg)}') + if 'type' not in cfg: + raise KeyError( + f'the cfg dict must contain the key "type", but got {cfg}') + cfg_ = cfg.copy() + + layer_type = cfg_.pop('type') + if layer_type != 'deconv': + raise KeyError(f'Unrecognized upsample type {layer_type}') + else: + upsample = nn.ConvTranspose2d + + if upsample is nn.Upsample: + cfg_['mode'] = layer_type + layer = upsample(*args, **kwargs, **cfg_) + return layer + +# @HEADS.register_module() + + +class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead): + """Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple + Baselines for Human Pose Estimation and Tracking``. + + TopdownHeatmapSimpleHead is consisted of (>=0) number of deconv layers + and a simple conv2d layer. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + in_index (int|Sequence[int]): Input feature index. Default: 0 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + Default: None. + + - 'resize_concat': Multiple feature maps will be resized to the + same size as the first one and then concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_keypoint (dict): Config for keypoint loss. Default: None. 
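+
+    Example (illustrative sketch; the values follow the vit_base config in
+    ``vit_pose.py``, where a 256x192 crop gives a 16x12 feature map)::
+
+        head = TopdownHeatmapSimpleHead(
+            in_channels=768, out_channels=17,
+            num_deconv_layers=2,
+            num_deconv_filters=(256, 256),
+            num_deconv_kernels=(4, 4),
+            extra=dict(final_conv_kernel=1))
+        heatmaps = head(torch.randn(1, 768, 16, 12))  # -> (1, 17, 64, 48)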
+ """ + + def __init__(self, + in_channels, + out_channels, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + in_index=0, + input_transform=None, + align_corners=False, + loss_keypoint=None, + train_cfg=None, + test_cfg=None, + upsample=0,): + super().__init__() + + self.in_channels = in_channels + self.loss = None + self.upsample = upsample + + self.train_cfg = {} if train_cfg is None else train_cfg + self.test_cfg = {} if test_cfg is None else test_cfg + self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap') + + self._init_inputs(in_channels, in_index, input_transform) + self.in_index = in_index + self.align_corners = align_corners + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + identity_final_layer = False + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [0, 1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + elif extra['final_conv_kernel'] == 1: + padding = 0 + else: + # 0 for Identity mapping. + identity_final_layer = True + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + if identity_final_layer: + self.final_layer = nn.Identity() + else: + conv_channels = num_deconv_filters[ + -1] if num_deconv_layers > 0 else self.in_channels + + layers = [] + if extra is not None: + num_conv_layers = extra.get('num_conv_layers', 0) + num_conv_kernels = extra.get('num_conv_kernels', + [1] * num_conv_layers) + + for i in range(num_conv_layers): + layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=conv_channels, + kernel_size=num_conv_kernels[i], + stride=1, + padding=(num_conv_kernels[i] - 1) // 2)) + layers.append( + nn.BatchNorm2d(conv_channels) + ) + layers.append(nn.ReLU(inplace=True)) + + layers.append( + build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=conv_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding)) + + if len(layers) > 1: + self.final_layer = nn.Sequential(*layers) + else: + self.final_layer = layers[0] + + def get_loss(self, output, target, target_weight): + """Calculate top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. + """ + + losses = dict() + + assert not isinstance(self.loss, nn.Sequential) + assert target.dim() == 4 and target_weight.dim() == 3 + losses['heatmap_loss'] = self.loss(output, target, target_weight) + + return losses + + def get_accuracy(self, output, target, target_weight): + """Calculate accuracy for top-down keypoint loss. + + Note: + - batch_size: N + - num_keypoints: K + - heatmaps height: H + - heatmaps weight: W + + Args: + output (torch.Tensor[N,K,H,W]): Output heatmaps. + target (torch.Tensor[N,K,H,W]): Target heatmaps. + target_weight (torch.Tensor[N,K,1]): + Weights across different joint types. 
+ """ + + accuracy = dict() + + if self.target_type == 'GaussianHeatmap': + _, avg_acc, _ = pose_pck_accuracy( + output.detach().cpu().numpy(), + target.detach().cpu().numpy(), + target_weight.detach().cpu().numpy().squeeze(-1) > 0) + accuracy['acc_pose'] = float(avg_acc) + + return accuracy + + def forward(self, x): + """Forward function.""" + x = self._transform_inputs(x) + x = self.deconv_layers(x) + x = self.final_layer(x) + return x + + def inference_model(self, x, flip_pairs=None): + """Inference function. + + Returns: + output_heatmap (np.ndarray): Output heatmaps. + + Args: + x (torch.Tensor[N,K,H,W]): Input features. + flip_pairs (None | list[tuple]): + Pairs of keypoints which are mirrored. + """ + output = self.forward(x) + + if flip_pairs is not None: + output_heatmap = flip_back( + output.detach().cpu().numpy(), + flip_pairs, + target_type=self.target_type) + # feature is not aligned, shift flipped heatmap for higher accuracy + if self.test_cfg.get('shift_heatmap', False): + output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1] + else: + output_heatmap = output.detach().cpu().numpy() + return output_heatmap + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform is not None, in_channels and in_index must be + list or tuple, with the same length. + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + + - 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + - 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + - None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor] | Tensor): multi-level img features. 
+ + Returns: + Tensor: The transformed inputs + """ + if not isinstance(inputs, list): + if not isinstance(inputs, list): + if self.upsample > 0: + inputs = resize( + input=F.relu(inputs), + scale_factor=self.upsample, + mode='bilinear', + align_corners=self.align_corners + ) + return inputs + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) diff --git a/dp2/detection/models/vit_pose/vit_pose.py b/dp2/detection/models/vit_pose/vit_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd6eafd75fd5208625123c2c74fa7510c871136 --- /dev/null +++ b/dp2/detection/models/vit_pose/vit_pose.py @@ -0,0 +1,218 @@ +# Code adapted from: https://github.com/gpastal24/ViTPose-Pytorch +from .topdown_heatmap_simple_head import TopdownHeatmapSimpleHead +import torch +from .backbone import ViT +import torchvision.transforms.functional as F +import torch.nn as nn +import tops + +model_large = dict( + type="TopDown", + pretrained=None, + backbone=dict( + type="ViT", + img_size=(256, 192), + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.5, + ), + keypoint_head=dict( + type="TopdownHeatmapSimpleHead", + in_channels=1024, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict( + final_conv_kernel=1, + ), + out_channels=17, + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process="default", + shift_heatmap=False, + target_type="GaussianHeatmap", + modulate_kernel=11, + use_udp=True, + ), +) + + +model_base = dict( + type="TopDown", + pretrained=None, + backbone=dict( + type="ViT", + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + 
ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type="TopdownHeatmapSimpleHead", + in_channels=768, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict( + final_conv_kernel=1, + ), + out_channels=17, + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), + train_cfg=dict(), + test_cfg=dict(), +) +model_huge = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=1280, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=17, + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type="GaussianHeatmap", + modulate_kernel=11, + use_udp=True)) + + +class VitPoseModel(nn.Module): + def __init__(self, model_name): + super().__init__() + assert model_name in ["vit_base", "vit_large", "vit_huge"] + model = { + "vit_base": model_base, + "vit_large": model_large, + "vit_huge": model_huge + }[model_name] + weight_url = { + "vit_base": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/90235a26-3b8c-427d-a264-c68155abecdcfcfcd8a9-0388-4575-b85b-607d3c0a9b149bef8f0f-a0f9-4662-a561-1b47ba5f1636", + "vit_large": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a580a44c-0afd-43ac-a2cb-9956c32b1d1a78c51ecb-81bb-4345-8710-13904cb9dbbe0703db2d-8534-42e0-ac4d-518ab51fe7db", + "vit_huge": "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/a33b6ada-4d2f-4ef7-8f83-b33f58b69f5b2a62e181-2131-467d-a900-027157a08571d761fad4-785b-4b84-8596-8932c7857e44" + }[model_name] + file_name = { + "vit_base": "vit-b-multi-coco-595b5e128b.pth", + "vit_large": "vit-l-multi-coco-9475d27cec.pth", + "vit_huge": "vit-h-multi-coco-dbc06d4337.pth", + }[model_name] + # Set check_hash to true if you suspect a download error. 
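+        # The backbone and keypoint head below are built from the mmpose-style
+        # config dicts above (model_base / model_large / model_huge); only the
+        # fields needed for inference are used, so loss_keypoint, train_cfg and
+        # test_cfg are ignored.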
+ weight_path = tops.download_file( + weight_url, file_name=file_name, check_hash=True) + + self.keypoint_head = tops.to_cuda(TopdownHeatmapSimpleHead( + in_channels=model["keypoint_head"]["in_channels"], + out_channels=model["keypoint_head"]["out_channels"], + num_deconv_filters=model["keypoint_head"]["num_deconv_filters"], + num_deconv_kernels=model["keypoint_head"]["num_deconv_kernels"], + num_deconv_layers=model["keypoint_head"]["num_deconv_layers"], + extra=model["keypoint_head"]["extra"], + )) + # print(head) + self.backbone = tops.to_cuda(ViT( + img_size=model["backbone"]["img_size"], + patch_size=model["backbone"]["patch_size"], + embed_dim=model["backbone"]["embed_dim"], + depth=model["backbone"]["depth"], + num_heads=model["backbone"]["num_heads"], + ratio=model["backbone"]["ratio"], + mlp_ratio=model["backbone"]["mlp_ratio"], + qkv_bias=model["backbone"]["qkv_bias"], + drop_path_rate=model["backbone"]["drop_path_rate"], + )) + ckpt = torch.load(weight_path, map_location=tops.get_device()) + self.load_state_dict(ckpt["state_dict"]) + self.backbone.eval() + self.keypoint_head.eval() + + def forward(self, img: torch.Tensor, boxes_ltrb: torch.Tensor): + assert img.ndim == 3 + assert img.dtype == torch.uint8 + assert boxes_ltrb.ndim == 2 and boxes_ltrb.shape[1] == 4 + assert boxes_ltrb.dtype == torch.long + boxes_ltrb = boxes_ltrb.clamp(0) + padded_boxes = torch.zeros_like(boxes_ltrb) + images = torch.zeros((len(boxes_ltrb), 3, 256, 192), device=img.device, dtype=torch.float32) + + for i, (x0, y0, x1, y1) in enumerate(boxes_ltrb): + x1 = min(img.shape[-1], x1) + y1 = min(img.shape[-2], y1) + correction_factor = 256 / 192 * (x1 - x0) / (y1 - y0) + if correction_factor > 1: + # increase y side + center = y0 + (y1 - y0) // 2 + length = (y1-y0).mul(correction_factor).round().long() + y0_new = center - length.div(2).long() + y1_new = center + length.div(2).long() + image_crop = img[:, y0:y1, x0:x1] + # print(y1,y2,x1,x2) + pad = ((y0_new-y0).abs(), (y1_new-y1).abs()) +# pad = (int(abs(y0_new-y0))), int(abs(y1_new-y1)) + image_crop = torch.nn.functional.pad(image_crop, [*(0, 0), *pad]) + padded_boxes[i] = torch.tensor([x0, y0_new, x1, y1_new]) + else: + center = x0 + (x1 - x0) // 2 + length = (x1-x0).div(correction_factor).round().long() + x0_new = center - length.div(2).long() + x1_new = center + length.div(2).long() + image_crop = img[:, y0:y1, x0:x1] + pad = ((x0_new-x0).abs(), (x1_new-x1).abs()) + image_crop = torch.nn.functional.pad(image_crop, [*pad, ]) + padded_boxes[i] = torch.tensor([x0_new, y0, x1_new, y1]) + image_crop = F.resize(image_crop.float(), (256, 192), antialias=True) + image_crop = F.normalize(image_crop, mean=[0.485*255, 0.456*255, + 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) + images[i] = image_crop + + x = self.backbone(images) + out = self.keypoint_head(x) + pts = torch.empty((out.shape[0], out.shape[1], 3), dtype=torch.float32, device=img.device) + # For each human, for each joint: y, x, confidence + b, indices = torch.max(out, dim=2) + b, indices = torch.max(b, dim=2) + + c, indicesc = torch.max(out, dim=3) + c, indicesc = torch.max(c, dim=2) + dim1 = torch.tensor(1./64, device=img.device) + dim2 = torch.tensor(1./48, device=img.device) + for i in range(0, out.shape[0]): + pts[i, :, 0] = indicesc[i, :] * dim1 * (padded_boxes[i][3] - padded_boxes[i][1]) + padded_boxes[i][1] + pts[i, :, 1] = indices[i, :] * dim2 * (padded_boxes[i][2] - padded_boxes[i][0]) + padded_boxes[i][0] + pts[i, :, 2] = c[i, :] + pts = pts[:, :, [1, 0, 2]] + return pts diff --git 
a/dp2/detection/models/vit_pose_maskrcnn.py b/dp2/detection/models/vit_pose_maskrcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f801d186486ef43e01ed554b7db34a69a7de5e34 --- /dev/null +++ b/dp2/detection/models/vit_pose_maskrcnn.py @@ -0,0 +1,73 @@ +import torch +import lzma +from pathlib import Path +from dp2.detection.base import BaseDetector +from .mask_rcnn import MaskRCNNDetector +from ..structures import PersonDetection +from tops import logger +from .vit_pose.vit_pose import VitPoseModel +from ..utils import masks_to_boxes + + +def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): + assert len(box1.shape) == 2 + assert len(box2.shape) == 2 + box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) + # This can be batched + for i, box in enumerate(box1): + is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) + is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) + is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) + box1_inside[i] = is_outside.logical_not().any() + return box1_inside + + +class MaskRCNNVitPose(BaseDetector): + + def __init__( + self, + mask_rcnn_cfg, + cse_post_process_cfg, + score_threshold: float, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + self.vit_pose = VitPoseModel("vit_huge") + + self.cse_post_process_cfg = cse_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}",) + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp, map_location="cpu") + kwargs = dict( + post_process_cfg=self.cse_post_process_cfg, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + maskrcnn_dets = self.mask_rcnn(im) + + maskrcnn_person = { + k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items() + } + boxes = masks_to_boxes(maskrcnn_person["segmentation"]) + keypoints = self.vit_pose(im, boxes).cpu() + keypoints[:, :, -1] = keypoints[:, :, -1] >= 0.3 + persons_without_cse = PersonDetection( + maskrcnn_person["segmentation"], **self.cse_post_process_cfg, + orig_imshape_CHW=im.shape, + keypoints=keypoints + ) + persons_without_cse.pre_process() + + all_detections = [persons_without_cse] + return all_detections diff --git a/dp2/detection/person_detector.py b/dp2/detection/person_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbd0df8c2aa44839a5de8bd9a6aeede054ff2ee --- /dev/null +++ b/dp2/detection/person_detector.py @@ -0,0 +1,135 @@ +import torch +import lzma +from dp2.detection.base import BaseDetector +from .utils import combine_cse_maskrcnn_dets +from .models.cse import CSEDetector +from .models.mask_rcnn import MaskRCNNDetector +from .models.keypoint_maskrcnn import KeypointMaskRCNN +from .structures import CSEPersonDetection, PersonDetection +from pathlib import Path + + +class CSEPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold) + self.post_process_cfg = 
cse_post_process_cfg + self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + embed_map=self.cse_detector.embed_map, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor, cse_dets=None): + mask_dets = self.mask_rcnn(im) + if cse_dets is None: + cse_dets = self.cse_detector(im) + segmentation = mask_dets["segmentation"] + segmentation, cse_dets, _ = combine_cse_maskrcnn_dets( + segmentation, cse_dets, self.iou_combine_threshold + ) + det = CSEPersonDetection( + segmentation=segmentation, + cse_dets=cse_dets, + embed_map=self.cse_detector.embed_map, + orig_imshape_CHW=im.shape, + **self.post_process_cfg + ) + return [det] + + +class MaskRCNNPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + self.post_process_cfg = cse_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + mask_dets = self.mask_rcnn(im) + segmentation = mask_dets["segmentation"] + det = PersonDetection( + segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape + ) + return [det] + + +class KeypointMaskRCNNPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = KeypointMaskRCNN( + **mask_rcnn_cfg, score_threshold=score_threshold + ) + self.post_process_cfg = cse_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + mask_dets = self.mask_rcnn(im) + segmentation = mask_dets["segmentation"] + det = PersonDetection( + segmentation, + **self.post_process_cfg, + orig_imshape_CHW=im.shape, + keypoints=mask_dets["keypoints"] + ) + return [det] diff --git a/dp2/detection/structures.py b/dp2/detection/structures.py new file mode 100644 index 0000000000000000000000000000000000000000..3daf58f4617feedb7724137721e85e32e94b87b2 --- /dev/null +++ b/dp2/detection/structures.py @@ -0,0 +1,504 @@ +import torch +import numpy as np +from dp2 import utils +from dp2.utils import vis_utils, crop_box +from .utils import ( + cut_pad_resize, masks_to_boxes, + get_kernel, transform_embedding, initialize_cse_boxes +) +from .box_utils import get_expanded_bbox, include_box +import torchvision +import tops +from .box_utils_fdf import 
expand_bbox as expand_bbox_fdf + + +class VehicleDetection: + + def __init__(self, segmentation: torch.BoolTensor) -> None: + self.segmentation = segmentation + self.boxes = masks_to_boxes(segmentation) + assert self.boxes.shape[1] == 4, self.boxes.shape + self.n_detections = self.segmentation.shape[0] + area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0]) + + sorted_idx = torch.argsort(area, descending=True) + self.segmentation = self.segmentation[sorted_idx] + self.boxes = self.boxes[sorted_idx].cpu() + + def pre_process(self): + pass + + def get_crop(self, idx: int, im): + assert idx < len(self) + box = self.boxes[idx] + im = crop_box(self.im, box) + mask = crop_box(self.segmentation[idx]) + mask = mask == 0 + return dict(img=im, mask=mask.float(), boxes=box) + + def visualize(self, im): + if len(self) == 0: + return im + im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not()) + return im + + def __len__(self): + return self.n_detections + + @staticmethod + def from_state_dict(state_dict, **kwargs): + numel = np.prod(state_dict["shape"]) + arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel) + segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"]) + return VehicleDetection(segmentation) + + def state_dict(self, **kwargs): + segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy())) + return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape) + + +class FaceDetection: + + def __init__(self, + boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, + keypoints: torch.Tensor = None, + **kwargs) -> None: + + self.boxes = boxes_ltrb.cpu() + assert self.boxes.shape[1] == 4, self.boxes.shape + self.target_imsize = tuple(target_imsize) + # Sory by area to paste in largest faces last + area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1) + idx = area.argsort(descending=False) + self.boxes = self.boxes[idx] + self.fdf128_expand = fdf128_expand + self.orig_keypoints = keypoints + if keypoints is not None: + self.orig_keypoints = self.orig_keypoints[idx] + assert keypoints.shape == (len(boxes_ltrb), 17, 2) or \ + keypoints.shape == (len(boxes_ltrb), 7, 2), keypoints.shape + + def visualize(self, im): + if len(self) == 0: + return im + orig_device = im.device + for box in self.boxes: + simple_expand = False if self.fdf128_expand else True + e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand)) + im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2) + im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2) + if self.orig_keypoints is not None: + im = vis_utils.draw_keypoints(im, self.orig_keypoints, radius=1) + + return im.to(device=orig_device) + + def get_crop(self, idx: int, im): + assert idx < len(self) + box = self.boxes[idx].numpy() + simple_expand = False if self.fdf128_expand else True + expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], simple_expand) + im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True) + + # Find the square mask corresponding to box. 
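+        # box is shifted into the expanded-crop frame and scaled by the
+        # crop-to-target resize factor; the resulting mask is 0 inside the
+        # original face box and 1 in the surrounding context that is kept.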
+ box_mask = box.copy().astype(float) + box_mask[[0, 2]] -= expanded_boxes[0] + box_mask[[1, 3]] -= expanded_boxes[1] + + width = expanded_boxes[2] - expanded_boxes[0] + resize_factor = self.target_imsize[0] / width + box_mask = (box_mask * resize_factor).astype(int) + mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32) + crop_box(mask, box_mask).fill_(0) + if self.orig_keypoints is None: + return dict( + img=im[None], mask=mask[None], + boxes=torch.from_numpy(expanded_boxes).view(1, -1)) + + keypoint = self.orig_keypoints[idx, :7, :2].clone() + keypoint[:, 0] -= expanded_boxes[0] + keypoint[:, 1] -= expanded_boxes[1] + w = expanded_boxes[2] - expanded_boxes[0] + keypoint /= w + keypoint = keypoint.clamp(0, 1) + return dict( + img=im[None], mask=mask[None], + boxes=torch.from_numpy(expanded_boxes).view(1, -1), + keypoints=keypoint[None]) + + def __len__(self): + return len(self.boxes) + + @staticmethod + def from_state_dict(state_dict, **kwargs): + return FaceDetection( + state_dict["boxes"].cpu(), + keypoints=state_dict["orig_keypoints"] if "orig_keypoints" in state_dict else None, + **kwargs) + + def state_dict(self, **kwargs): + return dict( + boxes=self.boxes, + cls=self.__class__, + orig_keypoints=self.orig_keypoints) + + def pre_process(self): + pass + + +def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape): + """ + Dilation happens after padding, which could place dilation in the padded area. + Remove this. + """ + x0, y0, x1, y1 = exp_box + H, W = orig_imshape + # Padding in original image space + p_y0 = max(0, -y0) + p_y1 = max(y1 - H, 0) + p_x0 = max(0, -x0) + p_x1 = max(x1 - W, 0) + resize_ratio = mask.shape[-2] / (y1-y0) + p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]] + mask[..., :p_y0, :] = 0 + mask[..., :p_x0] = 0 + mask[..., mask.shape[-2] - p_y1:, :] = 0 + mask[..., mask.shape[-1] - p_x1:] = 0 + + +class CSEPersonDetection: + + def __init__(self, + segmentation, cse_dets, + target_imsize, + exp_bbox_cfg, exp_bbox_filter, + dilation_percentage: float, + embed_map: torch.Tensor, + orig_imshape_CHW, + normalize_embedding: bool) -> None: + self.segmentation = segmentation + self.cse_dets = cse_dets + self.target_imsize = list(target_imsize) + self.pre_processed = False + self.exp_bbox_cfg = exp_bbox_cfg + self.exp_bbox_filter = exp_bbox_filter + self.dilation_percentage = dilation_percentage + self.embed_map = embed_map + self.embed_map_cpu = embed_map.cpu() + self.normalize_embedding = normalize_embedding + if self.normalize_embedding: + embed_map_mean = self.embed_map.mean(dim=0, keepdim=True) + embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd + self.orig_imshape_CHW = orig_imshape_CHW + + @torch.no_grad() + def pre_process(self): + if self.pre_processed: + return + boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu() + expanded_boxes = [] + included_boxes = [] + for i in range(len(boxes)): + exp_box = get_expanded_bbox( + boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, + target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) + if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): + continue + included_boxes.append(i) + expanded_boxes.append(exp_box) + expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) + self.segmentation = 
self.segmentation[included_boxes] + self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()} + + self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) + area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) + for i, box in enumerate(expanded_boxes): + self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] + + dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) + self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] + self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) + for i in range(len(expanded_boxes)): + remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) + self.boxes = expanded_boxes.cpu() + self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) + + self.pre_processed = True + self.n_detections = len(self.boxes) + self.mask = self.mask.logical_not() + + E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool) + self.vertices = torch.zeros_like(E_mask, dtype=torch.long) + for i in range(self.n_detections): + E_, E_mask[i] = transform_embedding( + self.cse_dets["instance_embedding"][i], + self.cse_dets["instance_segmentation"][i], + self.boxes[i], + self.cse_dets["bbox_XYXY"][i].cpu(), + self.target_imsize + ) + self.vertices[i] = utils.from_E_to_vertex( + E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None] + self.E_mask = E_mask + + sorted_idx = torch.argsort(area, descending=False) + self.mask = self.mask[sorted_idx] + self.boxes = self.boxes[sorted_idx.cpu()] + self.vertices = self.vertices[sorted_idx] + self.E_mask = self.E_mask[sorted_idx] + self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] + + def get_crop(self, idx: int, im): + self.pre_process() + assert idx < len(self) + box = self.boxes[idx] + mask = self.mask[idx] + im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) + + vertices_ = self.vertices[idx] + E_mask_ = self.E_mask[idx].float() + if self.normalize_embedding: + embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ + else: + embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ + + return dict( + img=im, + mask=mask.float()[None], + boxes=box.reshape(1, -1), + E_mask=E_mask_[None], + vertices=vertices_[None], + embed_map=self.embed_map, + embedding=embedding[None], + maskrcnn_mask=self.maskrcnn_mask[idx].float()[None] + ) + + def __len__(self): + self.pre_process() + return self.n_detections + + def state_dict(self, after_preprocess=False): + """ + The processed annotations occupy more space than the original detections. 
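+        When after_preprocess=True, the boolean masks are therefore bit-packed
+        (np.packbits) and the vertex maps stored as int16 to limit that overhead.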
+ """ + if not after_preprocess: + return { + "combined_segmentation": self.segmentation.bool(), + "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(), + "cse_instance_embedding": self.cse_dets["instance_embedding"], + "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(), + "cls": self.__class__, + "orig_imshape_CHW": self.orig_imshape_CHW + } + self.pre_process() + def compress_bool(x): return torch.from_numpy(np.packbits(x.bool().cpu().numpy())) + return dict( + E_mask=compress_bool(self.E_mask), + mask=compress_bool(self.mask), + maskrcnn_mask=compress_bool(self.maskrcnn_mask), + vertices=self.vertices.to(torch.int16).cpu(), + cls=self.__class__, + boxes=self.boxes, + orig_imshape_CHW=self.orig_imshape_CHW, + ) + + @staticmethod + def from_state_dict( + state_dict, embed_map, + post_process_cfg, **kwargs): + after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict + if after_preprocess: + detection = CSEPersonDetection( + segmentation=None, cse_dets=None, embed_map=embed_map, + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg) + detection.vertices = tops.to_cuda(state_dict["vertices"].long()) + numel = np.prod(detection.vertices.shape) + + def unpack_bool(x): + x = torch.from_numpy(np.unpackbits(x.numpy(), count=numel)) + return x.view(*detection.vertices.shape) + detection.E_mask = tops.to_cuda(unpack_bool(state_dict["E_mask"])) + detection.mask = tops.to_cuda(unpack_bool(state_dict["mask"])) + detection.maskrcnn_mask = tops.to_cuda(unpack_bool(state_dict["maskrcnn_mask"])) + detection.n_detections = len(detection.mask) + detection.pre_processed = True + + if isinstance(state_dict["boxes"], np.ndarray): + state_dict["boxes"] = torch.from_numpy(state_dict["boxes"]) + detection.boxes = state_dict["boxes"] + return detection + + cse_dets = dict( + instance_segmentation=state_dict["cse_instance_segmentation"], + instance_embedding=state_dict["cse_instance_embedding"], + embed_map=embed_map, + bbox_XYXY=state_dict["cse_bbox_XYXY"]) + cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()} + + segmentation = state_dict["combined_segmentation"] + return CSEPersonDetection( + segmentation, cse_dets, embed_map=embed_map, + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg) + + def visualize(self, im): + self.pre_process() + if len(self) == 0: + return im + im = vis_utils.draw_cropped_masks( + im.cpu(), self.mask.cpu(), self.boxes, visualize_instances=False) + E = self.embed_map_cpu[self.vertices.long().cpu()].squeeze(1).permute(0, 3, 1, 2) + im = vis_utils.draw_cse_all( + E, self.E_mask.squeeze(1).bool().cpu(), im, + self.boxes, self.embed_map_cpu) + im = torchvision.utils.draw_bounding_boxes(im, self.boxes, colors=(255, 0, 0), width=2) + return im + + +def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes): + keypoints = keypoints.clone() + N = boxes.shape[0] + tops.assert_shape(keypoints, (N, None, 3)) + tops.assert_shape(boxes, (N, 4)) + x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T] + + w = x1 - x0 + h = y1 - y0 + keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w + keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h + def check_outside(x): return (x < 0).logical_or(x > 1) + is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1])) + keypoints[:, :, 2] = keypoints[:, :, 2] > 0 + keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not()) + return keypoints + + +class PersonDetection: + + def __init__( + self, + 
segmentation, + target_imsize, + exp_bbox_cfg, exp_bbox_filter, + dilation_percentage: float, + orig_imshape_CHW, + kp_vis_thr=None, + keypoints=None, + **kwargs) -> None: + self.segmentation = segmentation + self.target_imsize = list(target_imsize) + self.pre_processed = False + self.exp_bbox_cfg = exp_bbox_cfg + self.exp_bbox_filter = exp_bbox_filter + self.dilation_percentage = dilation_percentage + self.orig_imshape_CHW = orig_imshape_CHW + self.orig_keypoints = keypoints + if keypoints is not None: + assert kp_vis_thr is not None + self.kp_vis_thr = kp_vis_thr + + @torch.no_grad() + def pre_process(self): + if self.pre_processed: + return + boxes = masks_to_boxes(self.segmentation).cpu() + expanded_boxes = [] + included_boxes = [] + for i in range(len(boxes)): + exp_box = get_expanded_bbox( + boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, + target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) + if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): + continue + included_boxes.append(i) + expanded_boxes.append(exp_box) + expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) + self.segmentation = self.segmentation[included_boxes] + if self.orig_keypoints is not None: + self.keypoints = self.orig_keypoints[included_boxes].clone() + self.keypoints[:, :, 2] = self.keypoints[:, :, 2] >= self.kp_vis_thr + area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)).cpu() + self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) + for i, box in enumerate(expanded_boxes): + self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] + if self.orig_keypoints is not None: + self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes) + dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) + self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] + self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) + for i in range(len(expanded_boxes)): + remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) + self.boxes = expanded_boxes + self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) + + self.pre_processed = True + self.n_detections = len(self.boxes) + self.mask = self.mask.logical_not() + + sorted_idx = torch.argsort(area, descending=False) + self.mask = self.mask[sorted_idx] + self.boxes = self.boxes[sorted_idx.cpu()] + self.segmentation = self.segmentation[sorted_idx] + self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] + if self.keypoints is not None: + self.keypoints = self.keypoints[sorted_idx.cpu()] + + def get_crop(self, idx: int, im: torch.Tensor): + assert idx < len(self) + self.pre_process() + box = self.boxes[idx] + mask = self.mask[idx][None].float() + im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) + batch = dict( + img=im, mask=mask, boxes=box.reshape(1, -1), + maskrcnn_mask=self.maskrcnn_mask[idx][None].float()) + if self.keypoints is not None: + batch["keypoints"] = self.keypoints[idx:idx+1] + return batch + + def __len__(self): + self.pre_process() + return self.n_detections + + def state_dict(self, **kwargs): + return dict( + segmentation=self.segmentation.bool(), + cls=self.__class__, + orig_imshape_CHW=self.orig_imshape_CHW, + keypoints=self.orig_keypoints + ) + + @staticmethod + def from_state_dict( + state_dict, + post_process_cfg, **kwargs): + return 
PersonDetection( + state_dict["segmentation"], + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg, + keypoints=state_dict["keypoints"]) + + def visualize(self, im): + self.pre_process() + im = im.cpu() + if len(self) == 0: + return im + im = vis_utils.draw_cropped_masks(im.clone(), self.mask.cpu(), self.boxes, visualize_instances=False) + if self.keypoints is not None: + im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes) + return im + + +def get_dilated_boxes(exp_bbox: torch.LongTensor, mask): + """ + mask: resized mask + """ + assert exp_bbox.shape[0] == mask.shape[0] + boxes = masks_to_boxes(mask.squeeze(1)).cpu() + H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0] + boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long() + boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long() + boxes[:, [0, 2]] += exp_bbox[:, 0:1] + boxes[:, [1, 3]] += exp_bbox[:, 1:2] + return boxes diff --git a/dp2/detection/utils.py b/dp2/detection/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31bd8cc40dceae5b83bb52e74cdf4be25e764487 --- /dev/null +++ b/dp2/detection/utils.py @@ -0,0 +1,179 @@ +import cv2 +import numpy as np +import torch +import tops +from skimage.morphology import disk +from torchvision.transforms.functional import resize, InterpolationMode +from functools import lru_cache + + +@lru_cache(maxsize=200) +def get_kernel(n: int): + kernel = disk(n, dtype=bool) + return tops.to_cuda(torch.from_numpy(kernel).bool()) + + +def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape): + """ + Transforms the detected embedding/mask directly to the target image shape + """ + + C, HE, WE = E.shape + assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox) + assert E_bbox[2] >= exp_bbox[0] + assert E_bbox[1] >= exp_bbox[1] + assert E_bbox[3] >= exp_bbox[1] + assert E_bbox[2] <= exp_bbox[2] + assert E_bbox[3] <= exp_bbox[3] + + x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) + x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) + y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) + y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) + new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32) + new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool) + + E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR) + new_E[:, y0:y1, x0:x1] = E + S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0 + new_S[y0:y1, x0:x1] = S + return new_E, new_S + + +def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor): + """ + mask: shape [N, H, W] + """ + assert len(mask1.shape) == 3 + assert len(mask2.shape) == 3 + assert mask1.device == mask2.device, (mask1.device, mask2.device) + assert mask2.dtype == mask2.dtype + assert mask1.dtype == torch.bool + assert mask1.shape[1:] == mask2.shape[1:] + N1, H1, W1 = mask1.shape + N2, H2, W2 = mask2.shape + iou = torch.zeros((N1, N2), dtype=torch.float32) + for i in range(N1): + cur = mask1[i:i+1] + inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() + union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() + iou[i] = inter / union + return iou + + +def 
find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float): + N1 = mask1.shape[0] + N2 = mask2.shape[0] + ious = pairwise_mask_iou(mask1, mask2).cpu().numpy() + indices = np.array([idx for idx, iou in np.ndenumerate(ious)]) + ious = ious.flatten() + mask = ious >= iou_threshold + ious = ious[mask] + indices = indices[mask] + + # do not sort by iou to keep ordering of mask rcnn / cse sorting. + taken1 = np.zeros((N1), dtype=bool) + taken2 = np.zeros((N2), dtype=bool) + matches = [] + for i, j in indices: + if taken1[i].any() or taken2[j].any(): + continue + matches.append((i, j)) + taken1[i] = True + taken2[j] = True + return matches + + +def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float): + assert 0 < iou_threshold <= 1 + matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold) + H, W = segmentation.shape[1:] + new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device) + cse_im_seg = cse_dets["im_segmentation"] + for idx, (i, j) in enumerate(matches): + new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j]) + cse_dets = dict( + instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]], + instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]], + bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]], + scores=cse_dets["scores"][[j for (i, j) in matches]], + ) + return new_seg, cse_dets, np.array(matches).reshape(-1, 2) + + +def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor): + """ + cse_boxes can be outside of segmentation. + """ + boxes = masks_to_boxes(segmentation) + + assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape) + combined = torch.stack((boxes, cse_boxes), dim=-1) + boxes = torch.cat(( + combined[:, :2].min(dim=2).values, + combined[:, 2:].max(dim=2).values, + ), dim=1) + return boxes + + +def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False): + """ + Crops or pads x to fit in the bbox and resize to target shape. + """ + C, H, W = x.shape + x0, y0, x1, y1 = bbox + + if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H: + new_x = x[:, y0:y1, x0:x1] + else: + new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device) + y0_t = max(0, -y0) + y1_t = min(y1-y0, (y1-y0)-(y1-H)) + x0_t = max(0, -x0) + x1_t = min(x1-x0, (x1-x0)-(x1-W)) + x0 = max(0, x0) + y0 = max(0, y0) + x1 = min(x1, W) + y1 = min(y1, H) + new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1] + # Nearest upsampling often generates more sharp synthesized identities. + interp = InterpolationMode.BICUBIC + if (y1-y0) < target_shape[0] and (x1-x0) < target_shape[1]: + interp = InterpolationMode.NEAREST + antialias = interp == InterpolationMode.BICUBIC + if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]: + return new_x + if x.dtype == torch.bool: + new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5 + elif x.dtype == torch.float32: + new_x = resize(new_x, target_shape, interpolation=interp, antialias=antialias) + elif x.dtype == torch.uint8: + if fdf_resize: # FDF dataset is created with cv2 INTER_AREA. + # Incorrect resizing generates noticeable poorer inpaintings. 
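+            # Example (hypothetical sizes): a 300x200 uint8 face crop resized down to
+            # (128, 128) takes the cv2.INTER_AREA branch below, matching how the FDF
+            # dataset was generated, while a 64x48 crop is upsampled with bicubic + antialias.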
+ upsampling = ((y1-y0) * (x1-x0)) < (target_shape[0] * target_shape[1]) + if upsampling: + new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, + antialias=True).round().clamp(0, 255).byte() + else: + device = new_x.device + new_x = new_x.permute(1, 2, 0).cpu().numpy() + new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA) + new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device) + else: + new_x = resize(new_x.float(), target_shape, interpolation=interp, + antialias=antialias).round().clamp(0, 255).byte() + else: + raise ValueError(f"Not supported dtype: {x.dtype}") + return new_x + + +def masks_to_boxes(segmentation: torch.Tensor): + assert len(segmentation.shape) == 3 + x = segmentation.any(dim=1).byte() # Compress rows + x0 = x.argmax(dim=1) + + x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1) + y = segmentation.any(dim=2).byte() + y0 = y.argmax(dim=1) + y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1) + return torch.stack([x0, y0, x1, y1], dim=1) diff --git a/dp2/discriminator/__init__.py b/dp2/discriminator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2c41a0f895fcb39d80b0cfeec8795899a04fa9 --- /dev/null +++ b/dp2/discriminator/__init__.py @@ -0,0 +1 @@ +from .sg2_discriminator import SG2Discriminator diff --git a/dp2/discriminator/sg2_discriminator.py b/dp2/discriminator/sg2_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..269675d44fec26f1838b56092bf98e28945a3462 --- /dev/null +++ b/dp2/discriminator/sg2_discriminator.py @@ -0,0 +1,79 @@ +from sg3_torch_utils.ops import upfirdn2d +import torch +import numpy as np +import torch.nn as nn +from .. import layers +from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block + + +class SG2Discriminator(layers.Module): + + def __init__( + self, + cnum: int, + max_cnum_mul: int, + imsize, + min_fmap_resolution: int, + im_channels: int, + input_condition: bool, + conv_clamp: int, + input_cse: bool, + cse_nc: int, + fix_residual: bool, + ): + super().__init__() + + cse_nc = 0 if cse_nc is None else cse_nc + self._max_imsize = max(imsize) + self._cnum = cnum + self._max_cnum_mul = max_cnum_mul + self._min_fmap_resolution = min_fmap_resolution + self._input_condition = input_condition + self.input_cse = input_cse + self.layers = nn.ModuleList() + + out_ch = self.get_chsize(self._max_imsize) + self.from_rgb = Block( + im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1), + out_ch, conv_clamp=conv_clamp + ) + n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 + + for i in range(n_levels): + resolution = [x//2**i for x in imsize] + in_ch = self.get_chsize(max(resolution)) + out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution)) + + down = 2 + if i == 0: + down = 1 + block = ResidualBlock( + in_ch, out_ch, down=down, conv_clamp=conv_clamp, + fix_residual=fix_residual + ) + self.layers.append(block) + self.output_layer = DiscriminatorEpilogue( + out_ch, resolution, conv_clamp=conv_clamp) + + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + + def forward(self, img, condition, mask, embedding=None, E_mask=None, **kwargs): + to_cat = [img] + if self._input_condition: + to_cat.extend([condition, mask, ]) + if self.input_cse: + to_cat.extend([embedding, E_mask]) + x = torch.cat(to_cat, dim=1) + x = self.from_rgb(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + x = self.output_layer(x) 
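+        # Usage sketch (hyperparameter values are illustrative, not taken from a config):
+        #   D = SG2Discriminator(cnum=64, max_cnum_mul=8, imsize=(288, 160),
+        #                        min_fmap_resolution=4, im_channels=3, input_condition=True,
+        #                        conv_clamp=256, input_cse=False, cse_nc=None, fix_residual=True)
+        #   score = D(img, condition, mask)["score"]  # typically shape [batch, 1]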
+ return dict(score=x) + + def get_chsize(self, imsize): + n = int(np.log2(self._max_imsize) - np.log2(imsize)) + mul = min(2 ** n, self._max_cnum_mul) + ch = self._cnum * mul + return int(ch) diff --git a/dp2/gan_trainer.py b/dp2/gan_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..149e0f6be0602e90f26b67551615d0abb96aad56 --- /dev/null +++ b/dp2/gan_trainer.py @@ -0,0 +1,325 @@ +import atexit +from collections import defaultdict +import logging +import typing +import torch +import time +from dp2.utils import vis_utils +from dp2 import utils +from tops import logger, checkpointer +import tops +from easydict import EasyDict + + +def accumulate_gradients(params, fp16_ddp_accumulate): + if len(params) == 0: + return + params = [param for param in params if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + orig_dtype = flat.dtype + if tops.world_size() > 1: + if fp16_ddp_accumulate: + flat = flat.half() / tops.world_size() + else: + flat /= tops.world_size() + torch.distributed.all_reduce(flat) + flat = flat.to(orig_dtype) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + + +def accumulate_buffers(module: torch.nn.Module): + buffers = [buf for buf in module.buffers()] + if len(buffers) == 0: + return + flat = torch.cat([buf.flatten() for buf in buffers]) + if tops.world_size() > 1: + torch.distributed.all_reduce(flat) + flat /= tops.world_size() + bufs = flat.split([buf.numel() for buf in buffers]) + for old, new in zip(buffers, bufs): + old.copy_(new.reshape(old.shape), non_blocking=True) + + +def check_ddp_consistency(module): + if tops.world_size() == 1: + return + assert isinstance(module, torch.nn.Module) + assert isinstance(module, torch.nn.Module) + params_buffs = list(module.named_parameters()) + list(module.named_buffers()) + for name, tensor in params_buffs: + fullname = type(module).__name__ + '.' 
+ name + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = torch.nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + + +class AverageMeter(): + def __init__(self) -> None: + self.to_log = dict() + self.n = defaultdict(int) + pass + + @torch.no_grad() + def update(self, values: dict): + for key, value in values.items(): + self.n[key] += 1 + if key in self.to_log: + self.to_log[key] += value.mean().detach() + else: + self.to_log[key] = value.mean().detach() + + def get_average(self): + return {key: value / self.n[key] for key, value in self.to_log.items()} + + +class GANTrainer: + + def __init__( + self, + G: torch.nn.Module, + D: torch.nn.Module, + G_EMA: torch.nn.Module, + D_optim: torch.optim.Optimizer, + G_optim: torch.optim.Optimizer, + dl_train: typing.Iterator, + dl_val: typing.Iterable, + scaler_D: torch.cuda.amp.GradScaler, + scaler_G: torch.cuda.amp.GradScaler, + ims_per_log: int, + max_images_to_train: int, + loss_handler, + ims_per_val: int, + evaluate_fn, + batch_size: int, + broadcast_buffers: bool, + fp16_ddp_accumulate: bool, + save_state: bool, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.G = G + self.D = D + self.G_EMA = G_EMA + self.D_optim = D_optim + self.G_optim = G_optim + self.dl_train = dl_train + self.dl_val = dl_val + self.scaler_D = scaler_D + self.scaler_G = scaler_G + self.loss_handler = loss_handler + self.max_images_to_train = max_images_to_train + self.images_per_val = ims_per_val + self.images_per_log = ims_per_log + self.evaluate_fn = evaluate_fn + self.batch_size = batch_size + self.broadcast_buffers = broadcast_buffers + self.fp16_ddp_accumulate = fp16_ddp_accumulate + + self.train_state = EasyDict( + next_log_step=0, + next_val_step=ims_per_val, + total_time=0 + ) + + checkpointer.register_models(dict( + generator=G, discriminator=D, EMA_generator=G_EMA, + D_optimizer=D_optim, + G_optimizer=G_optim, + train_state=self.train_state, + scaler_D=self.scaler_D, + scaler_G=self.scaler_G + )) + if checkpointer.has_checkpoint(): + checkpointer.load_registered_models() + logger.log(f"Resuming training from: global step: {logger.global_step()}") + else: + logger.add_dict({ + "stats/discriminator_parameters": tops.num_parameters(self.D), + "stats/generator_parameters": tops.num_parameters(self.G), + }, commit=False) + if save_state: + # If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint. 
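+            # Registering the save below with atexit means the newest weights are still
+            # flushed on an unexpected shutdown, at the cost that the snapshot may not
+            # correspond to the last fully completed logging step.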
+ atexit.register(checkpointer.save_registered_models) + + self._ims_per_log = ims_per_log + + self.to_log = AverageMeter() + self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad] + self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad] + logger.add_dict({ + "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D), + "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G), + }, commit=False, level=logging.INFO) + check_ddp_consistency(self.D) + check_ddp_consistency(self.G) + check_ddp_consistency(self.G_EMA.generator) + + def train_loop(self): + self.log_time() + while logger.global_step() <= self.max_images_to_train: + batch = next(self.dl_train) + self.G_EMA.update_beta() + self.to_log.update(self.step_D(batch)) + self.to_log.update(self.step_G(batch)) + self.G_EMA.update(self.G) + + if logger.global_step() >= self.train_state.next_log_step: + to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()} + to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()}) + to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()}) + self.to_log = AverageMeter() + logger.add_dict(to_log, commit=True) + self.train_state.next_log_step += self.images_per_log + if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8: + print("Stopping training as gradient scale < 1e-8") + logger.log("Stopping training as gradient scale < 1e-8") + break + + if logger.global_step() >= self.train_state.next_val_step: + self.evaluate() + self.log_time() + self.save_images() + self.train_state.next_val_step += self.images_per_val + logger.step(self.batch_size*tops.world_size()) + logger.log(f"Reached end of training at step {logger.global_step()}.") + checkpointer.save_registered_models() + + def estimate_ims_per_hour(self): + batch = next(self.dl_train) + n_ims = int(100e3) + n_steps = int(n_ims / (self.batch_size * tops.world_size())) + n_ims = n_steps * self.batch_size * tops.world_size() + for i in range(10): # Warmup + self.G_EMA.update_beta() + self.step_D(batch) + self.step_G(batch) + self.G_EMA.update(self.G) + start_time = time.time() + for i in utils.tqdm_(list(range(n_steps))): + self.G_EMA.update_beta() + self.step_D(batch) + self.step_G(batch) + self.G_EMA.update(self.G) + total_time = time.time() - start_time + ims_per_sec = n_ims / total_time + ims_per_hour = ims_per_sec * 60*60 + ims_per_day = ims_per_hour * 24 + logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M") + logger.log(f"Images per day: {ims_per_day/1e6:.3f}M") + import math + ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4)) + logger.log(f"Images per 4 days: {ims_per_4_day}") + logger.add_dict({ + "stats/ims_per_day": ims_per_day, + "stats/ims_per_4_day": ims_per_4_day + }) + + def log_time(self): + if not hasattr(self, "start_time"): + self.start_time = time.time() + self.last_time_step = logger.global_step() + return + n_images = logger.global_step() - self.last_time_step + if n_images == 0: + return + n_secs = time.time() - self.start_time + n_ims_per_sec = n_images / n_secs + training_time_hours = n_secs / 60 / 60 + self.train_state.total_time += training_time_hours + remaining_images = self.max_images_to_train - logger.global_step() + remaining_time = remaining_images / n_ims_per_sec / 60 / 60 + logger.add_dict({ + "stats/n_ims_per_sec": n_ims_per_sec, + "stats/total_traing_time_hours": self.train_state.total_time, + 
"stats/remaining_time_hours": remaining_time + }) + self.last_time_step = logger.global_step() + self.start_time = time.time() + + def save_images(self): + dl_val = iter(self.dl_val) + batch = next(dl_val) + # TRUNCATED visualization + ims_to_log = 8 + self.G_EMA.eval() + z = self.G.get_z(batch["img"]) + fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"] + fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu() + if "__key__" in batch: + batch.pop("__key__") + real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] + to_vis = torch.cat((real, fakes_truncated)) + logger.add_images("images/truncated", to_vis, nrow=2) + + # Diverse images + ims_diverse = 3 + batch = next(dl_val) + to_vis = [] + + for i in range(ims_diverse): + z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1) + fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu() + to_vis.append(fakes) + if "__key__" in batch: + batch.pop("__key__") + reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] + to_vis.insert(0, reals) + to_vis = torch.cat(to_vis) + logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1) + + self.G_EMA.train() + pass + + def evaluate(self): + logger.log("Stating evaluation.") + self.G_EMA.eval() + try: + checkpointer.save_registered_models(max_keep=3) + except Exception: + logger.log("Could not save checkpoint.") + if self.broadcast_buffers: + check_ddp_consistency(self.G) + check_ddp_consistency(self.D) + metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val) + metrics = {f"metrics/{k}": v for k, v in metrics.items()} + logger.add_dict(metrics, level=logger.logger.INFO) + + def step_D(self, batch): + utils.set_requires_grad(self.trainable_params_D, True) + utils.set_requires_grad(self.trainable_params_G, False) + tops.zero_grad(self.D) + loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D) + with torch.autograd.profiler.record_function("D_step"): + self.scaler_D.scale(loss).backward() + accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate) + if self.broadcast_buffers: + accumulate_buffers(self.D) + accumulate_buffers(self.G) + # Step will not unscale if unscale is called previously. 
+ self.scaler_D.step(self.D_optim) + self.scaler_D.update() + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, False) + return to_log + + def step_G(self, batch): + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, True) + tops.zero_grad(self.G) + loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G) + with torch.autograd.profiler.record_function("G_step"): + self.scaler_G.scale(loss).backward() + accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate) + if self.broadcast_buffers: + accumulate_buffers(self.G) + accumulate_buffers(self.D) + self.scaler_G.step(self.G_optim) + self.scaler_G.update() + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, False) + return to_log diff --git a/dp2/generator/__init__.py b/dp2/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/generator/base.py b/dp2/generator/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ad403785887b8db971956a23ac8bdd330a04c509 --- /dev/null +++ b/dp2/generator/base.py @@ -0,0 +1,149 @@ +import torch +import numpy as np +import tqdm +import tops +from ..layers import Module +from ..layers.sg2_layers import FullyConnectedLayer + + +class BaseGenerator(Module): + + def __init__(self, z_channels: int): + super().__init__() + self.z_channels = z_channels + self.latent_space = "Z" + + @torch.no_grad() + def get_z( + self, + x: torch.Tensor = None, + z: torch.Tensor = None, + truncation_value: float = None, + batch_size: int = None, + dtype=None, device=None) -> torch.Tensor: + """Generates a latent variable for generator. + """ + if z is not None: + return z + if x is not None: + batch_size = x.shape[0] + dtype = x.dtype + device = x.device + if device is None: + device = tops.get_device() + if truncation_value == 0: + return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype) + z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype) + if truncation_value is None: + return z + while z.abs().max() > truncation_value: + m = z.abs() > truncation_value + z[m] = torch.rand_like(z)[m] + return z + + def sample(self, truncation_value, z=None, **kwargs): + """ + Samples via interpolating to the mean (0). + """ + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + if z is None: + z = self.get_z(kwargs["condition"]) + z = z * truncation_value + return self.forward(**kwargs, z=z) + + +class SG2StyleNet(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + num_layers=2, # Number of mapping layers. + lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta=0.998, # Decay for tracking the moving average of W during training. + ): + super().__init__() + self.z_dim = z_dim + self.w_dim = w_dim + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + # Construct layers. 
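+        # Example (hypothetical dims): z_dim=512, w_dim=512, num_layers=2 gives
+        # features = [512, 512, 512], i.e. fc0: 512 -> 512 and fc1: 512 -> 512.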
+ features = [self.z_dim] + [self.w_dim] * self.num_layers + for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): + layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, update_emas=False, **kwargs): + tops.assert_shape(z, [None, self.z_dim]) + + # Embed, normalize, and concatenate inputs. + x = z.to(torch.float32) + x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() + # Execute layers. + for idx in range(self.num_layers): + x = getattr(self, f'fc{idx}')(x) + # Update moving average of W. + if update_emas: + self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}' + + def update_w(self, n=int(10e3), batch_size=32): + """ + Calculate w_ema over n iterations. + Useful in cases where w_ema is calculated incorrectly during training. + """ + n = n // batch_size + for i in tqdm.trange(n, desc="Updating w"): + z = torch.randn((batch_size, self.z_dim), device=tops.get_device()) + self(z, update_emas=True) + + def get_truncated(self, truncation_value, condition, z=None, **kwargs): + if z is None: + z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device()) + w = self(z) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + return self.w_avg.to(w.dtype).lerp(w, truncation_value) + + def multi_modal_truncate(self, truncation_value, condition, w_indices, z=None, **kwargs): + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + if z is None: + z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device()) + w = self(z) + if w_indices is None: + w_indices = np.random.randint(0, len(self.w_centers), size=(len(w))) + w_centers = self.w_centers[w_indices].to(w.device) + w = w_centers.to(w.dtype).lerp(w, truncation_value) + return w + +class BaseStyleGAN(BaseGenerator): + + def __init__(self, z_channels: int, w_dim: int): + super().__init__(z_channels) + self.style_net = SG2StyleNet(z_channels, w_dim) + self.latent_space = "W" + + def get_w(self, z, update_emas): + return self.style_net(z, update_emas=update_emas) + + @torch.no_grad() + def sample(self, truncation_value, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + w = self.style_net.get_truncated(truncation_value, **kwargs) + return self.forward(**kwargs, w=w) + + def update_w(self, *args, **kwargs): + self.style_net.update_w(*args, **kwargs) + + @torch.no_grad() + def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): + w = self.style_net.multi_modal_truncate(truncation_value, w_indices=w_indices, **kwargs) + return self.forward(**kwargs, w=w) diff --git a/dp2/generator/deep_privacy1.py b/dp2/generator/deep_privacy1.py new file mode 100644 index 0000000000000000000000000000000000000000..531d780ad995081b233c8d8d73242571eb7d33c4 --- /dev/null +++ b/dp2/generator/deep_privacy1.py @@ -0,0 +1,648 @@ +import torch +import torch.nn as nn +from easydict import EasyDict +from .base import BaseGenerator +import numpy as np +from typing import List + + +class LatentVariableConcat(nn.Module): + + def __init__(self, conv2d_config): + super().__init__() + + def forward(self, _inp): + x, mask, batch = _inp + z = batch["z"] + x = torch.cat((x, z), dim=1) + return (x, mask, batch) + + +def 
get_padding(kernel_size: int, dilation: int, stride: int): + out = (dilation * (kernel_size - 1) - 1) / 2 + 1 + return int(np.floor(out)) + + +class Conv2d(nn.Conv2d): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=None, dilation=1, groups=1, + bias=True, padding_mode='zeros', + demodulation=False, wsconv=False, gain=1, + *args, **kwargs): + if padding is None: + padding = get_padding(kernel_size, dilation, stride) + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode) + self.demodulation = demodulation + self.wsconv = wsconv + if self.wsconv: + fan_in = np.prod(self.weight.shape[1:]) / self.groups + self.ws_scale = gain / np.sqrt(fan_in) + nn.init.normal_(self.weight) + if bias: + nn.init.constant_(self.bias, val=0) + assert not self.padding_mode == "circular",\ + "conv2d_forward does not support circular padding. Look at original pytorch code" + + def _get_weight(self): + weight = self.weight + if self.wsconv: + weight = self.ws_scale * weight + if self.demodulation: + demod = torch.rsqrt(weight.pow(2).sum([1, 2, 3]) + 1e-7) + weight = weight * demod.view(self.out_channels, 1, 1, 1) + return weight + + def conv2d_forward(self, x, weight, bias=True): + bias_ = None + if bias: + bias_ = self.bias + return nn.functional.conv2d(x, weight, bias_, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, _inp): + x, mask = _inp + weight = self._get_weight() + return self.conv2d_forward(x, weight), mask + + def __repr__(self): + return ", ".join([ + super().__repr__(), + f"Demodulation={self.demodulation}", + f"Weight Scale={self.wsconv}", + f"Bias={self.bias is not None}" + ]) + + +class LeakyReLU(nn.LeakyReLU): + + def forward(self, _inp): + x, mask = _inp + return super().forward(x), mask + + +class AvgPool2d(nn.AvgPool2d): + + def forward(self, _inp): + x, mask, *args = _inp + x = super().forward(x) + mask = super().forward(mask) + if len(args) > 0: + return (x, mask, *args) + return x, mask + + +def up(x): + if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1: + # Analytical normalization + return x + return nn.functional.interpolate( + x, scale_factor=2, mode="nearest") + + +class NearestUpsample(nn.Module): + + def forward(self, _inp): + x, mask, *args = _inp + x = up(x) + mask = up(mask) + if len(args) > 0: + return (x, mask, *args) + return x, mask + + +class PixelwiseNormalization(nn.Module): + + def forward(self, _inp): + x, mask = _inp + norm = torch.rsqrt((x**2).mean(dim=1, keepdim=True) + 1e-7) + return x * norm, mask + + +class Linear(nn.Linear): + + def __init__(self, in_features, out_features): + super().__init__(in_features, out_features) + self.linear = nn.Linear(in_features, out_features) + fanIn = in_features + self.wtScale = 1 / np.sqrt(fanIn) + + nn.init.normal_(self.weight) + nn.init.constant_(self.bias, val=0) + + def _get_weight(self): + return self.weight * self.wtScale + + def forward_linear(self, x, weight): + return nn.functional.linear(x, weight, self.bias) + + def forward(self, x): + return self.forward_linear(x, self._get_weight()) + + +class OneHotPoseConcat(nn.Module): + + def forward(self, _inp): + x, mask, batch = _inp + landmarks = batch["landmarks_oh"] + res = x.shape[-1] + landmark = landmarks[res] + x = torch.cat((x, landmark), dim=1) + del batch["landmarks_oh"][res] + return x, mask, batch + + +def transition_features(x_old, x_new, transition_variable): + assert x_old.shape == x_new.shape,\ + "Old shape: {}, New: 
{}".format(x_old.shape, x_new.shape) + return torch.lerp(x_old.float(), x_new.float(), transition_variable) + + +class TransitionBlock(nn.Module): + + def forward(self, _inp): + x, mask, batch = _inp + x = transition_features( + batch["x_old"], x, batch["transition_value"]) + mask = transition_features( + batch["mask_old"], mask, batch["transition_value"]) + del batch["x_old"] + del batch["mask_old"] + return x, mask, batch + + +class UnetSkipConnection(nn.Module): + + def __init__(self, conv2d_config: dict, in_channels: int, + out_channels: int, resolution: int, + residual: bool, enabled: bool): + super().__init__() + self.use_iconv = conv2d_config.conv.type == "iconv" + self._in_channels = in_channels + self._out_channels = out_channels + self._resolution = resolution + self._enabled = enabled + self._residual = residual + if self.use_iconv: + self.beta0 = torch.nn.Parameter(torch.tensor(1.)) + self.beta1 = torch.nn.Parameter(torch.tensor(1.)) + else: + if self._residual: + self.conv = build_base_conv( + conv2d_config, False, in_channels // 2, + out_channels, kernel_size=1, padding=0) + else: + self.conv = ConvAct( + conv2d_config, in_channels, out_channels, + kernel_size=1, padding=0) + + def forward(self, _inp): + if not self._enabled: + return _inp + x, mask, batch = _inp + skip_x, skip_mask = batch["unet_features"][self._resolution] + assert x.shape == skip_x.shape, (x.shape, skip_x.shape) + del batch["unet_features"][self._resolution] + if self.use_iconv: + denom = skip_mask * self.beta0.relu() + mask * self.beta1.relu() + 1e-8 + gamma = skip_mask * self.beta0.relu() / denom + x = skip_x * gamma + (1 - gamma) * x + mask = skip_mask * gamma + (1 - gamma) * mask + else: + if self._residual: + skip_x, skip_mask = self.conv((skip_x, skip_mask)) + x = (x + skip_x) / np.sqrt(2) + if self._probabilistic: + mask = (mask + skip_mask) / np.sqrt(2) + else: + x = torch.cat((x, skip_x), dim=1) + x, mask = self.conv((x, mask)) + return x, mask, batch + + def __repr__(self): + return " ".join([ + self.__class__.__name__, + f"In channels={self._in_channels}", + f"Out channels={self._out_channels}", + f"Residual: {self._residual}", + f"Enabled: {self._enabled}" + f"IConv: {self.use_iconv}" + ]) + + +def get_conv(ctype, post_act): + type2conv = { + "conv": Conv2d, + "gconv": GatedConv + } + # Do not apply for output layer + if not post_act and ctype in ["gconv", "iconv"]: + return type2conv["conv"] + assert ctype in type2conv + return type2conv[ctype] + + +def build_base_conv( + conv2d_config, post_act: bool, *args, **kwargs) -> nn.Conv2d: + for k, v in conv2d_config.conv.items(): + assert k not in kwargs + kwargs[k] = v + # Demodulation should not be used for output layers. 
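+    # Weight demodulation (StyleGAN2-style per-output-channel rescaling) would distort
+    # the statistics of a final RGB/output projection, so it is only enabled when a
+    # post-activation follows (post_act=True).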
+ demodulation = conv2d_config.normalization == "demodulation" and post_act + kwargs["demodulation"] = demodulation + conv = get_conv(conv2d_config.conv.type, post_act) + return conv(*args, **kwargs) + + +def build_post_activation(in_channels, conv2d_config) -> List[nn.Module]: + _layers = [] + negative_slope = conv2d_config.leaky_relu_nslope + _layers.append(LeakyReLU(negative_slope, inplace=True)) + if conv2d_config.normalization == "pixel_wise": + _layers.append(PixelwiseNormalization()) + return _layers + + +def build_avgpool(conv2d_config, kernel_size) -> nn.AvgPool2d: + return AvgPool2d(kernel_size) + + +def build_convact(conv2d_config, *args, **kwargs): + conv = build_base_conv(conv2d_config, True, *args, **kwargs) + out_channels = conv.out_channels + post_act = build_post_activation(out_channels, conv2d_config) + return nn.Sequential(conv, *post_act) + + +class ConvAct(nn.Module): + + def __init__(self, conv2d_config, *args, **kwargs): + super().__init__() + self._conv2d_config = conv2d_config + conv = build_base_conv(conv2d_config, True, *args, **kwargs) + self.in_channels = conv.in_channels + self.out_channels = conv.out_channels + _layers = [conv] + _layers.extend(build_post_activation(self.out_channels, conv2d_config)) + self.layers = nn.Sequential(*_layers) + + def forward(self, _inp): + return self.layers(_inp) + + +class GatedConv(Conv2d): + + def __init__(self, in_channels, out_channels, *args, **kwargs): + out_channels *= 2 + super().__init__(in_channels, out_channels, *args, **kwargs) + assert self.out_channels % 2 == 0 + self.lrelu = nn.LeakyReLU(0.2, inplace=True) + self.sigmoid = nn.Sigmoid() + + def conv2d_forward(self, x, weight, bias=True): + x_ = super().conv2d_forward(x, weight, bias) + x = x_[:, :self.out_channels // 2] + y = x_[:, self.out_channels // 2:] + x = self.lrelu(x) + y = y.sigmoid() + assert x.shape == y.shape, f"{x.shape}, {y.shape}" + return x * y + + +class BasicBlock(nn.Module): + + def __init__( + self, conv2d_config, resolution: int, in_channels: int, + out_channels: List[int], residual: bool): + super().__init__() + assert len(out_channels) == 2 + self._resolution = resolution + self._residual = residual + self.out_channels = out_channels + _layers = [] + _in_channels = in_channels + for out_ch in out_channels: + conv = build_base_conv( + conv2d_config, True, _in_channels, out_ch, kernel_size=3, + resolution=resolution) + _layers.append(conv) + _layers.extend(build_post_activation(_in_channels, conv2d_config)) + _in_channels = out_ch + self.layers = nn.Sequential(*_layers) + if self._residual: + self.residual_conv = build_base_conv( + conv2d_config, post_act=False, in_channels=in_channels, + out_channels=out_channels[-1], + kernel_size=1, padding=0) + self.const = 1 / np.sqrt(2) + + def forward(self, _inp): + x, mask, batch = _inp + y = x + mask_ = mask + assert y.shape[-1] == self._resolution or y.shape[-1] == 1 + y, mask = self.layers((x, mask)) + if self._residual: + residual, mask_ = self.residual_conv((x, mask_)) + y = (y + residual) * self.const + mask = (mask + mask_) * self.const + return y, mask, batch + + def extra_repr(self): + return f"Residual={self._residual}, Resolution={self._resolution}" + + +class PoseNormalize(nn.Module): + + @torch.no_grad() + def forward(self, x): + return x * 2 - 1 + + +class ScalarPoseFCNN(nn.Module): + + def __init__(self, pose_size, hidden_size, + output_shape): + super().__init__() + pose_size = pose_size + self._hidden_size = hidden_size + output_size = np.prod(output_shape) + self.output_shape = 
output_shape + self.pose_preprocessor = nn.Sequential( + PoseNormalize(), + Linear(pose_size, hidden_size), + nn.LeakyReLU(.2), + Linear(hidden_size, output_size), + nn.LeakyReLU(.2) + ) + + def forward(self, _inp): + x, mask, batch = _inp + pose_info = batch["landmarks"] + del batch["landmarks"] + pose = self.pose_preprocessor(pose_info) + pose = pose.view(-1, *self.output_shape) + if x.shape[0] == 1 and x.shape[2] == 1 and x.shape[3] == 1: + # Analytical normalization propagation + pose = pose.mean(dim=2, keepdim=True).mean(dim=3, keepdims=True) + x = torch.cat((x, pose), dim=1) + return x, mask, batch + + def __repr__(self): + return " ".join([ + self.__class__.__name__, + f"hidden_size={self._hidden_size}", + f"output shape={self.output_shape}" + ]) + + +class Attention(nn.Module): + + def __init__(self, in_channels): + super(Attention, self).__init__() + # Channel multiplier + self.in_channels = in_channels + self.theta = Conv2d( + self.in_channels, self.in_channels // 8, kernel_size=1, padding=0, + bias=False) + self.phi = Conv2d( + self.in_channels, self.in_channels // 8, kernel_size=1, padding=0, + bias=False) + self.g = Conv2d( + self.in_channels, self.in_channels // 2, kernel_size=1, padding=0, + bias=False) + self.o = Conv2d( + self.in_channels // 2, self.in_channels, kernel_size=1, padding=0, + bias=False) + # Learnable gain parameter + self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True) + + def forward(self, _inp): + x, mask, batch = _inp + # Apply convs + theta, _ = self.theta((x, None)) + phi = nn.functional.max_pool2d(self.phi((x, None))[0], [2, 2]) + g = nn.functional.max_pool2d(self.g((x, None))[0], [2, 2]) + # Perform reshapes + theta = theta.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3]) + phi = phi.view(-1, self.in_channels // 8, x.shape[2] * x.shape[3] // 4) + g = g.view(-1, self.in_channels // 2, x.shape[2] * x.shape[3] // 4) + # Matmul and softmax to get attention maps + beta = nn.functional.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) + # Attention map times g path + + o = self.o((torch.bmm(g, beta.transpose(1, 2)).view(-1, + self.in_channels // 2, x.shape[2], x.shape[3]), None))[0] + return self.gamma * o + x, mask, batch + + +class MSGGenerator(BaseGenerator): + + def __init__(self): + super().__init__(512) + max_imsize = 128 + unet = dict(enabled=True, residual=False) + + min_fmap_resolution = 4 + model_size = 512 + image_channels = 3 + pose_size = 14 + residual = False + conv_size = { + 4: model_size, + 8: model_size, + 16: model_size, + 32: model_size, + 64: model_size//2, + 128: model_size//4, + 256: model_size//8, + 512: model_size//16 + } + self.removable_hooks = [] + self.rgb_convolutions = nn.ModuleDict() + self.max_imsize = max_imsize + self._image_channels = image_channels + self._min_fmap_resolution = min_fmap_resolution + self._residual = residual + self._pose_size = pose_size + self.current_imsize = max_imsize + self._unet_cfg = unet + self.concat_input_mask = True + self.res2channels = {int(k): v for k, v in conv_size.items()} + + self.conv2d_config = EasyDict( + pixel_normalization=True, + leaky_relu_nslope=.2, + normalization="pixel_wise", + conv=dict( + type="conv", + wsconv=True, + gain=1, + ) + ) + self._init_decoder() + self._init_encoder() + + def _init_encoder(self): + self.encoder = nn.ModuleList() + imsize = self.max_imsize + self.from_rgb = build_convact( + self.conv2d_config, + in_channels=self._image_channels + self.concat_input_mask*2, + out_channels=self.res2channels[imsize], + kernel_size=1) + while imsize 
>= self._min_fmap_resolution: + current_size = self.res2channels[imsize] + next_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)] + block = BasicBlock( + self.conv2d_config, imsize, current_size, + [current_size, next_size], self._residual) + self.encoder.add_module(f"basic_block{imsize}", block) + if imsize != self._min_fmap_resolution: + self.encoder.add_module( + f"downsample{imsize}", AvgPool2d(2)) + imsize //= 2 + + def _init_decoder(self): + self.decoder = nn.ModuleList() + self.decoder.add_module( + "latent_concat", LatentVariableConcat(self.conv2d_config)) + if self._pose_size > 0: + m = self._min_fmap_resolution + pose_shape = (16, m, m) + pose_fcnn = ScalarPoseFCNN(self._pose_size, 128, pose_shape) + self.decoder.add_module("pose_fcnn", pose_fcnn) + imsize = self._min_fmap_resolution + self.rgb_convolutions = nn.ModuleDict() + while imsize <= self.max_imsize: + current_size = self.res2channels[max(imsize//2, self._min_fmap_resolution)] + start_size = current_size + if imsize == self._min_fmap_resolution: + start_size += 32 + if self._pose_size > 0: + start_size += 16 + else: + self.decoder.add_module(f"upsample{imsize}", NearestUpsample()) + skip = UnetSkipConnection( + self.conv2d_config, current_size*2, current_size, imsize, + **self._unet_cfg) + self.decoder.add_module(f"skip_connection{imsize}", skip) + next_size = self.res2channels[imsize] + block = BasicBlock( + self.conv2d_config, imsize, start_size, [start_size, next_size], + residual=self._residual) + self.decoder.add_module(f"basic_block{imsize}", block) + + to_rgb = build_base_conv( + self.conv2d_config, False, in_channels=next_size, + out_channels=self._image_channels, kernel_size=1) + self.rgb_convolutions[str(imsize)] = to_rgb + imsize *= 2 + self.norm_constant = len(self.rgb_convolutions) + + def forward_decoder(self, x, mask, batch): + imsize_start = max(x.shape[-1] // 2, 1) + rgb = torch.zeros( + (x.shape[0], self._image_channels, + imsize_start, imsize_start), + dtype=x.dtype, device=x.device) + mask_size = 1 + mask_out = torch.zeros( + (x.shape[0], mask_size, + imsize_start, imsize_start), + dtype=x.dtype, device=x.device) + imsize = self._min_fmap_resolution // 2 + for module in self.decoder: + x, mask, batch = module((x, mask, batch)) + if isinstance(module, BasicBlock): + imsize *= 2 + rgb = up(rgb) + mask_out = up(mask_out) + conv = self.rgb_convolutions[str(imsize)] + rgb_, mask_ = conv((x, mask)) + assert rgb_.shape == rgb.shape,\ + f"rgb_ {rgb_.shape}, rgb: {rgb.shape}" + rgb = rgb + rgb_ + return rgb / self.norm_constant, mask_out + + def forward_encoder(self, x, mask, batch): + if self.concat_input_mask: + x = torch.cat((x, mask, 1 - mask), dim=1) + unet_features = {} + x, mask = self.from_rgb((x, mask)) + for module in self.encoder: + x, mask, batch = module((x, mask, batch)) + if isinstance(module, BasicBlock): + unet_features[module._resolution] = (x, mask) + return x, mask, unet_features + + def forward( + self, + condition, + mask, keypoints=None, z=None, + **kwargs): + keypoints = keypoints.flatten(start_dim=1).clip(-1, 1) + if z is None: + z = self.get_z(condition) + z = z.view(-1, 32, 4, 4) + batch = dict( + landmarks=keypoints, + z=z) + orig_mask = mask + x, mask, unet_features = self.forward_encoder(condition, mask, batch) + batch = dict( + landmarks=keypoints, + z=z, + unet_features=unet_features) + x, mask = self.forward_decoder(x, mask, batch) + x = condition * orig_mask + (1 - orig_mask) * x + return dict(img=x) + + def load_state_dict(self, state_dict, strict=True): + 
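+        # Remaps keys from older DeepPrivacy checkpoints, where blocks were indexed by
+        # position (basic_block0, basic_block1, ...), to the resolution-keyed module
+        # names used here (basic_block4, basic_block8, ...).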
if "parameters" in state_dict: + state_dict = state_dict["parameters"] + old_checkpoint = any("basic_block0" in key for key in state_dict) + if not old_checkpoint: + return super().load_state_dict(state_dict, strict=strict) + mapping = {} + imsize = self._min_fmap_resolution + i = 0 + while imsize <= self.max_imsize: + old_key = f"decoder.basic_block{i}." + new_key = f"decoder.basic_block{imsize}." + mapping[old_key] = new_key + if i >= 1: + old_key = old_key.replace("basic_block", "skip_connection") + new_key = new_key.replace("basic_block", "skip_connection") + mapping[old_key] = new_key + mapping[old_key] = new_key + old_key = f"encoder.basic_block{i}." + new_key = f"encoder.basic_block{imsize}." + mapping[old_key] = new_key + old_key = "from_rgb.conv.layers.0." + new_key = "from_rgb.0." + mapping[old_key] = new_key + i += 1 + imsize *= 2 + new_sd = {} + for key, value in state_dict.items(): + old_key = key + if "from_rgb" in key: + new_sd[key.replace("encoder.", "").replace(".conv.layers", "")] = value + continue + for subkey, new_subkey in mapping.items(): + if subkey in key: + old_key = key + key = key.replace(subkey, new_subkey) + + break + if "decoder.to_rgb" in key: + continue + + new_sd[key] = value + return super().load_state_dict(new_sd, strict=strict) + + def update_w(self, *args, **kwargs): + return diff --git a/dp2/generator/dummy_generators.py b/dp2/generator/dummy_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..c81b4d4f70bd84fb42bfe8ab3d6bd06f918533c4 --- /dev/null +++ b/dp2/generator/dummy_generators.py @@ -0,0 +1,60 @@ +import torch +from .base import BaseGenerator +from torchvision.transforms.functional import gaussian_blur +import torch.nn.functional as F + + +class PixelationGenerator(BaseGenerator): + + def __init__(self, pixelation_size, **kwargs): + super().__init__(z_channels=0) + self.pixelation_size = pixelation_size + self.z_channels = 0 + self.latent_space = None + + def forward(self, img, condition, mask, **kwargs): + old_shape = img.shape[-2:] + img = F.interpolate(img, size=( + self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True) + img = F.interpolate(img, size=old_shape, mode="bilinear", align_corners=True) + out = img*(1-mask) + condition*mask + return {"img": out} + + +class MaskOutGenerator(BaseGenerator): + + def __init__(self, noise: str, **kwargs): + super().__init__(z_channels=0) + self.noise = noise + self.z_channels = 0 + assert self.noise in ["rand", "constant"] + self.latent_space = None + + def forward(self, img, condition, mask, **kwargs): + + if self.noise == "constant": + img = torch.zeros_like(img) + elif self.noise == "rand": + img = torch.rand_like(img) + out = img*(1-mask) + condition*mask + return {"img": out} + + +class IdentityGenerator(BaseGenerator): + + def __init__(self): + super().__init__(z_channels=0) + + def forward(self, img, condition, mask, **kwargs): + return dict(img=img) + + +class GaussianBlurGenerator(BaseGenerator): + + def __init__(self): + super().__init__(z_channels=0) + self.sigma = 7 + + def forward(self, img, condition, mask, **kwargs): + img_blur = gaussian_blur(img, kernel_size=min(self.sigma*3, img.shape[-1]), sigma=self.sigma) + return dict(img=img * mask + (1-mask) * img_blur) diff --git a/dp2/generator/stylegan_unet.py b/dp2/generator/stylegan_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3dfc46da323d04919cf5c166ec038820eac1ad --- /dev/null +++ b/dp2/generator/stylegan_unet.py @@ -0,0 +1,211 @@ +import torch 
+import numpy as np +from dp2.layers import Sequential +from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock +from .base import BaseStyleGAN +from typing import List, Tuple +from .utils import spatial_embed_keypoints, mask_output + + +def get_chsize(imsize, cnum, max_imsize, max_cnum_mul): + n = int(np.log2(max_imsize) - np.log2(imsize)) + mul = min(2**n, max_cnum_mul) + ch = cnum * mul + return int(ch) + + +class StyleGANUnet(BaseStyleGAN): + def __init__( + self, + scale_grad: bool, + im_channels: int, + min_fmap_resolution: int, + imsize: List[int], + cnum: int, + max_cnum_mul: int, + mask_output: bool, + conv_clamp: int, + input_cse: bool, + cse_nc: int, + n_middle_blocks: int, + input_keypoints: bool, + n_keypoints: int, + input_keypoint_indices: Tuple[int], + fix_errors: bool, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.n_keypoints = n_keypoints + self.input_keypoint_indices = list(input_keypoint_indices) + self.input_keypoints = input_keypoints + assert not (input_cse and input_keypoints) + cse_nc = 0 if cse_nc is None else cse_nc + self.imsize = imsize + self._cnum = cnum + self._max_cnum_mul = max_cnum_mul + self._min_fmap_resolution = min_fmap_resolution + self._image_channels = im_channels + self._max_imsize = max(imsize) + self.input_cse = input_cse + self.gain_unet = np.sqrt(1/3) + n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 + encoder_layers = [] + self.from_rgb = Conv2d( + im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices), + cnum, 1 + ) + for i in range(n_levels): # Encoder layers + resolution = [x//2**i for x in imsize] + in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + second_ch = in_ch + out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) + down = 2 + + if i == 0: # first (lowest) block. 
Downsampling is performed at the start of the block + down = 1 + if i == n_levels - 1: + out_ch = second_ch + block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors) + encoder_layers.append(block) + self._encoder_out_shape = [ + get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul), + *resolution] + + self.encoder = torch.nn.ModuleList(encoder_layers) + + # initialize decoder + decoder_layers = [] + for i in range(n_levels): + resolution = [x//2**(n_levels-1-i) for x in imsize] + in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) + out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + if i == 0: # first (lowest) block + in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + + up = 1 + if i != n_levels - 1: + up = 2 + block = ResidualBlock( + in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3), + w_dim=self.style_net.w_dim, norm=True, up=up, + fix_residual=fix_errors + ) + decoder_layers.append(block) + if i != 0: + unet_block = Conv2d( + in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True, + gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5)) + setattr(self, f"unet_block{i}", unet_block) + + # Initialize "middle blocks" that do not have down/up sample + middle_blocks = [] + for i in range(n_middle_blocks): + ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul) + block = ResidualBlock( + ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3), + w_dim=self.style_net.w_dim, norm=True, + ) + middle_blocks.append(block) + if n_middle_blocks != 0: + self.middle_blocks = Sequential(*middle_blocks) + self.decoder = torch.nn.ModuleList(decoder_layers) + self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp) + # Initialize "middle blocks" that do not have down/up sample + self.decoder = torch.nn.ModuleList(decoder_layers) + self.scale_grad = scale_grad + self.mask_output = mask_output + + def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs): + for i, layer in enumerate(self.decoder): + if i != 0: + unet_layer = getattr(self, f"unet_block{i}") + x = x + unet_layer(unet_features[-i]) + x = layer(x, w=w, s=s) + x = self.to_rgb(x) + if self.mask_output: + x = mask_output(True, condition, x, mask) + return dict(img=x) + + def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs): + if self.input_cse: + x = torch.cat((condition, mask, embedding, E_mask), dim=1) + else: + x = torch.cat((condition, mask), dim=1) + if self.input_keypoints: + keypoints = keypoints[:, self.input_keypoint_indices] + one_hot_pose = spatial_embed_keypoints(keypoints, x) + x = torch.cat((x, one_hot_pose), dim=1) + x = self.from_rgb(x) + + unet_features = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + if i != len(self.encoder)-1: + unet_features.append(x) + if hasattr(self, "middle_blocks"): + for layer in self.middle_blocks: + x = layer(x) + return x, unet_features + + def forward( + self, condition, mask, + z=None, embedding=None, w=None, update_emas=False, x=None, + s=None, + keypoints=None, + unet_features=None, + E_mask=None, + **kwargs): + # Used to skip sampling from encoder in inference. E.g. for w projection. 
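+        # Reusing a precomputed (x, unet_features) pair is only valid at inference time;
+        # during training the encoder is always run on the masked condition image.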
+ if x is not None and unet_features is not None: + assert not self.training + else: + x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs) + if w is None: + if z is None: + z = self.get_z(condition) + w = self.get_w(z, update_emas=update_emas) + return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs) + + +class ComodStyleUNet(StyleGANUnet): + + def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None: + super().__init__(**kwargs) + min_fmap = min(self._encoder_out_shape[1:]) + enc_out_ch = self._encoder_out_shape[0] + n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res))) + comod_layers = [] + in_ch = enc_out_ch + for i in range(n_down): + comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod)) + in_ch = 256 + if n_down == 0: + comod_layers = [Conv2d(in_ch, 256, kernel_size=3)] + comod_layers.append(torch.nn.Flatten()) + out_res = [x//2**n_down for x in self._encoder_out_shape[1:]] + in_ch_fc = np.prod(out_res) * 256 + comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod)) + self.comod_block = Sequential(*comod_layers) + self.comod_fc = FullyConnectedLayer( + 512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod) + + def forward_dec(self, x, w, unet_features, condition, mask, **kwargs): + y = self.comod_block(x) + y = torch.cat((y, w), dim=1) + y = self.comod_fc(y) + for i, layer in enumerate(self.decoder): + if i != 0: + unet_layer = getattr(self, f"unet_block{i}") + x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5)) + x = layer(x, w=y) + x = self.to_rgb(x) + if self.mask_output: + x = mask_output(True, condition, x, mask) + return dict(img=x) + + def get_comod_y(self, batch, w): + x, unet_features = self.forward_enc(**batch) + y = self.comod_block(x) + y = torch.cat((y, w), dim=1) + y = self.comod_fc(y) + return y diff --git a/dp2/generator/utils.py b/dp2/generator/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5a983df28d81b10f81635881f77c4747162431 --- /dev/null +++ b/dp2/generator/utils.py @@ -0,0 +1,49 @@ +import torch +import tops +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +@torch.no_grad() +def spatial_embed_keypoints(keypoints: torch.Tensor, x): + tops.assert_shape(keypoints, (None, None, 3)) + B, N_K, _ = keypoints.shape + H, W = x.shape[-2:] + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32) + x, y, visible = keypoints.chunk(3, dim=2) + x = (x * W).round().long().clamp(0, W-1) + y = (y * H).round().long().clamp(0, H-1) + kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) + pos = (kp_idx*(H*W) + y*W + x + 1) + # Offset all by 1 to index invisible keypoints as 0 + pos = (pos * visible.round().long()).squeeze(dim=-1) + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) + keypoint_spatial.scatter_(1, pos, 1) + keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) + return keypoint_spatial + + +class MaskOutput(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x_real, x_fake, mask): + ctx.save_for_backward(mask) + out = x_real * mask + (1-mask) * x_fake + return out + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + fake_grad = grad_output + mask, = ctx.saved_tensors + fake_grad = fake_grad * (1 - mask) 
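+        # Normalize by the fraction of generated (mask == 0) pixels so the gradient
+        # magnitude does not depend on how much of the image is masked out.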
+ known_percentage = mask.view(mask.shape[0], -1).mean(dim=1) + fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1) + return None, fake_grad, None + + +def mask_output(scale_grad, x_real, x_fake, mask): + if scale_grad: + return MaskOutput.apply(x_real, x_fake, mask) + return x_real * mask + (1-mask) * x_fake diff --git a/dp2/infer.py b/dp2/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..ff1ea39ed6919be6f4e282823e0c291e48e248ee --- /dev/null +++ b/dp2/infer.py @@ -0,0 +1,78 @@ +import tops +import torch +from tops import checkpointer +from tops.config import instantiate +from tops.logger import warn +from dp2.generator.deep_privacy1 import MSGGenerator + + +def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None): + state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"] + if ckpt_mapper is not None: + state = ckpt_mapper(state) + if isinstance(G, MSGGenerator): + G.load_state_dict(state) + else: + load_state_dict(G, state) + tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M") + if "w_centers" in ckpt: + G.style_net.register_buffer("w_centers", ckpt["w_centers"]) + tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}") + if "style_net.w_centers" in state: + G.style_net.register_buffer("w_centers", state["style_net.w_centers"]) + tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}") + + +def build_trained_generator(cfg, map_location=None): + map_location = map_location if map_location is not None else tops.get_device() + G = instantiate(cfg.generator) + G.eval() + G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None + if hasattr(cfg, "ckpt_mapper"): + ckpt_mapper = instantiate(cfg.ckpt_mapper) + else: + ckpt_mapper = None + if "model_url" in cfg.common: + ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum) + load_generator_state(ckpt, G, ckpt_mapper) + return G.to(map_location) + try: + ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") + load_generator_state(ckpt, G, ckpt_mapper) + except FileNotFoundError as e: + tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}") + return G.to(map_location) + + +def build_trained_discriminator(cfg, map_location=None): + map_location = map_location if map_location is not None else tops.get_device() + D = instantiate(cfg.discriminator).to(map_location) + D.eval() + try: + ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") + if hasattr(cfg, "ckpt_mapper_D"): + ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"]) + D.load_state_dict(ckpt["discriminator"]) + except FileNotFoundError as e: + tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}") + return D + + +def load_state_dict(module: torch.nn.Module, state_dict: dict): + module_sd = module.state_dict() + to_remove = [] + for key, item in state_dict.items(): + if key not in module_sd: + continue + if item.shape != module_sd[key].shape: + to_remove.append(key) + warn(f"Incorrect shape. 
Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}") + for key in to_remove: + state_dict.pop(key) + for key, item in state_dict.items(): + if key not in module_sd: + warn(f"Did not fin key in model state dict: {key}") + for key, item in module_sd.items(): + if key not in state_dict: + warn(f"Did not find key in state dict: {key}") + module.load_state_dict(state_dict, strict=False) diff --git a/dp2/layers/__init__.py b/dp2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..680be5614803ce6a74a2994c023ab90b7d27a74b --- /dev/null +++ b/dp2/layers/__init__.py @@ -0,0 +1,22 @@ +from typing import Dict +import torch +import tops +import torch.nn as nn + + +class Sequential(nn.Sequential): + + def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: + for module in self: + x = module(x, **kwargs) + return x + + +class Module(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def extra_repr(self): + num_params = tops.num_parameters(self) / 10**6 + return f"Num params: {num_params:.3f}M" diff --git a/dp2/layers/sg2_layers.py b/dp2/layers/sg2_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ce111ffd5a4ce1d5a879ba49cd7a7108b35defdf --- /dev/null +++ b/dp2/layers/sg2_layers.py @@ -0,0 +1,229 @@ +from typing import List +import numpy as np +import torch +import tops +import torch.nn.functional as F +from sg3_torch_utils.ops import conv2d_resample +from sg3_torch_utils.ops import upfirdn2d +from sg3_torch_utils.ops import bias_act +from sg3_torch_utils.ops.fma import fma + + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation='linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. + ): + super().__init__() + self.repr = dict( + in_features=in_features, out_features=out_features, bias=bias, + activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + self.in_features = in_features + self.out_features = out_features + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None and self.bias_gain != 1: + b = b * self.bias_gain + x = F.linear(x, w) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + +class Conv2d(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size=3, # Convolution kernel size. + up=1, # Integer upsampling factor. + down=1, # Integer downsampling factor + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
+ bias=True, + norm=False, + lr_multiplier=1, + bias_init=0, + w_dim=None, + gain=1, + ): + super().__init__() + if norm: + self.norm = torch.nn.InstanceNorm2d(None) + assert norm in [True, False] + self.up = up + self.down = down + self.activation = activation + self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain + self.out_channels = out_channels + self.in_channels = in_channels + self.padding = kernel_size // 2 + + self.repr = dict( + in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, up=up, down=down, + activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias, + ) + + if self.up == 1 and self.down == 1: + self.resample_filter = None + else: + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.act_gain = bias_act.activation_funcs[activation].def_gain * gain + self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None + self.bias_gain = lr_multiplier + if w_dim is not None: + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) + + def forward(self, x, w=None, s=None): + tops.assert_shape(x, [None, self.weight.shape[1], None, None]) + if s is not None: + s = s[..., :self.in_channels * 2] + gamma, beta = s.view(-1, self.in_channels * 2, 1, 1).chunk(2, dim=1) + x = fma(x, gamma, beta) + elif hasattr(self, "affine"): + gamma = self.affine(w).view(-1, self.in_channels, 1, 1) + beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) + x = fma(x, gamma, beta) + w = self.weight * self.weight_gain + # Removing flip weight is not safe. + x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, + self.down, self.padding, flip_weight=self.up == 1) + if hasattr(self, "norm"): + x = self.norm(x) + b = self.bias * self.bias_gain if self.bias is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + +class Block(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + up=1, + down=1, + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.in_channels = in_channels + self.down = down + self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) + self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs) + + def forward(self, x, **layer_kwargs): + x = self.conv0(x, **layer_kwargs) + x = self.conv1(x, **layer_kwargs) + return x + + +class ResidualBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + up=1, + down=1, + gain_out=np.sqrt(0.5), + fix_residual: bool = False, + **layer_kwargs, # Arguments for conv layer. 
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.down = down + self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) + + self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs) + + self.skip = Conv2d( + in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down, + activation="linear" if fix_residual else "lrelu", + gain=gain_out + ) + self.gain_out = gain_out + + def forward(self, x, w=None, s=None, **layer_kwargs): + y = self.skip(x) + s_ = next(s) if s is not None else None + x = self.conv0(x, w, s=s_, **layer_kwargs) + s_ = next(s) if s is not None else None + x = self.conv1(x, w, s=s_, **layer_kwargs) + x = y + x + return x + + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = x.reshape(G, -1, F, c, H, W) + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + + +class DiscriminatorEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + resolution: List[int], # Resolution of this block. + mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation='lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, + num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2d( + in_channels + mbstd_num_channels, in_channels, + kernel_size=3, activation=activation, conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * resolution[0] * resolution[1], in_channels, activation=activation) + self.out = FullyConnectedLayer(in_channels, 1) + + def forward(self, x): + tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW] + # Main layers. 
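+        # The minibatch-stddev layer appends feature maps with cross-sample statistics,
+        # helping the discriminator detect a lack of diversity in generated batches.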
+ if self.mbstd is not None: + x = self.mbstd(x) + x = self.conv(x) + x = self.fc(x.flatten(1)) + x = self.out(x) + return x diff --git a/dp2/loss/__init__.py b/dp2/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16cbdd0051ff51bda0828f37c9ef5faed65a9ed7 --- /dev/null +++ b/dp2/loss/__init__.py @@ -0,0 +1 @@ +from .sg2_loss import StyleGAN2Loss diff --git a/dp2/loss/pl_regularization.py b/dp2/loss/pl_regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..cf47b9b4ea9f9a22eebae8d0d4be755f75c3878c --- /dev/null +++ b/dp2/loss/pl_regularization.py @@ -0,0 +1,49 @@ +import torch +import tops +import numpy as np +from sg3_torch_utils.ops import conv2d_gradfix + +pl_mean_total = torch.zeros([]) + + +class PLRegularization: + + def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs): + self.pl_mean = torch.zeros([], device=tops.get_device()) + self.pl_weight = weight + self.batch_shrink = batch_shrink + self.pl_decay = pl_decay + self.scale_by_mask = scale_by_mask + + def __call__(self, G, batch, grad_scaler): + batch_size = batch["img"].shape[0] // self.batch_shrink + batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"} + if "embed_map" in batch: + batch["embed_map"] = batch["embed_map"] + z = G.get_z(batch["img"]) + + with torch.cuda.amp.autocast(tops.AMP()): + gen_ws = G.style_net(z) + gen_img = G(**batch, w=gen_ws)["img"].float() + pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) + with conv2d_gradfix.no_weight_gradients(): + # Sums over HWC + pl_grads = torch.autograd.grad( + outputs=[grad_scaler.scale(gen_img * pl_noise)], + inputs=[gen_ws], + create_graph=True, + grad_outputs=torch.ones_like(gen_img), + only_inputs=True)[0] + + pl_grads = pl_grads.float() / grad_scaler.get_scale() + if self.scale_by_mask: + # Percentage of pixels known + scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1) + pl_grads = pl_grads / scaling + pl_lengths = pl_grads.square().sum(1).sqrt() + pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) + if not torch.isnan(pl_mean).any(): + self.pl_mean.copy_(pl_mean.detach()) + pl_penalty = (pl_lengths - pl_mean).square() + to_log = dict(pl_penalty=pl_penalty.mean().detach()) + return pl_penalty.view(-1) * self.pl_weight, to_log diff --git a/dp2/loss/r1_regularization.py b/dp2/loss/r1_regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..f974c5542bf49ed36b54b46cfc7c9c9bfaff9ce3 --- /dev/null +++ b/dp2/loss/r1_regularization.py @@ -0,0 +1,32 @@ +import torch +import tops + + +def r1_regularization( + real_img, real_score, mask, lambd: float, lazy_reg_interval: int, + lazy_regularization: bool, + scaler: torch.cuda.amp.GradScaler, mask_out: bool, + mask_out_scale: bool, + **kwargs +): + grad = torch.autograd.grad( + outputs=scaler.scale(real_score), + inputs=real_img, + grad_outputs=torch.ones_like(real_score), + create_graph=True, + only_inputs=True, + )[0] + inv_scale = 1.0 / scaler.get_scale() + grad = grad * inv_scale + with torch.cuda.amp.autocast(tops.AMP()): + if mask_out: + grad = grad * (1 - mask) + grad = grad.square().sum(dim=[1, 2, 3]) + if mask_out and mask_out_scale: + total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3] + n_fake = (1-mask).sum(dim=[1, 2, 3]) + scaling = total_pixels / n_fake + grad = grad * scaling + if lazy_regularization: + lambd_ = lambd * lazy_reg_interval / 2 # From stylegan2, lazy 
regularization + return grad * lambd_, grad.detach() diff --git a/dp2/loss/sg2_loss.py b/dp2/loss/sg2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..763263e2e7cb9330f24265ba8008e152fa4110f0 --- /dev/null +++ b/dp2/loss/sg2_loss.py @@ -0,0 +1,96 @@ +import functools +import torch +import tops +from tops import logger +from dp2.utils import forward_D_fake +from .utils import nsgan_d_loss, nsgan_g_loss +from .r1_regularization import r1_regularization +from .pl_regularization import PLRegularization + + +class StyleGAN2Loss: + + def __init__( + self, + D, + G, + r1_opts: dict, + EP_lambd: float, + lazy_reg_interval: int, + lazy_regularization: bool, + pl_reg_opts: dict, + ) -> None: + self.gradient_step_D = 0 + self._lazy_reg_interval = lazy_reg_interval + self.D = D + self.G = G + self.EP_lambd = EP_lambd + self.lazy_regularization = lazy_regularization + self.r1_reg = functools.partial( + r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval, + lazy_regularization=lazy_regularization) + self.do_PL_Reg = False + if pl_reg_opts.weight > 0: + self.pl_reg = PLRegularization(**pl_reg_opts) + self.do_PL_Reg = True + self.pl_start_nimg = pl_reg_opts.start_nimg + + def D_loss(self, batch: dict, grad_scaler): + to_log = {} + # Forward through G and D + do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0 + if do_GP: + batch["img"] = batch["img"].detach().requires_grad_(True) + with torch.cuda.amp.autocast(enabled=tops.AMP()): + with torch.no_grad(): + G_fake = self.G(**batch, update_emas=True) + D_out_real = self.D(**batch) + + D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) + + # Non saturating loss + nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"]) + tops.assert_shape(nsgan_loss, (batch["img"].shape[0], )) + to_log["d_loss"] = nsgan_loss.mean() + total_loss = nsgan_loss + epsilon_penalty = D_out_real["score"].pow(2).view(-1) + to_log["epsilon_penalty"] = epsilon_penalty.mean() + tops.assert_shape(epsilon_penalty, total_loss.shape) + total_loss = total_loss + epsilon_penalty * self.EP_lambd + + # Improved gradient penalty with lazy regularization + # Gradient penalty applies specialized autocast. 
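+            # R1 is only evaluated every lazy_reg_interval steps; r1_regularization
+            # scales the penalty weight by the interval to keep its effective strength.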
+ if do_GP: + gradient_pen, grad_unscaled = self.r1_reg( + batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler) + to_log["r1_gradient_penalty"] = grad_unscaled.mean() + tops.assert_shape(gradient_pen, total_loss.shape) + total_loss = total_loss + gradient_pen + + batch["img"] = batch["img"].detach().requires_grad_(False) + if "score" in D_out_real: + to_log["real_scores"] = D_out_real["score"] + to_log["real_logits_sign"] = D_out_real["score"].sign() + to_log["fake_logits_sign"] = D_out_fake["score"].sign() + to_log["fake_scores"] = D_out_fake["score"] + to_log = {key: item.mean().detach() for key, item in to_log.items()} + self.gradient_step_D += 1 + return total_loss.mean(), to_log + + def G_loss(self, batch: dict, grad_scaler): + with torch.cuda.amp.autocast(enabled=tops.AMP()): + to_log = {} + # Forward through G and D + G_fake = self.G(**batch) + D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) + # Adversarial Loss + total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1) + to_log["g_loss"] = total_loss.mean() + tops.assert_shape(total_loss, (batch["img"].shape[0], )) + + if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg: + pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler) + total_loss = total_loss + pl_reg.mean() + to_log.update(to_log_) + to_log = {key: item.mean().detach() for key, item in to_log.items()} + return total_loss.mean(), to_log diff --git a/dp2/loss/utils.py b/dp2/loss/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6e19c3a0c4718412e6d83e3405c73029275f35 --- /dev/null +++ b/dp2/loss/utils.py @@ -0,0 +1,26 @@ +import torch +import torch.nn.functional as F + + +def nsgan_g_loss(fake_score): + """ + Non-saturating criterion from Goodfellow et al. 2014 + """ + return torch.nn.functional.softplus(-fake_score) + + +def nsgan_d_loss(real_score, fake_score): + """ + Non-saturating criterion from Goodfellow et al. 
2014 + """ + d_loss = F.softplus(-real_score) + F.softplus(fake_score) + return d_loss.view(-1) + + +def smooth_masked_l1_loss(x, target, mask): + """ + Pixel-wise l1 loss for the area indicated by mask + """ + # Beta=.1 <-> square loss if pixel difference <= 12.8 + l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1, 2, 3]) / mask.sum(dim=[1, 2, 3]) + return l1 diff --git a/dp2/metrics/__init__.py b/dp2/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8bbb1ddd31e4453ad3663805f9aa06c077cf21 --- /dev/null +++ b/dp2/metrics/__init__.py @@ -0,0 +1,3 @@ +from .torch_metrics import compute_metrics_iteratively +from .fid import compute_fid +from .ppl import calculate_ppl diff --git a/dp2/metrics/fid.py b/dp2/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8e0e6171fe723aa9f37735a216ec8c9d4e4a8f --- /dev/null +++ b/dp2/metrics/fid.py @@ -0,0 +1,52 @@ +import tops +from dp2 import utils +from pathlib import Path +from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper +import torch +import torch_fidelity + + +class GeneratorIteratorWrapper(GenerativeModelModuleWrapper): + + def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int): + if isinstance(generator, utils.EMA): + generator = generator.generator + z_size = generator.z_channels + super().__init__(generator, z_size, "normal", 0) + self.zero_z = zero_z + self.dataloader = iter(dataloader) + self.n_diverse = n_diverse + self.cur_div_idx = 0 + + @torch.no_grad() + def forward(self, z, **kwargs): + if self.cur_div_idx == 0: + self.batch = next(self.dataloader) + if self.zero_z: + z = z.zero_() + self.cur_div_idx += 1 + self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx + with torch.cuda.amp.autocast(enabled=tops.AMP()): + img = self.module(**self.batch)["img"] + img = (utils.denormalize_img(img)*255).byte() + return img + + +def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse): + generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse) + batch_size = dataloader.batch_size + num_samples = (n_source * n_diverse) // batch_size * batch_size + assert n_diverse >= 1 + assert (not zero_z) or n_diverse == 1 + assert num_samples % batch_size == 0 + assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse) + metrics = torch_fidelity.calculate_metrics( + input1=generator, + input2=real_directory, + cuda=torch.cuda.is_available(), + fid=True, + input2_cache_name="_".join(Path(real_directory).parts) + "_cached", + input1_model_num_samples=int(num_samples), + batch_size=dataloader.batch_size + ) + return metrics["frechet_inception_distance"] diff --git a/dp2/metrics/fid_clip.py b/dp2/metrics/fid_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..43bde1bf74c69399308ed15ceda5aaeb59a69818 --- /dev/null +++ b/dp2/metrics/fid_clip.py @@ -0,0 +1,84 @@ +import pickle +import torch +import torchvision +from pathlib import Path +from dp2 import utils +import tops +try: + import clip +except ImportError: + print("Could not import clip.") +from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric +clip_model = None +clip_preprocess = None + + +@torch.no_grad() +def compute_fid_clip( + dataloader, generator, + cache_directory, + data_len=None, + **kwargs + ) -> dict: + """ + FID CLIP following the description in The Role of ImageNet Classes 
in Frechet Inception Distance, Thomas Kynkaamniemi et al. + Args: + n_samples (int): Creates N samples from same image to calculate stats + """ + global clip_model, clip_preprocess + if clip_model is None: + clip_model, preprocess = clip.load("ViT-B/32", device="cpu") + normalize_fn = preprocess.transforms[-1] + img_mean = normalize_fn.mean + img_std = normalize_fn.std + clip_model = tops.to_cuda(clip_model.visual) + clip_preprocess = tops.to_cuda(torch.nn.Sequential( + torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.Normalize(img_mean, img_std) + )) + cache_directory = Path(cache_directory) + if data_len is None: + data_len = len(dataloader)*dataloader.batch_size + fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl") + has_fid_cache = fid_cache_path.is_file() + if not has_fid_cache: + fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) + fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) + eidx = 0 + n_samples_seen = 0 + for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes = generator(**batch)["img"] + real_data = batch["img"] + fakes = utils.denormalize_img(fakes) + real_data = utils.denormalize_img(real_data) + if not has_fid_cache: + real_data = clip_preprocess(real_data) + fid_features_real[sidx:eidx] = clip_model(real_data) + fakes = clip_preprocess(fakes) + fid_features_fake[sidx:eidx] = clip_model(fakes) + fid_features_fake = fid_features_fake[:n_samples_seen] + fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() + if has_fid_cache: + if tops.rank() == 0: + with open(fid_cache_path, "rb") as fp: + fid_stat_real = pickle.load(fp) + else: + fid_features_real = fid_features_real[:n_samples_seen] + fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() + assert fid_features_real.shape == fid_features_fake.shape + if tops.rank() == 0: + fid_stat_real = fid_features_to_statistics(fid_features_real) + cache_directory.mkdir(exist_ok=True, parents=True) + with open(fid_cache_path, "wb") as fp: + pickle.dump(fid_stat_real, fp) + + if tops.rank() == 0: + print("Starting calculation of fid from features of shape:", fid_features_fake.shape) + fid_stat_fake = fid_features_to_statistics(fid_features_fake) + fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] + return dict(fid_clip=fid_) + return dict(fid_clip=-1) diff --git a/dp2/metrics/lpips.py b/dp2/metrics/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..397d1b12cd6952aafb6929bc3fa33f39ba509e33 --- /dev/null +++ b/dp2/metrics/lpips.py @@ -0,0 +1,77 @@ +import torch +import tops +import sys +from contextlib import redirect_stdout +from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average + + +class SampleSimilarityLPIPS(torch.nn.Module): + SUPPORTED_DTYPES = { + 'uint8': torch.uint8, + 'float32': torch.float32, + } + + def __init__(self): + + super().__init__() + self.chns = [64, 128, 256, 512, 512] + self.L = len(self.chns) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=True) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=True) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=True) + self.lin3 = 
NetLinLayer(self.chns[3], use_dropout=True) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=True) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + with redirect_stdout(sys.stderr): + fp = tops.download_file(URL_VGG16_LPIPS) + state_dict = torch.load(fp, map_location="cpu") + self.load_state_dict(state_dict) + self.net = VGG16features() + self.eval() + for param in self.parameters(): + param.requires_grad = False + mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2 + inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255) + self.register_buffer("mean", mean_rescaled) + self.register_buffer("std", inv_std_rescaled) + + def normalize(self, x): + # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] + x = (x.float() - self.mean) * self.std + return x + + @staticmethod + def resize(x, size): + if x.shape[-1] > size and x.shape[-2] > size: + x = torch.nn.functional.interpolate(x, (size, size), mode='area') + else: + x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False) + return x + + def lpips_from_feats(self, feats0, feats1): + diffs = {} + for kk in range(self.L): + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)] + val = sum(res) + return val + + def get_feats(self, x): + assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW' + if x.shape[-2] < 16 or x.shape[-1] < 16: # Resize images < 16x16 + f = 2 + size = tuple([int(f*_) for _ in x.shape[-2:]]) + x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False) + in0_input = self.normalize(x) + outs0 = self.net.forward(in0_input) + + feats = {} + for kk in range(self.L): + feats[kk] = normalize_tensor(outs0[kk]) + return feats + + def forward(self, in0, in1): + feats0 = self.get_feats(in0) + feats1 = self.get_feats(in1) + return self.lpips_from_feats(feats0, feats1), feats0, feats1 diff --git a/dp2/metrics/ppl.py b/dp2/metrics/ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..421aeafc5edc4647037fdc390737b269cdfbeae5 --- /dev/null +++ b/dp2/metrics/ppl.py @@ -0,0 +1,116 @@ +import numpy as np +import torch +import tops +from dp2 import utils +from torch_fidelity.helpers import get_kwarg, vassert +from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS +from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity +from torchvision.transforms.functional import resize + + +def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + + +@torch.no_grad() +def calculate_ppl( + dataloader, + generator, + latent_space=None, + data_len=None, + upsample_size=None, + **kwargs) -> dict: + """ + Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py + """ + if latent_space is None: + latent_space = generator.latent_space + assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}" + assert len(upsample_size) == 2 + epsilon = PPL_DEFAULTS["ppl_epsilon"] + interp = PPL_DEFAULTS['ppl_z_interp_mode'] + similarity_name = PPL_DEFAULTS['ppl_sample_similarity'] + sample_similarity_resize = 
PPL_DEFAULTS['ppl_sample_similarity_resize'] + sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype'] + discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower'] + discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher'] + + vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number') + vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile') + vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile') + if discard_percentile_lower is not None and discard_percentile_higher is not None: + vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles') + + sample_similarity = create_sample_similarity( + similarity_name, + sample_similarity_resize=sample_similarity_resize, + sample_similarity_dtype=sample_similarity_dtype, + cuda=False, + **kwargs + ) + sample_similarity = tops.to_cuda(sample_similarity) + rng = np.random.RandomState(get_kwarg('rng_seed', kwargs)) + distances = [] + if data_len is None: + data_len = len(dataloader) * dataloader.batch_size + z0 = sample_random(rng, (data_len, generator.z_channels), "normal") + z1 = sample_random(rng, (data_len, generator.z_channels), "normal") + if latent_space == "Z": + z1 = batch_interp(z0, z1, epsilon, interp) + print("Computing PPL IN", latent_space) + distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device()) + print(distances.shape) + end = 0 + n_samples = 0 + for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")): + start = end + end = start + batch["img"].shape[0] + n_samples += batch["img"].shape[0] + batch_lat_e0 = tops.to_cuda(z0[start:end]) + batch_lat_e1 = tops.to_cuda(z1[start:end]) + if latent_space == "W": + w0 = generator.get_w(batch_lat_e0, update_emas=False) + w1 = generator.get_w(batch_lat_e1, update_emas=False) + w1 = w0.lerp(w1, epsilon) # PPL end + rgb1 = generator(**batch, w=w0)["img"] + rgb2 = generator(**batch, w=w1)["img"] + else: + rgb1 = generator(**batch, z=batch_lat_e0)["img"] + rgb2 = generator(**batch, z=batch_lat_e1)["img"] + if rgb1.shape[-2] < upsample_size[0] or rgb1.shape[-1] < upsample_size[1]: + rgb1 = resize(rgb1, upsample_size, antialias=True) + rgb2 = resize(rgb2, upsample_size, antialias=True) + rgb1 = utils.denormalize_img(rgb1).mul(255).byte() + rgb2 = utils.denormalize_img(rgb2).mul(255).byte() + + sim = sample_similarity(rgb1, rgb2) + dist_lat_e01 = sim / (epsilon ** 2) + distances[start:end] = dist_lat_e01.view(-1) + distances = distances[:n_samples] + distances = tops.all_gather_uneven(distances).cpu().numpy() + if tops.rank() != 0: + return {"ppl/mean": -1, "ppl/std": -1} + if tops.rank() == 0: + cond, lo, hi = None, None, None + if discard_percentile_lower is not None: + lo = np.percentile(distances, discard_percentile_lower, interpolation='lower') + cond = lo <= distances + if discard_percentile_higher is not None: + hi = np.percentile(distances, discard_percentile_higher, interpolation='higher') + cond = np.logical_and(cond, distances <= hi) + if cond is not None: + distances = np.extract(cond, distances) + return { + "ppl/mean": float(np.mean(distances)), + "ppl/std": float(np.std(distances)), + } + else: + return {"ppl/mean"} diff --git a/dp2/metrics/torch_metrics.py b/dp2/metrics/torch_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8747b01a98f3aa4df9161e09c309c101b396f2 --- /dev/null +++ 
b/dp2/metrics/torch_metrics.py @@ -0,0 +1,177 @@ +import pickle +import numpy as np +import torch +import time +from pathlib import Path +from dp2 import utils +import tops +from .lpips import SampleSimilarityLPIPS +from torch_fidelity.defaults import DEFAULTS as trf_defaults +from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric +from torch_fidelity.utils import create_feature_extractor +lpips_model = None +fid_model = None + + +@torch.no_grad() +def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + se = (images1 - images2) ** 2 + se = se.view(images1.shape[0], -1).mean(dim=1) + return se + + +@torch.no_grad() +def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + mse_ = mse(images1, images2) + psnr = 10 * torch.log10(1 / mse_) + return psnr + + +@torch.no_grad() +def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + return _lpips_w_grad(images1, images2) + + +def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + global lpips_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + + images1 = images1.mul(255) + images2 = images2.mul(255) + with torch.cuda.amp.autocast(tops.AMP()): + dists = lpips_model(images1, images2)[0].view(-1) + return dists + + +@torch.no_grad() +def compute_metrics_iteratively( + dataloader, generator, + cache_directory, + data_len=None, + truncation_value: float = None, +) -> dict: + """ + Args: + n_samples (int): Creates N samples from same image to calculate stats + dataset_percentage (float): The percentage of the dataset to compute metrics on. + """ + + global lpips_model, fid_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + if fid_model is None: + fid_model = create_feature_extractor( + trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False) + fid_model = tops.to_cuda(fid_model) + cache_directory = Path(cache_directory) + start_time = time.time() + lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) + diversity_total = torch.zeros_like(lpips_total) + fid_cache_path = cache_directory.joinpath("fid_stats.pkl") + has_fid_cache = fid_cache_path.is_file() + if data_len is None: + data_len = len(dataloader)*dataloader.batch_size + if not has_fid_cache: + fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) + fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) + n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) + eidx = 0 + for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes1 = utils.denormalize_img(fakes1).mul(255) + fakes2 = utils.denormalize_img(fakes2).mul(255) + real_data = utils.denormalize_img(batch["img"]).mul(255) + lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) + fake2_lpips_feats = lpips_model.get_feats(fakes2) + lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) + + lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) + diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, 
fake2_lpips_feats).sum() + if not has_fid_cache: + fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0] + fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0] + fid_features_fake = fid_features_fake[:n_samples_seen] + if has_fid_cache: + if tops.rank() == 0: + with open(fid_cache_path, "rb") as fp: + fid_stat_real = pickle.load(fp) + else: + fid_features_real = fid_features_real[:n_samples_seen] + fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() + if tops.rank() == 0: + fid_stat_real = fid_features_to_statistics(fid_features_real) + cache_directory.mkdir(exist_ok=True, parents=True) + with open(fid_cache_path, "wb") as fp: + pickle.dump(fid_stat_real, fp) + fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() + if tops.rank() == 0: + print("Starting calculation of fid from features of shape:", fid_features_fake.shape) + fid_stat_fake = fid_features_to_statistics(fid_features_fake) + fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] + tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) + tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) + tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) + lpips_total = lpips_total / n_samples_seen + diversity_total = diversity_total / n_samples_seen + to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) + if tops.rank() == 0: + to_return["fid"] = fid_ + else: + to_return["fid"] = -1 + to_return["validation_time_s"] = time.time() - start_time + return to_return + + +@torch.no_grad() +def compute_lpips( + dataloader, generator, + truncation_value: float = None, + data_len=None, + ) -> dict: + """ + Args: + n_samples (int): Creates N samples from same image to calculate stats + dataset_percentage (float): The percentage of the dataset to compute metrics on. 
+ """ + global lpips_model, fid_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + start_time = time.time() + lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) + diversity_total = torch.zeros_like(lpips_total) + if data_len is None: + data_len = len(dataloader) * dataloader.batch_size + eidx = 0 + n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) + for batch in utils.tqdm_(dataloader, desc="Validating on dataset."): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] + real_data = batch["img"] + fakes1 = utils.denormalize_img(fakes1).mul(255) + fakes2 = utils.denormalize_img(fakes2).mul(255) + real_data = utils.denormalize_img(real_data).mul(255) + lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) + fake2_lpips_feats = lpips_model.get_feats(fakes2) + lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) + + lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) + diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum() + tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) + tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) + tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) + lpips_total = lpips_total / n_samples_seen + diversity_total = diversity_total / n_samples_seen + to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) + to_return = {k: v.cpu().item() for k, v in to_return.items()} + to_return["validation_time_s"] = time.time() - start_time + return to_return diff --git a/dp2/utils/__init__.py b/dp2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4edbacb6e8032ea081839f1a2408d4101868e79 --- /dev/null +++ b/dp2/utils/__init__.py @@ -0,0 +1,24 @@ +import pathlib +from tops.config import LazyConfig +from .torch_utils import ( + im2torch, im2numpy, denormalize_img, set_requires_grad, forward_D_fake, + binary_dilation, crop_box, remove_pad, + torch_wasserstein_loss +) +from .ema import EMA +from .utils import init_tops, tqdm_, print_config, config_to_str, trange_ +from .cse import from_E_to_vertex + + +def load_config(config_path): + config_path = pathlib.Path(config_path) + assert config_path.is_file(), config_path + cfg = LazyConfig.load(str(config_path)) + cfg.output_dir = pathlib.Path(str(config_path).replace("configs", str(cfg.common.output_dir)).replace(".py", "")) + if cfg.common.experiment_name is None: + cfg.experiment_name = str(config_path) + else: + cfg.experiment_name = cfg.common.experiment_name + cfg.checkpoint_dir = cfg.output_dir.joinpath("checkpoints") + print("Saving outputs to:", cfg.output_dir) + return cfg diff --git a/dp2/utils/bufferless_video_capture.py b/dp2/utils/bufferless_video_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5e1006057706f32c6adaeb812bf4834bbdfd28 --- /dev/null +++ b/dp2/utils/bufferless_video_capture.py @@ -0,0 +1,32 @@ +import queue +import threading +import cv2 + + +class BufferlessVideoCapture: + + def __init__(self, name, width=None, height=None): + self.cap = cv2.VideoCapture(name) + if width is not None and height is not None: + self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) + 
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
+        self.q = queue.Queue()
+        t = threading.Thread(target=self._reader)
+        t.daemon = True
+        t.start()
+
+    # read frames as soon as they are available, keeping only most recent one
+    def _reader(self):
+        while True:
+            ret, frame = self.cap.read()
+            if not ret:
+                break
+            if not self.q.empty():
+                try:
+                    self.q.get_nowait()  # discard previous (unprocessed) frame
+                except queue.Empty:
+                    pass
+            self.q.put((ret, frame))
+
+    def read(self):
+        return self.q.get()
diff --git a/dp2/utils/cse.py b/dp2/utils/cse.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3e01d28ba10e6d4d14ecd49c3a70dcaaa194ce
--- /dev/null
+++ b/dp2/utils/cse.py
@@ -0,0 +1,21 @@
+import warnings
+import torch
+from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES
+
+
+def from_E_to_vertex(E, M, embed_map):
+    """
+    M is 1 for unknown regions
+    """
+    assert len(E.shape) == 4
+    assert len(E.shape) == len(M.shape), (E.shape, M.shape)
+    assert E.shape[0] == 1
+    M = M.float()
+    M = torch.cat([M, 1-M], dim=1)
+    with warnings.catch_warnings():  # Ignore UserWarning for pytorch interpolate from detectron2
+        warnings.filterwarnings("ignore")
+        vertices, _ = get_closest_vertices_mask_from_ES(
+            E, M, E.shape[2], E.shape[3],
+            embed_map, device=E.device)
+
+    return vertices.long()
diff --git a/dp2/utils/ema.py b/dp2/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..475e6b5192575ad5a54541714b6c932227cbe7a3
--- /dev/null
+++ b/dp2/utils/ema.py
@@ -0,0 +1,80 @@
+import torch
+import copy
+import tops
+from tops import logger
+from .torch_utils import set_requires_grad
+
+
+class EMA:
+    """
+    Exponential moving average.
+    See:
+    Yazici, Y. et al. The unusual effectiveness of averaging in GAN training.
ICLR 2019 + + """ + + def __init__( + self, + generator: torch.nn.Module, + batch_size: int, + rampup: float, + ): + self.rampup = rampup + self._nimg_half_time = batch_size * 10 / 32 * 1000 + self._batch_size = batch_size + with torch.no_grad(): + self.generator = copy.deepcopy(generator.cpu()).eval() + self.generator = tops.to_cuda(self.generator) + self.old_ra_beta = 0 + set_requires_grad(self.generator, False) + + def update_beta(self): + y = self._nimg_half_time + global_step = logger.global_step() + if self.rampup != None: + y = min(y, global_step*self.rampup) + self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8)) + if self.ra_beta != self.old_ra_beta: + logger.add_scalar("stats/EMA_beta", self.ra_beta) + self.old_ra_beta = self.ra_beta + + @torch.no_grad() + def update(self, normal_G): + with torch.autograd.profiler.record_function("EMA_update"): + for ema_p, p in zip(self.generator.parameters(), + normal_G.parameters()): + ema_p.copy_(p.lerp(ema_p, self.ra_beta)) + for ema_buf, buff in zip(self.generator.buffers(), + normal_G.buffers()): + ema_buf.copy_(buff) + + def __call__(self, *args, **kwargs): + return self.generator(*args, **kwargs) + + def __getattr__(self, name: str): + if hasattr(self.generator, name): + return getattr(self.generator, name) + raise AttributeError(f"Generator object has no attribute {name}") + + def cuda(self, *args, **kwargs): + self.generator = self.generator.cuda() + return self + + def state_dict(self, *args, **kwargs): + return self.generator.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.generator.load_state_dict(*args, **kwargs) + + def eval(self): + self.generator.eval() + + def train(self): + self.generator.train() + + @property + def module(self): + return self.generator.module + + def sample(self, *args, **kwargs): + return self.generator.sample(*args, **kwargs) diff --git a/dp2/utils/torch_utils.py b/dp2/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80ab53e3dcedce4710a41d1d58bd292a9fa08432 --- /dev/null +++ b/dp2/utils/torch_utils.py @@ -0,0 +1,140 @@ +import torch +import tops + + +def denormalize_img(image, mean=0.5, std=0.5): + image = image * std + mean + image = torch.clamp(image.float(), 0, 1) + image = (image * 255) + image = torch.round(image) + return image / 255 + + +@torch.no_grad() +def im2numpy(images, to_uint8=False, denormalize=False): + if denormalize: + images = denormalize_img(images) + if images.dtype != torch.uint8: + images = images.clamp(0, 1) + return tops.im2numpy(images, to_uint8=to_uint8) + + +@torch.no_grad() +def im2torch(im, cuda=False, normalize=True, to_float=True): + im = tops.im2torch(im, cuda=cuda, to_float=to_float) + if normalize: + assert im.min() >= 0.0 and im.max() <= 1.0 + if normalize: + im = im * 2 - 1 + return im + + +@torch.no_grad() +def binary_dilation(im: torch.Tensor, kernel: torch.Tensor): + assert len(im.shape) == 4 + assert len(kernel.shape) == 2 + kernel = kernel.unsqueeze(0).unsqueeze(0) + padding = kernel.shape[-1]//2 + assert kernel.shape[-1] % 2 != 0 + if isinstance(im, torch.cuda.FloatTensor): + im, kernel = im.half(), kernel.half() + else: + im, kernel = im.float(), kernel.float() + im = torch.nn.functional.conv2d( + im, kernel, groups=im.shape[1], padding=padding) + im = im > 0.5 + return im + + +@torch.no_grad() +def binary_erosion(im: torch.Tensor, kernel: torch.Tensor): + assert len(im.shape) == 4 + assert len(kernel.shape) == 2 + kernel = kernel.unsqueeze(0).unsqueeze(0) + padding = 
kernel.shape[-1]//2 + assert kernel.shape[-1] % 2 != 0 + if isinstance(im, torch.cuda.FloatTensor): + im, kernel = im.half(), kernel.half() + else: + im, kernel = im.float(), kernel.float() + ksum = kernel.sum() + padding = (padding, padding, padding, padding) + im = torch.nn.functional.pad(im, padding, mode="reflect") + im = torch.nn.functional.conv2d( + im, kernel, groups=im.shape[1]) + return im.round() == ksum + + +def set_requires_grad(value: torch.nn.Module, requires_grad: bool): + if isinstance(value, (list, tuple)): + for param in value: + param.requires_grad = requires_grad + return + for p in value.parameters(): + p.requires_grad = requires_grad + + +def forward_D_fake(batch, fake_img, discriminator, **kwargs): + fake_batch = {k: v for k, v in batch.items() if k != "img"} + fake_batch["img"] = fake_img + return discriminator(**fake_batch, **kwargs) + + +def remove_pad(x: torch.Tensor, bbox_XYXY, imshape): + """ + Remove padding that is shown as negative + """ + H, W = imshape + x0, y0, x1, y1 = bbox_XYXY + padding = [ + max(0, -x0), + max(0, -y0), + max(x1 - W, 0), + max(y1 - H, 0) + ] + x0, y0 = padding[:2] + x1 = x.shape[2] - padding[2] + y1 = x.shape[1] - padding[3] + return x[:, y0:y1, x0:x1] + + +def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor: + """ + Crops x by bbox_XYXY. + """ + x0, y0, x1, y1 = bbox_XYXY + x0 = max(x0, 0) + y0 = max(y0, 0) + x1 = min(x1, x.shape[-1]) + y1 = min(y1, x.shape[-2]) + return x[..., y0:y1, x0:x1] + + +def torch_wasserstein_loss(tensor_a, tensor_b): + # Compute the first Wasserstein distance between two 1D distributions. + return (torch_cdf_loss(tensor_a, tensor_b, p=1)) + + +def torch_cdf_loss(tensor_a, tensor_b, p=1): + # last-dimension is weight distribution + # p is the norm of the distance, p=1 --> First Wasserstein Distance + # to get a positive weight with our normalized distribution + # we recommend combining this loss with other difference-based losses like L1 + + # normalize distribution, add 1e-14 to divisor to avoid 0/0 + tensor_a = tensor_a / (torch.sum(tensor_a, dim=-1, keepdim=True) + 1e-14) + tensor_b = tensor_b / (torch.sum(tensor_b, dim=-1, keepdim=True) + 1e-14) + # make cdf with cumsum + cdf_tensor_a = torch.cumsum(tensor_a, dim=-1) + cdf_tensor_b = torch.cumsum(tensor_b, dim=-1) + + # choose different formulas for different norm situations + if p == 1: + cdf_distance = torch.sum(torch.abs((cdf_tensor_a-cdf_tensor_b)), dim=-1) + elif p == 2: + cdf_distance = torch.sqrt(torch.sum(torch.pow((cdf_tensor_a-cdf_tensor_b), 2), dim=-1)) + else: + cdf_distance = torch.pow(torch.sum(torch.pow(torch.abs(cdf_tensor_a-cdf_tensor_b), p), dim=-1), 1/p) + + cdf_loss = cdf_distance.mean() + return cdf_loss diff --git a/dp2/utils/utils.py b/dp2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3dda0ecd189f2b1b21726a12ced68af348bdb5e --- /dev/null +++ b/dp2/utils/utils.py @@ -0,0 +1,30 @@ +import tops +import tqdm +from tops import logger, highlight_py_str +from tops.config import LazyConfig + + +def print_config(cfg): + logger.log("\n" + highlight_py_str(LazyConfig.to_py(cfg, prefix=""))) + + +def config_to_str(cfg): + return LazyConfig.to_py(cfg, prefix=".") + + +def init_tops(cfg, reinit=False): + tops.init( + cfg.output_dir, cfg.common.logger_backend, cfg.experiment_name, + cfg.common.wandb_project, dict(cfg), reinit) + + +def tqdm_(iterator, *args, **kwargs): + if tops.rank() == 0: + return tqdm.tqdm(iterator, *args, **kwargs) + return iterator + + +def trange_(*args, **kwargs): + if 
tops.rank() == 0: + return tqdm.trange(*args, **kwargs) + return range(*args) diff --git a/dp2/utils/vis_utils.py b/dp2/utils/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..996e062ce082cc4da8355551317f4678bc75feaa --- /dev/null +++ b/dp2/utils/vis_utils.py @@ -0,0 +1,440 @@ +import torch +import tops +import cv2 +import torchvision.transforms.functional as F +from typing import Optional, List, Union, Tuple +from .cse import from_E_to_vertex +import numpy as np +from tops import download_file +from .torch_utils import ( + denormalize_img, binary_dilation, binary_erosion, + remove_pad, crop_box) +from torchvision.utils import _generate_color_palette +from PIL import Image, ImageColor, ImageDraw + + +def get_coco_keypoints(): + # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py + keypoints = [ + 'nose', + 'left_eye', + 'right_eye', + 'left_ear', + 'right_ear', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_hip', + 'right_hip', + 'left_knee', + 'right_knee', + 'left_ankle', + 'right_ankle' + ] + keypoint_flip_map = { + 'left_eye': 'right_eye', + 'left_ear': 'right_ear', + 'left_shoulder': 'right_shoulder', + 'left_elbow': 'right_elbow', + 'left_wrist': 'right_wrist', + 'left_hip': 'right_hip', + 'left_knee': 'right_knee', + 'left_ankle': 'right_ankle' + } + connectivity = { + "nose": "left_eye", + "left_eye": "right_eye", + "right_eye": "nose", + "left_ear": "left_eye", + "right_ear": "right_eye", + "left_shoulder": "nose", + "right_shoulder": "nose", + "left_elbow": "left_shoulder", + "right_elbow": "right_shoulder", + "left_wrist": "left_elbow", + "right_wrist": "right_elbow", + "left_hip": "left_shoulder", + "right_hip": "right_shoulder", + "left_knee": "left_hip", + "right_knee": "right_hip", + "left_ankle": "left_knee", + "right_ankle": "right_knee" + } + connectivity_indices = [ + (sidx, keypoints.index(connectivity[kp])) + for sidx, kp in enumerate(keypoints) + ] + return keypoints, keypoint_flip_map, connectivity_indices + + +def get_coco_colors(): + return [ + *["red"]*5, + "blue", + "green", + "blue", + "green", + "blue", + "green", + "purple", + "orange", + "purple", + "orange", + "purple", + "orange", + ] + + +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[List[Tuple[int, int]]] = None, + visible: Optional[List[List[bool]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = None, + width: int = None, +) -> torch.Tensor: + """ + Function taken from torchvision source code. Added in torchvision 0.12 + + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (List[Tuple[int, int]]]): A List of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. 
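+    Example (editorial sketch, not from the original source; shapes and values are illustrative):
+        >>> img = torch.zeros((3, 256, 256), dtype=torch.uint8)
+        >>> kpts = torch.tensor([[[64.0, 64.0], [128.0, 96.0]]])  # one instance, two keypoints (x, y) in pixels
+        >>> out = draw_keypoints(img, kpts, radius=3, width=2)    # uint8 tensor of shape (3, 256, 256)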
+ """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + if width is None: + width = int(max(max(image.shape[-2:]) * 0.01, 1)) + if radius is None: + radius = int(max(max(image.shape[-2:]) * 0.01, 1)) + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + if isinstance(keypoints, torch.Tensor): + img_kpts = keypoints.to(torch.int64).tolist() + else: + assert isinstance(keypoints, np.ndarray) + img_kpts = keypoints.astype(int).tolist() + colors = get_coco_colors() + for inst_id, kpt_inst in enumerate(img_kpts): + + for kpt_id, kpt in enumerate(kpt_inst): + if visible is not None and int(visible[inst_id][kpt_id]) == 0: + continue + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + + draw.ellipse([x1, y1, x2, y2], fill=colors[kpt_id], outline=None, width=0) + + if connectivity is not None: + for connection in connectivity: + if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst): + continue + if visible is not None and int(visible[inst_id][connection[1]]) == 0 or int(visible[inst_id][connection[0]]) == 0: + continue + + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, fill=colors[connection[1]] + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +def visualize_keypoints(img, keypoints): + img = img.clone() + keypoints = keypoints.clone() + keypoints[:, :, 0] *= img.shape[-1] + keypoints[:, :, 1] *= img.shape[-2] + _, _, connectivity = get_coco_keypoints() + connectivity = np.array(connectivity) + visible = None + if keypoints.shape[-1] == 3: + visible = keypoints[:, :, 2] > 0 + for idx in range(img.shape[0]): + img[idx] = draw_keypoints( + img[idx], keypoints[idx:idx+1].long(), colors="red", + connectivity=connectivity, visible=visible[idx:idx+1]) + return img + + +def visualize_batch( + img: torch.Tensor, mask: torch.Tensor, + vertices: torch.Tensor = None, + E_mask: torch.Tensor = None, + embed_map: torch.Tensor = None, + semantic_mask: torch.Tensor = None, + embedding: torch.Tensor = None, + keypoints: torch.Tensor = None, + maskrcnn_mask: torch.Tensor = None, + **kwargs) -> torch.ByteTensor: + img = denormalize_img(img).mul(255).round().clamp(0, 255).byte() + img = draw_mask(img, mask) + if maskrcnn_mask is not None and maskrcnn_mask.shape == mask.shape: + img = draw_mask(img, maskrcnn_mask) + if vertices is not None or embedding is not None: + assert E_mask is not None + assert embed_map is not None + img, E_mask, embedding, embed_map, vertices = tops.to_cpu([ + img, E_mask, embedding, embed_map, vertices + ]) + img = draw_cse(img, E_mask, embedding, embed_map, vertices) + elif semantic_mask is not None: + img = draw_segmentation_masks(img, semantic_mask) + if keypoints is not None: + img = visualize_keypoints(img, keypoints) + return img + + +@torch.no_grad() +def 
draw_cse( + img: torch.Tensor, E_seg: torch.Tensor, + embedding: torch.Tensor = None, + embed_map: torch.Tensor = None, + vertices: torch.Tensor = None, t=0.7 +): + """ + E_seg: 1 for areas with embedding + """ + assert img.dtype == torch.uint8 + img = img.view(-1, *img.shape[-3:]) + E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) + if vertices is None: + assert embedding is not None + assert embed_map is not None + embedding = embedding.view(-1, *embedding.shape[-3:]) + vertices = torch.stack( + [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map) + for e, e_seg in zip(embedding, E_seg)]) + + i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1) + colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0]) + color_embed_map, _ = np.load(download_file( + "https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True) + color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0] + color_embed_map -= color_embed_map.min() + color_embed_map /= color_embed_map.max() + vertx2idx = (color_embed_map*255).long() + vertx2colormap = colormap_JET[vertx2idx] + + vertices = vertices.view(-1, *vertices.shape[-2:]) + E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) + # This operation might be good to do on cpu... + + E_color = vertx2colormap[vertices.long()] + E_color = E_color.to(E_seg.device) + E_color = E_color.permute(0, 3, 1, 2) + E_color = E_color*E_seg.byte() + + m = E_seg.bool().repeat(1, 3, 1, 1) + img[m] = (img[m] * (1-t) + t * E_color[m]).byte() + return img + + +def draw_cse_all( + embedding: List[torch.Tensor], E_mask: List[torch.Tensor], + im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7): + """ + E_seg: 1 for areas with embedding + """ + assert len(im.shape) == 3, im.shape + assert im.dtype == torch.uint8 + + N = len(E_mask) + im = im.clone() + for i in range(N): + assert len(E_mask[i].shape) == 2 + assert len(embedding[i].shape) == 3 + assert embed_map.shape[1] == embedding[i].shape[0] + assert len(boxes_XYXY[i]) == 4 + E = embedding[i] + x0, y0, x1, y1 = boxes_XYXY[i] + E = F.resize(E, (y1-y0, x1-x0), antialias=True) + s = E_mask[i].float() + s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float() + box = boxes_XYXY[i] + + im_ = crop_box(im, box) + s = remove_pad(s, box, im.shape[1:]) + E = remove_pad(E, box, im.shape[1:]) + E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None], embed_map=embed_map)[0] + E_color = E_color.to(im.device) + s = s.bool().repeat(3, 1, 1) + crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte() + return im + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, +) -> torch.Tensor: + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (list or None): List containing the colors of the masks. The colors can + be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list + with one element. 
By default, random colors are generated for each mask. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") + num_masks = masks.size()[0] + if num_masks == 0: + return image + if colors is None: + colors = _generate_color_palette(num_masks) + if not isinstance(colors[0], (Tuple, List)): + colors = [colors for i in range(num_masks)] + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = torch.uint8 + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=out_dtype, device=masks.device) + colors_.append(color) + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * (1 - alpha) + img_to_draw * alpha + return out.to(out_dtype) + + +def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True): + """ + Visualize mask where mask = 0. + Supports multiple instances. + mask shape: [N, C, H, W], where C is different instances in same image. 
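+    Example (editorial sketch, not part of the original code; values are illustrative):
+        >>> im = torch.zeros((1, 3, 64, 64), dtype=torch.uint8)
+        >>> mask = torch.ones((1, 1, 64, 64), dtype=torch.bool)
+        >>> mask[..., 16:48, 16:48] = 0                  # the region where mask == 0 is highlighted
+        >>> out = draw_mask(im, mask, t=0.2, color=(255, 255, 255))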
+ """ + orig_imshape = im.shape + if mask.numel() == 0: + return im + assert len(mask.shape) in (3, 4), mask.shape + mask = mask.view(-1, *mask.shape[-3:]) + im = im.view(-1, *im.shape[-3:]) + assert im.dtype == torch.uint8, im.dtype + assert 0 <= t <= 1 + if not visualize_instances: + mask = mask.any(dim=1, keepdim=True) + mask = mask.bool() + kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device) + outer_border = binary_dilation(mask, kernel).logical_xor(mask) + outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 + inner_border = binary_erosion(mask, kernel).logical_xor(mask) + inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 + mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1) + color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1) # .repeat(1, *im.shape[1:]) + color = color.repeat(im.shape[0], 1, *im.shape[-2:]) + im[mask] = (im[mask] * (1-t) + t * color[mask]).byte() + im[outer_border] = 255 + im[inner_border] = 0 + return im.view(*orig_imshape) + + +def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs): + for i, box in enumerate(boxes): + x0, y0, x1, y1 = boxes[i] + orig_shape = (y1-y0, x1-x0) + m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None] + m = remove_pad(m, boxes[i], im.shape[-2:]) + crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m)) + return im + + +def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs): + n_boxes = boxes.shape[0] + tops.assert_shape(all_keypoints, (n_boxes, 17, 3)) + im = im.clone() + for i, box in enumerate(boxes): + + x0, y0, x1, y1 = boxes[i] + orig_shape = (y1-y0, x1-x0) + keypoints = all_keypoints[i].clone() + keypoints[:, 0] *= orig_shape[1] + keypoints[:, 1] *= orig_shape[0] + keypoints = keypoints.long() + _, _, connectivity = get_coco_keypoints() + connectivity = np.array(connectivity) + visible = (keypoints[:, 2] > .5) + # Remove padding from keypoints before visualization + keypoints[:, 0] += min(x0, 0) + keypoints[:, 1] += min(y0, 0) + im_with_kp = draw_keypoints( + crop_box(im, box), keypoints[None], colors="red", connectivity=connectivity, visible=visible[None]) + crop_box(im, box).copy_(im_with_kp) + return im diff --git a/gradio_demos/body_cse.py b/gradio_demos/body_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..2ddf1859352841539fdf5c4b6e9d0f098264208d --- /dev/null +++ b/gradio_demos/body_cse.py @@ -0,0 +1,23 @@ +import gradio +from dp2 import utils +from tops.config import instantiate +import gradio.inputs +from gradio_demos.modules import ExampleDemo, WebcamDemo + + +cfg_body = utils.load_config("configs/anonymizers/FB_cse.py") +anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False) +anonymizer_body.initialize_tracker(fps=1) + + +with gradio.Blocks() as demo: + gradio.Markdown("#
DeepPrivacy2 - Realistic Image Anonymization
") + gradio.Markdown("###
Håkon Hukkelås, Rudolf Mester, Frank Lindseth
") + with gradio.Tab("Full-Body CSE Anonymization"): + ExampleDemo(anonymizer_body) + with gradio.Tab("Full-body CSE Webcam"): + WebcamDemo(anonymizer_body) + + +demo.launch() + diff --git a/gradio_demos/face.py b/gradio_demos/face.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8692fd191a36c09b933b5e36f3e31d0f78341b --- /dev/null +++ b/gradio_demos/face.py @@ -0,0 +1,22 @@ +import gradio +from dp2 import utils +from tops.config import instantiate +import gradio.inputs +from gradio_demos.modules import ExampleDemo, WebcamDemo + +cfg_face = utils.load_config("configs/anonymizers/face.py") +anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False) +print(anonymizer_face.detector) +anonymizer_face.initialize_tracker(fps=1) + + +with gradio.Blocks() as demo: + gradio.Markdown("#
DeepPrivacy2 - Realistic Image Anonymization
") + gradio.Markdown("###
Håkon Hukkelås, Rudolf Mester, Frank Lindseth
") + with gradio.Tab("Face Anonymization"): + ExampleDemo(anonymizer_face) + with gradio.Tab("Live Webcam"): + WebcamDemo(anonymizer_face) + +demo.launch() + diff --git a/gradio_demos/modules.py b/gradio_demos/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..3682f3d79688c3a55294a4f5ca4704112e0fd860 --- /dev/null +++ b/gradio_demos/modules.py @@ -0,0 +1,247 @@ +from collections import defaultdict +import gradio +import numpy as np +import torch +import cv2 +from PIL import Image +from dp2 import utils +from tops.config import instantiate +import tops +import gradio.inputs +from stylemc import get_and_cache_direction, get_styles +from sg3_torch_utils.ops import grid_sample_gradfix, bias_act, upfirdn2d + +grid_sample_gradfix.enabled = False +bias_act.enabled = False +upfirdn2d.enabled = False + + +class GuidedDemo: + def __init__(self, face_anonymizer, cfg_face, multi_modal_truncation, truncation_value) -> None: + self.anonymizer = face_anonymizer + self.multi_modal_truncation = multi_modal_truncation + self.truncation_value = truncation_value + assert sum([x is not None for x in list(face_anonymizer.generators.values())]) == 1 + self.generator = [x for x in list(face_anonymizer.generators.values()) if x is not None][0] + face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg) + face_G_cfg.train.batch_size = 1 + self.dl = instantiate(face_G_cfg.data.val.loader) + self.cache_dir = face_G_cfg.output_dir + self.precompute_edits() + + def precompute_edits(self): + self.precomputed_edits = set() + for edit in self.precomputed_edits: + get_and_cache_direction(self.cache_dir, self.dl, self.generator, edit) + if self.cache_dir.joinpath("stylemc_cache").is_dir(): + for path in self.cache_dir.joinpath("stylemc_cache").iterdir(): + text_prompt = path.stem.replace("_", " ") + self.precomputed_edits.add(text_prompt) + print(text_prompt) + self.edits = defaultdict(defaultdict) + + def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None): + if not isinstance(img, torch.Tensor): + img, cache_id = pil2torch(img) + img = tops.to_cuda(img) + + current_box_idx = current_box_idx % len(current_boxes) + edited_styles = [s.clone() for s in current_styles] + for face_idx, face_edits in edits.items(): + for prompt, strength in face_edits.items(): + direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt) + edited_styles[int(face_idx)] += direction * strength + update_identity[int(face_idx)] = True + assert img.dtype == torch.uint8 + img = self.anonymizer( + img, truncation_value=self.truncation_value, + multi_modal_truncation=self.multi_modal_truncation, amp=True, + cache_id=cache_id, + all_styles=edited_styles, + update_identity=update_identity) + update_identity = [True for i in range(len(update_identity))] + img = utils.im2numpy(img) + if show_boxes: + x0, y0, x1, y1 = [int(_) for _ in current_boxes[int(current_box_idx)]] + img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1) + return img, update_identity + + def update_image(self, img, show_boxes): + img, cache_id = pil2torch(img) + img = tops.to_cuda(img) + det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0] + current_styles = [] + for i in range(len(det)): + s = get_styles( + np.random.randint(0, 999999), self.generator, + None, truncation_value=self.truncation_value) + current_styles.append(s) + update_identity = [True for i in range(len(det))] + current_boxes = 
np.array(det.boxes) + edits = defaultdict(defaultdict) + cur_face_idx = -1 % len(current_boxes) + img, update_identity = self.anonymize( + img, show_boxes, cur_face_idx, + current_styles, current_boxes, update_identity, edits, cache_id=cache_id) + return img, current_styles, current_boxes, update_identity, edits, cur_face_idx + + def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits): + cur_face_idx = (cur_face_idx + change) % len(current_boxes) + img, update_identity = self.anonymize( + input_image, show_boxes, cur_face_idx, + current_styles, current_boxes, update_identity, edits) + return img, update_identity, cur_face_idx + + def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits): + face_idx = face_idx % len(current_boxes) + edits[face_idx][prompt] = strength + img, update_identity = self.anonymize( + input_image, show_boxes, face_idx, + current_styles, current_boxes, update_identity, edits) + return img, update_identity, edits + + def setup_interface(self): + current_styles = gradio.State() + current_boxes = gradio.State(None) + update_identity = gradio.State([]) + edits = gradio.State([]) + with gradio.Row(): + input_image = gradio.Image( + type="pil", label="Upload your image or try the example below!", source="webcam") + output_image = gradio.Image(type="numpy", label="Output") + with gradio.Row(): + update_btn = gradio.Button("Update Anonymization").style(full_width=True) + with gradio.Row(): + show_boxes = gradio.Checkbox(value=True, label="Show Selected") + cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False) + previous = gradio.Button("Previous Person") + next_ = gradio.Button("Next Person") + with gradio.Row(): + text_prompt = gradio.Textbox( + placeholder=" | ".join(list(self.precomputed_edits)), + label="Text Prompt for Edit") + edit_strength = gradio.Slider(0, 5, step=.01) + add_btn = gradio.Button("Add Edit") + add_btn.click( + self.add_style, + inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes,current_styles, current_boxes, update_identity, edits], + outputs=[output_image, update_identity, edits]) + update_btn.click( + self.update_image, + inputs=[input_image, show_boxes], + outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx]) + input_image.change( + self.update_image, + inputs=[input_image, show_boxes], + outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx]) + previous.click( + self.change_face, + inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], + outputs=[output_image, update_identity, cur_face_idx]) + next_.click( + self.change_face, + inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes,current_styles, update_identity, edits], + outputs=[output_image, update_identity, cur_face_idx]) + show_boxes.change( + self.anonymize, + inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits], + outputs=[output_image, update_identity]) + + +class WebcamDemo: + + def __init__(self, anonymizer) -> None: + self.anonymizer = anonymizer + with gradio.Row(): + input_image = gradio.Image(type="pil", source="webcam", streaming=True) + output_image = gradio.Image(type="numpy", label="Output") + with gradio.Row(): + truncation_value = gradio.Slider(0, 1, value=0, step=0.01) 
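+            # Editorial note (not in the original source): the slider presumably follows the usual
+            # StyleGAN truncation convention, where 0 stays close to the average latent (most
+            # conservative output) and values toward 1 allow more identity variation.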
+ truncation = gradio.Radio(["Multi-modal truncation", "Unimodal truncation"], value="Unimodal truncation") + with gradio.Row(): + visualize_det = gradio.Checkbox(value=False, label="Show Detections") + track = gradio.Checkbox(value=False, label="Track detections (samples same latent variable per track)") + input_image.stream( + self.anonymize, + inputs=[input_image, visualize_det, truncation_value,truncation, track, gradio.Variable(False)], + outputs=[output_image]) + self.track = True + + def anonymize(self, img: Image, visualize_detection: bool, truncation_value, truncation_type, track, reset_track): + if reset_track: + self.anonymizer.reset_tracker() + mmt = truncation_type == "Multi-modal truncation" + img, cache_id = pil2torch(img) + img = tops.to_cuda(img) + self.anonymizer + if visualize_detection: + img = self.anonymizer.visualize_detection(img, cache_id=cache_id) + else: + img = self.anonymizer( + img, + truncation_value=truncation_value, + multi_modal_truncation=mmt, + amp=True, + cache_id=cache_id, + track=track) + img = utils.im2numpy(img) + return img + + +class ExampleDemo(WebcamDemo): + + def __init__(self, anonymizer) -> None: + self.anonymizer = anonymizer + with gradio.Row(): + input_image = gradio.Image(type="pil", source="webcam") + output_image = gradio.Image(type="numpy", label="Output") + with gradio.Row(): + update_btn = gradio.Button("Update Anonymization").style(full_width=True) + resample = gradio.Button("Resample Latent Variables").style(full_width=True) + with gradio.Row(): + truncation_value = gradio.Slider(0, 1, value=0, step=0.01) + truncation = gradio.Radio(["Multi-modal truncation", "Unimodal truncation"], value="Unimodal truncation") + visualize_det = gradio.Checkbox(value=False, label="Show Detections") + visualize_det.change( + self.anonymize, + inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(False)], + outputs=[output_image]) + gradio.Examples( + ["media/erling.jpg", "media/regjeringen.jpg"], inputs=[input_image] + ) + + update_btn.click( + self.anonymize, + inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(False)], + outputs=[output_image]) + resample.click( + self.anonymize, + inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(True), gradio.Variable(True)], + outputs=[output_image]) + input_image.change( + self.anonymize, + inputs=[input_image, visualize_det, truncation_value, truncation, gradio.Variable(False), gradio.Variable(True)], + outputs=[output_image]) + self.track = False + self.truncation_value = truncation_value + + +class Information: + + def __init__(self) -> None: + gradio.Markdown("##
Face Anonymization Architecture
") + gradio.Markdown("---") + gradio.Image(value="media/overall_architecture.png") + gradio.Markdown("##
Full-Body Anonymization Architecture
") + gradio.Markdown("---") + gradio.Image(value="media/full_body.png") + gradio.Markdown("###
Generative Adversarial Networks
") + gradio.Markdown("---") + gradio.Image(value="media/gan_architecture.png") + + +def pil2torch(img: Image.Image): + img = img.convert("RGB") + img = np.array(img) + img = np.rollaxis(img, 2) + return torch.from_numpy(img), None diff --git a/media/erling.jpg b/media/erling.jpg new file mode 100644 index 0000000000000000000000000000000000000000..78612fd7643f95d48c6188e8b3c7b6f5557aef30 --- /dev/null +++ b/media/erling.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d1731504c0bdad6ed94578d13bc53c65181fef3c13db1a9214c6e9a255d9b02e +size 1227474 diff --git a/media/g7_leaders.jpg b/media/g7_leaders.jpg new file mode 100644 index 0000000000000000000000000000000000000000..573608b1f13612f30dfb6a4cd4b35db68543f64b --- /dev/null +++ b/media/g7_leaders.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:480fa6b02b77e12bb7e3b85c4c0c15797040fceb5f87eaaefdb626c2dfb49255 +size 2121459 diff --git a/media/regjeringen.jpg b/media/regjeringen.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f6ea34a1053634a1be00e06417aa6fddf838eea0 --- /dev/null +++ b/media/regjeringen.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a59538772ccc8f9a0a8b490017ffd7e1dfad2527fac9bf8f49b3a227117335a +size 526611 diff --git a/media/stylemc_example.jpg b/media/stylemc_example.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4b3f23f450425f3afd19b1cae94dc75c86e16b13 --- /dev/null +++ b/media/stylemc_example.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c95012af3c4a18229ca0d3031e22cae182cd8e41518b51b26a0040408e8210a +size 2513762 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d33bc492f1ab5454ab5f85f560e4a4fa015181c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +numpy +cython +matplotlib +tqdm +tensorboard +opencv-python +torch_fidelity +ninja +moviepy +pyspng +wandb +termcolor +tops@git+https://github.com/hukkelas/torch_ops.git +motpy@git+https://github.com/wmuron/motpy@c77f85d27e371c0a298e9a88ca99292d9b9cbe6b +fast_pytorch_kmeans +einops +einops_exts +regex +setuptools +resize_right +pillow +scipy +webdataset +scikit-image \ No newline at end of file diff --git a/sg3_torch_utils/LICENSE.txt b/sg3_torch_utils/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b5ee9bf994cc9441cb659c3527160b4ee5bcb33 --- /dev/null +++ b/sg3_torch_utils/LICENSE.txt @@ -0,0 +1,97 @@ +Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. + + +NVIDIA Source Code License for StyleGAN3 + + +======================================================================= + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. 
+ +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. Notwithstanding + the foregoing, NVIDIA and its affiliates may use the Work and any + derivative works commercially. As used herein, "non-commercially" + means for research or evaluation purposes only. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grant in Section 2.1) will terminate + immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor’s or its affiliates’ names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grant in Section 2.1) will + terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. 
+ +======================================================================= diff --git a/sg3_torch_utils/__init__.py b/sg3_torch_utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/sg3_torch_utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# empty diff --git a/sg3_torch_utils/custom_ops.py b/sg3_torch_utils/custom_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..4cc4e43fc6f6ce79f2bd68a44ba87990b9b8564e --- /dev/null +++ b/sg3_torch_utils/custom_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import os +import glob +import torch +import torch.utils.cpp_extension +import importlib +import hashlib +import shutil +from pathlib import Path + +from torch.utils.file_baton import FileBaton + +#---------------------------------------------------------------------------- +# Global options. + +verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' + +#---------------------------------------------------------------------------- +# Internal helper funcs. + +def _find_compiler_bindir(): + patterns = [ + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', + 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', + ] + for pattern in patterns: + matches = sorted(glob.glob(pattern)) + if len(matches): + return matches[-1] + return None + +#---------------------------------------------------------------------------- +# Main entry point for compiling and loading C++/CUDA plugins. + +_cached_plugins = dict() + +def get_plugin(module_name, sources, **build_kwargs): + assert verbosity in ['none', 'brief', 'full'] + + # Already cached? + if module_name in _cached_plugins: + return _cached_plugins[module_name] + + # Print status. + if verbosity == 'full': + print(f'Setting up PyTorch plugin "{module_name}"...') + elif verbosity == 'brief': + print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) + + try: # pylint: disable=too-many-nested-blocks + # Make sure we can find the necessary compiler binaries. + if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: + compiler_bindir = _find_compiler_bindir() + if compiler_bindir is None: + raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') + os.environ['PATH'] += ';' + compiler_bindir + + # Compile and load. 
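+        # Editorial note: torch.utils.cpp_extension.load (called below) JIT-compiles the given
+        # C++/CUDA sources into a Python extension module and imports it; the logic that follows
+        # only adds an md5-keyed source cache so unchanged sources can reuse a previous build.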
+ verbose_build = (verbosity == 'full') + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/misc.py b/sg3_torch_utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..10d8e31880affdd185580b6f5b98e92c79597dc3 --- /dev/null +++ b/sg3_torch_utils/misc.py @@ -0,0 +1,172 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import re +import contextlib +import numpy as np +import torch +import warnings + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
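+# Editorial usage sketch (not part of the original file): repeated calls with identical
+# arguments return the same cached tensor, so the host-to-device copy happens only once, e.g.
+#
+#   a = constant([1.0, 2.0], device=torch.device('cuda'))
+#   b = constant([1.0, 2.0], device=torch.device('cuda'))
+#   assert a is b   # served from _constant_cache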
+ +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). 
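+# Editorial usage sketch (not in the original file): decorating a function makes its calls show
+# up as a named range in autograd profiler traces, e.g.
+#
+#   @profiled_function
+#   def run_generator(z):   # `run_generator` and `G` are hypothetical names for illustration
+#       return G(z)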
+ +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield diff --git a/sg3_torch_utils/ops/__init__.py b/sg3_torch_utils/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ece0ea08fe2e939cc260a1dafc0ab5b391b773d9 --- /dev/null +++ b/sg3_torch_utils/ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +# empty diff --git a/sg3_torch_utils/ops/bias_act.cpp b/sg3_torch_utils/ops/bias_act.cpp new file mode 100755 index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330 --- /dev/null +++ b/sg3_torch_utils/ops/bias_act.cpp @@ -0,0 +1,99 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? 
dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.cu b/sg3_torch_utils/ops/bias_act.cu new file mode 100755 index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880 --- /dev/null +++ b/sg3_torch_utils/ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? 
x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.h b/sg3_torch_utils/ops/bias_act.h new file mode 100755 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/sg3_torch_utils/ops/bias_act.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/bias_act.py b/sg3_torch_utils/ops/bias_act.py new file mode 100755 index 0000000000000000000000000000000000000000..7c39717268055fafe737419486cf96f1f93f4fb5 --- /dev/null +++ b/sg3_torch_utils/ops/bias_act.py @@ -0,0 +1,215 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient bias and activation.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from easydict import EasyDict +from torch.cuda.amp import custom_bwd, custom_fwd +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +enabled = False +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + 
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + act: Name of the activation function to evaluate, or `"linear"` to disable. + Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. + See `activation_funcs` for a full list. `None` is not allowed. + alpha: Shape parameter for the activation function, or `None` to use the default. + gain: Scaling factor for the output tensor, or `None` to use default. + See `activation_funcs` for the default scaling of each activation function. + If unsure, consider specifying 1. + clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable + the clamping (default). + impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). + + Returns: + Tensor of the same shape and datatype as `x`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init(): + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + +#---------------------------------------------------------------------------- + +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops. + """ + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. + gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. 
+ if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + +#---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + @custom_bwd + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + @custom_bwd + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. 
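# --------------------------------------------------------------------------
# Editor's note (illustration only, not part of the NVIDIA source): the fused
# op above performs bias add -> activation -> gain -> clamp in a single CUDA
# kernel. A minimal sketch of the equivalent computation via the 'ref' path
# exposed by this module; shapes and values below are made up:
#
#     import torch
#     from sg3_torch_utils.ops import bias_act
#
#     x = torch.randn(4, 512, 8, 8)   # activations [N, C, H, W]
#     b = torch.randn(512)            # per-channel bias along dim=1
#     y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5,
#                           clamp=256, impl='ref')
#
#     # Same result with plain PyTorch ops:
#     y_ref = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2)
#     y_ref = (y_ref * 2 ** 0.5).clamp(-256, 256)
#     assert torch.allclose(y, y_ref, atol=1e-6)
# --------------------------------------------------------------------------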
+ _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/conv2d_gradfix.py b/sg3_torch_utils/ops/conv2d_gradfix.py new file mode 100755 index 0000000000000000000000000000000000000000..e66591f19fad68760d3df7c9737a14574b70ee83 --- /dev/null +++ b/sg3_torch_utils/ops/conv2d_gradfix.py @@ -0,0 +1,175 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import warnings +import contextlib +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. 
Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float()) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.float().sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. 
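# --------------------------------------------------------------------------
# Editor's note (illustration only): conv2d()/conv_transpose2d() above are
# drop-in replacements for their torch.nn.functional counterparts; the custom
# op is used only when `enabled` is True, cuDNN is available, the input is on
# CUDA and the PyTorch version is supported, otherwise they fall back
# silently. A hedged usage sketch (shapes are made up):
#
#     import torch
#     from sg3_torch_utils.ops import conv2d_gradfix
#
#     x = torch.randn(1, 3, 32, 32)
#     w = torch.randn(8, 3, 3, 3, requires_grad=True)
#     y = conv2d_gradfix.conv2d(x, w, padding=1)     # same as F.conv2d here
#
#     # Skip weight gradients (only affects the custom-op backward pass):
#     with conv2d_gradfix.no_weight_gradients():
#         y2 = conv2d_gradfix.conv2d(x, w, padding=1)
# --------------------------------------------------------------------------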
+ class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + @custom_bwd + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/conv2d_resample.py b/sg3_torch_utils/ops/conv2d_resample.py new file mode 100755 index 0000000000000000000000000000000000000000..4a999b58b36a5da53752024e86a9ebdf9c031d97 --- /dev/null +++ b/sg3_torch_utils/ops/conv2d_resample.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling.""" + +import torch + +from .. import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + w = w.flip([2, 3]) + + # Otherwise => execute using conv2d_gradfix. 
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. + if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
+ if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) + return x + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/fma.py b/sg3_torch_utils/ops/fma.py new file mode 100755 index 0000000000000000000000000000000000000000..b4e8ef9169440d4c3bd95befae7d26e3c1e1f017 --- /dev/null +++ b/sg3_torch_utils/ops/fma.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
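# --------------------------------------------------------------------------
# Editor's note (illustration only): conv2d_resample() above fuses optional
# FIR filtering with up/downsampling around a single convolution. A hedged
# sketch of a 2x-upsampling call; the filter taps and shapes are made up:
#
#     import torch
#     from sg3_torch_utils.ops import conv2d_resample, upfirdn2d
#
#     x = torch.randn(1, 64, 16, 16)            # [N, C, H, W]
#     w = torch.randn(32, 64, 3, 3)             # [out, in, kh, kw]
#     f = upfirdn2d.setup_filter([1, 3, 3, 1])  # 4x4 low-pass filter
#     y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
#     # y has shape [1, 32, 32, 32]: upsample by 2, then 3x3 conv, padding 1.
# --------------------------------------------------------------------------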
+ +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + +#---------------------------------------------------------------------------- + +def fma(a, b, c): # => a * b + c + return _FusedMultiplyAdd.apply(a, b, c) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + @custom_bwd + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/grid_sample_gradfix.py b/sg3_torch_utils/ops/grid_sample_gradfix.py new file mode 100755 index 0000000000000000000000000000000000000000..87067e150c591b1ace91816e7a5c3ee3a4aeacd3 --- /dev/null +++ b/sg3_torch_utils/ops/grid_sample_gradfix.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample` that +supports arbitrarily high order gradients between the input and output. +Only works on 2D images and assumes +`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd +from pkg_resources import parse_version +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 + + +#---------------------------------------------------------------------------- + +enabled = False # Enable the custom op by setting this to true. 
+ +#---------------------------------------------------------------------------- + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + if _use_pytorch_1_11_api: + output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) + else: + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + @custom_bwd + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- diff --git a/sg3_torch_utils/ops/upfirdn2d.cpp b/sg3_torch_utils/ops/upfirdn2d.cpp new file mode 100755 index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a --- /dev/null +++ b/sg3_torch_utils/ops/upfirdn2d.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.cu b/sg3_torch_utils/ops/upfirdn2d.cu new file mode 100755 index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916 --- /dev/null +++ b/sg3_torch_utils/ops/upfirdn2d.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. 
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. 
+ __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if 
(fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + } + if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 
2) // contiguous + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last + { + if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last + { + if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + } + if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last + { + if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.h b/sg3_torch_utils/ops/upfirdn2d.h new file mode 100755 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/sg3_torch_utils/ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
+// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/sg3_torch_utils/ops/upfirdn2d.py b/sg3_torch_utils/ops/upfirdn2d.py new file mode 100755 index 0000000000000000000000000000000000000000..a0bbd22d245481e7c5a19315e5cb3242b1278787 --- /dev/null +++ b/sg3_torch_utils/ops/upfirdn2d.py @@ -0,0 +1,388 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import os +import warnings +import numpy as np +import torch +import traceback + +from .. import custom_ops +from .. import misc +from . import conv2d_gradfix +from torch.cuda.amp import custom_bwd, custom_fwd + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +enabled = False + +def _init(): + global _inited, _plugin + if not _inited: + sources = ['upfirdn2d.cpp', 'upfirdn2d.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. 
Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. 
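+
+    Example (illustrative sketch; assumes a 4-tap binomial filter prepared with
+    `setup_filter` and a random input batch):
+
+        f = setup_filter([1, 3, 3, 1])                            # normalized 4x4 FIR filter
+        x = torch.randn(1, 3, 16, 16)
+        y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)   # -> [1, 3, 32, 32]
+        # The call above matches upsample2d(x, f) with its default 2x factor.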
+ + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. 
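+    # (The autograd Function class defined below is specialized for this exact
+    # combination of up/down factors, padding, flip_filter and gain, and memoized
+    # in _upfirdn2d_cuda_cache so that repeated calls with identical settings
+    # reuse the same class instead of redefining it.)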
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + @custom_bwd + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. 
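+
+    Internally, the signal gain is multiplied by `upx * upy` so that a constant
+    input keeps its average magnitude after upsampling. Illustrative sketch,
+    assuming a filter built with `setup_filter`:
+
+        f = setup_filter([1, 3, 3, 1])                # 4-tap binomial filter
+        y = upsample2d(torch.randn(1, 3, 32, 32), f)  # -> [1, 3, 64, 64]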
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        up:          Integer upsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the output. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    upx, upy = _parse_scaling(up)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw + upx - 1) // 2,
+        padx1 + (fw - upx) // 2,
+        pady0 + (fh + upy - 1) // 2,
+        pady1 + (fh - upy) // 2,
+    ]
+    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape is a fraction of the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        down:        Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the input. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
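+
+    Example (illustrative sketch; assumes a 2-tap box filter built with
+    `setup_filter`):
+
+        f = setup_filter([1, 1])                        # 2-tap box filter
+        y = downsample2d(torch.randn(1, 3, 64, 64), f)  # -> [1, 3, 32, 32]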
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- diff --git a/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e b/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e new file mode 100644 index 0000000000000000000000000000000000000000..b17574a8a99960be5cef91c3b96a1a3e3e233d79 --- /dev/null +++ b/torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1598f0dc475200dcbd0fcdfccab39c70b6023c4af47052cfa9cf867fa08fe047 +size 173628377 diff --git a/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth b/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth new file mode 100644 index 0000000000000000000000000000000000000000..3318ad1b61e111b3d24c265e1a84d26ae9420e6c --- /dev/null +++ b/torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:76ea889083f2dddaa87cbaf4cff6bb70e3e34c83fb0ab8de7b2df15eaf78caf1 +size 481004605