initial commit
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- .gitignore +16 -0
- LICENSE +21 -0
- README.md +78 -3
- assets/comparion.png +0 -0
- assets/demo0.gif +3 -0
- assets/demo1.gif +3 -0
- assets/pipeline.png +0 -0
- configs/common/__init__.py +1 -0
- configs/common/dataloader.py +214 -0
- configs/common/model.py +44 -0
- configs/common/optimizer.py +48 -0
- configs/common/scheduler.py +18 -0
- configs/common/train.py +16 -0
- configs/gazefollow.py +86 -0
- configs/gazefollow_518.py +86 -0
- configs/videoattentiontarget.py +90 -0
- data/__init__.py +5 -0
- data/augmentation.py +312 -0
- data/data_utils.py +181 -0
- data/gazefollow.py +295 -0
- data/masking.py +175 -0
- data/video_attention_target.py +228 -0
- data/video_attention_target_video.py +464 -0
- docs/eval.md +24 -0
- docs/install.md +19 -0
- docs/train.md +36 -0
- engine/__init__.py +1 -0
- engine/trainer.py +37 -0
- modeling/__init__.py +4 -0
- modeling/backbone/__init__.py +1 -0
- modeling/backbone/utils.py +154 -0
- modeling/backbone/vit.py +504 -0
- modeling/criterion/__init__.py +1 -0
- modeling/criterion/gaze_mapper_criterion.py +94 -0
- modeling/head/__init__.py +2 -0
- modeling/head/heatmap_head.py +225 -0
- modeling/head/inout_head.py +55 -0
- modeling/meta_arch/__init__.py +1 -0
- modeling/meta_arch/gaze_attention_mapper.py +97 -0
- modeling/patch_attention/__init__.py +1 -0
- modeling/patch_attention/patch_attention.py +106 -0
- pretrained/gazefollow.pth +3 -0
- pretrained/videoattentiontarget.pth +3 -0
- requirements.txt +17 -0
- scripts/convert_pth.py +32 -0
- scripts/gen_gazefollow_head_masks.py +185 -0
- tools/eval_on_gazefollow.py +92 -0
- tools/eval_on_video_attention_target.py +93 -0
- tools/train.py +104 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo0.gif filter=lfs diff=lfs merge=lfs -text
+assets/demo1.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,16 @@
.vscode
.idea
*output
backup
*/__pycache__/*
pretrained
logs
*.pyc
*.ipynb
*.out
*.log
*.pth
*.pkl
*.pt
*.npy
*debug*
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 HUST Vision Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,78 @@
<div align="center">
<h1>ViTGaze 👀</h1>
<h3>Gaze Following with Interaction Features in Vision Transformers</h3>

Yuehao Song<sup>1</sup> , Xinggang Wang<sup>1 :email:</sup> , Jingfeng Yao<sup>1</sup> , Wenyu Liu<sup>1</sup> , Jinglin Zhang<sup>2</sup> , Xiangmin Xu<sup>3</sup>

<sup>1</sup> Huazhong University of Science and Technology, <sup>2</sup> Shandong University, <sup>3</sup> South China University of Technology

(<sup>:email:</sup>) corresponding author.

ArXiv Preprint ([arXiv 2403.12778](https://arxiv.org/abs/2403.12778))

</div>

#



### News
* **`Mar. 25th, 2024`:** We released an initial version of ViTGaze.
* **`Mar. 19th, 2024`:** We released our paper on arXiv. Code/Models are coming soon. Please stay tuned! ☕️


## Introduction
<div align="center"><h5>A plain Vision Transformer can also do gaze following with the simple ViTGaze framework!</h5></div>



Inspired by the remarkable success of pre-trained plain Vision Transformers (ViTs), we introduce **ViTGaze**, a novel single-modality gaze-following framework. In contrast to previous methods, it is built mainly around a powerful encoder, with the decoder accounting for less than 1% of the parameters. Our principal insight is that the inter-token interactions within self-attention can be transferred to interactions between humans and scenes. ViTGaze achieves state-of-the-art (SOTA) performance among all single-modality methods (a 3.4% improvement in AUC and a 5.1% improvement in AP) and performs comparably to multi-modality methods while using 59% fewer parameters.

## Results
> Results from the [ViTGaze paper](https://arxiv.org/abs/2403.12778)



<table align="center">
    <tr>
        <th colspan="3">Results on <a href=http://gazefollow.csail.mit.edu/index.html>GazeFollow</a></th>
        <th colspan="3">Results on <a href=https://github.com/ejcgt/attention-target-detection>VideoAttentionTarget</a></th>
    </tr>
    <tr>
        <td><b>AUC</b></td>
        <td><b>Avg. Dist.</b></td>
        <td><b>Min. Dist.</b></td>
        <td><b>AUC</b></td>
        <td><b>Dist.</b></td>
        <td><b>AP</b></td>
    </tr>
    <tr>
        <td>0.949</td>
        <td>0.105</td>
        <td>0.047</td>
        <td>0.938</td>
        <td>0.102</td>
        <td>0.905</td>
    </tr>
</table>

Corresponding checkpoints are released:
- GazeFollow: [GoogleDrive](https://drive.google.com/file/d/164c4woGCmUI8UrM7GEKQrV1FbA3vGwP4/view?usp=drive_link)
- VideoAttentionTarget: [GoogleDrive](https://drive.google.com/file/d/11_O4Jm5wsvQ8qfLLgTlrudqSNvvepsV0/view?usp=drive_link)

## Getting Started
- [Installation](docs/install.md)
- [Train](docs/train.md)
- [Eval](docs/eval.md)

## Acknowledgements
ViTGaze is based on [detectron2](https://github.com/facebookresearch/detectron2). We use the efficient multi-head attention implemented in the [xFormers](https://github.com/facebookresearch/xformers) library.

## Citation
If you find ViTGaze useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
```bibtex
@article{vitgaze,
  title={ViTGaze: Gaze Following with Interaction Features in Vision Transformers},
  author={Yuehao Song and Xinggang Wang and Jingfeng Yao and Wenyu Liu and Jinglin Zhang and Xiangmin Xu},
  journal={arXiv preprint arXiv:2403.12778},
  year={2024}
}
```
assets/comparion.png
ADDED
assets/demo0.gif
ADDED
assets/demo1.gif
ADDED
assets/pipeline.png
ADDED
configs/common/__init__.py
ADDED
@@ -0,0 +1 @@
from . import *
configs/common/dataloader.py
ADDED
@@ -0,0 +1,214 @@
from os import path as osp
from typing import Literal

from omegaconf import OmegaConf
from detectron2.config import LazyCall as L
from detectron2.config import instantiate
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from data import *


DATA_ROOT = "${Root to Datasets}"
if DATA_ROOT == "${Root to Datasets}":
    raise Exception(
        f"""{osp.abspath(__file__)}: Rewrite `DATA_ROOT` with the root to the datasets.
        The directory structure should be:
        -DATA_ROOT
        |-videoattentiontarget
        |--images
        |--annotations
        |---train
        |---test
        |--head_masks
        |---images
        |-gazefollow
        |--train
        |--test2
        |--train_annotations_release.txt
        |--test_annotations_release.txt
        |--head_masks
        |---train
        |---test2
        """
    )

# Basic Config for Video Attention Target dataset and preprocessing
data_info = OmegaConf.create()
data_info.video_attention_target = OmegaConf.create()
data_info.video_attention_target.train_root = osp.join(DATA_ROOT, "videoattentiontarget/images")
data_info.video_attention_target.train_anno = osp.join(DATA_ROOT, "videoattentiontarget/annotations/train")
data_info.video_attention_target.val_root = osp.join(DATA_ROOT, "videoattentiontarget/images")
data_info.video_attention_target.val_anno = osp.join(DATA_ROOT, "videoattentiontarget/annotations/test")
data_info.video_attention_target.head_root = osp.join(DATA_ROOT, "videoattentiontarget/head_masks/images")

data_info.video_attention_target_video = OmegaConf.create()
data_info.video_attention_target_video.train_root = osp.join(DATA_ROOT, "videoattentiontarget/images")
data_info.video_attention_target_video.train_anno = osp.join(DATA_ROOT, "videoattentiontarget/annotations/train")
data_info.video_attention_target_video.val_root = osp.join(DATA_ROOT, "videoattentiontarget/images")
data_info.video_attention_target_video.val_anno = osp.join(DATA_ROOT, "videoattentiontarget/annotations/test")
data_info.video_attention_target_video.head_root = osp.join(DATA_ROOT, "videoattentiontarget/head_masks/images")

data_info.gazefollow = OmegaConf.create()
data_info.gazefollow.train_root = osp.join(DATA_ROOT, "gazefollow")
data_info.gazefollow.train_anno = osp.join(DATA_ROOT, "gazefollow/train_annotations_release.txt")
data_info.gazefollow.val_root = osp.join(DATA_ROOT, "gazefollow")
data_info.gazefollow.val_anno = osp.join(DATA_ROOT, "gazefollow/test_annotations_release.txt")
data_info.gazefollow.head_root = osp.join(DATA_ROOT, "gazefollow/head_masks")

data_info.input_size = 224
data_info.output_size = 64
data_info.quant_labelmap = True
data_info.mean = (0.485, 0.456, 0.406)
data_info.std = (0.229, 0.224, 0.225)
data_info.bbox_jitter = 0.5
data_info.rand_crop = 0.5
data_info.rand_flip = 0.5
data_info.color_jitter = 0.5
data_info.rand_rotate = 0.0
data_info.rand_lsj = 0.0

data_info.mask_size = 24
data_info.mask_scene = False
data_info.mask_head = False
data_info.max_scene_patches_ratio = 0.5
data_info.max_head_patches_ratio = 0.3
data_info.mask_prob = 0.2

data_info.seq_len = 16
data_info.max_len = 32


# Dataloader (gazefollow/video_attention_target, train/val)
def __build_dataloader(
    name: Literal["gazefollow", "video_attention_target", "video_attention_target_video"],
    is_train: bool,
    batch_size: int = 64,
    num_workers: int = 14,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    drop_last: bool = True,
    distributed: bool = False,
    **kwargs,
):
    assert name in [
        "gazefollow", "video_attention_target", "video_attention_target_video"
    ], f'{name} not in ("gazefollow", "video_attention_target", "video_attention_target_video")'

    for k, v in kwargs.items():
        if k in ["train_root", "train_anno", "val_root", "val_anno", "head_root"]:
            data_info[name][k] = v
        else:
            data_info[k] = v

    datasets = {
        "gazefollow": GazeFollow,
        "video_attention_target": VideoAttentionTarget,
        "video_attention_target_video": VideoAttentionTargetVideo,
    }
    dataset = L(datasets[name])(
        image_root=data_info[name]["train_root" if is_train else "val_root"],
        anno_root=data_info[name]["train_anno" if is_train else "val_anno"],
        head_root=data_info[name]["head_root"],
        transform=get_transform(
            input_resolution=data_info.input_size,
            mean=data_info.mean,
            std=data_info.std,
        ),
        input_size=data_info.input_size,
        output_size=data_info.output_size,
        quant_labelmap=data_info.quant_labelmap,
        is_train=is_train,
        bbox_jitter=data_info.bbox_jitter,
        rand_crop=data_info.rand_crop,
        rand_flip=data_info.rand_flip,
        color_jitter=data_info.color_jitter,
        rand_rotate=data_info.rand_rotate,
        rand_lsj=data_info.rand_lsj,
        mask_generator=(
            MaskGenerator(
                input_size=data_info.mask_size,
                mask_scene=data_info.mask_scene,
                mask_head=data_info.mask_head,
                max_scene_patches_ratio=data_info.max_scene_patches_ratio,
                max_head_patches_ratio=data_info.max_head_patches_ratio,
                mask_prob=data_info.mask_prob,
            )
            if is_train
            else None
        ),
    )
    if name == "video_attention_target_video":
        dataset.seq_len = data_info.seq_len
        dataset.max_len = data_info.max_len
    dataset = instantiate(dataset)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        collate_fn=video_collate if name == "video_attention_target_video" else None,
        sampler=DistributedSampler(dataset, shuffle=is_train) if distributed else None,
        drop_last=drop_last,
    )


dataloader = OmegaConf.create()
dataloader.gazefollow = OmegaConf.create()
dataloader.gazefollow.train = L(__build_dataloader)(name="gazefollow", is_train=True)
dataloader.gazefollow.val = L(__build_dataloader)(name="gazefollow", is_train=False)
dataloader.video_attention_target = OmegaConf.create()
dataloader.video_attention_target.train = L(__build_dataloader)(name="video_attention_target", is_train=True)
dataloader.video_attention_target.val = L(__build_dataloader)(name="video_attention_target", is_train=False)
dataloader.video_attention_target_video = OmegaConf.create()
dataloader.video_attention_target_video.train = L(__build_dataloader)(name="video_attention_target_video", is_train=True)
dataloader.video_attention_target_video.val = L(__build_dataloader)(name="video_attention_target_video", is_train=False)
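Note that the `dataloader` tree above holds deferred `LazyCall` nodes rather than live `DataLoader`s; a downstream script overrides fields and then materializes them. A minimal sketch (not part of this commit) using detectron2's `LazyConfig`/`instantiate`, assuming `DATA_ROOT` has already been rewritten to a real dataset root:

```python
# Minimal usage sketch; importing the config raises an Exception until DATA_ROOT is set.
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("configs/gazefollow.py")   # path relative to the repository root

cfg.dataloader.train.batch_size = 8              # override fields before instantiation
cfg.dataloader.train.num_workers = 4

train_loader = instantiate(cfg.dataloader.train) # builds the GazeFollow dataset + DataLoader
batch = next(iter(train_loader))                 # iterate like any PyTorch DataLoader
```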
configs/common/model.py
ADDED
@@ -0,0 +1,44 @@
from detectron2.config import LazyCall as L

from modeling import backbone, patch_attention, meta_arch, head, criterion


model = L(meta_arch.GazeAttentionMapper)()
model.backbone = L(backbone.build_backbone)(name="small", out_attn=[2, 5, 8, 11])
model.pam = L(patch_attention.build_pam)(name="PatchPAM", patch_size=16)
model.regressor = L(head.build_heatmap_head)(
    name="SimpleDeconv",
    in_channel=24,
    deconv_cfgs=[
        {"in_channels": 24, "out_channels": 12, "kernel_size": 3, "stride": 2},
        {"in_channels": 12, "out_channels": 6, "kernel_size": 3, "stride": 2},
        {"in_channels": 6, "out_channels": 3, "kernel_size": 3, "stride": 2},
        {"in_channels": 3, "out_channels": 1, "kernel_size": 3, "stride": 2},
    ],
    feat_type="attn",
)
model.classifier = L(head.build_inout_head)(name="SimpleLinear", in_channel=384)
model.criterion = L(criterion.GazeMapperCriterion)()
model.device = "cuda"
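The `model` above is likewise a nested tree of `LazyCall` nodes; `instantiate` builds the leaves (backbone, PAM, heads, criterion) first and then passes them to `GazeAttentionMapper`. A brief sketch of overriding leaf fields before building (illustrative only, not part of this commit; it assumes the repository root is on `PYTHONPATH`):

```python
# Illustrative only: overrides are applied to the config tree, not to a built module.
from detectron2.config import instantiate
from configs.common.model import model

model.backbone.name = "dinov2_small"   # same kind of override the dataset configs apply
model.device = "cpu"                   # e.g. build on CPU for a quick smoke test
net = instantiate(model)               # recursively builds backbone/pam/heads, then the meta-arch
```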
configs/common/optimizer.py
ADDED
@@ -0,0 +1,48 @@
from detectron2 import model_zoo


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.

    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone"):
        if ".pos_embed" in name or ".patch_embed" in name:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)


class LRDecayRater:
    def __init__(self, lr_decay_rate=1.0, num_layers=12, backbone_multiplier=1.0, freeze_pe=False, pam_lr_decay=1):
        self.lr_decay_rate = lr_decay_rate
        self.num_layers = num_layers
        self.backbone_multiplier = backbone_multiplier
        self.freeze_pe = freeze_pe
        self.pam_lr_decay = pam_lr_decay

    def __call__(self, name):
        if name.startswith("backbone"):
            if self.freeze_pe and ".pos_embed" in name or ".patch_embed" in name:
                return 0
            return self.backbone_multiplier * get_vit_lr_decay_rate(
                name, self.lr_decay_rate, self.num_layers
            )
        if name.startswith("pam"):
            return self.pam_lr_decay
        return 1


# Optimizer
optimizer = model_zoo.get_config("common/optim.py").AdamW
optimizer.params.lr_factor_func = LRDecayRater(num_layers=12, lr_decay_rate=0.65)
optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
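To illustrate the layer-wise decay above: with `lr_decay_rate=0.65` and `num_layers=12`, a parameter in backbone block `i` receives the multiplier `0.65 ** (12 - i)`, while parameters outside the backbone and PAM keep the base learning rate. A small sketch (not part of the commit; the parameter names only mimic the `backbone.blocks.<i>...` naming that `get_vit_lr_decay_rate` expects, exact names depend on `modeling/backbone/vit.py`):

```python
rater = LRDecayRater(num_layers=12, lr_decay_rate=0.65)

print(rater("backbone.blocks.0.attn.qkv.weight"))   # 0.65 ** 12 (earliest block, smallest lr)
print(rater("backbone.blocks.11.mlp.fc1.weight"))   # 0.65 ** 1  (last block, close to base lr)
print(rater("pam.embed.weight"))                    # pam_lr_decay (1 by default)
print(rater("regressor.deconv_layers.0.weight"))    # 1.0 for decoder/head parameters
# Note: because `and` binds tighter than `or`, ".patch_embed" parameters return 0 even when
# freeze_pe=False; wrap the membership tests in parentheses if that is not the intended behaviour.
```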
configs/common/scheduler.py
ADDED
@@ -0,0 +1,18 @@
from typing import Literal
from detectron2.config import LazyCall as L
from detectron2.solver import WarmupParamScheduler
from fvcore.common.param_scheduler import MultiStepParamScheduler, CosineParamScheduler


def get_scheduler(typ: Literal["multistep", "cosine"] = "multistep", **kwargs):
    if typ == "multistep":
        return MultiStepParamScheduler(**kwargs)
    elif typ == "cosine":
        return CosineParamScheduler(**kwargs)


lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(get_scheduler)(),
    warmup_length=0,
    warmup_factor=0.001,
)
configs/common/train.py
ADDED
@@ -0,0 +1,16 @@
train = dict(
    output_dir="./output",
    init_checkpoint="",
    max_iter=90000,
    amp=dict(enabled=False),  # options for Automatic Mixed Precision
    ddp=dict(  # options for DistributedDataParallel
        broadcast_buffers=True,
        find_unused_parameters=False,
        fp16_compression=True,
    ),
    checkpointer=dict(period=5000, max_to_keep=100),  # options for PeriodicCheckpointer
    eval_period=5000,
    log_period=100,
    device="cuda",
    # ...
)
configs/gazefollow.py
ADDED
@@ -0,0 +1,86 @@
from .common.dataloader import dataloader
from .common.model import model
from .common.optimizer import optimizer
from .common.scheduler import lr_multiplier
from .common.train import train
from os.path import join, basename
from torch.cuda import device_count


num_gpu = device_count()
ins_per_iter = 48
len_dataset = 126000
num_epoch = 14
# dataloader
dataloader = dataloader.gazefollow
dataloader.train.batch_size = ins_per_iter // num_gpu
dataloader.train.num_workers = dataloader.val.num_workers = 14
dataloader.train.distributed = num_gpu > 1
dataloader.train.rand_rotate = 0.5
dataloader.train.rand_lsj = 0.5
dataloader.train.input_size = dataloader.val.input_size = 434
dataloader.train.mask_scene = True
dataloader.train.mask_prob = 0.5
dataloader.train.mask_size = dataloader.train.input_size // 14
dataloader.train.max_scene_patches_ratio = 0.5
dataloader.val.batch_size = 64
dataloader.val.distributed = False
# train
train.init_checkpoint = "pretrained/dinov2_small.pth"
train.output_dir = join("./output", basename(__file__).split(".")[0])
train.max_iter = len_dataset * num_epoch // ins_per_iter
train.log_period = len_dataset // (ins_per_iter * 100)
train.checkpointer.max_to_keep = 10
train.checkpointer.period = len_dataset // ins_per_iter
train.seed = 0
# optimizer
optimizer.lr = 1e-4
optimizer.betas = (0.9, 0.99)
lr_multiplier.scheduler.typ = "cosine"
lr_multiplier.scheduler.start_value = 1
lr_multiplier.scheduler.end_value = 0.1
lr_multiplier.warmup_length = 1e-2
# model
model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
model.pam.name = "PatchPAM"
model.pam.embed_dim = 8
model.pam.patch_size = 14
model.backbone.name = "dinov2_small"
model.backbone.return_softmax_attn = True
model.backbone.out_attn = [2, 5, 8, 11]
model.backbone.use_cls_token = True
model.backbone.use_mask_token = True
model.regressor.name = "UpSampleConv"
model.regressor.in_channel = 24
model.regressor.use_conv = False
model.regressor.dim = 24
model.regressor.deconv_cfgs = [
    dict(in_channels=24, out_channels=16, kernel_size=3, stride=1, padding=1),
    dict(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1),
    dict(in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1),
]
model.regressor.feat_type = "attn"
model.classifier.name = "SimpleMlp"
model.classifier.in_channel = 384
model.criterion.aux_weight = 100
model.criterion.aux_head_thres = 0.05
model.criterion.use_focal_loss = True
model.device = "cuda"
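For reference, the derived schedule for this config follows from the constants above by simple arithmetic; a quick sanity check (not part of the commit):

```python
# Derived schedule for configs/gazefollow.py
ins_per_iter, len_dataset, num_epoch = 48, 126000, 14
print(len_dataset * num_epoch // ins_per_iter)   # max_iter            -> 36750
print(len_dataset // ins_per_iter)               # checkpointer.period -> 2625 (one checkpoint per epoch)
print(len_dataset // (ins_per_iter * 100))       # log_period          -> 26  (~100 logs per epoch)
print(434 // 14)                                 # mask_size           -> 31 patches per side at 434 px, patch size 14
```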
configs/gazefollow_518.py
ADDED
@@ -0,0 +1,86 @@
from .common.dataloader import dataloader
from .common.model import model
from .common.optimizer import optimizer
from .common.scheduler import lr_multiplier
from .common.train import train
from os.path import join, basename
from torch.cuda import device_count


num_gpu = device_count()
ins_per_iter = 32
len_dataset = 126000
num_epoch = 1
# dataloader
dataloader = dataloader.gazefollow
dataloader.train.batch_size = ins_per_iter // num_gpu
dataloader.train.num_workers = dataloader.val.num_workers = 14
dataloader.train.distributed = num_gpu > 1
dataloader.train.rand_rotate = 0.5
dataloader.train.rand_lsj = 0.5
dataloader.train.input_size = dataloader.val.input_size = 518
dataloader.train.mask_scene = True
dataloader.train.mask_prob = 0.5
dataloader.train.mask_size = dataloader.train.input_size // 14
dataloader.train.max_scene_patches_ratio = 0.5
dataloader.val.batch_size = 32
dataloader.val.distributed = False
# train
train.init_checkpoint = "output/gazefollow/model_final.pth"
train.output_dir = join("./output", basename(__file__).split(".")[0])
train.max_iter = len_dataset * num_epoch // ins_per_iter
train.log_period = len_dataset // (ins_per_iter * 100)
train.checkpointer.max_to_keep = 10
train.checkpointer.period = len_dataset // ins_per_iter
train.seed = 0
# optimizer
optimizer.lr = 1e-5
optimizer.betas = (0.9, 0.99)
lr_multiplier.scheduler.typ = "cosine"
lr_multiplier.scheduler.start_value = 1
lr_multiplier.scheduler.end_value = 0.1
lr_multiplier.warmup_length = 1e-2
# model
model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
model.pam.name = "PatchPAM"
model.pam.embed_dim = 8
model.pam.patch_size = 14
model.backbone.name = "dinov2_small"
model.backbone.return_softmax_attn = True
model.backbone.out_attn = [2, 5, 8, 11]
model.backbone.use_cls_token = True
model.backbone.use_mask_token = True
model.regressor.name = "UpSampleConv"
model.regressor.in_channel = 24
model.regressor.use_conv = False
model.regressor.dim = 24
model.regressor.deconv_cfgs = [
    dict(in_channels=24, out_channels=16, kernel_size=3, stride=1, padding=1),
    dict(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1),
    dict(in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1),
]
model.regressor.feat_type = "attn"
model.classifier.name = "SimpleMlp"
model.classifier.in_channel = 384
model.criterion.aux_weight = 0
model.criterion.aux_head_thres = 0.05
model.criterion.use_focal_loss = True
model.device = "cuda"
configs/videoattentiontarget.py
ADDED
@@ -0,0 +1,90 @@
from .common.dataloader import dataloader
from .common.model import model
from .common.optimizer import optimizer
from .common.scheduler import lr_multiplier
from .common.train import train
from os.path import join, basename
from torch.cuda import device_count


num_gpu = device_count()
ins_per_iter = 4
len_dataset = 4400
num_epoch = 1
# dataloader
dataloader = dataloader.video_attention_target_video
dataloader.train.batch_size = ins_per_iter // num_gpu
dataloader.train.num_workers = dataloader.val.num_workers = 14
dataloader.train.distributed = num_gpu > 1
dataloader.train.rand_rotate = 0.5
dataloader.train.rand_lsj = 0.5
dataloader.train.input_size = dataloader.val.input_size = 518
dataloader.train.mask_scene = True
dataloader.train.mask_prob = 0.5
dataloader.train.mask_size = dataloader.train.input_size // 14
dataloader.train.max_scene_patches_ratio = 0.5
dataloader.train.seq_len = 8
dataloader.val.quant_labelmap = False
dataloader.val.seq_len = 8
dataloader.val.batch_size = 4
dataloader.val.distributed = False
# train
train.init_checkpoint = "output/gazefollow_518/model_final.pth"
train.output_dir = join("./output", basename(__file__).split(".")[0])
train.max_iter = len_dataset * num_epoch // ins_per_iter
train.log_period = len_dataset // (ins_per_iter * 100)
train.checkpointer.max_to_keep = 100
train.checkpointer.period = len_dataset // ins_per_iter
train.seed = 0
# optimizer
optimizer.lr = 1e-6
# optimizer.params.lr_factor_func.backbone_multiplier = 0.1
# optimizer.params.lr_factor_func.pam_lr_decay = 0.1
lr_multiplier.scheduler.values = [1.0]
lr_multiplier.scheduler.milestones = []
lr_multiplier.scheduler.num_updates = train.max_iter
lr_multiplier.warmup_length = 0
# model
model.use_aux_loss = model.pam.use_aux_loss = model.criterion.use_aux_loss = True
model.pam.name = "PatchPAM"
model.pam.embed_dim = 8
model.pam.patch_size = 14
model.backbone.name = "dinov2_small"
model.backbone.return_softmax_attn = True
model.backbone.out_attn = [2, 5, 8, 11]
model.backbone.use_cls_token = True
model.backbone.use_mask_token = True
model.regressor.name = "UpSampleConv"
model.regressor.in_channel = 24
model.regressor.use_conv = False
model.regressor.dim = 24
model.regressor.deconv_cfgs = [
    dict(in_channels=24, out_channels=16, kernel_size=3, stride=1, padding=1),
    dict(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1),
    dict(in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1),
]
model.regressor.feat_type = "attn"
model.classifier.name = "SimpleMlp"
model.classifier.in_channel = 384
model.criterion.aux_weight = 0
model.criterion.aux_head_thres = 0.05
model.criterion.use_focal_loss = True
model.device = "cuda"
data/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .video_attention_target import VideoAttentionTarget
from .video_attention_target_video import VideoAttentionTargetVideo, video_collate
from .gazefollow import GazeFollow
from .data_utils import get_transform
from .masking import MaskGenerator
data/augmentation.py
ADDED
@@ -0,0 +1,312 @@
import math
from typing import Tuple, List
import numpy as np
from PIL import Image, ImageOps
from torchvision import transforms
from torchvision.transforms import functional as TF


class Augmentation:
    def __init__(self, p: float) -> None:
        self.p = p

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        raise NotImplementedError

    def __call__(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                 head_mask: Image, size: Tuple[int]):
        if np.random.random_sample() < self.p:
            return self.transform(image, bbox, gaze, head_mask, size)
        return image, bbox, gaze, head_mask, size


class AugmentationList:
    def __init__(self, augmentations: List[Augmentation]) -> None:
        self.augmentations = augmentations

    def __call__(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                 head_mask: Image, size: Tuple[int]):
        for aug in self.augmentations:
            image, bbox, gaze, head_mask, size = aug(image, bbox, gaze, head_mask, size)
        return image, bbox, gaze, head_mask, size


class BoxJitter(Augmentation):
    # Jitter (expansion-only) bounding box size
    def __init__(self, p: float, expansion: float = 0.2) -> None:
        super().__init__(p)
        self.expansion = expansion

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        x_min, y_min, x_max, y_max = bbox
        width, height = size
        k = np.random.random_sample() * self.expansion
        x_min = np.clip(x_min - k * abs(x_max - x_min), 0, width - 1)
        y_min = np.clip(y_min - k * abs(y_max - y_min), 0, height - 1)
        x_max = np.clip(x_max + k * abs(x_max - x_min), 0, width - 1)
        y_max = np.clip(y_max + k * abs(y_max - y_min), 0, height - 1)
        return image, (x_min, y_min, x_max, y_max), gaze, head_mask, size


class RandomCrop(Augmentation):
    def __init__(self, p: float) -> None:
        super().__init__(p)

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        x_min, y_min, x_max, y_max = bbox
        gaze_x, gaze_y = gaze
        width, height = size
        # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
        crop_x_min = np.min([gaze_x * width, x_min, x_max])
        crop_y_min = np.min([gaze_y * height, y_min, y_max])
        crop_x_max = np.max([gaze_x * width, x_min, x_max])
        crop_y_max = np.max([gaze_y * height, y_min, y_max])

        # Randomly select a random top left corner
        crop_x_min = np.random.uniform(0, crop_x_min)
        crop_y_min = np.random.uniform(0, crop_y_min)

        # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
        crop_width_min = crop_x_max - crop_x_min
        crop_height_min = crop_y_max - crop_y_min
        crop_width_max = width - crop_x_min
        crop_height_max = height - crop_y_min

        # Randomly select a width and a height
        crop_width = np.random.uniform(crop_width_min, crop_width_max)
        crop_height = np.random.uniform(crop_height_min, crop_height_max)

        # Round to integers
        crop_y_min, crop_x_min, crop_height, crop_width = map(
            int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
        )

        # Crop it
        image = TF.crop(image, crop_y_min, crop_x_min, crop_height, crop_width)
        head_mask = TF.crop(head_mask, crop_y_min, crop_x_min, crop_height, crop_width)

        # Convert coordinates into the cropped frame
        x_min, y_min, x_max, y_max = (
            x_min - crop_x_min,
            y_min - crop_y_min,
            x_max - crop_x_min,
            y_max - crop_y_min,
        )

        gaze_x = (gaze_x * width - crop_x_min) / float(crop_width)
        gaze_y = (gaze_y * height - crop_y_min) / float(crop_height)

        return (
            image,
            (x_min, y_min, x_max, y_max),
            (gaze_x, gaze_y),
            head_mask,
            (crop_width, crop_height),
        )


class RandomFlip(Augmentation):
    def __init__(self, p: float) -> None:
        super().__init__(p)

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
        x_min, y_min, x_max, y_max = bbox
        x_min, x_max = size[0] - x_max, size[0] - x_min
        gaze_x, gaze_y = 1 - gaze[0], gaze[1]
        return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size


class RandomRotate(Augmentation):
    def __init__(self, p: float, max_angle: int = 20, resample: int = Image.BILINEAR) -> None:
        super().__init__(p)
        self.max_angle = max_angle
        self.resample = resample

    def _random_rotation_matrix(self):
        angle = (2 * np.random.random_sample() - 1) * self.max_angle
        angle = -math.radians(angle)
        return [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

    @staticmethod
    def _transform(x, y, matrix):
        return (
            matrix[0] * x + matrix[1] * y + matrix[2],
            matrix[3] * x + matrix[4] * y + matrix[5],
        )

    @staticmethod
    def _inv_transform(x, y, matrix):
        x, y = x - matrix[2], y - matrix[5]
        return matrix[0] * x + matrix[3] * y, matrix[1] * x + matrix[4] * y

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        x_min, y_min, x_max, y_max = bbox
        gaze_x, gaze_y = gaze
        width, height = size
        rot_mat = self._random_rotation_matrix()

        # Calculate offsets
        rot_center = (width / 2.0, height / 2.0)
        rot_mat[2], rot_mat[5] = self._transform(-rot_center[0], -rot_center[1], rot_mat)
        rot_mat[2] += rot_center[0]
        rot_mat[5] += rot_center[1]
        xx = []
        yy = []
        for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
            x, y = self._transform(x, y, rot_mat)
            xx.append(x)
            yy.append(y)
        nw = math.ceil(max(xx)) - math.floor(min(xx))
        nh = math.ceil(max(yy)) - math.floor(min(yy))
        rot_mat[2], rot_mat[5] = self._transform(
            -(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
        )

        image = image.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)
        head_mask = head_mask.transform((nw, nh), Image.AFFINE, rot_mat, self.resample)

        xx = []
        yy = []
        for x, y in ((x_min, y_min), (x_min, y_max), (x_max, y_min), (x_max, y_max)):
            x, y = self._inv_transform(x, y, rot_mat)
            xx.append(x)
            yy.append(y)
        x_max, x_min = min(max(xx), nw), max(min(xx), 0)
        y_max, y_min = min(max(yy), nh), max(min(yy), 0)

        gaze_x, gaze_y = self._inv_transform(gaze_x * width, gaze_y * height, rot_mat)
        gaze_x = max(min(gaze_x / nw, 1), 0)
        gaze_y = max(min(gaze_y / nh, 1), 0)

        return (
            image,
            (x_min, y_min, x_max, y_max),
            (gaze_x, gaze_y),
            head_mask,
            (nw, nh),
        )


class ColorJitter(Augmentation):
    def __init__(self, p: float, brightness: float = 0.4, contrast: float = 0.4,
                 saturation: float = 0.2, hue: float = 0.1) -> None:
        super().__init__(p)
        self.color_jitter = transforms.ColorJitter(
            brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
        )

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        return self.color_jitter(image), bbox, gaze, head_mask, size


class RandomLSJ(Augmentation):
    def __init__(self, p: float, min_scale: float = 0.1) -> None:
        super().__init__(p)
        self.min_scale = min_scale

    def transform(self, image: Image, bbox: Tuple[float], gaze: Tuple[float],
                  head_mask: Image, size: Tuple[int]):
        x_min, y_min, x_max, y_max = bbox
        gaze_x, gaze_y = gaze
        width, height = size

        scale = self.min_scale + np.random.random_sample() * (1 - self.min_scale)
        nh, nw = int(height * scale), int(width * scale)

        image = TF.resize(image, (nh, nw))
        image = ImageOps.expand(image, (0, 0, width - nw, height - nh))
        head_mask = TF.resize(head_mask, (nh, nw))
        head_mask = ImageOps.expand(head_mask, (0, 0, width - nw, height - nh))

        x_min, y_min, x_max, y_max = (
            x_min * scale,
            y_min * scale,
            x_max * scale,
            y_max * scale,
        )
        gaze_x, gaze_y = gaze_x * scale, gaze_y * scale
        return image, (x_min, y_min, x_max, y_max), (gaze_x, gaze_y), head_mask, size
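All augmentations share one calling convention: they take and return `(image, bbox, gaze, head_mask, size)`, with the head box in pixels and the gaze point normalized to [0, 1], so they can be chained through `AugmentationList` exactly as the GazeFollow dataset does. A minimal standalone sketch, not part of the commit; the image paths, box, and gaze values below are made up:

```python
# Minimal usage sketch mirroring how data/gazefollow.py composes the training pipeline.
from PIL import Image

from data.augmentation import (
    AugmentationList, ColorJitter, BoxJitter, RandomCrop, RandomFlip, RandomLSJ
)

augment = AugmentationList([
    ColorJitter(0.5),
    BoxJitter(0.5),
    RandomCrop(0.5),
    RandomFlip(0.5),
    RandomLSJ(0.5),
])

image = Image.open("some_frame.jpg").convert("RGB")   # hypothetical path
head_mask = Image.open("some_frame_head_mask.png")    # hypothetical path
bbox = (120.0, 40.0, 220.0, 160.0)                    # head box in pixels
gaze = (0.62, 0.55)                                    # gaze target, normalized to [0, 1]

image, bbox, gaze, head_mask, size = augment(image, bbox, gaze, head_mask, image.size)
```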
data/data_utils.py
ADDED
@@ -0,0 +1,181 @@
from typing import Tuple
import torch
from torchvision import transforms
import numpy as np
import pandas as pd


def to_numpy(tensor: torch.Tensor):
    if torch.is_tensor(tensor):
        return tensor.cpu().detach().numpy()
    elif type(tensor).__module__ != "numpy":
        raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
    return tensor


def to_torch(ndarray: np.ndarray):
    if type(ndarray).__module__ == "numpy":
        return torch.from_numpy(ndarray)
    elif not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
    return ndarray


def get_head_box_channel(
    x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False
):
    head_box = (
        np.array([x_min / width, y_min / height, x_max / width, y_max / height])
        * resolution
    )
    int_head_box = head_box.astype(int)
    int_head_box = np.clip(int_head_box, 0, resolution - 1)
    if int_head_box[0] == int_head_box[2]:
        if int_head_box[0] == 0:
            int_head_box[2] = 1
        elif int_head_box[2] == resolution - 1:
            int_head_box[0] = resolution - 2
        elif abs(head_box[2] - int_head_box[2]) > abs(head_box[0] - int_head_box[0]):
            int_head_box[2] += 1
        else:
            int_head_box[0] -= 1
    if int_head_box[1] == int_head_box[3]:
        if int_head_box[1] == 0:
            int_head_box[3] = 1
        elif int_head_box[3] == resolution - 1:
            int_head_box[1] = resolution - 2
        elif abs(head_box[3] - int_head_box[3]) > abs(head_box[1] - int_head_box[1]):
            int_head_box[3] += 1
        else:
            int_head_box[1] -= 1
    head_box = int_head_box
    if coordconv:
        unit = np.array(range(0, resolution), dtype=np.float32)
        head_channel = []
        for i in unit:
            head_channel.append([unit + i])
        head_channel = np.squeeze(np.array(head_channel)) / float(np.max(head_channel))
        head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 0
    else:
        head_channel = np.zeros((resolution, resolution), dtype=np.float32)
        head_channel[head_box[1] : head_box[3], head_box[0] : head_box[2]] = 1
    head_channel = torch.from_numpy(head_channel)
    return head_channel


def draw_labelmap(img, pt, sigma, type="Gaussian"):
    # Draw a 2D gaussian
    # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
    img = to_numpy(img)

    # Check that any part of the gaussian is in-bounds
    size = int(6 * sigma + 1)
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [ul[0] + size, ul[1] + size]
    if ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or br[0] < 0 or br[1] < 0:
        # If not, just return the image as is
        return to_torch(img)

    # Generate gaussian
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    if type == "Gaussian":
        g = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))
    elif type == "Cauchy":
        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma**2) ** 1.5)

    # Usable gaussian range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0] : img_y[1], img_x[0] : img_x[1]] += g[g_y[0] : g_y[1], g_x[0] : g_x[1]]
    # img = img / np.max(img)
    return to_torch(img)


def draw_labelmap_no_quant(img, pt, sigma, type="Gaussian"):
    img = to_numpy(img)
    shape = img.shape
    x = np.arange(shape[0])
    y = np.arange(shape[1])
    xx, yy = np.meshgrid(x, y, indexing="ij")
    dist_matrix = (yy - float(pt[0])) ** 2 + (xx - float(pt[1])) ** 2
    if type == "Gaussian":
        g = np.exp(-dist_matrix / (2 * sigma**2))
    elif type == "Cauchy":
        g = sigma / ((dist_matrix + sigma**2) ** 1.5)
    g[dist_matrix > 10 * sigma**2] = 0
    img += g
    # img = img / np.max(img)
    return to_torch(img)


def multi_hot_targets(gaze_pts, out_res):
    w, h = out_res
    target_map = np.zeros((h, w))
    for p in gaze_pts:
        if p[0] >= 0:
            x, y = map(int, [p[0] * float(w), p[1] * float(h)])
            x = min(x, w - 1)
            y = min(y, h - 1)
            target_map[y, x] = 1
    return target_map


def get_cone(tgt, src, wh, theta=150):
    eye = src * wh
    gaze = tgt * wh

    pixel_mat = np.stack(
        np.meshgrid(np.arange(wh[0]), np.arange(wh[1])),
        -1,
    )

    dot_prod = np.sum((pixel_mat - eye) * (gaze - eye), axis=-1)
    gaze_vector_norm = np.sqrt(np.sum((gaze - eye) ** 2))
    pixel_mat_norm = np.sqrt(np.sum((pixel_mat - eye) ** 2, axis=-1))

    gaze_cones = dot_prod / (gaze_vector_norm * pixel_mat_norm)
    gaze_cones = np.nan_to_num(gaze_cones, nan=1)

    theta = theta * (np.pi / 180)
    beta = np.arccos(gaze_cones)
    # Create mask where true if beta is less than theta/2
    pixel_mat_presence = beta < (theta / 2)

    # Zero out values outside the gaze cone
    gaze_cones[~pixel_mat_presence] = 0
    gaze_cones = np.clip(gaze_cones, 0, None)

    return torch.from_numpy(gaze_cones).unsqueeze(0).float()


def get_transform(
    input_resolution: int, mean: Tuple[int, int, int], std: Tuple[int, int, int]
):
    return transforms.Compose(
        [
            transforms.Resize((input_resolution, input_resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )


def smooth_by_conv(window_size, df, col):
    padded_track = pd.concat(
        [
            pd.DataFrame([[df.iloc[0][col]]] * (window_size // 2), columns=[0]),
            df[col],
            pd.DataFrame([[df.iloc[-1][col]]] * (window_size // 2), columns=[0]),
        ]
    )
    smoothed_signals = np.convolve(
        padded_track.squeeze(), np.ones(window_size) / window_size, mode="valid"
    )
    return smoothed_signals
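As a quick illustration of the two central helpers above: `get_transform` produces the normalized ViT input tensor, and `draw_labelmap` renders a Gaussian around the (quantized) gaze point on the output-resolution heatmap. A small sketch, not part of the commit; the gaze coordinates and the sigma value are chosen only for illustration:

```python
import torch
from data.data_utils import get_transform, draw_labelmap

# Input pipeline: resize to the model resolution and normalize with ImageNet statistics.
transform = get_transform(
    input_resolution=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
)
# img_tensor = transform(pil_image)  # -> float tensor of shape (3, 224, 224)

# Target heatmap: a 64x64 map with a Gaussian centred on the gaze point.
heatmap = torch.zeros(64, 64)
gaze_x, gaze_y = 0.7, 0.3                                   # normalized gaze, made-up values
heatmap = draw_labelmap(heatmap, (gaze_x * 64, gaze_y * 64), sigma=3)
print(heatmap.max())                                        # peak close to 1 at the gaze location
```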
data/gazefollow.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path as osp
|
2 |
+
from typing import Callable, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision.transforms import functional as TF
|
7 |
+
from PIL import Image
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from . import augmentation
|
11 |
+
from .masking import MaskGenerator
|
12 |
+
from . import data_utils as utils
|
13 |
+
|
14 |
+
|
15 |
+
class GazeFollow(Dataset):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
image_root: str,
|
19 |
+
anno_root: str,
|
20 |
+
head_root: str,
|
21 |
+
transform: Callable,
|
22 |
+
input_size: int,
|
23 |
+
output_size: int,
|
24 |
+
quant_labelmap: bool = True,
|
25 |
+
is_train: bool = True,
|
26 |
+
*,
|
27 |
+
mask_generator: Optional[MaskGenerator] = None,
|
28 |
+
bbox_jitter: float = 0.5,
|
29 |
+
rand_crop: float = 0.5,
|
30 |
+
rand_flip: float = 0.5,
|
31 |
+
color_jitter: float = 0.5,
|
32 |
+
rand_rotate: float = 0.0,
|
33 |
+
rand_lsj: float = 0.0,
|
34 |
+
):
|
35 |
+
if is_train:
|
36 |
+
column_names = [
|
37 |
+
"path",
|
38 |
+
"idx",
|
39 |
+
"body_bbox_x",
|
40 |
+
"body_bbox_y",
|
41 |
+
"body_bbox_w",
|
42 |
+
"body_bbox_h",
|
43 |
+
"eye_x",
|
44 |
+
"eye_y",
|
45 |
+
"gaze_x",
|
46 |
+
"gaze_y",
|
47 |
+
"bbox_x_min",
|
48 |
+
"bbox_y_min",
|
49 |
+
"bbox_x_max",
|
50 |
+
"bbox_y_max",
|
51 |
+
"inout",
|
52 |
+
"meta0",
|
53 |
+
"meta1",
|
54 |
+
]
|
55 |
+
df = pd.read_csv(
|
56 |
+
anno_root,
|
57 |
+
sep=",",
|
58 |
+
names=column_names,
|
59 |
+
index_col=False,
|
60 |
+
encoding="utf-8-sig",
|
61 |
+
)
|
62 |
+
df = df[
|
63 |
+
df["inout"] != -1
|
64 |
+
] # only use "in" or "out "gaze. (-1 is invalid, 0 is out gaze)
|
65 |
+
df.reset_index(inplace=True)
|
66 |
+
self.y_train = df[
|
67 |
+
[
|
68 |
+
"bbox_x_min",
|
69 |
+
"bbox_y_min",
|
70 |
+
"bbox_x_max",
|
71 |
+
"bbox_y_max",
|
72 |
+
"eye_x",
|
73 |
+
"eye_y",
|
74 |
+
"gaze_x",
|
75 |
+
"gaze_y",
|
76 |
+
"inout",
|
77 |
+
]
|
78 |
+
]
|
79 |
+
self.X_train = df["path"]
|
80 |
+
self.length = len(df)
|
81 |
+
else:
|
82 |
+
column_names = [
|
83 |
+
"path",
|
84 |
+
"idx",
|
85 |
+
"body_bbox_x",
|
86 |
+
"body_bbox_y",
|
87 |
+
"body_bbox_w",
|
88 |
+
"body_bbox_h",
|
89 |
+
"eye_x",
|
90 |
+
"eye_y",
|
91 |
+
"gaze_x",
|
92 |
+
"gaze_y",
|
93 |
+
"bbox_x_min",
|
94 |
+
"bbox_y_min",
|
95 |
+
"bbox_x_max",
|
96 |
+
"bbox_y_max",
|
97 |
+
"meta0",
|
98 |
+
"meta1",
|
99 |
+
]
|
100 |
+
df = pd.read_csv(
|
101 |
+
anno_root,
|
102 |
+
sep=",",
|
103 |
+
names=column_names,
|
104 |
+
index_col=False,
|
105 |
+
encoding="utf-8-sig",
|
106 |
+
)
|
107 |
+
df = df[
|
108 |
+
[
|
109 |
+
"path",
|
110 |
+
"eye_x",
|
111 |
+
"eye_y",
|
112 |
+
"gaze_x",
|
113 |
+
"gaze_y",
|
114 |
+
"bbox_x_min",
|
115 |
+
"bbox_y_min",
|
116 |
+
"bbox_x_max",
|
117 |
+
"bbox_y_max",
|
118 |
+
]
|
119 |
+
].groupby(["path", "eye_x"])
|
120 |
+
self.keys = list(df.groups.keys())
|
121 |
+
self.X_test = df
|
122 |
+
self.length = len(self.keys)
|
123 |
+
|
124 |
+
self.data_dir = image_root
|
125 |
+
self.head_dir = head_root
|
126 |
+
self.transform = transform
|
127 |
+
self.is_train = is_train
|
128 |
+
|
129 |
+
self.input_size = input_size
|
130 |
+
self.output_size = output_size
|
131 |
+
|
132 |
+
self.draw_labelmap = (
|
133 |
+
utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
|
134 |
+
)
|
135 |
+
|
136 |
+
if self.is_train:
|
137 |
+
## data augmentation
|
138 |
+
self.augment = augmentation.AugmentationList(
|
139 |
+
[
|
140 |
+
augmentation.ColorJitter(color_jitter),
|
141 |
+
augmentation.BoxJitter(bbox_jitter),
|
142 |
+
augmentation.RandomCrop(rand_crop),
|
143 |
+
augmentation.RandomFlip(rand_flip),
|
144 |
+
augmentation.RandomRotate(rand_rotate),
|
145 |
+
augmentation.RandomLSJ(rand_lsj),
|
146 |
+
]
|
147 |
+
)
|
148 |
+
|
149 |
+
self.mask_generator = mask_generator
|
150 |
+
|
151 |
+
def __getitem__(self, index):
|
152 |
+
if not self.is_train:
|
153 |
+
g = self.X_test.get_group(self.keys[index])
|
154 |
+
cont_gaze = []
|
155 |
+
for _, row in g.iterrows():
|
156 |
+
path = row["path"]
|
157 |
+
x_min = row["bbox_x_min"]
|
158 |
+
y_min = row["bbox_y_min"]
|
159 |
+
x_max = row["bbox_x_max"]
|
160 |
+
y_max = row["bbox_y_max"]
|
161 |
+
eye_x = row["eye_x"]
|
162 |
+
eye_y = row["eye_y"]
|
163 |
+
gaze_x = row["gaze_x"]
|
164 |
+
gaze_y = row["gaze_y"]
|
165 |
+
cont_gaze.append(
|
166 |
+
[gaze_x, gaze_y]
|
167 |
+
) # all ground-truth gaze points are stacked up
|
168 |
+
for _ in range(len(cont_gaze), 20):
|
169 |
+
cont_gaze.append(
|
170 |
+
[-1, -1]
|
171 |
+
) # pad dummy gaze to match size for batch processing
|
172 |
+
cont_gaze = torch.FloatTensor(cont_gaze)
|
173 |
+
gaze_inside = True # always consider test samples as inside
|
174 |
+
else:
|
175 |
+
path = self.X_train.iloc[index]
|
176 |
+
(
|
177 |
+
x_min,
|
178 |
+
y_min,
|
179 |
+
x_max,
|
180 |
+
y_max,
|
181 |
+
eye_x,
|
182 |
+
eye_y,
|
183 |
+
gaze_x,
|
184 |
+
gaze_y,
|
185 |
+
inout,
|
186 |
+
) = self.y_train.iloc[index]
|
187 |
+
gaze_inside = bool(inout)
|
188 |
+
|
189 |
+
img = Image.open(osp.join(self.data_dir, path))
|
190 |
+
img = img.convert("RGB")
|
191 |
+
head_mask = Image.open(osp.join(self.head_dir, path))
|
192 |
+
width, height = img.size
|
193 |
+
x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
|
194 |
+
if x_max < x_min:
|
195 |
+
x_min, x_max = x_max, x_min
|
196 |
+
if y_max < y_min:
|
197 |
+
y_min, y_max = y_max, y_min
|
198 |
+
# expand face bbox a bit
|
199 |
+
k = 0.1
|
200 |
+
x_min = max(x_min - k * abs(x_max - x_min), 0)
|
201 |
+
y_min = max(y_min - k * abs(y_max - y_min), 0)
|
202 |
+
x_max = min(x_max + k * abs(x_max - x_min), width - 1)
|
203 |
+
y_max = min(y_max + k * abs(y_max - y_min), height - 1)
|
204 |
+
|
205 |
+
if self.is_train:
|
206 |
+
img, bbox, gaze, head_mask, size = self.augment(
|
207 |
+
img,
|
208 |
+
(x_min, y_min, x_max, y_max),
|
209 |
+
(gaze_x, gaze_y),
|
210 |
+
head_mask,
|
211 |
+
(width, height),
|
212 |
+
)
|
213 |
+
x_min, y_min, x_max, y_max = bbox
|
214 |
+
gaze_x, gaze_y = gaze
|
215 |
+
width, height = size
|
216 |
+
|
217 |
+
head_channel = utils.get_head_box_channel(
|
218 |
+
x_min,
|
219 |
+
y_min,
|
220 |
+
x_max,
|
221 |
+
y_max,
|
222 |
+
width,
|
223 |
+
height,
|
224 |
+
resolution=self.input_size,
|
225 |
+
coordconv=False,
|
226 |
+
).unsqueeze(0)
|
227 |
+
|
228 |
+
if self.is_train and self.mask_generator is not None:
|
229 |
+
image_mask = self.mask_generator(
|
230 |
+
x_min / width,
|
231 |
+
y_min / height,
|
232 |
+
x_max / width,
|
233 |
+
y_max / height,
|
234 |
+
head_channel,
|
235 |
+
)
|
236 |
+
|
237 |
+
if self.transform is not None:
|
238 |
+
img = self.transform(img)
|
239 |
+
head_mask = TF.to_tensor(
|
240 |
+
TF.resize(head_mask, (self.input_size, self.input_size))
|
241 |
+
)
|
242 |
+
|
243 |
+
# generate the heat map used for deconv prediction
|
244 |
+
gaze_heatmap = torch.zeros(
|
245 |
+
self.output_size, self.output_size
|
246 |
+
) # set the size of the output
|
247 |
+
if not self.is_train: # aggregated heatmap
|
248 |
+
num_valid = 0
|
249 |
+
for gaze_x, gaze_y in cont_gaze:
|
250 |
+
if gaze_x != -1:
|
251 |
+
num_valid += 1
|
252 |
+
gaze_heatmap += self.draw_labelmap(
|
253 |
+
torch.zeros(self.output_size, self.output_size),
|
254 |
+
[gaze_x * self.output_size, gaze_y * self.output_size],
|
255 |
+
3,
|
256 |
+
type="Gaussian",
|
257 |
+
)
|
258 |
+
gaze_heatmap /= num_valid
|
259 |
+
else:
|
260 |
+
# if gaze_inside:
|
261 |
+
gaze_heatmap = self.draw_labelmap(
|
262 |
+
gaze_heatmap,
|
263 |
+
[gaze_x * self.output_size, gaze_y * self.output_size],
|
264 |
+
3,
|
265 |
+
type="Gaussian",
|
266 |
+
)
|
267 |
+
|
268 |
+
imsize = torch.IntTensor([width, height])
|
269 |
+
|
270 |
+
if self.is_train:
|
271 |
+
out_dict = {
|
272 |
+
"images": img,
|
273 |
+
"head_channels": head_channel,
|
274 |
+
"heatmaps": gaze_heatmap,
|
275 |
+
"gazes": torch.FloatTensor([gaze_x, gaze_y]),
|
276 |
+
"gaze_inouts": torch.FloatTensor([gaze_inside]),
|
277 |
+
"head_masks": head_mask,
|
278 |
+
"imsize": imsize,
|
279 |
+
}
|
280 |
+
if self.mask_generator is not None:
|
281 |
+
out_dict["image_masks"] = image_mask
|
282 |
+
return out_dict
|
283 |
+
else:
|
284 |
+
return {
|
285 |
+
"images": img,
|
286 |
+
"head_channels": head_channel,
|
287 |
+
"heatmaps": gaze_heatmap,
|
288 |
+
"gazes": cont_gaze,
|
289 |
+
"gaze_inouts": torch.FloatTensor([gaze_inside]),
|
290 |
+
"head_masks": head_mask,
|
291 |
+
"imsize": imsize,
|
292 |
+
}
|
293 |
+
|
294 |
+
def __len__(self):
|
295 |
+
return self.length
|
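A minimal sketch of how this dataset could be wired into a training dataloader. The paths, normalization statistics, and sizes below are placeholders (the real values live in `configs/common/dataloader.py`), and the import assumes the repository root is on `PYTHONPATH`:

```python
import torch
from torchvision import transforms
from data.gazefollow import GazeFollow  # repo-local import; run from the ViTGaze root

# Placeholder transform; the actual mean/std and input size come from the configs.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

dataset = GazeFollow(
    image_root="/path/to/gazefollow",                     # placeholder
    anno_root="/path/to/gazefollow/train_annotations.txt",  # placeholder
    head_root="/path/to/gazefollow/head_masks",           # placeholder (gen_gazefollow_head_masks.py)
    transform=transform,
    input_size=224,
    output_size=64,
    is_train=True,
)

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["images"].shape, batch["heatmaps"].shape)  # e.g. [4, 3, 224, 224] and [4, 64, 64]
```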
data/masking.py
ADDED
@@ -0,0 +1,175 @@
1 |
+
import random
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class SceneMaskGenerator:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
input_size,
|
12 |
+
min_num_patches=16,
|
13 |
+
max_num_patches_ratio=0.5,
|
14 |
+
min_aspect=0.3,
|
15 |
+
):
|
16 |
+
if not isinstance(input_size, tuple):
|
17 |
+
input_size = (input_size,) * 2
|
18 |
+
self.input_size = input_size
|
19 |
+
self.num_patches = input_size[0] * input_size[1]
|
20 |
+
|
21 |
+
self.min_num_patches = min_num_patches
|
22 |
+
self.max_num_patches = max_num_patches_ratio * self.num_patches
|
23 |
+
|
24 |
+
self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
|
25 |
+
|
26 |
+
def _mask(self, mask, max_mask_patches):
|
27 |
+
delta = 0
|
28 |
+
for _ in range(4):
|
29 |
+
target_area = random.uniform(self.min_num_patches, max_mask_patches)
|
30 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
31 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
32 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
33 |
+
height, width = self.input_size
|
34 |
+
if w < width and h < height:
|
35 |
+
top = random.randint(0, height - h)
|
36 |
+
left = random.randint(0, width - w)
|
37 |
+
|
38 |
+
num_masked = mask[top : top + h, left : left + w].sum()
|
39 |
+
# count newly masked patches; accept the block only if it adds some and stays within budget
|
40 |
+
if 0 < h * w - num_masked <= max_mask_patches:
|
41 |
+
mask[top : top + h, left : left + w] = 1
|
42 |
+
delta = h * w - num_masked
|
43 |
+
break
|
44 |
+
return delta
|
45 |
+
|
46 |
+
def __call__(self, head_mask):
|
47 |
+
mask = np.zeros(shape=self.input_size, dtype=bool)
|
48 |
+
mask_count = 0
|
49 |
+
num_masking_patches = random.uniform(self.min_num_patches, self.max_num_patches)
|
50 |
+
while mask_count < num_masking_patches:
|
51 |
+
max_mask_patches = num_masking_patches - mask_count
|
52 |
+
delta = self._mask(mask, max_mask_patches)
|
53 |
+
if delta == 0:
|
54 |
+
break
|
55 |
+
else:
|
56 |
+
mask_count += delta
|
57 |
+
|
58 |
+
mask = torch.from_numpy(mask).unsqueeze(0)
|
59 |
+
head_mask = (
|
60 |
+
F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5
|
61 |
+
)
|
62 |
+
return torch.logical_and(mask, head_mask).squeeze(0)
|
63 |
+
|
64 |
+
|
65 |
+
class HeadMaskGenerator:
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
input_size,
|
69 |
+
min_num_patches=4,
|
70 |
+
max_num_patches_ratio=0.5,
|
71 |
+
min_aspect=0.3,
|
72 |
+
):
|
73 |
+
if not isinstance(input_size, tuple):
|
74 |
+
input_size = (input_size,) * 2
|
75 |
+
self.input_size = input_size
|
76 |
+
self.num_patches = input_size[0] * input_size[1]
|
77 |
+
|
78 |
+
self.min_num_patches = min_num_patches
|
79 |
+
self.max_num_patches_ratio = max_num_patches_ratio
|
80 |
+
|
81 |
+
self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))
|
82 |
+
|
83 |
+
def __call__(
|
84 |
+
self,
|
85 |
+
x_min,
|
86 |
+
y_min,
|
87 |
+
x_max,
|
88 |
+
y_max, # coords in [0,1]
|
89 |
+
):
|
90 |
+
height = math.floor((y_max - y_min) * self.input_size[0])
|
91 |
+
width = math.floor((x_max - x_min) * self.input_size[1])
|
92 |
+
origin_area = width * height
|
93 |
+
if origin_area < self.min_num_patches:
|
94 |
+
return torch.zeros(size=self.input_size, dtype=bool)
|
95 |
+
|
96 |
+
target_area = random.uniform(
|
97 |
+
self.min_num_patches, self.max_num_patches_ratio * origin_area
|
98 |
+
)
|
99 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
100 |
+
h = min(int(round(math.sqrt(target_area * aspect_ratio))), height)
|
101 |
+
w = min(int(round(math.sqrt(target_area / aspect_ratio))), width)
|
102 |
+
top = random.randint(0, height - h) + int(y_min * self.input_size[0])
|
103 |
+
left = random.randint(0, width - w) + int(x_min * self.input_size[1])
|
104 |
+
mask = torch.zeros(size=self.input_size, dtype=bool)
|
105 |
+
mask[top : top + h, left : left + w] = True
|
106 |
+
return mask
|
107 |
+
|
108 |
+
|
109 |
+
class MaskGenerator:
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
input_size,
|
113 |
+
mask_scene: bool = False,
|
114 |
+
mask_head: bool = False,
|
115 |
+
min_scene_patches=16,
|
116 |
+
max_scene_patches_ratio=0.5,
|
117 |
+
min_head_patches=4,
|
118 |
+
max_head_patches_ratio=0.5,
|
119 |
+
min_aspect=0.3,
|
120 |
+
mask_prob=0.2,
|
121 |
+
head_prob=0.2,
|
122 |
+
):
|
123 |
+
if not isinstance(input_size, tuple):
|
124 |
+
input_size = (input_size,) * 2
|
125 |
+
self.input_size = input_size
|
126 |
+
if mask_scene:
|
127 |
+
self.scene_mask_generator = SceneMaskGenerator(
|
128 |
+
input_size, min_scene_patches, max_scene_patches_ratio, min_aspect
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
self.scene_mask_generator = None
|
132 |
+
|
133 |
+
if mask_head:
|
134 |
+
self.head_mask_generator = HeadMaskGenerator(
|
135 |
+
input_size, min_head_patches, max_head_patches_ratio, min_aspect
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
self.head_mask_generator = None
|
139 |
+
|
140 |
+
self.no_mask = not (mask_scene or mask_head)
|
141 |
+
self.mask_head = mask_head and not mask_scene
|
142 |
+
self.mask_scene = mask_scene and not mask_head
|
143 |
+
self.scene_prob = mask_prob
|
144 |
+
self.head_prob = head_prob
|
145 |
+
|
146 |
+
def __call__(
|
147 |
+
self,
|
148 |
+
x_min,
|
149 |
+
y_min,
|
150 |
+
x_max,
|
151 |
+
y_max,
|
152 |
+
head_mask,
|
153 |
+
):
|
154 |
+
mask_scene = random.random() < self.scene_prob
|
155 |
+
mask_head = random.random() < self.head_prob
|
156 |
+
no_mask = (
|
157 |
+
self.no_mask
|
158 |
+
or (self.mask_head and not mask_head)
|
159 |
+
or (self.mask_scene and not mask_scene)
|
160 |
+
or not (mask_scene or mask_head)
|
161 |
+
)
|
162 |
+
if no_mask:
|
163 |
+
return torch.zeros(size=self.input_size, dtype=bool)
|
164 |
+
if self.mask_scene:
|
165 |
+
return self.scene_mask_generator(head_mask)
|
166 |
+
if self.mask_head:
|
167 |
+
return self.head_mask_generator(x_min, y_min, x_max, y_max)
|
168 |
+
if mask_head and mask_scene:
|
169 |
+
return torch.logical_or(
|
170 |
+
self.scene_mask_generator(head_mask),
|
171 |
+
self.head_mask_generator(x_min, y_min, x_max, y_max),
|
172 |
+
)
|
173 |
+
elif mask_head:
|
174 |
+
return self.head_mask_generator(x_min, y_min, x_max, y_max)
|
175 |
+
return self.scene_mask_generator(head_mask)
|
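A small sketch of how the combined `MaskGenerator` might be exercised on its own. The grid size and the dummy head channel below are made-up values; in the datasets above the generator is called with box coordinates normalized to [0, 1] plus the `head_channel` tensor, so that scene masks avoid the head region:

```python
import torch
from data.masking import MaskGenerator  # repo-local import

# Hypothetical 37x37 patch grid (e.g. a 518px input with 14px patches).
gen = MaskGenerator(
    input_size=37,
    mask_scene=True,
    mask_head=True,
    mask_prob=1.0,  # force masking for the demo
    head_prob=1.0,
)

# Dummy head channel: ones inside the head box, zeros elsewhere, shape (1, H, W).
head_channel = torch.zeros(1, 37, 37)
head_channel[:, 10:16, 12:18] = 1.0

# Box coordinates are passed normalized to [0, 1].
mask = gen(12 / 37, 10 / 37, 18 / 37, 16 / 37, head_channel)
print(mask.shape, mask.dtype, mask.sum().item())  # torch.Size([37, 37]) torch.bool <masked patches>
```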
data/video_attention_target.py
ADDED
@@ -0,0 +1,228 @@
1 |
+
import glob
|
2 |
+
from typing import Callable, Optional
|
3 |
+
from os import path as osp
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.utils.data.dataset import Dataset
|
7 |
+
import torchvision.transforms.functional as TF
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from . import augmentation
|
13 |
+
from . import data_utils as utils
|
14 |
+
from .masking import MaskGenerator
|
15 |
+
|
16 |
+
|
17 |
+
class VideoAttentionTarget(Dataset):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
image_root: str,
|
21 |
+
anno_root: str,
|
22 |
+
head_root: str,
|
23 |
+
transform: Callable,
|
24 |
+
input_size: int,
|
25 |
+
output_size: int,
|
26 |
+
quant_labelmap: bool = True,
|
27 |
+
is_train: bool = True,
|
28 |
+
*,
|
29 |
+
mask_generator: Optional[MaskGenerator] = None,
|
30 |
+
bbox_jitter: float = 0.5,
|
31 |
+
rand_crop: float = 0.5,
|
32 |
+
rand_flip: float = 0.5,
|
33 |
+
color_jitter: float = 0.5,
|
34 |
+
rand_rotate: float = 0.0,
|
35 |
+
rand_lsj: float = 0.0,
|
36 |
+
):
|
37 |
+
frames = []
|
38 |
+
for show_dir in glob.glob(osp.join(anno_root, "*")):
|
39 |
+
for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
|
40 |
+
df = pd.read_csv(
|
41 |
+
sequence_path,
|
42 |
+
header=None,
|
43 |
+
index_col=False,
|
44 |
+
names=[
|
45 |
+
"path",
|
46 |
+
"x_min",
|
47 |
+
"y_min",
|
48 |
+
"x_max",
|
49 |
+
"y_max",
|
50 |
+
"gaze_x",
|
51 |
+
"gaze_y",
|
52 |
+
],
|
53 |
+
)
|
54 |
+
|
55 |
+
show_name = sequence_path.split("/")[-3]
|
56 |
+
clip = sequence_path.split("/")[-2]
|
57 |
+
df["path"] = df["path"].apply(
|
58 |
+
lambda path: osp.join(show_name, clip, path)
|
59 |
+
)
|
60 |
+
# Add two columns for the bbox center
|
61 |
+
df["eye_x"] = (df["x_min"] + df["x_max"]) / 2
|
62 |
+
df["eye_y"] = (df["y_min"] + df["y_max"]) / 2
|
63 |
+
df = df.sample(frac=0.2, random_state=42)
|
64 |
+
frames.extend(df.values.tolist())
|
65 |
+
|
66 |
+
df = pd.DataFrame(
|
67 |
+
frames,
|
68 |
+
columns=[
|
69 |
+
"path",
|
70 |
+
"x_min",
|
71 |
+
"y_min",
|
72 |
+
"x_max",
|
73 |
+
"y_max",
|
74 |
+
"gaze_x",
|
75 |
+
"gaze_y",
|
76 |
+
"eye_x",
|
77 |
+
"eye_y",
|
78 |
+
],
|
79 |
+
)
|
80 |
+
# Drop rows with invalid bboxes
|
81 |
+
coords = torch.tensor(
|
82 |
+
np.array(
|
83 |
+
(
|
84 |
+
df["x_min"].values,
|
85 |
+
df["y_min"].values,
|
86 |
+
df["x_max"].values,
|
87 |
+
df["y_max"].values,
|
88 |
+
)
|
89 |
+
).transpose(1, 0)
|
90 |
+
)
|
91 |
+
valid_bboxes = (coords[:, 2:] >= coords[:, :2]).all(dim=1)
|
92 |
+
df = df.loc[valid_bboxes.tolist(), :]
|
93 |
+
df.reset_index(inplace=True)
|
94 |
+
self.df = df
|
95 |
+
self.length = len(df)
|
96 |
+
|
97 |
+
self.data_dir = image_root
|
98 |
+
self.head_dir = head_root
|
99 |
+
self.transform = transform
|
100 |
+
self.draw_labelmap = (
|
101 |
+
utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
|
102 |
+
)
|
103 |
+
self.is_train = is_train
|
104 |
+
|
105 |
+
self.input_size = input_size
|
106 |
+
self.output_size = output_size
|
107 |
+
|
108 |
+
if self.is_train:
|
109 |
+
## data augmentation
|
110 |
+
self.augment = augmentation.AugmentationList(
|
111 |
+
[
|
112 |
+
augmentation.ColorJitter(color_jitter),
|
113 |
+
augmentation.BoxJitter(bbox_jitter),
|
114 |
+
augmentation.RandomCrop(rand_crop),
|
115 |
+
augmentation.RandomFlip(rand_flip),
|
116 |
+
augmentation.RandomRotate(rand_rotate),
|
117 |
+
augmentation.RandomLSJ(rand_lsj),
|
118 |
+
]
|
119 |
+
)
|
120 |
+
|
121 |
+
self.mask_generator = mask_generator
|
122 |
+
|
123 |
+
def __getitem__(self, index):
|
124 |
+
(
|
125 |
+
_,
|
126 |
+
path,
|
127 |
+
x_min,
|
128 |
+
y_min,
|
129 |
+
x_max,
|
130 |
+
y_max,
|
131 |
+
gaze_x,
|
132 |
+
gaze_y,
|
133 |
+
eye_x,
|
134 |
+
eye_y,
|
135 |
+
) = self.df.iloc[index]
|
136 |
+
gaze_inside = gaze_x != -1 or gaze_y != -1
|
137 |
+
|
138 |
+
img = Image.open(osp.join(self.data_dir, path))
|
139 |
+
img = img.convert("RGB")
|
140 |
+
width, height = img.size
|
141 |
+
# Since we finetune from weights trained on GazeFollow,
|
142 |
+
# we don't incorporate the auxiliary task for VAT.
|
143 |
+
if osp.exists(osp.join(self.head_dir, path)):
|
144 |
+
head_mask = Image.open(osp.join(self.head_dir, path)).resize(
|
145 |
+
(width, height)
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
|
149 |
+
x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
|
150 |
+
if x_max < x_min:
|
151 |
+
x_min, x_max = x_max, x_min
|
152 |
+
if y_max < y_min:
|
153 |
+
y_min, y_max = y_max, y_min
|
154 |
+
gaze_x, gaze_y = gaze_x / width, gaze_y / height
|
155 |
+
# expand face bbox a bit
|
156 |
+
k = 0.1
|
157 |
+
x_min = max(x_min - k * abs(x_max - x_min), 0)
|
158 |
+
y_min = max(y_min - k * abs(y_max - y_min), 0)
|
159 |
+
x_max = min(x_max + k * abs(x_max - x_min), width - 1)
|
160 |
+
y_max = min(y_max + k * abs(y_max - y_min), height - 1)
|
161 |
+
|
162 |
+
if self.is_train:
|
163 |
+
img, bbox, gaze, head_mask, size = self.augment(
|
164 |
+
img,
|
165 |
+
(x_min, y_min, x_max, y_max),
|
166 |
+
(gaze_x, gaze_y),
|
167 |
+
head_mask,
|
168 |
+
(width, height),
|
169 |
+
)
|
170 |
+
x_min, y_min, x_max, y_max = bbox
|
171 |
+
gaze_x, gaze_y = gaze
|
172 |
+
width, height = size
|
173 |
+
|
174 |
+
head_channel = utils.get_head_box_channel(
|
175 |
+
x_min,
|
176 |
+
y_min,
|
177 |
+
x_max,
|
178 |
+
y_max,
|
179 |
+
width,
|
180 |
+
height,
|
181 |
+
resolution=self.input_size,
|
182 |
+
coordconv=False,
|
183 |
+
).unsqueeze(0)
|
184 |
+
|
185 |
+
if self.is_train and self.mask_generator is not None:
|
186 |
+
image_mask = self.mask_generator(
|
187 |
+
x_min / width,
|
188 |
+
y_min / height,
|
189 |
+
x_max / width,
|
190 |
+
y_max / height,
|
191 |
+
head_channel,
|
192 |
+
)
|
193 |
+
|
194 |
+
if self.transform is not None:
|
195 |
+
img = self.transform(img)
|
196 |
+
head_mask = TF.to_tensor(
|
197 |
+
TF.resize(head_mask, (self.input_size, self.input_size))
|
198 |
+
)
|
199 |
+
|
200 |
+
# generate the heat map used for deconv prediction
|
201 |
+
gaze_heatmap = torch.zeros(
|
202 |
+
self.output_size, self.output_size
|
203 |
+
) # set the size of the output
|
204 |
+
|
205 |
+
gaze_heatmap = self.draw_labelmap(
|
206 |
+
gaze_heatmap,
|
207 |
+
[gaze_x * self.output_size, gaze_y * self.output_size],
|
208 |
+
3,
|
209 |
+
type="Gaussian",
|
210 |
+
)
|
211 |
+
|
212 |
+
imsize = torch.IntTensor([width, height])
|
213 |
+
|
214 |
+
out_dict = {
|
215 |
+
"images": img,
|
216 |
+
"head_channels": head_channel,
|
217 |
+
"heatmaps": gaze_heatmap,
|
218 |
+
"gazes": torch.FloatTensor([gaze_x, gaze_y]),
|
219 |
+
"gaze_inouts": torch.FloatTensor([gaze_inside]),
|
220 |
+
"head_masks": head_mask,
|
221 |
+
"imsize": imsize,
|
222 |
+
}
|
223 |
+
if self.is_train and self.mask_generator is not None:
|
224 |
+
out_dict["image_masks"] = image_mask
|
225 |
+
return out_dict
|
226 |
+
|
227 |
+
def __len__(self):
|
228 |
+
return self.length
|
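The constructor above walks the VideoAttentionTarget annotation tree as `anno_root/<show>/<clip>/*.txt` (one text file per tracked person), rebuilds image paths as `<show>/<clip>/<frame>`, and subsamples 20% of each sequence's frames for this image-level loader. A small sketch of the same traversal, with a placeholder root, useful for sanity-checking a local copy of the dataset:

```python
import glob
from os import path as osp

anno_root = "/path/to/videoattentiontarget/annotations/train"  # placeholder

clips = set()
num_tracks = 0
for show_dir in glob.glob(osp.join(anno_root, "*")):
    for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
        show_name = sequence_path.split("/")[-3]
        clip = sequence_path.split("/")[-2]
        clips.add((show_name, clip))
        num_tracks += 1
        # Each line of the file: frame filename, x_min, y_min, x_max, y_max, gaze_x, gaze_y
        # (gaze_x = gaze_y = -1 marks an out-of-frame gaze target).

print(f"{len(clips)} clips, {num_tracks} annotated person tracks")
```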
data/video_attention_target_video.py
ADDED
@@ -0,0 +1,464 @@
1 |
+
import math
|
2 |
+
from os import path as osp
|
3 |
+
from typing import Callable, Optional
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
from torch.utils.data.dataset import Dataset
|
7 |
+
import torchvision.transforms.functional as TF
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image, ImageOps
|
10 |
+
import pandas as pd
|
11 |
+
from .masking import MaskGenerator
|
12 |
+
from . import data_utils as utils
|
13 |
+
|
14 |
+
|
15 |
+
class VideoAttentionTargetVideo(Dataset):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
image_root: str,
|
19 |
+
anno_root: str,
|
20 |
+
head_root: str,
|
21 |
+
transform: Callable,
|
22 |
+
input_size: int,
|
23 |
+
output_size: int,
|
24 |
+
quant_labelmap: bool = True,
|
25 |
+
is_train: bool = True,
|
26 |
+
seq_len: int = 8,
|
27 |
+
max_len: int = 32,
|
28 |
+
*,
|
29 |
+
mask_generator: Optional[MaskGenerator] = None,
|
30 |
+
bbox_jitter: float = 0.5,
|
31 |
+
rand_crop: float = 0.5,
|
32 |
+
rand_flip: float = 0.5,
|
33 |
+
color_jitter: float = 0.5,
|
34 |
+
rand_rotate: float = 0.0,
|
35 |
+
rand_lsj: float = 0.0,
|
36 |
+
):
|
37 |
+
dfs = []
|
38 |
+
for show_dir in glob.glob(osp.join(anno_root, "*")):
|
39 |
+
for sequence_path in glob.glob(osp.join(show_dir, "*", "*.txt")):
|
40 |
+
df = pd.read_csv(
|
41 |
+
sequence_path,
|
42 |
+
header=None,
|
43 |
+
index_col=False,
|
44 |
+
names=[
|
45 |
+
"path",
|
46 |
+
"x_min",
|
47 |
+
"y_min",
|
48 |
+
"x_max",
|
49 |
+
"y_max",
|
50 |
+
"gaze_x",
|
51 |
+
"gaze_y",
|
52 |
+
],
|
53 |
+
)
|
54 |
+
show_name = sequence_path.split("/")[-3]
|
55 |
+
clip = sequence_path.split("/")[-2]
|
56 |
+
df["path"] = df["path"].apply(
|
57 |
+
lambda path: osp.join(show_name, clip, path)
|
58 |
+
)
|
59 |
+
cur_len = len(df.index)
|
60 |
+
if is_train:
|
61 |
+
if cur_len <= max_len:
|
62 |
+
if cur_len >= seq_len:
|
63 |
+
dfs.append(df)
|
64 |
+
continue
|
65 |
+
remainder = cur_len % max_len
|
66 |
+
df_splits = [
|
67 |
+
df[i : i + max_len]
|
68 |
+
for i in range(0, cur_len - max_len, max_len)
|
69 |
+
]
|
70 |
+
if remainder >= seq_len:
|
71 |
+
df_splits.append(df[-remainder:])
|
72 |
+
dfs.extend(df_splits)
|
73 |
+
else:
|
74 |
+
if cur_len < seq_len:
|
75 |
+
continue
|
76 |
+
df_splits = [
|
77 |
+
df[i : i + seq_len]
|
78 |
+
for i in range(0, cur_len - seq_len, seq_len)
|
79 |
+
]
|
80 |
+
dfs.extend(df_splits)
|
81 |
+
|
82 |
+
for df in dfs:
|
83 |
+
df.reset_index(inplace=True)
|
84 |
+
self.dfs = dfs
|
85 |
+
self.length = len(dfs)
|
86 |
+
|
87 |
+
self.data_dir = image_root
|
88 |
+
self.head_dir = head_root
|
89 |
+
self.transform = transform
|
90 |
+
self.draw_labelmap = (
|
91 |
+
utils.draw_labelmap if quant_labelmap else utils.draw_labelmap_no_quant
|
92 |
+
)
|
93 |
+
self.is_train = is_train
|
94 |
+
|
95 |
+
self.input_size = input_size
|
96 |
+
self.output_size = output_size
|
97 |
+
self.seq_len = seq_len
|
98 |
+
|
99 |
+
if self.is_train:
|
100 |
+
self.bbox_jitter = bbox_jitter
|
101 |
+
self.rand_crop = rand_crop
|
102 |
+
self.rand_flip = rand_flip
|
103 |
+
self.color_jitter = color_jitter
|
104 |
+
self.rand_rotate = rand_rotate
|
105 |
+
self.rand_lsj = rand_lsj
|
106 |
+
self.mask_generator = mask_generator
|
107 |
+
|
108 |
+
def __getitem__(self, index):
|
109 |
+
df = self.dfs[index]
|
110 |
+
seq_len = len(df.index)
|
111 |
+
for coord in ["x_min", "y_min", "x_max", "y_max"]:
|
112 |
+
df[coord] = utils.smooth_by_conv(11, df, coord)
|
113 |
+
|
114 |
+
if self.is_train:
|
115 |
+
# cond for data augmentation
|
116 |
+
cond_jitter = np.random.random_sample()
|
117 |
+
cond_flip = np.random.random_sample()
|
118 |
+
cond_color = np.random.random_sample()
|
119 |
+
if cond_color < self.color_jitter:
|
120 |
+
n1 = np.random.uniform(0.5, 1.5)
|
121 |
+
n2 = np.random.uniform(0.5, 1.5)
|
122 |
+
n3 = np.random.uniform(0.5, 1.5)
|
123 |
+
cond_crop = np.random.random_sample()
|
124 |
+
cond_rotate = np.random.random_sample()
|
125 |
+
if cond_rotate < self.rand_rotate:
|
126 |
+
angle = (2 * np.random.random_sample() - 1) * 20
|
127 |
+
angle = -math.radians(angle)
|
128 |
+
cond_lsj = np.random.random_sample()
|
129 |
+
if cond_lsj < self.rand_lsj:
|
130 |
+
lsj_scale = 0.1 + np.random.random_sample() * 0.9
|
131 |
+
|
132 |
+
# if the sequence is longer than seq_len, randomly sample a start index and truncate to seq_len
|
133 |
+
if seq_len > self.seq_len:
|
134 |
+
sampled_ind = np.random.randint(0, seq_len - self.seq_len)
|
135 |
+
seq_len = self.seq_len
|
136 |
+
else:
|
137 |
+
sampled_ind = 0
|
138 |
+
|
139 |
+
if cond_crop < self.rand_crop:
|
140 |
+
sliced_x_min = df["x_min"].iloc[sampled_ind : sampled_ind + seq_len]
|
141 |
+
sliced_x_max = df["x_max"].iloc[sampled_ind : sampled_ind + seq_len]
|
142 |
+
sliced_y_min = df["y_min"].iloc[sampled_ind : sampled_ind + seq_len]
|
143 |
+
sliced_y_max = df["y_max"].iloc[sampled_ind : sampled_ind + seq_len]
|
144 |
+
|
145 |
+
sliced_gaze_x = df["gaze_x"].iloc[sampled_ind : sampled_ind + seq_len]
|
146 |
+
sliced_gaze_y = df["gaze_y"].iloc[sampled_ind : sampled_ind + seq_len]
|
147 |
+
|
148 |
+
check_sum = sliced_gaze_x.sum() + sliced_gaze_y.sum()
|
149 |
+
all_outside = check_sum == -2 * seq_len
|
150 |
+
|
151 |
+
# Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
|
152 |
+
if all_outside:
|
153 |
+
crop_x_min = np.min([sliced_x_min.min(), sliced_x_max.min()])
|
154 |
+
crop_y_min = np.min([sliced_y_min.min(), sliced_y_max.min()])
|
155 |
+
crop_x_max = np.max([sliced_x_min.max(), sliced_x_max.max()])
|
156 |
+
crop_y_max = np.max([sliced_y_min.max(), sliced_y_max.max()])
|
157 |
+
else:
|
158 |
+
crop_x_min = np.min(
|
159 |
+
[sliced_gaze_x.min(), sliced_x_min.min(), sliced_x_max.min()]
|
160 |
+
)
|
161 |
+
crop_y_min = np.min(
|
162 |
+
[sliced_gaze_y.min(), sliced_y_min.min(), sliced_y_max.min()]
|
163 |
+
)
|
164 |
+
crop_x_max = np.max(
|
165 |
+
[sliced_gaze_x.max(), sliced_x_min.max(), sliced_x_max.max()]
|
166 |
+
)
|
167 |
+
crop_y_max = np.max(
|
168 |
+
[sliced_gaze_y.max(), sliced_y_min.max(), sliced_y_max.max()]
|
169 |
+
)
|
170 |
+
|
171 |
+
# Randomly select a top-left corner for the crop
|
172 |
+
if crop_x_min >= 0:
|
173 |
+
crop_x_min = np.random.uniform(0, crop_x_min)
|
174 |
+
if crop_y_min >= 0:
|
175 |
+
crop_y_min = np.random.uniform(0, crop_y_min)
|
176 |
+
|
177 |
+
# Get image size
|
178 |
+
path = osp.join(self.data_dir, df["path"].iloc[0])
|
179 |
+
img = Image.open(path)
|
180 |
+
img = img.convert("RGB")
|
181 |
+
width, height = img.size
|
182 |
+
|
183 |
+
# Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
|
184 |
+
crop_width_min = crop_x_max - crop_x_min
|
185 |
+
crop_height_min = crop_y_max - crop_y_min
|
186 |
+
crop_width_max = width - crop_x_min
|
187 |
+
crop_height_max = height - crop_y_min
|
188 |
+
# Randomly select a width and a height
|
189 |
+
crop_width = np.random.uniform(crop_width_min, crop_width_max)
|
190 |
+
crop_height = np.random.uniform(crop_height_min, crop_height_max)
|
191 |
+
|
192 |
+
# Round to integers
|
193 |
+
crop_y_min, crop_x_min, crop_height, crop_width = map(
|
194 |
+
int, map(round, (crop_y_min, crop_x_min, crop_height, crop_width))
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
sampled_ind = 0
|
198 |
+
|
199 |
+
images = []
|
200 |
+
head_channels = []
|
201 |
+
heatmaps = []
|
202 |
+
gazes = []
|
203 |
+
gaze_inouts = []
|
204 |
+
imsizes = []
|
205 |
+
head_masks = []
|
206 |
+
if self.is_train and self.mask_generator is not None:
|
207 |
+
image_masks = []
|
208 |
+
for i, row in df.iterrows():
|
209 |
+
if self.is_train and (i < sampled_ind or i >= (sampled_ind + self.seq_len)):
|
210 |
+
continue
|
211 |
+
|
212 |
+
x_min = row["x_min"] # note: Already in image coordinates
|
213 |
+
y_min = row["y_min"] # note: Already in image coordinates
|
214 |
+
x_max = row["x_max"] # note: Already in image coordinates
|
215 |
+
y_max = row["y_max"] # note: Already in image coordinates
|
216 |
+
gaze_x = row["gaze_x"] # note: Already in image coordinates
|
217 |
+
gaze_y = row["gaze_y"] # note: Already in image coordinates
|
218 |
+
|
219 |
+
if x_min > x_max:
|
220 |
+
x_min, x_max = x_max, x_min
|
221 |
+
if y_min > y_max:
|
222 |
+
y_min, y_max = y_max, y_min
|
223 |
+
|
224 |
+
path = row["path"]
|
225 |
+
img = Image.open(osp.join(self.data_dir, path)).convert("RGB")
|
226 |
+
width, height = img.size
|
227 |
+
imsize = torch.FloatTensor([width, height])
|
228 |
+
imsizes.append(imsize)
|
229 |
+
# Since we finetune from weights trained on GazeFollow,
|
230 |
+
# we don't incorporate the auxiliary task for VAT.
|
231 |
+
if osp.exists(osp.join(self.head_dir, path)):
|
232 |
+
head_mask = Image.open(osp.join(self.head_dir, path)).resize(
|
233 |
+
(width, height)
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
head_mask = Image.fromarray(np.zeros((height, width), dtype=np.float32))
|
237 |
+
|
238 |
+
x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])
|
239 |
+
gaze_x, gaze_y = map(float, [gaze_x, gaze_y])
|
240 |
+
if gaze_x == -1 and gaze_y == -1:
|
241 |
+
gaze_inside = False
|
242 |
+
else:
|
243 |
+
if (
|
244 |
+
gaze_x < 0
|
245 |
+
): # move a gaze point that was slightly outside the image back inside
|
246 |
+
gaze_x = 0
|
247 |
+
if gaze_y < 0:
|
248 |
+
gaze_y = 0
|
249 |
+
gaze_inside = True
|
250 |
+
|
251 |
+
if self.is_train:
|
252 |
+
## data augmentation
|
253 |
+
# Jitter (expansion-only) bounding box size.
|
254 |
+
if cond_jitter < self.bbox_jitter:
|
255 |
+
k = cond_jitter * 0.1
|
256 |
+
x_min -= k * abs(x_max - x_min)
|
257 |
+
y_min -= k * abs(y_max - y_min)
|
258 |
+
x_max += k * abs(x_max - x_min)
|
259 |
+
y_max += k * abs(y_max - y_min)
|
260 |
+
x_min = np.clip(x_min, 0, width - 1)
|
261 |
+
x_max = np.clip(x_max, 0, width - 1)
|
262 |
+
y_min = np.clip(y_min, 0, height - 1)
|
263 |
+
y_max = np.clip(y_max, 0, height - 1)
|
264 |
+
|
265 |
+
# Random color change
|
266 |
+
if cond_color < self.color_jitter:
|
267 |
+
img = TF.adjust_brightness(img, brightness_factor=n1)
|
268 |
+
img = TF.adjust_contrast(img, contrast_factor=n2)
|
269 |
+
img = TF.adjust_saturation(img, saturation_factor=n3)
|
270 |
+
|
271 |
+
# Random Crop
|
272 |
+
if cond_crop < self.rand_crop:
|
273 |
+
# Crop it
|
274 |
+
img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width)
|
275 |
+
head_mask = TF.crop(
|
276 |
+
head_mask, crop_y_min, crop_x_min, crop_height, crop_width
|
277 |
+
)
|
278 |
+
|
279 |
+
# Record the crop's (x, y) offset
|
280 |
+
offset_x, offset_y = crop_x_min, crop_y_min
|
281 |
+
|
282 |
+
# convert coordinates into the cropped frame
|
283 |
+
x_min, y_min, x_max, y_max = (
|
284 |
+
x_min - offset_x,
|
285 |
+
y_min - offset_y,
|
286 |
+
x_max - offset_x,
|
287 |
+
y_max - offset_y,
|
288 |
+
)
|
289 |
+
if gaze_inside:
|
290 |
+
gaze_x, gaze_y = (gaze_x - offset_x), (gaze_y - offset_y)
|
291 |
+
else:
|
292 |
+
gaze_x = -1
|
293 |
+
gaze_y = -1
|
294 |
+
|
295 |
+
width, height = crop_width, crop_height
|
296 |
+
|
297 |
+
# Flip?
|
298 |
+
if cond_flip < self.rand_flip:
|
299 |
+
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
300 |
+
head_mask = head_mask.transpose(Image.FLIP_LEFT_RIGHT)
|
301 |
+
x_max_2 = width - x_min
|
302 |
+
x_min_2 = width - x_max
|
303 |
+
x_max = x_max_2
|
304 |
+
x_min = x_min_2
|
305 |
+
if gaze_x != -1 and gaze_y != -1:
|
306 |
+
gaze_x = width - gaze_x
|
307 |
+
|
308 |
+
# Random Rotation
|
309 |
+
if cond_rotate < self.rand_rotate:
|
310 |
+
rot_mat = [
|
311 |
+
round(math.cos(angle), 15),
|
312 |
+
round(math.sin(angle), 15),
|
313 |
+
0.0,
|
314 |
+
round(-math.sin(angle), 15),
|
315 |
+
round(math.cos(angle), 15),
|
316 |
+
0.0,
|
317 |
+
]
|
318 |
+
|
319 |
+
def _transform(x, y, matrix):
|
320 |
+
return (
|
321 |
+
matrix[0] * x + matrix[1] * y + matrix[2],
|
322 |
+
matrix[3] * x + matrix[4] * y + matrix[5],
|
323 |
+
)
|
324 |
+
|
325 |
+
def _inv_transform(x, y, matrix):
|
326 |
+
x, y = x - matrix[2], y - matrix[5]
|
327 |
+
return (
|
328 |
+
matrix[0] * x + matrix[3] * y,
|
329 |
+
matrix[1] * x + matrix[4] * y,
|
330 |
+
)
|
331 |
+
|
332 |
+
# Calculate offsets
|
333 |
+
rot_center = (width / 2.0, height / 2.0)
|
334 |
+
rot_mat[2], rot_mat[5] = _transform(
|
335 |
+
-rot_center[0], -rot_center[1], rot_mat
|
336 |
+
)
|
337 |
+
rot_mat[2] += rot_center[0]
|
338 |
+
rot_mat[5] += rot_center[1]
|
339 |
+
xx = []
|
340 |
+
yy = []
|
341 |
+
for x, y in ((0, 0), (width, 0), (width, height), (0, height)):
|
342 |
+
x, y = _transform(x, y, rot_mat)
|
343 |
+
xx.append(x)
|
344 |
+
yy.append(y)
|
345 |
+
nw = math.ceil(max(xx)) - math.floor(min(xx))
|
346 |
+
nh = math.ceil(max(yy)) - math.floor(min(yy))
|
347 |
+
rot_mat[2], rot_mat[5] = _transform(
|
348 |
+
-(nw - width) / 2.0, -(nh - height) / 2.0, rot_mat
|
349 |
+
)
|
350 |
+
|
351 |
+
img = img.transform((nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR)
|
352 |
+
head_mask = head_mask.transform(
|
353 |
+
(nw, nh), Image.AFFINE, rot_mat, Image.BILINEAR
|
354 |
+
)
|
355 |
+
|
356 |
+
xx = []
|
357 |
+
yy = []
|
358 |
+
for x, y in (
|
359 |
+
(x_min, y_min),
|
360 |
+
(x_min, y_max),
|
361 |
+
(x_max, y_min),
|
362 |
+
(x_max, y_max),
|
363 |
+
):
|
364 |
+
x, y = _inv_transform(x, y, rot_mat)
|
365 |
+
xx.append(x)
|
366 |
+
yy.append(y)
|
367 |
+
x_max, x_min = min(max(xx), nw), max(min(xx), 0)
|
368 |
+
y_max, y_min = min(max(yy), nh), max(min(yy), 0)
|
369 |
+
gaze_x, gaze_y = _inv_transform(gaze_x, gaze_y, rot_mat)
|
370 |
+
width, height = nw, nh
|
371 |
+
|
372 |
+
if cond_lsj < self.rand_lsj:
|
373 |
+
nh, nw = int(height * lsj_scale), int(width * lsj_scale)
|
374 |
+
img = TF.resize(img, (nh, nw))
|
375 |
+
img = ImageOps.expand(img, (0, 0, width - nw, height - nh))
|
376 |
+
head_mask = TF.resize(head_mask, (nh, nw))
|
377 |
+
head_mask = ImageOps.expand(
|
378 |
+
head_mask, (0, 0, width - nw, height - nh)
|
379 |
+
)
|
380 |
+
x_min, y_min, x_max, y_max = (
|
381 |
+
x_min * lsj_scale,
|
382 |
+
y_min * lsj_scale,
|
383 |
+
x_max * lsj_scale,
|
384 |
+
y_max * lsj_scale,
|
385 |
+
)
|
386 |
+
gaze_x, gaze_y = gaze_x * lsj_scale, gaze_y * lsj_scale
|
387 |
+
|
388 |
+
head_channel = utils.get_head_box_channel(
|
389 |
+
x_min,
|
390 |
+
y_min,
|
391 |
+
x_max,
|
392 |
+
y_max,
|
393 |
+
width,
|
394 |
+
height,
|
395 |
+
resolution=self.input_size,
|
396 |
+
coordconv=False,
|
397 |
+
).unsqueeze(0)
|
398 |
+
|
399 |
+
if self.is_train and self.mask_generator is not None:
|
400 |
+
image_mask = self.mask_generator(
|
401 |
+
x_min / width,
|
402 |
+
y_min / height,
|
403 |
+
x_max / width,
|
404 |
+
y_max / height,
|
405 |
+
head_channel,
|
406 |
+
)
|
407 |
+
image_masks.append(image_mask)
|
408 |
+
|
409 |
+
if self.transform is not None:
|
410 |
+
img = self.transform(img)
|
411 |
+
head_mask = TF.to_tensor(
|
412 |
+
TF.resize(head_mask, (self.input_size, self.input_size))
|
413 |
+
)
|
414 |
+
|
415 |
+
if gaze_inside:
|
416 |
+
gaze_x /= float(width) # fractional gaze
|
417 |
+
gaze_y /= float(height)
|
418 |
+
gaze_heatmap = torch.zeros(
|
419 |
+
self.output_size, self.output_size
|
420 |
+
) # set the size of the output
|
421 |
+
gaze_map = self.draw_labelmap(
|
422 |
+
gaze_heatmap,
|
423 |
+
[gaze_x * self.output_size, gaze_y * self.output_size],
|
424 |
+
3,
|
425 |
+
type="Gaussian",
|
426 |
+
)
|
427 |
+
gazes.append(torch.FloatTensor([gaze_x, gaze_y]))
|
428 |
+
else:
|
429 |
+
gaze_map = torch.zeros(self.output_size, self.output_size)
|
430 |
+
gazes.append(torch.FloatTensor([-1, -1]))
|
431 |
+
images.append(img)
|
432 |
+
head_channels.append(head_channel)
|
433 |
+
head_masks.append(head_mask)
|
434 |
+
heatmaps.append(gaze_map)
|
435 |
+
gaze_inouts.append(torch.FloatTensor([int(gaze_inside)]))
|
436 |
+
|
437 |
+
images = torch.stack(images)
|
438 |
+
head_channels = torch.stack(head_channels)
|
439 |
+
heatmaps = torch.stack(heatmaps)
|
440 |
+
gazes = torch.stack(gazes)
|
441 |
+
gaze_inouts = torch.stack(gaze_inouts)
|
442 |
+
head_masks = torch.stack(head_masks)
|
443 |
+
imsizes = torch.stack(imsizes)
|
444 |
+
|
445 |
+
out_dict = {
|
446 |
+
"images": images,
|
447 |
+
"head_channels": head_channels,
|
448 |
+
"heatmaps": heatmaps,
|
449 |
+
"gazes": gazes,
|
450 |
+
"gaze_inouts": gaze_inouts,
|
451 |
+
"head_masks": head_masks,
|
452 |
+
"imsize": imsizes,
|
453 |
+
}
|
454 |
+
if self.is_train and self.mask_generator is not None:
|
455 |
+
out_dict["image_masks"] = torch.stack(image_masks)
|
456 |
+
return out_dict
|
457 |
+
|
458 |
+
def __len__(self):
|
459 |
+
return self.length
|
460 |
+
|
461 |
+
|
462 |
+
def video_collate(batch):
|
463 |
+
keys = batch[0].keys()
|
464 |
+
return {key: torch.cat([item[key] for item in batch]) for key in keys}
|
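`video_collate` flattens a batch of clips by concatenating along the first (time) dimension instead of stacking, so the model sees a single `(B * T, ...)` batch. A minimal sketch with dummy tensors; the key names mirror the dicts returned by `VideoAttentionTargetVideo`:

```python
import torch
from data.video_attention_target_video import video_collate  # repo-local import


def dummy_clip(num_frames=8):
    # Shaped like one item returned by VideoAttentionTargetVideo.__getitem__.
    return {
        "images": torch.randn(num_frames, 3, 224, 224),
        "gaze_inouts": torch.ones(num_frames, 1),
    }


batch = video_collate([dummy_clip(), dummy_clip()])
print(batch["images"].shape)       # torch.Size([16, 3, 224, 224])
print(batch["gaze_inouts"].shape)  # torch.Size([16, 1])

# Typical use: DataLoader(video_dataset, batch_size=2, collate_fn=video_collate)
```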
docs/eval.md
ADDED
@@ -0,0 +1,24 @@
1 |
+
## Eval
|
2 |
+
### Testing Dataset
|
3 |
+
|
4 |
+
You should prepare GazeFollow and VideoAttentionTarget for evaluation.
|
5 |
+
|
6 |
+
* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
|
7 |
+
* If training with the auxiliary regression task, use `scripts\gen_gazefollow_head_masks.py` to generate head masks.
|
8 |
+
* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
|
9 |
+
|
10 |
+
Check `ViTGaze/configs/common/dataloader` and modify DATA_ROOT to point at your dataset directory.
|
11 |
+
|
12 |
+
### Evaluation
|
13 |
+
|
14 |
+
Run
|
15 |
+
```
|
16 |
+
bash val.sh configs/gazefollow_518.py ${Path2checkpoint} gf
|
17 |
+
```
|
18 |
+
to evaluate on GazeFollow.
|
19 |
+
|
20 |
+
Run
|
21 |
+
```
|
22 |
+
bash val.sh configs/videoattentiontarget.py ${Path2checkpoint} vat
|
23 |
+
```
|
24 |
+
to evaluate on VideoAttentionTarget.
|
docs/install.md
ADDED
@@ -0,0 +1,19 @@
1 |
+
## Installation
|
2 |
+
|
3 |
+
* Create a conda virtual env and activate it.
|
4 |
+
|
5 |
+
```
|
6 |
+
conda create -n ViTGaze python==3.9.18
|
7 |
+
conda activate ViTGaze
|
8 |
+
```
|
9 |
+
* Install packages.
|
10 |
+
|
11 |
+
```
|
12 |
+
cd path/to/ViTGaze
|
13 |
+
pip install -r requirements.txt
|
14 |
+
```
|
15 |
+
* Install [detectron2](https://github.com/facebookresearch/detectron2), following its [documentation](https://detectron2.readthedocs.io/en/latest/).
|
16 |
+
For ViTGaze, we recommend building it from the latest source code.
|
17 |
+
```
|
18 |
+
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
|
19 |
+
```
|
docs/train.md
ADDED
@@ -0,0 +1,36 @@
1 |
+
## Train
|
2 |
+
|
3 |
+
### Training Dataset
|
4 |
+
|
5 |
+
You should prepare GazeFollow and VideoAttentionTarget for training.
|
6 |
+
|
7 |
+
* Get [GazeFollow](https://www.dropbox.com/s/3ejt9pm57ht2ed4/gazefollow_extended.zip?dl=0).
|
8 |
+
* If training with the auxiliary regression task, use `scripts\gen_gazefollow_head_masks.py` to generate head masks.
|
9 |
+
* Get [VideoAttentionTarget](https://www.dropbox.com/s/8ep3y1hd74wdjy5/videoattentiontarget.zip?dl=0).
|
10 |
+
|
11 |
+
Check `ViTGaze/configs/common/dataloader` and modify DATA_ROOT to point at your dataset directory.
|
12 |
+
|
13 |
+
### Pretrained Model
|
14 |
+
|
15 |
+
* Get [DINOv2](https://github.com/facebookresearch/dinov2) pretrained ViT-S.
|
16 |
+
* Alternatively, you can download and preprocess the pretrained weights:
|
17 |
+
|
18 |
+
```
|
19 |
+
cd ViTGaze
|
20 |
+
mkdir pretrained && cd pretrained
|
21 |
+
wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
|
22 |
+
```
|
23 |
+
* Preprocess the model weights with `scripts\convert_pth.py` to convert them to the Detectron2 format.
|
24 |
+
### Train ViTGaze
|
25 |
+
|
26 |
+
You can modify configs in `configs/gazefollow.py`, `configs/gazefollow_518.py` and `configs/videoattentiontarget.py`.
|
27 |
+
|
28 |
+
Run:
|
29 |
+
|
30 |
+
```
|
31 |
+
bash train.sh
|
32 |
+
```
|
33 |
+
|
34 |
+
to train ViTGaze on the two datasets.
|
35 |
+
|
36 |
+
Training output will be saved in `ViTGaze/output/`.
|
engine/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .trainer import CycleTrainer
|
engine/trainer.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
from detectron2.engine import SimpleTrainer
|
4 |
+
from typing import Iterable, Generator
|
5 |
+
|
6 |
+
|
7 |
+
def cycle(iterable: Iterable) -> Generator:
|
8 |
+
while True:
|
9 |
+
for item in iterable:
|
10 |
+
yield item
|
11 |
+
|
12 |
+
|
13 |
+
class CycleTrainer(SimpleTrainer):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
model,
|
17 |
+
data_loader,
|
18 |
+
optimizer,
|
19 |
+
gather_metric_period=1,
|
20 |
+
zero_grad_before_forward=False,
|
21 |
+
async_write_metrics=False,
|
22 |
+
):
|
23 |
+
super().__init__(
|
24 |
+
model,
|
25 |
+
data_loader,
|
26 |
+
optimizer,
|
27 |
+
gather_metric_period,
|
28 |
+
zero_grad_before_forward,
|
29 |
+
async_write_metrics,
|
30 |
+
)
|
31 |
+
|
32 |
+
@property
|
33 |
+
def _data_loader_iter(self):
|
34 |
+
# only create the data loader iterator when it is used
|
35 |
+
if self._data_loader_iter_obj is None:
|
36 |
+
self._data_loader_iter_obj = cycle(self.data_loader)
|
37 |
+
return self._data_loader_iter_obj
|
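The `cycle` helper turns the finite data loader into an endless iterator, which is why `CycleTrainer` can keep calling `next()` on `_data_loader_iter` across epochs without Detectron2's default per-epoch bookkeeping. A tiny self-contained illustration of the helper's behaviour:

```python
from itertools import islice


def cycle(iterable):
    # Same as engine.trainer.cycle: restart the iterable whenever it is exhausted.
    while True:
        for item in iterable:
            yield item


loader = [1, 2, 3]  # stand-in for a finite DataLoader
print(list(islice(cycle(loader), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```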
modeling/_init__.py
ADDED
@@ -0,0 +1,4 @@
1 |
+
from . import backbone, patch_attention, head, criterion, meta_arch
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ["backbone", "patch_attention", "head", "criterion", "meta_arch"]
|
modeling/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .vit import build_backbone
|
modeling/backbone/utils.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
from functools import partial
|
2 |
+
from itertools import repeat
|
3 |
+
from typing import Iterable
|
4 |
+
import math
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
__all__ = [
|
10 |
+
"get_abs_pos",
|
11 |
+
"PatchEmbed",
|
12 |
+
"Mlp",
|
13 |
+
"DropPath",
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
def to_2tuple(x):
|
18 |
+
if isinstance(x, Iterable) and not isinstance(x, str):
|
19 |
+
return tuple(x)
|
20 |
+
return tuple(repeat(x, 2))
|
21 |
+
|
22 |
+
|
23 |
+
def get_abs_pos(abs_pos, has_cls_token, hw):
|
24 |
+
"""
|
25 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
26 |
+
dimension for the original embeddings.
|
27 |
+
Args:
|
28 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
29 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
30 |
+
hw (Tuple): size of input image tokens.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
34 |
+
"""
|
35 |
+
h, w = hw
|
36 |
+
if has_cls_token:
|
37 |
+
abs_pos = abs_pos[:, 1:]
|
38 |
+
xy_num = abs_pos.shape[1]
|
39 |
+
size = int(math.sqrt(xy_num))
|
40 |
+
assert size * size == xy_num
|
41 |
+
|
42 |
+
if size != h or size != w:
|
43 |
+
new_abs_pos = F.interpolate(
|
44 |
+
abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
|
45 |
+
size=(h, w),
|
46 |
+
mode="bicubic",
|
47 |
+
align_corners=False,
|
48 |
+
)
|
49 |
+
|
50 |
+
return new_abs_pos.permute(0, 2, 3, 1)
|
51 |
+
else:
|
52 |
+
return abs_pos.reshape(1, h, w, -1)
|
53 |
+
|
54 |
+
|
55 |
+
def drop_path(
|
56 |
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
57 |
+
):
|
58 |
+
if drop_prob == 0.0 or not training:
|
59 |
+
return x
|
60 |
+
keep_prob = 1 - drop_prob
|
61 |
+
shape = (x.shape[0],) + (1,) * (
|
62 |
+
x.ndim - 1
|
63 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
64 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
65 |
+
if keep_prob > 0.0 and scale_by_keep:
|
66 |
+
random_tensor.div_(keep_prob)
|
67 |
+
return x * random_tensor
|
68 |
+
|
69 |
+
|
70 |
+
class PatchEmbed(nn.Module):
|
71 |
+
"""
|
72 |
+
Image to Patch Embedding.
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
kernel_size=(16, 16),
|
78 |
+
stride=(16, 16),
|
79 |
+
padding=(0, 0),
|
80 |
+
in_chans=3,
|
81 |
+
embed_dim=768,
|
82 |
+
):
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
86 |
+
stride (Tuple): stride of the projection layer.
|
87 |
+
padding (Tuple): padding size of the projection layer.
|
88 |
+
in_chans (int): Number of input image channels.
|
89 |
+
embed_dim (int): Patch embedding dimension.
|
90 |
+
"""
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
self.proj = nn.Conv2d(
|
94 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
x = self.proj(x)
|
99 |
+
# B C H W -> B H W C
|
100 |
+
x = x.permute(0, 2, 3, 1)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
class DropPath(nn.Module):
|
105 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
106 |
+
|
107 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
108 |
+
super(DropPath, self).__init__()
|
109 |
+
self.drop_prob = drop_prob
|
110 |
+
self.scale_by_keep = scale_by_keep
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
114 |
+
|
115 |
+
def extra_repr(self):
|
116 |
+
return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
117 |
+
|
118 |
+
|
119 |
+
class Mlp(nn.Module):
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
in_features,
|
123 |
+
hidden_features=None,
|
124 |
+
out_features=None,
|
125 |
+
act_layer=nn.GELU,
|
126 |
+
norm_layer=None,
|
127 |
+
bias=True,
|
128 |
+
drop=0.0,
|
129 |
+
use_conv=False,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
out_features = out_features or in_features
|
133 |
+
hidden_features = hidden_features or in_features
|
134 |
+
bias = to_2tuple(bias)
|
135 |
+
drop_probs = to_2tuple(drop)
|
136 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
137 |
+
|
138 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
139 |
+
self.act = act_layer()
|
140 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
141 |
+
self.norm = (
|
142 |
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
143 |
+
)
|
144 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
145 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
x = self.fc1(x)
|
149 |
+
x = self.act(x)
|
150 |
+
x = self.drop1(x)
|
151 |
+
x = self.norm(x)
|
152 |
+
x = self.fc2(x)
|
153 |
+
x = self.drop2(x)
|
154 |
+
return x
|
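A short sketch of what `get_abs_pos` does when the fine-tuning token grid differs from the pretraining grid: a positional embedding pretrained on a 16x16 patch grid (plus a cls token) is bicubically resized to the target grid and returned in `(1, H, W, C)` layout. The sizes below are illustrative only:

```python
import torch
from modeling.backbone.utils import get_abs_pos  # repo-local import

# Pretrained positional embedding: 16x16 patch grid + 1 cls token, dim 384 (ViT-S).
abs_pos = torch.randn(1, 16 * 16 + 1, 384)

# Resize for a 37x37 token grid (e.g. a 518px image with 14px patches).
resized = get_abs_pos(abs_pos, has_cls_token=True, hw=(37, 37))
print(resized.shape)  # torch.Size([1, 37, 37, 384])
```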
modeling/backbone/vit.py
ADDED
@@ -0,0 +1,504 @@
1 |
+
import logging
|
2 |
+
from typing import Literal, Union
|
3 |
+
from functools import partial
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from detectron2.modeling import Backbone
|
7 |
+
|
8 |
+
try:
|
9 |
+
from xformers.ops import memory_efficient_attention
|
10 |
+
|
11 |
+
XFORMERS_ON = True
|
12 |
+
except ImportError:
|
13 |
+
XFORMERS_ON = False
|
14 |
+
from .utils import (
|
15 |
+
PatchEmbed,
|
16 |
+
get_abs_pos,
|
17 |
+
DropPath,
|
18 |
+
Mlp,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class Attention(nn.Module):
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
dim,
|
29 |
+
num_heads=8,
|
30 |
+
qkv_bias=True,
|
31 |
+
return_softmax_attn=True,
|
32 |
+
use_proj=True,
|
33 |
+
patch_token_offset=0,
|
34 |
+
):
|
35 |
+
"""
|
36 |
+
Args:
|
37 |
+
dim (int): Number of input channels.
|
38 |
+
num_heads (int): Number of attention heads.
|
39 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
40 |
+
"""
|
41 |
+
super().__init__()
|
42 |
+
self.num_heads = num_heads
|
43 |
+
head_dim = dim // num_heads
|
44 |
+
self.scale = head_dim**-0.5
|
45 |
+
|
46 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
47 |
+
self.proj = nn.Linear(dim, dim) if use_proj else nn.Identity()
|
48 |
+
|
49 |
+
self.return_softmax_attn = return_softmax_attn
|
50 |
+
|
51 |
+
self.patch_token_offset = patch_token_offset
|
52 |
+
|
53 |
+
def forward(self, x, return_attention=False, extra_token_offset=None):
|
54 |
+
B, L, _ = x.shape
|
55 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
56 |
+
qkv = self.qkv(x).view(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
57 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
58 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, L, -1).unbind(0)
|
59 |
+
|
60 |
+
if return_attention or not XFORMERS_ON:
|
61 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
62 |
+
if return_attention and not self.return_softmax_attn:
|
63 |
+
out_attn = attn
|
64 |
+
attn = attn.softmax(dim=-1)
|
65 |
+
if return_attention and self.return_softmax_attn:
|
66 |
+
out_attn = attn
|
67 |
+
x = attn @ v
|
68 |
+
else:
|
69 |
+
x = memory_efficient_attention(q, k, v, scale=self.scale)
|
70 |
+
|
71 |
+
x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
|
72 |
+
x = self.proj(x)
|
73 |
+
|
74 |
+
if return_attention:
|
75 |
+
out_attn = out_attn.reshape(B, self.num_heads, L, -1)
|
76 |
+
out_attn = out_attn[
|
77 |
+
:,
|
78 |
+
:,
|
79 |
+
self.patch_token_offset : extra_token_offset,
|
80 |
+
self.patch_token_offset : extra_token_offset,
|
81 |
+
]
|
82 |
+
return x, out_attn
|
83 |
+
else:
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
class Block(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
dim,
|
91 |
+
num_heads,
|
92 |
+
mlp_ratio=4.0,
|
93 |
+
qkv_bias=True,
|
94 |
+
drop_path=0.0,
|
95 |
+
norm_layer=nn.LayerNorm,
|
96 |
+
act_layer=nn.GELU,
|
97 |
+
init_values=None,
|
98 |
+
return_softmax_attn=True,
|
99 |
+
attention_map_only=False,
|
100 |
+
patch_token_offset=0,
|
101 |
+
):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
dim (int): Number of input channels.
|
105 |
+
num_heads (int): Number of attention heads in each ViT block.
|
106 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
107 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
108 |
+
drop_path (float): Stochastic depth rate.
|
109 |
+
norm_layer (nn.Module): Normalization layer.
|
110 |
+
act_layer (nn.Module): Activation layer.
|
111 |
+
"""
|
112 |
+
super().__init__()
|
113 |
+
self.attention_map_only = attention_map_only
|
114 |
+
self.norm1 = norm_layer(dim)
|
115 |
+
self.attn = Attention(
|
116 |
+
dim,
|
117 |
+
num_heads=num_heads,
|
118 |
+
qkv_bias=qkv_bias,
|
119 |
+
return_softmax_attn=return_softmax_attn,
|
120 |
+
use_proj=return_softmax_attn or not attention_map_only,
|
121 |
+
patch_token_offset=patch_token_offset,
|
122 |
+
)
|
123 |
+
|
124 |
+
if attention_map_only:
|
125 |
+
return
|
126 |
+
|
127 |
+
self.ls1 = (
|
128 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
129 |
+
)
|
130 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
131 |
+
self.norm2 = norm_layer(dim)
|
132 |
+
self.mlp = Mlp(
|
133 |
+
in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer
|
134 |
+
)
|
135 |
+
self.ls2 = (
|
136 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
137 |
+
)
|
138 |
+
|
139 |
+
def forward(self, x, return_attention=False, extra_token_offset=None):
|
140 |
+
shortcut = x
|
141 |
+
x = self.norm1(x)
|
142 |
+
|
143 |
+
if return_attention:
|
144 |
+
x, attn = self.attn(x, True, extra_token_offset)
|
145 |
+
else:
|
146 |
+
x = self.attn(x)
|
147 |
+
|
148 |
+
if self.attention_map_only:
|
149 |
+
return x, attn
|
150 |
+
|
151 |
+
x = shortcut + self.drop_path(self.ls1(x))
|
152 |
+
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
|
153 |
+
|
154 |
+
if return_attention:
|
155 |
+
return x, attn
|
156 |
+
else:
|
157 |
+
return x
|
158 |
+
|
159 |
+
|
160 |
+
class LayerScale(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
dim: int,
|
164 |
+
init_values: Union[float, torch.Tensor] = 1e-5,
|
165 |
+
inplace: bool = False,
|
166 |
+
) -> None:
|
167 |
+
super().__init__()
|
168 |
+
self.inplace = inplace
|
169 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
170 |
+
|
171 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
172 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
173 |
+
|
174 |
+
|
175 |
+
class ViT(Backbone):
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
patch_size=16,
|
179 |
+
in_chans=3,
|
180 |
+
embed_dim=768,
|
181 |
+
depth=12,
|
182 |
+
num_heads=12,
|
183 |
+
mlp_ratio=4.0,
|
184 |
+
qkv_bias=True,
|
185 |
+
drop_path_rate=0.0,
|
186 |
+
norm_layer=nn.LayerNorm,
|
187 |
+
act_layer=nn.GELU,
|
188 |
+
pretrain_img_size=224,
|
189 |
+
pretrain_use_cls_token=True,
|
190 |
+
init_values=None,
|
191 |
+
use_cls_token=False,
|
192 |
+
use_mask_token=False,
|
193 |
+
norm_features=False,
|
194 |
+
return_softmax_attn=True,
|
195 |
+
num_register_tokens=0,
|
196 |
+
num_msg_tokens=0,
|
197 |
+
register_as_msg=False,
|
198 |
+
shift_strides=None, # [1, -1, 2, -2],
|
199 |
+
cls_shift=False,
|
200 |
+
num_extra_tokens=4,
|
201 |
+
use_extra_embed=False,
|
202 |
+
num_frames=None,
|
203 |
+
out_feature=True,
|
204 |
+
out_attn=(),
|
205 |
+
):
|
206 |
+
super().__init__()
|
207 |
+
self.pretrain_use_cls_token = pretrain_use_cls_token
|
208 |
+
|
209 |
+
self.patch_size = patch_size
|
210 |
+
self.patch_embed = PatchEmbed(
|
211 |
+
kernel_size=(patch_size, patch_size),
|
212 |
+
stride=(patch_size, patch_size),
|
213 |
+
in_chans=in_chans,
|
214 |
+
embed_dim=embed_dim,
|
215 |
+
)
|
216 |
+
|
217 |
+
# Initialize absolute positional embedding with pretrain image size.
|
218 |
+
num_patches = (pretrain_img_size // patch_size) * (
|
219 |
+
pretrain_img_size // patch_size
|
220 |
+
)
|
221 |
+
num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
|
222 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
|
223 |
+
|
224 |
+
self.use_cls_token = use_cls_token
|
225 |
+
self.cls_token = (
|
226 |
+
nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_cls_token else None
|
227 |
+
)
|
228 |
+
|
229 |
+
assert num_register_tokens >= 0
|
230 |
+
self.register_tokens = (
|
231 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
|
232 |
+
if num_register_tokens > 0
|
233 |
+
else None
|
234 |
+
)
|
235 |
+
|
236 |
+
# We tried to leverage temporal information with TeViT, but it did not improve results
|
237 |
+
assert num_msg_tokens >= 0
|
238 |
+
self.num_msg_tokens = num_msg_tokens
|
239 |
+
if register_as_msg:
|
240 |
+
self.num_msg_tokens += num_register_tokens
|
241 |
+
self.msg_tokens = (
|
242 |
+
nn.Parameter(torch.zeros(1, num_msg_tokens, embed_dim))
|
243 |
+
if num_msg_tokens > 0
|
244 |
+
else None
|
245 |
+
)
|
246 |
+
|
247 |
+
patch_token_offset = (
|
248 |
+
num_msg_tokens + num_register_tokens + int(self.use_cls_token)
|
249 |
+
)
|
250 |
+
self.patch_token_offset = patch_token_offset
|
251 |
+
|
252 |
+
self.msg_shift = None
|
253 |
+
if shift_strides is not None:
|
254 |
+
self.msg_shift = []
|
255 |
+
for i in range(depth):
|
256 |
+
if i % 2 == 0:
|
257 |
+
self.msg_shift.append([_ for _ in shift_strides])
|
258 |
+
else:
|
259 |
+
self.msg_shift.append([-_ for _ in shift_strides])
|
260 |
+
|
261 |
+
self.cls_shift = None
|
262 |
+
if cls_shift:
|
263 |
+
self.cls_shift = [(-1) ** idx for idx in range(depth)]
|
264 |
+
|
265 |
+
assert num_extra_tokens >= 0
|
266 |
+
self.num_extra_tokens = num_extra_tokens
|
267 |
+
self.extra_pos_embed = (
|
268 |
+
nn.Linear(embed_dim, embed_dim)
|
269 |
+
if num_extra_tokens > 0 and use_extra_embed
|
270 |
+
else nn.Identity()
|
271 |
+
)
|
272 |
+
|
273 |
+
self.num_frames = num_frames
|
274 |
+
|
275 |
+
# Mask token for masking augmentation
|
276 |
+
self.use_mask_token = use_mask_token
|
277 |
+
self.mask_token = (
|
278 |
+
nn.Parameter(torch.zeros(1, embed_dim)) if use_mask_token else None
|
279 |
+
)
|
280 |
+
|
281 |
+
# stochastic depth decay rule
|
282 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
283 |
+
|
284 |
+
self.blocks = nn.ModuleList()
|
285 |
+
for i in range(depth):
|
286 |
+
block = Block(
|
287 |
+
dim=embed_dim,
|
288 |
+
num_heads=num_heads,
|
289 |
+
mlp_ratio=mlp_ratio,
|
290 |
+
qkv_bias=qkv_bias,
|
291 |
+
drop_path=dpr[i],
|
292 |
+
norm_layer=norm_layer,
|
293 |
+
act_layer=act_layer,
|
294 |
+
init_values=init_values,
|
295 |
+
return_softmax_attn=return_softmax_attn,
|
296 |
+
attention_map_only=(i == depth - 1) and not out_feature,
|
297 |
+
patch_token_offset=patch_token_offset,
|
298 |
+
)
|
299 |
+
self.blocks.append(block)
|
300 |
+
|
301 |
+
self.norm = norm_layer(embed_dim) if norm_features else nn.Identity()
|
302 |
+
|
303 |
+
self._out_features = out_feature
|
304 |
+
self._out_attn = out_attn
|
305 |
+
|
306 |
+
if self.pos_embed is not None:
|
307 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
308 |
+
|
309 |
+
self.apply(self._init_weights)
|
310 |
+
|
311 |
+
def _init_weights(self, m):
|
312 |
+
if isinstance(m, nn.Linear):
|
313 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
314 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
315 |
+
nn.init.constant_(m.bias, 0)
|
316 |
+
elif isinstance(m, nn.LayerNorm):
|
317 |
+
nn.init.constant_(m.bias, 0)
|
318 |
+
nn.init.constant_(m.weight, 1.0)
|
319 |
+
|
320 |
+
def forward(self, x, masks=None, guidance=None):
|
321 |
+
x = self.patch_embed(x)
|
322 |
+
if masks is not None:
|
323 |
+
x = torch.where(
|
324 |
+
masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
|
325 |
+
)
|
326 |
+
|
327 |
+
if self.pos_embed is not None:
|
328 |
+
x = x + get_abs_pos(
|
329 |
+
self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
|
330 |
+
)
|
331 |
+
B, H, W, _ = x.shape
|
332 |
+
x = x.reshape(B, H * W, -1)
|
333 |
+
|
334 |
+
if self.use_cls_token:
|
335 |
+
cls_tokens = self.cls_token.expand(len(x), -1, -1)
|
336 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
337 |
+
|
338 |
+
if self.register_tokens is not None:
|
339 |
+
register_tokens = self.register_tokens.expand(len(x), -1, -1)
|
340 |
+
x = torch.cat((register_tokens, x), dim=1)
|
341 |
+
|
342 |
+
if self.msg_tokens is not None:
|
343 |
+
msg_tokens = self.msg_tokens.expand(len(x), -1, -1)
|
344 |
+
x = torch.cat((msg_tokens, x), dim=1)
|
345 |
+
# [MSG, REG, CLS, PAT]
|
346 |
+
|
347 |
+
extra_tokens_offset = None
|
348 |
+
if guidance is not None:
|
349 |
+
guidance = guidance.reshape(len(guidance), -1, 1)
|
350 |
+
extra_tokens = (
|
351 |
+
(x[:, self.patch_token_offset :] * guidance)
|
352 |
+
.sum(dim=1, keepdim=True)
|
353 |
+
.expand(-1, self.num_extra_tokens, -1)
|
354 |
+
)
|
355 |
+
extra_tokens = self.extra_pos_embed(extra_tokens)
|
356 |
+
x = torch.cat((x, extra_tokens), dim=1)
|
357 |
+
extra_tokens_offset = -self.num_extra_tokens
|
358 |
+
# [MSG, REG, CLS, PAT, EXT]
|
359 |
+
|
360 |
+
attn_maps = []
|
361 |
+
for idx, blk in enumerate(self.blocks):
|
362 |
+
if idx in self._out_attn:
|
363 |
+
x, attn = blk(x, True, extra_tokens_offset)
|
364 |
+
attn_maps.append(attn)
|
365 |
+
else:
|
366 |
+
x = blk(x)
|
367 |
+
|
368 |
+
if self.msg_shift is not None:
|
369 |
+
msg_shift = self.msg_shift[idx]
|
370 |
+
msg_tokens = (
|
371 |
+
x[:, : self.num_msg_tokens]
|
372 |
+
if guidance is None
|
373 |
+
else x[:, extra_tokens_offset:]
|
374 |
+
)
|
375 |
+
msg_tokens = msg_tokens.reshape(
|
376 |
+
-1, self.num_frames, *msg_tokens.shape[1:]
|
377 |
+
)
|
378 |
+
msg_tokens = msg_tokens.chunk(len(msg_shift), dim=2)
|
379 |
+
msg_tokens = [
|
380 |
+
torch.roll(tokens, roll, dims=1)
|
381 |
+
for tokens, roll in zip(msg_tokens, msg_shift)
|
382 |
+
]
|
383 |
+
msg_tokens = torch.cat(msg_tokens, dim=2).flatten(0, 1)
|
384 |
+
if guidance is None:
|
385 |
+
x = torch.cat([msg_tokens, x[:, self.num_msg_tokens :]], dim=1)
|
386 |
+
else:
|
387 |
+
x = torch.cat([x[:, :extra_tokens_offset], msg_tokens], dim=1)
|
388 |
+
|
389 |
+
if self.cls_shift is not None:
|
390 |
+
cls_tokens = x[:, self.patch_token_offset - 1]
|
391 |
+
cls_tokens = cls_tokens.reshape(
|
392 |
+
-1, self.num_frames, 1, *cls_tokens.shape[1:]
|
393 |
+
)
|
394 |
+
cls_tokens = torch.roll(cls_tokens, self.cls_shift[idx], dims=1)
|
395 |
+
x = torch.cat(
|
396 |
+
[
|
397 |
+
x[:, : self.patch_token_offset - 1],
|
398 |
+
cls_tokens.flatten(0, 1),
|
399 |
+
x[:, self.patch_token_offset :],
|
400 |
+
],
|
401 |
+
dim=1,
|
402 |
+
)
|
403 |
+
|
404 |
+
x = self.norm(x)
|
405 |
+
|
406 |
+
outputs = {}
|
407 |
+
outputs["attention_maps"] = torch.cat(attn_maps, dim=1).reshape(
|
408 |
+
B, -1, H * W, H, W
|
409 |
+
)
|
410 |
+
if self._out_features:
|
411 |
+
outputs["last_feat"] = (
|
412 |
+
x[:, self.patch_token_offset : extra_tokens_offset]
|
413 |
+
.reshape(B, H, W, -1)
|
414 |
+
.permute(0, 3, 1, 2)
|
415 |
+
)
|
416 |
+
|
417 |
+
return outputs
|
418 |
+
|
419 |
+
|
420 |
+
def vit_tiny(**kwargs):
|
421 |
+
model = ViT(
|
422 |
+
patch_size=16,
|
423 |
+
embed_dim=192,
|
424 |
+
depth=12,
|
425 |
+
num_heads=3,
|
426 |
+
mlp_ratio=4,
|
427 |
+
qkv_bias=True,
|
428 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
429 |
+
**kwargs
|
430 |
+
)
|
431 |
+
return model
|
432 |
+
|
433 |
+
|
434 |
+
def vit_small(**kwargs):
|
435 |
+
model = ViT(
|
436 |
+
patch_size=16,
|
437 |
+
embed_dim=384,
|
438 |
+
depth=12,
|
439 |
+
num_heads=6,
|
440 |
+
mlp_ratio=4,
|
441 |
+
qkv_bias=True,
|
442 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
443 |
+
**kwargs
|
444 |
+
)
|
445 |
+
return model
|
446 |
+
|
447 |
+
|
448 |
+
def vit_base(**kwargs):
|
449 |
+
model = ViT(
|
450 |
+
patch_size=16,
|
451 |
+
embed_dim=768,
|
452 |
+
depth=12,
|
453 |
+
num_heads=12,
|
454 |
+
mlp_ratio=4,
|
455 |
+
qkv_bias=True,
|
456 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
457 |
+
**kwargs
|
458 |
+
)
|
459 |
+
return model
|
460 |
+
|
461 |
+
|
462 |
+
def dinov2_base(**kwargs):
|
463 |
+
model = ViT(
|
464 |
+
patch_size=14,
|
465 |
+
embed_dim=768,
|
466 |
+
depth=12,
|
467 |
+
num_heads=12,
|
468 |
+
mlp_ratio=4,
|
469 |
+
qkv_bias=True,
|
470 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
471 |
+
pretrain_img_size=518,
|
472 |
+
init_values=1,
|
473 |
+
**kwargs
|
474 |
+
)
|
475 |
+
return model
|
476 |
+
|
477 |
+
|
478 |
+
def dinov2_small(**kwargs):
|
479 |
+
model = ViT(
|
480 |
+
patch_size=14,
|
481 |
+
embed_dim=384,
|
482 |
+
depth=12,
|
483 |
+
num_heads=6,
|
484 |
+
mlp_ratio=4,
|
485 |
+
qkv_bias=True,
|
486 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
487 |
+
pretrain_img_size=518,
|
488 |
+
init_values=1,
|
489 |
+
**kwargs
|
490 |
+
)
|
491 |
+
return model
|
492 |
+
|
493 |
+
|
494 |
+
def build_backbone(
|
495 |
+
name: Literal["tiny", "small", "base", "dinov2_base", "dinov2_small"], **kwargs
|
496 |
+
):
|
497 |
+
vit_dict = {
|
498 |
+
"tiny": vit_tiny,
|
499 |
+
"small": vit_small,
|
500 |
+
"base": vit_base,
|
501 |
+
"dinov2_base": dinov2_base,
|
502 |
+
"dinov2_small": dinov2_small,
|
503 |
+
}
|
504 |
+
return vit_dict[name](**kwargs)
|
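A minimal usage sketch for the backbone factory above; the import path and the out_attn indices are assumptions chosen for illustration, not values taken from the repo's configs.

import torch
from modeling.backbone.vit import build_backbone

# ViT-small at 224x224: return the patch features plus the softmax attention
# maps of the last two blocks (indices 10 and 11 of 12).
backbone = build_backbone("small", out_feature=True, out_attn=(10, 11))
images = torch.randn(2, 3, 224, 224)          # B x 3 x H x W
outputs = backbone(images)
# outputs["attention_maps"]: B x (2 blocks * 6 heads) x 196 x 14 x 14
# outputs["last_feat"]:      B x 384 x 14 x 14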
modeling/criterion/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
from .gaze_mapper_criterion import GazeMapperCriterion
|
modeling/criterion/gaze_mapper_criterion.py
ADDED
@@ -0,0 +1,94 @@
|
|
1 |
+
from functools import partial
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from fvcore.nn import sigmoid_focal_loss_jit
|
6 |
+
|
7 |
+
|
8 |
+
class GazeMapperCriterion(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
heatmap_weight: float = 10000,
|
12 |
+
inout_weight: float = 100,
|
13 |
+
aux_weight: float = 100,
|
14 |
+
use_aux_loss: bool = False,
|
15 |
+
aux_head_thres: float = 0,
|
16 |
+
use_focal_loss: bool = False,
|
17 |
+
alpha: float = -1,
|
18 |
+
gamma: float = 2,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.heatmap_weight = heatmap_weight
|
22 |
+
self.inout_weight = inout_weight
|
23 |
+
self.aux_weight = aux_weight
|
24 |
+
self.aux_head_thres = aux_head_thres
|
25 |
+
|
26 |
+
self.heatmap_loss = nn.MSELoss(reduction="none")
|
27 |
+
|
28 |
+
if use_focal_loss:
|
29 |
+
self.inout_loss = partial(
|
30 |
+
sigmoid_focal_loss_jit, alpha=alpha, gamma=gamma, reduction="mean"
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
self.inout_loss = nn.BCEWithLogitsLoss()
|
34 |
+
|
35 |
+
if use_aux_loss:
|
36 |
+
self.aux_loss = nn.BCEWithLogitsLoss()
|
37 |
+
else:
|
38 |
+
self.aux_loss = None
|
39 |
+
|
40 |
+
def forward(
|
41 |
+
self,
|
42 |
+
pred_heatmap,
|
43 |
+
pred_inout,
|
44 |
+
gt_heatmap,
|
45 |
+
gt_inout,
|
46 |
+
pred_head_masks=None,
|
47 |
+
gt_head_masks=None,
|
48 |
+
):
|
49 |
+
loss_dict = {}
|
50 |
+
|
51 |
+
pred_heatmap = F.interpolate(
|
52 |
+
pred_heatmap,
|
53 |
+
size=tuple(gt_heatmap.shape[-2:]),
|
54 |
+
mode="bilinear",
|
55 |
+
align_corners=True,
|
56 |
+
)
|
57 |
+
heatmap_loss = (
|
58 |
+
self.heatmap_loss(pred_heatmap.squeeze(1), gt_heatmap) * self.heatmap_weight
|
59 |
+
)
|
60 |
+
heatmap_loss = torch.mean(heatmap_loss, dim=(-2, -1))
|
61 |
+
heatmap_loss = torch.sum(heatmap_loss.reshape(-1) * gt_inout.reshape(-1))
|
62 |
+
# Guard against 0/0 = NaN when every sample in the batch is out of frame
|
63 |
+
if heatmap_loss > 1e-7:
|
64 |
+
heatmap_loss = heatmap_loss / torch.sum(gt_inout)
|
65 |
+
loss_dict["regression loss"] = heatmap_loss
|
66 |
+
else:
|
67 |
+
loss_dict["regression loss"] = heatmap_loss * 0
|
68 |
+
|
69 |
+
inout_loss = (
|
70 |
+
self.inout_loss(pred_inout.reshape(-1), gt_inout.reshape(-1))
|
71 |
+
* self.inout_weight
|
72 |
+
)
|
73 |
+
loss_dict["classification loss"] = inout_loss
|
74 |
+
|
75 |
+
if self.aux_loss is not None:
|
76 |
+
pred_head_masks = F.interpolate(
|
77 |
+
pred_head_masks,
|
78 |
+
size=tuple(gt_head_masks.shape[-2:]),
|
79 |
+
mode="bilinear",
|
80 |
+
align_corners=True,
|
81 |
+
)
|
82 |
+
aux_loss = (
|
83 |
+
torch.clamp(
|
84 |
+
self.aux_loss(
|
85 |
+
pred_head_masks.reshape(-1), gt_head_masks.reshape(-1)
|
86 |
+
)
|
87 |
+
- self.aux_head_thres,
|
88 |
+
0,
|
89 |
+
)
|
90 |
+
* self.aux_weight
|
91 |
+
)
|
92 |
+
loss_dict["aux head loss"] = aux_loss
|
93 |
+
|
94 |
+
return loss_dict
|
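An illustrative sketch of how the criterion above is driven; the tensor shapes and batch size are assumptions.

import torch
from modeling.criterion import GazeMapperCriterion

criterion = GazeMapperCriterion()                 # MSE heatmap loss + BCE in/out loss
pred_heatmap = torch.randn(4, 1, 56, 56)          # upsampled internally to the GT size
pred_inout = torch.randn(4, 1)                    # in/out logits
gt_heatmap = torch.rand(4, 64, 64)
gt_inout = torch.tensor([[1.0], [1.0], [0.0], [1.0]])
loss_dict = criterion(pred_heatmap, pred_inout, gt_heatmap, gt_inout)
total_loss = sum(loss_dict.values())              # "regression loss" + "classification loss"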
modeling/head/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
1 |
+
from .inout_head import build_inout_head
|
2 |
+
from .heatmap_head import build_heatmap_head
|
modeling/head/heatmap_head.py
ADDED
@@ -0,0 +1,225 @@
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from detectron2.utils.registry import Registry
|
4 |
+
from typing import Literal, List, Dict, Optional, OrderedDict
|
5 |
+
|
6 |
+
|
7 |
+
HEATMAP_HEAD_REGISTRY = Registry("HEATMAP_HEAD_REGISTRY")
|
8 |
+
HEATMAP_HEAD_REGISTRY.__doc__ = "Registry for heatmap head"
|
9 |
+
|
10 |
+
|
11 |
+
class BaseHeatmapHead(nn.Module):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
in_channel: int,
|
15 |
+
deconv_cfgs: List[Dict],
|
16 |
+
dim: int = 96,
|
17 |
+
use_conv: bool = False,
|
18 |
+
use_residual: bool = False,
|
19 |
+
feat_type: Literal["attn", "both"] = "both",
|
20 |
+
attn_layer: Optional[str] = None,
|
21 |
+
pre_norm: bool = False,
|
22 |
+
use_head: bool = False,
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
self.feat_type = feat_type
|
26 |
+
self.use_head = use_head
|
27 |
+
|
28 |
+
if pre_norm:
|
29 |
+
self.pre_norm = nn.Sequential(
|
30 |
+
OrderedDict(
|
31 |
+
[
|
32 |
+
("bn", nn.BatchNorm2d(in_channel)),
|
33 |
+
("relu", nn.ReLU(inplace=True)),
|
34 |
+
]
|
35 |
+
)
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
self.pre_norm = nn.Identity()
|
39 |
+
|
40 |
+
if use_conv:
|
41 |
+
if use_residual:
|
42 |
+
from timm.models.resnet import Bottleneck, downsample_conv
|
43 |
+
|
44 |
+
self.conv = Bottleneck(
|
45 |
+
in_channel,
|
46 |
+
dim // 4,
|
47 |
+
downsample=downsample_conv(in_channel, dim, 1)
|
48 |
+
if in_channel != dim
|
49 |
+
else None,
|
50 |
+
attn_layer=attn_layer,
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
self.conv = nn.Sequential(
|
54 |
+
OrderedDict(
|
55 |
+
[
|
56 |
+
("conv", nn.Conv2d(in_channel, dim, 3, 1, 1)),
|
57 |
+
("bn", nn.BatchNorm2d(dim)),
|
58 |
+
("relu", nn.ReLU(inplace=True)),
|
59 |
+
]
|
60 |
+
)
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
self.conv = nn.Identity()
|
64 |
+
|
65 |
+
self.decoder: nn.Module = None
|
66 |
+
|
67 |
+
def get_feat(self, x):
|
68 |
+
if self.feat_type == "attn":
|
69 |
+
feat = x["attention_maps"]
|
70 |
+
elif self.feat_type == "feat":
|
71 |
+
feat = x["last_feat"]
|
72 |
+
return feat
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
feat = self.get_feat(x)
|
76 |
+
feat = self.pre_norm(feat)
|
77 |
+
feat = self.conv(feat)
|
78 |
+
return self.decoder(feat)
|
79 |
+
|
80 |
+
|
81 |
+
@HEATMAP_HEAD_REGISTRY.register()
|
82 |
+
class SimpleDeconv(BaseHeatmapHead):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
in_channel: int,
|
86 |
+
deconv_cfgs: List[Dict],
|
87 |
+
dim: int = 96,
|
88 |
+
use_conv: bool = False,
|
89 |
+
use_residual: bool = False,
|
90 |
+
feat_type: Literal["attn", "both"] = "both",
|
91 |
+
attn_layer: Optional[str] = None,
|
92 |
+
pre_norm: bool = False,
|
93 |
+
use_head: bool = False,
|
94 |
+
) -> None:
|
95 |
+
super().__init__(
|
96 |
+
in_channel,
|
97 |
+
deconv_cfgs,
|
98 |
+
dim,
|
99 |
+
use_conv,
|
100 |
+
use_residual,
|
101 |
+
feat_type,
|
102 |
+
attn_layer,
|
103 |
+
pre_norm,
|
104 |
+
use_head,
|
105 |
+
)
|
106 |
+
decoder_layers = []
|
107 |
+
for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
|
108 |
+
decoder_layers.extend(
|
109 |
+
[
|
110 |
+
(
|
111 |
+
"".join(["deconv", str(i)]),
|
112 |
+
nn.ConvTranspose2d(**deconv_cfg),
|
113 |
+
),
|
114 |
+
(
|
115 |
+
"".join(["bn", str(i)]),
|
116 |
+
nn.BatchNorm2d(deconv_cfg["out_channels"]),
|
117 |
+
),
|
118 |
+
("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
|
119 |
+
]
|
120 |
+
)
|
121 |
+
decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
|
122 |
+
self.decoder = nn.Sequential(OrderedDict(decoder_layers))
|
123 |
+
|
124 |
+
|
125 |
+
@HEATMAP_HEAD_REGISTRY.register()
|
126 |
+
class UpSampleConv(BaseHeatmapHead):
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
in_channel: int,
|
130 |
+
deconv_cfgs: List[Dict],
|
131 |
+
dim: int = 96,
|
132 |
+
use_conv: bool = False,
|
133 |
+
use_residual: bool = False,
|
134 |
+
feat_type: Literal["attn", "both"] = "both",
|
135 |
+
attn_layer: Optional[str] = None,
|
136 |
+
pre_norm: bool = False,
|
137 |
+
use_head: bool = False,
|
138 |
+
) -> None:
|
139 |
+
super().__init__(
|
140 |
+
in_channel,
|
141 |
+
deconv_cfgs,
|
142 |
+
dim,
|
143 |
+
use_conv,
|
144 |
+
use_residual,
|
145 |
+
feat_type,
|
146 |
+
attn_layer,
|
147 |
+
pre_norm,
|
148 |
+
use_head,
|
149 |
+
)
|
150 |
+
decoder_layers = []
|
151 |
+
for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
|
152 |
+
decoder_layers.extend(
|
153 |
+
[
|
154 |
+
(
|
155 |
+
"".join(["upsample", str(i)]),
|
156 |
+
nn.Upsample(
|
157 |
+
scale_factor=2, mode="bilinear", align_corners=True
|
158 |
+
),
|
159 |
+
),
|
160 |
+
(
|
161 |
+
"".join(["conv", str(i)]),
|
162 |
+
nn.Conv2d(**deconv_cfg),
|
163 |
+
),
|
164 |
+
(
|
165 |
+
"".join(["bn", str(i)]),
|
166 |
+
nn.BatchNorm2d(deconv_cfg["out_channels"]),
|
167 |
+
),
|
168 |
+
("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
|
169 |
+
]
|
170 |
+
)
|
171 |
+
decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
|
172 |
+
self.decoder = nn.Sequential(OrderedDict(decoder_layers))
|
173 |
+
|
174 |
+
|
175 |
+
@HEATMAP_HEAD_REGISTRY.register()
|
176 |
+
class PixelShuffle(BaseHeatmapHead):
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
in_channel: int,
|
180 |
+
deconv_cfgs: List[Dict],
|
181 |
+
dim: int = 96,
|
182 |
+
use_conv: bool = False,
|
183 |
+
use_residual: bool = False,
|
184 |
+
feat_type: Literal["attn", "both"] = "both",
|
185 |
+
attn_layer: Optional[str] = None,
|
186 |
+
pre_norm: bool = False,
|
187 |
+
use_head: bool = False,
|
188 |
+
) -> None:
|
189 |
+
super().__init__(
|
190 |
+
in_channel,
|
191 |
+
deconv_cfgs,
|
192 |
+
dim,
|
193 |
+
use_conv,
|
194 |
+
use_residual,
|
195 |
+
feat_type,
|
196 |
+
attn_layer,
|
197 |
+
pre_norm,
|
198 |
+
use_head,
|
199 |
+
)
|
200 |
+
decoder_layers = []
|
201 |
+
for i, deconv_cfg in enumerate(deconv_cfgs, start=1):
|
202 |
+
deconv_cfg["out_channels"] = deconv_cfg["out_channels"] * 4
|
203 |
+
decoder_layers.extend(
|
204 |
+
[
|
205 |
+
(
|
206 |
+
"".join(["conv", str(i)]),
|
207 |
+
nn.Conv2d(**deconv_cfg),
|
208 |
+
),
|
209 |
+
(
|
210 |
+
"".join(["pixel_shuffle", str(i)]),
|
211 |
+
nn.PixelShuffle(upscale_factor=2),
|
212 |
+
),
|
213 |
+
(
|
214 |
+
"".join(["bn", str(i)]),
|
215 |
+
nn.BatchNorm2d(deconv_cfg["out_channels"] // 4),
|
216 |
+
),
|
217 |
+
("".join(["relu", str(i)]), nn.ReLU(inplace=True)),
|
218 |
+
]
|
219 |
+
)
|
220 |
+
decoder_layers.append(("conv", nn.Conv2d(1, 1, 1)))
|
221 |
+
self.decoder = nn.Sequential(OrderedDict(decoder_layers))
|
222 |
+
|
223 |
+
|
224 |
+
def build_heatmap_head(name, *args, **kwargs):
|
225 |
+
return HEATMAP_HEAD_REGISTRY.get(name)(*args, **kwargs)
|
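A hypothetical configuration for the SimpleDeconv head registered above; the channel sizes and spatial resolution are assumptions. Note that the last deconv stage must end at one channel because the decoder closes with a 1x1 Conv2d(1, 1).

import torch
from modeling.head import build_heatmap_head

head = build_heatmap_head(
    "SimpleDeconv",
    in_channel=24,        # e.g. number of weighted attention maps fed to the head
    deconv_cfgs=[
        dict(in_channels=24, out_channels=8, kernel_size=4, stride=2, padding=1),
        dict(in_channels=8, out_channels=1, kernel_size=4, stride=2, padding=1),
    ],
    feat_type="attn",
)
x = {"attention_maps": torch.randn(2, 24, 14, 14)}
heatmap = head(x)         # 2 x 1 x 56 x 56 heatmap logits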
modeling/head/inout_head.py
ADDED
@@ -0,0 +1,55 @@
|
|
1 |
+
from typing import OrderedDict
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from detectron2.utils.registry import Registry
|
5 |
+
|
6 |
+
|
7 |
+
INOUT_HEAD_REGISTRY = Registry("INOUT_HEAD_REGISTRY")
|
8 |
+
INOUT_HEAD_REGISTRY.__doc__ = "Registry for inout head"
|
9 |
+
|
10 |
+
|
11 |
+
@INOUT_HEAD_REGISTRY.register()
|
12 |
+
class SimpleLinear(nn.Module):
|
13 |
+
def __init__(self, in_channel: int, dropout: float = 0) -> None:
|
14 |
+
super().__init__()
|
15 |
+
self.in_channel = in_channel
|
16 |
+
self.classifier = nn.Sequential(
|
17 |
+
OrderedDict(
|
18 |
+
[("dropout", nn.Dropout(dropout)), ("linear", nn.Linear(in_channel, 1))]
|
19 |
+
)
|
20 |
+
)
|
21 |
+
|
22 |
+
def get_feat(self, x, masks):
|
23 |
+
feats = x["head_feat"]
|
24 |
+
if masks is not None:
|
25 |
+
B, C = x["last_feat"].shape[:2]
|
26 |
+
scene_feats = x["last_feat"].view(B, C, -1).permute(0, 2, 1)
|
27 |
+
masks = masks / (masks.sum(dim=-1, keepdim=True) + 1e-6)
|
28 |
+
scene_feats = (scene_feats * masks.unsqueeze(-1)).sum(dim=1)
|
29 |
+
feats = torch.cat((feats, scene_feats), dim=1)
|
30 |
+
return feats
|
31 |
+
|
32 |
+
def forward(self, x, masks=None):
|
33 |
+
feat = self.get_feat(x, masks)
|
34 |
+
return self.classifier(feat)
|
35 |
+
|
36 |
+
|
37 |
+
@INOUT_HEAD_REGISTRY.register()
|
38 |
+
class SimpleMlp(SimpleLinear):
|
39 |
+
def __init__(self, in_channel: int, dropout: float = 0) -> None:
|
40 |
+
super().__init__(in_channel, dropout)
|
41 |
+
self.classifier = nn.Sequential(
|
42 |
+
OrderedDict(
|
43 |
+
[
|
44 |
+
("dropout0", nn.Dropout(dropout)),
|
45 |
+
("linear0", nn.Linear(in_channel, in_channel)),
|
46 |
+
("relu", nn.ReLU()),
|
47 |
+
("dropout1", nn.Dropout(dropout)),
|
48 |
+
("linear1", nn.Linear(in_channel, 1)),
|
49 |
+
]
|
50 |
+
)
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def build_inout_head(name, *args, **kwargs):
|
55 |
+
return INOUT_HEAD_REGISTRY.get(name)(*args, **kwargs)
|
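A short sketch of the in/out classifier; the 768-dim pooled head feature is an assumption.

import torch
from modeling.head import build_inout_head

classifier = build_inout_head("SimpleLinear", in_channel=768, dropout=0.1)
x = {"head_feat": torch.randn(4, 768)}            # gaze-weighted backbone feature per person
logits = classifier(x, masks=None)                # 4 x 1 logits; sigmoid is applied at inference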
modeling/meta_arch/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
from .gaze_attention_mapper import GazeAttentionMapper
|
modeling/meta_arch/gaze_attention_mapper.py
ADDED
@@ -0,0 +1,97 @@
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from typing import Dict, Union
|
4 |
+
|
5 |
+
|
6 |
+
class GazeAttentionMapper(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
backbone: nn.Module,
|
10 |
+
regressor: nn.Module,
|
11 |
+
classifier: nn.Module,
|
12 |
+
criterion: nn.Module,
|
13 |
+
pam: nn.Module,
|
14 |
+
use_aux_loss: bool = False,
|
15 |
+
device: Union[torch.device, str] = "cuda",
|
16 |
+
) -> None:
|
17 |
+
super().__init__()
|
18 |
+
self.backbone = backbone
|
19 |
+
self.pam = pam
|
20 |
+
self.regressor = regressor
|
21 |
+
self.classifier = classifier
|
22 |
+
self.criterion = criterion
|
23 |
+
self.use_aux_loss = use_aux_loss
|
24 |
+
self.device = torch.device(device)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
(
|
28 |
+
scenes,
|
29 |
+
heads,
|
30 |
+
gt_heatmaps,
|
31 |
+
gt_inouts,
|
32 |
+
head_masks,
|
33 |
+
image_masks,
|
34 |
+
) = self.preprocess_inputs(x)
|
35 |
+
# Calculate patch weights based on head position
|
36 |
+
embedded_heads = self.pam(scenes, heads)
|
37 |
+
aux_masks = None
|
38 |
+
if self.use_aux_loss:
|
39 |
+
embedded_heads, aux_masks = embedded_heads
|
40 |
+
|
41 |
+
# Get out-dict
|
42 |
+
x = self.backbone(
|
43 |
+
scenes,
|
44 |
+
image_masks,
|
45 |
+
None,
|
46 |
+
)
|
47 |
+
|
48 |
+
# Apply patch weights to get the final feats and attention maps
|
49 |
+
feats = x.get("last_feat", None)
|
50 |
+
if feats is not None:
|
51 |
+
x["head_feat"] = (
|
52 |
+
(embedded_heads.repeat(1, feats.shape[1], 1, 1) * feats)
|
53 |
+
.sum(dim=(2, 3))
|
54 |
+
.reshape(len(feats), -1)
|
55 |
+
) # BC
|
56 |
+
|
57 |
+
attn_maps = x["attention_maps"]
|
58 |
+
B, C, *_ = attn_maps.shape
|
59 |
+
x["attention_maps"] = (
|
60 |
+
attn_maps * embedded_heads.reshape(B, 1, -1, 1, 1).repeat(1, C, 1, 1, 1)
|
61 |
+
).sum(
|
62 |
+
dim=2
|
63 |
+
) # BCHW
|
64 |
+
|
65 |
+
# Apply heads
|
66 |
+
heatmaps = self.regressor(x)
|
67 |
+
inouts = self.classifier(x, None)
|
68 |
+
|
69 |
+
if self.training:
|
70 |
+
return self.criterion(
|
71 |
+
heatmaps,
|
72 |
+
inouts,
|
73 |
+
gt_heatmaps,
|
74 |
+
gt_inouts,
|
75 |
+
aux_masks,
|
76 |
+
head_masks,
|
77 |
+
)
|
78 |
+
# Inference
|
79 |
+
return heatmaps, inouts.sigmoid()
|
80 |
+
|
81 |
+
def preprocess_inputs(self, batched_inputs: Dict[str, torch.Tensor]):
|
82 |
+
return (
|
83 |
+
batched_inputs["images"].to(self.device),
|
84 |
+
batched_inputs["head_channels"].to(self.device),
|
85 |
+
batched_inputs["heatmaps"].to(self.device)
|
86 |
+
if "heatmaps" in batched_inputs.keys()
|
87 |
+
else None,
|
88 |
+
batched_inputs["gaze_inouts"].to(self.device)
|
89 |
+
if "gaze_inouts" in batched_inputs.keys()
|
90 |
+
else None,
|
91 |
+
batched_inputs["head_masks"].to(self.device)
|
92 |
+
if "head_masks" in batched_inputs.keys()
|
93 |
+
else None,
|
94 |
+
batched_inputs["image_masks"].to(self.device)
|
95 |
+
if "image_masks" in batched_inputs.keys()
|
96 |
+
else None,
|
97 |
+
)
|
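For reference, a shape sketch of the batch dictionary that preprocess_inputs above consumes; the resolution, heatmap size, and batch size are illustrative assumptions.

import torch

batched_inputs = {
    "images": torch.randn(4, 3, 224, 224),          # scene images
    "head_channels": torch.rand(4, 1, 224, 224),    # per-person head-location channel
    "heatmaps": torch.rand(4, 64, 64),              # GT gaze heatmaps (training only)
    "gaze_inouts": torch.randint(0, 2, (4, 1)).float(),
    # "head_masks" and "image_masks" are optional and may be omitted
}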
modeling/patch_attention/__init__.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
from .patch_attention import build_pam
|
modeling/patch_attention/patch_attention.py
ADDED
@@ -0,0 +1,106 @@
|
|
1 |
+
from typing import OrderedDict
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from detectron2.utils.registry import Registry
|
5 |
+
from fvcore.nn import c2_msra_fill
|
6 |
+
|
7 |
+
|
8 |
+
SPATIAL_GUIDANCE_REGISTRY = Registry("SPATIAL_GUIDANCE_REGISTRY")
|
9 |
+
SPATIAL_GUIDANCE_REGISTRY.__doc__ = "Registry for 2d spatial guidance"
|
10 |
+
|
11 |
+
|
12 |
+
class _PoolFusion(nn.Module):
|
13 |
+
def __init__(self, patch_size: int, use_avgpool: bool = False) -> None:
|
14 |
+
super().__init__()
|
15 |
+
self.patch_size = patch_size
|
16 |
+
self.attn_reducer = F.avg_pool2d if use_avgpool else F.max_pool2d
|
17 |
+
|
18 |
+
def forward(self, scenes, heads):
|
19 |
+
attn_masks = self.attn_reducer(
|
20 |
+
heads,
|
21 |
+
(self.patch_size, self.patch_size),
|
22 |
+
(self.patch_size, self.patch_size),
|
23 |
+
(0, 0),
|
24 |
+
)
|
25 |
+
patch_attn = attn_masks.masked_fill(attn_masks <= 0, -1e9)
|
26 |
+
return F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
|
27 |
+
*patch_attn.shape
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
@SPATIAL_GUIDANCE_REGISTRY.register()
|
32 |
+
class AvgFusion(_PoolFusion):
|
33 |
+
def __init__(self, patch_size: int) -> None:
|
34 |
+
super().__init__(patch_size, True)  # use_avgpool=True -> average pooling, matching the class name
|
35 |
+
|
36 |
+
|
37 |
+
@SPATIAL_GUIDANCE_REGISTRY.register()
|
38 |
+
class MaxFusion(_PoolFusion):
|
39 |
+
def __init__(self, patch_size: int) -> None:
|
40 |
+
super().__init__(patch_size, False)  # use_avgpool=False -> max pooling, matching the class name
|
41 |
+
|
42 |
+
|
43 |
+
@SPATIAL_GUIDANCE_REGISTRY.register()
|
44 |
+
class PatchPAM(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
patch_size: int,
|
48 |
+
act_layer=nn.ReLU,
|
49 |
+
embed_dim: int = 768,
|
50 |
+
use_aux_loss: bool = False,
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
self.patch_size = patch_size
|
54 |
+
patch_embed = nn.Conv2d(
|
55 |
+
3, embed_dim, (patch_size, patch_size), (patch_size, patch_size), (0, 0)
|
56 |
+
)
|
57 |
+
c2_msra_fill(patch_embed)
|
58 |
+
conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
|
59 |
+
c2_msra_fill(conv)
|
60 |
+
self.use_aux_loss = use_aux_loss
|
61 |
+
if use_aux_loss:
|
62 |
+
self.patch_embed = nn.Sequential(
|
63 |
+
OrderedDict(
|
64 |
+
[
|
65 |
+
("patch_embed", patch_embed),
|
66 |
+
("act_layer", act_layer(inplace=True)),
|
67 |
+
]
|
68 |
+
)
|
69 |
+
)
|
70 |
+
self.embed = conv
|
71 |
+
conv = nn.Conv2d(embed_dim, 1, (1, 1), (1, 1), (0, 0))
|
72 |
+
c2_msra_fill(conv)
|
73 |
+
self.aux_embed = conv
|
74 |
+
else:
|
75 |
+
self.embed = nn.Sequential(
|
76 |
+
OrderedDict(
|
77 |
+
[
|
78 |
+
("patch_embed", patch_embed),
|
79 |
+
("act_layer", act_layer(inplace=True)),
|
80 |
+
("embed", conv),
|
81 |
+
]
|
82 |
+
)
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, scenes, heads):
|
86 |
+
attn_masks = F.max_pool2d(
|
87 |
+
heads,
|
88 |
+
(self.patch_size, self.patch_size),
|
89 |
+
(self.patch_size, self.patch_size),
|
90 |
+
(0, 0),
|
91 |
+
)
|
92 |
+
if self.use_aux_loss:
|
93 |
+
embed = self.patch_embed(scenes)
|
94 |
+
aux_masks = self.aux_embed(embed)
|
95 |
+
patch_attn = self.embed(embed) * attn_masks
|
96 |
+
else:
|
97 |
+
patch_attn = self.embed(scenes) * attn_masks
|
98 |
+
patch_attn = patch_attn.masked_fill(attn_masks <= 0, -1e9)
|
99 |
+
patch_attn = F.softmax(patch_attn.view(len(patch_attn), -1), dim=1).view(
|
100 |
+
*patch_attn.shape
|
101 |
+
)
|
102 |
+
return (patch_attn, aux_masks) if self.use_aux_loss else patch_attn
|
103 |
+
|
104 |
+
|
105 |
+
def build_pam(name, *args, **kwargs):
|
106 |
+
return SPATIAL_GUIDANCE_REGISTRY.get(name)(*args, **kwargs)
|
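A usage sketch for the patch attention module (PAM) factory above; input sizes are assumptions.

import torch
from modeling.patch_attention import build_pam

pam = build_pam("PatchPAM", patch_size=16, embed_dim=768)
scenes = torch.randn(2, 3, 224, 224)
heads = torch.rand(2, 1, 224, 224)      # head-location channel
weights = pam(scenes, heads)            # 2 x 1 x 14 x 14, softmax-normalized over patches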
pretrained/gazefollow.pth
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:61f82255d4db0640208c90b037fd4bb62e061efd51cd03aa4b45f10159acab15
|
3 |
+
size 266808135
|
pretrained/videoattentiontarget.pth
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cdf8d767815b7c8f1ba78d89710dc7a2de2d08e97554654a296a3e71b4903cec
|
3 |
+
size 266808135
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
1 |
+
fvcore==0.1.5.post20221221
|
2 |
+
numpy==1.26.2
|
3 |
+
omegaconf==2.3.0
|
4 |
+
opencv-python==4.8.1.78
|
5 |
+
opencv-python-headless==4.6.0.66
|
6 |
+
pandas==1.4.4
|
7 |
+
Pillow==10.3.0
|
8 |
+
scikit-image==0.22.0
|
9 |
+
scikit-learn==1.5.0
|
10 |
+
scipy==1.11.4
|
11 |
+
timm==0.9.12
|
12 |
+
torch==2.2.0
|
13 |
+
torchaudio==2.0.2
|
14 |
+
torchvision==0.15.2
|
15 |
+
tqdm==4.66.3
|
16 |
+
xformers==0.0.21
|
17 |
+
yacs==0.1.8
|
scripts/convert_pth.py
ADDED
@@ -0,0 +1,32 @@
|
|
1 |
+
# Convert official model weights to the format that detectron2 (d2) expects
|
2 |
+
import argparse
|
3 |
+
from collections import OrderedDict
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def convert(src: str, dst: str):
|
8 |
+
checkpoint = torch.load(src)
|
9 |
+
has_model = "model" in checkpoint.keys()
|
10 |
+
checkpoint = checkpoint["model"] if has_model else checkpoint
|
11 |
+
if "state_dict" in checkpoint.keys():
|
12 |
+
checkpoint = checkpoint["state_dict"]
|
13 |
+
out_cp = OrderedDict()
|
14 |
+
for k, v in checkpoint.items():
|
15 |
+
out_cp[".".join(["backbone", k])] = v
|
16 |
+
out_cp = {"model": out_cp} if has_model else out_cp
|
17 |
+
torch.save(out_cp, dst)
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument(
|
23 |
+
"--src", "-s", type=str, required=True, help="Path to src weights.pth"
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--dst", "-d", type=str, required=True, help="Path to dst weights.pth"
|
27 |
+
)
|
28 |
+
args = parser.parse_args()
|
29 |
+
convert(
|
30 |
+
args.src,
|
31 |
+
args.dst,
|
32 |
+
)
|
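The converter is typically run as below; both paths are placeholders.

python scripts/convert_pth.py --src /path/to/official_vit_weights.pth --dst /path/to/converted_backbone.pth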
scripts/gen_gazefollow_head_masks.py
ADDED
@@ -0,0 +1,185 @@
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import tqdm
|
8 |
+
from PIL import Image
|
9 |
+
from retinaface.pre_trained_models import get_model
|
10 |
+
|
11 |
+
|
12 |
+
random.seed(1)
|
13 |
+
|
14 |
+
|
15 |
+
def gaussian(x_min, y_min, x_max, y_max):
|
16 |
+
x_min, x_max = sorted((x_min, x_max))
|
17 |
+
y_min, y_max = sorted((y_min, y_max))
|
18 |
+
x_mid, y_mid = (x_min + x_max) / 2, (y_min + y_max) / 2
|
19 |
+
x_sigma2, y_sigma2 = (
|
20 |
+
np.clip(np.square([x_max - x_min, y_max - y_min], dtype=float), 1, None) / 3
|
21 |
+
)
|
22 |
+
|
23 |
+
def _gaussian(_xs, _ys):
|
24 |
+
return np.exp(
|
25 |
+
-(np.square(_xs - x_mid) / x_sigma2 + np.square(_ys - y_mid) / y_sigma2)
|
26 |
+
)
|
27 |
+
|
28 |
+
return _gaussian
|
29 |
+
|
30 |
+
|
31 |
+
def plot_ori(label_path, data_dir):
|
32 |
+
df = pd.read_csv(
|
33 |
+
label_path,
|
34 |
+
names=[
|
35 |
+
# Original labels
|
36 |
+
"path",
|
37 |
+
"idx",
|
38 |
+
"body_bbox_x",
|
39 |
+
"body_bbox_y",
|
40 |
+
"body_bbox_w",
|
41 |
+
"body_bbox_h",
|
42 |
+
"eye_x",
|
43 |
+
"eye_y",
|
44 |
+
"gaze_x",
|
45 |
+
"gaze_y",
|
46 |
+
"x_min",
|
47 |
+
"y_min",
|
48 |
+
"x_max",
|
49 |
+
"y_max",
|
50 |
+
"inout",
|
51 |
+
"meta0",
|
52 |
+
"meta1",
|
53 |
+
],
|
54 |
+
index_col=False,
|
55 |
+
encoding="utf-8-sig",
|
56 |
+
)
|
57 |
+
grouped = df.groupby("path")
|
58 |
+
|
59 |
+
output_dir = os.path.join(data_dir, "head_masks")
|
60 |
+
|
61 |
+
for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with annotations: "):
|
62 |
+
if not os.path.exists(os.path.join(output_dir, image_name)):
|
63 |
+
w, h = Image.open(image_name).size
|
64 |
+
heatmap = np.zeros((h, w), dtype=np.float32)
|
65 |
+
indices = np.meshgrid(
|
66 |
+
np.linspace(0.0, float(w), num=w, endpoint=False),
|
67 |
+
np.linspace(0.0, float(h), num=h, endpoint=False),
|
68 |
+
)
|
69 |
+
for _, row in group_df.iterrows():
|
70 |
+
x_min, y_min, x_max, y_max = (
|
71 |
+
row["x_min"],
|
72 |
+
row["y_min"],
|
73 |
+
row["x_max"],
|
74 |
+
row["y_max"],
|
75 |
+
)
|
76 |
+
gauss = gaussian(x_min, y_min, x_max, y_max)
|
77 |
+
heatmap += gauss(*indices)
|
78 |
+
heatmap /= np.max(heatmap)
|
79 |
+
heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
|
80 |
+
output_filename = os.path.join(output_dir, image_name)
|
81 |
+
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
|
82 |
+
heatmap_image.save(output_filename)
|
83 |
+
|
84 |
+
|
85 |
+
def plot_gen(df, data_dir):
|
86 |
+
df = df[df["score"] > 0.8]
|
87 |
+
grouped = df.groupby("path")
|
88 |
+
|
89 |
+
output_dir = os.path.join(data_dir, "head_masks")
|
90 |
+
|
91 |
+
for image_name, group_df in tqdm.tqdm(grouped, desc="Generating masks with predictions: "):
|
92 |
+
w, h = Image.open(image_name).size
|
93 |
+
heatmap = np.zeros((h, w), dtype=np.float32)
|
94 |
+
indices = np.meshgrid(
|
95 |
+
np.linspace(0.0, float(w), num=w, endpoint=False),
|
96 |
+
np.linspace(0.0, float(h), num=h, endpoint=False),
|
97 |
+
)
|
98 |
+
for index, row in group_df.iterrows():
|
99 |
+
x_min, y_min, x_max, y_max = (
|
100 |
+
row["x_min"],
|
101 |
+
row["y_min"],
|
102 |
+
row["x_max"],
|
103 |
+
row["y_max"],
|
104 |
+
)
|
105 |
+
gauss = gaussian(x_min, y_min, x_max, y_max)
|
106 |
+
heatmap += gauss(*indices)
|
107 |
+
heatmap /= np.max(heatmap)
|
108 |
+
heatmap_image = Image.fromarray((heatmap * 255).astype(np.uint8), mode="L")
|
109 |
+
output_filename = os.path.join(output_dir, image_name)
|
110 |
+
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
|
111 |
+
heatmap_image.save(output_filename)
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
parser = argparse.ArgumentParser()
|
116 |
+
parser.add_argument("--dataset_dir", help="Root directory of dataset")
|
117 |
+
parser.add_argument(
|
118 |
+
"--subset",
|
119 |
+
help="Subset of dataset to process",
|
120 |
+
choices=["train", "test"],
|
121 |
+
)
|
122 |
+
args = parser.parse_args()
|
123 |
+
|
124 |
+
label_path = os.path.join(
|
125 |
+
args.dataset_dir, args.subset + "_annotations_release.txt"
|
126 |
+
)
|
127 |
+
|
128 |
+
column_names = [
|
129 |
+
"path",
|
130 |
+
"idx",
|
131 |
+
"body_bbox_x",
|
132 |
+
"body_bbox_y",
|
133 |
+
"body_bbox_w",
|
134 |
+
"body_bbox_h",
|
135 |
+
"eye_x",
|
136 |
+
"eye_y",
|
137 |
+
"gaze_x",
|
138 |
+
"gaze_y",
|
139 |
+
"bbox_x_min",
|
140 |
+
"bbox_y_min",
|
141 |
+
"bbox_x_max",
|
142 |
+
"bbox_y_max",
|
143 |
+
]
|
144 |
+
df = pd.read_csv(
|
145 |
+
label_path,
|
146 |
+
sep=",",
|
147 |
+
names=column_names,
|
148 |
+
usecols=column_names,
|
149 |
+
index_col=False,
|
150 |
+
)
|
151 |
+
df = df.groupby("path")
|
152 |
+
|
153 |
+
model = get_model("resnet50_2020-07-20", max_size=2048, device="cuda")
|
154 |
+
model.eval()
|
155 |
+
|
156 |
+
paths = list(df.groups.keys())
|
157 |
+
csv = []
|
158 |
+
for path in tqdm.tqdm(paths, desc="Predicting head bboxes: "):
|
159 |
+
img = cv2.imread(os.path.join(args.dataset_dir, path))
|
160 |
+
|
161 |
+
annotations = model.predict_jsons(img)
|
162 |
+
|
163 |
+
for annotation in annotations:
|
164 |
+
if len(annotation["bbox"]) == 0:
|
165 |
+
continue
|
166 |
+
|
167 |
+
csv.append(
|
168 |
+
[
|
169 |
+
path,
|
170 |
+
annotation["score"],
|
171 |
+
annotation["bbox"][0],
|
172 |
+
annotation["bbox"][1],
|
173 |
+
annotation["bbox"][2],
|
174 |
+
annotation["bbox"][3],
|
175 |
+
]
|
176 |
+
)
|
177 |
+
|
178 |
+
# Write csv
|
179 |
+
df = pd.DataFrame(
|
180 |
+
csv, columns=["path", "score", "x_min", "y_min", "x_max", "y_max"]
|
181 |
+
)
|
182 |
+
df.to_csv(os.path.join(args.dataset_dir, f"{args.subset}_head.csv"), index=False)
|
183 |
+
|
184 |
+
plot_gen(df, args.dataset_dir)
|
185 |
+
plot_ori(label_path, args.dataset_dir)
|
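A typical invocation of the mask-generation script; the dataset root is a placeholder.

python scripts/gen_gazefollow_head_masks.py --dataset_dir /path/to/gazefollow --subset train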
tools/eval_on_gazefollow.py
ADDED
@@ -0,0 +1,92 @@
|
|
1 |
+
import sys
|
2 |
+
from os import path as osp
|
3 |
+
import argparse
|
4 |
+
import warnings
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from detectron2.config import instantiate, LazyConfig
|
9 |
+
|
10 |
+
sys.path.append(osp.dirname(osp.dirname(__file__)))
|
11 |
+
from utils import *
|
12 |
+
|
13 |
+
|
14 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
15 |
+
|
16 |
+
|
17 |
+
def do_test(cfg, model, use_dark_inference=False):
|
18 |
+
val_loader = instantiate(cfg.dataloader.val)
|
19 |
+
|
20 |
+
model.train(False)
|
21 |
+
AUC = []
|
22 |
+
min_dist = []
|
23 |
+
avg_dist = []
|
24 |
+
with torch.no_grad():
|
25 |
+
for data in val_loader:
|
26 |
+
val_gaze_heatmap_pred, _ = model(data)
|
27 |
+
val_gaze_heatmap_pred = (
|
28 |
+
val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
|
29 |
+
)
|
30 |
+
|
31 |
+
# go through each data point and record AUC, min dist, avg dist
|
32 |
+
for b_i in range(len(val_gaze_heatmap_pred)):
|
33 |
+
# remove padding and recover valid ground truth points
|
34 |
+
valid_gaze = data["gazes"][b_i]
|
35 |
+
valid_gaze = valid_gaze[valid_gaze != -1].view(-1, 2)
|
36 |
+
# AUC: area under curve of ROC
|
37 |
+
multi_hot = multi_hot_targets(data["gazes"][b_i], data["imsize"][b_i])
|
38 |
+
if use_dark_inference:
|
39 |
+
pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
|
40 |
+
else:
|
41 |
+
pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
|
42 |
+
norm_p = [
|
43 |
+
pred_x / val_gaze_heatmap_pred[b_i].shape[-2],
|
44 |
+
pred_y / val_gaze_heatmap_pred[b_i].shape[-1],
|
45 |
+
]
|
46 |
+
scaled_heatmap = np.array(
|
47 |
+
Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
|
48 |
+
data["imsize"][b_i],
|
49 |
+
resample=Image.BILINEAR,
|
50 |
+
)
|
51 |
+
)
|
52 |
+
auc_score = auc(scaled_heatmap, multi_hot)
|
53 |
+
AUC.append(auc_score)
|
54 |
+
# min distance: minimum among all possible pairs of <ground truth point, predicted point>
|
55 |
+
all_distances = []
|
56 |
+
for gt_gaze in valid_gaze:
|
57 |
+
all_distances.append(L2_dist(gt_gaze, norm_p))
|
58 |
+
min_dist.append(min(all_distances))
|
59 |
+
# average distance: distance between the predicted point and human average point
|
60 |
+
mean_gt_gaze = torch.mean(valid_gaze, 0)
|
61 |
+
avg_distance = L2_dist(mean_gt_gaze, norm_p)
|
62 |
+
avg_dist.append(avg_distance)
|
63 |
+
|
64 |
+
print("|AUC |min dist|avg dist|")
|
65 |
+
print(
|
66 |
+
"|{:.4f}|{:.4f} |{:.4f} |".format(
|
67 |
+
torch.mean(torch.tensor(AUC)),
|
68 |
+
torch.mean(torch.tensor(min_dist)),
|
69 |
+
torch.mean(torch.tensor(avg_dist)),
|
70 |
+
)
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def main(args):
|
75 |
+
cfg = LazyConfig.load(args.config_file)
|
76 |
+
model: torch.nn.Module = instantiate(cfg.model)
|
77 |
+
model.load_state_dict(torch.load(args.model_weights)["model"])
|
78 |
+
model.to(cfg.train.device)
|
79 |
+
do_test(cfg, model, use_dark_inference=args.use_dark_inference)
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser()
|
84 |
+
parser.add_argument("--config_file", type=str, help="config file")
|
85 |
+
parser.add_argument(
|
86 |
+
"--model_weights",
|
87 |
+
type=str,
|
88 |
+
help="model weights",
|
89 |
+
)
|
90 |
+
parser.add_argument("--use_dark_inference", action="store_true")
|
91 |
+
args = parser.parse_args()
|
92 |
+
main(args)
|
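Example invocation with the GazeFollow config and checkpoint shipped in this commit; add --use_dark_inference to decode the heatmap peak with dark_inference instead of a plain argmax.

python tools/eval_on_gazefollow.py --config_file configs/gazefollow.py --model_weights pretrained/gazefollow.pth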
tools/eval_on_video_attention_target.py
ADDED
@@ -0,0 +1,93 @@
|
|
1 |
+
import sys
|
2 |
+
from os import path as osp
|
3 |
+
import argparse
|
4 |
+
import warnings
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from detectron2.config import instantiate, LazyConfig
|
9 |
+
|
10 |
+
sys.path.append(osp.dirname(osp.dirname(__file__)))
|
11 |
+
from utils import *
|
12 |
+
|
13 |
+
|
14 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
15 |
+
|
16 |
+
|
17 |
+
def do_test(cfg, model, use_dark_inference=False):
|
18 |
+
val_loader = instantiate(cfg.dataloader.val)
|
19 |
+
|
20 |
+
model.train(False)
|
21 |
+
AUC = []
|
22 |
+
dist = []
|
23 |
+
inout_gt = []
|
24 |
+
inout_pred = []
|
25 |
+
with torch.no_grad():
|
26 |
+
for data in val_loader:
|
27 |
+
val_gaze_heatmap_pred, val_gaze_inout_pred = model(data)
|
28 |
+
val_gaze_heatmap_pred = (
|
29 |
+
val_gaze_heatmap_pred.squeeze(1).cpu().detach().numpy()
|
30 |
+
)
|
31 |
+
val_gaze_inout_pred = val_gaze_inout_pred.cpu().detach().numpy()
|
32 |
+
|
33 |
+
# go through each data point and record AUC, dist, ap
|
34 |
+
for b_i in range(len(val_gaze_heatmap_pred)):
|
35 |
+
auc_batch = []
|
36 |
+
dist_batch = []
|
37 |
+
if data["gaze_inouts"][b_i]:
|
38 |
+
# remove padding and recover valid ground truth points
|
39 |
+
valid_gaze = data["gazes"][b_i]
|
40 |
+
# AUC: area under curve of ROC
|
41 |
+
multi_hot = data["heatmaps"][b_i]
|
42 |
+
multi_hot = (multi_hot > 0).float().numpy()
|
43 |
+
if use_dark_inference:
|
44 |
+
pred_x, pred_y = dark_inference(val_gaze_heatmap_pred[b_i])
|
45 |
+
else:
|
46 |
+
pred_x, pred_y = argmax_pts(val_gaze_heatmap_pred[b_i])
|
47 |
+
norm_p = [
|
48 |
+
pred_x / val_gaze_heatmap_pred[b_i].shape[-1],
|
49 |
+
pred_y / val_gaze_heatmap_pred[b_i].shape[-2],
|
50 |
+
]
|
51 |
+
scaled_heatmap = np.array(
|
52 |
+
Image.fromarray(val_gaze_heatmap_pred[b_i]).resize(
|
53 |
+
(64, 64),
|
54 |
+
resample=Image.Resampling.BILINEAR,
|
55 |
+
)
|
56 |
+
)
|
57 |
+
auc_score = auc(scaled_heatmap, multi_hot)
|
58 |
+
auc_batch.append(auc_score)
|
59 |
+
dist_batch.append(L2_dist(valid_gaze.numpy(), norm_p))
|
60 |
+
AUC.extend(auc_batch)
|
61 |
+
dist.extend(dist_batch)
|
62 |
+
inout_gt.extend(data["gaze_inouts"].cpu().numpy())
|
63 |
+
inout_pred.extend(val_gaze_inout_pred)
|
64 |
+
|
65 |
+
print("|AUC |dist |AP |")
|
66 |
+
print(
|
67 |
+
"|{:.4f}|{:.4f} |{:.4f} |".format(
|
68 |
+
torch.mean(torch.tensor(AUC)),
|
69 |
+
torch.mean(torch.tensor(dist)),
|
70 |
+
ap(inout_gt, inout_pred),
|
71 |
+
)
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def main(args):
|
76 |
+
cfg = LazyConfig.load(args.config_file)
|
77 |
+
model: torch.nn.Module = instantiate(cfg.model)
|
78 |
+
model.load_state_dict(torch.load(args.model_weights)["model"])
|
79 |
+
model.to(cfg.train.device)
|
80 |
+
do_test(cfg, model, use_dark_inference=args.use_dark_inference)
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == "__main__":
|
84 |
+
parser = argparse.ArgumentParser()
|
85 |
+
parser.add_argument("--config_file", type=str, help="config file")
|
86 |
+
parser.add_argument(
|
87 |
+
"--model_weights",
|
88 |
+
type=str,
|
89 |
+
help="model weights",
|
90 |
+
)
|
91 |
+
parser.add_argument("--use_dark_inference", action="store_true")
|
92 |
+
args = parser.parse_args()
|
93 |
+
main(args)
|
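Example invocation with the VideoAttentionTarget config and checkpoint from this commit:

python tools/eval_on_video_attention_target.py --config_file configs/videoattentiontarget.py --model_weights pretrained/videoattentiontarget.pth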
tools/train.py
ADDED
@@ -0,0 +1,104 @@
|
|
1 |
+
import os.path as osp
|
2 |
+
import logging
|
3 |
+
import warnings
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
7 |
+
from detectron2.config import LazyConfig, instantiate
|
8 |
+
from detectron2.engine import (
|
9 |
+
default_argument_parser,
|
10 |
+
default_setup,
|
11 |
+
default_writers,
|
12 |
+
hooks,
|
13 |
+
launch,
|
14 |
+
)
|
15 |
+
from detectron2.engine.defaults import create_ddp_model
|
16 |
+
from detectron2.utils import comm
|
17 |
+
|
18 |
+
|
19 |
+
sys.path.append(osp.dirname(osp.dirname(__file__)))
|
20 |
+
warnings.filterwarnings("ignore")
|
21 |
+
logger = logging.getLogger("detectron2")
|
22 |
+
|
23 |
+
|
24 |
+
from engine import CycleTrainer
|
25 |
+
|
26 |
+
|
27 |
+
def do_train(args, cfg):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
cfg: an object with the following attributes:
|
31 |
+
model: instantiate to a module
|
32 |
+
dataloader.{train,test}: instantiate to dataloaders
|
33 |
+
dataloader.evaluator: instantiate to evaluator for test set
|
34 |
+
optimizer: instantaite to an optimizer
|
35 |
+
lr_multiplier: instantiate to a fvcore scheduler
|
36 |
+
train: other misc config defined in `configs/common/train.py`, including:
|
37 |
+
output_dir (str)
|
38 |
+
init_checkpoint (str)
|
39 |
+
amp.enabled (bool)
|
40 |
+
max_iter (int)
|
41 |
+
eval_period, log_period (int)
|
42 |
+
device (str)
|
43 |
+
checkpointer (dict)
|
44 |
+
ddp (dict)
|
45 |
+
"""
|
46 |
+
model = instantiate(cfg.model)
|
47 |
+
logger = logging.getLogger("detectron2")
|
48 |
+
logger.info("Model:\n{}".format(model))
|
49 |
+
model.to(cfg.train.device)
|
50 |
+
|
51 |
+
cfg.optimizer.params.model = model
|
52 |
+
optim = instantiate(cfg.optimizer)
|
53 |
+
|
54 |
+
train_loader = instantiate(cfg.dataloader.train)
|
55 |
+
|
56 |
+
model = create_ddp_model(model, **cfg.train.ddp)
|
57 |
+
trainer = CycleTrainer(model, train_loader, optim)
|
58 |
+
checkpointer = DetectionCheckpointer(
|
59 |
+
model,
|
60 |
+
cfg.train.output_dir,
|
61 |
+
trainer=trainer,
|
62 |
+
)
|
63 |
+
trainer.register_hooks(
|
64 |
+
[
|
65 |
+
hooks.IterationTimer(),
|
66 |
+
hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
|
67 |
+
hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
|
68 |
+
if comm.is_main_process()
|
69 |
+
else None,
|
70 |
+
# hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
|
71 |
+
hooks.PeriodicWriter(
|
72 |
+
default_writers(cfg.train.output_dir, cfg.train.max_iter),
|
73 |
+
period=cfg.train.log_period,
|
74 |
+
)
|
75 |
+
if comm.is_main_process()
|
76 |
+
else None,
|
77 |
+
]
|
78 |
+
)
|
79 |
+
|
80 |
+
checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
|
81 |
+
if args.resume and checkpointer.has_checkpoint():
|
82 |
+
start_iter = trainer.iter + 1
|
83 |
+
else:
|
84 |
+
start_iter = 0
|
85 |
+
trainer.train(start_iter, cfg.train.max_iter)
|
86 |
+
|
87 |
+
|
88 |
+
def main(args):
|
89 |
+
cfg = LazyConfig.load(args.config_file)
|
90 |
+
cfg = LazyConfig.apply_overrides(cfg, args.opts)
|
91 |
+
default_setup(cfg, args)
|
92 |
+
do_train(args, cfg)
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
args = default_argument_parser().parse_args()
|
97 |
+
launch(
|
98 |
+
main,
|
99 |
+
args.num_gpus,
|
100 |
+
num_machines=args.num_machines,
|
101 |
+
machine_rank=args.machine_rank,
|
102 |
+
dist_url=args.dist_url,
|
103 |
+
args=(args,),
|
104 |
+
)
|
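Training is launched through detectron2's default argument parser; a single-GPU example (the config choice here is illustrative):

python tools/train.py --config-file configs/gazefollow.py --num-gpus 1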