yhsong committed
Commit f9561b9 · verified · 1 Parent(s): 64c8ca3

initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +16 -0
  3. LICENSE +21 -0
  4. README.md +78 -3
  5. assets/comparion.png +0 -0
  6. assets/demo0.gif +3 -0
  7. assets/demo1.gif +3 -0
  8. assets/pipeline.png +0 -0
  9. configs/common/__init__.py +1 -0
  10. configs/common/dataloader.py +214 -0
  11. configs/common/model.py +44 -0
  12. configs/common/optimizer.py +48 -0
  13. configs/common/scheduler.py +18 -0
  14. configs/common/train.py +16 -0
  15. configs/gazefollow.py +86 -0
  16. configs/gazefollow_518.py +86 -0
  17. configs/videoattentiontarget.py +90 -0
  18. data/__init__.py +5 -0
  19. data/augmentation.py +312 -0
  20. data/data_utils.py +181 -0
  21. data/gazefollow.py +295 -0
  22. data/masking.py +175 -0
  23. data/video_attention_target.py +228 -0
  24. data/video_attention_target_video.py +464 -0
  25. docs/eval.md +24 -0
  26. docs/install.md +19 -0
  27. docs/train.md +36 -0
  28. engine/__init__.py +1 -0
  29. engine/trainer.py +37 -0
  30. modeling/__init__.py +4 -0
  31. modeling/backbone/__init__.py +1 -0
  32. modeling/backbone/utils.py +154 -0
  33. modeling/backbone/vit.py +504 -0
  34. modeling/criterion/__init__.py +1 -0
  35. modeling/criterion/gaze_mapper_criterion.py +94 -0
  36. modeling/head/__init__.py +2 -0
  37. modeling/head/heatmap_head.py +225 -0
  38. modeling/head/inout_head.py +55 -0
  39. modeling/meta_arch/__init__.py +1 -0
  40. modeling/meta_arch/gaze_attention_mapper.py +97 -0
  41. modeling/patch_attention/__init__.py +1 -0
  42. modeling/patch_attention/patch_attention.py +106 -0
  43. pretrained/gazefollow.pth +3 -0
  44. pretrained/videoattentiontarget.pth +3 -0
  45. requirements.txt +17 -0
  46. scripts/convert_pth.py +32 -0
  47. scripts/gen_gazefollow_head_masks.py +185 -0
  48. tools/eval_on_gazefollow.py +92 -0
  49. tools/eval_on_video_attention_target.py +93 -0
  50. tools/train.py +104 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo0.gif filter=lfs diff=lfs merge=lfs -text
+ assets/demo1.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
+ .vscode
+ .idea
+ *output
+ backup
+ */__pycache__/*
+ pretrained
+ logs
+ *.pyc
+ *.ipynb
+ *.out
+ *.log
+ *.pth
+ *.pkl
+ *.pt
+ *.npy
+ *debug*
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 HUST Vision Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,78 @@
- ---
- license: mit
- ---
+ <div align="center">
+ <h1>ViTGaze 👀</h1>
+ <h3>Gaze Following with Interaction Features in Vision Transformers</h3>
+
+ Yuehao Song<sup>1</sup> , Xinggang Wang<sup>1 :email:</sup> , Jingfeng Yao<sup>1</sup> , Wenyu Liu<sup>1</sup> , Jinglin Zhang<sup>2</sup> , Xiangmin Xu<sup>3</sup>
+
+ <sup>1</sup> Huazhong University of Science and Technology, <sup>2</sup> Shandong University, <sup>3</sup> South China University of Technology
+
+ (<sup>:email:</sup>) corresponding author.
+
+ ArXiv Preprint ([arXiv 2403.12778](https://arxiv.org/abs/2403.12778))
+
+ </div>
+
+ #
+ ![Demo0](assets/demo0.gif)
+ ![Demo1](assets/demo1.gif)
+ ### News
+ * **`Mar. 25th, 2024`:** We released an initial version of ViTGaze.
+ * **`Mar. 19th, 2024`:** We released our paper on arXiv. Code and models are coming soon. Please stay tuned! ☕️
+
+
+ ## Introduction
+ <div align="center"><h5>Plain Vision Transformers can also do gaze following with the simple ViTGaze framework!</h5></div>
+
+ ![framework](assets/pipeline.png "framework")
+
+ Inspired by the remarkable success of pre-trained plain Vision Transformers (ViTs), we introduce a novel single-modality gaze-following framework, **ViTGaze**. In contrast to previous methods, it builds a brand-new gaze-following framework based mainly on a powerful encoder (the decoder accounts for less than 1% of the parameters). Our principal insight is that the inter-token interactions within self-attention can be transferred to interactions between humans and scenes. Our method achieves state-of-the-art (SOTA) performance among all single-modality methods (3.4% improvement on AUC, 5.1% improvement on AP) and performance comparable to multi-modality methods with 59% fewer parameters.
+
+ ## Results
+ > Results from the [ViTGaze paper](https://arxiv.org/abs/2403.12778)
+
+ ![comparison](assets/comparion.png "comparison")
+
+ <table align="center">
+ <tr>
+ <th colspan="3">Results on <a href=http://gazefollow.csail.mit.edu/index.html>GazeFollow</a></th>
+ <th colspan="3">Results on <a href=https://github.com/ejcgt/attention-target-detection>VideoAttentionTarget</a></th>
+ </tr>
+ <tr>
+ <td><b>AUC</b></td>
+ <td><b>Avg. Dist.</b></td>
+ <td><b>Min. Dist.</b></td>
+ <td><b>AUC</b></td>
+ <td><b>Dist.</b></td>
+ <td><b>AP</b></td>
+ </tr>
+ <tr>
+ <td>0.949</td>
+ <td>0.105</td>
+ <td>0.047</td>
+ <td>0.938</td>
+ <td>0.102</td>
+ <td>0.905</td>
+ </tr>
+ </table>
+
+ Corresponding checkpoints are released:
+ - GazeFollow: [GoogleDrive](https://drive.google.com/file/d/164c4woGCmUI8UrM7GEKQrV1FbA3vGwP4/view?usp=drive_link)
+ - VideoAttentionTarget: [GoogleDrive](https://drive.google.com/file/d/11_O4Jm5wsvQ8qfLLgTlrudqSNvvepsV0/view?usp=drive_link)
+ ## Getting Started
+ - [Installation](docs/install.md)
+ - [Train](docs/train.md)
+ - [Eval](docs/eval.md)
+
+ ## Acknowledgements
+ ViTGaze is based on [detectron2](https://github.com/facebookresearch/detectron2). We use the efficient multi-head attention implemented in the [xFormers](https://github.com/facebookresearch/xformers) library.
+
+ ## Citation
+ If you find ViTGaze useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
+ ```bibtex
+ @article{vitgaze,
+   title={ViTGaze: Gaze Following with Interaction Features in Vision Transformers},
+   author={Yuehao Song and Xinggang Wang and Jingfeng Yao and Wenyu Liu and Jinglin Zhang and Xiangmin Xu},
+   journal={arXiv preprint arXiv:2403.12778},
+   year={2024}
+ }
+ ```
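
The "interaction features" idea from the Introduction can be pictured in a few lines: attention maps taken from selected ViT blocks, read out at the tokens covered by the head box, already behave like a coarse person-scene interaction map. The sketch below is purely illustrative (dummy tensors, hypothetical head-box coordinates and head count), not the released implementation.

```python
# Illustrative sketch of "attention as interaction features" (not the released code).
# Assume per-head softmax attention from one ViT block: (heads, 1 + N, 1 + N),
# where N = grid * grid patch tokens and index 0 is the CLS token.
import torch

heads, grid = 6, 37                       # e.g. 518 / 14 = 37 patches per side
N = grid * grid
attn = torch.rand(heads, 1 + N, 1 + N)    # dummy attention
attn = attn / attn.sum(-1, keepdim=True)

# Patch indices covered by the person's head box (hypothetical, normalized coords).
x0, y0, x1, y1 = 0.55, 0.10, 0.70, 0.30
cols = torch.arange(int(x0 * grid), int(x1 * grid) + 1)
rows = torch.arange(int(y0 * grid), int(y1 * grid) + 1)
head_idx = (rows[:, None] * grid + cols[None, :]).flatten() + 1  # +1 skips CLS

# Average attention from head tokens to every patch token -> coarse interaction map.
interaction = attn[:, head_idx, 1:].mean(1)            # (heads, N)
interaction_map = interaction.reshape(heads, grid, grid)
print(interaction_map.shape)                            # torch.Size([6, 37, 37])
```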
assets/comparion.png ADDED
assets/demo0.gif ADDED

Git LFS Details

  • SHA256: 89e36ce1438c230928f376b4a2668e5daabace217a63e1df095ddce7202851ec
  • Pointer size: 132 Bytes
  • Size of remote file: 5.92 MB
assets/demo1.gif ADDED

Git LFS Details

  • SHA256: 3d5c7a413de2f9ff12818d0f476f868e5fb867c17d85315cce3eb090212468da
  • Pointer size: 132 Bytes
  • Size of remote file: 9.48 MB
assets/pipeline.png ADDED
configs/common/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import *
configs/common/dataloader.py ADDED
@@ -0,0 +1,214 @@
+ from os import path as osp
+ from typing import Literal
+
+ from omegaconf import OmegaConf
+ from detectron2.config import LazyCall as L
+ from detectron2.config import instantiate
+ from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+
+ from data import *
+
+
+ DATA_ROOT = "${Root to Datasets}"
+ if DATA_ROOT == "${Root to Datasets}":
+     raise Exception(
+         f"""{osp.abspath(__file__)}: Rewrite `DATA_ROOT` with the root to the datasets.
+         The directory structure should be:
+         -DATA_ROOT
+         |-videoattentiontarget
+         |--images
+         |--annotations
+         |---train
+         |---test
+         |--head_masks
+         |---images
+         |-gazefollow
+         |--train
+         |--test2
+         |--train_annotations_release.txt
+         |--test_annotations_release.txt
+         |--head_masks
+         |---train
+         |---test2
+         """
+     )
+
+ # Basic Config for Video Attention Target dataset and preprocessing
+ data_info = OmegaConf.create()
+ data_info.video_attention_target = OmegaConf.create()
+ data_info.video_attention_target.train_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/images"
+ )
+ data_info.video_attention_target.train_anno = osp.join(
+     DATA_ROOT, "videoattentiontarget/annotations/train"
+ )
+ data_info.video_attention_target.val_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/images"
+ )
+ data_info.video_attention_target.val_anno = osp.join(
+     DATA_ROOT, "videoattentiontarget/annotations/test"
+ )
+ data_info.video_attention_target.head_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/head_masks/images"
+ )
+
+ data_info.video_attention_target_video = OmegaConf.create()
+ data_info.video_attention_target_video.train_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/images"
+ )
+ data_info.video_attention_target_video.train_anno = osp.join(
+     DATA_ROOT, "videoattentiontarget/annotations/train"
+ )
+ data_info.video_attention_target_video.val_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/images"
+ )
+ data_info.video_attention_target_video.val_anno = osp.join(
+     DATA_ROOT, "videoattentiontarget/annotations/test"
+ )
+ data_info.video_attention_target_video.head_root = osp.join(
+     DATA_ROOT, "videoattentiontarget/head_masks/images"
+ )
+
+ data_info.gazefollow = OmegaConf.create()
+ data_info.gazefollow.train_root = osp.join(DATA_ROOT, "gazefollow")
+ data_info.gazefollow.train_anno = osp.join(
+     DATA_ROOT, "gazefollow/train_annotations_release.txt"
+ )
+ data_info.gazefollow.val_root = osp.join(DATA_ROOT, "gazefollow")
+ data_info.gazefollow.val_anno = osp.join(
+     DATA_ROOT, "gazefollow/test_annotations_release.txt"
+ )
+ data_info.gazefollow.head_root = osp.join(DATA_ROOT, "gazefollow/head_masks")
+
+ data_info.input_size = 224
+ data_info.output_size = 64
+ data_info.quant_labelmap = True
+ data_info.mean = (0.485, 0.456, 0.406)
+ data_info.std = (0.229, 0.224, 0.225)
+ data_info.bbox_jitter = 0.5
+ data_info.rand_crop = 0.5
+ data_info.rand_flip = 0.5
+ data_info.color_jitter = 0.5
+ data_info.rand_rotate = 0.0
+ data_info.rand_lsj = 0.0
+
+ data_info.mask_size = 24
+ data_info.mask_scene = False
+ data_info.mask_head = False
+ data_info.max_scene_patches_ratio = 0.5
+ data_info.max_head_patches_ratio = 0.3
+ data_info.mask_prob = 0.2
+
+ data_info.seq_len = 16
+ data_info.max_len = 32
+
+
+ # Dataloader (gazefollow/video_attention_target, train/val)
+ def __build_dataloader(
+     name: Literal[
+         "gazefollow", "video_attention_target", "video_attention_target_video"
+     ],
+     is_train: bool,
+     batch_size: int = 64,
+     num_workers: int = 14,
+     pin_memory: bool = True,
+     persistent_workers: bool = True,
+     drop_last: bool = True,
+     distributed: bool = False,
+     **kwargs,
+ ):
+     assert name in [
+         "gazefollow",
+         "video_attention_target",
+         "video_attention_target_video",
+     ], f'{name} not in ("gazefollow", "video_attention_target", "video_attention_target_video")'
+
+     for k, v in kwargs.items():
+         if k in ["train_root", "train_anno", "val_root", "val_anno", "head_root"]:
+             data_info[name][k] = v
+         else:
+             data_info[k] = v
+
+     datasets = {
+         "gazefollow": GazeFollow,
+         "video_attention_target": VideoAttentionTarget,
+         "video_attention_target_video": VideoAttentionTargetVideo,
+     }
+     dataset = L(datasets[name])(
+         image_root=data_info[name]["train_root" if is_train else "val_root"],
+         anno_root=data_info[name]["train_anno" if is_train else "val_anno"],
+         head_root=data_info[name]["head_root"],
+         transform=get_transform(
+             input_resolution=data_info.input_size,
+             mean=data_info.mean,
+             std=data_info.std,
+         ),
+         input_size=data_info.input_size,
+         output_size=data_info.output_size,
+         quant_labelmap=data_info.quant_labelmap,
+         is_train=is_train,
+         bbox_jitter=data_info.bbox_jitter,
+         rand_crop=data_info.rand_crop,
+         rand_flip=data_info.rand_flip,
+         color_jitter=data_info.color_jitter,
+         rand_rotate=data_info.rand_rotate,
+         rand_lsj=data_info.rand_lsj,
+         mask_generator=(
+             MaskGenerator(
+                 input_size=data_info.mask_size,
+                 mask_scene=data_info.mask_scene,
+                 mask_head=data_info.mask_head,
+                 max_scene_patches_ratio=data_info.max_scene_patches_ratio,
+                 max_head_patches_ratio=data_info.max_head_patches_ratio,
+                 mask_prob=data_info.mask_prob,
+             )
+             if is_train
+             else None
+         ),
+     )
+     if name == "video_attention_target_video":
+         dataset.seq_len = data_info.seq_len
+         dataset.max_len = data_info.max_len
+     dataset = instantiate(dataset)
+
+     return DataLoader(
+         dataset=dataset,
+         batch_size=batch_size,
+         num_workers=num_workers,
+         pin_memory=pin_memory,
+         persistent_workers=persistent_workers,
+         collate_fn=video_collate if name == "video_attention_target_video" else None,
+         sampler=DistributedSampler(dataset, shuffle=is_train) if distributed else None,
+         drop_last=drop_last,
+     )
+
+
+ dataloader = OmegaConf.create()
+ dataloader.gazefollow = OmegaConf.create()
+ dataloader.gazefollow.train = L(__build_dataloader)(
+     name="gazefollow",
+     is_train=True,
+ )
+ dataloader.gazefollow.val = L(__build_dataloader)(
+     name="gazefollow",
+     is_train=False,
+ )
+ dataloader.video_attention_target = OmegaConf.create()
+ dataloader.video_attention_target.train = L(__build_dataloader)(
+     name="video_attention_target",
+     is_train=True,
+ )
+ dataloader.video_attention_target.val = L(__build_dataloader)(
+     name="video_attention_target",
+     is_train=False,
+ )
+ dataloader.video_attention_target_video = OmegaConf.create()
+ dataloader.video_attention_target_video.train = L(__build_dataloader)(
+     name="video_attention_target_video",
+     is_train=True,
+ )
+ dataloader.video_attention_target_video.val = L(__build_dataloader)(
+     name="video_attention_target_video",
+     is_train=False,
+ )
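
As a usage note (not part of the commit): configs like this are consumed through detectron2's lazy-config machinery, so a dataloader can be built roughly as sketched below once `DATA_ROOT` has been rewritten and the datasets are laid out as in the docstring above.

```python
# Minimal sketch of instantiating a configured dataloader (assumes DATA_ROOT is set
# in configs/common/dataloader.py and the GazeFollow data is in place).
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("configs/gazefollow.py")
train_loader = instantiate(cfg.dataloader.train)   # calls __build_dataloader(...)

batch = next(iter(train_loader))
print(batch["images"].shape, batch["heatmaps"].shape)
```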
configs/common/model.py ADDED
@@ -0,0 +1,44 @@
+ from detectron2.config import LazyCall as L
+
+ from modeling import backbone, patch_attention, meta_arch, head, criterion
+
+
+ model = L(meta_arch.GazeAttentionMapper)()
+ model.backbone = L(backbone.build_backbone)(
+     name="small", out_attn=[2, 5, 8, 11]
+ )
+ model.pam = L(patch_attention.build_pam)(name="PatchPAM", patch_size=16)
+ model.regressor = L(head.build_heatmap_head)(
+     name="SimpleDeconv",
+     in_channel=24,
+     deconv_cfgs=[
+         {
+             "in_channels": 24,
+             "out_channels": 12,
+             "kernel_size": 3,
+             "stride": 2,
+         },
+         {
+             "in_channels": 12,
+             "out_channels": 6,
+             "kernel_size": 3,
+             "stride": 2,
+         },
+         {
+             "in_channels": 6,
+             "out_channels": 3,
+             "kernel_size": 3,
+             "stride": 2,
+         },
+         {
+             "in_channels": 3,
+             "out_channels": 1,
+             "kernel_size": 3,
+             "stride": 2,
+         },
+     ],
+     feat_type="attn",
+ )
+ model.classifier = L(head.build_inout_head)(name="SimpleLinear", in_channel=384)
+ model.criterion = L(criterion.GazeMapperCriterion)()
+ model.device = "cuda"
configs/common/optimizer.py ADDED
@@ -0,0 +1,48 @@
+ from detectron2 import model_zoo
+
+
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+     """
+     Calculate lr decay rate for different ViT blocks.
+     Args:
+         name (string): parameter name.
+         lr_decay_rate (float): base lr decay rate.
+         num_layers (int): number of ViT blocks.
+
+     Returns:
+         lr decay rate for the given parameter.
+     """
+     layer_id = num_layers + 1
+     if name.startswith("backbone"):
+         if ".pos_embed" in name or ".patch_embed" in name:
+             layer_id = 0
+         elif ".blocks." in name and ".residual." not in name:
+             layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+
+     return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+ class LRDecayRater:
+     def __init__(self, lr_decay_rate=1.0, num_layers=12, backbone_multiplier=1.0, freeze_pe=False, pam_lr_decay=1):
+         self.lr_decay_rate = lr_decay_rate
+         self.num_layers = num_layers
+         self.backbone_multiplier = backbone_multiplier
+         self.freeze_pe = freeze_pe
+         self.pam_lr_decay = pam_lr_decay
+
+     def __call__(self, name):
+         if name.startswith("backbone"):
+             if self.freeze_pe and (".pos_embed" in name or ".patch_embed" in name):
+                 return 0
+             return self.backbone_multiplier * get_vit_lr_decay_rate(
+                 name, self.lr_decay_rate, self.num_layers
+             )
+         if name.startswith("pam"):
+             return self.pam_lr_decay
+         return 1
+
+
+ # Optimizer
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
+ optimizer.params.lr_factor_func = LRDecayRater(num_layers=12, lr_decay_rate=0.65)
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
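
For intuition (a worked example, not part of the commit): with `lr_decay_rate=0.65` and 12 blocks, the factor returned by `get_vit_lr_decay_rate` shrinks geometrically toward the input end of the backbone, while non-backbone parameters keep the full learning rate.

```python
# Layer-wise lr decay factors for a 12-block ViT with base decay 0.65.
from configs.common.optimizer import get_vit_lr_decay_rate

print(get_vit_lr_decay_rate("backbone.blocks.11.mlp.fc1.weight", 0.65, 12))  # 0.65 ** 1
print(get_vit_lr_decay_rate("backbone.blocks.0.attn.qkv.weight", 0.65, 12))  # 0.65 ** 12
print(get_vit_lr_decay_rate("backbone.pos_embed", 0.65, 12))                 # 0.65 ** 13
print(get_vit_lr_decay_rate("regressor.deconv.0.weight", 0.65, 12))          # 1.0
```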
configs/common/scheduler.py ADDED
@@ -0,0 +1,18 @@
+ from typing import Literal
+ from detectron2.config import LazyCall as L
+ from detectron2.solver import WarmupParamScheduler
+ from fvcore.common.param_scheduler import MultiStepParamScheduler, CosineParamScheduler
+
+
+ def get_scheduler(typ: Literal["multistep", "cosine"] = "multistep", **kwargs):
+     if typ == "multistep":
+         return MultiStepParamScheduler(**kwargs)
+     elif typ == "cosine":
+         return CosineParamScheduler(**kwargs)
+
+
+ lr_multiplier = L(WarmupParamScheduler)(
+     scheduler=L(get_scheduler)(),
+     warmup_length=0,
+     warmup_factor=0.001,
+ )
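
A quick check of what the GazeFollow setting produces (a sketch, not part of the commit): with a cosine schedule from 1 to 0.1 and `warmup_length=1e-2`, the learning-rate multiplier ramps up over the first 1% of training and then decays toward 0.1 of the base learning rate.

```python
# Sketch: the lr multiplier used by configs/gazefollow.py, evaluated at a few
# points of training progress (a fraction in [0, 1)).
from detectron2.solver import WarmupParamScheduler
from fvcore.common.param_scheduler import CosineParamScheduler

lr_multiplier = WarmupParamScheduler(
    scheduler=CosineParamScheduler(1.0, 0.1),
    warmup_length=0.01,
    warmup_factor=0.001,
)
for where in (0.0, 0.005, 0.02, 0.5, 0.99):
    print(where, round(lr_multiplier(where), 4))
```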
configs/common/train.py ADDED
@@ -0,0 +1,16 @@
+ train = dict(
+     output_dir="./output",
+     init_checkpoint="",
+     max_iter=90000,
+     amp=dict(enabled=False),  # options for Automatic Mixed Precision
+     ddp=dict(  # options for DistributedDataParallel
+         broadcast_buffers=True,
+         find_unused_parameters=False,
+         fp16_compression=True,
+     ),
+     checkpointer=dict(period=5000, max_to_keep=100),  # options for PeriodicCheckpointer
+     eval_period=5000,
+     log_period=100,
+     device="cuda",
+     # ...
+ )
configs/gazefollow.py ADDED
@@ -0,0 +1,86 @@
+ from .common.dataloader import dataloader
+ from .common.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .common.train import train
+ from os.path import join, basename
+ from torch.cuda import device_count
+
+
+ num_gpu = device_count()
+ ins_per_iter = 48
+ len_dataset = 126000
+ num_epoch = 14
+ # dataloader
+ dataloader = dataloader.gazefollow
+ dataloader.train.batch_size = ins_per_iter // num_gpu
+ dataloader.train.num_workers = dataloader.val.num_workers = 14
+ dataloader.train.distributed = num_gpu > 1
+ dataloader.train.rand_rotate = 0.5
+ dataloader.train.rand_lsj = 0.5
+ dataloader.train.input_size = dataloader.val.input_size = 434
+ dataloader.train.mask_scene = True
+ dataloader.train.mask_prob = 0.5
+ dataloader.train.mask_size = dataloader.train.input_size // 14
+ dataloader.train.max_scene_patches_ratio = 0.5
+ dataloader.val.batch_size = 64
+ dataloader.val.distributed = False
+ # train
+ train.init_checkpoint = "pretrained/dinov2_small.pth"
+ train.output_dir = join("./output", basename(__file__).split(".")[0])
+ train.max_iter = len_dataset * num_epoch // ins_per_iter
+ train.log_period = len_dataset // (ins_per_iter * 100)
+ train.checkpointer.max_to_keep = 10
+ train.checkpointer.period = len_dataset // ins_per_iter
+ train.seed = 0
+ # optimizer
+ optimizer.lr = 1e-4
+ optimizer.betas = (0.9, 0.99)
+ lr_multiplier.scheduler.typ = "cosine"
+ lr_multiplier.scheduler.start_value = 1
+ lr_multiplier.scheduler.end_value = 0.1
+ lr_multiplier.warmup_length = 1e-2
+ # model
+ model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+ model.pam.name = "PatchPAM"
+ model.pam.embed_dim = 8
+ model.pam.patch_size = 14
+ model.backbone.name = "dinov2_small"
+ model.backbone.return_softmax_attn = True
+ model.backbone.out_attn = [2, 5, 8, 11]
+ model.backbone.use_cls_token = True
+ model.backbone.use_mask_token = True
+ model.regressor.name = "UpSampleConv"
+ model.regressor.in_channel = 24
+ model.regressor.use_conv = False
+ model.regressor.dim = 24
+ model.regressor.deconv_cfgs = [
+     dict(
+         in_channels=24,
+         out_channels=16,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=16,
+         out_channels=8,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=8,
+         out_channels=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+ ]
+ model.regressor.feat_type = "attn"
+ model.classifier.name = "SimpleMlp"
+ model.classifier.in_channel = 384
+ model.criterion.aux_weight = 100
+ model.criterion.aux_head_thres = 0.05
+ model.criterion.use_focal_loss = True
+ model.device = "cuda"
configs/gazefollow_518.py ADDED
@@ -0,0 +1,86 @@
+ from .common.dataloader import dataloader
+ from .common.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .common.train import train
+ from os.path import join, basename
+ from torch.cuda import device_count
+
+
+ num_gpu = device_count()
+ ins_per_iter = 32
+ len_dataset = 126000
+ num_epoch = 1
+ # dataloader
+ dataloader = dataloader.gazefollow
+ dataloader.train.batch_size = ins_per_iter // num_gpu
+ dataloader.train.num_workers = dataloader.val.num_workers = 14
+ dataloader.train.distributed = num_gpu > 1
+ dataloader.train.rand_rotate = 0.5
+ dataloader.train.rand_lsj = 0.5
+ dataloader.train.input_size = dataloader.val.input_size = 518
+ dataloader.train.mask_scene = True
+ dataloader.train.mask_prob = 0.5
+ dataloader.train.mask_size = dataloader.train.input_size // 14
+ dataloader.train.max_scene_patches_ratio = 0.5
+ dataloader.val.batch_size = 32
+ dataloader.val.distributed = False
+ # train
+ train.init_checkpoint = "output/gazefollow/model_final.pth"
+ train.output_dir = join("./output", basename(__file__).split(".")[0])
+ train.max_iter = len_dataset * num_epoch // ins_per_iter
+ train.log_period = len_dataset // (ins_per_iter * 100)
+ train.checkpointer.max_to_keep = 10
+ train.checkpointer.period = len_dataset // ins_per_iter
+ train.seed = 0
+ # optimizer
+ optimizer.lr = 1e-5
+ optimizer.betas = (0.9, 0.99)
+ lr_multiplier.scheduler.typ = "cosine"
+ lr_multiplier.scheduler.start_value = 1
+ lr_multiplier.scheduler.end_value = 0.1
+ lr_multiplier.warmup_length = 1e-2
+ # model
+ model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+ model.pam.name = "PatchPAM"
+ model.pam.embed_dim = 8
+ model.pam.patch_size = 14
+ model.backbone.name = "dinov2_small"
+ model.backbone.return_softmax_attn = True
+ model.backbone.out_attn = [2, 5, 8, 11]
+ model.backbone.use_cls_token = True
+ model.backbone.use_mask_token = True
+ model.regressor.name = "UpSampleConv"
+ model.regressor.in_channel = 24
+ model.regressor.use_conv = False
+ model.regressor.dim = 24
+ model.regressor.deconv_cfgs = [
+     dict(
+         in_channels=24,
+         out_channels=16,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=16,
+         out_channels=8,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=8,
+         out_channels=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+ ]
+ model.regressor.feat_type = "attn"
+ model.classifier.name = "SimpleMlp"
+ model.classifier.in_channel = 384
+ model.criterion.aux_weight = 0
+ model.criterion.aux_head_thres = 0.05
+ model.criterion.use_focal_loss = True
+ model.device = "cuda"
configs/videoattentiontarget.py ADDED
@@ -0,0 +1,90 @@
+ from .common.dataloader import dataloader
+ from .common.model import model
+ from .common.optimizer import optimizer
+ from .common.scheduler import lr_multiplier
+ from .common.train import train
+ from os.path import join, basename
+ from torch.cuda import device_count
+
+
+ num_gpu = device_count()
+ ins_per_iter = 4
+ len_dataset = 4400
+ num_epoch = 1
+ # dataloader
+ dataloader = dataloader.video_attention_target_video
+ dataloader.train.batch_size = ins_per_iter // num_gpu
+ dataloader.train.num_workers = dataloader.val.num_workers = 14
+ dataloader.train.distributed = num_gpu > 1
+ dataloader.train.rand_rotate = 0.5
+ dataloader.train.rand_lsj = 0.5
+ dataloader.train.input_size = dataloader.val.input_size = 518
+ dataloader.train.mask_scene = True
+ dataloader.train.mask_prob = 0.5
+ dataloader.train.mask_size = dataloader.train.input_size // 14
+ dataloader.train.max_scene_patches_ratio = 0.5
+ dataloader.train.seq_len = 8
+ dataloader.val.quant_labelmap = False
+ dataloader.val.seq_len = 8
+ dataloader.val.batch_size = 4
+ dataloader.val.distributed = False
+ # train
+ train.init_checkpoint = "output/gazefollow_518/model_final.pth"
+ train.output_dir = join("./output", basename(__file__).split(".")[0])
+ train.max_iter = len_dataset * num_epoch // ins_per_iter
+ train.log_period = len_dataset // (ins_per_iter * 100)
+ train.checkpointer.max_to_keep = 100
+ train.checkpointer.period = len_dataset // ins_per_iter
+ train.seed = 0
+ # optimizer
+ optimizer.lr = 1e-6
+ # optimizer.params.lr_factor_func.backbone_multiplier = 0.1
+ # optimizer.params.lr_factor_func.pam_lr_decay = 0.1
+ lr_multiplier.scheduler.values = [1.0]
+ lr_multiplier.scheduler.milestones = []
+ lr_multiplier.scheduler.num_updates = train.max_iter
+ lr_multiplier.warmup_length = 0
+ # model
+ model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
+ model.pam.name = "PatchPAM"
+ model.pam.embed_dim = 8
+ model.pam.patch_size = 14
+ model.backbone.name = "dinov2_small"
+ model.backbone.return_softmax_attn = True
+ model.backbone.out_attn = [2, 5, 8, 11]
+ model.backbone.use_cls_token = True
+ model.backbone.use_mask_token = True
+ model.regressor.name = "UpSampleConv"
+ model.regressor.in_channel = 24
+ model.regressor.use_conv = False
+ model.regressor.dim = 24
+ model.regressor.deconv_cfgs = [
+     dict(
+         in_channels=24,
+         out_channels=16,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=16,
+         out_channels=8,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+     dict(
+         in_channels=8,
+         out_channels=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+     ),
+ ]
+ model.regressor.feat_type = "attn"
+ model.classifier.name = "SimpleMlp"
+ model.classifier.in_channel = 384
+ model.criterion.aux_weight = 0
+ model.criterion.aux_head_thres = 0.05
+ model.criterion.use_focal_loss = True
+ model.device = "cuda"
data/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .video_attention_target import VideoAttentionTarget
+ from .video_attention_target_video import VideoAttentionTargetVideo, video_collate
+ from .gazefollow import GazeFollow
+ from .data_utils import get_transform
+ from .masking import MaskGenerator
data/augmentation.py ADDED
@@ -0,0 +1,312 @@
1
+ import math
2
+ from typing import Tuple, List
3
+ import numpy as np
4
+ from PIL import Image, ImageOps
5
+ from torchvision import transforms
6
+ from torchvision.transforms import functional as TF
7
+
8
+
9
+ class Augmentation:
10
+ def __init__(self, p: float) -> None:
11
+ self.p = p
12
+
13
+ def transform(
14
+ self,
15
+ image: Image,
16
+ bbox: Tuple[float],
17
+ gaze: Tuple[float],
18
+ head_mask: Image,
19
+ size: Tuple[int],
20
+ ):
21
+ raise NotImplementedError
22
+
23
+ def __call__(
24
+ self,
25
+ image: Image,
26
+ bbox: Tuple[float],
27
+ gaze: Tuple[float],
28
+ head_mask: Image,
29
+ size: Tuple[int],
30
+ ):
31
+ if np.random.random_sample() < self.p:
32
+ return self.transform(image, bbox, gaze, head_mask, size)
33
+ return image, bbox, gaze, head_mask, size
34
+
35
+
36
+ class AugmentationList:
37
+ def __init__(self, augmentations: List[Augmentation]) -> None:
38
+ self.augmentations = augmentations
39
+
40
+ def __call__(
41
+ self,
42
+ image: Image,
43
+ bbox: Tuple[float],
44
+ gaze: Tuple[float],
45
+ head_mask: Image,
46
+ size: Tuple[int],
47
+ ):
48
+ for aug in self.augmentations:
49
+ image, bbox, gaze, head_mask, size = aug(image, bbox, gaze, head_mask, size)
50
+ return image, bbox, gaze, head_mask, size
51
+
52
+
53
+ class BoxJitter(Augmentation):
54
+ # Jitter (expansion-only) bounding box size
55
+ def __init__(self, p: float, expansion: float = 0.2) -> None:
56
+ super().__init__(p)
57
+ self.expansion = expansion
58
+
59
+ def transform(
60
+ self,
61
+ image: Image,
62
+ bbox: Tuple[float],
63
+ gaze: Tuple[float],
64
+ head_mask: Image,
65
+ size: Tuple[int],
66
+ ):
67
+ x_min, y_min, x_max, y_max = bbox
68
+ width, height = size
69
+ k = np.random.random_sample() * self.expansion
70
+ x_min = np.clip(x_min - k * abs(x_max - x_min), 0, width - 1)
71
+ y_min = np.clip(y_min - k * abs(y_max - y_min), 0, height - 1)
72
+ x_max = np.clip(x_max + k * abs(x_max - x_min), 0, width - 1)
73
+ y_max = np.clip(y_max + k * abs(y_max - y_min), 0, height - 1)
74
+ return image, (x_min, y_min, x_max, y_max), gaze, head_mask, size
75
+
76
+
77
+ class RandomCrop(Augmentation):
78
+ def __init__(self, p: float) -> None:
79
+ super().__init__(p)
80
+
81
+ def transform(
82
+ self,
83
+ image: Image,
84
+ bbox: Tuple[float],
85
+ gaze: Tuple[float],
86
+ head_mask: Image,
87
+ size: Tuple[int],
88
+ ):
89
+ x_min, y_min, x_max, y_max = bbox
90
+ gaze_x, gaze_y = gaze
91
+ width, height = size
92
+ # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
93
+ crop_x_min = np.min([gaze_x * width, x_min, x_max])
94
+ crop_y_min = np.min([gaze_y * height, y_min, y_max])
95
+ crop_x_max = np.max([gaze_x * width, x_min, x_max])
96
+ crop_y_max = np.max([gaze_y * height, y_min, y_max])
97
+
98
+ # Randomly select a random top left corner
99
+ crop_x_min = np.random.uniform(0, crop_x_min)
100
+ crop_y_min = np.random.uniform(0, crop_y_min)
101
+
102
+ # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
103
+ crop_width_min = crop_x_max - crop_x_min
104
+ crop_height_min = crop_y_max - crop_y_min
105
+ crop_width_max = width - crop_x_min
106
+ crop_height_max = height - crop_y_min
107
+
108
+ # Randomly select a width and a height
109
+ crop_width = np.random.uniform(crop_width_min, crop_width_max)
110
+ crop_height = np.random.uniform(crop_height_min, crop_height_max)
111
+
112
+ # Round to integers
113
+ crop_y_min, crop_x_min, crop_height, crop_width = map(
114
+ int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
115
+ )
116
+
117
+ # Crop it
118
+ image = TF.crop(image, crop_y_min, crop_x_min, crop_height, crop_width)
119
+ head_mask = TF.crop(head_mask, crop_y_min, crop_x_min, crop_height, crop_width)
120
+
121
+ # convert coordinates into the cropped frame
122
+ x_min, y_min, x_max, y_max = (
123
+ x_min - crop_x_min,
124
+ y_min - crop_y_min,
125
+ x_max - crop_x_min,
126
+ y_max - crop_y_min,
127
+ )
128
+
129
+ gaze_x = (gaze_x * width - crop_x_min) / float(crop_width)
130
+ gaze_y = (gaze_y * height - crop_y_min) / float(crop_height)
131
+
132
+ return (
133
+ image,
134
+ (x_min, y_min, x_max, y_max),
135
+ (gaze_x, gaze_y),
136
+ head_mask,
137
+ (crop_width, crop_height),
138
+ )
139
+
140
+
141
+ class RandomFlip(Augmentation):
142
+ def __init__(self, p: float) -> None:
143
+ super().__init__(p)
144
+
145
+ def transform(
146
+ self,
147
+ image: Image,
148
+ bbox: Tuple[float],
149
+ gaze: Tuple[float],
150
+ head_mask: Image,
151
+ size: Tuple[int],
152
+ ):
153
+ image = image.transpose(Image.FLIP_LEFT_RIGHT)
154
+ head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
155
+ x_min, y_min, x_max, y_max = bbox
156
+ x_min, x_max = size[0] - x_max, size[0] - x_min
157
+ gaze_x, gaze_y = 1 - gaze[0], gaze[1]
158
+ return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size
159
+
160
+
161
+ class RandomRotate(Augmentation):
162
+ def __init__(
163
+ self, p: float, max_angle: int = 20, resample: int = Image.BILINEAR
164
+ ) -> None:
165
+ super().__init__(p)
166
+ self.max_angle = max_angle
167
+ self.resample = resample
168
+
169
+ def _random_rotation_matrix(self):
170
+ angle = (2 * np.random.random_sample() - 1) * self.max_angle
171
+ angle = -math.radians(angle)
172
+ return [
173
+ round(math.cos(angle), 15),
174
+ round(math.sin(angle), 15),
175
+ 0.0,
176
+ round(-math.sin(angle), 15),
177
+ round(math.cos(angle), 15),
178
+ 0.0,
179
+ ]
180
+
181
+ @staticmethod
182
+ def _transform(x, y, matrix):
183
+ return (
184
+ matrix[0] * x + matrix[1] * y + matrix[2],
185
+ matrix[3] * x + matrix[4] * y + matrix[5],
186
+ )
187
+
188
+ @staticmethod
189
+ def _inv_transform(x, y, matrix):
190
+ x, y = x - matrix[2], y - matrix[5]
191
+ return matrix[0] * x + matrix[3] * y, matrix[1] * x + matrix[4] * y
192
+
193
+ def transform(
194
+ self,
195
+ image: Image,
196
+ bbox: Tuple[float],
197
+ gaze: Tuple[float],
198
+ head_mask: Image,
199
+ size: Tuple[int],
200
+ ):
201
+ x_min, y_min, x_max, y_max = bbox
202
+ gaze_x, gaze_y = gaze
203
+ width, height = size
204
+ rot_mat = self._random_rotation_matrix()
205
+
206
+ # Calculate offsets
207
+ rot_center = (width / 2.0, height / 2.0)
208
+ rot_mat[2], rot_mat[5] = self._transform(
209
+ -rot_center[0], -rot_center[1], rot_mat
210
+ )
211
+ rot_mat[2] += rot_center[0]
212
+ rot_mat[5] += rot_center[1]
213
+ xx = []
214
+ yy = []
215
+ for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
216
+ x, y = self._transform(x, y, rot_mat)
217
+ xx.append(x)
218
+ yy.append(y)
219
+ nw = math.ceil(max(xx)) - math.floor(min(xx))
220
+ nh = math.ceil(max(yy)) - math.floor(min(yy))
221
+ rot_mat[2], rot_mat[5] = self._transform(
222
+ -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
223
+ )
224
+
225
+ image = image.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)
226
+ head_mask = head_mask.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)
227
+
228
+ xx = []
229
+ yy = []
230
+ for x, y in (
231
+ (x_min, y_min),
232
+ (x_min, y_max),
233
+ (x_max, y_min),
234
+ (x_max, y_max),
235
+ ):
236
+ x, y = self._inv_transform(x, y, rot_mat)
237
+ xx.append(x)
238
+ yy.append(y)
239
+ x_max, x_min = min(max(xx), nw), max(min(xx), 0)
240
+ y_max, y_min = min(max(yy), nh), max(min(yy), 0)
241
+
242
+ gaze_x, gaze_y = self._inv_transform(gaze_x * width, gaze_y * height, rot_mat)
243
+ gaze_x = max(min(gaze_x / nw, 1), 0)
244
+ gaze_y = max(min(gaze_y / nh, 1), 0)
245
+
246
+ return (
247
+ image,
248
+ (x_min, y_min, x_max, y_max),
249
+ (gaze_x, gaze_y),
250
+ head_mask,
251
+ (nw, nh),
252
+ )
253
+
254
+
255
+ class ColorJitter(Augmentation):
256
+ def __init__(
257
+ self,
258
+ p: float,
259
+ brightness: float = 0.4,
260
+ contrast: float = 0.4,
261
+ saturation: float = 0.2,
262
+ hue: float = 0.1,
263
+ ) -> None:
264
+ super().__init__(p)
265
+ self.color_jitter = transforms.ColorJitter(
266
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
267
+ )
268
+
269
+ def transform(
270
+ self,
271
+ image: Image,
272
+ bbox: Tuple[float],
273
+ gaze: Tuple[float],
274
+ head_mask: Image,
275
+ size: Tuple[int],
276
+ ):
277
+ return self.color_jitter(image), bbox, gaze, head_mask, size
278
+
279
+
280
+ class RandomLSJ(Augmentation):
281
+ def __init__(self, p: float, min_scale: float = 0.1) -> None:
282
+ super().__init__(p)
283
+ self.min_scale = min_scale
284
+
285
+ def transform(
286
+ self,
287
+ image: Image,
288
+ bbox: Tuple[float],
289
+ gaze: Tuple[float],
290
+ head_mask: Image,
291
+ size: Tuple[int],
292
+ ):
293
+ x_min, y_min, x_max, y_max = bbox
294
+ gaze_x, gaze_y = gaze
295
+ width, height = size
296
+
297
+ scale = self.min_scale + np.random.random_sample() * (1 - self.min_scale)
298
+ nh, nw = int(height * scale), int(width * scale)
299
+
300
+ image = TF.resize(image, (nh, nw))
301
+ image = ImageOps.expand(image, (0, 0, width - nw, height - nh))
302
+ head_mask = TF.resize(head_mask, (nh, nw))
303
+ head_mask = ImageOps.expand(head_mask, (0, 0, width - nw, height - nh))
304
+
305
+ x_min, y_min, x_max, y_max = (
306
+ x_min * scale,
307
+ y_min * scale,
308
+ x_max * scale,
309
+ y_max * scale,
310
+ )
311
+ gaze_x, gaze_y = gaze_x * scale, gaze_y * scale
312
+ return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size
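
For reference (an illustrative sketch with dummy inputs, not part of the commit): the datasets compose these augmentations through `AugmentationList`, and each augmentation consumes and returns the same `(image, bbox, gaze, head_mask, size)` tuple, with the head box in pixels and the gaze point normalized to [0, 1].

```python
# Sketch of how the datasets chain these augmentations (dummy image and labels).
from PIL import Image
from data.augmentation import AugmentationList, BoxJitter, ColorJitter, RandomCrop, RandomFlip

augment = AugmentationList(
    [ColorJitter(0.5), BoxJitter(0.5), RandomCrop(0.5), RandomFlip(0.5)]
)

image = Image.new("RGB", (640, 480))
head_mask = Image.new("L", (640, 480))
bbox = (300.0, 80.0, 380.0, 180.0)   # head box in pixels
gaze = (0.25, 0.60)                  # gaze target, normalized
image, bbox, gaze, head_mask, size = augment(image, bbox, gaze, head_mask, (640, 480))
print(bbox, gaze, size)
```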
data/data_utils.py ADDED
@@ -0,0 +1,181 @@
1
+ from typing import Tuple
2
+ import torch
3
+ from torchvision import transforms
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+ def to_numpy(tensor: torch.Tensor):
9
+ if torch.is_tensor(tensor):
10
+ return tensor.cpu().detach().numpy()
11
+ elif type(tensor).__module__ != "numpy":
12
+ raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
13
+ return tensor
14
+
15
+
16
+ def to_torch(ndarray: np.ndarray):
17
+ if type(ndarray).__module__ == "numpy":
18
+ return torch.from_numpy(ndarray)
19
+ elif not torch.is_tensor(ndarray):
20
+ raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
21
+ return ndarray
22
+
23
+
24
+ def get_head_box_channel(
25
+ x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False
26
+ ):
27
+ head_box = (
28
+ np.array([x_min / width, y_min / height, x_max / width, y_max / height])
29
+ * resolution
30
+ )
31
+ int_head_box = head_box.astype(int)
32
+ int_head_box = np.clip(int_head_box, 0, resolution - 1)
33
+ if int_head_box[0] == int_head_box[2]:
34
+ if int_head_box[0] == 0:
35
+ int_head_box[2] = 1
36
+ elif int_head_box[2] == resolution - 1:
37
+ int_head_box[0] = resolution - 2
38
+ elif abs(head_box[2] - int_head_box[2]) > abs(head_box[0] - int_head_box[0]):
39
+ int_head_box[2] += 1
40
+ else:
41
+ int_head_box[0] -= 1
42
+ if int_head_box[1] == int_head_box[3]:
43
+ if int_head_box[1] == 0:
44
+ int_head_box[3] = 1
45
+ elif int_head_box[3] == resolution - 1:
46
+ int_head_box[1] = resolution - 2
47
+ elif abs(head_box[3] - int_head_box[3]) > abs(head_box[1] - int_head_box[1]):
48
+ int_head_box[3] += 1
49
+ else:
50
+ int_head_box[1] -= 1
51
+ head_box = int_head_box
52
+ if coordconv:
53
+ unit = np.array(range(0, resolution), dtype=np.float32)
54
+ head_channel = []
55
+ for i in unit:
56
+ head_channel.append([unit + i])
57
+ head_channel = np.squeeze(np.array(head_channel)) / float(np.max(head_channel))
58
+ head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 0
59
+ else:
60
+ head_channel = np.zeros((resolution, resolution), dtype=np.float32)
61
+ head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 1
62
+ head_channel = torch.from_numpy(head_channel)
63
+ return head_channel
64
+
65
+
66
+ def draw_labelmap(img, pt, sigma, type="Gaussian"):
67
+ # Draw a 2D gaussian
68
+ # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
69
+ img = to_numpy(img)
70
+
71
+ # Check that any part of the gaussian is in-bounds
72
+ size = int(6 * sigma + 1)
73
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
74
+ br = [ul[0] + size, ul[1] + size]
75
+ if ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0:
76
+ # If not, just return the image as is
77
+ return to_torch(img)
78
+
79
+ # Generate gaussian
80
+ x = np.arange(0, size, 1, float)
81
+ y = x[:, np.newaxis]
82
+ x0 = y0 = size // 2
83
+ # The gaussian is not normalized, we want the center value to equal 1
84
+ if type == "Gaussian":
85
+ g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))
86
+ elif type == "Cauchy":
87
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma**2) ** 1.5)
88
+
89
+ # Usable gaussian range
90
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
91
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
92
+ # Image range
93
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
94
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
95
+
96
+ img[img_y[0] : img_y[1], img_x[0] : img_x[1]] += g[g_y[0] : g_y[1], g_x[0] : g_x[1]]
97
+ # img = img / np.max(img)
98
+ return to_torch(img)
99
+
100
+
101
+ def draw_labelmap_no_quant(img, pt, sigma, type="Gaussian"):
102
+ img = to_numpy(img)
103
+ shape = img.shape
104
+ x = np.arange(shape[0])
105
+ y = np.arange(shape[1])
106
+ xx, yy = np.meshgrid(x, y, indexing="ij")
107
+ dist_matrix = (yy - float(pt[0])) ** 2 + (xx - float(pt[1])) ** 2
108
+ if type == "Gaussian":
109
+ g = np.exp(-dist_matrix / (2 * sigma**2))
110
+ elif type == "Cauchy":
111
+ g = sigma / ((dist_matrix + sigma**2) ** 1.5)
112
+ g[dist_matrix > 10 * sigma**2] = 0
113
+ img += g
114
+ # img = img / np.max(img)
115
+ return to_torch(img)
116
+
117
+
118
+ def multi_hot_targets(gaze_pts, out_res):
119
+ w, h = out_res
120
+ target_map = np.zeros((h, w))
121
+ for p in gaze_pts:
122
+ if p[0] >= 0:
123
+ x, y = map(int, [p[0] * float(w), p[1] * float(h)])
124
+ x = min(x, w - 1)
125
+ y = min(y, h - 1)
126
+ target_map[y, x] = 1
127
+ return target_map
128
+
129
+
130
+ def get_cone(tgt, src, wh, theta=150):
131
+ eye = src * wh
132
+ gaze = tgt * wh
133
+
134
+ pixel_mat = np.stack(
135
+ np.meshgrid(np.arange(wh[0]), np.arange(wh[1])),
136
+ -1,
137
+ )
138
+
139
+ dot_prod = np.sum((pixel_mat - eye) * (gaze - eye), axis=-1)
140
+ gaze_vector_norm = np.sqrt(np.sum((gaze - eye) ** 2))
141
+ pixel_mat_norm = np.sqrt(np.sum((pixel_mat - eye) ** 2, axis=-1))
142
+
143
+ gaze_cones = dot_prod / (gaze_vector_norm * pixel_mat_norm)
144
+ gaze_cones = np.nan_to_num(gaze_cones, nan=1)
145
+
146
+ theta = theta * (np.pi / 180)
147
+ beta = np.arccos(gaze_cones)
148
+ # Create mask where true if beta is less than theta/2
149
+ pixel_mat_presence = beta < (theta / 2)
150
+
151
+ # Zero out values outside the gaze cone
152
+ gaze_cones[~pixel_mat_presence] = 0
153
+ gaze_cones = np.clip(gaze_cones, 0, None)
154
+
155
+ return torch.from_numpy(gaze_cones).unsqueeze(0).float()
156
+
157
+
158
+ def get_transform(
159
+ input_resolution: int, mean: Tuple[int, int, int], std: Tuple[int, int, int]
160
+ ):
161
+ return transforms.Compose(
162
+ [
163
+ transforms.Resize((input_resolution, input_resolution)),
164
+ transforms.ToTensor(),
165
+ transforms.Normalize(mean=mean, std=std),
166
+ ]
167
+ )
168
+
169
+
170
+ def smooth_by_conv(window_size, df, col):
171
+ padded_track = pd.concat(
172
+ [
173
+ pd.DataFrame([[df.iloc[0][col]]] * (window_size // 2), columns=[0]),
174
+ df[col],
175
+ pd.DataFrame([[df.iloc[-1][col]]] * (window_size // 2), columns=[0]),
176
+ ]
177
+ )
178
+ smoothed_signals = np.convolve(
179
+ padded_track.squeeze(), np.ones(window_size) / window_size, mode="valid"
180
+ )
181
+ return smoothed_signals
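
As a small usage note (illustrative, not part of the commit): the datasets call `draw_labelmap` with a 3-pixel sigma on the 64×64 output grid, so a ground-truth heatmap for one normalized gaze point can be rendered like this.

```python
# Sketch: render a 64x64 Gaussian ground-truth heatmap for one normalized gaze point.
import torch
from data.data_utils import draw_labelmap

output_size = 64
gaze_x, gaze_y = 0.3, 0.7                        # normalized gaze target
heatmap = draw_labelmap(
    torch.zeros(output_size, output_size),
    [gaze_x * output_size, gaze_y * output_size],
    3,
    type="Gaussian",
)
print(heatmap.shape, float(heatmap.max()))       # torch.Size([64, 64]), ~1.0
```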
data/gazefollow.py ADDED
@@ -0,0 +1,295 @@
1
+ from os import path as osp
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision.transforms import functional as TF
7
+ from PIL import Image
8
+ import pandas as pd
9
+
10
+ from . import augmentation
11
+ from .masking import MaskGenerator
12
+ from . import data_utils as utils
13
+
14
+
15
+ class GazeFollow(Dataset):
16
+ def __init__(
17
+ self,
18
+ image_root: str,
19
+ anno_root: str,
20
+ head_root: str,
21
+ transform: Callable,
22
+ input_size: int,
23
+ output_size: int,
24
+ quant_labelmap: bool = True,
25
+ is_train: bool = True,
26
+ *,
27
+ mask_generator: Optional[MaskGenerator] = None,
28
+ bbox_jitter: float = 0.5,
29
+ rand_crop: float = 0.5,
30
+ rand_flip: float = 0.5,
31
+ color_jitter: float = 0.5,
32
+ rand_rotate: float = 0.0,
33
+ rand_lsj: float = 0.0,
34
+ ):
35
+ if is_train:
36
+ column_names = [
37
+ "path",
38
+ "idx",
39
+ "body_bbox_x",
40
+ "body_bbox_y",
41
+ "body_bbox_w",
42
+ "body_bbox_h",
43
+ "eye_x",
44
+ "eye_y",
45
+ "gaze_x",
46
+ "gaze_y",
47
+ "bbox_x_min",
48
+ "bbox_y_min",
49
+ "bbox_x_max",
50
+ "bbox_y_max",
51
+ "inout",
52
+ "meta0",
53
+ "meta1",
54
+ ]
55
+ df = pd.read_csv(
56
+ anno_root,
57
+ sep=",",
58
+ names=column_names,
59
+ index_col=False,
60
+ encoding="utf-8-sig",
61
+ )
62
+ df = df[
63
+ df["inout"] != -1
64
+ ] # only use "in" or "out "gaze. (-1 is invalid, 0 is out gaze)
65
+ df.reset_index(inplace=True)
66
+ self.y_train = df[
67
+ [
68
+ "bbox_x_min",
69
+ "bbox_y_min",
70
+ "bbox_x_max",
71
+ "bbox_y_max",
72
+ "eye_x",
73
+ "eye_y",
74
+ "gaze_x",
75
+ "gaze_y",
76
+ "inout",
77
+ ]
78
+ ]
79
+ self.X_train = df["path"]
80
+ self.length = len(df)
81
+ else:
82
+ column_names = [
83
+ "path",
84
+ "idx",
85
+ "body_bbox_x",
86
+ "body_bbox_y",
87
+ "body_bbox_w",
88
+ "body_bbox_h",
89
+ "eye_x",
90
+ "eye_y",
91
+ "gaze_x",
92
+ "gaze_y",
93
+ "bbox_x_min",
94
+ "bbox_y_min",
95
+ "bbox_x_max",
96
+ "bbox_y_max",
97
+ "meta0",
98
+ "meta1",
99
+ ]
100
+ df = pd.read_csv(
101
+ anno_root,
102
+ sep=",",
103
+ names=column_names,
104
+ index_col=False,
105
+ encoding="utf-8-sig",
106
+ )
107
+ df = df[
108
+ [
109
+ "path",
110
+ "eye_x",
111
+ "eye_y",
112
+ "gaze_x",
113
+ "gaze_y",
114
+ "bbox_x_min",
115
+ "bbox_y_min",
116
+ "bbox_x_max",
117
+ "bbox_y_max",
118
+ ]
119
+ ].groupby(["path", "eye_x"])
120
+ self.keys = list(df.groups.keys())
121
+ self.X_test = df
122
+ self.length = len(self.keys)
123
+
124
+ self.data_dir = image_root
125
+ self.head_dir = head_root
126
+ self.transform = transform
127
+ self.is_train = is_train
128
+
129
+ self.input_size = input_size
130
+ self.output_size = output_size
131
+
132
+ self.draw_labelmap = (
133
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
134
+ )
135
+
136
+ if self.is_train:
137
+ ## data augmentation
138
+ self.augment = augmentation.AugmentationList(
139
+ [
140
+ augmentation.ColorJitter(color_jitter),
141
+ augmentation.BoxJitter(bbox_jitter),
142
+ augmentation.RandomCrop(rand_crop),
143
+ augmentation.RandomFlip(rand_flip),
144
+ augmentation.RandomRotate(rand_rotate),
145
+ augmentation.RandomLSJ(rand_lsj),
146
+ ]
147
+ )
148
+
149
+ self.mask_generator = mask_generator
150
+
151
+ def __getitem__(self, index):
152
+ if not self.is_train:
153
+ g = self.X_test.get_group(self.keys[index])
154
+ cont_gaze = []
155
+ for _, row in g.iterrows():
156
+ path = row["path"]
157
+ x_min = row["bbox_x_min"]
158
+ y_min = row["bbox_y_min"]
159
+ x_max = row["bbox_x_max"]
160
+ y_max = row["bbox_y_max"]
161
+ eye_x = row["eye_x"]
162
+ eye_y = row["eye_y"]
163
+ gaze_x = row["gaze_x"]
164
+ gaze_y = row["gaze_y"]
165
+ cont_gaze.append(
166
+ [gaze_x, gaze_y]
167
+ ) # all ground truth gaze are stacked up
168
+ for _ in range(len(cont_gaze), 20):
169
+ cont_gaze.append(
170
+ [-1, -1]
171
+ ) # pad dummy gaze to match size for batch processing
172
+ cont_gaze = torch.FloatTensor(cont_gaze)
173
+ gaze_inside = True # always consider test samples as inside
174
+ else:
175
+ path = self.X_train.iloc[index]
176
+ (
177
+ x_min,
178
+ y_min,
179
+ x_max,
180
+ y_max,
181
+ eye_x,
182
+ eye_y,
183
+ gaze_x,
184
+ gaze_y,
185
+ inout,
186
+ ) = self.y_train.iloc[index]
187
+ gaze_inside = bool(inout)
188
+
189
+ img = Image.open(osp.join(self.data_dir, path))
190
+ img = img.convert("RGB")
191
+ head_mask = Image.open(osp.join(self.head_dir, path))
192
+ width, height = img.size
193
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
194
+ if x_max < x_min:
195
+ x_min, x_max = x_max, x_min
196
+ if y_max < y_min:
197
+ y_min, y_max = y_max, y_min
198
+ # expand face bbox a bit
199
+ k = 0.1
200
+ x_min = max(x_min - k * abs(x_max - x_min), 0)
201
+ y_min = max(y_min - k * abs(y_max - y_min), 0)
202
+ x_max = min(x_max + k * abs(x_max - x_min), width - 1)
203
+ y_max = min(y_max + k * abs(y_max - y_min), height - 1)
204
+
205
+ if self.is_train:
206
+ img, bbox, gaze, head_mask, size = self.augment(
207
+ img,
208
+ (x_min, y_min, x_max, y_max),
209
+ (gaze_x, gaze_y),
210
+ head_mask,
211
+ (width, height),
212
+ )
213
+ x_min, y_min, x_max, y_max = bbox
214
+ gaze_x, gaze_y = gaze
215
+ width, height = size
216
+
217
+ head_channel = utils.get_head_box_channel(
218
+ x_min,
219
+ y_min,
220
+ x_max,
221
+ y_max,
222
+ width,
223
+ height,
224
+ resolution=self.input_size,
225
+ coordconv=False,
226
+ ).unsqueeze(0)
227
+
228
+ if self.is_train and self.mask_generator is not None:
229
+ image_mask = self.mask_generator(
230
+ x_min / width,
231
+ y_min / height,
232
+ x_max / width,
233
+ y_max / height,
234
+ head_channel,
235
+ )
236
+
237
+ if self.transform is not None:
238
+ img = self.transform(img)
239
+ head_mask = TF.to_tensor(
240
+ TF.resize(head_mask, (self.input_size, self.input_size))
241
+ )
242
+
243
+ # generate the heat map used for deconv prediction
244
+ gaze_heatmap = torch.zeros(
245
+ self.output_size, self.output_size
246
+ ) # set the size of the output
247
+ if not self.is_train: # aggregated heatmap
248
+ num_valid = 0
249
+ for gaze_x, gaze_y in cont_gaze:
250
+ if gaze_x != -1:
251
+ num_valid += 1
252
+ gaze_heatmap += self.draw_labelmap(
253
+ torch.zeros(self.output_size, self.output_size),
254
+ [gaze_x * self.output_size, gaze_y * self.output_size],
255
+ 3,
256
+ type="Gaussian",
257
+ )
258
+ gaze_heatmap /= num_valid
259
+ else:
260
+ # if gaze_inside:
261
+ gaze_heatmap = self.draw_labelmap(
262
+ gaze_heatmap,
263
+ [gaze_x * self.output_size, gaze_y * self.output_size],
264
+ 3,
265
+ type="Gaussian",
266
+ )
267
+
268
+ imsize = torch.IntTensor([width, height])
269
+
270
+ if self.is_train:
271
+ out_dict = {
272
+ "images": img,
273
+ "head_channels": head_channel,
274
+ "heatmaps": gaze_heatmap,
275
+ "gazes": torch.FloatTensor([gaze_x, gaze_y]),
276
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
277
+ "head_masks": head_mask,
278
+ "imsize": imsize,
279
+ }
280
+ if self.mask_generator is not None:
281
+ out_dict["image_masks"] = image_mask
282
+ return out_dict
283
+ else:
284
+ return {
285
+ "images": img,
286
+ "head_channels": head_channel,
287
+ "heatmaps": gaze_heatmap,
288
+ "gazes": cont_gaze,
289
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
290
+ "head_masks": head_mask,
291
+ "imsize": imsize,
292
+ }
293
+
294
+ def __len__(self):
295
+ return self.length
data/masking.py ADDED
@@ -0,0 +1,175 @@
1
+ import random
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class SceneMaskGenerator:
9
+ def __init__(
10
+ self,
11
+ input_size,
12
+ min_num_patches=16,
13
+ max_num_patches_ratio=0.5,
14
+ min_aspect=0.3,
15
+ ):
16
+ if not isinstance(input_size, tuple):
17
+ input_size = (input_size,) * 2
18
+ self.input_size = input_size
19
+ self.num_patches = input_size[0] * input_size[1]
20
+
21
+ self.min_num_patches = min_num_patches
22
+ self.max_num_patches = max_num_patches_ratio * self.num_patches
23
+
24
+ self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
25
+
26
+ def _mask(self, mask, max_mask_patches):
27
+ delta = 0
28
+ for _ in range(4):
29
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
30
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
31
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
32
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
33
+ height, width = self.input_size
34
+ if w < width and h < height:
35
+ top = random.randint(0, height - h)
36
+ left = random.randint(0, width - w)
37
+
38
+ num_masked = mask[top : top + h, left : left + w].sum()
39
+ # Overlap
40
+ if 0 < h * w - num_masked <= max_mask_patches:
41
+ mask[top : top + h, left : left + w] = 1
42
+ delta = h * w - num_masked
43
+ break
44
+ return delta
45
+
46
+ def __call__(self, head_mask):
47
+ mask = np.zeros(shape=self.input_size, dtype=bool)
48
+ mask_count = 0
49
+ num_masking_patches = random.uniform(self.min_num_patches, self.max_num_patches)
50
+ while mask_count < num_masking_patches:
51
+ max_mask_patches = num_masking_patches - mask_count
52
+ delta = self._mask(mask, max_mask_patches)
53
+ if delta == 0:
54
+ break
55
+ else:
56
+ mask_count += delta
57
+
58
+ mask = torch.from_numpy(mask).unsqueeze(0)
59
+ head_mask = (
60
+ F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5
61
+ )
62
+ return torch.logical_and(mask, head_mask).squeeze(0)
63
+
64
+
65
+ class HeadMaskGenerator:
66
+ def __init__(
67
+ self,
68
+ input_size,
69
+ min_num_patches=4,
70
+ max_num_patches_ratio=0.5,
71
+ min_aspect=0.3,
72
+ ):
73
+ if not isinstance(input_size, tuple):
74
+ input_size = (input_size,) * 2
75
+ self.input_size = input_size
76
+ self.num_patches = input_size[0] * input_size[1]
77
+
78
+ self.min_num_patches = min_num_patches
79
+ self.max_num_patches_ratio = max_num_patches_ratio
80
+
81
+ self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
82
+
83
+ def __call__(
84
+ self,
85
+ x_min,
86
+ y_min,
87
+ x_max,
88
+ y_max, # coords in [0,1]
89
+ ):
90
+ height = math.floor((y_max - y_min) * self.input_size[0])
91
+ width = math.floor((x_max - x_min) * self.input_size[1])
92
+ origin_area = width * height
93
+ if origin_area < self.min_num_patches:
94
+ return torch.zeros(size=self.input_size, dtype=bool)
95
+
96
+ target_area = random.uniform(
97
+ self.min_num_patches, self.max_num_patches_ratio * origin_area
98
+ )
99
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
100
+ h = min(int(round(math.sqrt(target_area * aspect_ratio))), height)
101
+ w = min(int(round(math.sqrt(target_area / aspect_ratio))), width)
102
+ top = random.randint(0, height - h) + int(y_min * self.input_size[0])
103
+ left = random.randint(0, width - w) + int(x_min * self.input_size[1])
104
+ mask = torch.zeros(size=self.input_size, dtype=bool)
105
+ mask[top : top + h, left : left + w] = True
106
+ return mask
107
+
108
+
109
+ class MaskGenerator:
110
+ def __init__(
111
+ self,
112
+ input_size,
113
+ mask_scene: bool = False,
114
+ mask_head: bool = False,
115
+ min_scene_patches=16,
116
+ max_scene_patches_ratio=0.5,
117
+ min_head_patches=4,
118
+ max_head_patches_ratio=0.5,
119
+ min_aspect=0.3,
120
+ mask_prob=0.2,
121
+ head_prob=0.2,
122
+ ):
123
+ if not isinstance(input_size, tuple):
124
+ input_size = (input_size,) * 2
125
+ self.input_size = input_size
126
+ if mask_scene:
127
+ self.scene_mask_generator = SceneMaskGenerator(
128
+ input_size, min_scene_patches, max_scene_patches_ratio, min_aspect
129
+ )
130
+ else:
131
+ self.scene_mask_generator = None
132
+
133
+ if mask_head:
134
+ self.head_mask_generator = HeadMaskGenerator(
135
+ input_size, min_head_patches, max_head_patches_ratio, min_aspect
136
+ )
137
+ else:
138
+ self.head_mask_generator = None
139
+
140
+ self.no_mask = not (mask_scene or mask_head)
141
+ self.mask_head = mask_head and not mask_scene
142
+ self.mask_scene = mask_scene and not mask_head
143
+ self.scene_prob = mask_prob
144
+ self.head_prob = head_prob
145
+
146
+ def __call__(
147
+ self,
148
+ x_min,
149
+ y_min,
150
+ x_max,
151
+ y_max,
152
+ head_mask,
153
+ ):
154
+ mask_scene = random.random() < self.scene_prob
155
+ mask_head = random.random() < self.head_prob
156
+ no_mask = (
157
+ self.no_mask
158
+ or (self.mask_head and not mask_head)
159
+ or (self.mask_scene and not mask_scene)
160
+ or not (mask_scene or mask_head)
161
+ )
162
+ if no_mask:
163
+ return torch.zeros(size=self.input_size, dtype=bool)
164
+ if self.mask_scene:
165
+ return self.scene_mask_generator(head_mask)
166
+ if self.mask_head:
167
+ return self.head_mask_generator(x_min, y_min, x_max, y_max)
168
+ if mask_head and mask_scene:
169
+ return torch.logical_or(
170
+ self.scene_mask_generator(head_mask),
171
+ self.head_mask_generator(x_min, y_min, x_max, y_max),
172
+ )
173
+ elif mask_head:
174
+ return self.head_mask_generator(x_min, y_min, x_max, y_max)
175
+ return self.scene_mask_generator(head_mask)
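
A minimal usage sketch of `MaskGenerator` as the datasets call it. The 37×37 patch grid is an assumption (a 518-pixel input with 14-pixel patches); the real sizes come from the configs, which are not shown here:

```
import torch
from data.masking import MaskGenerator

mask_gen = MaskGenerator(
    input_size=37,       # patch-grid resolution, assumed 518 / 14
    mask_scene=True,
    mask_head=True,
    mask_prob=0.2,
    head_prob=0.2,
)

# Normalized head bbox plus the (1, H, W) head-location channel the datasets pass in;
# the scene generator downsamples the channel to the patch grid internally.
head_channel = torch.zeros(1, 518, 518)
head_channel[:, 129:233, 155:259] = 1.0
image_mask = mask_gen(0.30, 0.25, 0.50, 0.45, head_channel)
print(image_mask.shape, image_mask.dtype)  # torch.Size([37, 37]) torch.bool
```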
data/video_attention_target.py ADDED
@@ -0,0 +1,228 @@
1
+ import glob
2
+ from typing import Callable, Optional
3
+ from os import path as osp
4
+
5
+ import torch
6
+ from torch.utils.data.dataset import Dataset
7
+ import torchvision.transforms.functional as TF
8
+ import numpy as np
9
+ import pandas as pd
10
+ from PIL import Image
11
+
12
+ from . import augmentation
13
+ from . import data_utils as utils
14
+ from .masking import MaskGenerator
15
+
16
+
17
+ class VideoAttentionTarget(Dataset):
18
+ def __init__(
19
+ self,
20
+ image_root: str,
21
+ anno_root: str,
22
+ head_root: str,
23
+ transform: Callable,
24
+ input_size: int,
25
+ output_size: int,
26
+ quant_labelmap: bool = True,
27
+ is_train: bool = True,
28
+ *,
29
+ mask_generator: Optional[MaskGenerator] = None,
30
+ bbox_jitter: float = 0.5,
31
+ rand_crop: float = 0.5,
32
+ rand_flip: float = 0.5,
33
+ color_jitter: float = 0.5,
34
+ rand_rotate: float = 0.0,
35
+ rand_lsj: float = 0.0,
36
+ ):
37
+ frames = []
38
+ for show_dir in glob.glob(osp.join(anno_root, "*")):
39
+ for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
40
+ df = pd.read_csv(
41
+ sequence_path,
42
+ header=None,
43
+ index_col=False,
44
+ names=[
45
+ "path",
46
+ "x_min",
47
+ "y_min",
48
+ "x_max",
49
+ "y_max",
50
+ "gaze_x",
51
+ "gaze_y",
52
+ ],
53
+ )
54
+
55
+ show_name = sequence_path.split("/")[-3]
56
+ clip = sequence_path.split("/")[-2]
57
+ df["path"] = df["path"].apply(
58
+ lambda path: osp.join(show_name, clip, path)
59
+ )
60
+ # Add two columns for the bbox center
61
+ df["eye_x"] = (df["x_min"] + df["x_max"]) / 2
62
+ df["eye_y"] = (df["y_min"] + df["y_max"]) / 2
63
+ df = df.sample(frac=0.2, random_state=42)  # subsample 20% of the frames in each clip
64
+ frames.extend(df.values.tolist())
65
+
66
+ df = pd.DataFrame(
67
+ frames,
68
+ columns=[
69
+ "path",
70
+ "x_min",
71
+ "y_min",
72
+ "x_max",
73
+ "y_max",
74
+ "gaze_x",
75
+ "gaze_y",
76
+ "eye_x",
77
+ "eye_y",
78
+ ],
79
+ )
80
+ # Drop rows with invalid bboxes
81
+ coords = torch.tensor(
82
+ np.array(
83
+ (
84
+ df["x_min"].values,
85
+ df["y_min"].values,
86
+ df["x_max"].values,
87
+ df["y_max"].values,
88
+ )
89
+ ).transpose(1, 0)
90
+ )
91
+ valid_bboxes = (coords[:, 2:] >= coords[:, :2]).all(dim=1)
92
+ df = df.loc[valid_bboxes.tolist(), :]
93
+ df.reset_index(inplace=True)
94
+ self.df = df
95
+ self.length = len(df)
96
+
97
+ self.data_dir = image_root
98
+ self.head_dir = head_root
99
+ self.transform = transform
100
+ self.draw_labelmap = (
101
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
102
+ )
103
+ self.is_train = is_train
104
+
105
+ self.input_size = input_size
106
+ self.output_size = output_size
107
+
108
+ if self.is_train:
109
+ ## data augmentation
110
+ self.augment = augmentation.AugmentationList(
111
+ [
112
+ augmentation.ColorJitter(color_jitter),
113
+ augmentation.BoxJitter(bbox_jitter),
114
+ augmentation.RandomCrop(rand_crop),
115
+ augmentation.RandomFlip(rand_flip),
116
+ augmentation.RandomRotate(rand_rotate),
117
+ augmentation.RandomLSJ(rand_lsj),
118
+ ]
119
+ )
120
+
121
+ self.mask_generator = mask_generator
122
+
123
+ def __getitem__(self, index):
124
+ (
125
+ _,
126
+ path,
127
+ x_min,
128
+ y_min,
129
+ x_max,
130
+ y_max,
131
+ gaze_x,
132
+ gaze_y,
133
+ eye_x,
134
+ eye_y,
135
+ ) = self.df.iloc[index]
136
+ gaze_inside = gaze_x != -1 or gaze_y != -1
137
+
138
+ img = Image.open(osp.join(self.data_dir, path))
139
+ img = img.convert("RGB")
140
+ width, height = img.size
141
+ # Since we finetune from weights trained on GazeFollow,
142
+ # we don't incorporate the auxiliary task for VAT.
143
+ if osp.exists(osp.join(self.head_dir, path)):
144
+ head_mask = Image.open(osp.join(self.head_dir, path)).resize(
145
+ (width, height)
146
+ )
147
+ else:
148
+ head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
149
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
150
+ if x_max < x_min:
151
+ x_min, x_max = x_max, x_min
152
+ if y_max < y_min:
153
+ y_min, y_max = y_max, y_min
154
+ gaze_x, gaze_y = gaze_x / width, gaze_y / height
155
+ # expand face bbox a bit
156
+ k = 0.1
157
+ x_min = max(x_min - k * abs(x_max - x_min), 0)
158
+ y_min = max(y_min - k * abs(y_max - y_min), 0)
159
+ x_max = min(x_max + k * abs(x_max - x_min), width - 1)
160
+ y_max = min(y_max + k * abs(y_max - y_min), height - 1)
161
+
162
+ if self.is_train:
163
+ img, bbox, gaze, head_mask, size = self.augment(
164
+ img,
165
+ (x_min, y_min, x_max, y_max),
166
+ (gaze_x, gaze_y),
167
+ head_mask,
168
+ (width, height),
169
+ )
170
+ x_min, y_min, x_max, y_max = bbox
171
+ gaze_x, gaze_y = gaze
172
+ width, height = size
173
+
174
+ head_channel = utils.get_head_box_channel(
175
+ x_min,
176
+ y_min,
177
+ x_max,
178
+ y_max,
179
+ width,
180
+ height,
181
+ resolution=self.input_size,
182
+ coordconv=False,
183
+ ).unsqueeze(0)
184
+
185
+ if self.is_train and self.mask_generator is not None:
186
+ image_mask = self.mask_generator(
187
+ x_min / width,
188
+ y_min / height,
189
+ x_max / width,
190
+ y_max / height,
191
+ head_channel,
192
+ )
193
+
194
+ if self.transform is not None:
195
+ img = self.transform(img)
196
+ head_mask = TF.to_tensor(
197
+ TF.resize(head_mask, (self.input_size, self.input_size))
198
+ )
199
+
200
+ # generate the heat map used for deconv prediction
201
+ gaze_heatmap = torch.zeros(
202
+ self.output_size, self.output_size
203
+ ) # set the size of the output
204
+
205
+ gaze_heatmap = self.draw_labelmap(
206
+ gaze_heatmap,
207
+ [gaze_x * self.output_size, gaze_y * self.output_size],
208
+ 3,
209
+ type="Gaussian",
210
+ )
211
+
212
+ imsize = torch.IntTensor([width, height])
213
+
214
+ out_dict = {
215
+ "images": img,
216
+ "head_channels": head_channel,
217
+ "heatmaps": gaze_heatmap,
218
+ "gazes": torch.FloatTensor([gaze_x, gaze_y]),
219
+ "gaze_inouts": torch.FloatTensor([gaze_inside]),
220
+ "head_masks": head_mask,
221
+ "imsize": imsize,
222
+ }
223
+ if self.is_train and self.mask_generator is not None:
224
+ out_dict["image_masks"] = image_mask
225
+ return out_dict
226
+
227
+ def __len__(self):
228
+ return self.length
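
A sketch of building the image-level dataset above for evaluation. The paths, the 518/64 resolutions, and the ImageNet normalization are placeholders; the released configs (`configs/videoattentiontarget.py` and `configs/common/dataloader.py`) define the real values:

```
import torchvision.transforms as T
from data.video_attention_target import VideoAttentionTarget

transform = T.Compose(
    [
        T.Resize((518, 518)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
val_set = VideoAttentionTarget(
    image_root="datasets/videoattentiontarget/images",
    anno_root="datasets/videoattentiontarget/annotations/test",
    head_root="datasets/videoattentiontarget/head_masks",
    transform=transform,
    input_size=518,
    output_size=64,
    is_train=False,
)
sample = val_set[0]
print(sample["images"].shape, sample["heatmaps"].shape)  # (3, 518, 518) (64, 64)
```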
data/video_attention_target_video.py ADDED
@@ -0,0 +1,464 @@
1
+ import math
2
+ from os import path as osp
3
+ from typing import Callable, Optional
4
+ import glob
5
+ import torch
6
+ from torch.utils.data.dataset import Dataset
7
+ import torchvision.transforms.functional as TF
8
+ import numpy as np
9
+ from PIL import Image, ImageOps
10
+ import pandas as pd
11
+ from .masking import MaskGenerator
12
+ from . import data_utils as utils
13
+
14
+
15
+ class VideoAttentionTargetVideo(Dataset):
16
+ def __init__(
17
+ self,
18
+ image_root: str,
19
+ anno_root: str,
20
+ head_root: str,
21
+ transform: Callable,
22
+ input_size: int,
23
+ output_size: int,
24
+ quant_labelmap: bool = True,
25
+ is_train: bool = True,
26
+ seq_len: int = 8,
27
+ max_len: int = 32,
28
+ *,
29
+ mask_generator: Optional[MaskGenerator] = None,
30
+ bbox_jitter: float = 0.5,
31
+ rand_crop: float = 0.5,
32
+ rand_flip: float = 0.5,
33
+ color_jitter: float = 0.5,
34
+ rand_rotate: float = 0.0,
35
+ rand_lsj: float = 0.0,
36
+ ):
37
+ dfs = []
38
+ for show_dir in glob.glob(osp.join(anno_root, "*")):
39
+ for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
40
+ df = pd.read_csv(
41
+ sequence_path,
42
+ header=None,
43
+ index_col=False,
44
+ names=[
45
+ "path",
46
+ "x_min",
47
+ "y_min",
48
+ "x_max",
49
+ "y_max",
50
+ "gaze_x",
51
+ "gaze_y",
52
+ ],
53
+ )
54
+ show_name = sequence_path.split("/")[-3]
55
+ clip = sequence_path.split("/")[-2]
56
+ df["path"] = df["path"].apply(
57
+ lambda path: osp.join(show_name, clip, path)
58
+ )
59
+ cur_len = len(df.index)
60
+ if is_train:
61
+ if cur_len <= max_len:
62
+ if cur_len >= seq_len:
63
+ dfs.append(df)
64
+ continue
65
+ remainder = cur_len % max_len
66
+ df_splits = [
67
+ df[i : i + max_len]
68
+ for i in range(0, cur_len - max_len, max_len)
69
+ ]
70
+ if remainder >= seq_len:
71
+ df_splits.append(df[-remainder:])
72
+ dfs.extend(df_splits)
73
+ else:
74
+ if cur_len < seq_len:
75
+ continue
76
+ df_splits = [
77
+ df[i : i + seq_len]
78
+ for i in range(0, cur_len - seq_len, seq_len)
79
+ ]
80
+ dfs.extend(df_splits)
81
+
82
+ for df in dfs:
83
+ df.reset_index(inplace=True)
84
+ self.dfs = dfs
85
+ self.length = len(dfs)
86
+
87
+ self.data_dir = image_root
88
+ self.head_dir = head_root
89
+ self.transform = transform
90
+ self.draw_labelmap = (
91
+ utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
92
+ )
93
+ self.is_train = is_train
94
+
95
+ self.input_size = input_size
96
+ self.output_size = output_size
97
+ self.seq_len = seq_len
98
+
99
+ if self.is_train:
100
+ self.bbox_jitter = bbox_jitter
101
+ self.rand_crop = rand_crop
102
+ self.rand_flip = rand_flip
103
+ self.color_jitter = color_jitter
104
+ self.rand_rotate = rand_rotate
105
+ self.rand_lsj = rand_lsj
106
+ self.mask_generator = mask_generator
107
+
108
+ def __getitem__(self, index):
109
+ df = self.dfs[index]
110
+ seq_len = len(df.index)
111
+ for coord in ["x_min", "y_min", "x_max", "y_max"]:
112
+ df[coord] = utils.smooth_by_conv(11, df, coord)
113
+
114
+ if self.is_train:
115
+ # cond for data augmentation
116
+ cond_jitter = np.random.random_sample()
117
+ cond_flip = np.random.random_sample()
118
+ cond_color = np.random.random_sample()
119
+ if cond_color < self.color_jitter:
120
+ n1 = np.random.uniform(0.5, 1.5)
121
+ n2 = np.random.uniform(0.5, 1.5)
122
+ n3 = np.random.uniform(0.5, 1.5)
123
+ cond_crop = np.random.random_sample()
124
+ cond_rotate = np.random.random_sample()
125
+ if cond_rotate < self.rand_rotate:
126
+ angle = (2 * np.random.random_sample() - 1) * 20
127
+ angle = -math.radians(angle)
128
+ cond_lsj = np.random.random_sample()
129
+ if cond_lsj < self.rand_lsj:
130
+ lsj_scale = 0.1 + np.random.random_sample() * 0.9
131
+
132
+ # if longer than seq_len, cut it down to the limit with the start index randomly sampled
133
+ if seq_len > self.seq_len:
134
+ sampled_ind = np.random.randint(0, seq_len - self.seq_len)
135
+ seq_len = self.seq_len
136
+ else:
137
+ sampled_ind = 0
138
+
139
+ if cond_crop < self.rand_crop:
140
+ sliced_x_min = df["x_min"].iloc[sampled_ind : sampled_ind + seq_len]
141
+ sliced_x_max = df["x_max"].iloc[sampled_ind : sampled_ind + seq_len]
142
+ sliced_y_min = df["y_min"].iloc[sampled_ind : sampled_ind + seq_len]
143
+ sliced_y_max = df["y_max"].iloc[sampled_ind : sampled_ind + seq_len]
144
+
145
+ sliced_gaze_x = df["gaze_x"].iloc[sampled_ind : sampled_ind + seq_len]
146
+ sliced_gaze_y = df["gaze_y"].iloc[sampled_ind : sampled_ind + seq_len]
147
+
148
+ check_sum = sliced_gaze_x.sum() + sliced_gaze_y.sum()
149
+ all_outside = check_sum == -2 * seq_len
150
+
151
+ # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
152
+ if all_outside:
153
+ crop_x_min = np.min([sliced_x_min.min(), sliced_x_max.min()])
154
+ crop_y_min = np.min([sliced_y_min.min(), sliced_y_max.min()])
155
+ crop_x_max = np.max([sliced_x_min.max(), sliced_x_max.max()])
156
+ crop_y_max = np.max([sliced_y_min.max(), sliced_y_max.max()])
157
+ else:
158
+ crop_x_min = np.min(
159
+ [sliced_gaze_x.min(), sliced_x_min.min(), sliced_x_max.min()]
160
+ )
161
+ crop_y_min = np.min(
162
+ [sliced_gaze_y.min(), sliced_y_min.min(), sliced_y_max.min()]
163
+ )
164
+ crop_x_max = np.max(
165
+ [sliced_gaze_x.max(), sliced_x_min.max(), sliced_x_max.max()]
166
+ )
167
+ crop_y_max = np.max(
168
+ [sliced_gaze_y.max(), sliced_y_min.max(), sliced_y_max.max()]
169
+ )
170
+
171
+ # Randomly select the top-left corner of the crop
172
+ if crop_x_min >= 0:
173
+ crop_x_min = np.random.uniform(0, crop_x_min)
174
+ if crop_y_min >= 0:
175
+ crop_y_min = np.random.uniform(0, crop_y_min)
176
+
177
+ # Get image size
178
+ path = osp.join(self.data_dir, df["path"].iloc[0])
179
+ img = Image.open(path)
180
+ img = img.convert("RGB")
181
+ width, height = img.size
182
+
183
+ # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
184
+ crop_width_min = crop_x_max - crop_x_min
185
+ crop_height_min = crop_y_max - crop_y_min
186
+ crop_width_max = width - crop_x_min
187
+ crop_height_max = height - crop_y_min
188
+ # Randomly select a width and a height
189
+ crop_width = np.random.uniform(crop_width_min, crop_width_max)
190
+ crop_height = np.random.uniform(crop_height_min, crop_height_max)
191
+
192
+ # Round to integers
193
+ crop_y_min, crop_x_min, crop_height, crop_width = map(
194
+ int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
195
+ )
196
+ else:
197
+ sampled_ind = 0
198
+
199
+ images = []
200
+ head_channels = []
201
+ heatmaps = []
202
+ gazes = []
203
+ gaze_inouts = []
204
+ imsizes = []
205
+ head_masks = []
206
+ if self.is_train and self.mask_generator is not None:
207
+ image_masks = []
208
+ for i, row in df.iterrows():
209
+ if self.is_train and (i < sampled_ind or i >= (sampled_ind + self.seq_len)):
210
+ continue
211
+
212
+ x_min = row["x_min"] # note: Already in image coordinates
213
+ y_min = row["y_min"] # note: Already in image coordinates
214
+ x_max = row["x_max"] # note: Already in image coordinates
215
+ y_max = row["y_max"] # note: Already in image coordinates
216
+ gaze_x = row["gaze_x"] # note: Already in image coordinates
217
+ gaze_y = row["gaze_y"] # note: Already in image coordinates
218
+
219
+ if x_min > x_max:
220
+ x_min, x_max = x_max, x_min
221
+ if y_min > y_max:
222
+ y_min, y_max = y_max, y_min
223
+
224
+ path = row["path"]
225
+ img = Image.open(osp.join(self.data_dir, path)).convert("RGB")
226
+ width, height = img.size
227
+ imsize = torch.FloatTensor([width, height])
228
+ imsizes.append(imsize)
229
+ # Since we finetune from weights trained on GazeFollow,
230
+ # we don't incorporate the auxiliary task for VAT.
231
+ if osp.exists(osp.join(self.head_dir, path)):
232
+ head_mask = Image.open(osp.join(self.head_dir, path)).resize(
233
+ (width, height)
234
+ )
235
+ else:
236
+ head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
237
+
238
+ x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
239
+ gaze_x, gaze_y = map(float, [gaze_x, gaze_y])
240
+ if gaze_x == -1 and gaze_y == -1:
241
+ gaze_inside = False
242
+ else:
243
+ if (
244
+ gaze_x < 0
245
+ ): # move gaze point that was slightly outside the image back in
246
+ gaze_x = 0
247
+ if gaze_y < 0:
248
+ gaze_y = 0
249
+ gaze_inside = True
250
+
251
+ if self.is_train:
252
+ ## data augmentation
253
+ # Jitter (expansion-only) bounding box size.
254
+ if cond_jitter < self.bbox_jitter:
255
+ k = cond_jitter * 0.1
256
+ x_min -= k * abs(x_max - x_min)
257
+ y_min -= k * abs(y_max - y_min)
258
+ x_max += k * abs(x_max - x_min)
259
+ y_max += k * abs(y_max - y_min)
260
+ x_min = np.clip(x_min, 0, width - 1)
261
+ x_max = np.clip(x_max, 0, width - 1)
262
+ y_min = np.clip(y_min, 0, height - 1)
263
+ y_max = np.clip(y_max, 0, height - 1)
264
+
265
+ # Random color change
266
+ if cond_color < self.color_jitter:
267
+ img = TF.adjust_brightness(img, brightness_factor=n1)
268
+ img = TF.adjust_contrast(img, contrast_factor=n2)
269
+ img = TF.adjust_saturation(img, saturation_factor=n3)
270
+
271
+ # Random Crop
272
+ if cond_crop < self.rand_crop:
273
+ # Crop it
274
+ img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width)
275
+ head_mask = TF.crop(
276
+ head_mask, crop_y_min, crop_x_min, crop_height, crop_width
277
+ )
278
+
279
+ # Record the crop's (x, y) offset
280
+ offset_x, offset_y = crop_x_min, crop_y_min
281
+
282
+ # convert coordinates into the cropped frame
283
+ x_min, y_min, x_max, y_max = (
284
+ x_min - offset_x,
285
+ y_min - offset_y,
286
+ x_max - offset_x,
287
+ y_max - offset_y,
288
+ )
289
+ if gaze_inside:
290
+ gaze_x, gaze_y = (gaze_x - offset_x), (gaze_y - offset_y)
291
+ else:
292
+ gaze_x = -1
293
+ gaze_y = -1
294
+
295
+ width, height = crop_width, crop_height
296
+
297
+ # Flip?
298
+ if cond_flip < self.rand_flip:
299
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
300
+ head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
301
+ x_max_2 = width - x_min
302
+ x_min_2 = width - x_max
303
+ x_max = x_max_2
304
+ x_min = x_min_2
305
+ if gaze_x != -1 and gaze_y != -1:
306
+ gaze_x = width - gaze_x
307
+
308
+ # Random Rotation
309
+ if cond_rotate < self.rand_rotate:
310
+ rot_mat = [
311
+ round(math.cos(angle), 15),
312
+ round(math.sin(angle), 15),
313
+ 0.0,
314
+ round(-math.sin(angle), 15),
315
+ round(math.cos(angle), 15),
316
+ 0.0,
317
+ ]
318
+
319
+ def _transform(x, y, matrix):
320
+ return (
321
+ matrix[0] * x + matrix[1] * y + matrix[2],
322
+ matrix[3] * x + matrix[4] * y + matrix[5],
323
+ )
324
+
325
+ def _inv_transform(x, y, matrix):
326
+ x, y = x - matrix[2], y - matrix[5]
327
+ return (
328
+ matrix[0] * x + matrix[3] * y,
329
+ matrix[1] * x + matrix[4] * y,
330
+ )
331
+
332
+ # Calculate offsets
333
+ rot_center = (width / 2.0, height / 2.0)
334
+ rot_mat[2], rot_mat[5] = _transform(
335
+ -rot_center[0], -rot_center[1], rot_mat
336
+ )
337
+ rot_mat[2] += rot_center[0]
338
+ rot_mat[5] += rot_center[1]
339
+ xx = []
340
+ yy = []
341
+ for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
342
+ x, y = _transform(x, y, rot_mat)
343
+ xx.append(x)
344
+ yy.append(y)
345
+ nw = math.ceil(max(xx)) - math.floor(min(xx))
346
+ nh = math.ceil(max(yy)) - math.floor(min(yy))
347
+ rot_mat[2], rot_mat[5] = _transform(
348
+ -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
349
+ )
350
+
351
+ img = img.transform((nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR)
352
+ head_mask = head_mask.transform(
353
+ (nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR
354
+ )
355
+
356
+ xx = []
357
+ yy = []
358
+ for x, y in (
359
+ (x_min, y_min),
360
+ (x_min, y_max),
361
+ (x_max, y_min),
362
+ (x_max, y_max),
363
+ ):
364
+ x, y = _inv_transform(x, y, rot_mat)
365
+ xx.append(x)
366
+ yy.append(y)
367
+ x_max, x_min = min(max(xx), nw), max(min(xx), 0)
368
+ y_max, y_min = min(max(yy), nh), max(min(yy), 0)
369
+ gaze_x, gaze_y = _inv_transform(gaze_x, gaze_y, rot_mat)
370
+ width, height = nw, nh
371
+
372
+ if cond_lsj < self.rand_lsj:
373
+ nh, nw = int(height * lsj_scale), int(width * lsj_scale)
374
+ img = TF.resize(img, (nh, nw))
375
+ img = ImageOps.expand(img, (0, 0, width - nw, height - nh))
376
+ head_mask = TF.resize(head_mask, (nh, nw))
377
+ head_mask = ImageOps.expand(
378
+ head_mask, (0, 0, width - nw, height - nh)
379
+ )
380
+ x_min, y_min, x_max, y_max = (
381
+ x_min * lsj_scale,
382
+ y_min * lsj_scale,
383
+ x_max * lsj_scale,
384
+ y_max * lsj_scale,
385
+ )
386
+ gaze_x, gaze_y = gaze_x * lsj_scale, gaze_y * lsj_scale
387
+
388
+ head_channel = utils.get_head_box_channel(
389
+ x_min,
390
+ y_min,
391
+ x_max,
392
+ y_max,
393
+ width,
394
+ height,
395
+ resolution=self.input_size,
396
+ coordconv=False,
397
+ ).unsqueeze(0)
398
+
399
+ if self.is_train and self.mask_generator is not None:
400
+ image_mask = self.mask_generator(
401
+ x_min / width,
402
+ y_min / height,
403
+ x_max / width,
404
+ y_max / height,
405
+ head_channel,
406
+ )
407
+ image_masks.append(image_mask)
408
+
409
+ if self.transform is not None:
410
+ img = self.transform(img)
411
+ head_mask = TF.to_tensor(
412
+ TF.resize(head_mask, (self.input_size, self.input_size))
413
+ )
414
+
415
+ if gaze_inside:
416
+ gaze_x /= float(width) # fractional gaze
417
+ gaze_y /= float(height)
418
+ gaze_heatmap = torch.zeros(
419
+ self.output_size, self.output_size
420
+ ) # set the size of the output
421
+ gaze_map = self.draw_labelmap(
422
+ gaze_heatmap,
423
+ [gaze_x * self.output_size, gaze_y * self.output_size],
424
+ 3,
425
+ type="Gaussian",
426
+ )
427
+ gazes.append(torch.FloatTensor([gaze_x, gaze_y]))
428
+ else:
429
+ gaze_map = torch.zeros(self.output_size, self.output_size)
430
+ gazes.append(torch.FloatTensor([-1, -1]))
431
+ images.append(img)
432
+ head_channels.append(head_channel)
433
+ head_masks.append(head_mask)
434
+ heatmaps.append(gaze_map)
435
+ gaze_inouts.append(torch.FloatTensor([int(gaze_inside)]))
436
+
437
+ images = torch.stack(images)
438
+ head_channels = torch.stack(head_channels)
439
+ heatmaps = torch.stack(heatmaps)
440
+ gazes = torch.stack(gazes)
441
+ gaze_inouts = torch.stack(gaze_inouts)
442
+ head_masks = torch.stack(head_masks)
443
+ imsizes = torch.stack(imsizes)
444
+
445
+ out_dict = {
446
+ "images": images,
447
+ "head_channels": head_channels,
448
+ "heatmaps": heatmaps,
449
+ "gazes": gazes,
450
+ "gaze_inouts": gaze_inouts,
451
+ "head_masks": head_masks,
452
+ "imsize": imsizes,
453
+ }
454
+ if self.is_train and self.mask_generator is not None:
455
+ out_dict["image_masks"] = torch.stack(image_masks)
456
+ return out_dict
457
+
458
+ def __len__(self):
459
+ return self.length
460
+
461
+
462
+ def video_collate(batch):
463
+ keys = batch[0].keys()
464
+ return {key: torch.cat([item[key] for item in batch]) for key in keys}
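
`video_collate` concatenates every field along the first dimension, so a batch of B clips of length T reaches the model as flat (B·T, ...) tensors. A self-contained check of that behaviour (shapes are illustrative):

```
import torch
from data.video_attention_target_video import video_collate

def fake_clip(t: int) -> dict:
    return {"images": torch.zeros(t, 3, 518, 518), "gaze_inouts": torch.zeros(t, 1)}

batch = video_collate([fake_clip(8), fake_clip(8)])
print(batch["images"].shape)       # torch.Size([16, 3, 518, 518])
print(batch["gaze_inouts"].shape)  # torch.Size([16, 1])

# Wired into a loader it would look like (dataset construction omitted):
# DataLoader(video_dataset, batch_size=2, collate_fn=video_collate)
```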
docs/eval.md ADDED
@@ -0,0 +1,24 @@
1
+ ## Eval
2
+ ### Testing Dataset
3
+
4
+ You should prepare GazeFollow and VideoAttentionTarget for evaluation.
5
+
6
+ * Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
7
+ * If you train with the auxiliary regression task, use `scripts/gen_gazefollow_head_masks.py` to generate head masks.
8
+ * Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
9
+
10
+ Check `ViTGaze/configs/common/dataloader.py` and modify `DATA_ROOT` to point to your datasets.
11
+
12
+ ### Evaluation
13
+
14
+ Run
15
+ ```
16
+ bash val.sh configs/gazefollow_518.py ${Path2checkpoint} gf
17
+ ```
18
+ to evaluate on GazeFollow.
19
+
20
+ Run
21
+ ```
22
+ bash val.sh configs/videoattentiontarget.py ${Path2checkpoint} vat
23
+ ```
24
+ to evaluate on VideoAttentionTarget.
docs/install.md ADDED
@@ -0,0 +1,19 @@
1
+ ## Installation
2
+
3
+ * Create a conda virtual env and activate it.
4
+
5
+ ```
6
+ conda create -n ViTGaze python==3.9.18
7
+ conda activate ViTGaze
8
+ ```
9
+ * Install packages.
10
+
11
+ ```
12
+ cd path/to/ViTGaze
13
+ pip install -r requirements.txt
14
+ ```
15
+ * Install [detectron2](https://github.com/facebookresearch/detectron2) following its [documentation](https://detectron2.readthedocs.io/en/latest/).
16
+ For ViTGaze, we recommend building it from the latest source code:
17
+ ```
18
+ python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
19
+ ```
docs/train.md ADDED
@@ -0,0 +1,36 @@
1
+ ## Train
2
+
3
+ ### Training Dataset
4
+
5
+ You should prepare GazeFollow and VideoAttentionTarget for training.
6
+
7
+ * Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
8
+ * If you train with the auxiliary regression task, use `scripts/gen_gazefollow_head_masks.py` to generate head masks.
9
+ * Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
10
+
11
+ Check `ViTGaze/configs/common/dataloader.py` and modify `DATA_ROOT` to point to your datasets.
12
+
13
+ ### Pretrained Model
14
+
15
+ * Get [DINOv2](https://github.com/facebookresearch/dinov2) pretrained ViT-S.
16
+ * Or download the pretrained weights by running:
17
+
18
+ ```
19
+ cd ViTGaze
20
+ mkdir pretrained && cd pretrained
21
+ wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
22
+ ```
23
+ * Preprocess the downloaded weights with `scripts/convert_pth.py` so they match the Detectron2 checkpoint format.
24
+ ### Train ViTGaze
25
+
26
+ You can modify configs in `configs/gazefollow.py`, `configs/gazefollow_518.py` and `configs/videoattentiontarget.py`.
27
+
28
+ Run:
29
+
30
+ ```
31
+ bash train.sh
32
+ ```
33
+
34
+ to train ViTGaze on the two datasets.
35
+
36
+ Training output will be saved in `ViTGaze/output/`.
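
The conversion script referenced above (`scripts/convert_pth.py`) is not shown in this part of the diff. As a rough, hypothetical sketch of the usual idea, one might remap the DINOv2 state dict under a backbone prefix and wrap it in the `{"model": ...}` layout that Detectron2's checkpointer reads; the actual script in this commit may differ:

```
import torch

# Hypothetical sketch only; the key prefix is an assumption, not the commit's convert_pth.py.
src = torch.load("pretrained/dinov2_vits14_pretrain.pth", map_location="cpu")
converted = {f"backbone.{k}": v for k, v in src.items()}
torch.save({"model": converted}, "pretrained/dinov2_vits14_converted.pth")
```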
engine/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .trainer import CycleTrainer
engine/trainer.py ADDED
@@ -0,0 +1,37 @@
1
+ import time
2
+ import torch
3
+ from detectron2.engine import SimpleTrainer
4
+ from typing import Iterable, Generator
5
+
6
+
7
+ def cycle(iterable: Iterable) -> Generator:
8
+ while True:
9
+ for item in iterable:
10
+ yield item
11
+
12
+
13
+ class CycleTrainer(SimpleTrainer):
14
+ def __init__(
15
+ self,
16
+ model,
17
+ data_loader,
18
+ optimizer,
19
+ gather_metric_period=1,
20
+ zero_grad_before_forward=False,
21
+ async_write_metrics=False,
22
+ ):
23
+ super().__init__(
24
+ model,
25
+ data_loader,
26
+ optimizer,
27
+ gather_metric_period,
28
+ zero_grad_before_forward,
29
+ async_write_metrics,
30
+ )
31
+
32
+ @property
33
+ def _data_loader_iter(self):
34
+ # only create the data loader iterator when it is used
35
+ if self._data_loader_iter_obj is None:
36
+ self._data_loader_iter_obj = cycle(self.data_loader)
37
+ return self._data_loader_iter_obj
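
The `cycle` helper restarts a finite data loader indefinitely, so training length is controlled by an iteration count rather than by epochs. A tiny self-contained check of its behaviour:

```
from engine.trainer import cycle

it = cycle([1, 2, 3])
print([next(it) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]
```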
modeling/_init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from . import backbone, patch_attention, head, criterion, meta_arch
2
+
3
+
4
+ __all__ = ["backbone", "patch_attention", "head", "criterion", "meta_arch"]
modeling/backbone/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .vit import build_backbone
modeling/backbone/utils.py ADDED
@@ -0,0 +1,154 @@
1
+ from functools import partial
2
+ from itertools import repeat
3
+ from typing import Iterable
4
+ import math
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ __all__ = [
10
+ "get_abs_pos",
11
+ "PatchEmbed",
12
+ "Mlp",
13
+ "DropPath",
14
+ ]
15
+
16
+
17
+ def to_2tuple(x):
18
+ if isinstance(x, Iterable) and not isinstance(x, str):
19
+ return tuple(x)
20
+ return tuple(repeat(x, 2))
21
+
22
+
23
+ def get_abs_pos(abs_pos, has_cls_token, hw):
24
+ """
25
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
26
+ dimension for the original embeddings.
27
+ Args:
28
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
29
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
30
+ hw (Tuple): size of input image tokens.
31
+
32
+ Returns:
33
+ Absolute positional embeddings after processing with shape (1, H, W, C)
34
+ """
35
+ h, w = hw
36
+ if has_cls_token:
37
+ abs_pos = abs_pos[:, 1:]
38
+ xy_num = abs_pos.shape[1]
39
+ size = int(math.sqrt(xy_num))
40
+ assert size * size == xy_num
41
+
42
+ if size != h or size != w:
43
+ new_abs_pos = F.interpolate(
44
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
45
+ size=(h, w),
46
+ mode="bicubic",
47
+ align_corners=False,
48
+ )
49
+
50
+ return new_abs_pos.permute(0, 2, 3, 1)
51
+ else:
52
+ return abs_pos.reshape(1, h, w, -1)
53
+
54
+
55
+ def drop_path(
56
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
57
+ ):
58
+ if drop_prob == 0.0 or not training:
59
+ return x
60
+ keep_prob = 1 - drop_prob
61
+ shape = (x.shape[0],) + (1,) * (
62
+ x.ndim - 1
63
+ ) # work with diff dim tensors, not just 2D ConvNets
64
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
65
+ if keep_prob > 0.0 and scale_by_keep:
66
+ random_tensor.div_(keep_prob)
67
+ return x * random_tensor
68
+
69
+
70
+ class PatchEmbed(nn.Module):
71
+ """
72
+ Image to Patch Embedding.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ kernel_size=(16, 16),
78
+ stride=(16, 16),
79
+ padding=(0, 0),
80
+ in_chans=3,
81
+ embed_dim=768,
82
+ ):
83
+ """
84
+ Args:
85
+ kernel_size (Tuple): kernel size of the projection layer.
86
+ stride (Tuple): stride of the projection layer.
87
+ padding (Tuple): padding size of the projection layer.
88
+ in_chans (int): Number of input image channels.
89
+ embed_dim (int): Patch embedding dimension.
90
+ """
91
+ super().__init__()
92
+
93
+ self.proj = nn.Conv2d(
94
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
95
+ )
96
+
97
+ def forward(self, x):
98
+ x = self.proj(x)
99
+ # B C H W -> B H W C
100
+ x = x.permute(0, 2, 3, 1)
101
+ return x
102
+
103
+
104
+ class DropPath(nn.Module):
105
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
106
+
107
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
108
+ super(DropPath, self).__init__()
109
+ self.drop_prob = drop_prob
110
+ self.scale_by_keep = scale_by_keep
111
+
112
+ def forward(self, x):
113
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
114
+
115
+ def extra_repr(self):
116
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
117
+
118
+
119
+ class Mlp(nn.Module):
120
+ def __init__(
121
+ self,
122
+ in_features,
123
+ hidden_features=None,
124
+ out_features=None,
125
+ act_layer=nn.GELU,
126
+ norm_layer=None,
127
+ bias=True,
128
+ drop=0.0,
129
+ use_conv=False,
130
+ ):
131
+ super().__init__()
132
+ out_features = out_features or in_features
133
+ hidden_features = hidden_features or in_features
134
+ bias = to_2tuple(bias)
135
+ drop_probs = to_2tuple(drop)
136
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
137
+
138
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
139
+ self.act = act_layer()
140
+ self.drop1 = nn.Dropout(drop_probs[0])
141
+ self.norm = (
142
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
143
+ )
144
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
145
+ self.drop2 = nn.Dropout(drop_probs[1])
146
+
147
+ def forward(self, x):
148
+ x = self.fc1(x)
149
+ x = self.act(x)
150
+ x = self.drop1(x)
151
+ x = self.norm(x)
152
+ x = self.fc2(x)
153
+ x = self.drop2(x)
154
+ return x
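
`get_abs_pos` above resizes a pretrained positional-embedding table to the current token grid. A small sketch with DINOv2-like sizes (a 37×37 pretraining grid plus a cls entry; the 24×32 target grid is arbitrary):

```
import torch
from modeling.backbone.utils import get_abs_pos

pos_embed = torch.zeros(1, 1 + 37 * 37, 384)   # cls token + 37*37 patch positions
resized = get_abs_pos(pos_embed, has_cls_token=True, hw=(24, 32))
print(resized.shape)  # torch.Size([1, 24, 32, 384])
```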
modeling/backbone/vit.py ADDED
@@ -0,0 +1,504 @@
1
+ import logging
2
+ from typing import Literal, Union
3
+ from functools import partial
4
+ import torch
5
+ import torch.nn as nn
6
+ from detectron2.modeling import Backbone
7
+
8
+ try:
9
+ from xformers.ops import memory_efficient_attention
10
+
11
+ XFORMERS_ON = True
12
+ except ImportError:
13
+ XFORMERS_ON = False
14
+ from .utils import (
15
+ PatchEmbed,
16
+ get_abs_pos,
17
+ DropPath,
18
+ Mlp,
19
+ )
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class Attention(nn.Module):
26
+ def __init__(
27
+ self,
28
+ dim,
29
+ num_heads=8,
30
+ qkv_bias=True,
31
+ return_softmax_attn=True,
32
+ use_proj=True,
33
+ patch_token_offset=0,
34
+ ):
35
+ """
36
+ Args:
37
+ dim (int): Number of input channels.
38
+ num_heads (int): Number of attention heads.
39
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
40
+ """
41
+ super().__init__()
42
+ self.num_heads = num_heads
43
+ head_dim = dim // num_heads
44
+ self.scale = head_dim**-0.5
45
+
46
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
47
+ self.proj = nn.Linear(dim, dim) if use_proj else nn.Identity()
48
+
49
+ self.return_softmax_attn = return_softmax_attn
50
+
51
+ self.patch_token_offset = patch_token_offset
52
+
53
+ def forward(self, x, return_attention=False, extra_token_offset=None):
54
+ B, L, _ = x.shape
55
+ # qkv with shape (3, B, nHead, H * W, C)
56
+ qkv = self.qkv(x).view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
57
+ # q, k, v with shape (B * nHead, H * W, C)
58
+ q, k, v = qkv.reshape(3, B * self.num_heads, L, -1).unbind(0)
59
+
60
+ if return_attention or not XFORMERS_ON:
61
+ attn = (q * self.scale) @ k.transpose(-2, -1)
62
+ if return_attention and not self.return_softmax_attn:
63
+ out_attn = attn
64
+ attn = attn.softmax(dim=-1)
65
+ if return_attention and self.return_softmax_attn:
66
+ out_attn = attn
67
+ x = attn @ v
68
+ else:
69
+ x = memory_efficient_attention(q, k, v, scale=self.scale)
70
+
71
+ x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
72
+ x = self.proj(x)
73
+
74
+ if return_attention:
75
+ out_attn = out_attn.reshape(B, self.num_heads, L, -1)
76
+ out_attn = out_attn[
77
+ :,
78
+ :,
79
+ self.patch_token_offset : extra_token_offset,
80
+ self.patch_token_offset : extra_token_offset,
81
+ ]
82
+ return x, out_attn
83
+ else:
84
+ return x
85
+
86
+
87
+ class Block(nn.Module):
88
+ def __init__(
89
+ self,
90
+ dim,
91
+ num_heads,
92
+ mlp_ratio=4.0,
93
+ qkv_bias=True,
94
+ drop_path=0.0,
95
+ norm_layer=nn.LayerNorm,
96
+ act_layer=nn.GELU,
97
+ init_values=None,
98
+ return_softmax_attn=True,
99
+ attention_map_only=False,
100
+ patch_token_offset=0,
101
+ ):
102
+ """
103
+ Args:
104
+ dim (int): Number of input channels.
105
+ num_heads (int): Number of attention heads in each ViT block.
106
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
107
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
108
+ drop_path (float): Stochastic depth rate.
109
+ norm_layer (nn.Module): Normalization layer.
110
+ act_layer (nn.Module): Activation layer.
111
+ """
112
+ super().__init__()
113
+ self.attention_map_only = attention_map_only
114
+ self.norm1 = norm_layer(dim)
115
+ self.attn = Attention(
116
+ dim,
117
+ num_heads=num_heads,
118
+ qkv_bias=qkv_bias,
119
+ return_softmax_attn=return_softmax_attn,
120
+ use_proj=return_softmax_attn or not attention_map_only,
121
+ patch_token_offset=patch_token_offset,
122
+ )
123
+
124
+ if attention_map_only:
125
+ return
126
+
127
+ self.ls1 = (
128
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
129
+ )
130
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
131
+ self.norm2 = norm_layer(dim)
132
+ self.mlp = Mlp(
133
+ in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer
134
+ )
135
+ self.ls2 = (
136
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
137
+ )
138
+
139
+ def forward(self, x, return_attention=False, extra_token_offset=None):
140
+ shortcut = x
141
+ x = self.norm1(x)
142
+
143
+ if return_attention:
144
+ x, attn = self.attn(x, True, extra_token_offset)
145
+ else:
146
+ x = self.attn(x)
147
+
148
+ if self.attention_map_only:
149
+ return x, attn
150
+
151
+ x = shortcut + self.drop_path(self.ls1(x))
152
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
153
+
154
+ if return_attention:
155
+ return x, attn
156
+ else:
157
+ return x
158
+
159
+
160
+ class LayerScale(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ init_values: Union[float, torch.Tensor] = 1e-5,
165
+ inplace: bool = False,
166
+ ) -> None:
167
+ super().__init__()
168
+ self.inplace = inplace
169
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
170
+
171
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
172
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
173
+
174
+
175
+ class ViT(Backbone):
176
+ def __init__(
177
+ self,
178
+ patch_size=16,
179
+ in_chans=3,
180
+ embed_dim=768,
181
+ depth=12,
182
+ num_heads=12,
183
+ mlp_ratio=4.0,
184
+ qkv_bias=True,
185
+ drop_path_rate=0.0,
186
+ norm_layer=nn.LayerNorm,
187
+ act_layer=nn.GELU,
188
+ pretrain_img_size=224,
189
+ pretrain_use_cls_token=True,
190
+ init_values=None,
191
+ use_cls_token=False,
192
+ use_mask_token=False,
193
+ norm_features=False,
194
+ return_softmax_attn=True,
195
+ num_register_tokens=0,
196
+ num_msg_tokens=0,
197
+ register_as_msg=False,
198
+ shift_strides=None, # [1, -1, 2, -2],
199
+ cls_shift=False,
200
+ num_extra_tokens=4,
201
+ use_extra_embed=False,
202
+ num_frames=None,
203
+ out_feature=True,
204
+ out_attn=(),
205
+ ):
206
+ super().__init__()
207
+ self.pretrain_use_cls_token = pretrain_use_cls_token
208
+
209
+ self.patch_size = patch_size
210
+ self.patch_embed = PatchEmbed(
211
+ kernel_size=(patch_size, patch_size),
212
+ stride=(patch_size, patch_size),
213
+ in_chans=in_chans,
214
+ embed_dim=embed_dim,
215
+ )
216
+
217
+ # Initialize absolute positional embedding with pretrain image size.
218
+ num_patches = (pretrain_img_size // patch_size) * (
219
+ pretrain_img_size // patch_size
220
+ )
221
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
222
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
223
+
224
+ self.use_cls_token = use_cls_token
225
+ self.cls_token = (
226
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_cls_token else None
227
+ )
228
+
229
+ assert num_register_tokens >= 0
230
+ self.register_tokens = (
231
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
232
+ if num_register_tokens > 0
233
+ else None
234
+ )
235
+
236
+ # We tried to leverage temporal information with TeViT, but it did not improve results.
237
+ assert num_msg_tokens >= 0
238
+ self.num_msg_tokens = num_msg_tokens
239
+ if register_as_msg:
240
+ self.num_msg_tokens += num_register_tokens
241
+ self.msg_tokens = (
242
+ nn.Parameter(torch.zeros(1, num_msg_tokens, embed_dim))
243
+ if num_msg_tokens > 0
244
+ else None
245
+ )
246
+
247
+ patch_token_offset = (
248
+ num_msg_tokens + num_register_tokens + int(self.use_cls_token)
249
+ )
250
+ self.patch_token_offset = patch_token_offset
251
+
252
+ self.msg_shift = None
253
+ if shift_strides is not None:
254
+ self.msg_shift = []
255
+ for i in range(depth):
256
+ if i % 2 == 0:
257
+ self.msg_shift.append([_ for _ in shift_strides])
258
+ else:
259
+ self.msg_shift.append([-_ for _ in shift_strides])
260
+
261
+ self.cls_shift = None
262
+ if cls_shift:
263
+ self.cls_shift = [(-1) ** idx for idx in range(depth)]
264
+
265
+ assert num_extra_tokens >= 0
266
+ self.num_extra_tokens = num_extra_tokens
267
+ self.extra_pos_embed = (
268
+ nn.Linear(embed_dim, embed_dim)
269
+ if num_extra_tokens > 0 and use_extra_embed
270
+ else nn.Identity()
271
+ )
272
+
273
+ self.num_frames = num_frames
274
+
275
+ # Mask token for masking augmentation
276
+ self.use_mask_token = use_mask_token
277
+ self.mask_token = (
278
+ nn.Parameter(torch.zeros(1, embed_dim)) if use_mask_token else None
279
+ )
280
+
281
+ # stochastic depth decay rule
282
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
283
+
284
+ self.blocks = nn.ModuleList()
285
+ for i in range(depth):
286
+ block = Block(
287
+ dim=embed_dim,
288
+ num_heads=num_heads,
289
+ mlp_ratio=mlp_ratio,
290
+ qkv_bias=qkv_bias,
291
+ drop_path=dpr[i],
292
+ norm_layer=norm_layer,
293
+ act_layer=act_layer,
294
+ init_values=init_values,
295
+ return_softmax_attn=return_softmax_attn,
296
+ attention_map_only=(i == depth - 1) and not out_feature,
297
+ patch_token_offset=patch_token_offset,
298
+ )
299
+ self.blocks.append(block)
300
+
301
+ self.norm = norm_layer(embed_dim) if norm_features else nn.Identity()
302
+
303
+ self._out_features = out_feature
304
+ self._out_attn = out_attn
305
+
306
+ if self.pos_embed is not None:
307
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
308
+
309
+ self.apply(self._init_weights)
310
+
311
+ def _init_weights(self, m):
312
+ if isinstance(m, nn.Linear):
313
+ nn.init.trunc_normal_(m.weight, std=0.02)
314
+ if isinstance(m, nn.Linear) and m.bias is not None:
315
+ nn.init.constant_(m.bias, 0)
316
+ elif isinstance(m, nn.LayerNorm):
317
+ nn.init.constant_(m.bias, 0)
318
+ nn.init.constant_(m.weight, 1.0)
319
+
320
+ def forward(self, x, masks=None, guidance=None):
321
+ x = self.patch_embed(x)
322
+ if masks is not None:
323
+ x = torch.where(
324
+ masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
325
+ )
326
+
327
+ if self.pos_embed is not None:
328
+ x = x + get_abs_pos(
329
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
330
+ )
331
+ B, H, W, _ = x.shape
332
+ x = x.reshape(B, H * W, -1)
333
+
334
+ if self.use_cls_token:
335
+ cls_tokens = self.cls_token.expand(len(x), -1, -1)
336
+ x = torch.cat((cls_tokens, x), dim=1)
337
+
338
+ if self.register_tokens is not None:
339
+ register_tokens = self.register_tokens.expand(len(x), -1, -1)
340
+ x = torch.cat((register_tokens, x), dim=1)
341
+
342
+ if self.msg_tokens is not None:
343
+ msg_tokens = self.msg_tokens.expand(len(x), -1, -1)
344
+ x = torch.cat((msg_tokens, x), dim=1)
345
+ # [MSG, REG, CLS, PAT]
346
+
347
+ extra_tokens_offset = None
348
+ if guidance is not None:
349
+ guidance = guidance.reshape(len(guidance), -1, 1)
350
+ extra_tokens = (
351
+ (x[:, self.patch_token_offset :] * guidance)
352
+ .sum(dim=1, keepdim=True)
353
+ .expand(-1, self.num_extra_tokens, -1)
354
+ )
355
+ extra_tokens = self.extra_pos_embed(extra_tokens)
356
+ x = torch.cat((x, extra_tokens), dim=1)
357
+ extra_tokens_offset = -self.num_extra_tokens
358
+ # [MSG, REG, CLS, PAT, EXT]
359
+
360
+ attn_maps = []
361
+ for idx, blk in enumerate(self.blocks):
362
+ if idx in self._out_attn:
363
+ x, attn = blk(x, True, extra_tokens_offset)
364
+ attn_maps.append(attn)
365
+ else:
366
+ x = blk(x)
367
+
368
+ if self.msg_shift is not None:
369
+ msg_shift = self.msg_shift[idx]
370
+ msg_tokens = (
371
+ x[:, : self.num_msg_tokens]
372
+ if guidance is None
373
+ else x[:, extra_tokens_offset:]
374
+ )
375
+ msg_tokens = msg_tokens.reshape(
376
+ -1, self.num_frames, *msg_tokens.shape[1:]
377
+ )
378
+ msg_tokens = msg_tokens.chunk(len(msg_shift), dim=2)
379
+ msg_tokens = [
380
+ torch.roll(tokens, roll, dims=1)
381
+ for tokens, roll in zip(msg_tokens, msg_shift)
382
+ ]
383
+ msg_tokens = torch.cat(msg_tokens, dim=2).flatten(0, 1)
384
+ if guidance is None:
385
+ x = torch.cat([msg_tokens, x[:, self.num_msg_tokens :]], dim=1)
386
+ else:
387
+ x = torch.cat([x[:, :extra_tokens_offset], msg_tokens], dim=1)
388
+
389
+ if self.cls_shift is not None:
390
+ cls_tokens = x[:, self.patch_token_offset - 1]
391
+ cls_tokens = cls_tokens.reshape(
392
+ -1, self.num_frames, 1, *cls_tokens.shape[1:]
393
+ )
394
+ cls_tokens = torch.roll(cls_tokens, self.cls_shift[idx], dims=1)
395
+ x = torch.cat(
396
+ [
397
+ x[:, : self.patch_token_offset - 1],
398
+ cls_tokens.flatten(0, 1),
399
+ x[:, self.patch_token_offset :],
400
+ ],
401
+ dim=1,
402
+ )
403
+
404
+ x = self.norm(x)
405
+
406
+ outputs = {}
407
+ outputs["attention_maps"] = torch.cat(attn_maps, dim=1).reshape(
408
+ B, -1, H * W, H, W
409
+ )
410
+ if self._out_features:
411
+ outputs["last_feat"] = (
412
+ x[:, self.patch_token_offset : extra_tokens_offset]
413
+ .reshape(B, H, W, -1)
414
+ .permute(0, 3, 1, 2)
415
+ )
416
+
417
+ return outputs
418
+
419
+
420
+ def vit_tiny(**kwargs):
421
+ model = ViT(
422
+ patch_size=16,
423
+ embed_dim=192,
424
+ depth=12,
425
+ num_heads=3,
426
+ mlp_ratio=4,
427
+ qkv_bias=True,
428
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
429
+ **kwargs
430
+ )
431
+ return model
432
+
433
+
434
+ def vit_small(**kwargs):
435
+ model = ViT(
436
+ patch_size=16,
437
+ embed_dim=384,
438
+ depth=12,
439
+ num_heads=6,
440
+ mlp_ratio=4,
441
+ qkv_bias=True,
442
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
443
+ **kwargs
444
+ )
445
+ return model
446
+
447
+
448
+ def vit_base(**kwargs):
449
+ model = ViT(
450
+ patch_size=16,
451
+ embed_dim=768,
452
+ depth=12,
453
+ num_heads=12,
454
+ mlp_ratio=4,
455
+ qkv_bias=True,
456
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
457
+ **kwargs
458
+ )
459
+ return model
460
+
461
+
462
+ def dinov2_base(**kwargs):
463
+ model = ViT(
464
+ patch_size=14,
465
+ embed_dim=768,
466
+ depth=12,
467
+ num_heads=12,
468
+ mlp_ratio=4,
469
+ qkv_bias=True,
470
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
471
+ pretrain_img_size=518,
472
+ init_values=1,
473
+ **kwargs
474
+ )
475
+ return model
476
+
477
+
478
+ def dinov2_small(**kwargs):
479
+ model = ViT(
480
+ patch_size=14,
481
+ embed_dim=384,
482
+ depth=12,
483
+ num_heads=6,
484
+ mlp_ratio=4,
485
+ qkv_bias=True,
486
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
487
+ pretrain_img_size=518,
488
+ init_values=1,
489
+ **kwargs
490
+ )
491
+ return model
492
+
493
+
494
+ def build_backbone(
495
+ name: Literal["tiny", "small", "base", "dinov2_base", "dinov2_small"], **kwargs
496
+ ):
497
+ vit_dict = {
498
+ "tiny": vit_tiny,
499
+ "small": vit_small,
500
+ "base": vit_base,
501
+ "dinov2_base": dinov2_base,
502
+ "dinov2_small": dinov2_small,
503
+ }
504
+ return vit_dict[name](**kwargs)
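
A sketch of instantiating the backbone through `build_backbone` and inspecting its two outputs. The attention-layer indices and the 518-pixel input are illustrative choices, not necessarily those of the released configs:

```
import torch
from modeling.backbone import build_backbone

backbone = build_backbone(
    "dinov2_small",
    out_attn=(2, 5, 8, 11),  # which blocks to take attention maps from (assumed indices)
    out_feature=True,
)
x = torch.randn(1, 3, 518, 518)       # 518 / 14 = 37 tokens per side
out = backbone(x)
print(out["attention_maps"].shape)    # torch.Size([1, 24, 1369, 37, 37]) -- 4 blocks x 6 heads
print(out["last_feat"].shape)         # torch.Size([1, 384, 37, 37])
```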
modeling/criterion/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .gaze_mapper_criterion import GazeMapperCriterion
modeling/criterion/gaze_mapper_criterion.py ADDED
@@ -0,0 +1,94 @@
1
+ from functools import partial
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from fvcore.nn import sigmoid_focal_loss_jit
6
+
7
+
8
+ class GazeMapperCriterion(nn.Module):
9
+ def __init__(
10
+ self,
11
+ heatmap_weight: float = 10000,
12
+ inout_weight: float = 100,
13
+ aux_weight: float = 100,
14
+ use_aux_loss: bool = False,
15
+ aux_head_thres: float = 0,
16
+ use_focal_loss: bool = False,
17
+ alpha: float = -1,
18
+ gamma: float = 2,
19
+ ):
20
+ super().__init__()
21
+ self.heatmap_weight = heatmap_weight
22
+ self.inout_weight = inout_weight
23
+ self.aux_weight = aux_weight
24
+ self.aux_head_thres = aux_head_thres
25
+
26
+ self.heatmap_loss = nn.MSELoss(reduction="none")  # per-element losses, weighted by gt_inout below
27
+
28
+ if use_focal_loss:
29
+ self.inout_loss = partial(
30
+ sigmoid_focal_loss_jit, alpha=alpha, gamma=gamma, reduction="mean"
31
+ )
32
+ else:
33
+ self.inout_loss = nn.BCEWithLogitsLoss()
34
+
35
+ if use_aux_loss:
36
+ self.aux_loss = nn.BCEWithLogitsLoss()
37
+ else:
38
+ self.aux_loss = None
39
+
40
+ def forward(
41
+ self,
42
+ pred_heatmap,
43
+ pred_inout,
44
+ gt_heatmap,
45
+ gt_inout,
46
+ pred_head_masks=None,
47
+ gt_head_masks=None,
48
+ ):
49
+ loss_dict = {}
50
+
51
+ pred_heatmap = F.interpolate(
52
+ pred_heatmap,
53
+ size=tuple(gt_heatmap.shape[-2:]),
54
+ mode="bilinear",
55
+ align_corners=True,
56
+ )
57
+ heatmap_loss = (
58
+ self.heatmap_loss(pred_heatmap.squeeze(1), gt_heatmap) * self.heatmap_weight
59
+ )
60
+ heatmap_loss = torch.mean(heatmap_loss, dim=(-2, -1))
61
+ heatmap_loss = torch.sum(heatmap_loss.reshape(-1) * gt_inout.reshape(-1))
62
+ # If every target in the batch is out of frame, avoid a 0/0 NaN
63
+ if heatmap_loss > 1e-7:
64
+ heatmap_loss = heatmap_loss / torch.sum(gt_inout)
65
+ loss_dict["regression loss"] = heatmap_loss
66
+ else:
67
+ loss_dict["regression loss"] = heatmap_loss * 0
68
+
69
+ inout_loss = (
70
+ self.inout_loss(pred_inout.reshape(-1), gt_inout.reshape(-1))
71
+ * self.inout_weight
72
+ )
73
+ loss_dict["classification loss"] = inout_loss
74
+
75
+ if self.aux_loss is not None:
76
+ pred_head_masks = F.interpolate(
77
+ pred_head_masks,
78
+ size=tuple(gt_head_masks.shape[-2:]),
79
+ mode="bilinear",
80
+ align_corners=True,
81
+ )
82
+ aux_loss = (
83
+ torch.clamp(
84
+ self.aux_loss(
85
+ pred_head_masks.reshape(-1), gt_head_masks.reshape(-1)
86
+ )
87
+ - self.aux_head_thres,
88
+ 0,
89
+ )
90
+ * self.aux_weight
91
+ )
92
+ loss_dict["aux head loss"] = aux_loss
93
+
94
+ return loss_dict
modeling/head/__init__.py ADDED
@@ -0,0 +1,2 @@
 
1
+ from .inout_head import build_inout_head
2
+ from .heatmap_head import build_heatmap_head
modeling/head/heatmap_head.py ADDED
@@ -0,0 +1,225 @@
1
+ import torch
2
+ from torch import nn
3
+ from detectron2.utils.registry import Registry
4
+ from typing import Literal, List, Dict, Optional, OrderedDict
5
+
6
+
7
+ HEATMAP_HEAD_REGISTRY = Registry("HEATMAP_HEAD_REGISTRY")
8
+ HEATMAP_HEAD_REGISTRY.__doc__ = "Registry for heatmap head"
9
+
10
+
11
+ class BaseHeatmapHead(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_channel: int,
15
+ deconv_cfgs: List[Dict],
16
+ dim: int = 96,
17
+ use_conv: bool = False,
18
+ use_residual: bool = False,
19
+ feat_type: Literal["attn", "both"] = "both",
20
+ attn_layer: Optional[str] = None,
21
+ pre_norm: bool = False,
22
+ use_head: bool = False,
23
+ ) -> None:
24
+ super().__init__()
25
+ self.feat_type = feat_type
26
+ self.use_head = use_head
27
+
28
+ if pre_norm:
29
+ self.pre_norm = nn.Sequential(
30
+ OrderedDict(
31
+ [
32
+ ("bn", nn.BatchNorm2d(in_channel)),
33
+ ("relu", nn.ReLU(inplace=True)),
34
+ ]
35
+ )
36
+ )
37
+ else:
38
+ self.pre_norm = nn.Identity()
39
+
40
+ if use_conv:
41
+ if use_residual:
42
+ from timm.models.resnet import Bottleneck, downsample_conv
43
+
44
+ self.conv = Bottleneck(
45
+ in_channel,
46
+ dim // 4,
47
+ downsample=downsample_conv(in_channel, dim, 1)
48
+ if in_channel != dim
49
+ else None,
50
+ attn_layer=attn_layer,
51
+ )
52
+ else:
53
+ self.conv = nn.Sequential(
54
+ OrderedDict(
55
+ [
56
+ ("conv", nn.Conv2d(in_channel, dim, 3, 1, 1)),
57
+ ("bn", nn.BatchNorm2d(dim)),
58
+ ("relu", nn.ReLU(inplace=True)),
59
+ ]
60
+ )
61
+ )
62
+ else:
63
+ self.conv = nn.Identity()
64
+
65
+ self.decoder: nn.Module = None
66
+
67
+ def get_feat(self, x):
68
+ if self.feat_type == "attn":
69
+ feat = x["attention_maps"]
70
+ elif self.feat_type == "feat":
71
+ feat = x["last_feat"]
72
+ return feat
73
+
74
+ def forward(self, x):
75
+ feat = self.get_feat(x)
76
+ feat = self.pre_norm(feat)
77
+ feat = self.conv(feat)
78
+ return self.decoder(feat)
79
+
80
+
81
+ @HEATMAP_HEAD_REGISTRY.register()
82
+ class SimpleDeconv(BaseHeatmapHead):
83
+ def __init__(
84
+ self,
85
+ in_channel: int,
86
+ deconv_cfgs: List[Dict],
87
+ dim: int = 96,
88
+ use_conv: bool = False,
89
+ use_residual: bool = False,
90
+ feat_type: Literal["attn", "feat"] = "attn",
91
+ attn_layer: Optional[str] = None,
92
+ pre_norm: bool = False,
93
+ use_head: bool = False,
94
+ ) -> None:
95
+ super().__init__(
96
+ in_channel,
97
+ deconv_cfgs,
98
+ dim,
99
+ use_conv,
100
+ use_residual,
101
+ feat_type,
102
+ attn_layer,
103
+ pre_norm,
104
+ use_head,
105
+ )
106
+ decoder_layers = []
107
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
108
+ decoder_layers.extend(
109
+ [
110
+ (
111
+ "".join(["deconv", str(i)]),
112
+ nn.ConvTranspose2d(**deconv_cfg),
113
+ ),
114
+ (
115
+ "".join(["bn", str(i)]),
116
+ nn.BatchNorm2d(deconv_cfg["out_channels"]),
117
+ ),
118
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
119
+ ]
120
+ )
121
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
122
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
123
+
124
+
125
+ @HEATMAP_HEAD_REGISTRY.register()
126
+ class UpSampleConv(BaseHeatmapHead):
127
+ def __init__(
128
+ self,
129
+ in_channel: int,
130
+ deconv_cfgs: List[Dict],
131
+ dim: int = 96,
132
+ use_conv: bool = False,
133
+ use_residual: bool = False,
134
+ feat_type: Literal["attn", "feat"] = "attn",
135
+ attn_layer: Optional[str] = None,
136
+ pre_norm: bool = False,
137
+ use_head: bool = False,
138
+ ) -> None:
139
+ super().__init__(
140
+ in_channel,
141
+ deconv_cfgs,
142
+ dim,
143
+ use_conv,
144
+ use_residual,
145
+ feat_type,
146
+ attn_layer,
147
+ pre_norm,
148
+ use_head,
149
+ )
150
+ decoder_layers = []
151
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
152
+ decoder_layers.extend(
153
+ [
154
+ (
155
+ "".join(["upsample", str(i)]),
156
+ nn.Upsample(
157
+ scale_factor=2, mode="bilinear", align_corners=True
158
+ ),
159
+ ),
160
+ (
161
+ "".join(["conv", str(i)]),
162
+ nn.Conv2d(**deconv_cfg),
163
+ ),
164
+ (
165
+ "".join(["bn", str(i)]),
166
+ nn.BatchNorm2d(deconv_cfg["out_channels"]),
167
+ ),
168
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
169
+ ]
170
+ )
171
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
172
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
173
+
174
+
175
+ @HEATMAP_HEAD_REGISTRY.register()
176
+ class PixelShuffle(BaseHeatmapHead):
177
+ def __init__(
178
+ self,
179
+ in_channel: int,
180
+ deconv_cfgs: List[Dict],
181
+ dim: int = 96,
182
+ use_conv: bool = False,
183
+ use_residual: bool = False,
184
+ feat_type: Literal["attn", "feat"] = "attn",
185
+ attn_layer: Optional[str] = None,
186
+ pre_norm: bool = False,
187
+ use_head: bool = False,
188
+ ) -> None:
189
+ super().__init__(
190
+ in_channel,
191
+ deconv_cfgs,
192
+ dim,
193
+ use_conv,
194
+ use_residual,
195
+ feat_type,
196
+ attn_layer,
197
+ pre_norm,
198
+ use_head,
199
+ )
200
+ decoder_layers = []
201
+ for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
202
+ deconv_cfg["out_channels"] = deconv_cfg["out_channels"] * 4
203
+ decoder_layers.extend(
204
+ [
205
+ (
206
+ "".join(["conv", str(i)]),
207
+ nn.Conv2d(**deconv_cfg),
208
+ ),
209
+ (
210
+ "".join(["pixel_shuffle", str(i)]),
211
+ nn.PixelShuffle(upscale_factor=2),
212
+ ),
213
+ (
214
+ "".join(["bn", str(i)]),
215
+ nn.BatchNorm2d(deconv_cfg["out_channels"] // 4),
216
+ ),
217
+ ("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
218
+ ]
219
+ )
220
+ decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
221
+ self.decoder = nn.Sequential(OrderedDict(decoder_layers))
222
+
223
+
224
+ def build_heatmap_head(name, *args, **kwargs):
225
+ return HEATMAP_HEAD_REGISTRY.get(name)(*args, **kwargs)
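Usage sketch for the registry above. The channel counts and deconv configs below are illustrative assumptions; the real settings live in configs/ (e.g. configs/common/model.py), not here.

import torch
from modeling.head import build_heatmap_head  # exported by modeling/head/__init__.py

head = build_heatmap_head(
    "SimpleDeconv",
    in_channel=24,  # assumed number of aggregated attention maps
    deconv_cfgs=[
        dict(in_channels=24, out_channels=8, kernel_size=4, stride=2, padding=1),
        dict(in_channels=8, out_channels=1, kernel_size=4, stride=2, padding=1),
    ],
    feat_type="attn",
)
x = {"attention_maps": torch.randn(2, 24, 16, 16)}
heatmap = head(x)  # each deconv doubles the resolution -> (2, 1, 64, 64)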
modeling/head/inout_head.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import OrderedDict
2
+ import torch
3
+ from torch import nn
4
+ from detectron2.utils.registry import Registry
5
+
6
+
7
+ INOUT_HEAD_REGISTRY = Registry("INOUT_HEAD_REGISTRY")
8
+ INOUT_HEAD_REGISTRY.__doc__ = "Registry for inout head"
9
+
10
+
11
+ @INOUT_HEAD_REGISTRY.register()
12
+ class SimpleLinear(nn.Module):
13
+ def __init__(self, in_channel: int, dropout: float = 0) -> None:
14
+ super().__init__()
15
+ self.in_channel = in_channel
16
+ self.classifier = nn.Sequential(
17
+ OrderedDict(
18
+ [("dropout", nn.Dropout(dropout)), ("linear", nn.Linear(in_channel, 1))]
19
+ )
20
+ )
21
+
22
+ def get_feat(self, x, masks):
23
+ feats = x["head_feat"]
24
+ if masks is not None:
25
+ B, C = x["last_feat"].shape[:2]
26
+ scene_feats = x["last_feat"].view(B, C, -1).permute(0, 2, 1)
27
+ masks = masks / (masks.sum(dim=-1, keepdim=True) + 1e-6)
28
+ scene_feats = (scene_feats * masks.unsqueeze(-1)).sum(dim=1)
29
+ feats = torch.cat((feats, scene_feats), dim=1)
30
+ return feats
31
+
32
+ def forward(self, x, masks=None):
33
+ feat = self.get_feat(x, masks)
34
+ return self.classifier(feat)
35
+
36
+
37
+ @INOUT_HEAD_REGISTRY.register()
38
+ class SimpleMlp(SimpleLinear):
39
+ def __init__(self, in_channel: int, dropout: float = 0) -> None:
40
+ super().__init__(in_channel, dropout)
41
+ self.classifier = nn.Sequential(
42
+ OrderedDict(
43
+ [
44
+ ("dropout0", nn.Dropout(dropout)),
45
+ ("linear0", nn.Linear(in_channel, in_channel)),
46
+ ("relu", nn.ReLU()),
47
+ ("dropout1", nn.Dropout(dropout)),
48
+ ("linear1", nn.Linear(in_channel, 1)),
49
+ ]
50
+ )
51
+ )
52
+
53
+
54
+ def build_inout_head(name, *args, **kwargs):
55
+ return INOUT_HEAD_REGISTRY.get(name)(*args, **kwargs)
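Usage sketch for the in/out classifier registry above; in_channel=768 (a ViT-B embedding size) is an assumption, and note that the sigmoid is applied by the meta-architecture at inference time, not by the head itself.

import torch
from modeling.head import build_inout_head

clf = build_inout_head("SimpleLinear", in_channel=768, dropout=0.1)
x = {"head_feat": torch.randn(2, 768)}  # passing masks would additionally pool x["last_feat"]
logits = clf(x)  # (2, 1) raw logits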
modeling/meta_arch/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .gaze_attention_mapper import GazeAttentionMapper
modeling/meta_arch/gaze_attention_mapper.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from torch import nn
3
+ from typing import Dict, Union
4
+
5
+
6
+ class GazeAttentionMapper(nn.Module):
7
+ def __init__(
8
+ self,
9
+ backbone: nn.Module,
10
+ regressor: nn.Module,
11
+ classifier: nn.Module,
12
+ criterion: nn.Module,
13
+ pam: nn.Module,
14
+ use_aux_loss: bool = False,
15
+ device: Union[torch.device, str] = "cuda",
16
+ ) -> None:
17
+ super().__init__()
18
+ self.backbone = backbone
19
+ self.pam = pam
20
+ self.regressor = regressor
21
+ self.classifier = classifier
22
+ self.criterion = criterion
23
+ self.use_aux_loss = use_aux_loss
24
+ self.device = torch.device(device)
25
+
26
+ def forward(self, x):
27
+ (
28
+ scenes,
29
+ heads,
30
+ gt_heatmaps,
31
+ gt_inouts,
32
+ head_masks,
33
+ image_masks,
34
+ ) = self.preprocess_inputs(x)
35
+ # Calculate patch weights based on head position
36
+ embedded_heads = self.pam(scenes, heads)
37
+ aux_masks = None
38
+ if self.use_aux_loss:
39
+ embedded_heads, aux_masks = embedded_heads
40
+
41
+ # Get out-dict
42
+ x = self.backbone(
43
+ scenes,
44
+ image_masks,
45
+ None,
46
+ )
47
+
48
+ # Apply patch weights to get the final feats and attention maps
49
+ feats = x.get("last_feat", None)
50
+ if feats is not None:
51
+ x["head_feat"] = (
52
+ (embedded_heads.repeat(1, feats.shape[1], 1, 1) * feats)
53
+ .sum(dim=(2, 3))
54
+ .reshape(len(feats), -1)
55
+ ) # BC
56
+
57
+ attn_maps = x["attention_maps"]
58
+ B, C, *_ = attn_maps.shape
59
+ x["attention_maps"] = (
60
+ attn_maps * embedded_heads.reshape(B, 1, -1, 1, 1).repeat(1, C, 1, 1, 1)
61
+ ).sum(
62
+ dim=2
63
+ ) # BCHW
64
+
65
+ # Apply heads
66
+ heatmaps = self.regressor(x)
67
+ inouts = self.classifier(x, None)
68
+
69
+ if self.training:
70
+ return self.criterion(
71
+ heatmaps,
72
+ inouts,
73
+ gt_heatmaps,
74
+ gt_inouts,
75
+ aux_masks,
76
+ head_masks,
77
+ )
78
+ # Inference
79
+ return heatmaps, inouts.sigmoid()
80
+
81
+ def preprocess_inputs(self, batched_inputs: Dict[str, torch.Tensor]):
82
+ return (
83
+ batched_inputs["images"].to(self.device),
84
+ batched_inputs["head_channels"].to(self.device),
85
+ batched_inputs["heatmaps"].to(self.device)
86
+ if "heatmaps" in batched_inputs.keys()
87
+ else None,
88
+ batched_inputs["gaze_inouts"].to(self.device)
89
+ if "gaze_inouts" in batched_inputs.keys()
90
+ else None,
91
+ batched_inputs["head_masks"].to(self.device)
92
+ if "head_masks" in batched_inputs.keys()
93
+ else None,
94
+ batched_inputs["image_masks"].to(self.device)
95
+ if "image_masks" in batched_inputs.keys()
96
+ else None,
97
+ )
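For reference, preprocess_inputs above consumes a batch dict with the keys shown below. The shapes are illustrative assumptions (the actual sizes come from the dataloader configs), and every key except "images" and "head_channels" is optional and filled with None when absent.

import torch

batch = {
    "images": torch.randn(2, 3, 224, 224),          # scene images
    "head_channels": torch.randn(2, 1, 224, 224),   # per-person head-location channel
    "heatmaps": torch.rand(2, 64, 64),              # gaze heatmap targets (training)
    "gaze_inouts": torch.ones(2),                   # in-frame / out-of-frame labels (training)
    "head_masks": torch.rand(2, 1, 224, 224),       # auxiliary head-mask supervision
    "image_masks": torch.zeros(2, 196, dtype=torch.bool),  # ViT patch masking
}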
modeling/patch_attention/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .patch_attention import build_pam
modeling/patch_attention/patch_attention.py ADDED
@@ -0,0 +1,106 @@
1
+ from typing import OrderedDict
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from detectron2.utils.registry import Registry
5
+ from fvcore.nn import c2_msra_fill
6
+
7
+
8
+ SPATIAL_GUIDANCE_REGISTRY = Registry("SPATIAL_GUIDANCE_REGISTRY")
9
+ SPATIAL_GUIDANCE_REGISTRY.__doc__ = "Registry for 2d spatial guidance"
10
+
11
+
12
+ class _PoolFusion(nn.Module):
13
+ def __init__(self, patch_size: int, use_avgpool: bool = False) -> None:
14
+ super().__init__()
15
+ self.patch_size = patch_size
16
+ self.attn_reducer = F.avg_pool2d if use_avgpool else F.max_pool2d
17
+
18
+ def forward(self, scenes, heads):
19
+ attn_masks = self.attn_reducer(
20
+ heads,
21
+ (self.patch_size, self.patch_size),
22
+ (self.patch_size, self.patch_size),
23
+ (0, 0),
24
+ )
25
+ patch_attn = attn_masks.masked_fill(attn_masks <= 0, -1e9)
26
+ return F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
27
+ *patch_attn.shape
28
+ )
29
+
30
+
31
+ @SPATIAL_GUIDANCE_REGISTRY.register()
32
+ class AvgFusion(_PoolFusion):
33
+ def __init__(self, patch_size: int) -> None:
34
+ super().__init__(patch_size, True)
35
+
36
+
37
+ @SPATIAL_GUIDANCE_REGISTRY.register()
38
+ class MaxFusion(_PoolFusion):
39
+ def __init__(self, patch_size: int) -> None:
40
+ super().__init__(patch_size, False)
41
+
42
+
43
+ @SPATIAL_GUIDANCE_REGISTRY.register()
44
+ class PatchPAM(nn.Module):
45
+ def __init__(
46
+ self,
47
+ patch_size: int,
48
+ act_layer=nn.ReLU,
49
+ embed_dim: int = 768,
50
+ use_aux_loss: bool = False,
51
+ ) -> None:
52
+ super().__init__()
53
+ self.patch_size = patch_size
54
+ patch_embed = nn.Conv2d(
55
+ 3, embed_dim, (patch_size, patch_size), (patch_size, patch_size), (0, 0)
56
+ )
57
+ c2_msra_fill(patch_embed)
58
+ conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
59
+ c2_msra_fill(conv)
60
+ self.use_aux_loss = use_aux_loss
61
+ if use_aux_loss:
62
+ self.patch_embed = nn.Sequential(
63
+ OrderedDict(
64
+ [
65
+ ("patch_embed", patch_embed),
66
+ ("act_layer", act_layer(inplace=True)),
67
+ ]
68
+ )
69
+ )
70
+ self.embed = conv
71
+ conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
72
+ c2_msra_fill(conv)
73
+ self.aux_embed = conv
74
+ else:
75
+ self.embed = nn.Sequential(
76
+ OrderedDict(
77
+ [
78
+ ("patch_embed", patch_embed),
79
+ ("act_layer", act_layer(inplace=True)),
80
+ ("embed", conv),
81
+ ]
82
+ )
83
+ )
84
+
85
+ def forward(self, scenes, heads):
86
+ attn_masks = F.max_pool2d(
87
+ heads,
88
+ (self.patch_size, self.patch_size),
89
+ (self.patch_size, self.patch_size),
90
+ (0, 0),
91
+ )
92
+ if self.use_aux_loss:
93
+ embed = self.patch_embed(scenes)
94
+ aux_masks = self.aux_embed(embed)
95
+ patch_attn = self.embed(embed) * attn_masks
96
+ else:
97
+ patch_attn = self.embed(scenes) * attn_masks
98
+ patch_attn = patch_attn.masked_fill(attn_masks <= 0, -1e9)
99
+ patch_attn = F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
100
+ *patch_attn.shape
101
+ )
102
+ return (patch_attn, aux_masks) if self.use_aux_loss else patch_attn
103
+
104
+
105
+ def build_pam(name, *args, **kwargs):
106
+ return SPATIAL_GUIDANCE_REGISTRY.get(name)(*args, **kwargs)
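Usage sketch for the spatial-guidance registry above. patch_size=16, embed_dim=768 and the input resolution are assumptions; the output is a softmax over patches, so it sums to 1 per sample and patches outside the head mask receive near-zero weight.

import torch
from modeling.patch_attention import build_pam

pam = build_pam("PatchPAM", patch_size=16, embed_dim=768)
scenes = torch.randn(2, 3, 224, 224)
heads = torch.zeros(2, 1, 224, 224)
heads[:, :, 32:96, 64:160] = 1.0       # binary head-region mask
weights = pam(scenes, heads)           # (2, 1, 14, 14), per-sample softmax weights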
pretrained/gazefollow.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61f82255d4db0640208c90b037fd4bb62e061efd51cd03aa4b45f10159acab15
3
+ size 266808135
pretrained/videoattentiontarget.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdf8d767815b7c8f1ba78d89710dc7a2de2d08e97554654a296a3e71b4903cec
3
+ size 266808135
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ fvcore==0.1.5.post20221221
2
+ numpy==1.26.2
3
+ omegaconf==2.3.0
4
+ opencv-python==4.8.1.78
5
+ opencv-python-headless==4.6.0.66
6
+ pandas==1.4.4
7
+ Pillow==10.3.0
8
+ scikit-image==0.22.0
9
+ scikit-learn==1.5.0
10
+ scipy==1.11.4
11
+ timm==0.9.12
12
+ torch==2.2.0
13
+ torchaudio==2.0.2
14
+ torchvision==0.15.2
15
+ tqdm==4.66.3
16
+ xformers==0.0.21
17
+ yacs==0.1.8
scripts/convert_pth.py ADDED
@@ -0,0 +1,32 @@
1
+ # Convert official model weights to the checkpoint format that detectron2 (d2) expects
2
+ import argparse
3
+ from collections import OrderedDict
4
+ import torch
5
+
6
+
7
+ def convert(src: str, dst: str):
8
+ checkpoint = torch.load(src)
9
+ has_model = "model" in checkpoint.keys()
10
+ checkpoint = checkpoint["model"] if has_model else checkpoint
11
+ if "state_dict" in checkpoint.keys():
12
+ checkpoint = checkpoint["state_dict"]
13
+ out_cp = OrderedDict()
14
+ for k, v in checkpoint.items():
15
+ out_cp[".".join(["backbone", k])] = v
16
+ out_cp = {"model": out_cp} if has_model else out_cp
17
+ torch.save(out_cp, dst)
18
+
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--src", "-s", type=str, required=True, help="Path to src weights.pth"
24
+ )
25
+ parser.add_argument(
26
+ "--dst", "-d", type=str, required=True, help="Path to dst weights.pth"
27
+ )
28
+ args = parser.parse_args()
29
+ convert(
30
+ args.src,
31
+ args.dst,
32
+ )
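A hedged usage sketch; both paths are placeholders, not files shipped with this repo.

# CLI form:
#   python scripts/convert_pth.py -s /path/to/official_backbone.pth -d pretrained/backbone_d2.pth
# Programmatic form (same effect):
convert("/path/to/official_backbone.pth", "pretrained/backbone_d2.pth")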
scripts/gen_gazefollow_head_masks.py ADDED
@@ -0,0 +1,185 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import cv2
5
+ import numpy as np
6
+ import pandas as pd
7
+ import tqdm
8
+ from PIL import Image
9
+ from retinaface.pre_trained_models import get_model
10
+
11
+
12
+ random.seed(1)
13
+
14
+
15
+ def gaussian(x_min, y_min, x_max, y_max):
16
+ x_min, x_max = sorted((x_min, x_max))
17
+ y_min, y_max = sorted((y_min, y_max))
18
+ x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2
19
+ x_sigma2, y_sigma2 = (
20
+ np.clip(np.square([x_max - x_min, y_max - y_min], dtype=float), 1, None) / 3
21
+ )
22
+
23
+ def _gaussian(_xs, _ys):
24
+ return np.exp(
25
+ -(np.square(_xs - x_mid) / x_sigma2 + np.square(_ys - y_mid) / y_sigma2)
26
+ )
27
+
28
+ return _gaussian
29
+
30
+
31
+ def plot_ori(label_path, data_dir):
32
+ df = pd.read_csv(
33
+ label_path,
34
+ names=[
35
+ # Original labels
36
+ "path",
37
+ "idx",
38
+ "body_bbox_x",
39
+ "body_bbox_y",
40
+ "body_bbox_w",
41
+ "body_bbox_h",
42
+ "eye_x",
43
+ "eye_y",
44
+ "gaze_x",
45
+ "gaze_y",
46
+ "x_min",
47
+ "y_min",
48
+ "x_max",
49
+ "y_max",
50
+ "inout",
51
+ "meta0",
52
+ "meta1",
53
+ ],
54
+ index_col=False,
55
+ encoding="utf-8-sig",
56
+ )
57
+ grouped = df.groupby("path")
58
+
59
+ output_dir = os.path.join(data_dir, "head_masks")
60
+
61
+ for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with annotations: "):
62
+ if not os.path.exists(os.path.join(output_dir, image_name)):
63
+ w, h = Image.open(os.path.join(data_dir, image_name)).size
64
+ heatmap = np.zeros((h, w), dtype=np.float32)
65
+ indices = np.meshgrid(
66
+ np.linspace(0.0, float(w), num=w, endpoint=False),
67
+ np.linspace(0.0, float(h), num=h, endpoint=False),
68
+ )
69
+ for _, row in group_df.iterrows():
70
+ x_min, y_min, x_max, y_max = (
71
+ row["x_min"],
72
+ row["y_min"],
73
+ row["x_max"],
74
+ row["y_max"],
75
+ )
76
+ gauss = gaussian(x_min, y_min, x_max, y_max)
77
+ heatmap += gauss(*indices)
78
+ heatmap /= np.max(heatmap)
79
+ heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
80
+ output_filename = os.path.join(output_dir, image_name)
81
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
82
+ heatmap_image.save(output_filename)
83
+
84
+
85
+ def plot_gen(df, data_dir):
86
+ df = df[df["score"] > 0.8]
87
+ grouped = df.groupby("path")
88
+
89
+ output_dir = os.path.join(data_dir, "head_masks")
90
+
91
+ for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with predictions: "):
92
+ w, h = Image.open(os.path.join(data_dir, image_name)).size
93
+ heatmap = np.zeros((h, w), dtype=np.float32)
94
+ indices = np.meshgrid(
95
+ np.linspace(0.0, float(w), num=w, endpoint=False),
96
+ np.linspace(0.0, float(h), num=h, endpoint=False),
97
+ )
98
+ for index, row in group_df.iterrows():
99
+ x_min, y_min, x_max, y_max = (
100
+ row["x_min"],
101
+ row["y_min"],
102
+ row["x_max"],
103
+ row["y_max"],
104
+ )
105
+ gauss = gaussian(x_min, y_min, x_max, y_max)
106
+ heatmap += gauss(*indices)
107
+ heatmap /= np.max(heatmap)
108
+ heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
109
+ output_filename = os.path.join(output_dir, image_name)
110
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
111
+ heatmap_image.save(output_filename)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ parser = argparse.ArgumentParser()
116
+ parser.add_argument("--dataset_dir", help="Root directory of dataset")
117
+ parser.add_argument(
118
+ "--subset",
119
+ help="Subset of dataset to process",
120
+ choices=["train", "test"],
121
+ )
122
+ args = parser.parse_args()
123
+
124
+ label_path = os.path.join(
125
+ args.dataset_dir, args.subset + "_annotations_release.txt"
126
+ )
127
+
128
+ column_names = [
129
+ "path",
130
+ "idx",
131
+ "body_bbox_x",
132
+ "body_bbox_y",
133
+ "body_bbox_w",
134
+ "body_bbox_h",
135
+ "eye_x",
136
+ "eye_y",
137
+ "gaze_x",
138
+ "gaze_y",
139
+ "bbox_x_min",
140
+ "bbox_y_min",
141
+ "bbox_x_max",
142
+ "bbox_y_max",
143
+ ]
144
+ df = pd.read_csv(
145
+ label_path,
146
+ sep=",",
147
+ names=column_names,
148
+ usecols=column_names,
149
+ index_col=False,
150
+ )
151
+ df = df.groupby("path")
152
+
153
+ model = get_model("resnet50_2020-07-20", max_size=2048, device="cuda")
154
+ model.eval()
155
+
156
+ paths = list(df.groups.keys())
157
+ csv = []
158
+ for path in tqdm.tqdm(paths, desc="Predicting head bboxes: "):
159
+ img = cv2.imread(os.path.join(args.dataset_dir, path))
160
+
161
+ annotations = model.predict_jsons(img)
162
+
163
+ for annotation in annotations:
164
+ if len(annotation["bbox"]) == 0:
165
+ continue
166
+
167
+ csv.append(
168
+ [
169
+ path,
170
+ annotation["score"],
171
+ annotation["bbox"][0],
172
+ annotation["bbox"][1],
173
+ annotation["bbox"][2],
174
+ annotation["bbox"][3],
175
+ ]
176
+ )
177
+
178
+ # Write csv
179
+ df = pd.DataFrame(
180
+ csv, columns=["path", "score", "x_min", "y_min", "x_max", "y_max"]
181
+ )
182
+ df.to_csv(os.path.join(args.dataset_dir, f"{args.subset}_head.csv"), index=False)
183
+
184
+ plot_gen(df, args.dataset_dir)
185
+ plot_ori(label_path, args.dataset_dir)
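A hedged usage sketch; the dataset path below is a placeholder, and --subset accepts train or test per the argparse definition above.

# CLI form:
#   python scripts/gen_gazefollow_head_masks.py --dataset_dir /data/gazefollow --subset train
# Rendering masks from the released annotations alone:
plot_ori("/data/gazefollow/train_annotations_release.txt", "/data/gazefollow")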
tools/eval_on_gazefollow.py ADDED
@@ -0,0 +1,92 @@
1
+ import sys
2
+ from os import path as osp
3
+ import argparse
4
+ import warnings
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from detectron2.config import instantiate, LazyConfig
9
+
10
+ sys.path.append(osp.dirname(osp.dirname(__file__)))
11
+ from utils import *
12
+
13
+
14
+ warnings.simplefilter(action="ignore", category=FutureWarning)
15
+
16
+
17
+ def do_test(cfg, model, use_dark_inference=False):
18
+ val_loader = instantiate(cfg.dataloader.val)
19
+
20
+ model.train(False)
21
+ AUC = []
22
+ min_dist = []
23
+ avg_dist = []
24
+ with torch.no_grad():
25
+ for data in val_loader:
26
+ val_gaze_heatmap_pred, _ = model(data)
27
+ val_gaze_heatmap_pred = (
28
+ val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
29
+ )
30
+
31
+ # go through each data point and record AUC, min dist, avg dist
32
+ for b_i in range(len(val_gaze_heatmap_pred)):
33
+ # remove padding and recover valid ground truth points
34
+ valid_gaze = data["gazes"][b_i]
35
+ valid_gaze = valid_gaze[valid_gaze != -1].view(-1, 2)
36
+ # AUC: area under curve of ROC
37
+ multi_hot = multi_hot_targets(data["gazes"][b_i], data["imsize"][b_i])
38
+ if use_dark_inference:
39
+ pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
40
+ else:
41
+ pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
42
+ norm_p = [
43
+ pred_x / val_gaze_heatmap_pred[b_i].shape[-1],
44
+ pred_y / val_gaze_heatmap_pred[b_i].shape[-2],
45
+ ]
46
+ scaled_heatmap = np.array(
47
+ Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
48
+ data["imsize"][b_i],
49
+ resample=Image.BILINEAR,
50
+ )
51
+ )
52
+ auc_score = auc(scaled_heatmap, multi_hot)
53
+ AUC.append(auc_score)
54
+ # min distance: minimum among all possible pairs of <ground truth point, predicted point>
55
+ all_distances = []
56
+ for gt_gaze in valid_gaze:
57
+ all_distances.append(L2_dist(gt_gaze, norm_p))
58
+ min_dist.append(min(all_distances))
59
+ # average distance: distance between the predicted point and human average point
60
+ mean_gt_gaze = torch.mean(valid_gaze, 0)
61
+ avg_distance = L2_dist(mean_gt_gaze, norm_p)
62
+ avg_dist.append(avg_distance)
63
+
64
+ print("|AUC |min dist|avg dist|")
65
+ print(
66
+ "|{:.4f}|{:.4f} |{:.4f} |".format(
67
+ torch.mean(torch.tensor(AUC)),
68
+ torch.mean(torch.tensor(min_dist)),
69
+ torch.mean(torch.tensor(avg_dist)),
70
+ )
71
+ )
72
+
73
+
74
+ def main(args):
75
+ cfg = LazyConfig.load(args.config_file)
76
+ model: torch.nn.Module = instantiate(cfg.model)
77
+ model.load_state_dict(torch.load(args.model_weights)["model"])
78
+ model.to(cfg.train.device)
79
+ do_test(cfg, model, use_dark_inference=args.use_dark_inference)
80
+
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("--config_file", type=str, help="config file")
85
+ parser.add_argument(
86
+ "--model_weights",
87
+ type=str,
88
+ help="model weights",
89
+ )
90
+ parser.add_argument("--use_dark_inference", action="store_true")
91
+ args = parser.parse_args()
92
+ main(args)
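A hedged usage sketch: the flag names come from the argparse block above, while pairing configs/gazefollow.py with pretrained/gazefollow.pth is an assumption based on the repo layout (see docs/eval.md for the intended commands).

# CLI form:
#   python tools/eval_on_gazefollow.py --config_file configs/gazefollow.py --model_weights pretrained/gazefollow.pth
# Programmatic form (same effect):
import argparse
args = argparse.Namespace(
    config_file="configs/gazefollow.py",
    model_weights="pretrained/gazefollow.pth",
    use_dark_inference=False,
)
main(args)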
tools/eval_on_video_attention_target.py ADDED
@@ -0,0 +1,93 @@
1
+ import sys
2
+ from os import path as osp
3
+ import argparse
4
+ import warnings
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ from detectron2.config import instantiate, LazyConfig
9
+
10
+ sys.path.append(osp.dirname(osp.dirname(__file__)))
11
+ from utils import *
12
+
13
+
14
+ warnings.simplefilter(action="ignore", category=FutureWarning)
15
+
16
+
17
+ def do_test(cfg, model, use_dark_inference=False):
18
+ val_loader = instantiate(cfg.dataloader.val)
19
+
20
+ model.train(False)
21
+ AUC = []
22
+ dist = []
23
+ inout_gt = []
24
+ inout_pred = []
25
+ with torch.no_grad():
26
+ for data in val_loader:
27
+ val_gaze_heatmap_pred, val_gaze_inout_pred = model(data)
28
+ val_gaze_heatmap_pred = (
29
+ val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
30
+ )
31
+ val_gaze_inout_pred = val_gaze_inout_pred.cpu().detach().numpy()
32
+
33
+ # go through each data point and record AUC, dist, ap
34
+ for b_i in range(len(val_gaze_heatmap_pred)):
35
+ auc_batch = []
36
+ dist_batch = []
37
+ if data["gaze_inouts"][b_i]:
38
+ # remove padding and recover valid ground truth points
39
+ valid_gaze = data["gazes"][b_i]
40
+ # AUC: area under curve of ROC
41
+ multi_hot = data["heatmaps"][b_i]
42
+ multi_hot = (multi_hot > 0).float().numpy()
43
+ if use_dark_inference:
44
+ pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
45
+ else:
46
+ pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
47
+ norm_p = [
48
+ pred_x / val_gaze_heatmap_pred[b_i].shape[-1],
49
+ pred_y / val_gaze_heatmap_pred[b_i].shape[-2],
50
+ ]
51
+ scaled_heatmap = np.array(
52
+ Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
53
+ (64, 64),
54
+ resample=Image.Resampling.BILINEAR,
55
+ )
56
+ )
57
+ auc_score = auc(scaled_heatmap, multi_hot)
58
+ auc_batch.append(auc_score)
59
+ dist_batch.append(L2_dist(valid_gaze.numpy(), norm_p))
60
+ AUC.extend(auc_batch)
61
+ dist.extend(dist_batch)
62
+ inout_gt.extend(data["gaze_inouts"].cpu().numpy())
63
+ inout_pred.extend(val_gaze_inout_pred)
64
+
65
+ print("|AUC |dist |AP |")
66
+ print(
67
+ "|{:.4f}|{:.4f} |{:.4f} |".format(
68
+ torch.mean(torch.tensor(AUC)),
69
+ torch.mean(torch.tensor(dist)),
70
+ ap(inout_gt, inout_pred),
71
+ )
72
+ )
73
+
74
+
75
+ def main(args):
76
+ cfg = LazyConfig.load(args.config_file)
77
+ model: torch.nn.Module = instantiate(cfg.model)
78
+ model.load_state_dict(torch.load(args.model_weights)["model"])
79
+ model.to(cfg.train.device)
80
+ do_test(cfg, model, use_dark_inference=args.use_dark_inference)
81
+
82
+
83
+ if __name__ == "__main__":
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument("--config_file", type=str, help="config file")
86
+ parser.add_argument(
87
+ "--model_weights",
88
+ type=str,
89
+ help="model weights",
90
+ )
91
+ parser.add_argument("--use_dark_inference", action="store_true")
92
+ args = parser.parse_args()
93
+ main(args)
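A hedged usage sketch mirroring the GazeFollow script; pairing configs/videoattentiontarget.py with pretrained/videoattentiontarget.pth is an assumption based on the repo layout.

# CLI form:
#   python tools/eval_on_video_attention_target.py --config_file configs/videoattentiontarget.py --model_weights pretrained/videoattentiontarget.pth
# Programmatic form (same effect):
import argparse
args = argparse.Namespace(
    config_file="configs/videoattentiontarget.py",
    model_weights="pretrained/videoattentiontarget.pth",
    use_dark_inference=False,
)
main(args)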
tools/train.py ADDED
@@ -0,0 +1,104 @@
1
+ import os.path as osp
2
+ import logging
3
+ import warnings
4
+ import sys
5
+
6
+ from detectron2.checkpoint import DetectionCheckpointer
7
+ from detectron2.config import LazyConfig, instantiate
8
+ from detectron2.engine import (
9
+ default_argument_parser,
10
+ default_setup,
11
+ default_writers,
12
+ hooks,
13
+ launch,
14
+ )
15
+ from detectron2.engine.defaults import create_ddp_model
16
+ from detectron2.utils import comm
17
+
18
+
19
+ sys.path.append(osp.dirname(osp.dirname(__file__)))
20
+ warnings.filterwarnings("ignore")
21
+ logger = logging.getLogger("detectron2")
22
+
23
+
24
+ from engine import CycleTrainer
25
+
26
+
27
+ def do_train(args, cfg):
28
+ """
29
+ Args:
30
+ cfg: an object with the following attributes:
31
+ model: instantiate to a module
32
+ dataloader.{train,test}: instantiate to dataloaders
33
+ dataloader.evaluator: instantiate to evaluator for test set
34
+ optimizer: instantaite to an optimizer
35
+ lr_multiplier: instantiate to a fvcore scheduler
36
+ train: other misc config defined in `configs/common/train.py`, including:
37
+ output_dir (str)
38
+ init_checkpoint (str)
39
+ amp.enabled (bool)
40
+ max_iter (int)
41
+ eval_period, log_period (int)
42
+ device (str)
43
+ checkpointer (dict)
44
+ ddp (dict)
45
+ """
46
+ model = instantiate(cfg.model)
47
+ logger = logging.getLogger("detectron2")
48
+ logger.info("Model:\n{}".format(model))
49
+ model.to(cfg.train.device)
50
+
51
+ cfg.optimizer.params.model = model
52
+ optim = instantiate(cfg.optimizer)
53
+
54
+ train_loader = instantiate(cfg.dataloader.train)
55
+
56
+ model = create_ddp_model(model, **cfg.train.ddp)
57
+ trainer = CycleTrainer(model, train_loader, optim)
58
+ checkpointer = DetectionCheckpointer(
59
+ model,
60
+ cfg.train.output_dir,
61
+ trainer=trainer,
62
+ )
63
+ trainer.register_hooks(
64
+ [
65
+ hooks.IterationTimer(),
66
+ hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
67
+ hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
68
+ if comm.is_main_process()
69
+ else None,
70
+ # hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
71
+ hooks.PeriodicWriter(
72
+ default_writers(cfg.train.output_dir, cfg.train.max_iter),
73
+ period=cfg.train.log_period,
74
+ )
75
+ if comm.is_main_process()
76
+ else None,
77
+ ]
78
+ )
79
+
80
+ checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
81
+ if args.resume and checkpointer.has_checkpoint():
82
+ start_iter = trainer.iter + 1
83
+ else:
84
+ start_iter = 0
85
+ trainer.train(start_iter, cfg.train.max_iter)
86
+
87
+
88
+ def main(args):
89
+ cfg = LazyConfig.load(args.config_file)
90
+ cfg = LazyConfig.apply_overrides(cfg, args.opts)
91
+ default_setup(cfg, args)
92
+ do_train(args, cfg)
93
+
94
+
95
+ if __name__ == "__main__":
96
+ args = default_argument_parser().parse_args()
97
+ launch(
98
+ main,
99
+ args.num_gpus,
100
+ num_machines=args.num_machines,
101
+ machine_rank=args.machine_rank,
102
+ dist_url=args.dist_url,
103
+ args=(args,),
104
+ )
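A hedged usage sketch; the flags come from detectron2's default_argument_parser, and using configs/gazefollow.py as the entry config is an assumption (see docs/train.md for the intended commands).

# CLI form:
#   python tools/train.py --config-file configs/gazefollow.py --num-gpus 1
# The lazy config can also be inspected before launching:
from detectron2.config import LazyConfig
cfg = LazyConfig.load("configs/gazefollow.py")
print(cfg.train.output_dir, cfg.train.max_iter)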