geopavlakos committed on
Commit
d7a991a
1 Parent(s): 023c893

Initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +5 -5
  2. _DATA/data/mano/MANO_RIGHT.pkl +3 -0
  3. _DATA/data/mano_mean_params.npz +3 -0
  4. _DATA/hamer_ckpts/checkpoints/hamer.ckpt +3 -0
  5. _DATA/hamer_ckpts/dataset_config.yaml +42 -0
  6. _DATA/hamer_ckpts/model_config.yaml +111 -0
  7. _DATA/vitpose_ckpts/vitpose+_huge/wholebody.pth +3 -0
  8. app.py +234 -0
  9. assets/list.txt +0 -0
  10. assets/test1.jpg +0 -0
  11. assets/test2.jpg +0 -0
  12. assets/test3.jpg +0 -0
  13. assets/test4.jpg +0 -0
  14. assets/test5.jpg +0 -0
  15. hamer/__init__.py +0 -0
  16. hamer/configs/__init__.py +111 -0
  17. hamer/configs/cascade_mask_rcnn_vitdet_h_75ep.py +129 -0
  18. hamer/configs/datasets_tar.yaml +42 -0
  19. hamer/configs_hydra/data/mix_all.yaml +31 -0
  20. hamer/configs_hydra/data_filtering/low1.yaml +13 -0
  21. hamer/configs_hydra/experiment/default.yaml +29 -0
  22. hamer/configs_hydra/experiment/hamer_vit_transformer.yaml +51 -0
  23. hamer/configs_hydra/extras/default.yaml +8 -0
  24. hamer/configs_hydra/hydra/default.yaml +26 -0
  25. hamer/configs_hydra/launcher/local.yaml +13 -0
  26. hamer/configs_hydra/launcher/slurm.yaml +22 -0
  27. hamer/configs_hydra/paths/default.yaml +18 -0
  28. hamer/configs_hydra/train.yaml +47 -0
  29. hamer/configs_hydra/trainer/cpu.yaml +6 -0
  30. hamer/configs_hydra/trainer/ddp.yaml +14 -0
  31. hamer/configs_hydra/trainer/default.yaml +10 -0
  32. hamer/configs_hydra/trainer/default_hamer.yaml +8 -0
  33. hamer/configs_hydra/trainer/gpu.yaml +6 -0
  34. hamer/configs_hydra/trainer/mps.yaml +6 -0
  35. hamer/datasets/__init__.py +56 -0
  36. hamer/datasets/dataset.py +27 -0
  37. hamer/datasets/image_dataset.py +275 -0
  38. hamer/datasets/json_dataset.py +213 -0
  39. hamer/datasets/mocap_dataset.py +25 -0
  40. hamer/datasets/utils.py +993 -0
  41. hamer/datasets/vitdet_dataset.py +97 -0
  42. hamer/models/__init__.py +46 -0
  43. hamer/models/backbones/__init__.py +7 -0
  44. hamer/models/backbones/vit.py +348 -0
  45. hamer/models/components/__init__.py +0 -0
  46. hamer/models/components/pose_transformer.py +358 -0
  47. hamer/models/components/t_cond_mlp.py +199 -0
  48. hamer/models/discriminator.py +99 -0
  49. hamer/models/hamer.py +363 -0
  50. hamer/models/heads/__init__.py +1 -0
README.md CHANGED
@@ -1,12 +1,12 @@
 ---
-title: HaMeR Test
-emoji: 📚
-colorFrom: pink
-colorTo: purple
+title: HaMeR
+emoji: 🔥
+colorFrom: yellow
+colorTo: yellow
 sdk: gradio
 sdk_version: 4.8.0
 app_file: app.py
 pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_DATA/data/mano/MANO_RIGHT.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45d60aa3b27ef9107a7afd4e00808f307fd91111e1cfa35afd5c4a62de264767
3
+ size 3821356
_DATA/data/mano_mean_params.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
3
+ size 1178
_DATA/hamer_ckpts/checkpoints/hamer.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5cc06f294d88a92dee24e603480aab04de532b49f0e08200804ee7d90e16f53
3
+ size 2689536166
_DATA/hamer_ckpts/dataset_config.yaml ADDED
@@ -0,0 +1,42 @@
1
+ COCOW-TRAIN:
2
+ TYPE: ImageDataset
3
+ URLS: hamer_training_data/dataset_tars/cocow-train/{000000..000036}.tar
4
+ epoch_size: 78666
5
+ DEX-TRAIN:
6
+ TYPE: ImageDataset
7
+ URLS: hamer_training_data/dataset_tars/dex-train/{000000..000406}.tar
8
+ epoch_size: 406888
9
+ FREIHAND-MOCAP:
10
+ DATASET_FILE: hamer_training_data/freihand_mocap.npz
11
+ FREIHAND-TRAIN:
12
+ TYPE: ImageDataset
13
+ URLS: hamer_training_data/dataset_tars/freihand-train/{000000..000130}.tar
14
+ epoch_size: 130240
15
+ H2O3D-TRAIN:
16
+ TYPE: ImageDataset
17
+ URLS: hamer_training_data/dataset_tars/h2o3d-train/{000000..000060}.tar
18
+ epoch_size: 121996
19
+ HALPE-TRAIN:
20
+ TYPE: ImageDataset
21
+ URLS: hamer_training_data/dataset_tars/halpe-train/{000000..000022}.tar
22
+ epoch_size: 34289
23
+ HO3D-TRAIN:
24
+ TYPE: ImageDataset
25
+ URLS: hamer_training_data/dataset_tars/ho3d-train/{000000..000083}.tar
26
+ epoch_size: 83325
27
+ INTERHAND26M-TRAIN:
28
+ TYPE: ImageDataset
29
+ URLS: hamer_training_data/dataset_tars/interhand26m-train/{000000..001056}.tar
30
+ epoch_size: 1424632
31
+ MPIINZSL-TRAIN:
32
+ TYPE: ImageDataset
33
+ URLS: hamer_training_data/dataset_tars/mpiinzsl-train/{000000..000015}.tar
34
+ epoch_size: 15184
35
+ MTC-TRAIN:
36
+ TYPE: ImageDataset
37
+ URLS: hamer_training_data/dataset_tars/mtc-train/{000000..000306}.tar
38
+ epoch_size: 363947
39
+ RHD-TRAIN:
40
+ TYPE: ImageDataset
41
+ URLS: hamer_training_data/dataset_tars/rhd-train/{000000..000041}.tar
42
+ epoch_size: 61705
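The URLS entries above are brace-range shard patterns rather than literal paths. A minimal sketch of how such a pattern expands into individual tar shards, using the braceexpand package that hamer/datasets/image_dataset.py also relies on (the pattern below is copied from the COCOW-TRAIN entry):

```python
# Minimal sketch: expanding a brace-range URL from dataset_config.yaml into tar shards.
import braceexpand

pattern = "hamer_training_data/dataset_tars/cocow-train/{000000..000036}.tar"
shards = list(braceexpand.braceexpand(pattern))

print(len(shards))   # 37 shards
print(shards[0])     # hamer_training_data/dataset_tars/cocow-train/000000.tar
print(shards[-1])    # hamer_training_data/dataset_tars/cocow-train/000036.tar
```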
_DATA/hamer_ckpts/model_config.yaml ADDED
@@ -0,0 +1,111 @@
1
+ task_name: train
2
+ tags:
3
+ - dev
4
+ train: true
5
+ test: false
6
+ ckpt_path: null
7
+ seed: null
8
+ DATASETS:
9
+ TRAIN:
10
+ FREIHAND-TRAIN:
11
+ WEIGHT: 0.25
12
+ INTERHAND26M-TRAIN:
13
+ WEIGHT: 0.25
14
+ MTC-TRAIN:
15
+ WEIGHT: 0.1
16
+ RHD-TRAIN:
17
+ WEIGHT: 0.05
18
+ COCOW-TRAIN:
19
+ WEIGHT: 0.1
20
+ HALPE-TRAIN:
21
+ WEIGHT: 0.05
22
+ MPIINZSL-TRAIN:
23
+ WEIGHT: 0.05
24
+ HO3D-TRAIN:
25
+ WEIGHT: 0.05
26
+ H2O3D-TRAIN:
27
+ WEIGHT: 0.05
28
+ DEX-TRAIN:
29
+ WEIGHT: 0.05
30
+ VAL:
31
+ FREIHAND-TRAIN:
32
+ WEIGHT: 1.0
33
+ MOCAP: FREIHAND-MOCAP
34
+ BETAS_REG: true
35
+ CONFIG:
36
+ SCALE_FACTOR: 0.3
37
+ ROT_FACTOR: 30
38
+ TRANS_FACTOR: 0.02
39
+ COLOR_SCALE: 0.2
40
+ ROT_AUG_RATE: 0.6
41
+ TRANS_AUG_RATE: 0.5
42
+ DO_FLIP: false
43
+ FLIP_AUG_RATE: 0.0
44
+ EXTREME_CROP_AUG_RATE: 0.0
45
+ EXTREME_CROP_AUG_LEVEL: 1
46
+ extras:
47
+ ignore_warnings: false
48
+ enforce_tags: true
49
+ print_config: true
50
+ exp_name: hamer
51
+ MANO:
52
+ DATA_DIR: _DATA/data/
53
+ MODEL_PATH: _DATA/data/mano
54
+ GENDER: neutral
55
+ NUM_HAND_JOINTS: 15
56
+ MEAN_PARAMS: _DATA/data/mano_mean_params.npz
57
+ CREATE_BODY_POSE: false
58
+ EXTRA:
59
+ FOCAL_LENGTH: 5000
60
+ NUM_LOG_IMAGES: 4
61
+ NUM_LOG_SAMPLES_PER_IMAGE: 8
62
+ PELVIS_IND: 0
63
+ GENERAL:
64
+ TOTAL_STEPS: 1000000
65
+ LOG_STEPS: 1000
66
+ VAL_STEPS: 1000
67
+ CHECKPOINT_STEPS: 10000
68
+ CHECKPOINT_SAVE_TOP_K: 1
69
+ NUM_WORKERS: 8
70
+ PREFETCH_FACTOR: 2
71
+ TRAIN:
72
+ LR: 1.0e-05
73
+ WEIGHT_DECAY: 0.0001
74
+ BATCH_SIZE: 32
75
+ LOSS_REDUCTION: mean
76
+ NUM_TRAIN_SAMPLES: 2
77
+ NUM_TEST_SAMPLES: 64
78
+ POSE_2D_NOISE_RATIO: 0.01
79
+ SMPL_PARAM_NOISE_RATIO: 0.005
80
+ MODEL:
81
+ IMAGE_SIZE: 256
82
+ IMAGE_MEAN:
83
+ - 0.485
84
+ - 0.456
85
+ - 0.406
86
+ IMAGE_STD:
87
+ - 0.229
88
+ - 0.224
89
+ - 0.225
90
+ BACKBONE:
91
+ TYPE: vit
92
+ PRETRAINED_WEIGHTS: hamer_training_data/vitpose_backbone.pth
93
+ MANO_HEAD:
94
+ TYPE: transformer_decoder
95
+ IN_CHANNELS: 2048
96
+ TRANSFORMER_DECODER:
97
+ depth: 6
98
+ heads: 8
99
+ mlp_dim: 1024
100
+ dim_head: 64
101
+ dropout: 0.0
102
+ emb_dropout: 0.0
103
+ norm: layer
104
+ context_dim: 1280
105
+ LOSS_WEIGHTS:
106
+ KEYPOINTS_3D: 0.05
107
+ KEYPOINTS_2D: 0.01
108
+ GLOBAL_ORIENT: 0.001
109
+ HAND_POSE: 0.001
110
+ BETAS: 0.0005
111
+ ADVERSARIAL: 0.0005
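This exported model_config.yaml is what the demo reads back at startup. A minimal sketch of loading it with get_config from hamer/configs/__init__.py (added later in this commit), assuming the _DATA layout above is in place:

```python
# Minimal sketch: reading the exported model config back through the helper in
# hamer/configs/__init__.py. Assumes the _DATA layout created by this commit.
from hamer.configs import get_config

model_cfg = get_config("_DATA/hamer_ckpts/model_config.yaml")
print(model_cfg.MODEL.IMAGE_SIZE)      # 256
print(model_cfg.EXTRA.FOCAL_LENGTH)    # 5000
print(model_cfg.MODEL.MANO_HEAD.TYPE)  # transformer_decoder
```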
_DATA/vitpose_ckpts/vitpose+_huge/wholebody.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0555e1e2392e6a2be2d9265368f344d70ccbfd656ad480aa5c1de2e604519c9
3
+ size 3807742341
app.py ADDED
@@ -0,0 +1,234 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+ import tempfile
5
+ import sys
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
+ # print file path
13
+ print(os.path.abspath(__file__))
14
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
15
+ os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
16
+ os.system('pip install /home/user/app/pyrender')
17
+ sys.path.append('/home/user/app/pyrender')
18
+
19
+ from hamer.configs import get_config
20
+ from hamer.datasets.vitdet_dataset import (DEFAULT_MEAN, DEFAULT_STD,
21
+ ViTDetDataset)
22
+ from hamer.models import HAMER
23
+ from hamer.utils import recursive_to
24
+ from hamer.utils.renderer import Renderer, cam_crop_to_full
25
+
26
+ try:
27
+ import detectron2
28
+ except:
29
+ import os
30
+ os.system('pip install --upgrade pip')
31
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
32
+
33
+ #try:
34
+ # from vitpose_model import ViTPoseModel
35
+ #except:
36
+ # os.system('pip install -v -e /home/user/app/vendor/ViTPose')
37
+ # from vitpose_model import ViTPoseModel
38
+ from vitpose_model import ViTPoseModel
39
+
40
+ OUT_FOLDER = 'demo_out'
41
+ os.makedirs(OUT_FOLDER, exist_ok=True)
42
+
43
+ # Setup HaMeR model
44
+ LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353)
45
+ DEFAULT_CHECKPOINT='_DATA/hamer_ckpts/checkpoints/hamer.ckpt'
46
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
47
+ model_cfg = str(Path(DEFAULT_CHECKPOINT).parent.parent / 'model_config.yaml')
48
+ model_cfg = get_config(model_cfg)
49
+ model = HAMER.load_from_checkpoint(DEFAULT_CHECKPOINT, strict=False, cfg=model_cfg).to(device)
50
+ model.eval()
51
+
52
+
53
+ # Load detector
54
+ from detectron2.config import LazyConfig
55
+
56
+ from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy
57
+
58
+ detectron2_cfg = LazyConfig.load(f"vendor/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py")
59
+ detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
60
+ for i in range(3):
61
+ detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
62
+ detector = DefaultPredictor_Lazy(detectron2_cfg)
63
+
64
+ # Setup the renderer
65
+ renderer = Renderer(model_cfg, faces=model.mano.faces)
66
+
67
+ # keypoint detector
68
+ cpm = ViTPoseModel(device)
69
+
70
+ import numpy as np
71
+
72
+ def infer(in_pil_img, in_threshold=0.8, out_pil_img=None):
73
+
74
+ open_cv_image = np.array(in_pil_img)
75
+ # Convert RGB to BGR
76
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
77
+ print("EEEEE", open_cv_image.shape)
78
+ det_out = detector(open_cv_image)
79
+ det_instances = det_out['instances']
80
+ valid_idx = (det_instances.pred_classes==0) & (det_instances.scores > in_threshold)
81
+ pred_bboxes=det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
82
+ pred_scores=det_instances.scores[valid_idx].cpu().numpy()
83
+
84
+
85
+ # Detect human keypoints for each person
86
+ vitposes_out = cpm.predict_pose(
87
+ open_cv_image,
88
+ [np.concatenate([pred_bboxes, pred_scores[:, None]], axis=1)],
89
+ )
90
+
91
+ bboxes = []
92
+ is_right = []
93
+
94
+ # Use hands based on hand keypoint detections
95
+ for vitposes in vitposes_out:
96
+ left_hand_keyp = vitposes['keypoints'][-42:-21]
97
+ right_hand_keyp = vitposes['keypoints'][-21:]
98
+
99
+ # Rejecting not confident detections (this could be improved)
100
+ keyp = left_hand_keyp
101
+ valid = keyp[:,2] > 0.5
102
+ if sum(valid) > 3:
103
+ bbox = [keyp[valid,0].min(), keyp[valid,1].min(), keyp[valid,0].max(), keyp[valid,1].max()]
104
+ bboxes.append(bbox)
105
+ is_right.append(0)
106
+ keyp = right_hand_keyp
107
+ valid = keyp[:,2] > 0.5
108
+ if sum(valid) > 3:
109
+ bbox = [keyp[valid,0].min(), keyp[valid,1].min(), keyp[valid,0].max(), keyp[valid,1].max()]
110
+ bboxes.append(bbox)
111
+ is_right.append(1)
112
+
113
+ if len(bboxes) == 0:
114
+ return None, []
115
+
116
+ boxes = np.stack(bboxes)
117
+ right = np.stack(is_right)
118
+
119
+
120
+ # Run HaMeR on all detected humans
121
+ dataset = ViTDetDataset(model_cfg, open_cv_image, boxes, right)
122
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
123
+
124
+ all_verts = []
125
+ all_cam_t = []
126
+ all_right = []
127
+ all_mesh_paths = []
128
+
129
+ temp_name = next(tempfile._get_candidate_names())
130
+
131
+ for batch in dataloader:
132
+ batch = recursive_to(batch, device)
133
+ with torch.no_grad():
134
+ out = model(batch)
135
+
136
+ multiplier = (2*batch['right']-1)
137
+ pred_cam = out['pred_cam']
138
+ pred_cam[:,1] = multiplier*pred_cam[:,1]
139
+ box_center = batch["box_center"].float()
140
+ box_size = batch["box_size"].float()
141
+ img_size = batch["img_size"].float()
142
+ multiplier = (2*batch['right']-1)
143
+ render_size = img_size
144
+ scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
145
+ pred_cam_t = cam_crop_to_full(pred_cam, box_center, box_size, render_size, scaled_focal_length).detach().cpu().numpy()
146
+
147
+ # Render the result
148
+ batch_size = batch['img'].shape[0]
149
+ for n in range(batch_size):
150
+ # Get filename from path img_path
151
+ # img_fn, _ = os.path.splitext(os.path.basename(img_path))
152
+ person_id = int(batch['personid'][n])
153
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:,None,None]/255) / (DEFAULT_STD[:,None,None]/255)
154
+ input_patch = batch['img'][n].cpu() * (DEFAULT_STD[:,None,None]/255) + (DEFAULT_MEAN[:,None,None]/255)
155
+ input_patch = input_patch.permute(1,2,0).numpy()
156
+
157
+
158
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
159
+ is_right = batch['right'][n].cpu().numpy()
160
+ verts[:,0] = (2*is_right-1)*verts[:,0]
161
+ cam_t = pred_cam_t[n]
162
+
163
+ all_verts.append(verts)
164
+ all_cam_t.append(cam_t)
165
+ all_right.append(is_right)
166
+
167
+ # Save all meshes to disk
168
+ # if args.save_mesh:
169
+ if True:
170
+ camera_translation = cam_t.copy()
171
+ tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE, is_right=is_right)
172
+
173
+ temp_path = os.path.join(f'{OUT_FOLDER}/{temp_name}_{person_id}.obj')
174
+ tmesh.export(temp_path)
175
+ all_mesh_paths.append(temp_path)
176
+
177
+ # Render front view
178
+ if len(all_verts) > 0:
179
+ misc_args = dict(
180
+ mesh_base_color=LIGHT_BLUE,
181
+ scene_bg_color=(1, 1, 1),
182
+ focal_length=scaled_focal_length,
183
+ )
184
+ cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=render_size[n], is_right=all_right, **misc_args)
185
+
186
+ # Overlay image
187
+ input_img = open_cv_image.astype(np.float32)[:,:,::-1]/255.0
188
+ input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
189
+ input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
190
+
191
+ # convert to PIL image
192
+ out_pil_img = Image.fromarray((input_img_overlay*255).astype(np.uint8))
193
+
194
+ return out_pil_img, all_mesh_paths
195
+ else:
196
+ return None, []
197
+
198
+
199
+ with gr.Blocks(title="HaMeR", css=".gradio-container") as demo:
200
+
201
+ gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">HaMeR</div>""")
202
+
203
+ with gr.Row():
204
+ with gr.Column():
205
+ input_image = gr.Image(label="Input image", type="pil")
206
+ with gr.Column():
207
+ output_image = gr.Image(label="Reconstructions", type="pil")
208
+ output_meshes = gr.File(label="3D meshes")
209
+
210
+ gr.HTML("""<br/>""")
211
+
212
+ with gr.Row():
213
+ threshold = gr.Slider(0, 1.0, value=0.6, label='Detection Threshold')
214
+ send_btn = gr.Button("Infer")
215
+ send_btn.click(fn=infer, inputs=[input_image, threshold], outputs=[output_image, output_meshes])
216
+
217
+ # with gr.Row():
218
+ example_images = gr.Examples([
219
+ ['/home/user/app/assets/test1.jpg'],
220
+ ['/home/user/app/assets/test2.jpg'],
221
+ ['/home/user/app/assets/test3.jpg'],
222
+ ['/home/user/app/assets/test4.jpg'],
223
+ ['/home/user/app/assets/test5.jpg'],
224
+ ],
225
+ inputs=[input_image, 0.6])
226
+
227
+
228
+ #demo.queue()
229
+ demo.launch(debug=True)
230
+
231
+
232
+
233
+
234
+ ### EOF ###
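The Gradio UI above is a thin wrapper around infer(). A minimal sketch of calling it directly, assuming the _DATA checkpoints are present and the blocking demo.launch(debug=True) call at the bottom is temporarily disabled so the module can be imported:

```python
# Minimal sketch: driving infer() without the Gradio UI. Assumes app.py's module-level
# setup (model, detector, renderer) runs on import and demo.launch() is disabled.
from PIL import Image
from app import infer

in_img = Image.open("assets/test1.jpg").convert("RGB")
overlay, mesh_paths = infer(in_img, in_threshold=0.6)

if overlay is None:
    print("no hands detected above the threshold")
else:
    overlay.save("demo_out/test1_overlay.png")
    print("meshes written to:", mesh_paths)
```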
assets/list.txt ADDED
File without changes
assets/test1.jpg ADDED
assets/test2.jpg ADDED
assets/test3.jpg ADDED
assets/test4.jpg ADDED
assets/test5.jpg ADDED
hamer/__init__.py ADDED
File without changes
hamer/configs/__init__.py ADDED
@@ -0,0 +1,111 @@
1
+ import os
2
+ from typing import Dict
3
+ from yacs.config import CfgNode as CN
4
+
5
+ CACHE_DIR_HAMER = "./_DATA"
6
+
7
+ def to_lower(x: Dict) -> Dict:
8
+ """
9
+ Convert all dictionary keys to lowercase
10
+ Args:
11
+ x (dict): Input dictionary
12
+ Returns:
13
+ dict: Output dictionary with all keys converted to lowercase
14
+ """
15
+ return {k.lower(): v for k, v in x.items()}
16
+
17
+ _C = CN(new_allowed=True)
18
+
19
+ _C.GENERAL = CN(new_allowed=True)
20
+ _C.GENERAL.RESUME = True
21
+ _C.GENERAL.TIME_TO_RUN = 3300
22
+ _C.GENERAL.VAL_STEPS = 100
23
+ _C.GENERAL.LOG_STEPS = 100
24
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
25
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
26
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
27
+ _C.GENERAL.NUM_GPUS = 1
28
+ _C.GENERAL.NUM_WORKERS = 4
29
+ _C.GENERAL.MIXED_PRECISION = True
30
+ _C.GENERAL.ALLOW_CUDA = True
31
+ _C.GENERAL.PIN_MEMORY = False
32
+ _C.GENERAL.DISTRIBUTED = False
33
+ _C.GENERAL.LOCAL_RANK = 0
34
+ _C.GENERAL.USE_SYNCBN = False
35
+ _C.GENERAL.WORLD_SIZE = 1
36
+
37
+ _C.TRAIN = CN(new_allowed=True)
38
+ _C.TRAIN.NUM_EPOCHS = 100
39
+ _C.TRAIN.BATCH_SIZE = 32
40
+ _C.TRAIN.SHUFFLE = True
41
+ _C.TRAIN.WARMUP = False
42
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
43
+ _C.TRAIN.CLIP_GRAD = False
44
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
45
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
46
+
47
+ _C.DATASETS = CN(new_allowed=True)
48
+
49
+ _C.MODEL = CN(new_allowed=True)
50
+ _C.MODEL.IMAGE_SIZE = 224
51
+
52
+ _C.EXTRA = CN(new_allowed=True)
53
+ _C.EXTRA.FOCAL_LENGTH = 5000
54
+
55
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
56
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
57
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
58
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
59
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
60
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
61
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
62
+ _C.DATASETS.CONFIG.DO_FLIP = False
63
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
64
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
65
+
66
+ def default_config() -> CN:
67
+ """
68
+ Get a yacs CfgNode object with the default config values.
69
+ """
70
+ # Return a clone so that the defaults will not be altered
71
+ # This is for the "local variable" use pattern
72
+ return _C.clone()
73
+
74
+ def dataset_config() -> CN:
75
+ """
76
+ Get dataset config file
77
+ Returns:
78
+ CfgNode: Dataset config as a yacs CfgNode object.
79
+ """
80
+ cfg = CN(new_allowed=True)
81
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
82
+ cfg.merge_from_file(config_file)
83
+ cfg.freeze()
84
+ return cfg
85
+
86
+ def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
87
+ """
88
+ Read a config file and optionally merge it with the default config file.
89
+ Args:
90
+ config_file (str): Path to config file.
91
+ merge (bool): Whether to merge with the default config or not.
92
+ Returns:
93
+ CfgNode: Config as a yacs CfgNode object.
94
+ """
95
+ if merge:
96
+ cfg = default_config()
97
+ else:
98
+ cfg = CN(new_allowed=True)
99
+ cfg.merge_from_file(config_file)
100
+
101
+ if update_cachedir:
102
+ def update_path(path: str) -> str:
103
+ if os.path.isabs(path):
104
+ return path
105
+ return os.path.join(CACHE_DIR_HAMER, path)
106
+
107
+ cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
108
+ cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
109
+
110
+ cfg.freeze()
111
+ return cfg
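A minimal sketch of the two config helpers defined above, assuming the packaged datasets_tar.yaml next to this file:

```python
# Minimal sketch: using the config helpers from hamer/configs/__init__.py.
from hamer.configs import default_config, dataset_config

cfg = default_config()                  # clone of the yacs defaults; edits stay local
print(cfg.MODEL.IMAGE_SIZE)             # 224 until a model_config.yaml overrides it

ds_cfg = dataset_config()               # parses hamer/configs/datasets_tar.yaml
print(ds_cfg["FREIHAND-TRAIN"].URLS)    # brace-range tar pattern for the FreiHAND shards
print(ds_cfg["FREIHAND-TRAIN"].epoch_size)
```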
hamer/configs/cascade_mask_rcnn_vitdet_h_75ep.py ADDED
@@ -0,0 +1,129 @@
1
+ ## coco_loader_lsj.py
2
+
3
+ import detectron2.data.transforms as T
4
+ from detectron2 import model_zoo
5
+ from detectron2.config import LazyCall as L
6
+
7
+ # Data using LSJ
8
+ image_size = 1024
9
+ dataloader = model_zoo.get_config("common/data/coco.py").dataloader
10
+ dataloader.train.mapper.augmentations = [
11
+ L(T.RandomFlip)(horizontal=True), # flip first
12
+ L(T.ResizeScale)(
13
+ min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size
14
+ ),
15
+ L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False),
16
+ ]
17
+ dataloader.train.mapper.image_format = "RGB"
18
+ dataloader.train.total_batch_size = 64
19
+ # recompute boxes due to cropping
20
+ dataloader.train.mapper.recompute_boxes = True
21
+
22
+ dataloader.test.mapper.augmentations = [
23
+ L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size),
24
+ ]
25
+
26
+ from functools import partial
27
+ from fvcore.common.param_scheduler import MultiStepParamScheduler
28
+
29
+ from detectron2 import model_zoo
30
+ from detectron2.config import LazyCall as L
31
+ from detectron2.solver import WarmupParamScheduler
32
+ from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate
33
+
34
+ # mask_rcnn_vitdet_b_100ep.py
35
+
36
+ model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model
37
+
38
+ # Initialization and trainer settings
39
+ train = model_zoo.get_config("common/train.py").train
40
+ train.amp.enabled = True
41
+ train.ddp.fp16_compression = True
42
+ train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
43
+
44
+
45
+ # Schedule
46
+ # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep
47
+ train.max_iter = 184375
48
+
49
+ lr_multiplier = L(WarmupParamScheduler)(
50
+ scheduler=L(MultiStepParamScheduler)(
51
+ values=[1.0, 0.1, 0.01],
52
+ milestones=[163889, 177546],
53
+ num_updates=train.max_iter,
54
+ ),
55
+ warmup_length=250 / train.max_iter,
56
+ warmup_factor=0.001,
57
+ )
58
+
59
+ # Optimizer
60
+ optimizer = model_zoo.get_config("common/optim.py").AdamW
61
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7)
62
+ optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
63
+
64
+ # cascade_mask_rcnn_vitdet_b_100ep.py
65
+
66
+ from detectron2.config import LazyCall as L
67
+ from detectron2.layers import ShapeSpec
68
+ from detectron2.modeling.box_regression import Box2BoxTransform
69
+ from detectron2.modeling.matcher import Matcher
70
+ from detectron2.modeling.roi_heads import (
71
+ FastRCNNOutputLayers,
72
+ FastRCNNConvFCHead,
73
+ CascadeROIHeads,
74
+ )
75
+
76
+ # arguments that don't exist for Cascade R-CNN
77
+ [model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]]
78
+
79
+ model.roi_heads.update(
80
+ _target_=CascadeROIHeads,
81
+ box_heads=[
82
+ L(FastRCNNConvFCHead)(
83
+ input_shape=ShapeSpec(channels=256, height=7, width=7),
84
+ conv_dims=[256, 256, 256, 256],
85
+ fc_dims=[1024],
86
+ conv_norm="LN",
87
+ )
88
+ for _ in range(3)
89
+ ],
90
+ box_predictors=[
91
+ L(FastRCNNOutputLayers)(
92
+ input_shape=ShapeSpec(channels=1024),
93
+ test_score_thresh=0.05,
94
+ box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)),
95
+ cls_agnostic_bbox_reg=True,
96
+ num_classes="${...num_classes}",
97
+ )
98
+ for (w1, w2) in [(10, 5), (20, 10), (30, 15)]
99
+ ],
100
+ proposal_matchers=[
101
+ L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False)
102
+ for th in [0.5, 0.6, 0.7]
103
+ ],
104
+ )
105
+
106
+ # cascade_mask_rcnn_vitdet_h_75ep.py
107
+
108
+ from functools import partial
109
+
110
+ train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
111
+
112
+ model.backbone.net.embed_dim = 1280
113
+ model.backbone.net.depth = 32
114
+ model.backbone.net.num_heads = 16
115
+ model.backbone.net.drop_path_rate = 0.5
116
+ # 7, 15, 23, 31 for global attention
117
+ model.backbone.net.window_block_indexes = (
118
+ list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31))
119
+ )
120
+
121
+ optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32)
122
+ optimizer.params.overrides = {}
123
+ optimizer.params.weight_decay_norm = None
124
+
125
+ train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep
126
+ lr_multiplier.scheduler.milestones = [
127
+ milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones
128
+ ]
129
+ lr_multiplier.scheduler.num_updates = train.max_iter
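app.py consumes this file as a detectron2 LazyConfig. A minimal sketch of loading it and lowering the per-stage score thresholds the same way the demo does, assuming detectron2 and its model-zoo configs are installed:

```python
# Minimal sketch: loading this LazyConfig and relaxing the cascade score thresholds,
# mirroring the detector setup in app.py.
from detectron2.config import LazyConfig
from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy

cfg = LazyConfig.load("hamer/configs/cascade_mask_rcnn_vitdet_h_75ep.py")
cfg.train.init_checkpoint = (
    "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/"
    "cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
)
for box_predictor in cfg.model.roi_heads.box_predictors:
    box_predictor.test_score_thresh = 0.25

detector = DefaultPredictor_Lazy(cfg)
```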
hamer/configs/datasets_tar.yaml ADDED
@@ -0,0 +1,42 @@
1
+ FREIHAND-TRAIN:
2
+ TYPE: ImageDataset
3
+ URLS: hamer_training_data/dataset_tars/freihand-train/{000000..000130}.tar
4
+ epoch_size: 130_240
5
+ INTERHAND26M-TRAIN:
6
+ TYPE: ImageDataset
7
+ URLS: hamer_training_data/dataset_tars/interhand26m-train/{000000..001056}.tar
8
+ epoch_size: 1_424_632
9
+ HALPE-TRAIN:
10
+ TYPE: ImageDataset
11
+ URLS: hamer_training_data/dataset_tars/halpe-train/{000000..000022}.tar
12
+ epoch_size: 34_289
13
+ COCOW-TRAIN:
14
+ TYPE: ImageDataset
15
+ URLS: hamer_training_data/dataset_tars/cocow-train/{000000..000036}.tar
16
+ epoch_size: 78_666
17
+ MTC-TRAIN:
18
+ TYPE: ImageDataset
19
+ URLS: hamer_training_data/dataset_tars/mtc-train/{000000..000306}.tar
20
+ epoch_size: 363_947
21
+ RHD-TRAIN:
22
+ TYPE: ImageDataset
23
+ URLS: hamer_training_data/dataset_tars/rhd-train/{000000..000041}.tar
24
+ epoch_size: 61_705
25
+ MPIINZSL-TRAIN:
26
+ TYPE: ImageDataset
27
+ URLS: hamer_training_data/dataset_tars/mpiinzsl-train/{000000..000015}.tar
28
+ epoch_size: 15_184
29
+ HO3D-TRAIN:
30
+ TYPE: ImageDataset
31
+ URLS: hamer_training_data/dataset_tars/ho3d-train/{000000..000083}.tar
32
+ epoch_size: 83_325
33
+ H2O3D-TRAIN:
34
+ TYPE: ImageDataset
35
+ URLS: hamer_training_data/dataset_tars/h2o3d-train/{000000..000060}.tar
36
+ epoch_size: 121_996
37
+ DEX-TRAIN:
38
+ TYPE: ImageDataset
39
+ URLS: hamer_training_data/dataset_tars/dex-train/{000000..000406}.tar
40
+ epoch_size: 406_888
41
+ FREIHAND-MOCAP:
42
+ DATASET_FILE: hamer_training_data/freihand_mocap.npz
hamer/configs_hydra/data/mix_all.yaml ADDED
@@ -0,0 +1,31 @@
1
+ # @package _global_
2
+ defaults:
3
+ - /data_filtering: low1
4
+
5
+ DATASETS:
6
+ TRAIN:
7
+ FREIHAND-TRAIN:
8
+ WEIGHT: 0.25
9
+ INTERHAND26M-TRAIN:
10
+ WEIGHT: 0.25
11
+ MTC-TRAIN:
12
+ WEIGHT: 0.1
13
+ RHD-TRAIN:
14
+ WEIGHT: 0.05
15
+ COCOW-TRAIN:
16
+ WEIGHT: 0.1
17
+ HALPE-TRAIN:
18
+ WEIGHT: 0.05
19
+ MPIINZSL-TRAIN:
20
+ WEIGHT: 0.05
21
+ HO3D-TRAIN:
22
+ WEIGHT: 0.05
23
+ H2O3D-TRAIN:
24
+ WEIGHT: 0.05
25
+ DEX-TRAIN:
26
+ WEIGHT: 0.05
27
+ VAL:
28
+ FREIHAND-TRAIN:
29
+ WEIGHT: 1.0
30
+
31
+ MOCAP: FREIHAND-MOCAP
hamer/configs_hydra/data_filtering/low1.yaml ADDED
@@ -0,0 +1,13 @@
1
+ # @package _global_
2
+
3
+ DATASETS:
4
+ # Data filtering during training
5
+ SUPPRESS_KP_CONF_THRESH: 0.3
6
+ FILTER_NUM_KP: 4
7
+ FILTER_NUM_KP_THRESH: 0.0
8
+ FILTER_REPROJ_THRESH: 31000
9
+
10
+ SUPPRESS_BETAS_THRESH: 3.0
11
+ SUPPRESS_BAD_POSES: False
12
+ POSES_BETAS_SIMULTANEOUS: True
13
+ FILTER_NO_POSES: False # If True, filters images that don't have poses
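These thresholds feed the webdataset filters in hamer/datasets/image_dataset.py later in this commit (FILTER_REPROJ_THRESH: 31000 is roughly the 10**4.5 ≈ 31623 default of filter_reproj_error). A minimal sketch of the keypoint-side filtering they control; the sample values are illustrative, not a real tar record:

```python
# Minimal sketch: how the low1.yaml thresholds are applied in image_dataset.py.
import numpy as np

SUPPRESS_KP_CONF_THRESH = 0.3  # from low1.yaml
FILTER_NUM_KP = 4
FILTER_NUM_KP_THRESH = 0.0

keypoints_2d = np.random.rand(21, 3)  # x, y, confidence for 21 hand keypoints (illustrative)

# suppress_bad_kps: zero out low-confidence keypoints instead of dropping the sample
conf = np.where(keypoints_2d[:, 2] < SUPPRESS_KP_CONF_THRESH, 0.0, keypoints_2d[:, 2])

# filter_numkp: keep the sample only if enough keypoints survive
keep = (conf > FILTER_NUM_KP_THRESH).sum() > FILTER_NUM_KP
print("keep sample:", keep)
```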
hamer/configs_hydra/experiment/default.yaml ADDED
@@ -0,0 +1,29 @@
1
+ # @package _global_
2
+
3
+ MANO:
4
+ DATA_DIR: ${oc.env:HOME}/.cache/4DHumans/data/
5
+ MODEL_PATH: ${MANO.DATA_DIR}/mano
6
+ GENDER: neutral
7
+ NUM_HAND_JOINTS: 15
8
+ MEAN_PARAMS: ${MANO.DATA_DIR}/mano_mean_params.npz
9
+ CREATE_BODY_POSE: FALSE
10
+
11
+ EXTRA:
12
+ FOCAL_LENGTH: 5000
13
+ NUM_LOG_IMAGES: 4
14
+ NUM_LOG_SAMPLES_PER_IMAGE: 8
15
+ PELVIS_IND: 0
16
+
17
+ DATASETS:
18
+ BETAS_REG: True
19
+ CONFIG:
20
+ SCALE_FACTOR: 0.3
21
+ ROT_FACTOR: 30
22
+ TRANS_FACTOR: 0.02
23
+ COLOR_SCALE: 0.2
24
+ ROT_AUG_RATE: 0.6
25
+ TRANS_AUG_RATE: 0.5
26
+ DO_FLIP: False
27
+ FLIP_AUG_RATE: 0.0
28
+ EXTREME_CROP_AUG_RATE: 0.0
29
+ EXTREME_CROP_AUG_LEVEL: 1
hamer/configs_hydra/experiment/hamer_vit_transformer.yaml ADDED
@@ -0,0 +1,51 @@
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - default.yaml
5
+
6
+ GENERAL:
7
+ TOTAL_STEPS: 1_000_000
8
+ LOG_STEPS: 1000
9
+ VAL_STEPS: 1000
10
+ CHECKPOINT_STEPS: 1000
11
+ CHECKPOINT_SAVE_TOP_K: 1
12
+ NUM_WORKERS: 25
13
+ PREFETCH_FACTOR: 2
14
+
15
+ TRAIN:
16
+ LR: 1e-5
17
+ WEIGHT_DECAY: 1e-4
18
+ BATCH_SIZE: 8
19
+ LOSS_REDUCTION: mean
20
+ NUM_TRAIN_SAMPLES: 2
21
+ NUM_TEST_SAMPLES: 64
22
+ POSE_2D_NOISE_RATIO: 0.01
23
+ SMPL_PARAM_NOISE_RATIO: 0.005
24
+
25
+ MODEL:
26
+ IMAGE_SIZE: 256
27
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
28
+ IMAGE_STD: [0.229, 0.224, 0.225]
29
+ BACKBONE:
30
+ TYPE: vit
31
+ PRETRAINED_WEIGHTS: hamer_training_data/vitpose_backbone.pth
32
+ MANO_HEAD:
33
+ TYPE: transformer_decoder
34
+ IN_CHANNELS: 2048
35
+ TRANSFORMER_DECODER:
36
+ depth: 6
37
+ heads: 8
38
+ mlp_dim: 1024
39
+ dim_head: 64
40
+ dropout: 0.0
41
+ emb_dropout: 0.0
42
+ norm: layer
43
+ context_dim: 1280 # from vitpose-H
44
+
45
+ LOSS_WEIGHTS:
46
+ KEYPOINTS_3D: 0.05
47
+ KEYPOINTS_2D: 0.01
48
+ GLOBAL_ORIENT: 0.001
49
+ HAND_POSE: 0.001
50
+ BETAS: 0.0005
51
+ ADVERSARIAL: 0.0005
hamer/configs_hydra/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
1
+ # disable python warnings if they annoy you
2
+ ignore_warnings: False
3
+
4
+ # ask user for tags if none are provided in the config
5
+ enforce_tags: True
6
+
7
+ # pretty print config tree at the start of the run using Rich library
8
+ print_config: True
hamer/configs_hydra/hydra/default.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # @package _global_
2
+ # https://hydra.cc/docs/configure_hydra/intro/
3
+
4
+ # enable color logging
5
+ defaults:
6
+ - override /hydra/hydra_logging: colorlog
7
+ - override /hydra/job_logging: colorlog
8
+
9
+ # exp_name: ovrd_${hydra:job.override_dirname}
10
+ exp_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}
11
+
12
+ hydra:
13
+ run:
14
+ dir: ${paths.log_dir}/${task_name}/runs/${exp_name}
15
+ sweep:
16
+ dir: ${paths.log_dir}/${task_name}/multiruns/${exp_name}
17
+ subdir: ${hydra.job.num}
18
+ job:
19
+ config:
20
+ override_dirname:
21
+ exclude_keys:
22
+ - trainer
23
+ - trainer.devices
24
+ - trainer.num_nodes
25
+ - callbacks
26
+ - debug
hamer/configs_hydra/launcher/local.yaml ADDED
@@ -0,0 +1,13 @@
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_local
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ nodes: 1
10
+ tasks_per_node: ${trainer.devices}
11
+ cpus_per_task: 6
12
+ gpus_per_node: ${trainer.devices}
13
+ name: hamer
hamer/configs_hydra/launcher/slurm.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_slurm
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ max_num_timeout: 3
10
+ partition: g40
11
+ qos: idle
12
+ nodes: 1
13
+ tasks_per_node: ${trainer.devices}
14
+ gpus_per_task: null
15
+ cpus_per_task: 12
16
+ gpus_per_node: ${trainer.devices}
17
+ cpus_per_gpu: null
18
+ comment: laion
19
+ name: hamer
20
+ setup:
21
+ - module load cuda openmpi libfabric-aws
22
+ - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
hamer/configs_hydra/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ data_dir: ${paths.root_dir}/data/
8
+
9
+ # path to logging directory
10
+ log_dir: logs/
11
+
12
+ # path to output directory, created dynamically by hydra
13
+ # path generation pattern is specified in `configs/hydra/default.yaml`
14
+ # use it to store all files generated during the run, like ckpts and metrics
15
+ output_dir: ${hydra:runtime.output_dir}
16
+
17
+ # path to working directory
18
+ work_dir: ${hydra:runtime.cwd}
hamer/configs_hydra/train.yaml ADDED
@@ -0,0 +1,47 @@
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - data: mix_all.yaml
8
+ - trainer: ddp.yaml
9
+ - paths: default.yaml
10
+ - extras: default.yaml
11
+ - hydra: default.yaml
12
+
13
+ # experiment configs allow for version control of specific hyperparameters
14
+ # e.g. best hyperparameters for given model and datamodule
15
+ - experiment: null
16
+ - texture_exp: null
17
+
18
+ # optional local config for machine/user specific settings
19
+ # it's optional since it doesn't need to exist and is excluded from version control
20
+ - optional launcher: local.yaml
21
+ # - optional launcher: slurm.yaml
22
+
23
+ # debugging config (enable through command line, e.g. `python train.py debug=default)
24
+ - debug: null
25
+
26
+ # task name, determines output directory path
27
+ task_name: "train"
28
+
29
+ # tags to help you identify your experiments
30
+ # you can overwrite this in experiment configs
31
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
32
+ # appending lists from command line is currently not supported :(
33
+ # https://github.com/facebookresearch/hydra/issues/1547
34
+ tags: ["dev"]
35
+
36
+ # set False to skip model training
37
+ train: True
38
+
39
+ # evaluate on test set, using best model weights achieved during training
40
+ # lightning chooses best weights based on the metric specified in checkpoint callback
41
+ test: False
42
+
43
+ # simply provide checkpoint path to resume training
44
+ ckpt_path: null
45
+
46
+ # seed for random number generators in pytorch, numpy and python.random
47
+ seed: null
hamer/configs_hydra/trainer/cpu.yaml ADDED
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - default.yaml
3
+ - default_hamer.yaml
4
+
5
+ accelerator: cpu
6
+ devices: 1
hamer/configs_hydra/trainer/ddp.yaml ADDED
@@ -0,0 +1,14 @@
1
+ defaults:
2
+ - default.yaml
3
+ - default_hamer.yaml
4
+
5
+ # use "ddp_spawn" instead of "ddp",
6
+ # it's slower but normal "ddp" currently doesn't work ideally with hydra
7
+ # https://github.com/facebookresearch/hydra/issues/2070
8
+ # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
9
+ strategy: ddp
10
+
11
+ accelerator: gpu
12
+ devices: 8
13
+ num_nodes: 1
14
+ sync_batchnorm: True
hamer/configs_hydra/trainer/default.yaml ADDED
@@ -0,0 +1,10 @@
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ accelerator: cpu
6
+ devices: 1
7
+
8
+ # set True to to ensure deterministic results
9
+ # makes training slower but gives more reproducibility than just setting seeds
10
+ deterministic: False
hamer/configs_hydra/trainer/default_hamer.yaml ADDED
@@ -0,0 +1,8 @@
1
+ num_sanity_val_steps: 0
2
+ log_every_n_steps: ${GENERAL.LOG_STEPS}
3
+ val_check_interval: ${GENERAL.VAL_STEPS}
4
+ precision: 16
5
+ max_steps: ${GENERAL.TOTAL_STEPS}
6
+ # move_metrics_to_cpu: True
7
+ limit_val_batches: 1
8
+ # track_grad_norm: -1
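The ${GENERAL.*} references above are OmegaConf interpolations that resolve against the composed config, where the @package _global_ experiment files place GENERAL at the root. A minimal sketch of that resolution with a hand-built config:

```python
# Minimal sketch: resolving the ${GENERAL.*} interpolations used in default_hamer.yaml.
from omegaconf import OmegaConf

composed = OmegaConf.create({
    "GENERAL": {"LOG_STEPS": 1000, "VAL_STEPS": 1000, "TOTAL_STEPS": 1_000_000},
    "trainer": {
        "log_every_n_steps": "${GENERAL.LOG_STEPS}",
        "val_check_interval": "${GENERAL.VAL_STEPS}",
        "max_steps": "${GENERAL.TOTAL_STEPS}",
    },
})
print(OmegaConf.to_container(composed.trainer, resolve=True))
# {'log_every_n_steps': 1000, 'val_check_interval': 1000, 'max_steps': 1000000}
```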
hamer/configs_hydra/trainer/gpu.yaml ADDED
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - default.yaml
3
+ - default_hamer.yaml
4
+
5
+ accelerator: gpu
6
+ devices: 1
hamer/configs_hydra/trainer/mps.yaml ADDED
@@ -0,0 +1,6 @@
1
+ defaults:
2
+ - default.yaml
3
+ - default_hamer.yaml
4
+
5
+ accelerator: mps
6
+ devices: 1
hamer/datasets/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ from yacs.config import CfgNode
7
+
8
+ from ..configs import to_lower
9
+ from .dataset import Dataset
10
+
11
+ class HAMERDataModule(pl.LightningDataModule):
12
+
13
+ def __init__(self, cfg: CfgNode, dataset_cfg: CfgNode) -> None:
14
+ """
15
+ Initialize LightningDataModule for HAMER training
16
+ Args:
17
+ cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
18
+ dataset_cfg (CfgNode): Dataset configuration file
19
+ """
20
+ super().__init__()
21
+ self.cfg = cfg
22
+ self.dataset_cfg = dataset_cfg
23
+ self.train_dataset = None
24
+ self.val_dataset = None
25
+ self.test_dataset = None
26
+ self.mocap_dataset = None
27
+
28
+ def setup(self, stage: Optional[str] = None) -> None:
29
+ """
30
+ Load datasets necessary for training
31
+ Args:
32
+ cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
33
+ """
34
+ if self.train_dataset == None:
35
+ self.train_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=True).with_epoch(100_000).shuffle(4000)
36
+ self.val_dataset = MixedWebDataset(self.cfg, self.dataset_cfg, train=False).shuffle(4000)
37
+ self.mocap_dataset = MoCapDataset(**to_lower(self.dataset_cfg[self.cfg.DATASETS.MOCAP]))
38
+
39
+ def train_dataloader(self) -> Dict:
40
+ """
41
+ Setup training data loader.
42
+ Returns:
43
+ Dict: Dictionary containing image and mocap data dataloaders
44
+ """
45
+ train_dataloader = torch.utils.data.DataLoader(self.train_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS, prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR)
46
+ mocap_dataloader = torch.utils.data.DataLoader(self.mocap_dataset, self.cfg.TRAIN.NUM_TRAIN_SAMPLES * self.cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True, num_workers=1)
47
+ return {'img': train_dataloader, 'mocap': mocap_dataloader}
48
+
49
+ def val_dataloader(self) -> torch.utils.data.DataLoader:
50
+ """
51
+ Setup val data loader.
52
+ Returns:
53
+ torch.utils.data.DataLoader: Validation dataloader
54
+ """
55
+ val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, num_workers=self.cfg.GENERAL.NUM_WORKERS)
56
+ return val_dataloader
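train_dataloader() above returns a dict pairing the image stream with an unpaired MANO/mocap stream used for the adversarial loss. A minimal sketch of wiring the data module up, assuming the training tars are available and that MixedWebDataset and MoCapDataset (referenced in setup() but not imported in the file as shown) are importable in this package:

```python
# Minimal sketch: instantiating HAMERDataModule and reading one batch from each stream.
from hamer.configs import get_config, dataset_config
from hamer.datasets import HAMERDataModule

cfg = get_config("_DATA/hamer_ckpts/model_config.yaml")
datamodule = HAMERDataModule(cfg, dataset_config())
datamodule.setup()                          # builds the train/val/mocap datasets

loaders = datamodule.train_dataloader()     # {'img': DataLoader, 'mocap': DataLoader}
img_batch = next(iter(loaders['img']))      # cropped hand images + annotations
mocap_batch = next(iter(loaders['mocap']))  # unpaired hand_pose/betas for the discriminator
```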
hamer/datasets/dataset.py ADDED
@@ -0,0 +1,27 @@
1
+ """
2
+ This file contains the definition of the base Dataset class.
3
+ """
4
+
5
+ class DatasetRegistration(type):
6
+ """
7
+ Metaclass for registering different datasets
8
+ """
9
+ def __init__(cls, name, bases, nmspc):
10
+ super().__init__(name, bases, nmspc)
11
+ if not hasattr(cls, 'registry'):
12
+ cls.registry = dict()
13
+ cls.registry[name] = cls
14
+
15
+ # Metamethods, called on class objects:
16
+ def __iter__(cls):
17
+ return iter(cls.registry)
18
+
19
+ def __str__(cls):
20
+ return str(cls.registry)
21
+
22
+ class Dataset(metaclass=DatasetRegistration):
23
+ """
24
+ Base Dataset class
25
+ """
26
+ def __init__(self, *args, **kwargs):
27
+ pass
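The DatasetRegistration metaclass records every Dataset subclass in a shared registry keyed by class name. A minimal sketch of that behaviour (MyToyDataset is purely illustrative):

```python
# Minimal sketch: the automatic registry kept by the DatasetRegistration metaclass.
from hamer.datasets.dataset import Dataset

class MyToyDataset(Dataset):
    pass

print('MyToyDataset' in Dataset.registry)  # True: subclasses register themselves
print(list(Dataset))                       # iterating the class walks the registry keys
```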
hamer/datasets/image_dataset.py ADDED
@@ -0,0 +1,275 @@
1
+ import copy
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ from typing import List
6
+ from yacs.config import CfgNode
7
+ import braceexpand
8
+ import cv2
9
+
10
+ from .dataset import Dataset
11
+ from .utils import get_example, expand_to_aspect_ratio
12
+
13
+ def expand(s):
14
+ return os.path.expanduser(os.path.expandvars(s))
15
+ def expand_urls(urls: str|List[str]):
16
+ if isinstance(urls, str):
17
+ urls = [urls]
18
+ urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
19
+ return urls
20
+
21
+ FLIP_KEYPOINT_PERMUTATION = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
22
+
23
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
24
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
25
+ DEFAULT_IMG_SIZE = 256
26
+
27
+ class ImageDataset(Dataset):
28
+
29
+ @staticmethod
30
+ def load_tars_as_webdataset(cfg: CfgNode, urls: str|List[str], train: bool,
31
+ resampled=False,
32
+ epoch_size=None,
33
+ cache_dir=None,
34
+ **kwargs) -> Dataset:
35
+ """
36
+ Loads the dataset from a webdataset tar file.
37
+ """
38
+
39
+ IMG_SIZE = cfg.MODEL.IMAGE_SIZE
40
+ BBOX_SHAPE = cfg.MODEL.get('BBOX_SHAPE', None)
41
+ MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN)
42
+ STD = 255. * np.array(cfg.MODEL.IMAGE_STD)
43
+
44
+ def split_data(source):
45
+ for item in source:
46
+ datas = item['data.pyd']
47
+ for data in datas:
48
+ if 'detection.npz' in item:
49
+ det_idx = data['extra_info']['detection_npz_idx']
50
+ mask = item['detection.npz']['masks'][det_idx]
51
+ else:
52
+ mask = np.ones_like(item['jpg'][:,:,0], dtype=bool)
53
+ yield {
54
+ '__key__': item['__key__'],
55
+ 'jpg': item['jpg'],
56
+ 'data.pyd': data,
57
+ 'mask': mask,
58
+ }
59
+
60
+ def suppress_bad_kps(item, thresh=0.0):
61
+ if thresh > 0:
62
+ kp2d = item['data.pyd']['keypoints_2d']
63
+ kp2d_conf = np.where(kp2d[:, 2] < thresh, 0.0, kp2d[:, 2])
64
+ item['data.pyd']['keypoints_2d'] = np.concatenate([kp2d[:,:2], kp2d_conf[:,None]], axis=1)
65
+ return item
66
+
67
+ def filter_numkp(item, numkp=4, thresh=0.0):
68
+ kp_conf = item['data.pyd']['keypoints_2d'][:, 2]
69
+ return (kp_conf > thresh).sum() > numkp
70
+
71
+ def filter_reproj_error(item, thresh=10**4.5):
72
+ losses = item['data.pyd'].get('extra_info', {}).get('fitting_loss', np.array({})).item()
73
+ reproj_loss = losses.get('reprojection_loss', None)
74
+ return reproj_loss is None or reproj_loss < thresh
75
+
76
+ def filter_bbox_size(item, thresh=1):
77
+ bbox_size_min = item['data.pyd']['scale'].min().item() * 200.
78
+ return bbox_size_min > thresh
79
+
80
+ def filter_no_poses(item):
81
+ return (item['data.pyd']['has_hand_pose'] > 0)
82
+
83
+ def supress_bad_betas(item, thresh=3):
84
+ has_betas = item['data.pyd']['has_betas']
85
+ if thresh > 0 and has_betas:
86
+ betas_abs = np.abs(item['data.pyd']['betas'])
87
+ if (betas_abs > thresh).any():
88
+ item['data.pyd']['has_betas'] = False
89
+ return item
90
+
91
+ def supress_bad_poses(item):
92
+ has_hand_pose = item['data.pyd']['has_hand_pose']
93
+ if has_hand_pose:
94
+ hand_pose = item['data.pyd']['hand_pose']
95
+ pose_is_probable = poses_check_probable(torch.from_numpy(hand_pose)[None, 3:], amass_poses_hist100_smooth).item()
96
+ if not pose_is_probable:
97
+ item['data.pyd']['has_hand_pose'] = False
98
+ return item
99
+
100
+ def poses_betas_simultaneous(item):
101
+ # We either have both hand_pose and betas, or neither
102
+ has_betas = item['data.pyd']['has_betas']
103
+ has_hand_pose = item['data.pyd']['has_hand_pose']
104
+ item['data.pyd']['has_betas'] = item['data.pyd']['has_hand_pose'] = np.array(float((has_hand_pose>0) and (has_betas>0)))
105
+ return item
106
+
107
+ def set_betas_for_reg(item):
108
+ # Always have betas set to true
109
+ has_betas = item['data.pyd']['has_betas']
110
+ betas = item['data.pyd']['betas']
111
+
112
+ if not (has_betas>0):
113
+ item['data.pyd']['has_betas'] = np.array(float((True)))
114
+ item['data.pyd']['betas'] = betas * 0
115
+ return item
116
+
117
+ # Load the dataset
118
+ if epoch_size is not None:
119
+ resampled = True
120
+ #corrupt_filter = lambda sample: (sample['__key__'] not in CORRUPT_KEYS)
121
+ import webdataset as wds
122
+ dataset = wds.WebDataset(expand_urls(urls),
123
+ nodesplitter=wds.split_by_node,
124
+ shardshuffle=True,
125
+ resampled=resampled,
126
+ cache_dir=cache_dir,
127
+ ) #.select(corrupt_filter)
128
+ if train:
129
+ dataset = dataset.shuffle(100)
130
+ dataset = dataset.decode('rgb8').rename(jpg='jpg;jpeg;png')
131
+
132
+ # Process the dataset
133
+ dataset = dataset.compose(split_data)
134
+
135
+ # Filter/clean the dataset
136
+ SUPPRESS_KP_CONF_THRESH = cfg.DATASETS.get('SUPPRESS_KP_CONF_THRESH', 0.0)
137
+ SUPPRESS_BETAS_THRESH = cfg.DATASETS.get('SUPPRESS_BETAS_THRESH', 0.0)
138
+ SUPPRESS_BAD_POSES = cfg.DATASETS.get('SUPPRESS_BAD_POSES', False)
139
+ POSES_BETAS_SIMULTANEOUS = cfg.DATASETS.get('POSES_BETAS_SIMULTANEOUS', False)
140
+ BETAS_REG = cfg.DATASETS.get('BETAS_REG', False)
141
+ FILTER_NO_POSES = cfg.DATASETS.get('FILTER_NO_POSES', False)
142
+ FILTER_NUM_KP = cfg.DATASETS.get('FILTER_NUM_KP', 4)
143
+ FILTER_NUM_KP_THRESH = cfg.DATASETS.get('FILTER_NUM_KP_THRESH', 0.0)
144
+ FILTER_REPROJ_THRESH = cfg.DATASETS.get('FILTER_REPROJ_THRESH', 0.0)
145
+ FILTER_MIN_BBOX_SIZE = cfg.DATASETS.get('FILTER_MIN_BBOX_SIZE', 0.0)
146
+ if SUPPRESS_KP_CONF_THRESH > 0:
147
+ dataset = dataset.map(lambda x: suppress_bad_kps(x, thresh=SUPPRESS_KP_CONF_THRESH))
148
+ if SUPPRESS_BETAS_THRESH > 0:
149
+ dataset = dataset.map(lambda x: supress_bad_betas(x, thresh=SUPPRESS_BETAS_THRESH))
150
+ if SUPPRESS_BAD_POSES:
151
+ dataset = dataset.map(lambda x: supress_bad_poses(x))
152
+ if POSES_BETAS_SIMULTANEOUS:
153
+ dataset = dataset.map(lambda x: poses_betas_simultaneous(x))
154
+ if FILTER_NO_POSES:
155
+ dataset = dataset.select(lambda x: filter_no_poses(x))
156
+ if FILTER_NUM_KP > 0:
157
+ dataset = dataset.select(lambda x: filter_numkp(x, numkp=FILTER_NUM_KP, thresh=FILTER_NUM_KP_THRESH))
158
+ if FILTER_REPROJ_THRESH > 0:
159
+ dataset = dataset.select(lambda x: filter_reproj_error(x, thresh=FILTER_REPROJ_THRESH))
160
+ if FILTER_MIN_BBOX_SIZE > 0:
161
+ dataset = dataset.select(lambda x: filter_bbox_size(x, thresh=FILTER_MIN_BBOX_SIZE))
162
+ if BETAS_REG:
163
+ dataset = dataset.map(lambda x: set_betas_for_reg(x)) # NOTE: Must be at the end
164
+
165
+ use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False)
166
+ border_mode = {
167
+ 'constant': cv2.BORDER_CONSTANT,
168
+ 'replicate': cv2.BORDER_REPLICATE,
169
+ }[cfg.DATASETS.get('BORDER_MODE', 'constant')]
170
+
171
+ # Process the dataset further
172
+ dataset = dataset.map(lambda x: ImageDataset.process_webdataset_tar_item(x, train,
173
+ augm_config=cfg.DATASETS.CONFIG,
174
+ MEAN=MEAN, STD=STD, IMG_SIZE=IMG_SIZE,
175
+ BBOX_SHAPE=BBOX_SHAPE,
176
+ use_skimage_antialias=use_skimage_antialias,
177
+ border_mode=border_mode,
178
+ ))
179
+ if epoch_size is not None:
180
+ dataset = dataset.with_epoch(epoch_size)
181
+
182
+ return dataset
183
+
184
+ @staticmethod
185
+ def process_webdataset_tar_item(item, train,
186
+ augm_config=None,
187
+ MEAN=DEFAULT_MEAN,
188
+ STD=DEFAULT_STD,
189
+ IMG_SIZE=DEFAULT_IMG_SIZE,
190
+ BBOX_SHAPE=None,
191
+ use_skimage_antialias=False,
192
+ border_mode=cv2.BORDER_CONSTANT,
193
+ ):
194
+ # Read data from item
195
+ key = item['__key__']
196
+ image = item['jpg']
197
+ data = item['data.pyd']
198
+ mask = item['mask']
199
+
200
+ keypoints_2d = data['keypoints_2d']
201
+ keypoints_3d = data['keypoints_3d']
202
+ center = data['center']
203
+ scale = data['scale']
204
+ hand_pose = data['hand_pose']
205
+ betas = data['betas']
206
+ right = data['right']
207
+ #right = True
208
+ has_hand_pose = data['has_hand_pose']
209
+ has_betas = data['has_betas']
210
+ # image_file = data['image_file']
211
+
212
+ # Process data
213
+ orig_keypoints_2d = keypoints_2d.copy()
214
+ center_x = center[0]
215
+ center_y = center[1]
216
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
217
+ if bbox_size < 1:
218
+ breakpoint()
219
+
220
+
221
+ mano_params = {'global_orient': hand_pose[:3],
222
+ 'hand_pose': hand_pose[3:],
223
+ 'betas': betas
224
+ }
225
+
226
+ has_mano_params = {'global_orient': has_hand_pose,
227
+ 'hand_pose': has_hand_pose,
228
+ 'betas': has_betas
229
+ }
230
+
231
+ mano_params_is_axis_angle = {'global_orient': True,
232
+ 'hand_pose': True,
233
+ 'betas': False
234
+ }
235
+
236
+ augm_config = copy.deepcopy(augm_config)
237
+ # Crop image and (possibly) perform data augmentation
238
+ img_rgba = np.concatenate([image, mask.astype(np.uint8)[:,:,None]*255], axis=2)
239
+ img_patch_rgba, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans = get_example(img_rgba,
240
+ center_x, center_y,
241
+ bbox_size, bbox_size,
242
+ keypoints_2d, keypoints_3d,
243
+ mano_params, has_mano_params,
244
+ FLIP_KEYPOINT_PERMUTATION,
245
+ IMG_SIZE, IMG_SIZE,
246
+ MEAN, STD, train, right, augm_config,
247
+ is_bgr=False, return_trans=True,
248
+ use_skimage_antialias=use_skimage_antialias,
249
+ border_mode=border_mode,
250
+ )
251
+ img_patch = img_patch_rgba[:3,:,:]
252
+ mask_patch = (img_patch_rgba[3,:,:] / 255.0).clip(0,1)
253
+ if (mask_patch < 0.5).all():
254
+ mask_patch = np.ones_like(mask_patch)
255
+
256
+ item = {}
257
+
258
+ item['img'] = img_patch
259
+ item['mask'] = mask_patch
260
+ # item['img_og'] = image
261
+ # item['mask_og'] = mask
262
+ item['keypoints_2d'] = keypoints_2d.astype(np.float32)
263
+ item['keypoints_3d'] = keypoints_3d.astype(np.float32)
264
+ item['orig_keypoints_2d'] = orig_keypoints_2d
265
+ item['box_center'] = center.copy()
266
+ item['box_size'] = bbox_size
267
+ item['img_size'] = 1.0 * img_size[::-1].copy()
268
+ item['mano_params'] = mano_params
269
+ item['has_mano_params'] = has_mano_params
270
+ item['mano_params_is_axis_angle'] = mano_params_is_axis_angle
271
+ item['_scale'] = scale
272
+ item['_trans'] = trans
273
+ item['imgname'] = key
274
+ # item['idx'] = idx
275
+ return item
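load_tars_as_webdataset above builds the webdataset stream for a single tar-sharded dataset entry. A minimal sketch of building one stream from the FreiHAND shards, assuming the tars exist locally under hamer_training_data/:

```python
# Minimal sketch: building one webdataset stream from the shards listed in datasets_tar.yaml.
from hamer.configs import get_config, dataset_config
from hamer.datasets.image_dataset import ImageDataset

cfg = get_config("_DATA/hamer_ckpts/model_config.yaml")
entry = dataset_config()["FREIHAND-TRAIN"]

dataset = ImageDataset.load_tars_as_webdataset(
    cfg,
    urls=entry.URLS,              # brace-range tar pattern, expanded via expand_urls()
    train=True,
    epoch_size=entry.epoch_size,  # forces resampling to a fixed epoch length
)
sample = next(iter(dataset))
print(sample['img'].shape, sample['box_size'])
```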
hamer/datasets/json_dataset.py ADDED
@@ -0,0 +1,213 @@
1
+ import copy
2
+ import os
3
+ import json
4
+ import glob
5
+ import numpy as np
6
+ import torch
7
+ from typing import Any, Dict, List
8
+ from yacs.config import CfgNode
9
+ import braceexpand
10
+ import cv2
11
+
12
+ from .dataset import Dataset
13
+ from .utils import get_example, expand_to_aspect_ratio
14
+ from .smplh_prob_filter import poses_check_probable, load_amass_hist_smooth
15
+
16
+ def expand(s):
17
+ return os.path.expanduser(os.path.expandvars(s))
18
+ def expand_urls(urls: str|List[str]):
19
+ if isinstance(urls, str):
20
+ urls = [urls]
21
+ urls = [u for url in urls for u in braceexpand.braceexpand(expand(url))]
22
+ return urls
23
+
24
+ AIC_TRAIN_CORRUPT_KEYS = {
25
+ '0a047f0124ae48f8eee15a9506ce1449ee1ba669',
26
+ '1a703aa174450c02fbc9cfbf578a5435ef403689',
27
+ '0394e6dc4df78042929b891dbc24f0fd7ffb6b6d',
28
+ '5c032b9626e410441544c7669123ecc4ae077058',
29
+ 'ca018a7b4c5f53494006ebeeff9b4c0917a55f07',
30
+ '4a77adb695bef75a5d34c04d589baf646fe2ba35',
31
+ 'a0689017b1065c664daef4ae2d14ea03d543217e',
32
+ '39596a45cbd21bed4a5f9c2342505532f8ec5cbb',
33
+ '3d33283b40610d87db660b62982f797d50a7366b',
34
+ }
35
+ CORRUPT_KEYS = {
36
+ *{f'aic-train/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
37
+ *{f'aic-train-vitpose/{k}' for k in AIC_TRAIN_CORRUPT_KEYS},
38
+ }
39
+
40
+ FLIP_KEYPOINT_PERMUTATION = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
41
+
42
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
43
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
44
+ DEFAULT_IMG_SIZE = 256
45
+
46
+ class JsonDataset(Dataset):
47
+
48
+ def __init__(self,
49
+ cfg: CfgNode,
50
+ dataset_file: str,
51
+ img_dir: str,
52
+ right: bool,
53
+ train: bool = False,
54
+ prune: Dict[str, Any] = {},
55
+ **kwargs):
56
+ """
57
+ Dataset class used for loading images and corresponding annotations.
58
+ Args:
59
+ cfg (CfgNode): Model config file.
60
+ dataset_file (str): Path to npz file containing dataset info.
61
+ img_dir (str): Path to image folder.
62
+ train (bool): Whether it is for training or not (enables data augmentation).
63
+ """
64
+ super(JsonDataset, self).__init__()
65
+ self.train = train
66
+ self.cfg = cfg
67
+
68
+ self.img_size = cfg.MODEL.IMAGE_SIZE
69
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
70
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
71
+
72
+ self.img_dir = img_dir
73
+ boxes = np.array(json.load(open(dataset_file, 'rb')))
74
+
75
+ self.imgname = glob.glob(os.path.join(self.img_dir,'*.jpg'))
76
+ self.imgname.sort()
77
+
78
+ self.flip_keypoint_permutation = copy.copy(FLIP_KEYPOINT_PERMUTATION)
79
+
80
+ num_pose = 3 * (self.cfg.MANO.NUM_HAND_JOINTS + 1)
81
+
82
+ # Bounding boxes are assumed to be in the center and scale format
83
+ boxes = boxes.astype(np.float32)
84
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
85
+ self.scale = 2 * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
86
+ self.personid = np.arange(len(boxes), dtype=np.int32)
87
+ if right:
88
+ self.right = np.ones(len(self.imgname), dtype=np.float32)
89
+ else:
90
+ self.right = np.zeros(len(self.imgname), dtype=np.float32)
91
+ assert self.scale.shape == (len(self.center), 2)
92
+
93
+ # Get gt SMPLX parameters, if available
94
+ try:
95
+ self.hand_pose = self.data['hand_pose'].astype(np.float32)
96
+ self.has_hand_pose = self.data['has_hand_pose'].astype(np.float32)
97
+ except:
98
+ self.hand_pose = np.zeros((len(self.imgname), num_pose), dtype=np.float32)
99
+ self.has_hand_pose = np.zeros(len(self.imgname), dtype=np.float32)
100
+ try:
101
+ self.betas = self.data['betas'].astype(np.float32)
102
+ self.has_betas = self.data['has_betas'].astype(np.float32)
103
+ except:
104
+ self.betas = np.zeros((len(self.imgname), 10), dtype=np.float32)
105
+ self.has_betas = np.zeros(len(self.imgname), dtype=np.float32)
106
+
107
+ # Try to get 2d keypoints, if available
108
+ try:
109
+ hand_keypoints_2d = self.data['hand_keypoints_2d']
110
+ except:
111
+ hand_keypoints_2d = np.zeros((len(self.center), 21, 3))
112
+ ## Try to get extra 2d keypoints, if available
113
+ #try:
114
+ # extra_keypoints_2d = self.data['extra_keypoints_2d']
115
+ #except KeyError:
116
+ # extra_keypoints_2d = np.zeros((len(self.center), 19, 3))
117
+
118
+ #self.keypoints_2d = np.concatenate((hand_keypoints_2d, extra_keypoints_2d), axis=1).astype(np.float32)
119
+ self.keypoints_2d = hand_keypoints_2d
120
+
121
+ # Try to get 3d keypoints, if available
122
+ try:
123
+ hand_keypoints_3d = self.data['hand_keypoints_3d'].astype(np.float32)
124
+ except:
125
+ hand_keypoints_3d = np.zeros((len(self.center), 21, 4), dtype=np.float32)
126
+ ## Try to get extra 3d keypoints, if available
127
+ #try:
128
+ # extra_keypoints_3d = self.data['extra_keypoints_3d'].astype(np.float32)
129
+ #except KeyError:
130
+ # extra_keypoints_3d = np.zeros((len(self.center), 19, 4), dtype=np.float32)
131
+
132
+ self.keypoints_3d = hand_keypoints_3d
133
+
134
+ #body_keypoints_3d[:, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], -1] = 0
135
+
136
+ #self.keypoints_3d = np.concatenate((body_keypoints_3d, extra_keypoints_3d), axis=1).astype(np.float32)
137
+
138
+ def __len__(self) -> int:
139
+ return len(self.scale)
140
+
141
+ def __getitem__(self, idx: int) -> Dict:
142
+ """
143
+ Returns an example from the dataset.
144
+ """
145
+ try:
146
+ image_file = self.imgname[idx].decode('utf-8')
147
+ except AttributeError:
148
+ image_file = self.imgname[idx]
149
+ keypoints_2d = self.keypoints_2d[idx].copy()
150
+ keypoints_3d = self.keypoints_3d[idx].copy()
151
+
152
+ center = self.center[idx].copy()
153
+ center_x = center[0]
154
+ center_y = center[1]
155
+ scale = self.scale[idx]
156
+ right = self.right[idx].copy()
157
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
158
+ #bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
159
+ bbox_size = ((scale*200).max())
160
+ bbox_expand_factor = bbox_size / ((scale*200).max())
161
+ hand_pose = self.hand_pose[idx].copy().astype(np.float32)
162
+ betas = self.betas[idx].copy().astype(np.float32)
163
+
164
+ has_hand_pose = self.has_hand_pose[idx].copy()
165
+ has_betas = self.has_betas[idx].copy()
166
+
167
+ mano_params = {'global_orient': hand_pose[:3],
168
+ 'hand_pose': hand_pose[3:],
169
+ 'betas': betas
170
+ }
171
+
172
+ has_mano_params = {'global_orient': has_hand_pose,
173
+ 'hand_pose': has_hand_pose,
174
+ 'betas': has_betas
175
+ }
176
+
177
+ mano_params_is_axis_angle = {'global_orient': True,
178
+ 'hand_pose': True,
179
+ 'betas': False
180
+ }
181
+
182
+ augm_config = self.cfg.DATASETS.CONFIG
183
+ # Crop image and (possibly) perform data augmentation
184
+ img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size = get_example(image_file,
185
+ center_x, center_y,
186
+ bbox_size, bbox_size,
187
+ keypoints_2d, keypoints_3d,
188
+ mano_params, has_mano_params,
189
+ self.flip_keypoint_permutation,
190
+ self.img_size, self.img_size,
191
+ self.mean, self.std, self.train, right, augm_config)
192
+
193
+ item = {}
194
+ # These are the keypoints in the original image coordinates (before cropping)
195
+ orig_keypoints_2d = self.keypoints_2d[idx].copy()
196
+
197
+ item['img'] = img_patch
198
+ item['keypoints_2d'] = keypoints_2d.astype(np.float32)
199
+ item['keypoints_3d'] = keypoints_3d.astype(np.float32)
200
+ item['orig_keypoints_2d'] = orig_keypoints_2d
201
+ item['box_center'] = self.center[idx].copy()
202
+ item['box_size'] = bbox_size
203
+ item['bbox_expand_factor'] = bbox_expand_factor
204
+ item['img_size'] = 1.0 * img_size[::-1].copy()
205
+ item['mano_params'] = mano_params
206
+ item['has_mano_params'] = has_mano_params
207
+ item['mano_params_is_axis_angle'] = mano_params_is_axis_angle
208
+ item['imgname'] = image_file
209
+ item['personid'] = int(self.personid[idx])
210
+ item['idx'] = idx
211
+ item['_scale'] = scale
212
+ item['right'] = self.right[idx].copy()
213
+ return item
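
A minimal usage sketch for the class above (the config path, JSON file, and image folder are illustrative placeholders, not files defined by this commit). It shows the expected inputs: a JSON list with one [x1, y1, x2, y2] box per image and a folder of .jpg images that sorts into the same order as the boxes.

from hamer.configs import get_config
from hamer.datasets.json_dataset import JsonDataset

model_cfg = get_config('path/to/model_config.yaml', update_cachedir=True)  # placeholder path
dataset = JsonDataset(model_cfg,
                      dataset_file='hand_boxes.json',  # placeholder JSON with [x1, y1, x2, y2] boxes
                      img_dir='frames/',               # placeholder folder of .jpg images
                      right=True, train=False)
item = dataset[0]
print(item['img'].shape, item['box_center'], item['box_size'])
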
hamer/datasets/mocap_dataset.py ADDED
@@ -0,0 +1,25 @@
1
+ import numpy as np
2
+ from typing import Dict
3
+
4
+ class MoCapDataset:
5
+
6
+ def __init__(self, dataset_file: str):
7
+ """
8
+ Dataset class used for loading a dataset of unpaired MANO parameter annotations
9
+ Args:
11
+ dataset_file (str): Path to npz file containing dataset info.
12
+ """
13
+ data = np.load(dataset_file)
14
+ self.pose = data['hand_pose'].astype(np.float32)[:, 3:]
15
+ self.betas = data['betas'].astype(np.float32)
16
+ self.length = len(self.pose)
17
+
18
+ def __getitem__(self, idx: int) -> Dict:
19
+ pose = self.pose[idx].copy()
20
+ betas = self.betas[idx].copy()
21
+ item = {'hand_pose': pose, 'betas': betas}
22
+ return item
23
+
24
+ def __len__(self) -> int:
25
+ return self.length
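
A short sketch of how this dataset is typically consumed (the .npz path is a placeholder; the file only needs 'hand_pose' and 'betas' arrays, as read in __init__ above). Because __getitem__ returns plain dicts of numpy arrays, it can be wrapped directly in a PyTorch DataLoader, e.g. to draw unpaired MANO samples for the discriminator.

import torch
from hamer.datasets.mocap_dataset import MoCapDataset

mocap = MoCapDataset('mano_mocap.npz')  # placeholder npz containing 'hand_pose' and 'betas'
loader = torch.utils.data.DataLoader(mocap, batch_size=64, shuffle=True, drop_last=True)
batch = next(iter(loader))
print(batch['hand_pose'].shape, batch['betas'].shape)  # e.g. (64, 45) and (64, 10) for MANO
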
hamer/datasets/utils.py ADDED
@@ -0,0 +1,993 @@
1
+ """
2
+ Parts of the code are taken or adapted from
3
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from skimage.transform import rotate, resize
8
+ from skimage.filters import gaussian
9
+ import random
10
+ import cv2
11
+ from typing import List, Dict, Tuple
12
+ from yacs.config import CfgNode
13
+
14
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
15
+ """Increase the size of the bounding box to match the target shape."""
16
+ if target_aspect_ratio is None:
17
+ return input_shape
18
+
19
+ try:
20
+ w , h = input_shape
21
+ except (ValueError, TypeError):
22
+ return input_shape
23
+
24
+ w_t, h_t = target_aspect_ratio
25
+ if h / w < h_t / w_t:
26
+ h_new = max(w * h_t / w_t, h)
27
+ w_new = w
28
+ else:
29
+ h_new = h
30
+ w_new = max(h * w_t / h_t, w)
31
+ if h_new < h or w_new < w:
32
+ breakpoint()
33
+ return np.array([w_new, h_new])
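# Editorial example (not part of this file): with the ViT backbone's BBOX_SHAPE of [192, 256]
# (see hamer/models/__init__.py below), a square 200x200 box is padded vertically to the 3:4
# aspect ratio while keeping its width:
#
#     expand_to_aspect_ratio(np.array([200., 200.]), target_aspect_ratio=[192, 256])
#     # -> array([200.        , 266.66666667])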
34
+
35
+ def do_augmentation(aug_config: CfgNode) -> Tuple:
36
+ """
37
+ Compute random augmentation parameters.
38
+ Args:
39
+ aug_config (CfgNode): Config containing augmentation parameters.
40
+ Returns:
41
+ scale (float): Box rescaling factor.
42
+ rot (float): Random image rotation.
43
+ do_flip (bool): Whether to flip image or not.
44
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
45
+ color_scale (List): Color rescaling factor
46
+ tx (float): Random translation along the x axis.
47
+ ty (float): Random translation along the y axis.
48
+ """
49
+
50
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
51
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
52
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
53
+ rot = np.clip(np.random.randn(), -2.0,
54
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
55
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
56
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
57
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
58
+ # extreme_crop_lvl = 0
59
+ c_up = 1.0 + aug_config.COLOR_SCALE
60
+ c_low = 1.0 - aug_config.COLOR_SCALE
61
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
62
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
63
+
64
+ def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
65
+ """
66
+ Rotate a 2D point on the x-y plane.
67
+ Args:
68
+ pt_2d (np.array): Input 2D point with shape (2,).
69
+ rot_rad (float): Rotation angle in radians.
70
+ Returns:
71
+ np.array: Rotated 2D point.
72
+ """
73
+ x = pt_2d[0]
74
+ y = pt_2d[1]
75
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
76
+ xx = x * cs - y * sn
77
+ yy = x * sn + y * cs
78
+ return np.array([xx, yy], dtype=np.float32)
79
+
80
+
81
+ def gen_trans_from_patch_cv(c_x: float, c_y: float,
82
+ src_width: float, src_height: float,
83
+ dst_width: float, dst_height: float,
84
+ scale: float, rot: float) -> np.array:
85
+ """
86
+ Create transformation matrix for the bounding box crop.
87
+ Args:
88
+ c_x (float): Bounding box center x coordinate in the original image.
89
+ c_y (float): Bounding box center y coordinate in the original image.
90
+ src_width (float): Bounding box width.
91
+ src_height (float): Bounding box height.
92
+ dst_width (float): Output box width.
93
+ dst_height (float): Output box height.
94
+ scale (float): Rescaling factor for the bounding box (augmentation).
95
+ rot (float): Random rotation applied to the box.
96
+ Returns:
97
+ trans (np.array): Target geometric transformation.
98
+ """
99
+ # augment size with scale
100
+ src_w = src_width * scale
101
+ src_h = src_height * scale
102
+ src_center = np.zeros(2)
103
+ src_center[0] = c_x
104
+ src_center[1] = c_y
105
+ # augment rotation
106
+ rot_rad = np.pi * rot / 180
107
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
108
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
109
+
110
+ dst_w = dst_width
111
+ dst_h = dst_height
112
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
113
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
114
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
115
+
116
+ src = np.zeros((3, 2), dtype=np.float32)
117
+ src[0, :] = src_center
118
+ src[1, :] = src_center + src_downdir
119
+ src[2, :] = src_center + src_rightdir
120
+
121
+ dst = np.zeros((3, 2), dtype=np.float32)
122
+ dst[0, :] = dst_center
123
+ dst[1, :] = dst_center + dst_downdir
124
+ dst[2, :] = dst_center + dst_rightdir
125
+
126
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
127
+
128
+ return trans
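# Editorial worked example (not part of this file): mapping a 300x300 box centered at (400, 250)
# onto a 256x256 patch with scale=1.0 and rot=0 yields an affine matrix that sends the box
# center to the patch center:
#
#     trans = gen_trans_from_patch_cv(c_x=400, c_y=250,
#                                     src_width=300, src_height=300,
#                                     dst_width=256, dst_height=256,
#                                     scale=1.0, rot=0.0)
#     np.dot(trans, np.array([400., 250., 1.]))  # -> approximately [128., 128.]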
129
+
130
+
131
+ def trans_point2d(pt_2d: np.array, trans: np.array):
132
+ """
133
+ Transform a 2D point using the affine transformation matrix trans.
134
+ Args:
135
+ pt_2d (np.array): Input 2D point with shape (2,).
136
+ trans (np.array): Transformation matrix.
137
+ Returns:
138
+ np.array: Transformed 2D point.
139
+ """
140
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
141
+ dst_pt = np.dot(trans, src_pt)
142
+ return dst_pt[0:2]
143
+
144
+ def get_transform(center, scale, res, rot=0):
145
+ """Generate transformation matrix."""
146
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
147
+ h = 200 * scale
148
+ t = np.zeros((3, 3))
149
+ t[0, 0] = float(res[1]) / h
150
+ t[1, 1] = float(res[0]) / h
151
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
152
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
153
+ t[2, 2] = 1
154
+ if not rot == 0:
155
+ rot = -rot # To match direction of rotation from cropping
156
+ rot_mat = np.zeros((3, 3))
157
+ rot_rad = rot * np.pi / 180
158
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
159
+ rot_mat[0, :2] = [cs, -sn]
160
+ rot_mat[1, :2] = [sn, cs]
161
+ rot_mat[2, 2] = 1
162
+ # Need to rotate around center
163
+ t_mat = np.eye(3)
164
+ t_mat[0, 2] = -res[1] / 2
165
+ t_mat[1, 2] = -res[0] / 2
166
+ t_inv = t_mat.copy()
167
+ t_inv[:2, 2] *= -1
168
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
169
+ return t
170
+
171
+
172
+ def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
173
+ """Transform pixel location to different reference."""
174
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
175
+ t = get_transform(center, scale, res, rot=rot)
176
+ if invert:
177
+ t = np.linalg.inv(t)
178
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
179
+ new_pt = np.dot(t, new_pt)
180
+ if as_int:
181
+ new_pt = new_pt.astype(int)
182
+ return new_pt[:2] + 1
183
+
184
+ def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
185
+ c_x = (ul[0] + br[0])/2
186
+ c_y = (ul[1] + br[1])/2
187
+ bb_width = patch_width = br[0] - ul[0]
188
+ bb_height = patch_height = br[1] - ul[1]
189
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
190
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
191
+ flags=cv2.INTER_LINEAR,
192
+ borderMode=border_mode,
193
+ borderValue=border_value
194
+ )
195
+
196
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
197
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
198
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
199
+ flags=cv2.INTER_LINEAR,
200
+ borderMode=cv2.BORDER_CONSTANT,
201
+ )
202
+
203
+ return img_patch
204
+
205
+ def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
206
+ bb_width: float, bb_height: float,
207
+ patch_width: float, patch_height: float,
208
+ do_flip: bool, scale: float, rot: float,
209
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
210
+ """
211
+ Crop image according to the supplied bounding box.
212
+ Args:
213
+ img (np.array): Input image of shape (H, W, 3)
214
+ c_x (float): Bounding box center x coordinate in the original image.
215
+ c_y (float): Bounding box center y coordinate in the original image.
216
+ bb_width (float): Bounding box width.
217
+ bb_height (float): Bounding box height.
218
+ patch_width (float): Output box width.
219
+ patch_height (float): Output box height.
220
+ do_flip (bool): Whether to flip image or not.
221
+ scale (float): Rescaling factor for the bounding box (augmentation).
222
+ rot (float): Random rotation applied to the box.
223
+ Returns:
224
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
225
+ trans (np.array): Transformation matrix.
226
+ """
227
+
228
+ img_height, img_width, img_channels = img.shape
229
+ if do_flip:
230
+ img = img[:, ::-1, :]
231
+ c_x = img_width - c_x - 1
232
+
233
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
234
+
235
+ #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
236
+
237
+ # skimage
238
+ center = np.zeros(2)
239
+ center[0] = c_x
240
+ center[1] = c_y
241
+ res = np.zeros(2)
242
+ res[0] = patch_width
243
+ res[1] = patch_height
244
+ # assumes bb_width = bb_height
245
+ # assumes patch_width = patch_height
246
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
247
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
248
+ scale1 = scale*bb_width/200.
249
+
250
+ # Upper left point
251
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
252
+ # Bottom right point
253
+ br = np.array(transform([res[0] + 1,
254
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
255
+
256
+ # Padding so that when rotated proper amount of context is included
257
+ try:
258
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
259
+ except:
260
+ breakpoint()
261
+ if not rot == 0:
262
+ ul -= pad
263
+ br += pad
264
+
265
+
266
+ if False:
267
+ # Old way of cropping image
268
+ ul_int = ul.astype(int)
269
+ br_int = br.astype(int)
270
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
271
+ if len(img.shape) > 2:
272
+ new_shape += [img.shape[2]]
273
+ new_img = np.zeros(new_shape)
274
+
275
+ # Range to fill new array
276
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
277
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
278
+ # Range to sample from original image
279
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
280
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
281
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
282
+ old_x[0]:old_x[1]]
283
+
284
+ # New way of cropping image
285
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
286
+
287
+ # print(f'{new_img.shape=}')
288
+ # print(f'{new_img1.shape=}')
289
+ # print(f'{np.allclose(new_img, new_img1)=}')
290
+ # print(f'{img.dtype=}')
291
+
292
+
293
+ if not rot == 0:
294
+ # Remove padding
295
+
296
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
297
+ new_img = new_img[pad:-pad, pad:-pad]
298
+
299
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
300
+ print(f'{img.shape=}')
301
+ print(f'{new_img.shape=}')
302
+ print(f'{ul=}')
303
+ print(f'{br=}')
304
+ print(f'{pad=}')
305
+ print(f'{rot=}')
306
+
307
+ breakpoint()
308
+
309
+ # resize image
310
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
311
+
312
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
313
+
314
+ return new_img, trans
315
+
316
+
317
+ def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
318
+ bb_width: float, bb_height: float,
319
+ patch_width: float, patch_height: float,
320
+ do_flip: bool, scale: float, rot: float,
321
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
322
+ """
323
+ Crop the input image and return the crop and the corresponding transformation matrix.
324
+ Args:
325
+ img (np.array): Input image of shape (H, W, 3)
326
+ c_x (float): Bounding box center x coordinate in the original image.
327
+ c_y (float): Bounding box center y coordinate in the original image.
328
+ bb_width (float): Bounding box width.
329
+ bb_height (float): Bounding box height.
330
+ patch_width (float): Output box width.
331
+ patch_height (float): Output box height.
332
+ do_flip (bool): Whether to flip image or not.
333
+ scale (float): Rescaling factor for the bounding box (augmentation).
334
+ rot (float): Random rotation applied to the box.
335
+ Returns:
336
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
337
+ trans (np.array): Transformation matrix.
338
+ """
339
+
340
+ img_height, img_width, img_channels = img.shape
341
+ if do_flip:
342
+ img = img[:, ::-1, :]
343
+ c_x = img_width - c_x - 1
344
+
345
+
346
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
347
+
348
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
349
+ flags=cv2.INTER_LINEAR,
350
+ borderMode=border_mode,
351
+ borderValue=border_value,
352
+ )
353
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
354
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
355
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
356
+ flags=cv2.INTER_LINEAR,
357
+ borderMode=cv2.BORDER_CONSTANT,
358
+ )
359
+
360
+ return img_patch, trans
361
+
362
+
363
+ def convert_cvimg_to_tensor(cvimg: np.array):
364
+ """
365
+ Convert image from HWC to CHW format.
366
+ Args:
367
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
368
+ Returns:
369
+ np.array: Output image of shape (3, H, W).
370
+ """
371
+ # from h,w,c(OpenCV) to c,h,w
372
+ img = cvimg.copy()
373
+ img = np.transpose(img, (2, 0, 1))
374
+ # from int to float
375
+ img = img.astype(np.float32)
376
+ return img
377
+
378
+ def fliplr_params(mano_params: Dict, has_mano_params: Dict) -> Tuple[Dict, Dict]:
379
+ """
380
+ Flip MANO parameters when flipping the image.
381
+ Args:
382
+ mano_params (Dict): MANO parameter annotations.
383
+ has_mano_params (Dict): Whether MANO annotations are valid.
384
+ Returns:
385
+ Dict, Dict: Flipped MANO parameters and valid flags.
386
+ """
387
+ global_orient = mano_params['global_orient'].copy()
388
+ hand_pose = mano_params['hand_pose'].copy()
389
+ betas = mano_params['betas'].copy()
390
+ has_global_orient = has_mano_params['global_orient'].copy()
391
+ has_hand_pose = has_mano_params['hand_pose'].copy()
392
+ has_betas = has_mano_params['betas'].copy()
393
+
394
+ global_orient[1::3] *= -1
395
+ global_orient[2::3] *= -1
396
+ hand_pose[1::3] *= -1
397
+ hand_pose[2::3] *= -1
398
+
399
+ mano_params = {'global_orient': global_orient.astype(np.float32),
400
+ 'hand_pose': hand_pose.astype(np.float32),
401
+ 'betas': betas.astype(np.float32)
402
+ }
403
+
404
+ has_mano_params = {'global_orient': has_global_orient,
405
+ 'hand_pose': has_hand_pose,
406
+ 'betas': has_betas
407
+ }
408
+
409
+ return mano_params, has_mano_params
410
+
411
+
412
+ def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
413
+ """
414
+ Flip 2D or 3D keypoints.
415
+ Args:
416
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
417
+ flip_permutation (List): Permutation to apply after flipping.
418
+ Returns:
419
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
420
+ """
421
+ joints = joints.copy()
422
+ # Flip horizontal
423
+ joints[:, 0] = width - joints[:, 0] - 1
424
+ joints = joints[flip_permutation, :]
425
+
426
+ return joints
427
+
428
+ def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: float) -> np.array:
429
+ """
430
+ Process 3D keypoints (rotation/flipping).
431
+ Args:
432
+ keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
433
+ flip_permutation (List): Permutation to apply after flipping.
434
+ rot (float): Random rotation applied to the keypoints.
435
+ do_flip (bool): Whether to flip keypoints or not.
436
+ Returns:
437
+ np.array: Transformed 3D keypoints with shape (N, 4).
438
+ """
439
+ if do_flip:
440
+ keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
441
+ # in-plane rotation
442
+ rot_mat = np.eye(3)
443
+ if not rot == 0:
444
+ rot_rad = -rot * np.pi / 180
445
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
446
+ rot_mat[0,:2] = [cs, -sn]
447
+ rot_mat[1,:2] = [sn, cs]
448
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
449
+ # flip the x coordinates
450
+ keypoints_3d = keypoints_3d.astype('float32')
451
+ return keypoints_3d
452
+
453
+ def rot_aa(aa: np.array, rot: float) -> np.array:
454
+ """
455
+ Rotate axis angle parameters.
456
+ Args:
457
+ aa (np.array): Axis-angle vector of shape (3,).
458
+ rot (np.array): Rotation angle in degrees.
459
+ Returns:
460
+ np.array: Rotated axis-angle vector.
461
+ """
462
+ # pose parameters
463
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
464
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
465
+ [0, 0, 1]])
466
+ # find the rotation of the hand in camera frame
467
+ per_rdg, _ = cv2.Rodrigues(aa)
468
+ # apply the global rotation to the global orientation
469
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
470
+ aa = (resrot.T)[0]
471
+ return aa.astype(np.float32)
472
+
473
+ def mano_param_processing(mano_params: Dict, has_mano_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
474
+ """
475
+ Apply random augmentations to the MANO parameters.
476
+ Args:
477
+ mano_params (Dict): MANO parameter annotations.
478
+ has_mano_params (Dict): Whether mano annotations are valid.
479
+ rot (float): Random rotation applied to the keypoints.
480
+ do_flip (bool): Whether to flip keypoints or not.
481
+ Returns:
482
+ Dict, Dict: Transformed MANO parameters and valid flags.
483
+ """
484
+ if do_flip:
485
+ mano_params, has_mano_params = fliplr_params(mano_params, has_mano_params)
486
+ mano_params['global_orient'] = rot_aa(mano_params['global_orient'], rot)
487
+ return mano_params, has_mano_params
488
+
489
+
490
+
491
+ def get_example(img_path: str|np.ndarray, center_x: float, center_y: float,
492
+ width: float, height: float,
493
+ keypoints_2d: np.array, keypoints_3d: np.array,
494
+ mano_params: Dict, has_mano_params: Dict,
495
+ flip_kp_permutation: List[int],
496
+ patch_width: int, patch_height: int,
497
+ mean: np.array, std: np.array,
498
+ do_augment: bool, is_right: bool, augm_config: CfgNode,
499
+ is_bgr: bool = True,
500
+ use_skimage_antialias: bool = False,
501
+ border_mode: int = cv2.BORDER_CONSTANT,
502
+ return_trans: bool = False) -> Tuple:
503
+ """
504
+ Get an example from the dataset and (possibly) apply random augmentations.
505
+ Args:
506
+ img_path (str or np.ndarray): Image filename, or an already-loaded image array.
507
+ center_x (float): Bounding box center x coordinate in the original image.
508
+ center_y (float): Bounding box center y coordinate in the original image.
509
+ width (float): Bounding box width.
510
+ height (float): Bounding box height.
511
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
512
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
513
+ mano_params (Dict): MANO parameter annotations.
514
+ has_mano_params (Dict): Whether MANO annotations are valid.
515
+ flip_kp_permutation (List): Permutation to apply to the keypoints after flipping.
516
+ patch_width (float): Output box width.
517
+ patch_height (float): Output box height.
518
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
519
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
520
+ do_augment (bool): Whether to apply data augmentation or not.
521
+ aug_config (CfgNode): Config containing augmentation parameters.
522
+ Returns:
523
+ return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
524
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_width)
525
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
526
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
527
+ mano_params (Dict): Transformed MANO parameters.
528
+ has_mano_params (Dict): Valid flag for transformed MANO parameters.
529
+ img_size (np.array): Image size of the original image.
530
+ """
531
+ if isinstance(img_path, str):
532
+ # 1. load image
533
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
534
+ if not isinstance(cvimg, np.ndarray):
535
+ raise IOError("Fail to read %s" % img_path)
536
+ elif isinstance(img_path, np.ndarray):
537
+ cvimg = img_path
538
+ else:
539
+ raise TypeError('img_path must be either a string or a numpy array')
540
+ img_height, img_width, img_channels = cvimg.shape
541
+
542
+ img_size = np.array([img_height, img_width])
543
+
544
+ # 2. get augmentation params
545
+ if do_augment:
546
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
547
+ else:
548
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.
549
+
550
+ # if it's a left hand, we flip
551
+ if not is_right:
552
+ do_flip = True
553
+
554
+ if width < 1 or height < 1:
555
+ breakpoint()
556
+
557
+ if do_extreme_crop:
558
+ if extreme_crop_lvl == 0:
559
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
560
+ elif extreme_crop_lvl == 1:
561
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
562
+
563
+ THRESH = 4
564
+ if width1 < THRESH or height1 < THRESH:
565
+ # print(f'{do_extreme_crop=}')
566
+ # print(f'width: {width}, height: {height}')
567
+ # print(f'width1: {width1}, height1: {height1}')
568
+ # print(f'center_x: {center_x}, center_y: {center_y}')
569
+ # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
570
+ # print(f'keypoints_2d: {keypoints_2d}')
571
+ # print(f'\n\n', flush=True)
572
+ # breakpoint()
573
+ pass
574
+ # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
575
+ else:
576
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
577
+
578
+ center_x += width * tx
579
+ center_y += height * ty
580
+
581
+ # Process 3D keypoints
582
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)
583
+
584
+ # 3. generate image patch
585
+ if use_skimage_antialias:
586
+ # Blur image to avoid aliasing artifacts
587
+ downsampling_factor = (patch_width / (width*scale))
588
+ if downsampling_factor > 1.1:
589
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0)
590
+
591
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
592
+ center_x, center_y,
593
+ width, height,
594
+ patch_width, patch_height,
595
+ do_flip, scale, rot,
596
+ border_mode=border_mode)
597
+ # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
598
+ # center_x, center_y,
599
+ # width, height,
600
+ # patch_width, patch_height,
601
+ # do_flip, scale, rot,
602
+ # border_mode=border_mode)
603
+
604
+ image = img_patch_cv.copy()
605
+ if is_bgr:
606
+ image = image[:, :, ::-1]
607
+ img_patch_cv = image.copy()
608
+ img_patch = convert_cvimg_to_tensor(image)
609
+
610
+
611
+ mano_params, has_mano_params = mano_param_processing(mano_params, has_mano_params, rot, do_flip)
612
+
613
+ # apply normalization
614
+ for n_c in range(min(img_channels, 3)):
615
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
616
+ if mean is not None and std is not None:
617
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
618
+ if do_flip:
619
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)
620
+
621
+
622
+ for n_jt in range(len(keypoints_2d)):
623
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
624
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
625
+
626
+ if not return_trans:
627
+ return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
628
+ else:
629
+ return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans
630
+
631
+ def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
632
+ """
633
+ Extreme cropping: Crop the box up to the hip locations.
634
+ Args:
635
+ center_x (float): x coordinate of the bounding box center.
636
+ center_y (float): y coordinate of the bounding box center.
637
+ width (float): Bounding box width.
638
+ height (float): Bounding box height.
639
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
640
+ Returns:
641
+ center_x (float): x coordinate of the new bounding box center.
642
+ center_y (float): y coordinate of the new bounding box center.
643
+ width (float): New bounding box width.
644
+ height (float): New bounding box height.
645
+ """
646
+ keypoints_2d = keypoints_2d.copy()
647
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
648
+ keypoints_2d[lower_body_keypoints, :] = 0
649
+ if keypoints_2d[:, -1].sum() > 1:
650
+ center, scale = get_bbox(keypoints_2d)
651
+ center_x = center[0]
652
+ center_y = center[1]
653
+ width = 1.1 * scale[0]
654
+ height = 1.1 * scale[1]
655
+ return center_x, center_y, width, height
656
+
657
+
658
+ def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
659
+ """
660
+ Extreme cropping: Crop the box up to the shoulder locations.
661
+ Args:
662
+ center_x (float): x coordinate of the bounding box center.
663
+ center_y (float): y coordinate of the bounding box center.
664
+ width (float): Bounding box width.
665
+ height (float): Bounding box height.
666
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
667
+ Returns:
668
+ center_x (float): x coordinate of the new bounding box center.
669
+ center_y (float): y coordinate of the new bounding box center.
670
+ width (float): New bounding box width.
671
+ height (float): New bounding box height.
672
+ """
673
+ keypoints_2d = keypoints_2d.copy()
674
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
675
+ keypoints_2d[lower_body_keypoints, :] = 0
676
+ center, scale = get_bbox(keypoints_2d)
677
+ if keypoints_2d[:, -1].sum() > 1:
678
+ center, scale = get_bbox(keypoints_2d)
679
+ center_x = center[0]
680
+ center_y = center[1]
681
+ width = 1.2 * scale[0]
682
+ height = 1.2 * scale[1]
683
+ return center_x, center_y, width, height
684
+
685
+ def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
686
+ """
687
+ Extreme cropping: Crop the box and keep only the head.
688
+ Args:
689
+ center_x (float): x coordinate of the bounding box center.
690
+ center_y (float): y coordinate of the bounding box center.
691
+ width (float): Bounding box width.
692
+ height (float): Bounding box height.
693
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
694
+ Returns:
695
+ center_x (float): x coordinate of the new bounding box center.
696
+ center_y (float): y coordinate of the new bounding box center.
697
+ width (float): New bounding box width.
698
+ height (float): New bounding box height.
699
+ """
700
+ keypoints_2d = keypoints_2d.copy()
701
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
702
+ keypoints_2d[lower_body_keypoints, :] = 0
703
+ if keypoints_2d[:, -1].sum() > 1:
704
+ center, scale = get_bbox(keypoints_2d)
705
+ center_x = center[0]
706
+ center_y = center[1]
707
+ width = 1.3 * scale[0]
708
+ height = 1.3 * scale[1]
709
+ return center_x, center_y, width, height
710
+
711
+ def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
712
+ """
713
+ Extreme cropping: Crop the box and keep only the torso.
714
+ Args:
715
+ center_x (float): x coordinate of the bounding box center.
716
+ center_y (float): y coordinate of the bounding box center.
717
+ width (float): Bounding box width.
718
+ height (float): Bounding box height.
719
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
720
+ Returns:
721
+ center_x (float): x coordinate of the new bounding box center.
722
+ center_y (float): y coordinate of the new bounding box center.
723
+ width (float): New bounding box width.
724
+ height (float): New bounding box height.
725
+ """
726
+ keypoints_2d = keypoints_2d.copy()
727
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
728
+ keypoints_2d[nontorso_body_keypoints, :] = 0
729
+ if keypoints_2d[:, -1].sum() > 1:
730
+ center, scale = get_bbox(keypoints_2d)
731
+ center_x = center[0]
732
+ center_y = center[1]
733
+ width = 1.1 * scale[0]
734
+ height = 1.1 * scale[1]
735
+ return center_x, center_y, width, height
736
+
737
+ def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
738
+ """
739
+ Extreme cropping: Crop the box and keep only the right arm.
740
+ Args:
741
+ center_x (float): x coordinate of the bounding box center.
742
+ center_y (float): y coordinate of the bounding box center.
743
+ width (float): Bounding box width.
744
+ height (float): Bounding box height.
745
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
746
+ Returns:
747
+ center_x (float): x coordinate of the new bounding box center.
748
+ center_y (float): y coordinate of the new bounding box center.
749
+ width (float): New bounding box width.
750
+ height (float): New bounding box height.
751
+ """
752
+ keypoints_2d = keypoints_2d.copy()
753
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
754
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
755
+ if keypoints_2d[:, -1].sum() > 1:
756
+ center, scale = get_bbox(keypoints_2d)
757
+ center_x = center[0]
758
+ center_y = center[1]
759
+ width = 1.1 * scale[0]
760
+ height = 1.1 * scale[1]
761
+ return center_x, center_y, width, height
762
+
763
+ def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
764
+ """
765
+ Extreme cropping: Crop the box and keep only the left arm.
766
+ Args:
767
+ center_x (float): x coordinate of the bounding box center.
768
+ center_y (float): y coordinate of the bounding box center.
769
+ width (float): Bounding box width.
770
+ height (float): Bounding box height.
771
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
772
+ Returns:
773
+ center_x (float): x coordinate of the new bounding box center.
774
+ center_y (float): y coordinate of the new bounding box center.
775
+ width (float): New bounding box width.
776
+ height (float): New bounding box height.
777
+ """
778
+ keypoints_2d = keypoints_2d.copy()
779
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
780
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
781
+ if keypoints_2d[:, -1].sum() > 1:
782
+ center, scale = get_bbox(keypoints_2d)
783
+ center_x = center[0]
784
+ center_y = center[1]
785
+ width = 1.1 * scale[0]
786
+ height = 1.1 * scale[1]
787
+ return center_x, center_y, width, height
788
+
789
+ def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
790
+ """
791
+ Extreme cropping: Crop the box and keep only the legs.
792
+ Args:
793
+ center_x (float): x coordinate of the bounding box center.
794
+ center_y (float): y coordinate of the bounding box center.
795
+ width (float): Bounding box width.
796
+ height (float): Bounding box height.
797
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
798
+ Returns:
799
+ center_x (float): x coordinate of the new bounding box center.
800
+ center_y (float): y coordinate of the new bounding box center.
801
+ width (float): New bounding box width.
802
+ height (float): New bounding box height.
803
+ """
804
+ keypoints_2d = keypoints_2d.copy()
805
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
806
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
807
+ if keypoints_2d[:, -1].sum() > 1:
808
+ center, scale = get_bbox(keypoints_2d)
809
+ center_x = center[0]
810
+ center_y = center[1]
811
+ width = 1.1 * scale[0]
812
+ height = 1.1 * scale[1]
813
+ return center_x, center_y, width, height
814
+
815
+ def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
816
+ """
817
+ Extreme cropping: Crop the box and keep only the right leg.
818
+ Args:
819
+ center_x (float): x coordinate of the bounding box center.
820
+ center_y (float): y coordinate of the bounding box center.
821
+ width (float): Bounding box width.
822
+ height (float): Bounding box height.
823
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
824
+ Returns:
825
+ center_x (float): x coordinate of the new bounding box center.
826
+ center_y (float): y coordinate of the new bounding box center.
827
+ width (float): New bounding box width.
828
+ height (float): New bounding box height.
829
+ """
830
+ keypoints_2d = keypoints_2d.copy()
831
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
832
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
833
+ if keypoints_2d[:, -1].sum() > 1:
834
+ center, scale = get_bbox(keypoints_2d)
835
+ center_x = center[0]
836
+ center_y = center[1]
837
+ width = 1.1 * scale[0]
838
+ height = 1.1 * scale[1]
839
+ return center_x, center_y, width, height
840
+
841
+ def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
842
+ """
843
+ Extreme cropping: Crop the box and keep only the left leg.
844
+ Args:
845
+ center_x (float): x coordinate of the bounding box center.
846
+ center_y (float): y coordinate of the bounding box center.
847
+ width (float): Bounding box width.
848
+ height (float): Bounding box height.
849
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
850
+ Returns:
851
+ center_x (float): x coordinate of the new bounding box center.
852
+ center_y (float): y coordinate of the new bounding box center.
853
+ width (float): New bounding box width.
854
+ height (float): New bounding box height.
855
+ """
856
+ keypoints_2d = keypoints_2d.copy()
857
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
858
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
859
+ if keypoints_2d[:, -1].sum() > 1:
860
+ center, scale = get_bbox(keypoints_2d)
861
+ center_x = center[0]
862
+ center_y = center[1]
863
+ width = 1.1 * scale[0]
864
+ height = 1.1 * scale[1]
865
+ return center_x, center_y, width, height
866
+
867
+ def full_body(keypoints_2d: np.array) -> bool:
868
+ """
869
+ Check if all main body joints are visible.
870
+ Args:
871
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
872
+ Returns:
873
+ bool: True if all main body joints are visible.
874
+ """
875
+
876
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
877
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
878
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
879
+
880
+ def upper_body(keypoints_2d: np.array):
881
+ """
882
+ Check if all upper body joints are visible.
883
+ Args:
884
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
885
+ Returns:
886
+ bool: True if only the upper body joints are visible.
887
+ """
888
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
889
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
890
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
891
+ upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
892
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
893
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
894
+
895
+ def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
896
+ """
897
+ Get center and scale for bounding box from openpose detections.
898
+ Args:
899
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
900
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
901
+ Returns:
902
+ center (np.array): Array of shape (2,) containing the new bounding box center.
903
+ scale (np.array): New bounding box scale of shape (2,), i.e. (width, height).
904
+ """
905
+ valid = keypoints_2d[:,-1] > 0
906
+ valid_keypoints = keypoints_2d[valid][:,:-1]
907
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
908
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
909
+ # adjust bounding box tightness
910
+ scale = bbox_size
911
+ scale *= rescale
912
+ return center, scale
913
+
914
+ def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
915
+ """
916
+ Perform extreme cropping
917
+ Args:
918
+ center_x (float): x coordinate of bounding box center.
919
+ center_y (float): y coordinate of bounding box center.
920
+ width (float): bounding box width.
921
+ height (float): bounding box height.
922
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
923
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
924
+ Returns:
925
+ center_x (float): x coordinate of bounding box center.
926
+ center_y (float): y coordinate of bounding box center.
927
+ width (float): bounding box width.
928
+ height (float): bounding box height.
929
+ """
930
+ p = torch.rand(1).item()
931
+ if full_body(keypoints_2d):
932
+ if p < 0.7:
933
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
934
+ elif p < 0.9:
935
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
936
+ else:
937
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
938
+ elif upper_body(keypoints_2d):
939
+ if p < 0.9:
940
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
941
+ else:
942
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
943
+
944
+ return center_x, center_y, max(width, height), max(width, height)
945
+
946
+ def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
947
+ """
948
+ Perform aggressive extreme cropping
949
+ Args:
950
+ center_x (float): x coordinate of bounding box center.
951
+ center_y (float): y coordinate of bounding box center.
952
+ width (float): bounding box width.
953
+ height (float): bounding box height.
954
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
955
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
956
+ Returns:
957
+ center_x (float): x coordinate of bounding box center.
958
+ center_y (float): y coordinate of bounding box center.
959
+ width (float): bounding box width.
960
+ height (float): bounding box height.
961
+ """
962
+ p = torch.rand(1).item()
963
+ if full_body(keypoints_2d):
964
+ if p < 0.2:
965
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
966
+ elif p < 0.3:
967
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
968
+ elif p < 0.4:
969
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
970
+ elif p < 0.5:
971
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
972
+ elif p < 0.6:
973
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
974
+ elif p < 0.7:
975
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
976
+ elif p < 0.8:
977
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
978
+ elif p < 0.9:
979
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
980
+ else:
981
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
982
+ elif upper_body(keypoints_2d):
983
+ if p < 0.2:
984
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
985
+ elif p < 0.4:
986
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
987
+ elif p < 0.6:
988
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
989
+ elif p < 0.8:
990
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
991
+ else:
992
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
993
+ return center_x, center_y, max(width, height), max(width, height)
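
To tie the helpers in this file together, here is a minimal editorial sketch of a single get_example call. The image path, box, zero keypoint placeholders, and identity flip permutation are illustrative assumptions (no annotations are required just to produce a normalized crop); the mean/std values match the ImageNet statistics used by the dataset classes in this repo.

import numpy as np
from hamer.datasets.utils import get_example

mean = 255. * np.array([0.485, 0.456, 0.406])
std = 255. * np.array([0.229, 0.224, 0.225])
kp2d = np.zeros((21, 3), dtype=np.float32)   # no 2D keypoint annotations
kp3d = np.zeros((21, 4), dtype=np.float32)   # no 3D keypoint annotations
mano = {'global_orient': np.zeros(3, dtype=np.float32),
        'hand_pose': np.zeros(45, dtype=np.float32),
        'betas': np.zeros(10, dtype=np.float32)}
has_mano = {k: np.array(0., dtype=np.float32) for k in mano}
identity_perm = list(range(21))              # placeholder flip permutation

img_patch, kp2d_c, kp3d_c, mano_c, has_c, img_size = get_example(
    'some_frame.jpg',                        # placeholder image path
    center_x=200., center_y=150., width=180., height=180.,
    keypoints_2d=kp2d, keypoints_3d=kp3d, mano_params=mano, has_mano_params=has_mano,
    flip_kp_permutation=identity_perm, patch_width=256, patch_height=256,
    mean=mean, std=std, do_augment=False, is_right=True, augm_config=None)
print(img_patch.shape)                       # (3, 256, 256): cropped, channel-first, normalized
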
hamer/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,97 @@
1
+ from typing import Dict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.filters import gaussian
6
+ from yacs.config import CfgNode
7
+ import torch
8
+
9
+ from .utils import (convert_cvimg_to_tensor,
10
+ expand_to_aspect_ratio,
11
+ generate_image_patch_cv2)
12
+
13
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
14
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
15
+
16
+ class ViTDetDataset(torch.utils.data.Dataset):
17
+
18
+ def __init__(self,
19
+ cfg: CfgNode,
20
+ img_cv2: np.array,
21
+ boxes: np.array,
22
+ right: np.array,
23
+ rescale_factor=2.5,
24
+ train: bool = False,
25
+ **kwargs):
26
+ super().__init__()
27
+ self.cfg = cfg
28
+ self.img_cv2 = img_cv2
29
+ # self.boxes = boxes
30
+
31
+ assert train == False, "ViTDetDataset is only for inference"
32
+ self.train = train
33
+ self.img_size = cfg.MODEL.IMAGE_SIZE
34
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
35
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
36
+
37
+ # Preprocess annotations
38
+ boxes = boxes.astype(np.float32)
39
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
40
+ self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
41
+ #self.scale = (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
42
+ self.personid = np.arange(len(boxes), dtype=np.int32)
43
+ self.right = right.astype(np.float32)
44
+
45
+ def __len__(self) -> int:
46
+ return len(self.personid)
47
+
48
+ def __getitem__(self, idx: int) -> Dict[str, np.array]:
49
+
50
+ center = self.center[idx].copy()
51
+ center_x = center[0]
52
+ center_y = center[1]
53
+
54
+ scale = self.scale[idx]
55
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
56
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
57
+ #bbox_size = scale.max()*200
58
+
59
+ patch_width = patch_height = self.img_size
60
+
61
+ right = self.right[idx].copy()
62
+ flip = right == 0
63
+
64
+ # 3. generate image patch
65
+ # if use_skimage_antialias:
66
+ cvimg = self.img_cv2.copy()
67
+ if True:
68
+ # Blur image to avoid aliasing artifacts
69
+ downsampling_factor = ((bbox_size*1.0) / patch_width)
70
+ print(f'{downsampling_factor=}')
71
+ downsampling_factor = downsampling_factor / 2.0
72
+ if downsampling_factor > 1.1:
73
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
74
+
75
+
76
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
77
+ center_x, center_y,
78
+ bbox_size, bbox_size,
79
+ patch_width, patch_height,
80
+ flip, 1.0, 0,
81
+ border_mode=cv2.BORDER_CONSTANT)
82
+ img_patch_cv = img_patch_cv[:, :, ::-1]
83
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
84
+
85
+ # apply normalization
86
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
87
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
88
+
89
+ item = {
90
+ 'img': img_patch,
91
+ 'personid': int(self.personid[idx]),
92
+ }
93
+ item['box_center'] = self.center[idx].copy()
94
+ item['box_size'] = bbox_size
95
+ item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
96
+ item['right'] = self.right[idx].copy()
97
+ return item
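
A brief editorial inference sketch for the dataset above. The image path and the single hand box are placeholder values; in the actual pipeline the boxes and the left/right flags would typically come from an upstream detector and keypoint model, and model_cfg comes from load_hamer (defined in hamer/models/__init__.py below).

import cv2
import numpy as np
import torch
from hamer.models import load_hamer, DEFAULT_CHECKPOINT
from hamer.datasets.vitdet_dataset import ViTDetDataset

model, model_cfg = load_hamer(DEFAULT_CHECKPOINT)
img_cv2 = cv2.imread('some_frame.jpg')         # BGR image (placeholder path)
boxes = np.array([[300., 200., 500., 400.]])   # one [x1, y1, x2, y2] hand box (placeholder values)
right = np.array([1.])                         # 1 = right hand; 0 = left hand, which is flipped internally

dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
for batch in loader:
    print(batch['img'].shape, batch['box_center'], batch['right'])
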
hamer/models/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ from .mano_wrapper import MANO
2
+ from .hamer import HAMER
3
+ from .discriminator import Discriminator
4
+
5
+ from ..utils.download import cache_url
6
+ from ..configs import CACHE_DIR_HAMER
7
+
8
+
9
+ def download_models(folder=CACHE_DIR_HAMER):
10
+ """Download checkpoints and files for running inference.
11
+ """
12
+ import os
13
+ os.makedirs(folder, exist_ok=True)
14
+ download_files = {
15
+ "hamer_data.tar.gz" : ["https://people.eecs.berkeley.edu/~jathushan/projects/4dhumans/hamer_data.tar.gz", folder],
16
+ }
17
+
18
+ for file_name, url in download_files.items():
19
+ output_path = os.path.join(url[1], file_name)
20
+ if not os.path.exists(output_path):
21
+ print("Downloading file: " + file_name)
22
+ # output = gdown.cached_download(url[0], output_path, fuzzy=True)
23
+ output = cache_url(url[0], output_path)
24
+ assert os.path.exists(output_path), f"{output} does not exist"
25
+
26
+ # if ends with tar.gz, tar -xzf
27
+ if file_name.endswith(".tar.gz"):
28
+ print("Extracting file: " + file_name)
29
+ os.system("tar -xvf " + output_path + " -C " + url[1])
30
+
31
+ DEFAULT_CHECKPOINT=f'{CACHE_DIR_HAMER}/hamer_ckpts/checkpoints/hamer.ckpt'
32
+ def load_hamer(checkpoint_path=DEFAULT_CHECKPOINT):
33
+ from pathlib import Path
34
+ from ..configs import get_config
35
+ model_cfg = str(Path(checkpoint_path).parent.parent / 'model_config.yaml')
36
+ model_cfg = get_config(model_cfg, update_cachedir=True)
37
+
38
+ # Override some config values, to crop bbox correctly
39
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
40
+ model_cfg.defrost()
41
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
42
+ model_cfg.MODEL.BBOX_SHAPE = [192,256]
43
+ model_cfg.freeze()
44
+
45
+ model = HAMER.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
46
+ return model, model_cfg
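
A minimal sketch of the entry point defined above; nothing outside this file is assumed. Whether the download step is needed depends on whether the checkpoint archive is already present under CACHE_DIR_HAMER.

from hamer.models import download_models, load_hamer, DEFAULT_CHECKPOINT

download_models()                                  # download and extract hamer_data.tar.gz into CACHE_DIR_HAMER
model, model_cfg = load_hamer(DEFAULT_CHECKPOINT)  # resolves model_config.yaml two directories above the checkpoint
model = model.eval()                               # inference mode; the weights are loaded via load_from_checkpoint
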
hamer/models/backbones/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .vit import vit
2
+
3
+ def create_backbone(cfg):
4
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
5
+ return vit(cfg)
6
+ else:
7
+ raise NotImplementedError('Backbone type is not implemented')
hamer/models/backbones/vit.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ from functools import partial
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
11
+
12
+ def vit(cfg):
13
+ return ViT(
14
+ img_size=(256, 192),
15
+ patch_size=16,
16
+ embed_dim=1280,
17
+ depth=32,
18
+ num_heads=16,
19
+ ratio=1,
20
+ use_checkpoint=False,
21
+ mlp_ratio=4,
22
+ qkv_bias=True,
23
+ drop_path_rate=0.55,
24
+ )
25
+
26
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
27
+ """
28
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
29
+ dimension for the original embeddings.
30
+ Args:
31
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
32
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
33
+ h, w (int): Target size of the token grid; ori_h, ori_w (int): original size of the token grid.
34
+
35
+ Returns:
36
+ Absolute positional embeddings after processing with shape (1, H, W, C)
37
+ """
38
+ cls_token = None
39
+ B, L, C = abs_pos.shape
40
+ if has_cls_token:
41
+ cls_token = abs_pos[:, 0:1]
42
+ abs_pos = abs_pos[:, 1:]
43
+
44
+ if ori_h != h or ori_w != w:
45
+ new_abs_pos = F.interpolate(
46
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
47
+ size=(h, w),
48
+ mode="bicubic",
49
+ align_corners=False,
50
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
51
+
52
+ else:
53
+ new_abs_pos = abs_pos
54
+
55
+ if cls_token is not None:
56
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
57
+ return new_abs_pos
58
+
59
+ class DropPath(nn.Module):
60
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61
+ """
62
+ def __init__(self, drop_prob=None):
63
+ super(DropPath, self).__init__()
64
+ self.drop_prob = drop_prob
65
+
66
+ def forward(self, x):
67
+ return drop_path(x, self.drop_prob, self.training)
68
+
69
+ def extra_repr(self):
70
+ return 'p={}'.format(self.drop_prob)
71
+
72
+ class Mlp(nn.Module):
73
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
74
+ super().__init__()
75
+ out_features = out_features or in_features
76
+ hidden_features = hidden_features or in_features
77
+ self.fc1 = nn.Linear(in_features, hidden_features)
78
+ self.act = act_layer()
79
+ self.fc2 = nn.Linear(hidden_features, out_features)
80
+ self.drop = nn.Dropout(drop)
81
+
82
+ def forward(self, x):
83
+ x = self.fc1(x)
84
+ x = self.act(x)
85
+ x = self.fc2(x)
86
+ x = self.drop(x)
87
+ return x
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(
91
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
92
+ proj_drop=0., attn_head_dim=None,):
93
+ super().__init__()
94
+ self.num_heads = num_heads
95
+ head_dim = dim // num_heads
96
+ self.dim = dim
97
+
98
+ if attn_head_dim is not None:
99
+ head_dim = attn_head_dim
100
+ all_head_dim = head_dim * self.num_heads
101
+
102
+ self.scale = qk_scale or head_dim ** -0.5
103
+
104
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
105
+
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(all_head_dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x):
111
+ B, N, C = x.shape
112
+ qkv = self.qkv(x)
113
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
114
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
115
+
116
+ q = q * self.scale
117
+ attn = (q @ k.transpose(-2, -1))
118
+
119
+ attn = attn.softmax(dim=-1)
120
+ attn = self.attn_drop(attn)
121
+
122
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
123
+ x = self.proj(x)
124
+ x = self.proj_drop(x)
125
+
126
+ return x
127
+
128
+ class Block(nn.Module):
129
+
130
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
131
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
132
+ norm_layer=nn.LayerNorm, attn_head_dim=None
133
+ ):
134
+ super().__init__()
135
+
136
+ self.norm1 = norm_layer(dim)
137
+ self.attn = Attention(
138
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
139
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
140
+ )
141
+
142
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
143
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
144
+ self.norm2 = norm_layer(dim)
145
+ mlp_hidden_dim = int(dim * mlp_ratio)
146
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
147
+
148
+ def forward(self, x):
149
+ x = x + self.drop_path(self.attn(self.norm1(x)))
150
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
151
+ return x
152
+
153
+
154
+ class PatchEmbed(nn.Module):
155
+ """ Image to Patch Embedding
156
+ """
157
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
158
+ super().__init__()
159
+ img_size = to_2tuple(img_size)
160
+ patch_size = to_2tuple(patch_size)
161
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
162
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
163
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
164
+ self.img_size = img_size
165
+ self.patch_size = patch_size
166
+ self.num_patches = num_patches
167
+
168
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
169
+
170
+ def forward(self, x, **kwargs):
171
+ B, C, H, W = x.shape
172
+ x = self.proj(x)
173
+ Hp, Wp = x.shape[2], x.shape[3]
174
+
175
+ x = x.flatten(2).transpose(1, 2)
176
+ return x, (Hp, Wp)
177
+
178
+
179
+ class HybridEmbed(nn.Module):
180
+ """ CNN Feature Map Embedding
181
+ Extract feature map from CNN, flatten, project to embedding dim.
182
+ """
183
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
184
+ super().__init__()
185
+ assert isinstance(backbone, nn.Module)
186
+ img_size = to_2tuple(img_size)
187
+ self.img_size = img_size
188
+ self.backbone = backbone
189
+ if feature_size is None:
190
+ with torch.no_grad():
191
+ training = backbone.training
192
+ if training:
193
+ backbone.eval()
194
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
195
+ feature_size = o.shape[-2:]
196
+ feature_dim = o.shape[1]
197
+ backbone.train(training)
198
+ else:
199
+ feature_size = to_2tuple(feature_size)
200
+ feature_dim = self.backbone.feature_info.channels()[-1]
201
+ self.num_patches = feature_size[0] * feature_size[1]
202
+ self.proj = nn.Linear(feature_dim, embed_dim)
203
+
204
+ def forward(self, x):
205
+ x = self.backbone(x)[-1]
206
+ x = x.flatten(2).transpose(1, 2)
207
+ x = self.proj(x)
208
+ return x
209
+
210
+
211
+ class ViT(nn.Module):
212
+
213
+ def __init__(self,
214
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
215
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
216
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
217
+ frozen_stages=-1, ratio=1, last_norm=True,
218
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
219
+ ):
220
+ # Protect mutable default arguments
221
+ super(ViT, self).__init__()
222
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
223
+ self.num_classes = num_classes
224
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
225
+ self.frozen_stages = frozen_stages
226
+ self.use_checkpoint = use_checkpoint
227
+ self.patch_padding = patch_padding
228
+ self.freeze_attn = freeze_attn
229
+ self.freeze_ffn = freeze_ffn
230
+ self.depth = depth
231
+
232
+ if hybrid_backbone is not None:
233
+ self.patch_embed = HybridEmbed(
234
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
235
+ else:
236
+ self.patch_embed = PatchEmbed(
237
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
238
+ num_patches = self.patch_embed.num_patches
239
+
240
+ # since the pretraining model has class token
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
242
+
243
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
244
+
245
+ self.blocks = nn.ModuleList([
246
+ Block(
247
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
248
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
249
+ )
250
+ for i in range(depth)])
251
+
252
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
253
+
254
+ if self.pos_embed is not None:
255
+ trunc_normal_(self.pos_embed, std=.02)
256
+
257
+ self._freeze_stages()
258
+
259
+ def _freeze_stages(self):
260
+ """Freeze parameters."""
261
+ if self.frozen_stages >= 0:
262
+ self.patch_embed.eval()
263
+ for param in self.patch_embed.parameters():
264
+ param.requires_grad = False
265
+
266
+ for i in range(1, self.frozen_stages + 1):
267
+ m = self.blocks[i]
268
+ m.eval()
269
+ for param in m.parameters():
270
+ param.requires_grad = False
271
+
272
+ if self.freeze_attn:
273
+ for i in range(0, self.depth):
274
+ m = self.blocks[i]
275
+ m.attn.eval()
276
+ m.norm1.eval()
277
+ for param in m.attn.parameters():
278
+ param.requires_grad = False
279
+ for param in m.norm1.parameters():
280
+ param.requires_grad = False
281
+
282
+ if self.freeze_ffn:
283
+ self.pos_embed.requires_grad = False
284
+ self.patch_embed.eval()
285
+ for param in self.patch_embed.parameters():
286
+ param.requires_grad = False
287
+ for i in range(0, self.depth):
288
+ m = self.blocks[i]
289
+ m.mlp.eval()
290
+ m.norm2.eval()
291
+ for param in m.mlp.parameters():
292
+ param.requires_grad = False
293
+ for param in m.norm2.parameters():
294
+ param.requires_grad = False
295
+
296
+ def init_weights(self):
297
+ """Initialize the weights in backbone.
298
+ Args:
299
+ pretrained (str, optional): Path to pre-trained weights.
300
+ Defaults to None.
301
+ """
302
+ def _init_weights(m):
303
+ if isinstance(m, nn.Linear):
304
+ trunc_normal_(m.weight, std=.02)
305
+ if isinstance(m, nn.Linear) and m.bias is not None:
306
+ nn.init.constant_(m.bias, 0)
307
+ elif isinstance(m, nn.LayerNorm):
308
+ nn.init.constant_(m.bias, 0)
309
+ nn.init.constant_(m.weight, 1.0)
310
+
311
+ self.apply(_init_weights)
312
+
313
+ def get_num_layers(self):
314
+ return len(self.blocks)
315
+
316
+ @torch.jit.ignore
317
+ def no_weight_decay(self):
318
+ return {'pos_embed', 'cls_token'}
319
+
320
+ def forward_features(self, x):
321
+ B, C, H, W = x.shape
322
+ x, (Hp, Wp) = self.patch_embed(x)
323
+
324
+ if self.pos_embed is not None:
325
+ # written this way for compatibility with multi-GPU training
326
+ # since the first (cls-token) element of the sin-cos pos embed is zero, adding it separately makes no difference
327
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
328
+
329
+ for blk in self.blocks:
330
+ if self.use_checkpoint:
331
+ x = checkpoint.checkpoint(blk, x)
332
+ else:
333
+ x = blk(x)
334
+
335
+ x = self.last_norm(x)
336
+
337
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
338
+
339
+ return xp
340
+
341
+ def forward(self, x):
342
+ x = self.forward_features(x)
343
+ return x
344
+
345
+ def train(self, mode=True):
346
+ """Convert the model into training mode."""
347
+ super().train(mode)
348
+ self._freeze_stages()
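As a shape sanity check (a sketch only; depth is reduced here to keep it lightweight, whereas vit() above builds the full 32-layer ViT-H):

import torch
from hamer.models.backbones.vit import ViT

backbone = ViT(img_size=(256, 192), patch_size=16, embed_dim=1280, depth=2, num_heads=16)
x = torch.randn(1, 3, 256, 192)   # the 256x192 hand crop fed to the backbone in hamer.py
feat = backbone(x)
print(feat.shape)                 # torch.Size([1, 1280, 16, 12]): a dense 16x12 feature map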
hamer/models/components/__init__.py ADDED
File without changes
hamer/models/components/pose_transformer.py ADDED
@@ -0,0 +1,358 @@
1
+ from inspect import isfunction
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+ from .t_cond_mlp import (
10
+ AdaptiveLayerNorm1D,
11
+ FrequencyEmbedder,
12
+ normalization_layer,
13
+ )
14
+ # from .vit import Attention, FeedForward
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if isfunction(d) else d
25
+
26
+
27
+ class PreNorm(nn.Module):
28
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
29
+ super().__init__()
30
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
31
+ self.fn = fn
32
+
33
+ def forward(self, x: torch.Tensor, *args, **kwargs):
34
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
35
+ return self.fn(self.norm(x, *args), **kwargs)
36
+ else:
37
+ return self.fn(self.norm(x), **kwargs)
38
+
39
+
40
+ class FeedForward(nn.Module):
41
+ def __init__(self, dim, hidden_dim, dropout=0.0):
42
+ super().__init__()
43
+ self.net = nn.Sequential(
44
+ nn.Linear(dim, hidden_dim),
45
+ nn.GELU(),
46
+ nn.Dropout(dropout),
47
+ nn.Linear(hidden_dim, dim),
48
+ nn.Dropout(dropout),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
57
+ super().__init__()
58
+ inner_dim = dim_head * heads
59
+ project_out = not (heads == 1 and dim_head == dim)
60
+
61
+ self.heads = heads
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.attend = nn.Softmax(dim=-1)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
68
+
69
+ self.to_out = (
70
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
71
+ if project_out
72
+ else nn.Identity()
73
+ )
74
+
75
+ def forward(self, x):
76
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
77
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
78
+
79
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
80
+
81
+ attn = self.attend(dots)
82
+ attn = self.dropout(attn)
83
+
84
+ out = torch.matmul(attn, v)
85
+ out = rearrange(out, "b h n d -> b n (h d)")
86
+ return self.to_out(out)
87
+
88
+
89
+ class CrossAttention(nn.Module):
90
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
91
+ super().__init__()
92
+ inner_dim = dim_head * heads
93
+ project_out = not (heads == 1 and dim_head == dim)
94
+
95
+ self.heads = heads
96
+ self.scale = dim_head**-0.5
97
+
98
+ self.attend = nn.Softmax(dim=-1)
99
+ self.dropout = nn.Dropout(dropout)
100
+
101
+ context_dim = default(context_dim, dim)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
103
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
104
+
105
+ self.to_out = (
106
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
107
+ if project_out
108
+ else nn.Identity()
109
+ )
110
+
111
+ def forward(self, x, context=None):
112
+ context = default(context, x)
113
+ k, v = self.to_kv(context).chunk(2, dim=-1)
114
+ q = self.to_q(x)
115
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
116
+
117
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
118
+
119
+ attn = self.attend(dots)
120
+ attn = self.dropout(attn)
121
+
122
+ out = torch.matmul(attn, v)
123
+ out = rearrange(out, "b h n d -> b n (h d)")
124
+ return self.to_out(out)
125
+
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class TransformerCrossAttn(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ depth: int,
165
+ heads: int,
166
+ dim_head: int,
167
+ mlp_dim: int,
168
+ dropout: float = 0.0,
169
+ norm: str = "layer",
170
+ norm_cond_dim: int = -1,
171
+ context_dim: Optional[int] = None,
172
+ ):
173
+ super().__init__()
174
+ self.layers = nn.ModuleList([])
175
+ for _ in range(depth):
176
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
177
+ ca = CrossAttention(
178
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
179
+ )
180
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
181
+ self.layers.append(
182
+ nn.ModuleList(
183
+ [
184
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
185
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
186
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
187
+ ]
188
+ )
189
+ )
190
+
191
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
192
+ if context_list is None:
193
+ context_list = [context] * len(self.layers)
194
+ if len(context_list) != len(self.layers):
195
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
196
+
197
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
198
+ x = self_attn(x, *args) + x
199
+ x = cross_attn(x, *args, context=context_list[i]) + x
200
+ x = ff(x, *args) + x
201
+ return x
202
+
203
+
204
+ class DropTokenDropout(nn.Module):
205
+ def __init__(self, p: float = 0.1):
206
+ super().__init__()
207
+ if p < 0 or p > 1:
208
+ raise ValueError(
209
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
210
+ )
211
+ self.p = p
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ # x: (batch_size, seq_len, dim)
215
+ if self.training and self.p > 0:
216
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
217
+ # TODO: permutation idx for each batch using torch.argsort
218
+ if zero_mask.any():
219
+ x = x[:, ~zero_mask, :]
220
+ return x
221
+
222
+
223
+ class ZeroTokenDropout(nn.Module):
224
+ def __init__(self, p: float = 0.1):
225
+ super().__init__()
226
+ if p < 0 or p > 1:
227
+ raise ValueError(
228
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
229
+ )
230
+ self.p = p
231
+
232
+ def forward(self, x: torch.Tensor):
233
+ # x: (batch_size, seq_len, dim)
234
+ if self.training and self.p > 0:
235
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
236
+ # Zero-out the masked tokens
237
+ x[zero_mask, :] = 0
238
+ return x
239
+
240
+
241
+ class TransformerEncoder(nn.Module):
242
+ def __init__(
243
+ self,
244
+ num_tokens: int,
245
+ token_dim: int,
246
+ dim: int,
247
+ depth: int,
248
+ heads: int,
249
+ mlp_dim: int,
250
+ dim_head: int = 64,
251
+ dropout: float = 0.0,
252
+ emb_dropout: float = 0.0,
253
+ emb_dropout_type: str = "drop",
254
+ emb_dropout_loc: str = "token",
255
+ norm: str = "layer",
256
+ norm_cond_dim: int = -1,
257
+ token_pe_numfreq: int = -1,
258
+ ):
259
+ super().__init__()
260
+ if token_pe_numfreq > 0:
261
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
262
+ self.to_token_embedding = nn.Sequential(
263
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
264
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
265
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
266
+ nn.Linear(token_dim_new, dim),
267
+ )
268
+ else:
269
+ self.to_token_embedding = nn.Linear(token_dim, dim)
270
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
271
+ if emb_dropout_type == "drop":
272
+ self.dropout = DropTokenDropout(emb_dropout)
273
+ elif emb_dropout_type == "zero":
274
+ self.dropout = ZeroTokenDropout(emb_dropout)
275
+ else:
276
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
277
+ self.emb_dropout_loc = emb_dropout_loc
278
+
279
+ self.transformer = Transformer(
280
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
281
+ )
282
+
283
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
284
+ x = inp
285
+
286
+ if self.emb_dropout_loc == "input":
287
+ x = self.dropout(x)
288
+ x = self.to_token_embedding(x)
289
+
290
+ if self.emb_dropout_loc == "token":
291
+ x = self.dropout(x)
292
+ b, n, _ = x.shape
293
+ x += self.pos_embedding[:, :n]
294
+
295
+ if self.emb_dropout_loc == "token_afterpos":
296
+ x = self.dropout(x)
297
+ x = self.transformer(x, *args)
298
+ return x
299
+
300
+
301
+ class TransformerDecoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ num_tokens: int,
305
+ token_dim: int,
306
+ dim: int,
307
+ depth: int,
308
+ heads: int,
309
+ mlp_dim: int,
310
+ dim_head: int = 64,
311
+ dropout: float = 0.0,
312
+ emb_dropout: float = 0.0,
313
+ emb_dropout_type: str = 'drop',
314
+ norm: str = "layer",
315
+ norm_cond_dim: int = -1,
316
+ context_dim: Optional[int] = None,
317
+ skip_token_embedding: bool = False,
318
+ ):
319
+ super().__init__()
320
+ if not skip_token_embedding:
321
+ self.to_token_embedding = nn.Linear(token_dim, dim)
322
+ else:
323
+ self.to_token_embedding = nn.Identity()
324
+ if token_dim != dim:
325
+ raise ValueError(
326
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
327
+ )
328
+
329
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
330
+ if emb_dropout_type == "drop":
331
+ self.dropout = DropTokenDropout(emb_dropout)
332
+ elif emb_dropout_type == "zero":
333
+ self.dropout = ZeroTokenDropout(emb_dropout)
334
+ elif emb_dropout_type == "normal":
335
+ self.dropout = nn.Dropout(emb_dropout)
336
+
337
+ self.transformer = TransformerCrossAttn(
338
+ dim,
339
+ depth,
340
+ heads,
341
+ dim_head,
342
+ mlp_dim,
343
+ dropout,
344
+ norm=norm,
345
+ norm_cond_dim=norm_cond_dim,
346
+ context_dim=context_dim,
347
+ )
348
+
349
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
350
+ x = self.to_token_embedding(inp)
351
+ b, n, _ = x.shape
352
+
353
+ x = self.dropout(x)
354
+ x += self.pos_embedding[:, :n]
355
+
356
+ x = self.transformer(x, *args, context=context, context_list=context_list)
357
+ return x
358
+
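A minimal usage sketch of TransformerDecoder (the dimensions are illustrative assumptions; the actual MANO head configuration comes from the model config, not this file): a set of tokens cross-attends to image features passed as context.

import torch
from hamer.models.components.pose_transformer import TransformerDecoder

decoder = TransformerDecoder(
    num_tokens=1, token_dim=1, dim=1024,
    depth=2, heads=8, dim_head=64, mlp_dim=1024,
    context_dim=1280,                 # e.g. the ViT-H feature dimension
)
tokens = torch.zeros(2, 1, 1)         # (B, num_tokens, token_dim)
context = torch.randn(2, 192, 1280)   # (B, N, context_dim), e.g. flattened 16x12 ViT features
out = decoder(tokens, context=context)
print(out.shape)                      # torch.Size([2, 1, 1024])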
hamer/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,199 @@
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+
7
+ class AdaptiveLayerNorm1D(torch.nn.Module):
8
+ def __init__(self, data_dim: int, norm_cond_dim: int):
9
+ super().__init__()
10
+ if data_dim <= 0:
11
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
12
+ if norm_cond_dim <= 0:
13
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
14
+ self.norm = torch.nn.LayerNorm(
15
+ data_dim
16
+ ) # TODO: Check if elementwise_affine=True is correct
17
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
18
+ torch.nn.init.zeros_(self.linear.weight)
19
+ torch.nn.init.zeros_(self.linear.bias)
20
+
21
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
22
+ # x: (batch, ..., data_dim)
23
+ # t: (batch, norm_cond_dim)
24
+ # return: (batch, data_dim)
25
+ x = self.norm(x)
26
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
27
+
28
+ # Add singleton dimensions to alpha and beta
29
+ if x.dim() > 2:
30
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
31
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
32
+
33
+ return x * (1 + alpha) + beta
34
+
35
+
36
+ class SequentialCond(torch.nn.Sequential):
37
+ def forward(self, input, *args, **kwargs):
38
+ for module in self:
39
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
40
+ # print(f'Passing on args to {module}', [a.shape for a in args])
41
+ input = module(input, *args, **kwargs)
42
+ else:
43
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
44
+ input = module(input)
45
+ return input
46
+
47
+
48
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
49
+ if norm == "batch":
50
+ return torch.nn.BatchNorm1d(dim)
51
+ elif norm == "layer":
52
+ return torch.nn.LayerNorm(dim)
53
+ elif norm == "ada":
54
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
55
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
56
+ elif norm is None:
57
+ return torch.nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown norm: {norm}")
60
+
61
+
62
+ def linear_norm_activ_dropout(
63
+ input_dim: int,
64
+ output_dim: int,
65
+ activation: torch.nn.Module = torch.nn.ReLU(),
66
+ bias: bool = True,
67
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
68
+ dropout: float = 0.0,
69
+ norm_cond_dim: int = -1,
70
+ ) -> SequentialCond:
71
+ layers = []
72
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
73
+ if norm is not None:
74
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
75
+ layers.append(copy.deepcopy(activation))
76
+ if dropout > 0.0:
77
+ layers.append(torch.nn.Dropout(dropout))
78
+ return SequentialCond(*layers)
79
+
80
+
81
+ def create_simple_mlp(
82
+ input_dim: int,
83
+ hidden_dims: List[int],
84
+ output_dim: int,
85
+ activation: torch.nn.Module = torch.nn.ReLU(),
86
+ bias: bool = True,
87
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
88
+ dropout: float = 0.0,
89
+ norm_cond_dim: int = -1,
90
+ ) -> SequentialCond:
91
+ layers = []
92
+ prev_dim = input_dim
93
+ for hidden_dim in hidden_dims:
94
+ layers.extend(
95
+ linear_norm_activ_dropout(
96
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
97
+ )
98
+ )
99
+ prev_dim = hidden_dim
100
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
101
+ return SequentialCond(*layers)
102
+
103
+
104
+ class ResidualMLPBlock(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ input_dim: int,
108
+ hidden_dim: int,
109
+ num_hidden_layers: int,
110
+ output_dim: int,
111
+ activation: torch.nn.Module = torch.nn.ReLU(),
112
+ bias: bool = True,
113
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
114
+ dropout: float = 0.0,
115
+ norm_cond_dim: int = -1,
116
+ ):
117
+ super().__init__()
118
+ if not (input_dim == output_dim == hidden_dim):
119
+ raise NotImplementedError(
120
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
121
+ )
122
+
123
+ layers = []
124
+ prev_dim = input_dim
125
+ for i in range(num_hidden_layers):
126
+ layers.append(
127
+ linear_norm_activ_dropout(
128
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
129
+ )
130
+ )
131
+ prev_dim = hidden_dim
132
+ self.model = SequentialCond(*layers)
133
+ self.skip = torch.nn.Identity()
134
+
135
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
136
+ return x + self.model(x, *args, **kwargs)
137
+
138
+
139
+ class ResidualMLP(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ input_dim: int,
143
+ hidden_dim: int,
144
+ num_hidden_layers: int,
145
+ output_dim: int,
146
+ activation: torch.nn.Module = torch.nn.ReLU(),
147
+ bias: bool = True,
148
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
149
+ dropout: float = 0.0,
150
+ num_blocks: int = 1,
151
+ norm_cond_dim: int = -1,
152
+ ):
153
+ super().__init__()
154
+ self.input_dim = input_dim
155
+ self.model = SequentialCond(
156
+ linear_norm_activ_dropout(
157
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
158
+ ),
159
+ *[
160
+ ResidualMLPBlock(
161
+ hidden_dim,
162
+ hidden_dim,
163
+ num_hidden_layers,
164
+ hidden_dim,
165
+ activation,
166
+ bias,
167
+ norm,
168
+ dropout,
169
+ norm_cond_dim,
170
+ )
171
+ for _ in range(num_blocks)
172
+ ],
173
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
174
+ )
175
+
176
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
177
+ return self.model(x, *args, **kwargs)
178
+
179
+
180
+ class FrequencyEmbedder(torch.nn.Module):
181
+ def __init__(self, num_frequencies, max_freq_log2):
182
+ super().__init__()
183
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
184
+ self.register_buffer("frequencies", frequencies)
185
+
186
+ def forward(self, x):
187
+ # x should be of size (N,) or (N, D)
188
+ N = x.size(0)
189
+ if x.dim() == 1: # (N,)
190
+ x = x.unsqueeze(1) # (N, D) where D=1
191
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
192
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
193
+ s = torch.sin(scaled)
194
+ c = torch.cos(scaled)
195
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
196
+ N, -1
197
+ ) # (N, D * 2 * num_frequencies + D)
198
+ return embedded
199
+
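A quick check of the FrequencyEmbedder output size (illustration only): an input of dimension D is expanded to D * (2 * num_frequencies + 1) features, i.e. sin and cos at each frequency plus the raw value.

import torch
from hamer.models.components.t_cond_mlp import FrequencyEmbedder

emb = FrequencyEmbedder(num_frequencies=4, max_freq_log2=3)  # frequencies 1, 2, 4, 8
x = torch.randn(5, 3)                                        # (N, D)
print(emb(x).shape)                                          # torch.Size([5, 27]) = (N, D * (2*4 + 1))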
hamer/models/discriminator.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Discriminator(nn.Module):
5
+
6
+ def __init__(self):
7
+ """
8
+ Pose + Shape discriminator proposed in HMR
9
+ """
10
+ super(Discriminator, self).__init__()
11
+
12
+ self.num_joints = 15
13
+ # poses_alone
14
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
15
+ nn.init.xavier_uniform_(self.D_conv1.weight)
16
+ nn.init.zeros_(self.D_conv1.bias)
17
+ self.relu = nn.ReLU(inplace=True)
18
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
19
+ nn.init.xavier_uniform_(self.D_conv2.weight)
20
+ nn.init.zeros_(self.D_conv2.bias)
21
+ pose_out = []
22
+ for i in range(self.num_joints):
23
+ pose_out_temp = nn.Linear(32, 1)
24
+ nn.init.xavier_uniform_(pose_out_temp.weight)
25
+ nn.init.zeros_(pose_out_temp.bias)
26
+ pose_out.append(pose_out_temp)
27
+ self.pose_out = nn.ModuleList(pose_out)
28
+
29
+ # betas
30
+ self.betas_fc1 = nn.Linear(10, 10)
31
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
32
+ nn.init.zeros_(self.betas_fc1.bias)
33
+ self.betas_fc2 = nn.Linear(10, 5)
34
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
35
+ nn.init.zeros_(self.betas_fc2.bias)
36
+ self.betas_out = nn.Linear(5, 1)
37
+ nn.init.xavier_uniform_(self.betas_out.weight)
38
+ nn.init.zeros_(self.betas_out.bias)
39
+
40
+ # poses_joint
41
+ self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
42
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
43
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
44
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
45
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
46
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
47
+ self.D_alljoints_out = nn.Linear(1024, 1)
48
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
49
+ nn.init.zeros_(self.D_alljoints_out.bias)
50
+
51
+
52
+ def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Forward pass of the discriminator.
55
+ Args:
56
+ poses (torch.Tensor): Tensor of shape (B, 15, 3, 3) containing a batch of MANO hand poses (excluding the global orientation).
57
+ betas (torch.Tensor): Tensor of shape (B, 10) containing a batch of MANO beta coefficients.
58
+ Returns:
59
+ torch.Tensor: Discriminator output with shape (B, 17): one score per joint, one for the betas and one for all joints jointly.
60
+ """
61
+ #import ipdb; ipdb.set_trace()
62
+ #bn = poses.shape[0]
63
+ # poses B x 207
64
+ #poses = poses.reshape(bn, -1)
65
+ # poses B x num_joints x 1 x 9
66
+ poses = poses.reshape(-1, self.num_joints, 1, 9)
67
+ bn = poses.shape[0]
68
+ # poses B x 9 x num_joints x 1
69
+ poses = poses.permute(0, 3, 1, 2).contiguous()
70
+
71
+ # poses_alone
72
+ poses = self.D_conv1(poses)
73
+ poses = self.relu(poses)
74
+ poses = self.D_conv2(poses)
75
+ poses = self.relu(poses)
76
+
77
+ poses_out = []
78
+ for i in range(self.num_joints):
79
+ poses_out_ = self.pose_out[i](poses[:, :, i, 0])
80
+ poses_out.append(poses_out_)
81
+ poses_out = torch.cat(poses_out, dim=1)
82
+
83
+ # betas
84
+ betas = self.betas_fc1(betas)
85
+ betas = self.relu(betas)
86
+ betas = self.betas_fc2(betas)
87
+ betas = self.relu(betas)
88
+ betas_out = self.betas_out(betas)
89
+
90
+ # poses_joint
91
+ poses = poses.reshape(bn,-1)
92
+ poses_all = self.D_alljoints_fc1(poses)
93
+ poses_all = self.relu(poses_all)
94
+ poses_all = self.D_alljoints_fc2(poses_all)
95
+ poses_all = self.relu(poses_all)
96
+ poses_all_out = self.D_alljoints_out(poses_all)
97
+
98
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
99
+ return disc_out
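For reference, a shape sketch of the discriminator (illustration only): with num_joints = 15 it scores each hand joint, the betas, and the full pose jointly, giving 15 + 1 + 1 = 17 outputs per sample.

import torch
from hamer.models.discriminator import Discriminator

disc = Discriminator()
poses = torch.randn(4, 15, 3, 3)   # per-joint rotation matrices, global orientation excluded
betas = torch.randn(4, 10)
print(disc(poses, betas).shape)    # torch.Size([4, 17])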
hamer/models/hamer.py ADDED
@@ -0,0 +1,363 @@
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from typing import Any, Dict, Mapping, Tuple
4
+
5
+ from yacs.config import CfgNode
6
+
7
+ from ..utils import SkeletonRenderer, MeshRenderer
8
+ from ..utils.geometry import aa_to_rotmat, perspective_projection
9
+ from ..utils.pylogger import get_pylogger
10
+ from .backbones import create_backbone
11
+ from .heads import build_mano_head
12
+ from .discriminator import Discriminator
13
+ from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
14
+ from . import MANO
15
+
16
+ log = get_pylogger(__name__)
17
+
18
+ class HAMER(pl.LightningModule):
19
+
20
+ def __init__(self, cfg: CfgNode, init_renderer: bool = False):
21
+ """
22
+ Setup HAMER model
23
+ Args:
24
+ cfg (CfgNode): Config file as a yacs CfgNode
25
+ """
26
+ super().__init__()
27
+
28
+ # Save hyperparameters
29
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
30
+
31
+ self.cfg = cfg
32
+ # Create backbone feature extractor
33
+ self.backbone = create_backbone(cfg)
34
+ #if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
35
+ # log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
36
+ # self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'])
37
+
38
+ # Create MANO head
39
+ self.mano_head = build_mano_head(cfg)
40
+
41
+ # Create discriminator
42
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
43
+ self.discriminator = Discriminator()
44
+
45
+ # Define loss functions
46
+ self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
47
+ self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
48
+ self.mano_parameter_loss = ParameterLoss()
49
+
50
+ # Instantiate MANO model
51
+ mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
52
+ self.mano = MANO(**mano_cfg)
53
+
54
+ # Buffer that shows whether we need to initialize ActNorm layers
55
+ self.register_buffer('initialized', torch.tensor(False))
56
+ # Setup renderer for visualization
57
+ if init_renderer:
58
+ self.renderer = SkeletonRenderer(self.cfg)
59
+ self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
60
+ else:
61
+ self.renderer = None
62
+ self.mesh_renderer = None
63
+
64
+ # Disable automatic optimization since we use adversarial training
65
+ self.automatic_optimization = False
66
+
67
+ def on_after_backward(self):
68
+ for name, param in self.named_parameters():
69
+ if param.grad is None:
70
+ print(param.shape)
71
+ print(name)
72
+
73
+ def get_parameters(self):
74
+ all_params = list(self.mano_head.parameters())
75
+ all_params += list(self.backbone.parameters())
76
+ return all_params
77
+
78
+ def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
79
+ """
80
+ Set up the model and discriminator optimizers
81
+ Returns:
82
+ Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
83
+ """
84
+ param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
85
+
86
+ optimizer = torch.optim.AdamW(params=param_groups,
87
+ # lr=self.cfg.TRAIN.LR,
88
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
89
+ optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
90
+ lr=self.cfg.TRAIN.LR,
91
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
92
+
93
+ return optimizer, optimizer_disc
94
+
95
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
96
+ """
97
+ Run a forward step of the network
98
+ Args:
99
+ batch (Dict): Dictionary containing batch data
100
+ train (bool): Flag indicating whether it is training or validation mode
101
+ Returns:
102
+ Dict: Dictionary containing the regression output
103
+ """
104
+
105
+ # Use RGB image as input
106
+ x = batch['img']
107
+ batch_size = x.shape[0]
108
+
109
+ # Compute conditioning features using the backbone
110
+ # if using ViT backbone, we need to use a different aspect ratio
111
+ conditioning_feats = self.backbone(x[:,:,:,32:-32])
112
+
113
+ pred_mano_params, pred_cam, _ = self.mano_head(conditioning_feats)
114
+
115
+ # Store useful regression outputs to the output dict
116
+ output = {}
117
+ output['pred_cam'] = pred_cam
118
+ output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()}
119
+
120
+ # Compute camera translation
121
+ device = pred_mano_params['hand_pose'].device
122
+ dtype = pred_mano_params['hand_pose'].dtype
123
+ focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
124
+ pred_cam_t = torch.stack([pred_cam[:, 1],
125
+ pred_cam[:, 2],
126
+ 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
127
+ output['pred_cam_t'] = pred_cam_t
128
+ output['focal_length'] = focal_length
129
+
130
+ # Compute model vertices, joints and the projected joints
131
+ pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
132
+ pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
133
+ pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
134
+ mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False)
135
+ pred_keypoints_3d = mano_output.joints
136
+ pred_vertices = mano_output.vertices
137
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
138
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
139
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
140
+ focal_length = focal_length.reshape(-1, 2)
141
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
142
+ translation=pred_cam_t,
143
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
144
+
145
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
146
+ return output
147
+
148
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
149
+ """
150
+ Compute losses given the input batch and the regression output
151
+ Args:
152
+ batch (Dict): Dictionary containing batch data
153
+ output (Dict): Dictionary containing the regression output
154
+ train (bool): Flag indicating whether it is training or validation mode
155
+ Returns:
156
+ torch.Tensor : Total loss for current batch
157
+ """
158
+
159
+ pred_mano_params = output['pred_mano_params']
160
+ pred_keypoints_2d = output['pred_keypoints_2d']
161
+ pred_keypoints_3d = output['pred_keypoints_3d']
162
+
163
+
164
+ batch_size = pred_mano_params['hand_pose'].shape[0]
165
+ device = pred_mano_params['hand_pose'].device
166
+ dtype = pred_mano_params['hand_pose'].dtype
167
+
168
+ # Get annotations
169
+ gt_keypoints_2d = batch['keypoints_2d']
170
+ gt_keypoints_3d = batch['keypoints_3d']
171
+ gt_mano_params = batch['mano_params']
172
+ has_mano_params = batch['has_mano_params']
173
+ is_axis_angle = batch['mano_params_is_axis_angle']
174
+
175
+ # Compute 3D keypoint loss
176
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
177
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
178
+
179
+ # Compute loss on MANO parameters
180
+ loss_mano_params = {}
181
+ for k, pred in pred_mano_params.items():
182
+ gt = gt_mano_params[k].view(batch_size, -1)
183
+ if is_axis_angle[k].all():
184
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
185
+ has_gt = has_mano_params[k]
186
+ loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
187
+
188
+ loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
189
+ self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
190
+ sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])
191
+
192
+ #loss = loss + 0*self.mano.body_pose.mean()
193
+
194
+ losses = dict(loss=loss.detach(),
195
+ loss_keypoints_2d=loss_keypoints_2d.detach(),
196
+ loss_keypoints_3d=loss_keypoints_3d.detach())
197
+
198
+ for k, v in loss_mano_params.items():
199
+ losses['loss_' + k] = v.detach()
200
+
201
+ output['losses'] = losses
202
+
203
+ return loss
204
+
205
+ # Tensorboard logging should run from rank zero only
206
+ @pl.utilities.rank_zero.rank_zero_only
207
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
208
+ """
209
+ Log results to Tensorboard
210
+ Args:
211
+ batch (Dict): Dictionary containing batch data
212
+ output (Dict): Dictionary containing the regression output
213
+ step_count (int): Global training step count
214
+ train (bool): Flag indicating whether it is training or validation mode
215
+ """
216
+
217
+ mode = 'train' if train else 'val'
218
+ batch_size = batch['keypoints_2d'].shape[0]
219
+ images = batch['img']
220
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
221
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
222
+ #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()
223
+
224
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
225
+ pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
226
+ focal_length = output['focal_length'].detach().reshape(batch_size, 2)
227
+ gt_keypoints_3d = batch['keypoints_3d']
228
+ gt_keypoints_2d = batch['keypoints_2d']
229
+ losses = output['losses']
230
+ pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
231
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
232
+
233
+ if write_to_summary_writer:
234
+ summary_writer = self.logger.experiment
235
+ for loss_name, val in losses.items():
236
+ summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
237
+ num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
238
+
239
+ gt_keypoints_3d = batch['keypoints_3d']
240
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
241
+
242
+ # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
243
+ #predictions = self.renderer(pred_keypoints_3d[:num_images],
244
+ # gt_keypoints_3d[:num_images],
245
+ # 2 * gt_keypoints_2d[:num_images],
246
+ # images=images[:num_images],
247
+ # camera_translation=pred_cam_t[:num_images])
248
+ predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
249
+ pred_cam_t[:num_images].cpu().numpy(),
250
+ images[:num_images].cpu().numpy(),
251
+ pred_keypoints_2d[:num_images].cpu().numpy(),
252
+ gt_keypoints_2d[:num_images].cpu().numpy(),
253
+ focal_length=focal_length[:num_images].cpu().numpy())
254
+ if write_to_summary_writer:
255
+ summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
256
+
257
+ return predictions
258
+
259
+ def forward(self, batch: Dict) -> Dict:
260
+ """
261
+ Run a forward step of the network in val mode
262
+ Args:
263
+ batch (Dict): Dictionary containing batch data
264
+ Returns:
265
+ Dict: Dictionary containing the regression output
266
+ """
267
+ return self.forward_step(batch, train=False)
268
+
269
+ def training_step_discriminator(self, batch: Dict,
270
+ hand_pose: torch.Tensor,
271
+ betas: torch.Tensor,
272
+ optimizer: torch.optim.Optimizer) -> torch.Tensor:
273
+ """
274
+ Run a discriminator training step
275
+ Args:
276
+ batch (Dict): Dictionary containing mocap batch data
277
+ hand_pose (torch.Tensor): Regressed hand pose from current step
278
+ betas (torch.Tensor): Regressed betas from current step
279
+ optimizer (torch.optim.Optimizer): Discriminator optimizer
280
+ Returns:
281
+ torch.Tensor: Discriminator loss
282
+ """
283
+ batch_size = hand_pose.shape[0]
284
+ gt_hand_pose = batch['hand_pose']
285
+ gt_betas = batch['betas']
286
+ gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3)
287
+ disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
288
+ loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
289
+ disc_real_out = self.discriminator(gt_rotmat, gt_betas)
290
+ loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
291
+ loss_disc = loss_fake + loss_real
292
+ loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
293
+ optimizer.zero_grad()
294
+ self.manual_backward(loss)
295
+ optimizer.step()
296
+ return loss_disc.detach()
297
+
298
+ def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
299
+ """
300
+ Run a full training step
301
+ Args:
302
+ joint_batch (Dict): Dictionary containing image and mocap batch data
303
+ batch_idx (int): Unused.
304
+ batch_idx (torch.Tensor): Unused.
305
+ Returns:
306
+ Dict: Dictionary containing regression output.
307
+ """
308
+ batch = joint_batch['img']
309
+ mocap_batch = joint_batch['mocap']
310
+ optimizer = self.optimizers(use_pl_optimizer=True)
311
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
312
+ optimizer, optimizer_disc = optimizer
313
+
314
+ batch_size = batch['img'].shape[0]
315
+ output = self.forward_step(batch, train=True)
316
+ pred_mano_params = output['pred_mano_params']
317
+ if self.cfg.get('UPDATE_GT_SPIN', False):
318
+ self.update_batch_gt_spin(batch, output)
319
+ loss = self.compute_loss(batch, output, train=True)
320
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
321
+ disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
322
+ loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
323
+ loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
324
+
325
+ # Error if Nan
326
+ if torch.isnan(loss):
327
+ raise ValueError('Loss is NaN')
328
+
329
+ optimizer.zero_grad()
330
+ self.manual_backward(loss)
331
+ # Clip gradient
332
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
333
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
334
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
335
+ optimizer.step()
336
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
337
+ loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
338
+ output['losses']['loss_gen'] = loss_adv
339
+ output['losses']['loss_disc'] = loss_disc
340
+
341
+ if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
342
+ self.tensorboard_logging(batch, output, self.global_step, train=True)
343
+
344
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)
345
+
346
+ return output
347
+
348
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
349
+ """
350
+ Run a validation step and log to Tensorboard
351
+ Args:
352
+ batch (Dict): Dictionary containing batch data
353
+ batch_idx (int): Unused.
354
+ Returns:
355
+ Dict: Dictionary containing regression output.
356
+ """
357
+ # batch_size = batch['img'].shape[0]
358
+ output = self.forward_step(batch, train=False)
359
+ loss = self.compute_loss(batch, output, train=False)
360
+ output['loss'] = loss
361
+ self.tensorboard_logging(batch, output, self.global_step, train=False)
362
+
363
+ return output
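Finally, a standalone sketch of the weak-perspective camera conversion used in forward_step above (the 5000-pixel focal length is only an assumption for illustration; the real value comes from cfg.EXTRA.FOCAL_LENGTH):

import torch

def cam_to_translation(pred_cam: torch.Tensor, focal_length: float = 5000.0, image_size: int = 256) -> torch.Tensor:
    # pred_cam: (B, 3) = (scale, tx, ty) predicted by the MANO head
    s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
    tz = 2 * focal_length / (image_size * s + 1e-9)   # same expression as in forward_step
    return torch.stack([tx, ty, tz], dim=-1)

print(cam_to_translation(torch.tensor([[1.0, 0.0, 0.0]])))   # tensor([[ 0.0000,  0.0000, 39.0625]])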
hamer/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .mano_head import build_mano_head