diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..e010efe37302940c4d5082a9c99eff1f8a62a96d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo0.gif filter=lfs diff=lfs merge=lfs -text
+assets/demo1.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7ba185e9f8cc7306d13547e81841c9586ca47e23
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+.vscode
+.idea
+*output
+backup
+*/__pycache__/*
+pretrained
+logs
+*.pyc
+*.ipynb
+*.out
+*.log
+*.pth
+*.pkl
+*.pt
+*.npy
+*debug*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..85d81725587d6c36deb2e975d77547709f61369b
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 HUST Vision Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 7be5fc7f47d5db027d120b8024982df93db95b74..7838b35d2a2c8741f99fc909f8fd0959006a543a 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,78 @@
----
-license: mit
----
+# ViTGaze 👀
+
+**Gaze Following with Interaction Features in Vision Transformers**
+
+Yuehao Song<sup>1</sup>, Xinggang Wang<sup>1 :email:</sup>, Jingfeng Yao<sup>1</sup>, Wenyu Liu<sup>1</sup>, Jinglin Zhang<sup>2</sup>, Xiangmin Xu<sup>3</sup>
+
+<sup>1</sup> Huazhong University of Science and Technology, <sup>2</sup> Shandong University, <sup>3</sup> South China University of Technology
+
+(<sup>:email:</sup>) corresponding author.
+
+arXiv Preprint ([arXiv 2403.12778](https://arxiv.org/abs/2403.12778))
+
+### News
+* **`Mar. 25th, 2024`:** We release an initial version of ViTGaze.
+* **`Mar. 19th, 2024`:** We released our paper on arXiv. Code/Models are coming soon. Please stay tuned! ☕️
+
+
+## Introduction
+A plain Vision Transformer can also perform gaze following within the simple ViTGaze framework!
+
+Inspired by the remarkable success of pre-trained plain Vision Transformers (ViTs), we introduce a novel single-modality gaze following framework, **ViTGaze**. In contrast to previous methods, ViTGaze is built mainly upon a powerful encoder, with the decoder accounting for less than 1% of the parameters. Our principal insight is that the inter-token interactions within self-attention can be transferred to interactions between humans and scenes. Our method achieves state-of-the-art (SOTA) performance among all single-modality methods (a 3.4% improvement in AUC and a 5.1% improvement in AP) and performs comparably to multi-modality methods while using 59% fewer parameters.
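+
+A minimal sketch of this idea (illustrative only; the tensor shapes and the tiny decoder below are placeholders, not the repository's actual modules):
+
+```python
+import torch
+import torch.nn as nn
+
+# Self-attention maps from a person (head) token to all scene tokens, collected
+# from several ViT blocks/heads, act as human-scene "interaction features".
+B, C, H, W = 1, 24, 32, 32                 # batch, collected attention maps, patch grid
+interaction_feats = torch.rand(B, C, H, W)
+
+# A lightweight decoder (well under 1% of the total parameters) regresses the gaze heatmap.
+decoder = nn.Sequential(
+    nn.Conv2d(C, 8, 3, padding=1), nn.ReLU(),
+    nn.Conv2d(8, 1, 3, padding=1),
+)
+gaze_heatmap = decoder(interaction_feats)  # (B, 1, H, W)
+```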
+
+## Results
+> Results from the [ViTGaze paper](https://arxiv.org/abs/2403.12778)
+
+| Dataset | AUC | Avg. Dist. | Min. Dist. | Dist. | AP |
+|:---|:---:|:---:|:---:|:---:|:---:|
+| GazeFollow | 0.949 | 0.105 | 0.047 | - | - |
+| VideoAttentionTarget | 0.938 | - | - | 0.102 | 0.905 |
+
+Corresponding checkpoints are released:
+- GazeFollow: [GoogleDrive](https://drive.google.com/file/d/164c4woGCmUI8UrM7GEKQrV1FbA3vGwP4/view?usp=drive_link)
+- VideoAttentionTarget: [GoogleDrive](https://drive.google.com/file/d/11_O4Jm5wsvQ8qfLLgTlrudqSNvvepsV0/view?usp=drive_link)
+
+## Getting Started
+- [Installation](docs/install.md)
+- [Train](docs/train.md)
+- [Eval](docs/eval.md)
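+
+The configs under `configs/` are detectron2-style lazy configs. A rough sketch of loading one programmatically (illustrative; see the docs above for the actual training and evaluation commands):
+
+```python
+from detectron2.config import LazyConfig, instantiate
+
+cfg = LazyConfig.load("configs/gazefollow.py")    # a plain-Python lazy config
+model = instantiate(cfg.model)                    # build the ViTGaze model
+train_loader = instantiate(cfg.dataloader.train)  # build the GazeFollow loader
+```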
+
+## Acknowledgements
+ViTGaze is based on [detectron2](https://github.com/facebookresearch/detectron2). We use the efficient multi-head attention implemented in the [xFormers](https://github.com/facebookresearch/xformers) library.
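+
+For reference, a minimal call to that xFormers kernel looks roughly like this (shapes and dtypes are purely illustrative):
+
+```python
+import torch
+from xformers.ops import memory_efficient_attention
+
+# (batch, sequence, heads, head_dim) query/key/value tensors
+q = k = v = torch.rand(1, 1024, 6, 64, device="cuda", dtype=torch.float16)
+out = memory_efficient_attention(q, k, v)  # same shape as q
+```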
+
+## Citation
+If you find ViTGaze useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
+```bibtex
+@article{vitgaze,
+ title={ViTGaze: Gaze Following with Interaction Features in Vision Transformers},
+ author={Yuehao Song and Xinggang Wang and Jingfeng Yao and Wenyu Liu and Jinglin Zhang and Xiangmin Xu},
+ journal={arXiv preprint arXiv:2403.12778},
+ year={2024}
+}
+```
diff --git a/assets/comparion.png b/assets/comparion.png
new file mode 100644
index 0000000000000000000000000000000000000000..e3858095eeb24643628f9d108b6d79a61c5ef1af
Binary files /dev/null and b/assets/comparion.png differ
diff --git a/assets/demo0.gif b/assets/demo0.gif
new file mode 100644
index 0000000000000000000000000000000000000000..d8009ea60e54012f974b8f57599d464180bd0734
--- /dev/null
+++ b/assets/demo0.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89e36ce1438c230928f376b4a2668e5daabace217a63e1df095ddce7202851ec
+size 5921851
diff --git a/assets/demo1.gif b/assets/demo1.gif
new file mode 100644
index 0000000000000000000000000000000000000000..6245e11013168d79398800694d22e1bc1ebb8937
--- /dev/null
+++ b/assets/demo1.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d5c7a413de2f9ff12818d0f476f868e5fb867c17d85315cce3eb090212468da
+size 9477364
diff --git a/assets/pipeline.png b/assets/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..c1b70543898066cad1414f66c9d30048ac03a75c
Binary files /dev/null and b/assets/pipeline.png differ
diff --git a/configs/common/__init__.py b/configs/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e690fd59145ce8900fd9ab8d8a996ee7d33834
--- /dev/null
+++ b/configs/common/__init__.py
@@ -0,0 +1 @@
+from . import *
diff --git a/configs/common/dataloader.py b/configs/common/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de6507f4d3075211230f2b65d4cad2d55b98ba2
--- /dev/null
+++ b/configs/common/dataloader.py
@@ -0,0 +1,214 @@
+from os import path as osp
+from typing import Literal
+
+from omegaconf import OmegaConf
+from detectron2.config import LazyCall as L
+from detectron2.config import instantiate
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from data import *
+
+
+DATA_ROOT = "${Root to Datasets}"
+if DATA_ROOT == "${Root to Datasets}":
+ raise Exception(
+ f"""{osp.abspath(__file__)}: Rewrite `DATA_ROOT` with the root to the datasets.
+The directory structure should be:
+-DATA_ROOT
+|-videoattentiontarget
+|--images
+|--annotations
+|---train
+|---test
+|--head_masks
+|---images
+|-gazefollow
+|--train
+|--test2
+|--train_annotations_release.txt
+|--test_annotations_release.txt
+|--head_masks
+|---train
+|---test2
+"""
+ )
+
+# Basic Config for Video Attention Target dataset and preprocessing
+data_info = OmegaConf.create()
+data_info.video_attention_target = OmegaConf.create()
+data_info.video_attention_target.train_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/images"
+)
+data_info.video_attention_target.train_anno = osp.join(
+ DATA_ROOT, "videoattentiontarget/annotations/train"
+)
+data_info.video_attention_target.val_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/images"
+)
+data_info.video_attention_target.val_anno = osp.join(
+ DATA_ROOT, "videoattentiontarget/annotations/test"
+)
+data_info.video_attention_target.head_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/head_masks/images"
+)
+
+data_info.video_attention_target_video = OmegaConf.create()
+data_info.video_attention_target_video.train_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/images"
+)
+data_info.video_attention_target_video.train_anno = osp.join(
+ DATA_ROOT, "videoattentiontarget/annotations/train"
+)
+data_info.video_attention_target_video.val_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/images"
+)
+data_info.video_attention_target_video.val_anno = osp.join(
+ DATA_ROOT, "videoattentiontarget/annotations/test"
+)
+data_info.video_attention_target_video.head_root = osp.join(
+ DATA_ROOT, "videoattentiontarget/head_masks/images"
+)
+
+data_info.gazefollow = OmegaConf.create()
+data_info.gazefollow.train_root = osp.join(DATA_ROOT, "gazefollow")
+data_info.gazefollow.train_anno = osp.join(
+ DATA_ROOT, "gazefollow/train_annotations_release.txt"
+)
+data_info.gazefollow.val_root = osp.join(DATA_ROOT, "gazefollow")
+data_info.gazefollow.val_anno = osp.join(
+ DATA_ROOT, "gazefollow/test_annotations_release.txt"
+)
+data_info.gazefollow.head_root = osp.join(DATA_ROOT, "gazefollow/head_masks")
+
+data_info.input_size = 224
+data_info.output_size = 64
+data_info.quant_labelmap = True
+data_info.mean = (0.485, 0.456, 0.406)
+data_info.std = (0.229, 0.224, 0.225)
+data_info.bbox_jitter = 0.5
+data_info.rand_crop = 0.5
+data_info.rand_flip = 0.5
+data_info.color_jitter = 0.5
+data_info.rand_rotate = 0.0
+data_info.rand_lsj = 0.0
+
+data_info.mask_size = 24
+data_info.mask_scene = False
+data_info.mask_head = False
+data_info.max_scene_patches_ratio = 0.5
+data_info.max_head_patches_ratio = 0.3
+data_info.mask_prob = 0.2
+
+data_info.seq_len = 16
+data_info.max_len = 32
+
+
+# Dataloader (gazefollow / video_attention_target, train/val)
+def __build_dataloader(
+ name: Literal[
+ "gazefollow", "video_attention_target", "video_attention_target_video"
+ ],
+ is_train: bool,
+ batch_size: int = 64,
+ num_workers: int = 14,
+ pin_memory: bool = True,
+ persistent_workers: bool = True,
+ drop_last: bool = True,
+ distributed: bool = False,
+ **kwargs,
+):
+ assert name in [
+ "gazefollow",
+ "video_attention_target",
+ "video_attention_target_video",
+ ], f'{name} not in ("gazefollow", "video_attention_target", "video_attention_target_video")'
+
+ for k, v in kwargs.items():
+ if k in ["train_root", "train_anno", "val_root", "val_anno", "head_root"]:
+ data_info[name][k] = v
+ else:
+ data_info[k] = v
+
+ datasets = {
+ "gazefollow": GazeFollow,
+ "video_attention_target": VideoAttentionTarget,
+ "video_attention_target_video": VideoAttentionTargetVideo,
+ }
+ dataset = L(datasets[name])(
+ image_root=data_info[name]["train_root" if is_train else "val_root"],
+ anno_root=data_info[name]["train_anno" if is_train else "val_anno"],
+ head_root=data_info[name]["head_root"],
+ transform=get_transform(
+ input_resolution=data_info.input_size,
+ mean=data_info.mean,
+ std=data_info.std,
+ ),
+ input_size=data_info.input_size,
+ output_size=data_info.output_size,
+ quant_labelmap=data_info.quant_labelmap,
+ is_train=is_train,
+ bbox_jitter=data_info.bbox_jitter,
+ rand_crop=data_info.rand_crop,
+ rand_flip=data_info.rand_flip,
+ color_jitter=data_info.color_jitter,
+ rand_rotate=data_info.rand_rotate,
+ rand_lsj=data_info.rand_lsj,
+ mask_generator=(
+ MaskGenerator(
+ input_size=data_info.mask_size,
+ mask_scene=data_info.mask_scene,
+ mask_head=data_info.mask_head,
+ max_scene_patches_ratio=data_info.max_scene_patches_ratio,
+ max_head_patches_ratio=data_info.max_head_patches_ratio,
+ mask_prob=data_info.mask_prob,
+ )
+ if is_train
+ else None
+ ),
+ )
+ if name == "video_attention_target_video":
+ dataset.seq_len = data_info.seq_len
+ dataset.max_len = data_info.max_len
+ dataset = instantiate(dataset)
+
+ return DataLoader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ collate_fn=video_collate if name == "video_attention_target_video" else None,
+ sampler=DistributedSampler(dataset, shuffle=is_train) if distributed else None,
+ drop_last=drop_last,
+ )
+
+
+dataloader = OmegaConf.create()
+dataloader.gazefollow = OmegaConf.create()
+dataloader.gazefollow.train = L(__build_dataloader)(
+ name="gazefollow",
+ is_train=True,
+)
+dataloader.gazefollow.val = L(__build_dataloader)(
+ name="gazefollow",
+ is_train=False,
+)
+dataloader.video_attention_target = OmegaConf.create()
+dataloader.video_attention_target.train = L(__build_dataloader)(
+ name="video_attention_target",
+ is_train=True,
+)
+dataloader.video_attention_target.val = L(__build_dataloader)(
+ name="video_attention_target",
+ is_train=False,
+)
+dataloader.video_attention_target_video = OmegaConf.create()
+dataloader.video_attention_target_video.train = L(__build_dataloader)(
+ name="video_attention_target_video",
+ is_train=True,
+)
+dataloader.video_attention_target_video.val = L(__build_dataloader)(
+ name="video_attention_target_video",
+ is_train=False,
+)
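+
+
+# Example (illustrative): a training script can materialize a loader from these
+# lazy configs with detectron2's `instantiate`:
+#   from detectron2.config import instantiate
+#   train_loader = instantiate(dataloader.gazefollow.train)
+#   for batch in train_loader:
+#       ...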
diff --git a/configs/common/model.py b/configs/common/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd9f0e3e0f6037ef0b1c839067ad337afbd897df
--- /dev/null
+++ b/configs/common/model.py
@@ -0,0 +1,44 @@
+from detectron2.config import LazyCall as L
+
+from modeling import backbone, patch_attention, meta_arch, head, criterion
+
+
+model = L(meta_arch.GazeAttentionMapper)()
+model.backbone = L(backbone.build_backbone)(
+ name="small", out_attn=[2, 5, 8, 11]
+)
+model.pam = L(patch_attention.build_pam)(name="PatchPAM", patch_size=16)
+model.regressor = L(head.build_heatmap_head)(
+ name="SimpleDeconv",
+ in_channel=24,
+ deconv_cfgs=[
+ {
+ "in_channels": 24,
+ "out_channels": 12,
+ "kernel_size": 3,
+ "stride": 2,
+ },
+ {
+ "in_channels": 12,
+ "out_channels": 6,
+ "kernel_size": 3,
+ "stride": 2,
+ },
+ {
+ "in_channels": 6,
+ "out_channels": 3,
+ "kernel_size": 3,
+ "stride": 2,
+ },
+ {
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "stride": 2,
+ },
+ ],
+ feat_type="attn",
+)
+model.classifier = L(head.build_inout_head)(name="SimpleLinear", in_channel=384)
+model.criterion = L(criterion.GazeMapperCriterion)()
+model.device = "cuda"
diff --git a/configs/common/optimizer.py b/configs/common/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd84ed34f12d46b3be905ef3edebfae26b867d4
--- /dev/null
+++ b/configs/common/optimizer.py
@@ -0,0 +1,48 @@
+from detectron2 import model_zoo
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone"):
+ if ".pos_embed" in name or ".patch_embed" in name:
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
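+# Example values (lr_decay_rate=0.65, num_layers=12):
+#   "backbone.pos_embed"                 -> 0.65 ** 13  (earliest layers get the smallest lr)
+#   "backbone.blocks.11.attn.qkv.weight" -> 0.65 ** 1
+#   any non-backbone parameter           -> 1.0
+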
+
+class LRDecayRater:
+ def __init__(self, lr_decay_rate=1.0, num_layers=12, backbone_multiplier=1.0, freeze_pe=False, pam_lr_decay=1):
+ self.lr_decay_rate = lr_decay_rate
+ self.num_layers = num_layers
+ self.backbone_multiplier = backbone_multiplier
+ self.freeze_pe = freeze_pe
+ self.pam_lr_decay = pam_lr_decay
+
+ def __call__(self, name):
+ if name.startswith("backbone"):
+            if self.freeze_pe and (".pos_embed" in name or ".patch_embed" in name):
+ return 0
+ return self.backbone_multiplier * get_vit_lr_decay_rate(
+ name, self.lr_decay_rate, self.num_layers
+ )
+ if name.startswith("pam"):
+ return self.pam_lr_decay
+ return 1
+
+
+# Optimizer
+optimizer = model_zoo.get_config("common/optim.py").AdamW
+optimizer.params.lr_factor_func = LRDecayRater(num_layers=12, lr_decay_rate=0.65)
+optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
diff --git a/configs/common/scheduler.py b/configs/common/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..120e06acb59ff30df5309db98d0c8cd3a5284596
--- /dev/null
+++ b/configs/common/scheduler.py
@@ -0,0 +1,18 @@
+from typing import Literal
+from detectron2.config import LazyCall as L
+from detectron2.solver import WarmupParamScheduler
+from fvcore.common.param_scheduler import MultiStepParamScheduler, CosineParamScheduler
+
+
+def get_scheduler(typ: Literal["multistep", "cosine"] = "multistep", **kwargs):
+ if typ == "multistep":
+ return MultiStepParamScheduler(**kwargs)
+ elif typ == "cosine":
+ return CosineParamScheduler(**kwargs)
+
+
+lr_multiplier = L(WarmupParamScheduler)(
+ scheduler=L(get_scheduler)(),
+ warmup_length=0,
+ warmup_factor=0.001,
+)
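+
+# Example override (as in configs/gazefollow.py): a cosine schedule decaying to 10%
+# of the base lr, with a short warmup:
+#   lr_multiplier.scheduler.typ = "cosine"
+#   lr_multiplier.scheduler.start_value = 1
+#   lr_multiplier.scheduler.end_value = 0.1
+#   lr_multiplier.warmup_length = 1e-2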
diff --git a/configs/common/train.py b/configs/common/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eb45504b9c470f45a0daf57f68f32e740d18124
--- /dev/null
+++ b/configs/common/train.py
@@ -0,0 +1,16 @@
+train = dict(
+ output_dir="./output",
+ init_checkpoint="",
+ max_iter=90000,
+ amp=dict(enabled=False), # options for Automatic Mixed Precision
+ ddp=dict( # options for DistributedDataParallel
+ broadcast_buffers=True,
+ find_unused_parameters=False,
+ fp16_compression=True,
+ ),
+ checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer
+ eval_period=5000,
+ log_period=100,
+ device="cuda",
+ # ...
+)
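+
+# These defaults are overridden per experiment, e.g. in configs/gazefollow.py:
+#   train.init_checkpoint = "pretrained/dinov2_small.pth"
+#   train.max_iter = len_dataset * num_epoch // ins_per_iter
+#   train.checkpointer.period = len_dataset // ins_per_iter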
diff --git a/configs/gazefollow.py b/configs/gazefollow.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee5119a1932f25be69b42eb4d522d3de191116b3
--- /dev/null
+++ b/configs/gazefollow.py
@@ -0,0 +1,86 @@
+from .common.dataloader import dataloader
+from .common.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .common.train import train
+from os.path import join, basename
+from torch.cuda import device_count
+
+
+num_gpu = device_count()
+ins_per_iter = 48
+len_dataset = 126000
+num_epoch = 14
+# dataloader
+dataloader = dataloader.gazefollow
+dataloader.train.batch_size = ins_per_iter // num_gpu
+dataloader.train.num_workers = dataloader.val.num_workers = 14
+dataloader.train.distributed = num_gpu > 1
+dataloader.train.rand_rotate = 0.5
+dataloader.train.rand_lsj = 0.5
+dataloader.train.input_size = dataloader.val.input_size = 434
+dataloader.train.mask_scene = True
+dataloader.train.mask_prob = 0.5
+dataloader.train.mask_size = dataloader.train.input_size // 14
+dataloader.train.max_scene_patches_ratio = 0.5
+dataloader.val.batch_size = 64
+dataloader.val.distributed = False
+# train
+train.init_checkpoint = "pretrained/dinov2_small.pth"
+train.output_dir = join("./output", basename(__file__).split(".")[0])
+train.max_iter = len_dataset * num_epoch // ins_per_iter
+train.log_period = len_dataset // (ins_per_iter * 100)
+train.checkpointer.max_to_keep = 10
+train.checkpointer.period = len_dataset // ins_per_iter
+train.seed = 0
+# optimizer
+optimizer.lr = 1e-4
+optimizer.betas = (0.9, 0.99)
+lr_multiplier.scheduler.typ = "cosine"
+lr_multiplier.scheduler.start_value = 1
+lr_multiplier.scheduler.end_value = 0.1
+lr_multiplier.warmup_length = 1e-2
+# model
+model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+model.pam.name = "PatchPAM"
+model.pam.embed_dim = 8
+model.pam.patch_size = 14
+model.backbone.name = "dinov2_small"
+model.backbone.return_softmax_attn = True
+model.backbone.out_attn = [2, 5, 8, 11]
+model.backbone.use_cls_token = True
+model.backbone.use_mask_token = True
+model.regressor.name = "UpSampleConv"
+model.regressor.in_channel = 24
+model.regressor.use_conv = False
+model.regressor.dim = 24
+model.regressor.deconv_cfgs = [
+ dict(
+ in_channels=24,
+ out_channels=16,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=16,
+ out_channels=8,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=8,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+]
+model.regressor.feat_type = "attn"
+model.classifier.name = "SimpleMlp"
+model.classifier.in_channel = 384
+model.criterion.aux_weight = 100
+model.criterion.aux_head_thres = 0.05
+model.criterion.use_focal_loss = True
+model.device = "cuda"
diff --git a/configs/gazefollow_518.py b/configs/gazefollow_518.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9cc1c86ad0a424c7e0953cb28646419922bd903
--- /dev/null
+++ b/configs/gazefollow_518.py
@@ -0,0 +1,86 @@
+from .common.dataloader import dataloader
+from .common.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .common.train import train
+from os.path import join, basename
+from torch.cuda import device_count
+
+
+num_gpu = device_count()
+ins_per_iter = 32
+len_dataset = 126000
+num_epoch = 1
+# dataloader
+dataloader = dataloader.gazefollow
+dataloader.train.batch_size = ins_per_iter // num_gpu
+dataloader.train.num_workers = dataloader.val.num_workers = 14
+dataloader.train.distributed = num_gpu > 1
+dataloader.train.rand_rotate = 0.5
+dataloader.train.rand_lsj = 0.5
+dataloader.train.input_size = dataloader.val.input_size = 518
+dataloader.train.mask_scene = True
+dataloader.train.mask_prob = 0.5
+dataloader.train.mask_size = dataloader.train.input_size // 14
+dataloader.train.max_scene_patches_ratio = 0.5
+dataloader.val.batch_size = 32
+dataloader.val.distributed = False
+# train
+train.init_checkpoint = "output/gazefollow/model_final.pth"
+train.output_dir = join("./output", basename(__file__).split(".")[0])
+train.max_iter = len_dataset * num_epoch // ins_per_iter
+train.log_period = len_dataset // (ins_per_iter * 100)
+train.checkpointer.max_to_keep = 10
+train.checkpointer.period = len_dataset // ins_per_iter
+train.seed = 0
+# optimizer
+optimizer.lr = 1e-5
+optimizer.betas = (0.9, 0.99)
+lr_multiplier.scheduler.typ = "cosine"
+lr_multiplier.scheduler.start_value = 1
+lr_multiplier.scheduler.end_value = 0.1
+lr_multiplier.warmup_length = 1e-2
+# model
+model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+model.pam.name = "PatchPAM"
+model.pam.embed_dim = 8
+model.pam.patch_size = 14
+model.backbone.name = "dinov2_small"
+model.backbone.return_softmax_attn = True
+model.backbone.out_attn = [2, 5, 8, 11]
+model.backbone.use_cls_token = True
+model.backbone.use_mask_token = True
+model.regressor.name = "UpSampleConv"
+model.regressor.in_channel = 24
+model.regressor.use_conv = False
+model.regressor.dim = 24
+model.regressor.deconv_cfgs = [
+ dict(
+ in_channels=24,
+ out_channels=16,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=16,
+ out_channels=8,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=8,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+]
+model.regressor.feat_type = "attn"
+model.classifier.name = "SimpleMlp"
+model.classifier.in_channel = 384
+model.criterion.aux_weight = 0
+model.criterion.aux_head_thres = 0.05
+model.criterion.use_focal_loss = True
+model.device = "cuda"
diff --git a/configs/videoattentiontarget.py b/configs/videoattentiontarget.py
new file mode 100644
index 0000000000000000000000000000000000000000..67fd6d15ecc4077b0a00d7bcbef516e8dfcc5f79
--- /dev/null
+++ b/configs/videoattentiontarget.py
@@ -0,0 +1,90 @@
+from .common.dataloader import dataloader
+from .common.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .common.train import train
+from os.path import join, basename
+from torch.cuda import device_count
+
+
+num_gpu = device_count()
+ins_per_iter = 4
+len_dataset = 4400
+num_epoch = 1
+# dataloader
+dataloader = dataloader.video_attention_target_video
+dataloader.train.batch_size = ins_per_iter // num_gpu
+dataloader.train.num_workers = dataloader.val.num_workers = 14
+dataloader.train.distributed = num_gpu > 1
+dataloader.train.rand_rotate = 0.5
+dataloader.train.rand_lsj = 0.5
+dataloader.train.input_size = dataloader.val.input_size = 518
+dataloader.train.mask_scene = True
+dataloader.train.mask_prob = 0.5
+dataloader.train.mask_size = dataloader.train.input_size // 14
+dataloader.train.max_scene_patches_ratio = 0.5
+dataloader.train.seq_len = 8
+dataloader.val.quant_labelmap = False
+dataloader.val.seq_len = 8
+dataloader.val.batch_size = 4
+dataloader.val.distributed = False
+# train
+train.init_checkpoint = "output/gazefollow_518/model_final.pth"
+train.output_dir = join("./output", basename(__file__).split(".")[0])
+train.max_iter = len_dataset * num_epoch // ins_per_iter
+train.log_period = len_dataset // (ins_per_iter * 100)
+train.checkpointer.max_to_keep = 100
+train.checkpointer.period = len_dataset // ins_per_iter
+train.seed = 0
+# optimizer
+optimizer.lr = 1e-6
+# optimizer.params.lr_factor_func.backbone_multiplier = 0.1
+# optimizer.params.lr_factor_func.pam_lr_decay = 0.1
+lr_multiplier.scheduler.values = [1.0]
+lr_multiplier.scheduler.milestones = []
+lr_multiplier.scheduler.num_updates = train.max_iter
+lr_multiplier.warmup_length = 0
+# model
+model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+model.pam.name = "PatchPAM"
+model.pam.embed_dim = 8
+model.pam.patch_size = 14
+model.backbone.name = "dinov2_small"
+model.backbone.return_softmax_attn = True
+model.backbone.out_attn = [2, 5, 8, 11]
+model.backbone.use_cls_token = True
+model.backbone.use_mask_token = True
+model.regressor.name = "UpSampleConv"
+model.regressor.in_channel = 24
+model.regressor.use_conv = False
+model.regressor.dim = 24
+model.regressor.deconv_cfgs = [
+ dict(
+ in_channels=24,
+ out_channels=16,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=16,
+ out_channels=8,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ dict(
+ in_channels=8,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+]
+model.regressor.feat_type = "attn"
+model.classifier.name = "SimpleMlp"
+model.classifier.in_channel = 384
+model.criterion.aux_weight = 0
+model.criterion.aux_head_thres = 0.05
+model.criterion.use_focal_loss = True
+model.device = "cuda"
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e7df832158379ba2b7820419fd9eb24be240503
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,5 @@
+from .video_attention_target import VideoAttentionTarget
+from .video_attention_target_video import VideoAttentionTargetVideo, video_collate
+from .gazefollow import GazeFollow
+from .data_utils import get_transform
+from .masking import MaskGenerator
diff --git a/data/augmentation.py b/data/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a195aa8e4f55112d8b4e03dde923be84aa0d63
--- /dev/null
+++ b/data/augmentation.py
@@ -0,0 +1,312 @@
+import math
+from typing import Tuple, List
+import numpy as np
+from PIL import Image, ImageOps
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+
+
+class Augmentation:
+ def __init__(self, p: float) -> None:
+ self.p = p
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ raise NotImplementedError
+
+ def __call__(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ if np.random.random_sample() < self.p:
+ return self.transform(image, bbox, gaze, head_mask, size)
+ return image, bbox, gaze, head_mask, size
+
+
+class AugmentationList:
+ def __init__(self, augmentations: List[Augmentation]) -> None:
+ self.augmentations = augmentations
+
+ def __call__(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ for aug in self.augmentations:
+ image, bbox, gaze, head_mask, size = aug(image, bbox, gaze, head_mask, size)
+ return image, bbox, gaze, head_mask, size
+
+
+class BoxJitter(Augmentation):
+ # Jitter (expansion-only) bounding box size
+ def __init__(self, p: float, expansion: float = 0.2) -> None:
+ super().__init__(p)
+ self.expansion = expansion
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ x_min, y_min, x_max, y_max = bbox
+ width, height = size
+ k = np.random.random_sample() * self.expansion
+ x_min = np.clip(x_min - k * abs(x_max - x_min), 0, width - 1)
+ y_min = np.clip(y_min - k * abs(y_max - y_min), 0, height - 1)
+ x_max = np.clip(x_max + k * abs(x_max - x_min), 0, width - 1)
+ y_max = np.clip(y_max + k * abs(y_max - y_min), 0, height - 1)
+ return image, (x_min, y_min, x_max, y_max), gaze, head_mask, size
+
+
+class RandomCrop(Augmentation):
+ def __init__(self, p: float) -> None:
+ super().__init__(p)
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ x_min, y_min, x_max, y_max = bbox
+ gaze_x, gaze_y = gaze
+ width, height = size
+ # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
+ crop_x_min = np.min([gaze_x * width, x_min, x_max])
+ crop_y_min = np.min([gaze_y * height, y_min, y_max])
+ crop_x_max = np.max([gaze_x * width, x_min, x_max])
+ crop_y_max = np.max([gaze_y * height, y_min, y_max])
+
+        # Randomly select a top-left corner for the crop
+ crop_x_min = np.random.uniform(0, crop_x_min)
+ crop_y_min = np.random.uniform(0, crop_y_min)
+
+ # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
+ crop_width_min = crop_x_max - crop_x_min
+ crop_height_min = crop_y_max - crop_y_min
+ crop_width_max = width - crop_x_min
+ crop_height_max = height - crop_y_min
+
+ # Randomly select a width and a height
+ crop_width = np.random.uniform(crop_width_min, crop_width_max)
+ crop_height = np.random.uniform(crop_height_min, crop_height_max)
+
+ # Round to integers
+ crop_y_min, crop_x_min, crop_height, crop_width = map(
+ int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
+ )
+
+ # Crop it
+ image = TF.crop(image, crop_y_min, crop_x_min, crop_height, crop_width)
+ head_mask = TF.crop(head_mask, crop_y_min, crop_x_min, crop_height, crop_width)
+
+ # convert coordinates into the cropped frame
+ x_min, y_min, x_max, y_max = (
+ x_min - crop_x_min,
+ y_min - crop_y_min,
+ x_max - crop_x_min,
+ y_max - crop_y_min,
+ )
+
+ gaze_x = (gaze_x * width - crop_x_min) / float(crop_width)
+ gaze_y = (gaze_y * height - crop_y_min) / float(crop_height)
+
+ return (
+ image,
+ (x_min, y_min, x_max, y_max),
+ (gaze_x, gaze_y),
+ head_mask,
+ (crop_width, crop_height),
+ )
+
+
+class RandomFlip(Augmentation):
+ def __init__(self, p: float) -> None:
+ super().__init__(p)
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ image = image.transpose(Image.FLIP_LEFT_RIGHT)
+ head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
+ x_min, y_min, x_max, y_max = bbox
+ x_min, x_max = size[0] - x_max, size[0] - x_min
+ gaze_x, gaze_y = 1 - gaze[0], gaze[1]
+ return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size
+
+
+class RandomRotate(Augmentation):
+ def __init__(
+ self, p: float, max_angle: int = 20, resample: int = Image.BILINEAR
+ ) -> None:
+ super().__init__(p)
+ self.max_angle = max_angle
+ self.resample = resample
+
+ def _random_rotation_matrix(self):
+ angle = (2 * np.random.random_sample() - 1) * self.max_angle
+ angle = -math.radians(angle)
+ return [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ @staticmethod
+ def _transform(x, y, matrix):
+ return (
+ matrix[0] * x + matrix[1] * y + matrix[2],
+ matrix[3] * x + matrix[4] * y + matrix[5],
+ )
+
+ @staticmethod
+ def _inv_transform(x, y, matrix):
+ x, y = x - matrix[2], y - matrix[5]
+ return matrix[0] * x + matrix[3] * y, matrix[1] * x + matrix[4] * y
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ x_min, y_min, x_max, y_max = bbox
+ gaze_x, gaze_y = gaze
+ width, height = size
+ rot_mat = self._random_rotation_matrix()
+
+ # Calculate offsets
+ rot_center = (width / 2.0, height / 2.0)
+ rot_mat[2], rot_mat[5] = self._transform(
+ -rot_center[0], -rot_center[1], rot_mat
+ )
+ rot_mat[2] += rot_center[0]
+ rot_mat[5] += rot_center[1]
+ xx = []
+ yy = []
+ for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
+ x, y = self._transform(x, y, rot_mat)
+ xx.append(x)
+ yy.append(y)
+ nw = math.ceil(max(xx)) - math.floor(min(xx))
+ nh = math.ceil(max(yy)) - math.floor(min(yy))
+ rot_mat[2], rot_mat[5] = self._transform(
+ -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
+ )
+
+ image = image.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)
+ head_mask = head_mask.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)
+
+ xx = []
+ yy = []
+ for x, y in (
+ (x_min, y_min),
+ (x_min, y_max),
+ (x_max, y_min),
+ (x_max, y_max),
+ ):
+ x, y = self._inv_transform(x, y, rot_mat)
+ xx.append(x)
+ yy.append(y)
+ x_max, x_min = min(max(xx), nw), max(min(xx), 0)
+ y_max, y_min = min(max(yy), nh), max(min(yy), 0)
+
+ gaze_x, gaze_y = self._inv_transform(gaze_x * width, gaze_y * height, rot_mat)
+ gaze_x = max(min(gaze_x / nw, 1), 0)
+ gaze_y = max(min(gaze_y / nh, 1), 0)
+
+ return (
+ image,
+ (x_min, y_min, x_max, y_max),
+ (gaze_x, gaze_y),
+ head_mask,
+ (nw, nh),
+ )
+
+
+class ColorJitter(Augmentation):
+ def __init__(
+ self,
+ p: float,
+ brightness: float = 0.4,
+ contrast: float = 0.4,
+ saturation: float = 0.2,
+ hue: float = 0.1,
+ ) -> None:
+ super().__init__(p)
+ self.color_jitter = transforms.ColorJitter(
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
+ )
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ return self.color_jitter(image), bbox, gaze, head_mask, size
+
+
+class RandomLSJ(Augmentation):
+ def __init__(self, p: float, min_scale: float = 0.1) -> None:
+ super().__init__(p)
+ self.min_scale = min_scale
+
+ def transform(
+ self,
+ image: Image,
+ bbox: Tuple[float],
+ gaze: Tuple[float],
+ head_mask: Image,
+ size: Tuple[int],
+ ):
+ x_min, y_min, x_max, y_max = bbox
+ gaze_x, gaze_y = gaze
+ width, height = size
+
+ scale = self.min_scale + np.random.random_sample() * (1 - self.min_scale)
+ nh, nw = int(height * scale), int(width * scale)
+
+ image = TF.resize(image, (nh, nw))
+ image = ImageOps.expand(image, (0, 0, width - nw, height - nh))
+ head_mask = TF.resize(head_mask, (nh, nw))
+ head_mask = ImageOps.expand(head_mask, (0, 0, width - nw, height - nh))
+
+ x_min, y_min, x_max, y_max = (
+ x_min * scale,
+ y_min * scale,
+ x_max * scale,
+ y_max * scale,
+ )
+ gaze_x, gaze_y = gaze_x * scale, gaze_y * scale
+ return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size
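+
+
+# Example composition (mirroring the dataset classes in this package):
+#   augment = AugmentationList([
+#       ColorJitter(0.5), BoxJitter(0.5), RandomCrop(0.5),
+#       RandomFlip(0.5), RandomRotate(0.0), RandomLSJ(0.0),
+#   ])
+#   image, bbox, gaze, head_mask, size = augment(image, bbox, gaze, head_mask, size)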
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed3e3a6300703df8c05e6da5aa835deca6f56fef
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,181 @@
+from typing import Tuple
+import torch
+from torchvision import transforms
+import numpy as np
+import pandas as pd
+
+
+def to_numpy(tensor: torch.Tensor):
+ if torch.is_tensor(tensor):
+ return tensor.cpu().detach().numpy()
+ elif type(tensor).__module__ != "numpy":
+ raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
+ return tensor
+
+
+def to_torch(ndarray: np.ndarray):
+ if type(ndarray).__module__ == "numpy":
+ return torch.from_numpy(ndarray)
+ elif not torch.is_tensor(ndarray):
+ raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
+ return ndarray
+
+
+def get_head_box_channel(
+ x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False
+):
+ head_box = (
+ np.array([x_min / width, y_min / height, x_max / width, y_max / height])
+ * resolution
+ )
+ int_head_box = head_box.astype(int)
+ int_head_box = np.clip(int_head_box, 0, resolution - 1)
+ if int_head_box[0] == int_head_box[2]:
+ if int_head_box[0] == 0:
+ int_head_box[2] = 1
+ elif int_head_box[2] == resolution - 1:
+ int_head_box[0] = resolution - 2
+ elif abs(head_box[2] - int_head_box[2]) > abs(head_box[0] - int_head_box[0]):
+ int_head_box[2] += 1
+ else:
+ int_head_box[0] -= 1
+ if int_head_box[1] == int_head_box[3]:
+ if int_head_box[1] == 0:
+ int_head_box[3] = 1
+ elif int_head_box[3] == resolution - 1:
+ int_head_box[1] = resolution - 2
+ elif abs(head_box[3] - int_head_box[3]) > abs(head_box[1] - int_head_box[1]):
+ int_head_box[3] += 1
+ else:
+ int_head_box[1] -= 1
+ head_box = int_head_box
+ if coordconv:
+ unit = np.array(range(0, resolution), dtype=np.float32)
+ head_channel = []
+ for i in unit:
+ head_channel.append([unit + i])
+ head_channel = np.squeeze(np.array(head_channel)) / float(np.max(head_channel))
+ head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 0
+ else:
+ head_channel = np.zeros((resolution, resolution), dtype=np.float32)
+ head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 1
+ head_channel = torch.from_numpy(head_channel)
+ return head_channel
+
+
+def draw_labelmap(img, pt, sigma, type="Gaussian"):
+ # Draw a 2D gaussian
+    # Adapted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
+ img = to_numpy(img)
+
+ # Check that any part of the gaussian is in-bounds
+ size = int(6 * sigma + 1)
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
+ br = [ul[0] + size, ul[1] + size]
+ if ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0:
+ # If not, just return the image as is
+ return to_torch(img)
+
+ # Generate gaussian
+ x = np.arange(0, size, 1, float)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+    # The gaussian is not normalized; we want the center value to equal 1
+ if type == "Gaussian":
+ g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))
+ elif type == "Cauchy":
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma**2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0] : img_y[1], img_x[0] : img_x[1]] += g[g_y[0] : g_y[1], g_x[0] : g_x[1]]
+ # img = img / np.max(img)
+ return to_torch(img)
+
+
+def draw_labelmap_no_quant(img, pt, sigma, type="Gaussian"):
+ img = to_numpy(img)
+ shape = img.shape
+ x = np.arange(shape[0])
+ y = np.arange(shape[1])
+ xx, yy = np.meshgrid(x, y, indexing="ij")
+ dist_matrix = (yy - float(pt[0])) ** 2 + (xx - float(pt[1])) ** 2
+ if type == "Gaussian":
+ g = np.exp(-dist_matrix / (2 * sigma**2))
+ elif type == "Cauchy":
+ g = sigma / ((dist_matrix + sigma**2) ** 1.5)
+ g[dist_matrix > 10 * sigma**2] = 0
+ img += g
+ # img = img / np.max(img)
+ return to_torch(img)
+
+
+def multi_hot_targets(gaze_pts, out_res):
+ w, h = out_res
+ target_map = np.zeros((h, w))
+ for p in gaze_pts:
+ if p[0] >= 0:
+ x, y = map(int, [p[0] * float(w), p[1] * float(h)])
+ x = min(x, w - 1)
+ y = min(y, h - 1)
+ target_map[y, x] = 1
+ return target_map
+
+
+def get_cone(tgt, src, wh, theta=150):
+ eye = src * wh
+ gaze = tgt * wh
+
+ pixel_mat = np.stack(
+ np.meshgrid(np.arange(wh[0]), np.arange(wh[1])),
+ -1,
+ )
+
+ dot_prod = np.sum((pixel_mat - eye) * (gaze - eye), axis=-1)
+ gaze_vector_norm = np.sqrt(np.sum((gaze - eye) ** 2))
+ pixel_mat_norm = np.sqrt(np.sum((pixel_mat - eye) ** 2, axis=-1))
+
+ gaze_cones = dot_prod / (gaze_vector_norm * pixel_mat_norm)
+ gaze_cones = np.nan_to_num(gaze_cones, nan=1)
+
+ theta = theta * (np.pi / 180)
+ beta = np.arccos(gaze_cones)
+ # Create mask where true if beta is less than theta/2
+ pixel_mat_presence = beta < (theta / 2)
+
+ # Zero out values outside the gaze cone
+ gaze_cones[~pixel_mat_presence] = 0
+ gaze_cones = np.clip(gaze_cones, 0, None)
+
+ return torch.from_numpy(gaze_cones).unsqueeze(0).float()
+
+
+def get_transform(
+ input_resolution: int, mean: Tuple[int, int, int], std: Tuple[int, int, int]
+):
+ return transforms.Compose(
+ [
+ transforms.Resize((input_resolution, input_resolution)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std),
+ ]
+ )
+
+
+def smooth_by_conv(window_size, df, col):
+ padded_track = pd.concat(
+ [
+ pd.DataFrame([[df.iloc[0][col]]] * (window_size // 2), columns=[0]),
+ df[col],
+ pd.DataFrame([[df.iloc[-1][col]]] * (window_size // 2), columns=[0]),
+ ]
+ )
+ smoothed_signals = np.convolve(
+ padded_track.squeeze(), np.ones(window_size) / window_size, mode="valid"
+ )
+ return smoothed_signals
diff --git a/data/gazefollow.py b/data/gazefollow.py
new file mode 100644
index 0000000000000000000000000000000000000000..68490853e991281d73ba1cbfa9053e2aab96aa25
--- /dev/null
+++ b/data/gazefollow.py
@@ -0,0 +1,295 @@
+from os import path as osp
+from typing import Callable, Optional
+
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import functional as TF
+from PIL import Image
+import pandas as pd
+
+from . import augmentation
+from .masking import MaskGenerator
+from . import data_utils as utils
+
+
+class GazeFollow(Dataset):
+ def __init__(
+ self,
+ image_root: str,
+ anno_root: str,
+ head_root: str,
+ transform: Callable,
+ input_size: int,
+ output_size: int,
+ quant_labelmap: bool = True,
+ is_train: bool = True,
+ *,
+ mask_generator: Optional[MaskGenerator] = None,
+ bbox_jitter: float = 0.5,
+ rand_crop: float = 0.5,
+ rand_flip: float = 0.5,
+ color_jitter: float = 0.5,
+ rand_rotate: float = 0.0,
+ rand_lsj: float = 0.0,
+ ):
+ if is_train:
+ column_names = [
+ "path",
+ "idx",
+ "body_bbox_x",
+ "body_bbox_y",
+ "body_bbox_w",
+ "body_bbox_h",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "bbox_x_min",
+ "bbox_y_min",
+ "bbox_x_max",
+ "bbox_y_max",
+ "inout",
+ "meta0",
+ "meta1",
+ ]
+ df = pd.read_csv(
+ anno_root,
+ sep=",",
+ names=column_names,
+ index_col=False,
+ encoding="utf-8-sig",
+ )
+ df = df[
+ df["inout"] != -1
+            ]  # only use "in" or "out" gaze (-1 is invalid, 0 is out-of-frame gaze)
+ df.reset_index(inplace=True)
+ self.y_train = df[
+ [
+ "bbox_x_min",
+ "bbox_y_min",
+ "bbox_x_max",
+ "bbox_y_max",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "inout",
+ ]
+ ]
+ self.X_train = df["path"]
+ self.length = len(df)
+ else:
+ column_names = [
+ "path",
+ "idx",
+ "body_bbox_x",
+ "body_bbox_y",
+ "body_bbox_w",
+ "body_bbox_h",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "bbox_x_min",
+ "bbox_y_min",
+ "bbox_x_max",
+ "bbox_y_max",
+ "meta0",
+ "meta1",
+ ]
+ df = pd.read_csv(
+ anno_root,
+ sep=",",
+ names=column_names,
+ index_col=False,
+ encoding="utf-8-sig",
+ )
+ df = df[
+ [
+ "path",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "bbox_x_min",
+ "bbox_y_min",
+ "bbox_x_max",
+ "bbox_y_max",
+ ]
+ ].groupby(["path", "eye_x"])
+ self.keys = list(df.groups.keys())
+ self.X_test = df
+ self.length = len(self.keys)
+
+ self.data_dir = image_root
+ self.head_dir = head_root
+ self.transform = transform
+ self.is_train = is_train
+
+ self.input_size = input_size
+ self.output_size = output_size
+
+ self.draw_labelmap = (
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
+ )
+
+ if self.is_train:
+ ## data augmentation
+ self.augment = augmentation.AugmentationList(
+ [
+ augmentation.ColorJitter(color_jitter),
+ augmentation.BoxJitter(bbox_jitter),
+ augmentation.RandomCrop(rand_crop),
+ augmentation.RandomFlip(rand_flip),
+ augmentation.RandomRotate(rand_rotate),
+ augmentation.RandomLSJ(rand_lsj),
+ ]
+ )
+
+ self.mask_generator = mask_generator
+
+ def __getitem__(self, index):
+ if not self.is_train:
+ g = self.X_test.get_group(self.keys[index])
+ cont_gaze = []
+ for _, row in g.iterrows():
+ path = row["path"]
+ x_min = row["bbox_x_min"]
+ y_min = row["bbox_y_min"]
+ x_max = row["bbox_x_max"]
+ y_max = row["bbox_y_max"]
+ eye_x = row["eye_x"]
+ eye_y = row["eye_y"]
+ gaze_x = row["gaze_x"]
+ gaze_y = row["gaze_y"]
+ cont_gaze.append(
+ [gaze_x, gaze_y]
+                )  # all ground-truth gaze points are stacked up
+ for _ in range(len(cont_gaze), 20):
+ cont_gaze.append(
+ [-1, -1]
+ ) # pad dummy gaze to match size for batch processing
+ cont_gaze = torch.FloatTensor(cont_gaze)
+ gaze_inside = True # always consider test samples as inside
+ else:
+ path = self.X_train.iloc[index]
+ (
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ eye_x,
+ eye_y,
+ gaze_x,
+ gaze_y,
+ inout,
+ ) = self.y_train.iloc[index]
+ gaze_inside = bool(inout)
+
+ img = Image.open(osp.join(self.data_dir, path))
+ img = img.convert("RGB")
+ head_mask = Image.open(osp.join(self.head_dir, path))
+ width, height = img.size
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
+ if x_max < x_min:
+ x_min, x_max = x_max, x_min
+ if y_max < y_min:
+ y_min, y_max = y_max, y_min
+ # expand face bbox a bit
+ k = 0.1
+ x_min = max(x_min - k * abs(x_max - x_min), 0)
+ y_min = max(y_min - k * abs(y_max - y_min), 0)
+ x_max = min(x_max + k * abs(x_max - x_min), width - 1)
+ y_max = min(y_max + k * abs(y_max - y_min), height - 1)
+
+ if self.is_train:
+ img, bbox, gaze, head_mask, size = self.augment(
+ img,
+ (x_min, y_min, x_max, y_max),
+ (gaze_x, gaze_y),
+ head_mask,
+ (width, height),
+ )
+ x_min, y_min, x_max, y_max = bbox
+ gaze_x, gaze_y = gaze
+ width, height = size
+
+ head_channel = utils.get_head_box_channel(
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ width,
+ height,
+ resolution=self.input_size,
+ coordconv=False,
+ ).unsqueeze(0)
+
+ if self.is_train and self.mask_generator is not None:
+ image_mask = self.mask_generator(
+ x_min / width,
+ y_min / height,
+ x_max / width,
+ y_max / height,
+ head_channel,
+ )
+
+ if self.transform is not None:
+ img = self.transform(img)
+ head_mask = TF.to_tensor(
+ TF.resize(head_mask, (self.input_size, self.input_size))
+ )
+
+ # generate the heat map used for deconv prediction
+ gaze_heatmap = torch.zeros(
+ self.output_size, self.output_size
+ ) # set the size of the output
+ if not self.is_train: # aggregated heatmap
+ num_valid = 0
+ for gaze_x, gaze_y in cont_gaze:
+ if gaze_x != -1:
+ num_valid += 1
+ gaze_heatmap += self.draw_labelmap(
+ torch.zeros(self.output_size, self.output_size),
+ [gaze_x * self.output_size, gaze_y * self.output_size],
+ 3,
+ type="Gaussian",
+ )
+ gaze_heatmap /= num_valid
+ else:
+ # if gaze_inside:
+ gaze_heatmap = self.draw_labelmap(
+ gaze_heatmap,
+ [gaze_x * self.output_size, gaze_y * self.output_size],
+ 3,
+ type="Gaussian",
+ )
+
+ imsize = torch.IntTensor([width, height])
+
+ if self.is_train:
+ out_dict = {
+ "images": img,
+ "head_channels": head_channel,
+ "heatmaps": gaze_heatmap,
+ "gazes": torch.FloatTensor([gaze_x, gaze_y]),
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
+ "head_masks": head_mask,
+ "imsize": imsize,
+ }
+ if self.mask_generator is not None:
+ out_dict["image_masks"] = image_mask
+ return out_dict
+ else:
+ return {
+ "images": img,
+ "head_channels": head_channel,
+ "heatmaps": gaze_heatmap,
+ "gazes": cont_gaze,
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
+ "head_masks": head_mask,
+ "imsize": imsize,
+ }
+
+ def __len__(self):
+ return self.length
diff --git a/data/masking.py b/data/masking.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9736d9f231bec5767fcf2ed502178553ad379a3
--- /dev/null
+++ b/data/masking.py
@@ -0,0 +1,175 @@
+import random
+import math
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+class SceneMaskGenerator:
+ def __init__(
+ self,
+ input_size,
+ min_num_patches=16,
+ max_num_patches_ratio=0.5,
+ min_aspect=0.3,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.input_size = input_size
+ self.num_patches = input_size[0] * input_size[1]
+
+ self.min_num_patches = min_num_patches
+ self.max_num_patches = max_num_patches_ratio * self.num_patches
+
+ self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
+
+ def _mask(self, mask, max_mask_patches):
+ delta = 0
+ for _ in range(4):
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ height, width = self.input_size
+ if w < width and h < height:
+ top = random.randint(0, height - h)
+ left = random.randint(0, width - w)
+
+ num_masked = mask[top : top + h, left : left + w].sum()
+                # Only accept the box if it adds new masked patches within the remaining budget
+ if 0 < h * w - num_masked <= max_mask_patches:
+ mask[top : top + h, left : left + w] = 1
+ delta = h * w - num_masked
+ break
+ return delta
+
+ def __call__(self, head_mask):
+ mask = np.zeros(shape=self.input_size, dtype=bool)
+ mask_count = 0
+ num_masking_patches = random.uniform(self.min_num_patches, self.max_num_patches)
+ while mask_count < num_masking_patches:
+ max_mask_patches = num_masking_patches - mask_count
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ mask = torch.from_numpy(mask).unsqueeze(0)
+ head_mask = (
+ F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5
+ )
+ return torch.logical_and(mask, head_mask).squeeze(0)
+
+
+class HeadMaskGenerator:
+ def __init__(
+ self,
+ input_size,
+ min_num_patches=4,
+ max_num_patches_ratio=0.5,
+ min_aspect=0.3,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.input_size = input_size
+ self.num_patches = input_size[0] * input_size[1]
+
+ self.min_num_patches = min_num_patches
+ self.max_num_patches_ratio = max_num_patches_ratio
+
+ self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
+
+ def __call__(
+ self,
+ x_min,
+ y_min,
+ x_max,
+ y_max, # coords in [0,1]
+ ):
+ height = math.floor((y_max - y_min) * self.input_size[0])
+ width = math.floor((x_max - x_min) * self.input_size[1])
+ origin_area = width * height
+ if origin_area < self.min_num_patches:
+ return torch.zeros(size=self.input_size, dtype=bool)
+
+ target_area = random.uniform(
+ self.min_num_patches, self.max_num_patches_ratio * origin_area
+ )
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = min(int(round(math.sqrt(target_area * aspect_ratio))), height)
+ w = min(int(round(math.sqrt(target_area / aspect_ratio))), width)
+ top = random.randint(0, height - h) + int(y_min * self.input_size[0])
+ left = random.randint(0, width - w) + int(x_min * self.input_size[1])
+ mask = torch.zeros(size=self.input_size, dtype=bool)
+ mask[top : top + h, left : left + w] = True
+ return mask
+
+
+class MaskGenerator:
+ def __init__(
+ self,
+ input_size,
+ mask_scene: bool = False,
+ mask_head: bool = False,
+ min_scene_patches=16,
+ max_scene_patches_ratio=0.5,
+ min_head_patches=4,
+ max_head_patches_ratio=0.5,
+ min_aspect=0.3,
+ mask_prob=0.2,
+ head_prob=0.2,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.input_size = input_size
+ if mask_scene:
+ self.scene_mask_generator = SceneMaskGenerator(
+ input_size, min_scene_patches, max_scene_patches_ratio, min_aspect
+ )
+ else:
+ self.scene_mask_generator = None
+
+ if mask_head:
+ self.head_mask_generator = HeadMaskGenerator(
+ input_size, min_head_patches, max_head_patches_ratio, min_aspect
+ )
+ else:
+ self.head_mask_generator = None
+
+ self.no_mask = not (mask_scene or mask_head)
+ self.mask_head = mask_head and not mask_scene
+ self.mask_scene = mask_scene and not mask_head
+ self.scene_prob = mask_prob
+ self.head_prob = head_prob
+
+ def __call__(
+ self,
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ head_mask,
+ ):
+ mask_scene = random.random() < self.scene_prob
+ mask_head = random.random() < self.head_prob
+ no_mask = (
+ self.no_mask
+ or (self.mask_head and not mask_head)
+ or (self.mask_scene and not mask_scene)
+ or not (mask_scene or mask_head)
+ )
+ if no_mask:
+ return torch.zeros(size=self.input_size, dtype=bool)
+ if self.mask_scene:
+ return self.scene_mask_generator(head_mask)
+ if self.mask_head:
+ return self.head_mask_generator(x_min, y_min, x_max, y_max)
+ if mask_head and mask_scene:
+ return torch.logical_or(
+ self.scene_mask_generator(head_mask),
+ self.head_mask_generator(x_min, y_min, x_max, y_max),
+ )
+ elif mask_head:
+ return self.head_mask_generator(x_min, y_min, x_max, y_max)
+ return self.scene_mask_generator(head_mask)
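+
+
+# Example (roughly matching configs/gazefollow.py): randomly mask scene patches
+#   mask_generator = MaskGenerator(
+#       input_size=31, mask_scene=True, max_scene_patches_ratio=0.5, mask_prob=0.5,
+#   )
+#   image_mask = mask_generator(x_min / w, y_min / h, x_max / w, y_max / h, head_channel)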
diff --git a/data/video_attention_target.py b/data/video_attention_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce4b50133b3fc26d62ee1950802670da5d474817
--- /dev/null
+++ b/data/video_attention_target.py
@@ -0,0 +1,228 @@
+import glob
+from typing import Callable, Optional
+from os import path as osp
+
+import torch
+from torch.utils.data.dataset import Dataset
+import torchvision.transforms.functional as TF
+import numpy as np
+import pandas as pd
+from PIL import Image
+
+from . import augmentation
+from . import data_utils as utils
+from .masking import MaskGenerator
+
+
+class VideoAttentionTarget(Dataset):
+ def __init__(
+ self,
+ image_root: str,
+ anno_root: str,
+ head_root: str,
+ transform: Callable,
+ input_size: int,
+ output_size: int,
+ quant_labelmap: bool = True,
+ is_train: bool = True,
+ *,
+ mask_generator: Optional[MaskGenerator] = None,
+ bbox_jitter: float = 0.5,
+ rand_crop: float = 0.5,
+ rand_flip: float = 0.5,
+ color_jitter: float = 0.5,
+ rand_rotate: float = 0.0,
+ rand_lsj: float = 0.0,
+ ):
+ frames = []
+ for show_dir in glob.glob(osp.join(anno_root, "*")):
+ for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
+ df = pd.read_csv(
+ sequence_path,
+ header=None,
+ index_col=False,
+ names=[
+ "path",
+ "x_min",
+ "y_min",
+ "x_max",
+ "y_max",
+ "gaze_x",
+ "gaze_y",
+ ],
+ )
+
+ show_name = sequence_path.split("/")[-3]
+ clip = sequence_path.split("/")[-2]
+ df["path"] = df["path"].apply(
+ lambda path: osp.join(show_name, clip, path)
+ )
+ # Add two columns for the bbox center
+ df["eye_x"] = (df["x_min"] + df["x_max"]) / 2
+ df["eye_y"] = (df["y_min"] + df["y_max"]) / 2
+ df = df.sample(frac=0.2, random_state=42)
+ frames.extend(df.values.tolist())
+
+ df = pd.DataFrame(
+ frames,
+ columns=[
+ "path",
+ "x_min",
+ "y_min",
+ "x_max",
+ "y_max",
+ "gaze_x",
+ "gaze_y",
+ "eye_x",
+ "eye_y",
+ ],
+ )
+ # Drop rows with invalid bboxes
+ coords = torch.tensor(
+ np.array(
+ (
+ df["x_min"].values,
+ df["y_min"].values,
+ df["x_max"].values,
+ df["y_max"].values,
+ )
+ ).transpose(1, 0)
+ )
+ valid_bboxes = (coords[:, 2:] >= coords[:, :2]).all(dim=1)
+ df = df.loc[valid_bboxes.tolist(), :]
+ df.reset_index(inplace=True)
+ self.df = df
+ self.length = len(df)
+
+ self.data_dir = image_root
+ self.head_dir = head_root
+ self.transform = transform
+ self.draw_labelmap = (
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
+ )
+ self.is_train = is_train
+
+ self.input_size = input_size
+ self.output_size = output_size
+
+ if self.is_train:
+ ## data augmentation
+ self.augment = augmentation.AugmentationList(
+ [
+ augmentation.ColorJitter(color_jitter),
+ augmentation.BoxJitter(bbox_jitter),
+ augmentation.RandomCrop(rand_crop),
+ augmentation.RandomFlip(rand_flip),
+ augmentation.RandomRotate(rand_rotate),
+ augmentation.RandomLSJ(rand_lsj),
+ ]
+ )
+
+ self.mask_generator = mask_generator
+
+ def __getitem__(self, index):
+ (
+ _,
+ path,
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ gaze_x,
+ gaze_y,
+ eye_x,
+ eye_y,
+ ) = self.df.iloc[index]
+ gaze_inside = gaze_x != -1 or gaze_y != -1
+
+ img = Image.open(osp.join(self.data_dir, path))
+ img = img.convert("RGB")
+ width, height = img.size
+ # Since we finetune from weights trained on GazeFollow,
+ # we don't incorporate the auxiliary task for VAT.
+ if osp.exists(osp.join(self.head_dir, path)):
+ head_mask = Image.open(osp.join(self.head_dir, path)).resize(
+ (width, height)
+ )
+ else:
+ head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
+ if x_max < x_min:
+ x_min, x_max = x_max, x_min
+ if y_max < y_min:
+ y_min, y_max = y_max, y_min
+ gaze_x, gaze_y = gaze_x / width, gaze_y / height
+ # expand face bbox a bit
+ k = 0.1
+ x_min = max(x_min - k * abs(x_max - x_min), 0)
+ y_min = max(y_min - k * abs(y_max - y_min), 0)
+ x_max = min(x_max + k * abs(x_max - x_min), width - 1)
+ y_max = min(y_max + k * abs(y_max - y_min), height - 1)
+
+ if self.is_train:
+ img, bbox, gaze, head_mask, size = self.augment(
+ img,
+ (x_min, y_min, x_max, y_max),
+ (gaze_x, gaze_y),
+ head_mask,
+ (width, height),
+ )
+ x_min, y_min, x_max, y_max = bbox
+ gaze_x, gaze_y = gaze
+ width, height = size
+
+ head_channel = utils.get_head_box_channel(
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ width,
+ height,
+ resolution=self.input_size,
+ coordconv=False,
+ ).unsqueeze(0)
+
+ if self.is_train and self.mask_generator is not None:
+ image_mask = self.mask_generator(
+ x_min / width,
+ y_min / height,
+ x_max / width,
+ y_max / height,
+ head_channel,
+ )
+
+ if self.transform is not None:
+ img = self.transform(img)
+ head_mask = TF.to_tensor(
+ TF.resize(head_mask, (self.input_size, self.input_size))
+ )
+
+ # generate the heat map used for deconv prediction
+ gaze_heatmap = torch.zeros(
+ self.output_size, self.output_size
+ ) # set the size of the output
+
+ gaze_heatmap = self.draw_labelmap(
+ gaze_heatmap,
+ [gaze_x * self.output_size, gaze_y * self.output_size],
+ 3,
+ type="Gaussian",
+ )
+
+ imsize = torch.IntTensor([width, height])
+
+ out_dict = {
+ "images": img,
+ "head_channels": head_channel,
+ "heatmaps": gaze_heatmap,
+ "gazes": torch.FloatTensor([gaze_x, gaze_y]),
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
+ "head_masks": head_mask,
+ "imsize": imsize,
+ }
+ if self.is_train and self.mask_generator is not None:
+ out_dict["image_masks"] = image_mask
+ return out_dict
+
+ def __len__(self):
+ return self.length
diff --git a/data/video_attention_target_video.py b/data/video_attention_target_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a1d4711f91b0a76166abeaefe76feaa6329064
--- /dev/null
+++ b/data/video_attention_target_video.py
@@ -0,0 +1,464 @@
+import math
+from os import path as osp
+from typing import Callable, Optional
+import glob
+import torch
+from torch.utils.data.dataset import Dataset
+import torchvision.transforms.functional as TF
+import numpy as np
+from PIL import Image, ImageOps
+import pandas as pd
+from .masking import MaskGenerator
+from . import data_utils as utils
+
+
+class VideoAttentionTargetVideo(Dataset):
+ def __init__(
+ self,
+ image_root: str,
+ anno_root: str,
+ head_root: str,
+ transform: Callable,
+ input_size: int,
+ output_size: int,
+ quant_labelmap: bool = True,
+ is_train: bool = True,
+ seq_len: int = 8,
+ max_len: int = 32,
+ *,
+ mask_generator: Optional[MaskGenerator] = None,
+ bbox_jitter: float = 0.5,
+ rand_crop: float = 0.5,
+ rand_flip: float = 0.5,
+ color_jitter: float = 0.5,
+ rand_rotate: float = 0.0,
+ rand_lsj: float = 0.0,
+ ):
+ dfs = []
+ for show_dir in glob.glob(osp.join(anno_root, "*")):
+ for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
+ df = pd.read_csv(
+ sequence_path,
+ header=None,
+ index_col=False,
+ names=[
+ "path",
+ "x_min",
+ "y_min",
+ "x_max",
+ "y_max",
+ "gaze_x",
+ "gaze_y",
+ ],
+ )
+ show_name = sequence_path.split("/")[-3]
+ clip = sequence_path.split("/")[-2]
+ df["path"] = df["path"].apply(
+ lambda path: osp.join(show_name, clip, path)
+ )
+ cur_len = len(df.index)
+ if is_train:
+ if cur_len <= max_len:
+ if cur_len >= seq_len:
+ dfs.append(df)
+ continue
+ remainder = cur_len % max_len
+ df_splits = [
+ df[i : i + max_len]
+ for i in range(0, cur_len - max_len, max_len)
+ ]
+ if remainder >= seq_len:
+ df_splits.append(df[-remainder:])
+ dfs.extend(df_splits)
+ else:
+ if cur_len < seq_len:
+ continue
+ df_splits = [
+ df[i : i + seq_len]
+ for i in range(0, cur_len - seq_len, seq_len)
+ ]
+ dfs.extend(df_splits)
+
+ for df in dfs:
+ df.reset_index(inplace=True)
+ self.dfs = dfs
+ self.length = len(dfs)
+
+ self.data_dir = image_root
+ self.head_dir = head_root
+ self.transform = transform
+ self.draw_labelmap = (
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
+ )
+ self.is_train = is_train
+
+ self.input_size = input_size
+ self.output_size = output_size
+ self.seq_len = seq_len
+
+ if self.is_train:
+ self.bbox_jitter = bbox_jitter
+ self.rand_crop = rand_crop
+ self.rand_flip = rand_flip
+ self.color_jitter = color_jitter
+ self.rand_rotate = rand_rotate
+ self.rand_lsj = rand_lsj
+ self.mask_generator = mask_generator
+
+ def __getitem__(self, index):
+ df = self.dfs[index]
+ seq_len = len(df.index)
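+ # Temporally smooth the head-box coordinates with an 11-frame convolution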
+ for coord in ["x_min", "y_min", "x_max", "y_max"]:
+ df[coord] = utils.smooth_by_conv(11, df, coord)
+
+ if self.is_train:
+ # Per-clip random draws deciding which augmentations apply (shared across all frames)
+ cond_jitter = np.random.random_sample()
+ cond_flip = np.random.random_sample()
+ cond_color = np.random.random_sample()
+ if cond_color < self.color_jitter:
+ n1 = np.random.uniform(0.5, 1.5)
+ n2 = np.random.uniform(0.5, 1.5)
+ n3 = np.random.uniform(0.5, 1.5)
+ cond_crop = np.random.random_sample()
+ cond_rotate = np.random.random_sample()
+ if cond_rotate < self.rand_rotate:
+ angle = (2 * np.random.random_sample() - 1) * 20
+ angle = -math.radians(angle)
+ cond_lsj = np.random.random_sample()
+ if cond_lsj < self.rand_lsj:
+ lsj_scale = 0.1 + np.random.random_sample() * 0.9
+
+ # If the clip is longer than seq_len, randomly sample a start index and truncate it to seq_len
+ if seq_len > self.seq_len:
+ sampled_ind = np.random.randint(0, seq_len - self.seq_len)
+ seq_len = self.seq_len
+ else:
+ sampled_ind = 0
+
+ if cond_crop < self.rand_crop:
+ sliced_x_min = df["x_min"].iloc[sampled_ind : sampled_ind + seq_len]
+ sliced_x_max = df["x_max"].iloc[sampled_ind : sampled_ind + seq_len]
+ sliced_y_min = df["y_min"].iloc[sampled_ind : sampled_ind + seq_len]
+ sliced_y_max = df["y_max"].iloc[sampled_ind : sampled_ind + seq_len]
+
+ sliced_gaze_x = df["gaze_x"].iloc[sampled_ind : sampled_ind + seq_len]
+ sliced_gaze_y = df["gaze_y"].iloc[sampled_ind : sampled_ind + seq_len]
+
+ check_sum = sliced_gaze_x.sum() + sliced_gaze_y.sum()
+ all_outside = check_sum == -2 * seq_len
+
+ # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
+ if all_outside:
+ crop_x_min = np.min([sliced_x_min.min(), sliced_x_max.min()])
+ crop_y_min = np.min([sliced_y_min.min(), sliced_y_max.min()])
+ crop_x_max = np.max([sliced_x_min.max(), sliced_x_max.max()])
+ crop_y_max = np.max([sliced_y_min.max(), sliced_y_max.max()])
+ else:
+ crop_x_min = np.min(
+ [sliced_gaze_x.min(), sliced_x_min.min(), sliced_x_max.min()]
+ )
+ crop_y_min = np.min(
+ [sliced_gaze_y.min(), sliced_y_min.min(), sliced_y_max.min()]
+ )
+ crop_x_max = np.max(
+ [sliced_gaze_x.max(), sliced_x_min.max(), sliced_x_max.max()]
+ )
+ crop_y_max = np.max(
+ [sliced_gaze_y.max(), sliced_y_min.max(), sliced_y_max.max()]
+ )
+
+ # Randomly select a top-left corner for the crop
+ if crop_x_min >= 0:
+ crop_x_min = np.random.uniform(0, crop_x_min)
+ if crop_y_min >= 0:
+ crop_y_min = np.random.uniform(0, crop_y_min)
+
+ # Get image size
+ path = osp.join(self.data_dir, df["path"].iloc[0])
+ img = Image.open(path)
+ img = img.convert("RGB")
+ width, height = img.size
+
+ # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
+ crop_width_min = crop_x_max - crop_x_min
+ crop_height_min = crop_y_max - crop_y_min
+ crop_width_max = width - crop_x_min
+ crop_height_max = height - crop_y_min
+ # Randomly select a width and a height
+ crop_width = np.random.uniform(crop_width_min, crop_width_max)
+ crop_height = np.random.uniform(crop_height_min, crop_height_max)
+
+ # Round to integers
+ crop_y_min, crop_x_min, crop_height, crop_width = map(
+ int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
+ )
+ else:
+ sampled_ind = 0
+
+ images = []
+ head_channels = []
+ heatmaps = []
+ gazes = []
+ gaze_inouts = []
+ imsizes = []
+ head_masks = []
+ if self.is_train and self.mask_generator is not None:
+ image_masks = []
+ for i, row in df.iterrows():
+ if self.is_train and (i < sampled_ind or i >= (sampled_ind + self.seq_len)):
+ continue
+
+ x_min = row["x_min"] # note: Already in image coordinates
+ y_min = row["y_min"] # note: Already in image coordinates
+ x_max = row["x_max"] # note: Already in image coordinates
+ y_max = row["y_max"] # note: Already in image coordinates
+ gaze_x = row["gaze_x"] # note: Already in image coordinates
+ gaze_y = row["gaze_y"] # note: Already in image coordinates
+
+ if x_min > x_max:
+ x_min, x_max = x_max, x_min
+ if y_min > y_max:
+ y_min, y_max = y_max, y_min
+
+ path = row["path"]
+ img = Image.open(osp.join(self.data_dir, path)).convert("RGB")
+ width, height = img.size
+ imsize = torch.FloatTensor([width, height])
+ imsizes.append(imsize)
+ # Since we finetune from weights trained on GazeFollow,
+ # we don't incorporate the auxiliary task for VAT.
+ if osp.exists(osp.join(self.head_dir, path)):
+ head_mask = Image.open(osp.join(self.head_dir, path)).resize(
+ (width, height)
+ )
+ else:
+ head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
+
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
+ gaze_x, gaze_y = map(float, [gaze_x, gaze_y])
+ if gaze_x == -1 and gaze_y == -1:
+ gaze_inside = False
+ else:
+ if (
+ gaze_x < 0
+ ): # move a gaze point that was slightly outside the image back in
+ gaze_x = 0
+ if gaze_y < 0:
+ gaze_y = 0
+ gaze_inside = True
+
+ if self.is_train:
+ ## data augmentation
+ # Jitter (expansion-only) bounding box size.
+ if cond_jitter < self.bbox_jitter:
+ k = cond_jitter * 0.1
+ x_min -= k * abs(x_max - x_min)
+ y_min -= k * abs(y_max - y_min)
+ x_max += k * abs(x_max - x_min)
+ y_max += k * abs(y_max - y_min)
+ x_min = np.clip(x_min, 0, width - 1)
+ x_max = np.clip(x_max, 0, width - 1)
+ y_min = np.clip(y_min, 0, height - 1)
+ y_max = np.clip(y_max, 0, height - 1)
+
+ # Random color change
+ if cond_color < self.color_jitter:
+ img = TF.adjust_brightness(img, brightness_factor=n1)
+ img = TF.adjust_contrast(img, contrast_factor=n2)
+ img = TF.adjust_saturation(img, saturation_factor=n3)
+
+ # Random Crop
+ if cond_crop < self.rand_crop:
+ # Crop it
+ img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width)
+ head_mask = TF.crop(
+ head_mask, crop_y_min, crop_x_min, crop_height, crop_width
+ )
+
+ # Record the crop's (x, y) offset
+ offset_x, offset_y = crop_x_min, crop_y_min
+
+ # convert coordinates into the cropped frame
+ x_min, y_min, x_max, y_max = (
+ x_min - offset_x,
+ y_min - offset_y,
+ x_max - offset_x,
+ y_max - offset_y,
+ )
+ if gaze_inside:
+ gaze_x, gaze_y = (gaze_x - offset_x), (gaze_y - offset_y)
+ else:
+ gaze_x = -1
+ gaze_y = -1
+
+ width, height = crop_width, crop_height
+
+ # Flip?
+ if cond_flip < self.rand_flip:
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
+ x_max_2 = width - x_min
+ x_min_2 = width - x_max
+ x_max = x_max_2
+ x_min = x_min_2
+ if gaze_x != -1 and gaze_y != -1:
+ gaze_x = width - gaze_x
+
+ # Random Rotation
+ if cond_rotate < self.rand_rotate:
+ rot_mat = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def _transform(x, y, matrix):
+ return (
+ matrix[0] * x + matrix[1] * y + matrix[2],
+ matrix[3] * x + matrix[4] * y + matrix[5],
+ )
+
+ def _inv_transform(x, y, matrix):
+ x, y = x - matrix[2], y - matrix[5]
+ return (
+ matrix[0] * x + matrix[3] * y,
+ matrix[1] * x + matrix[4] * y,
+ )
+
+ # Calculate offsets
+ rot_center = (width / 2.0, height / 2.0)
+ rot_mat[2], rot_mat[5] = _transform(
+ -rot_center[0], -rot_center[1], rot_mat
+ )
+ rot_mat[2] += rot_center[0]
+ rot_mat[5] += rot_center[1]
+ xx = []
+ yy = []
+ for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
+ x, y = _transform(x, y, rot_mat)
+ xx.append(x)
+ yy.append(y)
+ nw = math.ceil(max(xx)) - math.floor(min(xx))
+ nh = math.ceil(max(yy)) - math.floor(min(yy))
+ rot_mat[2], rot_mat[5] = _transform(
+ -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
+ )
+
+ img = img.transform((nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR)
+ head_mask = head_mask.transform(
+ (nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR
+ )
+
+ xx = []
+ yy = []
+ for x, y in (
+ (x_min, y_min),
+ (x_min, y_max),
+ (x_max, y_min),
+ (x_max, y_max),
+ ):
+ x, y = _inv_transform(x, y, rot_mat)
+ xx.append(x)
+ yy.append(y)
+ x_max, x_min = min(max(xx), nw), max(min(xx), 0)
+ y_max, y_min = min(max(yy), nh), max(min(yy), 0)
+ gaze_x, gaze_y = _inv_transform(gaze_x, gaze_y, rot_mat)
+ width, height = nw, nh
+
+ if cond_lsj < self.rand_lsj:
+ nh, nw = int(height * lsj_scale), int(width * lsj_scale)
+ img = TF.resize(img, (nh, nw))
+ img = ImageOps.expand(img, (0, 0, width - nw, height - nh))
+ head_mask = TF.resize(head_mask, (nh, nw))
+ head_mask = ImageOps.expand(
+ head_mask, (0, 0, width - nw, height - nh)
+ )
+ x_min, y_min, x_max, y_max = (
+ x_min * lsj_scale,
+ y_min * lsj_scale,
+ x_max * lsj_scale,
+ y_max * lsj_scale,
+ )
+ gaze_x, gaze_y = gaze_x * lsj_scale, gaze_y * lsj_scale
+
+ head_channel = utils.get_head_box_channel(
+ x_min,
+ y_min,
+ x_max,
+ y_max,
+ width,
+ height,
+ resolution=self.input_size,
+ coordconv=False,
+ ).unsqueeze(0)
+
+ if self.is_train and self.mask_generator is not None:
+ image_mask = self.mask_generator(
+ x_min / width,
+ y_min / height,
+ x_max / width,
+ y_max / height,
+ head_channel,
+ )
+ image_masks.append(image_mask)
+
+ if self.transform is not None:
+ img = self.transform(img)
+ head_mask = TF.to_tensor(
+ TF.resize(head_mask, (self.input_size, self.input_size))
+ )
+
+ if gaze_inside:
+ gaze_x /= float(width) # fractional gaze
+ gaze_y /= float(height)
+ gaze_heatmap = torch.zeros(
+ self.output_size, self.output_size
+ ) # set the size of the output
+ gaze_map = self.draw_labelmap(
+ gaze_heatmap,
+ [gaze_x * self.output_size, gaze_y * self.output_size],
+ 3,
+ type="Gaussian",
+ )
+ gazes.append(torch.FloatTensor([gaze_x, gaze_y]))
+ else:
+ gaze_map = torch.zeros(self.output_size, self.output_size)
+ gazes.append(torch.FloatTensor([-1, -1]))
+ images.append(img)
+ head_channels.append(head_channel)
+ head_masks.append(head_mask)
+ heatmaps.append(gaze_map)
+ gaze_inouts.append(torch.FloatTensor([int(gaze_inside)]))
+
+ images = torch.stack(images)
+ head_channels = torch.stack(head_channels)
+ heatmaps = torch.stack(heatmaps)
+ gazes = torch.stack(gazes)
+ gaze_inouts = torch.stack(gaze_inouts)
+ head_masks = torch.stack(head_masks)
+ imsizes = torch.stack(imsizes)
+
+ out_dict = {
+ "images": images,
+ "head_channels": head_channels,
+ "heatmaps": heatmaps,
+ "gazes": gazes,
+ "gaze_inouts": gaze_inouts,
+ "head_masks": head_masks,
+ "imsize": imsizes,
+ }
+ if self.is_train and self.mask_generator is not None:
+ out_dict["image_masks"] = torch.stack(image_masks)
+ return out_dict
+
+ def __len__(self):
+ return self.length
+
+
+def video_collate(batch):
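+ # Concatenate each field across clips along dim 0, turning a batch of (seq_len, ...) clip tensors into one flat (batch * seq_len, ...) batch.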
+ keys = batch[0].keys()
+ return {key: torch.cat([item[key] for item in batch]) for key in keys}
diff --git a/docs/eval.md b/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..7b8f8b004a416f4ddeaca29f4248b1d2534df511
--- /dev/null
+++ b/docs/eval.md
@@ -0,0 +1,24 @@
+## Eval
+### Testing Dataset
+
+You should prepare GazeFollow and VideoAttentionTarget for evaluation.
+
+* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
+* If the model was trained with the auxiliary regression task, use `scripts/gen_gazefollow_head_masks.py` to generate head masks.
+* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
+
+Edit `DATA_ROOT` in `ViTGaze/configs/common/dataloader` to point to your dataset locations.
+
+### Evaluation
+
+Run
+```
+bash val.sh configs/gazefollow_518.py ${Path2checkpoint} gf
+```
+to evaluate on GazeFollow.
+
+Run
+```
+bash val.sh configs/videoattentiontarget.py ${Path2checkpoint} vat
+```
+to evaluate on VideoAttentionTarget.
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..1e3215de27b62eed536398db1a522542bc603252
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,19 @@
+## Installation
+
+* Create a conda virtual env and activate it.
+
+ ```
+ conda create -n ViTGaze python==3.9.18
+ conda activate ViTGaze
+ ```
+* Install packages.
+
+ ```
+ cd path/to/ViTGaze
+ pip install -r requirements.txt
+ ```
+* Install [detectron2](https://github.com/facebookresearch/detectron2) following its [documentation](https://detectron2.readthedocs.io/en/latest/).
+ For ViTGaze, we recommend building it from the latest source code:
+ ```
+ python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
+ ```
diff --git a/docs/train.md b/docs/train.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad0fcc89ba3e2c51fb6153d912b53a73ec2d60f4
--- /dev/null
+++ b/docs/train.md
@@ -0,0 +1,36 @@
+## Train
+
+### Training Dataset
+
+You should prepare GazeFollow and VideoAttentionTarget for training.
+
+* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
+* If you train with the auxiliary regression task, use `scripts/gen_gazefollow_head_masks.py` to generate head masks (see the example after this list).
+* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
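+
+For example, to generate head masks for the GazeFollow training split (the dataset root below is a placeholder for your local path):
+
+```
+python scripts/gen_gazefollow_head_masks.py --dataset_dir /path/to/gazefollow_extended --subset train
+```
+
+The script writes the detected head boxes to `<dataset_dir>/train_head.csv` and the rendered masks to `<dataset_dir>/head_masks/`.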
+
+Edit `DATA_ROOT` in `ViTGaze/configs/common/dataloader` to point to your dataset locations.
+
+### Pretrained Model
+
+* Get [DINOv2](https://github.com/facebookresearch/dinov2) pretrained ViT-S.
+* Or download and preprocess the pretrained weights yourself:
+
+ ```
+ cd ViTGaze
+ mkdir pretrained && cd pretrained
+ wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
+ ```
+* Preprocess the model weights with `scripts/convert_pth.py` so they match the Detectron2 checkpoint format.
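+
+ A minimal example (the output filename is a placeholder of your choice):
+
+ ```
+ python scripts/convert_pth.py -s pretrained/dinov2_vits14_pretrain.pth -d pretrained/dinov2_vits14_pretrain_d2.pth
+ ```
+ The script prefixes every parameter name with `backbone.` so the weights load into the Detectron2 model definition.
+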
+### Train ViTGaze
+
+You can modify configs in `configs/gazefollow.py`, `configs/gazefollow_518.py` and `configs/videoattentiontarget.py`.
+
+Run:
+
+```
+bash train.sh
+```
+
+to train ViTGaze on the two datasets.
+
+Training output will be saved in `ViTGaze/output/`.
diff --git a/engine/__init__.py b/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3015fc1f3dbdb29d72faa3a0f0d90e5f2d7d472b
--- /dev/null
+++ b/engine/__init__.py
@@ -0,0 +1 @@
+from .trainer import CycleTrainer
diff --git a/engine/trainer.py b/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49bb68fb7cabd1285a97037149d0ebf2b6c6f37
--- /dev/null
+++ b/engine/trainer.py
@@ -0,0 +1,37 @@
+import time
+import torch
+from detectron2.engine import SimpleTrainer
+from typing import Iterable, Generator
+
+
+def cycle(iterable: Iterable) -> Generator:
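+ # Endlessly re-iterate the finite data loader so training can run for a fixed number of iterations.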
+ while True:
+ for item in iterable:
+ yield item
+
+
+class CycleTrainer(SimpleTrainer):
+ def __init__(
+ self,
+ model,
+ data_loader,
+ optimizer,
+ gather_metric_period=1,
+ zero_grad_before_forward=False,
+ async_write_metrics=False,
+ ):
+ super().__init__(
+ model,
+ data_loader,
+ optimizer,
+ gather_metric_period,
+ zero_grad_before_forward,
+ async_write_metrics,
+ )
+
+ @property
+ def _data_loader_iter(self):
+ # only create the data loader iterator when it is used
+ if self._data_loader_iter_obj is None:
+ self._data_loader_iter_obj = cycle(self.data_loader)
+ return self._data_loader_iter_obj
diff --git a/modeling/_init__.py b/modeling/_init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5759a489d356da9431693802cd1433071e1fc037
--- /dev/null
+++ b/modeling/_init__.py
@@ -0,0 +1,4 @@
+from . import backbone, patch_attention, head, criterion, meta_arch
+
+
+__all__ = ["backbone", "patch_attention", "head", "criterion", "meta_arch"]
diff --git a/modeling/backbone/__init__.py b/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b5e06e34b56b5e66cf8ade00f92e1271d479629
--- /dev/null
+++ b/modeling/backbone/__init__.py
@@ -0,0 +1 @@
+from .vit import build_backbone
diff --git a/modeling/backbone/utils.py b/modeling/backbone/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b779d27975791e7bb37990b227f2fd34ff0cf945
--- /dev/null
+++ b/modeling/backbone/utils.py
@@ -0,0 +1,154 @@
+from functools import partial
+from itertools import repeat
+from typing import Iterable
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+__all__ = [
+ "get_abs_pos",
+ "PatchEmbed",
+ "Mlp",
+ "DropPath",
+]
+
+
+def to_2tuple(x):
+ if isinstance(x, Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, 2))
+
+
+def get_abs_pos(abs_pos, has_cls_token, hw):
+ """
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+ dimension for the original embeddings.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+ hw (Tuple): size of input image tokens.
+
+ Returns:
+ Absolute positional embeddings after processing with shape (1, H, W, C)
+ """
+ h, w = hw
+ if has_cls_token:
+ abs_pos = abs_pos[:, 1:]
+ xy_num = abs_pos.shape[1]
+ size = int(math.sqrt(xy_num))
+ assert size * size == xy_num
+
+ if size != h or size != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ return new_abs_pos.permute(0, 2, 3, 1)
+ else:
+ return abs_pos.reshape(1, h, w, -1)
+
+
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size=(16, 16),
+ stride=(16, 16),
+ padding=(0, 0),
+ in_chans=3,
+ embed_dim=768,
+ ):
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x):
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = (
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ )
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
diff --git a/modeling/backbone/vit.py b/modeling/backbone/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..78f45a51362aa91eae475ec26a15346b5c6c08bc
--- /dev/null
+++ b/modeling/backbone/vit.py
@@ -0,0 +1,504 @@
+import logging
+from typing import Literal, Union
+from functools import partial
+import torch
+import torch.nn as nn
+from detectron2.modeling import Backbone
+
+try:
+ from xformers.ops import memory_efficient_attention
+
+ XFORMERS_ON = True
+except ImportError:
+ XFORMERS_ON = False
+from .utils import (
+ PatchEmbed,
+ get_abs_pos,
+ DropPath,
+ Mlp,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ return_softmax_attn=True,
+ use_proj=True,
+ patch_token_offset=0,
+ ):
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim) if use_proj else nn.Identity()
+
+ self.return_softmax_attn = return_softmax_attn
+
+ self.patch_token_offset = patch_token_offset
+
+ def forward(self, x, return_attention=False, extra_token_offset=None):
+ B, L, _ = x.shape
+ # qkv with shape (3, B, nHead, L, head_dim)
+ qkv = self.qkv(x).view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, L, head_dim)
+ q, k, v = qkv.reshape(3, B * self.num_heads, L, -1).unbind(0)
+
+ if return_attention or not XFORMERS_ON:
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ if return_attention and not self.return_softmax_attn:
+ out_attn = attn
+ attn = attn.softmax(dim=-1)
+ if return_attention and self.return_softmax_attn:
+ out_attn = attn
+ x = attn @ v
+ else:
+ x = memory_efficient_attention(q, k, v, scale=self.scale)
+
+ x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
+ x = self.proj(x)
+
+ if return_attention:
+ out_attn = out_attn.reshape(B, self.num_heads, L, -1)
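+ # Keep only patch-to-patch attention: drop msg/register/cls (and extra) rows and columns.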
+ out_attn = out_attn[
+ :,
+ :,
+ self.patch_token_offset : extra_token_offset,
+ self.patch_token_offset : extra_token_offset,
+ ]
+ return x, out_attn
+ else:
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ act_layer=nn.GELU,
+ init_values=None,
+ return_softmax_attn=True,
+ attention_map_only=False,
+ patch_token_offset=0,
+ ):
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ drop_path (float): Stochastic depth rate.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ """
+ super().__init__()
+ self.attention_map_only = attention_map_only
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ return_softmax_attn=return_softmax_attn,
+ use_proj=return_softmax_attn or not attention_map_only,
+ patch_token_offset=patch_token_offset,
+ )
+
+ if attention_map_only:
+ return
+
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(
+ in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+
+ def forward(self, x, return_attention=False, extra_token_offset=None):
+ shortcut = x
+ x = self.norm1(x)
+
+ if return_attention:
+ x, attn = self.attn(x, True, extra_token_offset)
+ else:
+ x = self.attn(x)
+
+ if self.attention_map_only:
+ return x, attn
+
+ x = shortcut + self.drop_path(self.ls1(x))
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+
+ if return_attention:
+ return x, attn
+ else:
+ return x
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, torch.Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class ViT(Backbone):
+ def __init__(
+ self,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path_rate=0.0,
+ norm_layer=nn.LayerNorm,
+ act_layer=nn.GELU,
+ pretrain_img_size=224,
+ pretrain_use_cls_token=True,
+ init_values=None,
+ use_cls_token=False,
+ use_mask_token=False,
+ norm_features=False,
+ return_softmax_attn=True,
+ num_register_tokens=0,
+ num_msg_tokens=0,
+ register_as_msg=False,
+ shift_strides=None, # [1, -1, 2, -2],
+ cls_shift=False,
+ num_extra_tokens=4,
+ use_extra_embed=False,
+ num_frames=None,
+ out_feature=True,
+ out_attn=(),
+ ):
+ super().__init__()
+ self.pretrain_use_cls_token = pretrain_use_cls_token
+
+ self.patch_size = patch_size
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ # Initialize absolute positional embedding with pretrain image size.
+ num_patches = (pretrain_img_size // patch_size) * (
+ pretrain_img_size // patch_size
+ )
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
+
+ self.use_cls_token = use_cls_token
+ self.cls_token = (
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_cls_token else None
+ )
+
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
+ if num_register_tokens > 0
+ else None
+ )
+
+ # We tried to leverage temporal information with TeViT, but it did not help
+ assert num_msg_tokens >= 0
+ self.num_msg_tokens = num_msg_tokens
+ if register_as_msg:
+ self.num_msg_tokens += num_register_tokens
+ self.msg_tokens = (
+ nn.Parameter(torch.zeros(1, num_msg_tokens, embed_dim))
+ if num_msg_tokens > 0
+ else None
+ )
+
+ patch_token_offset = (
+ num_msg_tokens + num_register_tokens + int(self.use_cls_token)
+ )
+ self.patch_token_offset = patch_token_offset
+
+ self.msg_shift = None
+ if shift_strides is not None:
+ self.msg_shift = []
+ for i in range(depth):
+ if i % 2 == 0:
+ self.msg_shift.append([_ for _ in shift_strides])
+ else:
+ self.msg_shift.append([-_ for _ in shift_strides])
+
+ self.cls_shift = None
+ if cls_shift:
+ self.cls_shift = [(-1) ** idx for idx in range(depth)]
+
+ assert num_extra_tokens >= 0
+ self.num_extra_tokens = num_extra_tokens
+ self.extra_pos_embed = (
+ nn.Linear(embed_dim, embed_dim)
+ if num_extra_tokens > 0 and use_extra_embed
+ else nn.Identity()
+ )
+
+ self.num_frames = num_frames
+
+ # Mask token for masking augmentation
+ self.use_mask_token = use_mask_token
+ self.mask_token = (
+ nn.Parameter(torch.zeros(1, embed_dim)) if use_mask_token else None
+ )
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ init_values=init_values,
+ return_softmax_attn=return_softmax_attn,
+ attention_map_only=(i == depth - 1) and not out_feature,
+ patch_token_offset=patch_token_offset,
+ )
+ self.blocks.append(block)
+
+ self.norm = norm_layer(embed_dim) if norm_features else nn.Identity()
+
+ self._out_features = out_feature
+ self._out_attn = out_attn
+
+ if self.pos_embed is not None:
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x, masks=None, guidance=None):
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
+ )
+
+ if self.pos_embed is not None:
+ x = x + get_abs_pos(
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
+ )
+ B, H, W, _ = x.shape
+ x = x.reshape(B, H * W, -1)
+
+ if self.use_cls_token:
+ cls_tokens = self.cls_token.expand(len(x), -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if self.register_tokens is not None:
+ register_tokens = self.register_tokens.expand(len(x), -1, -1)
+ x = torch.cat((register_tokens, x), dim=1)
+
+ if self.msg_tokens is not None:
+ msg_tokens = self.msg_tokens.expand(len(x), -1, -1)
+ x = torch.cat((msg_tokens, x), dim=1)
+ # [MSG, REG, CLS, PAT]
+
+ extra_tokens_offset = None
+ if guidance is not None:
+ guidance = guidance.reshape(len(guidance), -1, 1)
+ extra_tokens = (
+ (x[:, self.patch_token_offset :] * guidance)
+ .sum(dim=1, keepdim=True)
+ .expand(-1, self.num_extra_tokens, -1)
+ )
+ extra_tokens = self.extra_pos_embed(extra_tokens)
+ x = torch.cat((x, extra_tokens), dim=1)
+ extra_tokens_offset = -self.num_extra_tokens
+ # [MSG, REG, CLS, PAT, EXT]
+
+ attn_maps = []
+ for idx, blk in enumerate(self.blocks):
+ if idx in self._out_attn:
+ x, attn = blk(x, True, extra_tokens_offset)
+ attn_maps.append(attn)
+ else:
+ x = blk(x)
+
+ if self.msg_shift is not None:
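+ # Shift message tokens across the frames of a clip (direction alternates between blocks) to exchange temporal information.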
+ msg_shift = self.msg_shift[idx]
+ msg_tokens = (
+ x[:, : self.num_msg_tokens]
+ if guidance is None
+ else x[:, extra_tokens_offset:]
+ )
+ msg_tokens = msg_tokens.reshape(
+ -1, self.num_frames, *msg_tokens.shape[1:]
+ )
+ msg_tokens = msg_tokens.chunk(len(msg_shift), dim=2)
+ msg_tokens = [
+ torch.roll(tokens, roll, dims=1)
+ for tokens, roll in zip(msg_tokens, msg_shift)
+ ]
+ msg_tokens = torch.cat(msg_tokens, dim=2).flatten(0, 1)
+ if guidance is None:
+ x = torch.cat([msg_tokens, x[:, self.num_msg_tokens :]], dim=1)
+ else:
+ x = torch.cat([x[:, :extra_tokens_offset], msg_tokens], dim=1)
+
+ if self.cls_shift is not None:
+ cls_tokens = x[:, self.patch_token_offset - 1]
+ cls_tokens = cls_tokens.reshape(
+ -1, self.num_frames, 1, *cls_tokens.shape[1:]
+ )
+ cls_tokens = torch.roll(cls_tokens, self.cls_shift[idx], dims=1)
+ x = torch.cat(
+ [
+ x[:, : self.patch_token_offset - 1],
+ cls_tokens.flatten(0, 1),
+ x[:, self.patch_token_offset :],
+ ],
+ dim=1,
+ )
+
+ x = self.norm(x)
+
+ outputs = {}
+ outputs["attention_maps"] = torch.cat(attn_maps, dim=1).reshape(
+ B, -1, H * W, H, W
+ )
+ if self._out_features:
+ outputs["last_feat"] = (
+ x[:, self.patch_token_offset : extra_tokens_offset]
+ .reshape(B, H, W, -1)
+ .permute(0, 3, 1, 2)
+ )
+
+ return outputs
+
+
+def vit_tiny(**kwargs):
+ model = ViT(
+ patch_size=16,
+ embed_dim=192,
+ depth=12,
+ num_heads=3,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_small(**kwargs):
+ model = ViT(
+ patch_size=16,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def vit_base(**kwargs):
+ model = ViT(
+ patch_size=16,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ **kwargs
+ )
+ return model
+
+
+def dinov2_base(**kwargs):
+ model = ViT(
+ patch_size=14,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ pretrain_img_size=518,
+ init_values=1,
+ **kwargs
+ )
+ return model
+
+
+def dinov2_small(**kwargs):
+ model = ViT(
+ patch_size=14,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ pretrain_img_size=518,
+ init_values=1,
+ **kwargs
+ )
+ return model
+
+
+def build_backbone(
+ name: Literal["tiny", "small", "base", "dinov2_base", "dinov2_small"], **kwargs
+):
+ vit_dict = {
+ "tiny": vit_tiny,
+ "small": vit_small,
+ "base": vit_base,
+ "dinov2_base": dinov2_base,
+ "dinov2_small": dinov2_small,
+ }
+ return vit_dict[name](**kwargs)
diff --git a/modeling/criterion/__init__.py b/modeling/criterion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c63773c315c423290e8c2e30a3351fce5cb5b5f
--- /dev/null
+++ b/modeling/criterion/__init__.py
@@ -0,0 +1 @@
+from .gaze_mapper_criterion import GazeMapperCriterion
diff --git a/modeling/criterion/gaze_mapper_criterion.py b/modeling/criterion/gaze_mapper_criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..472b0334cba6627d134671626711a7b3f66ea93c
--- /dev/null
+++ b/modeling/criterion/gaze_mapper_criterion.py
@@ -0,0 +1,94 @@
+from functools import partial
+import torch
+from torch import nn
+from torch.nn import functional as F
+from fvcore.nn import sigmoid_focal_loss_jit
+
+
+class GazeMapperCriterion(nn.Module):
+ def __init__(
+ self,
+ heatmap_weight: float = 10000,
+ inout_weight: float = 100,
+ aux_weight: float = 100,
+ use_aux_loss: bool = False,
+ aux_head_thres: float = 0,
+ use_focal_loss: bool = False,
+ alpha: float = -1,
+ gamma: float = 2,
+ ):
+ super().__init__()
+ self.heatmap_weight = heatmap_weight
+ self.inout_weight = inout_weight
+ self.aux_weight = aux_weight
+ self.aux_head_thres = aux_head_thres
+
+ self.heatmap_loss = nn.MSELoss(reduction="none")
+
+ if use_focal_loss:
+ self.inout_loss = partial(
+ sigmoid_focal_loss_jit, alpha=alpha, gamma=gamma, reduction="mean"
+ )
+ else:
+ self.inout_loss = nn.BCEWithLogitsLoss()
+
+ if use_aux_loss:
+ self.aux_loss = nn.BCEWithLogitsLoss()
+ else:
+ self.aux_loss = None
+
+ def forward(
+ self,
+ pred_heatmap,
+ pred_inout,
+ gt_heatmap,
+ gt_inout,
+ pred_head_masks=None,
+ gt_head_masks=None,
+ ):
+ loss_dict = {}
+
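+ # Upsample the predicted heatmap to the ground-truth resolution before computing the per-pixel MSE.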
+ pred_heatmap = F.interpolate(
+ pred_heatmap,
+ size=tuple(gt_heatmap.shape[-2:]),
+ mode="bilinear",
+ align_corners=True,
+ )
+ heatmap_loss = (
+ self.heatmap_loss(pred_heatmap.squeeze(1), gt_heatmap) * self.heatmap_weight
+ )
+ heatmap_loss = torch.mean(heatmap_loss, dim=(-2, -1))
+ heatmap_loss = torch.sum(heatmap_loss.reshape(-1) * gt_inout.reshape(-1))
+ # If every gaze target in the batch is out of frame, avoid a 0/0 NaN
+ if heatmap_loss > 1e-7:
+ heatmap_loss = heatmap_loss / torch.sum(gt_inout)
+ loss_dict["regression loss"] = heatmap_loss
+ else:
+ loss_dict["regression loss"] = heatmap_loss * 0
+
+ inout_loss = (
+ self.inout_loss(pred_inout.reshape(-1), gt_inout.reshape(-1))
+ * self.inout_weight
+ )
+ loss_dict["classification loss"] = inout_loss
+
+ if self.aux_loss is not None:
+ pred_head_masks = F.interpolate(
+ pred_head_masks,
+ size=tuple(gt_head_masks.shape[-2:]),
+ mode="bilinear",
+ align_corners=True,
+ )
+ aux_loss = (
+ torch.clamp(
+ self.aux_loss(
+ pred_head_masks.reshape(-1), gt_head_masks.reshape(-1)
+ )
+ - self.aux_head_thres,
+ 0,
+ )
+ * self.aux_weight
+ )
+ loss_dict["aux head loss"] = aux_loss
+
+ return loss_dict
diff --git a/modeling/head/__init__.py b/modeling/head/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a67fa906dec39e14d7107607d45adf42fe9bdf0
--- /dev/null
+++ b/modeling/head/__init__.py
@@ -0,0 +1,2 @@
+from .inout_head import build_inout_head
+from .heatmap_head import build_heatmap_head
diff --git a/modeling/head/heatmap_head.py b/modeling/head/heatmap_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..60f464375e614a5c0f60de57b22eef107e9d72aa
--- /dev/null
+++ b/modeling/head/heatmap_head.py
@@ -0,0 +1,225 @@
+import torch
+from torch import nn
+from detectron2.utils.registry import Registry
+from typing import Literal, List, Dict, Optional, OrderedDict
+
+
+HEATMAP_HEAD_REGISTRY = Registry("HEATMAP_HEAD_REGISTRY")
+HEATMAP_HEAD_REGISTRY.__doc__ = "Registry for heatmap head"
+
+
+class BaseHeatmapHead(nn.Module):
+ def __init__(
+ self,
+ in_channel: int,
+ deconv_cfgs: List[Dict],
+ dim: int = 96,
+ use_conv: bool = False,
+ use_residual: bool = False,
+ feat_type: Literal["attn", "both"] = "both",
+ attn_layer: Optional[str] = None,
+ pre_norm: bool = False,
+ use_head: bool = False,
+ ) -> None:
+ super().__init__()
+ self.feat_type = feat_type
+ self.use_head = use_head
+
+ if pre_norm:
+ self.pre_norm = nn.Sequential(
+ OrderedDict(
+ [
+ ("bn", nn.BatchNorm2d(in_channel)),
+ ("relu", nn.ReLU(inplace=True)),
+ ]
+ )
+ )
+ else:
+ self.pre_norm = nn.Identity()
+
+ if use_conv:
+ if use_residual:
+ from timm.models.resnet import Bottleneck, downsample_conv
+
+ self.conv = Bottleneck(
+ in_channel,
+ dim // 4,
+ downsample=downsample_conv(in_channel, dim, 1)
+ if in_channel != dim
+ else None,
+ attn_layer=attn_layer,
+ )
+ else:
+ self.conv = nn.Sequential(
+ OrderedDict(
+ [
+ ("conv", nn.Conv2d(in_channel, dim, 3, 1, 1)),
+ ("bn", nn.BatchNorm2d(dim)),
+ ("relu", nn.ReLU(inplace=True)),
+ ]
+ )
+ )
+ else:
+ self.conv = nn.Identity()
+
+ self.decoder: nn.Module = None
+
+ def get_feat(self, x):
+ if self.feat_type == "attn":
+ feat = x["attention_maps"]
+ elif self.feat_type == "feat":
+ feat = x["last_feat"]
+ return feat
+
+ def forward(self, x):
+ feat = self.get_feat(x)
+ feat = self.pre_norm(feat)
+ feat = self.conv(feat)
+ return self.decoder(feat)
+
+
+@HEATMAP_HEAD_REGISTRY.register()
+class SimpleDeconv(BaseHeatmapHead):
+ def __init__(
+ self,
+ in_channel: int,
+ deconv_cfgs: List[Dict],
+ dim: int = 96,
+ use_conv: bool = False,
+ use_residual: bool = False,
+ feat_type: Literal["attn", "both"] = "both",
+ attn_layer: Optional[str] = None,
+ pre_norm: bool = False,
+ use_head: bool = False,
+ ) -> None:
+ super().__init__(
+ in_channel,
+ deconv_cfgs,
+ dim,
+ use_conv,
+ use_residual,
+ feat_type,
+ attn_layer,
+ pre_norm,
+ use_head,
+ )
+ decoder_layers = []
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
+ decoder_layers.extend(
+ [
+ (
+ "".join(["deconv", str(i)]),
+ nn.ConvTranspose2d(**deconv_cfg),
+ ),
+ (
+ "".join(["bn", str(i)]),
+ nn.BatchNorm2d(deconv_cfg["out_channels"]),
+ ),
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
+ ]
+ )
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
+
+
+@HEATMAP_HEAD_REGISTRY.register()
+class UpSampleConv(BaseHeatmapHead):
+ def __init__(
+ self,
+ in_channel: int,
+ deconv_cfgs: List[Dict],
+ dim: int = 96,
+ use_conv: bool = False,
+ use_residual: bool = False,
+ feat_type: Literal["attn", "both"] = "both",
+ attn_layer: Optional[str] = None,
+ pre_norm: bool = False,
+ use_head: bool = False,
+ ) -> None:
+ super().__init__(
+ in_channel,
+ deconv_cfgs,
+ dim,
+ use_conv,
+ use_residual,
+ feat_type,
+ attn_layer,
+ pre_norm,
+ use_head,
+ )
+ decoder_layers = []
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
+ decoder_layers.extend(
+ [
+ (
+ "".join(["upsample", str(i)]),
+ nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=True
+ ),
+ ),
+ (
+ "".join(["conv", str(i)]),
+ nn.Conv2d(**deconv_cfg),
+ ),
+ (
+ "".join(["bn", str(i)]),
+ nn.BatchNorm2d(deconv_cfg["out_channels"]),
+ ),
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
+ ]
+ )
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
+
+
+@HEATMAP_HEAD_REGISTRY.register()
+class PixelShuffle(BaseHeatmapHead):
+ def __init__(
+ self,
+ in_channel: int,
+ deconv_cfgs: List[Dict],
+ dim: int = 96,
+ use_conv: bool = False,
+ use_residual: bool = False,
+ feat_type: Literal["attn", "both"] = "both",
+ attn_layer: Optional[str] = None,
+ pre_norm: bool = False,
+ use_head: bool = False,
+ ) -> None:
+ super().__init__(
+ in_channel,
+ deconv_cfgs,
+ dim,
+ use_conv,
+ use_residual,
+ feat_type,
+ attn_layer,
+ pre_norm,
+ use_head,
+ )
+ decoder_layers = []
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
+ deconv_cfg["out_channels"] = deconv_cfg["out_channels"] * 4
+ decoder_layers.extend(
+ [
+ (
+ "".join(["conv", str(i)]),
+ nn.Conv2d(**deconv_cfg),
+ ),
+ (
+ "".join(["pixel_shuffle", str(i)]),
+ nn.PixelShuffle(upscale_factor=2),
+ ),
+ (
+ "".join(["bn", str(i)]),
+ nn.BatchNorm2d(deconv_cfg["out_channels"] // 4),
+ ),
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
+ ]
+ )
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
+
+
+def build_heatmap_head(name, *args, **kwargs):
+ return HEATMAP_HEAD_REGISTRY.get(name)(*args, **kwargs)
diff --git a/modeling/head/inout_head.py b/modeling/head/inout_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e8ae3a4077a9bdbda95cf4eaf97d37e0ff72d4
--- /dev/null
+++ b/modeling/head/inout_head.py
@@ -0,0 +1,55 @@
+from typing import OrderedDict
+import torch
+from torch import nn
+from detectron2.utils.registry import Registry
+
+
+INOUT_HEAD_REGISTRY = Registry("INOUT_HEAD_REGISTRY")
+INOUT_HEAD_REGISTRY.__doc__ = "Registry for inout head"
+
+
+@INOUT_HEAD_REGISTRY.register()
+class SimpleLinear(nn.Module):
+ def __init__(self, in_channel: int, dropout: float = 0) -> None:
+ super().__init__()
+ self.in_channel = in_channel
+ self.classifier = nn.Sequential(
+ OrderedDict(
+ [("dropout", nn.Dropout(dropout)), ("linear", nn.Linear(in_channel, 1))]
+ )
+ )
+
+ def get_feat(self, x, masks):
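+ # x["head_feat"] is the scene feature pooled with the head-position weights; if patch masks are given, a mask-weighted scene average is concatenated.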
+ feats = x["head_feat"]
+ if masks is not None:
+ B, C = x["last_feat"].shape[:2]
+ scene_feats = x["last_feat"].view(B, C, -1).permute(0, 2, 1)
+ masks = masks / (masks.sum(dim=-1, keepdim=True) + 1e-6)
+ scene_feats = (scene_feats * masks.unsqueeze(-1)).sum(dim=1)
+ feats = torch.cat((feats, scene_feats), dim=1)
+ return feats
+
+ def forward(self, x, masks=None):
+ feat = self.get_feat(x, masks)
+ return self.classifier(feat)
+
+
+@INOUT_HEAD_REGISTRY.register()
+class SimpleMlp(SimpleLinear):
+ def __init__(self, in_channel: int, dropout: float = 0) -> None:
+ super().__init__(in_channel, dropout)
+ self.classifier = nn.Sequential(
+ OrderedDict(
+ [
+ ("dropout0", nn.Dropout(dropout)),
+ ("linear0", nn.Linear(in_channel, in_channel)),
+ ("relu", nn.ReLU()),
+ ("dropout1", nn.Dropout(dropout)),
+ ("linear1", nn.Linear(in_channel, 1)),
+ ]
+ )
+ )
+
+
+def build_inout_head(name, *args, **kwargs):
+ return INOUT_HEAD_REGISTRY.get(name)(*args, **kwargs)
diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7b2a8079d210faff9e594e501e8980df6cec7a
--- /dev/null
+++ b/modeling/meta_arch/__init__.py
@@ -0,0 +1 @@
+from .gaze_attention_mapper import GazeAttentionMapper
diff --git a/modeling/meta_arch/gaze_attention_mapper.py b/modeling/meta_arch/gaze_attention_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6244bf6512b2c7d4fc25c2e731b4edc357cdfc4
--- /dev/null
+++ b/modeling/meta_arch/gaze_attention_mapper.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+from typing import Dict, Union
+
+
+class GazeAttentionMapper(nn.Module):
+ def __init__(
+ self,
+ backbone: nn.Module,
+ regressor: nn.Module,
+ classifier: nn.Module,
+ criterion: nn.Module,
+ pam: nn.Module,
+ use_aux_loss: bool = False,
+ device: Union[torch.device, str] = "cuda",
+ ) -> None:
+ super().__init__()
+ self.backbone = backbone
+ self.pam = pam
+ self.regressor = regressor
+ self.classifier = classifier
+ self.criterion = criterion
+ self.use_aux_loss = use_aux_loss
+ self.device = torch.device(device)
+
+ def forward(self, x):
+ (
+ scenes,
+ heads,
+ gt_heatmaps,
+ gt_inouts,
+ head_masks,
+ image_masks,
+ ) = self.preprocess_inputs(x)
+ # Calculate patch weights based on head position
+ embedded_heads = self.pam(scenes, heads)
+ aux_masks = None
+ if self.use_aux_loss:
+ embedded_heads, aux_masks = embedded_heads
+
+ # Get out-dict
+ x = self.backbone(
+ scenes,
+ image_masks,
+ None,
+ )
+
+ # Apply patch weights to get the final feats and attention maps
+ feats = x.get("last_feat", None)
+ if feats is not None:
+ x["head_feat"] = (
+ (embedded_heads.repeat(1, feats.shape[1], 1, 1) * feats)
+ .sum(dim=(2, 3))
+ .reshape(len(feats), -1)
+ ) # BC
+
+ attn_maps = x["attention_maps"]
+ B, C, *_ = attn_maps.shape
+ x["attention_maps"] = (
+ attn_maps * embedded_heads.reshape(B, 1, -1, 1, 1).repeat(1, C, 1, 1, 1)
+ ).sum(
+ dim=2
+ ) # BCHW
+
+ # Apply heads
+ heatmaps = self.regressor(x)
+ inouts = self.classifier(x, None)
+
+ if self.training:
+ return self.criterion(
+ heatmaps,
+ inouts,
+ gt_heatmaps,
+ gt_inouts,
+ aux_masks,
+ head_masks,
+ )
+ # Inference
+ return heatmaps, inouts.sigmoid()
+
+ def preprocess_inputs(self, batched_inputs: Dict[str, torch.Tensor]):
+ return (
+ batched_inputs["images"].to(self.device),
+ batched_inputs["head_channels"].to(self.device),
+ batched_inputs["heatmaps"].to(self.device)
+ if "heatmaps" in batched_inputs.keys()
+ else None,
+ batched_inputs["gaze_inouts"].to(self.device)
+ if "gaze_inouts" in batched_inputs.keys()
+ else None,
+ batched_inputs["head_masks"].to(self.device)
+ if "head_masks" in batched_inputs.keys()
+ else None,
+ batched_inputs["image_masks"].to(self.device)
+ if "image_masks" in batched_inputs.keys()
+ else None,
+ )
diff --git a/modeling/patch_attention/__init__.py b/modeling/patch_attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be13cc2e49b56af39e59796fa92cb766ca45c4d
--- /dev/null
+++ b/modeling/patch_attention/__init__.py
@@ -0,0 +1 @@
+from .patch_attention import build_pam
diff --git a/modeling/patch_attention/patch_attention.py b/modeling/patch_attention/patch_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a42b236865e1e2d05a356f4f686ee282bdd6e0e
--- /dev/null
+++ b/modeling/patch_attention/patch_attention.py
@@ -0,0 +1,106 @@
+from typing import OrderedDict
+from torch import nn
+from torch.nn import functional as F
+from detectron2.utils.registry import Registry
+from fvcore.nn import c2_msra_fill
+
+
+SPATIAL_GUIDANCE_REGISTRY = Registry("SPATIAL_GUIDANCE_REGISTRY")
+SPATIAL_GUIDANCE_REGISTRY.__doc__ = "Registry for 2d spatial guidance"
+
+
+class _PoolFusion(nn.Module):
+ def __init__(self, patch_size: int, use_avgpool: bool = False) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.attn_reducer = F.avg_pool2d if use_avgpool else F.max_pool2d
+
+ def forward(self, scenes, heads):
+ attn_masks = self.attn_reducer(
+ heads,
+ (self.patch_size, self.patch_size),
+ (self.patch_size, self.patch_size),
+ (0, 0),
+ )
+ patch_attn = attn_masks.masked_fill(attn_masks <= 0, -1e9)
+ return F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
+ *patch_attn.shape
+ )
+
+
+@SPATIAL_GUIDANCE_REGISTRY.register()
+class AvgFusion(_PoolFusion):
+ def __init__(self, patch_size: int) -> None:
+ super().__init__(patch_size, False)
+
+
+@SPATIAL_GUIDANCE_REGISTRY.register()
+class MaxFusion(_PoolFusion):
+ def __init__(self, patch_size: int) -> None:
+ super().__init__(patch_size, True)
+
+
+@SPATIAL_GUIDANCE_REGISTRY.register()
+class PatchPAM(nn.Module):
+ def __init__(
+ self,
+ patch_size: int,
+ act_layer=nn.ReLU,
+ embed_dim: int = 768,
+ use_aux_loss: bool = False,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ patch_embed = nn.Conv2d(
+ 3, embed_dim, (patch_size, patch_size), (patch_size, patch_size), (0, 0)
+ )
+ c2_msra_fill(patch_embed)
+ conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
+ c2_msra_fill(conv)
+ self.use_aux_loss = use_aux_loss
+ if use_aux_loss:
+ self.patch_embed = nn.Sequential(
+ OrderedDict(
+ [
+ ("patch_embed", patch_embed),
+ ("act_layer", act_layer(inplace=True)),
+ ]
+ )
+ )
+ self.embed = conv
+ conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
+ c2_msra_fill(conv)
+ self.aux_embed = conv
+ else:
+ self.embed = nn.Sequential(
+ OrderedDict(
+ [
+ ("patch_embed", patch_embed),
+ ("act_layer", act_layer(inplace=True)),
+ ("embed", conv),
+ ]
+ )
+ )
+
+ def forward(self, scenes, heads):
+ attn_masks = F.max_pool2d(
+ heads,
+ (self.patch_size, self.patch_size),
+ (self.patch_size, self.patch_size),
+ (0, 0),
+ )
+ if self.use_aux_loss:
+ embed = self.patch_embed(scenes)
+ aux_masks = self.aux_embed(embed)
+ patch_attn = self.embed(embed) * attn_masks
+ else:
+ patch_attn = self.embed(scenes) * attn_masks
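+ # Patches outside the head box get -1e9 before the softmax, so attention is distributed only over head patches.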
+ patch_attn = patch_attn.masked_fill(attn_masks <= 0, -1e9)
+ patch_attn = F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
+ *patch_attn.shape
+ )
+ return (patch_attn, aux_masks) if self.use_aux_loss else patch_attn
+
+
+def build_pam(name, *args, **kwargs):
+ return SPATIAL_GUIDANCE_REGISTRY.get(name)(*args, **kwargs)
diff --git a/pretrained/gazefollow.pth b/pretrained/gazefollow.pth
new file mode 100644
index 0000000000000000000000000000000000000000..71fcf32cdf245634a5df508b03cf710dcac49f18
--- /dev/null
+++ b/pretrained/gazefollow.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61f82255d4db0640208c90b037fd4bb62e061efd51cd03aa4b45f10159acab15
+size 266808135
diff --git a/pretrained/videoattentiontarget.pth b/pretrained/videoattentiontarget.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1c745d99f3723b0649ed282f28c942f5cd5ce02b
--- /dev/null
+++ b/pretrained/videoattentiontarget.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdf8d767815b7c8f1ba78d89710dc7a2de2d08e97554654a296a3e71b4903cec
+size 266808135
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..71b571a904fd06e3de67c9766361135e1f4d199e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+fvcore==0.1.5.post20221221
+numpy==1.26.2
+omegaconf==2.3.0
+opencv-python==4.8.1.78
+opencv-python-headless==4.6.0.66
+pandas==1.4.4
+Pillow==10.3.0
+scikit-image==0.22.0
+scikit-learn==1.5.0
+scipy==1.11.4
+timm==0.9.12
+torch==2.2.0
+torchaudio==2.0.2
+torchvision==0.15.2
+tqdm==4.66.3
+xformers==0.0.21
+yacs==0.1.8
diff --git a/scripts/convert_pth.py b/scripts/convert_pth.py
new file mode 100644
index 0000000000000000000000000000000000000000..48d63809813fde7ac061d0c3fe4fafb4283b03bd
--- /dev/null
+++ b/scripts/convert_pth.py
@@ -0,0 +1,32 @@
+# Convert official model weights to the checkpoint format that detectron2 expects
+import argparse
+from collections import OrderedDict
+import torch
+
+
+def convert(src: str, dst: str):
+ checkpoint = torch.load(src)
+ has_model = "model" in checkpoint.keys()
+ checkpoint = checkpoint["model"] if has_model else checkpoint
+ if "state_dict" in checkpoint.keys():
+ checkpoint = checkpoint["state_dict"]
+ out_cp = OrderedDict()
+ for k, v in checkpoint.items():
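+ # Prefix every parameter name with "backbone." so it matches the detectron2 model, whose ViT is registered as `backbone`.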
+ out_cp[".".join(["backbone", k])] = v
+ out_cp = {"model": out_cp} if has_model else out_cp
+ torch.save(out_cp, dst)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--src", "-s", type=str, required=True, help="Path to src weights.pth"
+ )
+ parser.add_argument(
+ "--dst", "-d", type=str, required=True, help="Path to dst weights.pth"
+ )
+ args = parser.parse_args()
+ convert(
+ args.src,
+ args.dst,
+ )
diff --git a/scripts/gen_gazefollow_head_masks.py b/scripts/gen_gazefollow_head_masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b35ab230c33393650e8b36011c0b919f1f49d2d
--- /dev/null
+++ b/scripts/gen_gazefollow_head_masks.py
@@ -0,0 +1,185 @@
+import argparse
+import os
+import random
+import cv2
+import numpy as np
+import pandas as pd
+import tqdm
+from PIL import Image
+from retinaface.pre_trained_models import get_model
+
+
+random.seed(1)
+
+
+def gaussian(x_min, y_min, x_max, y_max):
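+ # 2-D Gaussian centered on the head box; the variance is derived from the box width and height.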
+ x_min, x_max = sorted((x_min, x_max))
+ y_min, y_max = sorted((y_min, y_max))
+ x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2
+ x_sigma2, y_sigma2 = (
+ np.clip(np.square([x_max - x_min, y_max - y_min], dtype=float), 1, None) / 3
+ )
+
+ def _gaussian(_xs, _ys):
+ return np.exp(
+ -(np.square(_xs - x_mid) / x_sigma2 + np.square(_ys - y_mid) / y_sigma2)
+ )
+
+ return _gaussian
+
+
+def plot_ori(label_path, data_dir):
+ df = pd.read_csv(
+ label_path,
+ names=[
+ # Original labels
+ "path",
+ "idx",
+ "body_bbox_x",
+ "body_bbox_y",
+ "body_bbox_w",
+ "body_bbox_h",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "x_min",
+ "y_min",
+ "x_max",
+ "y_max",
+ "inout",
+ "meta0",
+ "meta1",
+ ],
+ index_col=False,
+ encoding="utf-8-sig",
+ )
+ grouped = df.groupby("path")
+
+ output_dir = os.path.join(data_dir, "head_masks")
+
+ for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with annotations: "):
+ if not os.path.exists(os.path.join(output_dir, image_name)):
+ w, h = Image.open(os.path.join(data_dir, image_name)).size
+ heatmap = np.zeros((h, w), dtype=np.float32)
+ indices = np.meshgrid(
+ np.linspace(0.0, float(w), num=w, endpoint=False),
+ np.linspace(0.0, float(h), num=h, endpoint=False),
+ )
+ for _, row in group_df.iterrows():
+ x_min, y_min, x_max, y_max = (
+ row["x_min"],
+ row["y_min"],
+ row["x_max"],
+ row["y_max"],
+ )
+ gauss = gaussian(x_min, y_min, x_max, y_max)
+ heatmap += gauss(*indices)
+ heatmap /= np.max(heatmap)
+ heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
+ output_filename = os.path.join(output_dir, image_name)
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+ heatmap_image.save(output_filename)
+
+
+def plot_gen(df, data_dir):
+ df = df[df["score"] > 0.8]
+ grouped = df.groupby("path")
+
+ output_dir = os.path.join(data_dir, "head_masks")
+
+ for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with predictions: "):
+        w, h = Image.open(os.path.join(data_dir, image_name)).size
+ heatmap = np.zeros((h, w), dtype=np.float32)
+ indices = np.meshgrid(
+ np.linspace(0.0, float(w), num=w, endpoint=False),
+ np.linspace(0.0, float(h), num=h, endpoint=False),
+ )
+ for index, row in group_df.iterrows():
+ x_min, y_min, x_max, y_max = (
+ row["x_min"],
+ row["y_min"],
+ row["x_max"],
+ row["y_max"],
+ )
+ gauss = gaussian(x_min, y_min, x_max, y_max)
+ heatmap += gauss(*indices)
+ heatmap /= np.max(heatmap)
+ heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
+ output_filename = os.path.join(output_dir, image_name)
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+ heatmap_image.save(output_filename)
+
+
+if __name__ == "__main__":
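+    # Example invocation (paths are illustrative; requires the retinaface package
+    # that provides retinaface.pre_trained_models, and a CUDA device):
+    #   python scripts/gen_gazefollow_head_masks.py --dataset_dir data/gazefollow --subset train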
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_dir", help="Root directory of dataset")
+ parser.add_argument(
+ "--subset",
+ help="Subset of dataset to process",
+ choices=["train", "test"],
+ )
+ args = parser.parse_args()
+
+ label_path = os.path.join(
+ args.dataset_dir, args.subset + "_annotations_release.txt"
+ )
+
+ column_names = [
+ "path",
+ "idx",
+ "body_bbox_x",
+ "body_bbox_y",
+ "body_bbox_w",
+ "body_bbox_h",
+ "eye_x",
+ "eye_y",
+ "gaze_x",
+ "gaze_y",
+ "bbox_x_min",
+ "bbox_y_min",
+ "bbox_x_max",
+ "bbox_y_max",
+ ]
+ df = pd.read_csv(
+ label_path,
+ sep=",",
+ names=column_names,
+ usecols=column_names,
+ index_col=False,
+ )
+ df = df.groupby("path")
+
+ model = get_model("resnet50_2020-07-20", max_size=2048, device="cuda")
+ model.eval()
+
+ paths = list(df.groups.keys())
+ csv = []
+ for path in tqdm.tqdm(paths, desc="Predicting head bboxes: "):
+ img = cv2.imread(os.path.join(args.dataset_dir, path))
+
+ annotations = model.predict_jsons(img)
+
+ for annotation in annotations:
+ if len(annotation["bbox"]) == 0:
+ continue
+
+ csv.append(
+ [
+ path,
+ annotation["score"],
+ annotation["bbox"][0],
+ annotation["bbox"][1],
+ annotation["bbox"][2],
+ annotation["bbox"][3],
+ ]
+ )
+
+ # Write csv
+ df = pd.DataFrame(
+ csv, columns=["path", "score", "x_min", "y_min", "x_max", "y_max"]
+ )
+ df.to_csv(os.path.join(args.dataset_dir, f"{args.subset}_head.csv"), index=False)
+
+ plot_gen(df, args.dataset_dir)
+    plot_ori(label_path, args.dataset_dir)
diff --git a/tools/eval_on_gazefollow.py b/tools/eval_on_gazefollow.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd51709ef4fe2671d59e2abde9c0bc21e389a86
--- /dev/null
+++ b/tools/eval_on_gazefollow.py
@@ -0,0 +1,92 @@
+import sys
+from os import path as osp
+import argparse
+import warnings
+import torch
+import numpy as np
+from PIL import Image
+from detectron2.config import instantiate, LazyConfig
+
+sys.path.append(osp.dirname(osp.dirname(__file__)))
+from utils import *
+
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def do_test(cfg, model, use_dark_inference=False):
+ val_loader = instantiate(cfg.dataloader.val)
+
+ model.train(False)
+ AUC = []
+ min_dist = []
+ avg_dist = []
+ with torch.no_grad():
+ for data in val_loader:
+ val_gaze_heatmap_pred, _ = model(data)
+ val_gaze_heatmap_pred = (
+ val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
+ )
+
+ # go through each data point and record AUC, min dist, avg dist
+ for b_i in range(len(val_gaze_heatmap_pred)):
+ # remove padding and recover valid ground truth points
+ valid_gaze = data["gazes"][b_i]
+ valid_gaze = valid_gaze[valid_gaze != -1].view(-1, 2)
+ # AUC: area under curve of ROC
+ multi_hot = multi_hot_targets(data["gazes"][b_i], data["imsize"][b_i])
+ if use_dark_inference:
+ pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
+ else:
+ pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
+                norm_p = [
+                    pred_x / val_gaze_heatmap_pred[b_i].shape[-1],
+                    pred_y / val_gaze_heatmap_pred[b_i].shape[-2],
+                ]
+ scaled_heatmap = np.array(
+ Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
+ data["imsize"][b_i],
+ resample=Image.BILINEAR,
+ )
+ )
+ auc_score = auc(scaled_heatmap, multi_hot)
+ AUC.append(auc_score)
+                # min distance: minimum distance between the prediction and any annotated gaze point
+ all_distances = []
+ for gt_gaze in valid_gaze:
+ all_distances.append(L2_dist(gt_gaze, norm_p))
+ min_dist.append(min(all_distances))
+ # average distance: distance between the predicted point and human average point
+ mean_gt_gaze = torch.mean(valid_gaze, 0)
+ avg_distance = L2_dist(mean_gt_gaze, norm_p)
+ avg_dist.append(avg_distance)
+
+ print("|AUC |min dist|avg dist|")
+ print(
+ "|{:.4f}|{:.4f} |{:.4f} |".format(
+ torch.mean(torch.tensor(AUC)),
+ torch.mean(torch.tensor(min_dist)),
+ torch.mean(torch.tensor(avg_dist)),
+ )
+ )
+
+
+def main(args):
+ cfg = LazyConfig.load(args.config_file)
+    model: torch.nn.Module = instantiate(cfg.model)
+ model.load_state_dict(torch.load(args.model_weights)["model"])
+ model.to(cfg.train.device)
+ do_test(cfg, model, use_dark_inference=args.use_dark_inference)
+
+
+if __name__ == "__main__":
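+    # Example (mirrors val.sh):
+    #   python tools/eval_on_gazefollow.py --config_file configs/gazefollow_518.py \
+    #       --model_weights output/gazefollow_518/model_final.pth --use_dark_inference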
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config_file", type=str, help="config file")
+ parser.add_argument(
+ "--model_weights",
+ type=str,
+ help="model weights",
+ )
+ parser.add_argument("--use_dark_inference", action="store_true")
+ args = parser.parse_args()
+ main(args)
diff --git a/tools/eval_on_video_attention_target.py b/tools/eval_on_video_attention_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae770fef9d5be39c0abcabc90f56962285902f61
--- /dev/null
+++ b/tools/eval_on_video_attention_target.py
@@ -0,0 +1,93 @@
+import sys
+from os import path as osp
+import argparse
+import warnings
+import torch
+import numpy as np
+from PIL import Image
+from detectron2.config import instantiate, LazyConfig
+
+sys.path.append(osp.dirname(osp.dirname(__file__)))
+from utils import *
+
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+
+
+def do_test(cfg, model, use_dark_inference=False):
+ val_loader = instantiate(cfg.dataloader.val)
+
+ model.train(False)
+ AUC = []
+ dist = []
+ inout_gt = []
+ inout_pred = []
+ with torch.no_grad():
+ for data in val_loader:
+ val_gaze_heatmap_pred, val_gaze_inout_pred = model(data)
+ val_gaze_heatmap_pred = (
+ val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
+ )
+ val_gaze_inout_pred = val_gaze_inout_pred.cpu().detach().numpy()
+
+ # go through each data point and record AUC, dist, ap
+ for b_i in range(len(val_gaze_heatmap_pred)):
+ auc_batch = []
+ dist_batch = []
+ if data["gaze_inouts"][b_i]:
+ # remove padding and recover valid ground truth points
+ valid_gaze = data["gazes"][b_i]
+ # AUC: area under curve of ROC
+ multi_hot = data["heatmaps"][b_i]
+ multi_hot = (multi_hot > 0).float().numpy()
+ if use_dark_inference:
+ pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
+ else:
+ pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
+ norm_p = [
+ pred_x / val_gaze_heatmap_pred[b_i].shape[-1],
+ pred_y / val_gaze_heatmap_pred[b_i].shape[-2],
+ ]
+ scaled_heatmap = np.array(
+ Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
+ (64, 64),
+ resample=Image.Resampling.BILINEAR,
+ )
+ )
+ auc_score = auc(scaled_heatmap, multi_hot)
+ auc_batch.append(auc_score)
+ dist_batch.append(L2_dist(valid_gaze.numpy(), norm_p))
+ AUC.extend(auc_batch)
+ dist.extend(dist_batch)
+ inout_gt.extend(data["gaze_inouts"].cpu().numpy())
+ inout_pred.extend(val_gaze_inout_pred)
+
+ print("|AUC |dist |AP |")
+ print(
+ "|{:.4f}|{:.4f} |{:.4f} |".format(
+ torch.mean(torch.tensor(AUC)),
+ torch.mean(torch.tensor(dist)),
+ ap(inout_gt, inout_pred),
+ )
+ )
+
+
+def main(args):
+ cfg = LazyConfig.load(args.config_file)
+    model: torch.nn.Module = instantiate(cfg.model)
+ model.load_state_dict(torch.load(args.model_weights)["model"])
+ model.to(cfg.train.device)
+ do_test(cfg, model, use_dark_inference=args.use_dark_inference)
+
+
+if __name__ == "__main__":
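+    # Example (mirrors val.sh):
+    #   python tools/eval_on_video_attention_target.py --config_file configs/videoattentiontarget.py \
+    #       --model_weights output/videoattentiontarget/model_final.pth --use_dark_inference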
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config_file", type=str, help="config file")
+ parser.add_argument(
+ "--model_weights",
+ type=str,
+ help="model weights",
+ )
+ parser.add_argument("--use_dark_inference", action="store_true")
+ args = parser.parse_args()
+ main(args)
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..93671beb24924c6e2e4cef92a4cb802bd10e5253
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,104 @@
+import os.path as osp
+import logging
+import warnings
+import sys
+
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import LazyConfig, instantiate
+from detectron2.engine import (
+ default_argument_parser,
+ default_setup,
+ default_writers,
+ hooks,
+ launch,
+)
+from detectron2.engine.defaults import create_ddp_model
+from detectron2.utils import comm
+
+
+sys.path.append(osp.dirname(osp.dirname(__file__)))
+warnings.filterwarnings("ignore")
+logger = logging.getLogger("detectron2")
+
+
+from engine import CycleTrainer
+
+
+def do_train(args, cfg):
+ """
+ Args:
+ cfg: an object with the following attributes:
+ model: instantiate to a module
+ dataloader.{train,test}: instantiate to dataloaders
+ dataloader.evaluator: instantiate to evaluator for test set
+            optimizer: instantiate to an optimizer
+ lr_multiplier: instantiate to a fvcore scheduler
+ train: other misc config defined in `configs/common/train.py`, including:
+ output_dir (str)
+ init_checkpoint (str)
+ amp.enabled (bool)
+ max_iter (int)
+ eval_period, log_period (int)
+ device (str)
+ checkpointer (dict)
+ ddp (dict)
+ """
+ model = instantiate(cfg.model)
+ logger = logging.getLogger("detectron2")
+ logger.info("Model:\n{}".format(model))
+ model.to(cfg.train.device)
+
+ cfg.optimizer.params.model = model
+ optim = instantiate(cfg.optimizer)
+
+ train_loader = instantiate(cfg.dataloader.train)
+
+ model = create_ddp_model(model, **cfg.train.ddp)
+ trainer = CycleTrainer(model, train_loader, optim)
+ checkpointer = DetectionCheckpointer(
+ model,
+ cfg.train.output_dir,
+ trainer=trainer,
+ )
+ trainer.register_hooks(
+ [
+ hooks.IterationTimer(),
+ hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
+ hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
+ if comm.is_main_process()
+ else None,
+ # hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
+ hooks.PeriodicWriter(
+ default_writers(cfg.train.output_dir, cfg.train.max_iter),
+ period=cfg.train.log_period,
+ )
+ if comm.is_main_process()
+ else None,
+ ]
+ )
+
+ checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
+ if args.resume and checkpointer.has_checkpoint():
+ start_iter = trainer.iter + 1
+ else:
+ start_iter = 0
+ trainer.train(start_iter, cfg.train.max_iter)
+
+
+def main(args):
+ cfg = LazyConfig.load(args.config_file)
+ cfg = LazyConfig.apply_overrides(cfg, args.opts)
+ default_setup(cfg, args)
+ do_train(args, cfg)
+
+
+if __name__ == "__main__":
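+    # Example (mirrors train.sh):
+    #   python tools/train.py --config-file configs/gazefollow.py --num-gpus 2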
+ args = default_argument_parser().parse_args()
+ launch(
+ main,
+ args.num_gpus,
+ num_machines=args.num_machines,
+ machine_rank=args.machine_rank,
+ dist_url=args.dist_url,
+ args=(args,),
+ )
diff --git a/tools/utils.py b/tools/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..633cfc5184fba56042259919d74542780bedf9da
--- /dev/null
+++ b/tools/utils.py
@@ -0,0 +1,164 @@
+from typing import Union, Iterable, Tuple
+import numpy as np
+import torch
+import cv2
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import average_precision_score
+
+
+def auc(heatmap, onehot_im, is_im=True):
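+    # Area under the ROC curve between the flattened predicted heatmap and the
+    # binary target map; with is_im=False both inputs are treated as already-flat vectors.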
+ if is_im:
+ auc_score = roc_auc_score(
+ np.reshape(onehot_im, onehot_im.size), np.reshape(heatmap, heatmap.size)
+ )
+ else:
+ auc_score = roc_auc_score(onehot_im, heatmap)
+ return auc_score
+
+
+def ap(label, pred):
+ return average_precision_score(label, pred)
+
+
+def argmax_pts(heatmap):
+ idx = np.unravel_index(heatmap.argmax(), heatmap.shape)
+ pred_y, pred_x = map(float, idx)
+ return pred_x, pred_y
+
+
+def L2_dist(p1, p2):
+ return np.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+
+
+def multi_hot_targets(gaze_pts, out_res):
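+    # Rasterize the normalized ground-truth gaze points onto a binary map of size
+    # out_res (width, height); padded entries with negative coordinates are skipped.
+    # Used as the target for the AUC metric.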
+ w, h = out_res
+ target_map = np.zeros((h, w))
+ for p in gaze_pts:
+ if p[0] >= 0:
+            x, y = map(int, [p[0] * float(w), p[1] * float(h)])
+ x = min(x, w - 1)
+ y = min(y, h - 1)
+ target_map[y, x] = 1
+ return target_map
+
+
+def inverse_transform(tensor: torch.Tensor) -> np.ndarray:
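+    # Undo the ImageNet mean/std normalization and return the first image of the
+    # batch as an HWC BGR uint8 array suitable for OpenCV drawing.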
+ tensor = tensor.detach().cpu().permute(0, 2, 3, 1)
+ mean = torch.tensor([0.485, 0.456, 0.406])
+ std = torch.tensor([0.229, 0.224, 0.225])
+ tensor = tensor * std + mean
+ return cv2.cvtColor((tensor.numpy() * 255).astype(np.uint8)[0], cv2.COLOR_RGB2BGR)
+
+
+def draw(data, heatmap, out_path, on_img=True):
+ img = inverse_transform(data["images"])
+ head_channel = cv2.applyColorMap(
+ (data["head_channels"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8),
+ cv2.COLORMAP_BONE,
+ )
+ hm = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
+ heatmap = hm
+ heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
+ if on_img:
+ img = cv2.addWeighted(img, 1, heatmap, 0.5, 1)
+ else:
+ img = heatmap
+ # img = cv2.addWeighted(img, 1, head_channel, 0.1, 1)
+ cv2.imwrite(out_path, img)
+
+
+def draw_origin_img(data, out_path):
+ img = inverse_transform(data["images"])
+ hm = cv2.applyColorMap(
+ (data["heatmaps"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8),
+ cv2.COLORMAP_JET,
+ )
+ hm[data["heatmaps"].squeeze().detach().cpu().numpy() == 0] = 0
+ hm = cv2.resize(hm, (img.shape[1], img.shape[0]))
+ head_channel = cv2.applyColorMap(
+ (data["head_channels"].squeeze().detach().cpu().numpy() * 255).astype(np.uint8),
+ cv2.COLORMAP_BONE,
+ )
+ head_channel[data["head_channels"].squeeze().detach().cpu().numpy() < 0.1] = 0
+ ori = cv2.addWeighted(img, 1, hm, 0.5, 1)
+ ori = cv2.addWeighted(ori, 1, head_channel, 0.1, 1)
+ cv2.imwrite(out_path, ori)
+
+
+class __Image2MP4:
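+    # Write an iterable of BGR frames (or a directory of image files, read in sorted
+    # order) to an MP4 file with OpenCV's mp4v codec, resizing every frame to a common size.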
+ def __init__(self):
+ self.Fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+
+ def __call__(
+ self,
+ frames: Union[Iterable[np.ndarray], str],
+ path: str,
+ fps: float = 30.0,
+ isize: Tuple[int, int] = None,
+ ):
+ if isinstance(frames, str): # directory of img files
+ from os import listdir, path as osp
+
+ imgs = sorted(listdir(frames))
+ frames = [
+ cv2.imread(osp.join(frames, img), cv2.IMREAD_COLOR) for img in imgs
+ ]
+
+ if isize is None:
+ isize = (frames[0].shape[1], frames[0].shape[0])
+
+ output_video = cv2.VideoWriter(path, self.Fourcc, fps, isize)
+ for frame in frames:
+ frame = cv2.resize(frame, isize)
+ output_video.write(frame)
+ output_video.release()
+
+
+img2mp4 = __Image2MP4()
+
+
+def dark_inference(heatmap: np.ndarray, gaussian_kernel: int = 39):
+ pred_x, pred_y = argmax_pts(heatmap)
+ pred_x, pred_y = int(pred_x), int(pred_y)
+ height, width = heatmap.shape[-2:]
+ # Gaussian blur
+ orig_max = heatmap.max()
+ border = (gaussian_kernel - 1) // 2
+ dr = np.zeros((height + 2 * border, width + 2 * border))
+ dr[border:-border, border:-border] = heatmap.copy()
+ dr = cv2.GaussianBlur(dr, (gaussian_kernel, gaussian_kernel), 0)
+ heatmap = dr[border:-border, border:-border].copy()
+ heatmap *= orig_max / np.max(heatmap)
+ # Log-likelihood
+ heatmap = np.maximum(heatmap, 1e-10)
+ heatmap = np.log(heatmap)
+ # DARK
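+    # i.e. Distribution-Aware coordinate Representation (Zhang et al., CVPR 2020):
+    # refine the integer argmax with a second-order Taylor expansion of the
+    # log-heatmap around the peak, giving offset = -Hessian^{-1} @ gradient.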
+ if 1 < pred_x < width - 2 and 1 < pred_y < height - 2:
+ dx = 0.5 * (heatmap[pred_y][pred_x + 1] - heatmap[pred_y][pred_x - 1])
+ dy = 0.5 * (heatmap[pred_y + 1][pred_x] - heatmap[pred_y - 1][pred_x])
+ dxx = 0.25 * (
+ heatmap[pred_y][pred_x + 2]
+ - 2 * heatmap[pred_y][pred_x]
+ + heatmap[pred_y][pred_x - 2]
+ )
+ dxy = 0.25 * (
+ heatmap[pred_y + 1][pred_x + 1]
+ - heatmap[pred_y - 1][pred_x + 1]
+ - heatmap[pred_y + 1][pred_x - 1]
+ + heatmap[pred_y - 1][pred_x - 1]
+ )
+ dyy = 0.25 * (
+ heatmap[pred_y + 2][pred_x]
+ - 2 * heatmap[pred_y][pred_x]
+ + heatmap[pred_y - 2][pred_x]
+ )
+        derivative = np.matrix([[dx], [dy]])
+        hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
+ if dxx * dyy - dxy ** 2 != 0:
+ hessianinv = hessian.I
+ offset = -hessianinv * derivative
+ offset_x, offset_y = np.squeeze(np.array(offset.T), axis=0)
+ pred_x += offset_x
+ pred_y += offset_y
+ return pred_x, pred_y
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0b942ebde34c3f91b61a6be836fee7901f72c829
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+config_files=(
+ "configs/gazefollow.py"
+ "configs/gazefollow_518.py"
+ "configs/videoattentiontarget.py"
+)
+
+run_experiment() {
+ local config="$1"
+ echo "Running experiment with config: $config"
+    python -u tools/train.py --config-file "$config" --num-gpus 2
+}
+
+for config in "${config_files[@]}"
+do
+    run_experiment "$config"
+ sleep 10
+done
diff --git a/val.sh b/val.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7e043f9ecf0481df6cd096863427f9fa3029e349
--- /dev/null
+++ b/val.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# [Usage] bash val.sh [$config_file] [$weight_file] [$(gf/vat)]; e.g.
+# bash val.sh configs/gazefollow_518.py output/gazefollow_518/model_final.pth gf
+# bash val.sh configs/videoattentiontarget.py output/videoattentiontarget/model_final.pth vat
+
+config_file="$1"
+checkpoint="$2"
+
+if [ "$3" = "gf" ]; then
+    evaluator="tools/eval_on_gazefollow.py"
+elif [ "$3" = "vat" ]; then
+    evaluator="tools/eval_on_video_attention_target.py"
+else
+ echo "Invalid dataset"
+ exit 1
+fi
+
+export CUDA_VISIBLE_DEVICES="0"
+echo "Evaluating with:"
+echo "config: $config_file"
+echo "checkpoint: $checkpoint"
+python -u "$evaluator" --config_file "$config_file" --model_weights "$checkpoint" --use_dark_inference