diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e010efe37302940c4d5082a9c99eff1f8a62a96d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo0.gif filter=lfs diff=lfs merge=lfs -text +assets/demo1.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7ba185e9f8cc7306d13547e81841c9586ca47e23 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +.vscode +.idea +*output +backup +*/__pycache__/* +pretrained +logs +*.pyc +*.ipynb +*.out +*.log +*.pth +*.pkl +*.pt +*.npy +*debug* diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..85d81725587d6c36deb2e975d77547709f61369b --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 HUST Vision Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..7838b35d2a2c8741f99fc909f8fd0959006a543a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,78 @@ ---- -license: mit ---- +
+

ViTGaze 👀

+

Gaze Following with Interaction Features in Vision Transformers

+ +Yuehao Song1 , Xinggang Wang1 :email: , Jingfeng Yao1 , Wenyu Liu1 , Jinglin Zhang2 , Xiangmin Xu3 + +1 Huazhong University of Science and Technology, 2 Shandong University, 3 South China University of Technology + +(:email:) corresponding author. + +ArXiv Preprint ([arXiv 2403.12778](https://arxiv.org/abs/2403.12778)) + +
+ +# +![Demo0](assets/demo0.gif) +![Demo1](assets/demo1.gif) +### News +* **`Mar. 25th, 2024`:** We release an initial version of ViTGaze. +* **`Mar. 19th, 2024`:** We released our paper on Arxiv. Code/Models are coming soon. Please stay tuned! ☕️ + + +## Introduction +
A plain Vision Transformer can also do gaze following with the simple ViTGaze framework!
+ +![framework](assets/pipeline.png "framework") + +Inspired by the remarkable success of pre-trained plain Vision Transformers (ViTs), we introduce a novel single-modality gaze following framework, **ViTGaze**. In contrast to previous methods, it creates a brand new gaze following framework based mainly on powerful encoders (relative decoder parameter less than 1%). Our principal insight lies in that the inter-token interactions within self-attention can be transferred to interactions between humans and scenes. Our method achieves state-of-the-art (SOTA) performance among all single-modality methods (3.4% improvement on AUC, 5.1% improvement on AP) and very comparable performance against multi-modality methods with 59% number of parameters less. + +## Results +> Results from the [ViTGaze paper](https://arxiv.org/abs/2403.12778) + +![comparison](assets/comparion.png "comparison") + + + + + + + + + + + + + + + + + + + + + + +
Results on GazeFollowResults on VideoAttentionTarget
AUCAvg. Dist.Min. Dist.AUCDist.AP
0.9490.1050.0470.9380.1020.905
+ +Corresponding checkpoints are released: +- GazeFollow: [GoogleDrive](https://drive.google.com/file/d/164c4woGCmUI8UrM7GEKQrV1FbA3vGwP4/view?usp=drive_link) +- VideoAttentionTarget: [GoogleDrive](https://drive.google.com/file/d/11_O4Jm5wsvQ8qfLLgTlrudqSNvvepsV0/view?usp=drive_link) +## Getting Started +- [Installation](docs/install.md) +- [Train](docs/train.md) +- [Eval](docs/eval.md) + +## Acknowledgements +ViTGaze is based on [detectron2](https://github.com/facebookresearch/detectron2). We use the efficient multi-head attention implemented in the [xFormers](https://github.com/facebookresearch/xformers) library. + +## Citation +If you find ViTGaze is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. +```bibtex +@article{vitgaze, + title={ViTGaze: Gaze Following with Interaction Features in Vision Transformers}, + author={Yuehao Song and Xinggang Wang and Jingfeng Yao and Wenyu Liu and Jinglin Zhang and Xiangmin Xu}, + journal={arXiv preprint arXiv:2403.12778}, + year={2024} +} +``` diff --git a/assets/comparion.png b/assets/comparion.png new file mode 100644 index 0000000000000000000000000000000000000000..e3858095eeb24643628f9d108b6d79a61c5ef1af Binary files /dev/null and b/assets/comparion.png differ diff --git a/assets/demo0.gif b/assets/demo0.gif new file mode 100644 index 0000000000000000000000000000000000000000..d8009ea60e54012f974b8f57599d464180bd0734 --- /dev/null +++ b/assets/demo0.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89e36ce1438c230928f376b4a2668e5daabace217a63e1df095ddce7202851ec +size 5921851 diff --git a/assets/demo1.gif b/assets/demo1.gif new file mode 100644 index 0000000000000000000000000000000000000000..6245e11013168d79398800694d22e1bc1ebb8937 --- /dev/null +++ b/assets/demo1.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d5c7a413de2f9ff12818d0f476f868e5fb867c17d85315cce3eb090212468da +size 9477364 diff --git a/assets/pipeline.png b/assets/pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..c1b70543898066cad1414f66c9d30048ac03a75c Binary files /dev/null and b/assets/pipeline.png differ diff --git a/configs/common/__init__.py b/configs/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e690fd59145ce8900fd9ab8d8a996ee7d33834 --- /dev/null +++ b/configs/common/__init__.py @@ -0,0 +1 @@ +from . import * diff --git a/configs/common/dataloader.py b/configs/common/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..2de6507f4d3075211230f2b65d4cad2d55b98ba2 --- /dev/null +++ b/configs/common/dataloader.py @@ -0,0 +1,214 @@ +from os import path as osp +from typing import Literal + +from omegaconf import OmegaConf +from detectron2.config import LazyCall as L +from detectron2.config import instantiate +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from data import * + + +DATA_ROOT = "${Root to Datasets}" +if DATA_ROOT == "${Root to Datasets}": + raise Exception( + f"""{osp.abspath(__file__)}: Rewrite `DATA_ROOT` with the root to the datasets. 
+The directory structure should be: +-DATA_ROOT +|-videoattentiontarget +|--images +|--annotations +|---train +|---test +|--head_masks +|---images +|-gazefollow +|--train +|--test2 +|--train_annotations_release.txt +|--test_annotations_release.txt +|--head_masks +|---train +|---test2 +""" + ) + +# Basic Config for Video Attention Target dataset and preprocessing +data_info = OmegaConf.create() +data_info.video_attention_target = OmegaConf.create() +data_info.video_attention_target.train_root = osp.join( + DATA_ROOT, "videoattentiontarget/images" +) +data_info.video_attention_target.train_anno = osp.join( + DATA_ROOT, "videoattentiontarget/annotations/train" +) +data_info.video_attention_target.val_root = osp.join( + DATA_ROOT, "videoattentiontarget/images" +) +data_info.video_attention_target.val_anno = osp.join( + DATA_ROOT, "videoattentiontarget/annotations/test" +) +data_info.video_attention_target.head_root = osp.join( + DATA_ROOT, "videoattentiontarget/head_masks/images" +) + +data_info.video_attention_target_video = OmegaConf.create() +data_info.video_attention_target_video.train_root = osp.join( + DATA_ROOT, "videoattentiontarget/images" +) +data_info.video_attention_target_video.train_anno = osp.join( + DATA_ROOT, "videoattentiontarget/annotations/train" +) +data_info.video_attention_target_video.val_root = osp.join( + DATA_ROOT, "videoattentiontarget/images" +) +data_info.video_attention_target_video.val_anno = osp.join( + DATA_ROOT, "videoattentiontarget/annotations/test" +) +data_info.video_attention_target_video.head_root = osp.join( + DATA_ROOT, "videoattentiontarget/head_masks/images" +) + +data_info.gazefollow = OmegaConf.create() +data_info.gazefollow.train_root = osp.join(DATA_ROOT, "gazefollow") +data_info.gazefollow.train_anno = osp.join( + DATA_ROOT, "gazefollow/train_annotations_release.txt" +) +data_info.gazefollow.val_root = osp.join(DATA_ROOT, "gazefollow") +data_info.gazefollow.val_anno = osp.join( + DATA_ROOT, "gazefollow/test_annotations_release.txt" +) +data_info.gazefollow.head_root = osp.join(DATA_ROOT, "gazefollow/head_masks") + +data_info.input_size = 224 +data_info.output_size = 64 +data_info.quant_labelmap = True +data_info.mean = (0.485, 0.456, 0.406) +data_info.std = (0.229, 0.224, 0.225) +data_info.bbox_jitter = 0.5 +data_info.rand_crop = 0.5 +data_info.rand_flip = 0.5 +data_info.color_jitter = 0.5 +data_info.rand_rotate = 0.0 +data_info.rand_lsj = 0.0 + +data_info.mask_size = 24 +data_info.mask_scene = False +data_info.mask_head = False +data_info.max_scene_patches_ratio = 0.5 +data_info.max_head_patches_ratio = 0.3 +data_info.mask_prob = 0.2 + +data_info.seq_len = 16 +data_info.max_len = 32 + + +# Dataloader(gazefollow/video_atention_target, train/val) +def __build_dataloader( + name: Literal[ + "gazefollow", "video_attention_target", "video_attention_target_video" + ], + is_train: bool, + batch_size: int = 64, + num_workers: int = 14, + pin_memory: bool = True, + persistent_workers: bool = True, + drop_last: bool = True, + distributed: bool = False, + **kwargs, +): + assert name in [ + "gazefollow", + "video_attention_target", + "video_attention_target_video", + ], f'{name} not in ("gazefollow", "video_attention_target", "video_attention_target_video")' + + for k, v in kwargs.items(): + if k in ["train_root", "train_anno", "val_root", "val_anno", "head_root"]: + data_info[name][k] = v + else: + data_info[k] = v + + datasets = { + "gazefollow": GazeFollow, + "video_attention_target": VideoAttentionTarget, + "video_attention_target_video": 
VideoAttentionTargetVideo, + } + dataset = L(datasets[name])( + image_root=data_info[name]["train_root" if is_train else "val_root"], + anno_root=data_info[name]["train_anno" if is_train else "val_anno"], + head_root=data_info[name]["head_root"], + transform=get_transform( + input_resolution=data_info.input_size, + mean=data_info.mean, + std=data_info.std, + ), + input_size=data_info.input_size, + output_size=data_info.output_size, + quant_labelmap=data_info.quant_labelmap, + is_train=is_train, + bbox_jitter=data_info.bbox_jitter, + rand_crop=data_info.rand_crop, + rand_flip=data_info.rand_flip, + color_jitter=data_info.color_jitter, + rand_rotate=data_info.rand_rotate, + rand_lsj=data_info.rand_lsj, + mask_generator=( + MaskGenerator( + input_size=data_info.mask_size, + mask_scene=data_info.mask_scene, + mask_head=data_info.mask_head, + max_scene_patches_ratio=data_info.max_scene_patches_ratio, + max_head_patches_ratio=data_info.max_head_patches_ratio, + mask_prob=data_info.mask_prob, + ) + if is_train + else None + ), + ) + if name == "video_attention_target_video": + dataset.seq_len = data_info.seq_len + dataset.max_len = data_info.max_len + dataset = instantiate(dataset) + + return DataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + collate_fn=video_collate if name == "video_attention_target_video" else None, + sampler=DistributedSampler(dataset, shuffle=is_train) if distributed else None, + drop_last=drop_last, + ) + + +dataloader = OmegaConf.create() +dataloader.gazefollow = OmegaConf.create() +dataloader.gazefollow.train = L(__build_dataloader)( + name="gazefollow", + is_train=True, +) +dataloader.gazefollow.val = L(__build_dataloader)( + name="gazefollow", + is_train=False, +) +dataloader.video_attention_target = OmegaConf.create() +dataloader.video_attention_target.train = L(__build_dataloader)( + name="video_attention_target", + is_train=True, +) +dataloader.video_attention_target.val = L(__build_dataloader)( + name="video_attention_target", + is_train=False, +) +dataloader.video_attention_target_video = OmegaConf.create() +dataloader.video_attention_target_video.train = L(__build_dataloader)( + name="video_attention_target_video", + is_train=True, +) +dataloader.video_attention_target_video.val = L(__build_dataloader)( + name="video_attention_target_video", + is_train=False, +) diff --git a/configs/common/model.py b/configs/common/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd9f0e3e0f6037ef0b1c839067ad337afbd897df --- /dev/null +++ b/configs/common/model.py @@ -0,0 +1,44 @@ +from detectron2.config import LazyCall as L + +from modeling import backbone, patch_attention, meta_arch, head, criterion + + +model = L(meta_arch.GazeAttentionMapper)() +model.backbone = L(backbone.build_backbone)( + name="small", out_attn=[2, 5, 8, 11] +) +model.pam = L(patch_attention.build_pam)(name="PatchPAM", patch_size=16) +model.regressor = L(head.build_heatmap_head)( + name="SimpleDeconv", + in_channel=24, + deconv_cfgs=[ + { + "in_channels": 24, + "out_channels": 12, + "kernel_size": 3, + "stride": 2, + }, + { + "in_channels": 12, + "out_channels": 6, + "kernel_size": 3, + "stride": 2, + }, + { + "in_channels": 6, + "out_channels": 3, + "kernel_size": 3, + "stride": 2, + }, + { + "in_channels": 3, + "out_channels": 1, + "kernel_size": 3, + "stride": 2, + }, + ], + feat_type="attn", +) +model.classifier = L(head.build_inout_head)(name="SimpleLinear", in_channel=384) 
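+# Note: the assignments in this file are detectron2 LazyCall configs, not
+# instantiated modules; experiment configs (e.g. configs/gazefollow.py) import
+# `model` and override these fields before the training script materializes it.
+# A minimal sketch of the assumed usage, illustrative only and mirroring the
+# `instantiate` call already used in configs/common/dataloader.py:
+#     from detectron2.config import instantiate
+#     gaze_mapper = instantiate(model)  # builds GazeAttentionMapper with the
+#                                       # backbone, PAM, regressor and classifier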
+model.criterion = L(criterion.GazeMapperCriterion)() +model.device = "cuda" diff --git a/configs/common/optimizer.py b/configs/common/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd84ed34f12d46b3be905ef3edebfae26b867d4 --- /dev/null +++ b/configs/common/optimizer.py @@ -0,0 +1,48 @@ +from detectron2 import model_zoo + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if ".pos_embed" in name or ".patch_embed" in name: + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +class LRDecayRater: + def __init__(self, lr_decay_rate=1.0, num_layers=12, backbone_multiplier=1.0, freeze_pe=False, pam_lr_decay=1): + self.lr_decay_rate = lr_decay_rate + self.num_layers = num_layers + self.backbone_multiplier = backbone_multiplier + self.freeze_pe = freeze_pe + self.pam_lr_decay = pam_lr_decay + + def __call__(self, name): + if name.startswith("backbone"): + if self.freeze_pe and ".pos_embed" in name or ".patch_embed" in name: + return 0 + return self.backbone_multiplier * get_vit_lr_decay_rate( + name, self.lr_decay_rate, self.num_layers + ) + if name.startswith("pam"): + return self.pam_lr_decay + return 1 + + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = LRDecayRater(num_layers=12, lr_decay_rate=0.65) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} diff --git a/configs/common/scheduler.py b/configs/common/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..120e06acb59ff30df5309db98d0c8cd3a5284596 --- /dev/null +++ b/configs/common/scheduler.py @@ -0,0 +1,18 @@ +from typing import Literal +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from fvcore.common.param_scheduler import MultiStepParamScheduler, CosineParamScheduler + + +def get_scheduler(typ: Literal["multistep", "cosine"] = "multistep", **kwargs): + if typ == "multistep": + return MultiStepParamScheduler(**kwargs) + elif typ == "cosine": + return CosineParamScheduler(**kwargs) + + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(get_scheduler)(), + warmup_length=0, + warmup_factor=0.001, +) diff --git a/configs/common/train.py b/configs/common/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb45504b9c470f45a0daf57f68f32e740d18124 --- /dev/null +++ b/configs/common/train.py @@ -0,0 +1,16 @@ +train = dict( + output_dir="./output", + init_checkpoint="", + max_iter=90000, + amp=dict(enabled=False), # options for Automatic Mixed Precision + ddp=dict( # options for DistributedDataParallel + broadcast_buffers=True, + find_unused_parameters=False, + fp16_compression=True, + ), + checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer + eval_period=5000, + log_period=100, + device="cuda", + # ... 
+) diff --git a/configs/gazefollow.py b/configs/gazefollow.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5119a1932f25be69b42eb4d522d3de191116b3 --- /dev/null +++ b/configs/gazefollow.py @@ -0,0 +1,86 @@ +from .common.dataloader import dataloader +from .common.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .common.train import train +from os.path import join, basename +from torch.cuda import device_count + + +num_gpu = device_count() +ins_per_iter = 48 +len_dataset = 126000 +num_epoch = 14 +# dataloader +dataloader = dataloader.gazefollow +dataloader.train.batch_size = ins_per_iter // num_gpu +dataloader.train.num_workers = dataloader.val.num_workers = 14 +dataloader.train.distributed = num_gpu > 1 +dataloader.train.rand_rotate = 0.5 +dataloader.train.rand_lsj = 0.5 +dataloader.train.input_size = dataloader.val.input_size = 434 +dataloader.train.mask_scene = True +dataloader.train.mask_prob = 0.5 +dataloader.train.mask_size = dataloader.train.input_size // 14 +dataloader.train.max_scene_patches_ratio = 0.5 +dataloader.val.batch_size = 64 +dataloader.val.distributed = False +# train +train.init_checkpoint = "pretrained/dinov2_small.pth" +train.output_dir = join("./output", basename(__file__).split(".")[0]) +train.max_iter = len_dataset * num_epoch // ins_per_iter +train.log_period = len_dataset // (ins_per_iter * 100) +train.checkpointer.max_to_keep = 10 +train.checkpointer.period = len_dataset // ins_per_iter +train.seed = 0 +# optimizer +optimizer.lr = 1e-4 +optimizer.betas = (0.9, 0.99) +lr_multiplier.scheduler.typ = "cosine" +lr_multiplier.scheduler.start_value = 1 +lr_multiplier.scheduler.end_value = 0.1 +lr_multiplier.warmup_length = 1e-2 +# model +model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True +model.pam.name = "PatchPAM" +model.pam.embed_dim = 8 +model.pam.patch_size = 14 +model.backbone.name = "dinov2_small" +model.backbone.return_softmax_attn = True +model.backbone.out_attn = [2, 5, 8, 11] +model.backbone.use_cls_token = True +model.backbone.use_mask_token = True +model.regressor.name = "UpSampleConv" +model.regressor.in_channel = 24 +model.regressor.use_conv = False +model.regressor.dim = 24 +model.regressor.deconv_cfgs = [ + dict( + in_channels=24, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=16, + out_channels=8, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=8, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + ), +] +model.regressor.feat_type = "attn" +model.classifier.name = "SimpleMlp" +model.classifier.in_channel = 384 +model.criterion.aux_weight = 100 +model.criterion.aux_head_thres = 0.05 +model.criterion.use_focal_loss = True +model.device = "cuda" diff --git a/configs/gazefollow_518.py b/configs/gazefollow_518.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cc1c86ad0a424c7e0953cb28646419922bd903 --- /dev/null +++ b/configs/gazefollow_518.py @@ -0,0 +1,86 @@ +from .common.dataloader import dataloader +from .common.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .common.train import train +from os.path import join, basename +from torch.cuda import device_count + + +num_gpu = device_count() +ins_per_iter = 32 +len_dataset = 126000 +num_epoch = 1 +# dataloader +dataloader = dataloader.gazefollow +dataloader.train.batch_size = ins_per_iter // num_gpu +dataloader.train.num_workers = 
dataloader.val.num_workers = 14 +dataloader.train.distributed = num_gpu > 1 +dataloader.train.rand_rotate = 0.5 +dataloader.train.rand_lsj = 0.5 +dataloader.train.input_size = dataloader.val.input_size = 518 +dataloader.train.mask_scene = True +dataloader.train.mask_prob = 0.5 +dataloader.train.mask_size = dataloader.train.input_size // 14 +dataloader.train.max_scene_patches_ratio = 0.5 +dataloader.val.batch_size = 32 +dataloader.val.distributed = False +# train +train.init_checkpoint = "output/gazefollow/model_final.pth" +train.output_dir = join("./output", basename(__file__).split(".")[0]) +train.max_iter = len_dataset * num_epoch // ins_per_iter +train.log_period = len_dataset // (ins_per_iter * 100) +train.checkpointer.max_to_keep = 10 +train.checkpointer.period = len_dataset // ins_per_iter +train.seed = 0 +# optimizer +optimizer.lr = 1e-5 +optimizer.betas = (0.9, 0.99) +lr_multiplier.scheduler.typ = "cosine" +lr_multiplier.scheduler.start_value = 1 +lr_multiplier.scheduler.end_value = 0.1 +lr_multiplier.warmup_length = 1e-2 +# model +model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True +model.pam.name = "PatchPAM" +model.pam.embed_dim = 8 +model.pam.patch_size = 14 +model.backbone.name = "dinov2_small" +model.backbone.return_softmax_attn = True +model.backbone.out_attn = [2, 5, 8, 11] +model.backbone.use_cls_token = True +model.backbone.use_mask_token = True +model.regressor.name = "UpSampleConv" +model.regressor.in_channel = 24 +model.regressor.use_conv = False +model.regressor.dim = 24 +model.regressor.deconv_cfgs = [ + dict( + in_channels=24, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=16, + out_channels=8, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=8, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + ), +] +model.regressor.feat_type = "attn" +model.classifier.name = "SimpleMlp" +model.classifier.in_channel = 384 +model.criterion.aux_weight = 0 +model.criterion.aux_head_thres = 0.05 +model.criterion.use_focal_loss = True +model.device = "cuda" diff --git a/configs/videoattentiontarget.py b/configs/videoattentiontarget.py new file mode 100644 index 0000000000000000000000000000000000000000..67fd6d15ecc4077b0a00d7bcbef516e8dfcc5f79 --- /dev/null +++ b/configs/videoattentiontarget.py @@ -0,0 +1,90 @@ +from .common.dataloader import dataloader +from .common.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .common.train import train +from os.path import join, basename +from torch.cuda import device_count + + +num_gpu = device_count() +ins_per_iter = 4 +len_dataset = 4400 +num_epoch = 1 +# dataloader +dataloader = dataloader.video_attention_target_video +dataloader.train.batch_size = ins_per_iter // num_gpu +dataloader.train.num_workers = dataloader.val.num_workers = 14 +dataloader.train.distributed = num_gpu > 1 +dataloader.train.rand_rotate = 0.5 +dataloader.train.rand_lsj = 0.5 +dataloader.train.input_size = dataloader.val.input_size = 518 +dataloader.train.mask_scene = True +dataloader.train.mask_prob = 0.5 +dataloader.train.mask_size = dataloader.train.input_size // 14 +dataloader.train.max_scene_patches_ratio = 0.5 +dataloader.train.seq_len = 8 +dataloader.val.quant_labelmap = False +dataloader.val.seq_len = 8 +dataloader.val.batch_size = 4 +dataloader.val.distributed = False +# train +train.init_checkpoint = "output/gazefollow_518/model_final.pth" +train.output_dir = join("./output", 
basename(__file__).split(".")[0]) +train.max_iter = len_dataset * num_epoch // ins_per_iter +train.log_period = len_dataset // (ins_per_iter * 100) +train.checkpointer.max_to_keep = 100 +train.checkpointer.period = len_dataset // ins_per_iter +train.seed = 0 +# optimizer +optimizer.lr = 1e-6 +# optimizer.params.lr_factor_func.backbone_multiplier = 0.1 +# optimizer.params.lr_factor_func.pam_lr_decay = 0.1 +lr_multiplier.scheduler.values = [1.0] +lr_multiplier.scheduler.milestones = [] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 0 +# model +model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True +model.pam.name = "PatchPAM" +model.pam.embed_dim = 8 +model.pam.patch_size = 14 +model.backbone.name = "dinov2_small" +model.backbone.return_softmax_attn = True +model.backbone.out_attn = [2, 5, 8, 11] +model.backbone.use_cls_token = True +model.backbone.use_mask_token = True +model.regressor.name = "UpSampleConv" +model.regressor.in_channel = 24 +model.regressor.use_conv = False +model.regressor.dim = 24 +model.regressor.deconv_cfgs = [ + dict( + in_channels=24, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=16, + out_channels=8, + kernel_size=3, + stride=1, + padding=1, + ), + dict( + in_channels=8, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + ), +] +model.regressor.feat_type = "attn" +model.classifier.name = "SimpleMlp" +model.classifier.in_channel = 384 +model.criterion.aux_weight = 0 +model.criterion.aux_head_thres = 0.05 +model.criterion.use_focal_loss = True +model.device = "cuda" diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7df832158379ba2b7820419fd9eb24be240503 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,5 @@ +from .video_attention_target import VideoAttentionTarget +from .video_attention_target_video import VideoAttentionTargetVideo, video_collate +from .gazefollow import GazeFollow +from .data_utils import get_transform +from .masking import MaskGenerator diff --git a/data/augmentation.py b/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a195aa8e4f55112d8b4e03dde923be84aa0d63 --- /dev/null +++ b/data/augmentation.py @@ -0,0 +1,312 @@ +import math +from typing import Tuple, List +import numpy as np +from PIL import Image, ImageOps +from torchvision import transforms +from torchvision.transforms import functional as TF + + +class Augmentation: + def __init__(self, p: float) -> None: + self.p = p + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + raise NotImplementedError + + def __call__( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + if np.random.random_sample() < self.p: + return self.transform(image, bbox, gaze, head_mask, size) + return image, bbox, gaze, head_mask, size + + +class AugmentationList: + def __init__(self, augmentations: List[Augmentation]) -> None: + self.augmentations = augmentations + + def __call__( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + for aug in self.augmentations: + image, bbox, gaze, head_mask, size = aug(image, bbox, gaze, head_mask, size) + return image, bbox, gaze, head_mask, size + + +class BoxJitter(Augmentation): + # Jitter (expansion-only) bounding box size + def __init__(self, p: float, 
expansion: float = 0.2) -> None: + super().__init__(p) + self.expansion = expansion + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + x_min, y_min, x_max, y_max = bbox + width, height = size + k = np.random.random_sample() * self.expansion + x_min = np.clip(x_min - k * abs(x_max - x_min), 0, width - 1) + y_min = np.clip(y_min - k * abs(y_max - y_min), 0, height - 1) + x_max = np.clip(x_max + k * abs(x_max - x_min), 0, width - 1) + y_max = np.clip(y_max + k * abs(y_max - y_min), 0, height - 1) + return image, (x_min, y_min, x_max, y_max), gaze, head_mask, size + + +class RandomCrop(Augmentation): + def __init__(self, p: float) -> None: + super().__init__(p) + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + x_min, y_min, x_max, y_max = bbox + gaze_x, gaze_y = gaze + width, height = size + # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target + crop_x_min = np.min([gaze_x * width, x_min, x_max]) + crop_y_min = np.min([gaze_y * height, y_min, y_max]) + crop_x_max = np.max([gaze_x * width, x_min, x_max]) + crop_y_max = np.max([gaze_y * height, y_min, y_max]) + + # Randomly select a random top left corner + crop_x_min = np.random.uniform(0, crop_x_min) + crop_y_min = np.random.uniform(0, crop_y_min) + + # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min) + crop_width_min = crop_x_max - crop_x_min + crop_height_min = crop_y_max - crop_y_min + crop_width_max = width - crop_x_min + crop_height_max = height - crop_y_min + + # Randomly select a width and a height + crop_width = np.random.uniform(crop_width_min, crop_width_max) + crop_height = np.random.uniform(crop_height_min, crop_height_max) + + # Round to integers + crop_y_min, crop_x_min, crop_height, crop_width = map( + int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width)) + ) + + # Crop it + image = TF.crop(image, crop_y_min, crop_x_min, crop_height, crop_width) + head_mask = TF.crop(head_mask, crop_y_min, crop_x_min, crop_height, crop_width) + + # convert coordinates into the cropped frame + x_min, y_min, x_max, y_max = ( + x_min - crop_x_min, + y_min - crop_y_min, + x_max - crop_x_min, + y_max - crop_y_min, + ) + + gaze_x = (gaze_x * width - crop_x_min) / float(crop_width) + gaze_y = (gaze_y * height - crop_y_min) / float(crop_height) + + return ( + image, + (x_min, y_min, x_max, y_max), + (gaze_x, gaze_y), + head_mask, + (crop_width, crop_height), + ) + + +class RandomFlip(Augmentation): + def __init__(self, p: float) -> None: + super().__init__(p) + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + image = image.transpose(Image.FLIP_LEFT_RIGHT) + head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT) + x_min, y_min, x_max, y_max = bbox + x_min, x_max = size[0] - x_max, size[0] - x_min + gaze_x, gaze_y = 1 - gaze[0], gaze[1] + return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size + + +class RandomRotate(Augmentation): + def __init__( + self, p: float, max_angle: int = 20, resample: int = Image.BILINEAR + ) -> None: + super().__init__(p) + self.max_angle = max_angle + self.resample = resample + + def _random_rotation_matrix(self): + angle = (2 * np.random.random_sample() - 1) * self.max_angle + angle = -math.radians(angle) + return [ + round(math.cos(angle), 15), + 
round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + @staticmethod + def _transform(x, y, matrix): + return ( + matrix[0] * x + matrix[1] * y + matrix[2], + matrix[3] * x + matrix[4] * y + matrix[5], + ) + + @staticmethod + def _inv_transform(x, y, matrix): + x, y = x - matrix[2], y - matrix[5] + return matrix[0] * x + matrix[3] * y, matrix[1] * x + matrix[4] * y + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + x_min, y_min, x_max, y_max = bbox + gaze_x, gaze_y = gaze + width, height = size + rot_mat = self._random_rotation_matrix() + + # Calculate offsets + rot_center = (width / 2.0, height / 2.0) + rot_mat[2], rot_mat[5] = self._transform( + -rot_center[0], -rot_center[1], rot_mat + ) + rot_mat[2] += rot_center[0] + rot_mat[5] += rot_center[1] + xx = [] + yy = [] + for x, y in ((0, 0), (width, 0), (width, height), (0, height)): + x, y = self._transform(x, y, rot_mat) + xx.append(x) + yy.append(y) + nw = math.ceil(max(xx)) - math.floor(min(xx)) + nh = math.ceil(max(yy)) - math.floor(min(yy)) + rot_mat[2], rot_mat[5] = self._transform( + -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat + ) + + image = image.transform((nw, nh), Image.AFFINE, rot_mat, self.resample) + head_mask = head_mask.transform((nw, nh), Image.AFFINE, rot_mat, self.resample) + + xx = [] + yy = [] + for x, y in ( + (x_min, y_min), + (x_min, y_max), + (x_max, y_min), + (x_max, y_max), + ): + x, y = self._inv_transform(x, y, rot_mat) + xx.append(x) + yy.append(y) + x_max, x_min = min(max(xx), nw), max(min(xx), 0) + y_max, y_min = min(max(yy), nh), max(min(yy), 0) + + gaze_x, gaze_y = self._inv_transform(gaze_x * width, gaze_y * height, rot_mat) + gaze_x = max(min(gaze_x / nw, 1), 0) + gaze_y = max(min(gaze_y / nh, 1), 0) + + return ( + image, + (x_min, y_min, x_max, y_max), + (gaze_x, gaze_y), + head_mask, + (nw, nh), + ) + + +class ColorJitter(Augmentation): + def __init__( + self, + p: float, + brightness: float = 0.4, + contrast: float = 0.4, + saturation: float = 0.2, + hue: float = 0.1, + ) -> None: + super().__init__(p) + self.color_jitter = transforms.ColorJitter( + brightness=brightness, contrast=contrast, saturation=saturation, hue=hue + ) + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + return self.color_jitter(image), bbox, gaze, head_mask, size + + +class RandomLSJ(Augmentation): + def __init__(self, p: float, min_scale: float = 0.1) -> None: + super().__init__(p) + self.min_scale = min_scale + + def transform( + self, + image: Image, + bbox: Tuple[float], + gaze: Tuple[float], + head_mask: Image, + size: Tuple[int], + ): + x_min, y_min, x_max, y_max = bbox + gaze_x, gaze_y = gaze + width, height = size + + scale = self.min_scale + np.random.random_sample() * (1 - self.min_scale) + nh, nw = int(height * scale), int(width * scale) + + image = TF.resize(image, (nh, nw)) + image = ImageOps.expand(image, (0, 0, width - nw, height - nh)) + head_mask = TF.resize(head_mask, (nh, nw)) + head_mask = ImageOps.expand(head_mask, (0, 0, width - nw, height - nh)) + + x_min, y_min, x_max, y_max = ( + x_min * scale, + y_min * scale, + x_max * scale, + y_max * scale, + ) + gaze_x, gaze_y = gaze_x * scale, gaze_y * scale + return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size diff --git a/data/data_utils.py b/data/data_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..ed3e3a6300703df8c05e6da5aa835deca6f56fef --- /dev/null +++ b/data/data_utils.py @@ -0,0 +1,181 @@ +from typing import Tuple +import torch +from torchvision import transforms +import numpy as np +import pandas as pd + + +def to_numpy(tensor: torch.Tensor): + if torch.is_tensor(tensor): + return tensor.cpu().detach().numpy() + elif type(tensor).__module__ != "numpy": + raise ValueError("Cannot convert {} to numpy array".format(type(tensor))) + return tensor + + +def to_torch(ndarray: np.ndarray): + if type(ndarray).__module__ == "numpy": + return torch.from_numpy(ndarray) + elif not torch.is_tensor(ndarray): + raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray))) + return ndarray + + +def get_head_box_channel( + x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False +): + head_box = ( + np.array([x_min / width, y_min / height, x_max / width, y_max / height]) + * resolution + ) + int_head_box = head_box.astype(int) + int_head_box = np.clip(int_head_box, 0, resolution - 1) + if int_head_box[0] == int_head_box[2]: + if int_head_box[0] == 0: + int_head_box[2] = 1 + elif int_head_box[2] == resolution - 1: + int_head_box[0] = resolution - 2 + elif abs(head_box[2] - int_head_box[2]) > abs(head_box[0] - int_head_box[0]): + int_head_box[2] += 1 + else: + int_head_box[0] -= 1 + if int_head_box[1] == int_head_box[3]: + if int_head_box[1] == 0: + int_head_box[3] = 1 + elif int_head_box[3] == resolution - 1: + int_head_box[1] = resolution - 2 + elif abs(head_box[3] - int_head_box[3]) > abs(head_box[1] - int_head_box[1]): + int_head_box[3] += 1 + else: + int_head_box[1] -= 1 + head_box = int_head_box + if coordconv: + unit = np.array(range(0, resolution), dtype=np.float32) + head_channel = [] + for i in unit: + head_channel.append([unit + i]) + head_channel = np.squeeze(np.array(head_channel)) / float(np.max(head_channel)) + head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 0 + else: + head_channel = np.zeros((resolution, resolution), dtype=np.float32) + head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 1 + head_channel = torch.from_numpy(head_channel) + return head_channel + + +def draw_labelmap(img, pt, sigma, type="Gaussian"): + # Draw a 2D gaussian + # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py + img = to_numpy(img) + + # Check that any part of the gaussian is in-bounds + size = int(6 * sigma + 1) + ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] + br = [ul[0] + size, ul[1] + size] + if ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0: + # If not, just return the image as is + return to_torch(img) + + # Generate gaussian + x = np.arange(0, size, 1, float) + y = x[:, np.newaxis] + x0 = y0 = size // 2 + # The gaussian is not normalized, we want the center value to equal 1 + if type == "Gaussian": + g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2)) + elif type == "Cauchy": + g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma**2) ** 1.5) + + # Usable gaussian range + g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] + g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] + # Image range + img_x = max(0, ul[0]), min(br[0], img.shape[1]) + img_y = max(0, ul[1]), min(br[1], img.shape[0]) + + img[img_y[0] : img_y[1], img_x[0] : img_x[1]] += g[g_y[0] : g_y[1], g_x[0] : g_x[1]] + # img = img / np.max(img) + return to_torch(img) + + +def draw_labelmap_no_quant(img, pt, sigma, type="Gaussian"): + img = 
to_numpy(img) + shape = img.shape + x = np.arange(shape[0]) + y = np.arange(shape[1]) + xx, yy = np.meshgrid(x, y, indexing="ij") + dist_matrix = (yy - float(pt[0])) ** 2 + (xx - float(pt[1])) ** 2 + if type == "Gaussian": + g = np.exp(-dist_matrix / (2 * sigma**2)) + elif type == "Cauchy": + g = sigma / ((dist_matrix + sigma**2) ** 1.5) + g[dist_matrix > 10 * sigma**2] = 0 + img += g + # img = img / np.max(img) + return to_torch(img) + + +def multi_hot_targets(gaze_pts, out_res): + w, h = out_res + target_map = np.zeros((h, w)) + for p in gaze_pts: + if p[0] >= 0: + x, y = map(int, [p[0] * float(w), p[1] * float(h)]) + x = min(x, w - 1) + y = min(y, h - 1) + target_map[y, x] = 1 + return target_map + + +def get_cone(tgt, src, wh, theta=150): + eye = src * wh + gaze = tgt * wh + + pixel_mat = np.stack( + np.meshgrid(np.arange(wh[0]), np.arange(wh[1])), + -1, + ) + + dot_prod = np.sum((pixel_mat - eye) * (gaze - eye), axis=-1) + gaze_vector_norm = np.sqrt(np.sum((gaze - eye) ** 2)) + pixel_mat_norm = np.sqrt(np.sum((pixel_mat - eye) ** 2, axis=-1)) + + gaze_cones = dot_prod / (gaze_vector_norm * pixel_mat_norm) + gaze_cones = np.nan_to_num(gaze_cones, nan=1) + + theta = theta * (np.pi / 180) + beta = np.arccos(gaze_cones) + # Create mask where true if beta is less than theta/2 + pixel_mat_presence = beta < (theta / 2) + + # Zero out values outside the gaze cone + gaze_cones[~pixel_mat_presence] = 0 + gaze_cones = np.clip(gaze_cones, 0, None) + + return torch.from_numpy(gaze_cones).unsqueeze(0).float() + + +def get_transform( + input_resolution: int, mean: Tuple[int, int, int], std: Tuple[int, int, int] +): + return transforms.Compose( + [ + transforms.Resize((input_resolution, input_resolution)), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) + + +def smooth_by_conv(window_size, df, col): + padded_track = pd.concat( + [ + pd.DataFrame([[df.iloc[0][col]]] * (window_size // 2), columns=[0]), + df[col], + pd.DataFrame([[df.iloc[-1][col]]] * (window_size // 2), columns=[0]), + ] + ) + smoothed_signals = np.convolve( + padded_track.squeeze(), np.ones(window_size) / window_size, mode="valid" + ) + return smoothed_signals diff --git a/data/gazefollow.py b/data/gazefollow.py new file mode 100644 index 0000000000000000000000000000000000000000..68490853e991281d73ba1cbfa9053e2aab96aa25 --- /dev/null +++ b/data/gazefollow.py @@ -0,0 +1,295 @@ +from os import path as osp +from typing import Callable, Optional + +import torch +from torch.utils.data import Dataset +from torchvision.transforms import functional as TF +from PIL import Image +import pandas as pd + +from . import augmentation +from .masking import MaskGenerator +from . 
import data_utils as utils + + +class GazeFollow(Dataset): + def __init__( + self, + image_root: str, + anno_root: str, + head_root: str, + transform: Callable, + input_size: int, + output_size: int, + quant_labelmap: bool = True, + is_train: bool = True, + *, + mask_generator: Optional[MaskGenerator] = None, + bbox_jitter: float = 0.5, + rand_crop: float = 0.5, + rand_flip: float = 0.5, + color_jitter: float = 0.5, + rand_rotate: float = 0.0, + rand_lsj: float = 0.0, + ): + if is_train: + column_names = [ + "path", + "idx", + "body_bbox_x", + "body_bbox_y", + "body_bbox_w", + "body_bbox_h", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "bbox_x_min", + "bbox_y_min", + "bbox_x_max", + "bbox_y_max", + "inout", + "meta0", + "meta1", + ] + df = pd.read_csv( + anno_root, + sep=",", + names=column_names, + index_col=False, + encoding="utf-8-sig", + ) + df = df[ + df["inout"] != -1 + ] # only use "in" or "out "gaze. (-1 is invalid, 0 is out gaze) + df.reset_index(inplace=True) + self.y_train = df[ + [ + "bbox_x_min", + "bbox_y_min", + "bbox_x_max", + "bbox_y_max", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "inout", + ] + ] + self.X_train = df["path"] + self.length = len(df) + else: + column_names = [ + "path", + "idx", + "body_bbox_x", + "body_bbox_y", + "body_bbox_w", + "body_bbox_h", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "bbox_x_min", + "bbox_y_min", + "bbox_x_max", + "bbox_y_max", + "meta0", + "meta1", + ] + df = pd.read_csv( + anno_root, + sep=",", + names=column_names, + index_col=False, + encoding="utf-8-sig", + ) + df = df[ + [ + "path", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "bbox_x_min", + "bbox_y_min", + "bbox_x_max", + "bbox_y_max", + ] + ].groupby(["path", "eye_x"]) + self.keys = list(df.groups.keys()) + self.X_test = df + self.length = len(self.keys) + + self.data_dir = image_root + self.head_dir = head_root + self.transform = transform + self.is_train = is_train + + self.input_size = input_size + self.output_size = output_size + + self.draw_labelmap = ( + utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant + ) + + if self.is_train: + ## data augmentation + self.augment = augmentation.AugmentationList( + [ + augmentation.ColorJitter(color_jitter), + augmentation.BoxJitter(bbox_jitter), + augmentation.RandomCrop(rand_crop), + augmentation.RandomFlip(rand_flip), + augmentation.RandomRotate(rand_rotate), + augmentation.RandomLSJ(rand_lsj), + ] + ) + + self.mask_generator = mask_generator + + def __getitem__(self, index): + if not self.is_train: + g = self.X_test.get_group(self.keys[index]) + cont_gaze = [] + for _, row in g.iterrows(): + path = row["path"] + x_min = row["bbox_x_min"] + y_min = row["bbox_y_min"] + x_max = row["bbox_x_max"] + y_max = row["bbox_y_max"] + eye_x = row["eye_x"] + eye_y = row["eye_y"] + gaze_x = row["gaze_x"] + gaze_y = row["gaze_y"] + cont_gaze.append( + [gaze_x, gaze_y] + ) # all ground truth gaze are stacked up + for _ in range(len(cont_gaze), 20): + cont_gaze.append( + [-1, -1] + ) # pad dummy gaze to match size for batch processing + cont_gaze = torch.FloatTensor(cont_gaze) + gaze_inside = True # always consider test samples as inside + else: + path = self.X_train.iloc[index] + ( + x_min, + y_min, + x_max, + y_max, + eye_x, + eye_y, + gaze_x, + gaze_y, + inout, + ) = self.y_train.iloc[index] + gaze_inside = bool(inout) + + img = Image.open(osp.join(self.data_dir, path)) + img = img.convert("RGB") + head_mask = Image.open(osp.join(self.head_dir, path)) + width, height = img.size + x_min, y_min, x_max, y_max = 
map(float, [x_min, y_min, x_max, y_max]) + if x_max < x_min: + x_min, x_max = x_max, x_min + if y_max < y_min: + y_min, y_max = y_max, y_min + # expand face bbox a bit + k = 0.1 + x_min = max(x_min - k * abs(x_max - x_min), 0) + y_min = max(y_min - k * abs(y_max - y_min), 0) + x_max = min(x_max + k * abs(x_max - x_min), width - 1) + y_max = min(y_max + k * abs(y_max - y_min), height - 1) + + if self.is_train: + img, bbox, gaze, head_mask, size = self.augment( + img, + (x_min, y_min, x_max, y_max), + (gaze_x, gaze_y), + head_mask, + (width, height), + ) + x_min, y_min, x_max, y_max = bbox + gaze_x, gaze_y = gaze + width, height = size + + head_channel = utils.get_head_box_channel( + x_min, + y_min, + x_max, + y_max, + width, + height, + resolution=self.input_size, + coordconv=False, + ).unsqueeze(0) + + if self.is_train and self.mask_generator is not None: + image_mask = self.mask_generator( + x_min / width, + y_min / height, + x_max / width, + y_max / height, + head_channel, + ) + + if self.transform is not None: + img = self.transform(img) + head_mask = TF.to_tensor( + TF.resize(head_mask, (self.input_size, self.input_size)) + ) + + # generate the heat map used for deconv prediction + gaze_heatmap = torch.zeros( + self.output_size, self.output_size + ) # set the size of the output + if not self.is_train: # aggregated heatmap + num_valid = 0 + for gaze_x, gaze_y in cont_gaze: + if gaze_x != -1: + num_valid += 1 + gaze_heatmap += self.draw_labelmap( + torch.zeros(self.output_size, self.output_size), + [gaze_x * self.output_size, gaze_y * self.output_size], + 3, + type="Gaussian", + ) + gaze_heatmap /= num_valid + else: + # if gaze_inside: + gaze_heatmap = self.draw_labelmap( + gaze_heatmap, + [gaze_x * self.output_size, gaze_y * self.output_size], + 3, + type="Gaussian", + ) + + imsize = torch.IntTensor([width, height]) + + if self.is_train: + out_dict = { + "images": img, + "head_channels": head_channel, + "heatmaps": gaze_heatmap, + "gazes": torch.FloatTensor([gaze_x, gaze_y]), + "gaze_inouts": torch.FloatTensor([gaze_inside]), + "head_masks": head_mask, + "imsize": imsize, + } + if self.mask_generator is not None: + out_dict["image_masks"] = image_mask + return out_dict + else: + return { + "images": img, + "head_channels": head_channel, + "heatmaps": gaze_heatmap, + "gazes": cont_gaze, + "gaze_inouts": torch.FloatTensor([gaze_inside]), + "head_masks": head_mask, + "imsize": imsize, + } + + def __len__(self): + return self.length diff --git a/data/masking.py b/data/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..c9736d9f231bec5767fcf2ed502178553ad379a3 --- /dev/null +++ b/data/masking.py @@ -0,0 +1,175 @@ +import random +import math +import numpy as np +import torch +from torch.nn import functional as F + + +class SceneMaskGenerator: + def __init__( + self, + input_size, + min_num_patches=16, + max_num_patches_ratio=0.5, + min_aspect=0.3, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.input_size = input_size + self.num_patches = input_size[0] * input_size[1] + + self.min_num_patches = min_num_patches + self.max_num_patches = max_num_patches_ratio * self.num_patches + + self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect)) + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(4): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = 
int(round(math.sqrt(target_area / aspect_ratio))) + height, width = self.input_size + if w < width and h < height: + top = random.randint(0, height - h) + left = random.randint(0, width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + mask[top : top + h, left : left + w] = 1 + delta = h * w - num_masked + break + return delta + + def __call__(self, head_mask): + mask = np.zeros(shape=self.input_size, dtype=bool) + mask_count = 0 + num_masking_patches = random.uniform(self.min_num_patches, self.max_num_patches) + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + mask = torch.from_numpy(mask).unsqueeze(0) + head_mask = ( + F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5 + ) + return torch.logical_and(mask, head_mask).squeeze(0) + + +class HeadMaskGenerator: + def __init__( + self, + input_size, + min_num_patches=4, + max_num_patches_ratio=0.5, + min_aspect=0.3, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.input_size = input_size + self.num_patches = input_size[0] * input_size[1] + + self.min_num_patches = min_num_patches + self.max_num_patches_ratio = max_num_patches_ratio + + self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect)) + + def __call__( + self, + x_min, + y_min, + x_max, + y_max, # coords in [0,1] + ): + height = math.floor((y_max - y_min) * self.input_size[0]) + width = math.floor((x_max - x_min) * self.input_size[1]) + origin_area = width * height + if origin_area < self.min_num_patches: + return torch.zeros(size=self.input_size, dtype=bool) + + target_area = random.uniform( + self.min_num_patches, self.max_num_patches_ratio * origin_area + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = min(int(round(math.sqrt(target_area * aspect_ratio))), height) + w = min(int(round(math.sqrt(target_area / aspect_ratio))), width) + top = random.randint(0, height - h) + int(y_min * self.input_size[0]) + left = random.randint(0, width - w) + int(x_min * self.input_size[1]) + mask = torch.zeros(size=self.input_size, dtype=bool) + mask[top : top + h, left : left + w] = True + return mask + + +class MaskGenerator: + def __init__( + self, + input_size, + mask_scene: bool = False, + mask_head: bool = False, + min_scene_patches=16, + max_scene_patches_ratio=0.5, + min_head_patches=4, + max_head_patches_ratio=0.5, + min_aspect=0.3, + mask_prob=0.2, + head_prob=0.2, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.input_size = input_size + if mask_scene: + self.scene_mask_generator = SceneMaskGenerator( + input_size, min_scene_patches, max_scene_patches_ratio, min_aspect + ) + else: + self.scene_mask_generator = None + + if mask_head: + self.head_mask_generator = HeadMaskGenerator( + input_size, min_head_patches, max_head_patches_ratio, min_aspect + ) + else: + self.head_mask_generator = None + + self.no_mask = not (mask_scene or mask_head) + self.mask_head = mask_head and not mask_scene + self.mask_scene = mask_scene and not mask_head + self.scene_prob = mask_prob + self.head_prob = head_prob + + def __call__( + self, + x_min, + y_min, + x_max, + y_max, + head_mask, + ): + mask_scene = random.random() < self.scene_prob + mask_head = random.random() < self.head_prob + no_mask = ( + self.no_mask + or (self.mask_head and not mask_head) + or 
(self.mask_scene and not mask_scene) + or not (mask_scene or mask_head) + ) + if no_mask: + return torch.zeros(size=self.input_size, dtype=bool) + if self.mask_scene: + return self.scene_mask_generator(head_mask) + if self.mask_head: + return self.head_mask_generator(x_min, y_min, x_max, y_max) + if mask_head and mask_scene: + return torch.logical_or( + self.scene_mask_generator(head_mask), + self.head_mask_generator(x_min, y_min, x_max, y_max), + ) + elif mask_head: + return self.head_mask_generator(x_min, y_min, x_max, y_max) + return self.scene_mask_generator(head_mask) diff --git a/data/video_attention_target.py b/data/video_attention_target.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4b50133b3fc26d62ee1950802670da5d474817 --- /dev/null +++ b/data/video_attention_target.py @@ -0,0 +1,228 @@ +import glob +from typing import Callable, Optional +from os import path as osp + +import torch +from torch.utils.data.dataset import Dataset +import torchvision.transforms.functional as TF +import numpy as np +import pandas as pd +from PIL import Image + +from . import augmentation +from . import data_utils as utils +from .masking import MaskGenerator + + +class VideoAttentionTarget(Dataset): + def __init__( + self, + image_root: str, + anno_root: str, + head_root: str, + transform: Callable, + input_size: int, + output_size: int, + quant_labelmap: bool = True, + is_train: bool = True, + *, + mask_generator: Optional[MaskGenerator] = None, + bbox_jitter: float = 0.5, + rand_crop: float = 0.5, + rand_flip: float = 0.5, + color_jitter: float = 0.5, + rand_rotate: float = 0.0, + rand_lsj: float = 0.0, + ): + frames = [] + for show_dir in glob.glob(osp.join(anno_root, "*")): + for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")): + df = pd.read_csv( + sequence_path, + header=None, + index_col=False, + names=[ + "path", + "x_min", + "y_min", + "x_max", + "y_max", + "gaze_x", + "gaze_y", + ], + ) + + show_name = sequence_path.split("/")[-3] + clip = sequence_path.split("/")[-2] + df["path"] = df["path"].apply( + lambda path: osp.join(show_name, clip, path) + ) + # Add two columns for the bbox center + df["eye_x"] = (df["x_min"] + df["x_max"]) / 2 + df["eye_y"] = (df["y_min"] + df["y_max"]) / 2 + df = df.sample(frac=0.2, random_state=42) + frames.extend(df.values.tolist()) + + df = pd.DataFrame( + frames, + columns=[ + "path", + "x_min", + "y_min", + "x_max", + "y_max", + "gaze_x", + "gaze_y", + "eye_x", + "eye_y", + ], + ) + # Drop rows with invalid bboxes + coords = torch.tensor( + np.array( + ( + df["x_min"].values, + df["y_min"].values, + df["x_max"].values, + df["y_max"].values, + ) + ).transpose(1, 0) + ) + valid_bboxes = (coords[:, 2:] >= coords[:, :2]).all(dim=1) + df = df.loc[valid_bboxes.tolist(), :] + df.reset_index(inplace=True) + self.df = df + self.length = len(df) + + self.data_dir = image_root + self.head_dir = head_root + self.transform = transform + self.draw_labelmap = ( + utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant + ) + self.is_train = is_train + + self.input_size = input_size + self.output_size = output_size + + if self.is_train: + ## data augmentation + self.augment = augmentation.AugmentationList( + [ + augmentation.ColorJitter(color_jitter), + augmentation.BoxJitter(bbox_jitter), + augmentation.RandomCrop(rand_crop), + augmentation.RandomFlip(rand_flip), + augmentation.RandomRotate(rand_rotate), + augmentation.RandomLSJ(rand_lsj), + ] + ) + + self.mask_generator = mask_generator + + def __getitem__(self, index): 
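+        # Builds one sample and returns a dict with "images", "head_channels",
+        # "heatmaps", "gazes", "gaze_inouts", "head_masks" and "imsize"
+        # (plus "image_masks" when a mask generator is active during training);
+        # see the `out_dict` constructed below.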
+ ( + _, + path, + x_min, + y_min, + x_max, + y_max, + gaze_x, + gaze_y, + eye_x, + eye_y, + ) = self.df.iloc[index] + gaze_inside = gaze_x != -1 or gaze_y != -1 + + img = Image.open(osp.join(self.data_dir, path)) + img = img.convert("RGB") + width, height = img.size + # Since we finetune from weights trained on GazeFollow, + # we don't incorporate the auxiliary task for VAT. + if osp.exists(osp.join(self.head_dir, path)): + head_mask = Image.open(osp.join(self.head_dir, path)).resize( + (width, height) + ) + else: + head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32)) + x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max]) + if x_max < x_min: + x_min, x_max = x_max, x_min + if y_max < y_min: + y_min, y_max = y_max, y_min + gaze_x, gaze_y = gaze_x / width, gaze_y / height + # expand face bbox a bit + k = 0.1 + x_min = max(x_min - k * abs(x_max - x_min), 0) + y_min = max(y_min - k * abs(y_max - y_min), 0) + x_max = min(x_max + k * abs(x_max - x_min), width - 1) + y_max = min(y_max + k * abs(y_max - y_min), height - 1) + + if self.is_train: + img, bbox, gaze, head_mask, size = self.augment( + img, + (x_min, y_min, x_max, y_max), + (gaze_x, gaze_y), + head_mask, + (width, height), + ) + x_min, y_min, x_max, y_max = bbox + gaze_x, gaze_y = gaze + width, height = size + + head_channel = utils.get_head_box_channel( + x_min, + y_min, + x_max, + y_max, + width, + height, + resolution=self.input_size, + coordconv=False, + ).unsqueeze(0) + + if self.is_train and self.mask_generator is not None: + image_mask = self.mask_generator( + x_min / width, + y_min / height, + x_max / width, + y_max / height, + head_channel, + ) + + if self.transform is not None: + img = self.transform(img) + head_mask = TF.to_tensor( + TF.resize(head_mask, (self.input_size, self.input_size)) + ) + + # generate the heat map used for deconv prediction + gaze_heatmap = torch.zeros( + self.output_size, self.output_size + ) # set the size of the output + + gaze_heatmap = self.draw_labelmap( + gaze_heatmap, + [gaze_x * self.output_size, gaze_y * self.output_size], + 3, + type="Gaussian", + ) + + imsize = torch.IntTensor([width, height]) + + out_dict = { + "images": img, + "head_channels": head_channel, + "heatmaps": gaze_heatmap, + "gazes": torch.FloatTensor([gaze_x, gaze_y]), + "gaze_inouts": torch.FloatTensor([gaze_inside]), + "head_masks": head_mask, + "imsize": imsize, + } + if self.is_train and self.mask_generator is not None: + out_dict["image_masks"] = image_mask + return out_dict + + def __len__(self): + return self.length diff --git a/data/video_attention_target_video.py b/data/video_attention_target_video.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a1d4711f91b0a76166abeaefe76feaa6329064 --- /dev/null +++ b/data/video_attention_target_video.py @@ -0,0 +1,464 @@ +import math +from os import path as osp +from typing import Callable, Optional +import glob +import torch +from torch.utils.data.dataset import Dataset +import torchvision.transforms.functional as TF +import numpy as np +from PIL import Image, ImageOps +import pandas as pd +from .masking import MaskGenerator +from . 
import data_utils as utils + + +class VideoAttentionTargetVideo(Dataset): + def __init__( + self, + image_root: str, + anno_root: str, + head_root: str, + transform: Callable, + input_size: int, + output_size: int, + quant_labelmap: bool = True, + is_train: bool = True, + seq_len: int = 8, + max_len: int = 32, + *, + mask_generator: Optional[MaskGenerator] = None, + bbox_jitter: float = 0.5, + rand_crop: float = 0.5, + rand_flip: float = 0.5, + color_jitter: float = 0.5, + rand_rotate: float = 0.0, + rand_lsj: float = 0.0, + ): + dfs = [] + for show_dir in glob.glob(osp.join(anno_root, "*")): + for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")): + df = pd.read_csv( + sequence_path, + header=None, + index_col=False, + names=[ + "path", + "x_min", + "y_min", + "x_max", + "y_max", + "gaze_x", + "gaze_y", + ], + ) + show_name = sequence_path.split("/")[-3] + clip = sequence_path.split("/")[-2] + df["path"] = df["path"].apply( + lambda path: osp.join(show_name, clip, path) + ) + cur_len = len(df.index) + if is_train: + if cur_len <= max_len: + if cur_len >= seq_len: + dfs.append(df) + continue + remainder = cur_len % max_len + df_splits = [ + df[i : i + max_len] + for i in range(0, cur_len - max_len, max_len) + ] + if remainder >= seq_len: + df_splits.append(df[-remainder:]) + dfs.extend(df_splits) + else: + if cur_len < seq_len: + continue + df_splits = [ + df[i : i + seq_len] + for i in range(0, cur_len - seq_len, seq_len) + ] + dfs.extend(df_splits) + + for df in dfs: + df.reset_index(inplace=True) + self.dfs = dfs + self.length = len(dfs) + + self.data_dir = image_root + self.head_dir = head_root + self.transform = transform + self.draw_labelmap = ( + utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant + ) + self.is_train = is_train + + self.input_size = input_size + self.output_size = output_size + self.seq_len = seq_len + + if self.is_train: + self.bbox_jitter = bbox_jitter + self.rand_crop = rand_crop + self.rand_flip = rand_flip + self.color_jitter = color_jitter + self.rand_rotate = rand_rotate + self.rand_lsj = rand_lsj + self.mask_generator = mask_generator + + def __getitem__(self, index): + df = self.dfs[index] + seq_len = len(df.index) + for coord in ["x_min", "y_min", "x_max", "y_max"]: + df[coord] = utils.smooth_by_conv(11, df, coord) + + if self.is_train: + # cond for data augmentation + cond_jitter = np.random.random_sample() + cond_flip = np.random.random_sample() + cond_color = np.random.random_sample() + if cond_color < self.color_jitter: + n1 = np.random.uniform(0.5, 1.5) + n2 = np.random.uniform(0.5, 1.5) + n3 = np.random.uniform(0.5, 1.5) + cond_crop = np.random.random_sample() + cond_rotate = np.random.random_sample() + if cond_rotate < self.rand_rotate: + angle = (2 * np.random.random_sample() - 1) * 20 + angle = -math.radians(angle) + cond_lsj = np.random.random_sample() + if cond_lsj < self.rand_lsj: + lsj_scale = 0.1 + np.random.random_sample() * 0.9 + + # if longer than seq_len_limit, cut it down to the limit with the init index randomly sampled + if seq_len > self.seq_len: + sampled_ind = np.random.randint(0, seq_len - self.seq_len) + seq_len = self.seq_len + else: + sampled_ind = 0 + + if cond_crop < self.rand_crop: + sliced_x_min = df["x_min"].iloc[sampled_ind : sampled_ind + seq_len] + sliced_x_max = df["x_max"].iloc[sampled_ind : sampled_ind + seq_len] + sliced_y_min = df["y_min"].iloc[sampled_ind : sampled_ind + seq_len] + sliced_y_max = df["y_max"].iloc[sampled_ind : sampled_ind + seq_len] + + sliced_gaze_x = 
df["gaze_x"].iloc[sampled_ind : sampled_ind + seq_len] + sliced_gaze_y = df["gaze_y"].iloc[sampled_ind : sampled_ind + seq_len] + + check_sum = sliced_gaze_x.sum() + sliced_gaze_y.sum() + all_outside = check_sum == -2 * seq_len + + # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target + if all_outside: + crop_x_min = np.min([sliced_x_min.min(), sliced_x_max.min()]) + crop_y_min = np.min([sliced_y_min.min(), sliced_y_max.min()]) + crop_x_max = np.max([sliced_x_min.max(), sliced_x_max.max()]) + crop_y_max = np.max([sliced_y_min.max(), sliced_y_max.max()]) + else: + crop_x_min = np.min( + [sliced_gaze_x.min(), sliced_x_min.min(), sliced_x_max.min()] + ) + crop_y_min = np.min( + [sliced_gaze_y.min(), sliced_y_min.min(), sliced_y_max.min()] + ) + crop_x_max = np.max( + [sliced_gaze_x.max(), sliced_x_min.max(), sliced_x_max.max()] + ) + crop_y_max = np.max( + [sliced_gaze_y.max(), sliced_y_min.max(), sliced_y_max.max()] + ) + + # Randomly select a random top left corner + if crop_x_min >= 0: + crop_x_min = np.random.uniform(0, crop_x_min) + if crop_y_min >= 0: + crop_y_min = np.random.uniform(0, crop_y_min) + + # Get image size + path = osp.join(self.data_dir, df["path"].iloc[0]) + img = Image.open(path) + img = img.convert("RGB") + width, height = img.size + + # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min) + crop_width_min = crop_x_max - crop_x_min + crop_height_min = crop_y_max - crop_y_min + crop_width_max = width - crop_x_min + crop_height_max = height - crop_y_min + # Randomly select a width and a height + crop_width = np.random.uniform(crop_width_min, crop_width_max) + crop_height = np.random.uniform(crop_height_min, crop_height_max) + + # Round to integers + crop_y_min, crop_x_min, crop_height, crop_width = map( + int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width)) + ) + else: + sampled_ind = 0 + + images = [] + head_channels = [] + heatmaps = [] + gazes = [] + gaze_inouts = [] + imsizes = [] + head_masks = [] + if self.is_train and self.mask_generator is not None: + image_masks = [] + for i, row in df.iterrows(): + if self.is_train and (i < sampled_ind or i >= (sampled_ind + self.seq_len)): + continue + + x_min = row["x_min"] # note: Already in image coordinates + y_min = row["y_min"] # note: Already in image coordinates + x_max = row["x_max"] # note: Already in image coordinates + y_max = row["y_max"] # note: Already in image coordinates + gaze_x = row["gaze_x"] # note: Already in image coordinates + gaze_y = row["gaze_y"] # note: Already in image coordinates + + if x_min > x_max: + x_min, x_max = x_max, x_min + if y_min > y_max: + y_min, y_max = y_max, y_min + + path = row["path"] + img = Image.open(osp.join(self.data_dir, path)).convert("RGB") + width, height = img.size + imsize = torch.FloatTensor([width, height]) + imsizes.append(imsize) + # Since we finetune from weights trained on GazeFollow, + # we don't incorporate the auxiliary task for VAT. 
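+            # Fall back to an all-zero head mask when no pre-computed mask exists for this frame.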
+ if osp.exists(osp.join(self.head_dir, path)): + head_mask = Image.open(osp.join(self.head_dir, path)).resize( + (width, height) + ) + else: + head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32)) + + x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max]) + gaze_x, gaze_y = map(float, [gaze_x, gaze_y]) + if gaze_x == -1 and gaze_y == -1: + gaze_inside = False + else: + if ( + gaze_x < 0 + ): # move gaze point that was sliglty outside the image back in + gaze_x = 0 + if gaze_y < 0: + gaze_y = 0 + gaze_inside = True + + if self.is_train: + ## data augmentation + # Jitter (expansion-only) bounding box size. + if cond_jitter < self.bbox_jitter: + k = cond_jitter * 0.1 + x_min -= k * abs(x_max - x_min) + y_min -= k * abs(y_max - y_min) + x_max += k * abs(x_max - x_min) + y_max += k * abs(y_max - y_min) + x_min = np.clip(x_min, 0, width - 1) + x_max = np.clip(x_max, 0, width - 1) + y_min = np.clip(y_min, 0, height - 1) + y_max = np.clip(y_max, 0, height - 1) + + # Random color change + if cond_color < self.color_jitter: + img = TF.adjust_brightness(img, brightness_factor=n1) + img = TF.adjust_contrast(img, contrast_factor=n2) + img = TF.adjust_saturation(img, saturation_factor=n3) + + # Random Crop + if cond_crop < self.rand_crop: + # Crop it + img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width) + head_mask = TF.crop( + head_mask, crop_y_min, crop_x_min, crop_height, crop_width + ) + + # Record the crop's (x, y) offset + offset_x, offset_y = crop_x_min, crop_y_min + + # convert coordinates into the cropped frame + x_min, y_min, x_max, y_max = ( + x_min - offset_x, + y_min - offset_y, + x_max - offset_x, + y_max - offset_y, + ) + if gaze_inside: + gaze_x, gaze_y = (gaze_x - offset_x), (gaze_y - offset_y) + else: + gaze_x = -1 + gaze_y = -1 + + width, height = crop_width, crop_height + + # Flip? 
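+                # Horizontal flip: mirror the image, head mask and bbox x-coordinates,
+                # and reflect the gaze x-coordinate if the target is inside the frame.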
+ if cond_flip < self.rand_flip: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT) + x_max_2 = width - x_min + x_min_2 = width - x_max + x_max = x_max_2 + x_min = x_min_2 + if gaze_x != -1 and gaze_y != -1: + gaze_x = width - gaze_x + + # Random Rotation + if cond_rotate < self.rand_rotate: + rot_mat = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def _transform(x, y, matrix): + return ( + matrix[0] * x + matrix[1] * y + matrix[2], + matrix[3] * x + matrix[4] * y + matrix[5], + ) + + def _inv_transform(x, y, matrix): + x, y = x - matrix[2], y - matrix[5] + return ( + matrix[0] * x + matrix[3] * y, + matrix[1] * x + matrix[4] * y, + ) + + # Calculate offsets + rot_center = (width / 2.0, height / 2.0) + rot_mat[2], rot_mat[5] = _transform( + -rot_center[0], -rot_center[1], rot_mat + ) + rot_mat[2] += rot_center[0] + rot_mat[5] += rot_center[1] + xx = [] + yy = [] + for x, y in ((0, 0), (width, 0), (width, height), (0, height)): + x, y = _transform(x, y, rot_mat) + xx.append(x) + yy.append(y) + nw = math.ceil(max(xx)) - math.floor(min(xx)) + nh = math.ceil(max(yy)) - math.floor(min(yy)) + rot_mat[2], rot_mat[5] = _transform( + -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat + ) + + img = img.transform((nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR) + head_mask = head_mask.transform( + (nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR + ) + + xx = [] + yy = [] + for x, y in ( + (x_min, y_min), + (x_min, y_max), + (x_max, y_min), + (x_max, y_max), + ): + x, y = _inv_transform(x, y, rot_mat) + xx.append(x) + yy.append(y) + x_max, x_min = min(max(xx), nw), max(min(xx), 0) + y_max, y_min = min(max(yy), nh), max(min(yy), 0) + gaze_x, gaze_y = _inv_transform(gaze_x, gaze_y, rot_mat) + width, height = nw, nh + + if cond_lsj < self.rand_lsj: + nh, nw = int(height * lsj_scale), int(width * lsj_scale) + img = TF.resize(img, (nh, nw)) + img = ImageOps.expand(img, (0, 0, width - nw, height - nh)) + head_mask = TF.resize(head_mask, (nh, nw)) + head_mask = ImageOps.expand( + head_mask, (0, 0, width - nw, height - nh) + ) + x_min, y_min, x_max, y_max = ( + x_min * lsj_scale, + y_min * lsj_scale, + x_max * lsj_scale, + y_max * lsj_scale, + ) + gaze_x, gaze_y = gaze_x * lsj_scale, gaze_y * lsj_scale + + head_channel = utils.get_head_box_channel( + x_min, + y_min, + x_max, + y_max, + width, + height, + resolution=self.input_size, + coordconv=False, + ).unsqueeze(0) + + if self.is_train and self.mask_generator is not None: + image_mask = self.mask_generator( + x_min / width, + y_min / height, + x_max / width, + y_max / height, + head_channel, + ) + image_masks.append(image_mask) + + if self.transform is not None: + img = self.transform(img) + head_mask = TF.to_tensor( + TF.resize(head_mask, (self.input_size, self.input_size)) + ) + + if gaze_inside: + gaze_x /= float(width) # fractional gaze + gaze_y /= float(height) + gaze_heatmap = torch.zeros( + self.output_size, self.output_size + ) # set the size of the output + gaze_map = self.draw_labelmap( + gaze_heatmap, + [gaze_x * self.output_size, gaze_y * self.output_size], + 3, + type="Gaussian", + ) + gazes.append(torch.FloatTensor([gaze_x, gaze_y])) + else: + gaze_map = torch.zeros(self.output_size, self.output_size) + gazes.append(torch.FloatTensor([-1, -1])) + images.append(img) + head_channels.append(head_channel) + head_masks.append(head_mask) + heatmaps.append(gaze_map) + 
gaze_inouts.append(torch.FloatTensor([int(gaze_inside)]))
+
+        images = torch.stack(images)
+        head_channels = torch.stack(head_channels)
+        heatmaps = torch.stack(heatmaps)
+        gazes = torch.stack(gazes)
+        gaze_inouts = torch.stack(gaze_inouts)
+        head_masks = torch.stack(head_masks)
+        imsizes = torch.stack(imsizes)
+
+        out_dict = {
+            "images": images,
+            "head_channels": head_channels,
+            "heatmaps": heatmaps,
+            "gazes": gazes,
+            "gaze_inouts": gaze_inouts,
+            "head_masks": head_masks,
+            "imsize": imsizes,
+        }
+        if self.is_train and self.mask_generator is not None:
+            out_dict["image_masks"] = torch.stack(image_masks)
+        return out_dict
+
+    def __len__(self):
+        return self.length
+
+
+def video_collate(batch):
+    keys = batch[0].keys()
+    return {key: torch.cat([item[key] for item in batch]) for key in keys}
diff --git a/docs/eval.md b/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b8f8b004a416f4ddeaca29f4248b1d2534df511
--- /dev/null
+++ b/docs/eval.md
@@ -0,0 +1,24 @@
+## Eval
+### Testing Dataset
+
+You should prepare GazeFollow and VideoAttentionTarget for evaluation.
+
+* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
+* If training with the auxiliary regression, use `scripts/gen_gazefollow_head_masks.py` to generate head masks.
+* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
+
+Modify DATA_ROOT in `ViTGaze/configs/common/dataloader` to point to your dataset paths.
+
+### Evaluation
+
+Run
+```
+bash val.sh configs/gazefollow_518.py ${Path2checkpoint} gf
+```
+to evaluate on GazeFollow.
+
+Run
+```
+bash val.sh configs/videoattentiontarget.py ${Path2checkpoint} vat
+```
+to evaluate on VideoAttentionTarget.
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..1e3215de27b62eed536398db1a522542bc603252
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,19 @@
+## Installation
+
+* Create a conda virtual environment and activate it.
+
+  ```
+  conda create -n ViTGaze python==3.9.18
+  conda activate ViTGaze
+  ```
+* Install the required packages.
+
+  ```
+  cd path/to/ViTGaze
+  pip install -r requirements.txt
+  ```
+* Install [detectron2](https://github.com/facebookresearch/detectron2) following its [documentation](https://detectron2.readthedocs.io/en/latest/).
+  For ViTGaze, we recommend building it from the latest source code.
+  ```
+  python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+  ```
diff --git a/docs/train.md b/docs/train.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad0fcc89ba3e2c51fb6153d912b53a73ec2d60f4
--- /dev/null
+++ b/docs/train.md
@@ -0,0 +1,36 @@
+## Train
+
+### Training Dataset
+
+You should prepare GazeFollow and VideoAttentionTarget for training.
+
+* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
+* If training with the auxiliary regression, use `scripts/gen_gazefollow_head_masks.py` to generate head masks.
+* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
+
+Modify DATA_ROOT in `ViTGaze/configs/common/dataloader` to point to your dataset paths.
+
+### Pretrained Model
+
+* Get the [DINOv2](https://github.com/facebookresearch/dinov2) pretrained ViT-S.
+* Or you could download and preprocess pretrained weights by + + ``` + cd ViTGaze + mkdir pretrained && cd pretrained + wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth + ``` +* Preprocess the model weights with `scripts\convert_pth.py` to fit Detectron2 format. +### Train ViTGaze + +You can modify configs in `configs/gazefollow.py`, `configs/gazefollow_518.py` and `configs/videoattentiontarget.py`. + +Run: + +``` + bash train.sh +``` + +to train ViTGaze on the two datasets. + +Training output will be saved in `ViTGaze/output/`. diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3015fc1f3dbdb29d72faa3a0f0d90e5f2d7d472b --- /dev/null +++ b/engine/__init__.py @@ -0,0 +1 @@ +from .trainer import CycleTrainer diff --git a/engine/trainer.py b/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a49bb68fb7cabd1285a97037149d0ebf2b6c6f37 --- /dev/null +++ b/engine/trainer.py @@ -0,0 +1,37 @@ +import time +import torch +from detectron2.engine import SimpleTrainer +from typing import Iterable, Generator + + +def cycle(iterable: Iterable) -> Generator: + while True: + for item in iterable: + yield item + + +class CycleTrainer(SimpleTrainer): + def __init__( + self, + model, + data_loader, + optimizer, + gather_metric_period=1, + zero_grad_before_forward=False, + async_write_metrics=False, + ): + super().__init__( + model, + data_loader, + optimizer, + gather_metric_period, + zero_grad_before_forward, + async_write_metrics, + ) + + @property + def _data_loader_iter(self): + # only create the data loader iterator when it is used + if self._data_loader_iter_obj is None: + self._data_loader_iter_obj = cycle(self.data_loader) + return self._data_loader_iter_obj diff --git a/modeling/_init__.py b/modeling/_init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5759a489d356da9431693802cd1433071e1fc037 --- /dev/null +++ b/modeling/_init__.py @@ -0,0 +1,4 @@ +from . import backbone, patch_attention, head, criterion, meta_arch + + +__all__ = ["backbone", "patch_attention", "head", "criterion", "meta_arch"] diff --git a/modeling/backbone/__init__.py b/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5e06e34b56b5e66cf8ade00f92e1271d479629 --- /dev/null +++ b/modeling/backbone/__init__.py @@ -0,0 +1 @@ +from .vit import build_backbone diff --git a/modeling/backbone/utils.py b/modeling/backbone/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b779d27975791e7bb37990b227f2fd34ff0cf945 --- /dev/null +++ b/modeling/backbone/utils.py @@ -0,0 +1,154 @@ +from functools import partial +from itertools import repeat +from typing import Iterable +import math +import torch.nn as nn +import torch.nn.functional as F + + +__all__ = [ + "get_abs_pos", + "PatchEmbed", + "Mlp", + "DropPath", +] + + +def to_2tuple(x): + if isinstance(x, Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, 2)) + + +def get_abs_pos(abs_pos, has_cls_token, hw): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. 
+ + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h, w = hw + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size=(16, 16), + stride=(16, 16), + padding=(0, 0), + in_chans=3, + embed_dim=768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + ) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/modeling/backbone/vit.py b/modeling/backbone/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..78f45a51362aa91eae475ec26a15346b5c6c08bc --- /dev/null +++ b/modeling/backbone/vit.py @@ -0,0 +1,504 @@ +import logging +from typing import Literal, Union +from functools import partial +import torch +import torch.nn as nn +from detectron2.modeling import Backbone + +try: + from xformers.ops import 
memory_efficient_attention + + XFORMERS_ON = True +except ImportError: + XFORMERS_ON = False +from .utils import ( + PatchEmbed, + get_abs_pos, + DropPath, + Mlp, +) + + +logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + return_softmax_attn=True, + use_proj=True, + patch_token_offset=0, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) if use_proj else nn.Identity() + + self.return_softmax_attn = return_softmax_attn + + self.patch_token_offset = patch_token_offset + + def forward(self, x, return_attention=False, extra_token_offset=None): + B, L, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, L, -1).unbind(0) + + if return_attention or not XFORMERS_ON: + attn = (q * self.scale) @ k.transpose(-2, -1) + if return_attention and not self.return_softmax_attn: + out_attn = attn + attn = attn.softmax(dim=-1) + if return_attention and self.return_softmax_attn: + out_attn = attn + x = attn @ v + else: + x = memory_efficient_attention(q, k, v, scale=self.scale) + + x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1) + x = self.proj(x) + + if return_attention: + out_attn = out_attn.reshape(B, self.num_heads, L, -1) + out_attn = out_attn[ + :, + :, + self.patch_token_offset : extra_token_offset, + self.patch_token_offset : extra_token_offset, + ] + return x, out_attn + else: + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + init_values=None, + return_softmax_attn=True, + attention_map_only=False, + patch_token_offset=0, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. 
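+            init_values (float or None): initial LayerScale value; None disables LayerScale.
+            return_softmax_attn (bool): if True, returned attention maps are post-softmax;
+                otherwise the raw attention logits are returned.
+            attention_map_only (bool): if True, build only the attention branch (no MLP/residual);
+                used for the last block when only attention maps are needed.
+            patch_token_offset (int): index of the first patch token, used to crop returned
+                attention maps to patch-to-patch interactions.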
+ """ + super().__init__() + self.attention_map_only = attention_map_only + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + return_softmax_attn=return_softmax_attn, + use_proj=return_softmax_attn or not attention_map_only, + patch_token_offset=patch_token_offset, + ) + + if attention_map_only: + return + + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + + def forward(self, x, return_attention=False, extra_token_offset=None): + shortcut = x + x = self.norm1(x) + + if return_attention: + x, attn = self.attn(x, True, extra_token_offset) + else: + x = self.attn(x) + + if self.attention_map_only: + return x, attn + + x = shortcut + self.drop_path(self.ls1(x)) + x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x)))) + + if return_attention: + return x, attn + else: + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, torch.Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class ViT(Backbone): + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + pretrain_img_size=224, + pretrain_use_cls_token=True, + init_values=None, + use_cls_token=False, + use_mask_token=False, + norm_features=False, + return_softmax_attn=True, + num_register_tokens=0, + num_msg_tokens=0, + register_as_msg=False, + shift_strides=None, # [1, -1, 2, -2], + cls_shift=False, + num_extra_tokens=4, + use_extra_embed=False, + num_frames=None, + out_feature=True, + out_attn=(), + ): + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + # Initialize absolute positional embedding with pretrain image size. 
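+        # During forward, get_abs_pos() resizes these embeddings (bicubic) to match the actual token grid.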
+ num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size + ) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + + self.use_cls_token = use_cls_token + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_cls_token else None + ) + + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens > 0 + else None + ) + + # We tried to leverage temporal information with TeViT while it doesn't work + assert num_msg_tokens >= 0 + self.num_msg_tokens = num_msg_tokens + if register_as_msg: + self.num_msg_tokens += num_register_tokens + self.msg_tokens = ( + nn.Parameter(torch.zeros(1, num_msg_tokens, embed_dim)) + if num_msg_tokens > 0 + else None + ) + + patch_token_offset = ( + num_msg_tokens + num_register_tokens + int(self.use_cls_token) + ) + self.patch_token_offset = patch_token_offset + + self.msg_shift = None + if shift_strides is not None: + self.msg_shift = [] + for i in range(depth): + if i % 2 == 0: + self.msg_shift.append([_ for _ in shift_strides]) + else: + self.msg_shift.append([-_ for _ in shift_strides]) + + self.cls_shift = None + if cls_shift: + self.cls_shift = [(-1) ** idx for idx in range(depth)] + + assert num_extra_tokens >= 0 + self.num_extra_tokens = num_extra_tokens + self.extra_pos_embed = ( + nn.Linear(embed_dim, embed_dim) + if num_extra_tokens > 0 and use_extra_embed + else nn.Identity() + ) + + self.num_frames = num_frames + + # Mask token for masking augmentation + self.use_mask_token = use_mask_token + self.mask_token = ( + nn.Parameter(torch.zeros(1, embed_dim)) if use_mask_token else None + ) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + init_values=init_values, + return_softmax_attn=return_softmax_attn, + attention_map_only=(i == depth - 1) and not out_feature, + patch_token_offset=patch_token_offset, + ) + self.blocks.append(block) + + self.norm = norm_layer(embed_dim) if norm_features else nn.Identity() + + self._out_features = out_feature + self._out_attn = out_attn + + if self.pos_embed is not None: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, masks=None, guidance=None): + x = self.patch_embed(x) + if masks is not None: + x = torch.where( + masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x + ) + + if self.pos_embed is not None: + x = x + get_abs_pos( + self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2]) + ) + B, H, W, _ = x.shape + x = x.reshape(B, H * W, -1) + + if self.use_cls_token: + cls_tokens = self.cls_token.expand(len(x), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + if self.register_tokens is not None: + register_tokens = self.register_tokens.expand(len(x), -1, -1) + x = torch.cat((register_tokens, x), dim=1) + + if self.msg_tokens is not None: 
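+            # Prepend temporal message tokens (TeViT-style); disabled by default (num_msg_tokens=0).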
+ msg_tokens = self.msg_tokens.expand(len(x), -1, -1) + x = torch.cat((msg_tokens, x), dim=1) + # [MSG, REG, CLS, PAT] + + extra_tokens_offset = None + if guidance is not None: + guidance = guidance.reshape(len(guidance), -1, 1) + extra_tokens = ( + (x[:, self.patch_token_offset :] * guidance) + .sum(dim=1, keepdim=True) + .expand(-1, self.num_extra_tokens, -1) + ) + extra_tokens = self.extra_pos_embed(extra_tokens) + x = torch.cat((x, extra_tokens), dim=1) + extra_tokens_offset = -self.num_extra_tokens + # [MSG, REG, CLS, PAT, EXT] + + attn_maps = [] + for idx, blk in enumerate(self.blocks): + if idx in self._out_attn: + x, attn = blk(x, True, extra_tokens_offset) + attn_maps.append(attn) + else: + x = blk(x) + + if self.msg_shift is not None: + msg_shift = self.msg_shift[idx] + msg_tokens = ( + x[:, : self.num_msg_tokens] + if guidance is None + else x[:, extra_tokens_offset:] + ) + msg_tokens = msg_tokens.reshape( + -1, self.num_frames, *msg_tokens.shape[1:] + ) + msg_tokens = msg_tokens.chunk(len(msg_shift), dim=2) + msg_tokens = [ + torch.roll(tokens, roll, dims=1) + for tokens, roll in zip(msg_tokens, msg_shift) + ] + msg_tokens = torch.cat(msg_tokens, dim=2).flatten(0, 1) + if guidance is None: + x = torch.cat([msg_tokens, x[:, self.num_msg_tokens :]], dim=1) + else: + x = torch.cat([x[:, :extra_tokens_offset], msg_tokens], dim=1) + + if self.cls_shift is not None: + cls_tokens = x[:, self.patch_token_offset - 1] + cls_tokens = cls_tokens.reshape( + -1, self.num_frames, 1, *cls_tokens.shape[1:] + ) + cls_tokens = torch.roll(cls_tokens, self.cls_shift[idx], dims=1) + x = torch.cat( + [ + x[:, : self.patch_token_offset - 1], + cls_tokens.flatten(0, 1), + x[:, self.patch_token_offset :], + ], + dim=1, + ) + + x = self.norm(x) + + outputs = {} + outputs["attention_maps"] = torch.cat(attn_maps, dim=1).reshape( + B, -1, H * W, H, W + ) + if self._out_features: + outputs["last_feat"] = ( + x[:, self.patch_token_offset : extra_tokens_offset] + .reshape(B, H, W, -1) + .permute(0, 3, 1, 2) + ) + + return outputs + + +def vit_tiny(**kwargs): + model = ViT( + patch_size=16, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_small(**kwargs): + model = ViT( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def vit_base(**kwargs): + model = ViT( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs + ) + return model + + +def dinov2_base(**kwargs): + model = ViT( + patch_size=14, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + pretrain_img_size=518, + init_values=1, + **kwargs + ) + return model + + +def dinov2_small(**kwargs): + model = ViT( + patch_size=14, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + pretrain_img_size=518, + init_values=1, + **kwargs + ) + return model + + +def build_backbone( + name: Literal["tiny", "small", "base", "dinov2_base", "dinov2_small"], **kwargs +): + vit_dict = { + "tiny": vit_tiny, + "small": vit_small, + "base": vit_base, + "dinov2_base": dinov2_base, + "dinov2_small": dinov2_small, + } + return vit_dict[name](**kwargs) diff --git a/modeling/criterion/__init__.py 
b/modeling/criterion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c63773c315c423290e8c2e30a3351fce5cb5b5f --- /dev/null +++ b/modeling/criterion/__init__.py @@ -0,0 +1 @@ +from .gaze_mapper_criterion import GazeMapperCriterion diff --git a/modeling/criterion/gaze_mapper_criterion.py b/modeling/criterion/gaze_mapper_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..472b0334cba6627d134671626711a7b3f66ea93c --- /dev/null +++ b/modeling/criterion/gaze_mapper_criterion.py @@ -0,0 +1,94 @@ +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from fvcore.nn import sigmoid_focal_loss_jit + + +class GazeMapperCriterion(nn.Module): + def __init__( + self, + heatmap_weight: float = 10000, + inout_weight: float = 100, + aux_weight: float = 100, + use_aux_loss: bool = False, + aux_head_thres: float = 0, + use_focal_loss: bool = False, + alpha: float = -1, + gamma: float = 2, + ): + super().__init__() + self.heatmap_weight = heatmap_weight + self.inout_weight = inout_weight + self.aux_weight = aux_weight + self.aux_head_thres = aux_head_thres + + self.heatmap_loss = nn.MSELoss(reduce=False) + + if use_focal_loss: + self.inout_loss = partial( + sigmoid_focal_loss_jit, alpha=alpha, gamma=gamma, reduction="mean" + ) + else: + self.inout_loss = nn.BCEWithLogitsLoss() + + if use_aux_loss: + self.aux_loss = nn.BCEWithLogitsLoss() + else: + self.aux_loss = None + + def forward( + self, + pred_heatmap, + pred_inout, + gt_heatmap, + gt_inout, + pred_head_masks=None, + gt_head_masks=None, + ): + loss_dict = {} + + pred_heatmap = F.interpolate( + pred_heatmap, + size=tuple(gt_heatmap.shape[-2:]), + mode="bilinear", + align_corners=True, + ) + heatmap_loss = ( + self.heatmap_loss(pred_heatmap.squeeze(1), gt_heatmap) * self.heatmap_weight + ) + heatmap_loss = torch.mean(heatmap_loss, dim=(-2, -1)) + heatmap_loss = torch.sum(heatmap_loss.reshape(-1) * gt_inout.reshape(-1)) + # Check whether all outside, avoid 0/0 to be nan + if heatmap_loss > 1e-7: + heatmap_loss = heatmap_loss / torch.sum(gt_inout) + loss_dict["regression loss"] = heatmap_loss + else: + loss_dict["regression loss"] = heatmap_loss * 0 + + inout_loss = ( + self.inout_loss(pred_inout.reshape(-1), gt_inout.reshape(-1)) + * self.inout_weight + ) + loss_dict["classification loss"] = inout_loss + + if self.aux_loss is not None: + pred_head_masks = F.interpolate( + pred_head_masks, + size=tuple(gt_head_masks.shape[-2:]), + mode="bilinear", + align_corners=True, + ) + aux_loss = ( + torch.clamp( + self.aux_loss( + pred_head_masks.reshape(-1), gt_head_masks.reshape(-1) + ) + - self.aux_head_thres, + 0, + ) + * self.aux_weight + ) + loss_dict["aux head loss"] = aux_loss + + return loss_dict diff --git a/modeling/head/__init__.py b/modeling/head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a67fa906dec39e14d7107607d45adf42fe9bdf0 --- /dev/null +++ b/modeling/head/__init__.py @@ -0,0 +1,2 @@ +from .inout_head import build_inout_head +from .heatmap_head import build_heatmap_head diff --git a/modeling/head/heatmap_head.py b/modeling/head/heatmap_head.py new file mode 100644 index 0000000000000000000000000000000000000000..60f464375e614a5c0f60de57b22eef107e9d72aa --- /dev/null +++ b/modeling/head/heatmap_head.py @@ -0,0 +1,225 @@ +import torch +from torch import nn +from detectron2.utils.registry import Registry +from typing import Literal, List, Dict, Optional, OrderedDict + + +HEATMAP_HEAD_REGISTRY = 
Registry("HEATMAP_HEAD_REGISTRY") +HEATMAP_HEAD_REGISTRY.__doc__ = "Registry for heatmap head" + + +class BaseHeatmapHead(nn.Module): + def __init__( + self, + in_channel: int, + deconv_cfgs: List[Dict], + dim: int = 96, + use_conv: bool = False, + use_residual: bool = False, + feat_type: Literal["attn", "both"] = "both", + attn_layer: Optional[str] = None, + pre_norm: bool = False, + use_head: bool = False, + ) -> None: + super().__init__() + self.feat_type = feat_type + self.use_head = use_head + + if pre_norm: + self.pre_norm = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(in_channel)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + else: + self.pre_norm = nn.Identity() + + if use_conv: + if use_residual: + from timm.models.resnet import Bottleneck, downsample_conv + + self.conv = Bottleneck( + in_channel, + dim // 4, + downsample=downsample_conv(in_channel, dim, 1) + if in_channel != dim + else None, + attn_layer=attn_layer, + ) + else: + self.conv = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channel, dim, 3, 1, 1)), + ("bn", nn.BatchNorm2d(dim)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + else: + self.conv = nn.Identity() + + self.decoder: nn.Module = None + + def get_feat(self, x): + if self.feat_type == "attn": + feat = x["attention_maps"] + elif self.feat_type == "feat": + feat = x["last_feat"] + return feat + + def forward(self, x): + feat = self.get_feat(x) + feat = self.pre_norm(feat) + feat = self.conv(feat) + return self.decoder(feat) + + +@HEATMAP_HEAD_REGISTRY.register() +class SimpleDeconv(BaseHeatmapHead): + def __init__( + self, + in_channel: int, + deconv_cfgs: List[Dict], + dim: int = 96, + use_conv: bool = False, + use_residual: bool = False, + feat_type: Literal["attn", "both"] = "both", + attn_layer: Optional[str] = None, + pre_norm: bool = False, + use_head: bool = False, + ) -> None: + super().__init__( + in_channel, + deconv_cfgs, + dim, + use_conv, + use_residual, + feat_type, + attn_layer, + pre_norm, + use_head, + ) + decoder_layers = [] + for i, deconv_cfg in enumerate(deconv_cfgs, start=1): + decoder_layers.extend( + [ + ( + "".join(["deconv", str(i)]), + nn.ConvTranspose2d(**deconv_cfg), + ), + ( + "".join(["bn", str(i)]), + nn.BatchNorm2d(deconv_cfg["out_channels"]), + ), + ("".join(["relu", str(i)]), nn.ReLU(inplace=True)), + ] + ) + decoder_layers.append(("conv", nn.Conv2d(1, 1, 1))) + self.decoder = nn.Sequential(OrderedDict(decoder_layers)) + + +@HEATMAP_HEAD_REGISTRY.register() +class UpSampleConv(BaseHeatmapHead): + def __init__( + self, + in_channel: int, + deconv_cfgs: List[Dict], + dim: int = 96, + use_conv: bool = False, + use_residual: bool = False, + feat_type: Literal["attn", "both"] = "both", + attn_layer: Optional[str] = None, + pre_norm: bool = False, + use_head: bool = False, + ) -> None: + super().__init__( + in_channel, + deconv_cfgs, + dim, + use_conv, + use_residual, + feat_type, + attn_layer, + pre_norm, + use_head, + ) + decoder_layers = [] + for i, deconv_cfg in enumerate(deconv_cfgs, start=1): + decoder_layers.extend( + [ + ( + "".join(["upsample", str(i)]), + nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True + ), + ), + ( + "".join(["conv", str(i)]), + nn.Conv2d(**deconv_cfg), + ), + ( + "".join(["bn", str(i)]), + nn.BatchNorm2d(deconv_cfg["out_channels"]), + ), + ("".join(["relu", str(i)]), nn.ReLU(inplace=True)), + ] + ) + decoder_layers.append(("conv", nn.Conv2d(1, 1, 1))) + self.decoder = nn.Sequential(OrderedDict(decoder_layers)) + + +@HEATMAP_HEAD_REGISTRY.register() +class 
PixelShuffle(BaseHeatmapHead): + def __init__( + self, + in_channel: int, + deconv_cfgs: List[Dict], + dim: int = 96, + use_conv: bool = False, + use_residual: bool = False, + feat_type: Literal["attn", "both"] = "both", + attn_layer: Optional[str] = None, + pre_norm: bool = False, + use_head: bool = False, + ) -> None: + super().__init__( + in_channel, + deconv_cfgs, + dim, + use_conv, + use_residual, + feat_type, + attn_layer, + pre_norm, + use_head, + ) + decoder_layers = [] + for i, deconv_cfg in enumerate(deconv_cfgs, start=1): + deconv_cfg["out_channels"] = deconv_cfg["out_channels"] * 4 + decoder_layers.extend( + [ + ( + "".join(["conv", str(i)]), + nn.Conv2d(**deconv_cfg), + ), + ( + "".join(["pixel_shuffle", str(i)]), + nn.PixelShuffle(upscale_factor=2), + ), + ( + "".join(["bn", str(i)]), + nn.BatchNorm2d(deconv_cfg["out_channels"] // 4), + ), + ("".join(["relu", str(i)]), nn.ReLU(inplace=True)), + ] + ) + decoder_layers.append(("conv", nn.Conv2d(1, 1, 1))) + self.decoder = nn.Sequential(OrderedDict(decoder_layers)) + + +def build_heatmap_head(name, *args, **kwargs): + return HEATMAP_HEAD_REGISTRY.get(name)(*args, **kwargs) diff --git a/modeling/head/inout_head.py b/modeling/head/inout_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e8ae3a4077a9bdbda95cf4eaf97d37e0ff72d4 --- /dev/null +++ b/modeling/head/inout_head.py @@ -0,0 +1,55 @@ +from typing import OrderedDict +import torch +from torch import nn +from detectron2.utils.registry import Registry + + +INOUT_HEAD_REGISTRY = Registry("INOUT_HEAD_REGISTRY") +INOUT_HEAD_REGISTRY.__doc__ = "Registry for inout head" + + +@INOUT_HEAD_REGISTRY.register() +class SimpleLinear(nn.Module): + def __init__(self, in_channel: int, dropout: float = 0) -> None: + super().__init__() + self.in_channel = in_channel + self.classifier = nn.Sequential( + OrderedDict( + [("dropout", nn.Dropout(dropout)), ("linear", nn.Linear(in_channel, 1))] + ) + ) + + def get_feat(self, x, masks): + feats = x["head_feat"] + if masks is not None: + B, C = x["last_feat"].shape[:2] + scene_feats = x["last_feat"].view(B, C, -1).permute(0, 2, 1) + masks = masks / (masks.sum(dim=-1, keepdim=True) + 1e-6) + scene_feats = (scene_feats * masks.unsqueeze(-1)).sum(dim=1) + feats = torch.cat((feats, scene_feats), dim=1) + return feats + + def forward(self, x, masks=None): + feat = self.get_feat(x, masks) + return self.classifier(feat) + + +@INOUT_HEAD_REGISTRY.register() +class SimpleMlp(SimpleLinear): + def __init__(self, in_channel: int, dropout: float = 0) -> None: + super().__init__(in_channel, dropout) + self.classifier = nn.Sequential( + OrderedDict( + [ + ("dropout0", nn.Dropout(dropout)), + ("linear0", nn.Linear(in_channel, in_channel)), + ("relu", nn.ReLU()), + ("dropout1", nn.Dropout(dropout)), + ("linear1", nn.Linear(in_channel, 1)), + ] + ) + ) + + +def build_inout_head(name, *args, **kwargs): + return INOUT_HEAD_REGISTRY.get(name)(*args, **kwargs) diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7b2a8079d210faff9e594e501e8980df6cec7a --- /dev/null +++ b/modeling/meta_arch/__init__.py @@ -0,0 +1 @@ +from .gaze_attention_mapper import GazeAttentionMapper diff --git a/modeling/meta_arch/gaze_attention_mapper.py b/modeling/meta_arch/gaze_attention_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d6244bf6512b2c7d4fc25c2e731b4edc357cdfc4 --- /dev/null +++ b/modeling/meta_arch/gaze_attention_mapper.py @@ -0,0 +1,97 
@@ +import torch +from torch import nn +from typing import Dict, Union + + +class GazeAttentionMapper(nn.Module): + def __init__( + self, + backbone: nn.Module, + regressor: nn.Module, + classifier: nn.Module, + criterion: nn.Module, + pam: nn.Module, + use_aux_loss: bool = False, + device: Union[torch.device, str] = "cuda", + ) -> None: + super().__init__() + self.backbone = backbone + self.pam = pam + self.regressor = regressor + self.classifier = classifier + self.criterion = criterion + self.use_aux_loss = use_aux_loss + self.device = torch.device(device) + + def forward(self, x): + ( + scenes, + heads, + gt_heatmaps, + gt_inouts, + head_masks, + image_masks, + ) = self.preprocess_inputs(x) + # Calculate patch weights based on head position + embedded_heads = self.pam(scenes, heads) + aux_masks = None + if self.use_aux_loss: + embedded_heads, aux_masks = embedded_heads + + # Get out-dict + x = self.backbone( + scenes, + image_masks, + None, + ) + + # Apply patch weights to get the final feats and attention maps + feats = x.get("last_feat", None) + if feats is not None: + x["head_feat"] = ( + (embedded_heads.repeat(1, feats.shape[1], 1, 1) * feats) + .sum(dim=(2, 3)) + .reshape(len(feats), -1) + ) # BC + + attn_maps = x["attention_maps"] + B, C, *_ = attn_maps.shape + x["attention_maps"] = ( + attn_maps * embedded_heads.reshape(B, 1, -1, 1, 1).repeat(1, C, 1, 1, 1) + ).sum( + dim=2 + ) # BCHW + + # Apply heads + heatmaps = self.regressor(x) + inouts = self.classifier(x, None) + + if self.training: + return self.criterion( + heatmaps, + inouts, + gt_heatmaps, + gt_inouts, + aux_masks, + head_masks, + ) + # Inference + return heatmaps, inouts.sigmoid() + + def preprocess_inputs(self, batched_inputs: Dict[str, torch.Tensor]): + return ( + batched_inputs["images"].to(self.device), + batched_inputs["head_channels"].to(self.device), + batched_inputs["heatmaps"].to(self.device) + if "heatmaps" in batched_inputs.keys() + else None, + batched_inputs["gaze_inouts"].to(self.device) + if "gaze_inouts" in batched_inputs.keys() + else None, + batched_inputs["head_masks"].to(self.device) + if "head_masks" in batched_inputs.keys() + else None, + batched_inputs["image_masks"].to(self.device) + if "image_masks" in batched_inputs.keys() + else None, + ) diff --git a/modeling/patch_attention/__init__.py b/modeling/patch_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9be13cc2e49b56af39e59796fa92cb766ca45c4d --- /dev/null +++ b/modeling/patch_attention/__init__.py @@ -0,0 +1 @@ +from .patch_attention import build_pam diff --git a/modeling/patch_attention/patch_attention.py b/modeling/patch_attention/patch_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2a42b236865e1e2d05a356f4f686ee282bdd6e0e --- /dev/null +++ b/modeling/patch_attention/patch_attention.py @@ -0,0 +1,106 @@ +from typing import OrderedDict +from torch import nn +from torch.nn import functional as F +from detectron2.utils.registry import Registry +from fvcore.nn import c2_msra_fill + + +SPATIAL_GUIDANCE_REGISTRY = Registry("SPATIAL_GUIDANCE_REGISTRY") +SPATIAL_GUIDANCE_REGISTRY.__doc__ = "Registry for 2d spatial guidance" + + +class _PoolFusion(nn.Module): + def __init__(self, patch_size: int, use_avgpool: bool = False) -> None: + super().__init__() + self.patch_size = patch_size + self.attn_reducer = F.avg_pool2d if use_avgpool else F.max_pool2d + + def forward(self, scenes, heads): + attn_masks = self.attn_reducer( + heads, + (self.patch_size, self.patch_size), + 
(self.patch_size, self.patch_size), + (0, 0), + ) + patch_attn = attn_masks.masked_fill(attn_masks <= 0, -1e9) + return F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view( + *patch_attn.shape + ) + + +@SPATIAL_GUIDANCE_REGISTRY.register() +class AvgFusion(_PoolFusion): + def __init__(self, patch_size: int) -> None: + super().__init__(patch_size, False) + + +@SPATIAL_GUIDANCE_REGISTRY.register() +class MaxFusion(_PoolFusion): + def __init__(self, patch_size: int) -> None: + super().__init__(patch_size, True) + + +@SPATIAL_GUIDANCE_REGISTRY.register() +class PatchPAM(nn.Module): + def __init__( + self, + patch_size: int, + act_layer=nn.ReLU, + embed_dim: int = 768, + use_aux_loss: bool = False, + ) -> None: + super().__init__() + self.patch_size = patch_size + patch_embed = nn.Conv2d( + 3, embed_dim, (patch_size, patch_size), (patch_size, patch_size), (0, 0) + ) + c2_msra_fill(patch_embed) + conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0)) + c2_msra_fill(conv) + self.use_aux_loss = use_aux_loss + if use_aux_loss: + self.patch_embed = nn.Sequential( + OrderedDict( + [ + ("patch_embed", patch_embed), + ("act_layer", act_layer(inplace=True)), + ] + ) + ) + self.embed = conv + conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0)) + c2_msra_fill(conv) + self.aux_embed = conv + else: + self.embed = nn.Sequential( + OrderedDict( + [ + ("patch_embed", patch_embed), + ("act_layer", act_layer(inplace=True)), + ("embed", conv), + ] + ) + ) + + def forward(self, scenes, heads): + attn_masks = F.max_pool2d( + heads, + (self.patch_size, self.patch_size), + (self.patch_size, self.patch_size), + (0, 0), + ) + if self.use_aux_loss: + embed = self.patch_embed(scenes) + aux_masks = self.aux_embed(embed) + patch_attn = self.embed(embed) * attn_masks + else: + patch_attn = self.embed(scenes) * attn_masks + patch_attn = patch_attn.masked_fill(attn_masks <= 0, -1e9) + patch_attn = F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view( + *patch_attn.shape + ) + return (patch_attn, aux_masks) if self.use_aux_loss else patch_attn + + +def build_pam(name, *args, **kwargs): + return SPATIAL_GUIDANCE_REGISTRY.get(name)(*args, **kwargs) diff --git a/pretrained/gazefollow.pth b/pretrained/gazefollow.pth new file mode 100644 index 0000000000000000000000000000000000000000..71fcf32cdf245634a5df508b03cf710dcac49f18 --- /dev/null +++ b/pretrained/gazefollow.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61f82255d4db0640208c90b037fd4bb62e061efd51cd03aa4b45f10159acab15 +size 266808135 diff --git a/pretrained/videoattentiontarget.pth b/pretrained/videoattentiontarget.pth new file mode 100644 index 0000000000000000000000000000000000000000..1c745d99f3723b0649ed282f28c942f5cd5ce02b --- /dev/null +++ b/pretrained/videoattentiontarget.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdf8d767815b7c8f1ba78d89710dc7a2de2d08e97554654a296a3e71b4903cec +size 266808135 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..71b571a904fd06e3de67c9766361135e1f4d199e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +fvcore==0.1.5.post20221221 +numpy==1.26.2 +omegaconf==2.3.0 +opencv-python==4.8.1.78 +opencv-python-headless==4.6.0.66 +pandas==1.4.4 +Pillow==10.3.0 +scikit-image==0.22.0 +scikit-learn==1.5.0 +scipy==1.11.4 +timm==0.9.12 +torch==2.2.0 +torchaudio==2.0.2 +torchvision==0.15.2 +tqdm==4.66.3 +xformers==0.0.21 +yacs==0.1.8 diff --git a/scripts/convert_pth.py b/scripts/convert_pth.py new file 
mode 100644 index 0000000000000000000000000000000000000000..48d63809813fde7ac061d0c3fe4fafb4283b03bd --- /dev/null +++ b/scripts/convert_pth.py @@ -0,0 +1,32 @@ +# Convert official model weights to format that d2 receives +import argparse +from collections import OrderedDict +import torch + + +def convert(src: str, dst: str): + checkpoint = torch.load(src) + has_model = "model" in checkpoint.keys() + checkpoint = checkpoint["model"] if has_model else checkpoint + if "state_dict" in checkpoint.keys(): + checkpoint = checkpoint["state_dict"] + out_cp = OrderedDict() + for k, v in checkpoint.items(): + out_cp[".".join(["backbone", k])] = v + out_cp = {"model": out_cp} if has_model else out_cp + torch.save(out_cp, dst) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--src", "-s", type=str, required=True, help="Path to src weights.pth" + ) + parser.add_argument( + "--dst", "-d", type=str, required=True, help="Path to dst weights.pth" + ) + args = parser.parse_args() + convert( + args.src, + args.dst, + ) diff --git a/scripts/gen_gazefollow_head_masks.py b/scripts/gen_gazefollow_head_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..6b35ab230c33393650e8b36011c0b919f1f49d2d --- /dev/null +++ b/scripts/gen_gazefollow_head_masks.py @@ -0,0 +1,185 @@ +import argparse +import os +import random +import cv2 +import numpy as np +import pandas as pd +import tqdm +from PIL import Image +from retinaface.pre_trained_models import get_model + + +random.seed(1) + + +def gaussian(x_min, y_min, x_max, y_max): + x_min, x_max = sorted((x_min, x_max)) + y_min, y_max = sorted((y_min, y_max)) + x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2 + x_sigma2, y_sigma2 = ( + np.clip(np.square([x_max - x_min, y_max - y_min], dtype=float), 1, None) / 3 + ) + + def _gaussian(_xs, _ys): + return np.exp( + -(np.square(_xs - x_mid) / x_sigma2 + np.square(_ys - y_mid) / y_sigma2) + ) + + return _gaussian + + +def plot_ori(label_path, data_dir): + df = pd.read_csv( + label_path, + names=[ + # Original labels + "path", + "idx", + "body_bbox_x", + "body_bbox_y", + "body_bbox_w", + "body_bbox_h", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "x_min", + "y_min", + "x_max", + "y_max", + "inout", + "meta0", + "meta1", + ], + index_col=False, + encoding="utf-8-sig", + ) + grouped = df.groupby("path") + + output_dir = os.path.join(data_dir, "head_masks") + + for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with annotations: "): + if not os.path.exists(os.path.join(output_dir, image_name)): + w, h = Image.open(image_name).size + heatmap = np.zeros((h, w), dtype=np.float32) + indices = np.meshgrid( + np.linspace(0.0, float(w), num=w, endpoint=False), + np.linspace(0.0, float(h), num=h, endpoint=False), + ) + for _, row in group_df.iterrows(): + x_min, y_min, x_max, y_max = ( + row["x_min"], + row["y_min"], + row["x_max"], + row["y_max"], + ) + gauss = gaussian(x_min, y_min, x_max, y_max) + heatmap += gauss(*indices) + heatmap /= np.max(heatmap) + heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L") + output_filename = os.path.join(output_dir, image_name) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + heatmap_image.save(output_filename) + + +def plot_gen(df, data_dir): + df = df[df["score"] > 0.8] + grouped = df.groupby("path") + + output_dir = os.path.join(data_dir, "head_masks") + + for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with predictions: "): + w, h = 
Image.open(image_name).size + heatmap = np.zeros((h, w), dtype=np.float32) + indices = np.meshgrid( + np.linspace(0.0, float(w), num=w, endpoint=False), + np.linspace(0.0, float(h), num=h, endpoint=False), + ) + for index, row in group_df.iterrows(): + x_min, y_min, x_max, y_max = ( + row["x_min"], + row["y_min"], + row["x_max"], + row["y_max"], + ) + gauss = gaussian(x_min, y_min, x_max, y_max) + heatmap += gauss(*indices) + heatmap /= np.max(heatmap) + heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L") + output_filename = os.path.join(output_dir, image_name) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + heatmap_image.save(output_filename) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_dir", help="Root directory of dataset") + parser.add_argument( + "--subset", + help="Subset of dataset to process", + choices=["train", "test"], + ) + args = parser.parse_args() + + label_path = os.path.join( + args.dataset_dir, args.subset + "_annotations_release.txt" + ) + + column_names = [ + "path", + "idx", + "body_bbox_x", + "body_bbox_y", + "body_bbox_w", + "body_bbox_h", + "eye_x", + "eye_y", + "gaze_x", + "gaze_y", + "bbox_x_min", + "bbox_y_min", + "bbox_x_max", + "bbox_y_max", + ] + df = pd.read_csv( + label_path, + sep=",", + names=column_names, + usecols=column_names, + index_col=False, + ) + df = df.groupby("path") + + model = get_model("resnet50_2020-07-20", max_size=2048, device="cuda") + model.eval() + + paths = list(df.groups.keys()) + csv = [] + for path in tqdm.tqdm(paths, desc="Predicting head bboxes: "): + img = cv2.imread(os.path.join(args.dataset_dir, path)) + + annotations = model.predict_jsons(img) + + for annotation in annotations: + if len(annotation["bbox"]) == 0: + continue + + csv.append( + [ + path, + annotation["score"], + annotation["bbox"][0], + annotation["bbox"][1], + annotation["bbox"][2], + annotation["bbox"][3], + ] + ) + + # Write csv + df = pd.DataFrame( + csv, columns=["path", "score", "x_min", "y_min", "x_max", "y_max"] + ) + df.to_csv(os.path.join(args.dataset_dir, f"{args.subset}_head.csv"), index=False) + + plot_gen(df, args.dataset_dir) + plot_ori(label_path, args.data_dir) diff --git a/tools/eval_on_gazefollow.py b/tools/eval_on_gazefollow.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd51709ef4fe2671d59e2abde9c0bc21e389a86 --- /dev/null +++ b/tools/eval_on_gazefollow.py @@ -0,0 +1,92 @@ +import sys +from os import path as osp +import argparse +import warnings +import torch +import numpy as np +from PIL import Image +from detectron2.config import instantiate, LazyConfig + +sys.path.append(osp.dirname(osp.dirname(__file__))) +from utils import * + + +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def do_test(cfg, model, use_dark_inference=False): + val_loader = instantiate(cfg.dataloader.val) + + model.train(False) + AUC = [] + min_dist = [] + avg_dist = [] + with torch.no_grad(): + for data in val_loader: + val_gaze_heatmap_pred, _ = model(data) + val_gaze_heatmap_pred = ( + val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy() + ) + + # go through each data point and record AUC, min dist, avg dist + for b_i in range(len(val_gaze_heatmap_pred)): + # remove padding and recover valid ground truth points + valid_gaze = data["gazes"][b_i] + valid_gaze = valid_gaze[valid_gaze != -1].view(-1, 2) + # AUC: area under curve of ROC + multi_hot = multi_hot_targets(data["gazes"][b_i], data["imsize"][b_i]) + if 
use_dark_inference: + pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i]) + else: + pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i]) + norm_p = [ + pred_x / val_gaze_heatmap_pred[b_i].shape[-2], + pred_y / val_gaze_heatmap_pred[b_i].shape[-1], + ] + scaled_heatmap = np.array( + Image.fromarray(val_gaze_heatmap_pred[b_i]).resize( + data["imsize"][b_i], + resample=Image.BILINEAR, + ) + ) + auc_score = auc(scaled_heatmap, multi_hot) + AUC.append(auc_score) + # min distance: minimum among all possible pairs of + all_distances = [] + for gt_gaze in valid_gaze: + all_distances.append(L2_dist(gt_gaze, norm_p)) + min_dist.append(min(all_distances)) + # average distance: distance between the predicted point and human average point + mean_gt_gaze = torch.mean(valid_gaze, 0) + avg_distance = L2_dist(mean_gt_gaze, norm_p) + avg_dist.append(avg_distance) + + print("|AUC |min dist|avg dist|") + print( + "|{:.4f}|{:.4f} |{:.4f} |".format( + torch.mean(torch.tensor(AUC)), + torch.mean(torch.tensor(min_dist)), + torch.mean(torch.tensor(avg_dist)), + ) + ) + + +def main(args): + cfg = LazyConfig.load(args.config_file) + model: torch.Module = instantiate(cfg.model) + model.load_state_dict(torch.load(args.model_weights)["model"]) + model.to(cfg.train.device) + do_test(cfg, model, use_dark_inference=args.use_dark_inference) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_file", type=str, help="config file") + parser.add_argument( + "--model_weights", + type=str, + help="model weights", + ) + parser.add_argument("--use_dark_inference", action="store_true") + args = parser.parse_args() + main(args) diff --git a/tools/eval_on_video_attention_target.py b/tools/eval_on_video_attention_target.py new file mode 100644 index 0000000000000000000000000000000000000000..ae770fef9d5be39c0abcabc90f56962285902f61 --- /dev/null +++ b/tools/eval_on_video_attention_target.py @@ -0,0 +1,93 @@ +import sys +from os import path as osp +import argparse +import warnings +import torch +import numpy as np +from PIL import Image +from detectron2.config import instantiate, LazyConfig + +sys.path.append(osp.dirname(osp.dirname(__file__))) +from utils import * + + +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def do_test(cfg, model, use_dark_inference=False): + val_loader = instantiate(cfg.dataloader.val) + + model.train(False) + AUC = [] + dist = [] + inout_gt = [] + inout_pred = [] + with torch.no_grad(): + for data in val_loader: + val_gaze_heatmap_pred, val_gaze_inout_pred = model(data) + val_gaze_heatmap_pred = ( + val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy() + ) + val_gaze_inout_pred = val_gaze_inout_pred.cpu().detach().numpy() + + # go through each data point and record AUC, dist, ap + for b_i in range(len(val_gaze_heatmap_pred)): + auc_batch = [] + dist_batch = [] + if data["gaze_inouts"][b_i]: + # remove padding and recover valid ground truth points + valid_gaze = data["gazes"][b_i] + # AUC: area under curve of ROC + multi_hot = data["heatmaps"][b_i] + multi_hot = (multi_hot > 0).float().numpy() + if use_dark_inference: + pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i]) + else: + pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i]) + norm_p = [ + pred_x / val_gaze_heatmap_pred[b_i].shape[-1], + pred_y / val_gaze_heatmap_pred[b_i].shape[-2], + ] + scaled_heatmap = np.array( + Image.fromarray(val_gaze_heatmap_pred[b_i]).resize( + (64, 64), + resample=Image.Resampling.BILINEAR, + ) + ) + auc_score = auc(scaled_heatmap, 
multi_hot) + auc_batch.append(auc_score) + dist_batch.append(L2_dist(valid_gaze.numpy(), norm_p)) + AUC.extend(auc_batch) + dist.extend(dist_batch) + inout_gt.extend(data["gaze_inouts"].cpu().numpy()) + inout_pred.extend(val_gaze_inout_pred) + + print("|AUC |dist |AP |") + print( + "|{:.4f}|{:.4f} |{:.4f} |".format( + torch.mean(torch.tensor(AUC)), + torch.mean(torch.tensor(dist)), + ap(inout_gt, inout_pred), + ) + ) + + +def main(args): + cfg = LazyConfig.load(args.config_file) + model: torch.Module = instantiate(cfg.model) + model.load_state_dict(torch.load(args.model_weights)["model"]) + model.to(cfg.train.device) + do_test(cfg, model, use_dark_inference=args.use_dark_inference) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_file", type=str, help="config file") + parser.add_argument( + "--model_weights", + type=str, + help="model weights", + ) + parser.add_argument("--use_dark_inference", action="store_true") + args = parser.parse_args() + main(args) diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..93671beb24924c6e2e4cef92a4cb802bd10e5253 --- /dev/null +++ b/tools/train.py @@ -0,0 +1,104 @@ +import os.path as osp +import logging +import warnings +import sys + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.utils import comm + + +sys.path.append(osp.dirname(osp.dirname(__file__))) +warnings.filterwarnings("ignore") +logger = logging.getLogger("detectron2") + + +from engine import CycleTrainer + + +def do_train(args, cfg): + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `configs/common/train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info("Model:\n{}".format(model)) + model.to(cfg.train.device) + + cfg.optimizer.params.model = model + optim = instantiate(cfg.optimizer) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + trainer = CycleTrainer(model, train_loader, optim) + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + trainer=trainer, + ) + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) + if comm.is_main_process() + else None, + # hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) + if comm.is_main_process() + else None, + ] + ) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) 
+ + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/tools/utils.py b/tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..633cfc5184fba56042259919d74542780bedf9da --- /dev/null +++ b/tools/utils.py @@ -0,0 +1,164 @@ +from typing import Union, Iterable, Tuple +import numpy as np +import torch +import cv2 +from sklearn.metrics import roc_auc_score +from sklearn.metrics import average_precision_score + + +def auc(heatmap, onehot_im, is_im=True): + if is_im: + auc_score = roc_auc_score( + np.reshape(onehot_im, onehot_im.size), np.reshape(heatmap, heatmap.size) + ) + else: + auc_score = roc_auc_score(onehot_im, heatmap) + return auc_score + + +def ap(label, pred): + return average_precision_score(label, pred) + + +def argmax_pts(heatmap): + idx = np.unravel_index(heatmap.argmax(), heatmap.shape) + pred_y, pred_x = map(float, idx) + return pred_x, pred_y + + +def L2_dist(p1, p2): + return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) + + +def multi_hot_targets(gaze_pts, out_res): + w, h = out_res + target_map = np.zeros((h, w)) + for p in gaze_pts: + if p[0] >= 0: + x, y = map(int, [p[0] * w.float(), p[1] * h.float()]) + x = min(x, w - 1) + y = min(y, h - 1) + target_map[y, x] = 1 + return target_map + + +def inverse_transform(tensor: torch.Tensor) -> np.ndarray: + tensor = tensor.detach().cpu().permute(0, 2, 3, 1) + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + tensor = tensor * std + mean + return cv2.cvtColor((tensor.numpy() * 255).astype(np.uint8)[0], cv2.COLOR_RGB2BGR) + + +def draw(data, heatmap, out_path, on_img=True): + img = inverse_transform(data["images"]) + head_channel = cv2.applyColorMap( + (data["head_channels"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8), + cv2.COLORMAP_BONE, + ) + hm = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET) + heatmap = hm + heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) + if on_img: + img = cv2.addWeighted(img, 1, heatmap, 0.5, 1) + else: + img = heatmap + # img = cv2.addWeighted(img, 1, head_channel, 0.1, 1) + cv2.imwrite(out_path, img) + + +def draw_origin_img(data, out_path): + img = inverse_transform(data["images"]) + hm = cv2.applyColorMap( + (data["heatmaps"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8), + cv2.COLORMAP_JET, + ) + hm[data["heatmaps"].squeeze().detach().cpu().numpy() == 0] = 0 + hm = cv2.resize(hm, (img.shape[1], img.shape[0])) + head_channel = cv2.applyColorMap( + (data["head_channels"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8), + cv2.COLORMAP_BONE, + ) + head_channel[data["head_channels"].squeeze().detach().cpu().numpy() < 0.1] = 0 + hm = cv2.resize(hm, (img.shape[1], img.shape[0])) + ori = cv2.addWeighted(img, 1, hm, 0.5, 1) + ori = cv2.addWeighted(ori, 1, head_channel, 0.1, 1) + cv2.imwrite(out_path, ori) + + +class __Image2MP4: + def __init__(self): + self.Fourcc = cv2.VideoWriter_fourcc(*"mp4v") + + def __call__( + self, + frames: Union[Iterable[np.ndarray], str], + path: str, + fps: float = 30.0, + isize: Tuple[int, int] = None, + ): + if isinstance(frames, str): # directory of img files + from os import listdir, path as osp + 
+ imgs = sorted(listdir(frames)) + frames = [ + cv2.imread(osp.join(frames, img), cv2.IMREAD_COLOR) for img in imgs + ] + + if isize is None: + isize = (frames[0].shape[1], frames[0].shape[0]) + + output_video = cv2.VideoWriter(path, self.Fourcc, fps, isize) + for frame in frames: + frame = cv2.resize(frame, isize) + output_video.write(frame) + output_video.release() + + +img2mp4 = __Image2MP4() + + +def dark_inference(heatmap: np.ndarray, gaussian_kernel: int = 39): + pred_x, pred_y = argmax_pts(heatmap) + pred_x, pred_y = int(pred_x), int(pred_y) + height, width = heatmap.shape[-2:] + # Gaussian blur + orig_max = heatmap.max() + border = (gaussian_kernel - 1) // 2 + dr = np.zeros((height + 2 * border, width + 2 * border)) + dr[border:-border, border:-border] = heatmap.copy() + dr = cv2.GaussianBlur(dr, (gaussian_kernel, gaussian_kernel), 0) + heatmap = dr[border:-border, border:-border].copy() + heatmap *= orig_max / np.max(heatmap) + # Log-likelihood + heatmap = np.maximum(heatmap, 1e-10) + heatmap = np.log(heatmap) + # DARK + if 1 < pred_x < width - 2 and 1 < pred_y < height - 2: + dx = 0.5 * (heatmap[pred_y][pred_x + 1] - heatmap[pred_y][pred_x - 1]) + dy = 0.5 * (heatmap[pred_y + 1][pred_x] - heatmap[pred_y - 1][pred_x]) + dxx = 0.25 * ( + heatmap[pred_y][pred_x + 2] + - 2 * heatmap[pred_y][pred_x] + + heatmap[pred_y][pred_x - 2] + ) + dxy = 0.25 * ( + heatmap[pred_y + 1][pred_x + 1] + - heatmap[pred_y - 1][pred_x + 1] + - heatmap[pred_y + 1][pred_x - 1] + + heatmap[pred_y - 1][pred_x - 1] + ) + dyy = 0.25 * ( + heatmap[pred_y + 2][pred_x] + - 2 * heatmap[pred_y][pred_x] + + heatmap[pred_y - 2][pred_x] + ) + derivative = np.matrix([[dx],[dy]]) + hessian = np.matrix([[dxx,dxy],[dxy,dyy]]) + if dxx * dyy - dxy ** 2 != 0: + hessianinv = hessian.I + offset = -hessianinv * derivative + offset_x, offset_y = np.squeeze(np.array(offset.T), axis=0) + pred_x += offset_x + pred_y += offset_y + return pred_x, pred_y diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..0b942ebde34c3f91b61a6be836fee7901f72c829 --- /dev/null +++ b/train.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES="0,1" + +config_files=( + "configs/gazefollow.py" + "configs/gazefollow_518.py" + "configs/videoattentiontarget.py" +) + +run_experiment() { + local config="$1" + echo "Running experiment with config: $config" + python -u tools/train.py --config-file "$config" --num-gpu 2 +} + +for config in "${config_files[@]}" +do + run_experiment "$config" & + pid=$! + wait "$pid" + sleep 10 +done diff --git a/val.sh b/val.sh new file mode 100644 index 0000000000000000000000000000000000000000..7e043f9ecf0481df6cd096863427f9fa3029e349 --- /dev/null +++ b/val.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# [Usage] bash val.sh [$config_file] [$weight_file] [$(gf/vat)]; e.g. +# bash val.sh configs/gazefollow_518.py output/gazefollow_518/model_final.pth gf +# bash val.sh configs/videoattentiontarget.py output/videoattentiontarget/model_final.pth vat + +config_file="$1" +checkpoint="$2" + +if [ "$3" = "gf" ]; then + evaluater="tools/eval_on_gazefollow.py" +elif [ "$3" = "vat" ]; then + evaluater="tools/eval_on_video_attention_target.py" +else + echo "Invalid dataset" + exit 1 +fi + +export CUDA_VISIBLE_DEVICES="0" +echo "Evaluating with:" +echo "config: $config_file" +echo "checkpoint: $checkpoint" +python -u $evaluater --config_file $config_file --model_weights "$checkpoint" --use_dark_inference