diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e8558fa76f120147e1ade865b3d8fabfc12c231c --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Active3DPose Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md index 154df8298fab5ecf322016157858e08cd1bccbe1..3f092443ce78a70c908e87b5f600ad12d8ae8486 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,110 @@ ---- -license: apache-2.0 ---- +# MotionBERT + +PyTorch [![arXiv](https://img.shields.io/badge/arXiv-2210.06551-b31b1b.svg)](https://arxiv.org/abs/2210.06551) Project Demo + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/monocular-3d-human-pose-estimation-on-human3)](https://paperswithcode.com/sota/monocular-3d-human-pose-estimation-on-human3?p=motionbert-unified-pretraining-for-human) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/one-shot-3d-action-recognition-on-ntu-rgbd)](https://paperswithcode.com/sota/one-shot-3d-action-recognition-on-ntu-rgbd?p=motionbert-unified-pretraining-for-human) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/motionbert-unified-pretraining-for-human/3d-human-pose-estimation-on-3dpw)](https://paperswithcode.com/sota/3d-human-pose-estimation-on-3dpw?p=motionbert-unified-pretraining-for-human) + +This is the official PyTorch implementation of the paper *"[Learning Human Motion Representations: A Unified Perspective](https://arxiv.org/pdf/2210.06551.pdf)"*. + + + +## Installation + +```bash +conda create -n motionbert python=3.7 anaconda +conda activate motionbert +# Please install PyTorch according to your CUDA version. +conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia +pip install -r requirements.txt +``` + + + +## Getting Started + +| Task | Document | +| --------------------------------- | ------------------------------------------------------------ | +| Pretrain | [docs/pretrain.md](docs/pretrain.md) | +| 3D human pose estimation | [docs/pose3d.md](docs/pose3d.md) | +| Skeleton-based action recognition | [docs/action.md](docs/action.md) | +| Mesh recovery | [docs/mesh.md](docs/mesh.md) | + + + +## Applications + +### In-the-wild inference (for custom videos) + +Please refer to [docs/inference.md](docs/inference.md). + +### Using MotionBERT for *human-centric* video representations + +```python +''' + x: 2D skeletons + type = + shape = [batch size * frames * joints(17) * channels(3)] + + MotionBERT: pretrained human motion encoder + type = + + E: encoded motion representation + type = + shape = [batch size * frames * joints(17) * channels(512)] +''' +E = MotionBERT.get_representation(x) +``` + + + +> **Hints** +> +> 1. The model could handle different input lengths (no more than 243 frames). No need to explicitly specify the input length elsewhere. +> 2. The model uses 17 body keypoints ([H36M format](https://github.com/JimmySuen/integral-human-pose/blob/master/pytorch_projects/common_pytorch/dataset/hm36.py#L32)). If you are using other formats, please convert them before feeding to MotionBERT. +> 3. 
Please refer to [model_action.py](lib/model/model_action.py) and [model_mesh.py](lib/model/model_mesh.py) for examples of (easily) adapting MotionBERT to different downstream tasks. +> 4. For RGB videos, you need to extract 2D poses ([inference.md](docs/inference.md)), convert the keypoint format ([dataset_wild.py](lib/data/dataset_wild.py)), and then feed to MotionBERT ([infer_wild.py](infer_wild.py)). +> + + + +## Model Zoo + + + +| Model | Download Link | Config | Performance | +| ------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ---------------- | +| MotionBERT (162MB) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgS425shtVi9e5reN?e=6UeBa2) | [pretrain/MB_pretrain.yaml](configs/pretrain/MB_pretrain.yaml) | - | +| MotionBERT-Lite (61MB) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgS27Ydcbpxlkl0ng?e=rq2Btn) | [pretrain/MB_lite.yaml](configs/pretrain/MB_lite.yaml) | - | +| 3D Pose (H36M-SH, scratch) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgSvNejMQ0OHxMGZC?e=KcwBk1) | [pose3d/MB_train_h36m.yaml](configs/pose3d/MB_train_h36m.yaml) | 39.2mm (MPJPE) | +| 3D Pose (H36M-SH, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgSoTqtyR5Zsgi8_Z?e=rn4VJf) | [pose3d/MB_ft_h36m.yaml](configs/pose3d/MB_ft_h36m.yaml) | 37.2mm (MPJPE) | +| Action Recognition (x-sub, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTX23yT_NO7RiZz-?e=nX6w2j) | [action/MB_ft_NTU60_xsub.yaml](configs/action/MB_ft_NTU60_xsub.yaml) | 97.2% (Top1 Acc) | +| Action Recognition (x-view, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTaNiXw2Nal-g37M?e=lSkE4T) | [action/MB_ft_NTU60_xview.yaml](configs/action/MB_ft_NTU60_xview.yaml) | 93.0% (Top1 Acc) | +| Mesh (with 3DPW, ft) | [OneDrive](https://1drv.ms/f/s!AvAdh0LSjEOlgTmgYNslCDWMNQi9?e=WjcB1F) | [mesh/MB_ft_pw3d.yaml](configs/mesh/MB_ft_pw3d.yaml) | 88.1mm (MPVE) | + +In most use cases (especially with finetuning), `MotionBERT-Lite` gives a similar performance with lower computation overhead. 
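For reference, here is a minimal, hedged sketch of loading a Model Zoo checkpoint and extracting motion representations, following the loading pattern used in [infer_wild.py](infer_wild.py). The checkpoint filename and the `'model_pos'` state-dict key are assumptions borrowed from that script (adjust them to the checkpoint you actually download), and the input is a dummy tensor:

```python
import torch
import torch.nn as nn
from lib.utils.tools import *      # provides get_config, as in infer_wild.py
from lib.utils.learning import *   # provides load_backbone

# Placeholder paths -- point these to your own config and downloaded checkpoint.
args = get_config('configs/pretrain/MB_pretrain.yaml')
ckpt_path = 'checkpoint/pretrain/MB_release/latest_epoch.bin'

model = load_backbone(args)                        # DSTformer motion encoder
if torch.cuda.is_available():
    model = nn.DataParallel(model).cuda()
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['model_pos'], strict=True)   # key used by infer_wild.py; adjust if yours differs
model.eval()

x = torch.randn(2, 243, 17, 3)                     # dummy 2D input: [batch, frames, joints, (x, y, conf)]
if torch.cuda.is_available():
    x = x.cuda()
with torch.no_grad():
    backbone = model.module if hasattr(model, 'module') else model
    E = backbone.get_representation(x)             # [batch, frames, 17, 512], as sketched above
print(E.shape)
```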
+ + + +## TODO + +- [x] Scripts and docs for pretraining + +- [x] Demo for custom videos + + + +## Citation + +If you find our work useful for your project, please consider citing the paper: + +```bibtex +@article{motionbert2022, + title = {Learning Human Motion Representations: A Unified Perspective}, + author = {Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou}, + year = {2022}, + journal = {arXiv preprint arXiv:2210.06551}, +} +``` + diff --git a/configs/action/MB_ft_NTU120_oneshot.yaml b/configs/action/MB_ft_NTU120_oneshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed25ac9719d0b7938338a21898cf2221a8b53513 --- /dev/null +++ b/configs/action/MB_ft_NTU120_oneshot.yaml @@ -0,0 +1,35 @@ +# General +finetune: True +partial_train: null + +# Traning +n_views: 2 +temp: 0.1 + +epochs: 50 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: embed +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 17 +hidden_dim: 2048 +dropout_ratio: 0.1 + +# Data +clip_len: 100 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/action/MB_ft_NTU60_xsub.yaml b/configs/action/MB_ft_NTU60_xsub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e26f8c42f58bbf754b3786f59bd6391b8f4a8dd4 --- /dev/null +++ b/configs/action/MB_ft_NTU60_xsub.yaml @@ -0,0 +1,35 @@ +# General +finetune: True +partial_train: null + +# Traning +epochs: 300 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: class +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 17 +hidden_dim: 2048 +dropout_ratio: 0.5 + +# Data +dataset: ntu60_hrnet +data_split: xsub +clip_len: 243 +action_classes: 60 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/action/MB_ft_NTU60_xview.yaml b/configs/action/MB_ft_NTU60_xview.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1fe649f5cca6a30579a42790f4fea2f6f93f5bf --- /dev/null +++ b/configs/action/MB_ft_NTU60_xview.yaml @@ -0,0 +1,35 @@ +# General +finetune: True +partial_train: null + +# Traning +epochs: 300 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: class +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 17 +hidden_dim: 2048 +dropout_ratio: 0.5 + +# Data +dataset: ntu60_hrnet +data_split: xview +clip_len: 243 +action_classes: 60 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/action/MB_train_NTU120_oneshot.yaml b/configs/action/MB_train_NTU120_oneshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9aa8d732177eff9a0ece5706c9658645944fb928 --- /dev/null +++ b/configs/action/MB_train_NTU120_oneshot.yaml @@ -0,0 +1,35 @@ +# General +finetune: False +partial_train: null + +# Traning +n_views: 2 +temp: 0.1 + +epochs: 50 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: embed +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 
17 +hidden_dim: 2048 +dropout_ratio: 0.1 + +# Data +clip_len: 100 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/action/MB_train_NTU60_xsub.yaml b/configs/action/MB_train_NTU60_xsub.yaml new file mode 100644 index 0000000000000000000000000000000000000000..adec630e16b18ce9d07ba2907f8ebd94e8ddc146 --- /dev/null +++ b/configs/action/MB_train_NTU60_xsub.yaml @@ -0,0 +1,35 @@ +# General +finetune: False +partial_train: null + +# Traning +epochs: 300 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.0001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: class +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 17 +hidden_dim: 2048 +dropout_ratio: 0.5 + +# Data +dataset: ntu60_hrnet +data_split: xsub +clip_len: 243 +action_classes: 60 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/action/MB_train_NTU60_xview.yaml b/configs/action/MB_train_NTU60_xview.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c76d291a8fbe0b6472f9175fd04bbe755a4e304d --- /dev/null +++ b/configs/action/MB_train_NTU60_xview.yaml @@ -0,0 +1,35 @@ +# General +finetune: False +partial_train: null + +# Traning +epochs: 300 +batch_size: 32 +lr_backbone: 0.0001 +lr_head: 0.0001 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +model_version: class +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +num_joints: 17 +hidden_dim: 2048 +dropout_ratio: 0.5 + +# Data +dataset: ntu60_hrnet +data_split: xview +clip_len: 243 +action_classes: 60 + +# Augmentation +random_move: True +scale_range_train: [1, 3] +scale_range_test: [2, 2] \ No newline at end of file diff --git a/configs/mesh/MB_ft_h36m.yaml b/configs/mesh/MB_ft_h36m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5f397ab68ee3a7d0f328712b5ca6cc3b5b748a8 --- /dev/null +++ b/configs/mesh/MB_ft_h36m.yaml @@ -0,0 +1,51 @@ +# General +finetune: True +partial_train: null +train_pw3d: False +warmup_h36m: 100 + +# Traning +epochs: 60 +checkpoint_frequency: 20 +batch_size: 128 +batch_size_img: 512 +dropout: 0.1 +dropout_loc: 1 +lr_backbone: 0.00005 +lr_head: 0.0005 +weight_decay: 0.01 +lr_decay: 0.98 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +hidden_dim: 1024 + +# Data +data_root: data/mesh +dt_file_h36m: mesh_det_h36m.pkl +clip_len: 16 +data_stride: 8 +sample_stride: 1 +num_joints: 17 + +# Loss +lambda_3d: 0.5 +lambda_scale: 0 +lambda_3dv: 10 +lambda_lv: 0 +lambda_lg: 0 +lambda_a: 0 +lambda_av: 0 +lambda_pose: 1000 +lambda_shape: 1 +lambda_norm: 20 +loss_type: 'L1' + +# Augmentation +flip: True \ No newline at end of file diff --git a/configs/mesh/MB_ft_pw3d.yaml b/configs/mesh/MB_ft_pw3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a59d6fe70f572416204f6f4f18ad82bb0f64a6c --- /dev/null +++ b/configs/mesh/MB_ft_pw3d.yaml @@ -0,0 +1,53 @@ +# General +finetune: True +partial_train: null +train_pw3d: True +warmup_h36m: 20 +warmup_coco: 100 + +# Traning +epochs: 60 +checkpoint_frequency: 20 +batch_size: 128 +batch_size_img: 512 +dropout: 0.1 +lr_backbone: 0.00005 +lr_head: 0.0005 +weight_decay: 0.01 +lr_decay: 0.98 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +hidden_dim: 1024 + +# Data +data_root: 
data/mesh +dt_file_h36m: mesh_det_h36m.pkl +dt_file_coco: mesh_det_coco.pkl +dt_file_pw3d: mesh_det_pw3d.pkl +clip_len: 16 +data_stride: 8 +sample_stride: 1 +num_joints: 17 + +# Loss +lambda_3d: 0.5 +lambda_scale: 0 +lambda_3dv: 10 +lambda_lv: 0 +lambda_lg: 0 +lambda_a: 0 +lambda_av: 0 +lambda_pose: 1000 +lambda_shape: 1 +lambda_norm: 20 +loss_type: 'L1' + +# Augmentation +flip: True \ No newline at end of file diff --git a/configs/mesh/MB_train_h36m.yaml b/configs/mesh/MB_train_h36m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..74669f73b09f157e326e47f1676899b7d9ba92db --- /dev/null +++ b/configs/mesh/MB_train_h36m.yaml @@ -0,0 +1,51 @@ +# General +finetune: False +partial_train: null +train_pw3d: False +warmup_h36m: 100 + +# Traning +epochs: 100 +checkpoint_frequency: 20 +batch_size: 128 +batch_size_img: 512 +dropout: 0.1 +dropout_loc: 1 +lr_backbone: 0.0001 +lr_head: 0.0001 +weight_decay: 0.01 +lr_decay: 0.98 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +hidden_dim: 1024 + +# Data +data_root: data/mesh +dt_file_h36m: mesh_det_h36m.pkl +clip_len: 16 +data_stride: 8 +sample_stride: 1 +num_joints: 17 + +# Loss +lambda_3d: 0.5 +lambda_scale: 0 +lambda_3dv: 10 +lambda_lv: 0 +lambda_lg: 0 +lambda_a: 0 +lambda_av: 0 +lambda_pose: 1000 +lambda_shape: 1 +lambda_norm: 20 +loss_type: 'L1' + +# Augmentation +flip: True \ No newline at end of file diff --git a/configs/mesh/MB_train_pw3d.yaml b/configs/mesh/MB_train_pw3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..379d0f95d74676783554703ad23312b9e0fcec65 --- /dev/null +++ b/configs/mesh/MB_train_pw3d.yaml @@ -0,0 +1,53 @@ +# General +finetune: False +partial_train: null +train_pw3d: True +warmup_h36m: 20 +warmup_coco: 100 + +# Traning +epochs: 60 +checkpoint_frequency: 20 +batch_size: 128 +batch_size_img: 512 +dropout: 0.1 +lr_backbone: 0.0001 +lr_head: 0.0001 +weight_decay: 0.01 +lr_decay: 0.98 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True +hidden_dim: 1024 + +# Data +data_root: data/mesh +dt_file_h36m: mesh_det_h36m.pkl +dt_file_coco: mesh_det_coco.pkl +dt_file_pw3d: mesh_det_pw3d.pkl +clip_len: 16 +data_stride: 8 +sample_stride: 1 +num_joints: 17 + +# Loss +lambda_3d: 0.5 +lambda_scale: 0 +lambda_3dv: 10 +lambda_lv: 0 +lambda_lg: 0 +lambda_a: 0 +lambda_av: 0 +lambda_pose: 1000 +lambda_shape: 1 +lambda_norm: 20 +loss_type: 'L1' + +# Augmentation +flip: True diff --git a/configs/pose3d/MB_ft_h36m.yaml b/configs/pose3d/MB_ft_h36m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b52f5a2c31057529959f871761e3430f1ada4371 --- /dev/null +++ b/configs/pose3d/MB_ft_h36m.yaml @@ -0,0 +1,50 @@ +# General +train_2d: False +no_eval: False +finetune: True +partial_train: null + +# Traning +epochs: 60 +checkpoint_frequency: 30 +batch_size: 32 +dropout: 0.0 +learning_rate: 0.0002 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: True +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: False +flip: True +mask_ratio: 0. +mask_T_ratio: 0. 
+noise: False diff --git a/configs/pose3d/MB_ft_h36m_global.yaml b/configs/pose3d/MB_ft_h36m_global.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb9aa5e18a1a6b27d07c4d1395e6feeb12857406 --- /dev/null +++ b/configs/pose3d/MB_ft_h36m_global.yaml @@ -0,0 +1,50 @@ +# General +train_2d: False +no_eval: False +finetune: True +partial_train: null + +# Traning +epochs: 60 +checkpoint_frequency: 30 +batch_size: 32 +dropout: 0.0 +learning_rate: 0.0002 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: False +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: False +flip: True +mask_ratio: 0. +mask_T_ratio: 0. +noise: False \ No newline at end of file diff --git a/configs/pose3d/MB_ft_h36m_global_lite.yaml b/configs/pose3d/MB_ft_h36m_global_lite.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a2d5d976b8e0b52c1a166d92206aa16e95c0eb8 --- /dev/null +++ b/configs/pose3d/MB_ft_h36m_global_lite.yaml @@ -0,0 +1,50 @@ +# General +train_2d: False +no_eval: False +finetune: True +partial_train: null + +# Traning +epochs: 60 +checkpoint_frequency: 30 +batch_size: 32 +dropout: 0.0 +learning_rate: 0.0005 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +maxlen: 243 +dim_feat: 256 +mlp_ratio: 4 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: False +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: False +flip: True +mask_ratio: 0. +mask_T_ratio: 0. +noise: False \ No newline at end of file diff --git a/configs/pose3d/MB_train_h36m.yaml b/configs/pose3d/MB_train_h36m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39a196e6a3089f0c88523e29a716c592369fc66a --- /dev/null +++ b/configs/pose3d/MB_train_h36m.yaml @@ -0,0 +1,51 @@ +# General +train_2d: False +no_eval: False +finetune: False +partial_train: null + +# Traning +epochs: 120 +checkpoint_frequency: 30 +batch_size: 32 +dropout: 0.0 +learning_rate: 0.0002 +weight_decay: 0.01 +lr_decay: 0.99 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: True +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: False +flip: True +mask_ratio: 0. +mask_T_ratio: 0. 
+noise: False + diff --git a/configs/pretrain/MB_lite.yaml b/configs/pretrain/MB_lite.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae06971e2b01a5e9a66a1c57ddfcb5a6bd31b6a5 --- /dev/null +++ b/configs/pretrain/MB_lite.yaml @@ -0,0 +1,53 @@ +# General +train_2d: True +no_eval: False +finetune: False +partial_train: null + +# Traning +epochs: 90 +checkpoint_frequency: 30 +batch_size: 64 +dropout: 0.0 +learning_rate: 0.0005 +weight_decay: 0.01 +lr_decay: 0.99 +pretrain_3d_curriculum: 30 + +# Model +maxlen: 243 +dim_feat: 256 +mlp_ratio: 4 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [AMASS, H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: True +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: True # synthetic: don't use 2D detection results, fake it (from 3D) +flip: True +mask_ratio: 0.05 +mask_T_ratio: 0.1 +noise: True +noise_path: params/synthetic_noise.pth +d2c_params_path: params/d2c_params.pkl \ No newline at end of file diff --git a/configs/pretrain/MB_pretrain.yaml b/configs/pretrain/MB_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..efc86eb7aa138893980d949ba43e0705606b46d3 --- /dev/null +++ b/configs/pretrain/MB_pretrain.yaml @@ -0,0 +1,53 @@ +# General +train_2d: True +no_eval: False +finetune: False +partial_train: null + +# Traning +epochs: 90 +checkpoint_frequency: 30 +batch_size: 64 +dropout: 0.0 +learning_rate: 0.0005 +weight_decay: 0.01 +lr_decay: 0.99 +pretrain_3d_curriculum: 30 + +# Model +maxlen: 243 +dim_feat: 512 +mlp_ratio: 2 +depth: 5 +dim_rep: 512 +num_heads: 8 +att_fuse: True + +# Data +data_root: data/motion3d/MB3D_f243s81/ +subset_list: [AMASS, H36M-SH] +dt_file: h36m_sh_conf_cam_source_final.pkl +clip_len: 243 +data_stride: 81 +rootrel: True +sample_stride: 1 +num_joints: 17 +no_conf: False +gt_2d: False + +# Loss +lambda_3d_velocity: 20.0 +lambda_scale: 0.5 +lambda_lv: 0.0 +lambda_lg: 0.0 +lambda_a: 0.0 +lambda_av: 0.0 + +# Augmentation +synthetic: True # synthetic: don't use 2D detection results, fake it (from 3D) +flip: True +mask_ratio: 0.05 +mask_T_ratio: 0.1 +noise: True +noise_path: params/synthetic_noise.pth +d2c_params_path: params/d2c_params.pkl diff --git a/docs/action.md b/docs/action.md new file mode 100644 index 0000000000000000000000000000000000000000..874f9415f9e9f62d3e9af446d6f1a6d2306666c8 --- /dev/null +++ b/docs/action.md @@ -0,0 +1,86 @@ +# Skeleton-based Action Recognition + +## Data + +The NTURGB+D 2D detection results are provided by [pyskl](https://github.com/kennymckormick/pyskl/blob/main/tools/data/README.md) using HRNet. + +1. Download [`ntu60_hrnet.pkl`](https://download.openmmlab.com/mmaction/pyskl/data/nturgbd/ntu60_hrnet.pkl) and [`ntu120_hrnet.pkl`](https://download.openmmlab.com/mmaction/pyskl/data/nturgbd/ntu120_hrnet.pkl) to `data/action/`. +2. Download the 1-shot split [here](https://1drv.ms/f/s!AvAdh0LSjEOlfi-hqlHxdVMZxWM) and put it to `data/action/`. 
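As a quick sanity check after downloading, the sketch below inspects `ntu60_hrnet.pkl` with the repo's `read_pkl` helper (run it from the repository root). The field names (`split`, `annotations`, `keypoint`, `keypoint_score`, `label`) follow the pyskl annotation format and are assumptions about the file layout rather than something defined in this repo:

```python
from lib.utils.tools import read_pkl   # run from the repo root

data = read_pkl('data/action/ntu60_hrnet.pkl')
print(data.keys())                                     # assumed: dict_keys(['split', 'annotations'])
print({k: len(v) for k, v in data['split'].items()})   # e.g. xsub_train / xsub_val / xview_train / xview_val
anno = data['annotations'][0]
print(anno['frame_dir'], anno['label'])                # clip id and action class
print(anno['keypoint'].shape)                          # (num_person, num_frame, 17, 2) HRNet 2D keypoints
print(anno['keypoint_score'].shape)                    # (num_person, num_frame, 17) detection confidences
```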
+ +## Running + +### NTURGB+D + +**Train from scratch:** + +```shell +# Cross-subject +python train_action.py \ +--config configs/action/MB_train_NTU60_xsub.yaml \ +--checkpoint checkpoint/action/MB_train_NTU60_xsub + +# Cross-view +python train_action.py \ +--config configs/action/MB_train_NTU60_xview.yaml \ +--checkpoint checkpoint/action/MB_train_NTU60_xview +``` + +**Finetune from pretrained MotionBERT:** + +```shell +# Cross-subject +python train_action.py \ +--config configs/action/MB_ft_NTU60_xsub.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU60_xsub + +# Cross-view +python train_action.py \ +--config configs/action/MB_ft_NTU60_xview.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU60_xview +``` + +**Evaluate:** + +```bash +# Cross-subject +python train_action.py \ +--config configs/action/MB_train_NTU60_xsub.yaml \ +--evaluate checkpoint/action/MB_train_NTU60_xsub/best_epoch.bin + +# Cross-view +python train_action.py \ +--config configs/action/MB_train_NTU60_xview.yaml \ +--evaluate checkpoint/action/MB_train_NTU60_xview/best_epoch.bin +``` + +### NTURGB+D-120 (1-shot) + +**Train from scratch:** + +```bash +python train_action_1shot.py \ +--config configs/action/MB_train_NTU120_oneshot.yaml \ +--checkpoint checkpoint/action/MB_train_NTU120_oneshot +``` + +**Finetune from a pretrained model:** + +```bash +python train_action_1shot.py \ +--config configs/action/MB_ft_NTU120_oneshot.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/action/FT_MB_release_MB_ft_NTU120_oneshot +``` + +**Evaluate:** + +```bash +python train_action_1shot.py \ +--config configs/action/MB_train_NTU120_oneshot.yaml \ +--evaluate checkpoint/action/MB_train_NTU120_oneshot/best_epoch.bin +``` + + + diff --git a/docs/inference.md b/docs/inference.md new file mode 100644 index 0000000000000000000000000000000000000000..0333a71b5870ff3bb75ea277d33ec371b76322c2 --- /dev/null +++ b/docs/inference.md @@ -0,0 +1,48 @@ +# In-the-wild Inference + +## 2D Pose + +Please use [AlphaPose](https://github.com/MVIG-SJTU/AlphaPose#quick-start) to extract the 2D keypoints for your video first. We use the *Fast Pose* model trained on *Halpe* dataset ([Link](https://github.com/MVIG-SJTU/AlphaPose/blob/master/docs/MODEL_ZOO.md#halpe-dataset-26-keypoints)). + +Note: Currently we only support single person. If your video contains multiple person, you may need to use the [Pose Tracking Module for AlphaPose](https://github.com/MVIG-SJTU/AlphaPose/tree/master/trackers) and set `--focus` to specify the target person id. + + + +## 3D Pose + +| ![pose_1](https://github.com/motionbert/motionbert.github.io/blob/main/assets/pose_1.gif?raw=true) | ![pose_2](https://raw.githubusercontent.com/motionbert/motionbert.github.io/main/assets/pose_2.gif) | +| ------------------------------------------------------------ | ------------------------------------------------------------ | + + +1. Please download the checkpoint [here](https://1drv.ms/f/s!AvAdh0LSjEOlgT67igq_cIoYvO2y?e=bfEc73) and put it to `checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/`. +1. 
Run the following command to infer from the extracted 2D poses: +```bash +python infer_wild.py \ +--vid_path \ +--json_path \ +--out_path +``` + + + +## Mesh + +| ![mesh_1](https://raw.githubusercontent.com/motionbert/motionbert.github.io/main/assets/mesh_1.gif) | ![mesh_2](https://github.com/motionbert/motionbert.github.io/blob/main/assets/mesh_2.gif?raw=true) | +| ------------------------------------------------------------ | ----------- | + +1. Please download the checkpoint [here](https://1drv.ms/f/s!AvAdh0LSjEOlgTmgYNslCDWMNQi9?e=WjcB1F) and put it to `checkpoint/mesh/FT_MB_release_MB_ft_pw3d/` +2. Run the following command to infer from the extracted 2D poses: +```bash +python infer_wild_mesh.py \ +--vid_path \ +--json_path \ +--out_path \ +--ref_3d_motion_path <3d-pose-results.npy> # Optional, use the estimated 3D motion for root trajectory. +``` + + + + + + + diff --git a/docs/mesh.md b/docs/mesh.md new file mode 100644 index 0000000000000000000000000000000000000000..3d7e1fe27a1c050647f78ce814ee9c7609e86d7d --- /dev/null +++ b/docs/mesh.md @@ -0,0 +1,61 @@ +# Human Mesh Recovery + +## Data + +1. Download the datasets [here](https://1drv.ms/f/s!AvAdh0LSjEOlfy-hqlHxdVMZxWM) and put them to `data/mesh/`. We use Human3.6M, COCO, and PW3D for training and testing. Descriptions of the joint regressors could be found in [SPIN](https://github.com/nkolot/SPIN/tree/master/data). +2. Download the SMPL model(`basicModel_neutral_lbs_10_207_0_v1.0.0.pkl`) from [SMPLify](https://smplify.is.tue.mpg.de/), put it to `data/mesh/`, and rename it as `SMPL_NEUTRAL.pkl` + + +## Running + +**Train from scratch:** + +```bash +# with 3DPW +python train_mesh.py \ +--config configs/mesh/MB_train_pw3d.yaml \ +--checkpoint checkpoint/mesh/MB_train_pw3d + +# H36M +python train_mesh.py \ +--config configs/mesh/MB_train_h36m.yaml \ +--checkpoint checkpoint/mesh/MB_train_h36m +``` + +**Finetune from a pretrained model:** + +```bash +# with 3DPW +python train_mesh.py \ +--config configs/mesh/MB_ft_pw3d.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/mesh/FT_MB_release_MB_ft_pw3d + +# H36M +python train_mesh.py \ +--config configs/mesh/MB_ft_h36m.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/mesh/FT_MB_release_MB_ft_h36m + +``` + +**Evaluate:** + +```bash +# with 3DPW +python train_mesh.py \ +--config configs/mesh/MB_train_pw3d.yaml \ +--evaluate checkpoint/mesh/MB_train_pw3d/best_epoch.bin + +# H36M +python train_mesh.py \ +--config configs/mesh/MB_train_h36m.yaml \ +--evaluate checkpoint/mesh/MB_train_h36m/best_epoch.bin +``` + + + + + + + diff --git a/docs/pose3d.md b/docs/pose3d.md new file mode 100644 index 0000000000000000000000000000000000000000..448e3ee148703d8eb26e1cce295c67a85276adaf --- /dev/null +++ b/docs/pose3d.md @@ -0,0 +1,51 @@ +# 3D Human Pose Estimation + +## Data + +1. Download the finetuned Stacked Hourglass detections and our preprocessed H3.6M data (.pkl) [here](https://1drv.ms/u/s!AvAdh0LSjEOlgSMvoapR8XVTGcVj) and put it to `data/motion3d`. + + > Note that the preprocessed data is only intended for reproducing our results more easily. If you want to use the dataset, please register to the [Human3.6m website](http://vision.imar.ro/human3.6m/) and download the dataset in its original format. Please refer to [LCN](https://github.com/CHUNYUWANG/lcn-pose#data) for how we prepare the H3.6M data. + +2. 
Slice the motion clips (len=243, stride=81) + + ```bash + python tools/convert_h36m.py + ``` + +## Running + +**Train from scratch:** + +```bash +python train.py \ +--config configs/pose3d/MB_train_h36m.yaml \ +--checkpoint checkpoint/pose3d/MB_train_h36m +``` + +**Finetune from pretrained MotionBERT:** + +```bash +python train.py \ +--config configs/pose3d/MB_ft_h36m.yaml \ +--pretrained checkpoint/pretrain/MB_release \ +--checkpoint checkpoint/pose3d/FT_MB_release_MB_ft_h36m +``` + +**Evaluate:** + +```bash +python train.py \ +--config configs/pose3d/MB_train_h36m.yaml \ +--evaluate checkpoint/pose3d/MB_train_h36m/best_epoch.bin +``` + + + + + + + + + + + diff --git a/docs/pretrain.md b/docs/pretrain.md new file mode 100644 index 0000000000000000000000000000000000000000..bfec36d102d0bfab8bf496a3dff3ffa989c82cc7 --- /dev/null +++ b/docs/pretrain.md @@ -0,0 +1,59 @@ +# Pretrain + +## Data + +### AMASS + +1. Please download data from the [official website](https://amass.is.tue.mpg.de/download.php) (SMPL+H). +2. We provide the preprocessing scripts as follows. Minor modifications might be necessary. + - [tools/compress_amass.py](../tools/compress_amass.py): downsample the frame rate + - [tools/preprocess_amass.py](../tools/preprocess_amass.py): render the mocap data and extract the 3D keypoints + - [tools/convert_amass.py](../tools/convert_amass.py): slice them to motion clips + + +### Human 3.6M + +Please refer to [pose3d.md](pose3d.md#data). + +### InstaVariety + +1. Please download data from [human_dynamics](https://github.com/akanazawa/human_dynamics/blob/master/doc/insta_variety.md#generating-tfrecords) to `data/motion2d`. +1. Use [tools/convert_insta.py](../tools/convert_insta.py) to preprocess the 2D keypoints (need to specify `name_action` ). + +### PoseTrack + +Please download PoseTrack18 from [MMPose](https://mmpose.readthedocs.io/en/latest/tasks/2d_body_keypoint.html#posetrack18) and unzip to `data/motion2d`. + + + +The processed directory tree should look like this: + +``` +. +└── data/ + ├── motion3d/ + │ └── MB3D_f243s81/ + │ ├── AMASS + │ └── H36M-SH + ├── motion2d/ + │ ├── InstaVariety/ + │ │ ├── motion_all.npy + │ │ └── id_all.npy + │ └── posetrack18_annotations/ + │ ├── train + │ └── ... + └── ... 
+``` + + + +## Train + +```bash +python train.py \ +--config configs/pretrain/MB_pretrain.yaml \ +-c checkpoint/pretrain/MB_pretrain +``` + + + diff --git a/infer_wild.py b/infer_wild.py new file mode 100644 index 0000000000000000000000000000000000000000..17acd194e06db341001101f4c7bb70b6710bdae8 --- /dev/null +++ b/infer_wild.py @@ -0,0 +1,97 @@ +import os +import numpy as np +import argparse +from tqdm import tqdm +import imageio +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from lib.utils.tools import * +from lib.utils.learning import * +from lib.utils.utils_data import flip_data +from lib.data.dataset_wild import WildDetDataset +from lib.utils.vismo import render_and_save + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/pose3d/MB_ft_h36m_global_lite.yaml", help="Path to the config file.") + parser.add_argument('-e', '--evaluate', default='checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-j', '--json_path', type=str, help='alphapose detection result json path') + parser.add_argument('-v', '--vid_path', type=str, help='video path') + parser.add_argument('-o', '--out_path', type=str, help='output path') + parser.add_argument('--pixel', action='store_true', help='align with pixle coordinates') + parser.add_argument('--focus', type=int, default=None, help='target person id') + parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input') + opts = parser.parse_args() + return opts + +opts = parse_args() +args = get_config(opts.config) + +model_backbone = load_backbone(args) +if torch.cuda.is_available(): + model_backbone = nn.DataParallel(model_backbone) + model_backbone = model_backbone.cuda() + +print('Loading checkpoint', opts.evaluate) +checkpoint = torch.load(opts.evaluate, map_location=lambda storage, loc: storage) +model_backbone.load_state_dict(checkpoint['model_pos'], strict=True) +model_pos = model_backbone +model_pos.eval() +testloader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True, + 'drop_last': False +} + +vid = imageio.get_reader(opts.vid_path, 'ffmpeg') +fps_in = vid.get_meta_data()['fps'] +vid_size = vid.get_meta_data()['size'] +os.makedirs(opts.out_path, exist_ok=True) + +if opts.pixel: + # Keep relative scale with pixel coornidates + wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus) +else: + # Scale to [-1,1] + wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus) + +test_loader = DataLoader(wild_dataset, **testloader_params) + +results_all = [] +with torch.no_grad(): + for batch_input in tqdm(test_loader): + N, T = batch_input.shape[:2] + if torch.cuda.is_available(): + batch_input = batch_input.cuda() + if args.no_conf: + batch_input = batch_input[:, :, :, :2] + if args.flip: + batch_input_flip = flip_data(batch_input) + predicted_3d_pos_1 = model_pos(batch_input) + predicted_3d_pos_flip = model_pos(batch_input_flip) + predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back + predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2.0 + else: + predicted_3d_pos = model_pos(batch_input) + if args.rootrel: + predicted_3d_pos[:,:,0,:]=0 # [N,T,17,3] + else: + predicted_3d_pos[:,0,0,2]=0 + 
pass + if args.gt_2d: + predicted_3d_pos[...,:2] = batch_input[...,:2] + results_all.append(predicted_3d_pos.cpu().numpy()) + +results_all = np.hstack(results_all) +results_all = np.concatenate(results_all) +render_and_save(results_all, '%s/X3D.mp4' % (opts.out_path), keep_imgs=False, fps=fps_in) +if opts.pixel: + # Convert to pixel coordinates + results_all = results_all * (min(vid_size) / 2.0) + results_all[:,:,:2] = results_all[:,:,:2] + np.array(vid_size) / 2.0 +np.save('%s/X3D.npy' % (opts.out_path), results_all) \ No newline at end of file diff --git a/infer_wild_mesh.py b/infer_wild_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1c9b7d022a6b1cf771f84c644c75ee086f99c1 --- /dev/null +++ b/infer_wild_mesh.py @@ -0,0 +1,157 @@ +import os +import os.path as osp +import numpy as np +import argparse +import pickle +from tqdm import tqdm +import time +import random +import imageio + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader + +from lib.utils.tools import * +from lib.utils.learning import * +from lib.utils.utils_data import flip_data +from lib.utils.utils_mesh import flip_thetas_batch +from lib.data.dataset_wild import WildDetDataset +# from lib.model.loss import * +from lib.model.model_mesh import MeshRegressor +from lib.utils.vismo import render_and_save, motion2video_mesh +from lib.utils.utils_smpl import * +from scipy.optimize import least_squares + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/mesh/MB_ft_pw3d.yaml", help="Path to the config file.") + parser.add_argument('-e', '--evaluate', default='checkpoint/mesh/FT_MB_release_MB_ft_pw3d/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-j', '--json_path', type=str, help='alphapose detection result json path') + parser.add_argument('-v', '--vid_path', type=str, help='video path') + parser.add_argument('-o', '--out_path', type=str, help='output path') + parser.add_argument('--ref_3d_motion_path', type=str, default=None, help='3D motion path') + parser.add_argument('--pixel', action='store_true', help='align with pixle coordinates') + parser.add_argument('--focus', type=int, default=None, help='target person id') + parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input') + opts = parser.parse_args() + return opts + +def err(p, x, y): + return np.linalg.norm(p[0] * x + np.array([p[1], p[2], p[3]]) - y, axis=-1).mean() + +def solve_scale(x, y): + print('Estimating camera transformation.') + best_res = 100000 + best_scale = None + for init_scale in tqdm(range(0,2000,5)): + p0 = [init_scale, 0.0, 0.0, 0.0] + est = least_squares(err, p0, args = (x.reshape(-1,3), y.reshape(-1,3))) + if est['fun'] < best_res: + best_res = est['fun'] + best_scale = est['x'][0] + print('Pose matching error = %.2f mm.' 
% best_res) + return best_scale + +opts = parse_args() +args = get_config(opts.config) + +# root_rel +# args.rootrel = True + +smpl = SMPL(args.data_root, batch_size=1).cuda() +J_regressor = smpl.J_regressor_h36m + +end = time.time() +model_backbone = load_backbone(args) +print(f'init backbone time: {(time.time()-end):02f}s') +end = time.time() +model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout) +print(f'init whole model time: {(time.time()-end):02f}s') + +if torch.cuda.is_available(): + model = nn.DataParallel(model) + model = model.cuda() + +chk_filename = opts.evaluate if opts.evaluate else opts.resume +print('Loading checkpoint', chk_filename) +checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) +model.load_state_dict(checkpoint['model'], strict=True) +model.eval() + +testloader_params = { + 'batch_size': 1, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True, + 'drop_last': False +} + +vid = imageio.get_reader(opts.vid_path, 'ffmpeg') +fps_in = vid.get_meta_data()['fps'] +vid_size = vid.get_meta_data()['size'] +os.makedirs(opts.out_path, exist_ok=True) + +if opts.pixel: + # Keep relative scale with pixel coornidates + wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus) +else: + # Scale to [-1,1] + wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus) + +test_loader = DataLoader(wild_dataset, **testloader_params) + +verts_all = [] +reg3d_all = [] +with torch.no_grad(): + for batch_input in tqdm(test_loader): + batch_size, clip_frames = batch_input.shape[:2] + if torch.cuda.is_available(): + batch_input = batch_input.cuda().float() + output = model(batch_input) + batch_input_flip = flip_data(batch_input) + output_flip = model(batch_input_flip) + output_flip_pose = output_flip[0]['theta'][:, :, :72] + output_flip_shape = output_flip[0]['theta'][:, :, 72:] + output_flip_pose = flip_thetas_batch(output_flip_pose) + output_flip_pose = output_flip_pose.reshape(-1, 72) + output_flip_shape = output_flip_shape.reshape(-1, 10) + output_flip_smpl = smpl( + betas=output_flip_shape, + body_pose=output_flip_pose[:, 3:], + global_orient=output_flip_pose[:, :3], + pose2rot=True + ) + output_flip_verts = output_flip_smpl.vertices.detach() + J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device) + output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts) # (NT,17,3) + output_flip_back = [{ + 'verts': output_flip_verts.reshape(batch_size, clip_frames, -1, 3) * 1000.0, + 'kp_3d': output_flip_kp3d.reshape(batch_size, clip_frames, -1, 3), + }] + output_final = [{}] + for k, v in output_flip_back[0].items(): + output_final[0][k] = (output[0][k] + output_flip_back[0][k]) / 2.0 + output = output_final + verts_all.append(output[0]['verts'].cpu().numpy()) + reg3d_all.append(output[0]['kp_3d'].cpu().numpy()) + +verts_all = np.hstack(verts_all) +verts_all = np.concatenate(verts_all) +reg3d_all = np.hstack(reg3d_all) +reg3d_all = np.concatenate(reg3d_all) + +if opts.ref_3d_motion_path: + ref_pose = np.load(opts.ref_3d_motion_path) + x = ref_pose - ref_pose[:, :1] + y = reg3d_all - reg3d_all[:, :1] + scale = solve_scale(x, y) + root_cam = ref_pose[:, :1] * scale + verts_all = verts_all - reg3d_all[:,:1] + root_cam + +render_and_save(verts_all, 
osp.join(opts.out_path, 'mesh.mp4'), keep_imgs=False, fps=fps_in, draw_face=True) + diff --git a/lib/data/augmentation.py b/lib/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..0818d641ccc52459141738d86be86c6fcced9662 --- /dev/null +++ b/lib/data/augmentation.py @@ -0,0 +1,99 @@ +import numpy as np +import os +import random +import torch +import copy +import torch.nn as nn +from lib.utils.tools import read_pkl +from lib.utils.utils_data import flip_data, crop_scale_3d + +class Augmenter2D(object): + """ + Make 2D augmentations on the fly. PyTorch batch-processing GPU version. + """ + def __init__(self, args): + self.d2c_params = read_pkl(args.d2c_params_path) + self.noise = torch.load(args.noise_path) + self.mask_ratio = args.mask_ratio + self.mask_T_ratio = args.mask_T_ratio + self.num_Kframes = 27 + self.noise_std = 0.002 + + def dis2conf(self, dis, a, b, m, s): + f = a/(dis+a)+b*dis + shift = torch.randn(*dis.shape)*s + m + # if torch.cuda.is_available(): + shift = shift.to(dis.device) + return f + shift + + def add_noise(self, motion_2d): + a, b, m, s = self.d2c_params["a"], self.d2c_params["b"], self.d2c_params["m"], self.d2c_params["s"] + if "uniform_range" in self.noise.keys(): + uniform_range = self.noise["uniform_range"] + else: + uniform_range = 0.06 + motion_2d = motion_2d[:,:,:,:2] + batch_size = motion_2d.shape[0] + num_frames = motion_2d.shape[1] + num_joints = motion_2d.shape[2] + mean = self.noise['mean'].float() + std = self.noise['std'].float() + weight = self.noise['weight'][:,None].float() + sel = torch.rand((batch_size, self.num_Kframes, num_joints, 1)) + gaussian_sample = (torch.randn(batch_size, self.num_Kframes, num_joints, 2) * std + mean) + uniform_sample = (torch.rand((batch_size, self.num_Kframes, num_joints, 2))-0.5) * uniform_range + noise_mean = 0 + delta_noise = torch.randn(num_frames, num_joints, 2) * self.noise_std + noise_mean + # if torch.cuda.is_available(): + mean = mean.to(motion_2d.device) + std = std.to(motion_2d.device) + weight = weight.to(motion_2d.device) + gaussian_sample = gaussian_sample.to(motion_2d.device) + uniform_sample = uniform_sample.to(motion_2d.device) + sel = sel.to(motion_2d.device) + delta_noise = delta_noise.to(motion_2d.device) + + delta = gaussian_sample*(sel=weight) + delta_expand = torch.nn.functional.interpolate(delta.unsqueeze(1), [num_frames, num_joints, 2], mode='trilinear', align_corners=True)[:,0] + delta_final = delta_expand + delta_noise + motion_2d = motion_2d + delta_final + dx = delta_final[:,:,:,0] + dy = delta_final[:,:,:,1] + dis2 = dx*dx+dy*dy + dis = torch.sqrt(dis2) + conf = self.dis2conf(dis, a, b, m, s).clip(0,1).reshape([batch_size, num_frames, num_joints, -1]) + return torch.cat((motion_2d, conf), dim=3) + + def add_mask(self, x): + ''' motion_2d: (N,T,17,3) + ''' + N,T,J,C = x.shape + mask = torch.rand(N,T,J,1, dtype=x.dtype, device=x.device) > self.mask_ratio + mask_T = torch.rand(1,T,1,1, dtype=x.dtype, device=x.device) > self.mask_T_ratio + x = x * mask * mask_T + return x + + def augment2D(self, motion_2d, mask=False, noise=False): + if noise: + motion_2d = self.add_noise(motion_2d) + if mask: + motion_2d = self.add_mask(motion_2d) + return motion_2d + +class Augmenter3D(object): + """ + Make 3D augmentations when dataloaders get items. NumPy single motion version. 
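    Example (a minimal usage sketch; ``args`` only needs a ``flip`` attribute and,
    optionally, ``scale_range_pretrain``):

        import numpy as np
        from types import SimpleNamespace
        from lib.data.augmentation import Augmenter3D

        aug = Augmenter3D(SimpleNamespace(flip=True))
        motion_3d = np.random.randn(243, 17, 3).astype(np.float32)  # (T, 17, 3) single motion
        motion_3d = aug.augment3D(motion_3d)  # left/right-flipped with probability 0.5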
+ """ + def __init__(self, args): + self.flip = args.flip + if hasattr(args, "scale_range_pretrain"): + self.scale_range_pretrain = args.scale_range_pretrain + else: + self.scale_range_pretrain = None + + def augment3D(self, motion_3d): + if self.scale_range_pretrain: + motion_3d = crop_scale_3d(motion_3d, self.scale_range_pretrain) + if self.flip and random.random()>0.5: + motion_3d = flip_data(motion_3d) + return motion_3d \ No newline at end of file diff --git a/lib/data/datareader_h36m.py b/lib/data/datareader_h36m.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f20b6845ea476921fbcdf5dc96fc0aee630951 --- /dev/null +++ b/lib/data/datareader_h36m.py @@ -0,0 +1,136 @@ +# Adapted from Optimizing Network Structure for 3D Human Pose Estimation (ICCV 2019) (https://github.com/CHUNYUWANG/lcn-pose/blob/master/tools/data.py) + +import numpy as np +import os, sys +import random +import copy +from lib.utils.tools import read_pkl +from lib.utils.utils_data import split_clips +random.seed(0) + +class DataReaderH36M(object): + def __init__(self, n_frames, sample_stride, data_stride_train, data_stride_test, read_confidence=True, dt_root = 'data/motion3d', dt_file = 'h36m_cpn_cam_source.pkl'): + self.gt_trainset = None + self.gt_testset = None + self.split_id_train = None + self.split_id_test = None + self.test_hw = None + self.dt_dataset = read_pkl('%s/%s' % (dt_root, dt_file)) + self.n_frames = n_frames + self.sample_stride = sample_stride + self.data_stride_train = data_stride_train + self.data_stride_test = data_stride_test + self.read_confidence = read_confidence + + def read_2d(self): + trainset = self.dt_dataset['train']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2] + testset = self.dt_dataset['test']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2] + # map to [-1, 1] + for idx, camera_name in enumerate(self.dt_dataset['train']['camera_name']): + if camera_name == '54138969' or camera_name == '60457274': + res_w, res_h = 1000, 1002 + elif camera_name == '55011271' or camera_name == '58860488': + res_w, res_h = 1000, 1000 + else: + assert 0, '%d data item has an invalid camera name' % idx + trainset[idx, :, :] = trainset[idx, :, :] / res_w * 2 - [1, res_h / res_w] + for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']): + if camera_name == '54138969' or camera_name == '60457274': + res_w, res_h = 1000, 1002 + elif camera_name == '55011271' or camera_name == '58860488': + res_w, res_h = 1000, 1000 + else: + assert 0, '%d data item has an invalid camera name' % idx + testset[idx, :, :] = testset[idx, :, :] / res_w * 2 - [1, res_h / res_w] + if self.read_confidence: + if 'confidence' in self.dt_dataset['train'].keys(): + train_confidence = self.dt_dataset['train']['confidence'][::self.sample_stride].astype(np.float32) + test_confidence = self.dt_dataset['test']['confidence'][::self.sample_stride].astype(np.float32) + if len(train_confidence.shape)==2: # (1559752, 17) + train_confidence = train_confidence[:,:,None] + test_confidence = test_confidence[:,:,None] + else: + # No conf provided, fill with 1. 
+ train_confidence = np.ones(trainset.shape)[:,:,0:1] + test_confidence = np.ones(testset.shape)[:,:,0:1] + trainset = np.concatenate((trainset, train_confidence), axis=2) # [N, 17, 3] + testset = np.concatenate((testset, test_confidence), axis=2) # [N, 17, 3] + return trainset, testset + + def read_3d(self): + train_labels = self.dt_dataset['train']['joint3d_image'][::self.sample_stride, :, :3].astype(np.float32) # [N, 17, 3] + test_labels = self.dt_dataset['test']['joint3d_image'][::self.sample_stride, :, :3].astype(np.float32) # [N, 17, 3] + # map to [-1, 1] + for idx, camera_name in enumerate(self.dt_dataset['train']['camera_name']): + if camera_name == '54138969' or camera_name == '60457274': + res_w, res_h = 1000, 1002 + elif camera_name == '55011271' or camera_name == '58860488': + res_w, res_h = 1000, 1000 + else: + assert 0, '%d data item has an invalid camera name' % idx + train_labels[idx, :, :2] = train_labels[idx, :, :2] / res_w * 2 - [1, res_h / res_w] + train_labels[idx, :, 2:] = train_labels[idx, :, 2:] / res_w * 2 + + for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']): + if camera_name == '54138969' or camera_name == '60457274': + res_w, res_h = 1000, 1002 + elif camera_name == '55011271' or camera_name == '58860488': + res_w, res_h = 1000, 1000 + else: + assert 0, '%d data item has an invalid camera name' % idx + test_labels[idx, :, :2] = test_labels[idx, :, :2] / res_w * 2 - [1, res_h / res_w] + test_labels[idx, :, 2:] = test_labels[idx, :, 2:] / res_w * 2 + + return train_labels, test_labels + def read_hw(self): + if self.test_hw is not None: + return self.test_hw + test_hw = np.zeros((len(self.dt_dataset['test']['camera_name']), 2)) + for idx, camera_name in enumerate(self.dt_dataset['test']['camera_name']): + if camera_name == '54138969' or camera_name == '60457274': + res_w, res_h = 1000, 1002 + elif camera_name == '55011271' or camera_name == '58860488': + res_w, res_h = 1000, 1000 + else: + assert 0, '%d data item has an invalid camera name' % idx + test_hw[idx] = res_w, res_h + self.test_hw = test_hw + return test_hw + + def get_split_id(self): + if self.split_id_train is not None and self.split_id_test is not None: + return self.split_id_train, self.split_id_test + vid_list_train = self.dt_dataset['train']['source'][::self.sample_stride] # (1559752,) + vid_list_test = self.dt_dataset['test']['source'][::self.sample_stride] # (566920,) + self.split_id_train = split_clips(vid_list_train, self.n_frames, data_stride=self.data_stride_train) + self.split_id_test = split_clips(vid_list_test, self.n_frames, data_stride=self.data_stride_test) + return self.split_id_train, self.split_id_test + + def get_hw(self): +# Only Testset HW is needed for denormalization + test_hw = self.read_hw() # train_data (1559752, 2) test_data (566920, 2) + split_id_train, split_id_test = self.get_split_id() + test_hw = test_hw[split_id_test][:,0,:] # (N, 2) + return test_hw + + def get_sliced_data(self): + train_data, test_data = self.read_2d() # train_data (1559752, 17, 3) test_data (566920, 17, 3) + train_labels, test_labels = self.read_3d() # train_labels (1559752, 17, 3) test_labels (566920, 17, 3) + split_id_train, split_id_test = self.get_split_id() + train_data, test_data = train_data[split_id_train], test_data[split_id_test] # (N, 27, 17, 3) + train_labels, test_labels = train_labels[split_id_train], test_labels[split_id_test] # (N, 27, 17, 3) + # ipdb.set_trace() + return train_data, test_data, train_labels, test_labels + + def denormalize(self, test_data): +# 
data: (N, n_frames, 51) or data: (N, n_frames, 17, 3) + n_clips = test_data.shape[0] + test_hw = self.get_hw() + data = test_data.reshape([n_clips, -1, 17, 3]) + assert len(data) == len(test_hw) + # denormalize (x,y,z) coordiantes for results + for idx, item in enumerate(data): + res_w, res_h = test_hw[idx] + data[idx, :, :, :2] = (data[idx, :, :, :2] + np.array([1, res_h / res_w])) * res_w / 2 + data[idx, :, :, 2:] = data[idx, :, :, 2:] * res_w / 2 + return data # [n_clips, -1, 17, 3] diff --git a/lib/data/datareader_mesh.py b/lib/data/datareader_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb1e87910a50bf36ee69c1dfd316075b2260b9f --- /dev/null +++ b/lib/data/datareader_mesh.py @@ -0,0 +1,59 @@ +import numpy as np +import os, sys +import copy +from lib.utils.tools import read_pkl +from lib.utils.utils_data import split_clips + +class DataReaderMesh(object): + def __init__(self, n_frames, sample_stride, data_stride_train, data_stride_test, read_confidence=True, dt_root = 'data/mesh', dt_file = 'pw3d_det.pkl', res=[1920, 1920]): + self.split_id_train = None + self.split_id_test = None + self.dt_dataset = read_pkl('%s/%s' % (dt_root, dt_file)) + self.n_frames = n_frames + self.sample_stride = sample_stride + self.data_stride_train = data_stride_train + self.data_stride_test = data_stride_test + self.read_confidence = read_confidence + self.res = res + + def read_2d(self): + if self.res is not None: + res_w, res_h = self.res + offset = [1, res_h / res_w] + else: + res = np.array(self.dt_dataset['train']['img_hw'])[::self.sample_stride].astype(np.float32) + res_w, res_h = res.max(1)[:, None, None], res.max(1)[:, None, None] + offset = 1 + trainset = self.dt_dataset['train']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2] + testset = self.dt_dataset['test']['joint_2d'][::self.sample_stride, :, :2].astype(np.float32) # [N, 17, 2] + # res_w, res_h = self.res + trainset = trainset / res_w * 2 - offset + testset = testset / res_w * 2 - offset + if self.read_confidence: + train_confidence = self.dt_dataset['train']['confidence'][::self.sample_stride].astype(np.float32) + test_confidence = self.dt_dataset['test']['confidence'][::self.sample_stride].astype(np.float32) + if len(train_confidence.shape)==2: + train_confidence = train_confidence[:,:,None] + test_confidence = test_confidence[:,:,None] + trainset = np.concatenate((trainset, train_confidence), axis=2) # [N, 17, 3] + testset = np.concatenate((testset, test_confidence), axis=2) # [N, 17, 3] + return trainset, testset + + def get_split_id(self): + if self.split_id_train is not None and self.split_id_test is not None: + return self.split_id_train, self.split_id_test + vid_list_train = self.dt_dataset['train']['source'][::self.sample_stride] + vid_list_test = self.dt_dataset['test']['source'][::self.sample_stride] + self.split_id_train = split_clips(vid_list_train, self.n_frames, self.data_stride_train) + self.split_id_test = split_clips(vid_list_test, self.n_frames, self.data_stride_test) + return self.split_id_train, self.split_id_test + + def get_sliced_data(self): + train_data, test_data = self.read_2d() + train_labels, test_labels = self.read_3d() + split_id_train, split_id_test = self.get_split_id() + train_data, test_data = train_data[split_id_train], test_data[split_id_test] # (N, 27, 17, 3) + train_labels, test_labels = train_labels[split_id_train], test_labels[split_id_test] # (N, 27, 17, 3) + return train_data, test_data, train_labels, test_labels + + \ No newline at end of file 
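Both readers above normalize pixel coordinates by the image width into a roughly [-1, 1] range before slicing clips, and DataReaderH36M.denormalize inverts that mapping for evaluation. A minimal round-trip sketch (hypothetical keypoint values, using the 1000x1002 Human3.6M camera resolution handled in read_2d):

import numpy as np

res_w, res_h = 1000, 1002                      # e.g. cameras '54138969' / '60457274'
kpt_px = np.array([[500.0, 501.0],             # hypothetical pixel keypoints
                   [0.0, 0.0],
                   [1000.0, 1002.0]])

kpt_norm = kpt_px / res_w * 2 - np.array([1, res_h / res_w])      # as in read_2d()
kpt_back = (kpt_norm + np.array([1, res_h / res_w])) * res_w / 2  # as in denormalize()

print(kpt_norm[0])                    # [0. 0.] -> image center maps to the origin
print(np.allclose(kpt_back, kpt_px))  # True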
diff --git a/lib/data/dataset_action.py b/lib/data/dataset_action.py new file mode 100644 index 0000000000000000000000000000000000000000..87bc5de62698baeb9d785139566b20ff5d4b5280 --- /dev/null +++ b/lib/data/dataset_action.py @@ -0,0 +1,206 @@ +import torch +import numpy as np +import os +import random +import copy +from torch.utils.data import Dataset, DataLoader +from lib.utils.utils_data import crop_scale, resample +from lib.utils.tools import read_pkl + +def get_action_names(file_path = "data/action/ntu_actions.txt"): + f = open(file_path, "r") + s = f.read() + actions = s.split('\n') + action_names = [] + for a in actions: + action_names.append(a.split('.')[1][1:]) + return action_names + +def make_cam(x, img_shape): + ''' + Input: x (M x T x V x C) + img_shape (height, width) + ''' + h, w = img_shape + if w >= h: + x_cam = x / w * 2 - 1 + else: + x_cam = x / h * 2 - 1 + return x_cam + +def coco2h36m(x): + ''' + Input: x (M x T x V x C) + + COCO: {0-nose 1-Leye 2-Reye 3-Lear 4Rear 5-Lsho 6-Rsho 7-Lelb 8-Relb 9-Lwri 10-Rwri 11-Lhip 12-Rhip 13-Lkne 14-Rkne 15-Lank 16-Rank} + + H36M: + 0: 'root', + 1: 'rhip', + 2: 'rkne', + 3: 'rank', + 4: 'lhip', + 5: 'lkne', + 6: 'lank', + 7: 'belly', + 8: 'neck', + 9: 'nose', + 10: 'head', + 11: 'lsho', + 12: 'lelb', + 13: 'lwri', + 14: 'rsho', + 15: 'relb', + 16: 'rwri' + ''' + y = np.zeros(x.shape) + y[:,:,0,:] = (x[:,:,11,:] + x[:,:,12,:]) * 0.5 + y[:,:,1,:] = x[:,:,12,:] + y[:,:,2,:] = x[:,:,14,:] + y[:,:,3,:] = x[:,:,16,:] + y[:,:,4,:] = x[:,:,11,:] + y[:,:,5,:] = x[:,:,13,:] + y[:,:,6,:] = x[:,:,15,:] + y[:,:,8,:] = (x[:,:,5,:] + x[:,:,6,:]) * 0.5 + y[:,:,7,:] = (y[:,:,0,:] + y[:,:,8,:]) * 0.5 + y[:,:,9,:] = x[:,:,0,:] + y[:,:,10,:] = (x[:,:,1,:] + x[:,:,2,:]) * 0.5 + y[:,:,11,:] = x[:,:,5,:] + y[:,:,12,:] = x[:,:,7,:] + y[:,:,13,:] = x[:,:,9,:] + y[:,:,14,:] = x[:,:,6,:] + y[:,:,15,:] = x[:,:,8,:] + y[:,:,16,:] = x[:,:,10,:] + return y + +def random_move(data_numpy, + angle_range=[-10., 10.], + scale_range=[0.9, 1.1], + transform_range=[-0.1, 0.1], + move_time_candidate=[1]): + data_numpy = np.transpose(data_numpy, (3,1,2,0)) # M,T,V,C-> C,T,V,M + C, T, V, M = data_numpy.shape + move_time = random.choice(move_time_candidate) + node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) + node = np.append(node, T) + num_node = len(node) + A = np.random.uniform(angle_range[0], angle_range[1], num_node) + S = np.random.uniform(scale_range[0], scale_range[1], num_node) + T_x = np.random.uniform(transform_range[0], transform_range[1], num_node) + T_y = np.random.uniform(transform_range[0], transform_range[1], num_node) + a = np.zeros(T) + s = np.zeros(T) + t_x = np.zeros(T) + t_y = np.zeros(T) + # linspace + for i in range(num_node - 1): + a[node[i]:node[i + 1]] = np.linspace( + A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 + s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], node[i + 1] - node[i]) + t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], node[i + 1] - node[i]) + t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], node[i + 1] - node[i]) + theta = np.array([[np.cos(a) * s, -np.sin(a) * s], + [np.sin(a) * s, np.cos(a) * s]]) + # perform transformation + for i_frame in range(T): + xy = data_numpy[0:2, i_frame, :, :] + new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) + new_xy[0] += t_x[i_frame] + new_xy[1] += t_y[i_frame] + data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) + data_numpy = np.transpose(data_numpy, (3,1,2,0)) # C,T,V,M -> M,T,V,C + return data_numpy + +def human_tracking(x): + M, T = 
x.shape[:2] + if M==1: + return x + else: + diff0 = np.sum(np.linalg.norm(x[0,1:] - x[0,:-1], axis=-1), axis=-1) # (T-1, V, C) -> (T-1) + diff1 = np.sum(np.linalg.norm(x[0,1:] - x[1,:-1], axis=-1), axis=-1) + x_new = np.zeros(x.shape) + sel = np.cumsum(diff0 > diff1) % 2 + sel = sel[:,None,None] + x_new[0][0] = x[0][0] + x_new[1][0] = x[1][0] + x_new[0,1:] = x[1,1:] * sel + x[0,1:] * (1-sel) + x_new[1,1:] = x[0,1:] * sel + x[1,1:] * (1-sel) + return x_new + +class ActionDataset(Dataset): + def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1], check_split=True): # data_split: train/test etc. + np.random.seed(0) + dataset = read_pkl(data_path) + if check_split: + assert data_split in dataset['split'].keys() + self.split = dataset['split'][data_split] + annotations = dataset['annotations'] + self.random_move = random_move + self.is_train = "train" in data_split or (check_split==False) + if "oneshot" in data_split: + self.is_train = False + self.scale_range = scale_range + motions = [] + labels = [] + for sample in annotations: + if check_split and (not sample['frame_dir'] in self.split): + continue + resample_id = resample(ori_len=sample['total_frames'], target_len=n_frames, randomness=self.is_train) + motion_cam = make_cam(x=sample['keypoint'], img_shape=sample['img_shape']) + motion_cam = human_tracking(motion_cam) + motion_cam = coco2h36m(motion_cam) + motion_conf = sample['keypoint_score'][..., None] + motion = np.concatenate((motion_cam[:,resample_id], motion_conf[:,resample_id]), axis=-1) + if motion.shape[0]==1: # Single person, make a fake zero person + fake = np.zeros(motion.shape) + motion = np.concatenate((motion, fake), axis=0) + motions.append(motion.astype(np.float32)) + labels.append(sample['label']) + self.motions = np.array(motions) + self.labels = np.array(labels) + + def __len__(self): + 'Denotes the total number of samples' + return len(self.motions) + + def __getitem__(self, index): + raise NotImplementedError + +class NTURGBD(ActionDataset): + def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1]): + super(NTURGBD, self).__init__(data_path, data_split, n_frames, random_move, scale_range) + + def __getitem__(self, idx): + 'Generates one sample of data' + motion, label = self.motions[idx], self.labels[idx] # (M,T,J,C) + if self.random_move: + motion = random_move(motion) + if self.scale_range: + result = crop_scale(motion, scale_range=self.scale_range) + else: + result = motion + return result.astype(np.float32), label + +class NTURGBD1Shot(ActionDataset): + def __init__(self, data_path, data_split, n_frames=243, random_move=True, scale_range=[1,1], check_split=False): + super(NTURGBD1Shot, self).__init__(data_path, data_split, n_frames, random_move, scale_range, check_split) + oneshot_classes = [0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114] + new_classes = set(range(120)) - set(oneshot_classes) + old2new = {} + for i, cid in enumerate(new_classes): + old2new[cid] = i + filtered = [not (x in oneshot_classes) for x in self.labels] + self.motions = self.motions[filtered] + filtered_labels = self.labels[filtered] + self.labels = [old2new[x] for x in filtered_labels] + + def __getitem__(self, idx): + 'Generates one sample of data' + motion, label = self.motions[idx], self.labels[idx] # (M,T,J,C) + if self.random_move: + motion = random_move(motion) + if self.scale_range: + result = crop_scale(motion, scale_range=self.scale_range) + else: + result = motion + return 
result.astype(np.float32), label \ No newline at end of file diff --git a/lib/data/dataset_mesh.py b/lib/data/dataset_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..c496a3ac34648a39076379508d67625099f589b3 --- /dev/null +++ b/lib/data/dataset_mesh.py @@ -0,0 +1,97 @@ +import torch +import numpy as np +import glob +import os +import io +import random +import pickle +from torch.utils.data import Dataset, DataLoader +from lib.data.augmentation import Augmenter3D +from lib.utils.tools import read_pkl +from lib.utils.utils_data import flip_data, crop_scale +from lib.utils.utils_mesh import flip_thetas +from lib.utils.utils_smpl import SMPL +from torch.utils.data import Dataset, DataLoader +from lib.data.datareader_h36m import DataReaderH36M +from lib.data.datareader_mesh import DataReaderMesh +from lib.data.dataset_action import random_move + +class SMPLDataset(Dataset): + def __init__(self, args, data_split, dataset): # data_split: train/test; dataset: h36m, coco, pw3d + random.seed(0) + np.random.seed(0) + self.clip_len = args.clip_len + self.data_split = data_split + if dataset=="h36m": + datareader = DataReaderH36M(n_frames=self.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=self.clip_len, dt_root=args.data_root, dt_file=args.dt_file_h36m) + elif dataset=="coco": + datareader = DataReaderMesh(n_frames=1, sample_stride=args.sample_stride, data_stride_train=1, data_stride_test=1, dt_root=args.data_root, dt_file=args.dt_file_coco, res=[640, 640]) + elif dataset=="pw3d": + datareader = DataReaderMesh(n_frames=self.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=self.clip_len, dt_root=args.data_root, dt_file=args.dt_file_pw3d, res=[1920, 1920]) + else: + raise Exception("Mesh dataset undefined.") + + split_id_train, split_id_test = datareader.get_split_id() # Index of clips + train_data, test_data = datareader.read_2d() + train_data, test_data = train_data[split_id_train], test_data[split_id_test] # Input: (N, T, 17, 3) + self.motion_2d = {'train': train_data, 'test': test_data}[data_split] + + dt = datareader.dt_dataset + smpl_pose_train = dt['train']['smpl_pose'][split_id_train] # (N, T, 72) + smpl_shape_train = dt['train']['smpl_shape'][split_id_train] # (N, T, 10) + smpl_pose_test = dt['test']['smpl_pose'][split_id_test] # (N, T, 72) + smpl_shape_test = dt['test']['smpl_shape'][split_id_test] # (N, T, 10) + + self.motion_smpl_3d = {'train': {'pose': smpl_pose_train, 'shape': smpl_shape_train}, 'test': {'pose': smpl_pose_test, 'shape': smpl_shape_test}}[data_split] + self.smpl = SMPL( + args.data_root, + batch_size=1, + ) + + def __len__(self): + 'Denotes the total number of samples' + return len(self.motion_2d) + + def __getitem__(self, index): + raise NotImplementedError + +class MotionSMPL(SMPLDataset): + def __init__(self, args, data_split, dataset): + super(MotionSMPL, self).__init__(args, data_split, dataset) + self.flip = args.flip + + def __getitem__(self, index): + 'Generates one sample of data' + # Select sample + motion_2d = self.motion_2d[index] # motion_2d: (T,17,3) + motion_2d[:,:,2] = np.clip(motion_2d[:,:,2], 0, 1) + motion_smpl_pose = self.motion_smpl_3d['pose'][index].reshape(-1, 24, 3) # motion_smpl_3d: (T, 24, 3) + motion_smpl_shape = self.motion_smpl_3d['shape'][index] # motion_smpl_3d: (T,10) + + if self.data_split=="train": + if self.flip and random.random() > 0.5: # Training augmentation - random flipping + motion_2d = flip_data(motion_2d) + 
motion_smpl_pose = flip_thetas(motion_smpl_pose) + + + motion_smpl_pose = torch.from_numpy(motion_smpl_pose).reshape(-1, 72).float() + motion_smpl_shape = torch.from_numpy(motion_smpl_shape).reshape(-1, 10).float() + motion_smpl = self.smpl( + betas=motion_smpl_shape, + body_pose=motion_smpl_pose[:, 3:], + global_orient=motion_smpl_pose[:, :3], + pose2rot=True + ) + motion_verts = motion_smpl.vertices.detach()*1000.0 + J_regressor = self.smpl.J_regressor_h36m + J_regressor_batch = J_regressor[None, :].expand(motion_verts.shape[0], -1, -1).to(motion_verts.device) + motion_3d_reg = torch.matmul(J_regressor_batch, motion_verts) # motion_3d: (T,17,3) + motion_verts = motion_verts - motion_3d_reg[:, :1, :] + motion_3d_reg = motion_3d_reg - motion_3d_reg[:, :1, :] # motion_3d: (T,17,3) + motion_theta = torch.cat((motion_smpl_pose, motion_smpl_shape), -1) + motion_smpl_3d = { + 'theta': motion_theta, # smpl pose and shape + 'kp_3d': motion_3d_reg, # 3D keypoints + 'verts': motion_verts, # 3D mesh vertices + } + return motion_2d, motion_smpl_3d \ No newline at end of file diff --git a/lib/data/dataset_motion_2d.py b/lib/data/dataset_motion_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b136f33507de96323d517452def1fe1686743700 --- /dev/null +++ b/lib/data/dataset_motion_2d.py @@ -0,0 +1,148 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader +import numpy as np +import os +import random +import copy +import json +from collections import defaultdict +from lib.utils.utils_data import crop_scale, flip_data, resample, split_clips + +def posetrack2h36m(x): + ''' + Input: x (T x V x C) + + PoseTrack keypoints = [ 'nose', + 'head_bottom', + 'head_top', + 'left_ear', + 'right_ear', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_hip', + 'right_hip', + 'left_knee', + 'right_knee', + 'left_ankle', + 'right_ankle'] + H36M: + 0: 'root', + 1: 'rhip', + 2: 'rkne', + 3: 'rank', + 4: 'lhip', + 5: 'lkne', + 6: 'lank', + 7: 'belly', + 8: 'neck', + 9: 'nose', + 10: 'head', + 11: 'lsho', + 12: 'lelb', + 13: 'lwri', + 14: 'rsho', + 15: 'relb', + 16: 'rwri' + ''' + y = np.zeros(x.shape) + y[:,0,:] = (x[:,11,:] + x[:,12,:]) * 0.5 + y[:,1,:] = x[:,12,:] + y[:,2,:] = x[:,14,:] + y[:,3,:] = x[:,16,:] + y[:,4,:] = x[:,11,:] + y[:,5,:] = x[:,13,:] + y[:,6,:] = x[:,15,:] + y[:,8,:] = x[:,1,:] + y[:,7,:] = (y[:,0,:] + y[:,8,:]) * 0.5 + y[:,9,:] = x[:,0,:] + y[:,10,:] = x[:,2,:] + y[:,11,:] = x[:,5,:] + y[:,12,:] = x[:,7,:] + y[:,13,:] = x[:,9,:] + y[:,14,:] = x[:,6,:] + y[:,15,:] = x[:,8,:] + y[:,16,:] = x[:,10,:] + y[:,0,2] = np.minimum(x[:,11,2], x[:,12,2]) + y[:,7,2] = np.minimum(y[:,0,2], y[:,8,2]) + return y + + +class PoseTrackDataset2D(Dataset): + def __init__(self, flip=True, scale_range=[0.25, 1]): + super(PoseTrackDataset2D, self).__init__() + self.flip = flip + data_root = "data/motion2d/posetrack18_annotations/train/" + file_list = sorted(os.listdir(data_root)) + all_motions = [] + all_motions_filtered = [] + self.scale_range = scale_range + for filename in file_list: + with open(os.path.join(data_root, filename), 'r') as file: + json_dict = json.load(file) + annots = json_dict['annotations'] + imgs = json_dict['images'] + motions = defaultdict(list) + for annot in annots: + tid = annot['track_id'] + pose2d = np.array(annot['keypoints']).reshape(-1,3) + motions[tid].append(pose2d) + all_motions += list(motions.values()) + for motion in 
all_motions: + if len(motion)<30: + continue + motion = np.array(motion[:30]) + if np.sum(motion[:,:,2]) <= 306: # Valid joint num threshold + continue + motion = crop_scale(motion, self.scale_range) + motion = posetrack2h36m(motion) + motion[motion[:,:,2]==0] = 0 + if np.sum(motion[:,0,2]) < 30: + continue # Root all visible (needed for framewise rootrel) + all_motions_filtered.append(motion) + all_motions_filtered = np.array(all_motions_filtered) + self.motions_2d = all_motions_filtered + + def __len__(self): + 'Denotes the total number of samples' + return len(self.motions_2d) + + def __getitem__(self, index): + 'Generates one sample of data' + motion_2d = torch.FloatTensor(self.motions_2d[index]) + if self.flip and random.random()>0.5: + motion_2d = flip_data(motion_2d) + return motion_2d, motion_2d + +class InstaVDataset2D(Dataset): + def __init__(self, n_frames=81, data_stride=27, flip=True, valid_threshold=0.0, scale_range=[0.25, 1]): + super(InstaVDataset2D, self).__init__() + self.flip = flip + self.scale_range = scale_range + motion_all = np.load('data/motion2d/InstaVariety/motion_all.npy') + id_all = np.load('data/motion2d/InstaVariety/id_all.npy') + split_id = split_clips(id_all, n_frames, data_stride) + motions_2d = motion_all[split_id] # [N, T, 17, 3] + valid_idx = (motions_2d[:,0,0,2] > valid_threshold) + self.motions_2d = motions_2d[valid_idx] + + def __len__(self): + 'Denotes the total number of samples' + return len(self.motions_2d) + + def __getitem__(self, index): + 'Generates one sample of data' + motion_2d = self.motions_2d[index] + motion_2d = crop_scale(motion_2d, self.scale_range) + motion_2d[motion_2d[:,:,2]==0] = 0 + if self.flip and random.random()>0.5: + motion_2d = flip_data(motion_2d) + motion_2d = torch.FloatTensor(motion_2d) + return motion_2d, motion_2d + \ No newline at end of file diff --git a/lib/data/dataset_motion_3d.py b/lib/data/dataset_motion_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..a2de10dc797f2c8d024ec2030af8cdd061fb18f4 --- /dev/null +++ b/lib/data/dataset_motion_3d.py @@ -0,0 +1,68 @@ +import torch +import numpy as np +import glob +import os +import io +import random +import pickle +from torch.utils.data import Dataset, DataLoader +from lib.data.augmentation import Augmenter3D +from lib.utils.tools import read_pkl +from lib.utils.utils_data import flip_data + +class MotionDataset(Dataset): + def __init__(self, args, subset_list, data_split): # data_split: train/test + np.random.seed(0) + self.data_root = args.data_root + self.subset_list = subset_list + self.data_split = data_split + file_list_all = [] + for subset in self.subset_list: + data_path = os.path.join(self.data_root, subset, self.data_split) + motion_list = sorted(os.listdir(data_path)) + for i in motion_list: + file_list_all.append(os.path.join(data_path, i)) + self.file_list = file_list_all + + def __len__(self): + 'Denotes the total number of samples' + return len(self.file_list) + + def __getitem__(self, index): + raise NotImplementedError + +class MotionDataset3D(MotionDataset): + def __init__(self, args, subset_list, data_split): + super(MotionDataset3D, self).__init__(args, subset_list, data_split) + self.flip = args.flip + self.synthetic = args.synthetic + self.aug = Augmenter3D(args) + self.gt_2d = args.gt_2d + + def __getitem__(self, index): + 'Generates one sample of data' + # Select sample + file_path = self.file_list[index] + motion_file = read_pkl(file_path) + motion_3d = motion_file["data_label"] + if self.data_split=="train": + if 
self.synthetic or self.gt_2d: + motion_3d = self.aug.augment3D(motion_3d) + motion_2d = np.zeros(motion_3d.shape, dtype=np.float32) + motion_2d[:,:,:2] = motion_3d[:,:,:2] + motion_2d[:,:,2] = 1 # No 2D detection, use GT xy and c=1. + elif motion_file["data_input"] is not None: # Have 2D detection + motion_2d = motion_file["data_input"] + if self.flip and random.random() > 0.5: # Training augmentation - random flipping + motion_2d = flip_data(motion_2d) + motion_3d = flip_data(motion_3d) + else: + raise ValueError('Training illegal.') + elif self.data_split=="test": + motion_2d = motion_file["data_input"] + if self.gt_2d: + motion_2d[:,:,:2] = motion_3d[:,:,:2] + motion_2d[:,:,2] = 1 + else: + raise ValueError('Data split unknown.') + return torch.FloatTensor(motion_2d), torch.FloatTensor(motion_3d) \ No newline at end of file diff --git a/lib/data/dataset_wild.py b/lib/data/dataset_wild.py new file mode 100644 index 0000000000000000000000000000000000000000..8a462c4759af73f247c7964fc3c6f53579d64a02 --- /dev/null +++ b/lib/data/dataset_wild.py @@ -0,0 +1,102 @@ +import torch +import numpy as np +import ipdb +import glob +import os +import io +import math +import random +import json +import pickle +import math +from torch.utils.data import Dataset, DataLoader +from lib.utils.utils_data import crop_scale + +def halpe2h36m(x): + ''' + Input: x (T x V x C) + //Halpe 26 body keypoints + {0, "Nose"}, + {1, "LEye"}, + {2, "REye"}, + {3, "LEar"}, + {4, "REar"}, + {5, "LShoulder"}, + {6, "RShoulder"}, + {7, "LElbow"}, + {8, "RElbow"}, + {9, "LWrist"}, + {10, "RWrist"}, + {11, "LHip"}, + {12, "RHip"}, + {13, "LKnee"}, + {14, "Rknee"}, + {15, "LAnkle"}, + {16, "RAnkle"}, + {17, "Head"}, + {18, "Neck"}, + {19, "Hip"}, + {20, "LBigToe"}, + {21, "RBigToe"}, + {22, "LSmallToe"}, + {23, "RSmallToe"}, + {24, "LHeel"}, + {25, "RHeel"}, + ''' + T, V, C = x.shape + y = np.zeros([T,17,C]) + y[:,0,:] = x[:,19,:] + y[:,1,:] = x[:,12,:] + y[:,2,:] = x[:,14,:] + y[:,3,:] = x[:,16,:] + y[:,4,:] = x[:,11,:] + y[:,5,:] = x[:,13,:] + y[:,6,:] = x[:,15,:] + y[:,7,:] = (x[:,18,:] + x[:,19,:]) * 0.5 + y[:,8,:] = x[:,18,:] + y[:,9,:] = x[:,0,:] + y[:,10,:] = x[:,17,:] + y[:,11,:] = x[:,5,:] + y[:,12,:] = x[:,7,:] + y[:,13,:] = x[:,9,:] + y[:,14,:] = x[:,6,:] + y[:,15,:] = x[:,8,:] + y[:,16,:] = x[:,10,:] + return y + +def read_input(json_path, vid_size, scale_range, focus): + with open(json_path, "r") as read_file: + results = json.load(read_file) + kpts_all = [] + for item in results: + if focus!=None and item['idx']!=focus: + continue + kpts = np.array(item['keypoints']).reshape([-1,3]) + kpts_all.append(kpts) + kpts_all = np.array(kpts_all) + kpts_all = halpe2h36m(kpts_all) + if vid_size: + w, h = vid_size + scale = min(w,h) / 2.0 + kpts_all[:,:,:2] = kpts_all[:,:,:2] - np.array([w, h]) / 2.0 + kpts_all[:,:,:2] = kpts_all[:,:,:2] / scale + motion = kpts_all + if scale_range: + motion = crop_scale(kpts_all, scale_range) + return motion.astype(np.float32) + +class WildDetDataset(Dataset): + def __init__(self, json_path, clip_len=243, vid_size=None, scale_range=None, focus=None): + self.json_path = json_path + self.clip_len = clip_len + self.vid_all = read_input(json_path, vid_size, scale_range, focus) + + def __len__(self): + 'Denotes the total number of samples' + return math.ceil(len(self.vid_all) / self.clip_len) + + def __getitem__(self, index): + 'Generates one sample of data' + st = index*self.clip_len + end = min((index+1)*self.clip_len, len(self.vid_all)) + return self.vid_all[st:end] \ No newline at end of file 
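WildDetDataset above expects a detection JSON that is a list of per-frame entries with an 'idx' track id and flattened Halpe-26 'keypoints'; it converts them to the H36M joint layout and serves fixed-length clips. A rough usage sketch (file name and resolution are hypothetical):

from torch.utils.data import DataLoader
from lib.data.dataset_wild import WildDetDataset

wild_dataset = WildDetDataset('sample.alphapose.json',  # hypothetical detection file
                              clip_len=243,
                              vid_size=(1920, 1080),     # hypothetical (width, height)
                              scale_range=None,
                              focus=None)                # None keeps every track id
wild_loader = DataLoader(wild_dataset, batch_size=1, shuffle=False)
for clip in wild_loader:
    print(clip.shape)  # (1, 243, 17, 3); the final clip may be shorter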
diff --git a/lib/model/DSTformer.py b/lib/model/DSTformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2af23881e5cc71a41b07722acd462de804bca6c8 --- /dev/null +++ b/lib/model/DSTformer.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import math +import warnings +import random +import numpy as np +from collections import OrderedDict +from functools import partial +from itertools import repeat +from lib.model.drop import DropPath + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., st_mode='vanilla'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.mode = st_mode + if self.mode == 'parallel': + self.ts_attn = nn.Linear(dim*2, dim*2) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.attn_count_s = None + self.attn_count_t = None + + def forward(self, x, seqlen=1): + B, N, C = x.shape + + if self.mode == 'series': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_temporal(q, k, v, seqlen=seqlen) + elif self.mode == 'parallel': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x_t = self.forward_temporal(q, k, v, seqlen=seqlen) + x_s = self.forward_spatial(q, k, v) + + alpha = torch.cat([x_s, x_t], dim=-1) + alpha = alpha.mean(dim=1, keepdim=True) + alpha = self.ts_attn(alpha).reshape(B, 1, C, 2) + alpha = alpha.softmax(dim=-1) + x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0] + elif self.mode == 'coupling': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_coupling(q, k, v, seqlen=seqlen) + elif self.mode == 'vanilla': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + elif self.mode == 'temporal': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_temporal(q, k, v, seqlen=seqlen) + elif self.mode == 'spatial': + qkv = 
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + else: + raise NotImplementedError(self.mode) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def reshape_T(self, x, seqlen=1, inverse=False): + if not inverse: + N, C = x.shape[-2:] + x = x.reshape(-1, seqlen, self.num_heads, N, C).transpose(1,2) + x = x.reshape(-1, self.num_heads, seqlen*N, C) #(B, H, TN, c) + else: + TN, C = x.shape[-2:] + x = x.reshape(-1, self.num_heads, seqlen, TN // seqlen, C).transpose(1,2) + x = x.reshape(-1, self.num_heads, TN // seqlen, C) #(BT, H, N, C) + return x + + def forward_coupling(self, q, k, v, seqlen=8): + BT, _, N, C = q.shape + q = self.reshape_T(q, seqlen) + k = self.reshape_T(k, seqlen) + v = self.reshape_T(v, seqlen) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + x = self.reshape_T(x, seqlen, inverse=True) + x = x.transpose(1,2).reshape(BT, N, C*self.num_heads) + return x + + def forward_spatial(self, q, k, v): + B, _, N, C = q.shape + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + x = x.transpose(1,2).reshape(B, N, C*self.num_heads) + return x + + def forward_temporal(self, q, k, v, seqlen=8): + B, _, N, C = q.shape + qt = q.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + kt = k.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + vt = v.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4) #(B, H, N, T, C) + + attn = (qt @ kt.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ vt #(B, H, N, T, C) + x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C*self.num_heads) + return x + + def count_attn(self, attn): + attn = attn.detach().cpu().numpy() + attn = attn.mean(axis=1) + attn_t = attn[:, :, 1].mean(axis=1) + attn_s = attn[:, :, 0].mean(axis=1) + if self.attn_count_s is None: + self.attn_count_s = attn_s + self.attn_count_t = attn_t + else: + self.attn_count_s = np.concatenate([self.attn_count_s, attn_s], axis=0) + self.attn_count_t = np.concatenate([self.attn_count_t, attn_t], axis=0) + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., mlp_out_ratio=1., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, st_mode='stage_st', att_fuse=False): + super().__init__() + # assert 'stage' in st_mode + self.st_mode = st_mode + self.norm1_s = norm_layer(dim) + self.norm1_t = norm_layer(dim) + self.attn_s = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="spatial") + self.attn_t = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="temporal") + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2_s = norm_layer(dim) + self.norm2_t = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_out_dim = int(dim * mlp_out_ratio) + self.mlp_s = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop) + self.mlp_t = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop) + self.att_fuse = att_fuse + if self.att_fuse: + self.ts_attn = nn.Linear(dim*2, dim*2) + def forward(self, x, seqlen=1): + if self.st_mode=='stage_st': + x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen)) + x = x + self.drop_path(self.mlp_s(self.norm2_s(x))) + x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen)) + x = x + self.drop_path(self.mlp_t(self.norm2_t(x))) + elif self.st_mode=='stage_ts': + x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen)) + x = x + self.drop_path(self.mlp_t(self.norm2_t(x))) + x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen)) + x = x + self.drop_path(self.mlp_s(self.norm2_s(x))) + elif self.st_mode=='stage_para': + x_t = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen)) + x_t = x_t + self.drop_path(self.mlp_t(self.norm2_t(x_t))) + x_s = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen)) + x_s = x_s + self.drop_path(self.mlp_s(self.norm2_s(x_s))) + if self.att_fuse: + # x_s, x_t: [BF, J, dim] + alpha = torch.cat([x_s, x_t], dim=-1) + BF, J = alpha.shape[:2] + # alpha = alpha.mean(dim=1, keepdim=True) + alpha = self.ts_attn(alpha).reshape(BF, J, -1, 2) + alpha = alpha.softmax(dim=-1) + x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0] + else: + x = (x_s + x_t)*0.5 + else: + raise NotImplementedError(self.st_mode) + return x + +class DSTformer(nn.Module): + def __init__(self, dim_in=3, dim_out=3, dim_feat=256, dim_rep=512, + depth=5, num_heads=8, mlp_ratio=4, + num_joints=17, maxlen=243, + qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, att_fuse=True): + super().__init__() + self.dim_out = dim_out + self.dim_feat = dim_feat + self.joints_embed = nn.Linear(dim_in, dim_feat) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks_st = nn.ModuleList([ + Block( + dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + st_mode="stage_st") + for i in range(depth)]) + self.blocks_ts = nn.ModuleList([ + Block( + dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + st_mode="stage_ts") + for i in range(depth)]) + self.norm = norm_layer(dim_feat) + if dim_rep: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(dim_feat, dim_rep)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + self.head = nn.Linear(dim_rep, dim_out) if dim_out > 0 else nn.Identity() + self.temp_embed = nn.Parameter(torch.zeros(1, maxlen, 1, dim_feat)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat)) + trunc_normal_(self.temp_embed, std=.02) + trunc_normal_(self.pos_embed, std=.02) + self.apply(self._init_weights) + self.att_fuse = att_fuse + if self.att_fuse: + self.ts_attn = nn.ModuleList([nn.Linear(dim_feat*2, 2) for i in range(depth)]) + for i in range(depth): + 
self.ts_attn[i].weight.data.fill_(0) + self.ts_attn[i].bias.data.fill_(0.5) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, dim_out, global_pool=''): + self.dim_out = dim_out + self.head = nn.Linear(self.dim_feat, dim_out) if dim_out > 0 else nn.Identity() + + def forward(self, x, return_rep=False): + B, F, J, C = x.shape + x = x.reshape(-1, J, C) + BF = x.shape[0] + x = self.joints_embed(x) + x = x + self.pos_embed + _, J, C = x.shape + x = x.reshape(-1, F, J, C) + self.temp_embed[:,:F,:,:] + x = x.reshape(BF, J, C) + x = self.pos_drop(x) + alphas = [] + for idx, (blk_st, blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)): + x_st = blk_st(x, F) + x_ts = blk_ts(x, F) + if self.att_fuse: + att = self.ts_attn[idx] + alpha = torch.cat([x_st, x_ts], dim=-1) + BF, J = alpha.shape[:2] + alpha = att(alpha) + alpha = alpha.softmax(dim=-1) + x = x_st * alpha[:,:,0:1] + x_ts * alpha[:,:,1:2] + else: + x = (x_st + x_ts)*0.5 + x = self.norm(x) + x = x.reshape(B, F, J, -1) + x = self.pre_logits(x) # [B, F, J, dim_feat] + if return_rep: + return x + x = self.head(x) + return x + + def get_representation(self, x): + return self.forward(x, return_rep=True) + \ No newline at end of file diff --git a/lib/model/drop.py b/lib/model/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..efbed356f80c4e94d0ecb7f1f9a64c3c3e232887 --- /dev/null +++ b/lib/model/drop.py @@ -0,0 +1,43 @@ +""" DropBlock, DropPath +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
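+    Each sample in the batch has its residual-branch output zeroed with probability
+    `drop_prob` during training, and surviving samples are rescaled by 1 / (1 - drop_prob)
+    so the expected value is preserved; in eval mode (or with drop_prob == 0) the module
+    acts as an identity mapping.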
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) \ No newline at end of file diff --git a/lib/model/loss.py b/lib/model/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4397ce1665f05debad6f46dc89be01d5481bf37c --- /dev/null +++ b/lib/model/loss.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F + +# Numpy-based errors + +def mpjpe(predicted, target): + """ + Mean per-joint position error (i.e. mean Euclidean distance), + often referred to as "Protocol #1" in many papers. + """ + assert predicted.shape == target.shape + return np.mean(np.linalg.norm(predicted - target, axis=len(target.shape)-1), axis=1) + +def p_mpjpe(predicted, target): + """ + Pose error: MPJPE after rigid alignment (scale, rotation, and translation), + often referred to as "Protocol #2" in many papers. + """ + assert predicted.shape == target.shape + + muX = np.mean(target, axis=1, keepdims=True) + muY = np.mean(predicted, axis=1, keepdims=True) + + X0 = target - muX + Y0 = predicted - muY + + normX = np.sqrt(np.sum(X0**2, axis=(1, 2), keepdims=True)) + normY = np.sqrt(np.sum(Y0**2, axis=(1, 2), keepdims=True)) + + X0 /= normX + Y0 /= normY + + H = np.matmul(X0.transpose(0, 2, 1), Y0) + U, s, Vt = np.linalg.svd(H) + V = Vt.transpose(0, 2, 1) + R = np.matmul(V, U.transpose(0, 2, 1)) + + # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 + sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) + V[:, :, -1] *= sign_detR + s[:, -1] *= sign_detR.flatten() + R = np.matmul(V, U.transpose(0, 2, 1)) # Rotation + tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) + a = tr * normX / normY # Scale + t = muX - a*np.matmul(muY, R) # Translation + # Perform rigid transformation on the input + predicted_aligned = a*np.matmul(predicted, R) + t + # Return MPJPE + return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape)-1), axis=1) + + +# PyTorch-based errors (for losses) + +def loss_mpjpe(predicted, target): + """ + Mean per-joint position error (i.e. mean Euclidean distance), + often referred to as "Protocol #1" in many papers. + """ + assert predicted.shape == target.shape + return torch.mean(torch.norm(predicted - target, dim=len(target.shape)-1)) + +def weighted_mpjpe(predicted, target, w): + """ + Weighted mean per-joint position error (i.e. 
mean Euclidean distance) + """ + assert predicted.shape == target.shape + assert w.shape[0] == predicted.shape[0] + return torch.mean(w * torch.norm(predicted - target, dim=len(target.shape)-1)) + +def loss_2d_weighted(predicted, target, conf): + assert predicted.shape == target.shape + predicted_2d = predicted[:,:,:,:2] + target_2d = target[:,:,:,:2] + diff = (predicted_2d - target_2d) * conf + return torch.mean(torch.norm(diff, dim=-1)) + +def n_mpjpe(predicted, target): + """ + Normalized MPJPE (scale only), adapted from: + https://github.com/hrhodin/UnsupervisedGeometryAwareRepresentationLearning/blob/master/losses/poses.py + """ + assert predicted.shape == target.shape + norm_predicted = torch.mean(torch.sum(predicted**2, dim=3, keepdim=True), dim=2, keepdim=True) + norm_target = torch.mean(torch.sum(target*predicted, dim=3, keepdim=True), dim=2, keepdim=True) + scale = norm_target / norm_predicted + return loss_mpjpe(scale * predicted, target) + +def weighted_bonelen_loss(predict_3d_length, gt_3d_length): + loss_length = 0.001 * torch.pow(predict_3d_length - gt_3d_length, 2).mean() + return loss_length + +def weighted_boneratio_loss(predict_3d_length, gt_3d_length): + loss_length = 0.1 * torch.pow((predict_3d_length - gt_3d_length)/gt_3d_length, 2).mean() + return loss_length + +def get_limb_lens(x): + ''' + Input: (N, T, 17, 3) + Output: (N, T, 16) + ''' + limbs_id = [[0,1], [1,2], [2,3], + [0,4], [4,5], [5,6], + [0,7], [7,8], [8,9], [9,10], + [8,11], [11,12], [12,13], + [8,14], [14,15], [15,16] + ] + limbs = x[:,:,limbs_id,:] + limbs = limbs[:,:,:,0,:]-limbs[:,:,:,1,:] + limb_lens = torch.norm(limbs, dim=-1) + return limb_lens + +def loss_limb_var(x): + ''' + Input: (N, T, 17, 3) + ''' + if x.shape[1]<=1: + return torch.FloatTensor(1).fill_(0.)[0].to(x.device) + limb_lens = get_limb_lens(x) + limb_lens_var = torch.var(limb_lens, dim=1) + limb_loss_var = torch.mean(limb_lens_var) + return limb_loss_var + +def loss_limb_gt(x, gt): + ''' + Input: (N, T, 17, 3), (N, T, 17, 3) + ''' + limb_lens_x = get_limb_lens(x) + limb_lens_gt = get_limb_lens(gt) # (N, T, 16) + return nn.L1Loss()(limb_lens_x, limb_lens_gt) + +def loss_velocity(predicted, target): + """ + Mean per-joint velocity error (i.e. 
mean Euclidean distance of the 1st derivative) + """ + assert predicted.shape == target.shape + if predicted.shape[1]<=1: + return torch.FloatTensor(1).fill_(0.)[0].to(predicted.device) + velocity_predicted = predicted[:,1:] - predicted[:,:-1] + velocity_target = target[:,1:] - target[:,:-1] + return torch.mean(torch.norm(velocity_predicted - velocity_target, dim=-1)) + +def loss_joint(predicted, target): + assert predicted.shape == target.shape + return nn.L1Loss()(predicted, target) + +def get_angles(x): + ''' + Input: (N, T, 17, 3) + Output: (N, T, 16) + ''' + limbs_id = [[0,1], [1,2], [2,3], + [0,4], [4,5], [5,6], + [0,7], [7,8], [8,9], [9,10], + [8,11], [11,12], [12,13], + [8,14], [14,15], [15,16] + ] + angle_id = [[ 0, 3], + [ 0, 6], + [ 3, 6], + [ 0, 1], + [ 1, 2], + [ 3, 4], + [ 4, 5], + [ 6, 7], + [ 7, 10], + [ 7, 13], + [ 8, 13], + [10, 13], + [ 7, 8], + [ 8, 9], + [10, 11], + [11, 12], + [13, 14], + [14, 15] ] + eps = 1e-7 + limbs = x[:,:,limbs_id,:] + limbs = limbs[:,:,:,0,:]-limbs[:,:,:,1,:] + angles = limbs[:,:,angle_id,:] + angle_cos = F.cosine_similarity(angles[:,:,:,0,:], angles[:,:,:,1,:], dim=-1) + return torch.acos(angle_cos.clamp(-1+eps, 1-eps)) + +def loss_angle(x, gt): + ''' + Input: (N, T, 17, 3), (N, T, 17, 3) + ''' + limb_angles_x = get_angles(x) + limb_angles_gt = get_angles(gt) + return nn.L1Loss()(limb_angles_x, limb_angles_gt) + +def loss_angle_velocity(x, gt): + """ + Mean per-angle velocity error (i.e. mean Euclidean distance of the 1st derivative) + """ + assert x.shape == gt.shape + if x.shape[1]<=1: + return torch.FloatTensor(1).fill_(0.)[0].to(x.device) + x_a = get_angles(x) + gt_a = get_angles(gt) + x_av = x_a[:,1:] - x_a[:,:-1] + gt_av = gt_a[:,1:] - gt_a[:,:-1] + return nn.L1Loss()(x_av, gt_av) + diff --git a/lib/model/loss_mesh.py b/lib/model/loss_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..82f615f1484d163b750eaf80eaa3560bfeadd773 --- /dev/null +++ b/lib/model/loss_mesh.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import ipdb +from lib.utils.utils_mesh import batch_rodrigues +from lib.model.loss import * + +class MeshLoss(nn.Module): + def __init__( + self, + loss_type='MSE', + device='cuda', + ): + super(MeshLoss, self).__init__() + self.device = device + self.loss_type = loss_type + if loss_type == 'MSE': + self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device) + self.criterion_regr = nn.MSELoss().to(self.device) + elif loss_type == 'L1': + self.criterion_keypoints = nn.L1Loss(reduction='none').to(self.device) + self.criterion_regr = nn.L1Loss().to(self.device) + + def forward( + self, + smpl_output, + data_gt, + ): + # to reduce time dimension + reduce = lambda x: x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]) + data_3d_theta = reduce(data_gt['theta']) + + preds = smpl_output[-1] + pred_theta = preds['theta'] + theta_size = pred_theta.shape[:2] + pred_theta = reduce(pred_theta) + preds_local = preds['kp_3d'] - preds['kp_3d'][:, :, 0:1,:] # (N, T, 17, 3) + gt_local = data_gt['kp_3d'] - data_gt['kp_3d'][:, :, 0:1,:] + real_shape, pred_shape = data_3d_theta[:, 72:], pred_theta[:, 72:] + real_pose, pred_pose = data_3d_theta[:, :72], pred_theta[:, :72] + loss_dict = {} + loss_dict['loss_3d_pos'] = loss_mpjpe(preds_local, gt_local) + loss_dict['loss_3d_scale'] = n_mpjpe(preds_local, gt_local) + loss_dict['loss_3d_velocity'] = loss_velocity(preds_local, gt_local) + loss_dict['loss_lv'] = loss_limb_var(preds_local) + loss_dict['loss_lg'] = loss_limb_gt(preds_local, gt_local) + 
loss_dict['loss_a'] = loss_angle(preds_local, gt_local) + loss_dict['loss_av'] = loss_angle_velocity(preds_local, gt_local) + + if pred_theta.shape[0] > 0: + loss_pose, loss_shape = self.smpl_losses(pred_pose, pred_shape, real_pose, real_shape) + loss_norm = torch.norm(pred_theta, dim=-1).mean() + loss_dict['loss_shape'] = loss_shape + loss_dict['loss_pose'] = loss_pose + loss_dict['loss_norm'] = loss_norm + return loss_dict + + def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas): + pred_rotmat_valid = batch_rodrigues(pred_rotmat.reshape(-1,3)).reshape(-1, 24, 3, 3) + gt_rotmat_valid = batch_rodrigues(gt_pose.reshape(-1,3)).reshape(-1, 24, 3, 3) + pred_betas_valid = pred_betas + gt_betas_valid = gt_betas + if len(pred_rotmat_valid) > 0: + loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid) + loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid) + else: + loss_regr_pose = torch.FloatTensor(1).fill_(0.).to(self.device) + loss_regr_betas = torch.FloatTensor(1).fill_(0.).to(self.device) + return loss_regr_pose, loss_regr_betas diff --git a/lib/model/loss_supcon.py b/lib/model/loss_supcon.py new file mode 100644 index 0000000000000000000000000000000000000000..17117d4210160679dee0bdc4a4b7d97c433e1d43 --- /dev/null +++ b/lib/model/loss_supcon.py @@ -0,0 +1,98 @@ +""" +Author: Yonglong Tian (yonglong@mit.edu) +Date: May 07, 2020 +""" +from __future__ import print_function + +import torch +import torch.nn as nn + + +class SupConLoss(nn.Module): + """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. + It also supports the unsupervised contrastive loss in SimCLR""" + def __init__(self, temperature=0.07, contrast_mode='all', + base_temperature=0.07): + super(SupConLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None, mask=None): + """Compute loss for model. If both `labels` and `mask` are None, + it degenerates to SimCLR unsupervised loss: + https://arxiv.org/pdf/2002.05709.pdf + + Args: + features: hidden vector of shape [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j + has the same class as sample i. Can be asymmetric. + Returns: + A loss scalar. 
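+        Example (shapes only, illustrative):
+            features: [bsz, n_views, d], e.g. [8, 2, 128] embeddings
+                (typically L2-normalized), two augmented views per sample.
+            labels: [bsz], e.g. [8] integer class ids.
+            loss = SupConLoss(temperature=0.07)(features, labels)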
+ """ + device = (torch.device('cuda') + if features.is_cuda + else torch.device('cpu')) + + if len(features.shape) < 3: + raise ValueError('`features` needs to be [bsz, n_views, ...],' + 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + + batch_size = features.shape[0] + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + elif labels is not None: + labels = labels.contiguous().view(-1, 1) + if labels.shape[0] != batch_size: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).float().to(device) + else: + mask = mask.float().to(device) + + contrast_count = features.shape[1] + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) + if self.contrast_mode == 'one': + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), + self.temperature) + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(batch_size * anchor_count).view(-1, 1).to(device), + 0 + ) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss diff --git a/lib/model/model_action.py b/lib/model/model_action.py new file mode 100644 index 0000000000000000000000000000000000000000..785ec2671b4bdb8a1b47190753b9b46534d4280f --- /dev/null +++ b/lib/model/model_action.py @@ -0,0 +1,71 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ActionHeadClassification(nn.Module): + def __init__(self, dropout_ratio=0., dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048): + super(ActionHeadClassification, self).__init__() + self.dropout = nn.Dropout(p=dropout_ratio) + self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1) + self.relu = nn.ReLU(inplace=True) + self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, num_classes) + + def forward(self, feat): + ''' + Input: (N, M, T, J, C) + ''' + N, M, T, J, C = feat.shape + feat = self.dropout(feat) + feat = feat.permute(0, 1, 3, 4, 2) # (N, M, T, J, C) -> (N, M, J, C, T) + feat = feat.mean(dim=-1) + feat = feat.reshape(N, M, -1) # (N, M, J*C) + feat = feat.mean(dim=1) + feat = self.fc1(feat) + feat = self.bn(feat) + feat = self.relu(feat) + feat = self.fc2(feat) + return feat + +class ActionHeadEmbed(nn.Module): + def __init__(self, dropout_ratio=0., dim_rep=512, num_joints=17, hidden_dim=2048): + super(ActionHeadEmbed, self).__init__() + self.dropout = nn.Dropout(p=dropout_ratio) + self.fc1 = 
nn.Linear(dim_rep*num_joints, hidden_dim) + def forward(self, feat): + ''' + Input: (N, M, T, J, C) + ''' + N, M, T, J, C = feat.shape + feat = self.dropout(feat) + feat = feat.permute(0, 1, 3, 4, 2) # (N, M, T, J, C) -> (N, M, J, C, T) + feat = feat.mean(dim=-1) + feat = feat.reshape(N, M, -1) # (N, M, J*C) + feat = feat.mean(dim=1) + feat = self.fc1(feat) + feat = F.normalize(feat, dim=-1) + return feat + +class ActionNet(nn.Module): + def __init__(self, backbone, dim_rep=512, num_classes=60, dropout_ratio=0., version='class', hidden_dim=2048, num_joints=17): + super(ActionNet, self).__init__() + self.backbone = backbone + self.feat_J = num_joints + if version=='class': + self.head = ActionHeadClassification(dropout_ratio=dropout_ratio, dim_rep=dim_rep, num_classes=num_classes, num_joints=num_joints) + elif version=='embed': + self.head = ActionHeadEmbed(dropout_ratio=dropout_ratio, dim_rep=dim_rep, hidden_dim=hidden_dim, num_joints=num_joints) + else: + raise Exception('Version Error.') + + def forward(self, x): + ''' + Input: (N, M x T x 17 x 3) + ''' + N, M, T, J, C = x.shape + x = x.reshape(N*M, T, J, C) + feat = self.backbone.get_representation(x) + feat = feat.reshape([N, M, T, self.feat_J, -1]) # (N, M, T, J, C) + out = self.head(feat) + return out \ No newline at end of file diff --git a/lib/model/model_mesh.py b/lib/model/model_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..dff579db07a3611c98deea341f3ccd8e87aea33c --- /dev/null +++ b/lib/model/model_mesh.py @@ -0,0 +1,101 @@ +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from lib.utils.utils_smpl import SMPL +from lib.utils.utils_mesh import rotation_matrix_to_angle_axis, rot6d_to_rotmat + +class SMPLRegressor(nn.Module): + def __init__(self, args, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.): + super(SMPLRegressor, self).__init__() + param_pose_dim = 24 * 6 + self.dropout = nn.Dropout(p=dropout_ratio) + self.fc1 = nn.Linear(num_joints*dim_rep, hidden_dim) + self.pool2 = nn.AdaptiveAvgPool2d((None, 1)) + self.fc2 = nn.Linear(num_joints*dim_rep, hidden_dim) + self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.1) + self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.1) + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + self.head_pose = nn.Linear(hidden_dim, param_pose_dim) + self.head_shape = nn.Linear(hidden_dim, 10) + nn.init.xavier_uniform_(self.head_pose.weight, gain=0.01) + nn.init.xavier_uniform_(self.head_shape.weight, gain=0.01) + self.smpl = SMPL( + args.data_root, + batch_size=64, + create_transl=False, + ) + mean_params = np.load(self.smpl.smpl_mean_params) + init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0) + init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0) + self.register_buffer('init_pose', init_pose) + self.register_buffer('init_shape', init_shape) + self.J_regressor = self.smpl.J_regressor_h36m + + def forward(self, feat, init_pose=None, init_shape=None): + N, T, J, C = feat.shape + NT = N * T + feat = feat.reshape(N, T, -1) + + feat_pose = feat.reshape(NT, -1) # (N*T, J*C) + + feat_pose = self.dropout(feat_pose) + feat_pose = self.fc1(feat_pose) + feat_pose = self.bn1(feat_pose) + feat_pose = self.relu1(feat_pose) # (NT, C) + + feat_shape = feat.permute(0,2,1) # (N, T, J*C) -> (N, J*C, T) + feat_shape = self.pool2(feat_shape).reshape(N, -1) # (N, J*C) + + feat_shape = self.dropout(feat_shape) + feat_shape = self.fc2(feat_shape) + feat_shape 
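# Hedged sketch of how ActionNet above consumes a backbone: any module exposing
# get_representation(x) for x of shape (N*M, T, 17, 3) and returning per-joint features
# works. DummyBackbone is purely illustrative; the real backbone is built by load_backbone.
import torch
import torch.nn as nn

class DummyBackbone(nn.Module):
    def __init__(self, dim_rep=512):
        super().__init__()
        self.proj = nn.Linear(3, dim_rep)
    def get_representation(self, x):                      # x: (N*M, T, 17, 3)
        return self.proj(x)                               # -> (N*M, T, 17, dim_rep)

net = ActionNet(backbone=DummyBackbone(), dim_rep=512, num_classes=60, version='class')
logits = net(torch.randn(2, 2, 16, 17, 3))                # (N, M, T, J, C) -> (2, 60)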
= self.bn2(feat_shape) + feat_shape = self.relu2(feat_shape) # (N, C) + + pred_pose = self.init_pose.expand(NT, -1) # (NT, C) + pred_shape = self.init_shape.expand(N, -1) # (N, C) + + pred_pose = self.head_pose(feat_pose) + pred_pose + pred_shape = self.head_shape(feat_shape) + pred_shape + pred_shape = pred_shape.expand(T, N, -1).permute(1, 0, 2).reshape(NT, -1) + pred_rotmat = rot6d_to_rotmat(pred_pose).view(-1, 24, 3, 3) + pred_output = self.smpl( + betas=pred_shape, + body_pose=pred_rotmat[:, 1:], + global_orient=pred_rotmat[:, 0].unsqueeze(1), + pose2rot=False + ) + pred_vertices = pred_output.vertices*1000.0 + assert self.J_regressor is not None + J_regressor_batch = self.J_regressor[None, :].expand(pred_vertices.shape[0], -1, -1).to(pred_vertices.device) + pred_joints = torch.matmul(J_regressor_batch, pred_vertices) + pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72) + output = [{ + 'theta' : torch.cat([pose, pred_shape], dim=1), # (N*T, 72+10) + 'verts' : pred_vertices, # (N*T, 6890, 3) + 'kp_3d' : pred_joints, # (N*T, 17, 3) + }] + return output + +class MeshRegressor(nn.Module): + def __init__(self, args, backbone, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.5): + super(MeshRegressor, self).__init__() + self.backbone = backbone + self.feat_J = num_joints + self.head = SMPLRegressor(args, dim_rep, num_joints, hidden_dim, dropout_ratio) + + def forward(self, x, init_pose=None, init_shape=None, n_iter=3): + ''' + Input: (N x T x 17 x 3) + ''' + N, T, J, C = x.shape + feat = self.backbone.get_representation(x) + feat = feat.reshape([N, T, self.feat_J, -1]) # (N, T, J, C) + smpl_output = self.head(feat) + for s in smpl_output: + s['theta'] = s['theta'].reshape(N, T, -1) + s['verts'] = s['verts'].reshape(N, T, -1, 3) + s['kp_3d'] = s['kp_3d'].reshape(N, T, -1, 3) + return smpl_output \ No newline at end of file diff --git a/lib/utils/learning.py b/lib/utils/learning.py new file mode 100644 index 0000000000000000000000000000000000000000..191e6697919a338f59ec53263b6edc3f300a2783 --- /dev/null +++ b/lib/utils/learning.py @@ -0,0 +1,102 @@ +import os +import numpy as np +import torch +import torch.nn as nn +from functools import partial +from lib.model.DSTformer import DSTformer + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +def load_pretrained_weights(model, checkpoint): + """Load pretrianed weights to model + Incompatible layers (unmatched in name or size) will be ignored + Args: + - model (nn.Module): network model, which must not be nn.DataParallel + - weight_path (str): path to pretrained weights + """ + import collections + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + model_dict = model.state_dict() + new_state_dict = 
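# Quick check of the top-k accuracy helper above; the logits are chosen so that top-1 and
# top-2 accuracy differ (purely illustrative numbers).
import torch
logits = torch.tensor([[0.10, 0.70, 0.20],
                       [0.90, 0.06, 0.04]])
target = torch.tensor([2, 0])
top1, top2 = accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())                           # 50.0 100.0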
collections.OrderedDict() + matched_layers, discarded_layers = [], [] + for k, v in state_dict.items(): + # If the pretrained state_dict was saved as nn.DataParallel, + # keys would contain "module.", which should be ignored. + if k.startswith('module.'): + k = k[7:] + if k in model_dict and model_dict[k].size() == v.size(): + new_state_dict[k] = v + matched_layers.append(k) + else: + discarded_layers.append(k) + model_dict.update(new_state_dict) + model.load_state_dict(model_dict, strict=True) + print('load_weight', len(matched_layers)) + return model + +def partial_train_layers(model, partial_list): + """Train partial layers of a given model.""" + for name, p in model.named_parameters(): + p.requires_grad = False + for trainable in partial_list: + if trainable in name: + p.requires_grad = True + break + return model + +def load_backbone(args): + if not(hasattr(args, "backbone")): + args.backbone = 'DSTformer' # Default + if args.backbone=='DSTformer': + model_backbone = DSTformer(dim_in=3, dim_out=3, dim_feat=args.dim_feat, dim_rep=args.dim_rep, + depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio, norm_layer=partial(nn.LayerNorm, eps=1e-6), + maxlen=args.maxlen, num_joints=args.num_joints) + elif args.backbone=='TCN': + from lib.model.model_tcn import PoseTCN + model_backbone = PoseTCN() + elif args.backbone=='poseformer': + from lib.model.model_poseformer import PoseTransformer + model_backbone = PoseTransformer(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0, attn_mask=None) + elif args.backbone=='mixste': + from lib.model.model_mixste import MixSTE2 + model_backbone = MixSTE2(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=512, depth=8, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) + elif args.backbone=='stgcn': + from lib.model.model_stgcn import Model as STGCN + model_backbone = STGCN() + else: + raise Exception("Undefined backbone type.") + return model_backbone \ No newline at end of file diff --git a/lib/utils/tools.py b/lib/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b780f0b184584923e54643372a26cca3e9c277 --- /dev/null +++ b/lib/utils/tools.py @@ -0,0 +1,69 @@ +import numpy as np +import os, sys +import pickle +import yaml +from easydict import EasyDict as edict +from typing import Any, IO + +ROOT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') + +class TextLogger: + def __init__(self, log_path): + self.log_path = log_path + with open(self.log_path, "w") as f: + f.write("") + def log(self, log): + with open(self.log_path, "a+") as f: + f.write(log + "\n") + +class Loader(yaml.SafeLoader): + """YAML Loader with `!include` constructor.""" + + def __init__(self, stream: IO) -> None: + """Initialise Loader.""" + + try: + self._root = os.path.split(stream.name)[0] + except AttributeError: + self._root = os.path.curdir + + super().__init__(stream) + +def construct_include(loader: Loader, node: yaml.Node) -> Any: + """Include file referenced at node.""" + + filename = os.path.abspath(os.path.join(loader._root, loader.construct_scalar(node))) + extension = os.path.splitext(filename)[1].lstrip('.') + + with open(filename, 'r') as f: + if extension in ('yaml', 'yml'): + return yaml.load(f, Loader) + elif extension in ('json', ): + return json.load(f) + else: + return ''.join(f.readlines()) + +def get_config(config_path): + 
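# Small self-contained check of the name-and-shape matching in load_pretrained_weights
# above: only tensors whose key and size both match are copied, the rest are skipped.
# The two toy models below are illustrative and not part of this repo.
import torch.nn as nn
src = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))     # pretend "pretrained" model
dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))     # last layer has a different size
dst = load_pretrained_weights(dst, {'state_dict': src.state_dict()})
# prints "load_weight 2": only '0.weight' and '0.bias' were transferred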
yaml.add_constructor('!include', construct_include, Loader) + with open(config_path, 'r') as stream: + config = yaml.load(stream, Loader=Loader) + config = edict(config) + _, config_filename = os.path.split(config_path) + config_name, _ = os.path.splitext(config_filename) + config.name = config_name + return config + +def ensure_dir(path): + """ + create path by first checking its existence, + :param paths: path + :return: + """ + if not os.path.exists(path): + os.makedirs(path) + +def read_pkl(data_url): + file = open(data_url,'rb') + content = pickle.load(file) + file.close() + return content \ No newline at end of file diff --git a/lib/utils/utils_data.py b/lib/utils/utils_data.py new file mode 100644 index 0000000000000000000000000000000000000000..df7b61efacfa737191237fdeba63f77a6408b701 --- /dev/null +++ b/lib/utils/utils_data.py @@ -0,0 +1,112 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import copy + +def crop_scale(motion, scale_range=[1, 1]): + ''' + Motion: [(M), T, 17, 3]. + Normalize to [-1, 1] + ''' + result = copy.deepcopy(motion) + valid_coords = motion[motion[..., 2]!=0][:,:2] + if len(valid_coords) < 4: + return np.zeros(motion.shape) + xmin = min(valid_coords[:,0]) + xmax = max(valid_coords[:,0]) + ymin = min(valid_coords[:,1]) + ymax = max(valid_coords[:,1]) + ratio = np.random.uniform(low=scale_range[0], high=scale_range[1], size=1)[0] + scale = max(xmax-xmin, ymax-ymin) * ratio + if scale==0: + return np.zeros(motion.shape) + xs = (xmin+xmax-scale) / 2 + ys = (ymin+ymax-scale) / 2 + result[...,:2] = (motion[..., :2]- [xs,ys]) / scale + result[...,:2] = (result[..., :2] - 0.5) * 2 + result = np.clip(result, -1, 1) + return result + +def crop_scale_3d(motion, scale_range=[1, 1]): + ''' + Motion: [T, 17, 3]. (x, y, z) + Normalize to [-1, 1] + Z is relative to the first frame's root. + ''' + result = copy.deepcopy(motion) + result[:,:,2] = result[:,:,2] - result[0,0,2] + xmin = np.min(motion[...,0]) + xmax = np.max(motion[...,0]) + ymin = np.min(motion[...,1]) + ymax = np.max(motion[...,1]) + ratio = np.random.uniform(low=scale_range[0], high=scale_range[1], size=1)[0] + scale = max(xmax-xmin, ymax-ymin) / ratio + if scale==0: + return np.zeros(motion.shape) + xs = (xmin+xmax-scale) / 2 + ys = (ymin+ymax-scale) / 2 + result[...,:2] = (motion[..., :2]- [xs,ys]) / scale + result[...,2] = result[...,2] / scale + result = (result - 0.5) * 2 + return result + +def flip_data(data): + """ + horizontal flip + data: [N, F, 17, D] or [F, 17, D]. X (horizontal coordinate) is the first channel in D. 
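# Numerical check of crop_scale above (toy 2D pose with confidences, values illustrative):
# the bounding square of the valid joints is mapped into [-1, 1], keeping aspect ratio via
# the larger side, and the confidence channel is left untouched.
import numpy as np
pose = np.zeros((1, 17, 3), dtype=np.float32)             # (T, 17, (x, y, conf))
pose[..., 0] = np.linspace(100, 200, 17)                  # x spans 100 px
pose[..., 1] = np.linspace(300, 350, 17)                  # y spans 50 px
pose[..., 2] = 1.0                                        # all joints visible
out = crop_scale(pose, scale_range=[1, 1])
print(out[..., 0].min(), out[..., 0].max())               # -1.0 1.0
print(out[..., 1].min(), out[..., 1].max())               # -0.5 0.5 (shorter side)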
+ Return + result: same + """ + left_joints = [4, 5, 6, 11, 12, 13] + right_joints = [1, 2, 3, 14, 15, 16] + flipped_data = copy.deepcopy(data) + flipped_data[..., 0] *= -1 # flip x of all joints + flipped_data[..., left_joints+right_joints, :] = flipped_data[..., right_joints+left_joints, :] + return flipped_data + +def resample(ori_len, target_len, replay=False, randomness=True): + if replay: + if ori_len > target_len: + st = np.random.randint(ori_len-target_len) + return range(st, st+target_len) # Random clipping from sequence + else: + return np.array(range(target_len)) % ori_len # Replay padding + else: + if randomness: + even = np.linspace(0, ori_len, num=target_len, endpoint=False) + if ori_len < target_len: + low = np.floor(even) + high = np.ceil(even) + sel = np.random.randint(2, size=even.shape) + result = np.sort(sel*low+(1-sel)*high) + else: + interval = even[1] - even[0] + result = np.random.random(even.shape)*interval + even + result = np.clip(result, a_min=0, a_max=ori_len-1).astype(np.uint32) + else: + result = np.linspace(0, ori_len, num=target_len, endpoint=False, dtype=int) + return result + +def split_clips(vid_list, n_frames, data_stride): + result = [] + n_clips = 0 + st = 0 + i = 0 + saved = set() + while i(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + batch_size = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(batch_size, 3, 3) + return rotMat + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + if rotation_matrix.shape[1:] == (3,3): + rot_mat = rotation_matrix.reshape(-1, 3, 3) + hom = torch.tensor([0, 0, 1], dtype=torch.float32, + device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1) + rotation_matrix = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + return aa + + +def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. 
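# Two quick properties of the helpers above (self-check, illustrative shapes):
# flip_data is its own inverse on the H36M joint ordering, and resample with
# randomness=False returns evenly spaced frame indices.
import numpy as np
x = np.random.randn(2, 27, 17, 3).astype(np.float32)
assert np.allclose(flip_data(flip_data(x)), x)            # mirror twice == identity
print(resample(ori_len=100, target_len=5, randomness=False))   # [ 0 20 40 60 80]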
+ + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}" + .format(quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. 
Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. + Input: + S: (25, 3) 3D joint locations + joints: (25, 3) 2D joint locations and confidence + Returns: + (3,) camera translation vector + """ + + num_joints = S.shape[0] + # focal length + f = np.array([focal_length,focal_length]) + # optical center + center = np.array([img_size/2., img_size/2.]) + + # transformations + Z = np.reshape(np.tile(S[:,2],(2,1)).T,-1) + XY = np.reshape(S[:,0:2],-1) + O = np.tile(center,num_joints) + F = np.tile(f,num_joints) + weight2 = np.reshape(np.tile(np.sqrt(joints_conf),(2,1)).T,-1) + + # least squares + Q = np.array([F*np.tile(np.array([1,0]),num_joints), F*np.tile(np.array([0,1]),num_joints), O-np.reshape(joints_2d,-1)]).T + c = (np.reshape(joints_2d,-1)-O)*Z - F*XY + + # weighted least squares + W = np.diagflat(weight2) + Q = np.dot(W,Q) + c = np.dot(W,c) + + # square matrix + A = np.dot(Q.T,Q) + b = np.dot(Q.T,c) + + # solution + trans = np.linalg.solve(A, b) + + return trans + + +def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.): + """ + This function is borrowed from https://github.com/nkolot/SPIN/utils/geometry.py + + Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d. 
+ Input: + S: (B, 49, 3) 3D joint locations + joints: (B, 49, 3) 2D joint locations and confidence + Returns: + (B, 3) camera translation vectors + """ + + device = S.device + # Use only joints 25:49 (GT joints) + S = S[:, 25:, :].cpu().numpy() + joints_2d = joints_2d[:, 25:, :].cpu().numpy() + joints_conf = joints_2d[:, :, -1] + joints_2d = joints_2d[:, :, :-1] + trans = np.zeros((S.shape[0], 3), dtype=np.float32) + # Find the translation for each example in the batch + for i in range(S.shape[0]): + S_i = S[i] + joints_i = joints_2d[i] + conf_i = joints_conf[i] + trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size) + return torch.from_numpy(trans).to(device) + + +def rot6d_to_rotmat_spin(x): + """Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Input: + (B,6) Batch of 6-D rotation representations + Output: + (B,3,3) Batch of corresponding rotation matrices + """ + x = x.view(-1,3,2) + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + + # inp = a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1 + # denom = inp.pow(2).sum(dim=1).sqrt().unsqueeze(-1) + 1e-8 + # b2 = inp / denom + + b3 = torch.cross(b1, b2) + return torch.stack((b1, b2, b3), dim=-1) + + +def rot6d_to_rotmat(x): + x = x.view(-1,3,2) + + # Normalize the first vector + b1 = F.normalize(x[:, :, 0], dim=1, eps=1e-6) + + dot_prod = torch.sum(b1 * x[:, :, 1], dim=1, keepdim=True) + # Compute the second vector by finding the orthogonal complement to it + b2 = F.normalize(x[:, :, 1] - dot_prod * b1, dim=-1, eps=1e-6) + + # Finish building the basis by taking the cross product + b3 = torch.cross(b1, b2, dim=1) + rot_mats = torch.stack([b1, b2, b3], dim=-1) + + return rot_mats + + +def rigid_transform_3D(A, B): + n, dim = A.shape + centroid_A = np.mean(A, axis = 0) + centroid_B = np.mean(B, axis = 0) + H = np.dot(np.transpose(A - centroid_A), B - centroid_B) / n + U, s, V = np.linalg.svd(H) + R = np.dot(np.transpose(V), np.transpose(U)) + if np.linalg.det(R) < 0: + s[-1] = -s[-1] + V[2] = -V[2] + R = np.dot(np.transpose(V), np.transpose(U)) + + varP = np.var(A, axis=0).sum() + c = 1/varP * np.sum(s) + + t = -np.dot(c*R, np.transpose(centroid_A)) + np.transpose(centroid_B) + return c, R, t + + +def rigid_align(A, B): + c, R, t = rigid_transform_3D(A, B) + A2 = np.transpose(np.dot(c*R, np.transpose(A))) + t + return A2 + +def compute_error(output, target): + with torch.no_grad(): + pred_verts = output[0]['verts'].reshape(-1, 6890, 3) + target_verts = target['verts'].reshape(-1, 6890, 3) + + pred_j3ds = output[0]['kp_3d'].reshape(-1, 17, 3) + target_j3ds = target['kp_3d'].reshape(-1, 17, 3) + + # mpve + pred_verts = pred_verts - pred_j3ds[:, :1, :] + target_verts = target_verts - target_j3ds[:, :1, :] + mpves = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).cpu() + + # mpjpe + pred_j3ds = pred_j3ds - pred_j3ds[:, :1, :] + target_j3ds = target_j3ds - target_j3ds[:, :1, :] + mpjpes = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu() + return mpjpes.mean(), mpves.mean() + +def compute_error_frames(output, target): + with torch.no_grad(): + pred_verts = output[0]['verts'].reshape(-1, 6890, 3) + target_verts = target['verts'].reshape(-1, 6890, 3) + + pred_j3ds = output[0]['kp_3d'].reshape(-1, 17, 3) + target_j3ds = target['kp_3d'].reshape(-1, 
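# Sketch tying together the conversions above, mirroring what SMPLRegressor does:
# 6-D pose -> rotation matrices -> 72-D axis-angle SMPL pose (random inputs, shapes only).
import torch
rot6d = torch.randn(4, 24 * 6)                            # one 6-D rotation per SMPL joint
rotmat = rot6d_to_rotmat(rot6d)                           # (4*24, 3, 3) with orthonormal columns
aa = rotation_matrix_to_angle_axis(rotmat.reshape(-1, 3, 3)).reshape(4, 72)
print(rotmat.shape, aa.shape)                             # torch.Size([96, 3, 3]) torch.Size([4, 72])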
17, 3) + + # mpve + pred_verts = pred_verts - pred_j3ds[:, :1, :] + target_verts = target_verts - target_j3ds[:, :1, :] + mpves = torch.sqrt(((pred_verts - target_verts) ** 2).sum(dim=-1)).mean(dim=-1).cpu() + + # mpjpe + pred_j3ds = pred_j3ds - pred_j3ds[:, :1, :] + target_j3ds = target_j3ds - target_j3ds[:, :1, :] + mpjpes = torch.sqrt(((pred_j3ds - target_j3ds) ** 2).sum(dim=-1)).mean(dim=-1).cpu() + return mpjpes, mpves + +def evaluate_mesh(results): + pred_verts = results['verts'].reshape(-1, 6890, 3) + target_verts = results['verts_gt'].reshape(-1, 6890, 3) + + pred_j3ds = results['kp_3d'].reshape(-1, 17, 3) + target_j3ds = results['kp_3d_gt'].reshape(-1, 17, 3) + num_samples = pred_j3ds.shape[0] + + # mpve + pred_verts = pred_verts - pred_j3ds[:, :1, :] + target_verts = target_verts - target_j3ds[:, :1, :] + mpve = np.mean(np.mean(np.sqrt(np.square(pred_verts - target_verts).sum(axis=2)), axis=1)) + + + # mpjpe-17 & mpjpe-14 + h36m_17_to_14 = (1, 2, 3, 4, 5, 6, 8, 10, 11, 12, 13, 14, 15, 16) + pred_j3ds_17j = (pred_j3ds - pred_j3ds[:, :1, :]) + target_j3ds_17j = (target_j3ds - target_j3ds[:, :1, :]) + + pred_j3ds = pred_j3ds_17j[:, h36m_17_to_14, :].copy() + target_j3ds = target_j3ds_17j[:, h36m_17_to_14, :].copy() + + mpjpe = np.mean(np.sqrt(np.square(pred_j3ds - target_j3ds).sum(axis=2)), axis=1) # (N, ) + mpjpe_17j = np.mean(np.sqrt(np.square(pred_j3ds_17j - target_j3ds_17j).sum(axis=2)), axis=1) # (N, ) + + pred_j3ds_pa, pred_j3ds_pa_17j = [], [] + for n in range(num_samples): + pred_j3ds_pa.append(rigid_align(pred_j3ds[n], target_j3ds[n])) + pred_j3ds_pa_17j.append(rigid_align(pred_j3ds_17j[n], target_j3ds_17j[n])) + pred_j3ds_pa = np.array(pred_j3ds_pa) + pred_j3ds_pa_17j = np.array(pred_j3ds_pa_17j) + + pa_mpjpe = np.mean(np.sqrt(np.square(pred_j3ds_pa - target_j3ds).sum(axis=2)), axis=1) # (N, ) + pa_mpjpe_17j = np.mean(np.sqrt(np.square(pred_j3ds_pa_17j - target_j3ds_17j).sum(axis=2)), axis=1) # (N, ) + + + error_dict = { + 'mpve': mpve.mean(), + 'mpjpe': mpjpe.mean(), + 'pa_mpjpe': pa_mpjpe.mean(), + 'mpjpe_17j': mpjpe_17j.mean(), + 'pa_mpjpe_17j': pa_mpjpe_17j.mean(), + } + return error_dict + + +def rectify_pose(pose): + """ + Rectify "upside down" people in global coord + + Args: + pose (72,): Pose. + + Returns: + Rotated pose. + """ + pose = pose.copy() + R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0] + R_root = cv2.Rodrigues(pose[:3])[0] + new_root = R_root.dot(R_mod) + pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3) + return pose + +def flip_thetas(thetas): + """Flip thetas. + + Parameters + ---------- + thetas : numpy.ndarray + Joints in shape (F, num_thetas, 3) + theta_pairs : list + List of theta pairs. + + Returns + ------- + numpy.ndarray + Flipped thetas with shape (F, num_thetas, 3) + + """ + #Joint pairs which defines the pairs of joint to be swapped when the image is flipped horizontally. + theta_pairs = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23)) + thetas_flip = thetas.copy() + # reflect horizontally + thetas_flip[:, :, 1] = -1 * thetas_flip[:, :, 1] + thetas_flip[:, :, 2] = -1 * thetas_flip[:, :, 2] + # change left-right parts + for pair in theta_pairs: + thetas_flip[:, pair[0], :], thetas_flip[:, pair[1], :] = \ + thetas_flip[:, pair[1], :], thetas_flip[:, pair[0], :].copy() + return thetas_flip + +def flip_thetas_batch(thetas): + """Flip thetas in batch. + + Parameters + ---------- + thetas : numpy.array + Joints in shape (N, F, num_thetas*3) + theta_pairs : list + List of theta pairs. 
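# Synthetic check of the Procrustes alignment behind PA-MPJPE above: if B is an exact
# scaled, rotated and translated copy of A, rigid_align recovers it up to round-off.
import numpy as np
A = np.random.randn(17, 3)
R, _ = np.linalg.qr(np.random.randn(3, 3))
if np.linalg.det(R) < 0:                                  # force a proper rotation
    R[:, 0] *= -1
B = 0.5 * A @ R.T + np.array([0.1, -0.2, 0.3])
print(np.abs(rigid_align(A, B) - B).max())                # ~1e-12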
+ + Returns + ------- + numpy.array + Flipped thetas with shape (N, F, num_thetas*3) + + """ + #Joint pairs which defines the pairs of joint to be swapped when the image is flipped horizontally. + theta_pairs = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23)) + thetas_flip = copy.deepcopy(thetas).reshape(*thetas.shape[:2], 24, 3) + # reflect horizontally + thetas_flip[:, :, :, 1] = -1 * thetas_flip[:, :, :, 1] + thetas_flip[:, :, :, 2] = -1 * thetas_flip[:, :, :, 2] + # change left-right parts + for pair in theta_pairs: + thetas_flip[:, :, pair[0], :], thetas_flip[:, :, pair[1], :] = \ + thetas_flip[:, :, pair[1], :], thetas_flip[:, :, pair[0], :].clone() + + return thetas_flip.reshape(*thetas.shape[:2], -1) + +# def smpl_aa_to_ortho6d(smpl_aa): +# # [...,72] -> [...,144] +# rot_aa = smpl_aa.reshape([-1,24,3]) +# rotmat = axis_angle_to_matrix(rot_aa) +# rot6d = matrix_to_rotation_6d(rotmat) +# rot6d = rot6d.reshape(-1,24*6) +# return rot6d \ No newline at end of file diff --git a/lib/utils/utils_smpl.py b/lib/utils/utils_smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..2215dd8517a1bdbfc1acf34c3ce13befc486fb67 --- /dev/null +++ b/lib/utils/utils_smpl.py @@ -0,0 +1,88 @@ +# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/models/hmr.py +# Adhere to their licence to use this script + +import torch +import numpy as np +import os.path as osp +from smplx import SMPL as _SMPL +from smplx.utils import ModelOutput, SMPLOutput +from smplx.lbs import vertices2joints + + +# Map joints to SMPL joints +JOINT_MAP = { + 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, + 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, + 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, + 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, + 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, + 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, + 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, + 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, + 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, + 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, + 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, + 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, + 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, + 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, + 'Spine (H36M)': 51, 'Jaw (H36M)': 52, + 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, + 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 +} +JOINT_NAMES = [ + 'OP Nose', 'OP Neck', 'OP RShoulder', + 'OP RElbow', 'OP RWrist', 'OP LShoulder', + 'OP LElbow', 'OP LWrist', 'OP MidHip', + 'OP RHip', 'OP RKnee', 'OP RAnkle', + 'OP LHip', 'OP LKnee', 'OP LAnkle', + 'OP REye', 'OP LEye', 'OP REar', + 'OP LEar', 'OP LBigToe', 'OP LSmallToe', + 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', + 'Right Ankle', 'Right Knee', 'Right Hip', + 'Left Hip', 'Left Knee', 'Left Ankle', + 'Right Wrist', 'Right Elbow', 'Right Shoulder', + 'Left Shoulder', 'Left Elbow', 'Left Wrist', + 'Neck (LSP)', 'Top of Head (LSP)', + 'Pelvis (MPII)', 'Thorax (MPII)', + 'Spine (H36M)', 'Jaw (H36M)', + 'Head (H36M)', 'Nose', 'Left Eye', + 'Right Eye', 'Left Ear', 'Right Ear' +] + +JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))} +SMPL_MODEL_DIR = 'data/mesh' +H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] +H36M_TO_J14 = H36M_TO_J17[:14] + + +class SMPL(_SMPL): + """ Extension of the official SMPL implementation to support more joints """ + + def 
__init__(self, *args, **kwargs): + super(SMPL, self).__init__(*args, **kwargs) + joints = [JOINT_MAP[i] for i in JOINT_NAMES] + self.smpl_mean_params = osp.join(args[0], 'smpl_mean_params.npz') + J_regressor_extra = np.load(osp.join(args[0], 'J_regressor_extra.npy')) + self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) + J_regressor_h36m = np.load(osp.join(args[0], 'J_regressor_h36m_correct.npy')) + self.register_buffer('J_regressor_h36m', torch.tensor(J_regressor_h36m, dtype=torch.float32)) + self.joint_map = torch.tensor(joints, dtype=torch.long) + + def forward(self, *args, **kwargs): + kwargs['get_skin'] = True + smpl_output = super(SMPL, self).forward(*args, **kwargs) + extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) + joints = torch.cat([smpl_output.joints, extra_joints], dim=1) + joints = joints[:, self.joint_map, :] + output = SMPLOutput(vertices=smpl_output.vertices, + global_orient=smpl_output.global_orient, + body_pose=smpl_output.body_pose, + joints=joints, + betas=smpl_output.betas, + full_pose=smpl_output.full_pose) + return output + + +def get_smpl_faces(): + smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False) + return smpl.faces \ No newline at end of file diff --git a/lib/utils/vismo.py b/lib/utils/vismo.py new file mode 100644 index 0000000000000000000000000000000000000000..92b290dc7dd7075f9914bc5578b1a5c5b120ddb1 --- /dev/null +++ b/lib/utils/vismo.py @@ -0,0 +1,345 @@ +import numpy as np +import os +import cv2 +import math +import copy +import imageio +import io +from tqdm import tqdm +from PIL import Image +from lib.utils.tools import ensure_dir +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from lib.utils.utils_smpl import * +import ipdb + +def render_and_save(motion_input, save_path, keep_imgs=False, fps=25, color="#F96706#FB8D43#FDB381", with_conf=False, draw_face=False): + ensure_dir(os.path.dirname(save_path)) + motion = copy.deepcopy(motion_input) + if motion.shape[-1]==2 or motion.shape[-1]==3: + motion = np.transpose(motion, (1,2,0)) #(T,17,D) -> (17,D,T) + if motion.shape[1]==2 or with_conf: + colors = hex2rgb(color) + if not with_conf: + J, D, T = motion.shape + motion_full = np.ones([J,3,T]) + motion_full[:,:2,:] = motion + else: + motion_full = motion + motion_full[:,:2,:] = pixel2world_vis_motion(motion_full[:,:2,:]) + motion2video(motion_full, save_path=save_path, colors=colors, fps=fps) + elif motion.shape[0]==6890: + # motion_world = pixel2world_vis_motion(motion, dim=3) + motion2video_mesh(motion, save_path=save_path, keep_imgs=keep_imgs, fps=fps, draw_face=draw_face) + else: + motion_world = pixel2world_vis_motion(motion, dim=3) + motion2video_3d(motion_world, save_path=save_path, keep_imgs=keep_imgs, fps=fps) + +def pixel2world_vis(pose): +# pose: (17,2) + return (pose + [1, 1]) * 512 / 2 + +def pixel2world_vis_motion(motion, dim=2, is_tensor=False): +# pose: (17,2,N) + N = motion.shape[-1] + if dim==2: + offset = np.ones([2,N]).astype(np.float32) + else: + offset = np.ones([3,N]).astype(np.float32) + offset[2,:] = 0 + if is_tensor: + offset = torch.tensor(offset) + return (motion + offset) * 512 / 2 + +def vis_data_batch(data_input, data_label, n_render=10, save_path='doodle/vis_train_data/'): + ''' + data_input: [N,T,17,2/3] + data_label: [N,T,17,3] + ''' + pathlib.Path(save_path).mkdir(parents=True, exist_ok=True) + for i in range(min(len(data_input), n_render)): + render_and_save(data_input[i][:,:,:2], 
'%s/input_%d.mp4' % (save_path, i)) + render_and_save(data_label[i], '%s/gt_%d.mp4' % (save_path, i)) + +def get_img_from_fig(fig, dpi=120): + buf = io.BytesIO() + fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0) + buf.seek(0) + img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) + buf.close() + img = cv2.imdecode(img_arr, 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA) + return img + +def rgb2rgba(color): + return (color[0], color[1], color[2], 255) + +def hex2rgb(hex, number_of_colors=3): + h = hex + rgb = [] + for i in range(number_of_colors): + h = h.lstrip('#') + hex_color = h[0:6] + rgb_color = [int(hex_color[i:i+2], 16) for i in (0, 2 ,4)] + rgb.append(rgb_color) + h = h[6:] + return rgb + +def joints2image(joints_position, colors, transparency=False, H=1000, W=1000, nr_joints=49, imtype=np.uint8, grayscale=False, bg_color=(255, 255, 255)): +# joints_position: [17*2] + nr_joints = joints_position.shape[0] + + if nr_joints == 49: # full joints(49): basic(15) + eyes(2) + toes(2) + hands(30) + limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7], \ + [8, 9], [8, 13], [9, 10], [10, 11], [11, 12], [13, 14], [14, 15], [15, 16], + ]#[0, 17], [0, 18]] #ignore eyes + + L = rgb2rgba(colors[0]) if transparency else colors[0] + M = rgb2rgba(colors[1]) if transparency else colors[1] + R = rgb2rgba(colors[2]) if transparency else colors[2] + + colors_joints = [M, M, L, L, L, R, R, + R, M, L, L, L, L, R, R, R, + R, R, L] + [L] * 15 + [R] * 15 + + colors_limbs = [M, L, R, M, L, L, R, + R, L, R, L, L, L, R, R, R, + R, R] + elif nr_joints == 15: # basic joints(15) + (eyes(2)) + limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7], + [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]] + # [0, 15], [0, 16] two eyes are not drawn + + L = rgb2rgba(colors[0]) if transparency else colors[0] + M = rgb2rgba(colors[1]) if transparency else colors[1] + R = rgb2rgba(colors[2]) if transparency else colors[2] + + colors_joints = [M, M, L, L, L, R, R, + R, M, L, L, L, R, R, R] + + colors_limbs = [M, L, R, M, L, L, R, + R, L, R, L, L, R, R] + elif nr_joints == 17: # H36M, 0: 'root', + # 1: 'rhip', + # 2: 'rkne', + # 3: 'rank', + # 4: 'lhip', + # 5: 'lkne', + # 6: 'lank', + # 7: 'belly', + # 8: 'neck', + # 9: 'nose', + # 10: 'head', + # 11: 'lsho', + # 12: 'lelb', + # 13: 'lwri', + # 14: 'rsho', + # 15: 'relb', + # 16: 'rwri' + limbSeq = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]] + + L = rgb2rgba(colors[0]) if transparency else colors[0] + M = rgb2rgba(colors[1]) if transparency else colors[1] + R = rgb2rgba(colors[2]) if transparency else colors[2] + + colors_joints = [M, R, R, R, L, L, L, M, M, M, M, L, L, L, R, R, R] + colors_limbs = [R, R, R, L, L, L, M, M, M, L, R, M, L, L, R, R] + + else: + raise ValueError("Only support number of joints be 49 or 17 or 15") + + if transparency: + canvas = np.zeros(shape=(H, W, 4)) + else: + canvas = np.ones(shape=(H, W, 3)) * np.array(bg_color).reshape([1, 1, 3]) + hips = joints_position[0] + neck = joints_position[8] + torso_length = ((hips[1] - neck[1]) ** 2 + (hips[0] - neck[0]) ** 2) ** 0.5 + head_radius = int(torso_length/4.5) + end_effectors_radius = int(torso_length/15) + end_effectors_radius = 7 + joints_radius = 7 + for i in range(0, len(colors_joints)): + if i in (17, 18): + continue + elif i > 18: + radius = 2 + else: + radius = joints_radius + if len(joints_position[i])==3: # If there is 
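# The default colour argument of render_and_save packs three hex colours back to back;
# hex2rgb above splits them into RGB triplets (used as left / middle / right limb colours).
print(hex2rgb("#F96706#FB8D43#FDB381"))
# [[249, 103, 6], [251, 141, 67], [253, 179, 129]]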
confidence, weigh by confidence + weight = joints_position[i][2] + if weight==0: + continue + cv2.circle(canvas, (int(joints_position[i][0]),int(joints_position[i][1])), radius, colors_joints[i], thickness=-1) + + stickwidth = 2 + for i in range(len(limbSeq)): + limb = limbSeq[i] + cur_canvas = canvas.copy() + point1_index = limb[0] + point2_index = limb[1] + point1 = joints_position[point1_index] + point2 = joints_position[point2_index] + if len(point1)==3: # If there is confidence, weigh by confidence + limb_weight = min(point1[2], point2[2]) + if limb_weight==0: + bb = bounding_box(canvas) + canvas_cropped = canvas[:,bb[2]:bb[3], :] + continue + X = [point1[1], point2[1]] + Y = [point1[0], point2[0]] + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1) + cv2.fillConvexPoly(cur_canvas, polygon, colors_limbs[i]) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + bb = bounding_box(canvas) + canvas_cropped = canvas[:,bb[2]:bb[3], :] + canvas = canvas.astype(imtype) + canvas_cropped = canvas_cropped.astype(imtype) + if grayscale: + if transparency: + canvas = cv2.cvtColor(canvas, cv2.COLOR_RGBA2GRAY) + canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGBA2GRAY) + else: + canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2GRAY) + canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGB2GRAY) + return [canvas, canvas_cropped] + + +def motion2video(motion, save_path, colors, h=512, w=512, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, fps=25, save_frame=False, grayscale=False, show_progress=True, as_array=False): + nr_joints = motion.shape[0] +# as_array = save_path.endswith(".npy") + vlen = motion.shape[-1] + + out_array = np.zeros([vlen, h, w, 3]) if as_array else None + videowriter = None if as_array else imageio.get_writer(save_path, fps=fps) + + if save_frame: + frames_dir = save_path[:-4] + '-frames' + ensure_dir(frames_dir) + + iterator = range(vlen) + if show_progress: iterator = tqdm(iterator) + for i in iterator: + [img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale) + if motion_tgt is not None: + [img_tgt, img_tgt_cropped] = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale) + img_ori = img.copy() + img = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0) + img_cropped = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0) + bb = bounding_box(img_cropped) + img_cropped = img_cropped[:, bb[2]:bb[3], :] + if save_frame: + save_image(img_cropped, os.path.join(frames_dir, "%04d.png" % i)) + if as_array: out_array[i] = img + else: videowriter.append_data(img) + + if not as_array: + videowriter.close() + + return out_array + +def motion2video_3d(motion, save_path, fps=25, keep_imgs = False): +# motion: (17,3,N) + videowriter = imageio.get_writer(save_path, fps=fps) + vlen = motion.shape[-1] + save_name = save_path.split('.')[0] + frames = [] + joint_pairs = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]] + joint_pairs_left = [[8, 11], [11, 12], [12, 13], [0, 4], [4, 5], [5, 6]] + joint_pairs_right = [[8, 14], [14, 15], [15, 16], [0, 1], [1, 2], [2, 3]] + + color_mid = "#00457E" + 
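# End-to-end sketch for the 3D branch of render_and_save defined earlier: a (T, 17, 3)
# clip in normalized coordinates is rendered to an mp4 (requires imageio-ffmpeg; the
# random motion and the output path are illustrative only).
import numpy as np
motion_3d = np.clip(0.1 * np.random.randn(30, 17, 3), -1, 1)   # (T, J, 3)
render_and_save(motion_3d, 'doodle/demo_3d.mp4', fps=25)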
color_left = "#02315E" + color_right = "#2F70AF" + for f in tqdm(range(vlen)): + j3d = motion[:,:,f] + fig = plt.figure(0, figsize=(10, 10)) + ax = plt.axes(projection="3d") + ax.set_xlim(-512, 0) + ax.set_ylim(-256, 256) + ax.set_zlim(-512, 0) + # ax.set_xlabel('X') + # ax.set_ylabel('Y') + # ax.set_zlabel('Z') + ax.view_init(elev=12., azim=80) + plt.tick_params(left = False, right = False , labelleft = False , + labelbottom = False, bottom = False) + for i in range(len(joint_pairs)): + limb = joint_pairs[i] + xs, ys, zs = [np.array([j3d[limb[0], j], j3d[limb[1], j]]) for j in range(3)] + if joint_pairs[i] in joint_pairs_left: + ax.plot(-xs, -zs, -ys, color=color_left, lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2) # axis transformation for visualization + elif joint_pairs[i] in joint_pairs_right: + ax.plot(-xs, -zs, -ys, color=color_right, lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2) # axis transformation for visualization + else: + ax.plot(-xs, -zs, -ys, color=color_mid, lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2) # axis transformation for visualization + + frame_vis = get_img_from_fig(fig) + videowriter.append_data(frame_vis) + videowriter.close() + +def motion2video_mesh(motion, save_path, fps=25, keep_imgs = False, draw_face=True): + videowriter = imageio.get_writer(save_path, fps=fps) + vlen = motion.shape[-1] + draw_skele = (motion.shape[0]==17) + save_name = save_path.split('.')[0] + smpl_faces = get_smpl_faces() + frames = [] + joint_pairs = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8], [8, 9], [8, 11], [8, 14], [9, 10], [11, 12], [12, 13], [14, 15], [15, 16]] + + + X, Y, Z = motion[:, 0], motion[:, 1], motion[:, 2] + max_range = np.array([X.max()-X.min(), Y.max()-Y.min(), Z.max()-Z.min()]).max() / 2.0 + mid_x = (X.max()+X.min()) * 0.5 + mid_y = (Y.max()+Y.min()) * 0.5 + mid_z = (Z.max()+Z.min()) * 0.5 + + for f in tqdm(range(vlen)): + j3d = motion[:,:,f] + plt.gca().set_axis_off() + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + fig = plt.figure(0, figsize=(8, 8)) + ax = plt.axes(projection="3d", proj_type = 'ortho') + ax.set_xlim(mid_x - max_range, mid_x + max_range) + ax.set_ylim(mid_y - max_range, mid_y + max_range) + ax.set_zlim(mid_z - max_range, mid_z + max_range) + ax.view_init(elev=-90, azim=-90) + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.margins(0, 0, 0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + plt.axis('off') + plt.xticks([]) + plt.yticks([]) + + # plt.savefig("filename.png", transparent=True, bbox_inches="tight", pad_inches=0) + + if draw_skele: + for i in range(len(joint_pairs)): + limb = joint_pairs[i] + xs, ys, zs = [np.array([j3d[limb[0], j], j3d[limb[1], j]]) for j in range(3)] + ax.plot(-xs, -zs, -ys, c=[0,0,0], lw=3, marker='o', markerfacecolor='w', markersize=3, markeredgewidth=2) # axis transformation for visualization + elif draw_face: + ax.plot_trisurf(j3d[:, 0], j3d[:, 1], triangles=smpl_faces, Z=j3d[:, 2], color=(166/255.0,188/255.0,218/255.0,0.9)) + else: + ax.scatter(j3d[:, 0], j3d[:, 1], j3d[:, 2], s=3, c='w', edgecolors='grey') + frame_vis = get_img_from_fig(fig, dpi=128) + plt.cla() + videowriter.append_data(frame_vis) + videowriter.close() + +def save_image(image_numpy, image_path): + image_pil = 
Image.fromarray(image_numpy) + image_pil.save(image_path) + +def bounding_box(img): + a = np.where(img != 0) + bbox = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1]) + return bbox diff --git a/params/d2c_params.pkl b/params/d2c_params.pkl new file mode 100644 index 0000000000000000000000000000000000000000..0a45d3604e9852cba80e9acd630dde2424d6187b --- /dev/null +++ b/params/d2c_params.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b02023c3fc660f4808c735e2f8a9eae1206a411f1ad7e3429d33719da1cd0d1 +size 184 diff --git a/params/synthetic_noise.pth b/params/synthetic_noise.pth new file mode 100644 index 0000000000000000000000000000000000000000..f4d1739dbce24feeef8c70c0a305b760ca0df605 --- /dev/null +++ b/params/synthetic_noise.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c801dfb859b08cf2ed96012176b0dcc7af2358d1a5d18a7c72b6e944416297b +size 1997 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f2f952dc8fe17f327de4cf2877a98e88e4e90c4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +tensorboardX +tqdm +easydict +prettytable +chumpy +opencv-python +imageio-ffmpeg +matplotlib==3.1.1 +roma +ipdb +pytorch-metric-learning # For one-hot action recognition +smplx[all] # For mesh recovery diff --git a/tools/compress_amass.py b/tools/compress_amass.py new file mode 100644 index 0000000000000000000000000000000000000000..3acc72c4121ea4a3fb8a2ea3e09d9efd0b1ee296 --- /dev/null +++ b/tools/compress_amass.py @@ -0,0 +1,62 @@ +import numpy as np +import os +import pickle + +raw_dir = './data/AMASS/amass_202203/' +processed_dir = './data/AMASS/amass_fps60' +os.makedirs(processed_dir, exist_ok=True) + +files = [] +length = 0 +target_fps = 60 + +def traverse(f): + fs = os.listdir(f) + for f1 in fs: + tmp_path = os.path.join(f,f1) + # file + if not os.path.isdir(tmp_path): + files.append(tmp_path) + # dir + else: + traverse(tmp_path) + +traverse(raw_dir) + +print('files:', len(files)) + +fnames = [] +all_motions = [] + +with open('data/AMASS/fps.csv', 'w') as f: + print('fname_new, len_ori, fps, len_new', file=f) + for fname in sorted(files): + try: + raw_x = np.load(fname) + x = dict(raw_x) + fps = x['mocap_framerate'] + len_ori = len(x['trans']) + sample_stride = round(fps / target_fps) + x['mocap_framerate'] = target_fps + x['trans'] = x['trans'][::sample_stride] + x['dmpls'] = x['dmpls'][::sample_stride] + x['poses'] = x['poses'][::sample_stride] + fname_new = '_'.join(fname.split('/')[2:]) + len_new = len(x['trans']) + + length += len_new + print(fname_new, ',', len_ori, ',', fps, ',', len_new, file=f) + fnames.append(fname_new) + all_motions.append(x) + np.savez('%s/%s' % (processed_dir, fname_new), x) + except: + pass + +# break + +print('poseFrame:', length) +print('motions:', len(fnames)) + +with open("data/AMASS/all_motions_fps%d.pkl" % target_fps, "wb") as myprofile: + pickle.dump(all_motions, myprofile) + diff --git a/tools/convert_amass.py b/tools/convert_amass.py new file mode 100644 index 0000000000000000000000000000000000000000..42c9d0a8de4bd4c2d7273f9824221aa1e520dfa1 --- /dev/null +++ b/tools/convert_amass.py @@ -0,0 +1,67 @@ +import os +import sys +import copy +import pickle +import ipdb +import torch +import numpy as np +sys.path.insert(0, os.getcwd()) +from lib.utils.utils_data import split_clips +from tqdm import tqdm + +fileName = open('data/AMASS/amass_joints_h36m_60.pkl','rb') +joints_all = pickle.load(fileName) + +joints_cam = [] +vid_list = [] 
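# The temporal downsampling rule used in compress_amass above: keep every
# round(fps / target_fps)-th frame, so e.g. 120 fps mocap becomes 60 fps.
import numpy as np
fps, target_fps = 120.0, 60
sample_stride = round(fps / target_fps)                   # -> 2
print(np.arange(10)[::sample_stride])                     # [0 2 4 6 8]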
+vid_len_list = [] +scale_factor = 0.298 + +for i, item in enumerate(joints_all): # (17,N,3): + item = item.astype(np.float32) + vid_len = item.shape[1] + vid_len_list.append(vid_len) + for _ in range(vid_len): + vid_list.append(i) + real2cam = np.array([[1,0,0], + [0,0,1], + [0,-1,0]], dtype=np.float32) + item = np.transpose(item, (1,0,2)) # (17,N,3) -> (N,17,3) + motion_cam = item @ real2cam + motion_cam *= scale_factor + # motion_cam = motion_cam - motion_cam[0,0,:] + joints_cam.append(motion_cam) + +joints_cam_all = np.vstack(joints_cam) +split_id = datareader.split_clips(vid_list, n_frames=243, data_stride=81) +print(joints_cam_all.shape) # (N,17,3) + +max_x, minx_x = np.max(joints_cam_all[:,:,0]), np.min(joints_cam_all[:,:,0]) +max_y, minx_y = np.max(joints_cam_all[:,:,1]), np.min(joints_cam_all[:,:,1]) +max_z, minx_z = np.max(joints_cam_all[:,:,2]), np.min(joints_cam_all[:,:,2]) +print(max_x, minx_x) +print(max_y, minx_y) +print(max_z, minx_z) + +joints_cam_clip = joints_cam_all[split_id] +print(joints_cam_clip.shape) # (N,27,17,3) + +# np.save('doodle/joints_cam_clip_amass_60.npy', joints_cam_clip) + +root_path = "data/motion3d/MB3D_f243s81/AMASS" +subset_name = "train" +save_path = os.path.join(root_path, subset_name) +if not os.path.exists(save_path): + os.makedirs(save_path) + +num_clips = len(joints_cam_clip) +for i in tqdm(range(num_clips)): + motion = joints_cam_clip[i] + data_dict = { + "data_input": None, + "data_label": motion + } + with open(os.path.join(save_path, "%08d.pkl" % i), "wb") as myprofile: + pickle.dump(data_dict, myprofile) + + diff --git a/tools/convert_h36m.py b/tools/convert_h36m.py new file mode 100644 index 0000000000000000000000000000000000000000..d2997ea40266d41be5abcac29a1b44a846273af0 --- /dev/null +++ b/tools/convert_h36m.py @@ -0,0 +1,38 @@ +import os +import sys +import pickle +import numpy as np +import random +sys.path.insert(0, os.getcwd()) +from lib.utils.tools import read_pkl +from lib.data.datareader_h36m import DataReaderH36M +from tqdm import tqdm + + +def save_clips(subset_name, root_path, train_data, train_labels): + len_train = len(train_data) + save_path = os.path.join(root_path, subset_name) + if not os.path.exists(save_path): + os.makedirs(save_path) + for i in tqdm(range(len_train)): + data_input, data_label = train_data[i], train_labels[i] + data_dict = { + "data_input": data_input, + "data_label": data_label + } + with open(os.path.join(save_path, "%08d.pkl" % i), "wb") as myprofile: + pickle.dump(data_dict, myprofile) + +datareader = DataReaderH36M(n_frames=243, sample_stride=1, data_stride_train=81, data_stride_test=243, dt_file = 'h36m_sh_conf_cam_source_final.pkl', dt_root='data/motion3d/') +train_data, test_data, train_labels, test_labels = datareader.get_sliced_data() +print(train_data.shape, test_data.shape) +assert len(train_data) == len(train_labels) +assert len(test_data) == len(test_labels) + +root_path = "data/motion3d/MB3D_f243s81/H36M-SH" +if not os.path.exists(root_path): + os.makedirs(root_path) + +save_clips("train", root_path, train_data, train_labels) +save_clips("test", root_path, test_data, test_labels) + diff --git a/tools/convert_insta.py b/tools/convert_insta.py new file mode 100644 index 0000000000000000000000000000000000000000..9135c2c7e20e37a0e4fb2ae4edcfd62105068c8b --- /dev/null +++ b/tools/convert_insta.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import sys +import random +import copy +import argparse +import math +import pickle +import json +import glob +import numpy as np 
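# What the real2cam matrix in convert_amass above does: right-multiplying a world-frame
# point (x, y, z) yields (x, -z, y), i.e. the Y-up AMASS frame is rotated into the
# camera-style frame used downstream (before the 0.298 scale factor is applied).
import numpy as np
real2cam = np.array([[1, 0, 0],
                     [0, 0, 1],
                     [0, -1, 0]], dtype=np.float32)
p = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
print(p @ real2cam)                                       # [[ 1. -3.  2.]]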
+sys.path.insert(0, os.getcwd()) +from lib.utils.utils_data import crop_scale + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--name_action', type=str) + args = parser.parse_args() + print("\nParameters:") + for attr, value in sorted(args.__dict__.items()): + print("\t{}={}".format(attr.upper(), value)) + return args + +def json2pose(json_dict): + pose_h36m = np.zeros([17,3]) + idx2key = ['Hip', + 'R Hip', + 'R Knee', + 'R Ankle', + 'L Hip', + 'L Knee', + 'L Ankle', + 'Belly', + 'Neck', + 'Nose', + 'Head', + 'L Shoulder', + 'L Elbow', + 'L Wrist', + 'R Shoulder', + 'R Elbow', + 'R Wrist', + ] + for i in range(17): + if idx2key[i]=='Belly' or idx2key[i]=='Head': + pose_h36m[i] = 0, 0, 0 + else: + item = json_dict[idx2key[i]] + pose_h36m[i] = item['x'], item['y'], item['logits'] + return pose_h36m + +def load_motion(json_path): + json_dict = json.load(open(json_path, 'r')) + pose_h36m = json2pose(json_dict) + return pose_h36m + + +args = parse_args() +dataset_root = 'data/Motion2d/InstaVariety/InstaVariety_tracks/' +action_motions = [] +dir_action = os.path.join(dataset_root, args.name_action) +for name_vid in sorted(os.listdir(dir_action)): + dir_vid = os.path.join(dir_action, name_vid) + for name_clip in sorted(os.listdir(dir_vid)): + motion_path = os.path.join(dir_vid, name_clip) + motion_list = sorted(glob.glob(motion_path+'/*.json')) + if len(motion_list)==0: + continue + motion = [load_motion(i) for i in motion_list] + motion = np.array(motion) + motion = crop_scale(motion) + motion[:,:,:2] = motion[:,:,:2] - motion[0:1,0:1,:2] + motion[motion[:,:,2]==0] = 0 + action_motions.append(motion) + print("%s Done, %d vids processed" % (name_vid, len(action_motions))) +print("%s Done, %d vids processed" % (args.name_action, len(action_motions))) +with open(os.path.join(dir_action, '%s.pkl' % args.name_action), 'wb') as f: + pickle.dump(action_motions, f) diff --git a/tools/preprocess_amass.py b/tools/preprocess_amass.py new file mode 100644 index 0000000000000000000000000000000000000000..399d48ccac91cc4b5996149cdc226861f2ed19f8 --- /dev/null +++ b/tools/preprocess_amass.py @@ -0,0 +1,64 @@ +import torch +import numpy as np +import os +from os import path as osp +from human_body_prior.body_model.body_model import BodyModel +import copy +import pickle +import ipdb +import pandas as pd + +df = pd.read_csv('./data/AMASS/fps.csv', sep=',',header=None) +fname_list = list(df[0][1:]) + +processed_dir = './data/AMASS/amass_fps60/' +J_reg_dir = 'data/AMASS/J_regressor_h36m_correct.npy' +all_motions = 'data/AMASS/all_motions_fps60.pkl' + +file = open(all_motions, 'rb') +motion_data = pickle.load(file) +J_reg = np.load(J_reg_dir) +all_joints = [] + +max_len = 2916 +with open('data/AMASS/clip_list.csv', 'w') as f: + print('clip_id, fname, clip_len', file=f) + for i, bdata in enumerate(motion_data): + if i%200==0: + print(i, 'seqs done.') + comp_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + subject_gender = bdata['gender'] + if (str(subject_gender) != 'female') and (str(subject_gender) != 'male'): + subject_gender = 'female' + + bm_fname = osp.join('data/AMASS/body_models/smplh/{}/model.npz'.format(subject_gender)) + dmpl_fname = osp.join('data/AMASS/body_models/dmpls/{}/model.npz'.format(subject_gender)) + + # number of body parameters + num_betas = 16 + # number of DMPL parameters + num_dmpls = 8 + + bm = BodyModel(bm_fname=bm_fname, num_betas=num_betas, num_dmpls=num_dmpls, dmpl_fname=dmpl_fname).to(comp_device) + time_length = 
len(bdata['trans']) + num_slice = time_length // max_len + + for sid in range(num_slice+1): + start = sid*max_len + end = min((sid+1)*max_len, time_length) + body_parms = { + 'root_orient': torch.Tensor(bdata['poses'][start:end, :3]).to(comp_device), # controls the global root orientation + 'pose_body': torch.Tensor(bdata['poses'][start:end, 3:66]).to(comp_device), # controls the body + 'pose_hand': torch.Tensor(bdata['poses'][start:end, 66:]).to(comp_device), # controls the finger articulation + 'trans': torch.Tensor(bdata['trans'][start:end]).to(comp_device), # controls the global body position + 'betas': torch.Tensor(np.repeat(bdata['betas'][:num_betas][np.newaxis], repeats=(end-start), axis=0)).to(comp_device), # controls the body shape. Body shape is static + 'dmpls': torch.Tensor(bdata['dmpls'][start:end, :num_dmpls]).to(comp_device) # controls soft tissue dynamics + } + body_trans_root = bm(**{k:v for k,v in body_parms.items() if k in ['pose_body', 'betas', 'pose_hand', 'dmpls', 'trans', 'root_orient']}) + mesh = body_trans_root.v.cpu().numpy() + kpts = np.dot(J_reg, mesh) # (17,T,3) + all_joints.append(kpts) + print(len(all_joints)-1, ',', fname_list[i], ',', end-start, file=f) + fileName = open('data/AMASS/amass_joints_h36m_60.pkl','wb') + pickle.dump(all_joints, fileName) + print(len(all_joints)) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5950ee6c1587462c5ba7cb80978e22ab8fbb15a8 --- /dev/null +++ b/train.py @@ -0,0 +1,383 @@ +import os +import numpy as np +import argparse +import errno +import math +import pickle +import tensorboardX +from tqdm import tqdm +from time import time +import copy +import random +import prettytable + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader + +from lib.utils.tools import * +from lib.utils.learning import * +from lib.utils.utils_data import flip_data +from lib.data.dataset_motion_2d import PoseTrackDataset2D, InstaVDataset2D +from lib.data.dataset_motion_3d import MotionDataset3D +from lib.data.augmentation import Augmenter2D +from lib.data.datareader_h36m import DataReaderH36M +from lib.model.loss import * + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.") + parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory') + parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory') + parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)') + parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed') + opts = parser.parse_args() + return opts + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + +def save_checkpoint(chk_path, epoch, lr, optimizer, model_pos, min_loss): + print('Saving checkpoint to', chk_path) + torch.save({ + 'epoch': epoch + 1, + 'lr': lr, + 'optimizer': optimizer.state_dict(), + 'model_pos': model_pos.state_dict(), + 
'min_loss' : min_loss + }, chk_path) + +def evaluate(args, model_pos, test_loader, datareader): + print('INFO: Testing') + results_all = [] + model_pos.eval() + with torch.no_grad(): + for batch_input, batch_gt in tqdm(test_loader): + N, T = batch_gt.shape[:2] + if torch.cuda.is_available(): + batch_input = batch_input.cuda() + if args.no_conf: + batch_input = batch_input[:, :, :, :2] + if args.flip: + batch_input_flip = flip_data(batch_input) + predicted_3d_pos_1 = model_pos(batch_input) + predicted_3d_pos_flip = model_pos(batch_input_flip) + predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back + predicted_3d_pos = (predicted_3d_pos_1+predicted_3d_pos_2) / 2 + else: + predicted_3d_pos = model_pos(batch_input) + if args.rootrel: + predicted_3d_pos[:,:,0,:] = 0 # [N,T,17,3] + else: + batch_gt[:,0,0,2] = 0 + + if args.gt_2d: + predicted_3d_pos[...,:2] = batch_input[...,:2] + results_all.append(predicted_3d_pos.cpu().numpy()) + results_all = np.concatenate(results_all) + results_all = datareader.denormalize(results_all) + _, split_id_test = datareader.get_split_id() + actions = np.array(datareader.dt_dataset['test']['action']) + factors = np.array(datareader.dt_dataset['test']['2.5d_factor']) + gts = np.array(datareader.dt_dataset['test']['joints_2.5d_image']) + sources = np.array(datareader.dt_dataset['test']['source']) + + num_test_frames = len(actions) + frames = np.array(range(num_test_frames)) + action_clips = actions[split_id_test] + factor_clips = factors[split_id_test] + source_clips = sources[split_id_test] + frame_clips = frames[split_id_test] + gt_clips = gts[split_id_test] + assert len(results_all)==len(action_clips) + + e1_all = np.zeros(num_test_frames) + e2_all = np.zeros(num_test_frames) + oc = np.zeros(num_test_frames) + results = {} + results_procrustes = {} + action_names = sorted(set(datareader.dt_dataset['test']['action'])) + for action in action_names: + results[action] = [] + results_procrustes[action] = [] + block_list = ['s_09_act_05_subact_02', + 's_09_act_10_subact_02', + 's_09_act_13_subact_01'] + for idx in range(len(action_clips)): + source = source_clips[idx][0][:-6] + if source in block_list: + continue + frame_list = frame_clips[idx] + action = action_clips[idx][0] + factor = factor_clips[idx][:,None,None] + gt = gt_clips[idx] + pred = results_all[idx] + pred *= factor + + # Root-relative Errors + pred = pred - pred[:,0:1,:] + gt = gt - gt[:,0:1,:] + err1 = mpjpe(pred, gt) + err2 = p_mpjpe(pred, gt) + e1_all[frame_list] += err1 + e2_all[frame_list] += err2 + oc[frame_list] += 1 + for idx in range(num_test_frames): + if e1_all[idx] > 0: + err1 = e1_all[idx] / oc[idx] + err2 = e2_all[idx] / oc[idx] + action = actions[idx] + results[action].append(err1) + results_procrustes[action].append(err2) + final_result = [] + final_result_procrustes = [] + summary_table = prettytable.PrettyTable() + summary_table.field_names = ['test_name'] + action_names + for action in action_names: + final_result.append(np.mean(results[action])) + final_result_procrustes.append(np.mean(results_procrustes[action])) + summary_table.add_row(['P1'] + final_result) + summary_table.add_row(['P2'] + final_result_procrustes) + print(summary_table) + e1 = np.mean(np.array(final_result)) + e2 = np.mean(np.array(final_result_procrustes)) + print('Protocol #1 Error (MPJPE):', e1, 'mm') + print('Protocol #2 Error (P-MPJPE):', e2, 'mm') + print('----------') + return e1, e2, results_all + +def train_epoch(args, model_pos, train_loader, losses, optimizer, has_3d, has_gt): + 
model_pos.train() + for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)): + batch_size = len(batch_input) + if torch.cuda.is_available(): + batch_input = batch_input.cuda() + batch_gt = batch_gt.cuda() + with torch.no_grad(): + if args.no_conf: + batch_input = batch_input[:, :, :, :2] + if not has_3d: + conf = copy.deepcopy(batch_input[:,:,:,2:]) # For 2D data, weight/confidence is at the last channel + if args.rootrel: + batch_gt = batch_gt - batch_gt[:,:,0:1,:] + else: + batch_gt[:,:,:,2] = batch_gt[:,:,:,2] - batch_gt[:,0:1,0:1,2] # Place the depth of first frame root to 0. + if args.mask or args.noise: + batch_input = args.aug.augment2D(batch_input, noise=(args.noise and has_gt), mask=args.mask) + # Predict 3D poses + predicted_3d_pos = model_pos(batch_input) # (N, T, 17, 3) + + optimizer.zero_grad() + if has_3d: + loss_3d_pos = loss_mpjpe(predicted_3d_pos, batch_gt) + loss_3d_scale = n_mpjpe(predicted_3d_pos, batch_gt) + loss_3d_velocity = loss_velocity(predicted_3d_pos, batch_gt) + loss_lv = loss_limb_var(predicted_3d_pos) + loss_lg = loss_limb_gt(predicted_3d_pos, batch_gt) + loss_a = loss_angle(predicted_3d_pos, batch_gt) + loss_av = loss_angle_velocity(predicted_3d_pos, batch_gt) + loss_total = loss_3d_pos + \ + args.lambda_scale * loss_3d_scale + \ + args.lambda_3d_velocity * loss_3d_velocity + \ + args.lambda_lv * loss_lv + \ + args.lambda_lg * loss_lg + \ + args.lambda_a * loss_a + \ + args.lambda_av * loss_av + losses['3d_pos'].update(loss_3d_pos.item(), batch_size) + losses['3d_scale'].update(loss_3d_scale.item(), batch_size) + losses['3d_velocity'].update(loss_3d_velocity.item(), batch_size) + losses['lv'].update(loss_lv.item(), batch_size) + losses['lg'].update(loss_lg.item(), batch_size) + losses['angle'].update(loss_a.item(), batch_size) + losses['angle_velocity'].update(loss_av.item(), batch_size) + losses['total'].update(loss_total.item(), batch_size) + else: + loss_2d_proj = loss_2d_weighted(predicted_3d_pos, batch_gt, conf) + loss_total = loss_2d_proj + losses['2d_proj'].update(loss_2d_proj.item(), batch_size) + losses['total'].update(loss_total.item(), batch_size) + loss_total.backward() + optimizer.step() + +def train_with_config(args, opts): + print(args) + try: + os.makedirs(opts.checkpoint) + except OSError as e: + if e.errno != errno.EEXIST: + raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) + train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs")) + + + print('Loading dataset...') + trainloader_params = { + 'batch_size': args.batch_size, + 'shuffle': True, + 'num_workers': 12, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + + testloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 12, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + + train_dataset = MotionDataset3D(args, args.subset_list, 'train') + test_dataset = MotionDataset3D(args, args.subset_list, 'test') + train_loader_3d = DataLoader(train_dataset, **trainloader_params) + test_loader = DataLoader(test_dataset, **testloader_params) + + if args.train_2d: + posetrack = PoseTrackDataset2D() + posetrack_loader_2d = DataLoader(posetrack, **trainloader_params) + instav = InstaVDataset2D() + instav_loader_2d = DataLoader(instav, **trainloader_params) + + datareader = DataReaderH36M(n_frames=args.clip_len, sample_stride=args.sample_stride, data_stride_train=args.data_stride, data_stride_test=args.clip_len, dt_root = 'data/motion3d', 
dt_file=args.dt_file) + min_loss = 100000 + model_backbone = load_backbone(args) + model_params = 0 + for parameter in model_backbone.parameters(): + model_params = model_params + parameter.numel() + print('INFO: Trainable parameter count:', model_params) + + if torch.cuda.is_available(): + model_backbone = nn.DataParallel(model_backbone) + model_backbone = model_backbone.cuda() + + if args.finetune: + if opts.resume or opts.evaluate: + chk_filename = opts.evaluate if opts.evaluate else opts.resume + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model_backbone.load_state_dict(checkpoint['model_pos'], strict=True) + model_pos = model_backbone + else: + chk_filename = os.path.join(opts.pretrained, opts.selection) + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model_backbone.load_state_dict(checkpoint['model_pos'], strict=True) + model_pos = model_backbone + else: + chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin") + if os.path.exists(chk_filename): + opts.resume = chk_filename + if opts.resume or opts.evaluate: + chk_filename = opts.evaluate if opts.evaluate else opts.resume + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model_backbone.load_state_dict(checkpoint['model_pos'], strict=True) + model_pos = model_backbone + + if args.partial_train: + model_pos = partial_train_layers(model_pos, args.partial_train) + + if not opts.evaluate: + lr = args.learning_rate + optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model_pos.parameters()), lr=lr, weight_decay=args.weight_decay) + lr_decay = args.lr_decay + st = 0 + if args.train_2d: + print('INFO: Training on {}(3D)+{}(2D) batches'.format(len(train_loader_3d), len(instav_loader_2d) + len(posetrack_loader_2d))) + else: + print('INFO: Training on {}(3D) batches'.format(len(train_loader_3d))) + if opts.resume: + st = checkpoint['epoch'] + if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + else: + print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') + lr = checkpoint['lr'] + if 'min_loss' in checkpoint and checkpoint['min_loss'] is not None: + min_loss = checkpoint['min_loss'] + + args.mask = (args.mask_ratio > 0 and args.mask_T_ratio > 0) + if args.mask or args.noise: + args.aug = Augmenter2D(args) + + # Training + for epoch in range(st, args.epochs): + print('Training epoch %d.' 
% epoch) + start_time = time() + losses = {} + losses['3d_pos'] = AverageMeter() + losses['3d_scale'] = AverageMeter() + losses['2d_proj'] = AverageMeter() + losses['lg'] = AverageMeter() + losses['lv'] = AverageMeter() + losses['total'] = AverageMeter() + losses['3d_velocity'] = AverageMeter() + losses['angle'] = AverageMeter() + losses['angle_velocity'] = AverageMeter() + N = 0 + + # Curriculum Learning + if args.train_2d and (epoch >= args.pretrain_3d_curriculum): + train_epoch(args, model_pos, posetrack_loader_2d, losses, optimizer, has_3d=False, has_gt=True) + train_epoch(args, model_pos, instav_loader_2d, losses, optimizer, has_3d=False, has_gt=False) + train_epoch(args, model_pos, train_loader_3d, losses, optimizer, has_3d=True, has_gt=True) + elapsed = (time() - start_time) / 60 + + if args.no_eval: + print('[%d] time %.2f lr %f 3d_train %f' % ( + epoch + 1, + elapsed, + lr, + losses['3d_pos'].avg)) + else: + e1, e2, results_all = evaluate(args, model_pos, test_loader, datareader) + print('[%d] time %.2f lr %f 3d_train %f e1 %f e2 %f' % ( + epoch + 1, + elapsed, + lr, + losses['3d_pos'].avg, + e1, e2)) + train_writer.add_scalar('Error P1', e1, epoch + 1) + train_writer.add_scalar('Error P2', e2, epoch + 1) + train_writer.add_scalar('loss_3d_pos', losses['3d_pos'].avg, epoch + 1) + train_writer.add_scalar('loss_2d_proj', losses['2d_proj'].avg, epoch + 1) + train_writer.add_scalar('loss_3d_scale', losses['3d_scale'].avg, epoch + 1) + train_writer.add_scalar('loss_3d_velocity', losses['3d_velocity'].avg, epoch + 1) + train_writer.add_scalar('loss_lv', losses['lv'].avg, epoch + 1) + train_writer.add_scalar('loss_lg', losses['lg'].avg, epoch + 1) + train_writer.add_scalar('loss_a', losses['angle'].avg, epoch + 1) + train_writer.add_scalar('loss_av', losses['angle_velocity'].avg, epoch + 1) + train_writer.add_scalar('loss_total', losses['total'].avg, epoch + 1) + + # Decay learning rate exponentially + lr *= lr_decay + for param_group in optimizer.param_groups: + param_group['lr'] *= lr_decay + + # Save checkpoints + chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch)) + chk_path_latest = os.path.join(opts.checkpoint, 'latest_epoch.bin') + chk_path_best = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch)) + + save_checkpoint(chk_path_latest, epoch, lr, optimizer, model_pos, min_loss) + if (epoch + 1) % args.checkpoint_frequency == 0: + save_checkpoint(chk_path, epoch, lr, optimizer, model_pos, min_loss) + if e1 < min_loss: + min_loss = e1 + save_checkpoint(chk_path_best, epoch, lr, optimizer, model_pos, min_loss) + + if opts.evaluate: + e1, e2, results_all = evaluate(args, model_pos, test_loader, datareader) + +if __name__ == "__main__": + opts = parse_args() + set_random_seed(opts.seed) + args = get_config(opts.config) + train_with_config(args, opts) \ No newline at end of file diff --git a/train_action.py b/train_action.py new file mode 100644 index 0000000000000000000000000000000000000000..e105c26e14580eb9aa04f979f3cb0560fb504b41 --- /dev/null +++ b/train_action.py @@ -0,0 +1,243 @@ +import os +import numpy as np +import time +import sys +import argparse +import errno +from collections import OrderedDict +import tensorboardX +from tqdm import tqdm +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DataLoader + +from lib.utils.tools import * +from lib.utils.learning import * +from lib.model.loss import * +from 
lib.data.dataset_action import NTURGBD +from lib.model.model_action import ActionNet + +random.seed(0) +np.random.seed(0) +torch.manual_seed(0) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.") + parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory') + parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory') + parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-freq', '--print_freq', default=100) + parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)') + opts = parser.parse_args() + return opts + +def validate(test_loader, model, criterion): + model.eval() + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + with torch.no_grad(): + end = time.time() + for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)): + batch_size = len(batch_input) + if torch.cuda.is_available(): + batch_gt = batch_gt.cuda() + batch_input = batch_input.cuda() + output = model(batch_input) # (N, num_classes) + loss = criterion(output, batch_gt) + + # update metric + losses.update(loss.item(), batch_size) + acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5)) + top1.update(acc1[0], batch_size) + top5.update(acc5[0], batch_size) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if (idx+1) % opts.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( + idx, len(test_loader), batch_time=batch_time, + loss=losses, top1=top1, top5=top5)) + return losses.avg, top1.avg, top5.avg + + +def train_with_config(args, opts): + print(args) + try: + os.makedirs(opts.checkpoint) + except OSError as e: + if e.errno != errno.EEXIST: + raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) + train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs")) + model_backbone = load_backbone(args) + if args.finetune: + if opts.resume or opts.evaluate: + pass + else: + chk_filename = os.path.join(opts.pretrained, opts.selection) + print('Loading backbone', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos'] + model_backbone = load_pretrained_weights(model_backbone, checkpoint) + if args.partial_train: + model_backbone = partial_train_layers(model_backbone, args.partial_train) + model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, num_classes=args.action_classes, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints) + criterion = torch.nn.CrossEntropyLoss() + if torch.cuda.is_available(): + model = nn.DataParallel(model) + model = model.cuda() + criterion = criterion.cuda() + best_acc = 0 + model_params = 0 + for parameter in model.parameters(): + model_params = model_params + parameter.numel() + print('INFO: Trainable parameter count:', 
model_params) + print('Loading dataset...') + trainloader_params = { + 'batch_size': args.batch_size, + 'shuffle': True, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + testloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + data_path = 'data/action/%s.pkl' % args.dataset + ntu60_xsub_train = NTURGBD(data_path=data_path, data_split=args.data_split+'_train', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train) + ntu60_xsub_val = NTURGBD(data_path=data_path, data_split=args.data_split+'_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test) + + train_loader = DataLoader(ntu60_xsub_train, **trainloader_params) + test_loader = DataLoader(ntu60_xsub_val, **testloader_params) + + chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin") + if os.path.exists(chk_filename): + opts.resume = chk_filename + if opts.resume or opts.evaluate: + chk_filename = opts.evaluate if opts.evaluate else opts.resume + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpoint['model'], strict=True) + + if not opts.evaluate: + optimizer = optim.AdamW( + [ {"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone}, + {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}, + ], lr=args.lr_backbone, + weight_decay=args.weight_decay + ) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay) + st = 0 + print('INFO: Training on {} batches'.format(len(train_loader))) + if opts.resume: + st = checkpoint['epoch'] + if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + else: + print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') + lr = checkpoint['lr'] + if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None: + best_acc = checkpoint['best_acc'] + # Training + for epoch in range(st, args.epochs): + print('Training epoch %d.' 
% epoch) + losses_train = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + batch_time = AverageMeter() + data_time = AverageMeter() + model.train() + end = time.time() + iters = len(train_loader) + for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)): # (N, 2, T, 17, 3) + data_time.update(time.time() - end) + batch_size = len(batch_input) + if torch.cuda.is_available(): + batch_gt = batch_gt.cuda() + batch_input = batch_input.cuda() + output = model(batch_input) # (N, num_classes) + optimizer.zero_grad() + loss_train = criterion(output, batch_gt) + losses_train.update(loss_train.item(), batch_size) + acc1, acc5 = accuracy(output, batch_gt, topk=(1, 5)) + top1.update(acc1[0], batch_size) + top5.update(acc5[0], batch_size) + loss_train.backward() + optimizer.step() + batch_time.update(time.time() - end) + end = time.time() + if (idx + 1) % opts.print_freq == 0: + print('Train: [{0}][{1}/{2}]\t' + 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'loss {loss.val:.3f} ({loss.avg:.3f})\t' + 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( + epoch, idx + 1, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses_train, top1=top1)) + sys.stdout.flush() + + test_loss, test_top1, test_top5 = validate(test_loader, model, criterion) + + train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1) + train_writer.add_scalar('train_top1', top1.avg, epoch + 1) + train_writer.add_scalar('train_top5', top5.avg, epoch + 1) + train_writer.add_scalar('test_loss', test_loss, epoch + 1) + train_writer.add_scalar('test_top1', test_top1, epoch + 1) + train_writer.add_scalar('test_top5', test_top5, epoch + 1) + + scheduler.step() + + # Save latest checkpoint. + chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin') + print('Saving checkpoint to', chk_path) + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_acc' : best_acc + }, chk_path) + + # Save best checkpoint. 
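+            # The best checkpoint is refreshed whenever the current top-1 test
+            # accuracy beats the best seen so far (best_acc).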
+ best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch)) + if test_top1 > best_acc: + best_acc = test_top1 + print("save best checkpoint") + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_acc' : best_acc + }, best_chk_path) + + if opts.evaluate: + test_loss, test_top1, test_top5 = validate(test_loader, model, criterion) + print('Loss {loss:.4f} \t' + 'Acc@1 {top1:.3f} \t' + 'Acc@5 {top5:.3f} \t'.format(loss=test_loss, top1=test_top1, top5=test_top5)) + +if __name__ == "__main__": + opts = parse_args() + args = get_config(opts.config) + train_with_config(args, opts) \ No newline at end of file diff --git a/train_action_1shot.py b/train_action_1shot.py new file mode 100644 index 0000000000000000000000000000000000000000..9f07902cb46b3d6b9d6258f59090e5932f756d5d --- /dev/null +++ b/train_action_1shot.py @@ -0,0 +1,243 @@ +import os +import numpy as np +import time +import sys +import argparse +import errno +from collections import OrderedDict +import tensorboardX +from tqdm import tqdm +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DataLoader + +from lib.utils.tools import * +from lib.utils.learning import * +from lib.model.loss import * +from lib.data.dataset_action import NTURGBD, NTURGBD1Shot +from lib.model.model_action import ActionNet + +from lib.model.loss_supcon import SupConLoss +from pytorch_metric_learning import samplers + +random.seed(0) +np.random.seed(0) +torch.manual_seed(0) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.") + parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory') + parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory') + parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-freq', '--print_freq', default=100) + parser.add_argument('-ms', '--selection', default='best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)') + opts = parser.parse_args() + return opts + +def extract_feats(dataloader_x, model): + all_feats = [] + all_gts = [] + with torch.no_grad(): + for idx, (batch_input, batch_gt) in tqdm(enumerate(dataloader_x)): # (N, 2, T, 17, 3) + if torch.cuda.is_available(): + batch_input = batch_input.cuda() + feat = model(batch_input) + all_feats.append(feat) + all_gts.append(batch_gt) + all_feats = torch.cat(all_feats) + all_gts = torch.cat(all_gts) + return all_feats, all_gts + +def validate(anchor_loader, test_loader, model): + train_feats, train_labels = extract_feats(anchor_loader, model) + test_feats, test_labels = extract_feats(test_loader, model) + M = len(train_feats) + N = len(test_feats) + train_feats = train_feats.unsqueeze(1) + test_feats = test_feats.unsqueeze(0) + dis = F.cosine_similarity(train_feats, test_feats, dim=-1) + pred = train_labels[torch.argmax(dis, dim=0)] + assert len(pred)==len(test_labels) + acc = sum(pred==test_labels) / len(pred) + return acc + +def 
train_with_config(args, opts): + print(args) + try: + os.makedirs(opts.checkpoint) + except OSError as e: + if e.errno != errno.EEXIST: + raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) + train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs")) + model_backbone = load_backbone(args) + if args.finetune: + if opts.resume or opts.evaluate: + pass + else: + chk_filename = os.path.join(opts.pretrained, "best_epoch.bin") + print('Loading backbone', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + new_state_dict = OrderedDict() + for k, v in checkpoint['model_pos'].items(): + name = k[7:] # remove 'module.' + new_state_dict[name] = v + model_backbone.load_state_dict(new_state_dict, strict=True) + if args.partial_train: + model_backbone = partial_train_layers(model_backbone, args.partial_train) + model = ActionNet(backbone=model_backbone, dim_rep=args.dim_rep, dropout_ratio=args.dropout_ratio, version=args.model_version, hidden_dim=args.hidden_dim, num_joints=args.num_joints) + criterion = SupConLoss(temperature=args.temp) + + if torch.cuda.is_available(): + model = nn.DataParallel(model) + model = model.cuda() + criterion = criterion.cuda() + + chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin") + if os.path.exists(chk_filename): + opts.resume = chk_filename + if opts.resume or opts.evaluate: + chk_filename = opts.evaluate if opts.evaluate else opts.resume + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpoint['model'], strict=True) + + best_acc = 0 + model_params = 0 + for parameter in model.parameters(): + model_params = model_params + parameter.numel() + print('INFO: Trainable parameter count:', model_params) + print('Loading dataset...') + + anchorloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + + testloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + data_path_1shot = 'data/action/ntu120_hrnet_oneshot.pkl' + ntu60_1shot_anchor = NTURGBD(data_path=data_path_1shot, data_split='oneshot_train', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test) + ntu60_1shot_test = NTURGBD(data_path=data_path_1shot, data_split='oneshot_val', n_frames=args.clip_len, random_move=False, scale_range=args.scale_range_test) + anchor_loader = DataLoader(ntu60_1shot_anchor, **anchorloader_params) + test_loader = DataLoader(ntu60_1shot_test, **testloader_params) + + if not opts.evaluate: + # Load training data (auxiliary set) + data_path = 'data/action/ntu120_hrnet.pkl' + ntu120_1shot_train = NTURGBD1Shot(data_path=data_path, data_split='', n_frames=args.clip_len, random_move=args.random_move, scale_range=args.scale_range_train, check_split=False) + sampler = samplers.MPerClassSampler(ntu120_1shot_train.labels, m=args.n_views, batch_size=args.batch_size, length_before_new_iter=len(ntu120_1shot_train)) + trainloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True, + 'sampler': sampler + } + train_loader = DataLoader(ntu120_1shot_train, **trainloader_params) + optimizer = optim.AdamW( + [ {"params": filter(lambda p: 
p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone}, + {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}, + ], lr=args.lr_backbone, + weight_decay=args.weight_decay + ) + scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay) + st = 0 + print('INFO: Training on {} batches'.format(len(train_loader))) + if opts.resume: + st = checkpoint['epoch'] + if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + else: + print('WARNING: this checkpoint does not contain an optimizer state. The optimizer will be reinitialized.') + + lr = checkpoint['lr'] + if 'best_acc' in checkpoint and checkpoint['best_acc'] is not None: + best_acc = checkpoint['best_acc'] + + # Training + for epoch in range(st, args.epochs): + print('Training epoch %d.' % epoch) + losses_train = AverageMeter() + batch_time = AverageMeter() + data_time = AverageMeter() + + model.train() + end = time.time() + + for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)): + data_time.update(time.time() - end) + batch_size = len(batch_input) + if torch.cuda.is_available(): + batch_gt = batch_gt.cuda() + batch_input = batch_input.cuda() + feat = model(batch_input) + feat = feat.reshape(batch_size, -1, args.hidden_dim) + optimizer.zero_grad() + loss_train = criterion(feat, batch_gt) + losses_train.update(loss_train.item(), batch_size) + loss_train.backward() + optimizer.step() + batch_time.update(time.time() - end) + end = time.time() + if (idx + 1) % opts.print_freq == 0: + print('Train: [{0}][{1}/{2}]\t' + 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'loss {loss.val:.3f} ({loss.avg:.3f})\t'.format( + epoch, idx + 1, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses_train)) + sys.stdout.flush() + test_top1 = validate(anchor_loader, test_loader, model) + train_writer.add_scalar('train_loss_supcon', losses_train.avg, epoch + 1) + train_writer.add_scalar('test_top1', test_top1, epoch + 1) + scheduler.step() + # Save latest checkpoint. 
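+            # The latest checkpoint keeps the epoch counter, learning rate,
+            # optimizer state, model weights and best one-shot accuracy so
+            # interrupted training can resume from here.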
+ chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin') + print('Saving checkpoint to', chk_path) + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_acc' : best_acc + }, chk_path) + + # Save best checkpoint + best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch)) + if test_top1 > best_acc: + best_acc = test_top1 + print("save best checkpoint") + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_acc' : best_acc + }, best_chk_path) + if opts.evaluate: + test_top1 = validate(anchor_loader, test_loader, model) + print(test_top1) +if __name__ == "__main__": + opts = parse_args() + args = get_config(opts.config) + train_with_config(args, opts) + \ No newline at end of file diff --git a/train_mesh.py b/train_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..9c73dd5856f4bc492f1fa47dad903fd1227c60b6 --- /dev/null +++ b/train_mesh.py @@ -0,0 +1,437 @@ +import os +import random +import copy +import time +import sys +import shutil +import argparse +import errno +import math +import numpy as np +from collections import defaultdict, OrderedDict +import tensorboardX +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import StepLR + +from lib.utils.tools import * +from lib.model.loss import * +from lib.model.loss_mesh import * +from lib.utils.utils_mesh import * +from lib.utils.utils_smpl import * +from lib.utils.utils_data import * +from lib.utils.learning import * +from lib.data.dataset_mesh import MotionSMPL +from lib.model.model_mesh import MeshRegressor +from torch.utils.data import DataLoader + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="configs/pretrain.yaml", help="Path to the config file.") + parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', help='checkpoint directory') + parser.add_argument('-p', '--pretrained', default='checkpoint', type=str, metavar='PATH', help='pretrained checkpoint directory') + parser.add_argument('-r', '--resume', default='', type=str, metavar='FILENAME', help='checkpoint to resume (file name)') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)') + parser.add_argument('-freq', '--print_freq', default=100) + parser.add_argument('-ms', '--selection', default='latest_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to finetune (file name)') + parser.add_argument('-sd', '--seed', default=0, type=int, help='random seed') + opts = parser.parse_args() + return opts + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + +def validate(test_loader, model, criterion, dataset_name='h36m'): + model.eval() + print(f'===========> validating {dataset_name}') + batch_time = AverageMeter() + losses = AverageMeter() + losses_dict = {'loss_3d_pos': AverageMeter(), + 'loss_3d_scale': AverageMeter(), + 'loss_3d_velocity': AverageMeter(), + 'loss_lv': AverageMeter(), + 'loss_lg': AverageMeter(), + 'loss_a': AverageMeter(), + 'loss_av': AverageMeter(), + 'loss_pose': AverageMeter(), + 'loss_shape': AverageMeter(), + 'loss_norm': AverageMeter(), + } + mpjpes = AverageMeter() + 
mpves = AverageMeter() + results = defaultdict(list) + smpl = SMPL(args.data_root, batch_size=1).cuda() + J_regressor = smpl.J_regressor_h36m + with torch.no_grad(): + end = time.time() + for idx, (batch_input, batch_gt) in tqdm(enumerate(test_loader)): + batch_size, clip_len = batch_input.shape[:2] + if torch.cuda.is_available(): + batch_gt['theta'] = batch_gt['theta'].cuda().float() + batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float() + batch_gt['verts'] = batch_gt['verts'].cuda().float() + batch_input = batch_input.cuda().float() + output = model(batch_input) + output_final = output + if args.flip: + batch_input_flip = flip_data(batch_input) + output_flip = model(batch_input_flip) + output_flip_pose = output_flip[0]['theta'][:, :, :72] + output_flip_shape = output_flip[0]['theta'][:, :, 72:] + output_flip_pose = flip_thetas_batch(output_flip_pose) + output_flip_pose = output_flip_pose.reshape(-1, 72) + output_flip_shape = output_flip_shape.reshape(-1, 10) + output_flip_smpl = smpl( + betas=output_flip_shape, + body_pose=output_flip_pose[:, 3:], + global_orient=output_flip_pose[:, :3], + pose2rot=True + ) + output_flip_verts = output_flip_smpl.vertices.detach()*1000.0 + J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device) + output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts) # (NT,17,3) + output_flip_back = [{ + 'theta': torch.cat((output_flip_pose.reshape(batch_size, clip_len, -1), output_flip_shape.reshape(batch_size, clip_len, -1)), dim=-1), + 'verts': output_flip_verts.reshape(batch_size, clip_len, -1, 3), + 'kp_3d': output_flip_kp3d.reshape(batch_size, clip_len, -1, 3), + }] + output_final = [{}] + for k, v in output_flip[0].items(): + output_final[0][k] = (output[0][k] + output_flip_back[0][k])*0.5 + output = output_final + loss_dict = criterion(output, batch_gt) + loss = args.lambda_3d * loss_dict['loss_3d_pos'] + \ + args.lambda_scale * loss_dict['loss_3d_scale'] + \ + args.lambda_3dv * loss_dict['loss_3d_velocity'] + \ + args.lambda_lv * loss_dict['loss_lv'] + \ + args.lambda_lg * loss_dict['loss_lg'] + \ + args.lambda_a * loss_dict['loss_a'] + \ + args.lambda_av * loss_dict['loss_av'] + \ + args.lambda_shape * loss_dict['loss_shape'] + \ + args.lambda_pose * loss_dict['loss_pose'] + \ + args.lambda_norm * loss_dict['loss_norm'] + # update metric + losses.update(loss.item(), batch_size) + loss_str = '' + for k, v in loss_dict.items(): + losses_dict[k].update(v.item(), batch_size) + loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k]) + mpjpe, mpve = compute_error(output, batch_gt) + mpjpes.update(mpjpe, batch_size) + mpves.update(mpve, batch_size) + + for keys in output[0].keys(): + output[0][keys] = output[0][keys].detach().cpu().numpy() + batch_gt[keys] = batch_gt[keys].detach().cpu().numpy() + results['kp_3d'].append(output[0]['kp_3d']) + results['verts'].append(output[0]['verts']) + results['kp_3d_gt'].append(batch_gt['kp_3d']) + results['verts_gt'].append(batch_gt['verts']) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if idx % int(opts.print_freq) == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + '{2}' + 'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t' + 'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format( + idx, len(test_loader), loss_str, batch_time=batch_time, + loss=losses, mpves=mpves, mpjpes=mpjpes)) + + print(f'==> start concating results of 
{dataset_name}') + for term in results.keys(): + results[term] = np.concatenate(results[term]) + print(f'==> start evaluating {dataset_name}...') + error_dict = evaluate_mesh(results) + err_str = '' + for err_key, err_val in error_dict.items(): + err_str += '{}: {:.2f}mm \t'.format(err_key, err_val) + print(f'=======================> {dataset_name} validation done: ', loss_str) + print(f'=======================> {dataset_name} validation done: ', err_str) + return losses.avg, error_dict['mpjpe'], error_dict['pa_mpjpe'], error_dict['mpve'], losses_dict + + +def train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch): + model.train() + end = time.time() + for idx, (batch_input, batch_gt) in tqdm(enumerate(train_loader)): + data_time.update(time.time() - end) + batch_size = len(batch_input) + + if torch.cuda.is_available(): + batch_gt['theta'] = batch_gt['theta'].cuda().float() + batch_gt['kp_3d'] = batch_gt['kp_3d'].cuda().float() + batch_gt['verts'] = batch_gt['verts'].cuda().float() + batch_input = batch_input.cuda().float() + output = model(batch_input) + optimizer.zero_grad() + loss_dict = criterion(output, batch_gt) + loss_train = args.lambda_3d * loss_dict['loss_3d_pos'] + \ + args.lambda_scale * loss_dict['loss_3d_scale'] + \ + args.lambda_3dv * loss_dict['loss_3d_velocity'] + \ + args.lambda_lv * loss_dict['loss_lv'] + \ + args.lambda_lg * loss_dict['loss_lg'] + \ + args.lambda_a * loss_dict['loss_a'] + \ + args.lambda_av * loss_dict['loss_av'] + \ + args.lambda_shape * loss_dict['loss_shape'] + \ + args.lambda_pose * loss_dict['loss_pose'] + \ + args.lambda_norm * loss_dict['loss_norm'] + losses_train.update(loss_train.item(), batch_size) + loss_str = '' + for k, v in loss_dict.items(): + losses_dict[k].update(v.item(), batch_size) + loss_str += '{0} {loss.val:.3f} ({loss.avg:.3f})\t'.format(k, loss=losses_dict[k]) + + mpjpe, mpve = compute_error(output, batch_gt) + mpjpes.update(mpjpe, batch_size) + mpves.update(mpve, batch_size) + + loss_train.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if idx % int(opts.print_freq) == 0: + print('Train: [{0}][{1}/{2}]\t' + 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'loss {loss.val:.3f} ({loss.avg:.3f})\t' + '{3}' + 'PVE {mpves.val:.3f} ({mpves.avg:.3f})\t' + 'JPE {mpjpes.val:.3f} ({mpjpes.avg:.3f})'.format( + epoch, idx + 1, len(train_loader), loss_str, batch_time=batch_time, + data_time=data_time, loss=losses_train, mpves=mpves, mpjpes=mpjpes)) + sys.stdout.flush() + +def train_with_config(args, opts): + print(args) + try: + os.makedirs(opts.checkpoint) + shutil.copy(opts.config, opts.checkpoint) + except OSError as e: + if e.errno != errno.EEXIST: + raise RuntimeError('Unable to create checkpoint directory:', opts.checkpoint) + train_writer = tensorboardX.SummaryWriter(os.path.join(opts.checkpoint, "logs")) + model_backbone = load_backbone(args) + if args.finetune: + if opts.resume or opts.evaluate: + pass + else: + chk_filename = os.path.join(opts.pretrained, opts.selection) + print('Loading backbone', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)['model_pos'] + model_backbone = load_pretrained_weights(model_backbone, checkpoint) + if args.partial_train: + model_backbone = partial_train_layers(model_backbone, args.partial_train) + model = MeshRegressor(args, backbone=model_backbone, 
dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout, num_joints=args.num_joints) + criterion = MeshLoss(loss_type = args.loss_type) + best_jpe = 9999.0 + model_params = 0 + for parameter in model.parameters(): + if parameter.requires_grad == True: + model_params = model_params + parameter.numel() + print('INFO: Trainable parameter count:', model_params) + print('Loading dataset...') + trainloader_params = { + 'batch_size': args.batch_size, + 'shuffle': True, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + testloader_params = { + 'batch_size': args.batch_size, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + if hasattr(args, "dt_file_h36m"): + mesh_train = MotionSMPL(args, data_split='train', dataset="h36m") + mesh_val = MotionSMPL(args, data_split='test', dataset="h36m") + train_loader = DataLoader(mesh_train, **trainloader_params) + test_loader = DataLoader(mesh_val, **testloader_params) + print('INFO: Training on {} batches (h36m)'.format(len(train_loader))) + + if hasattr(args, "dt_file_pw3d"): + if args.train_pw3d: + mesh_train_pw3d = MotionSMPL(args, data_split='train', dataset="pw3d") + train_loader_pw3d = DataLoader(mesh_train_pw3d, **trainloader_params) + print('INFO: Training on {} batches (pw3d)'.format(len(train_loader_pw3d))) + mesh_val_pw3d = MotionSMPL(args, data_split='test', dataset="pw3d") + test_loader_pw3d = DataLoader(mesh_val_pw3d, **testloader_params) + + + trainloader_img_params = { + 'batch_size': args.batch_size_img, + 'shuffle': True, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + testloader_img_params = { + 'batch_size': args.batch_size_img, + 'shuffle': False, + 'num_workers': 8, + 'pin_memory': True, + 'prefetch_factor': 4, + 'persistent_workers': True + } + + if hasattr(args, "dt_file_coco"): + mesh_train_coco = MotionSMPL(args, data_split='train', dataset="coco") + mesh_val_coco = MotionSMPL(args, data_split='test', dataset="coco") + train_loader_coco = DataLoader(mesh_train_coco, **trainloader_img_params) + test_loader_coco = DataLoader(mesh_val_coco, **testloader_img_params) + print('INFO: Training on {} batches (coco)'.format(len(train_loader_coco))) + + if torch.cuda.is_available(): + model = nn.DataParallel(model) + model = model.cuda() + + chk_filename = os.path.join(opts.checkpoint, "latest_epoch.bin") + if os.path.exists(chk_filename): + opts.resume = chk_filename + if opts.resume or opts.evaluate: + chk_filename = opts.evaluate if opts.evaluate else opts.resume + print('Loading checkpoint', chk_filename) + checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpoint['model'], strict=True) + if not opts.evaluate: + optimizer = optim.AdamW( + [ {"params": filter(lambda p: p.requires_grad, model.module.backbone.parameters()), "lr": args.lr_backbone}, + {"params": filter(lambda p: p.requires_grad, model.module.head.parameters()), "lr": args.lr_head}, + ], lr=args.lr_backbone, + weight_decay=args.weight_decay + ) + scheduler = StepLR(optimizer, step_size=1, gamma=args.lr_decay) + st = 0 + if opts.resume: + st = checkpoint['epoch'] + if 'optimizer' in checkpoint and checkpoint['optimizer'] is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + else: + print('WARNING: this checkpoint does not contain an optimizer state. 
The optimizer will be reinitialized.') + lr = checkpoint['lr'] + if 'best_jpe' in checkpoint and checkpoint['best_jpe'] is not None: + best_jpe = checkpoint['best_jpe'] + + # Training + for epoch in range(st, args.epochs): + print('Training epoch %d.' % epoch) + losses_train = AverageMeter() + losses_dict = { + 'loss_3d_pos': AverageMeter(), + 'loss_3d_scale': AverageMeter(), + 'loss_3d_velocity': AverageMeter(), + 'loss_lv': AverageMeter(), + 'loss_lg': AverageMeter(), + 'loss_a': AverageMeter(), + 'loss_av': AverageMeter(), + 'loss_pose': AverageMeter(), + 'loss_shape': AverageMeter(), + 'loss_norm': AverageMeter(), + } + mpjpes = AverageMeter() + mpves = AverageMeter() + batch_time = AverageMeter() + data_time = AverageMeter() + + if hasattr(args, "dt_file_h36m") and epoch < args.warmup_h36m: + train_epoch(args, opts, model, train_loader, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch) + test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, test_losses_dict = validate(test_loader, model, criterion, 'h36m') + for k, v in test_losses_dict.items(): + train_writer.add_scalar('test_loss/'+k, v.avg, epoch + 1) + train_writer.add_scalar('test_loss', test_loss, epoch + 1) + train_writer.add_scalar('test_mpjpe', test_mpjpe, epoch + 1) + train_writer.add_scalar('test_pa_mpjpe', test_pa_mpjpe, epoch + 1) + train_writer.add_scalar('test_mpve', test_mpve, epoch + 1) + + if hasattr(args, "dt_file_coco") and epoch < args.warmup_coco: + train_epoch(args, opts, model, train_loader_coco, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch) + + if hasattr(args, "dt_file_pw3d"): + if args.train_pw3d: + train_epoch(args, opts, model, train_loader_pw3d, losses_train, losses_dict, mpjpes, mpves, criterion, optimizer, batch_time, data_time, epoch) + test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, test_losses_dict_pw3d = validate(test_loader_pw3d, model, criterion, 'pw3d') + for k, v in test_losses_dict_pw3d.items(): + train_writer.add_scalar('test_loss_pw3d/'+k, v.avg, epoch + 1) + train_writer.add_scalar('test_loss_pw3d', test_loss_pw3d, epoch + 1) + train_writer.add_scalar('test_mpjpe_pw3d', test_mpjpe_pw3d, epoch + 1) + train_writer.add_scalar('test_pa_mpjpe_pw3d', test_pa_mpjpe_pw3d, epoch + 1) + train_writer.add_scalar('test_mpve_pw3d', test_mpve_pw3d, epoch + 1) + + for k, v in losses_dict.items(): + train_writer.add_scalar('train_loss/'+k, v.avg, epoch + 1) + train_writer.add_scalar('train_loss', losses_train.avg, epoch + 1) + train_writer.add_scalar('train_mpjpe', mpjpes.avg, epoch + 1) + train_writer.add_scalar('train_mpve', mpves.avg, epoch + 1) + + # Decay learning rate exponentially + scheduler.step() + # Save latest checkpoint. + chk_path = os.path.join(opts.checkpoint, 'latest_epoch.bin') + print('Saving checkpoint to', chk_path) + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_jpe' : best_jpe + }, chk_path) + + # Save checkpoint if necessary. 
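+            # A periodic snapshot is also written every checkpoint_frequency
+            # epochs, alongside the rolling latest_epoch.bin saved above.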
+ if (epoch+1) % args.checkpoint_frequency == 0: + chk_path = os.path.join(opts.checkpoint, 'epoch_{}.bin'.format(epoch)) + print('Saving checkpoint to', chk_path) + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_jpe' : best_jpe + }, chk_path) + + if hasattr(args, "dt_file_pw3d"): + best_jpe_cur = test_mpjpe_pw3d + else: + best_jpe_cur = test_mpjpe + # Save best checkpoint. + best_chk_path = os.path.join(opts.checkpoint, 'best_epoch.bin'.format(epoch)) + if best_jpe_cur < best_jpe: + best_jpe = best_jpe_cur + print("save best checkpoint") + torch.save({ + 'epoch': epoch+1, + 'lr': scheduler.get_last_lr(), + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict(), + 'best_jpe' : best_jpe + }, best_chk_path) + + if opts.evaluate: + if hasattr(args, "dt_file_h36m"): + test_loss, test_mpjpe, test_pa_mpjpe, test_mpve, _ = validate(test_loader, model, criterion, 'h36m') + if hasattr(args, "dt_file_pw3d"): + test_loss_pw3d, test_mpjpe_pw3d, test_pa_mpjpe_pw3d, test_mpve_pw3d, _ = validate(test_loader_pw3d, model, criterion, 'pw3d') + +if __name__ == "__main__": + opts = parse_args() + set_random_seed(opts.seed) + args = get_config(opts.config) + train_with_config(args, opts) \ No newline at end of file