diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..1e344cbe6b61128fc25d39a484f466b934eb568d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +**/*.gif filter=lfs diff=lfs merge=lfs -text +**/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..74a3b92ba2e58b784915ce0133c8d8cbc5c39b07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +*.pyc +.idea +*ZZZ* +sandbox +__pycache__ +wandb +temp +code diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..01e1483b59134eab44449bafe21dfec846e66308 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "externals/camviz"] + path = externals/camviz + url = git@github.com:TRI-ML/camviz.git diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..f59548d1a37a9aeec57a98ee62a7244716eefb6c --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,409 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. 
Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. 
NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. 
Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. 
if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. 
+ + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. + + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..96585d88d394ab2aba5d315c50d4c41c09bfc0ac --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +PROJECT ?= vidar +WORKSPACE ?= /workspace/$(PROJECT) +DOCKER_IMAGE ?= ${PROJECT}:latest + +SHMSIZE ?= 444G +WANDB_MODE ?= run +DOCKER_OPTS := \ + --name ${PROJECT} \ + --rm -it \ + --shm-size=${SHMSIZE} \ + -e AWS_DEFAULT_REGION \ + -e AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY \ + -e WANDB_API_KEY \ + -e WANDB_ENTITY \ + -e WANDB_MODE \ + -e HOST_HOSTNAME= \ + -e OMP_NUM_THREADS=1 -e KMP_AFFINITY="granularity=fine,compact,1,0" \ + -e OMPI_ALLOW_RUN_AS_ROOT=1 \ + -e OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ + -e NCCL_DEBUG=VERSION \ + -e DISPLAY=${DISPLAY} \ + -e XAUTHORITY \ + -e NVIDIA_DRIVER_CAPABILITIES=all \ + -v ~/.aws:/root/.aws \ + -v /root/.ssh:/root/.ssh \ + -v ~/.cache:/root/.cache \ + -v /dev/null:/dev/raw1394 \ + -v /mnt/fsx/tmp:/tmp \ + -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -v /home/jiadingfang/datasets:/data \ + -v ${PWD}:${WORKSPACE} \ + -w ${WORKSPACE} \ + --privileged \ + --ipc=host \ + --network=host + +NGPUS=$(shell nvidia-smi -L | wc -l) + +all: clean + +clean: + find . -name "*.pyc" | xargs rm -f && \ + find . 
-name "__pycache__" | xargs rm -rf + +docker-build: + docker build \ + -f docker/Dockerfile \ + -t ${DOCKER_IMAGE} . + +docker-interactive: docker-build + nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash + +docker-run: docker-build + nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash -c "${COMMAND}" diff --git a/configs/overfit/ddad_tiny.yaml b/configs/overfit/ddad_tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..121e845fc91c8e033c957d73a9c2b6356f5b21fc --- /dev/null +++ b/configs/overfit/ddad_tiny.yaml @@ -0,0 +1,31 @@ +wrapper: + recipe: wrapper|default + max_epochs: 1 +arch: + model: + file: depth/SelfSupervisedModel + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + pose: + recipe: networks/pose_net|default + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|ddad_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + pose: + recipe: optimizers|adam_20_05 +datasets: + train: + recipe: datasets/ddad_tiny|train_selfsup_front + validation: + recipe: datasets/ddad_tiny|validation_front + diff --git a/configs/overfit/generic.yaml b/configs/overfit/generic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36be6b63bbeca731d631c8bad8fe49c16b36e3f2 --- /dev/null +++ b/configs/overfit/generic.yaml @@ -0,0 +1,27 @@ +wrapper: + recipe: wrapper|default + max_epochs: 1 +arch: + model: + file: depth/SelfSupervisedModel + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + pose: + recipe: networks/pose_net|default + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +optimizers: + depth: + recipe: optimizers|adam_20_05 + pose: + recipe: optimizers|adam_20_05 +datasets: + train: + recipe: datasets/generic|default_train + validation: + recipe: datasets/generic|default_validation diff --git a/configs/overfit/kitti_tiny.yaml b/configs/overfit/kitti_tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8810a5f7944502ca2e94e0f79b882bfb814af47 --- /dev/null +++ b/configs/overfit/kitti_tiny.yaml @@ -0,0 +1,31 @@ +wrapper: + recipe: wrapper|default + max_epochs: 1 +arch: + model: + file: depth/SelfSupervisedModel + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + pose: + recipe: networks/pose_net|default + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + pose: + recipe: optimizers|adam_20_05 +datasets: + train: + recipe: datasets/kitti_tiny|train_selfsup_mr + validation: + recipe: datasets/kitti_tiny|validation_mr + diff --git a/configs/papers/define/scannet_temporal_test.yaml b/configs/papers/define/scannet_temporal_test.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9338a7da0c89fe7dafb46fd11680e18bb58aca4f --- /dev/null +++ b/configs/papers/define/scannet_temporal_test.yaml @@ -0,0 +1,79 @@ +wrapper: + seed: 42 + min_epochs: 0 + max_epochs: 0 + find_unused_parameters: True + flip_lr_prob: 0.5 + validate_first: True + validate_flipped: False + sync_batch_norm: True +evaluation: + rgb: + only_first: False + depth: + crop: '' + min_depth: 0.2 + max_depth: 10.0 + scale_output: resize + median_scaling: True + 
post_process: False +arch: + model: + checkpoint: /data/vidar/models/scannet_full.ckpt + file: perceiver/DefineGenericModel + use_pose_noise: [] + use_virtual_cameras: [] + virtual_cameras_eval: False + use_virtual_rgb: False + augment_canonical: False + encode_train: all + encode_eval: all + decode_train: all + decode_eval: all + decode_encodes: False + task_weights: [1.0,1.0] + scale_loss: True + sample_decoded_queries: 0.0 + network: + tasks: [depth] + depth_range: [0.1,10.] + image_shape: [128,192] + num_bands_orig: 20 + num_bands_dirs: 10 + max_resolution_orig: 60 + max_resolution_dirs: 60 + to_world: True + d_latents: 512 + num_latents: 1024 + hidden_dropout_prob: 0.25 + num_cross_attention_heads: 1 + num_self_attends_per_block: 8 + num_self_attention_heads: 8 + decoder_num_heads: 1 + rgb_feat_dim: 960 + rgb_feat_type: resnet_all + downsample_encoder: 4 + downsample_decoder: 4 + upsample_convex: 'convex' + encoder_with_rgb: True + decoder_with_rgb: False + output_mode: inv_depth + sample_encoding_rays: 0 + with_monodepth: False +datasets: + validation: + name: [ScanNetTemporal] + path: [/data/vidar/scannet_processed] + split: [test] + context: [-2,2] + stride: [10] + labels: [depth,pose] + cameras: [[0]] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 4 + augmentation: + resize: [128,192] + resize_supervision: True + preserve_depth: True diff --git a/configs/papers/define/scannet_temporal_test_context_1.yaml b/configs/papers/define/scannet_temporal_test_context_1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..67d4544a838105f58fc3e1557fcbf6a5fa3c6246 --- /dev/null +++ b/configs/papers/define/scannet_temporal_test_context_1.yaml @@ -0,0 +1,79 @@ +wrapper: + seed: 42 + min_epochs: 0 + max_epochs: 0 + find_unused_parameters: True + flip_lr_prob: 0.5 + validate_first: True + validate_flipped: False + sync_batch_norm: True +evaluation: + rgb: + only_first: False + depth: + crop: '' + min_depth: 0.2 + max_depth: 10.0 + scale_output: resize + median_scaling: True + post_process: False +arch: + model: + checkpoint: /data/vidar/models/scannet_full.ckpt + file: perceiver/DefineGenericModel + use_pose_noise: [] + use_virtual_cameras: [] + virtual_cameras_eval: False + use_virtual_rgb: False + augment_canonical: False + encode_train: all + encode_eval: all + decode_train: all + decode_eval: all + decode_encodes: False + task_weights: [1.0,1.0] + scale_loss: True + sample_decoded_queries: 0.0 + network: + tasks: [depth] + depth_range: [0.1,10.] 
+ image_shape: [128,192] + num_bands_orig: 20 + num_bands_dirs: 10 + max_resolution_orig: 60 + max_resolution_dirs: 60 + to_world: True + d_latents: 512 + num_latents: 1024 + hidden_dropout_prob: 0.25 + num_cross_attention_heads: 1 + num_self_attends_per_block: 8 + num_self_attention_heads: 8 + decoder_num_heads: 1 + rgb_feat_dim: 960 + rgb_feat_type: resnet_all + downsample_encoder: 4 + downsample_decoder: 4 + upsample_convex: 'convex' + encoder_with_rgb: True + decoder_with_rgb: False + output_mode: inv_depth + sample_encoding_rays: 0 + with_monodepth: False +datasets: + validation: + name: [ScanNetTemporal] + path: [/data/vidar/scannet_processed] + split: [test] + context: [1] + stride: [10] + labels: [depth,pose] + cameras: [[0]] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 4 + augmentation: + resize: [128,192] + resize_supervision: True + preserve_depth: True diff --git a/configs/papers/depthformer/inference_kitti.yaml b/configs/papers/depthformer/inference_kitti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0b8ec992b1b4508ec2f39f00cdfa452b9852fc8 --- /dev/null +++ b/configs/papers/depthformer/inference_kitti.yaml @@ -0,0 +1,31 @@ +wrapper: + recipe: wrapper|default +arch: + model: + file: depth/DepthFormerModel + checkpoint: /data/vidar/models/papers/final/DepthFormer_MR_selfsup_KITTI.ckpt + warp_context: [-1,1] + match_context: [-1] + motion_masking: True + matching_augmentation: False + freeze_teacher_and_pose: 40 + networks: + transformer: + recipe: networks/transformer|depthformer + mono_depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + multi_depth: + recipe: networks/multi_depth_res_net|depthformer + pose: + recipe: networks/pose_net|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +datasets: + validation: + recipe: datasets/kitti|validation_mr + labels: [depth,pose] + context: [-1,1] +save: + recipe: save|depth_splitname diff --git a/configs/papers/fsm/inference_ddad.yaml b/configs/papers/fsm/inference_ddad.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9eb1b4e943d85cf653d1764e6e127030bccd6ae8 --- /dev/null +++ b/configs/papers/fsm/inference_ddad.yaml @@ -0,0 +1,19 @@ +wrapper: + recipe: wrapper|default +arch: + model: + file: depth/FSMModel + checkpoint: /data/vidar/models/papers/final/FSM_MR_6cams_DDAD.ckpt + networks: + depth: + recipe: networks/focal_depth_res_net|fsm_ddad + pose: + recipe: networks/conv_pose_net|default +evaluation: + depth: + recipe: evaluation/depth|ddad_resize +datasets: + validation: + recipe: datasets/ddad|validation_6cams +save: + recipe: save|depth_splitname diff --git a/configs/papers/packnet/inference_packnet.yaml b/configs/papers/packnet/inference_packnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ec7bbe8d096bdd50c712eaa86fd0e090760b2e8 --- /dev/null +++ b/configs/papers/packnet/inference_packnet.yaml @@ -0,0 +1,20 @@ +wrapper: + recipe: wrapper|default +arch: + model: + file: depth/SelfSupervisedModel + checkpoint: /data/vidar/models/papers/final/PackNet_MR_selfsup_KITTI.ckpt + networks: + depth: + recipe: networks/packnet|default + depth_range: [0.1,100.0] + pose: + recipe: networks/conv_pose_net|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +datasets: + validation: + recipe: datasets/kitti|validation_mr +save: + recipe: save|depth_splitname diff --git a/configs/papers/packnet/inference_resnet.yaml b/configs/papers/packnet/inference_resnet.yaml new file mode 
100644 index 0000000000000000000000000000000000000000..4a0bba4dacf818fe9182cef1b2dffdd5cfa9db1f --- /dev/null +++ b/configs/papers/packnet/inference_resnet.yaml @@ -0,0 +1,20 @@ +wrapper: + recipe: wrapper|default +arch: + model: + file: depth/SelfSupervisedModel + checkpoint: /data/vidar/models/papers/final/ResNet18_MR_selfsup_KITTI.ckpt + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + pose: + recipe: networks/pose_net|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +datasets: + validation: + recipe: datasets/kitti|validation_mr +save: + recipe: save|depth_splitname diff --git a/configs/papers/selfcalib/ds_euroc.yaml b/configs/papers/selfcalib/ds_euroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63664c7229c50aadfc548d9626103c345054110e --- /dev/null +++ b/configs/papers/selfcalib/ds_euroc.yaml @@ -0,0 +1,40 @@ +wrapper: + recipe: wrapper|default + validate_first: False +arch: + model: + file: depth/SelfSupervisedModel + use_gt_intrinsics: False + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,80.0] + pose: + recipe: networks/pose_net|default + intrinsics: + file: intrinsics/IntrinsicsNet + camera_model: 'DS' + shape: [256, 384] + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + lr: 0.0002 + pose: + recipe: optimizers|adam_20_05 + lr: 0.0002 + intrinsics: + recipe: optimizers|adam_20_05 + lr: 0.01 +datasets: + train: + recipe: datasets/euroc|train_selfsup + validation: + recipe: datasets/euroc|validation diff --git a/configs/papers/selfcalib/eucm_euroc.yaml b/configs/papers/selfcalib/eucm_euroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b4afceb737c59f9b9374475965f4ebd7c8ea844 --- /dev/null +++ b/configs/papers/selfcalib/eucm_euroc.yaml @@ -0,0 +1,40 @@ +wrapper: + recipe: wrapper|default + validate_first: False +arch: + model: + file: depth/SelfSupervisedModel + use_gt_intrinsics: False + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,80.0] + pose: + recipe: networks/pose_net|default + intrinsics: + file: intrinsics/IntrinsicsNet + camera_model: 'EUCM' + shape: [256, 384] + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + lr: 0.0002 + pose: + recipe: optimizers|adam_20_05 + lr: 0.0002 + intrinsics: + recipe: optimizers|adam_20_05 + lr: 0.01 +datasets: + train: + recipe: datasets/euroc|train_selfsup + validation: + recipe: datasets/euroc|validation diff --git a/configs/papers/selfcalib/ucm_euroc.yaml b/configs/papers/selfcalib/ucm_euroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8e2edc37ce9a96d1fbb5594c2a0cb0e4da84684 --- /dev/null +++ b/configs/papers/selfcalib/ucm_euroc.yaml @@ -0,0 +1,40 @@ +wrapper: + recipe: wrapper|default + validate_first: False +arch: + model: + file: depth/SelfSupervisedModel + use_gt_intrinsics: False + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,80.0] + pose: + recipe: networks/pose_net|default + intrinsics: + file: intrinsics/IntrinsicsNet + camera_model: 'UCM' + shape: [256, 384] + losses: + reprojection: + recipe: 
losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + lr: 0.0002 + pose: + recipe: optimizers|adam_20_05 + lr: 0.0002 + intrinsics: + recipe: optimizers|adam_20_05 + lr: 0.01 +datasets: + train: + recipe: datasets/euroc|train_selfsup + validation: + recipe: datasets/euroc|validation diff --git a/configs/papers/selfcalib/ucm_kitti.yaml b/configs/papers/selfcalib/ucm_kitti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c47d53942450bf5aae7dd0ff7a01c9bb07f95c5 --- /dev/null +++ b/configs/papers/selfcalib/ucm_kitti.yaml @@ -0,0 +1,44 @@ +wrapper: + recipe: wrapper|default + validate_first: False +arch: + model: + file: depth/SelfSupervisedModel + use_gt_intrinsics: False + networks: + depth: + recipe: networks/mono_depth_res_net|default + depth_range: [0.1,100.0] + pose: + recipe: networks/pose_net|default + intrinsics: + file: intrinsics/IntrinsicsNet + camera_model: 'UCM' + shape: [192, 640] + losses: + reprojection: + recipe: losses/reprojection|default + smoothness: + recipe: losses/smoothness|default +evaluation: + depth: + recipe: evaluation/depth|kitti_resize +optimizers: + depth: + recipe: optimizers|adam_20_05 + lr: 0.0002 + pose: + recipe: optimizers|adam_20_05 + lr: 0.0002 + intrinsics: + recipe: optimizers|adam_20_05 + lr: 0.01 +datasets: + train: + recipe: datasets/kitti|train_selfsup_mr + dataloader: + batch_size: 8 + validation: + recipe: datasets/kitti|validation_mr + dataloader: + batch_size: 1 diff --git a/configs/recipes/checkpoint.yaml b/configs/recipes/checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffcb92e67db721cbbc10bc72d1ac01048e30cbf3 --- /dev/null +++ b/configs/recipes/checkpoint.yaml @@ -0,0 +1,8 @@ +default: + folder: /data/vidar/checkpoints + s3_bucket: miru-us-east-1/vidar/checkpoints + save_code: True + keep_top: 5 +default_local: + folder: /data/vidar/checkpoints + keep_top: 5 diff --git a/configs/recipes/datasets/ddad.yaml b/configs/recipes/datasets/ddad.yaml new file mode 100644 index 0000000000000000000000000000000000000000..583ebfbc3e1d4d1bc5d1ebce4bfc7d046950f76b --- /dev/null +++ b/configs/recipes/datasets/ddad.yaml @@ -0,0 +1,32 @@ +train_selfsup_6cams: + name: [Ouroboros] + path: [/data/vidar/DDAD/ddad.json] + split: [train] + augmentation: + jittering: [0.2, 0.2, 0.2, 0.05] + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [-1,1] + labels: [depth] + cameras: [[1],[5],[6],[7],[8],[9]] + depth_type: [lidar] + repeat: [100] +validation_6cams: + name: [Ouroboros] + path: [/data/vidar/DDAD/ddad.json] + split: [val] + augmentation: + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [] + labels: [depth] + cameras: [[1],[5],[6],[7],[8],[9]] + depth_type: [lidar] + + diff --git a/configs/recipes/datasets/ddad_tiny.yaml b/configs/recipes/datasets/ddad_tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08bbc9263e1ec555736a7a8266f3e879d81f61f5 --- /dev/null +++ b/configs/recipes/datasets/ddad_tiny.yaml @@ -0,0 +1,32 @@ +train_selfsup_front: + name: [Ouroboros] + path: [/data/vidar/DDAD_tiny/ddad_tiny.json] + split: [train] + augmentation: + jittering: [0.2, 0.2, 0.2, 0.05] + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [-1,1] + labels: [depth] + cameras: [[1]] + depth_type: [lidar] + 
repeat: [100] +validation_front: + name: [Ouroboros] + path: [/data/vidar/DDAD_tiny/ddad_tiny.json] + split: [train] + augmentation: + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [] + labels: [depth] + cameras: [[1]] + depth_type: [lidar] + + diff --git a/configs/recipes/datasets/euroc.yaml b/configs/recipes/datasets/euroc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..02949ad0297a2961ce5fe9f80093f1964db9d8eb --- /dev/null +++ b/configs/recipes/datasets/euroc.yaml @@ -0,0 +1,30 @@ +train_selfsup: + name: [EUROC] + path: [/data/vidar/euroc/euroc_cam/cam0] + augmentation: + jittering: [0.2, 0.2, 0.2, 0.05] + resize: [256, 384] + dataloader: + batch_size: 16 + pin_memory: True + num_workers: 16 + context: [-1, 1] + strides: [[49999872, 50000128]] + cameras: [[0]] + split: [euroc-train] + labels: [] + repeat: [1] +validation: + name: [EUROC] + path: [/data/vidar/euroc/euroc_has_depth/V2_01_easy_has_depth/mav0/] + augmentation: + resize: [256, 384] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + split: [euroc-val] + labels: [depth] + context: [] + strides: [[0, 0]] + cameras: [[0]] diff --git a/configs/recipes/datasets/generic.yaml b/configs/recipes/datasets/generic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48bd3864c6a18e7e57f0f7dbe063dcfb83f0228f --- /dev/null +++ b/configs/recipes/datasets/generic.yaml @@ -0,0 +1,30 @@ +default_train: + name: [Generic] + path: [data/generic] + split: [''] + context: [-1,1] + cameras: [[0]] + labels: [] + extension: [jpg] + repeat: [100] + augmentation: + jittering: [0.2, 0.2, 0.2, 0.05] + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 +default_validation: + name: [Generic] + path: [data/generic] + split: [''] + context: [] + cameras: [[0]] + labels: [] + extension: [jpg] + augmentation: + resize: [384, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 diff --git a/configs/recipes/datasets/kitti.yaml b/configs/recipes/datasets/kitti.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d481a4a8289661b6337b9ea67d541c225f8d7b9 --- /dev/null +++ b/configs/recipes/datasets/kitti.yaml @@ -0,0 +1,31 @@ +train_selfsup_mr: + name: [KITTI] + path: [/data/vidar/KITTI_raw] + split: [data_splits/eigen_zhou_files.txt] + augmentation: + jittering: [0.2, 0.2, 0.2, 0.05] + resize: [192, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [-1,1] + labels: [] + cameras: [[0]] + single_intrinsics: [True] + repeat: [1] +validation_mr: + name: [KITTI] + path: [/data/vidar/KITTI_raw] + split: [data_splits/eigen_test_files.txt] + augmentation: + resize: [192, 640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [] + labels: [depth] + cameras: [[0]] + single_intrinsics: [True] + depth_type: [velodyne,groundtruth] diff --git a/configs/recipes/datasets/kitti_tiny.yaml b/configs/recipes/datasets/kitti_tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0f547ec63eb3fb05f1a5a46f8aec10963634b38 --- /dev/null +++ b/configs/recipes/datasets/kitti_tiny.yaml @@ -0,0 +1,32 @@ +train_selfsup_mr: + name: [KITTI] + path: [/data/vidar/KITTI_tiny] + split: [kitti_tiny.txt] + augmentation: + jittering: [0.2,0.2,0.2,0.05] + resize: [192,640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [-1,1] + labels: [] + cameras: [[0]] + single_intrinsics: [True] + repeat: [100] + 
depth_type: [velodyne] +validation_mr: + name: [KITTI] + path: [/data/vidar/KITTI_tiny] + split: [kitti_tiny.txt] + augmentation: + resize: [192,640] + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [] + labels: [depth] + cameras: [[0]] + single_intrinsics: [True] + depth_type: [velodyne] diff --git a/configs/recipes/datasets/vkitti2.yaml b/configs/recipes/datasets/vkitti2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d3dca65f125b41ddfa01409936fd2c947e6baba9 --- /dev/null +++ b/configs/recipes/datasets/vkitti2.yaml @@ -0,0 +1,16 @@ +tiny: + name: [VKITTI2] + path: [/data/vidar/VKITTI2_tiny] + split: [''] + augmentation: + resize: [192, 640] + resize_supervision: True + dataloader: + batch_size: 1 + pin_memory: True + num_workers: 16 + context: [-1,1] + labels: [depth,pose] + labels_context: [depth,pose] + cameras: [[0]] + depth_type: [zbuffer] \ No newline at end of file diff --git a/configs/recipes/evaluation/depth.yaml b/configs/recipes/evaluation/depth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bf7ced1473d04ab3c4a6c3ae41b6b2b4c59d5004 --- /dev/null +++ b/configs/recipes/evaluation/depth.yaml @@ -0,0 +1,21 @@ +kitti_resize: + crop: garg + min_depth: 0.0 + max_depth: 80.0 + scale_output: resize + median_scaling: True + post_process: True +kitti_crop: + crop: garg + min_depth: 0.0 + max_depth: 80.0 + scale_output: top-center + median_scaling: True + post_process: True +ddad_resize: + crop: '' + min_depth: 0.0 + max_depth: 200.0 + scale_output: resize + median_scaling: True + post_process: True diff --git a/configs/recipes/losses/reprojection.yaml b/configs/recipes/losses/reprojection.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9206e53f8593d257e761df0bb8174e348dff974 --- /dev/null +++ b/configs/recipes/losses/reprojection.yaml @@ -0,0 +1,9 @@ +default: + file: ReprojectionLoss + automasking: True + reprojection_reduce_op: min + jitter_identity_reprojection: 0.00001 + photometric: + file: PhotometricLoss + weight: 1.0 + alpha: 0.85 \ No newline at end of file diff --git a/configs/recipes/losses/smoothness.yaml b/configs/recipes/losses/smoothness.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9be726f317d6e03af9464fabba8ba38a1da1554 --- /dev/null +++ b/configs/recipes/losses/smoothness.yaml @@ -0,0 +1,5 @@ +default: + file: SmoothnessLoss + normalize: True + weight: 0.0001 + gamma: 0.5 diff --git a/configs/recipes/losses/supervised_depth.yaml b/configs/recipes/losses/supervised_depth.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2736b71fb79ddd3c5f3504e3678e98e58747654 --- /dev/null +++ b/configs/recipes/losses/supervised_depth.yaml @@ -0,0 +1,40 @@ +huber: + file: SupervisedDepthLoss + gamma: 0.5 + method: huber + mask_sparse: True +abs_rel: + file: SupervisedDepthLoss + gamma: 0.5 + method: abs_rel + mask_sparse: True +l1log: + file: SupervisedDepthLoss + gamma: 0.5 + method: l1log + mask_sparse: True +l1: + file: SupervisedDepthLoss + gamma: 0.5 + method: l1 + mask_sparse: True +mse: + file: SupervisedDepthLoss + gamma: 0.5 + method: mse + mask_sparse: True +rmse: + file: SupervisedDepthLoss + gamma: 0.5 + method: rmse + mask_sparse: True +mixture: + file: SupervisedDepthLoss + gamma: 0.5 + method: mixture + mask_sparse: True +cross_entropy: + file: SupervisedDepthLoss + gamma: 0.5 + method: cross_entropy + mask_sparse: True \ No newline at end of file diff --git a/configs/recipes/networks/conv_pose_net.yaml 
b/configs/recipes/networks/conv_pose_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a0a8e0252cce457f6ab422b538d3ff9c63e2ab --- /dev/null +++ b/configs/recipes/networks/conv_pose_net.yaml @@ -0,0 +1,5 @@ +default: + file: pose/ConvPoseNet + version: 18 + pretrained: True + num_rgb_in: 2 \ No newline at end of file diff --git a/configs/recipes/networks/focal_depth_res_net.yaml b/configs/recipes/networks/focal_depth_res_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01f37b3d3be58f291fcbf3a4402b91584b45841e --- /dev/null +++ b/configs/recipes/networks/focal_depth_res_net.yaml @@ -0,0 +1,12 @@ +fsm_ddad: + file: depth/FocalDepthResNet + encoder: + version: 18 + pretrained: True + num_rgb_in: 1 + decoder: + use_skips: True + activation: sigmoid + num_ch_out: 1 + min_depth: 0.005 + max_depth: 0.3 diff --git a/configs/recipes/networks/mono_depth_res_net.yaml b/configs/recipes/networks/mono_depth_res_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ea5c0f134d394c739040112bc245f7060076594 --- /dev/null +++ b/configs/recipes/networks/mono_depth_res_net.yaml @@ -0,0 +1,10 @@ +default: + file: depth/MonoDepthResNet + encoder: + version: 18 + pretrained: True + num_rgb_in: 1 + decoder: + use_skips: True + activation: sigmoid + num_ch_out: 1 diff --git a/configs/recipes/networks/multi_depth_res_net.yaml b/configs/recipes/networks/multi_depth_res_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5704bb61c6eb44367c354a44b1749ed8cee5d4a6 --- /dev/null +++ b/configs/recipes/networks/multi_depth_res_net.yaml @@ -0,0 +1,17 @@ +depthformer: + file: depth/MultiDepthResNet + encoder: + version: 18 + pretrained: True + input_shape: [192,640] + adaptive_bins: True + depth_range: [0.1,100.] + depth_bin_range: [0.1,100.0] + depth_binning: sid + num_depth_bins: 128 + decoder: + use_skips: True + use_aux_depth: False + activation: sigmoid + num_ch_out: 1 + depth_range: [0.1,100.] 
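As a point of reference for how these network recipes are consumed, below is a minimal sketch that builds the monocular depth network from an encoder/decoder config with the same keys as the `default` entry in `mono_depth_res_net.yaml` above. It mirrors `demos/run_network/run_network.py` further down in this diff; loading a flat config file is assumed here, since the `recipe: networks/...|default` indirection used by the training configs is resolved by machinery not shown in this diff.

```python
# Sketch only: mirrors demos/run_network/run_network.py; the flat config file and the
# KITTI-sized dummy input are assumptions, not part of the recipe files above.
import torch

from vidar.arch.networks.depth.MonoDepthResNet import MonoDepthResNet
from vidar.utils.config import read_config

cfg = read_config('demos/run_network/config.yaml')   # same encoder/decoder keys as 'default' above
net = MonoDepthResNet(cfg)

rgb = torch.randn((1, 3, 192, 640))                  # dummy batch at the KITTI resolution used above
depth = net(rgb=rgb)['depths']                       # depth predictions, as in the demo script
```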
diff --git a/configs/recipes/networks/packnet.yaml b/configs/recipes/networks/packnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a8229c10e8fb31c8fe411105845054a2e32ca16 --- /dev/null +++ b/configs/recipes/networks/packnet.yaml @@ -0,0 +1,10 @@ +default: + file: depth/PackNet + encoder: + version: 18 + pretrained: True + num_rgb_in: 1 + decoder: + use_skips: True + activation: sigmoid + num_ch_out: 1 diff --git a/configs/recipes/networks/pose_net.yaml b/configs/recipes/networks/pose_net.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c180cea4674b2d5af7c49b6b524bfb9403b2ed66 --- /dev/null +++ b/configs/recipes/networks/pose_net.yaml @@ -0,0 +1,5 @@ +default: + file: pose/PoseNet + version: 18 + pretrained: True + num_rgb_in: 2 \ No newline at end of file diff --git a/configs/recipes/networks/transformer.yaml b/configs/recipes/networks/transformer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd0da952c5d19cc2bd003dfde69fb3a776613e29 --- /dev/null +++ b/configs/recipes/networks/transformer.yaml @@ -0,0 +1,12 @@ +depthformer: + file: transformers/MatchModule + context_adjustment: + expansion_ratio: 4 + feat_dim: 16 + num_blocks: 8 + channel_dim: 128 + nheads: 8 + num_attn_layers: 6 + min_depth: 0.1 + max_depth: 100.0 + num_bins: 128 \ No newline at end of file diff --git a/configs/recipes/optimizers.yaml b/configs/recipes/optimizers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6fb5ebd1ec4652f818a5647e3261dd8a1a55f99b --- /dev/null +++ b/configs/recipes/optimizers.yaml @@ -0,0 +1,7 @@ +adam_20_05: + name: Adam + lr: 0.0001 + scheduler: + name: StepLR + step_size: 20 + gamma: 0.5 diff --git a/configs/recipes/save.yaml b/configs/recipes/save.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a90afd4e668fb10ae2bb83de5f0e4e4a9bc6d74 --- /dev/null +++ b/configs/recipes/save.yaml @@ -0,0 +1,5 @@ +depth_splitname: + folder: /data/vidar/save + naming: splitname + rgb: [tgt] + depth: [viz,npz] \ No newline at end of file diff --git a/configs/recipes/wandb.yaml b/configs/recipes/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..759968b204b14882a9631665b7d4b18ec1d9c34e --- /dev/null +++ b/configs/recipes/wandb.yaml @@ -0,0 +1,5 @@ +default: + folder: /data/vidar/wandb + entity: tri + project: vidar + num_validation_logs: 5 diff --git a/configs/recipes/wrapper.yaml b/configs/recipes/wrapper.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7293836d994ee5e8af4c593df953cc4386353278 --- /dev/null +++ b/configs/recipes/wrapper.yaml @@ -0,0 +1,10 @@ +default: + seed: 42 + min_epochs: 0 + max_epochs: 50 + grad_scaler: False + find_unused_parameters: True + flip_lr_prob: 0.5 + validate_first: True + validate_flipped: True + sync_batch_norm: True diff --git a/data/generic/scene1/15616458296936490.jpg b/data/generic/scene1/15616458296936490.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a9b8a988f7fea63f4ffbb6088abde4e664f1d8f3 Binary files /dev/null and b/data/generic/scene1/15616458296936490.jpg differ diff --git a/data/generic/scene1/15616458297936490.jpg b/data/generic/scene1/15616458297936490.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb0177112b0b33f74199eb742872ddc1498e366e Binary files /dev/null and b/data/generic/scene1/15616458297936490.jpg differ diff --git a/data/generic/scene1/15616458298936492.jpg b/data/generic/scene1/15616458298936492.jpg new file mode 
100644 index 0000000000000000000000000000000000000000..17b4ef41aa30deac5c330b0bd076f2e0fba22e81 Binary files /dev/null and b/data/generic/scene1/15616458298936492.jpg differ diff --git a/data/generic/scene1/15616458299936482.jpg b/data/generic/scene1/15616458299936482.jpg new file mode 100644 index 0000000000000000000000000000000000000000..32fb198b40b9da2268b44ef815601992a16a4dbb Binary files /dev/null and b/data/generic/scene1/15616458299936482.jpg differ diff --git a/data/generic/scene1/15616458300936480.jpg b/data/generic/scene1/15616458300936480.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d225cb4f6abec315cc5b4e3354881ecf3e11a808 Binary files /dev/null and b/data/generic/scene1/15616458300936480.jpg differ diff --git a/data/masks/ddad/01.png b/data/masks/ddad/01.png new file mode 100755 index 0000000000000000000000000000000000000000..5c6a1bfcff7e67517af1b44942622ce8e75b9b14 --- /dev/null +++ b/data/masks/ddad/01.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04c221df55ea2cb011114b9418a114dbb9017c2090405344df505114ea1865bf +size 4943 diff --git a/data/masks/ddad/05.png b/data/masks/ddad/05.png new file mode 100755 index 0000000000000000000000000000000000000000..23d82588e15fe7b0119f5a2d564ef9a2604ccf26 --- /dev/null +++ b/data/masks/ddad/05.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aaa8817848b87232baf54a1c78fda5db0b91c0a2439be666c6243cdbc96a8bc8 +size 8006 diff --git a/data/masks/ddad/06.png b/data/masks/ddad/06.png new file mode 100755 index 0000000000000000000000000000000000000000..bd7a27a416acfaf90133e46d8ddb97daecd2188f --- /dev/null +++ b/data/masks/ddad/06.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9466db07da39e90f1f92a01ce53331a1f7badf3ddd56ea05f9aa564cccdedba0 +size 5831 diff --git a/data/masks/ddad/07.png b/data/masks/ddad/07.png new file mode 100755 index 0000000000000000000000000000000000000000..89498a2b4b0317388313d0fa7a9eb71d6f24384e --- /dev/null +++ b/data/masks/ddad/07.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b99acc91e02eeb16e931e628c63a556d9580a134fab7556ae09787027ab9e22 +size 6886 diff --git a/data/masks/ddad/08.png b/data/masks/ddad/08.png new file mode 100755 index 0000000000000000000000000000000000000000..25490ecd2e440d292a13ea8651409b9aa75a63cb --- /dev/null +++ b/data/masks/ddad/08.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea91d428b0562e1592077b734277fb8634d99c66a123ef33a6ac52a9724a5aca +size 7199 diff --git a/data/masks/ddad/09.png b/data/masks/ddad/09.png new file mode 100755 index 0000000000000000000000000000000000000000..39c0af2a34d2349f084d63e9b6343100c36c7ba3 --- /dev/null +++ b/data/masks/ddad/09.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e48cffda84af21884c0f4e1e1965d4a70e9724fd895a8e936bf16ffded90a5c +size 6517 diff --git a/demos/display_datasets/config.yaml b/demos/display_datasets/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..606012e38377ceef980a46bc26e09111b557b6ef --- /dev/null +++ b/demos/display_datasets/config.yaml @@ -0,0 +1,34 @@ + +# To download datasets use the following command: +# wget https://tri-ml-public.s3.amazonaws.com/github/vidar/datasets/{DATASET}.tar /data/vidar +# Don't forget to untar it afterwards, with: +# tar xvf /data/vidar/{DATASET}.tar -C /data/vidar + +datasets: + kitti: + name: [KITTI] + path: [/data/vidar/KITTI_tiny] + split: [kitti_tiny.txt] + context: [-1,3] + 
cameras: [[0,1]] + labels: [depth,pose] + labels_context: [depth, pose] + depth_type: [velodyne] + vkitti2: + name: [VKITTI2] + path: [/data/vidar/VKITTI2_tiny] + split: [train] + context: [-2,2] + cameras: [[0,1]] + labels: [depth,pose,optical_flow] + labels_context: [depth,pose,optical_flow] + ddad: + name: [Ouroboros] + path: [/data/vidar/DDAD_tiny/ddad_tiny.json] + split: [train] + context: [-1,3] + cameras: [[1,5,6,7,8,9]] + labels: [depth,pose] + labels_context: [depth,pose] + depth_type: [lidar] + virtual: [False] \ No newline at end of file diff --git a/demos/display_datasets/display_datasets.py b/demos/display_datasets/display_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..86fdf6f6191acbe91f8ff8938fcacd776f10e45b --- /dev/null +++ b/demos/display_datasets/display_datasets.py @@ -0,0 +1,14 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import sys + +from display.display_sample import display_sample +from vidar.utils.config import read_config +from vidar.utils.data import set_random_seed +from vidar.utils.setup import setup_datasets + +set_random_seed(42) + +cfg = read_config('demos/display_datasets/config.yaml') +datasets = setup_datasets(cfg.datasets, stack=False) +display_sample(datasets[0][sys.argv[1]][0][0], flip=False) diff --git a/demos/run_network/config.yaml b/demos/run_network/config.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d865f6752785a0c93360f33343c0f8a073fb82ed --- /dev/null +++ b/demos/run_network/config.yaml @@ -0,0 +1,10 @@ +encoder: + version: 18 + pretrained: True + num_rgb_in: 1 +decoder: + use_skips: True + activation: sigmoid + num_ch_out: 1 +depth_range: [0.1,200.] + diff --git a/demos/run_network/run_network.py b/demos/run_network/run_network.py new file mode 100755 index 0000000000000000000000000000000000000000..cfcef146ed3229c6e2f52b8cda12dadea761bb68 --- /dev/null +++ b/demos/run_network/run_network.py @@ -0,0 +1,16 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch + +from vidar.arch.networks.depth.MonoDepthResNet import MonoDepthResNet +from vidar.utils.config import read_config + +### Create network + +cfg = read_config('demos/run_network/config.yaml') +net = MonoDepthResNet(cfg) + +### Create dummy input and run network + +rgb = torch.randn((2, 3, 128, 128)) +depth = net(rgb=rgb)['depths'] diff --git a/display/display_sample.py b/display/display_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..c544f21d486a50c9186b16a9deda3552e72fe048 --- /dev/null +++ b/display/display_sample.py @@ -0,0 +1,130 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
+ +import json + +import numpy as np +from camviz import BBox3D +from camviz import Camera as CameraCV +from camviz import Draw + +from vidar.geometry.camera import Camera +from vidar.geometry.pose import Pose +from vidar.utils.data import make_batch, fold_batch, modrem +from vidar.utils.flip import flip_batch +from vidar.utils.viz import viz_depth, viz_optical_flow, viz_semantic + + +def change_key(dic, c, n): + steps = sorted(dic.keys()) + return steps[(steps.index(c) + n) % len(steps)] + + +def display_sample(data, flip=False): + + tasks = ['rgb', 'depth', 'fwd_optical_flow', 'bwd_optical_flow','semantic'] + cam_colors = ['red', 'blu', 'gre', 'yel', 'mag', 'cya'] * 100 + + data = make_batch(data) + if flip: + data = flip_batch(data) + data = fold_batch(data) + + rgb = data['rgb'] + intrinsics = data['intrinsics'] + depth = data['depth'] + pose = data['pose'] + + pose = Pose.from_dict(pose, to_global=True) + cam = Camera.from_dict(intrinsics, rgb, pose) + + num_cams = rgb[0].shape[0] + wh = rgb[0].shape[-2:][::-1] + + keys = [key for key in tasks if key in data.keys()] + + points = {} + for key, val in cam.items(): + points[key] = cam[key].reconstruct_depth_map( + depth[key], to_world=True).reshape(num_cams, 3, -1).permute(0, 2, 1) + + draw = Draw((wh[0] * 4, wh[1] * 3), width=2100) + draw.add2DimageGrid('cam', (0.0, 0.0, 0.5, 1.0), n=(3, 2), res=wh) + draw.add3Dworld('wld', (0.5, 0.0, 1.0, 1.0), pose=cam[0].Tcw.T[0]) + + draw.addTexture('cam', n=num_cams) + draw.addBuffer3f('lidar', 1000000, n=num_cams) + draw.addBuffer3f('color', 1000000, n=num_cams) + + with_bbox3d = 'bbox3d' in data + if with_bbox3d: + bbox3d_corners = [[BBox3D(b) for b in bb] for bb in data['bbox3d']['corners']] + + with_pointcache = 'pointcache' in data + if with_pointcache: + pointcache = np.concatenate([np.concatenate(pp, 0) for pp in data['pointcache']['points']], 0) + draw.addBufferf('pointcache', pointcache[:, :3]) + + camcv = [] + for i in range(num_cams): + camcv.append({key: CameraCV.from_vidar(val, i) for key, val in cam.items()}) + + t, k = 0, 0 + key = keys[k] + change = True + color = True + + while draw.input(): + if draw.SPACE: + color = not color + change = True + if draw.RIGHT: + change = True + k = (k + 1) % len(keys) + while t not in data[keys[k]].keys(): + k = (k + 1) % len(keys) + key = keys[k] + if draw.LEFT: + change = True + k = (k - 1) % len(keys) + while t not in data[keys[k]].keys(): + k = (k - 1) % len(keys) + key = keys[k] + if draw.UP: + change = True + t = change_key(data[key], t, 1) + while t not in data[keys[k]].keys(): + t = change_key(data[key], t, 1) + if draw.DOWN: + change = True + t = change_key(data[key], t, -1) + while t not in data[keys[k]].keys(): + t = change_key(data[key], t, -1) + if change: + change = False + for i in range(num_cams): + img = data[key][t][i] + if key == 'depth': + img = viz_depth(img, filter_zeros=True) + elif key in ['fwd_optical_flow', 'bwd_optical_flow']: + img = viz_optical_flow(img) + elif key == 'semantic': + ontology = json.load(open('vidar/datasets/ontologies/%s.json' % data['tag'][0])) + img = viz_semantic(img, ontology) + draw.updTexture('cam%d' % i, img) + draw.updBufferf('lidar%d' % i, points[t][i]) + draw.updBufferf('color%d' % i, data['rgb'][t][i]) + + draw.clear() + for i in range(num_cams): + draw['cam%d%d' % modrem(i, 2)].image('cam%d' % i) + draw['wld'].size(1).color(cam_colors[i]).points('lidar%d' % i, ('color%d' % i) if color else None) + for cam_key, cam_val in camcv[i].items(): + clr = cam_colors[i] if cam_key == t else 'gra' + tex 
= 'cam%d' % i if cam_key == t else None + draw['wld'].object(cam_val, color=clr, tex=tex) + if with_bbox3d: + [[draw['wld'].object(b) for b in bb] for bb in bbox3d_corners] + if with_pointcache: + draw['wld'].color('whi').points('pointcache') + + draw.update(30) diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..60a39bf920b89ddc87f4279c9ec8899d7c9d2224 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,107 @@ +FROM nvidia/cuda:11.3.1-devel-ubuntu18.04 + +ENV PROJECT=vidar +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 + +ENV PYTHON_VERSION=3.8 +ENV PYTORCH_VERSION=1.10.0+cu113 +ENV TORCHVISION_VERSION=0.11.1+cu113 +ENV CUDNN_VERSION=8.2.1.32-1+cuda11.3 +ENV NCCL_VERSION=2.9.9-1+cuda11.3 + +# Install basic libraries +RUN apt-get update && apt-get install -y \ + build-essential cmake g++-4.8 git curl docker.io vim wget ca-certificates + +# Install python and pip +RUN apt-get install -y python${PYTHON_VERSION} python3-pip +RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python & \ + ln -s /usr/bin/pip3 /usr/bin/pip + +# Upgrade pip +RUN pip install --upgrade pip + +# Install pytorch and torchvision +RUN pip install \ + torch==${PYTORCH_VERSION} \ + torchvision==${TORCHVISION_VERSION} \ + -f https://download.pytorch.org/whl/torch_stable.html + +# Install CUDNN and NCCL +RUN apt-get install -y \ + libcudnn8=${CUDNN_VERSION} \ + libnccl2=${NCCL_VERSION} + +# Install extra packages (apt-get) +RUN apt-get install -y \ + ffmpeg \ + tmux + +# Install extra packages (pip) +RUN pip install \ + tqdm==4.61.0 \ + boto3==1.17.83 \ + termcolor==1.1.0 \ + pyyaml==5.4.1 \ + wandb==0.10.31 \ + opencv-python==4.5.2.52 \ + flow_vis==0.1 \ + matplotlib==3.3.4 \ + fire==0.4.0 \ + pyquaternion==0.9.9 \ + pandas==1.1.5 \ + xarray==0.16.2 \ + diskcache==5.2.1 \ + tenacity==7.0.0 \ + pycocotools==2.0.2 \ + awscli==1.19.101 \ +# timm==0.4.9 \ + ref==0.0.2.2 \ + positional-encodings==4.0.0 \ + einops==0.3.2 \ + wget \ + ftfy \ + regex \ + tqdm + +# Install CamViz dependencies +RUN pip install \ + pygame==2.0.1 \ + PyOpenGL==3.1.5 \ + PyOpenGL-accelerate==3.1.5 +RUN apt-get install -y \ + mesa-utils \ + freeglut3-dev \ + libsdl2-2.0-0 \ + python-pygame + +# Install PyTorch3D +RUN pip install pytorch3d + +# Install CuPY +RUN pip install cupy + +# Install huggingface transformers +RUN pip install transformers + +# Install extras (should be moved to top when stable) +RUN pip install lpips wget scikit-image pyhocon dotmap path sacremoses filelock huggingface_hub +RUN pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup' +RUN pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html + +# Install DGP (dataset utils) +WORKDIR /workspace +RUN git clone https://github.com/VitorGuizilini-TRI/dgp.git +ENV PYTHONPATH="/workspace/dgp:$PYTHONPATH" + +# Create workspace folder +RUN mkdir -p /workspace/experiments +RUN mkdir -p /workspace/${PROJECT} +WORKDIR /workspace/${PROJECT} +# Copy project to workspace folder +COPY . 
/workspace/${PROJECT} + +# Set environment variables +ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH" +ENV PYTHONPATH="/workspace/${PROJECT}/externals/camviz:$PYTHONPATH" diff --git a/externals/camviz/.gitignore b/externals/camviz/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..1b58d2a8a27eda23d7a749f31cbf8b507da576ba --- /dev/null +++ b/externals/camviz/.gitignore @@ -0,0 +1,3 @@ +*.pyc +.idea +__pycache__ diff --git a/externals/camviz/LICENSE.md b/externals/camviz/LICENSE.md new file mode 100755 index 0000000000000000000000000000000000000000..996346997eb8db1521d947c13a73422a9d77eeb8 --- /dev/null +++ b/externals/camviz/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Toyota Research Institute (TRI) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/externals/camviz/README.md b/externals/camviz/README.md new file mode 100755 index 0000000000000000000000000000000000000000..9f413a6635ae386ef1417fa58b761040ee0dd81c --- /dev/null +++ b/externals/camviz/README.md @@ -0,0 +1,78 @@ + + + + + +## CamViz + +[Overview](#overview) // [Installation](#install) // [Demos](#demos) // [License](#license) + + + + + + +## Overview + +**CamViz** is a visualization library developed by the TRI-ML team with the goal of providing an interface for the visualization of monocular depth estimation results, both as depth maps and reconstructed pointclouds. It uses [PyGame](https://www.pygame.org/news) for window display and input management (mouse and keyboard), and [OpenGL](https://www.opengl.org//) for 2D and 3D drawing and rendering. It provides an easy and intuitive way to: +- Store information as textures and data buffers for efficient display +- Create 2D environments for image display and 3D environments for pointcloud visualization +- A pinhole camera class that manages most basic geometric operations (reconstruction, projection, transformation to different coordinate frames, etc.) + +Although **CamViz** works as a standalone library, it was designed specifically to be used in conjunction with other TRI-ML's repositories, in particular [PackNet-SFM](https://github.com/tri-ml/packnet-sfm) and [DDAD](https://github.com/tri-ml/ddad). To facilitate integration, it is also provided as a submodule in those repositories. + +## Installation + +We provide a `requirements.txt` file with all the required libraries (tested on Ubuntu 18.04). 
To start using **CamViz** all you need to do is: + +``` +git clone git@github.com:TRI-ML/camviz.git +cd camviz +pip install -r requirements.txt +PYTHONPATH=$PYTHONPATH:/path/to/camviz +``` + +## Demos + +The **CamViz** repository comes with a demo that visualizes a predicted monocular pointcloud (already calculated, and provided as part of the repository). We plan to include more demos as more functionalities are added, usually tied to scientific publications. +To run it, type the following command from the root folder: + +``` +python demos/pointcloud.py +``` + +The output should look like this: + + + + + +From this initial display you can: +- Zoom in/out on the images with the mouse wheel, and translate within image boundaries. +- Move freely within the 3D viewer (translation, rotation and zoom in/out) with the mouse. +- Change color modes with the `enter` key. + +## License + +The source code is released under the [MIT license](LICENSE.md). diff --git a/externals/camviz/camviz/__init__.py b/externals/camviz/camviz/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..99f399f2b8f1ebcb3492ba0c760ae3cbad564a36 --- /dev/null +++ b/externals/camviz/camviz/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from camviz.draw.draw import Draw +from camviz.objects.camera import Camera +from camviz.objects.pose import Pose +from camviz.objects.bbox2d import BBox2D +from camviz.objects.bbox3d import BBox3D diff --git a/externals/camviz/camviz/containers/buffer.py b/externals/camviz/camviz/containers/buffer.py new file mode 100755 index 0000000000000000000000000000000000000000..34d84064358403247a7f3d81fe3bb4d7c229a6d2 --- /dev/null +++ b/externals/camviz/camviz/containers/buffer.py @@ -0,0 +1,107 @@ + +import numpy as np +from OpenGL.GL import \ + glGenBuffers, glBindBuffer, glBufferData, glBufferSubData, GL_ARRAY_BUFFER, GL_STATIC_DRAW + +from camviz.utils.utils import numpyf +from camviz.utils.types import is_tuple, is_list, is_tensor +from camviz.utils.cmaps import jet + + +class Buffer: + """ + Initialize a data buffer + + Parameters + ---------- + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + dtype : numpy type (e.g. np.float32) + Numpy data type + gltype : OpenGL type (e.g. 
GL_FLOAT32) + OpenGL data type + """ + def __init__(self, data, dtype, gltype): + # Initialize buffer ID and max size + self.id, self.max = glGenBuffers(1), 0 + # Store data types + self.dtype, self.gltype = dtype, gltype + if is_tuple(data): + # If data is a tuple, store dimensions + data, (self.n, self.d) = None, data + else: + # Process data and store dimensions + data = self.process(data) + self.n, self.d = data.shape[:2] + # If size is larger than available, recreate buffer + if self.n > self.max: + self._create(data) + + @property + def size(self): + """Get buffer size""" + return self.n * self.d * np.dtype(self.dtype).itemsize + + def process(self, data): + """ + Process data buffer to get relevant information + + Parameters + ---------- + data : list or np.array or torch.Tensor + Data to be processed + + Returns + ------- + data : np.array + Processed data + """ + # If it's a list + if is_list(data): + data = numpyf(data) + # If it's a tensor + if is_tensor(data): + # If tensor is a grid with 3D coordinates [3,H,W] + if data.dim() == 3 and data.shape[0] == 3: + data = data.permute(1, 2, 0).reshape(-1, 3) + data = data.detach().cpu().numpy() + # If it's not the correct type, convert + if data.dtype != self.dtype: + data = data.astype(self.dtype) + # Expand if necessary + if len(data.shape) == 1: + data = np.expand_dims(data, 1) + # Return data + return data + + def _create(self, data): + """Create a new data buffer""" + self.max = self.n + glBindBuffer(GL_ARRAY_BUFFER, self.id) + glBufferData(GL_ARRAY_BUFFER, self.size, data, GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def update(self, data): + """Update data buffer""" + # Process data + data = self.process(data) + # Get dimensions or initialize as zero + self.n = 0 if data.size == 0 else data.shape[0] + # If dimensions are larger than available, recreate + if self.n > self.max: + self._create(data) + # Otherwise + else: + # Bind buffer and copy data + glBindBuffer(GL_ARRAY_BUFFER, self.id) + glBufferSubData(GL_ARRAY_BUFFER, 0, self.size, data.astype(self.dtype)) + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def clear(self): + """Clear buffer""" + self.n = 0 + + def updateJET(self, data): + """Update buffer using a JET colormap""" + self.update(jet(data)) diff --git a/externals/camviz/camviz/containers/texture.py b/externals/camviz/camviz/containers/texture.py new file mode 100755 index 0000000000000000000000000000000000000000..e64052e868aa0bcee6f484b6a5010c3eae7e7cb4 --- /dev/null +++ b/externals/camviz/camviz/containers/texture.py @@ -0,0 +1,144 @@ + +import cv2 +import numpy as np +import pygame +from OpenGL.GL import \ + glEnable, glDisable, glTexParameterf, \ + glBindTexture, glGenTextures, glTexImage2D, glTexSubImage2D, \ + GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_TEXTURE_WRAP_T, GL_TEXTURE_MAG_FILTER, \ + GL_TEXTURE_MIN_FILTER, GL_REPEAT, GL_NEAREST, GL_RGB, GL_RGBA, GL_UNSIGNED_BYTE + +from camviz.utils.types import is_str, is_tensor, is_numpy, is_tuple + + +def load(image): + """ + Load an image as a texture surface + + Parameters + ---------- + image : np.array or str + Input image + Returns + ------- + surface : pygame surface + Output surface for texture buffer + """ + # Return None if image is None + if image is None: + return None + # If image is a string, load from file + if is_str(image): + surface = pygame.image.load(image) + # If it's a numpy array + else: + # Convert to uint8 and transpose + image = image.astype(np.uint8) + image = np.transpose(image, (1, 0, 2)) + # Create surface + surface = 
pygame.surfarray.make_surface(image) + # Return surface + return pygame.image.tostring(surface, "RGBA", 1) + + +class Texture: + """ + Initialize a texture buffer + + Parameters + ---------- + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + """ + def __init__(self, data=None): + # Create a new texture ID + self.id = glGenTextures(1) + # If data exists create texture buffer from it + if data is not None: + self._create(data) + # Otherwise, just store dimensions + else: + self.wh = None + + def process(self, image): + """Process a new image to produce a texture buffer""" + # If it's a tensor + if is_tensor(image): + # Detach and transpose + image = image.detach().cpu().numpy() + if len(image.shape) == 4: + image = image[0] + if len(image.shape) == 3 and image.shape[0] == 3: + image = image.transpose((1, 2, 0)) + # If it's a numpy array + if is_numpy(image): + # Resize to proper shape + if image.shape[0] != self.wh[0] or image.shape[1] != self.wh[1]: + image = cv2.resize(image, self.wh, interpolation=cv2.INTER_LINEAR) + # Squeeze if necessary + if len(image.shape) == 3 and image.shape[2] == 1: + image = image.squeeze(-1) + # Stack to 3 channels if necessary + if len(image.shape) == 2: + image = np.stack([image] * 3, axis=2) + # Return image + return image * 255 + + def _create(self, data): + """Create a texture buffer from data""" + # If it's tuple, it only contains dimensions + if is_tuple(data): + image = None + w, h = data[:2] + # If it's a tensor, convert to numpy + elif is_tensor(data): + image = data.detach().cpu().numpy().transpose(1, 2, 0) * 255 + h, w = data.shape[-2:] + # Otherwise, it contains data and dimensions + else: + image = data * 255 + h, w = data.shape[:2] + # Store dimensions + self.wh = (int(w), int(h)) + # Bind and fill texture + glBindTexture(GL_TEXTURE_2D, self.id) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, self.wh[0], self.wh[1], + 0, GL_RGBA, GL_UNSIGNED_BYTE, load(image)) + glBindTexture(GL_TEXTURE_2D, 0) + # Return image + return image + + def update(self, image): + """Update texture buffer from an image""" + # Return None if image is None + if image is None: + return None + # If there are no stored dimensions, create a new texture buffer + if self.wh is None: + self._create(image) + # Otherwise, update texture buffer + else: + # Process image + image = self.process(image) + # Bind and update buffer + glBindTexture(GL_TEXTURE_2D, self.id) + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, self.wh[0], self.wh[1], + GL_RGBA, GL_UNSIGNED_BYTE, load(image)) + glBindTexture(GL_TEXTURE_2D, 0) + + def bind(self): + """Bind and store data in texture buffer""" + glEnable(GL_TEXTURE_2D) + glBindTexture(GL_TEXTURE_2D, self.id) + + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S , GL_REPEAT ) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T , GL_REPEAT ) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER , GL_NEAREST) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER , GL_NEAREST) + + @staticmethod + def unbind(): + """Unbind texture buffer""" + glDisable(GL_TEXTURE_2D) + glBindTexture(GL_TEXTURE_2D, 0) diff --git a/externals/camviz/camviz/data/buffer.py b/externals/camviz/camviz/data/buffer.py new file mode 100755 index 0000000000000000000000000000000000000000..c074ca4f9f776fd37cefd7e063be00badcf9f9ca --- /dev/null +++ b/externals/camviz/camviz/data/buffer.py @@ -0,0 +1,105 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
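For reference, the `load()` helper above reduces to the following standalone numpy/pygame round trip; this is a minimal sketch where the random image and its size are placeholders for illustration, not part of CamViz:

```
import numpy as np
import pygame

# Stand-in image, [H, W, 3] uint8 (Texture.process scales float images to 0-255 first)
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)

# pygame surfaces are indexed [W, H, 3], hence the transpose in load()
surface = pygame.surfarray.make_surface(image.transpose(1, 0, 2))

# Raw RGBA bytes, which is what gets handed to glTexImage2D / glTexSubImage2D
rgba_bytes = pygame.image.tostring(surface, "RGBA", 1)
print(len(rgba_bytes), 640 * 480 * 4)  # one 4-byte RGBA tuple per pixel
```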
+ +import numpy as np +from OpenGL.GL import \ + glGenBuffers, glBindBuffer, glBufferData, glBufferSubData, GL_ARRAY_BUFFER, GL_STATIC_DRAW + +from camviz.utils.utils import numpyf +from camviz.utils.types import is_tuple, is_list, is_tensor +from camviz.utils.cmaps import jet + + +class Buffer: + """ + Initialize a data buffer + + Parameters + ---------- + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + dtype : numpy type (e.g. np.float32) + Numpy data type + gltype : OpenGL type (e.g. GL_FLOAT32) + OpenGL data type + """ + def __init__(self, data, dtype, gltype): + # Initialize buffer ID and max size + self.id, self.max = glGenBuffers(1), 0 + # Store data types + self.dtype, self.gltype = dtype, gltype + if is_tuple(data): + # If data is a tuple, store dimensions + data, (self.n, self.d) = None, data + else: + # Process data and store dimensions + data = self.process(data) + self.n, self.d = data.shape[:2] + # If size is larger than available, recreate buffer + if self.n > self.max: + self._create(data) + + @property + def size(self): + """Get buffer size""" + return self.n * self.d * np.dtype(self.dtype).itemsize + + def process(self, data): + """ + Process data buffer to get relevant information + + Parameters + ---------- + data : list or np.array or torch.Tensor + Data to be processed + + Returns + ------- + data : np.array + Processed data + """ + # If it's a list + if is_list(data): + data = numpyf(data) + # If it's a tensor + if is_tensor(data): + data = data.detach().cpu().numpy() + # If it's not the correct type, convert + if data.dtype != self.dtype: + data = data.astype(self.dtype) + # Expand if necessary + if len(data.shape) == 1: + data = np.expand_dims(data, 1) + # Return data + return data + + def _create(self, data): + """Create a new data buffer""" + self.max = self.n + glBindBuffer(GL_ARRAY_BUFFER, self.id) + glBufferData(GL_ARRAY_BUFFER, self.size, data, GL_STATIC_DRAW) + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def update(self, data): + """Update data buffer""" + # Process data + data = self.process(data) + # Get dimensions or initialize as zero + self.n = 0 if data.size == 0 else data.shape[0] + # If dimensions are larger than available, recreate + if self.n > self.max: + self._create(data) + # Otherwise + else: + # Bind buffer and copy data + glBindBuffer(GL_ARRAY_BUFFER, self.id) + glBufferSubData(GL_ARRAY_BUFFER, 0, self.size, data.astype(self.dtype)) + glBindBuffer(GL_ARRAY_BUFFER, 0) + + def clear(self): + """Clear buffer""" + self.n = 0 + + def updateJET(self, data): + """Update buffer using a JET colormap""" + self.update(jet(data)) diff --git a/externals/camviz/camviz/data/texture.py b/externals/camviz/camviz/data/texture.py new file mode 100755 index 0000000000000000000000000000000000000000..0d8a453e75f6ce8829cd00f13f193aa7c6312ebd --- /dev/null +++ b/externals/camviz/camviz/data/texture.py @@ -0,0 +1,137 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
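The shape and dtype handling in `Buffer.process()` above, and the byte count returned by `Buffer.size`, can be sketched without an OpenGL context; `normalize` below is a hypothetical stand-in used only to illustrate the behaviour:

```
import numpy as np

def normalize(data, dtype=np.float32):
    """Mirror of the normalization Buffer.process() applies before upload."""
    data = np.asarray(data)
    if data.dtype != dtype:      # cast to the buffer's numpy dtype
        data = data.astype(dtype)
    if data.ndim == 1:           # promote [N] to [N, 1]
        data = np.expand_dims(data, 1)
    return data

pts = normalize([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
n, d = pts.shape
size_bytes = n * d * np.dtype(np.float32).itemsize  # what Buffer.size returns
print(pts.shape, size_bytes)                        # (2, 3) 24
```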
+ +import cv2 +import numpy as np +import pygame +from OpenGL.GL import \ + glEnable, glDisable, glTexParameterf, \ + glBindTexture, glGenTextures, glTexImage2D, glTexSubImage2D, \ + GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_TEXTURE_WRAP_T, GL_TEXTURE_MAG_FILTER, \ + GL_TEXTURE_MIN_FILTER, GL_REPEAT, GL_NEAREST, GL_RGB, GL_RGBA, GL_UNSIGNED_BYTE + +from camviz.utils.types import is_str, is_tensor, is_numpy, is_tuple + + +def load(image): + """ + Load an image as a texture surface + + Parameters + ---------- + image : np.array or str + Input image + Returns + ------- + surface : pygame surface + Output surface for texture buffer + """ + # Return None if image is None + if image is None: + return None + # If image is a string, load from file + if is_str(image): + surface = pygame.image.load(image) + # If it's a numpy array + else: + # Convert to uint8 and transpose + image = image.astype(np.uint8) + image = np.transpose(image, (1, 0, 2)) + # Create surface + surface = pygame.surfarray.make_surface(image) + # Return surface + return pygame.image.tostring(surface, "RGBA", 1) + + +class Texture: + """ + Initialize a texture buffer + + Parameters + ---------- + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + """ + def __init__(self, data=None): + # Create a new texture ID + self.id = glGenTextures(1) + # If data exists create texture buffer from it + if data is not None: + self._create(data) + # Otherwise, just store dimensions + else: + self.wh = None + + def process(self, image): + """Process a new image to produce a texture buffer""" + # If it's a tensor + if is_tensor(image): + # Detach and transpose + image = image.detach().cpu().numpy() + if len(image.shape) == 3: + image = image.transpose((1, 2, 0)) + # If it's a numpy array + if is_numpy(image): + # Resize to proper shape + if image.shape[0] != self.wh[0] or image.shape[1] != self.wh[1]: + image = cv2.resize(image, self.wh, interpolation=cv2.INTER_LINEAR) + # Squeeze if necessary + if len(image.shape) == 3 and image.shape[2] == 1: + image = image.squeeze(-1) + # Stack to 3 channels if necessary + if len(image.shape) == 2: + image = np.stack([image] * 3, axis=2) + # Return image + return image + + def _create(self, data): + """Create a texture buffer from data""" + # If it's tuple, it only contains dimensions + if is_tuple(data): + image, (w, h) = None, data[:2] + # Otherwise, it contains data and dimensions + else: + image, (h, w) = data, data.shape[:2] + # Store dimensions + self.wh = (int(w), int(h)) + # Bind and fill texture + glBindTexture(GL_TEXTURE_2D, self.id) + glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, self.wh[0], self.wh[1], + 0, GL_RGBA, GL_UNSIGNED_BYTE, load(image)) + glBindTexture(GL_TEXTURE_2D, 0) + # Return image + return image + + def update(self, image): + """Update texture buffer from an image""" + # Return None if image is None + if image is None: + return None + # If there are no stored dimensions, create a new texture buffer + if self.wh is None: + self._create(image) + # Otherwise, update texture buffer + else: + # Process image + image = self.process(image) + # Bind and update buffer + glBindTexture(GL_TEXTURE_2D, self.id) + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, self.wh[0], self.wh[1], + GL_RGBA, GL_UNSIGNED_BYTE, load(image)) + glBindTexture(GL_TEXTURE_2D, 0) + + def bind(self): + """Bind and store data in texture buffer""" + glEnable(GL_TEXTURE_2D) + glBindTexture(GL_TEXTURE_2D, self.id) + + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S , 
GL_REPEAT ) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T , GL_REPEAT ) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER , GL_NEAREST) + glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER , GL_NEAREST) + + @staticmethod + def unbind(): + """Unbind texture buffer""" + glDisable(GL_TEXTURE_2D) + glBindTexture(GL_TEXTURE_2D, 0) diff --git a/externals/camviz/camviz/draw/__init__.py b/externals/camviz/camviz/draw/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..6a0810f6a8884401e5b7063de351170926551ee8 --- /dev/null +++ b/externals/camviz/camviz/draw/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from camviz.draw.draw import Draw diff --git a/externals/camviz/camviz/draw/draw.py b/externals/camviz/camviz/draw/draw.py new file mode 100644 index 0000000000000000000000000000000000000000..928c55b4e655931f637941d6c4e1a5266af1548a --- /dev/null +++ b/externals/camviz/camviz/draw/draw.py @@ -0,0 +1,351 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import time + +import numpy as np +import pygame +from OpenGL.GL import glReadPixels, glViewport, glScissor, \ + glClear, glClearColor, glPixelStorei, \ + GL_BGR, GL_UNSIGNED_BYTE, GL_COLOR_BUFFER_BIT, GL_DEPTH_BUFFER_BIT, \ + GL_PACK_ALIGNMENT, GL_RGBA +from PIL import Image, ImageOps +from pygame.locals import * + +from camviz.draw.draw_buffer import drawBuffer +from camviz.draw.draw_input import DrawInput +from camviz.draw.draw_texture import DrawTexture +from camviz.opengl.opengl_colors import setColor +from camviz.opengl.opengl_shapes import setPointSize, setLineWidth +from camviz.screen.screen2Dimage import Screen2Dimage +from camviz.screen.screen3Dworld import Screen3Dworld +from camviz.utils.types import is_tuple, is_list +from camviz.utils.utils import labelrc + + +class Draw(DrawInput, DrawTexture, drawBuffer): + + def __init__(self, wh, rc=None, title=None, scale=1.0, width=1600): + """ + Draw class for display visualization + + Parameters + ---------- + wh : tuple (width, height) + Window dimensions + rc : tuple (row, column) + Number of rows and columns (multiplying wh) + title : str + Window title + scale : float + Scale for width/height window dimensions + """ + super().__init__() + # Initialize pygame display + pygame.init() + # Initialize title + if title is not None: + pygame.display.set_caption(title) + # Initialize parameters + wh = [int(val * scale) for val in wh] + if width is not None: + wh[0], wh[1] = width, width * wh[1] // wh[0] + self.wh = self.curr_color = self.curr_size = self.curr_width = None + self.screens, self.textures, self.buffers = {}, {}, {} + self.idx_screen = None + # Set size and color + self.setSize(wh, rc) + self.color('whi').size(1).width(1) + + def setSize(self, wh, rc=None): + """Set window size""" + # Multiply row and column to produce correct dimensions + if rc is not None: + wh = (wh[0] * rc[1], wh[1] * rc[0]) + # Store dimensions + self.wh = wh + # Initialize display + pygame.display.set_mode(self.wh, DOUBLEBUF|OPENGL) + + def __getitem__(self, name): + """Get screen from name""" + return self.screen(name) + + def scr(self, name): + """Get screen from name""" + return self.screens[name] + + def tex(self, name): + """Get texture from name""" + return self.textures[name] + + def buf(self, name): + """Get buffer from name""" + return self.buffers[name] + + def object(self, obj, *args, **kwargs): + """Display object on screen""" + obj.display(self, *args, **kwargs) + + def to_image(self): + 
"""Convert window into a numpy image""" + x, y, w, h = 0, 0, self.wh[0], self.wh[1] + data = glReadPixels(x, y, w, h, GL_BGR, GL_UNSIGNED_BYTE) + image = Image.frombytes("RGB", (w, h), data) + image = image.transpose(Image.FLIP_TOP_BOTTOM) + return np.asarray(image) + + def currScreen(self): + """Return current screen""" + return self.screens[self.idx_screen] + + def addScreen(self, luwh): + """ + Add a new screen to the draw window + + Parameters + ---------- + luwh : tuple (left, up, width, height) + Screen dimensions (percentage or pixels) + + Returns + ------- + l, u, w, h : int + Dimensions + """ + # Parse dimensions + l, u, w, h = luwh + w, h = w - l, h - u + # Convert percentages to pixels + if isinstance(l, float): + l = int(l * self.wh[0]) + if isinstance(u, float): + u = int(u * self.wh[1]) + if isinstance(w, float): + w = int(w * self.wh[0]) + if isinstance(h, float): + h = int(h * self.wh[1]) + # Get screen index and return dimensions + self.idx_screen = len(self.screens) + return l, u, w, h + + def screen(self, name): + """Set which screen will be used for drawing""" + self.idx_screen = name + # Get parameters + d = self.wh[1] + l, u, w, h = self.currScreen().luwh + u = d - (h + u) + # Create viewport and cropping + glViewport(l, u, w, h) + glScissor(l, u, w, h) + # Set background color + glClearColor(0.0, 0.0, 0.0, 1.0) + # Prepare current screen + self.currScreen().prepare() + return self + + def add2Dimage(self, name, luwh, res=None): + """ + Add 2D image screen + + Parameters + ---------- + name : str + Screen name + luwh : tuple (left, up, width, height) + Screen dimensions (pixels or percentage) + res : tuple (width, height) + Screen resolution + """ + # If name is a tuple, create labels for rows and columns + if is_tuple(name): + name = labelrc(name) + # If name is a list, create several screens + if is_list(name): + for i in range(len(name)): + luwh_i = list(luwh) + d = (luwh[3] - luwh[1]) / len(name) + luwh_i[1] = luwh[1] + i * d + luwh_i[3] = luwh_i[1] + d + for j in range(len(name[i])): + d = (luwh[2] - luwh[0]) / len(name[i]) + luwh_i[0] = luwh[0] + j * d + luwh_i[2] = luwh_i[0] + d + self.add2Dimage(name[i][j], luwh_i, res) + # Else, create a single screen + else: + self.screens[name] = Screen2Dimage(self.addScreen(luwh), res) + + def add2DimageRow(self, name, luwh, n, res=None): + """ + Add row with multiple 2D image screens + + Parameters + ---------- + name : str + Screen name + luwh : tuple (left, up, width, height) + Screen dimensions (pixels or percentage) + n : int + Number of columns in the row + res : tuple (width, height) + Screen resolution + """ + for i in range(n): + # Copy dimension vector + luwh_i = [val for val in luwh] + # Offset rows + luwh_i[0] = luwh[0] + (i / n) * (luwh[2] - luwh[0]) + luwh_i[2] = luwh[0] + ((i + 1) / n) * (luwh[2] - luwh[0]) + # Create 2D image screen + self.add2Dimage('%s%d' % (name, i), luwh_i, res) + + def add2DimageCol(self, name, luwh, n, res=None): + """ + Add column with multiple 2D image screens + + Parameters + ---------- + name : str + Screen name + luwh : tuple (left, up, width, height) + Screen dimensions (pixels or percentage) + n : int + Number of rows in the column + res : tuple (width, height) + Screen resolution + """ + for i in range(n): + # Copy dimension vector + luwh_i = [val for val in luwh] + # Offset columns + luwh_i[1] = luwh[1] + (i / n) * (luwh[3] - luwh[1]) + luwh_i[3] = luwh[1] + ((i + 1) / n) * (luwh[3] - luwh[1]) + # Create 2D image screen + self.add2Dimage('%s%d' % (name, i), luwh_i, res) + + 
def add2DimageGrid(self, name, luwh, n, res=None): + """ + Add grid with multiple 2D image screens + + Parameters + ---------- + name : str + Screen name + luwh : tuple (left, up, width, height) + Screen dimensions (pixels or percentage) + n : tuple (int, int) + Number of rows and columns in the grid + res : tuple (width, height) + Screen resolution + """ + for i in range(n[0]): + for j in range(n[1]): + # Copy dimension vector + luwh_i = [val for val in luwh] + # Offset columns + luwh_i[1] = luwh[1] + (i / n[0]) * (luwh[3] - luwh[1]) + luwh_i[3] = luwh[1] + ((i + 1) / n[0]) * (luwh[3] - luwh[1]) + # Offset rows + luwh_i[0] = luwh[0] + (j / n[1]) * (luwh[2] - luwh[0]) + luwh_i[2] = luwh[0] + ((j + 1) / n[1]) * (luwh[2] - luwh[0]) + # Create 2D image screen + self.add2Dimage('%s%d%d' % (name, i, j), luwh_i, res) + + def add3Dworld(self, name, luwh=(0.0, 0.0, 1.0, 1.0), **kwargs): + """ + Add a 3D world screen + + Parameters + ---------- + name : str + Screen name + luwh : tuple + Screen dimensions (left, up, width, height), in pixels or percentage + """ + # If name is a tuple, create labels for rows and columns + if is_tuple(name): + name = labelrc(name) + # If name is a list, create several screens + if is_list(name): + for i in range(len(name)): + luwh_i = list(luwh) + d = (luwh[3] - luwh[1]) / len(name) + luwh_i[1] = luwh[1] + i * d + luwh_i[3] = luwh_i[1] + d + for j in range(len(name[i])): + d = (luwh[2] - luwh[0]) / len(name[i]) + luwh_i[0] = luwh[0] + j * d + luwh_i[2] = luwh_i[0] + d + self.add3Dworld(name[i][j], luwh_i, **kwargs) + # Else, create a single screen + else: + self.screens[name] = Screen3Dworld(self.addScreen(luwh), **kwargs) + + @staticmethod + def clear(): + """Clear window""" + glClear(GL_COLOR_BUFFER_BIT|GL_DEPTH_BUFFER_BIT) + + def populate(self, data, fit=False): + """ + Populate screens with information from a dictionary + + Parameters + ---------- + data : dict + Dictionary with information for each screen + fit : bool + If true, resize screens to fit the data showing + """ + for key, val in data.items(): + # If it's a tuple, use key and val from positions 0 and 1 + if is_tuple(val): + self.screen(key).image(val[0], data=val[1], fit=fit) + # Else, use key and val directly + else: + self.screen(key).image(key, data=val, fit=fit) + + def size(self, n): + """Set point size""" + self.curr_size = 1 + setPointSize(n) + return self + + def width(self, n): + """Set line width""" + self.curr_width = n + setLineWidth(n) + return self + + def color(self, clr): + """Set plot color""" + self.curr_color = clr + setColor(clr) + return self + + def setCSW(self, csw): + """Set color, size and width (CSW) simultaneously""" + self.color(csw[0]).size(csw[1]).width(csw[2]) + + def getCSW(self): + """Get color, size and width (CSW) information""" + return self.curr_color, self.curr_size, self.curr_width + + @staticmethod + def halt(n): + """Stop for n milliseconds""" + time.sleep(n/1000) + + def save(self, filename): + """Save window as an image file""" + # Get dimensions + width, height = self.wh + # Store window information in a variable + glPixelStorei(GL_PACK_ALIGNMENT, 1) + data = glReadPixels(0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE) + image = Image.frombytes("RGBA", (width, height), data) + image = ImageOps.flip(image) # in my case image is flipped top-bottom for some reason + # Save image and halt for a bit + image.save(filename, 'PNG') + self.halt(1000) diff --git a/externals/camviz/camviz/draw/draw_buffer.py b/externals/camviz/camviz/draw/draw_buffer.py new file mode 
100755 index 0000000000000000000000000000000000000000..1a833c34f7e2ab297af7a9abf1aaa46850f4f0b2 --- /dev/null +++ b/externals/camviz/camviz/draw/draw_buffer.py @@ -0,0 +1,241 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np +from OpenGL.GL import glEnableClientState, glDisableClientState, \ + glPolygonMode, glVertexPointer, glBindBuffer, glColorPointer, \ + glDrawArrays, glDrawElements, glBegin, glEnd, glVertex2fv, glVertex3fv, \ + GL_ARRAY_BUFFER, GL_FILL, GL_ELEMENT_ARRAY_BUFFER, \ + GL_FLOAT, GL_UNSIGNED_INT, GL_POINTS, GL_FRONT_AND_BACK, GL_COLOR_ARRAY, \ + GL_VERTEX_ARRAY, GL_LINE, GL_LINES, GL_LINE_LOOP, GL_LINE_STRIP, GL_QUADS, GL_TRIANGLES + +from camviz.containers.buffer import Buffer +from camviz.opengl.opengl_shapes import drawConnects, drawMatches, drawAxis, drawEllipse +from camviz.utils.utils import grid_idx +from camviz.utils.types import is_str, is_list, is_tuple, is_int +from camviz.utils.cmaps import jet + + +class drawBuffer: + """Draw subclass containing data buffer methods""" + def __init__(self): + pass + + def addBuffer(self, name, data, dtype, gltype, n=None): + """ + Create a new data buffer + + Parameters + ---------- + name : str + Buffer name + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + dtype : numpy type (e.g. np.float32) + Numpy data type + gltype : OpenGL type (e.g. GL_FLOAT32) + OpenGL data type + n : int or tuple + Number of textures to be added + """ + # If it's a list, create one buffer for each item + if is_list(name): + for i in range(len(name)): + self.addBuffer(name[i], data[i] if is_list(data) else data, dtype, gltype) + # Otherwise, create a single buffer + else: + if n is not None: + if is_tuple(n): + for i in range(n[0]): + for j in range(n[1]): + self.buffers['%s%d%d' % (name, i, j)] = Buffer(data, dtype, gltype) + elif is_int(n): + for i in range(n): + self.buffers['%s%d' % (name, i)] = Buffer(data, dtype, gltype) + self.buffers[name] = Buffer(data, dtype, gltype) + + def addBufferf(self, name, data=0): + """Create a buffer with float32 values (2D or 3D is determined from data)""" + self.addBuffer(name, data, np.float32, GL_FLOAT) + + def addBufferu(self, name, data=0): + """Create a buffer with unsigned 32 values (2D or 3D is determined from data)""" + self.addBuffer(name, data, np.uint32, GL_UNSIGNED_INT) + + def addBuffer2f(self, name, data=0, n=None): + """Create a 2D empty buffer with float32 values""" + self.addBuffer(name, (data, 2), np.float32, GL_FLOAT, n) + + def addBuffer3f(self, name, data=0, n=None): + """Create a 3D empty buffer with float32 values""" + self.addBuffer(name, (data, 3), np.float32, GL_FLOAT, n) + + def addbufferIDX(self, name, data=0): + """Create an index buffer for shape drawing""" + self.addBufferu(name, grid_idx(data)) + + def addBufferJET(self, name, data=0): + """Create a JET colormap buffer from data""" + self.addBufferf(name, jet(data)) + + def addBuffer3JET(self, name, data=0): + """Create an empty 3D colormap buffer from data""" + self.addBuffer3f(name, data) + + def updBufferf(self, name, data): + """Update a buffer with float32 values""" + self.buffers[name].update(data) + + def clrBuffer(self, name): + """Clear a buffer""" + self.buffers[name].clear() + + def points(self, *args, **kwargs): + """Draw points""" + return self._drawSomething(GL_POINTS, *args, **kwargs) + + def lines( self, *args, **kwargs): + """Draw lines""" + return self._drawSomething(GL_LINES, *args, **kwargs) + + 
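A usage sketch for the buffer helpers above, mirroring how `display_sample()` earlier in this diff feeds pointclouds to the viewer (buffer sizes, names and the random points are illustrative; a display must be available):

```
import numpy as np
from camviz import Draw

draw = Draw((640, 480))
draw.add3Dworld('wld', (0.0, 0.0, 1.0, 1.0))

draw.addBuffer3f('pts', 100000)   # reserve space for up to 100k 3D points
draw.addBuffer3f('clr', 100000)   # matching per-point RGB colors

points = np.random.rand(1000, 3).astype(np.float32)
draw.updBufferf('pts', points)
draw.updBufferf('clr', points)    # reuse coordinates as colors for the sketch

while draw.input():
    draw.clear()
    draw['wld'].size(2).points('pts', 'clr')
    draw.update(30)
```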
def strips(self, *args, **kwargs): + """Draw strips (connecting adjacent vertices)""" + return self._drawSomething(GL_LINE_STRIP, *args, **kwargs) + + def loop( self, *args, **kwargs): + """Draw loops (strips with last vertices connected)""" + return self._drawSomething(GL_LINE_LOOP, *args, **kwargs) + + def quads( self, *args, **kwargs): + """Draw quadratics""" + return self._drawSomething(GL_QUADS, *args, **kwargs) + + def tris( self, *args, **kwargs): + """Draw triangles""" + return self._drawSomething(GL_TRIANGLES, *args, **kwargs) + + def grid( self, *args, **kwargs): + """Draw a grid""" + return self._drawSomething(GL_QUADS, *args, **kwargs) + + def matches(self, *args, **kwargs): + """Draw matches from two sets of points""" + drawMatches(*args, **kwargs) + return self + + def connects(self, *args, **kwargs): + """Draw a connection from one point to many points""" + drawConnects(*args, **kwargs) + return self + + def axis(self, *args, **kwargs): + """Draw coordinate axis""" + drawAxis(*args, **kwargs) + return self + + def ellipse(self, *args, **kwargs): + """Draw ellipse""" + drawEllipse(*args, **kwargs) + return self + + def _drawSomething(self, shape, *args, **kwargs): + """ + Base function for shape drawing + + Parameters + ---------- + shape : opengl shape + OpenGL shape to draw (e.g. GL_POINTS) + args : args + Extra draw arguments + kwargs : kwargs + Extra draw arguments + """ + # If it's a string, draw buffer + if is_str(args[0]): + return self._drawBuffer(shape, *args, **kwargs) + # Otherwise, copy data and draw + else: + return self._drawBase(shape, *args, **kwargs) + + def _drawBuffer(self, shape, vert, color=None, idx=None, wire=None): + """ + Draw from a buffer + + Parameters + ---------- + shape : opengl type + OpenGL shape to draw (e.g. GL_POINTS) + vert : buffer + Buffer with vertices + color : buffer + Buffer with colors + idx : buffer + Buffer with indexes + wire : buffer + Buffer with wire (color and width) + """ + # If wire is avaialble + if wire is not None: + csw = self.getCSW() + self.width(wire[1]) + color_wire = wire[0] if wire[0] in self.buffers else None + if wire[0] not in self.buffers: + self.color(wire[0]) + glPolygonMode(GL_FRONT_AND_BACK, GL_LINE) + self._drawBuffer(shape, vert, color=color_wire, idx=idx, wire=None) + glPolygonMode(GL_FRONT_AND_BACK, GL_FILL) + self.setCSW(csw) + # If vert is available + if vert is not None: + vert = self.buffers[vert] + glEnableClientState(GL_VERTEX_ARRAY) + glBindBuffer(GL_ARRAY_BUFFER, vert.id) + glVertexPointer(vert.d, vert.gltype, 0, None) + # If color is available + if color is not None: + color = self.buffers[color] + glEnableClientState(GL_COLOR_ARRAY) + glBindBuffer(GL_ARRAY_BUFFER, color.id) + glColorPointer(color.d, color.gltype, 0, None) + # If idx is available + if idx is None: + glDrawArrays(shape, 0, vert.n) + else: + idx = self.buffers[idx] + glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, idx.id) + glDrawElements(shape, idx.n, idx.gltype, None) + # Bind buffers + glBindBuffer(GL_ARRAY_BUFFER, 0) + # Unbind vertices + if vert is not None: + glDisableClientState(GL_VERTEX_ARRAY) + # Unbind colors + if color is not None: + glDisableClientState(GL_COLOR_ARRAY) + # Return self + return self + + def _drawBase(self, shape, verts): + """ + Draw a shape by copying data (very slow) + + Parameters + ---------- + shape : opengl shape + OpenGL shape to draw (e.g. 
GL_POINTS) + verts : np.array + Vertices to draw + """ + # If there are no vertices, do nothing + if len(verts) == 0: + return self + # Select 2D or 3D vertex draw function + glVertex = glVertex2fv if len(verts[0]) == 2 else glVertex3fv + # Draw vertices + glBegin(shape) + for vert in verts: + glVertex(vert) + glEnd() + # Return self + return self diff --git a/externals/camviz/camviz/draw/draw_input.py b/externals/camviz/camviz/draw/draw_input.py new file mode 100755 index 0000000000000000000000000000000000000000..28ec09d9756fa92b54f86aec8c07ad19ffb8f88a --- /dev/null +++ b/externals/camviz/camviz/draw/draw_input.py @@ -0,0 +1,322 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import pygame + + +class DrawInput: + """Draw subclass containing input controls""" + def __init__(self): + # Initialize basic keys + self.UP = self.DOWN = self.LEFT = self.RIGHT = False + self.RCTRL = self.LCTRL = self.RALT = self.LALT = self.RSHIFT = self.LSHIFT = False + self.SPACE = self.RETURN = self.PGUP = self.PGDOWN = False + # Initialize letter keys + self.KEY_Q = self.KEY_W = self.KEY_E = self.KEY_R = self.KEY_T = False + self.KEY_A = self.KEY_S = self.KEY_D = self.KEY_F = self.KEY_G = False + self.KEY_Z = self.KEY_X = self.KEY_C = self.KEY_V = self.KEY_B = False + # Initialize number keys + self.KEY_0 = self.KEY_1 = self.KEY_2 = self.KEY_3 = self.KEY_4 = self.KEY_5 = \ + self.KEY_6 = self.KEY_7 = self.KEY_8 = self.KEY_9 = False + # Initialize mouse keys + self.mouse_pos = self.motion_type = None + self.tmp_screen, self.tmp_focus = None, False + self.mouse_down = False + + def change_keys(self, key, flag): + """ + Change key to flag value + + Parameters + ---------- + key : pygame key + Key in consideration + flag : bool + State to set the key [True or False] + """ + if key == pygame.K_UP: + self.UP = flag + if key == pygame.K_DOWN: + self.DOWN = flag + if key == pygame.K_LEFT: + self.LEFT = flag + if key == pygame.K_RIGHT: + self.RIGHT = flag + + if key == pygame.K_RCTRL: + self.RCTRL = flag + if key == pygame.K_LCTRL: + self.LCTRL = flag + if key == pygame.K_RALT: + self.RALT = flag + if key == pygame.K_LALT: + self.LALT = flag + if key == pygame.K_RSHIFT: + self.RSHIFT = flag + if key == pygame.K_LSHIFT: + self.LSHIFT = flag + if key == pygame.K_PAGEUP: + self.PGUP = flag + if key == pygame.K_PAGEDOWN: + self.PGDOWN = flag + + if key == pygame.K_SPACE: + self.SPACE = flag + if key == pygame.K_RETURN: + self.RETURN = flag + + if key == pygame.K_s: + self.KEY_S = flag + if key == pygame.K_q: + self.KEY_Q = flag + if key == pygame.K_w: + self.KEY_W = flag + if key == pygame.K_e: + self.KEY_E = flag + if key == pygame.K_r: + self.KEY_R = flag + if key == pygame.K_t: + self.KEY_T = flag + if key == pygame.K_a: + self.KEY_A = flag + if key == pygame.K_s: + self.KEY_S = flag + if key == pygame.K_d: + self.KEY_D = flag + if key == pygame.K_f: + self.KEY_F = flag + if key == pygame.K_g: + self.KEY_G = flag + + if key == pygame.K_0: + self.KEY_0 = flag + if key == pygame.K_1: + self.KEY_1 = flag + if key == pygame.K_2: + self.KEY_2 = flag + if key == pygame.K_3: + self.KEY_3 = flag + if key == pygame.K_4: + self.KEY_4 = flag + if key == pygame.K_5: + self.KEY_5 = flag + if key == pygame.K_6: + self.KEY_6 = flag + if key == pygame.K_7: + self.KEY_7 = flag + if key == pygame.K_8: + self.KEY_8 = flag + if key == pygame.K_9: + self.KEY_9 = flag + + def input(self): + """ + Parse keyboard and mouse input + """ + events = pygame.event.get() # Get events + pos = pygame.mouse.get_pos() # Get mouse 
position + + # If mouse is not pressing down + if self.mouse_down is False: + # Get current screen based on mouse position + screen = None + for key, scr in self.screens.items(): + if scr.inside(pos): + screen = scr + break + # Set screen focus based on mouse position + focus = screen is not None and pygame.mouse.get_focused() + if not focus: + self.mouse_pos = None + else: + # Use stored screen and focus + screen, focus = self.tmp_screen, self.tmp_focus + + # For each event + for event in events: + # If x button is pressed + if event.type == pygame.QUIT: + return False + # If key has been presed down + if event.type == pygame.KEYDOWN: + # If escape has been pressed, exit + if event.key == pygame.K_ESCAPE: + return False + # If p has been pressed, return virtual camera pose + if event.key == pygame.K_p: + if self.currScreen().viewer is not None: + print('(%7.5f, %7.5f, %7.5f, %1.5f, %1.5f, %1.5f, %1.5f)' % + self.currScreen().viewer.current7()) + # Change key to pressed + self.change_keys(event.key, True) + # If key has been released + if event.type == pygame.KEYUP: + self.change_keys(event.key, False) + # If mouse button has been pressed down + if event.type == pygame.MOUSEBUTTONDOWN: + self.mouse_down = True + self.tmp_screen, self.tmp_focus = screen, focus + # If it's a 3D world screen + if focus and screen.mode is '3D_WORLD': + if event.button == 4: # Wheel forward + if self.RALT: # Going for rotation in Z + screen.viewer.rotateZ(5.0 if self.RCTRL else 0.05 if self.LCTRL else 0.5) + else: # Going for translation in Z + screen.viewer.translateZ(+(5.0 if self.RCTRL else 0.2 if self.LCTRL else 1.0)) + if event.button == 5: # Wheel backwards + if self.RALT: # Going for rotation in Z + screen.viewer.rotateZ(-5.0 if self.RCTRL else -0.05 if self.LCTRL else -0.5) + else: # Going for translation in Z + screen.viewer.translateZ(-(5.0 if self.RCTRL else 0.2 if self.LCTRL else 1.0)) + if event.button == 1: # Left button + self.motion_type, self.mouse_pos = 1, pos + if event.button == 3: # Right button + self.motion_type, self.mouse_pos = 3, pos + if event.button == 2: # Wheel press + screen.reset() + # If it's a 2D image screen + if focus and screen.mode is '2D_IMAGE': + if event.button == 1: # Left button + self.motion_type, self.mouse_pos = 1, pos + if event.button == 2: # Wheel press + screen.res = list(screen.orig_res) + else: + # Change resolution + rel = [(pos[0] - screen.luwh[0]) / screen.luwh[2] * screen.res[2] + screen.res[0], + (pos[1] - screen.luwh[1]) / screen.luwh[3] * screen.res[3] + screen.res[1]] + # Get speed multiplier + mlt = 1.20 if self.RSHIFT else 1.05 + # Wheel forward + if event.button == 4: + if screen.res[0] < 0.95 * screen.res[2] and \ + screen.res[1] < 0.95 * screen.res[3]: + screen.res[2] = (screen.res[0] + screen.res[2] - rel[0])/mlt + rel[0] - screen.res[0] + screen.res[0] = (screen.res[0] - rel[0])/mlt + rel[0] + screen.res[3] = (screen.res[1] + screen.res[3] - rel[1])/mlt + rel[1] - screen.res[1] + screen.res[1] = (screen.res[1] - rel[1])/mlt + rel[1] + # Wheel backwards + elif event.button == 5: + screen.res[2] = (screen.res[0] + screen.res[2] - rel[0])*mlt + rel[0] - screen.res[0] + screen.res[0] = (screen.res[0] - rel[0])*mlt + rel[0] + screen.res[3] = (screen.res[1] + screen.res[3] - rel[1])*mlt + rel[1] - screen.res[1] + screen.res[1] = (screen.res[1] - rel[1])*mlt + rel[1] + # Change resolution + screen.res[0] = max(screen.res[0], screen.orig_res[0]) + screen.res[1] = max(screen.res[1], screen.orig_res[1]) + screen.res[2] = min(screen.res[2], 
screen.orig_res[2]) + screen.res[3] = min(screen.res[3], screen.orig_res[3]) + # If mouse button has been released + if event.type == pygame.MOUSEBUTTONUP: + self.mouse_down = False + self.mouse_pos = None + # If screen has focus + if focus: + # If it's a 3D world screen + if screen.mode == '3D_WORLD': + # Get new mouse position + if self.mouse_pos is not None: + dX = pos[0] - self.mouse_pos[0] + dY = pos[1] - self.mouse_pos[1] + self.mouse_pos = pos + # If left button + if self.motion_type == 1: + mlin = 1.00 if self.RCTRL else 0.02 if self.LCTRL else 0.10 + screen.viewer.translateX(- dX * mlin) + screen.viewer.translateY(- dY * mlin) + # If right button + elif self.motion_type == 3: + mang = 0.25 if self.RCTRL else 0.01 if self.LCTRL else 0.05 + if screen.ref == 'cam': # Rotation in camera reference + screen.viewer.rotateX(- dY * mang) + screen.viewer.rotateY(+ dX * mang) + elif screen.ref == 'lidar': # Rotation in lidar reference + screen.viewer.rotateX(- dY * mang) + screen.viewer.rotateZ(- dX * mang) + # If it's a 2D image screen + elif screen.mode == '2D_IMAGE': + # Get new mouse position + if self.mouse_pos is not None: + mlin = 5.00 if self.RCTRL else 1.00 + dX = pos[0] - self.mouse_pos[0] + dY = pos[1] - self.mouse_pos[1] + self.mouse_pos = pos + # Resize and move screen center around + screen.res[0] -= dX * mlin + screen.res[2] -= dX * mlin + if screen.res[0] < screen.orig_res[0] or \ + screen.res[2] > screen.orig_res[2]: + screen.res[0] += dX * mlin + screen.res[2] += dX * mlin + screen.res[1] -= dY * mlin + screen.res[3] -= dY * mlin + if screen.res[1] < screen.orig_res[1] or \ + screen.res[3] > screen.orig_res[3]: + screen.res[1] += dY * mlin + screen.res[3] += dY * mlin + # Continue and return True + return True + + @staticmethod + def update(wait): + """Update window after every wait milisseconds""" + pygame.display.flip() + pygame.time.wait(wait) + + @staticmethod + def control(obj): + """Control an object with keyboard""" + # Get velocity values + dlin, dang = 0.2, 5.0 + # Get keys + keys = pygame.key.get_pressed() + # Check for changes + change = False + # Translate in +Z if UP + if keys[pygame.K_UP]: + change = True + obj.translateZ(+dlin) + # Translate in -Z if DOWN + if keys[pygame.K_DOWN ]: + change = True + obj.translateZ(-dlin) + # Translate in -X if LEFT + if keys[pygame.K_LEFT]: + change = True + obj.translateX(-dlin) + # Translate in +X if RIGHT + if keys[pygame.K_RIGHT]: + change = True + obj.translateX(+dlin) + # Translate in -Y if Q + if keys[pygame.K_q]: + change = True + obj.translateY(-dlin) + # Translate in +Y if A + if keys[pygame.K_a]: + change = True + obj.translateY(+dlin) + # Rotate in +Y if S + if keys[pygame.K_s]: + change = True + obj.rotateY(+dang) + # Rotate in -Y if F + if keys[pygame.K_f]: + change = True + obj.rotateY(-dang) + # Rotate in -X if E + if keys[pygame.K_e]: + change = True + obj.rotateX(-dang) + # Rotate in +X if D + if keys[pygame.K_d]: + change = True + obj.rotateX(+dang) + # Rotate in +Z if W + if keys[pygame.K_w]: + change = True + obj.rotateZ(+dang) + # Rotate in -Z if R + if keys[pygame.K_r]: + change = True + obj.rotateZ(-dang) + # Return change value + return change diff --git a/externals/camviz/camviz/draw/draw_texture.py b/externals/camviz/camviz/draw/draw_texture.py new file mode 100755 index 0000000000000000000000000000000000000000..c64ce3afe0c045bb33b9e3286dfedb89b30d5a9e --- /dev/null +++ b/externals/camviz/camviz/draw/draw_texture.py @@ -0,0 +1,99 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
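The input handling above supports the event-loop pattern used throughout this diff; a minimal skeleton (screen name is illustrative, and a display must be available):

```
from camviz import Draw

draw = Draw((640, 480))
draw.add3Dworld('wld', (0.0, 0.0, 1.0, 1.0))

paused = False
while draw.input():        # returns False on ESC or window close
    if draw.SPACE:         # key flags stay True while the key is held down
        paused = not paused
    draw.clear()
    # ... select screens and draw buffers / textures here ...
    draw.update(30)        # flip the display and wait 30 ms
```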
+ +from OpenGL.GL import \ + glTexCoord2f, glBegin, glEnd, glVertex2fv, glVertex3fv, \ + GL_QUADS + +from camviz.containers.texture import Texture +from camviz.opengl.opengl_colors import White +from camviz.utils.utils import labelrc, numpyf +from camviz.utils.types import is_tuple, is_list, is_int + +class DrawTexture: + """Draw subclass containing texture methods""" + def addTexture(self, name, data=None, n=None): + """ + Create a new texture buffer + + Parameters + ---------- + name : str + Buffer name + data : np.array [N,D] or tuple (n,d) + Data to be added to the buffer + If it's a tuple, create a data buffer of that size + n : int or tuple + Number of textures to be added + """ + # If it's a tuple, create individual names for each texture + if is_tuple(name): + name = labelrc(name) + # If it's a list, add each item to its own texture + if is_list(name): + for i in range(len(name)): + self.addTexture(name[i], data[i] if is_list(data) else data) + # Otherwise, create a single texture from data + else: + if n is not None: + if is_tuple(n): + for i in range(n[0]): + for j in range(n[1]): + self.textures['%s%d%d' % (name, i, j)] = Texture(data) + elif is_int(n): + for i in range(n): + self.textures['%s%d' % (name, i)] = Texture(data) + self.textures[name] = Texture(data) + + def updTexture(self, name, data): + """Update texture with new data""" + self.textures[name].update(data) + + def image(self, name, data=None, verts=None, fit=False): + """ + Display a texture on screen + + Parameters + ---------- + name : str + Name of the texture + data : np.array + Update texture with new data before displaying + verts : np.array + Vertices for the texture borders on screen + fit : bool + If true, resize screen to fit new image + """ + # If no name is provided, return None + if name is None: + return + # Get texture ID from name + tex = self.textures[name] + # Resize screen to fit screen if necessary + if fit is True: + self.currScreen().setRes(tex.wh) + # If data is provided, update texture first + if data is not None: + tex.update(data) + # If verts is not provided, create them based on screen dimension + if verts is None: + verts = [[tex.wh[0], 0.0 ], [tex.wh[0], tex.wh[1]], + [ 0.0 , tex.wh[1]], [ 0.0 , 0.0 ]] + verts = numpyf(verts) + # Draw texture + White() + tex.bind() + glBegin(GL_QUADS) + glVertex = glVertex2fv if len(verts[0]) == 2 else glVertex3fv + glTexCoord2f(1.0, 1.0) + glVertex(verts[0]) + glTexCoord2f(1.0, 0.0) + glVertex(verts[1]) + glTexCoord2f(0.0, 0.0) + glVertex(verts[2]) + glTexCoord2f(0.0, 1.0) + glVertex(verts[3]) + glEnd() + tex.unbind() + # Return self + return self + diff --git a/externals/camviz/camviz/objects/__init__.py b/externals/camviz/camviz/objects/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..520f40c05b716320ee33af7bcfdf1d8554e6a1c3 --- /dev/null +++ b/externals/camviz/camviz/objects/__init__.py @@ -0,0 +1,3 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from camviz.objects.camera import Camera diff --git a/externals/camviz/camviz/objects/bbox2d.py b/externals/camviz/camviz/objects/bbox2d.py new file mode 100755 index 0000000000000000000000000000000000000000..8b0a3db541f111e3799710e27628910fa6461a16 --- /dev/null +++ b/externals/camviz/camviz/objects/bbox2d.py @@ -0,0 +1,45 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
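A usage sketch for the texture helpers above (resolution and names are illustrative; the random image stands in for real data, and a display must be available):

```
import numpy as np
from camviz import Draw

wh = (640, 480)
draw = Draw(wh)
draw.add2Dimage('rgb', (0.0, 0.0, 1.0, 1.0), res=wh)
draw.addTexture('rgb', wh)                 # empty texture of known size

image = np.random.rand(wh[1], wh[0], 3)    # [H, W, 3] float image in [0, 1]
draw.updTexture('rgb', image)

while draw.input():
    draw.clear()
    draw['rgb'].image('rgb')               # show the texture on the 2D screen
    draw.update(30)
```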
+ +import numpy as np + +from camviz.objects.object import Object + + +class BBox2D(Object): + """ + Bounding Box 2D draw class + + Parameters + ---------- + points : np.array + List of points for the bounding box dimension (left, top, right, bottom) + pose : np.array + Bounding box pose on the screen (right, down) + """ + def __init__(self, points, pose=None): + super().__init__(pose=pose) + self.pts = np.array([[points[0], points[1]], + [points[2], points[1]], + [points[2], points[3]], + [points[0], points[3]]]) + + def draw(self, draw, color_line='gre', color_edge=None): + """ + Draw 2D bounding box on screen + + Parameters + ---------- + draw : camviz.Draw + Draw instance + color_line : str + Line color + color_edge : str + Edge color + """ + # Set color line if provided + if color_line is not None: + draw.color(color_line).width(2).lines( + self.pts[[0, 1, 1, 2, 2, 3, 3, 0]]) + # Set color edge if provided + if color_edge is not None: + draw.color(color_edge).size(4).points(self.pts) diff --git a/externals/camviz/camviz/objects/bbox3d.py b/externals/camviz/camviz/objects/bbox3d.py new file mode 100755 index 0000000000000000000000000000000000000000..d385c13c808b555d424bb66444613fcf64ca50e3 --- /dev/null +++ b/externals/camviz/camviz/objects/bbox3d.py @@ -0,0 +1,42 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from camviz.objects.object import Object + + +class BBox3D(Object): + """ + Bounding Box 3D draw class + + Parameters + ---------- + points : np.array + List of points for the bounding box dimension (assuming center is 0,0,0) + Order: +++, +-+, +--, ++-, -++, --+, ---, -+- + pose : np.array + Bounding box pose (x-forward, y-left, z-up) + """ + def __init__(self, points, pose=None): + super().__init__(pose=pose) + self.pts = points + + def draw(self, draw, color_line='gre', color_edge=None): + """ + Draw 2D bounding box on screen + + Parameters + ---------- + draw : camviz.Draw + Draw instance + color_line : str + Line color + color_edge : str + Edge color + """ + # Set color line if provided + if color_line is not None: + draw.color(color_line).width(2).lines( + self.pts[[0, 1, 1, 2, 2, 3, 3, 0, 4, 5, 5, 6, + 6, 7, 7, 4, 0, 4, 1, 5, 2, 6, 3, 7]]) + # Set color edge if provided + if color_edge is not None: + draw.color(color_edge).size(4).points(self.pts) diff --git a/externals/camviz/camviz/objects/camera.py b/externals/camviz/camviz/objects/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..3559597bf7b77bdfd4b860e07c8a086b021ed7f1 --- /dev/null +++ b/externals/camviz/camviz/objects/camera.py @@ -0,0 +1,188 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
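A usage sketch for the bounding-box classes above, following the corner ordering documented in `BBox3D` (the unit cube, window size and screen name are illustrative; a display must be available):

```
import numpy as np
from camviz import Draw, BBox3D

# Cube centered at the origin, corners ordered as the class docstring expects:
# +++, +-+, +--, ++-, -++, --+, ---, -+-
corners = 0.5 * np.array([[ 1,  1,  1], [ 1, -1,  1], [ 1, -1, -1], [ 1,  1, -1],
                          [-1,  1,  1], [-1, -1,  1], [-1, -1, -1], [-1,  1, -1]],
                         dtype=np.float32)
bbox = BBox3D(corners)

draw = Draw((640, 480))
draw.add3Dworld('wld', (0.0, 0.0, 1.0, 1.0))
while draw.input():
    draw.clear()
    draw['wld'].object(bbox, color_line='gre')
    draw.update(30)
```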
+ +import numpy as np + +from camviz.objects.object import Object +from camviz.utils.geometry import transpose, invert +from camviz.utils.types import is_list, is_float +from camviz.utils.utils import numpyf, add_row0, add_col1, image_grid + + +def camviz_camera(camera): + """ + Converts packnet-sfm cameras to camviz cameras + + Parameters + ---------- + camera : Camera or list[Camera] + Input packnet-sfm cameras + + Returns + ------- + camera_cv : camviz.objects.camera.Camera + output camviz cameras + """ + # Create a list of cameras if necessary + if is_list(camera): + return [camviz_camera(cam) for cam in camera] + # Return a list of cameras for each batch camera + return [Camera(cam=cam) for cam in camera] + + +class Camera(Object): + """ + Create a camera class + + Parameters + ---------- + scale : float + Scale used when drawing the object + wh : tuple + Image dimensions (width, height) + K : np.array + Camera intrinsics [3,3] + pose : np.array + Object pose + """ + def __init__(self, scale=1.0, wh=None, K=None, pose=None): + # Initialize object super-class + super().__init__(scale, pose) + # If intrinsics is provided, use it + if K is not None: + self.K = transpose(numpyf(K)) + self.iK = np.linalg.inv(self.K) + # If image dimensions is not provided, use it + if wh is not None: + if not isinstance(wh, (list, tuple)): + wh = wh.shape[:2] + self.w, self.h = wh + uv = numpyf([[self.w - 1, 0 ], + [self.w - 1, self.h - 1], + [ 0 , self.h - 1], + [ 0 , 0 ]]) + self.v = add_row0(self.i2c(scale, uv)) + + @staticmethod + def from_vidar(cam, b, scale=1.0): + return Camera(K=cam.K[b][:3, :3], + pose=cam.Tcw.T[b] if cam.Twc is not None else None, + wh=cam.wh, scale=scale) + + def i2c(self, depth=1.0, uv=None): + """ + Project an image to camera coordinates using a depth map + + Parameters + ---------- + depth : float or np.array + Depth values for lifting + uv : np.array + Image grid for lifting + + Returns + ------- + xyz : np.array + Lifted 3D points in camera frame of reference + """ + # If no grid is provided, uses depth map + if uv is None: + if not is_float(depth): + # Create image grid from depth values + uv = image_grid(depth) + else: + # Impossible to create an image grid + raise ValueError('No available grid for camera') + # Add third unitary coordinate to the image grid + if uv.shape[1] == 2: + uv = add_col1(uv) + # A depth map was provided, create a grid from it + elif uv.shape[1] > 3: + uv = image_grid(uv) + # If there are individual depth values per image grid cell + if not is_float(depth): + if len(depth.shape) == 1: + depth = depth[:, np.newaxis] + elif depth.shape[1] > 1: + if len(depth.shape) == 3: + depth = depth[:, :, 0] + depth = depth.reshape(-1, 1) + return (uv @ self.iK) * depth + + def c2i(self, xyz, filter=False, padding=0, return_z=False): + """ + Project 3D points in camera frame of reference to the image plane + + Parameters + ---------- + xyz : np.array + 3D points to be projected + filter : bool + Filter points outside boundaries + padding : int or float + Padding for filtering + return_z : bool + Return z values as well or not + + Returns + ------- + uv : np.array + 2D coordinates of projected points + idx : np.array + Valid indexes in case filtering was enabled + depth : np.array + Depth values in case return_z was enabled + """ + uv = (xyz / xyz[:, 2:] @ self.K)[:, :2] + if filter: + idx = (uv[:, 0] > -padding) & (uv[:, 0] < self.w + padding) & \ + (uv[:, 1] > -padding) & (uv[:, 1] < self.h + padding) & (xyz[:, 2] > 0) + if return_z: + return uv[idx], xyz[idx, 
2:], idx + else: + return uv[idx], idx + else: + if return_z: + return uv, xyz[:, 2:] + else: + return uv + + def c2w(self, xyz): + """Transform 3D points in camera frame of reference to world frame of reference""" + if xyz.shape[1] == 3: + xyz = add_col1(xyz) + return (xyz @ self.Tt)[:, :3] + + def w2c(self, xyz): + """Transform 3D points in world frame of reference to camera frame of reference""" + if xyz.shape[1] == 3: + xyz = add_col1(xyz) + return (xyz @ invert(self.Tt))[:, :3] + + def i2w(self, depth=1.0, uv=None): + """Lift 2D image points to 3D space in world frame of reference""" + return self.c2w(self.i2c(depth, uv)) + + def w2i(self, xyz, filter=False, padding=0, return_z=False): + """Project 3D points in world frame of reference to the image plane""" + return self.c2i(self.w2c(xyz), filter=filter, + padding=padding, return_z=return_z) + + def draw(self, draw, tex=None, axes=True, color='gra'): + """ + Draw a camera in a 3D screen + + Parameters + ---------- + draw : Draw + Draw class to be used + tex : int + Optional texture to draw on the camera image plane + axes : bool + True if coordinate axes should be drawn as well + color : str + Which color should be used for the camera + """ + draw.image(tex, verts=self.v[:4]) + draw.color(color).width(4).connects(self.v[4], self.v[:4]).loop(self.v[:4]) + if axes: + draw.axis(0.25 * self.scale) diff --git a/externals/camviz/camviz/objects/object.py b/externals/camviz/camviz/objects/object.py new file mode 100755 index 0000000000000000000000000000000000000000..4a0e95186e815e34d814047ed69c171c3f52252b --- /dev/null +++ b/externals/camviz/camviz/objects/object.py @@ -0,0 +1,111 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from OpenGL.GL import * + +from camviz.objects.pose import Pose + + +class Object: + """ + Base object draw class + + Parameters + ---------- + scale : float + Scale used when drawing the object + pose : np.array + Object pose + """ + def __init__(self, scale=1.0, pose=None): + self.scale = scale + self.pose = pose if isinstance(pose, Pose) else Pose(pose) + + @property + def t(self): + """Return pose translation""" + return self.pose.t + + @property + def R(self): + """Return pose rotation""" + return self.pose.R + + @property + def T(self): + """Return pose transformation""" + return self.pose.T + + @property + def Rt(self): + """Return pose rotation transposed""" + return self.pose.Rt + + @property + def Tt(self): + """Return pose transformation transposed""" + return self.pose.Tt + + def translateX(self, m): + """Translate object in X by m""" + return self.pose.translateX(m) + + def translateY(self, m): + """Translate object in Y by m""" + return self.pose.translateY(m) + + def translateZ(self, m): + """Translate object in Z by m""" + return self.pose.translateZ(m) + + def rotateX(self, d): + """Rotate object in X by d degrees""" + return self.pose.rotateX(d) + + def rotateY(self, d): + """Rotate object in Y by d degrees""" + return self.pose.rotateY(d) + + def rotateZ(self, d): + """Rotate object in Z by d degrees""" + return self.pose.rotateZ(d) + + def rotateI(self, d): + """Rotate object in X by d degrees (from the camera's perspective)""" + return self.pose.rotateI(d) + + def rotateJ(self, d): + """Rotate object in Y by d degrees (from the camera's perspective)""" + return self.pose.rotateJ(d) + + def rotateK(self, d): + """Rotate object in Z by d degrees (from the camera's perspective)""" + return self.pose.rotateK(d) + + def setPose(self, pose): + """Set object pose""" + return 
self.pose.setPose(pose) + + def display(self, *args, align=None, **kwargs): + """ + Display object + + Parameters + ---------- + args : args + Extra draw arguments + align : camviz.Pose + Pose used to align the object + kwargs : kwargs + Extra draw arguments + """ + # Get transformation (aligned or not) + if align is not None: + T = (align @ self.pose).Tt + else: + T = self.Tt + + # Draw object + glPushMatrix() + glMultMatrixf(T) + self.draw(*args, **kwargs) + glPopMatrix() diff --git a/externals/camviz/camviz/objects/pointcloud.py b/externals/camviz/camviz/objects/pointcloud.py new file mode 100755 index 0000000000000000000000000000000000000000..ddb786bbfd6e96d1140a0096f2148881b975a5be --- /dev/null +++ b/externals/camviz/camviz/objects/pointcloud.py @@ -0,0 +1,43 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from camviz.objects.object import * + + +class Pointcloud(Object): + """ + Bounding Box 3D draw class + + Parameters + ---------- + scale : float + Scale used when drawing the object + pts : np.array + Pointcloud points + pose : np.array + Bounding box pose + draw : camviz.Draw + Draw instance + """ + def __init__(self, scale=1.0, pts=None, pose=None, draw=None): + super().__init__(scale, pose) + if draw is not None: + draw.addBufferf('pts', pts) + self.pts = 'pts' + else: + self.pts = pts + + def draw(self, draw, size=1, color='whi'): + """ + Draw pointcloud on screen + + Parameters + ---------- + draw : camviz.Draw + Draw instance + size : int + Point size + color : str + Point color + """ + draw.color(color).size(size).points(self.pts) + diff --git a/externals/camviz/camviz/objects/pose.py b/externals/camviz/camviz/objects/pose.py new file mode 100755 index 0000000000000000000000000000000000000000..9d92457de71c127c2bbaea7d2ebe63769034904b --- /dev/null +++ b/externals/camviz/camviz/objects/pose.py @@ -0,0 +1,192 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from copy import deepcopy + +import numpy as np + +from camviz.objects.quaternion import Quaternion +from camviz.utils.geometry import unitX, unitY, unitZ +from camviz.utils.utils import numpyf, add_col1 + + +def rot2quat(R): + """Convert rotation matrix to quaternion""" + qw = np.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2 + qx = (R[1, 2] - R[2, 1]) / (4.0 * qw) + qy = (R[2, 0] - R[0, 2]) / (4.0 * qw) + qz = (R[0, 1] - R[1, 0]) / (4.0 * qw) + return Quaternion(qw, qx, qy, qz) + + +class Pose: + + def __init__(self, pose=None, align=False): + """ + Pose class + + Parameters + ---------- + pose : np.array + Initial pose + align : np.array + Optional transformation matrix used for alignment + """ + self.q = self.M = None + # If pose is provided, use it + if pose is not None: + self.setPose(pose, align) + # Otherwise, set to identity + else: + self.reset() + + def copy(self): + """Return a copy of the instance""" + return deepcopy(self) + + @property + def t(self): + """Return pose translation""" + return self.M[:3, 3] + + @property + def R(self): + """Return pose translation""" + return self.M[:3, :3] + + @property + def T(self): + """Return pose transformation""" + return self.M + + @property + def Rt(self): + """Return pose rotation transposed""" + return self.R.T + + @property + def Tt(self): + """Return pose transformation transposed""" + return self.T.T + + @property + def inv(self): + """Return inverted pose""" + Tinv = self.T.copy() + Tinv[:3, :3] = np.transpose(self.T[:3, :3]) + Tinv[:3, -1] = np.matmul(-1. 
* Tinv[:3, :3], self.T[:3, -1]) + return Pose(Tinv) + + @property + def Tinv(self): + """Return inverted pose transformation""" + return self.inv.T + + def translateX(self, m): + """Translate object in X by m""" + return self.translate(unitX(m)) + + def translateY(self, m): + """Translate object in Y by m""" + return self.translate(unitY(m)) + + def translateZ(self, m): + """Translate object in Z by m""" + return self.translate(unitZ(m)) + + def rotateX(self, d, M=None): + """Rotate object in X by d degrees""" + return self.rotate(d, (self.M if M is None else M)[:3, 0]) + + def rotateY(self, d, M=None): + """Rotate object in Y by d degrees""" + return self.rotate(d, (self.M if M is None else M)[:3, 1]) + + def rotateZ(self, d, M=None): + """Rotate object in Z by d degrees""" + return self.rotate(d, (self.M if M is None else M)[:3, 2]) + + def rotateI(self, d): + """Rotate object in X by d degrees (from the camera's perspective)""" + return self.rotate(d, unitX(1)) + + def rotateJ(self, d): + """Rotate object in Y by d degrees (from the camera's perspective)""" + return self.rotate(d, unitY(1)) + + def rotateK(self, d): + """Rotate object in Z by d degrees (from the camera's perspective)""" + return self.rotate(d, unitZ(1)) + + def setPose(self, mat, align=False): + """ + Set pose value + + Parameters + ---------- + mat : np.array + New pose value + align : np.array + Optional transformation matrix used for alignment + """ + # Convert to numpy + mat = numpyf(mat) + # If mat is as 1-dimensional vector + if len(mat.shape) == 1: + # If it has 16 values, reshape and use it as a transformation matrix + if mat.shape[0] == 16: + self.M = np.reshape(mat, (4, 4)) + self.q = rot2quat(self.M) + # If it has 7 values, treat is as translation + quaternion + if mat.shape[0] == 7: + self.M = numpyf(np.identity(4)) + self.M[:3, 3] = [mat[0], mat[1], mat[2]] + self.q = Quaternion(mat[3], mat[4], mat[5], mat[6]) + self.M[:3, :3] = self.q.rotmat().T + # If it's two-dimensional, treat it as a transformation matrix + elif len(mat.shape) == 2: + if mat.shape[0] == 4 and mat.shape[1] == 4: + self.M = mat + self.q = rot2quat(self.M) + # Update transformation matrix + self.M = numpyf(self.M) + # Align if necessary + if align: + R = np.array([[0, -1, 0, 0], + [0, 0, -1, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]]) + self.M = R @ self.M + self.q = rot2quat(self.M) + + def reset(self): + """Reset pose""" + self.q = Quaternion() + self.M = numpyf(np.identity(4)) + + def translate(self, axis): + """Translate pose in a certain axis""" + self.M[:3, 3] += self.q.rotate(numpyf(axis)) + return self + + def rotate(self, deg, axis): + """Rotate pose by deg in a certain axis""" + self.q *= Quaternion(numpyf(axis), deg) + self.M[:3, :3] = self.q.rotmat().T + return self + + def current7(self): + """Return current translation and quaternion values""" + t, q = self.M[:3, 3], self.q.coefs + return t[0], t[1], t[2], q[0], q[1], q[2], q[3] + + def __matmul__(self, other): + """Multiply pose with something else""" + # Pose x Pose + if isinstance(other, Pose): + return Pose(self.M @ other.T) + # Pose x points + elif other.shape[1] == 3: + return (add_col1(other) @ self.Tt)[:, :3] + # Generic multiplication + else: + return self.M @ other.T diff --git a/externals/camviz/camviz/objects/quaternion.py b/externals/camviz/camviz/objects/quaternion.py new file mode 100755 index 0000000000000000000000000000000000000000..fb70c5d815aafb9e917536a08fa419f7b6ceb8dd --- /dev/null +++ b/externals/camviz/camviz/objects/quaternion.py @@ -0,0 +1,72 @@ +# 
Copyright 2021 Toyota Research Institute. All rights reserved. + +import math + +import numpy as np + + +class Quaternion: + """Quaternion class""" + def __init__(self, *args): + # If no arguments are provided, create an identity quaternion + if len(args) == 0: + self.coefs = 1.0, 0.0, 0.0, 0.0 + # If a single argument is provided + elif len(args) == 1: + # If it's a tuple, it contains coefficients + if isinstance(args[0], tuple): + self.coefs = args[0] + # Otherwise, assume it's a rotation matrix + else: + R = np.array(args[0]) + w = np.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2 + x = (R[1, 2] - R[2, 1]) / (4.0 * w) + y = (R[2, 0] - R[0, 2]) / (4.0 * w) + z = (R[0, 1] - R[1, 0]) / (4.0 * w) + self.coefs = w, x, y, z + # If two arguments are provided, assume it's axis and degree + elif len(args) == 2: + v, d = args + r = d * math.pi / 360.0 + c, s = math.cos(r), math.sin(r) + self.coefs = c, v[0] * s, v[1] * s, v[2] * s + # If there are four arguments, assume each individual coefficient is provided + elif len(args) == 4: + self.coefs = args + + def __getitem__(self, idx): + """Return quaternion coefficients""" + return self.coefs[idx] + + def __mul__(self, r): + """Multiply quaternion with rotation angles""" + q, r = self.coefs, r.coefs + return Quaternion(r[0] * q[0] - r[1] * q[1] - r[2] * q[2] - r[3] * q[3], + r[0] * q[1] + r[1] * q[0] - r[2] * q[3] + r[3] * q[2], + r[0] * q[2] + r[1] * q[3] + r[2] * q[0] - r[3] * q[1], + r[0] * q[3] - r[1] * q[2] + r[2] * q[1] + r[3] * q[0]) + + def invert(self): + """Return inverted quaternion""" + w, x, y, z = self.coefs + d = np.sqrt(w * w + x * x + y * y + z * z) + return Quaternion(w / d, - x / d, - y / d, - z / d) + + def rotate(self, p): + """Rotate points""" + vec = self.coefs[1:] + + uv = np.cross(p, vec) + uuv = np.cross(uv, vec) + + return p + 2 * (self.coefs[0] * uv + uuv) + + def rotmat(self): + """Return rotation matrix""" + w, x, y, z = self.coefs + xx, yy, zz = x * x, y * y, z * z + return np.array([[1-2*yy-2*zz, 2*x*y-2*z*w, 2*x*z+2*y*w], + [2*x*y+2*z*w, 1-2*xx-2*zz, 2*y*z-2*x*w], + [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*xx-2*yy]]) + + diff --git a/externals/camviz/camviz/opengl/opengl_colors.py b/externals/camviz/camviz/opengl/opengl_colors.py new file mode 100755 index 0000000000000000000000000000000000000000..dbc7a63d48de6f60a156f0bb4ff5402da66a9b67 --- /dev/null +++ b/externals/camviz/camviz/opengl/opengl_colors.py @@ -0,0 +1,63 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
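[Editor's note: a quick sanity-check sketch for the Quaternion class added above; the axis, angle, and tolerances are arbitrary and only illustrate the constructor overloads and the rotmat/invert behavior.]

import numpy as np
from camviz.objects.quaternion import Quaternion

q_identity = Quaternion()                     # no arguments -> identity coefficients (1, 0, 0, 0)
q_axis_deg = Quaternion((0.0, 0.0, 1.0), 90)  # axis + angle in degrees (half-angle applied internally)
q_coeffs   = Quaternion(1.0, 0.0, 0.0, 0.0)   # four explicit coefficients (w, x, y, z)

# rotmat() returns an orthonormal 3x3 matrix for a unit quaternion
R = q_axis_deg.rotmat()
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)

# multiplying by the inverse recovers (approximately) the identity quaternion
q_rt = q_axis_deg * q_axis_deg.invert()
assert np.allclose(q_rt.coefs, (1.0, 0.0, 0.0, 0.0), atol=1e-6)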
+ +from OpenGL.GL import glColor3fv + +def Red(n=1.0): + """Change to red color""" + glColor3fv((n, 0.0, 0.0)) + +def Green(n=1.0): + """Change to green color""" + glColor3fv((0.0, n, 0.0)) + +def Blue(n=1.0): + """Change to blue color""" + glColor3fv((0.0, 0.0, n)) + +def Yellow(n=1.0): + """Change to yellow color""" + glColor3fv((n, n, 0.0)) + +def Magenta(n=1.0): + """Change to magenta color""" + glColor3fv((n, 0.0, n)) + +def Cyan(n=1.0): + """Change to cyan color""" + glColor3fv((0.0, n, n)) + +def Black(): + """Change to black color""" + glColor3fv((0.0, 0.0, 0.0)) + +def White(): + """Change to white color""" + glColor3fv((1.0, 1.0, 1.0)) + +def Gray(): + """Change to gray color""" + glColor3fv((0.5, 0.5, 0.5)) + +def setColor(clr, n=1.0): + """Change to an specific color based on a string""" + if clr == 'red': + Red(n) + if clr == 'gre': + Green(n) + if clr == 'blu': + Blue(n) + if clr == 'yel': + Yellow(n) + if clr == 'mag': + Magenta(n) + if clr == 'cya': + Cyan(n) + if clr == 'blk': + Black() + if clr == 'whi': + White() + if clr == 'gra': + Gray() + # If clr is a tuple, create that specific color + if isinstance(clr, tuple): + glColor3fv(clr) diff --git a/externals/camviz/camviz/opengl/opengl_shapes.py b/externals/camviz/camviz/opengl/opengl_shapes.py new file mode 100755 index 0000000000000000000000000000000000000000..8a1fe36ec8659c30f6732b18b47e1ccdd2681f2f --- /dev/null +++ b/externals/camviz/camviz/opengl/opengl_shapes.py @@ -0,0 +1,181 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np +from OpenGL.GL import \ + glPointSize, glLineWidth, glVertex2fv, glVertex3fv, \ + glPushMatrix, glPopMatrix, glMultMatrixf, glScalef, glBegin, glEnd, \ + GL_LINES, GL_LINE_LOOP +from OpenGL.GLU import \ + gluSphere, gluNewQuadric + +from camviz.opengl.opengl_colors import Green, Blue, Red +from camviz.utils.utils import numpyf, add_list +from camviz.utils.types import is_numpy, is_double_list + + +def vertex_line(pt1, pt2): + """Create line vertices""" + glVertex = glVertex2fv if len(pt1) == 2 else glVertex3fv + glVertex(pt1) + glVertex(pt2) + +def has_multiple(data): + """Checks if data has multiple entries (list of lists or (n,d) numpy array)""" + return (is_numpy(data) and len(data.shape) > 1) or is_double_list(data) + +def setPointSize(n=1): + """Set point size""" + glPointSize(n) + +def setLineWidth(n=1): + """Set line width""" + glLineWidth(n) + +def drawLine(pt1, pt2): + """Draw a line from two points""" + glBegin(GL_LINES) + vertex_line(pt1, pt2) + glEnd() + +def drawMatches(pts1, pts2): + """Draw 1 to 1 matches between two sets of points""" + glBegin(GL_LINES) + for i in range(len(pts1)): + vertex_line([pts1[i, 0], pts1[i, 1], pts1[i, 2]], + [pts2[i, 0], pts2[i, 1], pts2[i, 2]]) + glEnd() + +def drawConnects(vert1, verts2): + """Draw connections from each vert1 to all vert2""" + vert1, verts2 = numpyf(vert1), numpyf(verts2) + glBegin(GL_LINES) + for vert2 in verts2: + vertex_line(vert1, vert2) + glEnd() + +def drawRect(lu=None, rd=None, ct=None, wh=None, x=False): + """ + Draw a rectangle + + Parameters + ---------- + lu : np.array + Left/Up point + rd : np.array + Right/Down point + ct : np.array + Center point + wh : np.array + Width/height + x : bool + Draw an x inside the rectangle + """ + # If no center is provided, get border points + if ct is not None and wh is not None: + lu = [ct[0] - wh[0] / 2, ct[1] - wh[1] / 2] + rd = [ct[0] + wh[0] / 2, ct[1] + wh[1] / 2] + ld, ru = [lu[0], rd[1]], [rd[0], lu[1]] + # Else, get points based on 
center + elif lu is not None: + if rd is not None: + ld, ru = [lu[0], rd[1]], [rd[0], lu[1]] + elif wh is not None: + ld, ru = [lu[0], lu[1] + wh[1]], \ + [lu[0] + wh[0], rd[1]] + # Wrong parameters + else: + raise ValueError('wrong drawRect parameters') + # Draw rectangle + glBegin(GL_LINE_LOOP) + vertex_line(lu, ld) + vertex_line(rd, ru) + # Draw x + if x: + vertex_line(lu, rd) + vertex_line(ld, ru) + glEnd() + +def drawCross(ct, sz): + """ + Draw a cross on screen + + Parameters + ---------- + ct : np.array + Cross center + sz : float + Cross size + """ + # If there are multiple centers, draw one cross for each + if has_multiple(ct): + [drawCross(pt, sz) for pt in ct] + # If there is a single center + else: + # Get borders + u, d = ct[1] - sz / 2, ct[1] + sz / 2 + l, r = ct[0] - sz / 2, ct[0] + sz / 2 + # Draw cross + glBegin(GL_LINES) + vertex_line([l, ct[1]], [r, ct[1]]) + vertex_line([ct[0], u], [ct[0], d]) + glEnd() + +def drawAxis(scale=1.0, center=(0, 0, 0), width=None): + """ + Draw a xyz axis on screen + + Parameters + ---------- + scale : float + Axis scale + center : np.array + Axis center + width : int + Axis line width + """ + # Set width if provided + if width is not None: + setLineWidth(width) + # Convert center to numpy + center = numpyf(center) + # Draw axis + glBegin(GL_LINES) + Green() + vertex_line(center, add_list(center, (scale, 0, 0))) + Blue() + vertex_line(center, add_list(center, (0, scale, 0))) + Red() + vertex_line(center, add_list(center, (0, 0, scale))) + glEnd() + +def drawEllipse(mean, cov): + """ + Draw an ellipse on screen + + Parameters + ---------- + mean : np.array + Ellipse mean + cov : np.array + Ellipse covariance + """ + # If there are multiple means, draw one ellipse for each + if len(mean.shape) > 1: + for i in range(mean.shape[0]): + drawEllipse(mean[i], cov[i]) + # Else, draw a single ellipse + else: + # Get eigenvalue and eigenvector + val, vec = np.linalg.eig(cov) + # Get transformation matrix + Tt = np.eye(4) + Tt[:3, :3] = vec + Tt[3, :3] = mean + # Apply transformation matrix and draw + glPushMatrix() + glMultMatrixf(Tt) + glScalef(2.0 * np.sqrt(val[0]), + 2.0 * np.sqrt(val[1]), + 2.0 * np.sqrt(val[2])) + gluSphere(gluNewQuadric(), 1.00, 100, 20) + glPopMatrix() diff --git a/externals/camviz/camviz/screen/screen.py b/externals/camviz/camviz/screen/screen.py new file mode 100755 index 0000000000000000000000000000000000000000..ac4ecdaa04901025a252e5327374b4f69543747b --- /dev/null +++ b/externals/camviz/camviz/screen/screen.py @@ -0,0 +1,46 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
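[Editor's note: the drawEllipse routine above scales a unit sphere by the covariance eigendecomposition; below is a numpy-only restatement of that math. The sampling helper is illustrative and uses eigh since a covariance is symmetric, whereas the code above uses eig.]

import numpy as np

def ellipsoid_points(mean, cov, n=1000):
    """Sample surface points of the 2-sigma ellipsoid implied by (mean, cov) (illustrative only)."""
    val, vec = np.linalg.eigh(cov)                  # principal axes of the covariance
    u = np.random.randn(n, 3)
    u /= np.linalg.norm(u, axis=1, keepdims=True)   # directions on the unit sphere
    scaled = u * (2.0 * np.sqrt(val))               # stretch by 2 * sqrt(eigenvalue), as in drawEllipse
    return scaled @ vec.T + mean                    # rotate into the world frame and translate

pts = ellipsoid_points(np.zeros(3), np.diag([1.0, 4.0, 0.25]))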
+ +from copy import deepcopy + + +class Screen: + def __init__(self, luwh, mode): + """ + Screen class + + Parameters + ---------- + luwh : tuple + Left/right/width/height values + mode : str + Screen mode ('2D_IMAGE' or '3D_WORLD') + """ + assert mode in ['2D_IMAGE', '3D_WORLD'], 'Invalid screen mode' + self.luwh, self.mode = luwh, mode + self.origin = self.viewer = None + + def inside(self, pos): + """ + Check if a 2D coordinate is inside the screen + + Parameters + ---------- + pos : np.array + Pose to check + + Returns + ------- + inside : bool + Whether pos is inside the screen + """ + return self.luwh[0] < pos[0] < self.luwh[0] + self.luwh[2] and \ + self.luwh[1] < pos[1] < self.luwh[1] + self.luwh[3] + + def saveViewer(self): + """Save current virtual viewer camera (pose and intrinsics)""" + self.origin = deepcopy(self.viewer) + + def reset(self): + """Reset current virtual viewer camera""" + self.viewer = deepcopy(self.origin) + diff --git a/externals/camviz/camviz/screen/screen2Dimage.py b/externals/camviz/camviz/screen/screen2Dimage.py new file mode 100755 index 0000000000000000000000000000000000000000..0b1cbda5451bc4c12e667de536881acbee4db372 --- /dev/null +++ b/externals/camviz/camviz/screen/screen2Dimage.py @@ -0,0 +1,45 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +from OpenGL.GL import GL_PROJECTION, GL_DEPTH_TEST, GL_MODELVIEW +from OpenGL.GL import glMatrixMode, glLoadIdentity, glDisable +from OpenGL.GLU import gluOrtho2D + +from camviz.screen.screen import Screen + + +class Screen2Dimage(Screen): + """ + 2D screen for image display + + Parameters + ---------- + luwh : tuple + Left/up/width/height values + res : tuple + Image resolution + """ + def __init__(self, luwh, res): + super().__init__(luwh, '2D_IMAGE') + # Get resolution from dimensions if not provided + if res is None: + res = (self.luwh[2], self.luwh[3]) + # Initialize values + self.setRes(res) + self.orig_res = list(self.res) + self.background = 'whi' + + def setRes(self, res): + """Set new resolution""" + self.res = [0, 0, res[0], res[1]] + self.prepare() + + def prepare(self): + """Prepare screen for display""" + glMatrixMode(GL_PROJECTION) + glLoadIdentity() + glDisable(GL_DEPTH_TEST) + gluOrtho2D(self.res[0], self.res[2], + self.res[3], self.res[1]) + glMatrixMode(GL_MODELVIEW) + glLoadIdentity() + diff --git a/externals/camviz/camviz/screen/screen3Dworld.py b/externals/camviz/camviz/screen/screen3Dworld.py new file mode 100755 index 0000000000000000000000000000000000000000..ec72a73b270658e9d70668f1091ed586e7420240 --- /dev/null +++ b/externals/camviz/camviz/screen/screen3Dworld.py @@ -0,0 +1,110 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
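[Editor's note: with the gluOrtho2D(0, w, h, 0) setup used by Screen2Dimage above, the vertical axis is flipped so image pixels keep a top-left origin. A small sketch of the implied mapping, assuming a zero origin (the code also supports a non-zero res offset).]

def pixel_to_ndc(u, v, w, h):
    """Map pixel (u, v) with top-left origin to OpenGL normalized device coordinates."""
    x = 2.0 * u / w - 1.0   # u = 0 -> -1 (left),  u = w -> +1 (right)
    y = 1.0 - 2.0 * v / h   # v = 0 -> +1 (top),   v = h -> -1 (bottom), i.e. flipped
    return x, y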
+ +import numpy as np +from OpenGL.GL import GL_PROJECTION, GL_DEPTH_TEST, GL_MODELVIEW +from OpenGL.GL import glMatrixMode, glEnable, glLoadIdentity, glMultMatrixf +from OpenGL.GLU import gluPerspective, gluLookAt + +from camviz.objects.pose import Pose +from camviz.screen.screen import Screen + + +class Screen3Dworld(Screen): + """ + 3D screen for virtual world display + + Parameters + ---------- + luwh : tuple + Left/up/width/height values + wh : width/height + Virtual camera image dimensions + K : np.array [3,3] + Virtual camera intrinsics + nf : tuple + Near/far display parameters + background : str + Background color ['bla', 'whi'] + pose : tuple + Virtual camera pose + ref : str + Coordinate reference system ['cam', 'lidar'] + """ + def __init__(self, luwh, wh=None, K=None, nf=(0.01, 10000.0), + background='bla', pose=None, ref='cam'): + super().__init__(luwh, '3D_WORLD') + self.wh, self.K, self.nf = wh, K, nf + self.viewer = self.origin = self.P = None + self.background = background + self.ref = ref + # Start and prepare screen + self.start() + self.prepare() + # Rotate if using a lidar frame of reference + if ref == 'lidar': + self.viewer.rotateY(-90).rotateZ(90) + self.saveViewer() + # Set viewer pose if provided + if pose is not None: + self.viewer.setPose(pose) + self.saveViewer() + + def start(self): + """Start viewer""" + self.viewer = Pose() + self.origin = Pose() + if self.wh is not None and self.K is not None: + self.calibrate() + + def prepare(self): + """Prepare screen for display""" + glMatrixMode(GL_PROJECTION) + glLoadIdentity() + + glEnable(GL_DEPTH_TEST) + + # If calibration matrix is provided, use it + if self.P is not None: + glMultMatrixf(self.P) + # Otherwise, use a default perspective + else: + gluPerspective(45, self.luwh[2] / self.luwh[3], + self.nf[0], self.nf[1]) + + glMatrixMode(GL_MODELVIEW) + glLoadIdentity() + + # Determine look vector + T = self.viewer.T + gluLookAt( + T[0, 3], + T[1, 3], + T[2, 3], + T[0, 3] + T[0, 2], + T[1, 3] + T[1, 2], + T[2, 3] + T[2, 2], + - T[0, 1], + - T[1, 1], + - T[2, 1], + ) + + def calibrate(self): + """Calibrate screen for display""" + # Convert intrinsics to numpy if needed + if isinstance(self.K, list): + self.K = np.array(self.K) + + # Create transformation matrix + self.P = np.zeros(16) + + self.P[0] = 2 * self.K[0, 0] / self.wh[0] + self.P[5] = 2 * self.K[1, 1] / self.wh[1] + + self.P[8] = 2.0 * (self.K[0, 2] / self.wh[0]) - 1.0 + self.P[9] = 2.0 * (self.K[1, 2] / self.wh[1]) - 1.0 + + self.P[10] = - 1.0 * (self.nf[1] + self.nf[0]) / (self.nf[1] - self.nf[0]) + self.P[14] = - 2.0 * (self.nf[1] * self.nf[0]) / (self.nf[1] - self.nf[0]) + self.P[11] = - 1.0 + + self.P = np.reshape(self.P, (4, 4)) diff --git a/externals/camviz/camviz/utils/__init__.py b/externals/camviz/camviz/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..394ee97e5a51c8683a46d1e83df26245c04c4a56 --- /dev/null +++ b/externals/camviz/camviz/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + diff --git a/externals/camviz/camviz/utils/cmaps.py b/externals/camviz/camviz/utils/cmaps.py new file mode 100755 index 0000000000000000000000000000000000000000..d975b445e541f57240b1eb8c274934fbbd68f66c --- /dev/null +++ b/externals/camviz/camviz/utils/cmaps.py @@ -0,0 +1,63 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
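[Editor's note: a standalone restatement of the calibrate() method above, which converts pinhole intrinsics and near/far planes into the 4x4 projection passed to glMultMatrixf; the intrinsics values below are made-up example numbers.]

import numpy as np

def projection_from_intrinsics(K, wh, nf):
    """Build the OpenGL-style projection used by Screen3Dworld.calibrate() above."""
    w, h = wh
    n, f = nf
    P = np.zeros(16)
    P[0] = 2.0 * K[0, 0] / w
    P[5] = 2.0 * K[1, 1] / h
    P[8] = 2.0 * (K[0, 2] / w) - 1.0
    P[9] = 2.0 * (K[1, 2] / h) - 1.0
    P[10] = -(f + n) / (f - n)
    P[14] = -2.0 * f * n / (f - n)
    P[11] = -1.0
    return P.reshape(4, 4)   # same row-major layout the code reshapes before glMultMatrixf

K = np.array([[720., 0., 640.], [0., 720., 360.], [0., 0., 1.]])  # example intrinsics
P = projection_from_intrinsics(K, (1280, 720), (0.01, 10000.0))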
+ +import numpy as np +from matplotlib.cm import get_cmap + +from camviz.utils.types import is_numpy, is_tensor + + +def jet(data, range=None, exp=1.0): + """ + Creates a JET colormap from data + + Parameters + ---------- + data : np.array [N,1] + Data to be converted into a colormap + range : tuple (min,max) + Optional range value for the colormap (if None, use min and max from data) + exp : float + Exponential value to weight the color differently + + Returns + ------- + colormap : np.array [N,3] + Colormap obtained from data + """ + # Return if data is not available + if data is None or data.size == 0 or isinstance(data, tuple): + return data + else: + # If data is a tensor, convert to numpy + if is_tensor(data): + data = data.detach().cpu().numpy() + # If data is [N,1], remove second dimensions + if len(data.shape) > 1: + data = data.reshape(-1) + # Determine range if not available + if range is None: + data = data.copy() - np.min(data) + data = data / (np.max(data) + 1e-6) + else: + data = (data - range[0]) / (range[1] - range[0]) + data = np.maximum(np.minimum(data, 1.0), 0.0) + # Use exponential if requested + if exp != 1.0: + data = data ** exp + # Initialize colormap + jet = np.ones((data.shape[0], 3), dtype=np.float32) + # First stage + idx = (data <= 0.33) + jet[idx, 1] = data[idx] / 0.33 + jet[idx, 0] = 0.0 + # Second stage + idx = (data > 0.33) & (data <= 0.67) + jet[idx, 0] = (data[idx] - 0.33) / 0.33 + jet[idx, 2] = 1.0 - jet[idx, 0] + # Third stage + idx = data > 0.67 + jet[idx, 1] = 1.0 - (data[idx] - 0.67) / 0.33 + jet[idx, 2] = 0.0 + # Return colormap + return jet + diff --git a/externals/camviz/camviz/utils/geometry.py b/externals/camviz/camviz/utils/geometry.py new file mode 100755 index 0000000000000000000000000000000000000000..afcce0e117b29d1220726c005e4ed4bb2190287a --- /dev/null +++ b/externals/camviz/camviz/utils/geometry.py @@ -0,0 +1,23 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np + +def unitX(m=1.0): + """Return an unit vector on X""" + return np.array((m, 0, 0)) + +def unitY(m=1.0): + """Return an unit vector on Y""" + return np.array((0, m, 0)) + +def unitZ(m=1.0): + """Return an unit vector on Z""" + return np.array((0, 0, m)) + +def transpose(data): + """Transpose numpy array""" + return data.T + +def invert(data): + """Invert numpy array""" + return np.linalg.inv(data) diff --git a/externals/camviz/camviz/utils/image.py b/externals/camviz/camviz/utils/image.py new file mode 100755 index 0000000000000000000000000000000000000000..134bcd7a8c7d8487856247df5092b416ddc5dde5 --- /dev/null +++ b/externals/camviz/camviz/utils/image.py @@ -0,0 +1,26 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np +from PIL import Image + +def load_image(file, shape=None): + """ + Load an image and optionally resizes it + + Parameters + ---------- + file : str + Image filename + shape : tuple (width, height) + Optional reshape size + + Returns + ------- + image : np.array [H,W] + Loaded image + """ + image = Image.open(file) + if shape: + image = image.resize(shape, resample=Image.ANTIALIAS) + return np.array(image) + diff --git a/externals/camviz/camviz/utils/types.py b/externals/camviz/camviz/utils/types.py new file mode 100755 index 0000000000000000000000000000000000000000..ca4954dbd84e67eac4134322ec6a327aae967766 --- /dev/null +++ b/externals/camviz/camviz/utils/types.py @@ -0,0 +1,55 @@ +# Copyright 2020 Toyota Research Institute. All rights reserved. 
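[Editor's note: a quick usage sketch for the jet() colormap above; the depth values, shapes, and range are synthetic and only illustrate the call signature.]

import numpy as np
from camviz.utils.cmaps import jet

depth = np.random.uniform(1.0, 80.0, size=(192 * 640,))    # flattened per-point depth values
colors = jet(depth)                                        # per-point RGB in [0, 1], shape [N, 3]
colors_fixed = jet(depth, range=(0.0, 80.0), exp=0.5)      # fixed range plus exponent weighting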
+ +import numpy as np +import torch + + +def is_numpy(data): + """Checks if data is a numpy array.""" + return isinstance(data, np.ndarray) + + +def is_tensor(data): + """Checks if data is a torch tensor.""" + return type(data) == torch.Tensor + + +def is_tuple(data): + """Checks if data is a tuple.""" + return isinstance(data, tuple) + + +def is_list(data): + """Checks if data is a list.""" + return isinstance(data, list) + + +def is_double_list(data): + """Checks if data is a double list (list of lists)""" + return is_list(data) and is_list(data[0]) + + +def is_dict(data): + """Checks if data is a dictionary.""" + return isinstance(data, dict) + + +def is_str(data): + """Checks if data is a string.""" + return isinstance(data, str) + + +def is_int(data): + """Checks if data is an integer.""" + return isinstance(data, int) + + +def is_float(data): + """Checks if data is a float value.""" + return isinstance(data, float) + + +def is_seq(data): + """Checks if data is a list or tuple.""" + return is_tuple(data) or is_list(data) + diff --git a/externals/camviz/camviz/utils/utils.py b/externals/camviz/camviz/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..b91f02dcf0718df574404b1fb7ec249f1495540d --- /dev/null +++ b/externals/camviz/camviz/utils/utils.py @@ -0,0 +1,83 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np +from matplotlib.cm import get_cmap + +from camviz.utils.types import is_numpy, is_tensor + + +def add_row0(npy): + """Add a row with zeros to a numpy array""" + return np.vstack([npy, np.zeros((1, npy.shape[1]))]) + +def add_col1(npy): + """Add a column with ones to a numpy array""" + return np.hstack([npy, np.ones((npy.shape[0], 1))]) + +def flatten(lst): + """Flatten a list of lists into a list""" + return [l for ls in lst for l in ls] + +def change_coords(xyz1): + """Flip coordinates from camera to lidar frame of reference""" + xyz2 = xyz1[:, [0, 2, 1]] + xyz2[:, 1] *= -1 + return xyz2 + +def numpyf(data): + """Convert data to a numpy array if necessary""" + return data if is_numpy(data) else \ + data.cpu().detach().numpy() if is_tensor(data) else \ + np.array(data, dtype=np.float32) + +def labelrc(tup): + """Create row and column labels for buffers""" + if len(tup) == 2: + tup = ('', tup[0], tup[1]) + return [['%s%d%d' % (tup[0], j, i) + for i in range(tup[2])] for j in range(tup[1])] + +def add_list(lst1, lst2): + """Add two lists element-wise""" + return [l1 + l2 for l1, l2 in zip(lst1, lst2)] + +def image_grid(mat): + i, j = mat.shape[:2] + u, v = np.meshgrid(np.arange(j), np.arange(i)) + return np.stack([u.reshape(-1), v.reshape(-1), np.ones(i * j)], 1) + +def alternate_points(x1, x2): + x = np.zeros((x1.shape[0] + x2.shape[0], 3)) + for i in range(x1.shape[0]): + x[2*i], x[2*i+1] = x1[i], x2[i] + return x + + +def vis_inverse_depth(inv_depth, normalizer=None, percentile=95, colormap='plasma'): + cm = get_cmap(colormap) + if normalizer: + inv_depth /= normalizer + else: + inv_depth /= (np.percentile(inv_depth, percentile) + 1e-6) + return cm(np.clip(inv_depth, 0., 1.0))[:, :, :3] + + +def grid_idx(grid): + + nx, ny = grid.shape[:2] + nqx, nqy = nx - 1, ny - 1 + nqt = nqx * nqy + + idx = np.zeros(4 * nqt) + cnt_idx, cnt_data = 0, 0 + for i in range(nx): + for j in range(ny): + if i < nqx and j < nqy: + idx[cnt_idx + 0] = cnt_data + idx[cnt_idx + 1] = cnt_data + 1 + idx[cnt_idx + 2] = cnt_data + ny + 1 + idx[cnt_idx + 3] = cnt_data + ny + cnt_idx += 4 + cnt_data += 1 + + return idx diff --git 
a/externals/camviz/demos/data/ddad_eval.npz b/externals/camviz/demos/data/ddad_eval.npz new file mode 100755 index 0000000000000000000000000000000000000000..3d08956bcbdfdf1a595ad76b90b3675765dd765a --- /dev/null +++ b/externals/camviz/demos/data/ddad_eval.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3f552b98d21c639df946df0ddf5e80d0423b8f392035720155cb30524e8d55a1 +size 1862400 diff --git a/externals/camviz/demos/pointcloud.py b/externals/camviz/demos/pointcloud.py new file mode 100755 index 0000000000000000000000000000000000000000..fd56e01d068600ec2e0d381c1c70eb2a31b00a16 --- /dev/null +++ b/externals/camviz/demos/pointcloud.py @@ -0,0 +1,72 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. + +import numpy as np + +import camviz as cv + +# Load evaluation data +data = np.load('demos/data/ddad_eval.npz') + +# Get image resolution +wh = data['rgb'].shape[:2][::-1] + +# Create draw tool with specific width and height window dimensions +draw = cv.Draw(wh=(2000, 900), title='CamViz Pointcloud Demo') + +# Create image screen to show the RGB image +draw.add2Dimage('rgb', luwh=(0.00, 0.00, 0.33, 0.50), res=wh) + +# Create image screen to show the depth visualization +draw.add2Dimage('viz', luwh=(0.00, 0.50, 0.33, 1.00), res=wh) + +# Create world screen at specific position inside the window (% left/up/right/down) +draw.add3Dworld('wld', luwh=(0.33, 0.00, 1.00, 1.00), + pose=(7.25323, -3.80291, -5.89996, 0.98435, 0.07935, 0.15674, 0.01431)) + +# Parse dictionary information +rgb = data['rgb'] +intrinsics = data['intrinsics'] +depth = data['depth'] +viz = data['viz'] + +# Create camera from intrinsics and image dimensions (width and height) +camera = cv.objects.Camera(K=intrinsics, wh=wh) + +# Project depth maps from image (i) to camera (c) coordinates +points = camera.i2c(depth) + +# Create pointcloud colors +rgb_clr = rgb.reshape(-1, 3) # RGB colors +viz_clr = viz.reshape(-1, 3) # Depth visualization colors +hgt_clr = cv.utils.cmaps.jet(-points[:, 1]) # Height colors + +# Create RGB and visualization textures +draw.addTexture('rgb', rgb) # Create texture buffer to store rgb image +draw.addTexture('viz', viz) # Create texture buffer to store visualization image + +# Create buffers to store data for display +draw.addBufferf('pts', points) # Create data buffer to store depth points +draw.addBufferf('clr', rgb_clr) # Create data buffer to store rgb points color +draw.addBufferf('viz', viz_clr) # Create data buffer to store pointcloud heights +draw.addBufferf('hgt', hgt_clr) # Create data buffer to store pointcloud heights + +# Color dictionary +color_dict = {0: 'clr', 1: 'viz', 2: 'hgt'} + +# Display loop +color_mode = 0 +while draw.input(): + # If RETURN is pressed, switch color mode + if draw.RETURN: + color_mode = (color_mode + 1) % len(color_dict) + # Clear window + draw.clear() + # Draw image textures on their respective screens + draw['rgb'].image('rgb') + draw['viz'].image('viz') + # Draw points and colors from buffer + draw['wld'].size(2).points('pts', color_dict[color_mode]) + # Draw camera with texture as image + draw['wld'].object(camera, tex='rgb') + # Update window + draw.update(30) diff --git a/externals/camviz/media/figs/demo_pointcloud.png b/externals/camviz/media/figs/demo_pointcloud.png new file mode 100755 index 0000000000000000000000000000000000000000..fa5ccc13c91fa44bfae11637d2faa2c422677866 --- /dev/null +++ b/externals/camviz/media/figs/demo_pointcloud.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:6974972ece410a4e2462166e71a58db55887ec8448dbd7e01e887284543a65df +size 1357373 diff --git a/externals/camviz/media/figs/tri-logo.png b/externals/camviz/media/figs/tri-logo.png new file mode 100755 index 0000000000000000000000000000000000000000..926ca3c21e08abfd3f0293881b35fa5069bc6af9 --- /dev/null +++ b/externals/camviz/media/figs/tri-logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fff5679bfb6c236e48d49d48d6d33cfa673ea91079ffc988c09a466d98adaf6 +size 9260 diff --git a/externals/camviz/media/gifs/fsm.gif b/externals/camviz/media/gifs/fsm.gif new file mode 100755 index 0000000000000000000000000000000000000000..ba5d555e598c1d0ec54e2c83a7959a5c6e655200 --- /dev/null +++ b/externals/camviz/media/gifs/fsm.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7243ba0e55a93e05dfc586313b52c1d9e17ae5cc0f299ea927700035a1b57605 +size 3919285 diff --git a/externals/camviz/media/gifs/guda.gif b/externals/camviz/media/gifs/guda.gif new file mode 100755 index 0000000000000000000000000000000000000000..0b09b753c16e1e07fa8c6c8cfdcff18af84d3ebf --- /dev/null +++ b/externals/camviz/media/gifs/guda.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fffe3626fa2d8cdd8bfd432d7557db8eeaca7fa6a057e6d8d9be81bbaba7733f +size 4059325 diff --git a/externals/camviz/media/gifs/packnet-ddad.gif b/externals/camviz/media/gifs/packnet-ddad.gif new file mode 100755 index 0000000000000000000000000000000000000000..afb9941c5bee2027106259b30152c24f394ea105 --- /dev/null +++ b/externals/camviz/media/gifs/packnet-ddad.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d24eb15de8c40c9797d9f1222250ca64c64eeafa04dcf31d19f1309210c3c49 +size 4145246 diff --git a/externals/camviz/media/gifs/packnet-san.gif b/externals/camviz/media/gifs/packnet-san.gif new file mode 100755 index 0000000000000000000000000000000000000000..294a9cad627e0097c0504a826d1242ab3112a7e7 --- /dev/null +++ b/externals/camviz/media/gifs/packnet-san.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc1648b6be7ef1a0833c9e62fa63ecc8d1d22bee4b550bf6c3c47489f807010c +size 9579461 diff --git a/externals/camviz/requirements.txt b/externals/camviz/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..2f41d3e366d68b51fc5ec26c03717892efbab076 --- /dev/null +++ b/externals/camviz/requirements.txt @@ -0,0 +1,10 @@ +matplotlib==3.3.4 +numpy==1.19.5 +opencv-python==4.5.1.48 +Pillow==8.2.0 +pygame==2.0.1 +PyOpenGL==3.1.5 +pyparsing==2.4.7 +python-dateutil==2.8.1 +torch==1.8.1 + diff --git a/media/figs/camviz_ddad.jpg b/media/figs/camviz_ddad.jpg new file mode 100644 index 0000000000000000000000000000000000000000..07ea4ccfb2f70f5033a48ef031dceb9a2ded0e15 Binary files /dev/null and b/media/figs/camviz_ddad.jpg differ diff --git a/media/figs/camviz_kitti.jpg b/media/figs/camviz_kitti.jpg new file mode 100644 index 0000000000000000000000000000000000000000..31b08b0c039c72485fad682ad6fa5da59c80d76b Binary files /dev/null and b/media/figs/camviz_kitti.jpg differ diff --git a/media/figs/depthformer.gif b/media/figs/depthformer.gif new file mode 100644 index 0000000000000000000000000000000000000000..84f57099046f4b134a9d7dea8a9b9bcec9b0063d --- /dev/null +++ b/media/figs/depthformer.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e884d6cc819906c1c90582dd6b3b4ec95da17c8ffa1ee4db2883a16f31c3c04 +size 16166886 diff --git a/media/figs/fsm.gif b/media/figs/fsm.gif new file mode 100755 index 
0000000000000000000000000000000000000000..ba5d555e598c1d0ec54e2c83a7959a5c6e655200 --- /dev/null +++ b/media/figs/fsm.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7243ba0e55a93e05dfc586313b52c1d9e17ae5cc0f299ea927700035a1b57605 +size 3919285 diff --git a/media/figs/overfit_kitti.jpg b/media/figs/overfit_kitti.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1904875589652d526ff8634f04a7964d03ff3dcf Binary files /dev/null and b/media/figs/overfit_kitti.jpg differ diff --git a/media/figs/packnet.gif b/media/figs/packnet.gif new file mode 100644 index 0000000000000000000000000000000000000000..afb9941c5bee2027106259b30152c24f394ea105 --- /dev/null +++ b/media/figs/packnet.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d24eb15de8c40c9797d9f1222250ca64c64eeafa04dcf31d19f1309210c3c49 +size 4145246 diff --git a/media/figs/self-calibration.gif b/media/figs/self-calibration.gif new file mode 100644 index 0000000000000000000000000000000000000000..a194e1f09ba3122e2700da44448dd518fe9f2fde --- /dev/null +++ b/media/figs/self-calibration.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fabc94a693efa3044c9588c60fb528858427a23f24f6af789ecb1ad50d929f18 +size 7183493 diff --git a/media/figs/tri-logo.png b/media/figs/tri-logo.png new file mode 100644 index 0000000000000000000000000000000000000000..926ca3c21e08abfd3f0293881b35fa5069bc6af9 --- /dev/null +++ b/media/figs/tri-logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fff5679bfb6c236e48d49d48d6d33cfa673ea91079ffc988c09a466d98adaf6 +size 9260 diff --git a/media/tests/ddad.png b/media/tests/ddad.png new file mode 100644 index 0000000000000000000000000000000000000000..f099701bc286cb987a0ccf334c540faee432b466 --- /dev/null +++ b/media/tests/ddad.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b2de170d88a1b4af88b3f99c021f862ef71b96d6204f0d598f5a08cd6dbf59e +size 350154 diff --git a/media/tests/kitti.png b/media/tests/kitti.png new file mode 100644 index 0000000000000000000000000000000000000000..9a8ceac798f4f0852358443286d54507a0ad00c9 --- /dev/null +++ b/media/tests/kitti.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:401c2f28e2721202fad012426fd06bbf2215209b7057405e4b2c75859cd24599 +size 920682 diff --git a/scripts/run.py b/scripts/run.py new file mode 100755 index 0000000000000000000000000000000000000000..547f7d3f2b4ff82601646d61b4da5eb4ace8a620 --- /dev/null +++ b/scripts/run.py @@ -0,0 +1,25 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os + +import fire +import torch + +from vidar.core.trainer import Trainer +from vidar.core.wrapper import Wrapper +from vidar.utils.config import read_config + + +def train(cfg, **kwargs): + + os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu' + + cfg = read_config(cfg, **kwargs) + + wrapper = Wrapper(cfg, verbose=True) + trainer = Trainer(cfg) + trainer.learn(wrapper) + + +if __name__ == '__main__': + fire.Fire(train) diff --git a/scripts/run_ddp.py b/scripts/run_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd1df1f1c688a9e0ac2aff9a5562163e91a6424 --- /dev/null +++ b/scripts/run_ddp.py @@ -0,0 +1,47 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
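[Editor's note: for readers who prefer to call the training entrypoint above from Python rather than through Fire, a minimal equivalent of scripts/run.py; the config path is a placeholder.]

import os

import torch

from vidar.core.trainer import Trainer
from vidar.core.wrapper import Wrapper
from vidar.utils.config import read_config

os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

cfg = read_config('configs/my_experiment.yaml')  # placeholder path
wrapper = Wrapper(cfg, verbose=True)
trainer = Trainer(cfg)
trainer.learn(wrapper)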
+ +import os + +import fire +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from vidar.core.trainer import Trainer +from vidar.core.wrapper import Wrapper +from vidar.utils.config import read_config + + +def train(cfg, **kwargs): + + os.environ['DIST_MODE'] = 'ddp' + + cfg = read_config(cfg, **kwargs) + + mp.spawn(main_worker, + nprocs=torch.cuda.device_count(), + args=(cfg,), join=True) + + +def main_worker(gpu, cfg): + + torch.cuda.set_device(gpu) + world_size = torch.cuda.device_count() + + os.environ['RANK'] = str(gpu) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + os.environ['DIST_MODE'] = 'ddp' + + dist.init_process_group(backend='nccl', world_size=world_size, rank=gpu) + + wrapper = Wrapper(cfg, verbose=True) + trainer = Trainer(cfg) + trainer.learn(wrapper) + + dist.destroy_process_group() + + +if __name__ == '__main__': + fire.Fire(train) diff --git a/vidar/__init__.py b/vidar/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/__init__.py b/vidar/arch/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/blocks/depth/SigmoidToDepth.py b/vidar/arch/blocks/depth/SigmoidToDepth.py new file mode 100755 index 0000000000000000000000000000000000000000..1bf753e00f4e1b822c2dcdc2a8545e2f04fd65a3 --- /dev/null +++ b/vidar/arch/blocks/depth/SigmoidToDepth.py @@ -0,0 +1,30 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch.nn as nn + +from vidar.utils.decorators import iterate2 + + +class SigmoidToDepth(nn.Module, ABC): + """ + Converts sigmoid to depth map + + Parameters + ---------- + min_depth : Float + Minimum depth value + max_depth + Maximum depth value + """ + def __init__(self, min_depth, max_depth): + super().__init__() + self.min_depth = min_depth + self.max_depth = max_depth + self.diff_depth = (self.max_depth - self.min_depth) + + @iterate2 + def forward(self, sigmoid): + """Convert sigmoid to depth""" + return self.min_depth + self.diff_depth * sigmoid diff --git a/vidar/arch/blocks/depth/SigmoidToInvDepth.py b/vidar/arch/blocks/depth/SigmoidToInvDepth.py new file mode 100755 index 0000000000000000000000000000000000000000..70eb16903861fa711d0ea169ab59f95f9b280396 --- /dev/null +++ b/vidar/arch/blocks/depth/SigmoidToInvDepth.py @@ -0,0 +1,35 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch.nn as nn + +from vidar.utils.decorators import iterate2 +from vidar.utils.depth import inv2depth + + +class SigmoidToInvDepth(nn.Module, ABC): + """ + Converts sigmoid to inverse depth map + + Parameters + ---------- + min_depth : Float + Minimum depth value + max_depth + Maximum depth value + return_depth: + Whether the inverse depth map is inverted to depth when returning + """ + def __init__(self, min_depth, max_depth, return_depth=False): + super().__init__() + self.min_inv_depth = 1. / max_depth + self.max_inv_depth = 1. 
/ min_depth + self.diff_inv_depth = (self.max_inv_depth - self.min_inv_depth) + self.return_depth = return_depth + + @iterate2 + def forward(self, sigmoid): + """Convert sigmoid to inverse depth""" + inv_depth = self.min_inv_depth + self.diff_inv_depth * sigmoid + return inv_depth if not self.return_depth else inv2depth(inv_depth) diff --git a/vidar/arch/blocks/depth/SigmoidToLogDepth.py b/vidar/arch/blocks/depth/SigmoidToLogDepth.py new file mode 100755 index 0000000000000000000000000000000000000000..8b1a4c631f5b028bfccd1d1894a01b0dfa6b2897 --- /dev/null +++ b/vidar/arch/blocks/depth/SigmoidToLogDepth.py @@ -0,0 +1,21 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch +import torch.nn as nn + +from vidar.utils.decorators import iterate2 + + +class SigmoidToLogDepth(nn.Module, ABC): + """ + Converts sigmoid to a log depth map + """ + def __init__(self): + super().__init__() + + @iterate2 + def forward(self, sigmoid): + """Convert sigmoid to log depth""" + return torch.exp(sigmoid) diff --git a/vidar/arch/blocks/image/ViewSynthesis.py b/vidar/arch/blocks/image/ViewSynthesis.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0eed8d05ecbf4870209f37293d005d5acc5ded --- /dev/null +++ b/vidar/arch/blocks/image/ViewSynthesis.py @@ -0,0 +1,116 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from functools import partial + +import torch +import torch.nn as nn + +from vidar.utils.flow import coords_from_optical_flow +from vidar.utils.tensor import grid_sample, interpolate +from vidar.utils.types import is_list + + +class ViewSynthesis(nn.Module, ABC): + """ + Class for view synthesis calculation based on image warping + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg=None): + super().__init__() + self.grid_sample = partial( + grid_sample, mode='bilinear', padding_mode='border', align_corners=True) + self.interpolate = partial( + interpolate, mode='bilinear', scale_factor=None, align_corners=True) + self.grid_sample_zeros = partial( + grid_sample, mode='nearest', padding_mode='zeros', align_corners=True) + self.upsample_depth = cfg.has('upsample_depth', True) if cfg is not None else True + + @staticmethod + def get_num_scales(depths, optical_flow): + """Return number of scales based on input""" + if depths is not None: + return len(depths) + if optical_flow is not None: + return len(optical_flow) + else: + raise ValueError('Invalid inputs for view synthesis') + + @staticmethod + def get_tensor_ones(depths, optical_flow, scale): + """Return unitary tensor based on input""" + if depths is not None: + return torch.ones_like(depths[scale]) + elif optical_flow is not None: + b, _, h, w = optical_flow[scale].shape + return torch.ones((b, 1, h, w), device=optical_flow[scale].device) + else: + raise ValueError('Invalid inputs for view synthesis') + + def get_coords(self, rgbs, depths, cams, optical_flow, context, scale, tgt): + """ + Calculate projection coordinates for warping + + Parameters + ---------- + rgbs : list[torch.Tensor] + Input images (for dimensions) [B,3,H,W] + depths : list[torch.Tensor] + Target depth maps [B,1,H,W] + cams : list[Camera] + Input cameras + optical_flow : list[torch.Tensor] + Input optical flow for alternative warping + context : list[Int] + Context indices + scale : Int + Current scale + tgt : Int + Target index + + Returns + ------- + output : Dict + Dictionary containing 
warped images and masks + """ + if depths is not None and cams is not None: + cams_tgt = cams[0] if is_list(cams) else cams + cams_ctx = cams[1] if is_list(cams) else cams + depth = self.interpolate(depths[scale], size=rgbs[tgt][0].shape[-2:]) \ + if self.upsample_depth else depths[scale] + return { + ctx: cams_tgt[tgt].coords_from_depth(depth, cams_ctx[ctx]) for ctx in context + } + elif optical_flow is not None: + return { + ctx: coords_from_optical_flow( + optical_flow[scale]).permute(0, 2, 3, 1) for ctx in context + } + else: + raise ValueError('Invalid input for view synthesis') + + def forward(self, rgbs, depths=None, cams=None, + optical_flow=None, return_masks=False, tgt=0): + + context = [key for key in rgbs.keys() if key != tgt] + num_scales = self.get_num_scales(depths, optical_flow) + + warps, masks = [], [] + for scale in range(num_scales): + src = 0 if self.upsample_depth else scale + coords = self.get_coords(rgbs, depths, cams, optical_flow, context, scale, tgt=tgt) + warps.append([self.grid_sample( + rgbs[ctx][src], coords[ctx].type(rgbs[ctx][src].dtype)) for ctx in context]) + if return_masks: + ones = self.get_tensor_ones(depths, optical_flow, scale) + masks.append([self.grid_sample_zeros( + ones, coords[ctx].type(ones.dtype)) for ctx in context]) + + return { + 'warps': warps, + 'masks': masks if return_masks else None + } diff --git a/vidar/arch/losses/BaseLoss.py b/vidar/arch/losses/BaseLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..d614b61af482fb04830b0a32cee2ed9045f03e82 --- /dev/null +++ b/vidar/arch/losses/BaseLoss.py @@ -0,0 +1,125 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from collections import OrderedDict +from functools import partial + +import torch.nn as nn + +from vidar.utils.tensor import same_shape, interpolate + + +class BaseLoss(nn.Module, ABC): + """ + Base class for loss calculation + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg=None): + super().__init__() + + self.losses = OrderedDict() + self.blocks = OrderedDict() + + self.nearest = partial(interpolate, scale_factor=None, mode='nearest', align_corners=None) + self.bilinear = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True) + + if cfg is not None: + self.gamma = cfg.has('gamma', 1.0) + self.weight = cfg.has('weight', 1.0) + self.scales = cfg.has('scales', 99) + + self.flag_mask_sparse = cfg.has('mask_sparse', False) + self.flag_mask_range = cfg.has('mask_range', None) + + def forward(self, *args, **kwargs): + """Forward method""" + raise NotImplementedError('Forward not implemented for {}'.format(self.__name__)) + + def get_weights(self, scales): + """Get scale weights""" + return [self.weight * self.gamma ** i for i in range(scales)] + + def get_scales(self, scales): + """Get number of scales""" + return min(self.scales, len(scales)) + + @staticmethod + def interp(dst, src, fn): + """Interpolate dst to match src using fn""" + if dst is None or dst.dim() == 3: + return dst + assert dst.dim() == src.dim() + if dst.dim() == 4 and not same_shape(dst.shape, src.shape): + dst = fn(dst, size=src) + return dst + + def interp_bilinear(self, dst, src): + """Bilinear interpolation""" + return self.interp(dst, src, self.bilinear) + + def interp_nearest(self, dst, src): + """Nearest-neighbor interpolation""" + return self.interp(dst, src, self.nearest) + + def mask_sparse(self, mask, gt): + """Mask based on sparse GT""" + if mask is 
None: + return mask + if self.flag_mask_sparse: + mask *= gt.sum(1) > 0 + return mask + + def mask_range(self, mask, gt): + """Mask based on depth range""" + if mask is None: + return mask + if self.flag_mask_range is None: + return mask + mask *= (gt.sum(1) >= self.flag_mask_range[0]) & \ + (gt.sum(1) <= self.flag_mask_range[1]) + return mask + + @staticmethod + def flatten(pred, gt, mask=None, soft_mask=None): + """ + Flatten 2D inputs for loss calculation + + Parameters + ---------- + pred : torch.Tensor + Input predictions + gt : torch.Tensor + Input ground-truth + mask : torch.Tensor or None + Input mask (binary) + soft_mask : torch.Tensor or None + Input soft mask (probability) + + Returns + ------- + pred, gt, mask, soft_mask : torch.Tensor + Flattened inputs + """ + if pred.dim() == 4: + pred = pred.permute(0, 2, 3, 1) + pred = pred.reshape(-1, pred.shape[-1]) + + if gt.dim() == 4: + gt = gt.permute(0, 2, 3, 1) + gt = gt.reshape(-1, gt.shape[-1]) + + if mask is not None: + if mask.dim() == 4: + mask = mask.permute(0, 2, 3, 1) + mask = mask.reshape(-1) + + if soft_mask is not None: + if soft_mask.dim() == 4: + soft_mask = soft_mask.permute(0, 2, 3, 1) + soft_mask = soft_mask.reshape(-1) + + return pred, gt, mask, soft_mask diff --git a/vidar/arch/losses/ConsistencyLoss.py b/vidar/arch/losses/ConsistencyLoss.py new file mode 100755 index 0000000000000000000000000000000000000000..15a2a27859d8f22ae4600a90ca0802cea8f054b1 --- /dev/null +++ b/vidar/arch/losses/ConsistencyLoss.py @@ -0,0 +1,101 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from functools import partial + +import torch + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.utils.data import get_mask_from_list +from vidar.utils.tensor import interpolate, same_shape, multiply_mask, masked_average +from vidar.utils.types import is_list + + +class ConsistencyLoss(BaseLoss, ABC): + """ + Consistency loss class + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + self.interpolate = partial( + interpolate, mode='nearest', scale_factor=None, align_corners=None) + + def calculate(self, teacher, student, confidence_mask, valid_mask=None): + """ + Calculate consistency loss + + Parameters + ---------- + teacher : torch.Tensor + Teacher depth predictions [B,1,H,W] + student : torch.Tensor + Student depth predictions [B,1,H,W] + confidence_mask : torch.Tensor + Confidence mask for pixel selection [B,1,H,W] + valid_mask : torch.Tensor + Valid mask for pixel selection [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Consistency loss [1] + """ + if not same_shape(teacher.shape[-2:], student.shape[-2:]): + teacher = self.interpolate(teacher, size=student.shape[-2:]) + if not same_shape(confidence_mask.shape, teacher.shape): + confidence_mask = self.interpolate(confidence_mask, size=teacher.shape[-2:]) + if valid_mask is not None and not same_shape(valid_mask.shape, teacher.shape): + valid_mask = self.interpolate(valid_mask, size=teacher.shape[-2:]) + non_confidence_mask = (1 - confidence_mask).float() + consistency_loss = torch.abs(student - teacher.detach()) + return masked_average(consistency_loss, + multiply_mask(non_confidence_mask, valid_mask)) + + def forward(self, teacher, student, confidence_mask, valid_mask=None): + """ + Forward loop for loss calculation + + Parameters + ---------- + teacher : list[torch.Tensor] + Teacher depth predictions [B,1,H,W] + student : 
list[torch.Tensor] + Student depth predictions [B,1,H,W] + confidence_mask : list[torch.Tensor] + Confidence mask for pixel selection [B,1,H,W] + valid_mask : list[torch.Tensor] + Valid mask for pixel selection [B,1,H,W] + + Returns + ------- + output : Dict + Dictionary with loss and metrics + """ + scales = self.get_scales(student) + weights = self.get_weights(scales) + + losses, metrics = [], {} + + for i in range(scales): + teacher_i = teacher[i] if is_list(teacher) else teacher + student_i = student[i] if is_list(student) else student + confidence_mask_i = get_mask_from_list(confidence_mask, i) + valid_mask_i = get_mask_from_list(valid_mask, i) + + loss_i = weights[i] * self.calculate( + teacher_i, student_i, confidence_mask_i, valid_mask_i) + + metrics[f'consistency_loss/{i}'] = loss_i.detach() + losses.append(loss_i) + + loss = sum(losses) / len(losses) + + return { + 'loss': loss, + 'metrics': metrics, + } diff --git a/vidar/arch/losses/MultiCamPhotometricLoss.py b/vidar/arch/losses/MultiCamPhotometricLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3b5ed313f6b5a6fc229daea437ad85f0802918 --- /dev/null +++ b/vidar/arch/losses/MultiCamPhotometricLoss.py @@ -0,0 +1,173 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch + +from vidar.arch.losses.MultiViewPhotometricLoss import MultiViewPhotometricLoss +from vidar.arch.networks.layers.fsm.utils import coords_from_motion, warp_from_coords, mask_from_coords +from vidar.utils.depth import inv2depth +from vidar.utils.tensor import match_scales +from vidar.utils.types import is_list, is_double_list + + +class MultiCamPhotometricLoss(MultiViewPhotometricLoss): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Large value for loss masking + self.inf = 999999 + self.align_corners = True + + def warp(self, rgb_context, inv_depths, cam, cam_context, + scene_flow=None, with_mask=True): + # Initialize empty warp and mask list + warps_context, masks_context = [], [] + # If mask is available, use it instead of calculating + if is_list(with_mask): + masks_context, with_mask = with_mask, False + # Match inverse depth scales on reference images if necessary + rgbs_context = rgb_context if is_double_list(rgb_context) else \ + [match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners) + for rgb in rgb_context] + # Warp each reference image to target + for j, (ref_rgbs, ref_cam) in enumerate(zip(rgbs_context, cam_context)): + # Get warping coordinates + ref_coords = [coords_from_motion(ref_cam, inv2depth(inv_depths[i]), cam) for i in range(self.n)] + # Get warped images + warps_context.append([warp_from_coords( + ref_rgbs[i], ref_coords[i], align_corners=self.align_corners, padding_mode='zeros') #'reflection') + for i in range(self.n)]) + # Get warped masks if requested + if with_mask: + masks_context.append([mask_from_coords(ref_coords[i]) for i in range(self.n)]) + # Return warped reference images + return warps_context, masks_context + + def reduce_photometric_loss_min(self, photometric_losses, + unwarped_photometric_losses=None): + """ + Reduce photometric losses using minimum reprojection error + + Parameters + ---------- + photometric_losses : list[Tensor] + Photometric losses for each warped image [B,3,H,W] + + unwarped_photometric_losses : list[Tensor] + Unwarped photometric losses for each reference image [B,3,H,W] + + Returns + ------- + reduced_photometric_loss : Tensor + Reduced loss value (single value) + min_photometric_loss : Tensor + 
Masked photometric loss [B,1,H,W] + """ + # Calculate minimum photometric losses + min_photometric_loss = [torch.cat(losses, 1).min(1, True)[0] + for losses in photometric_losses] + # Get invalid minimum mask + valid_mask = [warped < self.inf for warped in min_photometric_loss] + # If unwarped photometric losses are provided + if unwarped_photometric_losses is not None and \ + len(unwarped_photometric_losses[0]) > 0: + # Calculate minimum unwarped photometric losses + min_unwarped_photometric_loss = [torch.cat(losses, 1).min(1, True)[0] + for losses in unwarped_photometric_losses] + # Get minimum mask (warped < unwarped) + minimum_mask = [warped < unwarped for warped, unwarped in + zip(min_photometric_loss, min_unwarped_photometric_loss)] + # Update valid mask with minimum mask + valid_mask = [minimum & valid for minimum, valid in + zip(minimum_mask, valid_mask)] + # Get reduced photometric loss + reduced_photometric_loss = sum( + [loss[mask].mean() for mask, loss in + zip(valid_mask, min_photometric_loss)]) / len(min_photometric_loss) + # Mask min photometric loss for visualization + for i in range(len(min_photometric_loss)): + min_photometric_loss[i][~valid_mask[i]] = 0 + # Store and return reduced photometric loss + return reduced_photometric_loss, min_photometric_loss + + def reduce_photometric_loss_mean(self, photometric_losses, + unwarped_photometric_losses=None): + """ + Reduce photometric losses using minimum reprojection error + + Parameters + ---------- + photometric_losses : list[Tensor] + Photometric losses for each warped image [B,3,H,W] + + unwarped_photometric_losses : list[Tensor] + Unwarped photometric losses for each reference image [B,3,H,W] + + Returns + ------- + reduced_photometric_loss : Tensor + Reduced loss value (single value) + min_photometric_loss : Tensor + Masked photometric loss [B,1,H,W] + """ + valid_mask = [[w < self.inf for w in warped] for warped in photometric_losses] + if unwarped_photometric_losses is not None and \ + len(unwarped_photometric_losses[0]) > 0: + # Get minimum mask (warped < unwarped) + minimum_mask = [[w < u for w, u in zip(warped, unwarped)] for warped, unwarped in + zip(photometric_losses, unwarped_photometric_losses)] + # Update valid mask with minimum mask + valid_mask = [[m & v for m, v in zip(minimum, valid)] + for minimum, valid in zip(minimum_mask, valid_mask)] + reduced_photometric_loss = [] + for i in range(len(photometric_losses)): + for j in range(len(photometric_losses[i])): + loss = photometric_losses[i][j][valid_mask[i][j]].mean() + if not torch.isnan(loss): + reduced_photometric_loss.append(loss) + reduced_photometric_loss = sum(reduced_photometric_loss) / len(reduced_photometric_loss) + # Store and return reduced photometric loss + return reduced_photometric_loss, [photometric_losses[0][0]] + + def forward(self, rgb, rgb_context, inv_depths, + cam, cam_context, return_logs=False, progress=0.0, + opt_flow=None, scene_flow=None, with_mask=False, automask=None): + # Initialize photometric losses + photometric_losses = [[] for _ in range(self.n)] + unwarped_photometric_losses = [[] for _ in range(self.n)] + # Create RGB scales + rgbs = match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners) + rgbs_context = [match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners) + for rgb in rgb_context] + # Warp context to target + warps_context, masks_context = self.warp( + rgbs_context, inv_depths, cam, cam_context, + scene_flow=scene_flow, with_mask=with_mask) + for j in range(len(rgbs_context)): + # 
Calculate and store image loss + photometric_loss = self.calc_photometric_loss(warps_context[j], rgbs) + for i in range(self.n): + if with_mask: + # Apply mask if available + photometric_loss[i][~masks_context[j][i]] = self.inf + # Stack photometric losses for each scale + photometric_losses[i].append(photometric_loss[i]) + # If using automask, calculate and store unwarped losses + if self.automask_loss and automask is not False: + unwarped_image_loss = self.calc_photometric_loss(rgbs_context[j], rgbs) + for i in range(self.n): + unwarped_photometric_losses[i].append(unwarped_image_loss[i]) + # Calculate reduced photometric loss + reduced_loss, masked_loss = self.reduce_photometric_loss_min( + photometric_losses, unwarped_photometric_losses) + # Include smoothness loss if requested + if self.smooth_loss_weight > 0.0: + reduced_loss += self.calc_smoothness_loss(inv_depths, rgbs) + # Remove masks from warps_context + return { + 'loss': reduced_loss.unsqueeze(0), + 'metrics': {}, + 'warp': warps_context, + 'masks': masks_context, + 'photo': masked_loss, + } + +######################################################################################################################## diff --git a/vidar/arch/losses/MultiViewPhotometricLoss.py b/vidar/arch/losses/MultiViewPhotometricLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..cebbac03c9f8b34786663be0d603bbdd2c0613d5 --- /dev/null +++ b/vidar/arch/losses/MultiViewPhotometricLoss.py @@ -0,0 +1,323 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as tf + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.geometry.camera import Camera +from vidar.utils.depth import inv2depth +from vidar.utils.tensor import match_scales + + +def view_synthesis(ref_image, depth, ref_cam, cam, + mode='bilinear', padding_mode='zeros', align_corners=True): + assert depth.shape[1] == 1, 'Depth map should have C=1' + # Reconstruct world points from target_camera + world_points = cam.reconstruct(depth, frame='w') + # Project world points onto reference camera + ref_coords = ref_cam.project(world_points, frame='w') + # View-synthesis given the projected reference points + return tf.grid_sample(ref_image, ref_coords, mode=mode, + padding_mode=padding_mode, align_corners=align_corners) + + +def gradient_x(image): + return image[:, :, :, :-1] - image[:, :, :, 1:] + + +def gradient_y(image): + return image[:, :, :-1, :] - image[:, :, 1:, :] + + +def inv_depths_normalize(inv_depths): + mean_inv_depths = [inv_depth.mean(2, True).mean(3, True) for inv_depth in inv_depths] + return [inv_depth / mean_inv_depth.clamp(min=1e-6) + for inv_depth, mean_inv_depth in zip(inv_depths, mean_inv_depths)] + + +def calc_smoothness(inv_depths, images, num_scales): + inv_depths_norm = inv_depths_normalize(inv_depths) + inv_depth_gradients_x = [gradient_x(d) for d in inv_depths_norm] + inv_depth_gradients_y = [gradient_y(d) for d in inv_depths_norm] + + image_gradients_x = [gradient_x(image) for image in images] + image_gradients_y = [gradient_y(image) for image in images] + + weights_x = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_x] + weights_y = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_y] + + # Note: Fix gradient addition + smoothness_x = [inv_depth_gradients_x[i] * weights_x[i] for i in range(num_scales)] + smoothness_y = [inv_depth_gradients_y[i] * weights_y[i] 
for i in range(num_scales)] + return smoothness_x, smoothness_y + + +def SSIM(x, y, C1=1e-4, C2=9e-4, kernel_size=3, stride=1): + """ + Structural Similarity (SSIM) distance between two images. + + Parameters + ---------- + x,y : torch.Tensor + Input images [B,3,H,W] + C1,C2 : float + SSIM parameters + kernel_size,stride : int + Convolutional parameters + + Returns + ------- + ssim : torch.Tensor + SSIM distance [1] + """ + pool2d = nn.AvgPool2d(kernel_size, stride=stride) + refl = nn.ReflectionPad2d(1) + + x, y = refl(x), refl(y) + mu_x = pool2d(x) + mu_y = pool2d(y) + + mu_x_mu_y = mu_x * mu_y + mu_x_sq = mu_x.pow(2) + mu_y_sq = mu_y.pow(2) + + sigma_x = pool2d(x.pow(2)) - mu_x_sq + sigma_y = pool2d(y.pow(2)) - mu_y_sq + sigma_xy = pool2d(x * y) - mu_x_mu_y + v1 = 2 * sigma_xy + C2 + v2 = sigma_x + sigma_y + C2 + + ssim_n = (2 * mu_x_mu_y + C1) * v1 + ssim_d = (mu_x_sq + mu_y_sq + C1) * v2 + ssim = ssim_n / ssim_d + + return ssim + + +class MultiViewPhotometricLoss(BaseLoss, ABC): + """ + Self-Supervised multiview photometric loss. + It takes two images, a depth map and a pose transformation to produce a + reconstruction of one image from the perspective of the other, and calculates + the difference between them + + Parameters + ---------- + num_scales : int + Number of inverse depth map scales to consider + ssim_loss_weight : float + Weight for the SSIM loss + occ_reg_weight : float + Weight for the occlusion regularization loss + smooth_loss_weight : float + Weight for the smoothness loss + C1,C2 : float + SSIM parameters + photometric_reduce_op : str + Method to reduce the photometric loss + disp_norm : bool + True if inverse depth is normalized for + clip_loss : float + Threshold for photometric loss clipping + progressive_scaling : float + Training percentage for progressive scaling (0.0 to disable) + padding_mode : str + Padding mode for view synthesis + automask_loss : bool + True if automasking is enabled for the photometric loss + kwargs : dict + Extra parameters + """ + def __init__(self, num_scales=4, ssim_loss_weight=0.85, occ_reg_weight=0.1, smooth_loss_weight=0.1, + C1=1e-4, C2=9e-4, photometric_reduce_op='mean', disp_norm=True, clip_loss=0.5, + progressive_scaling=0.0, padding_mode='zeros', automask_loss=False, **kwargs): + super().__init__() + self.n = num_scales + self.ssim_loss_weight = ssim_loss_weight + self.occ_reg_weight = occ_reg_weight + self.smooth_loss_weight = smooth_loss_weight + self.C1 = C1 + self.C2 = C2 + self.photometric_reduce_op = photometric_reduce_op + self.disp_norm = disp_norm + self.clip_loss = clip_loss + self.padding_mode = padding_mode + self.automask_loss = automask_loss + + # Asserts + if self.automask_loss: + assert self.photometric_reduce_op == 'min', \ + 'For automasking only the min photometric_reduce_op is supported.' + +######################################################################################################################## + + @property + def logs(self): + """Returns class logs.""" + return { + 'num_scales': self.n, + } + +######################################################################################################################## + + def warp_ref_image(self, inv_depths, ref_image, K, ref_K, pose): + """ + Warps a reference image to produce a reconstruction of the original one. 
+ + Parameters + ---------- + inv_depths : list[torch.Tensor] + Inverse depth map of the original image [B,1,H,W] + ref_image : torch.Tensor + Reference RGB image [B,3,H,W] + K : torch.Tensor + Original camera intrinsics [B,3,3] + ref_K : torch.Tensor + Reference camera intrinsics [B,3,3] + pose : Pose + Original -> Reference camera transformation + + Returns + ------- + ref_warped : torch.Tensor + Warped reference image (reconstructing the original one) [B,3,H,W] + """ + B, _, H, W = ref_image.shape + device = ref_image.device + # Generate cameras for all scales + cams, ref_cams = [], [] + for i in range(self.n): + _, _, DH, DW = inv_depths[i].shape + scale_factor = DW / float(W) + cams.append(Camera(K=K.float()).scaled(scale_factor).to(device)) + ref_cams.append(Camera(K=ref_K.float(), Tcw=pose).scaled(scale_factor).to(device)) + # View synthesis + depths = [inv2depth(inv_depths[i]) for i in range(self.n)] + ref_images = match_scales(ref_image, inv_depths, self.n) + ref_warped = [view_synthesis( + ref_images[i], depths[i], ref_cams[i], cams[i], + padding_mode=self.padding_mode) for i in range(self.n)] + # Return warped reference image + return ref_warped + +######################################################################################################################## + + def SSIM(self, x, y, kernel_size=3): + """ + Calculates the SSIM (Structural Similarity) loss + + Parameters + ---------- + x,y : torch.Tensor + Input images [B,3,H,W] + kernel_size : int + Convolutional parameter + + Returns + ------- + ssim : torch.Tensor + SSIM loss [1] + """ + ssim_value = SSIM(x, y, C1=self.C1, C2=self.C2, kernel_size=kernel_size) + return torch.clamp((1. - ssim_value) / 2., 0., 1.) + + def calc_photometric_loss(self, t_est, images): + """ + Calculates the photometric loss (L1 + SSIM) + Parameters + ---------- + t_est : list[torch.Tensor] + List of warped reference images in multiple scales [B,3,H,W] + images : list[torch.Tensor] + List of original images in multiple scales [B,3,H,W] + + Returns + ------- + photometric_loss : list[torch.Tensor] + Photometric loss [B,1,H,W] + """ + # L1 loss + l1_loss = [torch.abs(t_est[i] - images[i]) + for i in range(self.n)] + # SSIM loss + if self.ssim_loss_weight > 0.0: + ssim_loss = [self.SSIM(t_est[i], images[i], kernel_size=3) + for i in range(self.n)] + # Weighted Sum: alpha * ssim + (1 - alpha) * l1 + photometric_loss = [self.ssim_loss_weight * ssim_loss[i].mean(1, True) + + (1 - self.ssim_loss_weight) * l1_loss[i].mean(1, True) + for i in range(self.n)] + else: + photometric_loss = l1_loss + # Clip loss + if self.clip_loss > 0.0: + for i in range(self.n): + mean, std = photometric_loss[i].mean(), photometric_loss[i].std() + photometric_loss[i] = torch.clamp( + photometric_loss[i], max=float(mean + self.clip_loss * std)) + # Return total photometric loss + return photometric_loss + + def reduce_photometric_loss(self, photometric_losses): + """ + Combine the photometric loss from all context images + + Parameters + ---------- + photometric_losses : list[list[torch.Tensor]] + Pixel-wise photometric losses from the entire context [B,3,H,W] + + Returns + ------- + photometric_loss : torch.Tensor + Reduced photometric loss [1] + """ + # Reduce function + def reduce_function(losses): + if self.photometric_reduce_op == 'mean': + return sum([l.mean() for l in losses]) / len(losses) + elif self.photometric_reduce_op == 'min': + return torch.cat(losses, 1).min(1, True)[0].mean() + else: + raise NotImplementedError( + 'Unknown photometric_reduce_op: 
{}'.format(self.photometric_reduce_op)) + # Reduce photometric loss + photometric_loss = sum([reduce_function(photometric_losses[i]) + for i in range(self.n)]) / self.n + # Store and return reduced photometric loss + return photometric_loss + +######################################################################################################################## + + def calc_smoothness_loss(self, inv_depths, images): + """ + Calculates the smoothness loss for inverse depth maps. + + Parameters + ---------- + inv_depths : list[torch.Tensor] + Predicted inverse depth maps for all scales [B,1,H,W] + images : list[torch.Tensor] + Original images for all scales [B,3,H,W] + + Returns + ------- + smoothness_loss : torch.Tensor + Smoothness loss [1] + """ + # Calculate smoothness gradients + smoothness_x, smoothness_y = calc_smoothness(inv_depths, images, self.n) + # Calculate smoothness loss + smoothness_loss = sum([(smoothness_x[i].abs().mean() + + smoothness_y[i].abs().mean()) / 2 ** i + for i in range(self.n)]) / self.n + # Apply smoothness loss weight + smoothness_loss = self.smooth_loss_weight * smoothness_loss + # Store and return smoothness loss + return smoothness_loss + + + diff --git a/vidar/arch/losses/PhotometricLoss.py b/vidar/arch/losses/PhotometricLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..d29fbf532fd05876a6361c9819f462e9d7cb857c --- /dev/null +++ b/vidar/arch/losses/PhotometricLoss.py @@ -0,0 +1,34 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from functools import partial + +import torch + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.arch.losses.SSIMLoss import SSIMLoss +from vidar.utils.tensor import interpolate + + +class PhotometricLoss(BaseLoss, ABC): + def __init__(self, cfg): + super().__init__() + self.alpha = cfg.alpha + self.ssim_loss = SSIMLoss() + self.interpolate = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True) + + def forward(self, pred, gt): + + pred = self.interpolate(pred, size=gt) + l1_loss = torch.abs(pred - gt).mean(1, True) + + if self.alpha == 0.0: + photometric_loss = l1_loss + else: + ssim_loss = self.ssim_loss(pred, gt)['loss'].mean(1, True) + photometric_loss = self.alpha * ssim_loss + (1.0 - self.alpha) * l1_loss + + return { + 'loss': photometric_loss, + } + diff --git a/vidar/arch/losses/ReprojectionLoss.py b/vidar/arch/losses/ReprojectionLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..4f208cad2bb1731a1015184afa46a96619453aa4 --- /dev/null +++ b/vidar/arch/losses/ReprojectionLoss.py @@ -0,0 +1,230 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
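For orientation before the class definition: the reduction implemented below follows the minimum-reprojection / automasking recipe, where each pixel keeps the smallest photometric error over all context images and is dropped when the un-warped (identity) error is already smaller. A minimal standalone sketch of that reduction, with made-up tensor names and no dependence on the vidar utilities:

import torch

def min_reprojection_with_automask(warp_errors, identity_errors):
    # warp_errors / identity_errors: lists of per-pixel photometric errors [B,1,H,W]
    warped = torch.cat(warp_errors, dim=1).min(dim=1, keepdim=True)[0]
    identity = torch.cat(identity_errors, dim=1).min(dim=1, keepdim=True)[0]
    # Automask: keep only pixels where warping beats doing nothing
    mask = (warped < identity).float()
    # Average over surviving pixels only
    return (warped * mask).sum() / mask.sum().clamp(min=1.0)

# Toy usage with two context frames
errors = [torch.rand(2, 1, 4, 5) for _ in range(2)]
identities = [torch.rand(2, 1, 4, 5) for _ in range(2)]
loss = min_reprojection_with_automask(errors, identities)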
+ +from abc import ABC +from functools import partial + +import torch + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.utils.config import cfg_has +from vidar.utils.data import get_from_list, get_mask_from_list +from vidar.utils.tensor import interpolate, multiply_args, masked_average +from vidar.utils.types import is_list + + +class ReprojectionLoss(BaseLoss, ABC): + """ + Reprojection loss class + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + self.automasking = cfg.automasking + self.reprojection_reduce_op = cfg.reprojection_reduce_op + self.jitter_identity_reprojection = cfg.jitter_identity_reprojection + self.logvar_weight = cfg_has(cfg, 'logvar_weight', 0.0) + self.feature_weight = cfg_has(cfg, 'feature_weight', 0.0) + + self.interpolate = partial( + interpolate, mode='bilinear', scale_factor=None, align_corners=True) + + self.inf = 1e6 + + @staticmethod + def compute_reprojection_mask(reprojection_loss, identity_reprojection_loss): + """ + Compute reprojection mask based for automasking + + Parameters + ---------- + reprojection_loss : torch.Tensor + Warped reprojection loss [B,1,H,W] + identity_reprojection_loss : torch.Tensor + Identity reprojection loss [B,1,H,W] + + Returns + ------- + mask : torch.Tensor + Reprojection mask for automasking [B,1,H,W] + """ + if identity_reprojection_loss is None: + reprojection_mask = torch.ones_like(reprojection_loss) + else: + all_losses = torch.cat([reprojection_loss, identity_reprojection_loss], dim=1) + idxs = torch.argmin(all_losses, dim=1, keepdim=True) + reprojection_mask = (idxs == 0) + return reprojection_mask + + def reduce_reprojection(self, reprojection_losses, overlap_mask=None): + """ + Combine multi-image reprojection losses + + Parameters + ---------- + reprojection_losses : list[torch.Tensor] + Per-image reprojection losses + overlap_mask : list[torch.Tensor] or None + Valid mask to remove pixels + + Returns + ------- + reprojection_loss : torch.Tensor + Output loss [1] + overlap_mask : torch.Tensor + Reduced overlap mask + """ + if is_list(reprojection_losses): + reprojection_losses = torch.cat(reprojection_losses, 1) + if self.reprojection_reduce_op == 'mean': + assert overlap_mask is None, 'Not implemented yet' + reprojection_loss = reprojection_losses.mean(1, keepdim=True) + elif self.reprojection_reduce_op == 'min': + if overlap_mask is not None: + if is_list(overlap_mask): + overlap_mask = torch.cat(overlap_mask, 1) + reprojection_losses[~overlap_mask.bool()] = self.inf + reprojection_loss, _ = torch.min(reprojection_losses, dim=1, keepdim=True) + overlap_mask = reprojection_loss < self.inf + reprojection_loss[~overlap_mask] = 0.0 # For visualization purposes + else: + raise ValueError( + f'Invalid reprojection reduce operation: {self.reprojection_reduce_op}') + return reprojection_loss, overlap_mask + + def calculate(self, rgb, rgb_context, warps, logvar=None, + valid_mask=None, overlap_mask=None): + """ + Calculate reprojection loss + + Parameters + ---------- + rgb : torch.Tensor + Target image [B,3,H,W] + rgb_context : list[torch.Tensor] + List of context images [B,3,H,W] + warps : list[torch.Tensor] + List of warped images from view synthesis [B,3,H,W] + logvar : list[torch.Tensor] + Log variance for log-likelihood calculation + valid_mask : torch.Tensor or None + Valid mask for pixel filtering + overlap_mask : torch.Tensor or None + Overlap mask for pixel filtering + + Returns + ------- + average_loss : torch.Tensor + 
Output loss [1] + reprojection_mask : torch.Tensor + Combined reprojection mask (overlap + reprojection + valid) [B,1,H,W] + reprojection_loss : torch.Tensor + Per-pixel loss [B,1,H,W] + overlap_mask : torch.Tensor + Combined overlap mask [B,1,H,W] + """ + reprojection_losses = [ + self.losses['photometric'](warp, rgb)['loss'] for warp in warps] + reprojection_loss, overlap_mask = self.reduce_reprojection( + reprojection_losses, overlap_mask=overlap_mask) + + if 'featuremetric' in self.losses.keys(): + featuremetric_loss = [ + self.losses['featuremetric'](warp, rgb)['loss'] for warp in warps + ] + reduced_featuremetric_loss = torch.cat(reprojection_losses, 1).mean() + + if self.automasking: + reprojection_identity_losses = [ + self.losses['photometric'](context, rgb)['loss'] for context in rgb_context] + reprojection_identity_loss, _ = self.reduce_reprojection( + reprojection_identity_losses) + if self.jitter_identity_reprojection > 0: + reprojection_identity_loss += self.jitter_identity_reprojection * \ + torch.randn(reprojection_identity_loss.shape, device=reprojection_identity_loss.device) + else: + reprojection_identity_loss = None + + reprojection_mask = self.compute_reprojection_mask( + reprojection_loss, reprojection_identity_loss, + # reprojection_mask=valid_mask + # reprojection_mask=multiply_any(reprojection_mask, overlap_mask) + ) + reprojection_mask = multiply_args(reprojection_mask, valid_mask, overlap_mask) + + if logvar is not None and self.logvar_weight > 0.0: + logvar = self.interpolate(logvar, reprojection_loss.shape[-2:]) + reprojection_loss = reprojection_loss * torch.exp(-logvar) + + average_loss = masked_average(reprojection_loss, reprojection_mask) + # reprojection_loss *= reprojection_mask # REMOVE FOR VISUALIZATION + + if logvar is not None and self.logvar_weight > 0.0: + average_loss += self.logvar_weight * masked_average(logvar, reprojection_mask) + + if 'featuremetric' in self.losses.keys() and self.feature_weight > 0.0: + featuremetric_loss = [self.losses['featuremetric'](warp, rgb)['loss'] for warp in warps] + reduced_featuremetric_loss = torch.cat(featuremetric_loss, 1).mean() + average_loss += self.feature_weight * reduced_featuremetric_loss + + return average_loss, reprojection_mask, reprojection_loss, overlap_mask + + def forward(self, rgb, rgb_context, warps, logvar=None, + valid_mask=None, overlap_mask=None): + """ + Calculate reprojection loss + + Parameters + ---------- + rgb : torch.Tensor + Target image [B,3,H,W] + rgb_context : list[torch.Tensor] + List of context images [B,3,H,W] + warps : list[torch.Tensor] + List of warped images from view synthesis [B,3,H,W] + logvar : list[torch.Tensor] + Log variance for log-likelihood calculation + valid_mask : torch.Tensor or None + Valid mask for pixel filtering + overlap_mask : torch.Tensor or None + Overlap mask for pixel filtering + + Returns + ------- + output : Dictionary with loss, metrics, masks, photometric errors, and overlap + """ + scales = self.get_scales(warps) + weights = self.get_weights(scales) + + losses, masks, photos, overlaps, metrics = [], [], [], [], {} + + for i in range(scales): + rgb_i, rgb_context_i, warps_i = rgb[0], rgb_context[0], warps[i] + valid_mask_i = get_mask_from_list(valid_mask, i) + overlap_mask_i = get_mask_from_list(overlap_mask, i) + logvar_i = get_from_list(logvar, i) + + loss_i, mask_i, photo_i, overlap_mask_i = self.calculate( + rgb_i, rgb_context_i, warps_i, logvar=logvar_i, + valid_mask=valid_mask_i, overlap_mask=overlap_mask_i) + loss_i = weights[i] * loss_i + 
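+            # Per-scale losses are weighted here, logged individually below, and averaged into the final loss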
+ metrics[f'reprojection_loss/{i}'] = loss_i.detach() + + losses.append(loss_i) + masks.append(mask_i) + photos.append(photo_i) + overlaps.append(overlap_mask_i) + + loss = sum(losses) / len(losses) + + return { + 'loss': loss, + 'metrics': metrics, + 'mask': masks, + 'photo': photos, + 'overlap': overlaps, + } \ No newline at end of file diff --git a/vidar/arch/losses/SSIMLoss.py b/vidar/arch/losses/SSIMLoss.py new file mode 100755 index 0000000000000000000000000000000000000000..eefa3912488cacd50a441d3a6a26f3632692a4fd --- /dev/null +++ b/vidar/arch/losses/SSIMLoss.py @@ -0,0 +1,62 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch +import torch.nn as nn + +from vidar.arch.losses.BaseLoss import BaseLoss + + +class SSIMLoss(BaseLoss, ABC): + """SSIM (Structural Similarity Index Metric) loss class""" + def __init__(self): + super().__init__() + + self.mu_x_pool = nn.AvgPool2d(3, 1) + self.mu_y_pool = nn.AvgPool2d(3, 1) + + self.sig_x_pool = nn.AvgPool2d(3, 1) + self.sig_y_pool = nn.AvgPool2d(3, 1) + self.sig_xy_pool = nn.AvgPool2d(3, 1) + + self.refl = nn.ReflectionPad2d(1) + + self.C1 = 0.01 ** 2 + self.C2 = 0.03 ** 2 + + def forward(self, x, y): + """ + Calculates SSIM loss + + Parameters + ---------- + x : torch.Tensor + Input image 1 [B,3,H,W] + y : torch.Tensor + Input image 2 [B,3,H,W] + + Returns + ------- + output : Dict + Dictionary with loss + """ + x = self.refl(x) + y = self.refl(y) + + mu_x = self.mu_x_pool(x) + mu_y = self.mu_y_pool(y) + + sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 + sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 + sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y + + SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) + SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) + + loss = torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) + + return { + 'loss': loss, + } + diff --git a/vidar/arch/losses/SmoothnessLoss.py b/vidar/arch/losses/SmoothnessLoss.py new file mode 100755 index 0000000000000000000000000000000000000000..248a6aa8c63bd814e05b1049d6c66af87abd2c8e --- /dev/null +++ b/vidar/arch/losses/SmoothnessLoss.py @@ -0,0 +1,93 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
+ +from abc import ABC + +import torch + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.utils.tensor import same_shape, interpolate_image + + +class SmoothnessLoss(BaseLoss, ABC): + """ + Smoothness loss class + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + self.normalize = cfg.normalize + + def calculate(self, rgb, depth): + """ + Calculate smoothness loss + + Parameters + ---------- + rgb : torch.Tensor + Input image [B,3,H,W] + depth : torch.Tensor + Predicted depth map [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Smoothness loss [1] + """ + if self.normalize: + mean_depth = depth.mean(2, True).mean(3, True) + norm_depth = depth / (mean_depth + 1e-7) + else: + norm_depth = depth + + grad_depth_x = torch.abs(norm_depth[:, :, :, :-1] - norm_depth[:, :, :, 1:]) + grad_depth_y = torch.abs(norm_depth[:, :, :-1, :] - norm_depth[:, :, 1:, :]) + + grad_rgb_x = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), 1, keepdim=True) + grad_rgb_y = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), 1, keepdim=True) + + grad_depth_x *= torch.exp(-1.0 * grad_rgb_x) + grad_depth_y *= torch.exp(-1.0 * grad_rgb_y) + + return grad_depth_x.mean() + grad_depth_y.mean() + + def forward(self, rgb, depth): + """ + Calculate smoothness loss + + Parameters + ---------- + rgb : list[torch.Tensor] + Input images [B,3,H,W] + depth : list[torch.Tensor] + Predicted depth maps [B,1,H,W] + + Returns + ------- + output : Dict + Dictionary with loss and metrics + """ + scales = self.get_scales(rgb) + weights = self.get_weights(scales) + + losses, metrics = [], {} + + for i in range(scales): + rgb_i, depth_i = rgb[i], depth[i] + if not same_shape(rgb_i.shape[-2:], depth_i.shape[-2:]): + rgb_i = interpolate_image(rgb_i, shape=depth_i.shape[-2:]) + + loss_i = weights[i] * self.calculate(rgb_i, depth_i) + + metrics[f'smoothness_loss/{i}'] = loss_i.detach() + losses.append(loss_i) + + loss = sum(losses) / len(losses) + + return { + 'loss': loss, + 'metrics': metrics, + } diff --git a/vidar/arch/losses/SupervisedDepthLoss.py b/vidar/arch/losses/SupervisedDepthLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..50be573a76c52688bcddd34b5d1df551e42ea4be --- /dev/null +++ b/vidar/arch/losses/SupervisedDepthLoss.py @@ -0,0 +1,281 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
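Several criteria are defined in this file; one worth spelling out is the scale-invariant logarithmic (SILog) loss, computed below as sqrt(mean(g^2) - var_ratio * mean(g)^2) * ratio with g = log(pred) - log(gt). A minimal standalone version for reference (illustrative, not the wrapper used by the class at the end of the file):

import torch

def silog(pred, gt, var_ratio=0.85, ratio=10.0, eps=1e-6):
    # pred, gt: positive depths with matching shapes
    g = torch.log(pred.clamp(min=eps)) - torch.log(gt.clamp(min=eps))
    # Variance-like term stays non-negative for var_ratio <= 1
    return torch.sqrt((g ** 2).mean() - var_ratio * g.mean() ** 2) * ratio

print(silog(torch.rand(2, 1, 4, 4) + 0.5, torch.rand(2, 1, 4, 4) + 0.5))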
+ +from abc import ABC + +import torch +import torch.nn as nn + +from vidar.arch.losses.BaseLoss import BaseLoss +from vidar.utils.data import get_mask_from_list +from vidar.utils.depth import depth2index, get_depth_bins +from vidar.utils.depth import depth2inv +from vidar.utils.types import is_list + + +class BerHuLoss(nn.Module, ABC): + """BerHu Loss""" + def __init__(self, threshold=0.2): + super().__init__() + self.threshold = threshold + + def forward(self, pred, gt): + huber_c = self.threshold * torch.max(pred - gt) + diff = (pred - gt).abs() + diff2 = diff[diff > huber_c] ** 2 + return torch.cat((diff, diff2)) + + +class SilogLoss(nn.Module, ABC): + """Scale Invariant Logarithmic Loss""" + def __init__(self, ratio=10., var_ratio=0.85): + super().__init__() + self.ratio = ratio + self.var_ratio = var_ratio + + def forward(self, pred, gt): + log_diff = torch.log(pred) - torch.log(gt) + silog1 = (log_diff ** 2).mean() + silog2 = log_diff.mean() ** 2 + return torch.sqrt(silog1 - self.var_ratio * silog2) * self.ratio + + +class RMSELoss(nn.Module, ABC): + """Root Mean Squared Error Loss""" + def __init__(self): + super().__init__() + self.criterion = nn.MSELoss(reduction='none') + + def forward(self, pred, gt): + return torch.sqrt(self.criterion(pred, gt)) + + +class L1LogLoss(nn.Module, ABC): + """Root Mean Squared Error Loss""" + def __init__(self): + super().__init__() + self.criterion = nn.L1Loss(reduction='none') + + def forward(self, pred, gt): + return self.criterion(torch.log(pred), torch.log(gt)) + + +class MixtureLoss(nn.Module, ABC): + """Root Mean Squared Error Loss""" + def __init__(self): + super().__init__() + + @staticmethod + def laplacian(mu, std, gt): + std = std + 1e-12 + return 0.5 * torch.exp(-(torch.abs(mu - gt) / std)) / std + + def forward(self, pred, gt): + mu0, mu1 = pred[:, [0]], pred[:, [1]] + std0, std1 = pred[:, [2]], pred[:, [3]] + w0 = pred[:, [4]] + w1 = 1.0 - w0 + return (- torch.log(w0 * self.laplacian(mu0, std0, gt) + + w1 * self.laplacian(mu1, std1, gt))).mean() + + +class RootAbsRelLoss(nn.Module, ABC): + """Root Mean Squared Error Loss""" + def __init__(self): + super().__init__() + + def forward(self, pred, gt): + return torch.sqrt(torch.abs(pred - gt) / gt) + + +class SquareAbsRelLoss(nn.Module, ABC): + """Root Mean Squared Error Loss""" + def __init__(self): + super().__init__() + + def forward(self, pred, gt): + return (torch.abs(pred - gt) / gt) ** 2 + + +class CrossEntropyLoss(nn.Module, ABC): + """Supervised Loss""" + def __init__(self): + super().__init__() + self.gamma = 2.0 + self.alpha = 0.25 + + self.bootstrap_ratio = 0.3 + self.cross_entropy_loss = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduce=False) + + def forward(self, pred, gt): + min_depth, max_depth = 1.0, 100.0 + bins = get_depth_bins('linear', min_depth, max_depth, 100).to(pred.device) + gt = depth2index(gt, bins).squeeze(1) + + loss_ce = self.cross_entropy_loss(pred, gt.to(torch.long)) + num_bootstrapping = int(self.bootstrap_ratio * pred.shape[0]) + image_errors, _ = loss_ce.view(-1, ).sort() + worst_errors = image_errors[-num_bootstrapping:] + return torch.mean(worst_errors) + + +def get_criterion(method): + """Determines the supervised loss to be used""" + if method == 'l1': + return nn.L1Loss() + elif method == 'l1log': + return L1LogLoss() + elif method == 'mse': + return nn.MSELoss(reduction='none') + elif method == 'rmse': + return RMSELoss() + elif method == 'huber': + return nn.SmoothL1Loss(reduction='none') + elif method == 'berhu': + return 
BerHuLoss() + elif method == 'silog': + return SilogLoss() + elif method == 'abs_rel': + return lambda x, y: torch.abs(x - y) / x + elif method == 'root_abs_rel': + return RootAbsRelLoss() + elif method == 'square_abs_rel': + return SquareAbsRelLoss() + elif method == 'mixture': + return MixtureLoss() + elif method == 'cross_entropy': + return CrossEntropyLoss() + else: + raise ValueError('Unknown supervised loss {}'.format(method)) + + +class LossWrapper(nn.Module): + """ + Wrapper for supervised depth criteria + + Parameters + ---------- + method : String + Which supervised loss to use + """ + def __init__(self, method): + super().__init__() + self.criterion = get_criterion(method) + + def forward(self, pred, gt, soft_mask=None): + """ + Calculate supervised depth loss + + Parameters + ---------- + pred : torch.Tensor + Predicted depth [B,1,H,W] + gt : torch.Tensor + Ground-truth depth [B,1,H,W] + soft_mask + Mask for pixel weighting [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Supervised depth loss [1] + """ + loss = self.criterion(pred, gt) + if soft_mask is not None: + loss = loss * soft_mask.detach().view(-1, 1) + return loss.mean() + + +class SupervisedDepthLoss(BaseLoss, ABC): + def __init__(self, cfg): + """ + Supervised loss class + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + super().__init__(cfg) + self.criterion = LossWrapper(cfg.method) + self.inverse = cfg.has('inverse', False) + + def calculate(self, pred, gt, mask=None, soft_mask=None): + """ + Calculate supervised depth loss + + Parameters + ---------- + pred : torch.Tensor + Predicted depth [B,1,H,W] + gt : torch.Tensor + Ground-truth depth [B,1,H,W] + mask : torch.Tensor + Mask for pixel filtering [B,1,H,W] + soft_mask + Mask for pixel weighting [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Supervised depth loss [1] + """ + # Interpolations + pred = self.interp_nearest(pred, gt) + mask = self.interp_nearest(mask, gt) + soft_mask = self.interp_bilinear(soft_mask, gt) + + # Flatten tensors + pred, gt, mask, soft_mask = self.flatten(pred, gt, mask, soft_mask) + + # Masks + mask = self.mask_sparse(mask, gt) + mask = self.mask_range(mask, gt) + + # Calculate loss + return self.criterion(pred[mask], gt[mask], + soft_mask=soft_mask[mask] if soft_mask is not None else None) + + def forward(self, pred, gt, mask=None, soft_mask=None): + """ + Supervised depth loss + + Parameters + ---------- + pred : list[torch.Tensor] + Predicted depths [B,1,H,W] + gt : torch.Tensor + Ground-truth depth [B,1,H,W] + mask : torch.Tensor + Mask for pixel filtering [B,1,H,W] + soft_mask + Mask for pixel weighting [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Supervised depth loss [1] + """ + if self.inverse: + pred, gt = depth2inv(pred), depth2inv(gt) + + scales = self.get_scales(pred) + weights = self.get_weights(scales) + + losses, metrics = [], {} + + for i in range(scales): + pred_i, gt_i = pred[i], gt[i] if is_list(gt) else gt + mask_i = get_mask_from_list(mask, i, return_ones=gt_i) + soft_mask_i = get_mask_from_list(soft_mask, i) + + loss_i = weights[i] * self.calculate(pred_i, gt_i, mask_i, soft_mask_i) + + metrics[f'supervised_depth_loss/{i}'] = loss_i.detach() + losses.append(loss_i) + + loss = sum(losses) / len(losses) + + return { + 'loss': loss, + 'metrics': metrics, + } diff --git a/vidar/arch/models/BaseModel.py b/vidar/arch/models/BaseModel.py new file mode 100755 index 0000000000000000000000000000000000000000..1e3c5cec6b05f78b0515414b30eb212023df29a1 --- /dev/null +++ 
b/vidar/arch/models/BaseModel.py @@ -0,0 +1,43 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn as nn + +from vidar.utils.config import cfg_has + + +class BaseModel(nn.Module): + """Base model super class, that all other models inherit""" + def __init__(self, cfg=None): + super().__init__() + + self.blocks = torch.nn.ModuleDict() + self.networks = torch.nn.ModuleDict() + self.losses = torch.nn.ModuleDict() + + if cfg is not None: + self.num_scales = cfg_has(cfg.model, 'num_scales', 99) + + def _forward_unimplemented(self, *args): + pass + + def forward(self, *args, **kwargs): + """Model forward pass""" + raise NotImplementedError( + 'Please implement forward function in your own subclass model.') + + def get_num_scales(self, scales): + """Return number of predicted scales""" + return min(self.num_scales, len(scales)) + + def compute_pose(self, rgb, net, tgt=0, ctx=None, invert=True): + """Compute poses from pairs of images""" + if ctx is None: + ctx = [key for key in rgb.keys() if key != tgt] + return {idx: net( + [rgb[tgt], rgb[idx]], invert=(idx < tgt) and invert)['transformation'] + for idx in ctx} + + def set_attr(self, cfg, key, default): + """Set an attribute for the model""" + self.__setattr__(key, cfg_has(cfg, key, default)) \ No newline at end of file diff --git a/vidar/arch/models/__init__.py b/vidar/arch/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/models/depth/DepthFormerModel.py b/vidar/arch/models/depth/DepthFormerModel.py new file mode 100644 index 0000000000000000000000000000000000000000..36fc0f758f9df0973332cbe0884f7d74c0976381 --- /dev/null +++ b/vidar/arch/models/depth/DepthFormerModel.py @@ -0,0 +1,550 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
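Before the implementation: the teacher-student consistency used by this model hinges on a symmetric relative-depth agreement test (see conf_mask below), which keeps a pixel only when the single-frame and multi-frame depths are within a relative threshold of each other in both directions. A rough standalone sketch with toy depth maps (names are illustrative):

import torch

def mutual_depth_agreement(depth_a, depth_b, thr=1.0):
    # Keep pixels where each depth is within `thr` relative error of the other
    rel_ab = ((depth_a - depth_b) / depth_b).abs() < thr
    rel_ba = ((depth_b - depth_a) / depth_a).abs() < thr
    return rel_ab & rel_ba

d1 = torch.rand(1, 1, 4, 4) + 0.1
d2 = d1 * (1.0 + 0.05 * torch.randn_like(d1))
mask = mutual_depth_agreement(d1, d2)  # mostly True for small perturbations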
+ +import random +from abc import ABC +from functools import partial + +import torch + +from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis +from vidar.arch.models.BaseModel import BaseModel +from vidar.arch.models.utils import make_rgb_scales, break_context, create_cameras +from vidar.utils.data import get_from_dict +from vidar.utils.depth import inv2depth +from vidar.utils.tensor import interpolate, multiply_args +from vidar.utils.types import is_str + + +def curr_stereo(val): + return is_str(val) and val.startswith('0') + + +class DepthFormerModel(BaseModel, ABC): + """ + Depthformer base model (https://arxiv.org/abs/2204.07616) + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.set_attr(cfg.model, 'warp_context', None) + self.set_attr(cfg.model, 'match_context', None) + + self.motion_masking = cfg.model.motion_masking + self.matching_augmentation = cfg.model.matching_augmentation + self.freeze_teacher_and_pose = cfg.model.freeze_teacher_and_pose + + self.view_synthesis = ViewSynthesis() + + self.interpolate_nearest = partial( + interpolate, mode='nearest', scale_factor=None, align_corners=None) + + self.set_attr(cfg.model, 'spatial_weight', [0.0, 0.0]) + self.set_attr(cfg.model, 'spatio_temporal_weight', [0.0, 0.0]) + self.set_attr(cfg.model, 'run_mono', True) + self.set_attr(cfg.model, 'run_multi', True) + + self.set_attr(cfg.model, 'use_gt_pose', False) + self.set_attr(cfg.model, 'use_gt_depth', False) + self.set_attr(cfg.model, 'display', False) + self.set_attr(cfg.model, 'mono_type', 'mono') + self.set_attr(cfg.model, 'multi_temporal_only', False) + + self.set_attr(cfg.model, 'apply_consistency', True) + + @staticmethod + def process_stereo(batch): + """Process batch to recover stereo / monocular information""" + batch = {key: val for key, val in batch.items()} + new_intrinsics = {0: batch['intrinsics'][0]} + for key in batch['pose'].keys(): + if not is_str(key) and key != 0: + new_intrinsics[key] = batch['intrinsics'][0] + batch['intrinsics'] = new_intrinsics + suffixes = ['', 'r', 's', 't', 'u', 'v'] + if batch['rgb'][0].dim() == 5: + # Change RGB + rgb_stereo = {} + for key, val in batch['rgb'].items(): + rgb_stereo[key] = val[:, 0] + for i in range(1, val.shape[1]): + rgb_stereo['%d%s' % (key, suffixes[i])] = val[:, i] + batch['rgb'] = rgb_stereo + # Change pose + if 'pose' in batch: + pose_stereo = {} + for key, val in batch['pose'].items(): + pose_stereo[key] = val[:, 0] + for i in range(1, val.shape[1]): + pose_stereo['%d%s' % (key, suffixes[i])] = val[:, i] + batch['pose'] = pose_stereo + # Change intrinsics + if 'intrinsics' in batch: + intrinsics_stereo = {} + for key, val in batch['intrinsics'].items(): + intrinsics_stereo[key] = val[:, 0] + for i in range(1, val.shape[1]): + intrinsics_stereo['%d%s' % (key, suffixes[i])] = val[:, i] + for key in batch['pose'].keys(): + if not is_str(key) and key != 0: + for suffix in ['r', 's', 't', 'u', 'v']: + if '0%s' % suffix in intrinsics_stereo.keys(): + intrinsics_stereo['%d%s' % (key, suffix)] = intrinsics_stereo['0%s' % suffix] + batch['intrinsics'] = intrinsics_stereo + return batch + + @staticmethod + def get_stereo_pose(batch, pose, context): + """Get poses from stereo images""" + for key in context: + if key not in pose.keys() and curr_stereo(key): + pose[key] = batch['pose'][key] + return pose + + @staticmethod + def pose_context(pose, context): + """Extract context poses from a pose dictionary""" + for key in context: + 
if not is_str(key): + for key2 in list(pose.keys()): + if curr_stereo(key2): + new_key = '%d%s' % (key, key2[1:]) + if new_key in context: + pose[new_key] = pose[key2] @ pose[key] + return pose + + @staticmethod + def conf_mask(depth1, depth2, thr=1.0): + """Calculate confidence masks""" + mask1 = ((depth1 - depth2) / depth2).abs() < thr + mask2 = ((depth2 - depth1) / depth1).abs() < thr + return mask1 * mask2 + + def forward(self, batch, epoch=0): + """Model forward pass""" + + mono_depth_string = 'mono_depth' + multi_depth_string = 'multi_depth' + + predictions = {} + + batch = {key: val for key, val in batch.items()} + batch['rgb'] = {key: val for key, val in batch['rgb'].items()} + + ### TRANSFORMER + + batch = self.process_stereo(batch) + batch_rgb = batch['rgb'] + rgbs = make_rgb_scales(batch_rgb, ratio_scales=(0.5, 4)) + + loss_auto_encoder = None + rgbs_pseudo = rgbs + + ### Get images and contexts + + device = rgbs[0][0].device + batch_size = rgbs[0][0].shape[0] + + rgb, rgb_context = break_context( + rgbs_pseudo, tgt=0, ctx=self.match_context, scl=0, stack=True) + + rgbs0 = {key: val[0] for key, val in rgbs.items()} + + ### Warp pose + + warp_context_pose = [idx for idx in self.warp_context if not is_str(idx)] + if not self.use_gt_pose: + pose_warp = self.compute_pose( + rgbs0, self.networks['pose'], + ctx=warp_context_pose, invert=True) + else: + pose_warp = {key: batch['pose'][key] for key in warp_context_pose} + + pose_warp = self.get_stereo_pose(batch, pose_warp, self.warp_context) + pose_warp = self.pose_context(pose_warp, self.warp_context) + + ### Match pose + + if self.run_multi: + match_context_pose = [idx for idx in self.match_context if not is_str(idx)] + if not self.use_gt_pose: + with torch.no_grad(): + pose_match = self.compute_pose( + rgbs0, self.networks['pose'], + ctx=match_context_pose, invert=True) + else: + pose_match = {key: batch['pose'][key] for key in match_context_pose} + pose_match = self.get_stereo_pose(batch, pose_match, self.match_context) + pose_match = self.pose_context(pose_match, self.match_context) + else: + pose_match = None + + ### Augmentation Mask + + augmentation_mask = torch.zeros([batch_size, 1, 1, 1], device=device).float() + if self.run_multi: + if self.training and self.matching_augmentation: + for batch_idx in range(batch_size): + rand_num = random.random() + if rand_num < 0.25: + rgb_context[batch_idx] = \ + torch.stack([rgb[batch_idx] for _ in self.match_context], 0) + augmentation_mask[batch_idx] += 1 + elif rand_num < 0.5: + pose_match[-1][batch_idx] *= 0 + augmentation_mask[batch_idx] += 1 + + ### Warp cameras + + intrinsics = batch['intrinsics'] + cams_warp = create_cameras(rgbs[0][0], intrinsics, pose_warp) + + ### Monocular depth + + if self.run_mono: + + if self.mono_type == 'multi': + ctx = [ctx for ctx in self.match_context if curr_stereo(ctx)] + pose_match2 = {key: val for key, val in pose_match.items() if curr_stereo(key)} + cams_match2 = create_cameras(rgbs[0][0], intrinsics, pose_match2) + rgb2, rgb_context2 = break_context( + rgbs, tgt=0, ctx=ctx, scl=0, stack=True) + mono_depth_output = self.networks[mono_depth_string]( + rgb=rgb2, rgb_context=rgb_context2, cams=cams_match2, intrinsics=intrinsics, mode='multi') + predictions['depth_lowest_mono'] = { + 0: [inv2depth(mono_depth_output['lowest_cost'].unsqueeze(1)).detach()]} + predictions['volume_mono'] = {0: mono_depth_output['cost_volume']} + predictions['mask_confidence_mono'] = { + 0: [mono_depth_output['confidence_mask'].unsqueeze(1)]} + elif self.mono_type == 
'mono': + mono_depth_output = self.networks[mono_depth_string]( + rgb=rgb, intrinsics=intrinsics) + else: + raise ValueError + + if self.use_gt_depth: + depth_mono = [batch['depth'][0][:, 0]] + else: + depth_mono = mono_depth_output['depths'] + + predictions['depth_mono'] = {0: depth_mono} + else: + mono_depth_output = depth_mono = None + + ### Multi-frame depth + + if self.run_multi: + + if self.multi_temporal_only: + ctx = [ctx for ctx in self.match_context if not curr_stereo(ctx)] + pose_match3 = {key: val for key, val in pose_match.items() if not is_str(key)} + cams_match3 = create_cameras(rgbs[0][0], intrinsics[0], pose_match3) + rgb3, rgb_context3 = break_context( + rgbs_pseudo, tgt=0, ctx=ctx, scl=0, stack=True) + multi_depth_output = self.networks[multi_depth_string]( + rgb=rgb3, rgb_context=rgb_context3, cams=cams_match3, + intrinsics=intrinsics, networks=self.networks, + ) + else: + cams_match = create_cameras(rgbs[0][0], intrinsics, pose_match) + multi_depth_output = self.networks[multi_depth_string]( + rgb=rgb, rgb_context=rgb_context, cams=cams_match, + intrinsics=intrinsics, networks=self.networks, + ) + + if self.use_gt_depth: + depth_multi = [batch['depth'][0][:, 0]] + else: + depth_multi = multi_depth_output['depths'] + + predictions['depth_multi'] = {0: depth_multi} + predictions['volume_multi'] = {0: multi_depth_output['cost_volume']} + predictions['depth_lowest_multi'] = { + 0: [inv2depth(d.unsqueeze(1)).detach() for d in multi_depth_output['lowest_cost']]} + predictions['mask_confidence_multi'] = { + 0: [multi_depth_output['confidence_mask'].unsqueeze(1)]} + + if 'ssim_lowest_cost' in multi_depth_output: + predictions['depth_lowest_ssim'] = { + 0: [inv2depth(multi_depth_output['ssim_lowest_cost'].unsqueeze(1)).detach()]} + + else: + + multi_depth_output = depth_multi = None + + ### Confidence mask + + if self.run_multi: + shape = rgbs0[0].shape[-2:] + lowest_cost = self.interpolate_nearest( + multi_depth_output['lowest_cost'][0].unsqueeze(1), size=shape).squeeze(1).to(device) + confidence_mask = self.interpolate_nearest( + multi_depth_output['confidence_mask'].unsqueeze(1), size=shape).to(device) + if self.motion_masking and self.run_mono: + + if 'regression' in multi_depth_output: + inv_depth_low_res = multi_depth_output['regression']['disp_pred_low_res'] + inv_depth_low_res = self.interpolate_nearest( + inv_depth_low_res.unsqueeze(1), size=shape).squeeze(1).to(device) + lowest_cost = inv_depth_low_res + + matching_depth = 1. 
/ lowest_cost.unsqueeze(1).to(device) + confidence_mask *= self.conf_mask(matching_depth, depth_mono[0]) + # predictions['mask_confidence'] = {0: [confidence_mask.unsqueeze(1)]} + confidence_mask = confidence_mask * (1 - augmentation_mask) + else: + confidence_mask = None + + ########## LOSSES + + loss, metrics = [], {} + mono_metrics, multi_metrics = {}, {} + mono_visuals, multi_visuals = {}, {} + + valid_mask = get_from_dict(batch, 'mask') + if valid_mask is not None: + valid_mask = valid_mask[:, 0] + + predictions['output_mono'] = mono_depth_output + predictions['output_multi'] = multi_depth_output + + if 'depth_regr' in multi_depth_output: + predictions['depth_regr'] = { + 0: [d.unsqueeze(1) for d in multi_depth_output['depth_regr']] + } + + if 'cal' in self.networks['multi_depth'].networks.keys(): + cal = self.networks['multi_depth'].networks['cal'] + depth1 = depth_multi[0] + depth2 = predictions['depth_regr'][0][0] + from vidar.utils.tensor import interpolate + depth2 = interpolate(depth2, size=depth1.shape[-2:], scale_factor=None, mode='nearest', align_corners=None) + predictions['depth_regr'][0].insert(0, cal(depth1, depth2, rgb)) + + if not self.training: + return { + 'predictions': predictions, + } + + gt_depth = None if 'depth' not in batch else batch['depth'][0] + + ### Temporal losses + + cams_warp_temp = {key: val for key, val in cams_warp.items() + if not curr_stereo(key) or key == 0} + + if len(cams_warp_temp) > 0 \ + and self.spatial_weight[0] < 1.0 \ + and self.spatio_temporal_weight[0] < 1.0: + + if self.run_mono: + mono_loss_temp, mono_metrics_temp, mono_visuals_temp = \ + self.compute_loss_and_metrics( + rgbs, depth_mono, cams_warp_temp, logvar=None, valid_mask=valid_mask + ) + loss.append((1 - self.spatial_weight[0]) * + (1 - self.spatio_temporal_weight[0]) * mono_loss_temp) + mono_metrics.update(**mono_metrics_temp) + mono_visuals.update(**{f'temp_{key}': val for key, val in mono_visuals_temp.items()}) + metrics.update({f'mono_temp_{key}': val for key, val in mono_metrics_temp.items()}) + + if len(cams_warp_temp) > 0 \ + and self.spatial_weight[1] < 1.0 \ + and self.spatio_temporal_weight[1] < 1.0: + + if self.run_multi: + multi_loss_temp, multi_metrics_temp, multi_visuals_temp = \ + self.compute_loss_and_metrics( + rgbs, depth_multi, cams_warp_temp, logvar=None, + depths_consistency=depth_mono if self.apply_consistency else None, + confidence_mask=confidence_mask if self.apply_consistency else None, + valid_mask=valid_mask, + ) + loss.append((1 - self.spatial_weight[1]) * + (1 - self.spatio_temporal_weight[1]) * multi_loss_temp) + multi_metrics.update(**multi_metrics_temp) + multi_visuals.update(**{f'temp_{key}': val for key, val in multi_visuals_temp.items()}) + metrics.update({f'multi_temp_{key}': val for key, val in multi_metrics_temp.items()}) + + ### Spatial Losses + + cams_warp_spat = {key: val for key, val in cams_warp.items() + if curr_stereo(key) or key == 0} + + if len(cams_warp_spat) > 0 \ + and self.spatial_weight[0] > 0.0 \ + and self.spatio_temporal_weight[0] < 1.0: + + if self.run_mono: + mono_loss_spat, mono_metrics_spat, mono_visuals_spat = \ + self.compute_loss_and_metrics( + rgbs, depth_mono, cams_warp_spat, logvar=None, valid_mask=valid_mask + ) + loss.append(self.spatial_weight[0] * + (1 - self.spatio_temporal_weight[0]) * mono_loss_spat) + mono_metrics.update(**mono_metrics_spat) + mono_visuals.update(**{f'spat_{key}': val for key, val in mono_visuals_spat.items()}) + metrics.update({f'mono_spat_{key}': val for key, val in 
mono_metrics_spat.items()}) + + if len(cams_warp_spat) > 0 \ + and self.spatial_weight[1] > 0.0 \ + and self.spatio_temporal_weight[1] < 1.0: + + if self.run_multi: + multi_loss_spat, multi_metrics_spat, multi_visuals_spat = \ + self.compute_loss_and_metrics( + rgbs, depth_multi, cams_warp_spat, logvar=None, + depths_consistency=depth_mono if self.apply_consistency else None, + valid_mask=valid_mask, + confidence_mask=confidence_mask if self.apply_consistency else None + ) + loss.append(self.spatial_weight[1] * + (1 - self.spatio_temporal_weight[1]) * multi_loss_spat) + multi_metrics.update(**multi_metrics_spat) + multi_visuals.update(**{f'spat_{key}': val for key, val in multi_visuals_spat.items()}) + metrics.update({f'multi_spat_{key}': val for key, val in multi_metrics_spat.items()}) + + ### Spatio-temporal Losses + + if self.spatio_temporal_weight[0] > 0.0: + + if self.run_mono: + mono_loss_both, mono_metrics_both, mono_visuals_both = \ + self.compute_loss_and_metrics( + rgbs, depth_mono, cams_warp, logvar=None, valid_mask=valid_mask + ) + loss.append(self.spatio_temporal_weight[0] * mono_loss_both) + mono_metrics.update(**mono_metrics_both) + mono_visuals.update(**{f'both_{key}': val for key, val in mono_visuals_both.items()}) + metrics.update({f'mono_both_{key}': val for key, val in mono_metrics_both.items()}) + + if self.spatio_temporal_weight[1] > 0.0: + + if self.run_multi: + multi_loss_both, multi_metrics_both, multi_visuals_both = \ + self.compute_loss_and_metrics( + rgbs, depth_multi, cams_warp, logvar=None, + depths_consistency=depth_mono if self.apply_consistency else None, + confidence_mask=confidence_mask if self.apply_consistency else None, + valid_mask=valid_mask, + ) + loss.append(self.spatio_temporal_weight[1] * multi_loss_both) + multi_metrics.update(**multi_metrics_both) + multi_visuals.update(**{f'both_{key}': val for key, val in multi_visuals_both.items()}) + metrics.update({f'multi_both_{key}': val for key, val in multi_metrics_both.items()}) + + ### + + if loss_auto_encoder is not None: + loss.append(loss_auto_encoder) + + if 'depth_regr' in predictions: + regr_loss, regr_metrics, regr_visuals = \ + self.compute_loss_and_metrics( + rgbs, predictions['depth_regr'][0], cams_warp, logvar=None, valid_mask=valid_mask + ) + loss.append(regr_loss) + + depth_pred = [predictions['depth_regr'][0][0]] + depth_gt = depth_mono[0].detach() + supervision_output = self.losses['supervision'](depth_pred, depth_gt) + loss.append(supervision_output['loss']) + + loss = sum(loss) + + metrics.update({ + 'min_depth_bin': self.networks[multi_depth_string].networks['encoder'].min_depth_bin, + 'max_depth_bin': self.networks[multi_depth_string].networks['encoder'].max_depth_bin, + }) + + visuals = { + **{f'mono_{key}': val for key, val in mono_visuals.items()}, + **{f'multi_{key}': val for key, val in multi_visuals.items()}, + } + + if self.run_mono and self.run_multi and \ + self.training and epoch < self.freeze_teacher_and_pose: + self.networks[multi_depth_string].networks['encoder'].update_adaptive_depth_bins(depth_mono[0]) + if 'lowest_cost' in mono_depth_output: + self.networks[mono_depth_string].networks['encoder'].min_depth_bin = \ + self.networks[multi_depth_string].networks['encoder'].min_depth_bin + self.networks[mono_depth_string].networks['encoder'].max_depth_bin = \ + self.networks[multi_depth_string].networks['encoder'].max_depth_bin + + return { + 'loss': loss, + 'metrics': metrics, + 'visuals': visuals, + 'predictions': predictions, + } + + def compute_loss_and_metrics(self, 
rgbs, depths, cams, depths_consistency=None, + logvar=None, valid_mask=None, confidence_mask=None): + """ + Compute model loss and metrics + + Parameters + ---------- + rgbs : list[torch.Tensor] + Input RGB images + depths : list[torch.Tensor] + Predicted depth maps + cams : list[Camera] + Image cameras + depths_consistency : list[torch.Tensor] + Depth maps used for consistency loss calculation + logvar : list[torch.Tensor] + Predicted log-variance for depth maps + valid_mask : list[torch.Tensor] + Valid mask for masking out pixels + confidence_mask : list[torch.Tensor] + Confidence mask for consistency calculation + + Returns + ------- + loss : torch.Tensor + Final loss + metrics : Dict + Dictionary with calculated metrics + visuals : Dict + Dictionary with calculated visualizations + """ + num_scales = self.get_num_scales(depths) + + rgb_tgt = [rgbs[0][i] for i in range(num_scales)] + rgb_ctx = [[rgbs[j][i] for j in cams.keys() if j != 0] for i in range(num_scales)] + + loss, metrics, visuals = [], {}, {} + + if 'reprojection' in self.losses: + synth = self.view_synthesis(rgbs, depths, cams, return_masks=True) + reprojection_mask = multiply_args(valid_mask, confidence_mask) + reprojection_output = self.losses['reprojection']( + rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar, + valid_mask=reprojection_mask, overlap_mask=synth['masks']) + loss.append(reprojection_output['loss']) + metrics.update(reprojection_output['metrics']) + visuals['synth'] = synth + visuals['reproj'] = reprojection_output + + if 'smoothness' in self.losses: + smoothness_output = self.losses['smoothness'](rgb_tgt, depths) + loss.append(smoothness_output['loss']) + metrics.update(smoothness_output['metrics']) + + if 'consistency' in self.losses and depths_consistency is not None: + consistency_output = self.losses['consistency']( + depths_consistency, depths, + confidence_mask=reprojection_output['mask'], + valid_mask=valid_mask + ) + loss.append(consistency_output['loss']) + metrics.update(consistency_output['metrics']) + + loss = sum(loss) + + return loss, metrics, visuals diff --git a/vidar/arch/models/depth/DepthModel.py b/vidar/arch/models/depth/DepthModel.py new file mode 100755 index 0000000000000000000000000000000000000000..25567a40a56b7a239e81e404e78b80604c51b61e --- /dev/null +++ b/vidar/arch/models/depth/DepthModel.py @@ -0,0 +1,28 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from vidar.arch.models.BaseModel import BaseModel + + +class DepthModel(BaseModel): + """ + Base depth estimation model + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + def forward(self, batch, **kwargs): + """Model forward pass""" + + depth_output = self.networks['depth'](batch['rgb'][0]) + return { + 'loss': 0.0, + 'metrics': {}, + 'predictions': { + 'depth': {0: depth_output['depths']} + } + } diff --git a/vidar/arch/models/depth/FSMModel.py b/vidar/arch/models/depth/FSMModel.py new file mode 100644 index 0000000000000000000000000000000000000000..83afbcdd84610ffae465db680278b9d519f76dde --- /dev/null +++ b/vidar/arch/models/depth/FSMModel.py @@ -0,0 +1,311 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
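For readers new to the multi-camera setup: FSM stacks all cameras of the rig along the batch dimension and repeatedly splits tensors (or lists of tensors) back into per-camera chunks via split_batch, defined just below. A standalone sketch of the same splitting behavior on toy shapes (the function name here is illustrative):

import torch

def split_per_camera(tensor, n=1):
    # Mirrors split_batch below: lists are split element-wise and re-zipped
    if isinstance(tensor, list):
        split = [split_per_camera(t, n=n) for t in tensor]
        return list(map(list, zip(*split)))
    return torch.split(tensor, split_size_or_sections=n, dim=0)

# A batch holding 6 cameras stacked along dim 0, split into per-camera chunks
rgb = torch.rand(6, 3, 8, 8)
per_cam = split_per_camera(rgb, n=1)       # tuple of 6 tensors, each [1,3,8,8]
pairs = split_per_camera([rgb, rgb], n=1)  # list of 6 [target, context] pairs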
+ +import random + +import numpy as np +import torch + +from vidar.arch.losses.MultiCamPhotometricLoss import MultiCamPhotometricLoss +from vidar.arch.models.BaseModel import BaseModel +from vidar.arch.networks.layers.fsm.camera import Camera +from vidar.arch.networks.layers.fsm.pose import Pose +from vidar.arch.networks.layers.fsm.utils import \ + CameraNormalizer, flip_batch_input, flip_output, filter_dict, coords_from_motion, mask_from_coords +from vidar.utils.depth import depth2inv, inv2depth +from vidar.utils.types import is_list, is_seq + + +def split_batch(tensor, n=1): + """Split a tensor batch-wise""" + if is_list(tensor): + split = [split_batch(t, n=n) for t in tensor] + return list(map(list, zip(*split))) + return torch.split(tensor, split_size_or_sections=n, dim=0) + + +def global_cameras(intrinsics, pose, pose_context, hw=None): + """Create global cameras for target and source poses + target intrinsics""" + cam = camera_from_intrinsics_pose(pose, intrinsics, hw=hw) + cam_context = camera_from_intrinsics_pose(pose_context, intrinsics, pose, hw=hw) + return cam, cam_context + + +def camera_from_intrinsics_pose(pose, intrinsics, orig_pose=None, hw=None): + """ + Create one or more cameras from pose and intrinsics + + Parameters + ---------- + pose : torch.Tensor or list[torch.Tensor] + Poses to be used [B,4,4] + intrinsics : torch.Tensor or list[torch.Tensor] + Intrinsics to be used [B,3,3] + orig_pose : torch.Tensor or list[torch.Tensor] + Original poses from which pose is generated [B,4,4] + hw : tuple + Camera image dimensions + + Returns + ------- + camera : Camera + Camera instance created from the input + """ + # If pose is a sequence, do it for each one + if is_seq(pose): + # If intrinsics is not a sequence, repeat it + if not is_seq(intrinsics): + intrinsics = [intrinsics] * len(pose) + # If orig pose is not a sequence, repeat it + if not is_seq(orig_pose): + orig_pose = [orig_pose] * len(pose) + # Recursive loop for each item + return [camera_from_intrinsics_pose(p, i, o, hw=hw) + for p, i, o in zip(pose, intrinsics, orig_pose)] + # Compound original pose if available + if orig_pose is not None: + pose = Pose(orig_pose) @ Pose(pose).inverse() + # Return camera + return Camera(K=intrinsics, Twc=pose, hw=hw) + + +class FSMModel(BaseModel): + """ + Full Surround Monodepth (FSM) model (https://arxiv.org/abs/2104.00152) + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg, **kwargs): + super().__init__(cfg) + + self.rotation_mode = 'euler' + self.flip_lr_prob = 0.0 + self.upsample_depth_maps = False + + norm_focal = [] + self.focal = None if len(norm_focal) == 0 else CameraNormalizer(focal=norm_focal) + + pairs = [ + [0, 1], [1, 0], + [0, 2], [2, 0], + [1, 4], [4, 1], + [2, 3], [3, 2], + [3, 5], [5, 3], + [4, 5], [5, 4], + ] + + stereo_weight = 0.1 + stereo_context = True + gt_pose = False + + self.multicam_loss = MultiCamPhotometricLoss(**kwargs) + self.pairs = pairs + self.stereo_weight = stereo_weight + self.gt_pose = gt_pose + self.stereo_context = stereo_context + + self._input_keys = ['rgb', 'rgb_context', 'intrinsics', 'extrinsics', + 'pose_context', 'filename'] + self.networks = torch.nn.ModuleDict() + + def compute_depth_net(self, batch, force_flip=False): + """Computes inverse depth maps from single images""" + # Randomly flip and estimate inverse depth maps + flag_flip_lr = random.random() < self.flip_lr_prob if self.training else force_flip + output = self.depth_net_flipping(batch, flag_flip_lr) + if self.focal 
is not None: + output['inv_depths'] = self.focal.unormalize(output['inv_depths']) + # Return inverse depth maps + return output + + def compute_pose_net(self, image, contexts): + """Compute poses from image and a sequence of context images""" + pose_vec = self.networks['pose'](image, contexts) + return [Pose.from_vec(pose_vec[:, i], self.rotation_mode) + for i in range(pose_vec.shape[1])] + + def depth_net_flipping(self, batch, flip): + + # Which keys are being passed to the depth network + batch_input = {key: batch[key] for key in filter_dict(batch, self._input_keys)} + if self.focal is not None: + batch_input['rgb'] = self.focal.normalize(batch_input['rgb'], batch['intrinsics']) + if flip: + # Run depth network with flipped inputs + output = self.networks['depth'](**flip_batch_input(batch_input)) + # Flip output back if training + if self.training: + output = flip_output(output) + else: + # Run depth network + output = self.networks['depth'](**batch_input) + return output + + def forward2(self, batch, return_logs=False, force_flip=False): + # Generate inverse depth predictions + depth_output = self.compute_depth_net(batch, force_flip=force_flip) + # Generate pose predictions if available + pose_output = None + if 'rgb_context' in batch and self.networks['pose'] is not None: + pose_output = self.compute_pose_net( + batch['rgb'], batch['rgb_context']) + # Return output dictionary + return { + **depth_output, + 'poses': pose_output, + } + + def forward(self, batch, return_logs=False, progress=0.0, **kwargs): + + new_batch = {} + new_batch['rgb'] = batch['rgb'][0] + if self.training: + new_batch['rgb_context'] = [batch['rgb'][1], batch['rgb'][-1]] + new_batch['pose'] = batch['pose'][0] + new_batch['pose_context'] = [batch['pose'][1], batch['pose'][-1]] + new_batch['intrinsics'] = batch['intrinsics'][0] + new_batch['filename'] = batch['filename'] + batch = new_batch + + if self.training: + batch['rgb'] = batch['rgb'][0] + batch['rgb_context'] = [b[0] for b in batch['rgb_context']] + batch['pose'] = batch['pose'][0] + batch['pose_context'] = [b[0] for b in batch['pose_context']] + batch['intrinsics'] = batch['intrinsics'][0] + + if self.networks['depth'] is not None: + output_self_sup = self.forward2(batch) + depth = inv2depth(output_self_sup['inv_depths']) + else: + output_self_sup = {} + depth = batch['depth'] + + if not self.training: + output_new = { + 'predictions': { + 'depth': { + 0: [1. 
/ d for d in output_self_sup['inv_depths']] + } + } + } + output_self_sup = output_new + return output_self_sup + + rgb = batch['rgb'] + rgb_context = batch['rgb_context'] + intrinsics = batch['intrinsics'] + pose = batch['extrinsics'] if 'extrinsics' in batch else batch['pose'] + + pose_context_gt = batch['pose_context'] + if self.gt_pose: + pose_context = batch['pose_context'] + else: + pose_context = output_self_sup['poses'] + + for i in range(len(pose_context)): + pose_context[i] = pose_context[i].mat + + rgb_i, rgb_context_i = split_batch(rgb), split_batch(rgb_context) + pose_i, pose_context_i = split_batch(pose), split_batch(pose_context) + intrinsics_i, inv_depth_i = split_batch(intrinsics), depth2inv(split_batch(depth)) + cam_i, cam_context_i = global_cameras(intrinsics_i, pose_i, pose_context_i, hw=rgb.shape[2:]) + + _, pose_context_i_gt = split_batch(pose), split_batch(pose_context_gt) + _, cam_context_i_gt = global_cameras(intrinsics_i, pose_i, pose_context_i_gt, hw=rgb.shape[2:]) + + n_tgt = len(rgb_i) + + mono_coords = [coords_from_motion( + cam_context_i[tgt], inv2depth(inv_depth_i[tgt]), cam_i[tgt]) + for tgt in range(n_tgt)] + mono_masks = [mask_from_coords(coords) for coords in mono_coords] + + filename = batch['filename'] + try: + filename = ['camera' + f[0].split('/')[-2][-3:]+ '_mask.npy' for f in filename] + # filename = ['camera' + f.split('/')[-2][-3:]+ '_mask.npy' for f in filename] + masks = [torch.tensor(np.load(f)).unsqueeze(0).unsqueeze(0) for f in filename] + for tgt in range(n_tgt): + for i in range(len(mono_masks[tgt])): + for j in range(len(mono_masks[tgt][i])): + for k in range(len(mono_masks[tgt][i][j])): + # write_image('debug/camera_%d/mask_%d_%d_%d.png' % (tgt, i, j, k), + # mono_masks[tgt][i][j][k]) + resized_mask = torch.nn.functional.interpolate( + masks[tgt], mono_masks[tgt][i][j][k].shape[1:], mode='nearest').squeeze(0).bool() + mono_masks[tgt][i][j][k] *= resized_mask.to(mono_masks[tgt][i][j][k].device) + with_masks = True + except: + with_masks = False + pass + + mono = [] + outputs = [] + + for tgt in range(n_tgt): + output = self.multicam_loss( + rgb_i[tgt], rgb_context_i[tgt], inv_depth_i[tgt], + cam_i[tgt], cam_context_i[tgt], with_mask=mono_masks[tgt]) + if not torch.isnan(output['loss']): + mono.append(output['loss']) + outputs.append(output) + + stereo = [] + if not self.stereo_context and self.stereo_weight > 0: + + stereo_coords = [coords_from_motion( + [cam_i[src]], inv2depth(inv_depth_i[tgt]), cam_i[tgt]) + for tgt, src in self.pairs] + stereo_masks = [mask_from_coords(coords) for coords in stereo_coords] + + if with_masks: + for tgt in range(len(self.pairs)): + for i in range(len(stereo_masks[tgt])): + for j in range(len(stereo_masks[tgt][i])): + for k in range(len(stereo_masks[tgt][i][j])): + hw = stereo_masks[tgt][i][j][k].shape[1:] + h_st, h_fn = int(0.15 * hw[0]), int(0.75 * hw[0]) + stereo_masks[tgt][i][j][k][:, :h_st] = 0 + stereo_masks[tgt][i][j][k][:, h_fn:] = 0 + + for k, (tgt, src) in enumerate(self.pairs): + output = self.multicam_loss( + rgb_i[tgt], [rgb_i[src]], inv_depth_i[tgt], + cam_i[tgt], [cam_i[src]], with_mask=stereo_masks[k], automask=False) + if not torch.isnan(output['loss']): + stereo.append(self.stereo_weight * output['loss']) + + elif self.stereo_context and self.stereo_weight > 0: + + stereo_coords = [coords_from_motion( + [cam_i[src]] + cam_context_i[src], inv2depth(inv_depth_i[tgt]), cam_i[tgt]) + for tgt, src in self.pairs] + stereo_masks = [mask_from_coords(coords) for coords in stereo_coords] + + for 
tgt in range(len(self.pairs)): + for i in range(len(stereo_masks[tgt])): + for j in range(len(stereo_masks[tgt][i])): + for k in range(len(stereo_masks[tgt][i][j])): + hw = stereo_masks[tgt][i][j][k].shape[1:] + h_st, h_fn = int(0.15 * hw[0]), int(0.75 * hw[0]) + stereo_masks[tgt][i][j][k][:, :h_st] = 0 + stereo_masks[tgt][i][j][k][:, h_fn:] = 0 + + for k, (tgt, src) in enumerate(self.pairs): + output = self.multicam_loss( + rgb_i[tgt], [rgb_i[src]] + rgb_context_i[src], inv_depth_i[tgt], + cam_i[tgt], [cam_i[src]] + cam_context_i[src], with_mask=stereo_masks[k], automask=False) + if not torch.isnan(output['loss']): + stereo.append(self.stereo_weight * output['loss']) + + losses = mono + stereo + loss = sum(losses) / len(losses) + output_self_sup['loss'] = loss.unsqueeze(0) + + return { + **output_self_sup, + 'metrics': {}, + } diff --git a/vidar/arch/models/depth/SelfSupervisedModel.py b/vidar/arch/models/depth/SelfSupervisedModel.py new file mode 100644 index 0000000000000000000000000000000000000000..68c360b6202ccaf97f8775899240aadea836ab74 --- /dev/null +++ b/vidar/arch/models/depth/SelfSupervisedModel.py @@ -0,0 +1,176 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis +from vidar.arch.models.BaseModel import BaseModel +from vidar.arch.models.utils import make_rgb_scales, create_cameras +from vidar.utils.data import get_from_dict +from vidar.utils.config import cfg_has + + +class SelfSupervisedModel(BaseModel, ABC): + """ + Self-supervised depth estimation model + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + self.view_synthesis = ViewSynthesis() + self.set_attr(cfg.model, 'use_gt_pose', False) + self.set_attr(cfg.model, 'use_gt_intrinsics', True) + + if not self.use_gt_intrinsics: + self.camera_model = cfg_has(cfg.networks.intrinsics, 'camera_model', 'UCM') + if self.camera_model == 'UCM': + from vidar.geometry.camera_ucm import UCMCamera + self.camera_class = UCMCamera + elif self.camera_model == 'EUCM': + from vidar.geometry.camera_eucm import EUCMCamera + self.camera_class = EUCMCamera + elif self.camera_model == 'DS': + from vidar.geometry.camera_ds import DSCamera + self.camera_class = DSCamera + else: + raise NotImplementedError('Invalid camera type') + + def forward(self, batch, epoch=0): + """Model forward pass""" + + rgb = batch['rgb'] + if self.use_gt_intrinsics: + intrinsics = get_from_dict(batch, 'intrinsics') + else: + intrinsics = self.networks['intrinsics'](rgb=rgb[0]) + + valid_mask = get_from_dict(batch, 'mask') + + if self.use_gt_intrinsics: + depth_output = self.networks['depth'](rgb=rgb[0], intrinsics=intrinsics[0]) + else: + depth_output = self.networks['depth'](rgb=rgb[0]) + pred_depth = depth_output['depths'] + + predictions = { + 'depth': {0: pred_depth}, + } + + pred_logvar = get_from_dict(depth_output, 'logvar') + if pred_logvar is not None: + predictions['logvar'] = {0: pred_logvar} + + if not self.training: + return { + 'predictions': predictions, + } + + if self.use_gt_pose: + assert 'pose' in batch, 'You need input pose' + pose = batch['pose'] + elif 'pose' in self.networks: + pose = self.compute_pose(rgb, self.networks['pose'], tgt=0, invert=True) + predictions['pose'] = pose + else: + pose = None + + if not self.use_gt_intrinsics: + cams = {0: self.camera_class(I=intrinsics)} + for key in pose.keys(): + cams[key] = self.camera_class(I=intrinsics, 
Tcw=pose[key]) + else: + cams = create_cameras(rgb[0], intrinsics[0], pose) + + gt_depth = None if 'depth' not in batch else batch['depth'][0] + loss, metrics = self.compute_loss_and_metrics( + rgb, pred_depth, cams, gt_depth=gt_depth, + logvar=pred_logvar, valid_mask=valid_mask + ) + + if not self.use_gt_intrinsics: + if self.camera_model == 'UCM': + fx, fy, cx, cy, alpha = intrinsics[0].squeeze() + intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'alpha':alpha} + metrics.update(intrinsics_metrics) + elif self.camera_model == 'EUCM': + fx, fy, cx, cy, alpha, beta = intrinsics[0].squeeze() + intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'alpha':alpha, 'beta':beta} + metrics.update(intrinsics_metrics) + elif self.camera_model == 'DS': + fx, fy, cx, cy, xi, alpha = intrinsics[0].squeeze() + intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'xi':xi, 'alpha':alpha} + metrics.update(intrinsics_metrics) + else: + raise NotImplementedError('Invalid camera type') + + return { + 'loss': loss, + 'metrics': metrics, + 'predictions': predictions, + } + + def compute_loss_and_metrics(self, rgb, depths, cams, gt_depth=None, + logvar=None, valid_mask=None): + """ + Compute loss and metrics for training + + Parameters + ---------- + rgb : Dict + Dictionary with input images [B,3,H,W] + depths : list[torch.Tensor] + List with target depth maps in different scales [B,1,H,W] + cams : Dict + Dictionary with cameras for each input image + gt_depth : torch.Tensor + Ground-truth depth map for supervised training + logvar : list[torch.Tensor] + Log-variance maps for uncertainty training + valid_mask : list[torch.Tensor] + Binary mask for masking out invalid pixels [B,1,H,W] + + Returns + ------- + loss : torch.Tensor + Training loss + metrics : Dict + Dictionary with training metrics + """ + tgt = 0 + ctx = [key for key in rgb.keys() if key != tgt] + + num_scales = self.get_num_scales(depths) + + rgbs = make_rgb_scales(rgb, pyramid=depths) + rgb_tgt = [rgbs[tgt][i] for i in range(num_scales)] + rgb_ctx = [[rgbs[j][i] for j in ctx] for i in range(num_scales)] + + loss, metrics = [], {} + + if 'reprojection' in self.losses: + synth = self.view_synthesis( + rgbs, depths=depths, cams=cams, return_masks=True) + reprojection_output = self.losses['reprojection']( + rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar, + valid_mask=valid_mask, overlap_mask=synth['masks']) + loss.append(reprojection_output['loss']) + metrics.update(**reprojection_output['metrics']) + if 'smoothness' in self.losses: + smoothness_output = self.losses['smoothness'](rgb_tgt, depths) + loss.append(smoothness_output['loss']) + metrics.update(**smoothness_output['metrics']) + if 'supervision' in self.losses and gt_depth is not None: + supervision_output = self.losses['supervision'](depths, gt_depth) + loss.append(supervision_output['loss']) + metrics.update(**supervision_output['metrics']) + if 'normals' in self.losses and gt_depth is not None: + normals_output = self.losses['normals'](depths, gt_depth, cams[0]) + loss.append(normals_output['loss']) + metrics.update(**normals_output['metrics']) + + loss = sum(loss) + + return loss, metrics diff --git a/vidar/arch/models/depth/SupervisedModel.py b/vidar/arch/models/depth/SupervisedModel.py new file mode 100755 index 0000000000000000000000000000000000000000..44bfb7d24549c286c7207af7cfa925f115bd1d1c --- /dev/null +++ b/vidar/arch/models/depth/SupervisedModel.py @@ -0,0 +1,92 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
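# --- Illustrative sketch (editor's addition, not part of the original code) ---
# The self-supervised and supervised models here resize the input image to match
# each predicted depth scale before computing per-scale losses (make_rgb_scales).
# A minimal version of that idea, assuming bilinear resizing:
import torch
import torch.nn.functional as F

def make_rgb_pyramid(rgb, depth_pyramid):
    # One resized copy of the image per predicted depth resolution
    return [F.interpolate(rgb, size=d.shape[-2:], mode='bilinear', align_corners=False)
            for d in depth_pyramid]

rgb = torch.rand(2, 3, 192, 640)
depths = [torch.rand(2, 1, 192 // 2 ** i, 640 // 2 ** i) for i in range(4)]
rgb_pyramid = make_rgb_pyramid(rgb, depths)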
+ +from abc import ABC + +from vidar.arch.models.BaseModel import BaseModel +from vidar.utils.decorators import iterate1 +from vidar.utils.tensor import interpolate_image + + +@iterate1 +def make_rgb_scales(rgb, pyramid): + return [interpolate_image(rgb, shape=pyr.shape[-2:]) for pyr in pyramid] + + +class SupervisedModel(BaseModel, ABC): + """ + Supervised depth estimation model + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + def forward(self, batch, epoch): + """Model forward pass""" + + rgb = batch['rgb'] + + depth_output = self.networks['depth'](rgb=rgb[0]) + depths = depth_output['depths'] + + if not self.training: + return { + 'predictions': { + 'depth': {0: depths} + }, + } + + losses = self.compute_losses(rgb, depths, batch['depth']) + + return { + 'loss': losses['loss'], + 'metrics': { + }, + 'predictions': { + 'depth': {0: depths} + }, + } + + def compute_losses(self, rgb, depths, gt_depths): + """ + Compute loss and metrics for training + + Parameters + ---------- + rgb : Dict + Dictionary with input images [B,3,H,W] + depths : list[torch.Tensor] + List with target depth maps in different scales [B,1,H,W] + gt_depths : Dict + Dictionary with ground-truth depth maps + + Returns + ------- + loss : torch.Tensor + Training loss + metrics : Dict + Dictionary with training metrics + """ + tgt = 0 + + rgbs = make_rgb_scales(rgb, depths) + rgb_tgt = [rgbs[tgt][i] for i in range(len(rgbs[tgt]))] + + supervision_output = self.losses['supervision'](depths, gt_depths[tgt]) + smoothness_output = self.losses['smoothness'](rgb_tgt, depths) + + loss = supervision_output['loss'] + \ + smoothness_output['loss'] + + metrics = { + **supervision_output['metrics'], + **smoothness_output['metrics'], + } + + return { + 'loss': loss, + 'metrics': metrics, + } diff --git a/vidar/arch/models/depth/__init__.py b/vidar/arch/models/depth/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/models/perceiver/DefineGenericModel.py b/vidar/arch/models/perceiver/DefineGenericModel.py new file mode 100755 index 0000000000000000000000000000000000000000..dd624bb41c762a4956a3c6539e0765243bc2f9b8 --- /dev/null +++ b/vidar/arch/models/perceiver/DefineGenericModel.py @@ -0,0 +1,421 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
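# --- Illustrative sketch (editor's addition, not part of the original code) ---
# DefineGenericModel.get_idx below relies on parse_idx to expand wildcard
# [camera, context] pairs into concrete indices. A simplified stand-in showing
# the assumed '*' semantics (match every camera, or every context):
def expand_pair(pair, all_idx):
    cam, ctx = pair
    if cam == '*':
        return [p for p in all_idx if p[1] == ctx]
    if ctx == '*':
        return [p for p in all_idx if p[0] == cam]
    return [pair]

all_idx = [[i, j] for j in (-1, 0, 1) for i in range(2)]   # 2 cameras, 3 contexts
print(expand_pair(['*', 0], all_idx))                      # [[0, 0], [1, 0]]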
+ +import random +from abc import ABC + +import numpy as np +import torch + +from vidar.arch.models.BaseModel import BaseModel +from vidar.arch.models.perceiver.DefineModel import augment_canonical, create_virtual_cameras +from vidar.geometry.camera_nerf import CameraNerf +from vidar.geometry.pose import Pose +from vidar.utils.data import flatten +from vidar.utils.types import is_list + + +def sample_pred_gt(output, gt, pred): + for i in range(len(gt)): + b, n, h, w = gt[i].shape + query_idx = output['query_idx'][i] + gt[i] = torch.stack([gt[i][j].view(n, -1)[:, query_idx[j]] + for j in range(b)], 0) + pred[i][0] = pred[i][0].permute(0, 2, 1) + return gt, pred + + +def parse_idx(all_idx, idxs): + new_idxs = [] + for idx in idxs: + if is_list(idx[0]) and is_list(idx[1]): + for i in idx[0]: + for j in idx[1]: + new_idxs.append([i, j]) + elif is_list(idx[0]): + for i in idx[0]: + new_idxs.append([i, idx[1]]) + elif is_list(idx[1]): + for i in idx[1]: + new_idxs.append([idx[0], i]) + elif idx[0] == '*': + for i in all_idx: + if i[1] == idx[1]: + new_idxs.append(i) + elif idx[1] == '*': + for i in all_idx: + if i[0] == idx[0]: + new_idxs.append(i) + else: + new_idxs.append(idx) + return new_idxs + + +class DefineGenericModel(BaseModel, ABC): + + def __init__(self, cfg): + super().__init__(cfg) + + from vidar.arch.networks.perceiver.DefineNet import DefineNet + self.networks['perceiver'] = DefineNet(cfg.model.network) + self.weights = cfg.model.task_weights + self.use_pose_noise = cfg.model.use_pose_noise + self.use_virtual_cameras = cfg.model.use_virtual_cameras + self.virtual_cameras_eval = cfg.model.virtual_cameras_eval + self.use_virtual_rgb = cfg.model.use_virtual_rgb + self.augment_canonical = cfg.model.augment_canonical + self.scale_loss = cfg.model.scale_loss + + self.encode_train = cfg.model.encode_train + self.decode_train = cfg.model.decode_train + + self.encode_eval = cfg.model.encode_eval + self.decode_eval = cfg.model.decode_eval + + if self.encode_eval == 'same': + self.encode_eval = self.encode_train + if self.decode_eval == 'same': + self.decode_eval = self.decode_train + + self.decode_encodes = cfg.model.decode_encodes + self.sample_decoded_queries = cfg.model.sample_decoded_queries + + if cfg.model.has('display'): + if cfg.model.display == 'interactive': + from vidar.arch.models.perceiver.display import DisplayInteractive + self.display = DisplayInteractive() + elif cfg.model.display == 'training': + from vidar.arch.models.perceiver.display import DisplayTraining + self.display = DisplayTraining() + else: + self.display = None + + def get_idx(self, all_idx): + + n = len(all_idx) + + num_encodes = self.encode_train if self.training else self.encode_eval + num_decodes = self.decode_train if self.training else self.decode_eval + + encode_idx = None + if is_list(num_encodes): + if num_encodes[0].startswith('+'): + encode_idx = parse_idx(all_idx, num_encodes[1:]) + elif num_encodes[0].startswith('-'): + num_encodes_parse = parse_idx(all_idx, num_encodes[1:]) + encode_idx = [ + idx for idx in all_idx if idx not in num_encodes_parse] + if len(num_encodes[0]) > 1: + num = int(num_encodes[0][1:]) + encode_idx = np.random.permutation(encode_idx) + encode_idx = encode_idx[:num] + elif num_encodes == 'all': + num_encodes = n + elif num_encodes < 0: + num_encodes = n + num_encodes + + decode_idx = None + if is_list(num_decodes): + if num_decodes[0].startswith('+'): + decode_idx = parse_idx(all_idx, num_decodes[1:]) + elif num_decodes[0].startswith('-'): + num_decodes_parse = 
parse_idx(all_idx, num_decodes[1:]) + decode_idx = [ + idx for idx in all_idx if idx not in num_decodes_parse] + if len(num_decodes[0]) > 1: + num = int(num_decodes[0][1:]) + decode_idx = np.random.permutation(decode_idx) + decode_idx = decode_idx[:num] + elif num_decodes == 'all': + num_decodes = n + elif num_decodes == 'remaining': + num_decodes = n - num_encodes + elif num_decodes < 0: + num_encodes = n + num_decodes + + if self.training: + # Shuffle indices and separate encode and decode indices + all_idx = np.random.permutation(all_idx) + + if encode_idx is None: + encode_idx = all_idx[:num_encodes] + if decode_idx is None: + decode_idx = all_idx[-num_decodes:] + + encode_idx = [list(idx) for idx in encode_idx] + decode_idx = [list(idx) for idx in decode_idx] + + if self.decode_encodes: + decode_idx += [idx for idx in encode_idx if idx not in decode_idx] + + return encode_idx, decode_idx + + def parse_output(self, output, key, encode_idx, predictions): + if key in output.keys(): + pred = [[val] for val in output[key]] + if not self.training: + predictions_key = {} + for idx, (i, j) in enumerate(encode_idx): + if j not in predictions_key.keys(): + predictions_key[j] = [] + predictions_key[j].append(output[key][idx]) + predictions_key = {key: [torch.stack(val, 1)] + for key, val in predictions_key.items()} + predictions[key] = predictions_key + else: + pred = None + return pred + + def forward(self, batch, epoch=0, collate=True): + + # Run on list of b + + if is_list(batch): + output = [self.forward(b, collate=False) for b in batch] + loss = [out['loss'] for out in output] + return {'loss': sum(loss) / len(loss), 'predictions': {}, 'metrics': {}} + + if not collate: + for key in ['rgb', 'intrinsics', 'pose', 'depth']: + batch[key] = {k: v.unsqueeze(0) for k, v in batch[key].items()} + + # Unsqueeze batch data if there is only one camera + + key_dim = {'rgb': 4, 'depth': 4, 'intrinsics': 3, 'pose': 3} + for key in ['rgb', 'depth', 'intrinsics', 'pose']: + for ctx in batch[key].keys(): + if batch[key][ctx].dim() == key_dim[key]: + batch[key][ctx] = batch[key][ctx].unsqueeze(1) + for key in ['intrinsics', 'pose']: + for ctx in batch[key].keys(): + if batch[key][ctx].dim() == 3: + batch[key][ctx] = batch[key][ctx].unsqueeze(1) + + ### + + rgb = batch['rgb'] + intrinsics = batch['intrinsics'] + pose = batch['pose'] + depth = batch['depth'] + + # Get context keys, batch size and number of cameras + ctx = [key for key in rgb.keys() if key != 'virtual'] + b, n = rgb[0].shape[:2] + + # Create all indices in a list + ii, jj = list(range(n)), ctx + all_idx = flatten([[[i, j] for i in ii] for j in jj]) + encode_idx, decode_idx = self.get_idx(all_idx) + + # Prepare pose and add jittering if requested + + pose = [{j: pose[j][i] for j in ctx} for i in range(b)] + + pose0 = [Pose().to(rgb[0].device) for _ in range(b)] + if self.training and len(self.use_pose_noise) > 0.0: + if random.random() < self.use_pose_noise[0]: + for i in range(b): + pose0[i].translateUp( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].translateLeft( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].translateForward( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].rotateRoll( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + pose0[i].rotatePitch( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + pose0[i].rotateYaw( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + for i in range(b): + pose[i][0][[0]] = pose0[i].T + + pose = 
[Pose.from_dict( + p, to_global=True, zero_origin=False, to_matrix=True) for p in pose] + pose = {j: torch.stack([pose[i][j] for i in range(b)], 0) for j in ctx} + + # Augment canonical pose if requested + if self.training and self.augment_canonical: + pose = augment_canonical(pose) + + # Separate batch data per camera + rgb = [{j: rgb[j][:, i] for j in ctx} for i in range(n)] + intrinsics = [{j: intrinsics[0][:, i] for j in ctx} for i in range(n)] + pose = [{j: pose[j][:, i] for j in ctx} for i in range(n)] + depth = [{j: depth[j][:, i] for j in ctx} for i in range(n)] + + # Create camera with batch information + cams = [{j: CameraNerf(K=intrinsics[i][0], Twc=pose[i][j], hw=rgb[i][j]) + for j in ctx} for i in range(n)] + + # Create encode dictionary + encode_data = [{ + 'rgb': rgb[i][j], + 'cam': cams[i][j], + 'gt_depth': depth[i][j], + } for i, j in encode_idx] + + # Create decode dictionary + decode_data = [{ + 'rgb': rgb[i][j], + 'cam': cams[i][j], + 'gt_depth': depth[i][j], + } for i, j in decode_idx] + + # Run PerceiverIO (encode and decode) + perceiver_output = self.networks['perceiver']( + encode_data=encode_data, + decode_data=decode_data, + sample_queries=self.sample_decoded_queries, + filter_invalid=False, + ) + output = perceiver_output['output'] + + predictions = {} + + # DEPTH + + # Get predicted depths + pred_depths = self.parse_output( + output, 'depth', decode_idx, predictions) + # Get predicted monocular depths + pred_depths_mono = self.parse_output( + output, 'depth_mono', decode_idx, predictions) + + # RGB + + # Get predicted RGB + pred_rgbs = self.parse_output(output, 'rgb', decode_idx, predictions) + + # VIRTUAL + + if len(self.use_virtual_cameras) > 0 and (self.training or self.virtual_cameras_eval): + + virtual_data = create_virtual_cameras( + decode_data, + n_samples=self.use_virtual_cameras[1], + cam_noise=self.use_virtual_cameras[2:-1], + center_noise=self.use_virtual_cameras[-1], + thr=0.1, tries=10, decay=0.9, + ) + virtual_output = self.networks['perceiver'].decode( + latent=perceiver_output['latent'], data=virtual_data, + sources=['cam'], field='cam', + sample_queries=self.sample_decoded_queries, + filter_invalid=True + )['output'] + + gt_virtual_cams = [data['cam'] for data in virtual_data] + + if pred_depths is not None: + pred_depths_virtual = [[depth] + for depth in virtual_output['depth']] + gt_depths_virtual = [data['gt_depth'] for data in virtual_data] + + if not self.training: + batch['depth']['virtual'] = torch.stack( + gt_depths_virtual, 1) + predictions['depth']['virtual'] = [torch.stack( + [pred[0] for pred in pred_depths_virtual], 1)] + + if 'query_idx' in virtual_output: + gt_depths_virtual, pred_depths_virtual = sample_pred_gt( + virtual_output, gt_depths_virtual, pred_depths_virtual) + else: + pred_depths_virtual, gt_depths_virtual = None, None + + if pred_rgbs is not None: + pred_rgbs_virtual = [[rgb] for rgb in virtual_output['rgb']] + gt_rgbs_virtual = [data['rgb'] for data in virtual_data] + + if not self.training: + batch['rgb']['virtual'] = torch.stack(gt_rgbs_virtual, 1) + predictions['rgb']['virtual'] = [torch.stack( + [pred[0] for pred in pred_rgbs_virtual], 1)] + + if 'query_idx' in virtual_output: + gt_rgbs_virtual, pred_rgbs_virtual = sample_pred_gt( + virtual_output, gt_rgbs_virtual, pred_rgbs_virtual) + else: + pred_rgbs_virtual, gt_rgbs_virtual = None, None + + else: + + virtual_data = virtual_output = None + + ########################################################## + + if self.display is not None: + 
self.display.loop(self.networks, encode_data, + decode_data, output, virtual_data) + + ########################################################## + + if not self.training: + return { + 'predictions': predictions, + 'batch': batch, + } + + # Get GT images and depths + gt_depths = [depth[i][j] for i, j in decode_idx] + gt_rgbs = [rgb[i][j] for i, j in decode_idx] + + if 'query_idx' in output: + if pred_depths is not None: + gt_depths, pred_depths = sample_pred_gt( + output, gt_depths, pred_depths) + if pred_rgbs is not None: + gt_rgbs, pred_rgbs = sample_pred_gt(output, gt_rgbs, pred_rgbs) + + loss, metrics = self.compute_loss_and_metrics( + pred_rgbs, gt_rgbs, + pred_depths, gt_depths, + ) + + if len(self.use_virtual_cameras) > 0: + virtual_loss, _ = self.compute_loss_and_metrics( + pred_rgbs_virtual if self.use_virtual_rgb else None, + gt_rgbs_virtual if self.use_virtual_rgb else None, + pred_depths_virtual, gt_depths_virtual, + ) + loss = loss + self.use_virtual_cameras[0] * virtual_loss + + if pred_depths_mono is not None: + mono_loss, _ = self.compute_loss_and_metrics( + None, None, pred_depths_mono, gt_depths, + ) + loss = loss + mono_loss + + return { + 'loss': loss, + 'metrics': metrics, + 'predictions': predictions, + } + + def compute_loss_and_metrics(self, pred_rgbs, gt_rgbs, pred_depths, gt_depths): + + loss, metrics = [], {} + + # Calculate RGB losses + if pred_rgbs is not None and 'rgb' in self.losses and self.weights[0] > 0.0: + loss_rgb = [] + for pred, gt in zip(pred_rgbs, gt_rgbs): + rgb_output = self.losses['rgb'](pred, gt) + loss_rgb.append(self.weights[0] * rgb_output['loss']) + loss.append(sum(loss_rgb) / len(loss_rgb)) + + # Calculate depth losses + if pred_depths is not None and 'depth' in self.losses and self.weights[1] > 0.0: + loss_depth = [] + for pred, gt in zip(pred_depths, gt_depths): + depth_output = self.losses['depth'](pred, gt) + loss_depth.append(self.weights[1] * depth_output['loss']) + loss.append(sum(loss_depth) / len(loss_depth)) + + if len(loss) == 2 and self.scale_loss: + ratio_rgb_depth = loss[1].item() / loss[0].item() + loss[0] = loss[0] * ratio_rgb_depth + + loss = sum(loss) / len(loss) + + return loss, metrics diff --git a/vidar/arch/models/perceiver/DefineModel.py b/vidar/arch/models/perceiver/DefineModel.py new file mode 100755 index 0000000000000000000000000000000000000000..248f46a1e81974678130f9e4bf35c293a97940f1 --- /dev/null +++ b/vidar/arch/models/perceiver/DefineModel.py @@ -0,0 +1,354 @@ +# Copyright 2021 Toyota Research Institute. All rights reserved. 
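# --- Illustrative sketch (editor's addition, not part of the original code) ---
# The scale_loss option above rescales the RGB loss by the detached ratio of the
# two loss magnitudes, so depth and RGB terms contribute comparably without
# backpropagating through the ratio itself. The balancing step in isolation:
import torch

loss_rgb = torch.tensor(0.02, requires_grad=True)
loss_depth = torch.tensor(0.35, requires_grad=True)
ratio = loss_depth.item() / loss_rgb.item()    # .item() keeps the ratio out of the graph
total = (loss_rgb * ratio + loss_depth) / 2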
+ +import random +from abc import ABC + +import torch + +from vidar.arch.models.BaseModel import BaseModel +from vidar.geometry.camera_nerf import CameraNerf +from vidar.geometry.pose import Pose +from vidar.utils.data import flatten +from vidar.utils.types import is_list, is_int + + +def augment_canonical(pose): + + ctx = list(pose.keys()) + num = list(range(pose[0].shape[1])) + + i = random.choice(ctx) + j = random.choice(num) + + base = Pose(pose[i][:, j]).inverse().T + for key in ctx: + for n in num: + pose[key][:, n] = pose[key][:, n] @ base + + return pose + + +def parse_output(output, key, encode_idx, predictions): + if key in output.keys(): + pred = [[val] for val in output[key]] + predictions_key = {} + for idx, (i, j) in enumerate(encode_idx): + if j not in predictions_key.keys(): + predictions_key[j] = [] + predictions_key[j].append(output[key][idx]) + predictions_key = {key: [torch.stack(val, 1)] + for key, val in predictions_key.items()} + predictions[key] = predictions_key + else: + pred = None + return pred + + +def create_virtual_cameras(encode_data, n_samples=1, cam_noise=None, center_noise=None, + downsample=1.0, thr=0.1, tries=10, decay=0.9): + + gt_cams = [datum['cam'] for datum in encode_data] + gt_depths = [datum['gt_depth'] for datum in encode_data] + gt_rgbs = [datum['rgb'] for datum in encode_data] + + if not is_int(thr): + n = gt_rgbs[0].shape[-2] * gt_rgbs[0].shape[-1] + thr = int(n * thr) + + pcls_proj = [cam.scaled(downsample).reconstruct_depth_map(depth, to_world=True) + for cam, depth in zip(gt_cams, gt_depths)] + pcls_proj = [pcl.reshape(*pcl.shape[:2], -1) for pcl in pcls_proj] + pcl_proj = torch.cat([pcl for pcl in pcls_proj], -1) + clr_proj = torch.cat([rgb.reshape(*rgb.shape[:2], -1) + for rgb in gt_rgbs], -1) + + gt_pcl_centers = [pcl.mean(-1) for pcl in pcls_proj] + + virtual_data = [] + for gt_rgb, gt_depth, gt_cam, gt_pcl_center in zip(gt_rgbs, gt_depths, gt_cams, gt_pcl_centers): + for i in range(n_samples): + + cam = gt_cam.clone() + pcl_center = gt_pcl_center.clone() + + weight = 1.0 + rgb_proj = depth_proj = None + for j in range(tries): + + if center_noise is not None: + pcl_center_noise = weight * center_noise * \ + (2 * torch.rand_like(gt_pcl_center) - 1) + pcl_center = gt_pcl_center + pcl_center_noise + + if cam_noise is not None: + cam.look_at(pcl_center) + cam.Twc.translateUp( + weight * cam_noise[0] * (2 * random.random() - 1)) + cam.Twc.translateLeft( + weight * cam_noise[1] * (2 * random.random() - 1)) + cam.Twc.translateForward( + weight * cam_noise[2] * (2 * random.random() - 1)) + cam.look_at(pcl_center) + + rgb_proj_try, depth_proj_try = cam.project_pointcloud( + pcl_proj, clr_proj) + + valid = (depth_proj_try > 0).sum() > thr + if valid: + rgb_proj, depth_proj = rgb_proj_try, depth_proj_try + break + else: + weight = weight * decay + + if rgb_proj is None and depth_proj is None: + rgb_proj, depth_proj = gt_rgb, gt_depth + cam = gt_cam.clone() + + virtual_data.append({ + 'cam': cam, + 'rgb': rgb_proj.contiguous(), + 'gt_depth': depth_proj.contiguous(), + }) + + return virtual_data + + +def ControlVidarCamera(draw, cam, tvel=0.2, rvel=0.1): + change = False + if draw.UP: + cam.Twc.translateForward(tvel) + change = True + if draw.DOWN: + cam.Twc.translateBackward(tvel) + change = True + if draw.LEFT: + cam.Twc.translateLeft(tvel) + change = True + if draw.RIGHT: + cam.Twc.translateRight(tvel) + change = True + if draw.PGUP: + cam.Twc.translateUp(tvel) + change = True + if draw.PGDOWN: + cam.Twc.translateDown(tvel) + change = True + if 
draw.KEY_A: + cam.Twc.rotateYaw(-rvel) + change = True + if draw.KEY_D: + cam.Twc.rotateYaw(+rvel) + change = True + if draw.KEY_W: + cam.Twc.rotatePitch(+rvel) + change = True + if draw.KEY_S: + cam.Twc.rotatePitch(-rvel) + change = True + if draw.KEY_Q: + cam.Twc.rotateRoll(-rvel) + change = True + if draw.KEY_E: + cam.Twc.rotateRoll(+rvel) + change = True + return change + + +class HuggingModel(BaseModel, ABC): + + def __init__(self, cfg): + super().__init__(cfg) + + from vidar.arch.networks.perceiver.HuggingNet import HuggingNet + self.networks['perceiver'] = HuggingNet(cfg.model.network) + self.weights = [1.0, 1.0] + self.use_pose_noise = cfg.model.use_pose_noise + self.use_virtual_cameras = cfg.model.use_virtual_cameras + self.use_virtual_rgb = cfg.model.use_virtual_rgb + self.augment_canonical = cfg.model.augment_canonical + + def forward(self, batch, epoch=0, collate=True): + + if is_list(batch): + output = [self.forward(b, collate=False) for b in batch] + loss = [out['loss'] for out in output] + return {'loss': sum(loss) / len(loss), 'predictions': {}, 'metrics': {}} + + if not collate: + for key in ['rgb', 'intrinsics', 'pose', 'depth']: + batch[key] = {k: v.unsqueeze(0) for k, v in batch[key].items()} + + rgb = batch['rgb'] + intrinsics = batch['intrinsics'] + pose = batch['pose'] + depth = batch['depth'] + + ctx = [0] # rgb.keys() + b, n = rgb[0].shape[:2] + + ii, jj = range(n), ctx + + pose = [{key: val[i] for key, val in pose.items()} for i in range(b)] + + pose0 = [Pose().to(rgb[0].device) for _ in range(b)] + if self.training and len(self.use_pose_noise) > 0.0: + if random.random() < self.use_pose_noise[0]: + for i in range(b): + pose0[i].translateUp( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].translateLeft( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].translateForward( + self.use_pose_noise[1] * (2 * random.random() - 1)) + pose0[i].rotateRoll( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + pose0[i].rotatePitch( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + pose0[i].rotateYaw( + torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1)) + for i in range(b): + pose[i][0][[0]] = pose0[i].T + + pose = [Pose.from_dict( + p, to_global=True, zero_origin=False, to_matrix=True) for p in pose] + pose = {key: torch.stack([pose[i][key] for i in range( + len(pose))], 0) for key in pose[0].keys()} + + if self.training and self.augment_canonical: + pose = augment_canonical(pose) + + rgb = [{key: val[:, i] for key, val in rgb.items()} for i in range(n)] + intrinsics = [{key: val[:, i] + for key, val in intrinsics.items()} for i in range(n)] + pose = [{key: val[:, i] for key, val in pose.items()} + for i in range(n)] + depth = [{key: val[:, i] for key, val in depth.items()} + for i in range(n)] + + cams = [{j: CameraNerf(K=intrinsics[i][0], Twc=pose[i][j], hw=rgb[i][j]) + for j in ctx} for i in range(n)] + + encode_idx = flatten([[[i, j] for i in ii] for j in jj]) + + encode_data = [{ + 'rgb': rgb[i][j], + 'cam': cams[i][j], + 'gt_depth': depth[i][j], + } for i, j in encode_idx] + + perceiver_output = self.networks['perceiver']( + encode_data=encode_data, + ) + + output = perceiver_output['output'] + + predictions = {} + + # DEPTH + + pred_depths = parse_output(output, 'depth', encode_idx, predictions) + pred_depths_mono = parse_output( + output, 'depth_mono', encode_idx, predictions) + + if pred_depths is not None and pred_depths_mono is not None: + for i in range(len(pred_depths)): + pred_depths[i] += 
pred_depths_mono[i] + + # RGB + + pred_rgbs = parse_output(output, 'rgb', encode_idx, predictions) + + # VIRTUAL + + if len(self.use_virtual_cameras) > 0: + + virtual_data = create_virtual_cameras( + encode_data, + n_samples=self.use_virtual_cameras[1], + cam_noise=self.use_virtual_cameras[2:-1], + center_noise=self.use_virtual_cameras[-1], + thr=0.1, tries=10, decay=0.9, + ) + virtual_output = self.networks['perceiver'].decode( + latent=perceiver_output['latent'], data=virtual_data, + sources=['cam'], field='cam', + )['output'] + + gt_virtual_cams = [data['cam'] for data in virtual_data] + + if pred_depths is not None: + pred_depths_virtual = [[depth] + for depth in virtual_output['depth']] + gt_depths_virtual = [data['gt_depth'] for data in virtual_data] + + batch['depth']['virtual'] = torch.stack(gt_depths_virtual, 1) + predictions['depth']['virtual'] = [torch.stack( + [pred[0] for pred in pred_depths_virtual], 1)] + else: + pred_depths_virtual, gt_depths_virtual = None, None + + if pred_rgbs is not None: + pred_rgbs_virtual = [[rgb] for rgb in virtual_output['rgb']] + gt_rgbs_virtual = [data['rgb'] for data in virtual_data] + + batch['rgb']['virtual'] = torch.stack(gt_rgbs_virtual, 1) + predictions['rgb']['virtual'] = [torch.stack( + [pred[0] for pred in pred_rgbs_virtual], 1)] + else: + pred_rgbs_virtual, gt_rgbs_virtual = None, None + + ########################################################## + + # display_data_interactive(self.networks, encode_data, encode_data, output, virtual_data) + # display_data(rgb, depth, cams, ctx, n, batch, predictions, gt_virtual_cams) + + ########################################################## + + if not self.training: + return { + 'predictions': predictions, + 'batch': batch, + } + + gt_depths = [depth[i][j] for i, j in encode_idx] + gt_rgbs = [rgb[i][j] for i, j in encode_idx] + + loss, metrics = self.compute_loss_and_metrics( + pred_rgbs, gt_rgbs, + pred_depths, gt_depths, + ) + + if len(self.use_virtual_cameras) > 0: + virtual_loss, _ = self.compute_loss_and_metrics( + pred_rgbs_virtual if self.use_virtual_rgb else None, + gt_rgbs_virtual if self.use_virtual_rgb else None, + pred_depths_virtual, gt_depths_virtual, + ) + loss = loss + self.use_virtual_cameras[0] * virtual_loss + + return { + 'loss': loss, + 'metrics': metrics, + 'predictions': predictions, + } + + def compute_loss_and_metrics(self, pred_rgbs, gt_rgbs, pred_depths, gt_depths): + + loss, metrics = [], {} + + if pred_rgbs is not None and 'rgb' in self.losses and self.weights[0] > 0.0: + loss_rgb = [] + for pred, gt in zip(pred_rgbs, gt_rgbs): + rgb_output = self.losses['rgb'](pred, gt) + loss_rgb.append(self.weights[0] * rgb_output['loss']) + loss.append(sum(loss_rgb) / len(loss_rgb)) + if pred_depths is not None and 'depth' in self.losses and self.weights[1] > 0.0: + loss_depth = [] + for pred, gt in zip(pred_depths, gt_depths): + depth_output = self.losses['depth'](pred, gt) + loss_depth.append(self.weights[1] * depth_output['loss']) + loss.append(sum(loss_depth) / len(loss_depth)) + + loss = sum(loss) / len(loss) + + return loss, metrics diff --git a/vidar/arch/models/utils.py b/vidar/arch/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33c2ebb92be2f1c8519b27d64b8e7a3aed301d8b --- /dev/null +++ b/vidar/arch/models/utils.py @@ -0,0 +1,112 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
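# --- Illustrative sketch (editor's addition, not part of the original code) ---
# break_context below separates a frame dictionary keyed by temporal offset into
# the target frame and its (optionally stacked) context frames. The core idea:
import torch

frames = {0: torch.rand(2, 3, 32, 32), -1: torch.rand(2, 3, 32, 32), 1: torch.rand(2, 3, 32, 32)}
tgt = frames[0]
ctx = torch.stack([frames[k] for k in frames if k != 0], 1)   # [B, num_ctx, 3, H, W]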
+ +import torch + +from vidar.geometry.camera_full import CameraFull +from vidar.utils.decorators import iterate1 +from vidar.utils.tensor import interpolate_image +from vidar.utils.types import is_dict + + +@iterate1 +def make_rgb_scales(rgb, pyramid=None, ratio_scales=None): + """ + Create different RGB scales to correspond with predictions + + Parameters + ---------- + rgb : torch.Tensor + Input image [B,3,H,W] + pyramid : list[torch.Tensor] + List with tensors at different scales + ratio_scales : Tuple + Alternatively, you can provide how many scales and the downsampling ratio for each + + Returns + ------- + pyramid : list[torch.Tensor] + List with the input image at the same resolutions as pyramid + """ + assert pyramid is None or ratio_scales is None + if pyramid is not None: + return [interpolate_image(rgb, shape=pyr.shape[-2:]) for pyr in pyramid] + elif ratio_scales is not None: + return [interpolate_image(rgb, scale_factor=ratio_scales[0] ** i) + for i in range(ratio_scales[1])] + else: + raise NotImplementedError('Invalid option') + + +def break_context(dic, tgt=0, ctx=None, scl=None, stack=False): + """ + Separate a dictionary between target and context information + + Parameters + ---------- + dic : Dict + Input dictionary + tgt : Int + Which key corresponds to target + ctx : Int + Which key corresponds to context (if None, use everything else) + scl : Int + Which scale should be used (it None, assume there are no scales) + stack : Bool + Stack output context or not + + Returns + ------- + tgt : torch.Tensor + Target information + ctx : list[torch.Tensor] or torch.Tensor + Context information (list or stacked) + """ + # Get remaining frames if context is not provided + if ctx is None: + ctx = [key for key in dic.keys() if key != tgt] + # Get all scales or a single scale + if scl is None: + tgt, ctx = dic[tgt], [dic[key] for key in ctx if key != tgt] + else: + tgt, ctx = dic[tgt][scl], [dic[key][scl] for key in ctx if key != tgt] + # Stack context if requested + if stack: + ctx = torch.stack(ctx, 1) + # Return target and context + return tgt, ctx + + +def create_cameras(rgb, intrinsics, pose, zero_origin=True, scaled=None): + """ + Create cameras from batch information + Parameters + ---------- + rgb : Dict + Dictionary with images + intrinsics : Dict + Dictionary with camera intrinsics + pose : Dict + Dictionary with camera poses + zero_origin : Bool + Zero target camera to the origin or not + scaled : Float + Scale factor for the output cameras + + Returns + ------- + cams : Dict + Dictionary with output cameras + """ + if pose is None: + return None + cams = {key: CameraFull( + K=intrinsics[key] if is_dict(intrinsics) else intrinsics, + Twc=pose[key], + hw=rgb[key] if is_dict(rgb) else rgb, + ).scaled(scaled).to(pose[key].device) for key in pose.keys()} + if zero_origin: + cams[0] = CameraFull( + K=intrinsics[0] if is_dict(intrinsics) else intrinsics, + hw=rgb[0] if is_dict(rgb) else rgb, + ).scaled(scaled).to(rgb.device) + return cams diff --git a/vidar/arch/networks/BaseNet.py b/vidar/arch/networks/BaseNet.py new file mode 100755 index 0000000000000000000000000000000000000000..6489f5b287dd0d433743515cae08a6a5ca5b9fd8 --- /dev/null +++ b/vidar/arch/networks/BaseNet.py @@ -0,0 +1,54 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
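# --- Illustrative sketch (editor's addition, not part of the original code) ---
# BaseNet.load below strips everything up to (and including) the sub-network name
# from checkpoint keys, so a full-model checkpoint can be loaded into a single
# network. The same renaming on a hypothetical state dict with placeholder values:
state_dict = {
    'model.depth_net.encoder.conv1.weight': 0,
    'model.pose_net.fc.weight': 1,
}
name = 'depth_net'
stripped = {k[k.find(name) + len(name) + 1:]: v
            for k, v in state_dict.items() if k.find(name) > -1}
# -> {'encoder.conv1.weight': 0}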
+ +import torch +import torch.nn as nn + +from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth +from vidar.utils.config import cfg_has + + +class BaseNet(nn.Module): + """Base network class, that all other networks inherit""" + def __init__(self, cfg): + super().__init__() + self.networks = torch.nn.ModuleDict() + self.blocks = torch.nn.ModuleDict() + + if cfg_has(cfg, 'depth_range'): + self.to_depth = SigmoidToInvDepth( + cfg.depth_range[0], cfg.depth_range[1], return_depth=True) + else: + self.to_depth = None + + def _forward_unimplemented(self, *args): + raise NotImplementedError('Forward unimplemented is unimplemented!') + + def set_attr(self, cfg, key, default): + """Set a network attribute""" + self.__setattr__(key, cfg_has(cfg, key, default)) + + def train(self, mode=True): + """Set all networks and blocks to train or val""" + super().train(mode=mode) + for key, val in self.networks.items(): + val.train(mode=mode) + for key, val in self.blocks.values(): + val.train(mode=mode) + + def eval(self): + self.train(mode=False) + + def sigmoid_to_depth(self, sigmoids): + """Convert sigmoids to depth values""" + return self.to_depth(sigmoids) if self.to_depth is not None else sigmoids + + def load(self, ckpt, name): + """Loads a checkpoint onto the network""" + state_dict = torch.load(ckpt, map_location='cpu')['state_dict'] + updated_state_dict = {} + for key, val in state_dict.items(): + idx = key.find(name) + if idx > -1: + updated_state_dict[key[idx + len(name) + 1:]] = val + self.load_state_dict(updated_state_dict) + diff --git a/vidar/arch/networks/__init__.py b/vidar/arch/networks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/networks/decoders/DepthDecoder.py b/vidar/arch/networks/decoders/DepthDecoder.py new file mode 100755 index 0000000000000000000000000000000000000000..846e77fda2f187c7cf490f3c8d1eaaa3dc04116c --- /dev/null +++ b/vidar/arch/networks/decoders/DepthDecoder.py @@ -0,0 +1,81 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
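# --- Illustrative sketch (editor's addition, not part of the original code) ---
# BaseNet above converts decoder sigmoids into depths through SigmoidToInvDepth.
# One common formulation of that mapping (assumed here; the actual block may
# differ in details) interpolates in inverse-depth space between 1/max and 1/min:
import torch

def sigmoid_to_depth(sig, min_depth=0.1, max_depth=100.0):
    min_inv, max_inv = 1.0 / max_depth, 1.0 / min_depth
    inv_depth = min_inv + (max_inv - min_inv) * sig
    return 1.0 / inv_depth

depth = sigmoid_to_depth(torch.sigmoid(torch.randn(1, 1, 4, 4)))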
+ +from abc import ABC +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn + +from vidar.arch.networks.layers.convs import ConvBlock, Conv3x3, upsample +from vidar.utils.config import cfg_has + + +class DepthDecoder(nn.Module, ABC): + """ + Depth decoder network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + self.num_scales = cfg_has(cfg, 'num_scales', 4) + self.use_skips = cfg.use_skips + + self.num_ch_enc = cfg.num_ch_enc + self.num_ch_dec = np.array([16, 32, 64, 128, 256]) + self.num_ch_out = cfg.num_ch_out + + self.convs = OrderedDict() + for i in range(self.num_scales, -1, -1): + + num_ch_in = self.num_ch_enc[-1] if i == self.num_scales else self.num_ch_dec[i + 1] + num_ch_out = self.num_ch_dec[i] + self.convs[('upconv', i, 0)] = ConvBlock( + num_ch_in, num_ch_out) + + num_ch_in = self.num_ch_dec[i] + if self.use_skips and i > 0: + num_ch_in += self.num_ch_enc[i - 1] + num_ch_out = self.num_ch_dec[i] + self.convs[('upconv', i, 1)] = ConvBlock( + num_ch_in, num_ch_out) + + for i in range(self.num_scales): + self.convs[('outconv', i)] = Conv3x3( + self.num_ch_dec[i], self.num_ch_out) + + self.decoder = nn.ModuleList(list(self.convs.values())) + + if cfg.activation == 'sigmoid': + self.activation = nn.Sigmoid() + elif cfg.activation == 'identity': + self.activation = nn.Identity() + elif cfg.activation == 'softmax': + self.activation = nn.Softmax(dim=1) + else: + raise ValueError('Invalid activation function {}'.format(cfg.activation)) + + def forward(self, input_features): + """Network forward pass""" + + outputs = {} + + x = input_features[-1] + for i in range(self.num_scales, -1, -1): + x = self.convs[('upconv', i, 0)](x) + x = [upsample(x)] + if self.use_skips and i > 0: + x += [input_features[i - 1]] + x = torch.cat(x, 1) + x = self.convs[('upconv', i, 1)](x) + if i in range(self.num_scales): + outputs[('features', i)] = x + outputs[('output', i)] = self.activation( + self.convs[('outconv', i)](x)) + + return outputs diff --git a/vidar/arch/networks/decoders/PoseDecoder.py b/vidar/arch/networks/decoders/PoseDecoder.py new file mode 100755 index 0000000000000000000000000000000000000000..89764d89998233d01ba8a41463217634678be0d7 --- /dev/null +++ b/vidar/arch/networks/decoders/PoseDecoder.py @@ -0,0 +1,55 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
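# --- Illustrative sketch (editor's addition, not part of the original code) ---
# One decoder step from DepthDecoder.forward above: upsample the running feature
# map, then concatenate the encoder skip feature of matching resolution before
# the next ConvBlock. Shapes are hypothetical:
import torch
import torch.nn.functional as F

x = torch.rand(1, 256, 8, 8)        # current decoder features
skip = torch.rand(1, 128, 16, 16)   # encoder skip features
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = torch.cat([x, skip], 1)         # [1, 384, 16, 16]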
+ +from abc import ABC + +import torch +import torch.nn as nn + + +class PoseDecoder(nn.Module, ABC): + """ + Pose decoder network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, num_ch_enc, num_input_features, + num_frames_to_predict_for=None, + stride=1, output_multiplier=0.01): + super().__init__() + + self.num_encoder_channels = num_ch_enc + self.num_input_features = num_input_features + self.output_multiplier = output_multiplier + + if num_frames_to_predict_for is None: + num_frames_to_predict_for = num_input_features - 1 + self.num_output_predictions = num_frames_to_predict_for + + self.convs = { + 'squeeze': nn.Conv2d(self.num_encoder_channels[-1], 256, 1), + ('pose', 0): nn.Conv2d(num_input_features * 256, 256, 3, stride, 1), + ('pose', 1): nn.Conv2d(256, 256, 3, stride, 1), + ('pose', 2): nn.Conv2d(256, 6 * num_frames_to_predict_for, 1), + } + + self.net = nn.ModuleList(list(self.convs.values())) + self.relu = nn.ReLU() + + def forward(self, all_features): + """Network forward pass""" + + last_features = [f[-1] for f in all_features] + last_features = [self.relu(self.convs['squeeze'](f)) for f in last_features] + cat_features = torch.cat(last_features, 1) + + for i in range(3): + cat_features = self.convs[('pose', i)](cat_features) + if i < 2: + cat_features = self.relu(cat_features) + + output = self.output_multiplier * \ + cat_features.mean(3).mean(2).view(-1, self.num_output_predictions, 1, 6) + return torch.split(output, split_size_or_sections=3, dim=-1) diff --git a/vidar/arch/networks/depth/FocalDepthResNet.py b/vidar/arch/networks/depth/FocalDepthResNet.py new file mode 100644 index 0000000000000000000000000000000000000000..5070a3cce0643ce05ae20e51d6a696d53c5ea330 --- /dev/null +++ b/vidar/arch/networks/depth/FocalDepthResNet.py @@ -0,0 +1,48 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
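# --- Illustrative sketch (editor's addition, not part of the original code) ---
# PoseDecoder above returns a scaled 6-DoF vector per predicted frame, split into
# a 3-vector rotation (euler, given the models' rotation_mode) and a 3-vector
# translation. How that output is typically unpacked:
import torch

out = 0.01 * torch.randn(2, 2, 1, 6)                  # [B, num_frames, 1, 6]
rotation, translation = torch.split(out, 3, dim=-1)   # two [B, num_frames, 1, 3] tensors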
+ +from abc import ABC + +from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder +from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder as ResnetEncoder +from vidar.utils.depth import inv2depth, depth2inv + + +class FocalDepthResNet(BaseNet, ABC): + """ + Depth network with focal length normalization + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.networks['encoder'] = ResnetEncoder(cfg.encoder) + cfg.decoder.num_ch_enc = self.networks['encoder'].num_ch_enc + self.networks['decoder'] = DepthDecoder(cfg.decoder) + self.scale_inv_depth = SigmoidToInvDepth( + min_depth=cfg.min_depth, max_depth=cfg.max_depth) + + def forward(self, rgb, intrinsics, **kwargs): + """Network forward pass""" + + x = self.networks['encoder'](rgb) + x = self.networks['decoder'](x) + inv_depths = [x[('output', i)] for i in range(4)] + + if self.training: + inv_depths = [self.scale_inv_depth(inv_depth)[0] for inv_depth in inv_depths] + else: + inv_depths = [self.scale_inv_depth(inv_depths[0])[0]] + + depths = inv2depth(inv_depths) + depths = [d * intrinsics[:, 0, 0].view(rgb.shape[0], 1, 1, 1) for d in depths] + inv_depths = depth2inv(depths) + + return { + 'inv_depths': inv_depths + } diff --git a/vidar/arch/networks/depth/MonoDepthResNet.py b/vidar/arch/networks/depth/MonoDepthResNet.py new file mode 100755 index 0000000000000000000000000000000000000000..7d48e40f4da31e79780e62ecccc77a324ebe814e --- /dev/null +++ b/vidar/arch/networks/depth/MonoDepthResNet.py @@ -0,0 +1,46 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder +from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder +from vidar.utils.config import cfg_has + + +class MonoDepthResNet(BaseNet, ABC): + """ + Single-frame monocular depth network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.num_scales = cfg_has(cfg, 'num_scales', 4) + self.set_attr(cfg, 'scale_intrinsics', False) + + self.networks['mono_encoder'] = ResNetEncoder(cfg.encoder) + cfg.decoder.num_ch_enc = self.networks['mono_encoder'].num_ch_enc + self.networks['mono_depth'] = DepthDecoder(cfg.decoder) + + def forward(self, rgb, intrinsics=None): + """Network forward pass""" + + features = self.networks['mono_encoder'](rgb) + output = self.networks['mono_depth'](features) + + sigmoids = [output[('output', i)] for i in range(self.num_scales)] + depths = self.sigmoid_to_depth(sigmoids) + + if intrinsics is not None and self.scale_intrinsics: + depths = [d * intrinsics[:, 0, 0].view( + rgb.shape[0], 1, 1, 1) for d in depths] + + return { + 'features': features, + 'depths': depths, + } diff --git a/vidar/arch/networks/depth/MultiDepthResNet.py b/vidar/arch/networks/depth/MultiDepthResNet.py new file mode 100644 index 0000000000000000000000000000000000000000..65eb95d4cc372d7213040027478ec095e1f69525 --- /dev/null +++ b/vidar/arch/networks/depth/MultiDepthResNet.py @@ -0,0 +1,50 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
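# --- Illustrative sketch (editor's addition, not part of the original code) ---
# FocalDepthResNet and MonoDepthResNet above multiply predicted depths by fx
# (intrinsics[:, 0, 0]), so the network effectively predicts focal-normalized
# depth. The scaling step in isolation, with a hypothetical fx:
import torch

depth = torch.rand(4, 1, 32, 32)
K = torch.zeros(4, 3, 3)
K[:, 0, 0] = K[:, 1, 1] = 700.0
K[:, 2, 2] = 1.0
scaled_depth = depth * K[:, 0, 0].view(4, 1, 1, 1)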
+ +from abc import ABC + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder +from vidar.arch.networks.encoders.MultiResNetEncoder import MultiResNetEncoder +# from vidar.arch.networks.encoders.MultiResNetEncoderStereoTwin import ResnetEncoderMatchingStereo +from vidar.utils.config import cfg_has + + +class MultiDepthResNet(BaseNet, ABC): + """ + Multi-frame monocular depth network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.num_scales = cfg_has(cfg, 'num_scales', 4) + self.set_attr(cfg, 'scale_intrinsics', False) + + self.networks['encoder'] = MultiResNetEncoder(cfg.encoder) + cfg.decoder.num_ch_enc = self.networks['encoder'].num_ch_enc + self.networks['depth'] = DepthDecoder(cfg.decoder) + + def forward(self, rgb, rgb_context, cams, + intrinsics=None, networks=None): + """Network forward pass""" + + encoder_output = self.networks['encoder']( + rgb, rgb_context, cams, networks=networks) + + network_output = { + **encoder_output, + } + + output = self.networks['depth'](encoder_output['features']) + sigmoids = [output[('output', i)] for i in range(self.num_scales)] + network_output['depths'] = self.sigmoid_to_depth(sigmoids) + + if intrinsics is not None and self.scale_intrinsics: + network_output['depths'] = [d * intrinsics[0][:, 0, 0].view( + rgb.shape[0], 1, 1, 1) for d in network_output['depths']] + + return network_output diff --git a/vidar/arch/networks/depth/PackNet.py b/vidar/arch/networks/depth/PackNet.py new file mode 100644 index 0000000000000000000000000000000000000000..5a797d269355d341183dc37717a129f47127e6a4 --- /dev/null +++ b/vidar/arch/networks/depth/PackNet.py @@ -0,0 +1,166 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
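# --- Illustrative sketch (editor's addition, not part of the original code) ---
# PackNet below replaces pooling with packing: a space-to-depth fold moves each
# 2x2 neighborhood into channels, and unpacking reverses it, so downsampling keeps
# information that pooling would discard. Only the folding step is shown here;
# the real Pack/Unpack layers also apply 3D convolutions:
import torch
import torch.nn as nn

x = torch.rand(1, 8, 16, 16)
packed = nn.PixelUnshuffle(2)(x)        # [1, 32, 8, 8]
unpacked = nn.PixelShuffle(2)(packed)   # [1, 8, 16, 16], exactly recovers x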
+ +from abc import ABC + +import torch +import torch.nn as nn + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.layers.packnet.packnet import \ + PackLayerConv3d, UnpackLayerConv3d, Conv2D, ResidualBlock, InvDepth +from vidar.utils.depth import inv2depth + + +class PackNet(BaseNet, ABC): + """ + PackNet depth network (https://arxiv.org/abs/1905.02693) + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + # Configuration parameters + self.min_depth = cfg.has('min_depth', 0.5) + self.dropout = cfg.has('dropout', 0.0) + + # Input/output channels + in_channels = 3 + out_channels = 1 + + # Hyper-parameters + ni, no = 64, out_channels + n1, n2, n3, n4, n5 = 64, 64, 128, 256, 512 + num_blocks = [2, 2, 3, 3] + pack_kernel = [5, 3, 3, 3, 3] + unpack_kernel = [3, 3, 3, 3, 3] + iconv_kernel = [3, 3, 3, 3, 3] + + # Initial convolutional layer + self.pre_calc = Conv2D(in_channels, ni, 5, 1) + + # Support for different versions + n1o, n1i = n1, n1 + ni + no + n2o, n2i = n2, n2 + n1 + no + n3o, n3i = n3, n3 + n2 + no + n4o, n4i = n4, n4 + n3 + n5o, n5i = n5, n5 + n4 + + # Encoder + + self.pack1 = PackLayerConv3d(n1, pack_kernel[0]) + self.pack2 = PackLayerConv3d(n2, pack_kernel[1]) + self.pack3 = PackLayerConv3d(n3, pack_kernel[2]) + self.pack4 = PackLayerConv3d(n4, pack_kernel[3]) + self.pack5 = PackLayerConv3d(n5, pack_kernel[4]) + + self.conv1 = Conv2D(ni, n1, 7, 1) + self.conv2 = ResidualBlock(n1, n2, num_blocks[0], 1, dropout=self.dropout) + self.conv3 = ResidualBlock(n2, n3, num_blocks[1], 1, dropout=self.dropout) + self.conv4 = ResidualBlock(n3, n4, num_blocks[2], 1, dropout=self.dropout) + self.conv5 = ResidualBlock(n4, n5, num_blocks[3], 1, dropout=self.dropout) + + # Decoder + + self.unpack5 = UnpackLayerConv3d(n5, n5o, unpack_kernel[0]) + self.unpack4 = UnpackLayerConv3d(n5, n4o, unpack_kernel[1]) + self.unpack3 = UnpackLayerConv3d(n4, n3o, unpack_kernel[2]) + self.unpack2 = UnpackLayerConv3d(n3, n2o, unpack_kernel[3]) + self.unpack1 = UnpackLayerConv3d(n2, n1o, unpack_kernel[4]) + + self.iconv5 = Conv2D(n5i, n5, iconv_kernel[0], 1) + self.iconv4 = Conv2D(n4i, n4, iconv_kernel[1], 1) + self.iconv3 = Conv2D(n3i, n3, iconv_kernel[2], 1) + self.iconv2 = Conv2D(n2i, n2, iconv_kernel[3], 1) + self.iconv1 = Conv2D(n1i, n1, iconv_kernel[4], 1) + + # Depth Layers + + self.unpack_disps = nn.PixelShuffle(2) + self.unpack_disp4 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None) + self.unpack_disp3 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None) + self.unpack_disp2 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None) + + self.disp4_layer = InvDepth(n4, out_channels=out_channels, min_depth=self.min_depth) + self.disp3_layer = InvDepth(n3, out_channels=out_channels, min_depth=self.min_depth) + self.disp2_layer = InvDepth(n2, out_channels=out_channels, min_depth=self.min_depth) + self.disp1_layer = InvDepth(n1, out_channels=out_channels, min_depth=self.min_depth) + + self.init_weights() + + def init_weights(self): + """Weight initialization""" + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, rgb, intrinsics=None): + """Network forward pass""" + + # Initial convolution + + x = self.pre_calc(rgb) + + # Encoder + + x1 = self.conv1(x) + x1p = self.pack1(x1) + x2 = self.conv2(x1p) + x2p = self.pack2(x2) + x3 = self.conv3(x2p) + x3p = 
self.pack3(x3) + x4 = self.conv4(x3p) + x4p = self.pack4(x4) + x5 = self.conv5(x4p) + x5p = self.pack5(x5) + + # Skips + + skip1 = x + skip2 = x1p + skip3 = x2p + skip4 = x3p + skip5 = x4p + + # Decoder + + unpack5 = self.unpack5(x5p) + concat5 = torch.cat((unpack5, skip5), 1) + iconv5 = self.iconv5(concat5) + + unpack4 = self.unpack4(iconv5) + concat4 = torch.cat((unpack4, skip4), 1) + iconv4 = self.iconv4(concat4) + inv_depth4 = self.disp4_layer(iconv4) + up_inv_depth4 = self.unpack_disp4(inv_depth4) + + unpack3 = self.unpack3(iconv4) + concat3 = torch.cat((unpack3, skip3, up_inv_depth4), 1) + iconv3 = self.iconv3(concat3) + inv_depth3 = self.disp3_layer(iconv3) + up_inv_depth3 = self.unpack_disp3(inv_depth3) + + unpack2 = self.unpack2(iconv3) + concat2 = torch.cat((unpack2, skip2, up_inv_depth3), 1) + iconv2 = self.iconv2(concat2) + inv_depth2 = self.disp2_layer(iconv2) + up_inv_depth2 = self.unpack_disp2(inv_depth2) + + unpack1 = self.unpack1(iconv2) + concat1 = torch.cat((unpack1, skip1, up_inv_depth2), 1) + iconv1 = self.iconv1(concat1) + inv_depth1 = self.disp1_layer(iconv1) + + if self.training: + inv_depths = [inv_depth1, inv_depth2, inv_depth3, inv_depth4] + else: + inv_depths = [inv_depth1] + + return { + 'depths': inv2depth(inv_depths), + } diff --git a/vidar/arch/networks/encoders/MultiResNetEncoder.py b/vidar/arch/networks/encoders/MultiResNetEncoder.py new file mode 100755 index 0000000000000000000000000000000000000000..021c2539edda918dd3d5d6f4aec898c49f88f7a1 --- /dev/null +++ b/vidar/arch/networks/encoders/MultiResNetEncoder.py @@ -0,0 +1,223 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torchvision.models as models + +from vidar.utils.config import cfg_has +from vidar.utils.data import flatten +from vidar.utils.depth import depth2inv +from vidar.utils.tensor import grid_sample + +RESNET_VERSIONS = { + 18: models.resnet18, + 34: models.resnet34, + 50: models.resnet50, + 101: models.resnet101, + 152: models.resnet152 +} + + +class MultiResNetEncoder(nn.Module, ABC): + """ + Multi-frame depth encoder + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + self.adaptive_bins = cfg.adaptive_bins + self.depth_binning = cfg.depth_binning + self.set_missing_to_max = True + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + + self.depth_range = cfg.depth_range + self.min_depth_bin = cfg.depth_bin_range[0] + self.max_depth_bin = cfg.depth_bin_range[1] + self.num_depth_bins = cfg.num_depth_bins + + self.min_depth_bin = torch.nn.Parameter(torch.tensor( + self.min_depth_bin), requires_grad=False) + self.max_depth_bin = torch.nn.Parameter(torch.tensor( + self.max_depth_bin), requires_grad=False) + + self.matching_height = cfg.input_shape[0] // 4 + self.matching_width = cfg.input_shape[1] // 4 + + self.depth_bins = None + self.warp_depths = None + + assert cfg.version in RESNET_VERSIONS, ValueError( + '{} is not a valid number of resnet layers'.format(cfg.version)) + encoder = RESNET_VERSIONS[cfg.version](cfg.pretrained) + + self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu) + self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1) + self.layer2 = encoder.layer2 + self.layer3 = encoder.layer3 + self.layer4 = encoder.layer4 + + if cfg.version > 34: + self.num_ch_enc[1:] *= 4 + + self.prematching_conv = nn.Sequential( + 
nn.Conv2d(64, out_channels=16, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True) + ) + + self.double_volume = cfg_has(cfg, 'double_volume', False) + + self.reduce_conv = nn.Sequential( + nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins * (2 if self.double_volume else 1), + out_channels=self.num_ch_enc[1], + kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True) + ) + + self.grid_sample = partial( + grid_sample, padding_mode='zeros', mode='bilinear', align_corners=True) + + self.volume_masking = cfg_has(cfg, 'volume_masking', False) + + def update_adaptive_depth_bins(self, depth): + """Change depth bins based on predicted depth""" + min_depth = depth.detach().min(-1)[0].min(-1)[0] + max_depth = depth.detach().max(-1)[0].max(-1)[0] + + min_depth = min_depth.mean().cpu().item() + max_depth = max_depth.mean().cpu().item() + + min_depth = max(self.depth_range[0], min_depth * 0.9) + max_depth = min(self.depth_range[1], max_depth * 1.1) + + self.min_depth_bin = nn.Parameter( + self.min_depth_bin * 0.99 + min_depth * 0.01, requires_grad=False) + self.max_depth_bin = nn.Parameter( + self.max_depth_bin * 0.99 + max_depth * 0.01, requires_grad=False) + + def compute_depth_bins(self, min_depth_bin, max_depth_bin, device): + """Compute depth bins based on minimum and maximum values""" + min_depth_bin = min_depth_bin.cpu() + max_depth_bin = max_depth_bin.cpu() + + if self.depth_binning == 'inverse': + self.depth_bins = 1. / np.linspace( + 1. / max_depth_bin, 1. / min_depth_bin, self.num_depth_bins)[::-1] + elif self.depth_binning == 'linear': + self.depth_bins = np.linspace( + min_depth_bin, max_depth_bin, self.num_depth_bins) + elif self.depth_binning == 'sid': + self.depth_bins = np.array( + [np.exp(np.log(min_depth_bin) + np.log(max_depth_bin / min_depth_bin) * i / (self.num_depth_bins - 1)) + for i in range(self.num_depth_bins)]) + else: + raise NotImplementedError + self.depth_bins = torch.from_numpy(self.depth_bins).float().to(device) + + ones = torch.ones((1, self.matching_height, self.matching_width), + dtype=torch.float, device=device) + return torch.stack([depth * ones for depth in self.depth_bins], 1) + + def feature_extraction(self, image, return_all_feats=False): + """Extract features from input images""" + image = (image - 0.45) / 0.225 + feats_0 = self.layer0(image) + feats_1 = self.layer1(feats_0) + return [feats_0, feats_1] if return_all_feats else feats_1 + + def indices_to_inv_depth(self, indices): + """Convert bin indices to inverse depth values""" + batch, height, width = indices.shape + depth = self.depth_bins[indices.reshape(-1)] + return 1 / depth.reshape((batch, height, width)) + + def compute_confidence_mask(self, cost_volume, num_bins_threshold=None): + """Compute confidence mask based on cost volume""" + if num_bins_threshold is None: + num_bins_threshold = self.num_depth_bins + return ((cost_volume > 0).sum(1) == num_bins_threshold).float() + + def forward(self, rgb, rgb_context=None, + cams=None, mode='multi', networks=None): + """Network forward pass""" + + feats = self.feature_extraction(rgb, return_all_feats=True) + current_feats = feats[-1] + + if mode == 'mono': + feats.append(self.layer2_mono(feats[-1])) + feats.append(self.layer3_mono(feats[-1])) + feats.append(self.layer4_mono(feats[-1])) + return { + 'features': feats, + } + + output = {} + + with torch.no_grad(): + if self.warp_depths is None or self.adaptive_bins: + self.warp_depths = self.compute_depth_bins( + self.min_depth_bin, self.max_depth_bin, device=rgb.device) + + b, n, c, h, w = 
rgb_context.shape + rgb_context = rgb_context.reshape(b * n, c, h, w) + feats_context = self.feature_extraction(rgb_context, return_all_feats=True) + + output_transformer = networks['transformer'](rgb, rgb_context, rgb.device, cams[-1].scaled(1/4)) + output['depth_regr'] = [ + output_transformer['depth1_low'], + ] + output['depth_regr'] = flatten(output['depth_regr']) + + if 'ssim_lowest_cost' in output_transformer.keys(): + output['lowest_cost_ssim'] = output_transformer['ssim_lowest_cost'] + + mask3d = output_transformer['warped_mask'] + + mask2d = (mask3d.sum(0) == mask3d.shape[0]).float() + mask2d[:, :2, :] = 0 + mask2d[:, -2:, :] = 0 + mask2d[:, :, :2] = 0 + mask2d[:, :,-2:] = 0 + + output['confidence_mask_transformer'] = mask2d + output['confidence_mask_transformer3d'] = mask3d + output['lowest_cost_transformer1'] = depth2inv(output_transformer['depth1_low']) + output['lowest_cost_transformer2'] = depth2inv(output_transformer['depth2_low']) + output['cost_volume_transformer'] = output_transformer['attn_weight_softmax'][0].permute(0, 3, 1, 2) + + cost_volume = output['cost_volume_transformer'] + confidence_mask = output['confidence_mask_transformer'] + lowest_cost = output['lowest_cost_transformer1'] + + if 'ssim_lowest_cost' in output_transformer: + output['ssim_lowest_cost'] = output_transformer['ssim_lowest_cost'] + + confidence_mask = self.compute_confidence_mask( + cost_volume.detach() * confidence_mask.detach()) + cost_volume = cost_volume * confidence_mask.unsqueeze(1) + + post_matching_feats = self.reduce_conv( + torch.cat([current_feats, cost_volume], 1)) + + feats.append(self.layer2(post_matching_feats)) + feats.append(self.layer3(feats[-1])) + feats.append(self.layer4(feats[-1])) + + output.update(**{ + 'features': feats, + 'lowest_cost': lowest_cost, + 'confidence_mask': confidence_mask, + 'cost_volume': cost_volume, + }) + + return output diff --git a/vidar/arch/networks/encoders/ResNetEncoder.py b/vidar/arch/networks/encoders/ResNetEncoder.py new file mode 100755 index 0000000000000000000000000000000000000000..b4cb41bf97745a762d2b154f616db20a4ed2a2e2 --- /dev/null +++ b/vidar/arch/networks/encoders/ResNetEncoder.py @@ -0,0 +1,112 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
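+# Usage sketch (illustrative only, not part of the original file; config fields
+# follow the attributes read in ResNetEncoder.__init__ below):
+#
+#   encoder = ResNetEncoder(cfg)   # cfg.version in {18, 34, 50, 101, 152},
+#                                  # cfg.num_rgb_in, cfg.pretrained
+#   feats = encoder(rgb)           # list of 5 feature maps, from 1/2 to 1/32 resolution,
+#                                  # with channel counts given by encoder.num_ch_enc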
+ +from abc import ABC + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import torchvision.models as models + +RESNET_VERSIONS = { + 18: models.resnet18, + 34: models.resnet34, + 50: models.resnet50, + 101: models.resnet101, + 152: models.resnet152 +} + +################## + + +class ResNetMultiInput(models.ResNet, ABC): + """ResNet encoder with multiple inputs""" + def __init__(self, block_type, block_channels, num_input_rgb): + super().__init__(block_type, block_channels) + + self.inplanes = 64 + self.conv1 = nn.Conv2d( + num_input_rgb * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block_type, 64, block_channels[0]) + self.layer2 = self._make_layer(block_type, 128, block_channels[1], stride=2) + self.layer3 = self._make_layer(block_type, 256, block_channels[2], stride=2) + self.layer4 = self._make_layer(block_type, 512, block_channels[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def resnet_multi_input(num_layers, num_input_rgb, pretrained=True): + """Create a resnet encoder with multiple input images by copying the first layer""" + assert num_layers in [18, 50], 'Can only run with 18 or 50 layer resnet' + + block_channels = { + 18: [2, 2, 2, 2], + 50: [3, 4, 6, 3] + }[num_layers] + + block_type = { + 18: models.resnet.BasicBlock, + 50: models.resnet.Bottleneck + }[num_layers] + + model = ResNetMultiInput(block_type, block_channels, num_input_rgb) + + if pretrained: + loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) + loaded['conv1.weight'] = torch.cat( + [loaded['conv1.weight']] * num_input_rgb, 1) / num_input_rgb + model.load_state_dict(loaded) + + return model + + +class ResNetEncoder(nn.Module, ABC): + """ + ResNet encoder network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + self.num_ch_enc = np.array([64, 64, 128, 256, 512]) + self.features = [] + + assert cfg.version in RESNET_VERSIONS, f'Invalid ResNet version: {cfg.version}' + + if cfg.num_rgb_in > 1: + self.encoder = resnet_multi_input( + cfg.version, cfg.num_rgb_in, cfg.pretrained) + else: + self.encoder = RESNET_VERSIONS[cfg.version](cfg.pretrained) + + if cfg.version > 34: + self.num_ch_enc[1:] *= 4 + + def forward(self, input_image): + """Network forward pass""" + + x = (input_image - 0.45) / 0.225 + x = self.encoder.conv1(x) + x = self.encoder.bn1(x) + + self.features.clear() + self.features.append(self.encoder.relu(x)) + self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) + self.features.append(self.encoder.layer2(self.features[-1])) + self.features.append(self.encoder.layer3(self.features[-1])) + self.features.append(self.encoder.layer4(self.features[-1])) + + return self.features diff --git a/vidar/arch/networks/intrinsics/IntrinsicsNet.py b/vidar/arch/networks/intrinsics/IntrinsicsNet.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9acfad2a3fa1eef4cd6d6453136850b10ca314 --- /dev/null +++ b/vidar/arch/networks/intrinsics/IntrinsicsNet.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import, division, 
print_function + +import numpy as np +import torch +import torch.nn as nn + +from collections import OrderedDict + +from abc import ABC + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.utils.config import cfg_has + + +class IntrinsicsNet(BaseNet, ABC): + def __init__(self, cfg): + super().__init__(cfg) + assert cfg_has(cfg, 'shape') + self.image_shape = cfg.shape + self.camera_model = cfg_has(cfg, 'camera_model', 'UCM') + self.sigmoid_init = nn.Parameter(torch.tensor(self.setup_sigmoid_init(cfg), dtype=torch.float), requires_grad=True) + self.scale = nn.Parameter(torch.tensor(self.setup_scale(cfg), dtype=torch.float, requires_grad=False), requires_grad=False) + self.offset = nn.Parameter(torch.tensor(self.setup_offset(cfg), dtype=torch.float, requires_grad=False), requires_grad=False) + + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + + def __len__(self): + if self.camera_model == 'Pinhole': + return 4 + elif self.camera_model == 'UCM': + return 5 + elif self.camera_model == 'EUCM': + return 6 + elif self.camera_model == 'DS': + return 6 + else: + raise NotImplementedError('Camera model {} is not implemented. Please choose from \{Pinhole,UCM, EUCM, DS\}.'.format(self.camera_model)) + + def setup_sigmoid_init(self, cfg): + if cfg_has(cfg, 'sigmoid_init'): + assert len(cfg.sigmoid_init) == self.__len__() + return np.array(cfg.sigmoid_init) + else: + if self.camera_model == 'Pinhole': + return np.array([0.0, 0.0, 0.0, 0.0]) + elif self.camera_model == 'UCM': + return np.array([0.0, 0.0, 0.0, 0.0, 0.0]) + elif self.camera_model == 'EUCM': + return np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + elif self.camera_model == 'DS': + return np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + else: + raise NotImplementedError('Camera model {} is not implemented. Please choose from \{Pinhole,UCM, EUCM, DS\}.'.format(self.camera_model)) + + def setup_scale(self, cfg): + if cfg_has(cfg, 'scale'): + assert len(cfg.scale) == self.__len__() + return np.array(cfg.scale) + else: + h, w = self.image_shape + fx_scale, fy_scale = (h + w), (h + w) + cx_scale = w + cy_scale = h + alpha_scale = 1.0 + beta_scale = 2.0 + xi_scale = 2.0 + + if self.camera_model == 'Pinhole': + return np.array([fx_scale, fy_scale, cx_scale, cy_scale]) + elif self.camera_model == 'UCM': + return np.array([fx_scale, fy_scale, cx_scale, cy_scale, alpha_scale]) + elif self.camera_model == 'EUCM': + return np.array([fx_scale, fy_scale, cx_scale, cy_scale, alpha_scale, beta_scale]) + elif self.camera_model == 'DS': + return np.array([fx_scale, fy_scale, cx_scale, cy_scale, xi_scale, alpha_scale]) + else: + raise NotImplementedError('Camera model {} is not implemented. Please choose from \{Pinhole,UCM, EUCM, DS\}.'.format(self.camera_model)) + + def setup_offset(self, cfg): + if cfg_has(cfg, 'offset'): + assert len(cfg.offset) == self.__len__() + return np.array(cfg.offset) + else: + if self.camera_model == 'Pinhole': + return np.zeros(4) + elif self.camera_model == 'UCM': + return np.zeros(5) + elif self.camera_model == 'EUCM': + return np.zeros(6) + elif self.camera_model == 'DS': + return np.array([0.0, 0.0, 0.0, 0.0, -1.0, 0.0]) + else: + raise NotImplementedError('Camera model {} is not implemented. 
Please choose from \{Pinhole,UCM, EUCM, DS\}.'.format(self.camera_model)) + + + def forward(self, rgb): + B = rgb.shape[0] + + self.scale.requires_grad = False + self.offset.requires_grad = False + + I = self.sigmoid(self.sigmoid_init) * self.scale + self.offset + + return I.unsqueeze(0).repeat(B,1) \ No newline at end of file diff --git a/vidar/arch/networks/layers/__init__.py b/vidar/arch/networks/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/networks/layers/convs.py b/vidar/arch/networks/layers/convs.py new file mode 100644 index 0000000000000000000000000000000000000000..d441d2384358683b3d92a277d175cc584e4cd180 --- /dev/null +++ b/vidar/arch/networks/layers/convs.py @@ -0,0 +1,43 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch.nn as nn +import torch.nn.functional as F + + +def upsample(x): + """Upsample input tensor by a factor of 2""" + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class ConvBlock(nn.Module, ABC): + """Layer to perform a convolution followed by ELU""" + def __init__(self, in_channels, out_channels, kernel_size=3): + super().__init__() + self.conv = Conv3x3(in_channels, out_channels, kernel_size=kernel_size) + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +class Conv3x3(nn.Module, ABC): + """Layer to pad and convolve input""" + def __init__(self, in_channels, out_channels, use_refl=True, kernel_size=3): + super().__init__() + if kernel_size == 3: + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + else: + self.pad = nn.Identity() + self.conv = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=kernel_size) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out diff --git a/vidar/arch/networks/layers/depthformer/context_adjustment.py b/vidar/arch/networks/layers/depthformer/context_adjustment.py new file mode 100644 index 0000000000000000000000000000000000000000..7dfcf0317b3298f0f965c0711b9b972269c88fd1 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/context_adjustment.py @@ -0,0 +1,72 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
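+# Usage sketch (illustrative only, not part of the original file; cfg fields
+# follow what the constructor reads below):
+#
+#   cal = ContextAdjustmentLayer(cfg)     # cfg.num_blocks, cfg.feat_dim, cfg.expansion_ratio
+#   depth_refined = cal(depth_raw, img)   # depth_raw: [B,1,H,W], img: [B,3,H,W]
+#
+# The raw depth is standardized, refined by image-conditioned residual blocks,
+# and mapped back to the original depth statistics.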
+ +import torch +from torch import nn +from torch.nn.utils import weight_norm + + +class ContextAdjustmentLayer(nn.Module): + """ + Context adjustment layer + Base on https://github.com/mli0603/stereo-transformer/blob/main/module/context_adjustment_layer.py + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + num_blocks = cfg.num_blocks + feature_dim = cfg.feat_dim + expansion = cfg.expansion_ratio + + self.num_blocks = num_blocks + + self.in_conv = nn.Conv2d(4, feature_dim, kernel_size=3, padding=1) + self.layers = nn.ModuleList([ResBlock(feature_dim, expansion) for _ in range(num_blocks)]) + self.out_conv = nn.Conv2d(feature_dim, 1, kernel_size=3, padding=1) + + def forward(self, depth_raw, img): + """Network forward pass""" + + eps = 1e-6 + mean_depth_pred = depth_raw.mean() + std_depth_pred = depth_raw.std() + eps + depth_pred_normalized = (depth_raw - mean_depth_pred) / std_depth_pred + + feat = self.in_conv(torch.cat([depth_pred_normalized, img], dim=1)) + for layer in self.layers: + feat = layer(feat, depth_pred_normalized) + + depth_res = self.out_conv(feat) + depth_final = depth_pred_normalized + depth_res + + return depth_final * std_depth_pred + mean_depth_pred + + +class ResBlock(nn.Module): + def __init__(self, n_feats, expansion_ratio, res_scale=1.0): + """ + ResNet block + + Parameters + ---------- + n_feats : Int + Number of layer features + expansion_ratio : Int + Expansion ratio for middle layer + res_scale : Float + Scale ratio for residual connections + """ + super(ResBlock, self).__init__() + self.res_scale = res_scale + self.module = nn.Sequential( + weight_norm(nn.Conv2d(n_feats + 1, n_feats * expansion_ratio, kernel_size=3, padding=1)), + nn.ReLU(inplace=True), + weight_norm(nn.Conv2d(n_feats * expansion_ratio, n_feats, kernel_size=3, padding=1)) + ) + + def forward(self, x, depth): + return x + self.module(torch.cat([depth, x], dim=1)) * self.res_scale diff --git a/vidar/arch/networks/layers/depthformer/feature_extraction.py b/vidar/arch/networks/layers/depthformer/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd0a3258130e83f7577817b74dc0059b60322b2 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/feature_extraction.py @@ -0,0 +1,106 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
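+# Usage sketch (illustrative only, not part of the original file):
+#
+#   backbone = SppBackbone(cfg)
+#   feats = backbone(target, context)   # target/context: [B,3,H,W]; both images are
+#                                       # stacked along the batch dimension internally
+#
+# Returns the stacked input plus feature maps at 1/2, 1/4, 1/8 and 1/16 resolution;
+# the last one concatenates the four spatial-pyramid-pooling branches.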
+ +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torchvision.models.resnet import BasicBlock + + +class SppBackbone(nn.Module): + """ + Feature extraction network + Base on https://github.com/mli0603/stereo-transformer/blob/main/module/feat_extractor_backbone.py + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + self.inplanes = 32 + self.in_conv = nn.Sequential( + nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(16), + nn.ReLU(inplace=True), + nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(16), + nn.ReLU(inplace=True), + nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + self.resblock_1 = self._make_layer(BasicBlock, 64, 3, 2) + self.resblock_2 = self._make_layer(BasicBlock, 128, 3, 2) + + self.branch1 = nn.Sequential( + nn.AvgPool2d((16, 16), stride=(16, 16)), + nn.Conv2d(128, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + self.branch2 = nn.Sequential( + nn.AvgPool2d((8, 8), stride=(8, 8)), + nn.Conv2d(128, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d((4, 4), stride=(4, 4)), + nn.Conv2d(128, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + self.branch4 = nn.Sequential( + nn.AvgPool2d((2, 2), stride=(2, 2)), + nn.Conv2d(128, 32, kernel_size=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True) + ) + + def _make_layer(self, block, planes, blocks, stride=1): + """Create intermediate layer""" + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion) + ) + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def forward(self, target, context): + """Network forward pass""" + + _, _, h, w = target.shape + + both = torch.cat([target, context], dim=0) + + output = self.in_conv(both) + + output_1 = self.resblock_1(output) + output_2 = self.resblock_2(output_1) + + h_spp, w_spp = math.ceil(h / 16), math.ceil(w / 16) + spp_1 = self.branch1(output_2) + spp_1 = F.interpolate(spp_1, size=(h_spp, w_spp), mode='bilinear', align_corners=False) + spp_2 = self.branch2(output_2) + spp_2 = F.interpolate(spp_2, size=(h_spp, w_spp), mode='bilinear', align_corners=False) + spp_3 = self.branch3(output_2) + spp_3 = F.interpolate(spp_3, size=(h_spp, w_spp), mode='bilinear', align_corners=False) + spp_4 = self.branch4(output_2) + spp_4 = F.interpolate(spp_4, size=(h_spp, w_spp), mode='bilinear', align_corners=False) + output_3 = torch.cat([spp_1, spp_2, spp_3, spp_4], dim=1) # 1/16 + + return [both, output, output_1, output_2, output_3] + diff --git a/vidar/arch/networks/layers/depthformer/regression.py b/vidar/arch/networks/layers/depthformer/regression.py new file mode 100644 index 0000000000000000000000000000000000000000..4517bfcf4c505185810ff0fb345d3f0dfe0577b0 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/regression.py @@ -0,0 +1,118 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
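+# Usage sketch (illustrative only, not part of the original file; arguments
+# follow RegressionHead.forward below):
+#
+#   head = RegressionHead(cfg)            # cfg.context_adjustment sub-configuration
+#   out = head(attn_weight, target, context, sampled_rows, sampled_cols,
+#              min_depth, max_depth, num_bins)
+#   depth = out['depth1'][0]              # soft-argmax over depth bins, upsampled, refined
+#                                         # by the context adjustment layer, and converted
+#                                         # to metric depth with compute_depth_bin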
+ +import torch +import torch.nn.functional as F +from torch import nn + +from vidar.arch.networks.layers.depthformer.context_adjustment import ContextAdjustmentLayer +from vidar.utils.volume import compute_depth_bin + + +class RegressionHead(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cal = ContextAdjustmentLayer(cfg.context_adjustment) + self.phi = nn.Parameter(torch.tensor(0.0, requires_grad=True)) + self.monocular = True + + @staticmethod + def _compute_unscaled_pos_shift(w, device): + return torch.linspace(0, w - 1, w)[None, None, None, :].to(device) + + @staticmethod + def _compute_low_res_depth(pos_shift, attn_weight): + high_response = torch.argmax(attn_weight, dim=-1) # NxHxW + response_range = torch.stack([high_response - 1, high_response, high_response + 1], dim=-1) + attn_weight_pad = F.pad(attn_weight, [1, 1], value=0.0) + attn_weight_rw = torch.gather(attn_weight_pad, -1, response_range + 1) + + norm = attn_weight_rw.sum(-1, keepdim=True) + norm[norm < 0.1] = 1.0 + + attn_weight_rw = attn_weight_rw / norm + pos_pad = F.pad(pos_shift, [1, 1]).expand_as(attn_weight_pad).clone() + pos_pad[..., -1] = pos_shift[..., -1] + 1 + pos_rw = torch.gather(pos_pad, -1, response_range + 1) + depth_pred_low_res = (attn_weight_rw * pos_rw) + + depth_pred_low_res = depth_pred_low_res.sum(-1) + + return depth_pred_low_res, norm, high_response + + def upsample(self, x, depth_pred, scale=1.0): + _, _, h, w = x.size() + depth_pred_attn = depth_pred * scale + depth_pred = F.interpolate(depth_pred_attn[None,], size=(h, w), mode='nearest') + depth_pred_final = self.cal(depth_pred, x) + return depth_pred_final.squeeze(1), depth_pred_attn.squeeze(1) + + def softmax(self, attn): + bs, h, w, d = attn.shape + similarity_matrix = torch.cat([attn, self.phi.expand(bs, h, w, 1).to(attn.device)], -1) + attn_softmax = F.softmax(similarity_matrix, dim=-1) + return attn_softmax + + def forward(self, attn_weight, target, context, sampled_rows, sampled_cols, min_depth, max_depth, num_bins): + + stride = [1] + + outputs = [] + for s in stride: + output = self.forward2( + attn_weight, target, sampled_cols, min_depth, max_depth, num_bins, s) + outputs.append(output) + final_output = {} + for key in outputs[0].keys(): + final_output[key] = [o[key] for o in outputs] + return final_output + + def forward2(self, attn_weight, target, sampled_cols, min_depth, max_depth, num_bins, stride=1): + + bs, _, h, w = target.size() + output = {} + + if stride > 1: + shape = list(attn_weight.shape) + shape[-1] = shape[-1] // stride + attn_weight_tmp = torch.zeros(shape, dtype=attn_weight.dtype, device=attn_weight.device) + for i in range(0, shape[-1]): + attn_weight_tmp[..., i] = attn_weight[..., i * stride:(i + 1) * stride].mean(-1) + attn_weight = attn_weight_tmp + + attn_ot = self.softmax(attn_weight) + attn_ot = attn_ot[..., :-1] + output['attn_weight_softmax'] = attn_ot + + pos_shift = self._compute_unscaled_pos_shift(attn_weight.shape[3], attn_weight.device) + + depth_pred_low_res1, matched_attn1, high_response1 = self._compute_low_res_depth(pos_shift, attn_ot) + depth_pred_low_res2, matched_attn2, high_response2 = self._compute_low_res_depth(pos_shift, attn_ot) + + output['high_response'] = high_response1 + + if sampled_cols is not None: + output['depth_pred1'], output['depth_pred1_low'] = \ + self.upsample(target, depth_pred_low_res1) + output['depth_pred2'], output['depth_pred2_low'] = \ + self.upsample(target, depth_pred_low_res2) + else: + output['depth_pred_low'] = depth_pred_low_res1 + 
output['depth_pred'] = depth_pred_low_res1 + + if self.monocular: + + num_bins = num_bins // stride + + depth1 = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred1']) + depth1_low = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred1_low']) + + depth2 = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred2']) + depth2_low = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred2_low']) + + output['depth1'] = depth1 + output['depth1_low'] = depth1_low + + output['depth2'] = depth2 + output['depth2_low'] = depth2_low + + return output diff --git a/vidar/arch/networks/layers/depthformer/tokenizer.py b/vidar/arch/networks/layers/depthformer/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f9134b258b054f8ff74f8ad92ee095045053f9 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/tokenizer.py @@ -0,0 +1,114 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +from torch import nn +from torchvision.models.densenet import _DenseBlock + + +def center_crop(layer, max_height, max_width): + _, _, h, w = layer.size() + xy1 = (w - max_width) // 2 + xy2 = (h - max_height) // 2 + return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] + + +class TransitionUp(nn.Module): + """Transposed convolution for upsampling""" + def __init__(self, in_channels: int, out_channels: int, scale: int = 2): + super().__init__() + if scale == 2: + self.convTrans = nn.ConvTranspose2d( + in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=0, bias=True) + elif scale == 4: + self.convTrans = nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=0, bias=False), + nn.BatchNorm2d(out_channels), + nn.ConvTranspose2d( + in_channels=out_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=0, bias=True) + ) + + def forward(self, x, skip): + out = self.convTrans(x) + out = center_crop(out, skip.size(2), skip.size(3)) + out = torch.cat([out, skip], 1) + return out + + +class DoubleConv(nn.Module): + """Helper class with two convolutional layers, plus BatchNorm and ReLU""" + def __init__(self, in_channels, out_channels): + super(DoubleConv, self).__init__() + + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Tokenizer(nn.Module): + """ + Feature tokenization network + Base on https://github.com/mli0603/stereo-transformer/blob/main/module/feat_extractor_tokenizer.py + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super(Tokenizer, self).__init__() + + block_config = [4, 4, 4, 4] + backbone_feat_channel = [64, 128, 128] + hidden_dim = cfg.channel_dim + growth_rate = 4 + + backbone_feat_channel.reverse() + block_config.reverse() + + self.num_resolution = len(backbone_feat_channel) + self.block_config = block_config + self.growth_rate = growth_rate + + self.bottle_neck = _DenseBlock( + block_config[0], backbone_feat_channel[0], 4, drop_rate=0.0, growth_rate=growth_rate) + up = [] + dense_block = [] + prev_block_channels = growth_rate * block_config[0] + for i in 
range(self.num_resolution): + if i == self.num_resolution - 1: + up.append(TransitionUp(prev_block_channels, hidden_dim, 4)) + dense_block.append(DoubleConv(hidden_dim + 3, hidden_dim)) + else: + up.append(TransitionUp(prev_block_channels, prev_block_channels)) + cur_channels_count = prev_block_channels + backbone_feat_channel[i + 1] + dense_block.append( + _DenseBlock(block_config[i + 1], cur_channels_count, 4, drop_rate=0.0, growth_rate=growth_rate)) + prev_block_channels = growth_rate * block_config[i + 1] + + self.up = nn.ModuleList(up) + self.dense_block = nn.ModuleList(dense_block) + + def forward(self, features): + """Network forward pass""" + + features.reverse() + output = self.bottle_neck(features[0]) + output = output[:, -(self.block_config[0] * self.growth_rate):] + for i in range(self.num_resolution): + hs = self.up[i](output, features[i + 1]) + output = self.dense_block[i](hs) + if i < self.num_resolution - 1: + output = output[:, -(self.block_config[i + 1] * self.growth_rate):] + return output diff --git a/vidar/arch/networks/layers/depthformer/transformer.py b/vidar/arch/networks/layers/depthformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea81a7bcc0bdee245657e4fd7c3c18ce656737a4 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/transformer.py @@ -0,0 +1,223 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import copy +from functools import partial + +import torch +import torch.nn.functional as F +from torch import nn + +from vidar.arch.losses.SSIMLoss import SSIMLoss +from vidar.utils.tensor import grid_sample +from vidar.utils.volume import compute_depth_bins, compute_depth_bin + + +def get_clones(module, N): + """Create clones of a module""" + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def prepareA(feat): + """Reorganize features in one way""" + return feat.permute(1, 3, 2, 0).flatten(2).permute(1, 2, 0) + + +def prepareB(x): + """Reorganize features in another way""" + d, c, h, w = x.shape + return x.permute(1, 2, 3, 0).reshape(c, h * w, d).permute(2, 1, 0) + + +def unprepare(feat, shape): + """Return features back to original shape""" + b, c, h, w = shape + return feat.permute(2, 0, 1).reshape(c, w, h, b).permute(3, 0, 2, 1) + + +class Transformer(nn.Module): + """ + Transformer network for Feature Matching (https://arxiv.org/abs/2204.07616) + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__() + + self.hidden_dim = cfg.channel_dim + self.num_attn_layers = cfg.num_attn_layers + + self_attn_layer = TransformerSelfAttnLayer(self.hidden_dim, cfg.nheads) + self.self_attn_layers = get_clones(self_attn_layer, self.num_attn_layers) + + cross_attn_layer = TransformerCrossAttnLayer(self.hidden_dim, cfg.nheads) + self.cross_attn_layers = get_clones(cross_attn_layer, self.num_attn_layers) + + self.norm = nn.LayerNorm(self.hidden_dim) + + self.grid_sample = partial( + grid_sample, padding_mode='zeros', mode='bilinear', align_corners=True) + + self.grid_sample_nearest = partial( + grid_sample, padding_mode='zeros', mode='nearest', align_corners=True) + + def _alternating_attn(self, feat1, feat2, + cam=None, min_depth=None, max_depth=None, num_bins=None): + """Perform self- and cross-attention between two feature maps""" + device = feat1.device + cam = cam.to(device) + h, w = cam.hw + + depth_bins = compute_depth_bins(min_depth, max_depth, num_bins, 'sid').to(device) + + ones = torch.ones((1, h, w), 
dtype=feat1.dtype, device=device) + warped_depth = torch.stack([depth * ones for depth in depth_bins], 1) + coords = cam.coords_from_cost_volume(warped_depth)[0] + + coords[coords < -1] = -2 + coords[coords > +1] = +2 + + repeated_feat2 = feat2.repeat([num_bins, 1, 1, 1]) + warped = self.grid_sample(repeated_feat2, coords.type(repeated_feat2.dtype)) + + repeated_ones = ones.repeat([num_bins, 1, 1, 1]) + warped_mask = self.grid_sample_nearest(repeated_ones, coords.type(repeated_ones.dtype)) + + with torch.no_grad(): + ssim_volume = SSIMLoss()(feat1, warped)['loss'].mean(1).unsqueeze(0) + lowest_cost = 1. / compute_depth_bin(min_depth, max_depth, num_bins, torch.min(ssim_volume, 1)[1]) + + feat1 = prepareB(feat1) + feat2 = prepareB(warped) + + attn_weight = None + for idx, (self_attn, cross_attn) in \ + enumerate(zip(self.self_attn_layers, self.cross_attn_layers)): + feat1 = self_attn(feat1) + feat1, feat2, attn_weight = cross_attn(feat1, feat2) + + return { + 'attn_weight': attn_weight, + 'warped_mask': warped_mask, + 'ssim_lowest_cost': lowest_cost, + 'ssim_cost_volume': ssim_volume, + } + + def forward(self, feat1, feat2, cam=None, min_depth=None, max_depth=None, num_bins=None): + """Network forward pass""" + + bs, c, hn, w = feat1.shape + + transformer_output = self._alternating_attn( + feat1, feat2, cam=cam, min_depth=min_depth, max_depth=max_depth, num_bins=num_bins) + transformer_output['attn_weight'] = \ + transformer_output['attn_weight'].view(bs, hn, w, num_bins) + + return transformer_output + + +class TransformerSelfAttnLayer(nn.Module): + """Self-attention layer for transformers""" + def __init__(self, hidden_dim, nheads): + super().__init__() + self.self_attn = MultiheadAttentionRelative(hidden_dim, nheads) + + self.norm1 = nn.LayerNorm(hidden_dim) + + def forward(self, feat): + feat_out = self.norm1(feat) + feat_out, _, _ = self.self_attn(query=feat_out, key=feat_out, value=feat_out) + return feat + feat_out + + +class TransformerCrossAttnLayer(nn.Module): + """Cross-attention layer for transformers""" + def __init__(self, hidden_dim, nheads): + super().__init__() + self.cross_attn = MultiheadAttentionRelative(hidden_dim, nheads) + + self.norm1 = nn.LayerNorm(hidden_dim) + self.norm2 = nn.LayerNorm(hidden_dim) + + def forward(self, feat1, feat2): + + feat1_2 = self.norm1(feat1) + feat2_2 = self.norm1(feat2) + + feat1_2, attn_weight, raw_attn = self.cross_attn(query=feat1_2, key=feat2_2, value=feat2_2) + feat1 = feat1 + feat1_2 + + return feat1, feat2, raw_attn + + @torch.no_grad() + def _generate_square_subsequent_mask(self, sz): + mask = torch.triu(torch.ones(sz, sz), diagonal=1) + mask[mask == 1] = float('-inf') + return mask + + +class MultiheadAttentionRelative(nn.MultiheadAttention): + """Multi-head attention layer""" + def __init__(self, embed_dim, num_heads): + super().__init__( + embed_dim, num_heads, dropout=0.0, bias=True, + add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None) + + def forward(self, query, key, value): + + w2, bsz2, embed_dim2 = key.size() + w, bsz, embed_dim = query.size() + + head_dim = embed_dim // self.num_heads + + if torch.equal(query, key) and torch.equal(key, value): + q, k, v = F.linear( + query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + elif torch.equal(key, value): + _b = self.in_proj_bias + _start = 0 + _end = embed_dim + _w = self.in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + _b = 
self.in_proj_bias + _start = embed_dim + _end = None + _w = self.in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + else: + raise ValueError('Invalid key/query/value') + + scaling = float(head_dim) ** - 0.5 + q = q * scaling + + q = q.contiguous().view(w, bsz, self.num_heads, head_dim) + if k is not None: + k = k.contiguous().view(-1, bsz, self.num_heads, head_dim) + if v is not None: + v = v.contiguous().view(-1, bsz, self.num_heads, head_dim) + + attn = torch.einsum('wnec,vnec->newv', q, k) + raw_attn = attn + attn = F.softmax(attn, dim=-1) + + v_out = torch.bmm(attn.view(bsz * self.num_heads, w, w2), + v.permute(1, 2, 0, 3).view(bsz * self.num_heads, w2, head_dim)) + v_out = v_out.reshape(bsz, self.num_heads, w, head_dim).permute(2, 0, 1, 3).reshape(w, bsz, embed_dim) + v_out = F.linear(v_out, self.out_proj.weight, self.out_proj.bias) + + attn = attn.sum(dim=1) / self.num_heads + raw_attn = raw_attn.sum(dim=1) + + return v_out, attn, raw_attn diff --git a/vidar/arch/networks/layers/depthformer/transformer_net.py b/vidar/arch/networks/layers/depthformer/transformer_net.py new file mode 100644 index 0000000000000000000000000000000000000000..7158d19d6f08704bcf98b71f46e148e750a51792 --- /dev/null +++ b/vidar/arch/networks/layers/depthformer/transformer_net.py @@ -0,0 +1,106 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn as nn + +from vidar.arch.networks.layers.depthformer.feature_extraction import SppBackbone as Backbone +from vidar.arch.networks.layers.depthformer.regression import RegressionHead +from vidar.arch.networks.layers.depthformer.tokenizer import Tokenizer +from vidar.arch.networks.layers.depthformer.transformer import Transformer + + +def batched_index_select(source, dim, index): + views = [source.shape[0]] + [1 if i != dim else -1 for i in range(1, len(source.shape))] + expanse = list(source.shape) + expanse[0] = -1 + expanse[dim] = -1 + index = index.view(views).expand(expanse) + return torch.gather(source, dim, index) + + +class TransformerNet(nn.Module): + + def __init__(self, cfg, decoder_type='regression'): + super().__init__() + + self.backbone = Backbone(cfg) + self.tokenizer = Tokenizer(cfg) + self.transformer = Transformer(cfg) + + self.min_depth = cfg.min_depth + self.max_depth = cfg.max_depth + self.num_bins = cfg.num_bins + + self.decoder_type = decoder_type + + self.regression_head = RegressionHead(cfg) + + self._reset_parameters() + self._disable_batchnorm_tracking() + self._relu_inplace() + + def _reset_parameters(self): + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)): + nn.init.constant_(m.weight, 1) + nn.init.zeros_(m.bias) + + def _disable_batchnorm_tracking(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.track_running_stats = False + + def _relu_inplace(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.inplace = True + + def fix_layers(self): + def iterate(module): + for key in module.keys(): + if type(module[key]) == nn.BatchNorm2d: + module[key] = nn.InstanceNorm2d( + module[key].num_features, + module[key].eps, + module[key].momentum, + module[key].affine, + ) + iterate(module[key]._modules) + iterate(self._modules) + + def forward(self, target, context, sampled_rows, sampled_cols, 
cam=None): + + bs, _, h, w = target.size() + + feat_all = self.backbone(target, context) + feat = [feat_all[0]] + feat_all[2:] + + feat1 = [f[[0]] for f in feat] + feat2 = [f[[1]] for f in feat] + + feat1 = self.tokenizer(feat1) + feat2 = self.tokenizer(feat2) + + if sampled_cols is not None: + feat1 = batched_index_select(feat1, 3, sampled_cols) + feat2 = batched_index_select(feat2, 3, sampled_cols) + if sampled_rows is not None: + feat1 = batched_index_select(feat1, 2, sampled_rows) + feat2 = batched_index_select(feat2, 2, sampled_rows) + + output_transformer = self.transformer( + feat1, feat2, cam=cam, min_depth=self.min_depth, max_depth=self.max_depth, num_bins=self.num_bins) + + output_regression = self.regression_head( + output_transformer['attn_weight'], target, context, sampled_rows, sampled_cols, + min_depth=self.min_depth, max_depth=self.max_depth, num_bins=self.num_bins, + ) + + return { + **output_transformer, + **output_regression, + } diff --git a/vidar/arch/networks/layers/fsm/camera.py b/vidar/arch/networks/layers/fsm/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..83f536f291657aff2f7d5ab3f10c30182ef82ee6 --- /dev/null +++ b/vidar/arch/networks/layers/fsm/camera.py @@ -0,0 +1,289 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from functools import lru_cache + +import torch +import torch.nn as nn + +from vidar.arch.networks.layers.fsm.camera_utils import scale_intrinsics, invert_intrinsics +from vidar.arch.networks.layers.fsm.pose import Pose +from vidar.utils.tensor import pixel_grid +from vidar.utils.types import is_tensor, is_list + + +class Camera(nn.Module): + """ + Differentiable camera class implementing reconstruction and projection + functions for a pinhole model. 
+ """ + def __init__(self, K, Tcw=None, Twc=None, hw=None): + """ + Initializes the Camera class + + Parameters + ---------- + K : torch.Tensor + Camera intrinsics [B,3,3] + Tcw : Pose or torch.Tensor + Camera -> World pose transformation [B,4,4] + Twc : Pose or torch.Tensor + World -> Camera pose transformation [B,4,4] + hw : tuple or torch.Tensor + Camera width and height, or a tensor with the proper shape + """ + super().__init__() + assert Tcw is None or Twc is None, 'You should provide either Tcw or Twc' + self.K = K + self.hw = None if hw is None else hw.shape[-2:] if is_tensor(hw) else hw[-2:] + if Tcw is not None: + self.Tcw = Tcw if isinstance(Tcw, Pose) else Pose(Tcw) + elif Twc is not None: + self.Tcw = Twc.inverse() if isinstance(Twc, Pose) else Pose(Twc).inverse() + else: + self.Tcw = Pose.identity(len(self.K)) + + def __len__(self): + """Batch size of the camera intrinsics""" + return len(self.K) + + def __getitem__(self, idx): + """Return single camera from a batch position""" + return Camera(K=self.K[idx].unsqueeze(0), + hw=self.hw, Tcw=self.Tcw[idx]).to(self.device) + + @property + def wh(self): + """Return camera width and height""" + return None if self.hw is None else self.hw[::-1] + + @property + def pose(self): + """Return camera pose""" + return self.Twc.mat + + @property + def device(self): + """Return camera device""" + return self.K.device + + def invert_pose(self): + """Return new camera with inverted pose""" + return Camera(K=self.K, Tcw=self.Twc) + + def to(self, *args, **kwargs): + """Moves object to a specific device""" + self.K = self.K.to(*args, **kwargs) + self.Tcw = self.Tcw.to(*args, **kwargs) + return self + + @property + def fx(self): + """Focal length in x""" + return self.K[:, 0, 0] + + @property + def fy(self): + """Focal length in y""" + return self.K[:, 1, 1] + + @property + def cx(self): + """Principal point in x""" + return self.K[:, 0, 2] + + @property + def cy(self): + """Principal point in y""" + return self.K[:, 1, 2] + + @property + @lru_cache() + def Twc(self): + """World -> Camera pose transformation (inverse of Tcw)""" + return self.Tcw.inverse() + + @property + @lru_cache() + def Kinv(self): + """Inverse intrinsics (for lifting)""" + return invert_intrinsics(self.K) + + def equal(self, cam): + """Check if two cameras are the same""" + return torch.allclose(self.K, cam.K) and \ + torch.allclose(self.Tcw.mat, cam.Tcw.mat) + + def scaled(self, x_scale, y_scale=None): + """ + Returns a scaled version of the camera (changing intrinsics) + + Parameters + ---------- + x_scale : float + Resize scale in x + y_scale : float + Resize scale in y. If None, use the same as x_scale + + Returns + ------- + camera : Camera + Scaled version of the current camera + """ + # If single value is provided, use for both dimensions + if y_scale is None: + y_scale = x_scale + # If no scaling is necessary, return same camera + if x_scale == 1. 
and y_scale == 1.: + return self + # Scale intrinsics + K = scale_intrinsics(self.K.clone(), x_scale, y_scale) + # Scale image dimensions + hw = None if self.hw is None else (int(self.hw[0] * y_scale), + int(self.hw[1] * x_scale)) + # Return scaled camera + return Camera(K=K, Tcw=self.Tcw, hw=hw) + + def scaled_K(self, shape): + """Return scaled intrinsics to match a shape""" + if self.hw is None: + return self.K + else: + y_scale, x_scale = [sh / hw for sh, hw in zip(shape[-2:], self.hw)] + return scale_intrinsics(self.K, x_scale, y_scale) + + def scaled_Kinv(self, shape): + """Return scaled inverse intrinsics to match a shape""" + return invert_intrinsics(self.scaled_K(shape)) + + def reconstruct(self, depth, frame='w', scene_flow=None, return_grid=False): + """ + Reconstructs pixel-wise 3D points from a depth map. + + Parameters + ---------- + depth : torch.Tensor + Depth map for the camera [B,1,H,W] + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + scene_flow : torch.Tensor + Optional per-point scene flow to be added (camera reference frame) [B,3,H,W] + return_grid : bool + Return pixel grid as well + + Returns + ------- + points : torch.tensor + Pixel-wise 3D points [B,3,H,W] + """ + # If depth is a list, return each reconstruction + if is_list(depth): + return [self.reconstruct(d, frame, scene_flow, return_grid) for d in depth] + # Dimension assertions + assert depth.dim() == 4 and depth.shape[1] == 1, \ + 'Wrong dimensions for camera reconstruction' + + # Create flat index grid [B,3,H,W] + B, _, H, W = depth.shape + grid = pixel_grid((H, W), B, device=depth.device, normalize=False, with_ones=True) + flat_grid = grid.view(B, 3, -1) + + # Get inverse intrinsics + Kinv = self.Kinv if self.hw is None else self.scaled_Kinv(depth.shape) + + # Estimate the outward rays in the camera frame + Xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W) + # Scale rays to metric depth + Xc = Xnorm * depth + + # Add scene flow if provided + if scene_flow is not None: + Xc = Xc + scene_flow + + # If in camera frame of reference + if frame == 'c': + pass + # If in world frame of reference + elif frame == 'w': + Xc = self.Twc @ Xc + # If none of the above + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + # Return points and grid if requested + return (Xc, grid) if return_grid else Xc + + def project(self, X, frame='w', normalize=True, return_z=False): + """ + Projects 3D points onto the image plane + + Parameters + ---------- + X : torch.Tensor + 3D points to be projected [B,3,H,W] + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + normalize : bool + Normalize grid coordinates + return_z : bool + Return the projected z coordinate as well + + Returns + ------- + points : torch.Tensor + 2D projected points that are within the image boundaries [B,H,W,2] + """ + assert 2 < X.dim() <= 4 and X.shape[1] == 3, \ + 'Wrong dimensions for camera projection' + + # Determine if input is a grid + is_grid = X.dim() == 4 + # If it's a grid, flatten it + X_flat = X.view(X.shape[0], 3, -1) if is_grid else X + + # Get dimensions + hw = X.shape[2:] if is_grid else self.hw + # Get intrinsics + K = self.scaled_K(X.shape) if is_grid else self.K + + # Project 3D points onto the camera image plane + if frame == 'c': + Xc = K.bmm(X_flat) + elif frame == 'w': + Xc = K.bmm(self.Tcw @ X_flat) + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + # Extract coordinates + Z = Xc[:, 2].clamp(min=1e-5) + XZ = Xc[:, 0] / Z + YZ = Xc[:, 1] / Z + + # Normalize points + 
if normalize and hw is not None: + XZ = 2 * XZ / (hw[1] - 1) - 1. + YZ = 2 * YZ / (hw[0] - 1) - 1. + + # Clamp out-of-bounds pixels + Xmask = ((XZ > 1) + (XZ < -1)).detach() + XZ[Xmask] = 2. + Ymask = ((YZ > 1) + (YZ < -1)).detach() + YZ[Ymask] = 2. + + # Stack X and Y coordinates + XY = torch.stack([XZ, YZ], dim=-1) + # Reshape coordinates to a grid if possible + if is_grid and hw is not None: + XY = XY.view(X.shape[0], hw[0], hw[1], 2) + + # If also returning depth + if return_z: + # Reshape depth values to a grid if possible + if is_grid and hw is not None: + Z = Z.view(X.shape[0], hw[0], hw[1], 1).permute(0, 3, 1, 2) + # Otherwise, reshape to an array + else: + Z = Z.view(X.shape[0], -1, 1).permute(0, 2, 1) + # Return coordinates and depth values + return XY, Z + else: + # Return coordinates + return XY diff --git a/vidar/arch/networks/layers/fsm/camera_utils.py b/vidar/arch/networks/layers/fsm/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f42b646086443ebd49f57e1b5b5b4fa81101b545 --- /dev/null +++ b/vidar/arch/networks/layers/fsm/camera_utils.py @@ -0,0 +1,112 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn.functional as tf + + +def construct_K(fx, fy, cx, cy, dtype=torch.float, device=None): + """Construct a [3,3] camera intrinsics from pinhole parameters""" + return torch.tensor([[fx, 0, cx], + [ 0, fy, cy], + [ 0, 0, 1]], dtype=dtype, device=device) + + +def scale_intrinsics(K, x_scale, y_scale): + """Scale intrinsics given x_scale and y_scale factors""" + K = K.clone() + K[..., 0, 0] *= x_scale + K[..., 1, 1] *= y_scale + # K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5 + # K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5 + K[..., 0, 2] = K[..., 0, 2] * x_scale + K[..., 1, 2] = K[..., 1, 2] * y_scale + return K + + +def invert_intrinsics(K): + """Invert camera intrinsics""" + Kinv = K.clone() + Kinv[:, 0, 0] = 1. / K[:, 0, 0] + Kinv[:, 1, 1] = 1. / K[:, 1, 1] + Kinv[:, 0, 2] = -1. * K[:, 0, 2] / K[:, 0, 0] + Kinv[:, 1, 2] = -1. * K[:, 1, 2] / K[:, 1, 1] + return Kinv + + +def view_synthesis(ref_image, depth, ref_cam, cam, scene_flow=None, + mode='bilinear', padding_mode='zeros', align_corners=True): + """ + Synthesize an image from another plus a depth map. 
+ + Parameters + ---------- + ref_image : torch.Tensor + Reference image to be warped [B,3,H,W] + depth : torch.Tensor + Depth map from the original image [B,1,H,W] + ref_cam : Camera + Camera class for the reference image + cam : Camera + Camera class for the original image + scene_flow : torch.Tensor + Scene flow use for warping [B,3,H,W] + mode : str + Mode for grid sampling + padding_mode : str + Padding mode for grid sampling + align_corners : bool + Corner alignment for grid sampling + + Returns + ------- + ref_warped : torch.Tensor + Warped reference image in the original frame of reference [B,3,H,W] + """ + assert depth.shape[1] == 1, 'Depth map should have C=1' + # Reconstruct world points from target_camera + world_points = cam.reconstruct(depth, frame='w', scene_flow=scene_flow) + # Project world points onto reference camera + ref_coords = ref_cam.project(world_points, frame='w') + # View-synthesis given the projected reference points + return tf.grid_sample(ref_image, ref_coords, mode=mode, + padding_mode=padding_mode, align_corners=align_corners) + + +def view_synthesis_generic(ref_image, depth, ref_cam, cam, + mode='bilinear', padding_mode='zeros', align_corners=True, + progress=0.0): + """ + Synthesize an image from another plus a depth map. + + Parameters + ---------- + ref_image : torch.Tensor + Reference image to be warped [B,3,H,W] + depth : torch.Tensor + Depth map from the original image [B,1,H,W] + ref_cam : Camera + Camera class for the reference image + cam : Camera + Camera class for the original image + mode : str + Interpolation mode + padding_mode : str + Padding mode for interpolation + align_corners : bool + Corner alignment for grid sampling + progress : float + Training process (percentage) + + Returns + ------- + ref_warped : torch.Tensor + Warped reference image in the original frame of reference [B,3,H,W] + """ + assert depth.shape[1] == 1, 'Depth map should have C=1' + # Reconstruct world points from target_camera + world_points = cam.reconstruct(depth, frame='w') + # Project world points onto reference camera + ref_coords = ref_cam.project(world_points, progress=progress, frame='w') + # View-synthesis given the projected reference points + return tf.grid_sample(ref_image, ref_coords, mode=mode, + padding_mode=padding_mode, align_corners=align_corners) diff --git a/vidar/arch/networks/layers/fsm/pose.py b/vidar/arch/networks/layers/fsm/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..cfea5b408e729c65fbca6b217655ed0ee908f8c4 --- /dev/null +++ b/vidar/arch/networks/layers/fsm/pose.py @@ -0,0 +1,108 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch + +from vidar.geometry.pose_utils import invert_pose, pose_vec2mat + + +class Pose: + """ + Pose class, that encapsulates a [4,4] transformation matrix + for a specific reference frame + """ + def __init__(self, mat): + """ + Initializes a Pose object. 
+ + Parameters + ---------- + mat : torch.Tensor + Transformation matrix [B,4,4] + """ + assert tuple(mat.shape[-2:]) == (4, 4) + if mat.dim() == 2: + mat = mat.unsqueeze(0) + assert mat.dim() == 3 + self.mat = mat + + def __len__(self): + """Batch size of the transformation matrix""" + return len(self.mat) + + def __getitem__(self, i): + return Pose(self.mat[i].unsqueeze(0)).to(self.device) + + @property + def device(self): + """Return pose device""" + return self.mat.device + + @classmethod + def identity(cls, N=1, device=None, dtype=torch.float): + """Initializes as a [4,4] identity matrix""" + return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1])) + + @classmethod + def from_vec(cls, vec, mode): + """Initializes from a [B,6] batch vector""" + mat = pose_vec2mat(vec, mode) # [B,3,4] + pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1]) + pose[:, :3, :3] = mat[:, :3, :3] + pose[:, :3, -1] = mat[:, :3, -1] + return cls(pose) + + @property + def shape(self): + """Returns the transformation matrix shape""" + return self.mat.shape + + def item(self): + """Returns the transformation matrix""" + return self.mat + + def repeat(self, *args, **kwargs): + """Repeats the transformation matrix multiple times""" + self.mat = self.mat.repeat(*args, **kwargs) + return self + + def inverse(self): + """Returns a new Pose that is the inverse of this one""" + return Pose(invert_pose(self.mat)) + + def to(self, *args, **kwargs): + """Moves object to a specific device""" + self.mat = self.mat.to(*args, **kwargs) + return self + + def transform_pose(self, pose): + """Creates a new pose object that compounds this and another one (self * pose)""" + assert tuple(pose.shape[-2:]) == (4, 4) + return Pose(self.mat.bmm(pose.item())) + + def transform_points(self, points): + """Transforms 3D points using this object""" + assert 2 < points.dim() <= 4 and points.shape[1] == 3, \ + 'Wrong dimensions for transform_points' + # Determine if input is a grid + is_grid = points.dim() == 4 + # If it's a grid, flatten it + points_flat = points.view(points.shape[0], 3, -1) if is_grid else points + # Tranform points + out = self.mat[:, :3, :3].bmm(points_flat) + \ + self.mat[:, :3, -1].unsqueeze(-1) + # Return transformed points + return out.view(points.shape) if is_grid else out + + def __matmul__(self, other): + """Transforms the input (Pose or 3D points) using this object""" + if isinstance(other, Pose): + return self.transform_pose(other) + elif isinstance(other, torch.Tensor): + if other.shape[1] == 3 and other.dim() > 2: + assert other.dim() == 3 or other.dim() == 4 + return self.transform_points(other) + else: + raise ValueError('Unknown tensor dimensions {}'.format(other.shape)) + else: + raise NotImplementedError() + diff --git a/vidar/arch/networks/layers/fsm/utils.py b/vidar/arch/networks/layers/fsm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3c6dda22891a38da274cdd3879f248881a2a4f57 --- /dev/null +++ b/vidar/arch/networks/layers/fsm/utils.py @@ -0,0 +1,298 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
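# ---------------------------------------------------------------------------
# Editor's note: the commented sketch below illustrates the Pose class defined
# in pose.py above; it is not part of the original diff, and values/shapes are
# made up for the example.
#
#   import torch
#   from vidar.arch.networks.layers.fsm.pose import Pose
#
#   T_ab = Pose.identity(N=2)                         # batch of [4,4] identity transforms
#   T_bc = Pose(torch.eye(4).unsqueeze(0).repeat(2, 1, 1))
#   T_ac = T_ab @ T_bc                                # compose two poses
#   pts = torch.rand(2, 3, 100)                       # [B,3,N] 3D points
#   pts_a = T_ac @ pts                                # transform points with the same operator
#   grid = torch.rand(2, 3, 48, 160)                  # [B,3,H,W] point grids also work
#   grid_a = T_ac @ grid
#
# Editor's note: flip_batch_input() and flip_output() below call a flip_lr() helper
# that is not imported in this file; it is assumed to come from the repo's flip
# utilities.
# ---------------------------------------------------------------------------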
+ +import torch +import torch.nn.functional as tfunc + +from vidar.utils.types import is_list + + +def coords_from_motion(ref_camera, tgt_depth, tgt_camera, scene_flow=None): + """ + Get coordinates from motion (depth + ego-motion) information + + Parameters + ---------- + ref_camera : Camera + Reference camera + tgt_depth : Tensor + Target depth map [B,1,H,W] + tgt_camera : Camera + Target camera + scene_flow : Tensor + Target optical flow + + Returns + ------- + coords : torch.Tensor + Warping coordinates [B,2,H,W] + """ + # If there are multiple reference cameras, iterate for each + if is_list(ref_camera): + return [coords_from_motion(camera, tgt_depth, tgt_camera, scene_flow) + for camera in ref_camera] + # If there are multiple depth maps, iterate for each + if is_list(tgt_depth): + return [coords_from_motion(ref_camera, depth, tgt_camera, scene_flow) + for depth in tgt_depth] + # Reconstruct and reproject points to generate warping coordinates + world_points = tgt_camera.reconstruct(tgt_depth, frame='w', scene_flow=scene_flow) + return ref_camera.project(world_points, frame='w').permute(0, 3, 1, 2).contiguous() + + +def mask_from_coords(coords): + """ + Get overlap mask from coordinates + + Parameters + ---------- + coords : Tensor + Warping coordinates [B,2,H,W] + + Returns + ------- + mask : Tensor + Overlap mask [B,1,H,W] + """ + # If there are multiple warping coordinates, iterate for each + if is_list(coords): + return [mask_from_coords(coord) for coord in coords] + # Create and return mask + b, _, h, w = coords.shape + mask = torch.ones((b, 1, h, w), dtype=torch.float32, device=coords.device, requires_grad=False) + mask = warp_from_coords(mask, coords, mode='bilinear', padding_mode='zeros', align_corners=True) + return mask.bool() + + +def warp_from_coords(tensor, coords, mask=False, mode='bilinear', + padding_mode='zeros', align_corners=True): + """ + Warp an image from a coordinate map + + Parameters + ---------- + tensor : torch.Tensor + Input tensor for warping [B,?,H,W] + coords : torch.Tensor + Warping coordinates [B,2,H,W] + mask : Bool + Whether the warped tensor is masked for non-overlapping regions + mode : String + Warping mode + padding_mode : String + Padding mode + align_corners : Bool + Align corners flag + + Returns + ------- + warp : torch.Tensor + Warped tensor [B,?,H,W] + """ + # Sample grid from data with coordinates + warp = tfunc.grid_sample(tensor, coords.permute(0, 2, 3, 1).contiguous(), + mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + # If masking + if mask: + mask = torch.ones_like(tensor, requires_grad=False) + mask = tfunc.grid_sample(mask, coords.permute(0, 2, 3, 1).contiguous()) + warp = warp * (mask >= 1.0).detach() + # Returned warped tensor + return warp + + +def filter_dict(dictionary, keywords): + """ + Returns only the keywords that are part of a dictionary + + Parameters + ---------- + dictionary : dict + Dictionary for filtering + keywords : list of str + Keywords that will be filtered + + Returns + ------- + keywords : list of str + List containing the keywords that are keys in dictionary + """ + return [key for key in keywords if key in dictionary] + + +def merge_outputs(*outputs): + """ + Merges model outputs for logging + + Parameters + ---------- + outputs : tuple of dict + Outputs to be merged + + Returns + ------- + output : dict + Dictionary with a "metrics" key containing a dictionary with various metrics and + all other keys that are not "loss" (it is handled differently). 
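
    Example (editor's illustrative sketch, not part of the original docstring)::

        >>> out_a = {'loss': 1.0, 'metrics': {'abs_rel': 0.1}, 'depth': 'd'}
        >>> out_b = {'loss': 2.0, 'metrics': {'rmse': 3.2}}
        >>> merged = merge_outputs(out_a, out_b)
        >>> sorted(merged['metrics'])
        ['abs_rel', 'rmse']
        >>> merged['depth']
        'd'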
+ """ + ignore = ['loss'] # Keys to ignore + combine = ['metrics'] # Keys to combine + merge = {key: {} for key in combine} + for output in outputs: + # Iterate over all keys + for key, val in output.items(): + # Combine these keys + if key in combine: + for sub_key, sub_val in output[key].items(): + assert sub_key not in merge[key].keys(), \ + 'Combining duplicated key {} to {}'.format(sub_key, key) + merge[key][sub_key] = sub_val + # Ignore these keys + elif key not in ignore: + assert key not in merge.keys(), \ + 'Adding duplicated key {}'.format(key) + merge[key] = val + return merge + + +def flip_batch_input(batch): + """ + Flip batch input information (copies data first) + + Parameters + ---------- + batch : dict + Batch information + + Returns + ------- + batch : dict + Flipped batch + """ + # Flip images and input depth + for key in filter_dict(batch, [ + 'rgb', 'input_depth' + ]): + batch[key] = flip_lr(batch[key]) + # Flip context images + for key in filter_dict(batch, [ + 'rgb_context', + ]): + batch[key] = [flip_lr(img) for img in batch[key]] + # Flip intrinsics + for key in filter_dict(batch, [ + 'intrinsics' + ]): + batch[key] = batch[key].clone() + batch[key][:, 0, 2] = batch['rgb'].shape[3] - batch[key][:, 0, 2] + # Return flipped batch + return batch + + +def flip_output(output): + """ + Flip output information + + Parameters + ---------- + output : dict + Dictionary of model outputs (e.g. with keys like 'inv_depths' and 'uncertainty') + + Returns + ------- + output : dict + Flipped output + """ + # Flip list of tensors + for key in filter_dict(output, [ + 'inv_depths', 'uncertainty', 'logits_semantic' + ]): + output[key] = [flip_lr(val) for val in output[key]] + return output + + +class CameraNormalizer: + """ + Camera normalizer class. + Initialized with a desired focal lenght, and will normalize images to follow these values. + These images can then be unormalized to return to the original resolution/intrinsics + + Parameters + ---------- + focal : tuple + Focal lengths (fx, fy) + """ + def __init__(self, focal, mode='reflect'): + self.focal = focal + self.mode = mode + self.diffs = [] + + def normalize(self, rgb, intrinsics): + """ + Normalize input image + + Parameters + ---------- + rgb : torch.Tensor + Input image [B,3,H,W] + intrinsics : torch.Tensor + Input intrinsics [B,3,3] + + Returns + ------- + rgb_pad : torch.Tensor + Normalized image with padding [B,3,H,W] + """ + rgb_pad = [] + self.diffs.clear() + # Process each image independently + for i in range(len(rgb)): + rgb_i = rgb[i].unsqueeze(0) + intrinsics_i = intrinsics[i] + wh_orig = rgb_i.shape[2:] + # Get resize ratio + ratio = [float(self.focal[1] / intrinsics_i[1, 1]), + float(self.focal[0] / intrinsics_i[0, 0])] + wh_norm = [int(o * r) for o, r in zip(wh_orig, ratio)] + # Resize image + rgb_i_norm = torch.nn.functional.interpolate( + rgb_i, size=wh_norm, mode='bilinear', align_corners=True) + # Pad image + diff = [int(o - n) for o, n in zip(wh_orig, wh_norm)] + rgb_i_pad = torch.nn.functional.pad( + rgb_i_norm, pad=[diff[1] // 2, (diff[1] + 1) // 2, + diff[0] // 2, (diff[0] + 1) // 2], mode=self.mode) + rgb_pad.append(rgb_i_pad) + self.diffs.append(diff) + # Return concatenation of all images + return torch.cat(rgb_pad, 0) + + def unormalize(self, rgb): + """ + Unormalize image following the previous normalization. 
+ + Parameters + ---------- + rgb : torch.Tensor + Normalized image with padding [B,3,H,W] + + Returns + ------- + orig_rgb : torch.Tensor + Original image + """ + # If it's a list, unnormalize each one + if is_list(rgb): + return [self.unormalize(r) for r in rgb] + rgb_orig = [] + hw = rgb.shape[2:] + for i in range(len(rgb)): + rgb_i = rgb[i].unsqueeze(0) + diff_i = self.diffs[i] + rgb_i = rgb_i[:, :, + (diff_i[0] // 2): (hw[0] - (diff_i[0] + 1) // 2), + (diff_i[1] // 2): (hw[1] - (diff_i[1] + 1) // 2)] + rgb_i_orig = torch.nn.functional.interpolate( + rgb_i, size=hw, mode='bilinear', align_corners=True) + rgb_orig.append(rgb_i_orig) + return torch.cat(rgb_orig, 0) \ No newline at end of file diff --git a/vidar/arch/networks/layers/inits.py b/vidar/arch/networks/layers/inits.py new file mode 100644 index 0000000000000000000000000000000000000000..5438597d5bd8145769b2539b2602f51d75e612c5 --- /dev/null +++ b/vidar/arch/networks/layers/inits.py @@ -0,0 +1,14 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn as nn + + +def weights_init_xavier(m): + """Xavier weight initialization""" + if isinstance(m, nn.Conv2d): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + diff --git a/vidar/arch/networks/layers/packnet/__init__.py b/vidar/arch/networks/layers/packnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/arch/networks/layers/packnet/packnet.py b/vidar/arch/networks/layers/packnet/packnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e14fa5fa9bbf4f069d1756ed2dcb57b8b33fac20 --- /dev/null +++ b/vidar/arch/networks/layers/packnet/packnet.py @@ -0,0 +1,288 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from functools import partial + +import torch +import torch.nn as nn + + +class Conv2D(nn.Module): + """ + 2D convolution with GroupNorm and ELU + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + kernel_size : int + Kernel size + stride : int + Stride + """ + def __init__(self, in_channels, out_channels, kernel_size, stride): + super().__init__() + self.kernel_size = kernel_size + self.conv_base = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride) + self.pad = nn.ConstantPad2d([kernel_size // 2] * 4, value=0) + self.normalize = torch.nn.GroupNorm(16, out_channels) + self.activ = nn.ELU(inplace=True) + + def forward(self, x): + """Runs the Conv2D layer.""" + x = self.conv_base(self.pad(x)) + return self.activ(self.normalize(x)) + + +class ResidualConv(nn.Module): + """2D Convolutional residual block with GroupNorm and ELU""" + def __init__(self, in_channels, out_channels, stride, dropout=None): + """ + Initializes a ResidualConv object. 
+ + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + stride : int + Stride + dropout : float + Dropout value + """ + super().__init__() + self.conv1 = Conv2D(in_channels, out_channels, 3, stride) + self.conv2 = Conv2D(out_channels, out_channels, 3, 1) + self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride) + self.normalize = torch.nn.GroupNorm(16, out_channels) + self.activ = nn.ELU(inplace=True) + + if dropout: + # self.dropout = nn.Dropout(dropout) + self.conv3 = nn.Sequential(self.conv3, nn.Dropout2d(dropout)) + else: + self.dropout = None + + def forward(self, x): + """Runs the ResidualConv layer.""" + x_out = self.conv1(x) + x_out = self.conv2(x_out) + shortcut = self.conv3(x) + # if self.dropout: + # shortcut = self.dropout(shortcut) + return self.activ(self.normalize(x_out + shortcut)) + + +def ResidualBlock(in_channels, out_channels, num_blocks, stride, dropout=None): + """ + Returns a ResidualBlock with various ResidualConv layers. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + num_blocks : int + Number of residual blocks + stride : int + Stride + dropout : float + Dropout value + """ + layers = [ResidualConv(in_channels, out_channels, stride, dropout=dropout)] + for i in range(1, num_blocks): + layers.append(ResidualConv(out_channels, out_channels, 1, dropout=dropout)) + return nn.Sequential(*layers) + + +class InvDepth(nn.Module): + """Inverse depth layer""" + def __init__(self, in_channels, out_channels=1, min_depth=0.5): + """ + Initializes an InvDepth object. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + min_depth : float + Minimum depth value to calculate + """ + super().__init__() + self.min_depth = min_depth + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1) + self.pad = nn.ConstantPad2d([1] * 4, value=0) + self.activ = nn.Sigmoid() + + def forward(self, x): + """Runs the InvDepth layer.""" + x = self.conv1(self.pad(x)) + return self.activ(x) / self.min_depth + + +def packing(x, r=2): + """ + Takes a [B,C,H,W] tensor and returns a [B,(r^2)C,H/r,W/r] tensor, by concatenating + neighbor spatial pixels as extra channels. It is the inverse of nn.PixelShuffle + (if you apply both sequentially you should get the same tensor) + + Parameters + ---------- + x : torch.Tensor [B,C,H,W] + Input tensor + r : int + Packing ratio + + Returns + ------- + out : torch.Tensor [B,(r^2)C,H/r,W/r] + Packed tensor + """ + b, c, h, w = x.shape + out_channel = c * (r ** 2) + out_h, out_w = h // r, w // r + x = x.contiguous().view(b, c, out_h, r, out_w, r) + return x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, out_channel, out_h, out_w) + + +class PackLayerConv2d(nn.Module): + """ + Packing layer with 2d convolutions. Takes a [B,C,H,W] tensor, packs it + into [B,(r^2)C,H/r,W/r] and then convolves it to produce [B,C,H/r,W/r]. + """ + def __init__(self, in_channels, kernel_size, r=2): + """ + Initializes a PackLayerConv2d object. 
+ + Parameters + ---------- + in_channels : int + Number of input channels + kernel_size : int + Kernel size + r : int + Packing ratio + """ + super().__init__() + self.conv = Conv2D(in_channels * (r ** 2), in_channels, kernel_size, 1) + self.pack = partial(packing, r=r) + + def forward(self, x): + """Runs the PackLayerConv2d layer.""" + x = self.pack(x) + x = self.conv(x) + return x + + +class UnpackLayerConv2d(nn.Module): + """ + Unpacking layer with 2d convolutions. Takes a [B,C,H,W] tensor, convolves it + to produce [B,(r^2)C,H,W] and then unpacks it to produce [B,C,rH,rW]. + """ + def __init__(self, in_channels, out_channels, kernel_size, r=2): + """ + Initializes a UnpackLayerConv2d object. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + kernel_size : int + Kernel size + r : int + Packing ratio + """ + super().__init__() + self.conv = Conv2D(in_channels, out_channels * (r ** 2), kernel_size, 1) + self.unpack = nn.PixelShuffle(r) + + def forward(self, x): + """Runs the UnpackLayerConv2d layer.""" + x = self.conv(x) + x = self.unpack(x) + return x + + +class PackLayerConv3d(nn.Module): + """ + Packing layer with 3d convolutions. Takes a [B,C,H,W] tensor, packs it + into [B,(r^2)C,H/r,W/r] and then convolves it to produce [B,C,H/r,W/r]. + """ + def __init__(self, in_channels, kernel_size, r=2, d=8): + """ + Initializes a PackLayerConv3d object. + + Parameters + ---------- + in_channels : int + Number of input channels + kernel_size : int + Kernel size + r : int + Packing ratio + d : int + Number of 3D features + """ + super().__init__() + self.conv = Conv2D(in_channels * (r ** 2) * d, in_channels, kernel_size, 1) + self.pack = partial(packing, r=r) + self.conv3d = nn.Conv3d(1, d, kernel_size=(3, 3, 3), + stride=(1, 1, 1), padding=(1, 1, 1)) + + def forward(self, x): + """Runs the PackLayerConv3d layer.""" + x = self.pack(x) + x = x.unsqueeze(1) + x = self.conv3d(x) + b, c, d, h, w = x.shape + x = x.view(b, c * d, h, w) + x = self.conv(x) + return x + + +class UnpackLayerConv3d(nn.Module): + """ + Unpacking layer with 3d convolutions. Takes a [B,C,H,W] tensor, convolves it + to produce [B,(r^2)C,H,W] and then unpacks it to produce [B,C,rH,rW]. + """ + def __init__(self, in_channels, out_channels, kernel_size, r=2, d=8): + """ + Initializes a UnpackLayerConv3d object. + + Parameters + ---------- + in_channels : int + Number of input channels + out_channels : int + Number of output channels + kernel_size : int + Kernel size + r : int + Packing ratio + d : int + Number of 3D features + """ + super().__init__() + self.conv = Conv2D(in_channels, out_channels * (r ** 2) // d, kernel_size, 1) + self.unpack = nn.PixelShuffle(r) + self.conv3d = nn.Conv3d(1, d, kernel_size=(3, 3, 3), + stride=(1, 1, 1), padding=(1, 1, 1)) + + def forward(self, x): + """Runs the UnpackLayerConv3d layer.""" + x = self.conv(x) + x = x.unsqueeze(1) + x = self.conv3d(x) + b, c, d, h, w = x.shape + x = x.view(b, c * d, h, w) + x = self.unpack(x) + return x + diff --git a/vidar/arch/networks/layers/resnet_encoder.py b/vidar/arch/networks/layers/resnet_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f69984cc5586665e82ab40ea1c2fdf8ae357b40 --- /dev/null +++ b/vidar/arch/networks/layers/resnet_encoder.py @@ -0,0 +1,46 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
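# ---------------------------------------------------------------------------
# Editor's note: the commented sketch below (not part of the original diff) checks
# the claim in packnet.py above that packing() is the inverse of nn.PixelShuffle;
# the tensor shape is made up for the example.
#
#   import torch
#   from torch import nn
#   from vidar.arch.networks.layers.packnet.packnet import packing
#
#   x = torch.rand(2, 8, 64, 96)               # [B,C,H,W] with H and W divisible by r
#   packed = packing(x, r=2)                   # [B,(r^2)C,H/r,W/r] = [2, 32, 32, 48]
#   assert torch.allclose(nn.PixelShuffle(2)(packed), x)
# ---------------------------------------------------------------------------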
+ +import torch.nn as nn +import torchvision.models as models + + +class ResNetEncoder(nn.Module): + """Creates a ResNet encoder with different parameters""" + def __init__(self, name): + super().__init__() + if name == 'densenet121': + self.base_model = models.densenet121(pretrained=True).features + self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5'] + self.feat_out_channels = [64, 64, 128, 256, 1024] + elif name == 'densenet161': + self.base_model = models.densenet161(pretrained=True).features + self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5'] + self.feat_out_channels = [96, 96, 192, 384, 2208] + elif name == 'resnet50': + self.base_model = models.resnet50(pretrained=True) + self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4'] + self.feat_out_channels = [64, 256, 512, 1024, 2048] + elif name == 'resnet101': + self.base_model = models.resnet101(pretrained=True) + self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4'] + self.feat_out_channels = [64, 256, 512, 1024, 2048] + elif name == 'resnext50': + self.base_model = models.resnext50_32x4d(pretrained=True) + self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4'] + self.feat_out_channels = [64, 256, 512, 1024, 2048] + elif name == 'resnext101': + self.base_model = models.resnext101_32x8d(pretrained=True) + self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4'] + self.feat_out_channels = [64, 256, 512, 1024, 2048] + else: + raise NotImplementedError('Not supported encoder: {}'.format(name)) + + def forward(self, x): + features, skips = [x], [x] + for key, val in self.base_model._modules.items(): + if not any(x in key for x in ['fc', 'avgpool']): + feature = val(features[-1]) + features.append(feature) + if any(x in key for x in self.feat_names): + skips.append(feature) + return skips diff --git a/vidar/arch/networks/perceiver/DefineNet.py b/vidar/arch/networks/perceiver/DefineNet.py new file mode 100755 index 0000000000000000000000000000000000000000..8e087a8293072816219e66b595b3a5f9b1c56e15 --- /dev/null +++ b/vidar/arch/networks/perceiver/DefineNet.py @@ -0,0 +1,371 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import abc +from transformers import PerceiverModel, PerceiverConfig +from transformers.models.perceiver.modeling_perceiver import build_position_encoding + +from vidar.arch.networks.perceiver.externals.modeling_perceiver import PerceiverDepthDecoder, PerceiverRGBDecoder, build_position_encoding +from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth +from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder +from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder +from vidar.utils.config import Config +from vidar.utils.networks import freeze_layers_and_norms +from vidar.utils.tensor import interpolate +from vidar.utils.types import is_int + + +class DownSampleRGB(nn.Module): + def __init__(self, out_dim): + super().__init__() + self.conv = torch.nn.Conv2d(3, out_dim, kernel_size=7, stride=2, padding=3) + self.norm = torch.nn.BatchNorm2d(out_dim) + self.actv = torch.nn.ReLU() + self.pool = torch.nn.MaxPool2d(2, stride=2) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.actv(x) + x = self.pool(x) + return x + + +class DefineNet(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.tasks = cfg.tasks + self.to_world = cfg.to_world + self.depth_range = cfg.depth_range + self.rgb_feat_dim = cfg.rgb_feat_dim + 
self.rgb_feat_type = cfg.rgb_feat_type + self.encoder_with_rgb = cfg.encoder_with_rgb + self.decoder_with_rgb = cfg.decoder_with_rgb + self.output_mode = cfg.output_mode + self.sample_encoding_rays = cfg.sample_encoding_rays + self.with_monodepth = cfg.with_monodepth + + self.upsample_convex = cfg.upsample_convex + self.downsample_encoder = cfg.downsample_encoder + self.downsample_decoder = cfg.downsample_decoder + + self.image_shape = [s // self.downsample_encoder for s in cfg.image_shape] + + self.fourier_encoding_orig, _ = build_position_encoding( + position_encoding_type='fourier', + fourier_position_encoding_kwargs={ + 'num_bands': cfg.num_bands_orig, + 'max_resolution': [cfg.max_resolution_orig] * 3, + 'concat_pos': True, + 'sine_only': False, + } + ) + + self.fourier_encoding_dirs, _ = build_position_encoding( + position_encoding_type='fourier', + fourier_position_encoding_kwargs={ + 'num_bands': cfg.num_bands_dirs, + 'max_resolution': [cfg.num_bands_dirs] * 3, + 'concat_pos': True, + 'sine_only': False, + } + ) + + tot_encoder = self.fourier_encoding_orig.output_size() + \ + self.fourier_encoding_dirs.output_size() + if self.encoder_with_rgb: + tot_encoder += self.rgb_feat_dim + + tot_decoder = self.fourier_encoding_orig.output_size() + \ + self.fourier_encoding_dirs.output_size() + if self.decoder_with_rgb: + tot_decoder += self.rgb_feat_dim + + tot_decoder_depth = tot_decoder + tot_decoder_rgb = tot_decoder + + self.config = PerceiverConfig( + train_size=self.image_shape, + d_latents=cfg.d_latents, + d_model=tot_encoder, + num_latents=cfg.num_latents, + hidden_act='gelu', + hidden_dropout_prob=cfg.hidden_dropout_prob, + initializer_range=0.02, + layer_norm_eps=1e-12, + num_blocks=1, + num_cross_attention_heads=cfg.num_cross_attention_heads, + num_self_attends_per_block=cfg.num_self_attends_per_block, + num_self_attention_heads=cfg.num_self_attention_heads, + qk_channels=None, + v_channels=None, + ) + + if 'depth' in self.tasks: + self.decoder = PerceiverDepthDecoder( + self.config, + num_channels=tot_decoder_depth, + use_query_residual=False, + output_num_channels=1, + position_encoding_type="none", + min_depth=self.depth_range[0], + max_depth=self.depth_range[1], + num_heads=cfg.decoder_num_heads, + upsample_mode=cfg.upsample_convex, + upsample_value=cfg.downsample_decoder, + output_mode=cfg.output_mode + ) + if 'rgb' in self.tasks: + self.decoder_rgb = PerceiverRGBDecoder( + self.config, + num_channels=tot_decoder_rgb, + use_query_residual=False, + output_num_channels=3, + position_encoding_type="none", + num_heads=cfg.decoder_num_heads, + upsample_mode=cfg.upsample_convex, + upsample_value=cfg.downsample_decoder, + ) + + self.model = PerceiverModel( + self.config, + ) + + if self.rgb_feat_type == 'convnet': + self.feature = DownSampleRGB(out_dim=self.rgb_feat_dim) + elif self.rgb_feat_type in ['resnet', 'resnet_all', 'resnet_all_rgb']: + self.feature = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1)) + + if self.with_monodepth: + self.mono_encoder = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1)) + self.mono_decoder = DepthDecoder(Config( + num_scales=4, use_skips=True, num_ch_enc=self.feature.num_ch_enc, + num_ch_out=1, activation='sigmoid', + )) + self.sigmoid_to_depth = SigmoidToInvDepth( + min_depth=self.depth_range[0], max_depth=self.depth_range[1], return_depth=True) + + def get_rgb_feat(self, rgb): + if self.rgb_feat_type == 'convnet': + return { + 'feat': self.feature(rgb) + } + elif self.rgb_feat_type == 'resnet': + return { + 'feat': 
self.feature(rgb)[1] + } + elif self.rgb_feat_type.startswith('resnet_all'): + all_feats = self.feature(rgb) + feats = all_feats[1:] + for i in range(1, len(feats)): + feats[i] = interpolate( + feats[i], size=feats[0], scale_factor=None, mode='bilinear', align_corners=True) + if self.rgb_feat_type.endswith('rgb'): + feats = feats + [interpolate( + rgb, size=feats[0], scale_factor=None, mode='bilinear', align_corners=True)] + feat = torch.cat(feats, 1) + return { + 'all_feats': all_feats, + 'feat': feat + } + + def run_monodepth(self, rgb, freeze): + freeze_layers_and_norms(self.mono_encoder, flag_freeze=freeze) + freeze_layers_and_norms(self.mono_decoder, flag_freeze=freeze) + mono_features = self.mono_encoder(rgb) + mono_output = self.mono_decoder(mono_features) + sigmoids = [mono_output[('output', i)] for i in range(1)] + return self.sigmoid_to_depth(sigmoids)[0] + + def embeddings(self, data, sources, downsample): + + if 'rgb' in sources: + assert 'rgb' in data[0].keys() + b = [datum['rgb'].shape[0] for datum in data] + rgb = torch.cat([datum['rgb'] for datum in data], 0) + output_feats = self.get_rgb_feat(rgb) + feats = torch.split(output_feats['feat'], b) + for i in range(len(data)): + data[i]['feat'] = feats[i] + + if self.with_monodepth: + depth = self.run_monodepth(rgb, freeze=False) + depth = torch.split(depth, b) + for i in range(len(data)): + data[i]['depth_mono'] = depth[i] + + encodings = [] + for datum in data: + + encoding = OrderedDict() + + if 'cam' in sources: + assert 'cam' in data[0].keys() + + cam = datum['cam'].scaled(1. / downsample) + orig = cam.get_origin(flatten=True) + + if self.to_world: + dirs = cam.get_viewdirs(normalize=True, flatten=True, to_world=True) + else: + dirs = cam.no_translation().get_viewdirs(normalize=True, flatten=True, to_world=True) + + orig_encodings = self.fourier_encoding_orig( + index_dims=None, pos=orig, batch_size=orig.shape[0], device=orig.device) + dirs_encodings = self.fourier_encoding_dirs( + index_dims=None, pos=dirs, batch_size=dirs.shape[0], device=dirs.device) + + encoding['cam'] = torch.cat([orig_encodings, dirs_encodings], -1) + + if 'rgb' in sources: + rgb = datum['feat'] + rgb_flat = rgb.view(*rgb.shape[:-2], -1).permute(0, 2, 1) + encoding['rgb'] = rgb_flat + + encoding['all'] = torch.cat([val for val in encoding.values()], -1) + encodings.append(encoding) + + return encodings + + @staticmethod + def sample_decoder(data, embeddings, field, sample_queries, filter_invalid): + + query_idx = [] + + if filter_invalid: + tot_min = [] + + for i in range(len(embeddings)): + for b in range(data[i]['rgb'].shape[0]): + tot_min.append((data[i]['rgb'][b].mean(0) >= 0).sum()) + tot_min = min(tot_min) + + tot = embeddings[0][field][0].shape[0] + tot = int(sample_queries * tot) + tot = min([tot, tot_min]) + + for i in range(len(embeddings)): + idx = [] + + for b in range(data[i]['rgb'].shape[0]): + if filter_invalid: + + valid = data[i]['rgb'][b].mean(0, keepdim=True) >= 0 + valid = valid.view(1, -1).permute(1, 0) + + num = embeddings[i][field][0].shape[0] + all_idx = torch.arange(num, device=valid.device).unsqueeze(1) + valid_idx = all_idx[valid] + + num = valid_idx.shape[0] + idx_i = torch.randperm(num)[tot:] + valid[valid_idx[idx_i]] = 0 + idx_i = all_idx[valid] + + else: + + num = embeddings[i][field][0].shape[0] + tot = int(sample_queries * num) + idx_i = torch.randperm(num)[:tot] + + idx.append(idx_i) + + idx = torch.stack(idx, 0) + embeddings[i][field] = torch.stack( + [embeddings[i][field][b][idx[b]] for b in range(idx.shape[0])], 
0) + + query_idx.append(idx) + + return query_idx, embeddings + + def forward(self, encode_data, decode_data=None, + sample_queries=0, filter_invalid=False): + + encode_field = 'all' if self.encoder_with_rgb else 'cam' + decode_field = 'all' if self.decoder_with_rgb else 'cam' + + encode_sources = ['rgb', 'cam'] + decode_sources = ['cam'] + + shape = encode_data[0]['cam'].hw + + output = {} + + encode_dict = self.encode( + data=encode_data, field=encode_field, sources=encode_sources + ) + + if 'depth_mono' in encode_data[0].keys(): + output['depth_mono'] = [datum['depth_mono'] for datum in encode_data] + + decode_embeddings = encode_dict['embeddings'] if decode_data is None else None + + decode_dict = self.decode( + latent=encode_dict['latent'], shape=shape, + data=decode_data, embeddings=decode_embeddings, + field=decode_field, sources=decode_sources, + sample_queries=sample_queries, filter_invalid=filter_invalid + ) + + output.update(decode_dict['output']) + + return { + 'output': output, + 'encode_embeddings': encode_dict['embeddings'], + 'decode_embeddings': decode_dict['embeddings'], + 'latent': encode_dict['latent'], + } + + def encode(self, field, sources, data=None, embeddings=None): + assert data is not None or embeddings is not None + assert data is None or embeddings is None + + if embeddings is None: + embeddings = self.embeddings(data, sources=sources, downsample=self.downsample_encoder) + + all_embeddings = torch.cat([emb[field] for emb in embeddings], 1) + + if self.training and self.sample_encoding_rays > 0: + tot = self.sample_encoding_rays if is_int(self.sample_encoding_rays) \ + else int(self.sample_encoding_rays * all_embeddings.shape[1]) + all_embeddings = torch.stack([all_embeddings[i, torch.randperm(all_embeddings.shape[1])[:tot], :] + for i in range(all_embeddings.shape[0])], 0) + + return { + 'embeddings': embeddings, + 'latent': self.model(inputs=all_embeddings).last_hidden_state, + } + + def decode(self, latent, field, sources=None, data=None, embeddings=None, shape=None, + sample_queries=0, filter_invalid=False): + assert data is not None or embeddings is not None + assert data is None or embeddings is None + + if embeddings is None: + shape = data[0]['cam'].hw + shape = [s // self.downsample_decoder for s in shape] + embeddings = self.embeddings(data, sources=sources, downsample=self.downsample_decoder) + + output = {} + + if self.training and (sample_queries > 0): # or filter_invalid): + output['query_idx'], embeddings = self.sample_decoder( + data, embeddings, field, sample_queries, filter_invalid) + shape = None + + if 'rgb' in self.tasks: + output['rgb'] = [ + self.decoder_rgb(query=emb[field], z=latent, shape=shape).logits + for emb in embeddings] + + if 'depth' in self.tasks: + output['depth'] = [ + self.decoder(query=emb[field], z=latent, shape=shape).logits + for emb in embeddings] + + return { + 'embeddings': embeddings, + 'output': output, + } diff --git a/vidar/arch/networks/perceiver/MLPNet.py b/vidar/arch/networks/perceiver/MLPNet.py new file mode 100755 index 0000000000000000000000000000000000000000..00543ad989cc8bb6d0c2a9655d23da559e90abd2 --- /dev/null +++ b/vidar/arch/networks/perceiver/MLPNet.py @@ -0,0 +1,504 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + +from vidar.arch.networks.layers.define.decoders.camera import CameraDecoder +from vidar.arch.networks.layers.define.decoders.depth import DepthDecoder +from vidar.arch.networks.layers.define.decoders.multiview import MultiviewDecoder +from 
vidar.arch.networks.layers.define.decoders.normals import NormalsDecoder +from vidar.arch.networks.layers.define.decoders.rgb import RGBDecoder +from vidar.arch.networks.layers.define.decoders.semantic import SemanticDecoder +from vidar.arch.networks.layers.define.decoders.volumetric import VolumetricDecoder +from vidar.arch.networks.layers.define.embeddings.camera import CameraEmbeddings +from vidar.arch.networks.layers.define.embeddings.image import ImageEmbeddings +from vidar.arch.networks.layers.define.embeddings.multiview import MultiviewEmbeddings +from vidar.arch.networks.layers.define.embeddings.projection import ProjectionEmbeddings +from vidar.arch.networks.layers.define.embeddings.volumetric import VolumetricEmbeddings +from vidar.arch.networks.layers.define.perceiver.model import PerceiverModel +from vidar.utils.data import make_list, str_not_in +from vidar.utils.networks import freeze_layers_and_norms +from vidar.utils.types import is_int, is_dict, is_tensor, is_list + + +def enable_dropout(model): + """ Function to enable the dropout layers during test-time """ + for m in model.modules(): + if m.__class__.__name__.startswith('Dropout'): + m.train() + + +class MLPNet(nn.Module): + + DECODER_CLASSES = { + 'rgb': RGBDecoder, + 'depth': DepthDecoder, + 'normals': NormalsDecoder, + 'semantic': SemanticDecoder, + 'camera': CameraDecoder, + 'volumetric': VolumetricDecoder, + 'multiview': MultiviewDecoder, + } + + EMBEDDING_CLASSES = { + 'image': ImageEmbeddings, + 'camera': CameraEmbeddings, + 'volumetric': VolumetricEmbeddings, + 'multiview': MultiviewEmbeddings, + 'projection': ProjectionEmbeddings, + } + + def __init__(self, cfg): + super().__init__() + + # VARIATIONAL + + self.is_variational = cfg.has('variational') + self.variational_params = None if not self.is_variational else { + 'kld_weight': cfg.variational.kld_weight, + 'encode_vae': cfg.variational.encode_vae, + 'soft_mask_weight': cfg.variational.has('soft_mask_weight', 1.0), + } + + # Parameters + + self.encode_sources = cfg.encode_sources + self.decode_sources = cfg.decode_sources + self.extra_sources = cfg.has('extra_sources', []) + self.encode_cameras = cfg.has('encode_cameras', 'mod') + self.latent_grid = cfg.latent.has('mode') and cfg.latent.mode == 'grid' + self.latent_grid_dim = cfg.latent.has('grid_dim', 0) + + self.use_image_embeddings = cfg.encoder.has('use_image_embeddings', True) + + if len(self.encode_sources) == 0 or not is_list(self.encode_sources[0]): + self.encode_sources = [self.encode_sources] + + # Downsample + + self.downsample_encoder = cfg.downsample_encoder + self.downsample_decoder = cfg.downsample_decoder + + # Create embeddings + + self.embeddings = nn.ModuleDict() + for key in cfg.embeddings.keys(): + mode = key if '_' not in key else key.split('_')[0] + self.embeddings[key] = self.EMBEDDING_CLASSES[mode](cfg.embeddings.dict[key]) + + # Parse shared decoder parameters + + if cfg.decoders.has('shared'): + for task in cfg.decoders.keys(): + if task is not 'shared': + for key, val in cfg.decoders.shared.items(): + if key not in cfg.decoders.dict[task].keys(): + cfg.decoders.dict[task].dict[key] = val + + # Embeddings dimension + + tot_encoders = [self.total_dimension(sources) for sources in self.encode_sources] + + self.models = nn.ModuleList() + for i in range(len(tot_encoders)): + cfg.encoder.d_model = tot_encoders[i] + is_variational = self.is_variational and self.variational_params['encode_vae'][i] + self.models.append(PerceiverModel(cfg, is_variational=is_variational)) + + # Decoders 
[###### MOVED TO AFTER PERCEIVER MODEL ######] + + self.decoders = nn.ModuleDict() + for i in range(len(self.decode_sources)): + is_variational = self.is_variational and \ + self.variational_params['encode_vae'][self.decode_sources[i][1]] + self.decode_sources[i] += [is_variational] + for name, _, _, embeddings, decoders, _ in self.decode_sources: + tot_decoder = self.total_dimension(embeddings) + for dec in decoders: + if dec not in self.decoders.keys(): + cfg.decoders.dict[dec].name = name + cfg.decoders.dict[dec].num_channels = tot_decoder + cfg.decoders.dict[dec].d_latents = cfg.latent.dim + self.decoders[dec] = self.DECODER_CLASSES[dec.split('_')[0]](cfg.decoders.dict[dec]) + + self.bypass_self_attention = cfg.has('bypass_self_attention', True) + self.freeze_encoder = cfg.encoder.has('freeze', False) + self.really_encode = cfg.encoder.has('really_encode', True) + + @property + def tasks(self): + return self.decoders.keys() + + def total_dimension(self, sources): + if not self.use_image_embeddings and 'image' in sources: + sources = list(sources) + sources.remove('image') + return sum([self.embeddings[source].channels for source in sources if source not in self.extra_sources]) + \ + self.latent_grid_dim + + def get_embeddings(self, data, camera_mode, sources, downsample, encoded=None, sample=None, + previous=None, monodepth=None): + + cam_mode = f'cam_{camera_mode}' + + if 'image' in sources: + rgb_dict = {key: val['rgb'] for key, val in data.items()} + feats_dict = self.embeddings['image'](rgb_dict) + for key in data.keys(): + data[key]['feats'] = feats_dict[key] + + if monodepth is not None: + monodepth['depth_idx'] = {} + + embeddings = OrderedDict() + for key, val in data.items(): + + embedding = OrderedDict() + + cam_scaled = val[cam_mode].scaled(1. / downsample) + cam_scaled2 = None + + embedding['info'] = {} + + if self.training and sample is not None and sample > 0.0: + if previous is None: + idx, start = self.sample_decoder_idx(sample, cam_scaled.batch_size, cam_scaled.hw) + else: + idx, start = previous['info'][key]['idx'], previous['info'][key]['start'] + if start is not None: + cam_scaled2 = cam_scaled.offset_start(start).scaled(1. 
/ sample) + val['cam_scaled'] = embedding['info']['cam_scaled'] = cam_scaled2 + else: + idx = start = None + val['cam_scaled'] = embedding['info']['cam_scaled'] = cam_scaled + + if monodepth is not None: + b, _, h, w = monodepth['depth'][key][0].shape + depth = monodepth['depth'][key][0].view(b, 1, -1).permute(0, 2, 1) + if idx is not None: + depth = torch.stack([depth[j][idx[j]] for j in range(len(idx))], 0) + # from vidar.utils.viz import viz_depth + # from vidar.utils.write import write_image + # write_image('depth.png', viz_depth(depth1.view(b, 1, 192, 640)[0])) + # write_image('depth_sub.png', viz_depth(depth2.view(b, 1, 48, 160)[0])) + # import sys + # sys.exit() + monodepth['depth_idx'][key] = depth + + embedding['info']['idx'] = idx + embedding['info']['start'] = start + + for source in sources: + if source.startswith('camera'): + val['camera'], embedding[source] = self.embeddings[source]( + cam_scaled, key, idx, + meta=val['meta'] if 'meta' in val else None) + if start is not None: + h, w = cam_scaled2.hw + b, n, c = val['camera'].shape + val['camera'] = val['camera'].permute(0, 2, 1).reshape(b, c, h, w) + + if 'image' in sources: + rgb = val['feats'] + embedding['image'] = rgb.view(*rgb.shape[:-2], -1).permute(0, 2, 1) + if idx is not None: + embedding['image'] = torch.stack([ + embedding['image'][i][idx[i]] for i in range(embedding['image'].shape[0])], 0) + + for source in sources: + if source.startswith('volumetric'): + embedding['info']['z_samples'], embedding[source] = self.embeddings[source]( + cam_scaled, key, data, previous, idx) + + for source in sources: + if source.startswith('multiview'): + if encoded is None: + encoded = {'data': data} + multiview_previous = previous if previous is not None else monodepth + embedding['info']['z_samples'], embedding['info']['xyz'], embedding[source] = \ + self.embeddings[source](cam_scaled, key, data, encoded, multiview_previous, idx) + embedding.pop('image', None) + + for source in sources: + if source.startswith('projection'): + embedding[source] = self.embeddings[source]( + cam_scaled2 if cam_scaled2 is not None else cam_scaled, + key, encoded, embedding['info'], idx) + + if idx is not None and previous is None: + data_key = data[key] + for key_gt in data_key.keys(): + if data_key[key_gt] is not None and not key_gt.startswith('cam'): + if is_tensor(data_key[key_gt]) and data_key[key_gt].dim() == 4: + data_key[key_gt] = data_key[key_gt].permute(0, 2, 3, 1).reshape( + data_key[key_gt].shape[0], -1, data_key[key_gt].shape[1]) + data_key[key_gt] = torch.stack([ + data_key[key_gt][i, idx[i]] for i in range(data_key[key_gt].shape[0])], 0) + if start is not None: + h, w = cam_scaled2.hw + b, n, c = data_key[key_gt].shape + data_key[key_gt] = data_key[key_gt].permute(0, 2, 1).reshape(b, c, h, w) + + embeddings[key] = embedding + + if monodepth is not None: + monodepth.pop('depth_idx') + + return embeddings + + @staticmethod + def sample_decoder_idx(sample_decoder, b, hw): + n = hw[0] * hw[1] + if is_int(sample_decoder): + idx, start = [], [] + for _ in range(b): + start_i = torch.randint(0, sample_decoder, (2,)) + idx_i = torch.arange(0, n).reshape(hw) + idx_i = idx_i[start_i[0]::sample_decoder, start_i[1]::sample_decoder].reshape(-1) + idx.append(idx_i) + start.append(start_i) + start = torch.stack(start, 0) + else: + tot = int(sample_decoder * n) + idx = [torch.randperm(n)[:tot] for _ in range(b)] + start = None + idx = torch.stack(idx, 0) + return idx, start + + def bypass_encode(self, data, scene, idx): + return { + 'embeddings': None, 
+ 'latent': self.models[idx].embeddings(batch_size=1, scene=scene) if self.bypass_self_attention else + self.models[idx](data=None)['last_hidden_state'], + 'data': data, + } + + def encode(self, sources=None, data=None, embeddings=None, sample_encoder=0, scene=None): + assert data is not None or embeddings is not None + assert data is None or embeddings is None + + sources = sources if sources is not None else self.encode_sources + + return [self.single_encode(sources, data, embeddings, sample_encoder, scene, idx=i) + for i in range(len(sources))] + + def single_encode(self, sources=None, data=None, embeddings=None, sample_encoder=0, scene=None, idx=0): + assert data is not None or embeddings is not None + assert data is None or embeddings is None + + # Freeze encoder if requested + for model in self.models: + freeze_layers_and_norms(model, flag_freeze=self.freeze_encoder) + + # Get default sources if they are not provided + sources = sources if sources is not None else self.encode_sources + sources = sources[idx] + + camera_mode = self.encode_cameras[idx] if is_list(self.encode_cameras) else self.encode_cameras + + # Don't encode if there is no data or sources to use + if len(data) == 0 or len(sources) == 0: + return self.bypass_encode(data, scene, idx=idx) + + # Create embeddings if they are not provided + if embeddings is None: + embeddings = self.get_embeddings( + data=data, sources=sources, camera_mode=camera_mode, downsample=self.downsample_encoder) + embeddings = {key: torch.cat([val[source] for source in sources if source in val], -1) for key, val in embeddings.items()} + + # Sample embeddings if requested + if self.training and sample_encoder > 0: + for key in embeddings.keys(): + tot = sample_encoder if is_int(sample_encoder) \ + else int(sample_encoder * embeddings[key].shape[1]) + embeddings[key] = torch.stack([ + embeddings[key][i, torch.randperm(embeddings[key].shape[1])[:tot], :] + for i in range(embeddings[key].shape[0])], 0) + + # Don't encode if not requested + if not self.really_encode: + return self.bypass_encode(data, scene, idx=idx) + + # Encode embeddings + encode_output = self.models[idx](data=embeddings, scene=scene) + + # Return embeddings, latent space, and updated data + return { + 'embeddings': embeddings, + 'latent': encode_output['last_hidden_state'], + 'data': data, + } + + def decode(self, encoded, sources=None, data=None, embeddings=None, sample_decoder=0, scene=None, monodepth=None): + assert data is not None or embeddings is not None + assert data is None or embeddings is None + + # Get default sources if they are not provided + sources = sources if sources is not None else self.decode_sources + + # Initialize structures + outputs, previous = [], None + merged_output = {'losses': {}, 'embeddings': {}, 'output': {}} + + # Decode output for each source + for source in sources: + output = self.single_decode(encoded, source[1], source[2], source[3], source[4], + data, embeddings, previous, sample_decoder, + is_variational=source[5], scene=scene, monodepth=monodepth) + previous = output['output'] + outputs.append(output) + + # Combine all outputs + for output, source in zip(outputs, sources): + name = '' if len(source[0]) == 0 else '_' + source[0] + merged_output['losses'].update({'%s%s' % (key, name): val + for key, val in output['losses'].items()}) + merged_output['embeddings'].update({'%s%s' % (key, name): val + for key, val in output['embeddings'].items()}) + merged_output['output'].update({'%s%s' % (key, name): val + for key, val in 
output['output'].items()}) + + # Return merged output + return merged_output + + def single_decode(self, encoded, idx, camera_mode, sources, tasks, data=None, + embeddings=None, previous=None, sample_decoder=0, + is_variational=False, scene=None, monodepth=None): + + # Create embeddings if they are not provided + if embeddings is None: + embeddings = self.get_embeddings( + data=data, sources=sources, camera_mode=camera_mode, downsample=self.downsample_decoder, + encoded=encoded[idx], sample=sample_decoder, previous=previous, monodepth=monodepth, + ) + +#### + + latent = encoded[idx]['latent'] + + #if self.latent_grid: + if 0: + latent1, latent2 = latent + for key in embeddings.keys(): + xyz = embeddings[key]['info']['xyz'] + embeddings[key]['grid'] = latent1.sample(xyz).squeeze(1).permute(0, 2, 3, 1) + sources = [s for s in sources] + ['grid'] + latent = latent2 + +#### + + # Initialize output and losses + output, losses = {}, {} + + # Get additional information and stack embeddings according to source + info = {key: val['info'] for key, val in embeddings.items()} + source_embeddings = {key: torch.cat([ + val[source] for source in sources if source not in self.extra_sources], -1) + for key, val in embeddings.items()} + extra_embeddings = {key: torch.cat([ + val[source] for source in self.extra_sources], -1) if len(self.extra_sources) > 0 else None + for key, val in embeddings.items()} + + # Expand latent space if batch dimensions do not agree + batch_size = source_embeddings[list(source_embeddings.keys())[0]].shape[0] + if not is_dict(latent): + if latent.shape[0] == 1 and latent.shape[0] != batch_size: + latent = latent.repeat(batch_size, 1, 1) + else: + for key in latent.keys(): + if latent[key].shape[0] == 1 and latent[key].shape[0] != batch_size: + latent[key] = latent[key].repeat(batch_size, 1, 1) + + # Sample from latent space if it's variational + if is_variational: + output_variational = self.sample_from_latent(latent) + losses.update(**{key: val for key, val in output_variational.items() if 'loss' in key}) + latent = output_variational['sampled_latent'] + + # Decode all tasks from each embedding + for task in tasks: + output[task] = {key: make_list(self.decoders[task]( + query=val, z=latent[key] if is_dict(latent) else latent, + key=key, info=info, previous=previous, + extra=extra_embeddings[key], scene=scene)['predictions']) + for key, val in source_embeddings.items()} + # Break volumetric into rgb and depth predictions + if task.startswith('volumetric'): + for task_key in ['rgb', 'depth']: + output[task.replace('volumetric', task_key)] = { + key: [v[task_key] for v in val] for key, val in output[task].items()} + if task.startswith('multiview'): + for task_key in ['rgb', 'depth']: + output[task.replace('multiview', task_key)] = { + key: [v[task_key] for v in val] for key, val in output[task].items()} + + output['info'] = info + + # Return losses, embeddings, and output + return { + 'losses': losses, + 'embeddings': embeddings, + 'output': output, + } + + def multi_decode(self, encoded, sources=None, data=None, + embeddings=None, sample_decoder=0, num_evaluations=None, scene=None, monodepth=None): + output = {} + + for i in range(num_evaluations): + output_i = self.decode( + encoded, sources, data, embeddings, sample_decoder, scene=scene, monodepth=monodepth) + + if i == 0: + output = {key: val for key, val in output_i['output'].items()} + else: + for task in output_i['output'].keys(): + if str_not_in(task, ['info', 'volumetric']): + for key in output_i['output'][task].keys(): + 
output[task][key].extend(output_i['output'][task][key]) + + if not self.training: + for task in list(output.keys()): + if str_not_in(task, ['info', 'volumetric']): + output[f'{task}_mean'] = {} + output[f'stddev_{task}'] = {} + for key in output[task].keys(): + val = torch.stack(output[task][key], 0) + output[f'{task}_mean'][key] = [val.mean(0)] + output[f'stddev_{task}'][key] = [val.std(0).sum(1, keepdim=True)] + + return { + 'losses': None, + 'embeddings': embeddings, + 'output': output, + } + + def sample_from_latent(self, latent): + + if is_dict(latent): + latents = {key: self.sample_from_latent(val) for key, val in latent.items()} + output = { + 'sampled_latent': {key: val['sampled_latent'] for key, val in latents.items()} + } + if self.training: + output['kld_loss'] = sum([val['kld_loss'] for val in latents.values()]) / len(latents) + return output + + n = latent.shape[-1] // 2 + mu, logvar = latent[:, :, :n], latent[:, :, n:] + logvar = logvar.clamp(max=10) + + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + sampled_latent = eps * std + mu + + output = { + 'sampled_latent': sampled_latent + } + + if self.training: + output['kld_loss'] = self.variational_params['kld_weight'] * torch.mean( + - 0.5 * torch.mean(1 + logvar - mu ** 2 - logvar.exp(), dim=[1, 2]), dim=0) + + return output + diff --git a/vidar/arch/networks/perceiver/externals/modeling_perceiver.py b/vidar/arch/networks/perceiver/externals/modeling_perceiver.py new file mode 100755 index 0000000000000000000000000000000000000000..988eefb9cd54d2227e92cd68f114dfced5f4df53 --- /dev/null +++ b/vidar/arch/networks/perceiver/externals/modeling_perceiver.py @@ -0,0 +1,3751 @@ +# coding=utf-8 +# Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Perceiver model.""" + +from transformers import PerceiverModel, PerceiverConfig, PreTrainedModel, apply_chunking_to_forward +from transformers.utils import ModelOutput +import abc +import math +from dataclasses import dataclass +from functools import reduce +from operator import __add__ +from typing import Any, Callable, Mapping, Optional, Tuple + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn, Tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from packaging import version +import torch.nn.functional as tfn + + +def upsample_tensor(tensor, mask, up=8): + b, c, h, w = tensor.shape + mask = mask.view(b, 1, 9, up, up, h, w) + mask = torch.softmax(mask, dim=2) + + up_tensor = tfn.unfold(tensor, [3, 3], padding=1) + up_tensor = up_tensor.view(b, -1, 9, 1, 1, h, w) + + up_tensor = torch.sum(mask * up_tensor, dim=2) + up_tensor = up_tensor.permute(0, 1, 4, 2, 5, 3) + return up_tensor.reshape(b, -1, up * h, up * w) + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError( + f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class SiLUActivation(nn.Module): + """ + See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear + Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function + Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated + Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with + later. + """ + + def forward(self, input: Tensor) -> Tensor: + return nn.functional.silu(input) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). 
Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + super().__init__() + if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"): + self.act = self._mish_python + else: + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. + """ + + def forward(self, input: Tensor) -> Tensor: + return input + + +ACT2FN = { + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), + "gelu_fast": FastGELUActivation(), + "gelu_new": NewGELUActivation(), + "gelu_python": GELUActivation(use_gelu_python=True), + "linear": LinearActivation(), + "mish": MishActivation(), + "quick_gelu": QuickGELUActivation(), + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "silu": SiLUActivation(), + "swish": SiLUActivation(), + "tanh": nn.Tanh(), +} + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError( + f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") +#from ...activations import ACT2FN +# from ...file_utils import ( +# ModelOutput, +# add_start_docstrings, +# add_start_docstrings_to_model_forward, +# replace_return_docstrings, +# ) +#from ...modeling_outputs import BaseModelOutputWithCrossAttentions +# from ...modeling_utils import ( +# PreTrainedModel, +# apply_chunking_to_forward, +# find_pruneable_heads_and_indices, +# prune_linear_layer, +# ) + + +#from ...utils import logging +#from .configuration_perceiver import PerceiverConfig + + +ModalitySizeType = Mapping[str, int] +PreprocessorOutputType = Tuple[torch.Tensor, + Optional[torch.Tensor], torch.Tensor] +PreprocessorType = Callable[..., PreprocessorOutputType] +PostprocessorType = Callable[..., Any] + +#logger = logging.get_logger(__name__) + +#_CHECKPOINT_FOR_DOC = "deepmind/language-perceiver" +_CONFIG_FOR_DOC = "PerceiverConfig" +#_TOKENIZER_FOR_DOC = "PerceiverTokenizer" + +PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "deepmind/language-perceiver", + # See all Perceiver models at https://huggingface.co/models?filter=perceiver +] + + +@dataclass +class PerceiverModelOutput(ModelOutput): + """ + Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + logits: torch.FloatTensor = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverDecoderOutput(ModelOutput): + """ + Base class for Perceiver decoder outputs, with potential cross-attentions. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`): + Output of the basic decoder. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + logits: torch.FloatTensor = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverMaskedLMOutput(ModelOutput): + """ + Base class for Perceiver's masked language model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_latents, + num_latents)`. Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class PerceiverClassifierOutput(ModelOutput): + """ + Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal + autoencoding. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class PerceiverEmbeddings(nn.Module): + """Construct the latent embeddings.""" + + def __init__(self, config): + super().__init__() + self.latents = nn.Parameter(torch.randn( + config.num_latents, config.d_latents)) + + def forward(self, batch_size): + return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang + + +class PerceiverSelfAttention(nn.Module): + """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + ): + super().__init__() + self.num_heads = num_heads + # Q and K must have the same number of channels. + # Default to preserving Q's input's shape. + if qk_channels is None: + qk_channels = q_dim + # V's num_channels determines the shape of the output of QKV-attention. + # Default to the same number of channels used in the key-query operation. 
+ if v_channels is None: + v_channels = qk_channels + if qk_channels % num_heads != 0: + raise ValueError( + f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).") + if v_channels % num_heads != 0: + raise ValueError( + f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).") + + self.qk_channels = qk_channels + self.v_channels = v_channels + self.qk_channels_per_head = self.qk_channels // num_heads + self.v_channels_per_head = self.v_channels // num_heads + + # Layer normalization + self.layernorm1 = nn.LayerNorm(q_dim) + self.layernorm2 = nn.LayerNorm( + kv_dim) if is_cross_attention else nn.Identity() + + # Projection matrices + self.query = nn.Linear(q_dim, qk_channels) + self.key = nn.Linear(kv_dim, qk_channels) + self.value = nn.Linear(kv_dim, v_channels) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, channels_per_head): + new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + inputs=None, + inputs_mask=None, + output_attentions=False, + ): + hidden_states = self.layernorm1(hidden_states) + inputs = self.layernorm2(inputs) + + # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module, + # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to. + is_cross_attention = inputs is not None + queries = self.query(hidden_states) + + if is_cross_attention: + keys = self.key(inputs) + values = self.value(inputs) + attention_mask = inputs_mask + else: + keys = self.key(hidden_states) + values = self.value(hidden_states) + + # Reshape channels for multi-head attention. + # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head) + queries = self.transpose_for_scores(queries, self.qk_channels_per_head) + keys = self.transpose_for_scores(keys, self.qk_channels_per_head) + values = self.transpose_for_scores(values, self.v_channels_per_head) + + # Take the dot product between the queries and keys to get the raw attention scores. + attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) + + batch_size, num_heads, seq_len, q_head_dim = queries.shape + _, _, _, v_head_dim = values.shape + hiddens = self.num_heads * v_head_dim + + attention_scores = attention_scores / math.sqrt(q_head_dim) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
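+        # attention_probs has shape (batch_size, num_heads, q_seq_len, kv_seq_len).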
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, values) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (hiddens,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else ( + context_layer,) + + return outputs + + +class PerceiverSelfOutput(nn.Module): + def __init__(self, config, input_channels, output_channels): + super().__init__() + self.dense = nn.Linear(input_channels, output_channels) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + return hidden_states + + +class PerceiverAttention(nn.Module): + """Attention module, including a dense block.""" + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + use_query_residual=True, + ): + super().__init__() + # MultiHead attention + if is_cross_attention and qk_channels is None: + if config.cross_attention_shape_for_attention == "q": + qk_channels = q_dim + elif config.cross_attention_shape_for_attention == "kv": + qk_channels = kv_dim + else: + raise ValueError( + f"Unknown value {config.cross_attention_shape_for_attention} for " + "cross_attention_shape_for_attention." + ) + else: + if qk_channels is None: + qk_channels = q_dim + if v_channels is None: + v_channels = qk_channels + self.self = PerceiverSelfAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + ) + # dense block + output_channels = None + if is_cross_attention: + output_channels = q_dim + else: + if output_channels is None: + output_channels = v_channels + self.output = PerceiverSelfOutput( + config, input_channels=self.self.v_channels, output_channels=output_channels) + self.use_query_residual = use_query_residual + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - \ + len(heads) + self.self.all_head_size = self.self.attention_head_size * \ + self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + inputs=None, + inputs_mask=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + + # Output projection + attention_output = self.output(self_outputs[0]) + + # Optionally include a residual to the original queries. + # Consider omitting the residual if the semantics of query and output + # are different, e.g. if queries are positions and outputs are pixels. 
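+        # In this file the query residual is disabled (use_query_residual=False), e.g.,
+        # in the masked-LM decoder and the optical-flow decoder defined further below.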
+ if self.use_query_residual: + attention_output = attention_output + hidden_states + + # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class PerceiverMLP(nn.Module): + """A Transformer-style dense module to follow attention.""" + + def __init__(self, config, input_size, widening_factor): + super().__init__() + self.dense1 = nn.Linear(input_size, widening_factor * input_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(input_size, input_size) + + def forward(self, hidden_states): + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + return hidden_states + + +class PerceiverLayer(nn.Module): + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + widening_factor=4, + use_query_residual=True, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PerceiverAttention( + config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + use_query_residual=use_query_residual, + ) + self.layernorm = nn.LayerNorm(q_dim) + self.mlp = PerceiverMLP(config, input_size=q_dim, + widening_factor=widening_factor) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + inputs=None, + inputs_mask=None, + output_attentions=False, + ): + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + attention_output = attention_outputs[0] + + # add attentions if we output attention weights + outputs = attention_outputs[1:] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + layer_output = layer_output + attention_output # residual connection + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + layer_output = self.layernorm(attention_output) + layer_output = self.mlp(layer_output) + return layer_output + + +class PerceiverEncoder(nn.Module): + """The Perceiver Encoder: a scalable, fully attentional encoder.""" + + def __init__(self, config, kv_dim=None): + super().__init__() + self.config = config + + # Check that we can use multihead-attention with these shapes. + if config.d_latents % config.num_self_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_self_attend_heads ({config.num_self_attention_heads})." + ) + if config.d_latents % config.num_cross_attention_heads != 0: + raise ValueError( + f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_cross_attend_heads ({config.num_cross_attention_heads})." + ) + + # Construct the cross attention layer. + self.cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_cross_attention_heads, + q_dim=config.d_latents, + kv_dim=kv_dim, + widening_factor=config.cross_attention_widening_factor, + use_query_residual=config.use_query_residual, + ) + + # Construct a single block of self-attention layers. 
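+        # The block holds config.num_self_attends_per_block layers, all operating on d_latents channels.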
+ # We get deeper architectures by applying this block more than once. + self_attention_layers = [] + for _ in range(config.num_self_attends_per_block): + layer = PerceiverLayer( + config, + is_cross_attention=False, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_self_attention_heads, + q_dim=config.d_latents, + kv_dim=config.d_latents, + widening_factor=config.self_attention_widening_factor, + ) + self_attention_layers.append(layer) + + self.self_attends = nn.ModuleList(self_attention_layers) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + inputs=None, + inputs_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + # Apply the cross-attention between the latents (hidden_states) and inputs: + layer_outputs = self.cross_attention( + hidden_states, + attention_mask=attention_mask, + head_mask=None, + inputs=inputs, + inputs_mask=inputs_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_cross_attentions = all_cross_attentions + (layer_outputs[1],) + + # Apply the block of self-attention layers more than once: + for _ in range(self.config.num_blocks): + for i, layer_module in enumerate(self.self_attends): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + \ + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class PerceiverPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = PerceiverConfig + base_model_prefix = "perceiver" + main_input_name = "inputs" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif hasattr(module, "latents"): + module.latents.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding): + module.position_embeddings.data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.ParameterDict): + for modality in module.keys(): + module[modality].data.normal_( + mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +PERCEIVER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +PERCEIVER_MODEL_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + decoder (*DecoderType*, *optional*): + Optional decoder to use to decode the latent representation of the encoder. Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder*, + *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationDecoder*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder*. + input_preprocessor (*PreprocessorType*, *optional*): + Optional input preprocessor to use. Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverImagePreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor*. + output_postprocessor (*PostprocessorType*, *optional*): + Optional output postprocessor to use. 
Examples include + *transformers.models.perceiver.modeling_perceiver.PerceiverImagePostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverProjectionPostprocessor*, + *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor*. + + Note that you can define your own decoders, preprocessors and/or postprocessors to fit your use-case. +""" + +PERCEIVER_INPUTS_DOCSTRING = r""" + Args: + inputs (`torch.FloatTensor`): + Inputs to the perceiver. Can be anything: images, text, audio, video, etc. + attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +# @add_start_docstrings( +# """The Perceiver: a scalable, fully attentional architecture.""", +# PERCEIVER_MODEL_START_DOCSTRING, +# ) +class PerceiverModel(PerceiverPreTrainedModel): + def __init__( + self, + config, + decoder=None, + input_preprocessor: PreprocessorType = None, + output_postprocessor: PostprocessorType = None, + ): + super().__init__(config) + self.config = config + + self.input_preprocessor = input_preprocessor + self.output_postprocessor = output_postprocessor + self.embeddings = PerceiverEmbeddings(config) + self.encoder = PerceiverEncoder( + config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model + ) + self.decoder = decoder + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.latents + + def set_input_embeddings(self, value): + self.embeddings.latents = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + # @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs, + attention_mask=None, + subsampled_output_points=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel + >>> from transformers.models.perceiver.modeling_perceiver import ( + ... PerceiverTextPreprocessor, + ... PerceiverImagePreprocessor, + ... PerceiverClassificationDecoder, + ... ) + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> # EXAMPLE 1: using the Perceiver to classify texts + >>> # - we define a TextPreprocessor, which can be used to embed tokens + >>> # - we define a ClassificationDecoder, which can be used to decode the + >>> # final hidden states of the latents to classification logits + >>> # using trainable position embeddings + >>> config = PerceiverConfig() + >>> preprocessor = PerceiverTextPreprocessor(config) + >>> decoder = PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ) + >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder) + + >>> # you can then do a forward pass as follows: + >>> tokenizer = PerceiverTokenizer() + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + + >>> with torch.no_grad(): + ... outputs = model(inputs=inputs) + >>> logits = outputs.logits + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + + >>> # EXAMPLE 2: using the Perceiver to classify images + >>> # - we define an ImagePreprocessor, which can be used to embed images + >>> preprocessor = PerceiverImagePreprocessor( + ... config, + ... prep_type="conv1x1", + ... spatial_downsample=1, + ... out_channels=256, + ... position_encoding_type="trainable", + ... concat_or_add_pos="concat", + ... project_pos_dim=256, + ... trainable_position_encoding_kwargs=dict( + ... num_channels=256, + ... index_dims=config.image_size ** 2, + ... ), + ... ) + + >>> model = PerceiverModel( + ... config, + ... input_preprocessor=preprocessor, + ... decoder=PerceiverClassificationDecoder( + ... config, + ... num_channels=config.d_latents, + ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1), + ... use_query_residual=True, + ... ), + ... ) + + >>> # you can then do a forward pass as follows: + >>> feature_extractor = PerceiverFeatureExtractor() + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = feature_extractor(image, return_tensors="pt").pixel_values + + >>> with torch.no_grad(): + ... 
outputs = model(inputs=inputs) + >>> logits = outputs.logits + + >>> # to train, one can train the model using standard cross-entropy: + >>> criterion = torch.nn.CrossEntropyLoss() + + >>> labels = torch.tensor([1]) + >>> loss = criterion(logits, labels) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # if self.input_preprocessor is not None: + # inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs) + # else: + # modality_sizes = None + # inputs_without_pos = None + if inputs.size()[-1] != self.config.d_model: + raise ValueError( + f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. " + "Make sure to set config.d_model appropriately." + ) + + batch_size, seq_length, _ = inputs.size() + device = inputs.device + + # If no attention mask is provided, make them all ones + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length)), device=device) + # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = self.invert_attention_mask(attention_mask) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_blocks x num_heads] + # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N] + head_mask = self.get_head_mask( + head_mask, self.config.num_blocks * self.config.num_self_attends_per_block) + + embedding_output = self.embeddings(batch_size=batch_size) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=None, + head_mask=head_mask, + inputs=inputs, + inputs_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + logits = None + # if self.decoder: + # if subsampled_output_points is not None: + # output_modality_sizes = { + # "audio": subsampled_output_points["audio"].shape[0], + # "image": subsampled_output_points["image"].shape[0], + # "label": 1, + # } + # else: + # output_modality_sizes = None + # decoder_query = self.decoder.decoder_query( + # inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points + # ) + # decoder_outputs = self.decoder( + # decoder_query, + # z=sequence_output, + # query_mask=extended_attention_mask, + # output_attentions=output_attentions, + # ) + # logits = decoder_outputs.logits + # + # # add cross-attentions of decoder + # if output_attentions and decoder_outputs.cross_attentions is not None: + # if return_dict: + # encoder_outputs.cross_attentions = ( + # encoder_outputs.cross_attentions + decoder_outputs.cross_attentions + # ) + # else: + # encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions + # + # if self.output_postprocessor: + # logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes) + + if not return_dict: + if logits is not None: + return (logits, sequence_output) + encoder_outputs[1:] + else: + return (sequence_output,) + encoder_outputs[1:] + + return PerceiverModelOutput( + logits=logits, + last_hidden_state=sequence_output, + 
hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +# @add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING) +class PerceiverForMaskedLM(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + text_preprocessor = PerceiverTextPreprocessor(config) + + trainable_position_encoding_kwargs_decoder = dict( + num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=text_preprocessor, + decoder=PerceiverBasicDecoder( + config, + output_num_channels=config.d_latents, + # we need to define the seq_len of the inputs beforehand + output_index_dims=config.max_position_embeddings, + num_channels=text_preprocessor.num_channels, + qk_channels=8 * 32, + v_channels=text_preprocessor.num_channels, + num_heads=8, + use_query_residual=False, + final_project=False, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + ), + ) + self.embedding_decoder = PerceiverEmbeddingDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + input_ids=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverTokenizer, PerceiverForMaskedLM + >>> import torch + + >>> tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver") + + >>> # training + >>> text = "This is an incomplete sentence where some words are missing." + >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt") + >>> # mask " missing." + >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id + >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids + + >>> outputs = model(**inputs, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> text = "This is an incomplete sentence where some words are missing." + >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt") + + >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space. + >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id + + >>> # forward pass + >>> with torch.no_grad(): + ... outputs = model(**encoding) + >>> logits = outputs.logits + + >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist() + >>> tokenizer.decode(masked_tokens_predictions) + ' missing.' 
+ ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.embedding_decoder( + outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings + ) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return PerceiverMaskedLMOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings("""Example use of Perceiver for text classification.""", PERCEIVER_START_DOCSTRING) +class PerceiverForSequenceClassification(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_decoder = dict( + num_channels=config.d_latents, index_dims=1) + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverTextPreprocessor(config), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + input_ids=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels - + 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > + 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverTokenizer, PerceiverForSequenceClassification + + >>> tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver") + >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver") + + >>> text = "hello world" + >>> inputs = tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + ```""" + if inputs is not None and input_ids is not None: + raise ValueError("You cannot use both `inputs` and `input_ids`") + elif inputs is None and input_ids is not None: + inputs = input_ids + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings( + """ +Example use of Perceiver for image classification, for tasks such as ImageNet. + +This model uses learned position embeddings. In other words, this model is not given any privileged information about +the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet. + +[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] +(with `prep_type="conv1x1"`) to preprocess the input images, and +[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of +[`PerceiverModel`] into classification logits. 
+""", + # PERCEIVER_START_DOCSTRING, +# ) + + +class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + trainable_position_encoding_kwargs_preprocessor = dict( + num_channels=256, index_dims=config.image_size ** 2) + trainable_position_encoding_kwargs_decoder = dict( + num_channels=config.d_latents, index_dims=1) + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv1x1", + spatial_downsample=1, + out_channels=256, + position_encoding_type="trainable", + concat_or_add_pos="concat", + project_pos_dim=256, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + pixel_values=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationLearned + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-learned") + >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned") + + >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + ```""" + if inputs is not None and pixel_values is not None: + raise ValueError("You cannot use both `inputs` and `pixel_values`") + elif inputs is None and pixel_values is not None: + inputs = pixel_values + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings( +# """ +# Example use of Perceiver for image classification, for tasks such as ImageNet. + +# This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of +# 79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT). + +# [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] +# (with `prep_type="pixels"`) to preprocess the input images, and +# [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of +# [`PerceiverModel`] into classification logits. 
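+# It computes fixed 2D Fourier position features (num_bands=64 over the 224 x 224 input grid) directly on the raw
+# pixels, while only the classification decoder below uses trainable position encodings.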
+# """, +# PERCEIVER_START_DOCSTRING, +# ) +class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = dict( + concat_pos=True, max_resolution=(224, 224), num_bands=64, sine_only=False + ) + trainable_position_encoding_kwargs_decoder = dict( + num_channels=config.d_latents, index_dims=1) + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="pixels", + spatial_downsample=1, + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + pixel_values=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationFourier + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-fourier") + >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier") + + >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + ```""" + if inputs is not None and pixel_values is not None: + raise ValueError("You cannot use both `inputs` and `pixel_values`") + elif inputs is None and pixel_values is not None: + inputs = pixel_values + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if 
self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings( +# """ +# Example use of Perceiver for image classification, for tasks such as ImageNet. + +# This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy +# of 82.1 on ImageNet. + +# [`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] +# (with `prep_type="conv"`) to preprocess the input images, and +# [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of +# [`PerceiverModel`] into classification logits. +# """, +# PERCEIVER_START_DOCSTRING, +# ) +class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = dict( + concat_pos=True, max_resolution=(56, 56), num_bands=64, sine_only=False + ) + trainable_position_encoding_kwargs_decoder = dict( + num_channels=config.d_latents, index_dims=1) + + self.num_labels = config.num_labels + self.perceiver = PerceiverModel( + config, + input_preprocessor=PerceiverImagePreprocessor( + config, + prep_type="conv", + spatial_downsample=1, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ), + decoder=PerceiverClassificationDecoder( + config, + num_channels=config.d_latents, + trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder, + use_query_residual=True, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + pixel_values=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationConvProcessing + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-conv") + >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv") + + >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values + >>> outputs = model(inputs=inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + ```""" + if inputs is not None and pixel_values is not None: + raise ValueError("You cannot use both `inputs` and `pixel_values`") + elif inputs is None and pixel_values is not None: + inputs = pixel_values + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings( +# """ +# Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses +# [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the +# input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent +# representation of [`PerceiverModel`]. + +# As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel +# (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position +# of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation +# using the same encoding used for the input. 
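+# The 54 per-pixel values correspond to conv_after_patching_in_channels=54 in the image preprocessor configured below.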
+# """, +# PERCEIVER_START_DOCSTRING, +# ) +class PerceiverForOpticalFlow(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + fourier_position_encoding_kwargs_preprocessor = dict( + num_bands=64, + max_resolution=config.train_size, + sine_only=False, + concat_pos=True, + ) + fourier_position_encoding_kwargs_decoder = dict( + concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False + ) + + image_preprocessor = PerceiverImagePreprocessor( + config, + prep_type="patches", + spatial_downsample=1, + conv_after_patching=True, + conv_after_patching_in_channels=54, + temporal_downsample=2, + position_encoding_type="fourier", + # position_encoding_kwargs + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor, + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=image_preprocessor, + decoder=PerceiverOpticalFlowDecoder( + config, + num_channels=image_preprocessor.num_channels, + output_image_shape=config.train_size, + rescale_factor=100.0, + # decoder kwargs + use_query_residual=False, + output_num_channels=2, + # We query the decoder using the first frame features + # rather than a standard decoder position encoding. + position_encoding_type="fourier", + fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder, + ), + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
+ + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverForOpticalFlow + >>> import torch + + >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver") + + >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel, + >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels) + >>> # patches have shape (batch_size, num_frames, num_channels, height, width) + >>> # the authors train on resolutions of 368 x 496 + >>> patches = torch.randn(1, 2, 27, 368, 496) + >>> outputs = model(inputs=patches) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + raise NotImplementedError( + "Optical flow training is not yet supported") + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# @add_start_docstrings( + """ +Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700. + +[`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to +preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to +preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad +each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies +the Perceiver encoder. + +[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of +[`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are +created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is +computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent +representation. This is determined by the subsampled indices for each modality, which can be provided as additional +input to the forward pass of [`PerceiverForMultimodalAutoencoding`]. + +[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different +modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention +is performed with the latent representation of [`PerceiverModel`]. + +Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an +actual video. It first splits up the output into the different modalities, and then applies the respective +postprocessor for each modality. + +Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the +"label" modality), this auto-encoding model becomes a Kinetics 700 video classifier. 
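+
+For instance, after running the chunked auto-encoding loop shown in the forward docstring below, the classification
+scores can be read from the "label" modality of the postprocessed output (a sketch, assuming the multimodal
+postprocessor returns a dictionary keyed by modality name):
+
+```python
+>>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
+>>> label_logits = outputs.logits["label"]  # shape (batch_size, 700)
+>>> predicted_class_idx = label_logits.argmax(-1).item()
+```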
+#""", +# PERCEIVER_START_DOCSTRING, +# ) + + +class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + n_audio_samples = config.num_frames * config.audio_samples_per_frame + + input_preprocessor = PerceiverMultimodalPreprocessor( + min_padding_size=4, + modalities={ + "audio": PerceiverAudioPreprocessor( + config, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=dict( + num_bands=192, + max_resolution=(n_audio_samples,), + sine_only=False, + concat_pos=True, + ), + prep_type="patches", + samples_per_patch=config.samples_per_patch, + ), + "image": PerceiverImagePreprocessor( + config, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=dict( + num_bands=32, + max_resolution=(config.num_frames, + config.image_size, config.image_size), + sine_only=False, + concat_pos=True, + ), + prep_type="patches", + spatial_downsample=4, + temporal_downsample=1, + ), + "label": PerceiverOneHotPreprocessor(config), + }, + mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0}, + ) + + image_decoder = PerceiverBasicVideoAutoencodingDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + output_shape=config.output_shape, + output_num_channels=512, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=dict( + num_bands=32, + max_resolution=(config.num_frames, + config.image_size, config.image_size), + sine_only=False, + concat_pos=True, + ), + ) + + decoder = PerceiverMultimodalDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + # Modality specific decoders are used ONLY to generate queries. + # All modalties are decoded together using a unified decoder. + modalities={ + "audio": PerceiverBasicDecoder( + config, + # Autoencoding, don't pass inputs to the queries. + concat_preprocessed_input=False, + output_index_dims=(n_audio_samples // + config.samples_per_patch,), + output_num_channels=512, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="fourier", + fourier_position_encoding_kwargs=dict( + num_bands=192, + max_resolution=(n_audio_samples,), + sine_only=False, + concat_pos=True, + ), + ), + "image": image_decoder, + "label": PerceiverClassificationDecoder( + config, + # Autoencoding, don't pass inputs to the queries. 
+ concat_preprocessed_input=False, + use_query_residual=False, + position_encoding_only=True, + position_encoding_type="trainable", + trainable_position_encoding_kwargs=dict( + num_channels=1024, + index_dims=1, + ), + ), + }, + num_outputs=None, + output_num_channels=512, + use_query_residual=False, + ) + + output_postprocessor = PerceiverMultimodalPostprocessor( + modalities={ + "audio": PerceiverAudioPostprocessor(config, in_channels=512), + "image": PerceiverProjectionPostprocessor(in_channels=512, out_channels=3), + "label": PerceiverClassificationPostprocessor(config, in_channels=512), + } + ) + + self.perceiver = PerceiverModel( + config, + input_preprocessor=input_preprocessor, + decoder=decoder, + output_postprocessor=output_postprocessor, + ) + + # Initialize weights and apply final processing + self.post_init() + + # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + inputs=None, + attention_mask=None, + subsampled_output_points=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + labels=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import PerceiverForMultimodalAutoencoding + >>> import torch + >>> import numpy as np + + >>> # create multimodal inputs + >>> images = torch.randn((1, 16, 3, 224, 224)) + >>> audio = torch.randn((1, 30720, 1)) + >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700))) + + >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver") + + >>> # in the Perceiver IO paper, videos are auto-encoded in chunks + >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries + >>> nchunks = 128 + >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks + >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks + >>> # process the first chunk + >>> chunk_idx = 0 + >>> subsampling = { + ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)), + ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)), + ... "label": None, + ... 
} + + >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.perceiver( + inputs=inputs, + attention_mask=attention_mask, + subsampled_output_points=subsampled_output_points, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits if return_dict else outputs[0] + + loss = None + if labels is not None: + raise NotImplementedError( + "Multimodal autoencoding training is not yet supported") + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return PerceiverClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Below: position encodings + + +def build_position_encoding( + position_encoding_type, + out_channels=None, + project_pos_dim=-1, + trainable_position_encoding_kwargs=None, + fourier_position_encoding_kwargs=None, +): + """ + Builds the position encoding. + + Args: + + - out_channels: refers to the number of channels of the position encodings. + - project_pos_dim: if specified, will project the position encodings to this dimension. + + """ + + if position_encoding_type == "trainable": + if not trainable_position_encoding_kwargs: + raise ValueError( + "Make sure to pass trainable_position_encoding_kwargs") + output_pos_enc = PerceiverTrainablePositionEncoding( + **trainable_position_encoding_kwargs) + elif position_encoding_type == "fourier": + # We don't use the index_dims argument, as this is only known during the forward pass + if not fourier_position_encoding_kwargs: + raise ValueError( + "Make sure to pass fourier_position_encoding_kwargs") + output_pos_enc = PerceiverFourierPositionEncoding( + **fourier_position_encoding_kwargs) + else: + raise ValueError( + f"Unknown position encoding type: {position_encoding_type}.") + + # Optionally, project the position encoding to a target dimension: + positions_projection = nn.Linear( + out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity() + + return output_pos_enc, positions_projection + + +# Below: Perceiver decoders + + +class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta): + """Perceiver abstract decoder.""" + + @abc.abstractmethod + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + raise NotImplementedError + + @property + @abc.abstractmethod + def num_query_channels(self): + raise NotImplementedError + + @abc.abstractmethod + def forward(self, query, z, query_mask=None): + raise NotImplementedError + + +class PerceiverProjectionDecoder(PerceiverAbstractDecoder): + """ + Baseline projection decoder (no cross-attention). + + Args: + config ([`PerceiverConfig`]): + Model configuration. 
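+
+ Examples (a minimal, illustrative sketch; the 10-label configuration below is hypothetical):
+
+ ```python
+ >>> import torch
+ >>> from transformers import PerceiverConfig
+
+ >>> config = PerceiverConfig(num_labels=10)
+ >>> decoder = PerceiverProjectionDecoder(config)
+ >>> latents = torch.randn(2, config.num_latents, config.d_latents)
+ >>> logits = decoder(query=None, z=latents)  # mean-pool the latents, then project to logits
+ >>> logits.shape
+ torch.Size([2, 10])
+ ```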
+ """ + + def __init__(self, config): + super().__init__() + self.classifier = nn.Linear(config.d_latents, config.num_labels) + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return None + + def forward(self, query, z, query_mask=None): + # (batch_size, num_latents, d_latents) -> (batch_size, d_latents) + z = torch.mean(z, dim=1) + # (batch_size, d_latents) -> (batch_size, config.num_labels) + logits = self.classifier(z) + return logits + + +class PerceiverBasicDecoder(PerceiverAbstractDecoder): + """ + Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a + cross-attention operation, in which the latents produce keys and values. + + The shape of the output of this class depends on how one defines the output queries (also called decoder queries). + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_num_channels (`int`, *optional*): + The number of channels in the output. Will only be used in case *final_project* is set to `True`. + position_encoding_type (`str`, *optional*, defaults to "trainable"): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". + output_index_dims (`int`, *optional*): + The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'. + num_channels (`int`, *optional*): + The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'. + qk_channels (`int`, *optional*): + The number of channels of the queries and keys in the cross-attention layer. + v_channels (`int`, *optional*, defaults to 128): + The number of channels of the values in the cross-attention layer. + num_heads (`int`, *optional*, defaults to 1): + The number of attention heads in the cross-attention layer. + widening_factor (`int`, *optional*, defaults to 1): + The widening factor of the cross-attention layer. + use_query_residual (`bool`, *optional*, defaults to `False`): + Whether to use a residual connection between the query and the output of the cross-attention layer. + concat_preprocessed_input (`bool`, *optional*, defaults to `False`): + Whether to concatenate the preprocessed input to the query. + final_project (`bool`, *optional*, defaults to `True`): + Whether to project the output of the cross-attention layer to a target dimension. + position_encoding_only (`bool`, *optional*, defaults to `False`): + Whether to only use this class to define output queries. + """ + + def __init__( + self, + config, + output_num_channels, + position_encoding_type="trainable", + # The following 2 arguments are ignored if position_encoding_type == 'none': + output_index_dims=None, + num_channels=128, + subsampled_index_dims=None, + qk_channels=None, + v_channels=None, + num_heads=1, + widening_factor=1, + use_query_residual=False, + concat_preprocessed_input=False, + final_project=True, + position_encoding_only=False, + **position_encoding_kwargs, + ): + super().__init__() + + self.output_num_channels = output_num_channels + # If `none`, the decoder will not construct any position encodings. + # You should construct your own when quering the decoder. 
+ self.output_position_encodings = None + self.position_encoding_type = position_encoding_type + self.position_encoding_kwargs = position_encoding_kwargs + if position_encoding_type != "none": + self.output_position_encodings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, **position_encoding_kwargs + ) + + self.output_index_dims = output_index_dims + self.num_channels = num_channels + if subsampled_index_dims is None: + subsampled_index_dims = output_index_dims + self.subsampled_index_dims = subsampled_index_dims + self.concat_preprocessed_input = concat_preprocessed_input + self.final_project = final_project + self.position_encoding_only = position_encoding_only + + # for multimodal autoencoding, we don't need the decoder cross-attention and final layer + # so then we will set position_encoding_only to True + if not self.position_encoding_only: + self.decoding_cross_attention = PerceiverLayer( + config, + is_cross_attention=True, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=num_channels, + kv_dim=config.d_latents, + widening_factor=widening_factor, + use_query_residual=use_query_residual, + ) + self.final_layer = nn.Linear( + num_channels, output_num_channels) if final_project else nn.Identity() + + @property + def num_query_channels(self) -> int: + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError( + "You cannot calculate number of decoder query channels when position_encoding_type is set to none" + ) + if self.position_encoding_only: + if "project_pos_dim" in self.position_encoding_kwargs: + return self.position_encoding_kwargs["project_pos_dim"] + return self.output_position_encodings.output_size() + if self.final_project: + return self.output_num_channels + return self.num_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if self.position_encoding_type == "none": # Queries come from elsewhere + raise ValueError( + "You cannot construct decoder queries when position_encoding_type is set to none") + if subsampled_points is not None: + # subsampled_points are the indices if the inputs would be flattened + # however, the inputs aren't flattened, that's why we use unravel_index + # to get the indices for the unflattened array + # unravel_index returns a tuple (x_idx, y_idx, ...) + # stack to get the [n, d] tensor of coordinates + indices = list( + torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims) + ) + pos = torch.stack(indices, dim=1) + batch_size = inputs.shape[0] + # Map these coordinates to [-1, 1] + pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :] + pos = torch.broadcast_to( + pos[None], [batch_size, pos.shape[0], pos.shape[1]]) + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_emb = self.output_position_encodings(batch_size) + elif self.position_encoding_type == "fourier": + pos_emb = self.output_position_encodings( + self.output_index_dims, batch_size=batch_size, device=inputs.device, pos=pos + ) + + # Optionally project them to a target dimension. + pos_emb = self.positions_projection(pos_emb) + pos_emb = torch.reshape( + pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]]) + else: + batch_size = inputs.shape[0] + index_dims = inputs.shape[2:] + + # Construct the position encoding. 
+ if self.position_encoding_type == "trainable":
+ pos_emb = self.output_position_encodings(batch_size)
+ elif self.position_encoding_type == "fourier":
+ pos_emb = self.output_position_encodings(
+ index_dims, batch_size, device=inputs.device)
+
+ # Optionally project them to a target dimension.
+ pos_emb = self.positions_projection(pos_emb)
+
+ if self.concat_preprocessed_input:
+ if inputs_without_pos is None:
+ raise ValueError(
+ "Value is required for inputs_without_pos if concat_preprocessed_input is True")
+ pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
+
+ return pos_emb
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ # Cross-attention decoding.
+ # key, value: B x N x K; query: B x M x K
+ # Attention maps -> B x N x M
+ # Output -> B x M x K
+ cross_attentions = () if output_attentions else None
+
+ layer_outputs = self.decoding_cross_attention(
+ query,
+ attention_mask=query_mask,
+ head_mask=None,
+ inputs=z,
+ inputs_mask=None,
+ output_attentions=output_attentions,
+ )
+ output = layer_outputs[0]
+
+ if output_attentions:
+ cross_attentions = cross_attentions + (layer_outputs[1],)
+
+ logits = self.final_layer(output)
+
+ return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
+
+
+class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
+ """
+ Cross-attention based classification decoder. Lightweight wrapper of [`PerceiverBasicDecoder`] for logit output.
+ Turns the output of the Perceiver encoder, of shape (batch_size, num_latents, d_latents), into a tensor of shape
+ (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels).
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config, **decoder_kwargs):
+ super().__init__()
+
+ self.num_labels = config.num_labels
+ self.decoder = PerceiverBasicDecoder(
+ config,
+ output_num_channels=self.num_labels,
+ output_index_dims=1, # Predict a single logit array.
+ **decoder_kwargs,
+ )
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ return self.decoder.decoder_query(
+ inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
+ )
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+
+ # B x 1 x num_classes -> B x num_classes
+ logits = decoder_outputs.logits[:, 0, :]
+
+ return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
+
+
+class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
+ """Cross-attention based optical flow decoder."""
+
+ def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
+ super().__init__()
+
+ self.output_image_shape = output_image_shape
+ self.output_num_channels = output_num_channels
+ self.rescale_factor = rescale_factor
+ self.decoder = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels, **decoder_kwargs)
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if subsampled_points is not None:
+ raise ValueError("FlowDecoder doesn't support subsampling yet.")
+ return inputs
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+ preds = decoder_outputs.logits
+ # Output flow and rescale.
+ preds /= self.rescale_factor
+ preds = preds.reshape(
+ [preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]])
+ return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions)
+
+
+class PerceiverRGBDecoder(PerceiverAbstractDecoder):
+ """Cross-attention based RGB decoder. Decodes latents into per-pixel RGB values (via a sigmoid), with optional
+ convex, bilinear, or pixel-shuffle ('unpack') upsampling."""
+
+ def __init__(self, config, output_num_channels=3, upsample_mode=None, upsample_value=None, **decoder_kwargs):
+ super().__init__()
+
+ if upsample_value != 1:
+ self.upsample_mode = upsample_mode
+ self.upsample_value = upsample_value
+ if self.upsample_mode == 'convex':
+ output_num_channels_mask = 9 * upsample_value ** 2
+ self.decoder_mask = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels_mask, **decoder_kwargs)
+ elif self.upsample_mode == 'bilinear':
+ from vidar.utils.tensor import interpolate
+ from functools import partial
+ self.interpolate = partial(interpolate, scale_factor=upsample_value,
+ size=None, mode='bilinear', align_corners=True)
+ elif self.upsample_mode == 'unpack':
+ output_num_channels *= upsample_value ** 2
+ self.interpolate = torch.nn.PixelShuffle(upsample_value)
+ else:
+ self.upsample_mode = None
+
+ self.output_num_channels = output_num_channels
+ self.decoder = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels, **decoder_kwargs)
+ self.sigmoid = torch.nn.Sigmoid()
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if subsampled_points is not None:
+ raise ValueError("PerceiverRGBDecoder doesn't support subsampling yet.")
+ return inputs
+
+ def forward(self, query, z, shape=None,
query_mask=None, output_attentions=False): + decoder_outputs = self.decoder( + query, z, output_attentions=output_attentions) + pred = decoder_outputs.logits + + if shape is not None: + pred = pred.reshape( + [pred.shape[0]] + list(shape) + [pred.shape[-1]]).permute(0, 3, 1, 2) + + pred = self.sigmoid(pred) + + if self.upsample_mode == 'convex': + mask = self.decoder_mask( + query, z, output_attentions=output_attentions).logits + if shape is not None: + mask = mask.reshape( + [mask.shape[0]] + list(shape) + [mask.shape[-1]]).permute(0, 3, 1, 2) + pred = upsample_tensor(pred, mask, up=self.upsample_value) + elif self.upsample_mode == 'bilinear': + pred = self.interpolate(pred) + elif self.upsample_mode == 'unpack': + pred = self.interpolate(pred) + + return PerceiverDecoderOutput(logits=pred, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverDepthDecoder(PerceiverAbstractDecoder): + """Cross-attention based optical flow decoder.""" + + def __init__(self, config, output_num_channels=1, upsample_mode=None, upsample_value=None, + min_depth=0.1, max_depth=100.0, output_mode='inv_depth', **decoder_kwargs): + super().__init__() + + if upsample_value != 1: + self.upsample_mode = upsample_mode + self.upsample_value = upsample_value + if self.upsample_mode == 'convex': + output_num_channels_mask = 9 * upsample_value ** 2 + self.decoder_mask = PerceiverBasicDecoder( + config, output_num_channels=output_num_channels_mask, **decoder_kwargs) + elif self.upsample_mode == 'bilinear': + from vidar.utils.tensor import interpolate + from functools import partial + self.interpolate = partial(interpolate, scale_factor=upsample_value, + size=None, mode='bilinear', align_corners=True) + elif self.upsample_mode == 'unpack': + output_num_channels *= upsample_value ** 2 + self.interpolate = torch.nn.PixelShuffle(upsample_value) + else: + self.upsample_mode = None + + self.output_num_channels = output_num_channels + self.decoder = PerceiverBasicDecoder( + config, output_num_channels=output_num_channels, **decoder_kwargs) + + self.output_mode = output_mode + if self.output_mode == 'inv_depth': + from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth + self.sigmoid = torch.nn.Sigmoid() + self.sigmoid_to_depth = SigmoidToInvDepth( + min_depth=min_depth, max_depth=max_depth, return_depth=True) + elif self.output_mode == 'log_depth': + from vidar.arch.blocks.depth.SigmoidToLogDepth import SigmoidToLogDepth + self.sigmoid_to_log_depth = SigmoidToLogDepth() + else: + raise ValueError('Invalid depth output mode') + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + if subsampled_points is not None: + raise ValueError("FlowDecoder doesn't support subsampling yet.") + return inputs + + def forward(self, query, z, shape=None, query_mask=None, output_attentions=False): + decoder_outputs = self.decoder( + query, z, output_attentions=output_attentions) + pred = decoder_outputs.logits + + if shape is not None: + pred = pred.reshape( + [pred.shape[0]] + list(shape) + [pred.shape[-1]]).permute(0, 3, 1, 2) + + if self.output_mode == 'inv_depth': + pred = self.sigmoid_to_depth(self.sigmoid(pred)) + elif self.output_mode == 'log_depth': + pred = self.sigmoid_to_log_depth(pred) + + if self.upsample_mode == 'convex': + mask = self.decoder_mask( + query, z, output_attentions=output_attentions).logits + if shape is not None: + mask = mask.reshape( + 
[mask.shape[0]] + list(shape) + [mask.shape[-1]]).permute(0, 3, 1, 2) + pred = upsample_tensor(pred, mask, up=self.upsample_value) + elif self.upsample_mode == 'bilinear': + pred = self.interpolate(pred) + elif self.upsample_mode == 'unpack': + pred = self.interpolate(pred) + + return PerceiverDecoderOutput(logits=pred, cross_attentions=decoder_outputs.cross_attentions) + + +class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder): + """ + Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video + reshaping logic. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + output_shape (`List[int]`): + Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension. + position_encoding_type (`str`): + The type of position encoding to use. Can be either "trainable", "fourier", or "none". + """ + + def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs): + super().__init__() + if len(output_shape) != 4: # B, T, H, W + raise ValueError( + f"Expected rank 4 output_shape, got {output_shape}.") + # Build the decoder components: + self.output_shape = output_shape + self.output_num_channels = decoder_kwargs["output_num_channels"] + + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=self.output_shape[1:4], # T*H*W + position_encoding_type=position_encoding_type, + **decoder_kwargs, + ) + + @property + def num_query_channels(self) -> int: + return self.decoder.num_query_channels + + def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None): + return self.decoder.decoder_query( + inputs, + modality_sizes=modality_sizes, + inputs_without_pos=inputs_without_pos, + subsampled_points=subsampled_points, + ) + + def forward(self, query, z, query_mask=None): + decoder_outputs = self.decoder(query, z) + logits = decoder_outputs.logits + + logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]]) + return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions) + + +def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]: + """ + Partitions a [B, N, C] tensor into tensors for each modality. + + Args: + modality_sizes + dict specifying the size of the modality + inputs: + input tensor + + Returns: + dict mapping name of modality to its associated tensor. + """ + outputs = {} + index = 0 + # Apply a predictable ordering to the modalities + for modality in sorted(modality_sizes.keys()): + size = modality_sizes[modality] + inp = inputs[:, index: index + size] + index += size + outputs[modality] = inp + return outputs + + +class PerceiverMultimodalDecoder(PerceiverAbstractDecoder): + """ + Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary + mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that + modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are + concatenated along the time dimension. + + Next, there is a shared cross attention operation across all modalities. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + modalities (`Dict[str, PerceiverAbstractDecoder]`): + Dictionary mapping modality name to the decoder of that modality. + num_outputs (`int`): + The number of outputs of the decoder. 
+ output_num_channels (`int`): + The number of channels in the output. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. + subsampled_index_dims (`Dict[str, PerceiverAbstractDecoder]`, *optional*): + Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that + modality. + """ + + def __init__( + self, + config, + modalities, + num_outputs, + output_num_channels, + min_padding_size=2, + subsampled_index_dims=None, + **decoder_kwargs + ): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.subsampled_index_dims = subsampled_index_dims + self.min_padding_size = min_padding_size + self.output_num_channels = output_num_channels + self.num_outputs = num_outputs + self.decoder = PerceiverBasicDecoder( + config, + output_index_dims=(num_outputs,), + output_num_channels=output_num_channels, + position_encoding_type="none", + num_channels=self.num_query_channels, + **decoder_kwargs, + ) + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn( + 1, self.num_query_channels - decoder.num_query_channels)) + for modality, decoder in modalities.items() + } + ) + + @property + def num_query_channels(self) -> int: + max_channel_size = max( + decoder.num_query_channels for _, decoder in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None): + # Partition the flat inputs among the different modalities + inputs = restructure(modality_sizes, inputs) + + # Obtain modality-specific decoders' queries + subsampled_points = subsampled_points or dict() + + decoder_queries = dict() + for modality, decoder in self.modalities.items(): + # Get input_without_pos for this modality if it exists. + input_without_pos = None + if inputs_without_pos is not None: + input_without_pos = inputs_without_pos.get(modality, None) + query = decoder.decoder_query( + inputs=inputs[modality], + modality_sizes=None, + inputs_without_pos=input_without_pos, + subsampled_points=subsampled_points.get(modality, None), + ) + decoder_queries[modality] = query + + # Pad all queries with trainable position encodings to make them have the same channels + + def embed(modality, x): + x = torch.reshape( + x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]]) + pos = self.padding[modality] + pos = torch.broadcast_to( + pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]]) + return torch.cat([x, pos], dim=2) + + # Apply a predictable ordering to the modalities + return torch.cat( + [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1 + ) + + def forward(self, query, z, query_mask=None, output_attentions=False): + # B x 1 x num_classes -> B x num_classes + decoder_outputs = self.decoder( + query, z, output_attentions=output_attentions) + + return decoder_outputs + + +# Below: IO pre- and post-processor classes for Perceiver. +def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor: + """ + Space to depth transform. Rearranges blocks of spatial data, into depth. + + This function assumes the channels to be first, but will place the channels last after transformation. 
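+
+ For example (illustrative only), a rank-4 input of shape (1, 3, 224, 224) with `spatial_block_size=2` becomes a
+ (1, 112, 112, 12) tensor:
+
+ ```python
+ >>> import torch
+ >>> frames = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
+ >>> space_to_depth(frames, spatial_block_size=2).shape
+ torch.Size([1, 112, 112, 12])
+ ```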
+ + Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15. + """ + if len(frames.shape) == 4: + batch_size, num_channels, height, width = frames.shape + # split up dimensions (height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C) + frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous() + # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C) + frames = frames.view( + batch_size, + height // spatial_block_size, + width // spatial_block_size, + (spatial_block_size ** 2) * num_channels, + ) + return frames + elif len(frames.shape) == 5: + batch_size, time, num_channels, height, width = frames.shape + # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size) + frames = frames.view( + batch_size, + time // temporal_block_size, + temporal_block_size, + num_channels, + height // spatial_block_size, + spatial_block_size, + width // spatial_block_size, + spatial_block_size, + ) + # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C) + frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous() + # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C) + frames = frames.view( + batch_size, + time // temporal_block_size, + height // spatial_block_size, + width // spatial_block_size, + temporal_block_size * (spatial_block_size ** 2) * num_channels, + ) + return frames + else: + raise ValueError( + "Frames should be of rank 4 (batch, channels, height, width)" + " or rank 5 (batch, time, channels, height, width)" + ) + + +class Conv2dSamePadding(nn.Conv2d): + """ + Conv2d layer with padding="same" support. Source: + https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6 + """ + + def __init__(self, *args, **kwargs): + super(Conv2dSamePadding, self).__init__(*args, **kwargs) + self.zero_pad_2d = nn.ZeroPad2d( + reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) + for k in self.kernel_size[::-1]]) + ) + + def forward(self, input): + return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias) + + +class Conv2DDownsample(nn.Module): + """Downsamples 4x by applying a 2D convolution and doing max pooling.""" + + def __init__( + self, + num_layers: int = 1, + in_channels: int = 3, + out_channels: int = 64, + use_batchnorm: bool = True, + ): + """ + Constructs a Conv2DDownsample model. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 64): + The number of conv output channels. + use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batchnorm. 
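+
+ A shape sketch (assuming a hypothetical 224 x 224 RGB input; the stride-2 convolution followed by stride-2 max
+ pooling gives the overall 4x downsampling):
+
+ ```python
+ >>> import torch
+ >>> downsample = Conv2DDownsample(in_channels=3, out_channels=64)
+ >>> downsample(torch.randn(1, 3, 224, 224)).shape
+ torch.Size([1, 64, 56, 56])
+ ```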
+ """ + super().__init__() + + self.conv = Conv2dSamePadding( + in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False + ) + self.batchnorm = nn.BatchNorm2d( + num_features=out_channels) if use_batchnorm else nn.Identity() + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + out = self.conv(inputs) + out = self.batchnorm(out) + out = self.relu(out) + out = self.max_pool(out) + return out + + +def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False): + """ + Generate a Fourier frequency position encoding with linear spacing. + + Args: + pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`): + The Tensor containing the position of n points in d dimensional space. + num_bands (`int`): + The number of frequency bands (K) to use. + max_resolution (`Tuple[int]`, *optional*, defaults to (224, 224)): + The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension. + concat_pos (`bool`, *optional*, defaults to `True`): + Whether to concatenate the input position encoding to the Fourier features. + sine_only (`bool`, *optional*, defaults to `False`): + Whether to use a single phase (sin) or two (sin/cos) for each frequency band. + + Returns: + `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If + `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d, + sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1), + ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the + kth frequency band. + """ + + batch_size = pos.shape[0] + device = pos.device + + min_freq = 1.0 + # Nyquist frequency at the target resolution: + freq_bands = torch.stack( + [torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device) for res in max_resolution], dim=0 + ) + + # Get frequency bands for each spatial dimension. + # Output is size [n, d * num_bands] + per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :] + per_pos_features = torch.reshape( + per_pos_features, [-1, np.prod(per_pos_features.shape[1:])]) + + if sine_only: + # Output is size [n, d * num_bands] + per_pos_features = torch.sin(np.pi * (per_pos_features)) + else: + # Output is size [n, 2 * d * num_bands] + per_pos_features = torch.cat( + [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1 + ) + # Concatenate the raw input positions. + if concat_pos: + # Adds d bands to the encoding. + per_pos_features = torch.cat( + [pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1) + return per_pos_features + + +def build_linear_positions(index_dims, output_range=(-1.0, 1.0)): + """ + Generate an array of position indices for an N-D input array. + + Args: + index_dims (`List[int]`): + The shape of the index dimensions of the input array. + output_range (`Tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`): + The min and max values taken by each input index dimension. + + Returns: + `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`. 
+ """ + + def _linspace(n_xels_per_dim): + return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32) + + dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims] + array_index_grid = torch.meshgrid(*dim_ranges) + + return torch.stack(array_index_grid, dim=-1) + + +class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta): + """Perceiver abstract position encoding.""" + + @property + @abc.abstractmethod + def num_dimensions(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + def output_size(self, *args, **kwargs) -> int: + raise NotImplementedError + + @abc.abstractmethod + def forward(self, batch_size, pos): + raise NotImplementedError + + +class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): + """Trainable position encoding.""" + + def __init__(self, index_dims, num_channels=128): + super().__init__() + self._num_channels = num_channels + self._index_dims = index_dims + index_dim = np.prod(index_dims) + self.position_embeddings = nn.Parameter( + torch.randn(index_dim, num_channels)) + + @property + def num_dimensions(self) -> int: + if isinstance(self._index_dims, int): + return 1 + return len(self._index_dims) + + def output_size(self, *args, **kwargs) -> int: + return self._num_channels + + def forward(self, batch_size): + position_embeddings = self.position_embeddings + + if batch_size is not None: + position_embeddings = position_embeddings.expand( + batch_size, -1, -1) + return position_embeddings + + +def _check_or_build_spatial_positions(pos, index_dims, batch_size): + """ + Checks or builds spatial position features (x, y, ...). + + Args: + pos (`torch.FloatTensor`): + None, or an array of position features. If None, position features are built. Otherwise, their size is checked. + index_dims (`List[int]`): + An iterable giving the spatial/index size of the data to be featurized. + batch_size (`int`): + The batch size of the data to be featurized. + + Returns: + `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features. + """ + if pos is None: + pos = build_linear_positions(index_dims) + pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape) + pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1]) + # else: + # # Just a warning label: you probably don't want your spatial features to + # # have a different spatial layout than your pos coordinate system. + # # But feel free to override if you think it'll work! 
+ # if pos.shape[-1] != len(index_dims): + # raise ValueError("Spatial features have the wrong number of dimensions.") + return pos + + +class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding): + """Fourier (Sinusoidal) position encoding.""" + + def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False): + super().__init__() + self.num_bands = num_bands + self.max_resolution = max_resolution + self.concat_pos = concat_pos + self.sine_only = sine_only + + @property + def num_dimensions(self) -> int: + return len(self.max_resolution) + + def output_size(self): + """Returns size of positional encodings last dimension.""" + num_dims = len(self.max_resolution) + encoding_size = self.num_bands * num_dims + if not self.sine_only: + encoding_size *= 2 + if self.concat_pos: + encoding_size += self.num_dimensions + + return encoding_size + + def forward(self, index_dims, batch_size, device, pos=None): + pos = _check_or_build_spatial_positions(pos, index_dims, batch_size) + fourier_pos_enc = generate_fourier_features( + pos, + num_bands=self.num_bands, + max_resolution=self.max_resolution, + concat_pos=self.concat_pos, + sine_only=self.sine_only, + ).to(device) + return fourier_pos_enc + + +class AbstractPreprocessor(nn.Module): + @property + def num_channels(self) -> int: + """Returns size of preprocessor output.""" + raise NotImplementedError() + + +class PerceiverTextPreprocessor(AbstractPreprocessor): + """ + Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings. + + The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration. + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.embeddings = nn.Embedding( + num_embeddings=config.vocab_size, embedding_dim=config.d_model) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.d_model) + + @property + def num_channels(self) -> int: + return self.config.d_model + + def forward(self, inputs): + embeddings = self.embeddings(inputs) + + seq_length = inputs.shape[1] + position_ids = torch.arange(0, seq_length, device=inputs.device) + embeddings = embeddings + self.position_embeddings(position_ids) + + return embeddings, None, None + + +class PerceiverEmbeddingDecoder(nn.Module): + """ + Module to decode embeddings (for masked language modeling). + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.bias = nn.Parameter(torch.zeros(self.vocab_size)) + + def forward(self, hidden_states, embedding_layer): + batch_size, seq_len, d_model = hidden_states.shape + output = torch.matmul(hidden_states.reshape( + [-1, d_model]), embedding_layer.weight.T) # Flatten batch dim + output = output + self.bias + + return output.reshape([batch_size, seq_len, self.vocab_size]) + + +class PerceiverMultimodalPostprocessor(nn.Module): + """ + Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single + postprocessor. + + Args: + modalities (`Dict[str, PostprocessorType]`): + Dictionary mapping modality name to postprocessor class for that modality. + input_is_dict (`bool`, *optional*, defaults to `False`): + If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. 
If + False, input is a tensor which is sliced up during postprocessing by *modality_sizes*. + """ + + def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False): + super().__init__() + self.modalities = nn.ModuleDict(modalities) + self.input_is_dict = input_is_dict + + def forward( + self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None + ) -> Mapping[str, torch.Tensor]: + if not self.input_is_dict: + # Slice up modalities by their sizes. + if modality_sizes is None: + raise ValueError( + "Modality sizes should be specified if input is not a dictionary.") + inputs = restructure(modality_sizes=modality_sizes, inputs=inputs) + + outputs = { + modality: postprocessor( + inputs[modality], pos=pos, modality_sizes=None) + for modality, postprocessor in self.modalities.items() + } + return outputs + + +class PerceiverClassificationPostprocessor(nn.Module): + """ + Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + """ + + def __init__(self, config, in_channels): + super().__init__() + self.classifier = nn.Linear(in_channels, config.num_labels) + + def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits[:, 0, :] + + +class PerceiverAudioPostprocessor(nn.Module): + """ + Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + in_channels (`int`): + Number of channels in the input. + postproc_type (`str`, *optional*, defaults to `"patches"`): + Postprocessor type to use. Currently, only "patches" is supported. + """ + + def __init__(self, config, in_channels, postproc_type: str = "patches"): + super().__init__() + + # to be supported: 'conv', 'patches', 'pixels' + if postproc_type not in ("patches",): + raise ValueError("Invalid postproc_type!") + + # Architecture parameters: + self.classifier = nn.Linear(in_channels, config.samples_per_patch) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + + logits = self.classifier(inputs) + return torch.reshape(logits, [inputs.shape[0], -1]) + + +class PerceiverProjectionPostprocessor(nn.Module): + """ + Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower + dimension. + + Args: + in_channels (`int`): + Number of channels in the input. + out_channels (`int`): + Number of channels in the output. + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.classifier = nn.Linear(in_channels, out_channels) + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor: + logits = self.classifier(inputs) + return logits + + +class PerceiverImagePreprocessor(AbstractPreprocessor): + """ + Image preprocessing for Perceiver Encoder. + + Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to + "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the + position encoding kwargs are set equal to the *out_channels*. + + Args: + config ([*PerceiverConfig*]): + Model configuration. 
+ prep_type (`str`, *optional*, defaults to `"conv"`): + Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels". + spatial_downsample (`int`, *optional*, defaults to 4): + Spatial downsampling factor. + temporal_downsample (`int`, *optional*, defaults to 1): + Temporal downsampling factor (only relevant in case a time dimension is present). + position_encoding_type (`str`, *optional*, defaults to `"fourier"`): + Position encoding type. Can be "fourier" or "trainable". + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input. + out_channels (`int`, *optional*, defaults to 64): + Number of channels in the output. + conv_after_patching (`bool`, *optional*, defaults to `False`): + Whether to apply a convolutional layer after patching. + conv_after_patching_in_channels (`int`, *optional*, defaults to 54): + Number of channels in the input of the convolutional layer after patching. + conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`): + Whether to use batch normalization in the convolutional layer. + concat_or_add_pos (`str`, *optional*, defaults to `"concat"`): + How to concatenate the position encoding to the input. Can be "concat" or "add". + project_pos_dim (`int`, *optional*, defaults to -1): + Dimension of the position encoding to project to. If -1, no projection is applied. + **position_encoding_kwargs (`Dict`, *optional*): + Keyword arguments for the position encoding. + """ + + def __init__( + self, + config, + prep_type="conv", + spatial_downsample: int = 4, + temporal_downsample: int = 1, + position_encoding_type: str = "fourier", + in_channels: int = 3, + out_channels: int = 64, + conv_after_patching: bool = False, + # only relevant when conv_after_patching = True + conv_after_patching_in_channels: int = 54, + conv2d_use_batchnorm: bool = True, + concat_or_add_pos: str = "concat", + project_pos_dim: int = -1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type not in ("conv", "patches", "pixels", "conv1x1"): + raise ValueError(f"Prep_type {prep_type} is invalid") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError( + f"Invalid value {concat_or_add_pos} for concat_or_add_pos.") + + self.in_channels = in_channels + self.prep_type = prep_type + self.spatial_downsample = spatial_downsample + self.temporal_downsample = temporal_downsample + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.conv_after_patching = conv_after_patching + self.out_channels = out_channels + + if self.prep_type == "conv": + # Downsampling with conv is currently restricted + convnet_num_layers = math.log(spatial_downsample, 4) + convnet_num_layers_is_int = convnet_num_layers == np.round( + convnet_num_layers) + if not convnet_num_layers_is_int or temporal_downsample != 1: + raise ValueError( + "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv." + ) + self.convnet = Conv2DDownsample( + in_channels=in_channels, + num_layers=int(convnet_num_layers), + out_channels=out_channels, + use_batchnorm=conv2d_use_batchnorm, + ) + + elif self.prep_type == "conv1x1": + if temporal_downsample != 1: + raise ValueError("Conv1x1 does not downsample in time.") + self.convnet_1x1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + # spatial_downsample is unconstrained for 1x1 convolutions. 
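+ # With this stride, the 1x1 convolution projects the channels to out_channels and subsamples pixels in one step.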
+ stride=(spatial_downsample, spatial_downsample), + ) + + # Position embeddings + self.project_pos_dim = project_pos_dim + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + # Optional convolutional layer after patches. + self.conv_after_patches = ( + nn.Linear(conv_after_patching_in_channels, + self.out_channels) if conv_after_patching else nn.Identity() + ) + + @property + def num_channels(self) -> int: + # Let's assume that the number of resolutions (in the context of image preprocessing) + # of the input data is 2 or 3 depending on whether we are processing image or video respectively. + # In this case, for convenience, we will declare is_temporal variable, + # which will show whether the data has a temporal dimension or not. + is_temporal = self.position_embeddings.num_dimensions > 2 + + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + + # inputs + if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"): + inp_dim = self.out_channels + elif self.prep_type == "pixels": + inp_dim = self.in_channels + if not is_temporal: + inp_dim = math.ceil(inp_dim / self.spatial_downsample) + elif self.prep_type == "patches": + if self.conv_after_patching: + inp_dim = self.out_channels + else: + inp_dim = self.in_channels * self.spatial_downsample ** 2 + if is_temporal: + inp_dim *= self.temporal_downsample + + return inp_dim + pos_dim + + def _build_network_inputs(self, inputs: torch.Tensor, pos: torch.Tensor, network_input_is_1d: bool = True): + """ + Construct the final input, including position encoding. + + This method expects the inputs to always have channels as last dimension. + + """ + batch_size = inputs.shape[0] + index_dims = inputs.shape[1:-1] + indices = np.prod(index_dims) + + # Flatten input features to a 1D index dimension if necessary. + if len(inputs.shape) > 3 and network_input_is_1d: + inputs = torch.reshape(inputs, [batch_size, indices, -1]) + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings( + index_dims, batch_size, device=inputs.device) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if not network_input_is_1d: + # Reshape pos to match the input feature shape + # if the network takes non-1D inputs + sh = inputs.shape + pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1]) + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + return inputs_with_pos, inputs + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + if self.prep_type == "conv": + # Convnet image featurization. 
+ # Downsamples spatially by a factor of 4 + inputs = self.convnet(inputs) + + elif self.prep_type == "conv1x1": + # map inputs to self.out_channels + inputs = self.convnet_1x1(inputs) + + elif self.prep_type == "pixels": + # if requested, downsamples in the crudest way + if inputs.ndim == 4: + inputs = inputs[:: self.spatial_downsample, + :: self.spatial_downsample] + elif inputs.ndim == 5: + inputs = inputs[ + :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample + ] + else: + raise ValueError("Unsupported data format for pixels.") + + elif self.prep_type == "patches": + # Space2depth featurization. + # Video: B x T x C x H x W + inputs = space_to_depth( + inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample + ) + + if inputs.ndim == 5 and inputs.shape[1] == 1: + # for flow + inputs = inputs.squeeze(dim=1) + + # Optionally apply conv layer. + inputs = self.conv_after_patches(inputs) + + if self.prep_type != "patches": + # move channels to last dimension, as the _build_network_inputs method below expects this + if inputs.ndim == 4: + inputs = torch.moveaxis(inputs, 1, -1) + elif inputs.ndim == 5: + inputs = torch.moveaxis(inputs, 2, -1) + else: + raise ValueError("Unsupported data format for conv1x1.") + + inputs, inputs_without_pos = self._build_network_inputs( + inputs, pos, network_input_is_1d) + modality_sizes = None # Size for each modality, only needed for multimodal + + return inputs, modality_sizes, inputs_without_pos + + +class PerceiverOneHotPreprocessor(AbstractPreprocessor): + """ + One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input. + + Args: + config ([`PerceiverConfig`]): + Model configuration. + """ + + def __init__(self, config): + super().__init__() + self.config: PerceiverConfig = config + + @property + def num_channels(self) -> int: + return self.config.num_labels + + def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True): + # Add a dummy index dimension. + inputs = inputs[:, None, :] + + # No position encodings, so the 1st (input) and 3rd (inputs_without_pos) + # outputs are identical. + return inputs, None, inputs + + +class PerceiverAudioPreprocessor(AbstractPreprocessor): + """ + Audio preprocessing for Perceiver Encoder. + + Args: + config ([*PerceiverConfig*]): + Model configuration. + prep_type (`str`, *optional*, defaults to `"patches"`): + Preprocessor type to use. Only "patches" is supported. + samples_per_patch (`int`, *optional*, defaults to 96): + Number of samples per patch. + position_encoding_type (`str`, *optional*, defaults to `"fourier"`): + Type of position encoding to use. Can be "trainable" or "fourier". + concat_or_add_pos (`str`, *optional*, defaults to `"concat"`): + How to concatenate the position encoding to the input. Can be "concat" or "add". + out_channels (`int`, *optional*, defaults to 64): + Number of channels in the output. + project_pos_dim (`int`, *optional*, defaults to -1): + Dimension of the position encoding to project to. If -1, no projection is applied. + **position_encoding_kwargs (`Dict`, *optional*): + Keyword arguments for the position encoding. 
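+
+    Example (illustrative shape arithmetic only; the position-encoding size of 193 below is an
+    assumed value, not a default of this class):
+
+        >>> samples_per_patch, pos_dim = 96, 193
+        >>> num_channels = samples_per_patch + pos_dim  # with concat_or_add_pos == "concat"
+        >>> num_channels
+        289
+        >>> # a raw waveform of 96 * 100 samples is reshaped into 100 patches of 96 samples each
+        >>> batch_size, num_samples = 2, 96 * 100
+        >>> (batch_size, num_samples // samples_per_patch, samples_per_patch)
+        (2, 100, 96)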
+ """ + + def __init__( + self, + config, + prep_type: str = "patches", + samples_per_patch: int = 96, + position_encoding_type: str = "fourier", + concat_or_add_pos: str = "concat", + out_channels=64, + project_pos_dim=-1, + **position_encoding_kwargs, + ): + super().__init__() + self.config = config + + if prep_type not in ("patches",): + raise ValueError( + f"Prep_type {prep_type} is invalid, can only be 'patches'.") + + if concat_or_add_pos not in ["concat", "add"]: + raise ValueError( + f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.") + + self.samples_per_patch = samples_per_patch + self.position_encoding_type = position_encoding_type + self.concat_or_add_pos = concat_or_add_pos + self.project_pos_dim = project_pos_dim + + # Position embeddings + self.position_embeddings, self.positions_projection = build_position_encoding( + position_encoding_type=position_encoding_type, + out_channels=out_channels, + project_pos_dim=project_pos_dim, + **position_encoding_kwargs, + ) + + @property + def num_channels(self) -> int: + # position embedding + if self.project_pos_dim > 0: + pos_dim = self.project_pos_dim + else: + pos_dim = self.position_embeddings.output_size() + if self.concat_or_add_pos == "add": + return pos_dim + return self.samples_per_patch + pos_dim + + def _build_network_inputs(self, inputs, pos): + """Construct the final input, including position encoding.""" + batch_size = inputs.shape[0] + index_dims = inputs.shape[1:-1] + + # Construct the position encoding. + if self.position_encoding_type == "trainable": + pos_enc = self.position_embeddings(batch_size) + elif self.position_encoding_type == "fourier": + pos_enc = self.position_embeddings( + index_dims, batch_size, device=inputs.device) + + # Optionally project them to a target dimension. + pos_enc = self.positions_projection(pos_enc) + + if self.concat_or_add_pos == "concat": + inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1) + elif self.concat_or_add_pos == "add": + inputs_with_pos = inputs + pos_enc + + return inputs_with_pos, inputs + + def forward(self, inputs, pos, network_input_is_1d: bool = True): + inputs = torch.reshape( + inputs, [inputs.shape[0], -1, self.samples_per_patch]) + + inputs, inputs_without_pos = self._build_network_inputs(inputs, pos) + modality_sizes = None # Size for each modality, only needed for multimodal + + return inputs, modality_sizes, inputs_without_pos + + +class PerceiverMultimodalPreprocessor(AbstractPreprocessor): + """ + Multimodal preprocessing for Perceiver Encoder. + + Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number + of channels. + + Args: + modalities (`Dict[str, PreprocessorType]`): + Dict mapping modality name to preprocessor. + mask_probs (`Dict[str, float]`): + Dict mapping modality name to masking probability of that modality. + min_padding_size (`int`, *optional*, defaults to 2): + The minimum padding size for all modalities. The final output will have num_channels equal to the maximum + channels across all modalities plus min_padding_size. 
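+
+    Example (illustrative channel arithmetic only; the per-modality channel counts below are
+    assumptions, not values computed by this class):
+
+        >>> modality_channels = {"image": 243, "audio": 289, "label": 700}
+        >>> min_padding_size = 2
+        >>> common_channel_size = max(modality_channels.values()) + min_padding_size
+        >>> common_channel_size
+        702
+        >>> # each modality is padded with a trainable embedding covering the remaining channels
+        >>> {k: common_channel_size - v for k, v in modality_channels.items()}
+        {'image': 459, 'audio': 413, 'label': 2}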
+ """ + + def __init__( + self, + modalities: Mapping[str, PreprocessorType], + mask_probs: Optional[Mapping[str, float]] = None, + min_padding_size: int = 2, + ): + super().__init__() + self.modalities = modalities + self.min_padding_size = min_padding_size + self.mask_probs = mask_probs if mask_probs is not None else dict() + self.padding = nn.ParameterDict( + { + modality: nn.Parameter(torch.randn( + 1, self.num_channels - preprocessor.num_channels)) + for modality, preprocessor in modalities.items() + } + ) + self.mask = nn.ParameterDict( + {modality: nn.Parameter(torch.randn(1, self.num_channels)) + for modality, _ in self.mask_probs.items()} + ) + + @property + def num_channels(self) -> int: + max_channel_size = max(processor.num_channels for _, + processor in self.modalities.items()) + common_channel_size = max_channel_size + self.min_padding_size + return common_channel_size + + def forward( + self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True + ) -> PreprocessorOutputType: + padded = {} + modality_sizes = {} + inputs_without_pos = {} + for modality, preprocessor in self.modalities.items(): + # preprocess each modality using the respective preprocessor. + output, _, inputs_without_pos[modality] = preprocessor( + inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d + ) + + # pad to the same common_channel_size. + batch_size, num_samples, num_channels = output.shape + pos_enc = self.padding[modality].expand(batch_size, -1, -1) + + padding = torch.broadcast_to( + pos_enc, + [batch_size, num_samples, self.num_channels - num_channels], + ) + output_padded = torch.cat([output, padding], dim=2) + + # mask if required + if modality in self.mask_probs: + mask_token = self.mask[modality].expand(batch_size, -1, -1) + mask_prob = self.mask_probs[modality] + mask = torch.bernoulli(torch.full( + [batch_size, num_samples], mask_prob)) + mask = torch.unsqueeze(mask, dim=2).to(mask_token.device) + output_padded = (1 - mask) * output_padded + mask * mask_token + + padded[modality] = output_padded + modality_sizes[modality] = output_padded.shape[1] + + # Apply a predictable ordering to the modalities + padded_ls = [padded[k] for k in sorted(padded.keys())] + + # Finally, concatenate along the time dimension + final_inputs = torch.cat(padded_ls, dim=1) + + return final_inputs, modality_sizes, inputs_without_pos diff --git a/vidar/arch/networks/pose/ConvPoseNet.py b/vidar/arch/networks/pose/ConvPoseNet.py new file mode 100644 index 0000000000000000000000000000000000000000..abed96abd72f2350779c181b80d65e0d733f54ab --- /dev/null +++ b/vidar/arch/networks/pose/ConvPoseNet.py @@ -0,0 +1,86 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
+ +import torch +import torch.nn as nn + + +def conv_gn(in_planes, out_planes, kernel_size=3): + """ + Convolutional block with GroupNorm + + Parameters + ---------- + in_planes : int + Number of input channels + out_planes : int + Number of output channels + kernel_size : int + Convolutional kernel size + + Returns + ------- + layers : nn.Sequential + Sequence of Conv2D + GroupNorm + ReLU + """ + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, stride=2), + nn.GroupNorm(16, out_planes), + nn.ReLU(inplace=True) + ) + + +class ConvPoseNet(nn.Module): + """Pose network """ + + def __init__(self, nb_ref_imgs=2, rotation_mode='euler', **kwargs): + + nb_ref_imgs = 2 + rotation_mode = 'euler' + + super().__init__() + self.nb_ref_imgs = nb_ref_imgs + self.rotation_mode = rotation_mode + + conv_channels = [16, 32, 64, 128, 256, 256, 256] + self.conv1 = conv_gn(3 * (1 + self.nb_ref_imgs), conv_channels[0], kernel_size=7) + self.conv2 = conv_gn(conv_channels[0], conv_channels[1], kernel_size=5) + self.conv3 = conv_gn(conv_channels[1], conv_channels[2]) + self.conv4 = conv_gn(conv_channels[2], conv_channels[3]) + self.conv5 = conv_gn(conv_channels[3], conv_channels[4]) + self.conv6 = conv_gn(conv_channels[4], conv_channels[5]) + self.conv7 = conv_gn(conv_channels[5], conv_channels[6]) + + self.pose_pred = nn.Conv2d(conv_channels[6], 6 * self.nb_ref_imgs, + kernel_size=1, padding=0) + + self.init_weights() + + def init_weights(self): + """Initialize weights""" + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + nn.init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, image, context): + """Network forward pass""" + + assert (len(context) == self.nb_ref_imgs) + input = [image] + input.extend(context) + input = torch.cat(input, 1) + out_conv1 = self.conv1(input) + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3(out_conv2) + out_conv4 = self.conv4(out_conv3) + out_conv5 = self.conv5(out_conv4) + out_conv6 = self.conv6(out_conv5) + out_conv7 = self.conv7(out_conv6) + + pose = self.pose_pred(out_conv7) + pose = pose.mean(3).mean(2) + pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) + + return pose diff --git a/vidar/arch/networks/pose/PoseNet.py b/vidar/arch/networks/pose/PoseNet.py new file mode 100755 index 0000000000000000000000000000000000000000..3dccc04d40f89886e8cfa700fa37d40927912363 --- /dev/null +++ b/vidar/arch/networks/pose/PoseNet.py @@ -0,0 +1,49 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
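+
+# Usage sketch for ConvPoseNet defined above (illustrative only; the batch size and the
+# 192x640 resolution are assumptions, not requirements):
+#
+#     >>> import torch
+#     >>> from vidar.arch.networks.pose.ConvPoseNet import ConvPoseNet
+#     >>> net = ConvPoseNet(nb_ref_imgs=2)
+#     >>> image = torch.randn(4, 3, 192, 640)           # target frame
+#     >>> context = [torch.randn(4, 3, 192, 640)] * 2   # two reference frames
+#     >>> net(image, context).shape                     # one scaled 6-DoF vector per reference frame
+#     torch.Size([4, 2, 6])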
+ +from abc import ABC + +import torch + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.decoders.PoseDecoder import PoseDecoder +from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder +from vidar.geometry.pose_utils import vec2mat + + +class PoseNet(BaseNet, ABC): + """ + Pose Network + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.networks['pose_encoder'] = \ + ResNetEncoder(cfg) + + self.networks['pose'] = \ + PoseDecoder( + num_ch_enc=self.networks['pose_encoder'].num_ch_enc, + num_input_features=1, + num_frames_to_predict_for=2 + ) + + def forward(self, rgb, invert): + """Network forward pass""" + + rgb = torch.cat(rgb[::-1] if invert else rgb, 1) + feats = self.networks['pose_encoder'](rgb) + rotation, translation = self.networks['pose']([feats]) + transformation = vec2mat( + rotation[:, 0], translation[:, 0], invert=invert) + + return { + 'rotation': rotation, + 'translation': translation, + 'transformation': transformation, + } + diff --git a/vidar/arch/networks/transformers/MatchModule.py b/vidar/arch/networks/transformers/MatchModule.py new file mode 100644 index 0000000000000000000000000000000000000000..44114c9d118bf647de4806e6d059e68d37109cf5 --- /dev/null +++ b/vidar/arch/networks/transformers/MatchModule.py @@ -0,0 +1,46 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC + +import torch + +from vidar.arch.networks.BaseNet import BaseNet +from vidar.arch.networks.layers.depthformer.transformer_net import TransformerNet +from vidar.utils.config import cfg_has + + +class MatchModule(BaseNet, ABC): + """ + Feature matching module (https://arxiv.org/abs/2204.07616) + + Parameters + ---------- + cfg : Config + Configuration with parameters + """ + def __init__(self, cfg): + super().__init__(cfg) + + self.downsample = cfg_has(cfg, 'downsample', 3) + self.set_attr(cfg, 'monocular', False) + self.set_attr(cfg, 'decoder_type', 'regression') + self.set_attr(cfg, 'pos_enc', True) + self.set_attr(cfg, 'calc_right', True) + + self.match_module = TransformerNet(cfg, decoder_type=self.decoder_type) + if cfg_has(cfg, 'fix_layers', True): + self.match_module.fix_layers() + self.set_attr(cfg, 'preprocess', False) + + def forward(self, target, context, device, cam): + """Network forward pass""" + + bs, _, h, w = target.size() + + downsample = 4 + col_offset = int(downsample / 2) + row_offset = int(downsample / 2) + sampled_cols = torch.arange(col_offset, w, downsample)[None,].expand(bs, -1).to(device) + sampled_rows = torch.arange(row_offset, h, downsample)[None,].expand(bs, -1).to(device) + + return self.match_module(target, context, sampled_rows, sampled_cols, cam) diff --git a/vidar/core/__init__.py b/vidar/core/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/core/checkpoint.py b/vidar/core/checkpoint.py new file mode 100755 index 0000000000000000000000000000000000000000..5ed6eba1473283a7e35006f33aa77e474a34f41d --- /dev/null +++ b/vidar/core/checkpoint.py @@ -0,0 +1,251 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
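+
+# Sketch of the sampling grid built in MatchModule.forward above (illustrative only; the
+# 16x32 input resolution is an assumption):
+#
+#     >>> import torch
+#     >>> h, w, downsample = 16, 32, 4
+#     >>> offset = downsample // 2
+#     >>> torch.arange(offset, w, downsample)   # sampled_cols, w // downsample entries
+#     tensor([ 2,  6, 10, 14, 18, 22, 26, 30])
+#     >>> torch.arange(offset, h, downsample)   # sampled_rows, h // downsample entries
+#     tensor([ 2,  6, 10, 14])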
+ +import os +from datetime import datetime + +import numpy as np +import torch + +from vidar.utils.config import cfg_has +from vidar.utils.logging import pcolor + + +class ModelCheckpoint: + """ + Class for model checkpointing + + Parameters + ---------- + cfg : Config + Configuration with parameters + verbose : Bool + Print information on screen if enabled + """ + def __init__(self, cfg, verbose=False): + super().__init__() + + # Create checkpoint folder + self.folder = cfg_has(cfg, 'folder', None) + self.name = cfg_has(cfg, 'name', datetime.now().strftime("%Y-%m-%d_%Hh%Mm%Ss")) + if self.folder: + self.path = os.path.join(self.folder, self.name) + os.makedirs(self.path, exist_ok=True) + else: + self.path = None + + # Exclude folders + self.excludes = ['sandbox'] + + # If there is no folder, only track metrics + self.tracking_only = self.path is None + + # Store arguments + self.keep_top = cfg_has(cfg, 'keep_top', -1) + self.dataset = cfg_has(cfg, 'dataset', []) + self.monitor = cfg_has(cfg, 'monitor', []) + self.mode = cfg_has(cfg, 'mode', []) + + # Number of metrics to track + self.num_tracking = len(self.mode) + + # Prepare s3 bucket + if cfg_has(cfg, 's3_bucket'): + self.s3_path = f's3://{cfg.s3_bucket}/{self.name}' + self.s3_url = f'https://s3.console.aws.amazon.com/s3/buckets/{self.s3_path[5:]}' + else: + self.s3_path = self.s3_url = None + + # Get starting information + self.torch_inf = torch.tensor(np.Inf) + mode_dict = { + 'min': (self.torch_inf, 'min'), + 'max': (-self.torch_inf, 'max'), + 'auto': (-self.torch_inf, 'max') if \ + 'acc' in self.monitor or \ + 'a1' in self.monitor or \ + 'fmeasure' in self.monitor \ + else (self.torch_inf, 'min'), + } + + if self.mode: + self.top = [[] for _ in self.mode] + self.store_val = [[] for _ in self.mode] + self.previous = [0 for _ in self.mode] + self.best = [mode_dict[m][0] for m in self.mode] + self.mode = [mode_dict[m][1] for m in self.mode] + else: + self.top = [] + + # Print if requested + if verbose: + self.print() + + # Save if requested + if cfg_has(cfg, 'save_code', False): + self.save_code() + if self.s3_url: + self.sync_s3(verbose=False) + + def print(self): + """Print information on screen""" + font_base = {'color': 'red', 'attrs': ('bold', 'dark')} + font_name = {'color': 'red', 'attrs': ('bold',)} + font_underline = {'color': 'red', 'attrs': ('underline',)} + + print(pcolor('#' * 60, **font_base)) + if self.path: + print(pcolor('### Checkpoint: ', **font_base) + \ + pcolor('{}/{}'.format(self.folder, self.name), **font_name)) + if self.s3_url: + print(pcolor('### ', **font_base) + \ + pcolor('{}'.format(self.s3_url), **font_underline)) + else: + print(pcolor('### Checkpoint: ', **font_base) + \ + pcolor('Tracking only', **font_name)) + print(pcolor('#' * 60, **font_base)) + + @staticmethod + def save_model(wrapper, name, epoch): + """Save model""" + torch.save({ + 'config': wrapper.cfg, 'epoch': epoch, + 'state_dict': wrapper.arch.state_dict(), + }, name) + + @staticmethod + def del_model(name): + """Delete model""" + if os.path.isfile(name): + os.remove(name) + + def save_code(self): + """Save code in the models folder""" + excludes = ' '.join([f'--exclude {exclude}' for exclude in self.excludes]) + os.system(f"tar cfz {self.path}/{self.name}.tar.gz {excludes} *") + + def sync_s3(self, verbose=True): + """Sync saved models with the s3 bucket""" + + font_base = {'color': 'magenta', 'attrs': ('bold', 'dark')} + font_name = {'color': 'magenta', 'attrs': ('bold',)} + + if verbose: + print(pcolor('Syncing ', **font_base) + + 
pcolor('{}'.format(self.path), **font_name) + + pcolor(' -> ', **font_base) + + pcolor('{}'.format(self.s3_path), **font_name)) + + command = f'aws s3 sync {self.path} {self.s3_path} ' \ + f'--acl bucket-owner-full-control --quiet --delete' + os.system(command) + + def print_improvements(self, key, value, idx, is_best): + """Print color-coded changes in tracked metrics""" + + font1 = {'color': 'cyan', 'attrs':('dark', 'bold')} + font2 = {'color': 'cyan', 'attrs': ('bold',)} + font3 = {'color': 'yellow', 'attrs': ('bold',)} + font4 = {'color': 'green', 'attrs': ('bold',)} + font5 = {'color': 'red', 'attrs': ('bold',)} + + current_inf = self.best[idx] == self.torch_inf or \ + self.best[idx] == -self.torch_inf + + print( + pcolor(f'{key}', **font2) + \ + pcolor(f' ({self.mode[idx]}) : ', **font1) + \ + ('' if current_inf else + pcolor('%3.6f' % self.previous[idx], **font3) + + pcolor(f' -> ', **font1)) + \ + (pcolor('%3.6f' % value, **font4) if is_best else + pcolor('%3.6f' % value, **font5)) + + ('' if current_inf else + pcolor(' (%3.6f)' % self.best[idx], **font2)) + ) + + def save(self, wrapper, epoch, verbose=True): + """Save model""" + # Do nothing if no path is provided + if self.path: + + name = '%03d.ckpt' % epoch + folder = os.path.join(self.path, 'models') + + os.makedirs(folder, exist_ok=True) + folder_name = os.path.join(folder, name) + self.save_model(wrapper, folder_name, epoch) + self.top.append(folder_name) + if 0 < self.keep_top < len(self.top): + self.del_model(self.top.pop(0)) + if self.s3_url: + self.sync_s3(verbose=False) + + if verbose: + print() + + def check_and_save(self, wrapper, metrics, prefixes, epoch, verbose=True): + """Check if model should be saved and maybe save it""" + # Not tracking any metric, save every iteration + if self.num_tracking == 0: + # Do nothing if no path is provided + if self.path: + + name = '%03d.ckpt' % epoch + folder = os.path.join(self.path, 'models') + + os.makedirs(folder, exist_ok=True) + folder_name = os.path.join(folder, name) + self.save_model(wrapper, folder_name, epoch) + self.top.append(folder_name) + if 0 < self.keep_top < len(self.top): + self.del_model(self.top.pop(0)) + if self.s3_url: + self.sync_s3(verbose=False) + + # Check if saving for every metric + else: + + for idx in range(self.num_tracking): + + key = '{}-{}'.format(prefixes[self.dataset[idx]], self.monitor[idx]) + value = metrics[key] + + if self.mode[idx] == 'min': + is_best = value < self.best[idx] + will_store = len(self.store_val[idx]) < self.keep_top or \ + value < np.max(self.store_val[idx]) + store_idx = 0 if len(self.store_val[idx]) == 0 else int(np.argmax(self.store_val[idx])) + else: + is_best = value > self.best[idx] + will_store = len(self.store_val[idx]) < self.keep_top or \ + value > np.min(self.store_val[idx]) + store_idx = 0 if len(self.store_val[idx]) == 0 else int(np.argmin(self.store_val[idx])) + + if verbose: + self.print_improvements(key, value, idx, is_best) + + self.previous[idx] = value + + if is_best: + self.best[idx] = value + + if is_best or will_store: + + if self.path: + + name = '%03d_%3.6f.ckpt' % (epoch, value) + folder = os.path.join(self.path, key) + + os.makedirs(folder, exist_ok=True) + folder_name = os.path.join(folder, name) + self.save_model(wrapper, folder_name, epoch) + self.top[idx].append(folder_name) + self.store_val[idx].append(value) + if 0 < self.keep_top < len(self.top[idx]): + self.del_model(self.top[idx].pop(store_idx)) + self.store_val[idx].pop(store_idx) + if self.s3_url: + self.sync_s3(verbose=False) + + if 
verbose: + print() diff --git a/vidar/core/logger.py b/vidar/core/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..98d2ea88ecd4392f3e2086a52ce2125cabc677c8 --- /dev/null +++ b/vidar/core/logger.py @@ -0,0 +1,311 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import wandb + +from vidar.utils.config import cfg_has +from vidar.utils.distributed import world_size +from vidar.utils.logging import pcolor +from vidar.utils.types import is_dict, is_tensor, is_seq, is_namespace +from vidar.utils.viz import viz_depth, viz_inv_depth, viz_normals, viz_optical_flow, viz_camera + + +class WandbLogger: + """ + Wandb logger class to monitor training + + Parameters + ---------- + cfg : Config + Configuration with parameters + verbose : Bool + Print information on screen if enabled + """ + def __init__(self, cfg, verbose=False): + super().__init__() + + self.num_logs = { + 'train': cfg_has(cfg, 'num_train_logs', 0), + 'val': cfg_has(cfg, 'num_validation_logs', 0), + 'test': cfg_has(cfg, 'num_test_logs', 0), + } + + self._name = cfg.name if cfg_has(cfg, 'name') else None + self._dir = cfg.folder + self._entity = cfg.entity + self._project = cfg.project + + self._tags = cfg_has(cfg, 'tags', '') + self._notes = cfg_has(cfg, 'notes', '') + + self._id = None + self._anonymous = None + self._log_model = True + + self._experiment = self._create_experiment() + self._metrics = OrderedDict() + + self.only_first = cfg_has(cfg, 'only_first', False) + + cfg.name = self.run_name + cfg.url = self.run_url + + if verbose: + self.print() + + @staticmethod + def finish(): + """Finish wandb session""" + wandb.finish() + + def print(self): + """Print information on screen""" + + font_base = {'color': 'red', 'attrs': ('bold', 'dark')} + font_name = {'color': 'red', 'attrs': ('bold',)} + font_underline = {'color': 'red', 'attrs': ('underline',)} + + print(pcolor('#' * 60, **font_base)) + print(pcolor('### WandB: ', **font_base) + \ + pcolor('{}'.format(self.run_name), **font_name)) + print(pcolor('### ', **font_base) + \ + pcolor('{}'.format(self.run_url), **font_underline)) + print(pcolor('#' * 60, **font_base)) + + def __getstate__(self): + """Get the current logger state""" + state = self.__dict__.copy() + state['_id'] = self._experiment.id if self._experiment is not None else None + state['_experiment'] = None + return state + + def _create_experiment(self): + """Creates and returns a new experiment""" + experiment = wandb.init( + name=self._name, dir=self._dir, project=self._project, + anonymous=self._anonymous, reinit=True, id=self._id, notes=self._notes, + resume='allow', tags=self._tags, entity=self._entity + ) + wandb.run.save() + return experiment + + def watch(self, model: nn.Module, log='gradients', log_freq=100): + """Watch training parameters""" + self.experiment.watch(model, log=log, log_freq=log_freq) + + @property + def experiment(self): + """Returns the experiment (creates a new if it doesn't exist)""" + if self._experiment is None: + self._experiment = self._create_experiment() + return self._experiment + + @property + def run_name(self): + """Returns run name""" + return wandb.run.name if self._experiment else None + + @property + def run_url(self): + """Returns run URL""" + return f'https://app.wandb.ai/' \ + f'{wandb.run.entity}/' \ + f'{wandb.run.project}/runs/' \ + f'{wandb.run.id}' if self._experiment else None + + def log_config(self, 
cfg): + """Log model configuration""" + cfg = recursive_convert_config(deepcopy(cfg)) + self.experiment.config.update(cfg, allow_val_change=True) + + def log_metrics(self, metrics): + """Log training metrics""" + self._metrics.update(metrics) + if 'epochs' in metrics or 'samples' in metrics: + self.experiment.log(self._metrics) + self._metrics.clear() + + def log_images(self, batch, output, prefix, ontology=None): + """ + Log images depending on its nature + + Parameters + ---------- + batch : Dict + Dictionary containing batch information + output : Dict + Dictionary containing output information + prefix : String + Prefix string for the log name + ontology : Dict + Dictionary with ontology information + """ + for data, suffix in zip([batch, output['predictions']], ['-gt', '-pred']): + for key in data.keys(): + if key.startswith('rgb'): + self._metrics.update(log_rgb( + key, prefix + suffix, data, only_first=self.only_first)) + elif key.startswith('depth'): + self._metrics.update(log_depth( + key, prefix + suffix, data, only_first=self.only_first)) + elif key.startswith('inv_depth'): + self._metrics.update(log_inv_depth( + key, prefix + suffix, data, only_first=self.only_first)) + elif 'normals' in key: + self._metrics.update(log_normals( + key, prefix + suffix, data, only_first=self.only_first)) + elif key.startswith('stddev'): + self._metrics.update(log_stddev( + key, prefix + suffix, data, only_first=self.only_first)) + elif key.startswith('logvar'): + self._metrics.update(log_logvar( + key, prefix + suffix, data, only_first=self.only_first)) + elif 'optical_flow' in key: + self._metrics.update(log_optical_flow( + key, prefix + suffix, data, only_first=self.only_first)) + elif 'mask' in key or 'valid' in key: + self._metrics.update(log_rgb( + key, prefix, data, only_first=self.only_first)) + # elif 'camera' in key: + # self._metrics.update(log_camera( + # key, prefix + suffix, data, only_first=self.only_first)) + # elif 'uncertainty' in key: + # self._metrics.update(log_uncertainty(key, prefix, data)) + # elif 'semantic' in key and ontology is not None: + # self._metrics.update(log_semantic(key, prefix, data, ontology=ontology)) + # if 'scene_flow' in key: + # self._metrics.update(log_scene_flow(key, prefix_idx, data)) + # elif 'score' in key: + # # Log score as image heatmap + # self._metrics.update(log_keypoint_score(key, prefix, data)) + + def log_data(self, mode, batch, output, dataset, prefix, ontology=None): + """Helper function used to log images""" + idx = batch['idx'][0] + num_logs = self.num_logs[mode] + if num_logs > 0: + interval = (len(dataset) // world_size() // num_logs) * world_size() + if interval == 0 or (idx % interval == 0 and idx < interval * num_logs): + prefix = '{}-{}-{}'.format(mode, prefix, batch['idx'][0].item()) + # batch, output = prepare_logging(batch, output) + self.log_images(batch, output, prefix, ontology=ontology) + + +def recursive_convert_config(cfg): + """Convert configuration to dictionary recursively""" + cfg = cfg.__dict__ + for key, val in cfg.items(): + if is_namespace(val): + cfg[key] = recursive_convert_config(val) + return cfg + + +def prep_image(key, prefix, image): + """Prepare image for logging""" + if is_tensor(image): + if image.dim() == 2: + image = image.unsqueeze(0) + if image.dim() == 4: + image = image[0] + image = image.detach().permute(1, 2, 0).cpu().numpy() + prefix_key = '{}-{}'.format(prefix, key) + return {prefix_key: wandb.Image(image, caption=key)} + + +def log_sequence(key, prefix, data, i, only_first, fn): + """Logs a 
sequence of images (list, tuple or dict)""" + log = {} + if is_dict(data): + for ctx, dict_val in data.items(): + if is_seq(dict_val): + if only_first: + dict_val = dict_val[:1] + for idx, list_val in enumerate(dict_val): + if list_val.dim() == 5: + for j in range(list_val.shape[1]): + log.update(fn('%s(%s_%d)_%d' % (key, str(ctx), j, idx), prefix, list_val[:, j], i)) + else: + log.update(fn('%s(%s)_%d' % (key, str(ctx), idx), prefix, list_val, i)) + else: + if dict_val.dim() == 5: + for j in range(dict_val.shape[1]): + log.update(fn('%s(%s_%d)' % (key, str(ctx), j), prefix, dict_val[:, j], i)) + else: + log.update(fn('%s(%s)' % (key, str(ctx)), prefix, dict_val, i)) + elif is_seq(data): + if only_first: + data = data[:1] + for idx, list_val in enumerate(data): + log.update(fn('%s_%d' % (key, idx), prefix, list_val, i)) + else: + log.update(fn('%s' % key, prefix, data, i)) + return log + + +def log_rgb(key, prefix, batch, i=0, only_first=None): + """Log RGB image""" + rgb = batch[key] if is_dict(batch) else batch + if is_seq(rgb) or is_dict(rgb): + return log_sequence(key, prefix, rgb, i, only_first, log_rgb) + return prep_image(key, prefix, rgb[i].clamp(min=0.0, max=1.0)) + + +def log_depth(key, prefix, batch, i=0, only_first=None): + """Log depth map""" + depth = batch[key] if is_dict(batch) else batch + if is_seq(depth) or is_dict(depth): + return log_sequence(key, prefix, depth, i, only_first, log_depth) + return prep_image(key, prefix, viz_depth(depth[i], filter_zeros=True)) + + +def log_inv_depth(key, prefix, batch, i=0, only_first=None): + """Log inverse depth map""" + inv_depth = batch[key] if is_dict(batch) else batch + if is_seq(inv_depth) or is_dict(inv_depth): + return log_sequence(key, prefix, inv_depth, i, only_first, log_inv_depth) + return prep_image(key, prefix, viz_inv_depth(inv_depth[i])) + + +def log_normals(key, prefix, batch, i=0, only_first=None): + """Log normals""" + normals = batch[key] if is_dict(batch) else batch + if is_seq(normals) or is_dict(normals): + return log_sequence(key, prefix, normals, i, only_first, log_normals) + return prep_image(key, prefix, viz_normals(normals[i])) + + +def log_optical_flow(key, prefix, batch, i=0, only_first=None): + """Log optical flow""" + optical_flow = batch[key] if is_dict(batch) else batch + if is_seq(optical_flow) or is_dict(optical_flow): + return log_sequence(key, prefix, optical_flow, i, only_first, log_optical_flow) + return prep_image(key, prefix, viz_optical_flow(optical_flow[i])) + + +def log_stddev(key, prefix, batch, i=0, only_first=None): + """Log standard deviation""" + stddev = batch[key] if is_dict(batch) else batch + if is_seq(stddev) or is_dict(stddev): + return log_sequence(key, prefix, stddev, i, only_first, log_stddev) + return prep_image(key, prefix, viz_inv_depth(stddev[i], colormap='jet')) + + +def log_logvar(key, prefix, batch, i=0, only_first=None): + """Log standard deviation""" + logvar = batch[key] if is_dict(batch) else batch + if is_seq(logvar) or is_dict(logvar): + return log_sequence(key, prefix, logvar, i, only_first, log_logvar) + return prep_image(key, prefix, viz_inv_depth(torch.exp(logvar[i]), colormap='jet')) + + +def log_camera(key, prefix, batch, i=0, only_first=None): + """Log camera""" + camera = batch[key] if is_dict(batch) else batch + if is_seq(camera) or is_dict(camera): + return log_sequence(key, prefix, camera, i, only_first, log_camera) + return prep_image(key, prefix, viz_camera(camera[i])) + diff --git a/vidar/core/saver.py b/vidar/core/saver.py new file mode 100644 index 
0000000000000000000000000000000000000000..e0e0506a8496ea70d38ad7346daeb5da1e95ac1c --- /dev/null +++ b/vidar/core/saver.py @@ -0,0 +1,214 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os + +from vidar.utils.config import cfg_has +from vidar.utils.data import make_list +from vidar.utils.types import is_dict, is_list +from vidar.utils.viz import viz_depth, viz_optical_flow +from vidar.utils.write import write_depth, write_image, write_pickle, write_npz + + +class Saver: + """ + Wandb logger class to monitor training + + Parameters + ---------- + cfg : Config + Configuration with parameters + ckpt : String + Name of the model checkpoint (used to create the save folder) + """ + def __init__(self, cfg, ckpt=None): + self.folder = cfg_has(cfg, 'folder', None) + + self.rgb = make_list(cfg.rgb) if cfg_has(cfg, 'rgb') else [] + self.depth = make_list(cfg.depth) if cfg_has(cfg, 'depth') else [] + self.pose = make_list(cfg.pose) if cfg_has(cfg, 'pose') else [] + self.optical_flow = make_list(cfg.optical_flow) if cfg_has(cfg, 'optical_flow') else [] + + self.store_data = cfg_has(cfg, 'store_data', False) + self.separate = cfg.has('separate', False) + + self.ckpt = None if ckpt is None else \ + os.path.splitext(os.path.basename(ckpt))[0] + + self.naming = cfg_has(cfg, 'naming', 'filename') + assert self.naming in ['filename', 'splitname'], \ + 'Invalid naming for saver: {}'.format(self.naming) + + def get_filename(self, path, batch, idx, i): + """Get filename based on input information""" + if self.naming == 'filename': + filename = os.path.join(path, batch['filename'][0][i]).replace('{}', 'rgb') + os.makedirs(os.path.dirname(filename), exist_ok=True) + return filename + elif self.naming == 'splitname': + if self.separate: + return os.path.join(path, '%010d' % idx, '%010d' % idx) + else: + return os.path.join(path, '%010d' % idx) + else: + raise NotImplementedError('Invalid naming for saver: {}'.format(self.naming)) + + def save_data(self, batch, output, prefix): + """ + Prepare for data saving + + Parameters + ---------- + batch : Dict + Dictionary with batch information + output : Dict + Dictionary with output information + prefix : String + Prefix string for the log name + """ + if self.folder is None: + return + + idx = batch['idx'] + predictions = output['predictions'] + + path = os.path.join(self.folder, prefix) + if self.ckpt is not None: + path = os.path.join(path, self.ckpt) + os.makedirs(path, exist_ok=True) + + self.save(batch, predictions, path, idx, 0) + + def save(self, batch, predictions, path, idx, i): + """ + Save batch and prediction information + + Parameters + ---------- + batch : Dict + Dictionary with batch information + predictions : Dict + Dictionary with output predictions + path : String + Path where data will be saved + idx : Int + Batch index in the split + i : Int + Index within batch + + Returns + ------- + data : Dict + Dictionary with output data that was saved + """ + + filename = self.get_filename(path, batch, idx, i) + + raw_intrinsics = batch['raw_intrinsics'][0][i].cpu() if 'raw_intrinsics' in batch else \ + batch['intrinsics'][0][i].cpu() if 'intrinsics' in batch else None + intrinsics = batch['intrinsics'][0][i].cpu() if 'intrinsics' in batch else None + + data = { + 'raw_intrinsics': raw_intrinsics, + 'intrinsics': intrinsics, + } + + for key in batch.keys(): + + if key.startswith('rgb'): + data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()} + for ctx in batch[key].keys(): + rgb = 
batch[key][ctx][i].cpu() + if 'gt' in self.rgb: + if rgb.dim() == 5: + for j in range(rgb.shape[1]): + write_image('%s_%s(%d_%d)_gt.png' % (filename, key, j, ctx), + rgb[:, j]) + else: + write_image('%s_%s(%d)_gt.png' % (filename, key, ctx), + rgb) + + if key.startswith('depth'): + data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()} + for ctx in batch[key].keys(): + depth = batch[key][ctx][i].cpu() + if 'gt_png' in self.depth: + write_depth('%s_%s(%d)_gt.png' % (filename, key, ctx), + depth) + if 'gt_npz' in self.depth: + write_depth('%s_%s(%d)_gt.npz' % (filename, key, ctx), + depth, intrinsics=raw_intrinsics) + if 'gt_viz' in self.depth: + write_image('%s_%s(%d)_gt_viz.png' % (filename, key, ctx), + viz_depth(depth, filter_zeros=True)) + + if key.startswith('pose'): + pose = {k: v[i].cpu() for k, v in batch[key].items()} + data[key + '_gt'] = pose + if 'gt' in self.pose: + write_pickle('%s_%s_gt' % (filename, key), + pose) + + for key in predictions.keys(): + + if key.startswith('rgb'): + data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()} + for ctx in predictions[key].keys(): + rgb = predictions[key][ctx][i].cpu() + if 'pred' in self.rgb: + if rgb.dim() == 5: + for j in range(rgb.shape[1]): + write_image('%s_%s(%d_%d)_pred.png' % (filename, key, j, ctx), + rgb[:, j]) + else: + write_image('%s_%s(%d)_pred.png' % (filename, key, ctx), + rgb) + + if key.startswith('depth'): + data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()} + for ctx in predictions[key].keys(): + depth = predictions[key][ctx][0][i].cpu() + if 'png' in self.depth: + write_depth('%s_%s(%d)_pred.png' % (filename, key, ctx), + depth) + if 'npz' in self.depth: + write_depth('%s_%s(%d)_pred.npz' % (filename, key, ctx), + depth, intrinsics=intrinsics) + if 'viz' in self.depth: + write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), + viz_depth(depth)) + + if key.startswith('pose'): + pose = {key: val[i].cpu() for key, val in predictions[key].items()} + data[key + '_pred'] = pose + if 'pred' in self.pose: + write_pickle('%s_%s_pred' % (filename, key), + pose) + + if key.startswith('fwd_optical_flow'): + optical_flow = {key: val[i].cpu() for key, val in predictions[key].items()} + data[key + '_pred'] = optical_flow + if 'npz' in self.optical_flow: + write_npz('%s_%s_pred' % (filename, key), + {'fwd_optical_flow': optical_flow}) + if 'viz' in self.optical_flow: + for ctx in optical_flow.keys(): + write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), + viz_optical_flow(optical_flow[ctx])) + + if key.startswith('mask'): + if is_dict(predictions[key]): + data[key] = {k: v[i].cpu() for k, v in predictions[key].items()} + for ctx in data[key].keys(): + write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), predictions[key][ctx][0]) + elif is_list(predictions[key]): + data[key] = [v[i].cpu() for k, v in predictions[key]] + for ctx in data[key]: + write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), predictions[key][ctx][0]) + else: + data[key] = predictions[key][i].cpu() + write_image('%s_%s_pred_viz.png' % (filename, key), predictions[key][0]) + + if self.store_data: + write_pickle('%s' % filename, data) + + return data diff --git a/vidar/core/trainer.py b/vidar/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcbe596068dbe1d349760f2c67508e46882f297 --- /dev/null +++ b/vidar/core/trainer.py @@ -0,0 +1,472 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
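+
+# Sketch of the file naming used by Saver.save above with naming == 'splitname' (illustrative
+# only; the save path, sample index, key and context values are assumptions):
+#
+#     >>> import os
+#     >>> path, idx, key, ctx = '/results/val_prefix/model_010', 42, 'depth', 0
+#     >>> filename = os.path.join(path, '%010d' % idx)
+#     >>> '%s_%s(%d)_pred.png' % (filename, key, ctx)
+#     '/results/val_prefix/model_010/0000000042_depth(0)_pred.png'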
+ +from collections import OrderedDict + +import torch +from tqdm import tqdm + +from vidar.core.checkpoint import ModelCheckpoint +from vidar.core.logger import WandbLogger +from vidar.core.saver import Saver +from vidar.utils.config import cfg_has, dataset_prefix +from vidar.utils.data import make_list, keys_in +from vidar.utils.distributed import on_rank_0, rank, world_size, print0, dist_mode +from vidar.utils.logging import pcolor, AvgMeter +from vidar.utils.setup import setup_dataloader, reduce +from vidar.utils.types import is_dict, is_seq, is_numpy, is_tensor, is_list + + +def sample_to_cuda(sample, proc_rank, dtype=None): + """ + Copy sample to GPU + + Parameters + ---------- + sample : Dict + Dictionary with sample information + proc_rank : Int + Process rank + dtype : torch.Type + Data type for conversion + + Returns + ------- + sample : Dict + Dictionary with sample on the GPU + """ + # Do nothing if cuda is not available + if not torch.cuda.is_available(): + return sample + # If it's a sequence (list or tuple) + if is_seq(sample): + return [sample_to_cuda(val, proc_rank, dtype) for val in sample] + # If it's a dictionary + elif is_dict(sample): + return {key: sample_to_cuda(sample[key], proc_rank, dtype) for key in sample.keys()} + # If it's a torch tensor + elif is_tensor(sample): + dtype = dtype if torch.is_floating_point(sample) else None + return sample.to(f'cuda:{proc_rank}', dtype=dtype) + # If it's a numpy array + elif is_numpy(sample): + tensor_data = torch.Tensor(sample) + dtype = dtype if torch.is_floating_point(tensor_data) else None + return tensor_data.to(f'cuda:{proc_rank}', dtype=dtype) + # Otherwise, do nothing + else: + return sample + + +class Trainer: + """ + Trainer class for model optimization and inference + + Parameters + ---------- + cfg : Config + Configuration with parameters + ckpt : String + Name of the model checkpoint to start from + """ + def __init__(self, cfg, ckpt=None): + super().__init__() + + self.avg_losses = {} + self.min_epochs = cfg_has(cfg.wrapper, 'min_epochs', 0) + self.max_epochs = cfg_has(cfg.wrapper, 'max_epochs', 100) + + self.validate_first = cfg_has(cfg.wrapper, 'validate_first', False) + self.find_unused_parameters = cfg_has(cfg.wrapper, 'find_unused_parameters', False) + self.grad_scaler = cfg_has(cfg.wrapper, 'grad_scaler', False) and torch.cuda.is_available() + + self.saver = self.logger = self.checkpoint = None + self.prep_logger_and_checkpoint(cfg) + + self.prep_saver(cfg, ckpt) + + self.all_modes = ['train', 'mixed', 'validation', 'test'] + self.train_modes = ['train', 'mixed'] + + self.current_epoch = 0 + self.training_bar_metrics = cfg_has(cfg.wrapper, 'training_bar_metrics', []) + + @property + def progress(self): + """Current epoch progress (percentage)""" + return self.current_epoch / self.max_epochs + + @property + def proc_rank(self): + """Process rank""" + return rank() + + @property + def world_size(self): + """World size""" + return world_size() + + @property + def is_rank_0(self): + """True if worker is on rank 0""" + return self.proc_rank == 0 + + def param_logs(self, optimizers): + """Returns various logs for tracking""" + params = OrderedDict() + for key, val in optimizers.items(): + params[f'{key}_learning_rate'] = val['optimizer'].param_groups[0]['lr'] + params[f'{key}_weight_decay'] = val['optimizer'].param_groups[0]['weight_decay'] + params['progress'] = self.progress + return { + **params, + } + + @on_rank_0 + def prep_logger_and_checkpoint(self, cfg): + """Prepare logger and checkpoint class if 
requested""" + add_logger = cfg_has(cfg, 'wandb') + add_checkpoint = cfg_has(cfg, 'checkpoint') + + if add_logger: + self.logger = WandbLogger(cfg.wandb, verbose=True) + if add_checkpoint and not cfg_has(cfg.checkpoint, 'name'): + cfg.checkpoint.name = self.logger.run_name + else: + self.logger = None + + if add_checkpoint: + self.checkpoint = ModelCheckpoint(cfg.checkpoint, verbose=True) + else: + self.checkpoint = None + + if add_logger: + self.logger.log_config(cfg) + + def prep_saver(self, cfg, ckpt=None): + """Prepare saver class if requested""" + ckpt = ckpt if ckpt is not None else cfg.arch.model.has('checkpoint', None) + add_saver = cfg_has(cfg, 'save') + + if add_saver: + print0(pcolor('#' * 60, color='red', attrs=('dark',))) + print0(pcolor('### Saving data to: %s' % cfg.save.folder, color='red')) + print0(pcolor('#' * 60, color='red', attrs=('dark',))) + self.saver = Saver(cfg.save, ckpt) + + @on_rank_0 + def check_and_save(self, wrapper, output, prefixes): + """Check for conditions and save if it's time""" + if self.checkpoint is not None: + self.checkpoint.check_and_save( + wrapper, output, prefixes, epoch=self.current_epoch) + + @on_rank_0 + def log_losses_and_metrics(self, metrics=None, optimizers=None): + """Log losses and metrics on wandb""" + if self.logger is not None: + self.logger.log_metrics({ + '{}'.format(key): val.get() for key, val in self.avg_losses.items() + }) + if optimizers is not None: + self.logger.log_metrics(self.param_logs(optimizers)) + if metrics is not None: + self.logger.log_metrics({ + **metrics, 'epochs': self.current_epoch, + }) + + @on_rank_0 + def print_logger_and_checkpoint(self): + """Print logger and checkpoint information""" + font_base = {'color': 'red', 'attrs': ('bold', 'dark')} + font_name = {'color': 'red', 'attrs': ('bold',)} + font_underline = {'color': 'red', 'attrs': ('underline',)} + + if self.logger or self.checkpoint: + print(pcolor('#' * 120, **font_base)) + if self.logger: + print(pcolor('### WandB: ', **font_base) + \ + pcolor('{}'.format(self.logger.run_name), **font_name) + \ + pcolor(' - ', **font_base) + \ + pcolor('{}'.format(self.logger.run_url), **font_underline)) + if self.checkpoint and self.checkpoint.s3_url is not None: + print(pcolor('### Checkpoint: ', **font_base) + \ + pcolor('{}'.format(self.checkpoint.s3_url), **font_underline)) + if self.logger or self.checkpoint: + print(pcolor('#' * 120 + '\n', **font_base)) + + @on_rank_0 + def update_train_progress_bar(self, progress_bar): + """Update training progress bar on screen""" + string = '| {} | Loss {:.3f}'.format( + self.current_epoch, self.avg_losses['loss'].get()) + bar_keys = self.training_bar_metrics + for key in keys_in(self.avg_losses, bar_keys): + name, abbrv = (key[0], key[1]) if is_list(key) else (key, key) + string += ' | {} {:.2f}'.format(abbrv, self.avg_losses[name].get()) + progress_bar.set_description(string) + + @on_rank_0 + def update_averages(self, output): + """Update loss averages""" + averages = {'loss': output['loss'], **output['metrics']} + for key in averages.keys(): + if key not in self.avg_losses.keys(): + self.avg_losses[key] = AvgMeter(50) + self.avg_losses[key](averages[key].item() if is_tensor(averages[key]) else averages[key]) + + def train_progress_bar(self, dataloader, ncols=None, aux_dataloader=None): + """Print training progress bar on screen""" + full_dataloader = dataloader if aux_dataloader is None else zip(dataloader, aux_dataloader) + return tqdm(enumerate(full_dataloader, 0), + unit='im', unit_scale=self.world_size * 
dataloader.batch_size, + total=len(dataloader), smoothing=0, + disable=not self.is_rank_0, ncols=ncols) + + def val_progress_bar(self, dataloader, prefix, ncols=None): + """Print validation progress bar on screen""" + return tqdm(enumerate(dataloader, 0), + unit='im', unit_scale=self.world_size * dataloader.batch_size, + total=len(dataloader), smoothing=0, + disable=not self.is_rank_0, ncols=ncols, + desc=prefix) + + def prepare_distributed_model(self, wrapper): + """Prepare model for distributed training or not (CPU/GPU/DDP)""" + if dist_mode() == 'cpu': + wrapper.arch = wrapper.arch + elif dist_mode() == 'gpu': + wrapper = wrapper.cuda(self.proc_rank) + wrapper.arch = wrapper.arch + elif dist_mode() == 'ddp': + wrapper = wrapper.cuda(self.proc_rank) + wrapper.arch = torch.nn.parallel.DistributedDataParallel( + wrapper.arch, device_ids=[self.proc_rank], + find_unused_parameters=self.find_unused_parameters, + broadcast_buffers=True) + else: + raise ValueError('Wrong distributed mode {}'.format(dist_mode)) + return wrapper + + def prepare_dataloaders(self, wrapper): + """Prepare dataloaders for training and inference""" + font1 = {'color': 'blue', 'attrs': ('dark', 'bold')} + font2 = {'color': 'blue', 'attrs': ('bold',)} + + print0(pcolor('#' * 60, **font1)) + + if dist_mode() == 'cpu': + print0(pcolor(f'### ', **font1) + + pcolor(f'CPU Training', **font2)) + elif dist_mode() == 'gpu': + print0(pcolor(f'### ', **font1) + + pcolor(f'GPU Training', **font2)) + elif dist_mode() == 'ddp': + print0(pcolor(f'### ', **font1) + + pcolor(f'DDP Training ', **font2) + + pcolor(f'with ', **font1) + + pcolor(f'{self.world_size}', **font2) + + pcolor(f' GPUs', **font1)) + + # Send wrapper to GPU + wrapper = self.prepare_distributed_model(wrapper) + + for key in wrapper.datasets_cfg.keys(): + wrapper.datasets_cfg[key] = make_list(wrapper.datasets_cfg[key]) + + # Prepare dataloaders + dataloaders = { + key: setup_dataloader(val, wrapper.datasets_cfg[key][0].dataloader, key) + for key, val in wrapper.datasets.items() if key in wrapper.datasets_cfg.keys() + } + + # Prepare prefixes + + prefixes = { + key: [dataset_prefix(wrapper.datasets_cfg[key][n], n) for n in range(len(val))] + for key, val in wrapper.datasets_cfg.items() if 'name' in wrapper.datasets_cfg[key][0].__dict__.keys() + } + + # Reduce information + reduced_dataloaders = reduce(dataloaders, self.all_modes, self.train_modes) + reduced_prefixes = reduce(prefixes, self.all_modes, self.train_modes) + + print0(pcolor('#' * 60, **font1)) + + return reduced_dataloaders, reduced_prefixes + + def filter_optimizers(self, optimizers): + """Filter optimizers to find those being used at each epoch""" + in_optimizers, out_optimizers = {}, {} + for key, val in optimizers.items(): + if 'stop_epoch' not in val['settings'] or \ + val['settings']['stop_epoch'] >= self.current_epoch: + in_optimizers[key] = val['optimizer'] + else: + out_optimizers[key] = val['optimizer'] + + if rank() == 0: + + string = pcolor('Optimizing: ', color='yellow') + for key, val in in_optimizers.items(): + string += pcolor('{}'.format(key), color='green', attrs=('bold', 'dark')) + string += pcolor(' ({}) '.format(val.param_groups[0]['lr']), + color='green', attrs=('dark',)) + for key, val in out_optimizers.items(): + string += pcolor('{}'.format(key), color='cyan', attrs=('bold', 'dark')) + string += pcolor(' ({}) '.format(val.param_groups[0]['lr']), + color='cyan', attrs=('dark',)) + + print(pcolor('#' * 120, color='yellow', attrs=('dark',))) + print(string) + print(pcolor('#' * 120, 
color='yellow', attrs=('dark',))) + print() + + return in_optimizers, out_optimizers + + def learn(self, wrapper): + """Entry-point class for training a model""" + # Get optimizers and schedulers + optimizers, schedulers = wrapper.configure_optimizers_and_schedulers() + + # Get gradient scaler if requested + scaler = torch.cuda.amp.GradScaler() if self.grad_scaler else None + + # Get learn information + dataloaders, prefixes = self.prepare_dataloaders(wrapper) + aux_dataloader = None if 'mixed' not in dataloaders else dataloaders['mixed'] + + # Check for train and validation dataloaders + has_train_dataloader = 'train' in dataloaders + has_validation_dataloader = 'validation' in dataloaders + + # Validate before training if requested + if self.validate_first and has_validation_dataloader: + validation_output = self.validate('validation', dataloaders, prefixes, wrapper) + self.post_validation(validation_output, optimizers, prefixes['validation'], wrapper) + else: + self.current_epoch += 1 + + # Epoch loop + if has_train_dataloader: + for epoch in range(self.current_epoch, self.max_epochs + 1): + + # Train and log + self.train(dataloaders['train'], optimizers, schedulers, wrapper, scaler=scaler, + aux_dataloader=aux_dataloader) + + # Validate, save and log + if has_validation_dataloader: + validation_output = self.validate('validation', dataloaders, prefixes, wrapper) + self.post_validation(validation_output, optimizers, prefixes['validation'], wrapper) + + # Take a scheduler step + if wrapper.update_schedulers == 'epoch': + for scheduler in schedulers.values(): + scheduler.step() + + # Finish logger if available + if self.logger: + self.logger.finish() + + def train(self, dataloader, optimizers, schedulers, wrapper, scaler=None, aux_dataloader=None): + """Training loop for each epoch""" + # Choose which optimizers to use + in_optimizers, out_optimizers = self.filter_optimizers(optimizers) + + # Set wrapper to train + wrapper.train_custom(in_optimizers, out_optimizers) + + # Shuffle dataloader sampler + if hasattr(dataloader.sampler, "set_epoch"): + dataloader.sampler.set_epoch(self.current_epoch) + + # Shuffle auxiliar dataloader sampler + if aux_dataloader is not None: + if hasattr(aux_dataloader.sampler, "set_epoch"): + aux_dataloader.sampler.set_epoch(self.current_epoch) + + # Prepare progress bar + progress_bar = self.train_progress_bar( + dataloader, aux_dataloader=aux_dataloader, ncols=120) + + # Zero gradients for the first iteration + for optimizer in in_optimizers.values(): + optimizer.zero_grad() + + # Loop through all batches + for i, batch in progress_bar: + + # Send samples to GPU and take a training step + batch = sample_to_cuda(batch, self.proc_rank) + output = wrapper.training_step(batch, epoch=self.current_epoch) + + # Step optimizer + if wrapper.update_schedulers == 'step': + for scheduler in schedulers.values(): + scheduler.step() + + # Backprop through loss + if scaler is None: + output['loss'].backward() + else: + scaler.scale(output['loss']).backward() + + for optimizer in in_optimizers.values(): + if not output['loss'].isnan().any(): + if scaler is None: + optimizer.step() + else: + scaler.step(optimizer) + else: + print('NAN DETECTED!', i, batch['idx']) + optimizer.zero_grad() + if scaler is not None: + scaler.update() + + self.update_averages(output) + self.update_train_progress_bar(progress_bar) + + # Return outputs for epoch end + return wrapper.training_epoch_end() + + @torch.no_grad() + def validate(self, mode, dataloaders, prefixes, wrapper): + """Validation 
loop""" + # Set wrapper to eval + wrapper.eval_custom() + # For all validation datasets + dataset_outputs = [] + for dataset_idx, (dataset, dataloader, prefix) in \ + enumerate(zip(wrapper.datasets[mode], dataloaders[mode], prefixes[mode])): + # Prepare progress bar for that dataset + progress_bar = self.val_progress_bar(dataloader, prefix, ncols=120) + # For all batches + batch_outputs = [] + for batch_idx, batch in progress_bar: + # Send batch to GPU and take a validation step + batch = sample_to_cuda(batch, self.proc_rank) + output, results = wrapper.validation_step(batch, epoch=self.current_epoch) + if 'batch' in output: + batch = output['batch'] + batch_outputs += results + if self.logger: + self.logger.log_data('val', batch, output, dataset, prefix) + if self.saver: + self.saver.save_data(batch, output, prefix) + # Append dataset outputs to list of all outputs + dataset_outputs.append(batch_outputs) + # Get results from validation epoch end + return wrapper.validation_epoch_end(dataset_outputs, prefixes[mode]) + + def post_validation(self, output, optimizers, prefixes, wrapper): + """Post-processing steps for validation""" + self.check_and_save(wrapper, output, prefixes) + self.log_losses_and_metrics(output, optimizers) + self.print_logger_and_checkpoint() + self.current_epoch += 1 + + def test(self, wrapper): + """Test a model by running validation once""" + dataloaders, prefixes = self.prepare_dataloaders(wrapper) + self.validate('validation', dataloaders, prefixes, wrapper) + diff --git a/vidar/core/wrapper.py b/vidar/core/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..953d226d98428bde3f958ef928b9a4dfa84281e1 --- /dev/null +++ b/vidar/core/wrapper.py @@ -0,0 +1,262 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
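+
+# Typical driver for the Trainer defined above and the Wrapper defined below (a minimal
+# sketch; the config path is an assumption and distributed/DDP launch setup is omitted):
+#
+#     >>> from vidar.core.trainer import Trainer
+#     >>> from vidar.core.wrapper import Wrapper
+#     >>> wrapper = Wrapper('configs/my_experiment.yaml', verbose=True)  # hypothetical config file
+#     >>> trainer = Trainer(wrapper.cfg)
+#     >>> trainer.learn(wrapper)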
+ +import os +import random +from abc import ABC +from collections import OrderedDict + +import torch + +from vidar.utils.config import cfg_has, read_config +from vidar.utils.data import set_random_seed +from vidar.utils.distributed import rank, world_size +from vidar.utils.flip import flip_batch, flip_output +from vidar.utils.logging import pcolor, set_debug +from vidar.utils.networks import load_checkpoint, save_checkpoint, freeze_layers_and_norms +from vidar.utils.setup import setup_arch, setup_datasets, setup_metrics +from vidar.utils.types import is_str + + +class Wrapper(torch.nn.Module, ABC): + """ + Trainer class for model optimization and inference + + Parameters + ---------- + cfg : Config + Configuration with parameters + ckpt : String + Name of the model checkpoint to start from + verbose : Bool + Print information on screen if enabled + """ + def __init__(self, cfg, ckpt=None, verbose=False): + super().__init__() + + if verbose and rank() == 0: + font = {'color': 'cyan', 'attrs': ('bold', 'dark')} + print(pcolor('#' * 100, **font)) + print(pcolor('#' * 42 + ' VIDAR WRAPPER ' + '#' * 43, **font)) + print(pcolor('#' * 100, **font)) + + # Get configuration + cfg = read_config(cfg) if is_str(cfg) else cfg + self.cfg = cfg + + # Data augmentations + self.flip_lr_prob = cfg_has(cfg.wrapper, 'flip_lr_prob', 0.0) + self.validate_flipped = cfg_has(cfg.wrapper, 'validate_flipped', False) + + # Set random seed + set_random_seed(cfg.wrapper.seed + rank()) + set_debug(cfg_has(cfg.wrapper, 'debug', False)) + + # Setup architecture, datasets and tasks + self.arch = setup_arch(cfg.arch, checkpoint=ckpt, verbose=verbose) if cfg_has(cfg, 'arch') else None + self.datasets, self.datasets_cfg = setup_datasets( + cfg.datasets, verbose=verbose) if cfg_has(cfg, 'datasets') else (None, None) + self.metrics = setup_metrics(cfg.evaluation) if cfg_has(cfg, 'evaluation') else {} + + sync_batch_norm = cfg_has(cfg.wrapper, 'sync_batch_norm', False) + if sync_batch_norm and os.environ['DIST_MODE'] == 'ddp': + self.arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.arch) + + self.mixed_precision = cfg_has(cfg.wrapper, 'mixed_precision', False) + + self.update_schedulers = None + + def save(self, filename, epoch=None): + """Save checkpoint""" + save_checkpoint(filename, self, epoch=epoch) + + def load(self, checkpoint, strict=True, verbose=False): + """Load checkpoint""" + load_checkpoint(self, checkpoint, strict=strict, verbose=verbose) + + def train_custom(self, in_optimizers, out_optimizers): + """Customized training flag for the model""" + self.arch.train() + for key in in_optimizers.keys(): + arch = self.arch.module if hasattr(self.arch, 'module') else self.arch + freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=False) + for key in out_optimizers.keys(): + arch = self.arch.module if hasattr(self.arch, 'module') else self.arch + freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=True) + + def eval_custom(self): + """Customized evaluation flag for the model""" + self.arch.eval() + + def configure_optimizers_and_schedulers(self): + """Configure depth and pose optimizers and the corresponding scheduler""" + + if not cfg_has(self.cfg, 'optimizers'): + return None, None + + optimizers = OrderedDict() + schedulers = OrderedDict() + + for key, val in self.cfg.optimizers.__dict__.items(): + assert key in self.arch.networks, f'There is no network for optimizer {key}' + optimizers[key] = { + 'optimizer': getattr(torch.optim, val.name)(**{ + 'lr': val.lr, + 'weight_decay': 
cfg_has(val, 'weight_decay', 0.0), + 'params': self.arch.networks[key].parameters(), + }), + 'settings': {} if not cfg_has(val, 'settings') else val.settings.__dict__ + } + if cfg_has(val, 'scheduler'): + if val.scheduler.name == 'CosineAnnealingWarmUpRestarts': + from cosine_annealing_warmup import CosineAnnealingWarmupRestarts + epoch = float(len(self.datasets['train']) / ( + world_size() * self.datasets_cfg['train'].dataloader.batch_size * self.datasets_cfg['train'].repeat[0])) + schedulers[key] = CosineAnnealingWarmupRestarts(**{ + 'optimizer': optimizers[key]['optimizer'], + 'first_cycle_steps': int(val.scheduler.first_cycle_steps * epoch), + 'cycle_mult': val.scheduler.cycle_mult, + 'min_lr': val.scheduler.min_lr, + 'max_lr': val.scheduler.max_lr, + 'warmup_steps': int(val.scheduler.warmup_steps * epoch), + 'gamma': val.scheduler.gamma, + }) + self.update_schedulers = 'step' + elif val.scheduler.name == 'LinearWarmUp': + from externals.huggingface.transformers.src.transformers.optimization import get_linear_schedule_with_warmup + schedulers[key] = get_linear_schedule_with_warmup(**{ + 'optimizer': optimizers[key]['optimizer'], + 'num_warmup_steps': val.scheduler.num_warmup_steps, + 'num_training_steps': val.scheduler.num_training_steps, + }) + self.update_schedulers = 'step' + else: + schedulers[key] = getattr(torch.optim.lr_scheduler, val.scheduler.name)(**{ + 'optimizer': optimizers[key]['optimizer'], + 'step_size': val.scheduler.step_size, + 'gamma': val.scheduler.gamma, + }) + self.update_schedulers = 'epoch' + + # Return optimizer and scheduler + return optimizers, schedulers + + def run_arch(self, batch, epoch, flip, unflip): + """ + Run model on a batch + + Parameters + ---------- + batch : Dict + Dictionary with batch information + epoch : Int + Current epoch + flip : Bool + Batch should be flipped + unflip : Bool + Output should be unflipped + + Returns + ------- + output : Dict + Dictionary with model outputs + """ + batch = flip_batch(batch) if flip else batch + output = self.arch(batch, epoch=epoch) + return flip_output(output) if flip and unflip else output + + def training_step(self, batch, epoch): + """Processes a training batch""" + flip_lr = False if self.flip_lr_prob == 0 else \ + random.random() < self.flip_lr_prob + + if self.mixed_precision: + with torch.cuda.amp.autocast(): + output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False) + else: + output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False) + + losses = {key: val for key, val in output.items() if key.startswith('loss')} + + return { + **losses, + 'metrics': output['metrics'] + } + + def validation_step(self, batch, epoch): + """Processes a validation batch""" + # from vidar.utils.data import break_batch + # batch = break_batch(batch) + + if self.mixed_precision: + with torch.cuda.amp.autocast(): + output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False) + flipped_output = None if not self.validate_flipped else \ + self.run_arch(batch, epoch=epoch, flip=True, unflip=True) + else: + output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False) + flipped_output = None if not self.validate_flipped else \ + self.run_arch(batch, epoch=epoch, flip=True, unflip=True) + + if 'batch' in output: + batch = output['batch'] + + results = self.evaluate(batch, output, flipped_output) + + results = [{ + 'idx': batch['idx'][i], + **{key: val[i] for key, val in results['metrics'].items()} + } for i in range(len(batch['idx']))] + + return output, results + + @staticmethod + def 
training_epoch_end(): + """Finishes a training epoch (do nothing for now)""" + return {} + + def validation_epoch_end(self, output, prefixes): + """Finishes a validation epoch""" + if isinstance(output[0], dict): + output = [output] + + metrics_dict = {} + for task in self.metrics: + metrics_dict.update( + self.metrics[task].reduce( + output, self.datasets['validation'], prefixes)) + + return metrics_dict + + def evaluate(self, batch, output, flipped_output=None): + """ + Evaluate batch to produce predictions and metrics for different tasks + + Parameters + ---------- + batch : Dict + Dictionary with batch information + output : Dict + Dictionary with output information + flipped_output : Dict + Dictionary with flipped output information + + Returns + ------- + results: Dict + Dictionary with evaluation results + """ + # Evaluate different tasks + metrics, predictions = OrderedDict(), OrderedDict() + for task in self.metrics: + task_metrics, task_predictions = \ + self.metrics[task].evaluate(batch, output['predictions'], + flipped_output['predictions'] if flipped_output else None) + metrics.update(task_metrics) + predictions.update(task_predictions) + # Crate results dictionary with metrics and predictions + results = {'metrics': metrics, 'predictions': predictions} + # Return final results + return results + + + diff --git a/vidar/datasets/BaseDataset.py b/vidar/datasets/BaseDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc241194616d62d3810ab72d414dfc70b872b54 --- /dev/null +++ b/vidar/datasets/BaseDataset.py @@ -0,0 +1,173 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os +from abc import ABC + +from torch.utils.data import Dataset + +from vidar.utils.types import is_list + + +class BaseDataset(Dataset, ABC): + """ + Base dataset class + + Parameters + ---------- + path : String + Dataset location + context : Tuple + Temporal context + cameras : Tuple + Camera names + labels : Tuple + Labels to be loaded + labels_context : + Context labels to be loaded + data_transform : Function + Transformations to be applied to sample + ontology : String + Which semantic ontology should be used + return_ontology : Bool + Whether the ontology should be returned + virtual : Bool + Whether the dataset is virtual or not + kwargs : Dict + Additional parameters + """ + def __init__(self, path, context, cameras, labels=(), labels_context=(), + data_transform=None, ontology=None, return_ontology=False, virtual=False, + **kwargs): + super().__init__() + + self.path = path + self.labels = labels + self.labels_context = labels_context + self.cameras = cameras + self.data_transform = data_transform + + self.num_cameras = len(cameras) if is_list(cameras) else cameras + + self.bwd_contexts = [ctx for ctx in context if ctx < 0] + self.fwd_contexts = [ctx for ctx in context if ctx > 0] + + self.bwd_context = 0 if len(context) == 0 else - min(0, min(context)) + self.fwd_context = 0 if len(context) == 0 else max(0, max(context)) + + self.context = [v for v in range(- self.bwd_context, 0)] + \ + [v for v in range(1, self.fwd_context + 1)] + + self.num_context = self.bwd_context + self.fwd_context + self.with_context = self.num_context > 0 + + self.ontology = ontology + self.return_ontology = return_ontology + self.virtual = virtual + + def relative_path(self, filename): + return {key: os.path.splitext(val.replace(self.path + '/', ''))[0] + for key, val in filename.items()} + + # Label properties + + @property + def with_depth(self): + """If 
dataset contains depth""" + return 'depth' in self.labels + + @property + def with_input_depth(self): + """If dataset contains input depth""" + return 'input_depth' in self.labels + + @property + def with_pose(self): + """If dataset contains pose""" + return 'pose' in self.labels + + @property + def with_semantic(self): + """If dataset contains semantic""" + return 'semantic' in self.labels + + @property + def with_instance(self): + """If dataset contains instance""" + return 'instance' in self.labels + + @property + def with_optical_flow(self): + """If dataset contains optical flow""" + return 'optical_flow' in self.labels + + @property + def with_scene_flow(self): + """If dataset contains scene flow""" + return 'scene_flow' in self.labels + + @property + def with_bbox2d(self): + """If dataset contains 2d bounding boxes""" + return 'bbox2d' in self.labels + + @property + def with_bbox3d(self): + """If dataset contains 3d bounding boxes""" + return 'bbox3d' in self.labels + + @property + def with_lidar(self): + """If dataset contains lidar""" + return 'lidar' in self.labels + + @property + def with_radar(self): + """If dataset contains radar""" + return 'radar' in self.labels + + @property + def with_pointcache(self): + """If dataset contains pointcaches""" + return 'pointcache' in self.labels + + # Label context properties + + @property + def with_depth_context(self): + """If dataset contains context depth""" + return 'depth' in self.labels_context + + @property + def with_input_depth_context(self): + """If dataset contains context input depth""" + return 'input_depth' in self.labels_context + + @property + def with_semantic_context(self): + """If dataset contains context semantic""" + return 'semantic' in self.labels_context + + @property + def with_instance_context(self): + """If dataset contains context instance""" + return 'instance' in self.labels_context + + @property + def with_optical_flow_context(self): + """If dataset contains context optical flow""" + return 'optical_flow' in self.labels_context + + @property + def with_scene_flow_context(self): + """If dataset contains context scene flow""" + return 'scene_flow' in self.labels_context + + @property + def with_bbox2d_context(self): + """If dataset contains context 2d bounding boxes""" + return 'bbox2d' in self.labels_context + + @property + def with_bbox3d_context(self): + """If dataset contains context 3d bounding boxes""" + return 'bbox3d' in self.labels_context diff --git a/vidar/datasets/EUROCDataset.py b/vidar/datasets/EUROCDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c70c0151d94448ff11d7bbc9fb1a5a632cf0051e --- /dev/null +++ b/vidar/datasets/EUROCDataset.py @@ -0,0 +1,177 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import re +from collections import defaultdict +import os +from abc import ABC + +from PIL import Image +import numpy as np +from vidar.utils.read import read_image +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.misc import stack_sample + + +def dummy_calibration(image): + w, h = [float(d) for d in image.size] + return np.array([[1000. , 0. , w / 2. - 0.5], + [0. , 1000. , h / 2. - 0.5], + [0. , 0. , 1. 
]]) + + +def get_idx(filename): + return int(re.search(r'\d+', filename).group()) + + +class EUROCDataset(BaseDataset, ABC): + """ + KITTI dataset class + + Parameters + ---------- + split : String + Split file + stride : Tuple + Which context strides to use + spaceholder : String + Space pattern on input images + data_transform : Function + Transformations applied to the sample + """ + def __init__(self, split, strides=(1,), spaceholder='{:19}', **kwargs): + super().__init__(**kwargs) + + self.split = split + self.spaceholder = spaceholder + + self.backward_context = strides[0] + self.forward_context = strides[1] + self.has_context = self.backward_context + self.forward_context > 0 + + self.file_tree = defaultdict(list) + self.read_files(self.path) + + self.files = [] + for k, v in self.file_tree.items(): + file_set = set(self.file_tree[k]) + files = [fname for fname in sorted(v) if self._has_context(fname, file_set)] + self.files.extend([[k, fname] for fname in files]) + + def read_files(self, directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True): + """Read input images""" + files = defaultdict(list) + for entry in os.scandir(directory): + relpath = os.path.relpath(entry.path, directory) + if entry.is_dir(): + d_files = self.read_files(entry.path, ext=ext, skip_empty=skip_empty) + if skip_empty and not len(d_files): + continue + self.file_tree[entry.path] = d_files[entry.path] + elif entry.is_file(): + if ext is None or entry.path.lower().endswith(tuple(ext)): + files[directory].append(relpath) + return files + + def __len__(self): + return len(self.files) + + def _change_idx(self, idx, filename): + """Prepare name strings according to index""" + _, ext = os.path.splitext(os.path.basename(filename)) + return self.spaceholder.format(idx) + ext + + def _has_context(self, filename, file_set): + """Check if image has context""" + context_paths = self._get_context_file_paths(filename, file_set) + return len([f in file_set for f in context_paths]) >= len(self.context) + + def _get_context_file_paths(self, filename, file_set): + """Get file path for contexts""" + fidx = get_idx(filename) + idxs = [-self.backward_context, -self.forward_context, self.backward_context, self.forward_context] + potential_files = [self._change_idx(fidx + i, filename) for i in idxs] + return [fname for fname in potential_files if fname in file_set] + + def _read_rgb_context_files(self, session, filename): + """Read context images""" + file_set = set(self.file_tree[session]) + context_paths = self._get_context_file_paths(filename, file_set) + return [self._read_rgb_file(session, filename) for filename in context_paths] + + def _read_rgb_file(self, session, filename): + """Read target images""" + gray_image = read_image(os.path.join(self.path, session, filename)) + gray_image_np = np.array(gray_image) + rgb_image_np = np.stack([gray_image_np for _ in range(3)], axis=2) + return Image.fromarray(rgb_image_np) + + def _read_npy_depth(self, session, depth_filename): + """Read depth from numpy file""" + depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename) + return np.load(depth_file_path) + + def _read_depth(self, session, depth_filename): + """Get the depth map from a file.""" + return self._read_npy_depth(session, depth_filename) + + def _has_depth(self, session, depth_filename): + """Check if depth map exists""" + depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename) + return os.path.isfile(depth_file_path) + + def __getitem__(self, idx): + """Get 
dataset sample""" + + samples = [] + + session, filename = self.files[idx] + image = self._read_rgb_file(session, filename) + + sample = { + 'idx': idx, + 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]), + 'rgb': {0: image}, + 'intrinsics': {0: dummy_calibration(image)}, + } + + if self.has_context: + image_context = self._read_rgb_context_files(session, filename) + sample['rgb'].update({ + key: val for key, val in zip(self.context, image_context) + }) + + depth_filename = filename.split('.')[0] + 'depth.npy' + if self.with_depth: + if self._has_depth(session, depth_filename): + sample['depth'] = {0: self._read_depth(session, depth_filename)} + + samples.append(sample) + + if self.data_transform: + samples = self.data_transform(samples) + + return stack_sample(samples) + + +if __name__ == "__main__": + + data_dir = '/data/vidar/euroc/euroc_cam/cam0' + euroc_dataset = EUROCDataset(path=data_dir, + strides=[49999872, 50000128], + context=[-1,1], + split='{:19}', + labels=['depth'], + cameras=[[0]], + ) + print(len(euroc_dataset)) + print('\nsample 0:') + print(euroc_dataset[0].keys()) + print(euroc_dataset[0]['filename']) + print(euroc_dataset[0]['rgb']) + print(euroc_dataset[0]['intrinsics']) + + print('\nsample 1:') + print(euroc_dataset[1].keys()) + print(euroc_dataset[1]['filename']) + print(euroc_dataset[1]['rgb']) + print(euroc_dataset[1]['intrinsics']) \ No newline at end of file diff --git a/vidar/datasets/GenericDataset.py b/vidar/datasets/GenericDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b9af533139b2ac5c64ea218fe92b74425c112e8c --- /dev/null +++ b/vidar/datasets/GenericDataset.py @@ -0,0 +1,86 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import numpy as np + +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.FolderTree import FolderTree +from vidar.datasets.utils.misc import stack_sample +from vidar.utils.read import read_image + + +class GenericDataset(BaseDataset): + def __init__(self, tag=None, single_folder=False, split=None, extension='png', **kwargs): + """ + Generic dataset, used to load information from folders + + Parameters + ---------- + tag : String + Dataset tag + single_folder : Bool + Whether the dataset is a single folder + split : String + Dataset split + kwargs : Dict + Additional arguments + """ + super().__init__(**kwargs) + self.tag = 'generic' if tag is None else tag + if split is None or split == '': + split = ('', ) + self.rgb_tree = FolderTree( + self.path, context=self.context, sub_folders=split, + single_folder=single_folder, suffix=f'.{extension}') + + def __len__(self): + """Dataset length""" + return len(self.rgb_tree) + + @staticmethod + def get_intrinsics(rgb): + """Return dummy intrinsics""" + return np.array([[rgb.size[0] / 2., 0., rgb.size[0] / 2.], + [0., rgb.size[1], rgb.size[1] / 2.], + [0., 0., 1.]]) + + def __getitem__(self, idx): + """Get dataset sample""" + samples = [] + + for _ in self.cameras: + + # Filename + filename = self.rgb_tree.get_item(idx) + + # Base sample + sample = { + 'idx': idx, + 'tag': self.tag, + 'filename': self.relative_path(filename), + 'splitname': '%010d' % idx + } + + # Image + sample['rgb'] = read_image(filename) + + # Intrinsics + sample['intrinsics'] = { + 0: self.get_intrinsics((sample['rgb'][0])) + } + + # If with context + if self.with_context: + filename_context = self.rgb_tree.get_context(idx) + sample['rgb'].update(read_image(filename_context)) + + # Stack sample + samples.append(sample) + 
+ # Transform data + if self.data_transform: + samples = self.data_transform(samples) + + # Return stacked sample + return stack_sample(samples) + + diff --git a/vidar/datasets/KITTIDataset.py b/vidar/datasets/KITTIDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5927f4e74bb2e0e4630603a6a0570bcac22a800e --- /dev/null +++ b/vidar/datasets/KITTIDataset.py @@ -0,0 +1,485 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import glob +import os + +import numpy as np + +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.KITTIDataset_utils import \ + pose_from_oxts_packet, read_calib_file, transform_from_rot_trans +from vidar.datasets.utils.misc import \ + invert_pose, stack_sample, make_relative_pose +from vidar.utils.read import read_image + +# Cameras from the stereo pair (left is the origin) +IMAGE_FOLDER = { + 'left': 'image_02', + 'right': 'image_03', +} + + +# Name of different calibration files +CALIB_FILE = { + 'cam2cam': 'calib_cam_to_cam.txt', + 'velo2cam': 'calib_velo_to_cam.txt', + 'imu2velo': 'calib_imu_to_velo.txt', +} + + +PNG_DEPTH_DATASETS = ['groundtruth'] +OXTS_POSE_DATA = 'oxts' + + +def read_npz_depth(file, depth_type): + """Reads a .npz depth map given a certain depth_type""" + depth = np.load(file)[depth_type + '_depth'].astype(np.float32) + return np.expand_dims(depth, axis=2) + + +def read_png_depth(file): + """Reads a .png depth map""" + depth_png = np.array(read_image(file), dtype=int) + assert (np.max(depth_png) > 255), 'Wrong .png depth file' + depth = depth_png.astype(np.float) / 256. + depth[depth_png == 0] = -1. + return np.expand_dims(depth, axis=2) + + +class KITTIDataset(BaseDataset): + """ + KITTI dataset class + + Parameters + ---------- + path : String + Path to the dataset + split : String + Split file, with paths to the images to be used + data_transform : Function + Transformations applied to the sample + depth_type : String + Which depth type to load + input_depth_type : String + Which input depth type to load + """ + def __init__(self, split, tag=None, + depth_type=None, input_depth_type=None, + single_intrinsics=False, **kwargs): + # Assertions + super().__init__(**kwargs) + self.tag = 'kitti' if tag is None else tag + + self.baseline = + 0.5407 + + self.backward_context_paths = [] + self.forward_context_paths = [] + + self.single_intrinsics = None if not single_intrinsics else \ + np.array([[0.58, 0.00, 0.5, 0], + [0.00, 1.92, 0.5, 0], + [0.00, 0.00, 1.0, 0], + [0.00, 0.00, 0.0, 1]], dtype=np.float32) + + self.split = split.split('/')[-1].split('.')[0] + + self.depth_type = depth_type + self.input_depth_type = input_depth_type + + self._cache = {} + self.pose_cache = {} + self.oxts_cache = {} + self.calibration_cache = {} + self.imu2velo_calib_cache = {} + self.sequence_origin_cache = {} + + with open(os.path.join(self.path, split), "r") as f: + data = f.readlines() + + self.paths = [] + # Get file list from data + for i, fname in enumerate(data): + path = os.path.join(self.path, fname.split()[0]) + add_flag = True + if add_flag and self.with_depth: + # Check if depth file exists + depth = self._get_depth_file(path, self.depth_type) + add_flag = depth is not None and os.path.exists(depth) + if add_flag and self.with_input_depth: + # Check if input depth file exists + depth = self._get_depth_file(path, self.input_depth_type) + add_flag = depth is not None and os.path.exists(depth) + if add_flag: + self.paths.append(path) + + # If using context, filter file list + if 
self.with_context: + paths_with_context = [] + for stride in [1]: + for idx, file in enumerate(self.paths): + backward_context_idxs, forward_context_idxs = \ + self._get_sample_context( + file, self.bwd_context, self.fwd_context, stride) + if backward_context_idxs is not None and forward_context_idxs is not None: + exists = True + if self.with_depth_context: + _, depth_context_files = self._get_context_files( + self.paths[idx], backward_context_idxs + forward_context_idxs) + for depth_file in depth_context_files: + exists = os.path.exists(depth_file) + if not exists: + break + if exists: + paths_with_context.append(self.paths[idx]) + self.forward_context_paths.append(forward_context_idxs) + self.backward_context_paths.append(backward_context_idxs[::-1]) + self.paths = paths_with_context + + if len(self.cameras) > 1: + self.paths = [im.replace('image_03', 'image_02') for im in self.paths] + + if 1 in self.cameras: + self.paths_stereo = [im.replace('image_02', 'image_03') for im in self.paths] + else: + self.paths_stereo = None + + @staticmethod + def _get_next_file(idx, file): + """Get next file given next idx and current file""" + base, ext = os.path.splitext(os.path.basename(file)) + return os.path.join(os.path.dirname(file), str(idx).zfill(len(base)) + ext) + + @staticmethod + def _get_parent_folder(image_file): + """Get the parent folder from image_file""" + return os.path.abspath(os.path.join(image_file, "../../../..")) + + @staticmethod + def _get_intrinsics(image_file, calib_data): + """Get intrinsics from the calib_data dictionary""" + for cam in ['left', 'right']: + # Check for both cameras, if found replace and return intrinsics + if IMAGE_FOLDER[cam] in image_file: + return np.reshape(calib_data[IMAGE_FOLDER[cam].replace('image', 'P_rect')], (3, 4))[:, :3] + + @staticmethod + def _read_raw_calib_file(folder): + """Read raw calibration files from folder""" + return read_calib_file(os.path.join(folder, CALIB_FILE['cam2cam'])) + + @staticmethod + def _get_keypoints(filename, size): + """Get keypoints from file""" + filename = filename. \ + replace('KITTI_tiny', 'KITTI_tiny_keypoints'). 
\ + replace('.png', '.txt.npz') + keypoints = np.load(filename)['data'] + keypoints_coord, keypoints_desc = keypoints[:, :2], keypoints[:, 2:] + keypoints_coord[:, 0] *= size[0] / 320 + keypoints_coord[:, 1] *= size[1] / 240 + return keypoints_coord, keypoints_desc + + def get_filename(self, sample_idx): + """ + Returns the filename for an index, following DGP structure + + Parameters + ---------- + sample_idx : Int + Sample index + + Returns + ------- + filename : String + Filename for that sample + """ + filename = os.path.splitext(self.paths[sample_idx].replace(self.path + '/', ''))[0] + for cam in ['left', 'right']: + filename = filename.replace('{}/data'.format(IMAGE_FOLDER[cam]), + 'proj_depth/{}/%s' % IMAGE_FOLDER[cam]) + return filename +######################################################################################################################## +#### DEPTH +######################################################################################################################## + + def _read_depth(self, depth_file): + """Get the depth map from a file""" + if depth_file.endswith('.npz'): + return read_npz_depth(depth_file, 'velodyne') + elif depth_file.endswith('.png'): + return read_png_depth(depth_file) + else: + raise NotImplementedError( + 'Depth type {} not implemented'.format(self.depth_type)) + + @staticmethod + def _get_depth_file(image_file, depth_type): + """Get the corresponding depth file from an image file""" + for cam in ['left', 'right']: + if IMAGE_FOLDER[cam] in image_file: + depth_file = image_file.replace( + IMAGE_FOLDER[cam] + '/data', 'proj_depth/{}/{}'.format(depth_type, IMAGE_FOLDER[cam])) + if depth_type not in PNG_DEPTH_DATASETS: + depth_file = depth_file.replace('png', 'npz') + return depth_file + + def _get_sample_context(self, sample_name, + backward_context, forward_context, stride=1): + """ + Get a sample context + + Parameters + ---------- + sample_name : String + Path + Name of the sample + backward_context : Int + Size of backward context + forward_context : Int + Size of forward context + stride : Int + Stride value to consider when building the context + + Returns + ------- + backward_context : list[Int] + List containing the indexes for the backward context + forward_context : list[Int] + List containing the indexes for the forward context + """ + base, ext = os.path.splitext(os.path.basename(sample_name)) + parent_folder = os.path.dirname(sample_name) + f_idx = int(base) + + # Check number of files in folder + if parent_folder in self._cache: + max_num_files = self._cache[parent_folder] + else: + max_num_files = len(glob.glob(os.path.join(parent_folder, '*' + ext))) + self._cache[parent_folder] = max_num_files + + # Check bounds + if (f_idx - backward_context * stride) < 0 or ( + f_idx + forward_context * stride) >= max_num_files: + return None, None + + # Backward context + c_idx = f_idx + backward_context_idxs = [] + while len(backward_context_idxs) < backward_context and c_idx > 0: + c_idx -= stride + filename = self._get_next_file(c_idx, sample_name) + if os.path.exists(filename): + backward_context_idxs.append(c_idx) + if c_idx < 0: + return None, None + + # Forward context + c_idx = f_idx + forward_context_idxs = [] + while len(forward_context_idxs) < forward_context and c_idx < max_num_files: + c_idx += stride + filename = self._get_next_file(c_idx, sample_name) + if os.path.exists(filename): + forward_context_idxs.append(c_idx) + if c_idx >= max_num_files: + return None, None + + return backward_context_idxs, 
forward_context_idxs + + def _get_context_files(self, sample_name, idxs): + """ + Returns image and depth context files + + Parameters + ---------- + sample_name : String + Name of current sample + idxs : list[Int] + Context indexes + + Returns + ------- + image_context_paths : list[String] + List of image names for the context + depth_context_paths : list[String] + List of depth names for the context + """ + image_context_paths = [self._get_next_file(i, sample_name) for i in idxs] + if self.with_depth: + depth_context_paths = [self._get_depth_file(f, self.depth_type) for f in image_context_paths] + return image_context_paths, depth_context_paths + else: + return image_context_paths, None + +######################################################################################################################## +#### POSE +######################################################################################################################## + + def _get_imu2cam_transform(self, image_file): + """Gets the transformation between IMU and camera from an image file""" + parent_folder = self._get_parent_folder(image_file) + if image_file in self.imu2velo_calib_cache: + return self.imu2velo_calib_cache[image_file] + + cam2cam = read_calib_file(os.path.join(parent_folder, CALIB_FILE['cam2cam'])) + imu2velo = read_calib_file(os.path.join(parent_folder, CALIB_FILE['imu2velo'])) + velo2cam = read_calib_file(os.path.join(parent_folder, CALIB_FILE['velo2cam'])) + + velo2cam_mat = transform_from_rot_trans(velo2cam['R'], velo2cam['T']) + imu2velo_mat = transform_from_rot_trans(imu2velo['R'], imu2velo['T']) + cam_2rect_mat = transform_from_rot_trans(cam2cam['R_rect_00'], np.zeros(3)) + + imu2cam = cam_2rect_mat @ velo2cam_mat @ imu2velo_mat + self.imu2velo_calib_cache[image_file] = imu2cam + return imu2cam + + @staticmethod + def _get_oxts_file(image_file): + """Gets the oxts file from an image file""" + # find oxts pose file + for cam in ['left', 'right']: + # Check for both cameras, if found replace and return file name + if IMAGE_FOLDER[cam] in image_file: + return image_file.replace(IMAGE_FOLDER[cam], OXTS_POSE_DATA).replace('.png', '.txt') + # Something went wrong (invalid image file) + raise ValueError('Invalid KITTI path for pose supervision.') + + def _get_oxts_data(self, image_file): + """Gets the oxts data from an image file""" + oxts_file = self._get_oxts_file(image_file) + if oxts_file in self.oxts_cache: + oxts_data = self.oxts_cache[oxts_file] + else: + oxts_data = np.loadtxt(oxts_file, delimiter=' ', skiprows=0) + self.oxts_cache[oxts_file] = oxts_data + return oxts_data + + def _get_pose(self, image_file, camera): + """Gets the pose information from an image file""" + if image_file in self.pose_cache: + return self.pose_cache[image_file] + # Find origin frame in this sequence to determine scale & origin translation + base, ext = os.path.splitext(os.path.basename(image_file)) + origin_frame = os.path.join(os.path.dirname(image_file), str(0).zfill(len(base)) + ext) + # Get origin data + origin_oxts_data = self._get_oxts_data(origin_frame) + lat = origin_oxts_data[0] + scale = np.cos(lat * np.pi / 180.) 
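+        # The Mercator scale is computed once from the latitude of the sequence's first frame and
+        # reused for every frame, so the translations returned by pose_from_oxts_packet() share a
+        # consistent metric scale. The transform composed below relates the current frame to that
+        # first frame, is mapped from the IMU frame into the camera frame with the IMU-to-camera
+        # calibration, and is finally inverted before caching.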
+ # Get origin pose + origin_R, origin_t = pose_from_oxts_packet(origin_oxts_data, scale) + origin_pose = transform_from_rot_trans(origin_R, origin_t) + # Compute current pose + oxts_data = self._get_oxts_data(image_file) + R, t = pose_from_oxts_packet(oxts_data, scale) + pose = transform_from_rot_trans(R, t) + # Compute odometry pose + imu2cam = self._get_imu2cam_transform(image_file) + odo_pose = (imu2cam @ np.linalg.inv(origin_pose) @ + pose @ np.linalg.inv(imu2cam)).astype(np.float32) + odo_pose = invert_pose(odo_pose) + # Cache and return pose + self.pose_cache[image_file] = odo_pose + if camera == 1: + odo_pose[0, -1] -= self.baseline + return odo_pose + +######################################################################################################################## + + def __len__(self): + """Dataset length""" + return len(self.paths) + + def __getitem__(self, idx): + """Get dataset sample""" + + samples = [] + + for camera in self.cameras: + + # Filename + filename = self.paths[idx] if camera == 0 else self.paths_stereo[idx] + + # Base sample + sample = { + 'idx': idx, + 'tag': self.tag, + 'filename': self.relative_path({0: filename}), + 'splitname': '%s_%010d' % (self.split, idx) + } + + # Image + sample['rgb'] = {0: read_image(filename)} + + # Intrinsics + parent_folder = self._get_parent_folder(filename) + if parent_folder in self.calibration_cache: + c_data = self.calibration_cache[parent_folder] + else: + c_data = self._read_raw_calib_file(parent_folder) + self.calibration_cache[parent_folder] = c_data + + # Return individual or single intrinsics + if self.single_intrinsics is not None: + intrinsics = self.single_intrinsics.copy() + intrinsics[0, :] *= sample['rgb'][0].size[0] + intrinsics[1, :] *= sample['rgb'][0].size[1] + sample['intrinsics'] = {0: intrinsics} + else: + sample['intrinsics'] = {0: self._get_intrinsics(filename, c_data)} + + # Add pose information if requested + if self.with_pose: + sample['pose'] = {0: self._get_pose(filename, camera)} + + # Add depth information if requested + if self.with_depth: + sample['depth'] = {0: self._read_depth( + self._get_depth_file(filename, self.depth_type))} + + # Add input depth information if requested + if self.with_input_depth: + sample['input_depth'] = {0: self._read_depth( + self._get_depth_file(filename, self.input_depth_type))} + + # Add context information if requested + if self.with_context: + + # Add context images + all_context_idxs = self.backward_context_paths[idx] + \ + self.forward_context_paths[idx] + image_context_paths, depth_context_paths = \ + self._get_context_files(filename, all_context_idxs) + image_context = [read_image(f) for f in image_context_paths] + sample['rgb'].update({ + key: val for key, val in zip(self.context, image_context) + }) + + # Add context poses + if self.with_pose: + image_context_pose = [self._get_pose(f, camera) for f in image_context_paths] + sample['pose'].update({ + key: val for key, val in zip(self.context, image_context_pose) + }) + + # Add context depth + if self.with_depth_context: + depth_context = [self._read_depth(self._get_depth_file( + path, self.depth_type)) for path in depth_context_paths] + sample['depth'].update({ + key: val for key, val in zip(self.context, depth_context) + }) + + # Stack sample + samples.append(sample) + + # Make relative poses + samples = make_relative_pose(samples) + + # Transform data + if self.data_transform: + samples = self.data_transform(samples) + + # Return stacked sample + return stack_sample(samples) + 
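+# Illustrative usage sketch; the dataset root and split file below are hypothetical placeholders:
+#
+#   dataset = KITTIDataset(
+#       path='/data/KITTI_raw', split='splits/eigen_zhou_files.txt',
+#       context=[-1, 1], cameras=[0], labels=['depth'], depth_type='velodyne')
+#   sample = dataset[0]
+#   sample['rgb'][0]                      # target image (PIL.Image)
+#   sample['rgb'][-1], sample['rgb'][1]   # temporal context images
+#   sample['depth'][0]                    # projected LiDAR depth map [H,W,1]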
+######################################################################################################################## diff --git a/vidar/datasets/KITTIDataset_utils.py b/vidar/datasets/KITTIDataset_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..0a1ca1965ad21220bdece990d2cd1dd02e120869 --- /dev/null +++ b/vidar/datasets/KITTIDataset_utils.py @@ -0,0 +1,223 @@ +"""Provides helper methods for loading and parsing KITTI data.""" + +from collections import namedtuple + +import numpy as np + +__author__ = "Lee Clement" +__email__ = "lee.clement@robotics.utias.utoronto.ca" + +# Per dataformat.txt +OxtsPacket = namedtuple('OxtsPacket', + 'lat, lon, alt, ' + + 'roll, pitch, yaw, ' + + 'vn, ve, vf, vl, vu, ' + + 'ax, ay, az, af, al, au, ' + + 'wx, wy, wz, wf, wl, wu, ' + + 'pos_accuracy, vel_accuracy, ' + + 'navstat, numsats, ' + + 'posmode, velmode, orimode') + +# Bundle into an easy-to-access structure +OxtsData = namedtuple('OxtsData', 'packet, T_w_imu') + + +def rotx(t): + """ + Rotation about the x-axis + + Parameters + ---------- + t : Float + Theta angle + + Returns + ------- + matrix : np.Array + Rotation matrix [3,3] + """ + c = np.cos(t) + s = np.sin(t) + return np.array([[1, 0, 0], + [0, c, -s], + [0, s, c]]) + + +def roty(t): + """ + Rotation about the y-axis + + Parameters + ---------- + t : Float + Theta angle + + Returns + ------- + matrix : np.Array + Rotation matrix [3,3] + """ + c = np.cos(t) + s = np.sin(t) + return np.array([[c, 0, s], + [0, 1, 0], + [-s, 0, c]]) + + +def rotz(t): + """ + Rotation about the z-axis + + Parameters + ---------- + t : Float + Theta angle + + Returns + ------- + matrix : np.Array + Rotation matrix [3,3] + """ + c = np.cos(t) + s = np.sin(t) + return np.array([[c, -s, 0], + [s, c, 0], + [0, 0, 1]]) + + +def transform_from_rot_trans(R, t): + """ + Transformation matrix from rotation matrix and translation vector. + + Parameters + ---------- + R : np.Array + Rotation matrix [3,3] + t : np.Array + translation vector [3] + + Returns + ------- + matrix : np.Array + Transformation matrix [4,4] + """ + R = R.reshape(3, 3) + t = t.reshape(3, 1) + return np.vstack((np.hstack([R, t]), [0, 0, 0, 1])) + + +def read_calib_file(filepath): + """ + Read in a calibration file and parse into a dictionary + + Parameters + ---------- + filepath : String + File path to read from + + Returns + ------- + calib : Dict + Dictionary with calibration values + """ + data = {} + + with open(filepath, 'r') as f: + for line in f.readlines(): + key, value = line.split(':', 1) + # The only non-float values in these files are dates, which + # we don't care about anyway + try: + data[key] = np.array([float(x) for x in value.split()]) + except ValueError: + pass + + return data + + +def pose_from_oxts_packet(raw_data, scale): + """ + Helper method to compute a SE(3) pose matrix from an OXTS packet + + Parameters + ---------- + raw_data : Dict + Oxts data to read from + scale : Float + Oxts scale + + Returns + ------- + R : np.Array + Rotation matrix [3,3] + t : np.Array + Translation vector [3] + """ + packet = OxtsPacket(*raw_data) + er = 6378137. # earth radius (approx.) in meters + + # Use a Mercator projection to get the translation vector + tx = scale * packet.lon * np.pi * er / 180. + ty = scale * er * \ + np.log(np.tan((90. 
+ packet.lat) * np.pi / 360.)) + tz = packet.alt + t = np.array([tx, ty, tz]) + + # Use the Euler angles to get the rotation matrix + Rx = rotx(packet.roll) + Ry = roty(packet.pitch) + Rz = rotz(packet.yaw) + R = Rz.dot(Ry.dot(Rx)) + + # Combine the translation and rotation into a homogeneous transform + return R, t + + +def load_oxts_packets_and_poses(oxts_files): + """ + Generator to read OXTS ground truth data. + Poses are given in an East-North-Up coordinate system + whose origin is the first GPS position. + + Parameters + ---------- + oxts_files : list[String] + List of oxts files to read from + + Returns + ------- + oxts : list[Dict] + List of oxts ground-truth data + """ + # Scale for Mercator projection (from first lat value) + scale = None + # Origin of the global coordinate system (first GPS position) + origin = None + + oxts = [] + + for filename in oxts_files: + with open(filename, 'r') as f: + for line in f.readlines(): + line = line.split() + # Last five entries are flags and counts + line[:-5] = [float(x) for x in line[:-5]] + line[-5:] = [int(float(x)) for x in line[-5:]] + + packet = OxtsPacket(*line) + + if scale is None: + scale = np.cos(packet.lat * np.pi / 180.) + + R, t = pose_from_oxts_packet(packet, scale) + + if origin is None: + origin = t + + T_w_imu = transform_from_rot_trans(R, t - origin) + + oxts.append(OxtsData(packet, T_w_imu)) + + return oxts + + diff --git a/vidar/datasets/OuroborosDataset.py b/vidar/datasets/OuroborosDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..004688ae84564c53fb71658ebe9532742f1dffd0 --- /dev/null +++ b/vidar/datasets/OuroborosDataset.py @@ -0,0 +1,487 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os +import pickle + +import numpy as np +from dgp.utils.camera import Camera +from dgp.utils.pose import Pose + +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.misc import \ + stack_sample, make_relative_pose +from vidar.utils.data import make_list +from vidar.utils.read import read_image +from vidar.utils.types import is_str + + +def load_from_file(filename, key): + """Load data cache from a file""" + data = np.load(filename, allow_pickle=True)[key] + if len(data.shape) == 0: + data = None + return data + + +def save_to_file(filename, key, value): + """Save data to a cache file""" + os.makedirs(os.path.dirname(filename), exist_ok=True) + np.savez_compressed(filename, **{key: value}) + + +def generate_proj_maps(camera, Xw, shape): + """Render pointcloud on image. + + Parameters + ---------- + camera: Camera + Camera object with appropriately set extrinsics wrt world. + Xw: np.Array + 3D point cloud (x, y, z) in the world coordinate. [N,3] + shape: np.Array + Output depth image shape [H, W] + + Returns + ------- + depth: np.Array + Rendered depth image + """ + assert len(shape) == 2, 'Shape needs to be 2-tuple.' 
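+    # Projection proceeds in three steps: (1) move the world-frame points into the camera frame
+    # using the camera's world-to-camera pose, (2) project them to integer pixel coordinates with
+    # the intrinsics K, and (3) keep only points that land inside the image with positive depth,
+    # writing their z values into an [H,W] map (points hitting the same pixel keep the last value
+    # written).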
+ # Move point cloud to the camera's (C) reference frame from the world (W) + Xc = camera.p_cw * Xw + # Project the points as if they were in the camera's frame of reference + uv = Camera(K=camera.K).project(Xc).astype(int) + # Colorize the point cloud based on depth + z_c = Xc[:, 2] + + # Create an empty image to overlay + H, W = shape + proj_depth = np.zeros((H, W), dtype=np.float32) + in_view = np.logical_and.reduce([(uv >= 0).all(axis=1), uv[:, 0] < W, uv[:, 1] < H, z_c > 0]) + uv, z_c = uv[in_view], z_c[in_view] + proj_depth[uv[:, 1], uv[:, 0]] = z_c + + # Return projected maps + return proj_depth + + +class OuroborosDataset(BaseDataset): + """ + DGP dataset class + + Parameters + ---------- + path : String + Path to the dataset + split : String {'train', 'val', 'test'} + Which dataset split to use + cameras : list[String] + Which cameras to get information from + depth_type : String + Which lidar will be used to generate ground-truth information + input_depth_type : String + Which lidar will be used as input to the networks + with_pose : Bool + If enabled pose estimates are also returned + with_extra_context : Bool + If enabled extra context information (e.g. depth, semantic, instance) are also returned + back_context : Int + Size of the backward context + forward_context : Int + Size of the forward context + data_transform : Function + Transformations applied to the sample + dataset : String ['synchronized', 'parallel_domain'] + Which dataset will be used + only_cache : Bool + Only use cached pointcloud information, without loading the sensor + """ + def __init__(self, split, tag=None, + depth_type=None, input_depth_type=None, + masks=None, **kwargs): + super().__init__(**kwargs) + self.tag = 'ouroboros' if tag is None else tag + + cameras = [c if is_str(c) else 'camera_%02d' % c for c in self.cameras] + + # Store variables + self.split = split + self.dataset_idx = 0 + self.sensors = list(cameras) + + # Store task information + self.depth_type = depth_type + self.input_depth_type = input_depth_type + self.only_cache = False + + self.masks_path = masks + + # Add requested annotations + requested_annotations = [] + + # Add depth sensor + if self.with_depth and not self.only_cache and \ + self.depth_type != 'zbuffer': + self.sensors.append(depth_type) + self.depth_idx = len(self.sensors) - 1 + + # Add input depth sensor + if self.with_input_depth and not self.only_cache and \ + self.input_depth_type != 'zbuffer' and \ + self.input_depth_type != self.depth_type: + self.sensors.append(input_depth_type) + self.input_depth_idx = len(self.sensors) - 1 + + # Add radar sensor + if self.with_radar: + self.sensors.append('radar') + self.radar_idx = len(self.sensors) - 1 + + # Choose which dataset to use + if not self.virtual: + from dgp.datasets.synchronized_dataset import SynchronizedSceneDataset + dataset = SynchronizedSceneDataset + extra_args = {} + else: + from dgp.datasets.pd_dataset import ParallelDomainSceneDataset + dataset = ParallelDomainSceneDataset + extra_args = { + 'use_virtual_camera_datums': False, + } + + # Initialize chosen dataset + self.dataset = dataset( + scene_dataset_json=self.path, + split=split, + datum_names=self.sensors, + backward_context=self.bwd_context, + forward_context=self.fwd_context, + requested_annotations=requested_annotations, + only_annotated_datums=False, + **extra_args, + ) + + def depth_to_world_points(self, depth, datum_idx): + """ + Unproject depth from a camera's perspective into a world-frame pointcloud + + Parameters + ---------- + depth : 
np.Array + Depth map to be lifted [H,W] + datum_idx : Int + Index of the camera + + Returns + ------- + pointcloud : np.Array + Lifted 3D pointcloud [Nx3] + """ + # Access data + intrinsics = self.get_current('intrinsics', datum_idx) + pose = self.get_current('pose', datum_idx) + # Create pixel grid for 3D unprojection + h, w = depth.shape[:2] + uv = np.mgrid[:w, :h].transpose(2, 1, 0).reshape(-1, 2).astype(np.float32) + # Unproject grid to 3D in the camera frame of reference + pcl = Camera(K=intrinsics).unproject(uv) * depth.reshape(-1, 1) + # Return pointcloud in world frame of reference + return pose * pcl + + def create_camera(self, datum_idx, context=None): + """ + Create current camera + + Parameters + ---------- + datum_idx : Int + Index of the camera + context : Int + Context value for choosing current of reference information + + Returns + ------- + camera : Camera + DGP camera + """ + camera_pose = self.get_current_or_context('pose', datum_idx, context) + camera_intrinsics = self.get_current_or_context('intrinsics', datum_idx, context) + return Camera(K=camera_intrinsics, p_cw=camera_pose.inverse()) + + def create_proj_maps(self, filename, camera_idx, depth_idx, depth_type, + world_points=None, context=None): + """ + Creates the depth map for a camera by projecting LiDAR information. + It also caches the depth map following DGP folder structure, so it's not recalculated + + Parameters + ---------- + filename : String + Filename used for loading / saving + camera_idx : Int + Camera sensor index + depth_idx : Int + Depth sensor index + depth_type : String + Which depth type will be loaded + world_points : np.Array [Nx3] + Points that will be projected (optional) + context : Int + Context value for choosing current of reference information + + Returns + ------- + depth : np.Array + Depth map for that datum in that sample [H,W] + """ + # If we want the z-buffer (simulation) + if depth_type == 'zbuffer': + sensor_name = self.get_current('datum_name', camera_idx) + filename = filename.replace(self.sensors[camera_idx], sensor_name) + filename = '{}/{}.npz'.format( + os.path.dirname(self.path), filename.format('depth')) + return np.load(filename)['data'], None, None + # Otherwise, we want projected information + filename_depth = '{}/{}.npz'.format( + os.path.dirname(self.path), filename.format('projected/depth/{}'.format(depth_type))) + # Load and return if exists + try: + # Get cached depth map + depth = load_from_file(filename_depth, 'depth') + return depth + except: + pass + # Calculate world points if needed + if world_points is None: + # Get lidar information + lidar_pose = self.get_current_or_context('pose', depth_idx, context) + lidar_points = self.get_current_or_context('point_cloud', depth_idx, context) + world_points = lidar_pose * lidar_points + + # Create camera + camera = self.create_camera(camera_idx, context) + image_shape = self.get_current_or_context('rgb', camera_idx, context).size[::-1] + # Generate depth maps + depth = generate_proj_maps(camera, world_points, image_shape) + # Save depth map + save_to_file(filename_depth, 'depth', depth) + # Return depth + return depth + + + def get_current(self, key, sensor_idx, as_dict=False): + """Return current timestep of a key from a sensor""" + current = self.sample_dgp[self.bwd_context][sensor_idx][key] + return current if not as_dict else {0: current} + + def get_backward(self, key, sensor_idx): + """Return backward timesteps of a key from a sensor""" + return [] if self.bwd_context == 0 else \ + 
[self.sample_dgp[i][sensor_idx][key] for i in range(0, self.bwd_context)] + + def get_forward(self, key, sensor_idx): + """Return forward timesteps of a key from a sensor""" + return [] if self.fwd_context == 0 else \ + [self.sample_dgp[i][sensor_idx][key] + for i in range(self.bwd_context + 1, + self.bwd_context + self.fwd_context + 1)] + + def get_context(self, key, sensor_idx, as_dict=False): + """Get both backward and forward contexts""" + context = self.get_backward(key, sensor_idx) + self.get_forward(key, sensor_idx) + if not as_dict: + return context + else: + return {key: val for key, val in zip(self.context, context)} + + def get_current_or_context(self, key, sensor_idx, context=None, as_dict=False): + """Return current or context information for a given key and sensor index""" + if context is None: + return self.get_current(key, sensor_idx, as_dict=as_dict) + else: + return self.get_context(key, sensor_idx, as_dict=as_dict)[context] + + def has_dgp_key(self, key, sensor_idx): + """Returns True if the DGP sample contains a certain key""" + return key in self.sample_dgp[self.bwd_context][sensor_idx].keys() + + def get_filename(self, sample_idx, datum_idx, context=0): + """ + Returns the filename for an index, following DGP structure + + Parameters + ---------- + sample_idx : Int + Sample index + datum_idx : Int + Datum index + context : Int + Context offset for the sample + + Returns + ------- + filename : String + Filename for the datum in that sample + """ + scene_idx, sample_idx_in_scene, _ = self.dataset.dataset_item_index[sample_idx] + scene_dir = self.dataset.scenes[scene_idx].directory + filename = self.dataset.get_datum( + scene_idx, sample_idx_in_scene + context, self.sensors[datum_idx]).datum.image.filename + return os.path.splitext(os.path.join(os.path.basename(scene_dir), + filename.replace('rgb', '{}')))[0] + + def __len__(self): + """Length of dataset""" + return len(self.dataset) + + def __getitem__(self, idx): + """Get dataset sample""" + + # Get DGP sample (if single sensor, make it a list) + self.sample_dgp = self.dataset[idx] + self.sample_dgp = [make_list(sample) for sample in self.sample_dgp] + + # Reorganize sensors to the right order + sensor_names = [self.get_current('datum_name', i).lower() for i in range(len(self.sensors))] + indexes = [sensor_names.index(v) for v in self.sensors] + self.sample_dgp = [[s[idx] for idx in indexes] for s in self.sample_dgp] + + # Loop over all cameras + samples = [] + for i in range(self.num_cameras): + + # Filename + filename = self.get_filename(idx, i) + + # Base sample + sample = { + 'idx': idx, + 'tag': self.tag, + 'filename': self.relative_path({0: filename}), + 'splitname': '%s_%010d' % (self.split, idx), + 'sensor_name': self.get_current('datum_name', i), + } + + # Image and intrinsics + sample.update({ + 'rgb': self.get_current('rgb', i, as_dict=True), + 'intrinsics': self.get_current('intrinsics', i, as_dict=True), + }) + + # If masks are returned + if self.masks_path is not None: + sample.update({ + 'mask': read_image(os.path.join( + self.masks_path, '%02d.png' % self.cameras[i])) + }) + + # If depth is returned + if self.with_depth: + # Get depth maps + depth = self.create_proj_maps( + filename, i, self.depth_idx, self.depth_type) + # Include depth map + sample.update({ + 'depth': {0: depth} + }) + + # If input depth is returned + if self.with_input_depth: + sample.update({ + 'input_depth': {0: self.create_proj_maps( + filename, i, self.input_depth_idx, self.input_depth_type)[0]} + }) + + # If pose is returned + if 
self.with_pose: + sample.update({ + 'extrinsics': {key: val.inverse().matrix for key, val in + self.get_current('extrinsics', i, as_dict=True).items()}, + 'pose': {key: val.inverse().matrix for key, val in + self.get_current('pose', i, as_dict=True).items()}, + }) + + # If context is returned + if self.with_context: + + # Include context images + sample['rgb'].update(self.get_context('rgb', i, as_dict=True)) + + # Create contexts filenames if extra context is required + filename_context = [] + for context in range(-self.bwd_context, 0): + filename_context.append(self.get_filename(idx, i, context)) + for context in range(1, self.fwd_context + 1): + filename_context.append(self.get_filename(idx, i, context)) + sample['filename_context'] = filename_context + + # If context pose is returned + if self.with_pose: + # Get original values to calculate relative motion + inv_orig_extrinsics = Pose.from_matrix(sample['extrinsics'][0]).inverse() + sample['extrinsics'].update( + {key: (inv_orig_extrinsics * val.inverse()).matrix for key, val in zip( + self.context, self.get_context('extrinsics', i))}) + sample['pose'].update( + {key: (val.inverse()).matrix for key, val in zip( + self.context, self.get_context('pose', i))}) + + # If context depth is returned + if self.with_depth_context: + depth_context = [ + self.create_proj_maps( + filename, i, self.depth_idx, self.depth_type, + context=k) + for k, filename in enumerate(filename_context)] + sample['depth'].update( + {key: val for key, val in zip( + self.context, [dsf for dsf in depth_context])}) + + + samples.append(sample) + + # Make relative poses + samples = make_relative_pose(samples) + + # Add LiDAR information + + lidar_sample = {} + if self.with_lidar: + + # Include pointcloud information + lidar_sample.update({ + 'lidar_pointcloud': self.get_current('point_cloud', self.depth_idx), + }) + + # If pose is included + if self.with_pose: + lidar_sample.update({ + 'lidar_extrinsics': self.get_current('extrinsics', self.depth_idx).matrix, + 'lidar_pose': self.get_current('pose', self.depth_idx).matrix, + }) + + # If extra context is included + if self.with_extra_context: + lidar_sample['lidar_context'] = self.get_context('point_cloud', self.depth_idx) + # If context pose is included + if self.with_pose: + # Get original values to calculate relative motion + orig_extrinsics = Pose.from_matrix(lidar_sample['lidar_extrinsics']) + orig_pose = Pose.from_matrix(lidar_sample['lidar_pose']) + lidar_sample.update({ + 'lidar_extrinsics_context': + [(orig_extrinsics.inverse() * extrinsics).inverse().matrix + for extrinsics in self.get_context('extrinsics', self.depth_idx)], + 'lidar_pose_context': + [(orig_pose.inverse() * pose).inverse().matrix + for pose in self.get_context('pose', self.depth_idx)], + }) + + + # Apply same data transformations for all sensors + if self.data_transform: + samples = self.data_transform(samples) + # lidar_sample = self.data_transform(lidar_sample) + + # Return sample (stacked if necessary) + return stack_sample(samples, lidar_sample) diff --git a/vidar/datasets/ScanNetTemporalDataset.py b/vidar/datasets/ScanNetTemporalDataset.py new file mode 100755 index 0000000000000000000000000000000000000000..42524c99d3e0ab5834c9e2c58ef9bf5f12806353 --- /dev/null +++ b/vidar/datasets/ScanNetTemporalDataset.py @@ -0,0 +1,116 @@ + +import os + +import numpy as np + +import cv2 + +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.FolderTree import FolderTree +from vidar.datasets.utils.misc import stack_sample +from 
vidar.utils.read import read_image +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.misc import invert_pose, stack_sample, make_relative_pose +from vidar.utils.read import read_image + + +class ScanNetTemporalDataset(BaseDataset): + def __init__(self, tag=None, single_folder=False, split=None, stride=1, **kwargs): + super().__init__(**kwargs) + self.tag = 'scannet_temporal' if tag is None else tag + if split is None or split == '': + split = ('', ) + self.rgb_tree = FolderTree( + os.path.join(self.path, split), + context=self.context, sub_folders=['color'], stride=stride, + single_folder=single_folder, suffix='.jpg') + + def __len__(self): + """Dataset length""" + return len(self.rgb_tree) + + @staticmethod + def get_intrinsics(rgb): + """Return dummy intrinsics""" + return np.array([[rgb.size[0] / 2., 0., rgb.size[0] / 2.], + [0., rgb.size[1], rgb.size[1] / 2.], + [0., 0., 1.]]) + + @staticmethod + def load_intrinsics(filename): + filename_intrinsics = {key: '/'.join(val.split('/')[:-2]) + '/intrinsic/intrinsic_depth.txt' + for key, val in filename.items()} + return {key: np.genfromtxt(val).astype(np.float32).reshape((4, 4))[:3, :3] + for key, val in filename_intrinsics.items()} + + @staticmethod + def load_depth(filename): + try: + filename_depth = {key: val.replace('color', 'depth').replace('.jpg', '.png') + for key, val in filename.items()} + return {key: (cv2.imread(val, cv2.IMREAD_ANYDEPTH).astype(np.float32) / 1000.0) + for key, val in filename_depth.items()} + except: + filename_depth = {key: val.replace('color', 'depth').replace('.jpg', '.npy') + for key, val in filename.items()} + return {key: (np.load(val) / 1000.0).astype(np.float32) + for key, val in filename_depth.items()} + + @staticmethod + def load_pose(filename): + filename_pose = {key: val.replace('color', 'pose').replace('.jpg', '.txt') + for key, val in filename.items()} + return {key: invert_pose(np.genfromtxt(val).astype(np.float32).reshape((4, 4))) + for key, val in filename_pose.items()} + + def __getitem__(self, idx): + """Get dataset sample given an index.""" + + samples = [] + + for _ in self.cameras: + + # Filename + filename = self.rgb_tree.get_item(idx) + + # Base sample + sample = { + 'idx': idx, + 'tag': self.tag, + 'filename': self.relative_path(filename), + 'splitname': '%010d' % idx + } + + # Image + sample['rgb'] = read_image(filename) + + # Intrinsics + sample['intrinsics'] = self.load_intrinsics(filename) + + if self.with_depth: + sample['depth'] = self.load_depth(filename) + + if self.with_pose: + sample['pose'] = self.load_pose(filename) + + # If with context + if self.with_context: + filename_context = self.rgb_tree.get_context(idx) + sample['rgb'].update(read_image(filename_context)) + if self.with_depth: + sample['depth'].update(self.load_depth(filename_context)) + if self.with_pose: + sample['pose'].update(self.load_pose(filename_context)) + + # Stack sample + samples.append(sample) + + # Make relative poses + samples = make_relative_pose(samples) + + # Transform data + if self.data_transform: + samples = self.data_transform(samples) + + # Return stacked sample + return stack_sample(samples) diff --git a/vidar/datasets/VKITTI2Dataset.py b/vidar/datasets/VKITTI2Dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d5985d56c80ff7240cbda51500961f97c6a9a374 --- /dev/null +++ b/vidar/datasets/VKITTI2Dataset.py @@ -0,0 +1,413 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
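+# The make_tree() helper below composes folder trees of the form
+#   <path>/<sub_folder>/<scene>/<mode>/frames/<sub_folder>/Camera_<camera>
+# e.g. sub_folder='rgb', mode='clone', camera=0 resolves to .../rgb/<scene>/clone/frames/rgb/Camera_0.
+# Only the '<mode>/frames/<sub_folder>/Camera_<camera>' pattern is taken verbatim from make_tree();
+# the scene-level expansion is an assumption handled by FolderTree (vidar/datasets/utils/FolderTree.py).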
+ +import csv +import os + +import cv2 +import numpy as np + +from vidar.datasets.BaseDataset import BaseDataset +from vidar.datasets.utils.FolderTree import FolderTree +from vidar.datasets.utils.misc import \ + convert_ontology, initialize_ontology, stack_sample, make_relative_pose +from vidar.utils.data import dict_remove_nones +from vidar.utils.decorators import iterate1 +from vidar.utils.read import read_image + + +def make_tree(path, sub_folder, camera, mode, context): + """ + Create a folder tree for a certain task + + Parameters + ---------- + path : String + Data path + sub_folder : String + Subfolder path + camera : Int + Camera index + mode : String + Which task we are using + context : list[Int] + Context samples + + Returns + ------- + tree : FolderTree + Folder tree containing task data + """ + path = os.path.join(path, sub_folder) + sub_folders = '{}/frames/{}/Camera_{}'.format(mode, sub_folder, camera) + return FolderTree(path, sub_folders=sub_folders, context=context) + + +def semantic_color_to_id(semantic_color, ontology): + """ + Convert semantic color to semantic ID + + Parameters + ---------- + semantic_color : numpy.Array + Matrix with semantic colors [H, W, 3] + ontology : Dict + Ontology dictionary, with {id: color} + + Returns + ------- + semantic_id : numpy.Array + Matrix with semantic IDs [H, W] + """ + # Create semantic ID map + semantic_id = np.zeros(semantic_color.shape[:2]) + # Loop over every ontology item and assign ID to color + for key, val in ontology.items(): + idx = (semantic_color[:, :, 0] == val['color'][0]) & \ + (semantic_color[:, :, 1] == val['color'][1]) & \ + (semantic_color[:, :, 2] == val['color'][2]) + semantic_id[idx] = key + # Return semantic ID map + return semantic_id + + +class VKITTI2Dataset(BaseDataset): + """ + VKITTI2 dataset class + + Parameters + ---------- + path : String + Path to the dataset + split : String {'train', 'val', 'test'} + Which dataset split to use + ontology : String + Which ontology should be used + return_ontology : Bool + Returns ontology information in the sample + data_transform : Function + Transformations applied to the sample + """ + def __init__(self, split, tag=None, **kwargs): + super().__init__(**kwargs) + self.tag = 'vkitti2' if tag is None else tag + + # Store variables + self.split = split + self.mode = 'clone' + + # Initialize ontology + if self.with_semantic: + self.ontology, self.ontology_convert = initialize_ontology('vkitti2', self.ontology) + self.return_ontology = self.return_ontology + + # Create RGB tree + self.rgb_tree = make_tree( + self.path, 'rgb', 0, self.mode, self.context) + + # Create semantic tree + if self.with_semantic: + self.semantic_tree = make_tree( + self.path, 'classSegmentation', 0, self.mode, self.context) + + # Create instance tree + if self.with_instance: + self.instance_tree = make_tree( + self.path, 'instanceSegmentation', 0, self.mode, self.context) + + def __len__(self): + """Dataset length""" + return len(self.rgb_tree) + + @staticmethod + @iterate1 + def _get_depth(filename): + """Get depth map from filename""" + filename = filename.replace('rgb', 'depth').replace('jpg', 'png') + return cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100. 
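    # The PNGs read by _get_depth store depth in centimeters as 16-bit values, so the division by
    # 100 converts them to meters (e.g. a raw pixel value of 655 decodes to 6.55 m). A minimal
    # stand-alone read with a hypothetical path would be:
    #   cv2.imread('Scene01/clone/frames/depth/Camera_0/depth_00000.png',
    #              cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.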
+ + @staticmethod + @iterate1 + def _get_intrinsics(filename, camera, mode): + """Get intrinsics from filename""" + # Get sample number in the scene + number = int(filename.split('/')[-1].replace('rgb_', '').replace('.jpg', '')) + # Get intrinsic filename + filename_idx = filename.rfind(mode) + len(mode) + filename_intrinsics = os.path.join(filename[:filename_idx].replace( + '/rgb/', '/textgt/'), 'intrinsic.txt') + # Open intrinsic file + with open(filename_intrinsics, 'r') as f: + # Get intrinsic parameters + lines = list(csv.reader(f, delimiter=' '))[1:] + params = [float(p) for p in lines[number * 2 + camera][2:]] + # Build intrinsics matrix + intrinsics = np.array([[params[0], 0.0, params[2]], + [0.0, params[1], params[3]], + [0.0, 0.0, 1.0]]).astype(np.float32) + # Return intrinsics + return intrinsics + + @staticmethod + @iterate1 + def _get_pose(filename, camera, mode): + """Get pose from filename""" + # Get sample number in the scene + number = int(filename.split('/')[-1].replace('rgb_', '').replace('.jpg', '')) + # Get intrinsic filename + filename_idx = filename.rfind(mode) + len(mode) + filename_pose = os.path.join(filename[:filename_idx].replace( + '/rgb/', '/textgt/'), 'extrinsic.txt') + # Open intrinsics file + with open(filename_pose, 'r') as f: + # Get pose parameters + lines = list(csv.reader(f, delimiter=' '))[1:] + pose = np.array([float(p) for p in lines[number * 2 + camera][2:]]).reshape(4, 4) + # Return pose + return pose + + @staticmethod + def _get_ontology(filename, mode): + """Get ontology from filename""" + # Get ontology filename + filename_idx = filename.rfind(mode) + len(mode) + filename_ontology = os.path.join(filename[:filename_idx].replace( + '/classSegmentation/', '/textgt/'), 'colors.txt') + # Open ontology file + with open(filename_ontology, 'r') as f: + # Get ontology parameters + lines = list(csv.reader(f, delimiter=' '))[1:] + from collections import OrderedDict + ontology = OrderedDict() + for i, line in enumerate(lines): + ontology[i] = { + 'name': line[0], + 'color': np.array([int(clr) for clr in line[1:]]) + } + return ontology + + def _get_semantic(self, filename): + """Get semantic from filename""" + # Get semantic color map + semantic_color = {key: np.array(val) for key, val in read_image(filename).items()} + # Return semantic id map + semantic_id = {key: semantic_color_to_id(val, self.ontology) for key, val in semantic_color.items()} + return convert_ontology(semantic_id, self.ontology_convert) + + @staticmethod + def _get_instance(filename): + """Get instance from filename""" + # Get instance id map + return np.array(read_image(filename)) + + @staticmethod + def _get_bbox3d(filename): + + bboxes3d_dim = [] + bboxes3d_pos = [] + bboxes3d_rot = [] + bboxes3d_idx = [] + + k = int(filename.split('/')[-1][4:-4]) + bb = '/'.join(filename.replace('/rgb/', '/textgt/').split('/')[:-4]) + bb += '/pose.txt' + + with open(bb, 'r') as file: + for i, f in enumerate(file): + if i == 0: + continue + line = [float(a) for a in f.split(' ')] + if line[0] == k and line[1] == 0: + bboxes3d_dim.append(np.array([line[6], line[5], line[4]])) + bboxes3d_pos.append(np.array(line[13:16])) + # bboxes3d_rot.append(np.array([line[18], line[17], line[16]])) + bboxes3d_rot.append(np.array([line[17], line[16], line[18]])) + bboxes3d_idx.append(np.array([line[2]])) + + return { + 'dim': np.stack(bboxes3d_dim, 0), + 'pos': np.stack(bboxes3d_pos, 0), + 'rot': np.stack(bboxes3d_rot, 0), + 'idx': np.stack(bboxes3d_idx, 0), + } + + @staticmethod + @iterate1 + def 
_get_optical_flow(filename, mode): + """ + Get optical flow from filename. Code obtained here: + https://europe.naverlabs.com/research/computer-vision-research-naver-labs-europe/proxy-virtual-worlds-vkitti-2/ + """ + # Get filename + if mode == 'bwd': + filename = filename.replace('rgb', 'backwardFlow') + elif mode == 'fwd': + filename = filename.replace('/rgb/', '/forwardFlow/').replace('rgb_', 'flow_') + else: + raise ValueError('Invalid optical flow mode') + filename = filename.replace('jpg', 'png') + # Return None if file does not exist + if not os.path.exists(filename): + return None + else: + # Get optical flow + optical_flow = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + h, w = optical_flow.shape[:2] + # Get invalid optical flow pixels + invalid = optical_flow[..., 0] == 0 + # Normalize and scale optical flow values + optical_flow = 2.0 / (2 ** 16 - 1.0) * optical_flow[..., 2:0:-1].astype('f4') - 1. + optical_flow[..., 0] *= w - 1 + optical_flow[..., 1] *= h - 1 + # Remove invalid pixels + optical_flow[invalid] = 0 + return optical_flow + + @staticmethod + @iterate1 + def _get_scene_flow(filename, mode): + """Get scene flow from filename. Code obtained here: + https://europe.naverlabs.com/research/computer-vision-research-naver-labs-europe/proxy-virtual-worlds-vkitti-2/ + """ + # Get filename + if mode == 'bwd': + filename = filename.replace('rgb', 'backwardSceneFlow') + elif mode == 'fwd': + filename = filename.replace('/rgb/', '/forwardSceneFlow/').replace('rgb_', 'sceneFlow_') + else: + raise ValueError('Invalid scene flow mode') + filename = filename.replace('jpg', 'png') + # Return None if file does not exist + if not os.path.exists(filename): + return None + else: + # Get scene flow + scene_flow = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + # Return normalized and scaled optical flow (-10m to 10m) + return (scene_flow[:, :, ::-1] * 2. / 65535. - 1.) * 10. 
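    # Both decoders above follow the official VKITTI2 16-bit PNG encoding linked in their
    # docstrings: optical-flow values are mapped from [0, 2**16 - 1] to [-1, 1] and then scaled by
    # (W - 1) horizontally and (H - 1) vertically, with pixels whose first (blue) channel is 0
    # marked invalid, while scene-flow values are mapped to the [-10 m, 10 m] range. For example,
    # a raw u-channel value of 49151 in a 1242-pixel-wide image decodes to roughly
    # (2 * 49151 / 65535 - 1) * 1241 ~= +620.5 px.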
+ + def __getitem__(self, idx): + """Get dataset sample""" + + samples = [] + + for camera in self.cameras: + + # Get filename + filename = self.rgb_tree.get_item(idx) + filename = {key: val.replace('Camera_0', 'Camera_{}'.format(camera)) + for key, val in filename.items()} + + # Base sample + sample = { + 'idx': idx, + 'tag': self.tag, + 'filename': self.relative_path(filename), + 'splitname': '%s_%010d' % (self.split, idx), + } + + # Image and intrinsics + sample.update({ + 'rgb': read_image(filename), + 'intrinsics': self._get_intrinsics(filename, camera, self.mode), + }) + + # If returning pose + if self.with_pose: + sample['pose'] = self._get_pose(filename, camera, self.mode) + + # If returning depth + if self.with_depth: + sample['depth'] = self._get_depth(filename) + + # If returning input depth + if self.with_input_depth: + sample['input_depth'] = self._get_depth(filename) + + # If returning semantic + if self.with_semantic: + filename = self.semantic_tree.get_item(idx) + sample.update({'semantic': self._get_semantic(filename)}) + # If returning ontology + if self.return_ontology: + sample.update({'ontology': self._get_ontology(filename, self.mode)}) + + # If returning instance + if self.with_instance: + filename = self.instance_tree.get_item(idx) + sample.update({'instance': self._get_instance(filename)}) + + # If returning 3D bounding boxes + if self.with_bbox3d: + filename = self.rgb_tree.get_item(idx) + sample.update({ + 'bboxes3d': self._get_bbox3d(filename) + }) + + # If returning optical flow + if self.with_optical_flow: + sample['bwd_optical_flow'] = \ + dict_remove_nones(self._get_optical_flow(filename, 'bwd')) + sample['fwd_optical_flow'] = \ + dict_remove_nones(self._get_optical_flow(filename, 'fwd')) + + # If returning scene flow + if self.with_scene_flow: + sample['bwd_scene_flow'] = \ + dict_remove_nones(self._get_scene_flow(filename, 'bwd')) + sample['fwd_scene_flow'] = \ + dict_remove_nones(self._get_scene_flow(filename, 'fwd')) + + # If returning context information + if self.with_context: + + # Get context filenames + filename_context = self.rgb_tree.get_context(idx) + filename_context = {key: val.replace('Camera_0', 'Camera_{}'.format(camera)) + for key, val in filename_context.items()} + + # Get RGB context + sample['rgb'].update(read_image(filename_context)) + + # Get pose context + if self.with_pose: + sample['pose'].update(self._get_pose(filename_context, camera, self.mode)) + + # Get depth context + if self.with_depth_context: + sample['depth'].update(self._get_depth(filename_context)) + + # Get input depth context + if self.with_input_depth_context: + sample['input_depth'].update(self._get_depth(filename_context)) + + # Get semantic context + if self.with_semantic_context: + sample['semantic'].update(self._get_semantic(self.semantic_tree.get_context(idx))) + + # Get optical flow context + if self.with_optical_flow_context: + sample['bwd_optical_flow'].update( + dict_remove_nones(self._get_optical_flow(filename_context, 'bwd'))) + sample['fwd_optical_flow'].update( + dict_remove_nones(self._get_optical_flow(filename_context, 'fwd'))) + + # Get scene flow context + if self.with_scene_flow_context: + sample['bwd_scene_flow'].update( + dict_remove_nones(self._get_scene_flow(filename_context, 'bwd'))) + sample['fwd_scene_flow'].update( + dict_remove_nones(self._get_scene_flow(filename_context, 'fwd'))) + + # Stack sample + samples.append(sample) + + # Make relative poses + samples = make_relative_pose(samples) + + # Transform data + if self.data_transform: + 
samples = self.data_transform(samples) + + # Return stacked sample + return stack_sample(samples) + diff --git a/vidar/datasets/__init__.py b/vidar/datasets/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/datasets/augmentations/__init__.py b/vidar/datasets/augmentations/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/datasets/augmentations/crop.py b/vidar/datasets/augmentations/crop.py new file mode 100755 index 0000000000000000000000000000000000000000..36a18ee99b7b807ea066ff9a8c2992c218e4caa2 --- /dev/null +++ b/vidar/datasets/augmentations/crop.py @@ -0,0 +1,158 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from copy import deepcopy + +import numpy as np + +from vidar.utils.data import keys_with +from vidar.utils.decorators import iterate1 + + +@iterate1 +def crop_pil(image, borders): + """ + Crop a PIL Image + + Parameters + ---------- + image : PIL Image + Input image + borders : Tuple + Borders used for cropping (left, top, right, lower) + + Returns + ------- + image : PIL.Image + Cropped image + """ + return image.crop(borders) + + +@iterate1 +def crop_npy(depth, borders): + """ + Crop a numpy depth map + + Parameters + ---------- + depth : np.Array + Input numpy array + borders : Tuple + Borders used for cropping + + Returns + ------- + image : np.array + Cropped numpy array + """ + # Return if depth value is None + return depth[borders[1]:borders[3], borders[0]:borders[2]] + + +@iterate1 +def crop_intrinsics(intrinsics, borders): + """ + Crop camera intrinsics matrix + + Parameters + ---------- + intrinsics : np.Array + Original intrinsics matrix [3,3] + borders : Tuple + Borders used for cropping + Returns + ------- + intrinsics : np.Array + Cropped intrinsics matrix [3,3] + """ + intrinsics = np.copy(intrinsics) + intrinsics[0, 2] -= borders[0] + intrinsics[1, 2] -= borders[1] + return intrinsics + + +def crop_sample_input(sample, borders): + """ + Crops the input information of a sample + + Parameters + ---------- + sample : Dict + Dictionary with sample values + borders : Tuple + Borders used for cropping + + Returns + ------- + sample : Dict + Cropped sample + """ + # Intrinsics + for key in keys_with(sample, 'intrinsics', without='raw'): + # Create copy of full intrinsics + if f'raw_{key}' not in sample.keys(): + sample[f'raw_{key}'] = deepcopy(sample[key]) + sample[key] = crop_intrinsics(sample[key], borders) + # RGB + for key in keys_with(sample, 'rgb', without='raw'): + sample[key] = crop_pil(sample[key], borders) + # Mask + for key in keys_with(sample, 'mask', without='raw'): + sample[key] = crop_pil(sample[key], borders) + # Input depth + for key in keys_with(sample, 'input_depth'): + sample[key] = crop_npy(sample[key], borders) + # Return cropped sample + return sample + + +def crop_sample_supervision(sample, borders): + """ + Crops the output information of a sample + + Parameters + ---------- + sample : Dict + Dictionary with sample values + borders : Tuple + Borders used for cropping + + Returns + ------- + sample : Dict + Cropped sample + """ + for key in keys_with(sample, 'depth', without='input_depth'): + sample[key] = crop_npy(sample[key], borders) + for key in keys_with(sample, 'optical_flow'): + sample[key] = crop_npy(sample[key], borders) + for key in keys_with(sample, 'scene_flow'): + sample[key] = crop_npy(sample[key], borders) + # Return 
cropped sample + return sample + + +def crop_sample(sample, borders): + """ + Crops a sample, including image, intrinsics and depth maps. + + Parameters + ---------- + sample : Dict + Dictionary with sample values + borders : Tuple + Borders used for cropping + + Returns + ------- + sample : Dict + Cropped sample + """ + # Crop input information + sample = crop_sample_input(sample, borders) + # Crop output information + sample = crop_sample_supervision(sample, borders) + # Return cropped sample + return sample + + diff --git a/vidar/datasets/augmentations/image.py b/vidar/datasets/augmentations/image.py new file mode 100644 index 0000000000000000000000000000000000000000..a891a225a702ee9a573498f360cf4b0167e4ca74 --- /dev/null +++ b/vidar/datasets/augmentations/image.py @@ -0,0 +1,138 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import random + +import torch +import torchvision.transforms as transforms + +from vidar.utils.data import keys_in +from vidar.utils.decorators import iterate1 + + +def random_colorjitter(parameters): + """ + Creates a reusable color jitter transformation + + Parameters + ---------- + parameters : Tuple + Color jittering parameters (brightness, contrast, saturation, hue, color) + + Returns + ------- + transform : torchvision.Transform + Color jitter transformation with fixed parameters + """ + # Get and unpack values + brightness, contrast, saturation, hue = parameters + brightness = [max(0, 1 - brightness), 1 + brightness] + contrast = [max(0, 1 - contrast), 1 + contrast] + saturation = [max(0, 1 - saturation), 1 + saturation] + hue = [-hue, hue] + + # Initialize transformation list + all_transforms = [] + + # Add brightness transformation + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + all_transforms.append(transforms.Lambda( + lambda img: transforms.functional.adjust_brightness(img, brightness_factor))) + # Add contrast transformation + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + all_transforms.append(transforms.Lambda( + lambda img: transforms.functional.adjust_contrast(img, contrast_factor))) + # Add saturation transformation + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + all_transforms.append(transforms.Lambda( + lambda img: transforms.functional.adjust_saturation(img, saturation_factor))) + # Add hue transformation + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + all_transforms.append(transforms.Lambda( + lambda img: transforms.functional.adjust_hue(img, hue_factor))) + # Shuffle transformation order + random.shuffle(all_transforms) + # Return composed transformation + return transforms.Compose(all_transforms) + + +def colorjitter_sample(samples, parameters, background=None, prob=1.0): + """ + Jitters input images as data augmentation. + + Parameters + ---------- + samples : Dict + Input sample + parameters : tuple (brightness, contrast, saturation, hue, color) + Color jittering parameters + background: None or String + Which background color should be use + prob : Float + Jittering probability + + Returns + ------- + sample : dict + Jittered sample + """ + if random.random() < prob: + # Prepare jitter transformation + colorjitter_transform = random_colorjitter(parameters[:4]) + # Prepare color transformation if requested + if len(parameters) > 4 and parameters[4] > 0: + matrix = (random.uniform(1. 
- parameters[4], 1 + parameters[4]), 0, 0, 0, + 0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0, 0, + 0, 0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0) + else: + matrix = None + for sample in samples: + # Jitter sample keys + for key in keys_in(sample, ['rgb']): + for ctx in sample[key].keys(): + bkg, color = [], {'white': (255, 255, 255), 'black': (0, 0, 0)} + if background is not None: + for i in range(sample[key][ctx].size[0]): + for j in range(sample[key][ctx].size[1]): + if sample[key][ctx].getpixel((i,j)) == color[background]: + bkg.append((i,j)) + sample[key][ctx] = colorjitter_transform(sample[key][ctx]) + if matrix is not None: + sample[key][ctx] = sample[key][ctx].convert('RGB', matrix) + if background is not None: + for ij in bkg: + sample[key][ctx].putpixel(ij, color[background]) + # Return jittered (?) sample + return samples + + +@iterate1 +def normalize_sample(sample, mean, std): + """ + Normalize sample + + Parameters + ---------- + sample : Dict + Input sample dictionary + mean : torch.Tensor + Normalization mean [B,3] + std : torch.Tensor + Normalization standard deviation [B,3] + + Returns + ------- + sample : Dict + Normalized sample + """ + # Get mean and std values in the right shape + mean = torch.tensor(mean).reshape(3, 1, 1) + std = torch.tensor(std).reshape(3, 1, 1) + # Apply mean and std to every image + for key_sample in keys_in(sample, ['rgb']): + sample[key_sample] = {key:(val - mean) / std for + key, val in sample[key_sample].items()} + return sample diff --git a/vidar/datasets/augmentations/misc.py b/vidar/datasets/augmentations/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbf175195b7a9ed51dab7b6341e0ce5ea2dad1e --- /dev/null +++ b/vidar/datasets/augmentations/misc.py @@ -0,0 +1,98 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from copy import deepcopy + +import numpy as np + +from vidar.utils.data import keys_in +from vidar.utils.decorators import iterate1 + + +def duplicate_sample(sample, keys): + """ + Duplicates sample images and contexts to preserve their un-augmented versions. + + Parameters + ---------- + sample : Dict + Input sample + + Returns + ------- + sample : Dict + Sample including [+"_original"] keys with copies of images and contexts. + """ + for key in keys_in(sample, keys): + sample[f'raw_{key}'] = deepcopy(sample[key]) + # Return duplicated sample + return sample + + +@iterate1 +def mask_depth_number(depth, num_points): + """ + Mask depth map to remove valid pixels given the target number of points to keep. + + Parameters + ---------- + depth : np.Array + Depth map to be masked + num_points : Int + Number of input depth points that should be kept at each iteration + Returns + ------- + depth : np.Array + Masked depth map (modification done in-place!) + """ + # Find probability of maintaining + total_points = depth.shape[0] * depth.shape[1] + rnd = np.random.rand(depth.shape[0], depth.shape[1]) + percentile = 100 * num_points / total_points + # Mask depth map + mask = rnd < np.percentile(rnd, q=100 - percentile) + depth[mask] = 0.0 + # Return depth map + return depth + + +@iterate1 +def mask_depth_percentage(depth, percentage): + """ + Mask depth map to remove valid pixels given a range of percentages. + + Parameters + ---------- + depth : np.Array + Depth map to be masked + percentage : Tuple + Min/Max percentages to be maintained (min, max) + Returns + ------- + depth : np.Array + Masked depth map (modification done in-place!) 
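    Example
    -------
    Illustrative use on a dictionary of dense depth maps, keeping between 20% and 50% of the
    pixels of each map (shapes and values are arbitrary):
    >>> depth = {0: np.random.rand(192, 640) + 1.0}
    >>> sparse = mask_depth_percentage(depth, (0.2, 0.5))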
+ """ + # Find probability of maintaining + rnd = np.random.uniform(low=percentage[0], high=percentage[1], size=1)[0] + # Mask depth map + depth[np.random.rand(*depth.shape) > rnd] = 0.0 + # Return depth map + return depth + + +def clip_depth(sample, max_value): + """Clip depth map to a maximum range""" + for i in range(len(sample)): + if 'depth' in sample[i]: + for ctx in sample[i]['depth'].keys(): + sample[i]['depth'][ctx][sample[i]['depth'][ctx] > max_value] = max_value + return sample + + +def mask_depth_range(sample, depth_range): + """Mask out depth map within a range""" + for i in range(len(sample)): + if 'depth' in sample[i]: + for ctx in sample[i]['depth'].keys(): + sample[i]['depth'][ctx][sample[i]['depth'][ctx] < depth_range[0]] = 0.0 + sample[i]['depth'][ctx][sample[i]['depth'][ctx] > depth_range[1]] = 0.0 + return sample diff --git a/vidar/datasets/augmentations/resize.py b/vidar/datasets/augmentations/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..0deec59b69fa761d7aac94526d4f2e37d1b54b34 --- /dev/null +++ b/vidar/datasets/augmentations/resize.py @@ -0,0 +1,312 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from copy import deepcopy + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from torchvision.transforms import InterpolationMode + +from vidar.utils.data import keys_with +from vidar.utils.decorators import iterate1 +from vidar.utils.types import is_seq + + +@iterate1 +def resize_pil(image, shape, interpolation=InterpolationMode.LANCZOS): + """ + Resizes input image + + Parameters + ---------- + image : Image PIL + Input image + shape : Tuple + Output shape [H,W] + interpolation : Int + Interpolation mode + + Returns + ------- + image : Image PIL + Resized image + """ + transform = transforms.Resize(shape, interpolation=interpolation) + return transform(image) + + +@iterate1 +def resize_npy(depth, shape, expand=True): + """ + Resizes depth map + + Parameters + ---------- + depth : np.Array + Depth map [h,w] + shape : Tuple + Output shape (H,W) + expand : Bool + Expand output to [H,W,1] + + Returns + ------- + depth : np.Array + Resized depth map [H,W] + """ + # If a single number is provided, use resize ratio + if not is_seq(shape): + shape = tuple(int(s * shape) for s in depth.shape) + # Resize depth map + depth = cv2.resize(depth, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) + # Return resized depth map + return np.expand_dims(depth, axis=2) if expand else depth + + +@iterate1 +def resize_npy_preserve(depth, shape): + """ + Resizes depth map preserving all valid depth pixels + Multiple downsampled points can be assigned to the same pixel. 
+ + Parameters + ---------- + depth : np.Array + Depth map [h,w] + shape : Tuple + Output shape (H,W) + + Returns + ------- + depth : np.Array + Resized depth map [H,W,1] + """ + # If a single number is provided, use resize ratio + if not is_seq(shape): + shape = tuple(int(s * shape) for s in depth.shape) + # Store dimensions and reshapes to single column + depth = np.squeeze(depth) + h, w = depth.shape + x = depth.reshape(-1) + # Create coordinate grid + uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2) + # Filters valid points + idx = x > 0 + crd, val = uv[idx], x[idx] + # Downsamples coordinates + crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32) + crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32) + # Filters points inside image + idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1]) + crd, val = crd[idx], val[idx] + # Creates downsampled depth image and assigns points + depth = np.zeros(shape) + depth[crd[:, 0], crd[:, 1]] = val + # Return resized depth map + return np.expand_dims(depth, axis=2) + + +@iterate1 +def resize_torch_preserve(depth, shape): + """ + Resizes depth map preserving all valid depth pixels + Multiple downsampled points can be assigned to the same pixel. + + Parameters + ---------- + depth : torch.Tensor + Depth map [B,1,h,w] + shape : Tuple + Output shape (H,W) + + Returns + ------- + depth : torch.Tensor + Resized depth map [B,1,H,W] + """ + if depth.dim() == 4: + return torch.stack([resize_torch_preserve(depth[i], shape) + for i in range(depth.shape[0])], 0) + # If a single number is provided, use resize ratio + if not is_seq(shape): + shape = tuple(int(s * shape) for s in depth.shape) + # Store dimensions and reshapes to single column + c, h, w = depth.shape + # depth = np.squeeze(depth) + # h, w = depth.shape + x = depth.reshape(-1) + # Create coordinate grid + uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2) + # Filters valid points + idx = x > 0 + crd, val = uv[idx], x[idx] + # Downsamples coordinates + crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32) + crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32) + # Filters points inside image + idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1]) + crd, val = crd[idx], val[idx] + # Creates downsampled depth image and assigns points + depth = torch.zeros(shape, device=depth.device, dtype=depth.dtype) + depth[crd[:, 0], crd[:, 1]] = val + # Return resized depth map + return depth.unsqueeze(0) + + +@iterate1 +def resize_npy_multiply(data, shape): + """Resize a numpy array and scale its content accordingly""" + ratio_w = shape[0] / data.shape[0] + ratio_h = shape[1] / data.shape[1] + out = resize_npy(data, shape, expand=False) + out[..., 0] *= ratio_h + out[..., 1] *= ratio_w + return out + + +@iterate1 +def resize_intrinsics(intrinsics, original, resized): + """ + Resize camera intrinsics matrix to match a target resolution + + Parameters + ---------- + intrinsics : np.Array + Original intrinsics matrix [3,3] + original : Tuple + Original image resolution [W,H] + resized : Tuple + Target image resolution [w,h] + Returns + ------- + intrinsics : np.Array + Resized intrinsics matrix [3,3] + """ + intrinsics = np.copy(intrinsics) + intrinsics[0] *= resized[0] / original[0] + intrinsics[1] *= resized[1] / original[1] + return intrinsics + + +@iterate1 +def resize_sample_input(sample, shape, shape_supervision=None, + depth_downsample=1.0, preserve_depth=False, + pil_interpolation=InterpolationMode.LANCZOS): + """ + Resizes the input information of a sample + + Parameters + 
---------- + sample : Dict + Dictionary with sample values + shape : tuple (H,W) + Output shape + shape_supervision : Tuple + Output supervision shape (H,W) + depth_downsample: Float + Resize ratio for depth maps + preserve_depth : Bool + Preserve depth maps when resizing + pil_interpolation : Int + Interpolation mode + + Returns + ------- + sample : Dict + Resized sample + """ + # Intrinsics + for key in keys_with(sample, 'intrinsics', without='raw'): + if f'raw_{key}' not in sample.keys(): + sample[f'raw_{key}'] = deepcopy(sample[key]) + sample[key] = resize_intrinsics(sample[key], list(sample['rgb'].values())[0].size, shape[::-1]) + # RGB + for key in keys_with(sample, 'rgb', without='raw'): + sample[key] = resize_pil(sample[key], shape, interpolation=pil_interpolation) + # Mask + for key in keys_with(sample, 'mask', without='raw'): + sample[key] = resize_pil(sample[key], shape, interpolation=InterpolationMode.NEAREST) + # Input depth + for key in keys_with(sample, 'input_depth'): + shape_depth = [int(s * depth_downsample) for s in shape] + resize_npy_depth = resize_npy_preserve if preserve_depth else resize_npy + sample[key] = resize_npy_depth(sample[key], shape_depth) + return sample + + +@iterate1 +def resize_sample_supervision(sample, shape, depth_downsample=1.0, preserve_depth=False): + """ + Resizes the output information of a sample + + Parameters + ---------- + sample : Dict + Dictionary with sample values + shape : Tuple + Output shape (H,W) + depth_downsample: Float + Resize ratio for depth maps + preserve_depth : Bool + Preserve depth maps when resizing + + Returns + ------- + sample : Dict + Resized sample + """ + # Depth + for key in keys_with(sample, 'depth', without='input_depth'): + shape_depth = [int(s * depth_downsample) for s in shape] + resize_npy_depth = resize_npy_preserve if preserve_depth else resize_npy + sample[key] = resize_npy_depth(sample[key], shape_depth) + # Semantic + for key in keys_with(sample, 'semantic'): + sample[key] = resize_npy(sample[key], shape, expand=False) + # Optical flow + for key in keys_with(sample, 'optical_flow'): + sample[key] = resize_npy_multiply(sample[key], shape) + # Scene flow + for key in keys_with(sample, 'scene_flow'): + sample[key] = resize_npy(sample[key], shape, expand=False) + # Return resized sample + return sample + + +def resize_sample(sample, shape, shape_supervision=None, depth_downsample=1.0, preserve_depth=False, + pil_interpolation=InterpolationMode.LANCZOS): + """ + Resizes a sample, including image, intrinsics and depth maps. 
+ + Parameters + ---------- + sample : Dict + Dictionary with sample values + shape : Tuple + Output shape (H,W) + shape_supervision : Tuple + Output shape (H,W) + depth_downsample: Float + Resize ratio for depth maps + preserve_depth : Bool + Preserve depth maps when resizing + pil_interpolation : Int + Interpolation mode + + Returns + ------- + sample : Dict + Resized sample + """ + # Resize input information + sample = resize_sample_input(sample, shape, + depth_downsample=depth_downsample, + preserve_depth=preserve_depth, + pil_interpolation=pil_interpolation) + # Resize output information + sample = resize_sample_supervision(sample, shape_supervision, + depth_downsample=depth_downsample, + preserve_depth=preserve_depth) + # Return resized sample + return sample diff --git a/vidar/datasets/augmentations/tensor.py b/vidar/datasets/augmentations/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..c90bf7de2bdef5d5d3f1773ecdedf21645c54755 --- /dev/null +++ b/vidar/datasets/augmentations/tensor.py @@ -0,0 +1,53 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torchvision.transforms as transforms + +from vidar.utils.decorators import iterate1 + + +@iterate1 +def to_tensor(matrix, tensor_type='torch.FloatTensor'): + """Casts a matrix to a torch.Tensor""" + return torch.Tensor(matrix).type(tensor_type) + + +@iterate1 +def to_tensor_image(image, tensor_type='torch.FloatTensor'): + """Casts an image to a torch.Tensor""" + transform = transforms.ToTensor() + return transform(image).type(tensor_type) + + +@iterate1 +def to_tensor_sample(sample, tensor_type='torch.FloatTensor'): + """ + Casts the keys of sample to tensors. + + Parameters + ---------- + sample : Dict + Input sample + tensor_type : String + Type of tensor we are casting to + + Returns + ------- + sample : Dict + Sample with keys cast as tensors + """ + # Convert using torchvision + keys = ['rgb', 'mask', 'input_depth', 'depth', 'disparity', + 'optical_flow', 'scene_flow'] + for key_sample, val_sample in sample.items(): + for key in keys: + if key in key_sample: + sample[key_sample] = to_tensor_image(val_sample, tensor_type) + # Convert from numpy + keys = ['intrinsics', 'extrinsics', 'pose', 'pointcloud', 'semantic'] + for key_sample, val_sample in sample.items(): + for key in keys: + if key in key_sample: + sample[key_sample] = to_tensor(val_sample, tensor_type) + # Return converted sample + return sample diff --git a/vidar/datasets/utils/FolderTree.py b/vidar/datasets/utils/FolderTree.py new file mode 100644 index 0000000000000000000000000000000000000000..667a86052903187dd9cb982951a00be674e50c30 --- /dev/null +++ b/vidar/datasets/utils/FolderTree.py @@ -0,0 +1,132 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os +from glob import glob + +import numpy as np + +from vidar.utils.data import make_list +from vidar.utils.types import is_list + + +class FolderTree: + """ + Creates a dataset tree folder structure for file loading. 
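    Each scene folder contributes a contiguous block of indices: per-folder file counts (minus the
    context window) are accumulated into `slices`, so a flat dataset index maps back to a
    (folder, frame) pair and context frames are read from the same folder at fixed offsets.
    Typical use (paths are illustrative):
    >>> tree = FolderTree('/data/scannet/scans', sub_folders=['color'],
    ...                   suffix='.jpg', context=(-1, 1))
    >>> target = tree.get_item(10)      # {0: filename}
    >>> frames = tree.get_context(10)   # {-1: ..., 0: ..., 1: ...} filenames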
+ + Parameters + ---------- + path : String + Path where dataset is stored + prefix : String + Optional prefix for each filename + suffix : String + Optional prefix for each filename + sub_folders : list[String] + Optional list of sub_folders located inside each data folder where data is stored + deep : Int + How deep in the folder structure we should go + single_folder : Bool + Whether the folder structure should be ignored, and a single folder is used + nested : bool + If true, go one folder deeper to find scenes + stride: Int + Stride for context generation + context : Tuple + Which context should be used + """ + def __init__(self, path, prefix='', suffix='', sub_folders=('',), deep=1, + single_folder=False, nested=False, stride=1, context=()): + + # Store context information + self.context = list(context) + if 0 not in self.context: + self.context.append(0) + self.num_context = 0 if len(self.context) == 0 else max(self.context) - min(self.context) + self.with_context = self.num_context > 0 + self.min_context = 0 if not self.with_context else min(self.context) + + self.stride = stride + self.pad_numbers = False + + # Initialize empty folder tree + self.folder_tree = [] + + # If we are providing a file list, treat each line as a scene + if is_list(path): + self.folder_tree = [[file] for file in path] + # If we are providing a folder + else: + # Get folders + string = '*' + '/*' * (deep - 1) + folders = glob(os.path.join(path, string)) + folders.sort() + + # If nesting, go one folder deeper in order to find the scenes + if nested: + upd_folders = [] + for folder in folders: + new_folders = glob(os.path.join(folder, '*')) + upd_folders.extend(new_folders) + folders = upd_folders + folders.sort() + + if single_folder: + # Use current folder as the only one + self.folder_tree.append(folders) + else: + # Populate folder tree + for folder in folders: + # For each sub-folder + for sub_folder in make_list(sub_folders): + # Get and sort files in each folder + files = glob(os.path.join(folder, sub_folder, '{}*{}'.format(prefix, suffix))) + if self.pad_numbers: + for i in range(len(files)): + pref, suf = files[i].split('/')[:-1], files[i].split('/')[-1] + num, ext = suf.split('.') + files[i] = '/'.join(pref) + ('/%010d' % int(num)) + '.' + ext + files.sort() + if self.pad_numbers: + for i in range(len(files)): + pref, suf = files[i].split('/')[:-1], files[i].split('/')[-1] + num, ext = suf.split('.') + files[i] = '/'.join(pref) + ('/%d' % int(num)) + '.' 
+ ext + if self.stride > 1: + files = files[::self.stride] + # Only store if there are more images than context + if len(files) > self.num_context: + self.folder_tree.append(files) + + # Get size of each folder + self.slices = [len(folder) for folder in self.folder_tree] + # Compensate for context size + if self.with_context: + self.slices = [s - self.num_context for s in self.slices] + # Create cumulative size and get total + self.slices = [0] + list(np.cumsum(self.slices)) + self.total = self.slices[-1] + + def __len__(self): + """Dataset size""" + return self.total + + def get_idxs(self, idx): + """Get folder and file indexes given dataset index""" + idx1 = np.searchsorted(self.slices, idx, side='right') - 1 + idx2 = idx - self.slices[idx1] + return idx1, idx2 + + def get_item(self, idx, return_loc=False): + """Return filename item given index""" + idx1, idx2 = self.get_idxs(idx) + item = {0: self.folder_tree[idx1][idx2 - self.min_context]} + if return_loc: + return item, idx2 - self.min_context + else: + return item + + def get_context(self, idx): + """Return forward context given index.""" + idx1, idx2 = self.get_idxs(idx) + return {ctx: self.folder_tree[idx1][idx2 - self.min_context + ctx] for ctx in self.context} + diff --git a/vidar/datasets/utils/__init__.py b/vidar/datasets/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vidar/datasets/utils/misc.py b/vidar/datasets/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..67b3fd2d7ffe0f32327c9d50c79cde5700b9ec77 --- /dev/null +++ b/vidar/datasets/utils/misc.py @@ -0,0 +1,322 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import json +import os +import random +from collections import OrderedDict + +import numpy as np +import torch + +from vidar.utils.decorators import iterate1 +from vidar.utils.types import is_seq, is_tensor, is_dict, is_int + + +def stack_sample(sample, lidar_sample=None, radar_sample=None): + """ + Stack samples from multiple cameras + + Parameters + ---------- + sample : list[Dict] + List of camera samples + lidar_sample : list[Dict] + List of lidar samples + radar_sample : list[Dict] + List of radar samples + + Returns + ------- + stacked_sample: Dict + Stacked sample + """ + # If there are no tensors, return empty list + if len(sample) == 0: + return None + # If there is only one sensor don't do anything + if len(sample) == 1: + sample = sample[0] + return sample + # Otherwise, stack sample + first_sample = sample[0] + stacked_sample = {} + for key, val in first_sample.items(): + # Global keys (do not stack) + if key in ['idx', 'dataset_idx']: + stacked_sample[key] = first_sample[key] + # Meta keys + elif key in ['meta']: + stacked_sample[key] = {} + for key2 in first_sample[key].keys(): + stacked_sample[key][key2] = {} + for key3 in first_sample[key][key2].keys(): + stacked_sample[key][key2][key3] = torch.stack( + [torch.tensor(s[key][key2][key3]) for s in sample], 0) + # Stack tensors + elif is_tensor(val): + stacked_sample[key] = torch.stack([s[key] for s in sample], 0) + # Stack list + elif is_seq(first_sample[key]): + stacked_sample[key] = [] + # Stack list of torch tensors + if is_tensor(first_sample[key][0]): + for i in range(len(first_sample[key])): + stacked_sample[key].append( + torch.stack([s[key][i] for s in sample], 0)) + else: + stacked_sample[key] = [s[key] for s in sample] + # Repeat for dictionaries + elif is_dict(first_sample[key]): + 
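            # Recurse on nested dictionaries (e.g. per-context 'rgb'/'depth' entries),
            # stacking each sub-key across cameras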
stacked_sample[key] = stack_sample([s[key] for s in sample]) + # Append lists + else: + stacked_sample[key] = [s[key] for s in sample] + + # Return stacked sample + return stacked_sample + + +def merge_sample(samples): + """Merge information from multiple samples""" + merged_sample = {} + for sample in samples: + for key, val in sample.items(): + if key not in merged_sample: + merged_sample[key] = val + else: + merged_sample[key] = merge_sample([merged_sample[key], val]) + return merged_sample + + +def parse_crop(cfg, shape): + """Parse crop information to generate borders""" + borders = None + if cfg.has('crop_borders'): + borders = parse_crop_borders(cfg.crop_borders, shape) + if cfg.has('crop_random'): + if borders is None: + borders = [0, 0, shape[1], shape[0]] + borders = parse_crop_random(borders, cfg.crop_random) + return borders + + +def parse_crop_borders(borders, shape): + """ + Calculate borders for cropping. + + Parameters + ---------- + borders : Tuple + Border input for parsing. Can be one of the following forms: + (int, int, int, int): y, height, x, width + (int, int): y, x --> y, height = image_height - y, x, width = image_width - x + Negative numbers are taken from image borders, according to the shape argument + Float numbers for y and x are treated as percentage, according to the shape argument, + and in this case height and width are centered at that point. + shape : Tuple + Image shape (image_height, image_width), used to determine negative crop boundaries + + Returns + ------- + borders : Tuple + Parsed borders for cropping (left, top, right, bottom) + """ + # Return full image if there are no borders to crop + if len(borders) == 0: + return 0, 0, shape[1], shape[0] + # Copy borders for modification + borders = list(borders).copy() + # If borders are 4-dimensional + if len(borders) == 4: + borders = [borders[2], borders[0], borders[3], borders[1]] + if is_int(borders[0]): + # If horizontal cropping is integer (regular cropping) + borders[0] += shape[1] if borders[0] < 0 else 0 + borders[2] += shape[1] if borders[2] <= 0 else borders[0] + else: + # If horizontal cropping is float (center cropping) + center_w, half_w = borders[0] * shape[1], borders[2] / 2 + borders[0] = int(center_w - half_w) + borders[2] = int(center_w + half_w) + if is_int(borders[1]): + # If vertical cropping is integer (regular cropping) + borders[1] += shape[0] if borders[1] < 0 else 0 + borders[3] += shape[0] if borders[3] <= 0 else borders[1] + else: + # If vertical cropping is float (center cropping) + center_h, half_h = borders[1] * shape[0], borders[3] / 2 + borders[1] = int(center_h - half_h) + borders[3] = int(center_h + half_h) + # If borders are 2-dimensional + elif len(borders) == 2: + borders = [borders[1], borders[0]] + if is_int(borders[0]): + # If cropping is integer (regular cropping) + borders = (max(0, borders[0]), + max(0, borders[1]), + shape[1] + min(0, borders[0]), + shape[0] + min(0, borders[1])) + else: + # If cropping is float (center cropping) + center_w, half_w = borders[0] * shape[1], borders[1] / 2 + center_h, half_h = borders[0] * shape[0], borders[1] / 2 + borders = (int(center_w - half_w), int(center_h - half_h), + int(center_w + half_w), int(center_h + half_h)) + # Otherwise, invalid + else: + raise NotImplementedError('Crop tuple must have 2 or 4 values.') + # Assert that borders are valid + assert 0 <= borders[0] < borders[2] <= shape[1] and \ + 0 <= borders[1] < borders[3] <= shape[0], 'Crop borders {} are invalid'.format(borders) + # Return updated borders + 
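    # e.g. borders=(100, 200, 50, 300) interpreted as (y, height, x, width) on a (375, 1242)
    # image parses to (left, top, right, bottom) = (50, 100, 350, 300)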
return borders + + +def parse_crop_random(borders, shape): + """ + Create borders for random cropping. + Crops are generated anywhere in the image inside the borders + + Parameters + ---------- + borders : Tuple + Area of the image where random cropping can happen (left, top, right, bottom) + shape : Tuple + Cropped output shape (height, width) + + Returns + ------- + borders : tuple + Parsed borders for cropping (left, top, right, bottom) + """ + # Return full borders if there is no random crop + if len(shape) == 0: + return borders + # Check if random crop is valid + assert 0 < shape[1] <= borders[2] - borders[0] and \ + 0 < shape[0] <= borders[3] - borders[1], 'Random crop must be smaller than the image' + # Sample a crop + x = random.randint(borders[0], borders[2] - shape[1]) + y = random.randint(borders[1], borders[3] - shape[0]) + # Return new borders + return x, y, x + shape[1], y + shape[0] + + +@iterate1 +def invert_pose(pose): + """ + Inverts a transformation matrix (pose) + + Parameters + ---------- + pose : np.Array + Input pose [4, 4] + + Returns + ------- + inv_pose : np.Array + Inverted pose [4, 4] + """ + inv_pose = np.eye(4) + inv_pose[:3, :3] = np.transpose(pose[:3, :3]) + inv_pose[:3, -1] = - inv_pose[:3, :3] @ pose[:3, -1] + # return np.linalg.inv(pose) + return inv_pose + + +def make_relative_pose(samples): + """ + Convert sample poses to relative frane of reference (based on the first target frame) + + Parameters + ---------- + samples : list[Dict] + Input samples + + Returns + ------- + samples : list[Dict] + Relative samples + """ + # Do nothing if there is no pose + if 'pose' not in samples[0]: + return samples + # Get inverse current poses + inv_pose = [invert_pose(samples[i]['pose'][0]) for i in range(len(samples))] + # For each camera + for i in range(len(samples)): + # For each context + for j in samples[0]['pose'].keys(): + if j == 0: + if i > 0: + samples[i]['pose'][j] = \ + samples[i]['pose'][j] @ inv_pose[0] + else: + samples[i]['pose'][j] = \ + samples[i]['pose'][j] @ inv_pose[i] + return samples + + +def dummy_intrinsics(image): + """ + Return dummy intrinsics calculated based on image resolution + + Parameters + ---------- + image : PIL Image + Image from which intrinsics will be calculated + + Returns + ------- + intrinsics : np.Array + Image intrinsics (fx = cx = w/2, fy = cy = h/2) [3,3] + """ + w, h = [float(d) for d in image.size] + return np.array([[w/2, 0., w/2. - 0.5], + [0., h/2, h/2. 
- 0.5], + [0., 0., 1.]]) + + +def load_ontology(name, filter_list=None): + """Loads ontology from file and optionally filters it""" + filename = 'vidar/ontologies/{}.json'.format(name) + if os.path.exists(filename): + ontology = json.load(open(filename, 'r')) + if filter_list is not None and len(filter_list) > 0: + ontology = filter_ontology(ontology, filter_list) + return ontology + else: + return None + + +def save_ontology(ontology, name): + """Save ontology to a JSON file""" + if is_seq(ontology): + ontology = ontology[0] + for key in ontology.keys(): + ontology[key]['color'] = ontology[key]['color'].tolist() + json.dump(ontology, open('ontologies/{}.json'.format(name), 'w')) + + +def filter_ontology(ontology, values): + """Filter ontology to remove certain classes""" + new_ontology = OrderedDict() + for i, val in enumerate(values[1:]): + new_ontology[i] = ontology[str(val)] + return new_ontology + + +def convert_ontology(semantic_id, ontology_convert): + """Convert from one ontology to another""" + if ontology_convert is None: + return semantic_id + else: + semantic_id_convert = semantic_id.copy() + for key, val in ontology_convert.items(): + semantic_id_convert[semantic_id == key] = val + return semantic_id_convert + + +def initialize_ontology(base, ontology): + """Initialize ontology and conversion table if necessary""" + return load_ontology(base), None diff --git a/vidar/datasets/utils/transforms.py b/vidar/datasets/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..bd48f67f38494f6e48116b047d3f8434351f5478 --- /dev/null +++ b/vidar/datasets/utils/transforms.py @@ -0,0 +1,98 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from functools import partial + +from vidar.datasets.augmentations.image import \ + colorjitter_sample, normalize_sample +from vidar.datasets.augmentations.crop import \ + crop_sample_input, crop_sample +from vidar.datasets.augmentations.misc import \ + duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range +from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input +from vidar.datasets.augmentations.tensor import to_tensor_sample +from vidar.datasets.utils.misc import parse_crop +from vidar.utils.types import is_list + + +def train_transforms(sample, cfg): + """ + Training data augmentation transformations + + Parameters + ---------- + sample : Dict + Sample to be augmented + cfg : Config + Configuration for transformations + + Returns + ------- + sample : Dict + Augmented sample + """ + # Resize + if cfg.has('resize'): + resize_fn = resize_sample if cfg.has('resize_supervision') else resize_sample_input + shape_supervision = None if not cfg.has('resize_supervision') else \ + cfg.resize if not is_list(cfg.resize_supervision) else cfg.resize_supervision + sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision, + depth_downsample=cfg.has('depth_downsample', 1.0), + preserve_depth=cfg.has('preserve_depth', False)) + # Crop + if cfg.has('crop_borders') or cfg.has('crop_random'): + crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input + sample = [crop_fn(s, parse_crop(cfg, s['rgb'][0].size[::-1])) for s in sample] + # Clip depth to a maximum value + if cfg.has('clip_depth'): + sample = clip_depth(sample, cfg.clip_depth) + if cfg.has('mask_depth_range'): + sample = mask_depth_range(sample, cfg.mask_depth_range) + # Change input depth density + if 'input_depth' in sample: + if 
cfg.has('input_depth_number'): + sample['input_depth'] = mask_depth_number( + sample['input_depth'], cfg.input_depth_number) + if cfg.has('input_depth_percentage'): + sample['input_depth'] = mask_depth_percentage( + sample['input_depth'], cfg.input_depth_percentage) + # Apply jittering + if cfg.has('jittering'): + sample = duplicate_sample(sample, ['rgb']) + sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0) + # Convert to tensor + sample = to_tensor_sample(sample) + if cfg.has('normalization'): + sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1]) + # Return augmented sample + return sample + + +def no_transform(sample): + """No transformation, only convert sample to tensors""" + sample = to_tensor_sample(sample) + return sample + + +def get_transforms(mode, cfg=None): + """ + Get data augmentation transformations for each split + + Parameters + ---------- + mode : String {'train', 'validation', 'test'} + Mode from which we want the data augmentation transformations + cfg : Config + Configuration file + + Returns + ------- + XXX_transform: Partial function + Data augmentation transformation for that mode + """ + if mode == 'train': + return partial(train_transforms, cfg=cfg) + elif mode == 'none': + return partial(no_transform) + else: + raise ValueError('Unknown mode {}'.format(mode)) + diff --git a/vidar/geometry/camera.py b/vidar/geometry/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba80045ad235422ccd5aa651210c86fef102d62 --- /dev/null +++ b/vidar/geometry/camera.py @@ -0,0 +1,528 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from abc import ABC +from copy import deepcopy + +import torch +import torch.nn as nn +from einops import rearrange + +from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics +from vidar.geometry.pose import Pose +from vidar.geometry.pose_utils import invert_pose +from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave +from vidar.utils.types import is_tensor, is_seq + + +class Camera(nn.Module, ABC): + """ + Camera class for 3D reconstruction + + Parameters + ---------- + K : torch.Tensor + Camera intrinsics [B,3,3] + hw : Tuple + Camera height and width + Twc : Pose or torch.Tensor + Camera pose (world to camera) [B,4,4] + Tcw : Pose or torch.Tensor + Camera pose (camera to world) [B,4,4] + """ + def __init__(self, K, hw, Twc=None, Tcw=None): + super().__init__() + + # Asserts + + assert Twc is None or Tcw is None + + # Fold if multi-batch + + if K.dim() == 4: + K = rearrange(K, 'b n h w -> (b n) h w') + if Twc is not None: + Twc = rearrange(Twc, 'b n h w -> (b n) h w') + if Tcw is not None: + Tcw = rearrange(Tcw, 'b n h w -> (b n) h w') + + # Intrinsics + + if same_shape(K.shape[-2:], (3, 3)): + self._K = torch.eye(4, dtype=K.dtype, device=K.device).repeat(K.shape[0], 1, 1) + self._K[:, :3, :3] = K + else: + self._K = K + + # Pose + + if Twc is None and Tcw is None: + self._Twc = torch.eye(4, dtype=K.dtype, device=K.device).unsqueeze(0).repeat(K.shape[0], 1, 1) + else: + self._Twc = invert_pose(Tcw) if Tcw is not None else Twc + if is_tensor(self._Twc): + self._Twc = Pose(self._Twc) + + # Resolution + + self._hw = hw + if is_tensor(self._hw): + self._hw = self._hw.shape[-2:] + + def __getitem__(self, idx): + """Return batch-wise pose""" + if is_seq(idx): + return type(self).from_list([self.__getitem__(i) for i in idx]) + else: + return type(self)( + 
K=self._K[[idx]], + Twc=self._Twc[[idx]] if self._Twc is not None else None, + hw=self._hw, + ) + + def __len__(self): + """Return length as intrinsics batch""" + return self._K.shape[0] + + def __eq__(self, cam): + """Check if two cameras are the same""" + if not isinstance(cam, type(self)): + return False + if self._hw[0] != cam.hw[0] or self._hw[1] != cam.hw[1]: + return False + if not torch.allclose(self._K, cam.K): + return False + if not torch.allclose(self._Twc.T, cam.Twc.T): + return False + return True + + def clone(self): + """Return a copy of this camera""" + return deepcopy(self) + + @property + def pose(self): + """Return camera pose (world to camera)""" + return self._Twc.T + + @property + def K(self): + """Return camera intrinsics""" + return self._K + + @K.setter + def K(self, K): + """Set camera intrinsics""" + self._K = K + + @property + def invK(self): + """Return inverse of camera intrinsics""" + return invert_intrinsics(self._K) + + @property + def batch_size(self): + """Return batch size""" + return self._Twc.T.shape[0] + + @property + def hw(self): + """Return camera height and width""" + return self._hw + + @hw.setter + def hw(self, hw): + """Set camera height and width""" + self._hw = hw + + @property + def wh(self): + """Get camera width and height""" + return self._hw[::-1] + + @property + def n_pixels(self): + """Return number of pixels""" + return self._hw[0] * self._hw[1] + + @property + def fx(self): + """Return horizontal focal length""" + return self._K[:, 0, 0] + + @property + def fy(self): + """Return vertical focal length""" + return self._K[:, 1, 1] + + @property + def cx(self): + """Return horizontal principal point""" + return self._K[:, 0, 2] + + @property + def cy(self): + """Return vertical principal point""" + return self._K[:, 1, 2] + + @property + def fxy(self): + """Return focal length""" + return torch.tensor([self.fx, self.fy], dtype=self.dtype, device=self.device) + + @property + def cxy(self): + """Return principal points""" + return self._K[:, :2, 2] + # return torch.tensor([self.cx, self.cy], dtype=self.dtype, device=self.device) + + @property + def Tcw(self): + """Return camera pose (camera to world)""" + return None if self._Twc is None else self._Twc.inverse() + + @Tcw.setter + def Tcw(self, Tcw): + """Set camera pose (camera to world)""" + self._Twc = Tcw.inverse() + + @property + def Twc(self): + """Return camera pose (world to camera)""" + return self._Twc + + @Twc.setter + def Twc(self, Twc): + """Set camera pose (world to camera)""" + self._Twc = Twc + + @property + def dtype(self): + """Return tensor type""" + return self._K.dtype + + @property + def device(self): + """Return device""" + return self._K.device + + def detach_pose(self): + """Detach pose from the graph""" + return type(self)(K=self._K, hw=self._hw, + Twc=self._Twc.detach() if self._Twc is not None else None) + + def detach_K(self): + """Detach intrinsics from the graph""" + return type(self)(K=self._K.detach(), hw=self._hw, Twc=self._Twc) + + def detach(self): + """Detach camera from the graph""" + return type(self)(K=self._K.detach(), hw=self._hw, + Twc=self._Twc.detach() if self._Twc is not None else None) + + def inverted_pose(self): + """Invert camera pose""" + return type(self)(K=self._K, hw=self._hw, + Twc=self._Twc.inverse() if self._Twc is not None else None) + + def no_translation(self): + """Return new camera without translation""" + Twc = self.pose.clone() + Twc[:, :-1, -1] = 0 + return type(self)(K=self._K, hw=self._hw, Twc=Twc) + + @staticmethod + def 
from_dict(K, hw, Twc=None): + """Create cameras from a pose dictionary""" + return {key: Camera(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()} + + # @staticmethod + # def from_dict(K, hw, Twc=None): + # return {key: Camera(K=K[(0, 0)], hw=hw[(0, 0)], Twc=val) for key, val in Twc.items()} + + @staticmethod + def from_list(cams): + """Create cameras from a list""" + K = torch.cat([cam.K for cam in cams], 0) + Twc = torch.cat([cam.Twc.T for cam in cams], 0) + return Camera(K=K, Twc=Twc, hw=cams[0].hw) + + def scaled(self, scale_factor): + """Return a scaled camera""" + if scale_factor is None or scale_factor == 1: + return self + if is_seq(scale_factor): + if len(scale_factor) == 4: + scale_factor = scale_factor[-2:] + scale_factor = [float(scale_factor[i]) / float(self._hw[i]) for i in range(2)] + else: + scale_factor = [scale_factor] * 2 + return type(self)( + K=scale_intrinsics(self._K, scale_factor), + hw=[int(self._hw[i] * scale_factor[i]) for i in range(len(self._hw))], + Twc=self._Twc + ) + + def offset_start(self, start): + """Offset camera intrinsics based on a crop""" + new_cam = self.clone() + start = start.to(self.device) + new_cam.K[:, 0, 2] -= start[:, 1] + new_cam.K[:, 1, 2] -= start[:, 0] + return new_cam + + def interpolate(self, rgb): + """Interpolate an image to fit the camera""" + if rgb.dim() == 5: + rgb = rearrange(rgb, 'b n c h w -> (b n) c h w') + return interpolate(rgb, scale_factor=None, size=self.hw, mode='bilinear', align_corners=True) + + def interleave_K(self, b): + """Interleave camera intrinsics to fit multiple batches""" + return type(self)( + K=interleave(self._K, b), + Twc=self._Twc, + hw=self._hw, + ) + + def interleave_Twc(self, b): + """Interleave camera pose to fit multiple batches""" + return type(self)( + K=self._K, + Twc=interleave(self._Twc, b), + hw=self._hw, + ) + + def interleave(self, b): + """Interleave camera to fit multiple batches""" + return type(self)( + K=interleave(self._K, b), + Twc=interleave(self._Twc, b), + hw=self._hw, + ) + + def Pwc(self, from_world=True): + """Return projection matrix""" + return self._K[:, :3] if not from_world or self._Twc is None else \ + torch.matmul(self._K, self._Twc.T)[:, :3] + + def to_world(self, points): + """Transform points to world coordinates""" + if points.dim() > 3: + points = points.reshape(points.shape[0], 3, -1) + return points if self.Tcw is None else self.Tcw * points + + def from_world(self, points): + """Transform points back to camera coordinates""" + if points.dim() > 3: + points = points.reshape(points.shape[0], 3, -1) + return points if self._Twc is None else \ + torch.matmul(self._Twc.T, cat_channel_ones(points, 1))[:, :3] + + def to(self, *args, **kwargs): + """Copy camera to device""" + self._K = self._K.to(*args, **kwargs) + if self._Twc is not None: + self._Twc = self._Twc.to(*args, **kwargs) + return self + + def cuda(self, *args, **kwargs): + """Copy camera to CUDA""" + return self.to('cuda') + + def relative_to(self, cam): + """Create a new camera relative to another camera""" + return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc.inverse()) + + def global_from(self, cam): + """Create a new camera in global coordinates relative to another camera""" + return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc) + + def reconstruct_depth_map(self, depth, to_world=False): + """ + Reconstruct a depth map from the camera viewpoint + + Parameters + ---------- + depth : torch.Tensor + Input depth map [B,1,H,W] + to_world : Bool + Transform points to world coordinates 
+ + Returns + ------- + points : torch.Tensor + Output 3D points [B,3,H,W] + """ + if depth is None: + return None + b, _, h, w = depth.shape + grid = pixel_grid(depth, with_ones=True, device=depth.device).view(b, 3, -1) + points = depth.view(b, 1, -1) * torch.matmul(self.invK[:, :3, :3], grid) + if to_world and self.Tcw is not None: + points = self.Tcw * points + return points.view(b, 3, h, w) + + def reconstruct_cost_volume(self, volume, to_world=True, flatten=True): + """ + Reconstruct a cost volume from the camera viewpoint + + Parameters + ---------- + volume : torch.Tensor + Input depth map [B,1,D,H,W] + to_world : Bool + Transform points to world coordinates + flatten: Bool + Flatten volume points + + Returns + ------- + points : torch.Tensor + Output 3D points [B,3,D,H,W] + """ + c, d, h, w = volume.shape + grid = pixel_grid((h, w), with_ones=True, device=volume.device).view(3, -1).repeat(1, d) + points = torch.stack([ + (volume.view(c, -1) * torch.matmul(invK[:3, :3].unsqueeze(0), grid)).view(3, d * h * w) + for invK in self.invK], 0) + if to_world and self.Tcw is not None: + points = self.Tcw * points + if flatten: + return points.view(-1, 3, d, h * w).permute(0, 2, 1, 3) + else: + return points.view(-1, 3, d, h, w) + + def project_points(self, points, from_world=True, normalize=True, return_z=False): + """ + Project points back to image plane + + Parameters + ---------- + points : torch.Tensor + Input 3D points [B,3,H,W] or [B,3,N] + from_world : Bool + Whether points are in the global frame + normalize : Bool + Whether projections should be normalized to [-1,1] + return_z : Bool + Whether projected depth is return as well + + Returns + ------- + coords : torch.Tensor + Projected 2D coordinates [B,2,H,W] + depth : torch.Tensor + Projected depth [B,1,H,W] + """ + is_depth_map = points.dim() == 4 + hw = self._hw if not is_depth_map else points.shape[-2:] + + if is_depth_map: + points = points.reshape(points.shape[0], 3, -1) + b, _, n = points.shape + + points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1)) + + coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7) + depth = points[:, 2] + + if not is_depth_map: + if normalize: + coords = norm_pixel_grid(coords, hw=self._hw, in_place=True) + invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ + (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth < 0) + coords[invalid.unsqueeze(1).repeat(1, 2, 1)] = -2 + if return_z: + return coords.permute(0, 2, 1), depth + else: + return coords.permute(0, 2, 1) + + coords = coords.view(b, 2, *hw) + depth = depth.view(b, 1, *hw) + + if normalize: + coords = norm_pixel_grid(coords, hw=self._hw, in_place=True) + invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ + (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth[:, 0] < 0) + coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 + + if return_z: + return coords.permute(0, 2, 3, 1), depth + else: + return coords.permute(0, 2, 3, 1) + + def project_cost_volume(self, points, from_world=True, normalize=True): + """ + Project points back to image plane + + Parameters + ---------- + points : torch.Tensor + Input 3D points [B,3,H,W] or [B,3,N] + from_world : Bool + Whether points are in the global frame + normalize : Bool + Whether projections should be normalized to [-1,1] + + Returns + ------- + coords : torch.Tensor + Projected 2D coordinates [B,2,H,W] + """ + if points.dim() == 4: + points = points.permute(0, 2, 1, 3).reshape(points.shape[0], 3, -1) + b, _, n = points.shape + + points = torch.matmul(self.Pwc(from_world), 
cat_channel_ones(points, 1)) + + coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7) + coords = coords.view(b, 2, -1, *self._hw).permute(0, 2, 3, 4, 1) + + if normalize: + coords[..., 0] /= self._hw[1] - 1 + coords[..., 1] /= self._hw[0] - 1 + return 2 * coords - 1 + else: + return coords + + def coords_from_cost_volume(self, volume, ref_cam=None): + """ + Get warp coordinates from a cost volume + + Parameters + ---------- + volume : torch.Tensor + Input cost volume [B,1,D,H,W] + ref_cam : Camera + Optional to generate cross-camera coordinates + + Returns + ------- + coords : torch.Tensor + Projected 2D coordinates [B,2,H,W] + """ + if ref_cam is None: + return self.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=False), from_world=True) + else: + return ref_cam.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=True), from_world=True) + + def coords_from_depth(self, depth, ref_cam=None): + """ + Get warp coordinates from a depth map + + Parameters + ---------- + depth : torch.Tensor + Input cost volume [B,1,D,H,W] + ref_cam : Camera + Optional to generate cross-camera coordinates + + Returns + ------- + coords : torch.Tensor + Projected 2D coordinates [B,2,H,W] + """ + if ref_cam is None: + return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True) + else: + return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True) diff --git a/vidar/geometry/camera_ds.py b/vidar/geometry/camera_ds.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb39de6560c4dab67cd04fcd779b24545b0946e --- /dev/null +++ b/vidar/geometry/camera_ds.py @@ -0,0 +1,213 @@ +from functools import lru_cache +import torch +import torch.nn as nn + +from vidar.geometry.pose import Pose +from vidar.utils.tensor import pixel_grid + +######################################################################################################################## + +class DSCamera(nn.Module): + """ + Differentiable camera class implementing reconstruction and projection + functions for the double sphere (DS) camera model. 
+ """ + def __init__(self, I, Tcw=None): + """ + Initializes the Camera class + + Parameters + ---------- + I : torch.Tensor [6] + Camera intrinsics parameter vector + Tcw : Pose + Camera -> World pose transformation + """ + super().__init__() + self.I = I + if Tcw is None: + self.Tcw = Pose.identity(len(I)) + elif isinstance(Tcw, Pose): + self.Tcw = Tcw + else: + self.Tcw = Pose(Tcw) + + self.Tcw.to(self.I.device) + + def __len__(self): + """Batch size of the camera intrinsics""" + return len(self.I) + + def to(self, *args, **kwargs): + """Moves object to a specific device""" + self.I = self.I.to(*args, **kwargs) + self.Tcw = self.Tcw.to(*args, **kwargs) + return self + +######################################################################################################################## + + @property + def fx(self): + """Focal length in x""" + return self.I[:, 0].unsqueeze(1).unsqueeze(2) + + @property + def fy(self): + """Focal length in y""" + return self.I[:, 1].unsqueeze(1).unsqueeze(2) + + @property + def cx(self): + """Principal point in x""" + return self.I[:, 2].unsqueeze(1).unsqueeze(2) + + @property + def cy(self): + """Principal point in y""" + return self.I[:, 3].unsqueeze(1).unsqueeze(2) + + @property + def xi(self): + """alpha in DS model""" + return self.I[:, 4].unsqueeze(1).unsqueeze(2) + + @property + def alpha(self): + """beta in DS model""" + return self.I[:, 5].unsqueeze(1).unsqueeze(2) + + @property + @lru_cache() + def Twc(self): + """World -> Camera pose transformation (inverse of Tcw)""" + return self.Tcw.inverse() + +######################################################################################################################## + + def reconstruct(self, depth, frame='w'): + """ + Reconstructs pixel-wise 3D points from a depth map. 
+ + Parameters + ---------- + depth : torch.Tensor [B,1,H,W] + Depth map for the camera + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.tensor [B,3,H,W] + Pixel-wise 3D points + """ + + if depth is None: + return None + b, c, h, w = depth.shape + assert c == 1 + + grid = pixel_grid(depth, with_ones=True, device=depth.device) + + # Estimate the outward rays in the camera frame + fx, fy, cx, cy, xi, alpha = self.fx, self.fy, self.cx, self.cy, self.xi, self.alpha + + if torch.any(torch.isnan(alpha)): + raise ValueError('alpha is nan') + + u = grid[:,0,:,:] + v = grid[:,1,:,:] + + mx = (u - cx) / fx + my = (v - cy) / fy + r_square = mx ** 2 + my ** 2 + mz = (1 - alpha ** 2 * r_square) / (alpha * torch.sqrt(1 - (2 * alpha - 1) * r_square) + (1 - alpha)) + coeff = (mz * xi + torch.sqrt(mz ** 2 + (1 - xi ** 2) * r_square)) / (mz ** 2 + r_square) + + x = coeff * mx + y = coeff * my + z = coeff * mz - xi + z = z.clamp(min=1e-7) + + x_norm = x / z + y_norm = y / z + z_norm = z / z + xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1) + + # Scale rays to metric depth + Xc = xnorm * depth + + # If in camera frame of reference + if frame == 'c': + return Xc + # If in world frame of reference + elif frame == 'w': + return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w) + # If none of the above + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + def project(self, X, frame='w'): + """ + Projects 3D points onto the image plane + + Parameters + ---------- + X : torch.Tensor [B,3,H,W] + 3D points to be projected + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.Tensor [B,H,W,2] + 2D projected points that are within the image boundaries + """ + B, C, H, W = X.shape + assert C == 3 + + # Project 3D points onto the camera image plane + if frame == 'c': + X = X + elif frame == 'w': + X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W) + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + fx, fy, cx, cy, xi, alpha = self.fx, self.fy, self.cx, self.cy, self.xi, self.alpha + x, y, z = X[:,0,:], X[:,1,:], X[:,2,:] + z = z.clamp(min=1e-7) + d_1 = torch.sqrt( x ** 2 + y ** 2 + z ** 2 ) + d_2 = torch.sqrt( x ** 2 + y ** 2 + (xi * d_1 + z) ** 2 ) + + Xnorm = fx * x / (alpha * d_2 + (1 - alpha) * (xi * d_1 + z)) + cx + Ynorm = fy * y / (alpha * d_2 + (1 - alpha) * (xi * d_1 + z)) + cy + Xnorm = 2 * Xnorm / (W-1) - 1 + Ynorm = 2 * Ynorm / (H-1) - 1 + + coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2) + z = z.unsqueeze(1) + + invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ + (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0) + coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 + + # Return pixel coordinates + return coords.permute(0, 2, 3, 1) + + def reconstruct_depth_map(self, depth, to_world=True): + if to_world: + return self.reconstruct(depth, frame='w') + else: + return self.reconstruct(depth, frame='c') + + def project_points(self, points, from_world=True, normalize=True, return_z=False): + if from_world: + return self.project(points, frame='w') + else: + return self.project(points, frame='c') + + def coords_from_depth(self, depth, ref_cam=None): + if ref_cam is None: + return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True) + else: + return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True) \ No newline at end of file diff --git a/vidar/geometry/camera_eucm.py 
b/vidar/geometry/camera_eucm.py new file mode 100644 index 0000000000000000000000000000000000000000..472db8bd98fe0fa72ab1de4ff7ffb343fb3636bb --- /dev/null +++ b/vidar/geometry/camera_eucm.py @@ -0,0 +1,215 @@ +from functools import lru_cache +import torch +import torch.nn as nn + +from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics +from vidar.geometry.pose import Pose +from vidar.geometry.pose_utils import invert_pose +from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave +from vidar.utils.types import is_tensor, is_seq + +######################################################################################################################## + +class EUCMCamera(nn.Module): + """ + Differentiable camera class implementing reconstruction and projection + functions for the extended unified camera model (EUCM). + """ + def __init__(self, I, Tcw=None): + """ + Initializes the Camera class + + Parameters + ---------- + I : torch.Tensor [6] + Camera intrinsics parameter vector + Tcw : Pose + Camera -> World pose transformation + """ + super().__init__() + self.I = I + if Tcw is None: + self.Tcw = Pose.identity(len(I)) + elif isinstance(Tcw, Pose): + self.Tcw = Tcw + else: + self.Tcw = Pose(Tcw) + + self.Tcw.to(self.I.device) + + def __len__(self): + """Batch size of the camera intrinsics""" + return len(self.I) + + def to(self, *args, **kwargs): + """Moves object to a specific device""" + self.I = self.I.to(*args, **kwargs) + self.Tcw = self.Tcw.to(*args, **kwargs) + return self + +######################################################################################################################## + + @property + def fx(self): + """Focal length in x""" + return self.I[:, 0].unsqueeze(1).unsqueeze(2) + + @property + def fy(self): + """Focal length in y""" + return self.I[:, 1].unsqueeze(1).unsqueeze(2) + + @property + def cx(self): + """Principal point in x""" + return self.I[:, 2].unsqueeze(1).unsqueeze(2) + + @property + def cy(self): + """Principal point in y""" + return self.I[:, 3].unsqueeze(1).unsqueeze(2) + + @property + def alpha(self): + """alpha in EUCM model""" + return self.I[:, 4].unsqueeze(1).unsqueeze(2) + + @property + def beta(self): + """beta in EUCM model""" + return self.I[:, 5].unsqueeze(1).unsqueeze(2) + + @property + @lru_cache() + def Twc(self): + """World -> Camera pose transformation (inverse of Tcw)""" + return self.Tcw.inverse() + +######################################################################################################################## + + def reconstruct(self, depth, frame='w'): + """ + Reconstructs pixel-wise 3D points from a depth map. 
+ + Parameters + ---------- + depth : torch.Tensor [B,1,H,W] + Depth map for the camera + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.tensor [B,3,H,W] + Pixel-wise 3D points + """ + + if depth is None: + return None + b, c, h, w = depth.shape + assert c == 1 + + grid = pixel_grid(depth, with_ones=True, device=depth.device) + + # Estimate the outward rays in the camera frame + fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta + + if torch.any(torch.isnan(alpha)): + raise ValueError('alpha is nan') + + u = grid[:,0,:,:] + v = grid[:,1,:,:] + + mx = (u - cx) / fx + my = (v - cy) / fy + r_square = mx ** 2 + my ** 2 + mz = (1 - beta * alpha ** 2 * r_square) / (alpha * torch.sqrt(1 - (2 * alpha - 1) * beta * r_square) + (1 - alpha)) + coeff = 1 / torch.sqrt(mx ** 2 + my ** 2 + mz ** 2) + + x = coeff * mx + y = coeff * my + z = coeff * mz + z = z.clamp(min=1e-7) + + x_norm = x / z + y_norm = y / z + z_norm = z / z + xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1) + + # Scale rays to metric depth + Xc = xnorm * depth + + # If in camera frame of reference + if frame == 'c': + return Xc + # If in world frame of reference + elif frame == 'w': + return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w) + # If none of the above + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + def project(self, X, frame='w'): + """ + Projects 3D points onto the image plane + + Parameters + ---------- + X : torch.Tensor [B,3,H,W] + 3D points to be projected + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.Tensor [B,H,W,2] + 2D projected points that are within the image boundaries + """ + B, C, H, W = X.shape + assert C == 3 + + # Project 3D points onto the camera image plane + if frame == 'c': + X = X + elif frame == 'w': + X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W) + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta + x, y, z = X[:,0,:], X[:,1,:], X[:,2,:] + d = torch.sqrt( beta * ( x ** 2 + y ** 2 ) + z ** 2 ) + z = z.clamp(min=1e-7) + + Xnorm = fx * x / (alpha * d + (1 - alpha) * z + 1e-7) + cx + Ynorm = fy * y / (alpha * d + (1 - alpha) * z + 1e-7) + cy + Xnorm = 2 * Xnorm / (W-1) - 1 + Ynorm = 2 * Ynorm / (H-1) - 1 + + coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2) + z = z.unsqueeze(1) + + invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ + (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0) + coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 + + # Return pixel coordinates + return coords.permute(0, 2, 3, 1) + + def reconstruct_depth_map(self, depth, to_world=True): + if to_world: + return self.reconstruct(depth, frame='w') + else: + return self.reconstruct(depth, frame='c') + + def project_points(self, points, from_world=True, normalize=True, return_z=False): + if from_world: + return self.project(points, frame='w') + else: + return self.project(points, frame='c') + + def coords_from_depth(self, depth, ref_cam=None): + if ref_cam is None: + return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True) + else: + return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True) \ No newline at end of file diff --git a/vidar/geometry/camera_full.py b/vidar/geometry/camera_full.py new file mode 100644 index 
0000000000000000000000000000000000000000..21372ef5e1339484ca71474523357dc1b74e2019 --- /dev/null +++ b/vidar/geometry/camera_full.py @@ -0,0 +1,230 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +from torch_scatter import scatter_min + +from vidar.geometry.camera import Camera +from vidar.utils.tensor import unnorm_pixel_grid + + +class CameraFull(Camera): + """Camera class with additional functionality""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.convert_matrix = torch.tensor( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ).unsqueeze(0) + + @staticmethod + def from_list(cams): + """Create cameras from a list""" + K = torch.cat([cam.K for cam in cams], 0) + Twc = torch.cat([cam.Twc.T for cam in cams], 0) + return CameraFull(K=K, Twc=Twc, hw=cams[0].hw) + + def switch(self): + """Switch camera between conventions""" + T = self.convert_matrix.to(self.device) + Twc = T @ self.Twc.T @ T + return type(self)(K=self.K, Twc=Twc, hw=self.hw) + + def bwd(self): + """Switch camera to the backwards convention""" + T = self.convert_matrix.to(self.device) + Tcw = T @ self.Twc.T @ T + return type(self)(K=self.K, Tcw=Tcw, hw=self.hw) + + def fwd(self): + """Switch camera to the forward convention""" + T = self.convert_matrix.to(self.device) + Twc = T @ self.Tcw.T @ T + return type(self)(K=self.K, Twc=Twc, hw=self.hw) + + def look_at(self, at, up=torch.Tensor([0, 1, 0])): + """ + Set a direction for the camera to point (in-place) + + Parameters + ---------- + at : torch.Tensor + Where the camera should be pointing at [B,3] + up : torch.Tensor + Up direction [B,3] + """ + eps = 1e-5 + eye = self.Tcw.T[:, :3, -1] + + at = at.unsqueeze(0) + up = up.unsqueeze(0).to(at.device) + + z_axis = at - eye + z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps + + up = up.expand(z_axis.shape) + x_axis = torch.cross(up, z_axis) + x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps + + y_axis = torch.cross(z_axis, x_axis) + y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps + + R = torch.stack((x_axis, y_axis, z_axis), dim=-1) + + Tcw = self.Tcw + Tcw.T[:, :3, :3] = R + self.Twc = Tcw.inverse() + + def get_origin(self, flatten=False): + """Return camera origin""" + orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw) + if flatten: + orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1) + return orig + + def get_viewdirs(self, normalize=False, flatten=False, to_world=False): + """Return camera viewing rays""" + ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device) + rays = self.reconstruct_depth_map(ones, to_world=False) + if normalize: + rays = rays / torch.norm(rays, dim=1).unsqueeze(1) + if to_world: + rays = self.to_world(rays).reshape(len(self), 3, *self.hw) + if flatten: + rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1) + return rays + + def get_render_rays(self, near=None, far=None, n_rays=None, gt=None): + """ + Get render rays + + Parameters + ---------- + near : Float + Near plane + far : Float + Far plane + n_rays : Int + Number of rays + gt : torch.Tensor + Ground-truth values for concatenation + + Returns + ------- + rays : torch.Tensor + Camera viewing rays + """ + b = len(self) + + ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device) + + rays = self.reconstruct_depth_map(ones, to_world=False) + rays = rays / torch.norm(rays, dim=1).unsqueeze(1) + + rays[:, 1] = - rays[:, 1] + rays[:, 2] = - rays[:, 2] + + 
orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw) + rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw) + + info = [orig, rays] + if near is not None: + info = info + [near * ones] + if far is not None: + info = info + [far * ones] + if gt is not None: + info = info + [gt] + + rays = torch.cat(info, 1) + rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1]) + + if n_rays is not None: + idx = torch.randint(0, self.n_pixels, (n_rays,)) + rays = rays[:, idx, :] + + return rays + + def get_plucker(self): + """Get plucker vectors""" + b = len(self) + ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device) + rays = self.reconstruct_depth_map(ones, to_world=False) + rays = rays / torch.norm(rays, dim=1).unsqueeze(1) + orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw) + + orig = orig.view(1, 3, -1).permute(0, 2, 1) + rays = rays.view(1, 3, -1).permute(0, 2, 1) + + cross = torch.cross(orig, rays, dim=-1) + plucker = torch.cat((rays, cross), dim=-1) + + return plucker + + def project_pointcloud(self, pcl_src, rgb_src, thr=1): + """ + Project pointcloud to the camera plane + + Parameters + ---------- + pcl_src : torch.Tensor + Input 3D pointcloud + rgb_src : torch.Tensor + Pointcloud color information + thr : Int + Threshold for the number of valid points + + Returns + ------- + rgb_tgt : torch.Tensor + Projected image [B,3,H,W] + depth_tgt : torch.Tensor + Projected depth map [B,1,H,W] + """ + if rgb_src.dim() == 4: + rgb_src = rgb_src.view(*rgb_src.shape[:2], -1) + + # Get projected coordinates and depth values + uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True) + + rgbs_tgt, depths_tgt = [], [] + + b = pcl_src.shape[0] + for i in range(b): + uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1) + + # Remove out-of-bounds coordinates and points behind the camera + idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \ + (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0) + + # Unormalize and stack coordinates for scatter operation + uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long() + uv = uv[:, 0] + uv[:, 1] * self.hw[1] + + # Min scatter operation (only keep the closest depth) + depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device) + depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt) + depth_tgt[depth_tgt == 1e10] = 0. 
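# scatter_min acts as a z-buffer here: for every flattened pixel index in `uv` it keeps
# the smallest projected depth among all points that land on that pixel, and `argmin`
# records which source point supplied it; pixels that received no point keep the 1e10
# sentinel, which is zeroed just above.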
+ + num_valid = (depth_tgt > 0).sum() + if num_valid > thr: + + # Substitute invalid values with zero + invalid = argmin == argmin.max() + argmin[invalid] = 0 + rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin] + rgb_tgt[invalid] = -1 + + else: + + rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype) + + # Reshape outputs + rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2) + depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1]) + + rgbs_tgt.append(rgb_tgt) + depths_tgt.append(depth_tgt) + + rgb_tgt = torch.cat(rgbs_tgt, 0) + depth_tgt = torch.cat(depths_tgt, 0) + + return rgb_tgt, depth_tgt diff --git a/vidar/geometry/camera_nerf.py b/vidar/geometry/camera_nerf.py new file mode 100755 index 0000000000000000000000000000000000000000..7cf46d9ea3b9517edd6701122df8b66044e709dd --- /dev/null +++ b/vidar/geometry/camera_nerf.py @@ -0,0 +1,193 @@ + +import torch +from torch_scatter import scatter_min + +from vidar.geometry.camera import Camera +from vidar.utils.tensor import unnorm_pixel_grid + + +class CameraNerf(Camera): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.convert_matrix = torch.tensor( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ).unsqueeze(0) + + @staticmethod + def from_list(cams): + K = torch.cat([cam.K for cam in cams], 0) + Twc = torch.cat([cam.Twc.T for cam in cams], 0) + return CameraNerf(K=K, Twc=Twc, hw=cams[0].hw) + + @staticmethod + def from_dict(K, hw, Twc=None): + return {key: CameraNerf(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()} + + def switch(self): + T = self.convert_matrix.to(self.device) + Twc = T @ self.Twc.T @ T + return type(self)(K=self.K, Twc=Twc, hw=self.hw) + + def bwd(self): + T = self.convert_matrix.to(self.device) + Tcw = T @ self.Twc.T @ T + return type(self)(K=self.K, Tcw=Tcw, hw=self.hw) + + def fwd(self): + T = self.convert_matrix.to(self.device) + Twc = T @ self.Tcw.T @ T + return type(self)(K=self.K, Twc=Twc, hw=self.hw) + + def look_at(self, at, up=torch.Tensor([0, 1, 0])): + + eps = 1e-5 + eye = self.Tcw.T[:, :3, -1] + + at = at.unsqueeze(0) + up = up.unsqueeze(0).to(at.device) + up /= up.norm(dim=-1, keepdim=True) + eps + + z_axis = at - eye + z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps + + up = up.expand(z_axis.shape) + x_axis = torch.cross(up, z_axis) + x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps + + y_axis = torch.cross(z_axis, x_axis) + y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps + + R = torch.stack((x_axis, y_axis, z_axis), dim=-1) + + Tcw = self.Tcw + Tcw.T[:, :3, :3] = R + self.Twc = Tcw.inverse() + + def get_origin(self, flatten=False): + orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw) + if flatten: + orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1) + return orig + + def get_viewdirs(self, normalize=False, flatten=False, to_world=False): + + ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device) + rays = self.reconstruct_depth_map(ones, to_world=False) + if normalize: + rays = rays / torch.norm(rays, dim=1).unsqueeze(1) + if to_world: + rays = self.to_world(rays).reshape(len(self), 3, *self.hw) + if flatten: + rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1) + return rays + + def get_render_rays(self, near=None, far=None, n_rays=None, gt=None): + + b = len(self) + + ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device) + + rays = self.reconstruct_depth_map(ones, to_world=False) + rays 
= rays / torch.norm(rays, dim=1).unsqueeze(1) + + rays[:, 1] = - rays[:, 1] + rays[:, 2] = - rays[:, 2] + + orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw) + rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw) + + info = [orig, rays] + if near is not None: + info = info + [near * ones] + if far is not None: + info = info + [far * ones] + if gt is not None: + info = info + [gt] + + rays = torch.cat(info, 1) + rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1]) + + if n_rays is not None: + idx = torch.randint(0, self.n_pixels, (n_rays,)) + rays = rays[:, idx, :] + + return rays + + def get_plucker(self): + + b = len(self) + ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device) + rays = self.reconstruct_depth_map(ones, to_world=False) + rays = rays / torch.norm(rays, dim=1).unsqueeze(1) + orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw) + + orig = orig.view(1, 3, -1).permute(0, 2, 1) + rays = rays.view(1, 3, -1).permute(0, 2, 1) + + cross = torch.cross(orig, rays, dim=-1) + plucker = torch.cat((rays, cross), dim=-1) + + return plucker + + def project_pointcloud(self, pcl_src, rgb_src, thr=1): + + if rgb_src.dim() == 4: + rgb_src = rgb_src.view(*rgb_src.shape[:2], -1) + + # Get projected coordinates and depth values + uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True) + + rgbs_tgt, depths_tgt = [], [] + + b = pcl_src.shape[0] + for i in range(b): + uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1) + + # Remove out-of-bounds coordinates and points behind the camera + idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \ + (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0) + + # Unormalize and stack coordinates for scatter operation + uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long() + uv = uv[:, 0] + uv[:, 1] * self.hw[1] + + # Min scatter operation (only keep the closest depth) + depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device) + depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt) + depth_tgt[depth_tgt == 1e10] = 0. + + num_valid = (depth_tgt > 0).sum() + if num_valid > thr: + + # Substitute invalid values with zero + invalid = argmin == argmin.max() + argmin[invalid] = 0 + rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin] + rgb_tgt[invalid] = -1 + + else: + + rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype) + + # Reshape outputs + rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2) + depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1]) + + rgbs_tgt.append(rgb_tgt) + depths_tgt.append(depth_tgt) + + rgb_tgt = torch.cat(rgbs_tgt, 0) + depth_tgt = torch.cat(depths_tgt, 0) + + return rgb_tgt, depth_tgt + + def reconstruct_depth_map_rays(self, depth, to_world=False): + if depth is None: + return None + b, _, h, w = depth.shape + rays = self.get_viewdirs(normalize=True, to_world=False) + points = (rays * depth).view(b, 3, -1) + if to_world and self.Tcw is not None: + points = self.Tcw * points + return points.view(b, 3, h, w) diff --git a/vidar/geometry/camera_ucm.py b/vidar/geometry/camera_ucm.py new file mode 100644 index 0000000000000000000000000000000000000000..90e9058ea78b4e30ac73d6136c1a75a579fddfdb --- /dev/null +++ b/vidar/geometry/camera_ucm.py @@ -0,0 +1,212 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
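The `get_plucker` methods above store each viewing ray as a direction plus its moment (camera origin crossed with the direction). A minimal standalone sketch of that parameterization, using plain PyTorch and toy values that are not taken from this diff:

import torch

# Toy origin and (normalized) direction for a single ray.
o = torch.tensor([0.5, -1.0, 2.0])
d = torch.tensor([0.1, 0.2, 1.0])
d = d / d.norm()

# The moment m = o x d is identical for any point chosen along the ray,
# which is what makes the (d, m) 6-vector a well-defined ray descriptor.
m1 = torch.linalg.cross(o, d)
m2 = torch.linalg.cross(o + 3.0 * d, d)  # another point on the same ray
assert torch.allclose(m1, m2, atol=1e-6)

# 6-D Pluecker vector, analogous to torch.cat((rays, cross), dim=-1) above.
plucker = torch.cat([d, m1])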
+ +from functools import lru_cache +import torch +import torch.nn as nn + +from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics +from vidar.geometry.pose import Pose +from vidar.geometry.pose_utils import invert_pose +from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave +from vidar.utils.types import is_tensor, is_seq + +######################################################################################################################## + +class UCMCamera(nn.Module): + """ + Differentiable camera class implementing reconstruction and projection + functions for the unified camera model (UCM). + """ + def __init__(self, I, Tcw=None): + """ + Initializes the Camera class + + Parameters + ---------- + I : torch.Tensor [5] + Camera intrinsics parameter vector + Tcw : Pose + Camera -> World pose transformation + """ + super().__init__() + self.I = I + if Tcw is None: + self.Tcw = Pose.identity(len(I)) + elif isinstance(Tcw, Pose): + self.Tcw = Tcw + else: + self.Tcw = Pose(Tcw) + + self.Tcw.to(self.I.device) + + def __len__(self): + """Batch size of the camera intrinsics""" + return len(self.I) + + def to(self, *args, **kwargs): + """Moves object to a specific device""" + self.I = self.I.to(*args, **kwargs) + self.Tcw = self.Tcw.to(*args, **kwargs) + return self + +######################################################################################################################## + + @property + def fx(self): + """Focal length in x""" + return self.I[:, 0].unsqueeze(1).unsqueeze(2) + + @property + def fy(self): + """Focal length in y""" + return self.I[:, 1].unsqueeze(1).unsqueeze(2) + + @property + def cx(self): + """Principal point in x""" + return self.I[:, 2].unsqueeze(1).unsqueeze(2) + + @property + def cy(self): + """Principal point in y""" + return self.I[:, 3].unsqueeze(1).unsqueeze(2) + + @property + def alpha(self): + """alpha in UCM model""" + return self.I[:, 4].unsqueeze(1).unsqueeze(2) + + @property + @lru_cache() + def Twc(self): + """World -> Camera pose transformation (inverse of Tcw)""" + return self.Tcw.inverse() + +######################################################################################################################## + + def reconstruct(self, depth, frame='w'): + """ + Reconstructs pixel-wise 3D points from a depth map. 
+ + Parameters + ---------- + depth : torch.Tensor [B,1,H,W] + Depth map for the camera + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.tensor [B,3,H,W] + Pixel-wise 3D points + """ + + if depth is None: + return None + b, c, h, w = depth.shape + assert c == 1 + + grid = pixel_grid(depth, with_ones=True, device=depth.device) + + # Estimate the outward rays in the camera frame + fx, fy, cx, cy, alpha = self.fx, self.fy, self.cx, self.cy, self.alpha # [B,1,1] + + if torch.any(torch.isnan(alpha)): + raise ValueError('alpha is nan') + + u = grid[:,0,:,:] + v = grid[:,1,:,:] + + mx = (u - cx) / fx * (1 - alpha) + my = (v - cy) / fy * (1 - alpha) + r_square = mx ** 2 + my ** 2 + xi = alpha / (1 - alpha) # [B, 1, 1] + coeff = (xi + torch.sqrt(1 + (1 - xi ** 2) * r_square)) / (1 + r_square) # [B, H, W] + + x = coeff * mx + y = coeff * my + z = coeff * 1 - xi + z = z.clamp(min=1e-7) + + x_norm = x / z + y_norm = y / z + z_norm = z / z + xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1).float() + + # Scale rays to metric depth + Xc = xnorm * depth + + # If in camera frame of reference + if frame == 'c': + return Xc + # If in world frame of reference + elif frame == 'w': + return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w) + # If none of the above + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + def project(self, X, frame='w'): + """ + Projects 3D points onto the image plane + + Parameters + ---------- + X : torch.Tensor [B,3,H,W] + 3D points to be projected + frame : 'w' + Reference frame: 'c' for camera and 'w' for world + + Returns + ------- + points : torch.Tensor [B,H,W,2] + 2D projected points that are within the image boundaries + """ + B, C, H, W = X.shape + assert C == 3 + + # Project 3D points onto the camera image plane + if frame == 'c': + X = X + elif frame == 'w': + X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W) + else: + raise ValueError('Unknown reference frame {}'.format(frame)) + + d = torch.norm(X, dim=1) + fx, fy, cx, cy, alpha = self.fx, self.fy, self.cx, self.cy, self.alpha + x, y, z = X[:,0,:], X[:,1,:], X[:,2,:] + z = z.clamp(min=1e-7) + + Xnorm = fx * x / (alpha * d + (1 - alpha) * z + 1e-7) + cx + Ynorm = fy * y / (alpha * d + (1 - alpha) * z + 1e-7) + cy + Xnorm = 2 * Xnorm / (W-1) - 1 + Ynorm = 2 * Ynorm / (H-1) - 1 + + coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2) + z = z.unsqueeze(1) + + invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \ + (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0) + coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2 + + # Return pixel coordinates + return coords.permute(0, 2, 3, 1) + + def reconstruct_depth_map(self, depth, to_world=True): + if to_world: + return self.reconstruct(depth, frame='w') + else: + return self.reconstruct(depth, frame='c') + + def project_points(self, points, from_world=True, normalize=True, return_z=False): + if from_world: + return self.project(points, frame='w') + else: + return self.project(points, frame='c') + + def coords_from_depth(self, depth, ref_cam=None): + if ref_cam is None: + return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True) + else: + return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True) \ No newline at end of file diff --git a/vidar/geometry/camera_utils.py b/vidar/geometry/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..461d140a4ebb0ad16ed1157f129bb773fb4467ad --- 
/dev/null +++ b/vidar/geometry/camera_utils.py @@ -0,0 +1,31 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from vidar.utils.types import is_seq + + +def invert_intrinsics(K): + """Invert camera intrinsics""" + Kinv = K.clone() + Kinv[:, 0, 0] = 1. / K[:, 0, 0] + Kinv[:, 1, 1] = 1. / K[:, 1, 1] + Kinv[:, 0, 2] = -1. * K[:, 0, 2] / K[:, 0, 0] + Kinv[:, 1, 2] = -1. * K[:, 1, 2] / K[:, 1, 1] + return Kinv + + +def scale_intrinsics(K, ratio): + """Scale intrinsics given a ratio (tuple for individual hw ratios, float for the same ratio)""" + if is_seq(ratio): + ratio_h, ratio_w = ratio + else: + ratio_h = ratio_w = ratio + + K = K.clone() + K[..., 0, 0] *= ratio_w + K[..., 1, 1] *= ratio_h + # K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5 + # K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5 + K[..., 0, 2] = K[..., 0, 2] * ratio_w + K[..., 1, 2] = K[..., 1, 2] * ratio_h + + return K diff --git a/vidar/geometry/pose.py b/vidar/geometry/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..b385cb772a658028ebc3f7ad0666c66846747a4b --- /dev/null +++ b/vidar/geometry/pose.py @@ -0,0 +1,182 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch + +from vidar.geometry.pose_utils import invert_pose, pose_vec2mat, to_global_pose, euler2mat +from vidar.utils.types import is_int + + +def from_dict_sample(T, to_global=False, zero_origin=False, to_matrix=False): + """ + Create poses from a sample dictionary + + Parameters + ---------- + T : Dict + Dictionary containing input poses [B,4,4] + to_global : Bool + Whether poses should be converted to global frame of reference + zero_origin : Bool + Whether the target camera should be the center of the frame of reference + to_matrix : Bool + Whether output poses should be classes or tensors + + Returns + ------- + pose : Dict + Dictionary containing output poses + """ + pose = {key: Pose(val) for key, val in T.items()} + if to_global: + pose = to_global_pose(pose, zero_origin=zero_origin) + if to_matrix: + pose = {key: val.T for key, val in pose.items()} + return pose + + +def from_dict_batch(T, **kwargs): + """Create poses from a batch dictionary""" + pose_batch = [from_dict_sample({key: val[b] for key, val in T.items()}, **kwargs) + for b in range(T[0].shape[0])] + return {key: torch.stack([v[key] for v in pose_batch], 0) for key in pose_batch[0]} + + +class Pose: + """ + Pose class for 3D operations + + Parameters + ---------- + T : torch.Tensor or Int + Transformation matrix [B,4,4], or batch size (poses initialized as identity) + """ + def __init__(self, T=1): + if is_int(T): + T = torch.eye(4).repeat(T, 1, 1) + self.T = T if T.dim() == 3 else T.unsqueeze(0) + + def __len__(self): + """Return batch size""" + return len(self.T) + + def __getitem__(self, i): + """Return batch-wise pose""" + return Pose(self.T[[i]]) + + def __mul__(self, data): + """Transforms data (pose or 3D points)""" + if isinstance(data, Pose): + return Pose(self.T.bmm(data.T)) + elif isinstance(data, torch.Tensor): + return self.T[:, :3, :3].bmm(data) + self.T[:, :3, -1].unsqueeze(-1) + else: + raise NotImplementedError() + + def detach(self): + """Return detached pose""" + return Pose(self.T.detach()) + + @property + def shape(self): + """Return pose shape""" + return self.T.shape + + @property + def device(self): + """Return pose device""" + return self.T.device + + @property + def dtype(self): + """Return pose type""" + return self.T.dtype + + @classmethod + def identity(cls, 
N=1, device=None, dtype=torch.float): + """Initializes a batch of N [4,4] identity matrices""" + return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1])) + + @staticmethod + def from_dict(T, to_global=False, zero_origin=False, to_matrix=False): + """Create poses from a dictionary""" + if T[0].dim() == 3: + return from_dict_sample(T, to_global=to_global, zero_origin=zero_origin, to_matrix=to_matrix) + elif T[0].dim() == 4: + return from_dict_batch(T, to_global=to_global, zero_origin=zero_origin, to_matrix=True) + + @classmethod + def from_vec(cls, vec, mode): + """Initializes from a [B,6] batch vector""" + mat = pose_vec2mat(vec, mode) + pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1]) + pose[:, :3, :3] = mat[:, :3, :3] + pose[:, :3, -1] = mat[:, :3, -1] + return cls(pose) + + def repeat(self, *args, **kwargs): + """Repeats the transformation matrix multiple times""" + self.T = self.T.repeat(*args, **kwargs) + return self + + def inverse(self): + """Returns a new Pose that is the inverse of this one""" + return Pose(invert_pose(self.T)) + + def to(self, *args, **kwargs): + """Copy pose to device""" + self.T = self.T.to(*args, **kwargs) + return self + + def cuda(self, *args, **kwargs): + """Copy pose to CUDA""" + self.to('cuda') + return self + + def translate(self, xyz): + """Translate pose""" + self.T[:, :3, -1] = self.T[:, :3, -1] + xyz.to(self.device) + return self + + def rotate(self, rpw): + """Rotate pose""" + rot = euler2mat(rpw) + T = invert_pose(self.T).clone() + T[:, :3, :3] = T[:, :3, :3] @ rot.to(self.device) + self.T = invert_pose(T) + return self + + def rotateRoll(self, r): + """Rotate pose in the roll axis""" + return self.rotate(torch.tensor([[0, 0, r]])) + + def rotatePitch(self, p): + """Rotate pose in the pitch axis""" + return self.rotate(torch.tensor([[p, 0, 0]])) + + def rotateYaw(self, w): + """Rotate pose in the yaw axis""" + return self.rotate(torch.tensor([[0, w, 0]])) + + def translateForward(self, t): + """Translate pose forward""" + return self.translate(torch.tensor([[0, 0, -t]])) + + def translateBackward(self, t): + """Translate pose backward""" + return self.translate(torch.tensor([[0, 0, +t]])) + + def translateLeft(self, t): + """Translate pose left""" + return self.translate(torch.tensor([[+t, 0, 0]])) + + def translateRight(self, t): + """Translate pose right""" + return self.translate(torch.tensor([[-t, 0, 0]])) + + def translateUp(self, t): + """Translate pose up""" + return self.translate(torch.tensor([[0, +t, 0]])) + + def translateDown(self, t): + """Translate pose down""" + return self.translate(torch.tensor([[0, -t, 0]])) diff --git a/vidar/geometry/pose_utils.py b/vidar/geometry/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a57ee1531b5eb69012f361a6c3bae8136c940f2 --- /dev/null +++ b/vidar/geometry/pose_utils.py @@ -0,0 +1,206 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
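As a usage reference for the Pose class above, here is a minimal sketch of building, inverting, composing, and applying a pose to points. It assumes the vidar modules introduced in this diff are importable and uses made-up values:

import torch
from vidar.geometry.pose import Pose

# A single pose [1,4,4]: identity rotation plus a translation along x.
T = torch.eye(4).unsqueeze(0)
T[:, :3, -1] = torch.tensor([1.0, 0.0, 0.0])
pose = Pose(T)

points = torch.rand(1, 3, 10)       # [B,3,N] points
moved = pose * points               # Pose.__mul__ applies R @ p + t
back = pose.inverse() * moved       # the inverse pose undoes the transform
assert torch.allclose(points, back, atol=1e-6)

# Poses compose with the same operator: (a * b).T == a.T @ b.T
identity = pose.inverse() * pose
assert torch.allclose(identity.T, torch.eye(4).unsqueeze(0), atol=1e-6)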
+ +import torch +import torch.nn.functional as F + +from vidar.utils.decorators import iterate1 + + +def to_global_pose(pose, zero_origin=False): + """Get global pose coordinates from current and context poses""" + if zero_origin: + pose[0].T[[0]] = torch.eye(4, device=pose[0].device, dtype=pose[0].dtype) + for b in range(1, len(pose[0])): + pose[0].T[[b]] = (pose[0][b] * pose[0][0]).T.float() + for key in pose.keys(): + if key != 0: + pose[key] = pose[key] * pose[0] + return pose + + +# def to_global_pose(pose, zero_origin=False): +# """Get global pose coordinates from current and context poses""" +# if zero_origin: +# pose[(0, 0)].T = torch.eye(4, device=pose[(0, 0)].device, dtype=pose[(0, 0)].dtype). \ +# repeat(pose[(0, 0)].shape[0], 1, 1) +# for key in pose.keys(): +# if key[0] == 0 and key[1] != 0: +# pose[key].T = (pose[key] * pose[(0, 0)]).T +# for key in pose.keys(): +# if key[0] != 0: +# pose[key] = pose[key] * pose[(0, 0)] +# return pose + + +def euler2mat(angle): + """Convert euler angles to rotation matrix""" + B = angle.size(0) + x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] + + cosz = torch.cos(z) + sinz = torch.sin(z) + + zeros = z.detach() * 0 + ones = zeros.detach() + 1 + zmat = torch.stack([ cosz, -sinz, zeros, + sinz, cosz, zeros, + zeros, zeros, ones], dim=1).view(B, 3, 3) + + cosy = torch.cos(y) + siny = torch.sin(y) + + ymat = torch.stack([ cosy, zeros, siny, + zeros, ones, zeros, + -siny, zeros, cosy], dim=1).view(B, 3, 3) + + cosx = torch.cos(x) + sinx = torch.sin(x) + + xmat = torch.stack([ ones, zeros, zeros, + zeros, cosx, -sinx, + zeros, sinx, cosx], dim=1).view(B, 3, 3) + + rot_mat = xmat.bmm(ymat).bmm(zmat) + return rot_mat + + +def pose_vec2mat(vec, mode='euler'): + """Convert translation and Euler rotation to a [B,4,4] torch.Tensor transformation matrix""" + if mode is None: + return vec + trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:] + if mode == 'euler': + rot_mat = euler2mat(rot) + else: + raise ValueError('Rotation mode not supported {}'.format(mode)) + mat = torch.cat([rot_mat, trans], dim=2) # [B,3,4] + return mat + + +@iterate1 +def invert_pose(T): + """Invert a [B,4,4] torch.Tensor pose""" + Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([len(T), 1, 1]) + Tinv[:, :3, :3] = torch.transpose(T[:, :3, :3], -2, -1) + Tinv[:, :3, -1] = torch.bmm(-1. 
* Tinv[:, :3, :3], T[:, :3, -1].unsqueeze(-1)).squeeze(-1) + return Tinv + # return torch.linalg.inv(T) + + +def tvec_to_translation(tvec): + """Convert translation vector to translation matrix (no rotation)""" + batch_size = tvec.shape[0] + T = torch.eye(4).to(device=tvec.device).repeat(batch_size, 1, 1) + t = tvec.contiguous().view(-1, 3, 1) + T[:, :3, 3, None] = t + return T + + +def euler2rot(euler): + """Convert Euler parameters to a [B,3,3] torch.Tensor rotation matrix""" + euler_norm = torch.norm(euler, 2, 2, True) + axis = euler / (euler_norm + 1e-7) + + cos_a = torch.cos(euler_norm) + sin_a = torch.sin(euler_norm) + cos1_a = 1 - cos_a + + x = axis[..., 0].unsqueeze(1) + y = axis[..., 1].unsqueeze(1) + z = axis[..., 2].unsqueeze(1) + + x_sin = x * sin_a + y_sin = y * sin_a + z_sin = z * sin_a + x_cos1 = x * cos1_a + y_cos1 = y * cos1_a + z_cos1 = z * cos1_a + + xx_cos1 = x * x_cos1 + yy_cos1 = y * y_cos1 + zz_cos1 = z * z_cos1 + xy_cos1 = x * y_cos1 + yz_cos1 = y * z_cos1 + zx_cos1 = z * x_cos1 + + batch_size = euler.shape[0] + rot = torch.zeros((batch_size, 4, 4)).to(device=euler.device) + + rot[:, 0, 0] = torch.squeeze(xx_cos1 + cos_a) + rot[:, 0, 1] = torch.squeeze(xy_cos1 - z_sin) + rot[:, 0, 2] = torch.squeeze(zx_cos1 + y_sin) + rot[:, 1, 0] = torch.squeeze(xy_cos1 + z_sin) + rot[:, 1, 1] = torch.squeeze(yy_cos1 + cos_a) + rot[:, 1, 2] = torch.squeeze(yz_cos1 - x_sin) + rot[:, 2, 0] = torch.squeeze(zx_cos1 - y_sin) + rot[:, 2, 1] = torch.squeeze(yz_cos1 + x_sin) + rot[:, 2, 2] = torch.squeeze(zz_cos1 + cos_a) + rot[:, 3, 3] = 1 + + return rot + + +def vec2mat(euler, translation, invert=False): + """Convert Euler rotation and translation to a [B,4,4] torch.Tensor transformation matrix""" + R = euler2rot(euler) + t = translation.clone() + + if invert: + R = R.transpose(1, 2) + t *= -1 + + T = tvec_to_translation(t) + + if invert: + M = torch.matmul(R, T) + else: + M = torch.matmul(T, R) + + return M + + +def rot2quat(R): + """Convert a [B,3,3] rotation matrix to [B,4] quaternions""" + b, _, _ = R.shape + q = torch.ones((b, 4), device=R.device) + + R00 = R[:, 0, 0] + R01 = R[:, 0, 1] + R02 = R[:, 0, 2] + R10 = R[:, 1, 0] + R11 = R[:, 1, 1] + R12 = R[:, 1, 2] + R20 = R[:, 2, 0] + R21 = R[:, 2, 1] + R22 = R[:, 2, 2] + + q[:, 3] = torch.sqrt(1.0 + R00 + R11 + R22) / 2 + q[:, 0] = (R21 - R12) / (4 * q[:, 3]) + q[:, 1] = (R02 - R20) / (4 * q[:, 3]) + q[:, 2] = (R10 - R01) / (4 * q[:, 3]) + + return q + + +def quat2rot(q): + """Convert [B,4] quaternions to [B,3,3] rotation matrix""" + b, _ = q.shape + q = F.normalize(q, dim=1) + R = torch.ones((b, 3, 3), device=q.device) + + qr = q[:, 0] + qi = q[:, 1] + qj = q[:, 2] + qk = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2) + R[:, 0, 1] = 2 * (qj * qi - qk * qr) + R[:, 0, 2] = 2 * (qi * qk + qr * qj) + R[:, 1, 0] = 2 * (qj * qi + qk * qr) + R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2) + R[:, 1, 2] = 2 * (qj * qk - qi * qr) + R[:, 2, 0] = 2 * (qk * qi - qj * qr) + R[:, 2, 1] = 2 * (qj * qk + qi * qr) + R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2) + + return R diff --git a/vidar/metrics/base.py b/vidar/metrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..61f303f9823d80fb1dee9eb7e779cdd68c6800e7 --- /dev/null +++ b/vidar/metrics/base.py @@ -0,0 +1,163 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
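The quaternion helpers above appear to use different component orderings: rot2quat returns the scalar last (x, y, z, w), while quat2rot reads the scalar from the first component (w, x, y, z). A minimal round-trip sketch under that assumption, presuming the vidar package from this diff is importable:

import torch
from vidar.geometry.pose_utils import euler2mat, rot2quat, quat2rot

# Small Euler angles -> rotation matrix [1,3,3].
R = euler2mat(torch.tensor([[0.1, -0.2, 0.3]]))

q_xyzw = rot2quat(R)                    # scalar-last, as written above
q_wxyz = torch.roll(q_xyzw, 1, dims=1)  # move the scalar to the front for quat2rot

R_back = quat2rot(q_wxyz)
assert torch.allclose(R, R_back, atol=1e-5)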
+ +from collections import OrderedDict +from functools import partial + +import numpy as np +import torch + +from vidar.utils.distributed import reduce_value +from vidar.utils.tensor import same_shape, interpolate + + +class BaseEvaluation: + """ + Base class for evaluation metrics + + Parameters + ---------- + cfg : Config + Configuration file + name : String + Evaluation name + task : String + Task referent to the evaluation + metrics : String + Metrics name + """ + def __init__(self, cfg, name, task, metrics): + self.name = name + self.task = task + self.width = 32 + 11 * len(metrics) + self.metrics = metrics + self.modes = [''] + + self.font1 = {'color': 'magenta', 'attrs': ('bold',)} + self.font2 = {'color': 'cyan', 'attrs': ()} + + self.nearest = partial(interpolate, scale_factor=None, mode='nearest', align_corners=None) + self.bilinear = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True) + + self.only_first = cfg.has('only_first', False) + + @property + def horz_line(self): + """Print horizontal line""" + return '|{:<}|'.format('*' * self.width) + + @property + def metr_line(self): + """Print metrics line""" + return '| {:^30} |' + ' {:^8} |' * len(self.metrics) + + @property + def outp_line(self): + """Print output line""" + return '{:<30}' + ' | {:^8.3f}' * len(self.metrics) + + @staticmethod + def wrap(string): + """Wrap line around vertical bars""" + return '| {} |'.format(string) + + def check_name(self, key): + """Check name for prefixes""" + return key.startswith(self.name) or \ + key.startswith('fwd_' + self.name) or \ + key.startswith('bwd_' + self.name) + + def reduce_fn(self, *args, **kwargs): + """Reduce function""" + raise NotImplementedError('reduce_fn not implemented for {}'.format(self.__name__)) + + def populate_metrics_dict(self, *args, **kwargs): + """Populate metrics function""" + raise NotImplementedError('create_dict_key not implemented for {}'.format(self.__name__)) + + def print(self, *args, **kwargs): + """Print function""" + raise NotImplementedError('print not implemented for {}'.format(self.__name__)) + + @staticmethod + def interp(dst, src, fn): + """Interpolate dst to be the size of src using the interpolation function fn""" + if dst is None: + return dst + assert dst.dim() == src.dim() + if dst.dim() == 4 and not same_shape(dst.shape, src.shape): + dst = fn(dst, size=src) + return dst + + def interp_bilinear(self, dst, src): + """Bilinear interpolation""" + return self.interp(dst, src, self.bilinear) + + def interp_nearest(self, dst, src): + """Nearest interpolation""" + return self.interp(dst, src, self.nearest) + + def reduce(self, output, dataloaders, prefixes, verbose=True): + """Reduce function""" + reduced_data = self.reduce_metrics(output, dataloaders) + metrics_dict = self.create_metrics_dict(reduced_data, prefixes) + if verbose: + self.print(reduced_data, prefixes) + return metrics_dict + + def create_metrics_dict(self, reduced_data, prefixes): + """Create metrics dictionary""" + metrics_dict = {} + # For all datasets + for n, metrics in enumerate(reduced_data): + if metrics: # If there are calculated metrics + self.populate_metrics_dict(metrics, metrics_dict, prefixes[n]) + # Return metrics dictionary + return metrics_dict + + def reduce_metrics(self, dataset_outputs, datasets, ontology=None, strict=True): + """Reduce metrics""" + # If there is only one dataset, wrap in a list + if isinstance(dataset_outputs[0], dict): + dataset_outputs = [dataset_outputs] + # List storing metrics for all datasets + all_metrics_dict 
= [] + # Loop over all datasets and all batches + for batch_outputs, dataset in zip(dataset_outputs, datasets): + # Initialize metrics dictionary + metrics_dict = OrderedDict() + # Get length, names and dimensions + length = len(dataset) + names = [key for key in list(batch_outputs[0].keys()) if self.check_name(key)] + dims = [tuple(batch_outputs[0][name].size()) for name in names] + # Get data device + device = batch_outputs[0]['idx'].device + # Count how many times each sample was seen + if strict: + seen = torch.zeros(length, device=device) + for output in batch_outputs: + seen[output['idx']] += 1 + seen = reduce_value(seen, average=False, name='idx') + assert not np.any(seen.cpu().numpy() == 0), \ + 'Not all samples were seen during evaluation' + # Reduce relevant metrics + for name, dim in zip(names, dims): + metrics = torch.zeros([length] + list(dim), device=device) + + # Count how many times each sample was seen + if not strict: + seen = torch.zeros(length, device=device) + for output in batch_outputs: + if name in output: + seen[output['idx']] += 1 + seen = reduce_value(seen, average=False, name='idx') + + for output in batch_outputs: + if name in output: + metrics[output['idx']] = output[name] + metrics = reduce_value(metrics, average=False, name=name) + metrics_dict[name] = self.reduce_fn(metrics, seen) + # Append metrics dictionary to the list + all_metrics_dict.append(metrics_dict) + # Return list of metrics dictionary + return all_metrics_dict + diff --git a/vidar/metrics/depth.py b/vidar/metrics/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..86a8013d167719bc1f0587b1c24b9388ad1ce544 --- /dev/null +++ b/vidar/metrics/depth.py @@ -0,0 +1,223 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
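For reference, the metric names used by the depth evaluation below (abs_rel, sqr_rel, rmse, rmse_log, silog, a1/a2/a3) follow the standard monocular-depth definitions. This is a minimal sketch of those conventional formulas on already-masked 1D tensors with toy values, not a copy of the repository's compute method (silog in particular has several common variants):

import torch

def depth_metrics(gt, pred):
    """Standard monocular depth metrics over valid (already masked) 1D tensors."""
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    abs_rel = ((gt - pred).abs() / gt).mean()
    sqr_rel = ((gt - pred) ** 2 / gt).mean()
    rmse = ((gt - pred) ** 2).mean().sqrt()
    rmse_log = ((gt.log() - pred.log()) ** 2).mean().sqrt()

    err = pred.log() - gt.log()
    silog = (err ** 2).mean() - err.mean() ** 2   # one common scale-invariant log variant

    return abs_rel, sqr_rel, rmse, rmse_log, silog, a1, a2, a3

# Toy example with hypothetical depth values (in meters).
gt = torch.tensor([2.0, 5.0, 10.0])
pred = torch.tensor([2.2, 4.5, 11.0])
print(depth_metrics(gt, pred))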
+ +import torch + +from vidar.metrics.base import BaseEvaluation +from vidar.metrics.utils import create_crop_mask, scale_output +from vidar.utils.config import cfg_has +from vidar.utils.data import dict_remove_nones +from vidar.utils.depth import post_process_depth +from vidar.utils.distributed import on_rank_0 +from vidar.utils.logging import pcolor +from vidar.utils.types import is_dict + + +class DepthEvaluation(BaseEvaluation): + """ + Detph evaluation metrics + + Parameters + ---------- + cfg : Config + Configuration file + """ + def __init__(self, cfg): + super().__init__(cfg, + name='depth', task='depth', + metrics=('abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'silog', 'a1', 'a2', 'a3'), + ) + + self.min_depth = cfg.min_depth + self.max_depth = cfg.max_depth + self.crop = cfg_has(cfg, 'crop', '') + self.scale_output = cfg_has(cfg, 'scale_output', 'resize') + + self.post_process = cfg_has(cfg, 'post_process', False) + self.median_scaling = cfg_has(cfg, 'median_scaling', False) + self.valid_threshold = cfg.has('valid_threshold', None) + + if self.post_process: + self.modes += ['pp'] + if self.median_scaling: + self.modes += ['gt'] + if self.post_process and self.median_scaling: + self.modes += ['pp_gt'] + + @staticmethod + def reduce_fn(metrics, seen): + """Reduce function""" + valid = seen.view(-1) > 0 + return (metrics[valid] / seen.view(-1, 1)[valid]).mean(0) + + def populate_metrics_dict(self, metrics, metrics_dict, prefix): + """Populate metrics function""" + for metric in metrics: + if metric.startswith(self.name): + name, suffix = metric.split('|') + for i, key in enumerate(self.metrics): + metrics_dict[f'{prefix}-{name}|{key}_{suffix}'] = \ + metrics[metric][i].item() + + @on_rank_0 + def print(self, reduced_data, prefixes): + """Print function""" + print() + print(self.horz_line) + print(self.metr_line.format(*((self.name.upper(),) + self.metrics))) + for n, metrics in enumerate(reduced_data): + if sum([self.name in key for key in metrics.keys()]) == 0: + continue + print(self.horz_line) + print(self.wrap(pcolor('*** {:<114}'.format(prefixes[n]), **self.font1))) + print(self.horz_line) + for key, metric in sorted(metrics.items()): + if self.name in key: + print(self.wrap(pcolor(self.outp_line.format( + *((key.upper(),) + tuple(metric.tolist()))), **self.font2))) + print(self.horz_line) + print() + + def compute(self, gt, pred, use_gt_scale=True, mask=None): + """ + Compute depth metrics + + Parameters + ---------- + gt : torch.Tensor + Ground-truth depth maps [B,1,H,W] + pred : torch.Tensor + Predicted depth map [B,1,H,W] + use_gt_scale : Bool + Use median-scaling + mask : torch.Tensor or None + Mask to remove pixels from evaluation + + Returns + ------- + metrics : torch.Tensor + Depth metrics + """ + # Match predicted depth map to ground-truth resolution + pred = scale_output(pred, gt, self.scale_output) + # Create crop mask if requested + crop_mask = create_crop_mask(self.crop, gt) + # For each batch sample + metrics = [] + for i, (pred_i, gt_i) in enumerate(zip(pred, gt)): + + # Squeeze GT and PRED + gt_i, pred_i = torch.squeeze(gt_i), torch.squeeze(pred_i) + mask_i = torch.squeeze(mask[i]) if mask is not None else None + + # Keep valid pixels (min/max depth and crop) + valid = (gt_i > self.min_depth) & (gt_i < self.max_depth) + # Remove invalid predicted pixels as well + valid = valid & (pred_i > 0) + # Apply crop mask if requested + valid = valid & crop_mask.bool() if crop_mask is not None else valid + # Apply provided mask if available + valid = valid & mask_i.bool() if 
mask is not None else valid + + # Invalid evaluation + if self.valid_threshold is not None and valid.sum() < self.valid_threshold: + return None + + # Keep only valid pixels + gt_i, pred_i = gt_i[valid], pred_i[valid] + # GT median scaling if needed + if use_gt_scale: + pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i) + # Clamp PRED depth values to min/max values + pred_i = pred_i.clamp(self.min_depth, self.max_depth) + + # Calculate depth metrics + + thresh = torch.max((gt_i / pred_i), (pred_i / gt_i)) + a1 = (thresh < 1.25).float().mean() + a2 = (thresh < 1.25 ** 2).float().mean() + a3 = (thresh < 1.25 ** 3).float().mean() + + diff_i = gt_i - pred_i + abs_rel = torch.mean(torch.abs(diff_i) / gt_i) + sq_rel = torch.mean(diff_i ** 2 / gt_i) + rmse = torch.sqrt(torch.mean(diff_i ** 2)) + rmse_log = torch.sqrt(torch.mean((torch.log(gt_i) - torch.log(pred_i)) ** 2)) + + err = torch.log(pred_i) - torch.log(gt_i) + silog = torch.sqrt(torch.mean(err ** 2) - torch.mean(err) ** 2) * 100 + + metrics.append([abs_rel, sq_rel, rmse, rmse_log, silog, a1, a2, a3]) + + # Return metrics + return torch.tensor(metrics, dtype=gt.dtype) + + def evaluate(self, batch, output, flipped_output=None): + """ + Evaluate predictions + + Parameters + ---------- + batch : Dict + Dictionary containing ground-truth information + output : Dict + Dictionary containing predictions + flipped_output : Bool + Optional flipped output for post-processing + + Returns + ------- + metrics : Dict + Dictionary with calculated metrics + predictions : Dict + Dictionary with additional predictions + """ + metrics, predictions = {}, {} + if self.name not in batch: + return metrics, predictions + # For each output item + for key, val in output.items(): + # If it corresponds to this task + if key.startswith(self.name) and 'debug' not in key: + # Loop over every context + val = val if is_dict(val) else {0: val} + for ctx in val.keys(): + # Loop over every scale + for i in range(1 if self.only_first else len(val[ctx])): + + pred = val[ctx][i] + gt = batch[self.name][ctx] + + if self.post_process: + pred_flipped = flipped_output[key][ctx][i] + pred_pp = post_process_depth(pred, pred_flipped, method='mean') + else: + pred_pp = None + + if i > 0: + pred = self.interp_nearest(pred, val[ctx][0]) + if self.post_process: + pred_pp = self.interp_nearest(pred_pp, val[ctx][0]) + + if pred.dim() == 4: + suffix = '(%s)' % str(ctx) + ('_%d' % i if not self.only_first else '') + for mode in self.modes: + metrics[f'{key}|{mode}{suffix}'] = \ + self.compute( + gt=gt, + pred=pred_pp if 'pp' in mode else pred, + use_gt_scale='gt' in mode, + mask=None, + ) + elif pred.dim() == 5: + for j in range(pred.shape[1]): + suffix = '(%s_%d)' % (str(ctx), j) + ('_%d' % i if not self.only_first else '') + for mode in self.modes: + metrics[f'{key}|{mode}{suffix}'] = self.compute( + gt=gt[:, j], + pred=pred_pp[:, j] if 'pp' in mode else pred[:, j], + use_gt_scale='gt' in mode, + mask=None, + ) + + return dict_remove_nones(metrics), predictions + diff --git a/vidar/metrics/utils.py b/vidar/metrics/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..91a632391be37cccd5304fbeadf74f955b0d245f --- /dev/null +++ b/vidar/metrics/utils.py @@ -0,0 +1,85 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
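A hedged usage sketch for DepthEvaluation.compute above, on dummy tensors; it assumes the vidar package is importable, and the depth range, crop and other Config fields are illustrative placeholders rather than values taken from the repository's configuration files:

    import torch

    from vidar.metrics.depth import DepthEvaluation
    from vidar.utils.config import Config

    cfg = Config(min_depth=0.5, max_depth=80.0, crop='garg', scale_output='resize',
                 post_process=False, median_scaling=True)
    evaluator = DepthEvaluation(cfg)

    gt = torch.rand(2, 1, 192, 640) * 80.0 + 0.5     # fake ground-truth depth [B,1,H,W]
    pred = gt * 1.05                                 # fake prediction, 5% off

    # Per-sample metrics [B,8], ordered as evaluator.metrics
    metrics = evaluator.compute(gt, pred, use_gt_scale=True)
    print(dict(zip(evaluator.metrics, metrics.mean(0).tolist())))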
+
+import torch
+
+from vidar.utils.tensor import interpolate_image
+
+
+def scale_output(pred, gt, scale_fn):
+    """
+    Match depth maps to ground-truth resolution
+
+    Parameters
+    ----------
+    pred : torch.Tensor
+        Predicted depth maps [B,1,h,w]
+    gt : torch.Tensor
+        Ground-truth depth maps [B,1,H,W]
+    scale_fn : String
+        How to scale output to GT resolution
+            resize: Bilinear interpolation to the GT resolution
+            top-center: Pad the top and the left/right borders with zeros
+
+    Returns
+    -------
+    pred : torch.Tensor
+        Uncropped predicted depth maps [B,1,H,W]
+    """
+    if pred.dim() == 5 and gt.dim() == 5:
+        return torch.stack([scale_output(pred[:, i], gt[:, i], scale_fn) for i in range(pred.shape[1])], 1)
+    # Return depth map if scaling is not required
+    if scale_fn == 'none':
+        return pred
+    elif scale_fn == 'resize':
+        # Resize depth map to GT resolution
+        return interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)
+    else:
+        # Create empty depth map with GT resolution
+        pred_uncropped = torch.zeros(gt.shape, dtype=pred.dtype, device=pred.device)
+        # Uncrop top vertically and center horizontally
+        if scale_fn == 'top-center':
+            top, left = gt.shape[2] - pred.shape[2], (gt.shape[3] - pred.shape[3]) // 2
+            pred_uncropped[:, :, top:(top + pred.shape[2]), left:(left + pred.shape[3])] = pred
+        else:
+            raise NotImplementedError('Depth scale function {} not implemented.'.format(scale_fn))
+        # Return uncropped depth map
+        return pred_uncropped
+
+
+def create_crop_mask(crop, gt):
+    """
+    Create crop mask for evaluation
+
+    Parameters
+    ----------
+    crop : String
+        Type of crop
+    gt : torch.Tensor
+        Ground-truth depth map (for dimensions)
+
+    Returns
+    -------
+    crop_mask : torch.Tensor
+        Mask for evaluation
+    """
+    # Return None if mask is not required
+    if crop in ('', None):
+        return None
+    # Create empty mask
+    batch_size, _, gt_height, gt_width = gt.shape
+    crop_mask = torch.zeros(gt.shape[-2:]).byte().type_as(gt)
+    # Get specific mask
+    if crop == 'eigen_nyu':
+        crop_mask[20:459, 24:615] = 1
+    elif crop == 'bts_nyu':
+        crop_mask[45:471, 41:601] = 1
+    elif crop == 'garg':
+        y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
+        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
+        crop_mask[y1:y2, x1:x2] = 1
+    elif crop == 'eigen':
+        y1, y2 = int(0.3324324 * gt_height), int(0.91351351 * gt_height)
+        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
+        crop_mask[y1:y2, x1:x2] = 1
+    # Return crop mask
+    return crop_mask
diff --git a/vidar/utils/config.py b/vidar/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c8d1b97e3e007db5bfea85b051c12ff3b9dfc7
--- /dev/null
+++ b/vidar/utils/config.py
@@ -0,0 +1,468 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
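A quick sketch of the two evaluation helpers above on dummy tensors (shapes and values are illustrative only):

    import torch

    from vidar.metrics.utils import create_crop_mask, scale_output

    gt = torch.rand(2, 1, 192, 640)                  # ground-truth resolution
    pred = torch.rand(2, 1, 96, 320)                 # half-resolution prediction

    up = scale_output(pred, gt, 'resize')            # bilinear upsampling to the GT size
    print(up.shape)                                  # torch.Size([2, 1, 192, 640])

    mask = create_crop_mask('garg', gt)              # Garg crop, defined on the GT grid
    print(mask.shape, int(mask.sum()))               # torch.Size([192, 640]) and the pixel count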
+ +import importlib +import os +from argparse import Namespace + +import torch +import yaml + +from vidar.utils.data import make_list, num_trainable_params +from vidar.utils.distributed import print0 +from vidar.utils.logging import pcolor +from vidar.utils.networks import load_checkpoint +from vidar.utils.types import is_dict, is_list, is_namespace + + +def cfg_has(*args): + """ + Check if a key is in configuration + + Parameters + ---------- + args : Tuple (Config, String, Value) + Inputs: + length 2 = configuration/name, + length 3 = configuration/name/default + + Returns + ------- + Flag : Bool or Value + True/False if key is in configuration, key value/default if default is provided + """ + if len(args) == 2: + cfg, name = args + if not is_list(name): + return name in cfg.__dict__.keys() + else: + return all([n in cfg.__dict__.keys() for n in name]) + elif len(args) == 3: + cfg, name, default = args + has = name in cfg.__dict__.keys() + return cfg.__dict__[name] if has else default + else: + raise ValueError('Wrong number of arguments for cfg_has') + + +def cfg_add_to_dict(dic, cfg, key, i=None): + """ + Add configuration key to dictionary + + Parameters + ---------- + dic : Dict + Input dictionary + cfg : Config + Input configuration + key : String + Input key + i : Int + Optional list index + """ + if cfg_has(cfg, key): + dic[key] = cfg.__dict__[key] if i is None \ + else cfg.__dict__[key][0] if len(cfg.__dict__[key]) == 1 \ + else cfg.__dict__[key][i] + + +def cfg_from_dict(dic): + """ + Create configuration from dictionary + + Parameters + ---------- + dic : Dict + Input dictionary + + Returns + ------- + cfg : Config + Output configuration + """ + for key, val in dic.items(): + if is_dict(val): + dic[key] = cfg_from_dict(val) + return Config(**dic) + + +def update_cfg(cfg): + """ + Update configuration with hard-coded information + + Parameters + ---------- + cfg : Config + Input configuration + + Returns + ------- + cfg : Config + Updated configuration + """ + if not torch.cuda.is_available(): + cfg.setup.grad_scaler = False + return cfg + + +def to_namespace(data): + """ + Convert dictionary to namespace + + Parameters + ---------- + data : Dict or Config + Input dictionary + + Returns + ------- + cfg : Config + Output configuration + """ + for key in data.keys(): + if is_dict(data[key]): + data[key] = to_namespace(data[key]) + return Config(**data) + + +def merge_dict(default, config): + """ + Merge two dictionaries + + Parameters + ---------- + default : Dict + Dictionary with default values + config : Dict + Dictionary with values to update + + Returns + ------- + cfg : Dict + Updated dictionary + """ + if is_namespace(default): + default = default.__dict__ + for key in config.keys(): + if key not in default.keys(): + default[key] = {} + if not is_dict(config[key]): + default[key] = config[key] + else: + default[key] = merge_dict(default[key], config[key]) + return default + + +def update_from_kwargs(cfg, **kwargs): + """ + Update configuration based on keyword arguments + + Parameters + ---------- + cfg : Config + Input configuration + kwargs : Dict + Keyword arguments + + Returns + ------- + cfg : Config + Updated configuration + """ + if kwargs is not None: + for key, val in kwargs.items(): + key_split = key.split('.') + dic = cfg.__dict__ + for k in key_split[:-1]: + dic = dic[k].__dict__ + dic[key_split[-1]] = val + return cfg + + +def recursive_recipe(cfg, super_key=None): + """ + Add recipe parameters to configuration + + Parameters + ---------- + cfg : Config + Input 
configuration + super_key : String + Which recipe entry to use + + Returns + ------- + cfg : Config + Updated configuration + """ + for key in list(cfg.keys()): + if is_dict(cfg[key]): + cfg[key] = recursive_recipe(cfg[key], super_key=key) + elif key == 'recipe': + recipe = 'configs/recipes/' + cfg.pop(key) + if '|' in recipe: + recipe, super_key = recipe.split('|') + recipe = read_config(recipe + '.yaml') + while '.' in super_key: + split = super_key.split('.') + recipe = recipe.__dict__[split[0]] + super_key = '.'.join(split[1:]) + recipe = recipe.__dict__[super_key].__dict__ + cfg = merge_dict(recipe, cfg) + return cfg + + +def read_config(path, **kwargs): + """ + Create configuration from file + + Parameters + ---------- + path : String + Configuration path + kwargs : Dict + Keyword arguments to update configuration + + Returns + ------- + cfg : Config + Output configuration + """ + """Read configuration from file""" + with open(path) as cfg: + config = yaml.load(cfg, Loader=yaml.FullLoader) + config = recursive_recipe(config) + cfg = to_namespace(config) + if kwargs is not None: + cfg = update_from_kwargs(cfg, **kwargs) + return cfg + + +def is_recursive(val): + """ + Check if configuration entry is recursive + + Parameters + ---------- + val : Config + Input Configuration + + Returns + ------- + Flag : Bool + True/False if is recursive or not + """ + return 'file' in val.__dict__.keys() + + +def get_folder_name(path, mode, root='vidar/arch'): + """ + Get folder and name from configuration path + + Parameters + ---------- + path : String + Input path + mode : String + Which mode to use (e.g., models, networks, losses) + root : String + Which folder to use + + Returns + ------- + folder : String + Output folder + name : String + Output name + """ + """Get folder and name from configuration path""" + folder, name = os.path.dirname(path), os.path.basename(path) + folder = os.path.join(root, mode, folder) + if folder.endswith('/'): + folder = folder[:-1] + return folder, name + + +def recursive_assignment(model, cfg, mode, verbose=True): + """ + Recursively assign information from a configuration + + Parameters + ---------- + model : torch.nn.Module + Which network we are using + cfg : Config + Input Configuration + mode : String + Which mode we are using (e.g., models, networks, losses) + verbose : Bool + Print information on screen + """ + font = {'color': 'yellow', 'attrs': ('dark',)} + for key, val in cfg.__dict__.items(): + cls = cfg.__dict__[key] + if is_namespace(cls): + if is_recursive(val): + folder, name = get_folder_name(val.file, mode) + getattr(model, mode)[key] = load_class(name, folder)(cls) + if verbose: + string = '######### {}'.format(getattr(model, mode)[key].__class__.__name__) + num_params = num_trainable_params(getattr(model, mode)[key]) + if num_params > 0: + string += f' ({num_params:,} parameters)' + print0(pcolor(string, **font)) + if cfg_has(val, 'checkpoint'): + model_attr = getattr(model, mode)[key] + load_checkpoint(model_attr, val.checkpoint, strict=False, verbose=verbose, prefix=key) + recursive_assignment(getattr(model, mode)[key], cls, mode, verbose=verbose) + if key == 'blocks': + for key2, val2 in cfg.__dict__[key].__dict__.items(): + cls2 = cfg.__dict__[key].__dict__[key2] + if is_recursive(val2): + folder, name = get_folder_name(val2.file, 'blocks') + model.blocks[key2] = load_class(name, folder)(cls2) + recursive_assignment(model.blocks[key2], cls2, 'blocks', verbose=verbose) + + +def load_class(filename, paths, concat=True, methodname=None): + """ + 
Look for a file in different locations and return its method with the same name + Optionally, you can use concat to search in path.filename instead + + Parameters + ---------- + filename : String + Name of the file we are searching for + paths : String or list[String] + Folders in which the file will be searched + concat : Bol + Flag to concatenate filename to each path during the search + methodname : String or list[String] + Method name (If None, use filename + If it's a string, use it as the methodname + If it's a list, use the first methodname found) + + Returns + ------- + method : Function + Loaded method + """ + # If method name is not given, use filename + methodname = make_list(filename if methodname is None else methodname) + # for each path in paths + for path in make_list(paths): + # Create full path + path = path.replace('/', '.') + full_path = '{}.{}'.format(path, filename) if concat else path + # Get module + module = importlib.import_module(full_path) + # Try all method names + for name in methodname: + method = getattr(module, name, None) + # Return if found + if method is not None: + return method + # Didn't find anything + raise ValueError('Unknown class {}'.format(filename)) + + +def get_from_cfg_list(cfg, key, idx): + """ + Get configuration value from a list + + Parameters + ---------- + cfg : Config + Input configuration + key : String + Input configuration key + idx : Int + List index + + Returns + ------- + data : Value + Key value at that index if it's a list, otherwise return the key value directly + """ + if key not in cfg.__dict__.keys(): + return None + data = cfg.__dict__[key] + return data if not is_list(data) else data[idx] if len(data) > 1 else data[0] + + +def dataset_prefix(cfg, idx): + """ + Create dataset prefix based on configuration information + + Parameters + ---------- + cfg : Config + Input configuration + idx : Int + Input index for information retrieval + + Returns + ------- + prefix : String + Dataset prefix + """ + # Dataset path is always available + # prefix = cfg.name[idx] + prefix = '{}'.format(os.path.splitext(get_from_cfg_list(cfg, 'path', idx).split('/')[-1])[0]) + # If split is available + val = get_from_cfg_list(cfg, 'split', idx) + if val is not None: + prefix += '-{}'.format(os.path.splitext(os.path.basename(val))[0]) + # If input depth type is available + val = get_from_cfg_list(cfg, 'input_depth_type', idx) + if val is not None and val not in [None, '']: + prefix += '-+{}'.format(val) + # If depth type is available + val = get_from_cfg_list(cfg, 'depth_type', idx) + if val is not None and val not in [None, '']: + prefix += '-{}'.format(val) + # If there is camera information + val = get_from_cfg_list(cfg, 'cameras', idx) + if val is not None and is_list(val) and len(val) > 0: + prefix += '-cam{}'.format(val[0]) + # Return prefix + return prefix + + +class Config(Namespace): + """ + Configuration class for passing arguments between other classes + + Parameters + ---------- + kwargs: Dict + Arguments to create configuration + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @staticmethod + def from_file(file): + """Read configuration from file""" + return read_config(file) + + @property + def dict(self): + """Return configuration as dictionary""" + return self.__dict__ + + def keys(self): + """Return dictionary keys of configuration""" + return self.dict.keys() + + def items(self): + """Return dictionary items of configuration""" + return self.dict.items() + + def values(self): + """Return dictionary values of 
configuration""" + return self.dict.values() + + def has(self, *args): + """Check if configuration has certain parameters""" + return cfg_has(self, *args) + diff --git a/vidar/utils/data.py b/vidar/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..540bb918f9176a65dcd39d3252fdf9698cf1e6cc --- /dev/null +++ b/vidar/utils/data.py @@ -0,0 +1,245 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import random +from collections import OrderedDict +from inspect import signature + +import numpy as np +import torch + +from vidar.utils.decorators import iterate1, iterate2 +from vidar.utils.types import is_list, is_double_list, is_tuple, is_tensor, is_dict, is_seq + +KEYS_IMAGE = [ + 'rgb', 'mask', + 'input_depth', 'depth', + 'bwd_optical_flow', 'fwd_optical_flow', + ] + +KEYS_MATRIX = [ + 'intrinsics', 'extrinsics', 'pose', 'semantic', + ] + + +def modrem(v, n): + """Return round division and remainder""" + return v // n, v % n + + +def flatten(lst): + """Flatten a list of lists into a list""" + return [l for ls in lst for l in ls] if is_double_list(lst) else lst + + +def keys_with(dic, string, without=()): + """Return keys from a dictionary that contain a certain string""" + return [key for key in dic if string in key and not any(w in key for w in make_list(without))] + + +def keys_startswith(dic, string): + """Return keys from a dictionary that contain a certain string""" + return [key for key in dic if key.startswith(string)] + + +def keys_in(dic, keys): + """Return only keys in a dictionary""" + return [key for key in keys if key in dic] + + +def str_not_in(string, keys): + for key in keys: + if key in string: + return False + return True + + +def make_list(var, n=None): + """Wraps the input into a list, and optionally repeats it to be size n""" + var = var if is_list(var) or is_tuple(var) else [var] + if n is None: + return var + else: + assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list' + return var * n if len(var) == 1 else var + + +def filter_args(func, keys): + """Filters a dictionary, so it only contains keys that are arguments of a function""" + filtered = {} + sign = list(signature(func).parameters.keys()) + for k, v in {**keys}.items(): + if k in sign: + filtered[k] = v + return filtered + + +def dict_remove_nones(dic): + """Filters dictionary to remove keys with None values""" + return {key: val for key, val in dic.items() if val is not None} + + +@iterate1 +def matmul1(v1, v2): + """Iteratively multiply tensors""" + return v1 @ v2 + + +@iterate2 +def matmul2(v1, v2): + """Iteratively multiply tensors""" + return v1 @ v2 + + +@iterate1 +def unsqueeze(x): + """Iteratively unsqueeze tensors to batch size 1""" + return x.unsqueeze(0) if is_tensor(x) else x + + +@iterate1 +def fold(data, n): + """Iteratively folds first and second dimensions into one""" + shape = list(data.shape) + if len(shape) == n + 1: + shape = [shape[0] * shape[1]] + shape[2:] + return data.view(*shape) + else: + return data + + +@iterate1 +def expand(data, n, d): + """Iteratively folds first and second dimensions into one""" + shape = list(data.shape) + if len(shape) == n: + return data.unsqueeze(d) + else: + return data + + +def fold_batch(batch, device=None): + """Combine the first (batch) and second (camera) dimensions of a batch""" + if is_seq(batch): + return [fold_batch(b, device=device) for b in batch] + for key in keys_in(batch, KEYS_IMAGE): + batch[key] = fold(batch[key], 4) + for key in keys_in(batch, KEYS_MATRIX): + 
batch[key] = fold(batch[key], 3) + if device is not None: + batch = batch_to_device(batch, device) + return batch + + +def expand_batch(batch, d, device=None): + """Expand the batch to include an additional dimension (0 for batch, 1 for camera)""" + if is_seq(batch): + return [expand_batch(b, d, device=device) for b in batch] + d = {'batch': 0, 'camera': 1}[d] + for key in keys_in(batch, KEYS_IMAGE): + batch[key] = expand(batch[key], 4, d) + for key in keys_in(batch, KEYS_MATRIX): + batch[key] = expand(batch[key], 3, d) + if device is not None: + batch = batch_to_device(batch, device) + return batch + + +def batch_to_device(batch, device): + """Copy batch information to device""" + if is_dict(batch): + return {key: batch_to_device(val, device) for key, val in batch.items()} + if is_list(batch): + return [batch_to_device(val, device) for val in batch] + if is_tensor(batch): + return batch.to(device) + return batch + + +def num_trainable_params(model): + """Return number of trainable parameters""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def set_random_seed(seed): + if seed >= 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def make_batch(batch, device=None): + """Transforms a sample into a batch""" + for key in batch.keys(): + if is_dict(batch[key]): + batch[key] = make_batch(batch[key]) + elif is_tensor(batch[key]): + batch[key] = batch[key].unsqueeze(0) + if device is not None: + batch = batch_to_device(batch, device) + return batch + + +def break_key(sample, n=None): + """Break a multi-camera sample key, so different cameras have their own entries (context, camera)""" + if sample is None: + return sample + new_sample = OrderedDict() + for ctx in sample.keys(): + if is_dict(sample[ctx]): + for key2, val in sample[ctx].items(): + if val.dim() == 1: + val = val.unsqueeze(1) + for i in range(val.shape[1]): + if (ctx, i) not in new_sample.keys(): + new_sample[(ctx, i)] = {} + new_sample[(ctx, i)][key2] = val[:, [i]] + elif sample[ctx].dim() == n + 1: + for i in range(sample[ctx].shape[1]): + new_sample[(ctx, i)] = sample[ctx][:, i] + return new_sample + + +def break_batch(batch): + """Break a multi-camera batch, so different cameras have their own entries (context, camera)""" + for key in keys_in(batch, KEYS_IMAGE): + for ctx in list(batch[key].keys()): + if batch[key][ctx].dim() == 5: + for n in range(batch[key][ctx].shape[1]): + batch[key][(ctx,n)] = batch[key][ctx][:, n] + batch[key].pop(ctx) + for key in keys_in(batch, KEYS_MATRIX): + for ctx in list(batch[key].keys()): + if batch[key][ctx].dim() == 4: + for n in range(batch[key][ctx].shape[1]): + batch[key][(ctx,n)] = batch[key][ctx][:, n] + batch[key].pop(ctx) + return batch + + +def dict_has(dic, key): + """Check if a dictionary has a certain key""" + return key in dic + + +def get_from_dict(dic, key): + """Get value from a dictionary (return None if key is not in dictionary)""" + return None if key not in dic else dic[key] + + +def get_mask_from_list(mask, i, return_ones=None): + """Retrieve mask from a list (if it's not a list, return the mask itself, and create one if requested)""" + if return_ones is None: + return None if mask is None else mask[i] if is_list(mask) else mask + else: + mask = torch.ones_like(return_ones[i] if is_list(return_ones) else return_ones).bool() if mask is None \ + else mask[i].clone().bool() if is_list(mask) else mask.clone().bool() + if mask.dim() == 4: + return mask[:, [0]] + elif mask.dim() == 3: + return 
mask[..., [0]] + + +def get_from_list(lst, i): + """Get information from a list (return None if input is None, and return value directly if it's not a list)""" + return None if lst is None else lst[i] if is_seq(lst) else lst diff --git a/vidar/utils/decorators.py b/vidar/utils/decorators.py new file mode 100755 index 0000000000000000000000000000000000000000..38d11b5d34014241ece78df8308907ee90e2913f --- /dev/null +++ b/vidar/utils/decorators.py @@ -0,0 +1,60 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from vidar.utils.types import is_seq, is_dict + + +def iterate1(func): + """Decorator to iterate over a list (first argument)""" + def inner(var, *args, **kwargs): + if is_seq(var): + return [func(v, *args, **kwargs) for v in var] + elif is_dict(var): + return {key: func(val, *args, **kwargs) for key, val in var.items()} + else: + return func(var, *args, **kwargs) + return inner + + +def iterate2(func): + """Decorator to iterate over a list (second argument)""" + def inner(self, var, *args, **kwargs): + if is_seq(var): + return [func(self, v, *args, **kwargs) for v in var] + elif is_dict(var): + return {key: func(self, val, *args, **kwargs) for key, val in var.items()} + else: + return func(self, var, *args, **kwargs) + return inner + + +def iterate12(func): + """Decorator to iterate over a list (first argument)""" + def inner(var1, var2, *args, **kwargs): + if is_seq(var1) and is_seq(var2): + return [func(v1, v2, *args, **kwargs) for v1, v2 in zip(var1, var2)] + elif is_dict(var1) and is_dict(var2): + return {key: func(val1, val2, *args, **kwargs) + for key, val1, val2 in zip(var1.keys(), var1.values(), var2.values())} + else: + return func(var1, var2, *args, **kwargs) + return inner + + +def multi_write(func): + """Decorator to write multiple files""" + def inner(filename, data, **kwargs): + if is_seq(data): + for i in range(len(data)): + filename_i, ext = filename.split('.') + filename_i = '%s_%d.%s' % (filename_i, i, ext) + func(filename_i, data[i], **kwargs) + return + elif is_dict(data): + for key, val in data.items(): + filename_i, ext = filename.split('.') + filename_i = '%s(%s).%s' % (filename_i, key, ext) + func(filename_i, val, **kwargs) + return + else: + return func(filename, data) + return inner diff --git a/vidar/utils/depth.py b/vidar/utils/depth.py new file mode 100644 index 0000000000000000000000000000000000000000..59493454bc15dbf5ddd6e3460f233444ddfab2f3 --- /dev/null +++ b/vidar/utils/depth.py @@ -0,0 +1,287 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn.functional as tfn + +from vidar.geometry.camera import Camera +from vidar.utils.decorators import iterate1 +from vidar.utils.types import is_tensor, is_numpy + + +@iterate1 +@iterate1 +def inv2depth(inv_depth): + """ + Invert an inverse depth map to produce a depth map + + Parameters + ---------- + inv_depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array] + Inverse depth map [B,1,H,W] + + Returns + ------- + depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array] + Depth map [B,1,H,W] + """ + if is_tensor(inv_depth): + depth = 1. / inv_depth.clamp(min=1e-6, max=None) + elif is_numpy(inv_depth): + depth = 1. / inv_depth.clip(min=1e-6, max=None) + else: + raise NotImplementedError('Wrong inverse depth format.') + depth[inv_depth <= 0.] = 0. 
+ return depth + + +@iterate1 +@iterate1 +def depth2inv(depth): + """ + Invert a depth map to produce an inverse depth map + + Parameters + ---------- + depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array] + Depth map [B,1,H,W] + + Returns + ------- + inv_depth : torch.Tensor or list[torch.Tensor] pr np.array or list[np.array] + Inverse depth map [B,1,H,W] + + """ + if is_tensor(depth): + inv_depth = 1. / depth.clamp(min=1e-6, max=None) + elif is_numpy(depth): + inv_depth = 1. / depth.clip(min=1e-6, max=None) + else: + raise NotImplementedError('Wrong depth format') + inv_depth[depth <= 0.] = 0. + return inv_depth + + +def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'): + """ + Fuse inverse depth and flipped inverse depth maps + + Parameters + ---------- + inv_depth : torch.Tensor + Inverse depth map [B,1,H,W] + inv_depth_hat : torch.Tensor + Flipped inverse depth map produced from a flipped image [B,1,H,W] + method : String + Method that will be used to fuse the inverse depth maps + + Returns + ------- + fused_inv_depth : torch.Tensor [B,1,H,W] + Fused inverse depth map + """ + if method == 'mean': + return 0.5 * (inv_depth + inv_depth_hat) + elif method == 'max': + return torch.max(inv_depth, inv_depth_hat) + elif method == 'min': + return torch.min(inv_depth, inv_depth_hat) + else: + raise ValueError('Unknown post-process method {}'.format(method)) + + +def post_process_inv_depth(inv_depth, inv_depth_flipped, method='mean'): + """ + Post-process an inverse and flipped inverse depth map + + Parameters + ---------- + inv_depth : torch.Tensor + Inverse depth map [B,1,H,W] + inv_depth_flipped : torch.Tensor + Inverse depth map produced from a flipped image [B,1,H,W] + method : String + Method that will be used to fuse the inverse depth maps + + Returns + ------- + inv_depth_pp : torch.Tensor + Post-processed inverse depth map [B,1,H,W] + """ + from vidar.utils.flip import flip_lr + B, C, H, W = inv_depth.shape + inv_depth_hat = inv_depth_flipped + inv_depth_fused = fuse_inv_depth(inv_depth, inv_depth_hat, method=method) + xs = torch.linspace(0., 1., W, device=inv_depth.device, + dtype=inv_depth.dtype).repeat(B, C, H, 1) + mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.) 
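+    # mask ramps from 1 at the left border (flat over the first ~5% of the width)
+    # down to 0 at ~10% of the width; mask_hat below is its left-right mirror.
+    # The blend below therefore keeps one prediction near the left border, the
+    # other near the right border, and the fused average everywhere in between.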
+ mask_hat = flip_lr(mask) + post_processed = mask_hat * inv_depth + mask * inv_depth_hat + \ + (1.0 - mask - mask_hat) * inv_depth_fused + mask0, mask_hat0 = inv_depth == 0, inv_depth_hat == 0 + post_processed[mask_hat0] = inv_depth[mask_hat0] + post_processed[mask0] = inv_depth_hat[mask0] + post_processed[mask0 & mask_hat0] = 0 + return post_processed + + +def post_process_depth(depth, depth_flipped, method='mean'): + """Post-process a depth map and flipped depth map""" + return inv2depth(post_process_inv_depth( + depth2inv(depth), depth2inv(depth_flipped), method=method)) + + +def calculate_normals(points, camera=None, intrinsics=None, pad_last=True): + """ + Calculate normals from a pointcloud map or from a depth map + intrinsics + + Parameters + ---------- + points : torch.Tensor + A pointcloud map [B,3,H,W] containing 3D coordinates + A depth map [B,1,H,W] containing depth values + camera : Camera + Camera used for normal calculation, in case a depth map is provided + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] necessary in case a depth map is provided, to create the pointcloud map + pad_last : Bool + If true, pad the last row and column with zeros + + Returns + ------- + normals : torch.Tensor + Normal map [B,3,H,W] containing normal estimates + """ + if intrinsics is None and camera is None: + return points + # Create pointcloud map if intrinsics are provided + if camera is not None: + points = camera.reconstruct_depth_map(points) + elif intrinsics is not None: + points = Camera(K=intrinsics, hw=points).reconstruct_depth_map(points) + # Prepare points for cross-product + p0 = points[:, :, :-1, :-1] + p1 = points[:, :, 1:, :-1] + p2 = points[:, :, :-1, 1:] + # Calculate and normalize normals + normals = torch.cross(p1 - p0, p2 - p0, 1) + normals = normals / normals.norm(dim=1, keepdim=True) + # Pad normals + if pad_last: + normals = torch.nn.functional.pad(normals, [0, 1, 0, 1], mode='replicate') + # # Substitute nan values with zero + normals[torch.isnan(normals)] = 0.0 + # Return normals + return normals + + +def calc_dot_prod(pts, nrm): + """ + Calculate dot product of 3D points and their normals + + Parameters + ---------- + pts : torch.Tensor + Input 3D points [B,3,H,W] + nrm : torch.Tensor + Input 3D normal vectors [B,3,H,W] + + Returns + ------- + dots : torch.Tensor + Output dot product [B,1,H,W] + """ + pts = pts / pts.norm(dim=1, keepdim=True) + nrm = nrm / nrm.norm(dim=1, keepdim=True) + dots = torch.sum(pts * nrm, dim=1, keepdim=True) + return dots + + +def get_depth_bins(mode, min_val, max_val, num_vals): + """ + Create discretize depth bins + + Parameters + ---------- + mode : String + Discretization mode + min_val : Float + Minimum depth value + max_val : Float + Maximum depth value + num_vals : Int + Number of intervals + + Returns + ------- + bins : torch.Tensor + Discretized depth values [num_vals] + """ + if mode == 'inverse': + depth_bins = 1. / torch.linspace( + 1. / max_val, 1. 
/ min_val, num_vals)[::-1] + elif mode == 'linear': + depth_bins = torch.linspace( + min_val, max_val, num_vals) + elif mode == 'sid': + depth_bins = torch.tensor( + [torch.exp(torch.log(min_val) + torch.log(max_val / min_val) * i / (num_vals - 1)) + for i in range(num_vals)]) + else: + raise NotImplementedError + return depth_bins.float() + + +def depth2index(depth, bins): + """ + Convert a depth map to discretized indexes + + Parameters + ---------- + depth : torch.Tensor + Input depth map [B,1,H,W] + bins : torch.Tensor + Discretized depth bins [D] + + Returns + ------- + idx : torch.Tensor + Discretized depth indexes [B,1,H,W] + """ + if depth.dim() == 2: + depth = tfn.relu(depth - bins.reshape(1, -1)) + elif depth.dim() == 4: + depth = tfn.relu(depth - bins.reshape(1, -1, 1, 1)) + else: + raise ValueError('Invalid depth dimension') + idx = torch.min(depth, dim=1)[1] + # idx[(idx < 0) | (idx == len(bins) - 1)] = -1 + idx[(idx < 0)] = 0 + idx[idx > len(bins) - 1] = len(bins) - 1 + return idx.unsqueeze(1) + + +def index2depth(idx, bins): + """ + Converts discretized indexes to depth map + + Parameters + ---------- + idx : torch.Tensor + Discretized indexes [B,1,H,W] + bins : torch.Tensor + Discretized depth bins [D] + + Returns + ------- + depth : torch.Tensor + Output depth map [B,1,H,W] + """ + if idx.dim() == 4: + b, _, h, w = idx.shape + bins = bins.reshape(1, -1, 1, 1).repeat(idx.shape[0], 1, idx.shape[2], idx.shape[3]).to(idx.device) + elif idx.dim() == 3: + b, _, n = idx.shape + bins = bins.reshape(1, -1, 1).repeat(idx.shape[0], 1, idx.shape[2]).to(idx.device) + else: + raise ValueError('Invalid depth dimension') + return torch.gather(bins, 1, idx) diff --git a/vidar/utils/distributed.py b/vidar/utils/distributed.py new file mode 100755 index 0000000000000000000000000000000000000000..ab0564a6e55ec4a1b5aa9611c530ed266733a22c --- /dev/null +++ b/vidar/utils/distributed.py @@ -0,0 +1,74 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
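A round-trip sketch for the depth discretization helpers above, using the 'linear' mode and made-up values; the quantization error stays below one bin width:

    import torch

    from vidar.utils.depth import depth2index, get_depth_bins, index2depth

    bins = get_depth_bins('linear', 1.0, 80.0, 64)   # 64 linearly spaced depth values
    depth = torch.rand(2, 1, 32, 32) * 79.0 + 1.0    # random depths inside [1, 80]

    idx = depth2index(depth, bins)                   # [B,1,H,W] bin indexes
    recon = index2depth(idx, bins)                   # [B,1,H,W] quantized depths

    print((recon - depth).abs().max().item())        # below one bin width (~1.25 here)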
+
+import os
+
+import torch.distributed as dist
+
+
+def dist_mode():
+    return os.getenv('DIST_MODE')
+
+
+def rank():
+    """Returns process rank"""
+    if dist_mode() in ['cpu', 'gpu', None]:
+        return 0
+    elif dist_mode() == 'ddp':
+        return int(os.environ['RANK'])
+    else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
+
+
+def world_size():
+    """Returns world size"""
+    if dist_mode() in ['cpu', 'gpu', None]:
+        return 1
+    elif dist_mode() == 'ddp':
+        return int(os.environ['WORLD_SIZE'])
+    else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
+
+
+def on_rank_0(func):
+    """Decorator to run function only on rank 0"""
+    def wrapper(*args, **kwargs):
+        if rank() == 0:
+            return func(*args, **kwargs)
+    return wrapper
+
+
+@on_rank_0
+def print0(string='\n'):
+    """Function to print only on rank 0"""
+    print(string)
+
+
+def reduce_value(value, average, name):
+    """
+    Reduce the value of a tensor across all processes (sum, optionally averaged)
+
+    Parameters
+    ----------
+    value : torch.Tensor
+        Value to be reduced
+    average : Bool
+        Whether values will be averaged or not
+    name : String
+        Value name
+
+    Returns
+    -------
+    value : torch.Tensor
+        Reduced value
+    """
+    if dist_mode() == 'cpu':
+        return value
+    elif dist_mode() == 'gpu':
+        return value
+    elif dist_mode() == 'ddp':
+        dist.all_reduce(tensor=value, op=dist.ReduceOp.SUM, async_op=False)
+        if average:
+            value /= world_size()
+        return value
+    else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
diff --git a/vidar/utils/flip.py b/vidar/utils/flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2f360a314ac9cef635dbc7c5a6d6af8d4907c4
--- /dev/null
+++ b/vidar/utils/flip.py
@@ -0,0 +1,182 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
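A single-process sketch of the distributed helpers above, assuming DIST_MODE is set to 'cpu' (no process group is required in that mode):

    import os

    import torch

    from vidar.utils.distributed import print0, rank, reduce_value, world_size

    os.environ['DIST_MODE'] = 'cpu'                  # single-process run
    print(rank(), world_size())                      # 0 1

    value = torch.tensor([1.0, 2.0, 3.0])
    print(reduce_value(value, average=True, name='example'))   # returned unchanged
    print0('only rank 0 prints this')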
+ +import torch +from pytorch3d.transforms.rotation_conversions import \ + matrix_to_euler_angles, euler_angles_to_matrix + +from vidar.utils.data import keys_in +from vidar.utils.decorators import iterate1, iterate12 +from vidar.utils.types import is_tensor, is_list, is_seq + + +def flip_lr_fn(tensor): + """Function to flip a tensor from left to right""" + return torch.flip(tensor, [-1]) + + +def flip_flow_lr_fn(flow): + """Function to flip a flow tensor from left to right""" + flow_flip = torch.flip(flow, [3]) + flow_flip[:, :1, :, :] *= -1 + return flow_flip.contiguous() + + +def flip_intrinsics_lr_fn(K, shape): + """Function to flip a 3x3 intrinsic matrix from left to right""" + K = K.clone() + K[:, 0, 2] = shape[-1] - K[:, 0, 2] + return K + + +def flip_pose_lr_fn(T): + """Function to flip a 4x4 transformation matrix from left to right""" + rot = T[:, :3, :3] + axis = matrix_to_euler_angles(rot, convention='XYZ') + axis[:, [1, 2]] = axis[:, [1, 2]] * -1 + rot = euler_angles_to_matrix(axis, convention='XYZ') + T[:, :3, :3] = rot + T[:, 0, -1] = - T[:, 0, -1] + return T + + +@iterate1 +def flip_lr(tensor, flip=True): + """Flip a tensor from left to right""" + # Not flipping option + if not flip: + return tensor + # If it's a list, repeat + if is_list(tensor): + return [flip_lr(t) for t in tensor] + # Return flipped tensor + if tensor.dim() == 5: + return torch.stack([flip_lr_fn(tensor[:, i]) + for i in range(tensor.shape[1])], 1) + else: + return flip_lr_fn(tensor) + + +@iterate1 +def flip_flow_lr(flow, flip=True): + """Flip a flow tensor from left to right""" + # Not flipping option + if not flip: + return flow + # If it's a list, repeat + if is_list(flow): + return [flip_flow_lr(f) for f in flow] + # Flip flow and invert first dimension + if flow.dim() == 5: + return torch.stack([flip_flow_lr_fn(flow[:, i]) + for i in range(flow.shape[1])], 1) + else: + return flip_flow_lr_fn(flow) + + +@iterate12 +def flip_intrinsics_lr(K, shape, flip=True): + """Flip a 3x3 camera intrinsic matrix from left to right""" + # Not flipping option + if not flip: + return K + # If shape is a tensor, use it's dimensions + if is_tensor(shape): + shape = shape.shape + # Flip horizontal information (first row) + if K.dim() == 4: + return torch.stack([flip_intrinsics_lr_fn(K[:, i], shape) + for i in range(K.shape[1])], 1) + else: + return flip_intrinsics_lr_fn(K, shape) + + +def flip_pose_lr(pose, flip=True): + """Flip a 4x4 transformation matrix from left to right""" + # Not flipping option + if not flip: + return pose + # Repeat for all pose keys + for key in pose.keys(): + # Get pose key + if key == 0: + if pose[key].dim() == 3: + continue + elif pose[key].dim() == 4: + T = pose[key][:, 1:].clone() + else: + raise ValueError('Invalid pose dimension') + else: + T = pose[key].clone() + # Flip pose + if T.dim() == 4: + T = torch.stack([flip_pose_lr_fn(T[:, i]) + for i in range(T.shape[1])], 1) + else: + T = flip_pose_lr_fn(T) + # Store flipped value back + if key == 0: + pose[key][:, 1:] = T + else: + pose[key] = T + # Return flipped pose + return pose + + +def flip_batch(batch, flip=True): + """Flip a batch from left to right""" + # Not flipping option + if not flip: + return batch + # If it's a list, repeat + if is_seq(batch): + return [flip_batch(b) for b in batch] + # Flip batch + flipped_batch = {} + # Keys to not flip + for key in keys_in(batch, ['idx', 'filename', 'splitname']): + flipped_batch[key] = batch[key] + # Tensor flipping + for key in keys_in(batch, ['rgb', 'mask', 'input_depth', 'depth', 
'semantic']): + flipped_batch[key] = flip_lr(batch[key]) + # Intrinsics flipping + for key in keys_in(batch, ['intrinsics']): + flipped_batch[key] = flip_intrinsics_lr(batch[key], batch['rgb']) + # Pose flipping + for key in keys_in(batch, ['pose']): + flipped_batch[key] = flip_pose_lr(batch[key]) + return flipped_batch + + +def flip_predictions(predictions, flip=True): + """Flip predictions from left to right""" + # Not flipping option + if not flip: + return predictions + # Flip predictions + flipped_predictions = {} + for key in predictions.keys(): + if key.startswith('depth'): + flipped_predictions[key] = flip_lr(predictions[key]) + if key.startswith('pose'): + flipped_predictions[key] = flip_pose_lr(predictions[key]) + # Return flipped predictions + return flipped_predictions + + +def flip_output(output, flip=True): + """Flip output from left to right""" + # Not flipping option + if not flip: + return output + # If it's a list, repeat + if is_seq(output): + return [flip_output(b) for b in output] + # Flip output + flipped_output = {} + # Do not flip loss and metrics + for key in keys_in(output, ['loss', 'metrics']): + flipped_output[key] = output[key] + # Flip predictions + flipped_output['predictions'] = flip_predictions(output['predictions']) + # Return flipped output + return flipped_output diff --git a/vidar/utils/flow.py b/vidar/utils/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffbe6a4b6ec60835e3749ef063ff2fee141ceb1 --- /dev/null +++ b/vidar/utils/flow.py @@ -0,0 +1,293 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import torch +import torch.nn.functional as tfn + +from vidar.utils.data import make_list +from vidar.utils.flow_triangulation_support import bearing_grid, mult_rotation_bearing, triangulation +from vidar.utils.tensor import pixel_grid, norm_pixel_grid, unnorm_pixel_grid +from vidar.utils.types import is_list + + +def warp_from_coords(tensor, coords, mode='bilinear', + padding_mode='zeros', align_corners=True): + """ + Warp an image from a coordinate map + + Parameters + ---------- + tensor : torch.Tensor + Input tensor for warping [B,?,H,W] + coords : torch.Tensor + Warping coordinates [B,2,H,W] + mode : String + Warping mode + padding_mode : String + Padding mode + align_corners : Bool + Align corners flag + + Returns + ------- + warp : torch.Tensor + Warped tensor [B,?,H,W] + """ + # Sample grid from data with coordinates + warp = tfn.grid_sample(tensor, coords.permute(0, 2, 3, 1), + mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + # Returned warped tensor + return warp + + +def coords_from_optical_flow(optflow): + """ + Get warping coordinates from optical flow + Parameters + ---------- + optflow : torch.Tensor + Input optical flow tensor [B,2,H,W] + + Returns + ------- + coords : torch.Tensor + Warping coordinates [B,2,H,W] + """ + # Create coordinate with optical flow + coords = pixel_grid(optflow, device=optflow) + optflow + # Normalize and return coordinate grid + return norm_pixel_grid(coords) + + +def warp_depth_from_motion(ref_depth, tgt_depth, ref_cam): + """ + Warp depth map using motion (depth + ego-motion) information + + Parameters + ---------- + ref_depth : torch.Tensor + Reference depth map [B,1,H,W] + tgt_depth : torch.Tensor + Target depth map [B,1,H,W] + ref_cam : Camera + Reference camera + + Returns + ------- + warp : torch.Tensor + Warped depth map [B,1,H,W] + """ + ref_depth = reproject_depth_from_motion(ref_depth, ref_cam) + return 
warp_from_motion(ref_depth, tgt_depth, ref_cam) + + +def reproject_depth_from_motion(ref_depth, ref_cam): + """ + Calculate reprojected depth from motion (depth + ego-motion) information + + Parameters + ---------- + ref_depth : torch.Tensor + Reference depth map [B,1,H,W] + ref_cam : Camera + Reference camera + + Returns + ------- + coords : torch.Tensor + Warping coordinates from reprojection [B,2,H,W] + """ + ref_points = ref_cam.reconstruct_depth_map(ref_depth, to_world=True) + return ref_cam.project_points(ref_points, from_world=False, return_z=True)[1] + + +def warp_from_motion(ref_rgb, tgt_depth, ref_cam): + """ + Warp image using motion (depth + ego-motion) information + + Parameters + ---------- + ref_rgb : torch.Tensor + Reference image [B,3,H,W] + tgt_depth : torch.Tensor + Target depth map [B,1,H,W] + ref_cam : Camera + Reference camera + + Returns + ------- + warp : torch.Tensor + Warped image [B,3,H,W] + """ + tgt_points = ref_cam.reconstruct_depth_map(tgt_depth, to_world=False) + return warp_from_coords(ref_rgb, ref_cam.project_points(tgt_points, from_world=True).permute(0, 3, 1, 2)) + + +def coords_from_motion(ref_camera, tgt_depth, tgt_camera): + """ + Get coordinates from motion (depth + ego-motion) information + + Parameters + ---------- + ref_camera : Camera + Reference camera + tgt_depth : torch.Tensor + Target depth map [B,1,H,W] + tgt_camera : Camera + Target camera + + Returns + ------- + coords : torch.Tensor + Warping coordinates [B,2,H,W] + """ + if is_list(ref_camera): + return [coords_from_motion(camera, tgt_depth, tgt_camera) + for camera in ref_camera] + # If there are multiple depth maps, iterate for each + if is_list(tgt_depth): + return [coords_from_motion(ref_camera, depth, tgt_camera) + for depth in tgt_depth] + world_points = tgt_camera.reconstruct_depth_map(tgt_depth, to_world=True) + return ref_camera.project_points(world_points, from_world=True).permute(0, 3, 1, 2) + + +def optflow_from_motion(ref_camera, tgt_depth): + """ + Get optical flow from motion (depth + ego-motion) information + + Parameters + ---------- + ref_camera : Camera + Reference camera + tgt_depth : torch.Tensor + Target depth map + + Returns + ------- + optflow : torch.Tensor + Optical flow map [B,2,H,W] + """ + coords = ref_camera.coords_from_depth(tgt_depth).permute(0, 3, 1, 2) + return optflow_from_coords(coords) + + +def optflow_from_coords(coords): + """ + Get optical flow from coordinates + + Parameters + ---------- + coords : torch.Tensor + Input warping coordinates [B,2,H,W] + + Returns + ------- + optflow : torch.Tensor + Optical flow map [B,2,H,W] + """ + return unnorm_pixel_grid(coords) - pixel_grid(coords, device=coords) + + +def warp_from_optflow(ref_rgb, tgt_optflow): + """ + Warp image using optical flow information + + Parameters + ---------- + ref_rgb : torch.Tensor + Reference image [B,3,H,W] + tgt_optflow : torch.Tensor + Target optical flow [B,2,H,W] + + Returns + ------- + warp : torch.Tensor + Warped image [B,3,H,W] + """ + coords = coords_from_optical_flow(tgt_optflow) + return warp_from_coords(ref_rgb, coords, align_corners=True, + mode='bilinear', padding_mode='zeros') + + +def reverse_optflow(tgt_optflow, ref_optflow): + """ + Reverse optical flow + + Parameters + ---------- + tgt_optflow : torch.Tensor + Target optical flow [B,2,H,W] + ref_optflow : torch.Tensor + Reference optical flow [B,2,H,W] + + Returns + ------- + optflow : torch.Tensor + Reversed optical flow [B,2,H,W] + """ + return - warp_from_optflow(tgt_optflow, ref_optflow) + + +def 
mask_from_coords(coords, align_corners=True): + """ + Get overlap mask from coordinates + + Parameters + ---------- + coords : torch.Tensor + Warping coordinates [B,2,H,W] + align_corners : Bool + Align corners flag + + Returns + ------- + mask : torch.Tensor + Overlap mask [B,1,H,W] + """ + if is_list(coords): + return [mask_from_coords(coord) for coord in coords] + b, _, h, w = coords.shape + mask = torch.ones((b, 1, h, w), dtype=torch.float32, device=coords.device, requires_grad=False) + mask = warp_from_coords(mask, coords, mode='nearest', padding_mode='zeros', align_corners=True) + return mask.bool() + + +def depth_from_optflow(rgb, intrinsics, pose_context, flows, + residual=False, clip_range=None): + """ + Get depth from optical flow + camera information + + Parameters + ---------- + rgb : torch.Tensor + Base image [B,3,H,W] + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] + pose_context : torch.Tensor or list[torch.Tensor] + List of relative context camera poses [B,4,4] + flows : torch.Tensor or list[torch.Tensor] + List of target optical flows [B,2,H,W] + residual : Bool + Return residual error with depth + clip_range : Tuple + Depth range clipping values + + Returns + ------- + depth : torch.Tensor + Depth map [B,1,H,W] + """ + # Make lists if necessary + flows = make_list(flows) + pose_context = make_list(pose_context) + # Extract rotations and translations + rotations = [p[:, :3, :3] for p in pose_context] + translations = [p[:, :3, -1] for p in pose_context] + # Get bearings + bearings = bearing_grid(rgb, intrinsics).to(rgb.device) + rot_bearings = [mult_rotation_bearing(rotation, bearings) + for rotation in rotations] + # Return triangulation results + return triangulation(rot_bearings, translations, flows, intrinsics, + clip_range=clip_range, residual=residual) diff --git a/vidar/utils/flow_triangulation_support.py b/vidar/utils/flow_triangulation_support.py new file mode 100755 index 0000000000000000000000000000000000000000..286005be792e4fe562d0695001ab3c2d72c0f0a3 --- /dev/null +++ b/vidar/utils/flow_triangulation_support.py @@ -0,0 +1,223 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
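A sanity-check sketch for the warping helpers above: with a zero optical flow the warp should reproduce the reference image, and the overlap mask should cover every pixel (dummy tensors, batch size 1):

    import torch

    from vidar.utils.flow import coords_from_optical_flow, mask_from_coords, warp_from_optflow

    ref_rgb = torch.rand(1, 3, 64, 64)
    flow = torch.zeros(1, 2, 64, 64)                 # no motion

    warped = warp_from_optflow(ref_rgb, flow)
    print((warped - ref_rgb).abs().max().item())     # ~0 (identity warp)

    coords = coords_from_optical_flow(flow)          # normalized warping grid [B,2,H,W]
    print(mask_from_coords(coords).float().mean().item())   # 1.0, every pixel stays in view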
+ +import numpy as np +import torch +import torch.nn.functional as tfunc + +from vidar.utils.tensor import pixel_grid, cat_channel_ones + + +def bearing_grid(rgb, intrinsics): + """ + Create a homogeneous bearing grid from camera intrinsics and a base image + + Parameters + ---------- + rgb : torch.Tensor + Base image for dimensions [B,3,H,W] + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] + + Returns + ------- + grid : torch.Tensor + Bearing grid [B,3,H,W] + """ + # Create pixel grid from base image + b, _, h, w = rgb.shape + grid = pixel_grid((h, w), b).to(rgb.device) + # Normalize pixel grid with camera parameters + grid[:, 0] = (grid[:, 0] - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1) + grid[:, 1] = (grid[:, 1] - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1) + # Return bearing grid (with 1s as extra dimension) + return cat_channel_ones(grid) + + +def mult_rotation_bearing(rotation, bearing): + """ + Rotates a bearing grid + + Parameters + ---------- + rotation : torch.Tensor + Rotation matrix [B,3,3] + bearing : torch.Tensor + Bearing grid [B,3,H,W] + + Returns + ------- + rot_bearing : torch.Tensor + Rotated bearing grid [B,3,H,W] + """ + # Multiply rotation and bearing + product = torch.bmm(rotation, bearing.view(bearing.shape[0], 3, -1)) + # Return product with bearing shape + return product.view(bearing.shape) + + +def pre_triangulation(ref_bearings, ref_translations, tgt_flows, + intrinsics, concat=True): + """ + Triangulates bearings and flows + + Parameters + ---------- + ref_bearings : list[torch.Tensor] + Reference bearings [B,3,H,W] + ref_translations : list[torch.Tensor] + Reference translations [B,3] + tgt_flows : list[torch.Tensor] + Target optical flow values [B,2,H,W] + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] + concat : Bool + True if cross product results are concatenated + + Returns + ------- + rs : torch.Tensor or list[torch.Tensor] + Bearing x translation cross product [B,3,H,W] (concatenated or not) + ss : torch.Tensor or list[torch.Tensor] + Bearing x bearing cross product [B,3,H,W] (concatenated or not) + """ + # Get target bearings from flow + tgt_bearings = [flow2bearing(flow, intrinsics, normalize=True) + for flow in tgt_flows] + # Bearings x translation cross product + rs = [torch.cross(tgt_bearing, ref_translation[:, :, None, None].expand_as(tgt_bearing), dim=1) + for tgt_bearing, ref_translation in zip(tgt_bearings, ref_translations)] + # Bearings x bearings cross product + ss = [torch.cross(tgt_bearing, ref_bearing, dim=1) + for tgt_bearing, ref_bearing in zip(tgt_bearings, ref_bearings)] + if concat: + # If results are to be concatenated + return torch.cat(rs, dim=1), torch.cat(ss, dim=1) + else: + # Otherwise, return as lists + return rs, ss + + +def depth_ls2views(r, s, clip_range=None): + """ + Least-squares depth estimation from two views + + Parameters + ---------- + r : torch.Tensor + Bearing x translation cross product between images [B,3,H,W] + s : torch.Tensor + Bearing x translation cross product between images [B,3,H,W] + clip_range : Tuple + Depth clipping range (min, max) + + Returns + ------- + depth : torch.Tensor + Calculated depth [B,1,H,W] + error : torch.Tensor + Calculated error [B,1,H,W] + hessian : torch.Tensor + Calculated hessian [B,1,H,W] + + """ + # Calculate matrices + hessian = (s * s).sum(dim=1, keepdims=True) + depth = -(s * r).sum(dim=1, keepdims=True) / (hessian + 1e-30) + error = (r * 
r).sum(dim=1, keepdims=True) - hessian * (depth ** 2) + + # Clip depth and other matrices if requested + if clip_range is not None: + + invalid_mask = (depth <= clip_range[0]) + invalid_mask |= (depth >= clip_range[1]) + + depth[invalid_mask] = 0 + error[invalid_mask] = 0 + hessian[invalid_mask] = 0 + # Return calculated matrices + return depth, error, hessian + + +def flow2bearing(flow, intrinsics, normalize=True): + """ + Convert optical flow to bearings + + Parameters + ---------- + flow : torch.Tensor + Input optical flow [B,2,H,W] + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] + normalize : Bool + True if bearings are normalized + + Returns + ------- + bearings : torch.Tensor + Calculated bearings [B,3,H,W] + """ + # Create initial grid + height, width = flow.shape[2:] + xx, yy = np.meshgrid(range(width), range(height)) + # Initialize bearing matrix + bearings = torch.zeros_like(flow) + # Populate bearings + match = (flow[:, 0] + torch.from_numpy(xx).to(flow.device), + flow[:, 1] + torch.from_numpy(yy).to(flow.device)) + bearings[:, 0] = (match[0] - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1) + bearings[:, 1] = (match[1] - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1) + # Stack 1s as the last dimension + bearings = cat_channel_ones(bearings) + # Normalize if necessary + if normalize: + bearings = tfunc.normalize(bearings) + # Return bearings + return bearings + + +def triangulation(ref_bearings, ref_translations, + tgt_flows, intrinsics, clip_range=None, residual=False): + """ + Triangulate optical flow points to produce depth estimates + + Parameters + ---------- + ref_bearings : list[torch.Tensor] + Reference bearings [B,3,H,W] + ref_translations : list[torch.Tensor] + Reference translations [B,3] + tgt_flows : list[torch.Tensor] + Target optical flow to reference [B,2,H,W] + intrinsics : torch.Tensor + Camera intrinsics [B,3,3] + clip_range : Tuple + Depth clipping range + residual : Bool + True to return residual error and squared root of Hessian + + Returns + ------- + depth : torch.Tensor + Estimated depth [B,1,H,W] + error : torch.Tensor + Estimated error [B,1,H,W] + sqrt_hessian : torch.Tensor + Squared root of Hessian [B,1,H,W] + """ + # Pre-triangulate flows + rs, ss = pre_triangulation(ref_bearings, ref_translations, tgt_flows, intrinsics, concat=False) + # Calculate list of triangulations + outputs = [depth_ls2views(*rs_ss, clip_range=clip_range) for rs_ss in zip(rs, ss)] + # Calculate predicted hessian and depths + hessian = sum([output[2] for output in outputs]) + depth = sum([output[0] * output[2] for output in outputs]) / (hessian + 1e-12) + # Return depth + residual error and hessian matrix + if residual: + error = torch.sqrt(sum([output[2] * (depth - output[0]) ** 2 + output[1] + for output in outputs]).clamp_min(0)) + sqrt_hessian = torch.sqrt(hessian) + return depth, (error, sqrt_hessian) + # Return depth + else: + return depth + diff --git a/vidar/utils/logging.py b/vidar/utils/logging.py new file mode 100755 index 0000000000000000000000000000000000000000..b0e85b2a5cad2e51e886c94d20efcab9078ecf5a --- /dev/null +++ b/vidar/utils/logging.py @@ -0,0 +1,144 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
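A small sketch of the bearing helpers above: with a zero flow, flow2bearing returns unit-normalized versions of the same rays that bearing_grid builds from the intrinsics (the intrinsics below are made-up values):

    import torch

    from vidar.utils.flow_triangulation_support import bearing_grid, flow2bearing

    intrinsics = torch.tensor([[[500., 0., 32.],
                                [0., 500., 24.],
                                [0., 0., 1.]]])      # fake pinhole intrinsics [B,3,3]
    rgb = torch.rand(1, 3, 48, 64)
    flow = torch.zeros(1, 2, 48, 64)

    rays = bearing_grid(rgb, intrinsics)             # homogeneous bearings [B,3,H,W]
    bearings = flow2bearing(flow, intrinsics)        # same directions, unit norm
    print(rays.shape, bearings.norm(dim=1).mean().item())   # [1, 3, 48, 64] ~1.0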
+ +import argparse +import os +from functools import partial + +from termcolor import colored + +from vidar.utils.distributed import on_rank_0 + + +def pcolor(string, color, on_color=None, attrs=None): + """ + Produces a colored string for printing + + Parameters + ---------- + string : String + String that will be colored + color : String + Color to use + on_color : String + Background color to use + attrs : list[String] + Different attributes for the string + + Returns + ------- + string: String + Colored string + """ + return colored(string, color, on_color, attrs) + + +@on_rank_0 +def print_config(config): + """ + Prints header for model configuration + + Parameters + ---------- + config : Config + Model configuration + """ + header_colors = { + 0: ('red', ('bold', 'dark')), + 1: ('cyan', ('bold','dark')), + 2: ('green', ('bold', 'dark')), + 3: ('green', ('bold', 'dark')), + } + line_colors = ('blue', ()) + + # Recursive print function + def print_recursive(rec_args, pad=3, level=0): + # if level == 0: + # print(pcolor('config:', + # color=header_colors[level][0], + # attrs=header_colors[level][1])) + for key, val in rec_args.__dict__.items(): + if isinstance(val, argparse.Namespace): + print(pcolor('{} {}:'.format('-' * pad, key), + color=header_colors[level][0], + attrs=header_colors[level][1])) + print_recursive(val, pad + 2, level + 1) + else: + print('{}: {}'.format(pcolor('{} {}'.format('-' * pad, key), + color=line_colors[0], + attrs=line_colors[1]), val)) + + # Color partial functions + pcolor1 = partial(pcolor, color='blue', attrs=('bold', 'dark')) + pcolor2 = partial(pcolor, color='blue', attrs=('bold',)) + # Config and name + line = pcolor1('#' * 120) + # if 'default' in config.__dict__.keys(): + # path = pcolor1('### Config: ') + \ + # pcolor2('{}'.format(config.default.replace('/', '.'))) + \ + # pcolor1(' -> ') + \ + # pcolor2('{}'.format(config.config.replace('/', '.'))) + # if 'name' in config.__dict__.keys(): + # name = pcolor1('### Name: ') + \ + # pcolor2('{}'.format(config.name)) + # # Add wandb link if available + # if not config.wandb.dry_run: + # name += pcolor1(' -> ') + \ + # pcolor2('{}'.format(config.wandb.url)) + # # Add s3 link if available + # if config.checkpoint.s3_path is not '': + # name += pcolor1('\n### s3:') + \ + # pcolor2(' {}'.format(config.checkpoint.s3_url)) + # # # Create header string + # # header = '%s\n%s\n%s\n%s' % (line, path, name, line) + + # Print header, config and header again + print() + # print(header) + print_recursive(config) + # print(header) + print() + + +def set_debug(debug): + """ + Enable or disable debug terminal logging + + Parameters + ---------- + debug : Bool + Debugging flag (True to enable) + """ + # Disable logging if requested + if not debug: + os.environ['NCCL_DEBUG'] = '' + os.environ['WANDB_SILENT'] = 'true' + # warnings.filterwarnings("ignore") + # logging.disable(logging.CRITICAL) + + +class AvgMeter: + """Average meter for logging""" + def __init__(self, n_max=100): + self.n_max = n_max + self.values = [] + + def __call__(self, value): + """Append new value and returns average""" + self.values.append(value) + if len(self.values) > self.n_max: + self.values.pop(0) + return self.get() + + def get(self): + """Get current average""" + return sum(self.values) / len(self.values) + + def reset(self): + """Reset meter""" + self.values.clear() + + def get_and_reset(self): + """Get current average and reset""" + average = self.get() + self.reset() + return average diff --git a/vidar/utils/networks.py 
b/vidar/utils/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..29bc8216b6738d301c33c3b4eecbc51611f93e5d --- /dev/null +++ b/vidar/utils/networks.py @@ -0,0 +1,203 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import math + +import torch +import torch.nn as nn + +from vidar.utils.distributed import print0, rank, dist_mode +from vidar.utils.logging import pcolor +from vidar.utils.tensor import same_shape +from vidar.utils.types import is_list + + +def freeze_layers(network, layers=('ALL',), flag_freeze=True): + """ + Freeze layers of a network (weights and biases) + + Parameters + ---------- + network : nn.Module + Network to be modified + layers : List or Tuple + List of layers to freeze/unfreeze ('ALL' for everything) + flag_freeze : Bool + Whether the layers will be frozen (True) or not (False) + """ + if len(layers) > 0: + for name, parameters in network.named_parameters(): + for layer in layers: + if layer in name or layer == 'ALL': + parameters.requires_grad_(not flag_freeze) + + +def freeze_norms(network, layers=('ALL',), flag_freeze=True): + """ + Freeze layers of a network (normalization) + + Parameters + ---------- + network : nn.Module + Network to be modified + layers : List or Tuple + List of layers to freeze/unfreeze ('ALL' for everything) + flag_freeze : Bool + Whether the layers will be frozen (True) or not (False) + """ + if len(layers) > 0: + for name, module in network.named_modules(): + for layer in layers: + if layer in name or layer == 'ALL': + if isinstance(module, nn.BatchNorm2d): + if hasattr(module, 'weight'): + module.weight.requires_grad_(not flag_freeze) + if hasattr(module, 'bias'): + module.bias.requires_grad_(not flag_freeze) + if flag_freeze: + module.eval() + else: + module.train() + + +def freeze_layers_and_norms(network, layers=('ALL',), flag_freeze=True): + """Freeze layers and normalizations of a network""" + freeze_layers(network, layers, flag_freeze) + freeze_norms(network, layers, flag_freeze) + + +def make_val_fit(model, key, val, updated_state_dict, strict=False): + """ + Parse state dictionary to fit a model, and make tensors fit if requested + + Parameters + ---------- + model : nn.Module + Network to be used + key : String + Which key will be used + val : torch.Tensor + Key value + updated_state_dict : Dict + Updated dictionary + strict : Bool + True if no changes are allowed, False if tensors can be changed to fit + + Returns + ------- + fit : Int + Number of tensors that fit the model + """ + fit = 0 + val_new = model.state_dict()[key] + if same_shape(val.shape, val_new.shape): + updated_state_dict[key] = val + fit += 1 + elif not strict: + for i in range(val.dim()): + if val.shape[i] != val_new.shape[i]: + if val_new.shape[i] > val.shape[i]: + ratio = math.ceil(val_new.shape[i] / val.shape[i]) + val = torch.cat([val] * ratio, i) + if val.shape[i] != val_new.shape[i]: + val = val[:val_new.shape[i]] + if same_shape(val.shape, val_new.shape): + updated_state_dict[key] = val + fit += 1 + elif val_new.shape[0] < val.shape[i]: + val = val[:val_new.shape[i]] + if same_shape(val.shape, val_new.shape): + updated_state_dict[key] = val + fit += 1 + assert fit <= 1 # Each tensor cannot fit 2 or more times + return fit + + +def load_checkpoint(model, checkpoint, strict=False, verbose=False, prefix=None): + """ + Load checkpoint into a model + + Parameters + ---------- + model : nn.Module + Input network + checkpoint : String or list[String] + Checkpoint path (if it's a list, load them in 
order)
+    strict : Bool
+        True if all tensors are required, False if can be partially loaded
+    verbose : Bool
+        Print information on screen
+    prefix : String
+        Prefix used to change keys
+
+    Returns
+    -------
+    model: nn.Module
+        Loaded network
+    """
+    if is_list(checkpoint):
+        for ckpt in checkpoint:
+            load_checkpoint(model, ckpt, strict, verbose)
+        return model
+
+    font1 = {'color': 'magenta', 'attrs': ('bold', 'dark')}
+    font2 = {'color': 'magenta', 'attrs': ('bold',)}
+
+    if verbose:
+        print0(pcolor('#' * 60, **font1))
+        print0(pcolor('###### Loading from checkpoint: ', **font1) +
+               pcolor('{}'.format(checkpoint), **font2))
+
+    state_dict = torch.load(
+        checkpoint,
+        map_location='cpu' if dist_mode() == 'cpu' else 'cuda:{}'.format(rank())
+    )['state_dict']
+    updated_state_dict = {}
+
+    total, fit = len(model.state_dict()), 0
+    for key, val in state_dict.items():
+
+        for start in ['model.', 'module.']:
+            if key.startswith(start):
+                key = key[len(start):]
+        if prefix is not None:
+            idx = key.find(prefix)
+            if idx > -1:
+                key = key[(idx + len(prefix) + 1):]
+        if key in model.state_dict().keys():
+            fit += make_val_fit(model, key, val, updated_state_dict, strict=strict)
+
+    model.load_state_dict(updated_state_dict, strict=strict)
+
+    if verbose:
+        color = 'red' if fit == 0 else 'yellow' if fit < total else 'green'
+        print0(pcolor('###### Loaded ', **font1) + \
+               pcolor('{}/{}'.format(fit,total), color=color, attrs=('bold',)) + \
+               pcolor(' tensors', **font1))
+        print0(pcolor('#' * 60, **font1))
+
+    return model
+
+
+def save_checkpoint(filename, wrapper, epoch=None):
+    """
+    Save checkpoint to disk
+
+    Parameters
+    ----------
+    filename : String
+        Name of the file
+    wrapper : nn.Module
+        Model wrapper to save
+    epoch : Int
+        Training epoch
+    """
+    if epoch is None:
+        torch.save({
+            'state_dict': wrapper.state_dict(),
+        }, filename)
+    else:
+        torch.save({
+            'epoch': epoch,
+            'config': wrapper.cfg,
+            'state_dict': wrapper.arch.state_dict(),
+        }, filename)
diff --git a/vidar/utils/read.py b/vidar/utils/read.py
new file mode 100755
index 0000000000000000000000000000000000000000..2b85c499fb9ff49315127b54d196437e56c7133a
--- /dev/null
+++ b/vidar/utils/read.py
@@ -0,0 +1,74 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import pickle as pkl
+
+import numpy as np
+from PIL import Image
+
+from vidar.utils.decorators import iterate1
+
+
+def read_pickle(filename):
+    """
+    Read pickle file
+
+    Parameters
+    ----------
+    filename : String
+        File to read from
+
+    Returns
+    -------
+    data : Value
+        Data loaded from file
+    """
+    if not filename.endswith('.pkl'):
+        filename += '.pkl'
+    return pkl.load(open(filename, 'rb'))
+
+
+@iterate1
+def read_image(path):
+    """
+    Read an image using PIL
+
+    Parameters
+    ----------
+    path : String
+        Path to the image
+
+    Returns
+    -------
+    image : PIL Image
+        Loaded image
+    """
+    return Image.open(path)
+
+
+@iterate1
+def read_depth(file):
+    """
+    Load a depth map from file
+
+    Parameters
+    ----------
+    file : String
+        Depth map filename (.npz or .png or .dpt)
+
+    Returns
+    -------
+    depth : np.array
+        Depth map (invalid pixels are 0) [H,W]
+    """
+    # If loading a .npz array
+    if file.endswith('npz'):
+        return np.load(file)['depth']
+    # If loading a .png image
+    elif file.endswith('png'):
+        depth_png = np.array(read_image(file), dtype=int)
+        assert (np.max(depth_png) > 255), 'Wrong .png depth file'
+        return depth_png.astype(np.float64) / 256.
+ # Invalid type + else: + raise NotImplementedError('Depth extension not supported.') + diff --git a/vidar/utils/reduce.py b/vidar/utils/reduce.py new file mode 100755 index 0000000000000000000000000000000000000000..717b5d5a827a8b2510eb068bfbbe8c5044882d46 --- /dev/null +++ b/vidar/utils/reduce.py @@ -0,0 +1,76 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from collections import OrderedDict + + +def average_key(batch_list, key): + """ + Average key in a list of batches + + Parameters + ---------- + batch_list : list[Dict] + List containing dictionaries with the same keys + key : String + Key to be averaged + + Returns + ------- + average : Float + Average of the value contained in key for all batches + """ + values = [batch[key] for batch in batch_list] + return sum(values) / len(values) + + +def average_sub_key(batch_list, key, sub_key): + """ + Average subkey in a dictionary in a list of batches + + Parameters + ---------- + batch_list : list[Dict] + List containing dictionaries with the same keys + key : String + Key to be averaged + sub_key : String + Sub key to be averaged (belonging to key) + + Returns + ------- + average : Float + Average of the value contained in the sub_key of key for all batches + """ + values = [batch[key][sub_key] for batch in batch_list] + return sum(values) / len(values) + + +def average_loss_and_metrics(batch_list, prefix): + """ + Average loss and metrics values in a list of batches + + Parameters + ---------- + batch_list : list[Dict] + List containing dictionaries with the same keys + prefix : String + Prefix string for metrics logging + + Returns + ------- + values : Dict + Dictionary containing a 'loss' float entry and a 'metrics' dict entry + """ + values = OrderedDict() + + key = 'loss' + values['{}-{}'.format(prefix, key)] = \ + average_key(batch_list, key) + + key = 'metrics' + for sub_key in batch_list[0][key].keys(): + values['{}-{}'.format(prefix, sub_key)] = \ + average_sub_key(batch_list, key, sub_key) + + return values + diff --git a/vidar/utils/setup.py b/vidar/utils/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1533bc90d773c30eaf6ce1098ae56674cd73d9 --- /dev/null +++ b/vidar/utils/setup.py @@ -0,0 +1,361 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
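+#
+# Rough usage sketch (the `cfg` sub-sections named here are hypothetical, not
+# prescribed by this file):
+#
+#   model = setup_arch(cfg.arch, verbose=True)
+#   datasets, datasets_cfg = setup_datasets(cfg.datasets, verbose=True)
+#   train_loader = setup_dataloader(datasets['train'], cfg.dataloader, 'train')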
+ +import time +from collections import OrderedDict +from copy import deepcopy + +import numpy as np +import torch +from torch.utils.data import ConcatDataset, DataLoader + +from vidar.datasets.utils.transforms import get_transforms +from vidar.metrics.depth import DepthEvaluation +from vidar.utils.config import get_folder_name, load_class, \ + recursive_assignment, cfg_has, cfg_add_to_dict, get_from_cfg_list +from vidar.utils.config import merge_dict, to_namespace +from vidar.utils.data import flatten, keys_in +from vidar.utils.decorators import iterate1 +from vidar.utils.distributed import print0, rank, world_size, dist_mode +from vidar.utils.logging import pcolor +from vidar.utils.networks import load_checkpoint, save_checkpoint +from vidar.utils.types import is_namespace + + +def setup_arch(cfg, checkpoint=None, verbose=False): + """ + Set architecture up for training/inference + + Parameters + ---------- + cfg : Config + Configuration file + checkpoint : String + Checkpoint to be loaded + verbose : Bool + Print information on screen + + Returns + ------- + model: nn.Module + Model ready to go + """ + font = {'color': 'green'} + + if verbose: + print0(pcolor('#' * 60, **font)) + print0(pcolor('### Preparing Architecture', **font)) + print0(pcolor('#' * 60, **font)) + + font1 = {'color': 'yellow', 'attrs': ('dark',)} + font2 = {'color': 'yellow', 'attrs': ('dark', 'bold')} + + folder, name = get_folder_name(cfg.model.file, 'models') + model = load_class(name, folder)(cfg) + + if cfg_has(cfg, 'model'): + if verbose: + print0(pcolor('###### Model:', **font2)) + print0(pcolor('######### %s' % model.__class__.__name__, **font1)) + recursive_assignment(model, cfg.model, 'models', verbose=verbose) + + if cfg_has(cfg, 'networks'): + if verbose: + print0(pcolor('###### Networks:', **font2)) + recursive_assignment(model, cfg.networks, 'networks', verbose=verbose) + + if cfg_has(cfg, 'losses'): + if verbose: + print0(pcolor('###### Losses:', **font2)) + recursive_assignment(model, cfg.losses, 'losses', verbose=verbose) + + if checkpoint is not None: + model = load_checkpoint(model, checkpoint, + strict=True, verbose=verbose) + elif cfg_has(cfg.model, 'checkpoint'): + model = load_checkpoint(model, cfg.model.checkpoint, + strict=cfg.model.has('checkpoint_strict', False), verbose=verbose) + + if cfg.model.has('checkpoint_save'): + save_checkpoint(cfg.model.checkpoint_save, model) + + return model + + +def setup_dataset(cfg, root='vidar/datasets', verbose=False): + """ + Set dataset up for training/inference + + Parameters + ---------- + cfg : Config + Configuration file + root : String + Where the dataset is located + verbose : Bool + Print information on screen + + Returns + ------- + dataset : Dataset + Dataset ready to go + """ + shared_keys = ['context', 'labels', 'labels_context'] + + num_datasets = 0 + for key, val in cfg.__dict__.items(): + if key not in shared_keys and not is_namespace(val): + num_datasets = max(num_datasets, len(val)) + + datasets = [] + for i in range(num_datasets): + args = {} + for key, val in cfg.__dict__.items(): + if not is_namespace(val): + cfg_add_to_dict(args, cfg, key, i if key not in shared_keys else None) + + args['data_transform'] = get_transforms('train', cfg.augmentation) \ + if cfg_has(cfg, 'augmentation') else get_transforms('none') + + name = get_from_cfg_list(cfg, 'name', i) + repeat = get_from_cfg_list(cfg, 'repeat', i) + cameras = get_from_cfg_list(cfg, 'cameras', i) + + context = cfg.context + labels = cfg.labels + + dataset = load_class(name + 
'Dataset', root)(**args) + + if cfg_has(cfg, 'repeat') and repeat > 1: + dataset = ConcatDataset([dataset for _ in range(repeat)]) + + if verbose: + string = f'######### {name}: {len(dataset)} samples' + if cfg_has(cfg, 'repeat'): + string += f' (x{repeat})' + if cfg_has(cfg, 'context'): + string += f' | context {context}'.replace(', ', ',') + if cfg_has(cfg, 'cameras'): + string += f' | cameras {cameras}'.replace(', ', ',') + if cfg_has(cfg, 'labels'): + string += f' | labels {labels}'.replace(', ', ',') + print0(pcolor(string , color='yellow', attrs=('dark',))) + + datasets.append(dataset) + + return datasets + + +def setup_datasets(cfg, verbose=False, concat_modes=('train', 'mixed'), stack=True): + """ + Set multiple datasets up for training/inference + + Parameters + ---------- + cfg : Config + Configuration file + verbose : Bool + Print information on screen + concat_modes : String + Which dataset modes are going to be concatenated into a single one + stack : Bool + Whether datasets are stacked together + + Returns + ------- + datasets : Dict + Datasets ready to go + datasets_cfg : Dict + Dataset configurations + """ + if verbose: + print0(pcolor('#' * 60, 'green')) + print0(pcolor('### Preparing Datasets', 'green')) + print0(pcolor('#' * 60, 'green')) + + font = {'color': 'yellow', 'attrs': ('bold', 'dark')} + + datasets_cfg = {} + for key in cfg.__dict__.keys(): + datasets_cfg[key] = cfg.__dict__[key] + for mode in ['train', 'validation']: + if key.startswith(mode) and key != mode and mode in cfg.__dict__.keys(): + datasets_cfg[key] = to_namespace(merge_dict(deepcopy( + cfg.__dict__[mode].__dict__), cfg.__dict__[key].__dict__)) + + datasets = {} + for key, val in list(datasets_cfg.items()): + if 'name' in val.__dict__.keys(): + if verbose: + print0(pcolor('###### {}'.format(key), **font)) + datasets[key] = setup_dataset(val, verbose=verbose) + datasets_cfg[key] = [datasets_cfg[key]] * len(datasets[key]) + for mode in concat_modes: + if key.startswith(mode) and len(datasets[key]) > 1: + datasets[key] = ConcatDataset(datasets[key]) + else: + datasets_cfg.pop(key) + + if stack: + datasets = stack_datasets(datasets) + + modes = ['train', 'mixed', 'validation', 'test'] + reduced_datasets_cfg = {key: [] for key in modes} + for key, val in datasets_cfg.items(): + for mode in modes: + if key.startswith(mode): + reduced_datasets_cfg[mode].append(val) + for key in list(reduced_datasets_cfg.keys()): + reduced_datasets_cfg[key] = flatten(reduced_datasets_cfg[key]) + if len(reduced_datasets_cfg[key]) == 0: + reduced_datasets_cfg.pop(key) + datasets_cfg = reduced_datasets_cfg + + if 'train' in datasets_cfg: + datasets_cfg['train'] = datasets_cfg['train'][0] + + return datasets, datasets_cfg + + +def setup_metrics(cfg): + """ + Set metrics up for evaluation + + Parameters + ---------- + cfg : Config + Configuration file + + Returns + ------- + tasks : Dict + Dictionary containing metric classes for requested tasks + """ + + methods = { + 'depth': DepthEvaluation, + } + + available_tasks = [key for key in cfg.__dict__.keys() if key is not 'tasks'] + requested_tasks = cfg_has(cfg, 'tasks', available_tasks) + tasks = [task for task in available_tasks if task in requested_tasks and task in methods] + + return {task: methods[task](cfg.__dict__[task]) for task in tasks} + + +def worker_init_fn(worker_id): + """Function to initialize workers""" + time_seed = np.array(time.time(), dtype=np.int32) + np.random.seed(time_seed + worker_id) + + +def get_datasampler(dataset, shuffle): + """Return distributed 
data sampler""" + return torch.utils.data.distributed.DistributedSampler( + dataset, shuffle=shuffle, + num_replicas=world_size(), rank=rank()) + + +def no_collate(batch): + """Dummy function to use when dataset is not to be collated""" + return batch + + +@iterate1 +def setup_dataloader(dataset, cfg, mode): + """ + Create a dataloader class + + Parameters + ---------- + mode : String {'train', 'validation', 'test'} + Mode from which we want the dataloader + dataset : Dataset + List of datasets from which to create dataloaders + cfg : Config + Model configuration (cf. configs/default_config.py) + + Returns + ------- + dataloaders : list[Dataloader] + List of created dataloaders for each input dataset + """ + ddp = dist_mode() == 'ddp' + shuffle = 'train' in mode + return DataLoader(dataset, + batch_size=cfg_has(cfg, 'batch_size', 1), + pin_memory=cfg_has(cfg, 'pin_memory', True), + num_workers=cfg_has(cfg, 'num_workers', 8), + worker_init_fn=worker_init_fn, + shuffle=False if ddp else shuffle, + sampler=get_datasampler(dataset, shuffle=shuffle) if ddp else None, + collate_fn=None if cfg_has(cfg, 'collate', True) else no_collate, + ) + + +def reduce(data, modes, train_modes): + """ + Reduce dictionary values + + Parameters + ---------- + data : Dict + Dictionary with data for reduction + modes : String + Data mode ('train', 'validation', 'test') + train_modes : list[String] + Which modes are training modes + + Returns + ------- + reduced : Dict + Dictionary with reduced information + """ + reduced = { + mode: flatten([val for key, val in data.items() if mode in key]) + for mode in modes + } + for key, val in list(reduced.items()): + if len(val) == 0: + reduced.pop(key) + for mode in keys_in(reduced, train_modes): + reduced[mode] = reduced[mode][0] + return reduced + + +def stack_datasets(datasets): + """ + Stack datasets together for training/validation + + Parameters + ---------- + datasets : Dict + Dictionary containing datasets + + Returns + ------- + stacked_datasets: : Dict + Dictionary containing stacked datasets + """ + all_modes = ['train', 'mixed', 'validation', 'test'] + train_modes = ['train', 'mixed'] + + stacked_datasets = OrderedDict() + + for mode in all_modes: + stacked_datasets[mode] = [] + for key, val in datasets.items(): + if mode in key: + stacked_datasets[mode].append(val) + stacked_datasets[mode] = flatten(stacked_datasets[mode]) + + for mode in train_modes: + length = len(stacked_datasets[mode]) + if length == 1: + stacked_datasets[mode] = stacked_datasets[mode][0] + elif length > 1: + stacked_datasets[mode] = ConcatDataset(stacked_datasets[mode]) + for key in list(datasets.keys()): + if key.startswith(mode) and key != mode: + datasets.pop(key) + + return stacked_datasets diff --git a/vidar/utils/tensor.py b/vidar/utils/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..497404b1287cfa9b4acb067912cd11d07aa78c29 --- /dev/null +++ b/vidar/utils/tensor.py @@ -0,0 +1,315 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
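+#
+# Illustrative usage sketch: `pixel_grid` and `cat_channel_ones` below are
+# typically combined to build homogeneous pixel coordinates, e.g.
+#
+#   grid = pixel_grid((192, 640), b=2)   # [2,2,192,640] grid of (x,y) pixel coordinates
+#   homo = cat_channel_ones(grid)        # [2,3,192,640] with an extra channel of ones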
+ +from functools import reduce + +import torch +import torch.nn.functional as tfn + +from vidar.utils.decorators import iterate1 +from vidar.utils.types import is_tensor, is_dict, is_seq + + +@iterate1 +def interpolate(tensor, size, scale_factor, mode, align_corners): + """ + Interpolate a tensor to a different resolution + + Parameters + ---------- + tensor : torch.Tensor + Input tensor [B,?,H,W] + size : Tuple + Interpolation size (H,W) + scale_factor : Float + Scale factor for interpolation + mode : String + Interpolation mode + align_corners : Bool + Corner alignment flag + + Returns + ------- + tensor : torch.Tensor + Interpolated tensor [B,?,h,w] + """ + if is_tensor(size): + size = size.shape[-2:] + return tfn.interpolate( + tensor, size=size, scale_factor=scale_factor, + mode=mode, align_corners=align_corners, recompute_scale_factor=False, + ) + + +def masked_average(loss, mask, eps=1e-7): + """Calculates the average of a tensor considering mask information""" + return (loss * mask).sum() / (mask.sum() + eps) + + +def multiply_mask(data, mask): + """Multiplies a tensor with a mask""" + return data if (data is None or mask is None) else data * mask + + +def multiply_args(*args): + """Multiplies input arguments""" + valids = [v for v in args if v is not None] + return None if not valids else reduce((lambda x, y: x * y), valids) + + +def grid_sample(tensor, grid, padding_mode, mode, align_corners): + return tfn.grid_sample(tensor, grid, + padding_mode=padding_mode, mode=mode, align_corners=align_corners) + + +def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False): + """ + Creates a pixel grid for image operations + + Parameters + ---------- + hw : Tuple + Height/width of the grid + b : Int + Batch size + with_ones : Bool + Stack an extra channel with 1s + device : String + Device where the grid will be created + normalize : Bool + Whether the grid is normalized between [-1,1] + + Returns + ------- + grid : torch.Tensor + Output pixel grid [B,2,H,W] + """ + if is_tensor(hw): + b, hw = hw.shape[0], hw.shape[-2:] + if is_tensor(device): + device = device.device + hi, hf = 0, hw[0] - 1 + wi, wf = 0, hw[1] - 1 + yy, xx = torch.meshgrid([torch.linspace(hi, hf, hw[0], device=device), + torch.linspace(wi, wf, hw[1], device=device)], indexing='ij') + if with_ones: + grid = torch.stack([xx, yy, torch.ones(hw, device=device)], 0) + else: + grid = torch.stack([xx, yy], 0) + if b is not None: + grid = grid.unsqueeze(0).repeat(b, 1, 1, 1) + if normalize: + grid = norm_pixel_grid(grid) + return grid + + +def norm_pixel_grid(grid, hw=None, in_place=False): + """ + Normalize a pixel grid to be between [0,1] + + Parameters + ---------- + grid : torch.Tensor + Grid to be normalized [B,2,H,W] + hw : Tuple + Height/Width for normalization + in_place : Bool + Whether the operation is done in place or not + + Returns + ------- + grid : torch.Tensor + Normalized grid [B,2,H,W] + """ + if hw is None: + hw = grid.shape[-2:] + if not in_place: + grid = grid.clone() + grid[:, 0] = 2.0 * grid[:, 0] / (hw[1] - 1) - 1.0 + grid[:, 1] = 2.0 * grid[:, 1] / (hw[0] - 1) - 1.0 + return grid + + +def unnorm_pixel_grid(grid, hw=None, in_place=False): + """ + Unnormalize pixel grid to be between [0,H] and [0,W] + + Parameters + ---------- + grid : torch.Tensor + Grid to be normalized [B,2,H,W] + hw : Tuple + Height/width for unnormalization + in_place : Bool + Whether the operation is done in place or not + + Returns + ------- + grid : torch.Tensor + Unnormalized grid [B,2,H,W] + """ + if hw is None: + 
hw = grid.shape[-2:] + if not in_place: + grid = grid.clone() + grid[:, 0] = 0.5 * (hw[1] - 1) * (grid[:, 0] + 1) + grid[:, 1] = 0.5 * (hw[0] - 1) * (grid[:, 1] + 1) + return grid + + +def match_scales(image, targets, num_scales, + mode='bilinear', align_corners=True): + """ + Creates multiple resolution versions of the input to match another list of tensors + + Parameters + ---------- + image : torch.Tensor + Input image [B,?,H,W] + targets : list[torch.Tensor] + Target resolutions + num_scales : int + Number of scales to consider + mode : String + Interpolation mode + align_corners : Bool + Corner alignment flag + + Returns + ------- + images : list[torch.Tensor] + List containing tensors in the required resolutions + """ + # For all scales + images = [] + image_shape = image.shape[-2:] + for i in range(num_scales): + target_shape = targets[i].shape + # If image shape is equal to target shape + if same_shape(image_shape, target_shape): + images.append(image) + else: + # Otherwise, interpolate + images.append(interpolate_image( + image, target_shape, mode=mode, align_corners=align_corners)) + # Return scaled images + return images + + +def cat_channel_ones(tensor, n=1): + """ + Concatenate tensor with an extra channel of ones + + Parameters + ---------- + tensor : torch.Tensor + Tensor to be concatenated + n : Int + Which channel will be concatenated + + Returns + ------- + cat_tensor : torch.Tensor + Concatenated tensor + """ + # Get tensor shape with 1 channel + shape = list(tensor.shape) + shape[n] = 1 + # Return concatenation of tensor with ones + return torch.cat([tensor, torch.ones(shape, + device=tensor.device, dtype=tensor.dtype)], n) + + +def same_shape(shape1, shape2): + """Checks if two shapes are the same""" + if len(shape1) != len(shape2): + return False + for i in range(len(shape1)): + if shape1[i] != shape2[i]: + return False + return True + + +def interpolate_image(image, shape=None, scale_factor=None, mode='bilinear', + align_corners=True, recompute_scale_factor=False): + """ + Interpolate an image to a different resolution + + Parameters + ---------- + image : torch.Tensor + Image to be interpolated [B,?,h,w] + shape : torch.Tensor or tuple + Output shape [H,W] + scale_factor : Float + Scale factor for output shape + mode : String + Interpolation mode + align_corners : Bool + True if corners will be aligned after interpolation + recompute_scale_factor : Bool + True if scale factor is recomputed + + Returns + ------- + image : torch.Tensor + Interpolated image [B,?,H,W] + """ + assert shape is not None or scale_factor is not None, 'Invalid option for interpolate_image' + if mode == 'nearest': + align_corners = None + # Take last two dimensions as shape + if shape is not None: + if is_tensor(shape): + shape = shape.shape + if len(shape) > 2: + shape = shape[-2:] + # If the shapes are the same, do nothing + if same_shape(image.shape[-2:], shape): + return image + # Interpolate image to match the shape + return tfn.interpolate(image, size=shape, scale_factor=scale_factor, + mode=mode, align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor) + + +def check_assert(pred, gt, atol=1e-5, rtol=1e-5): + """ + Check two dictionaries with allclose assertions + + Parameters + ---------- + pred : Dict + Dictionary with predictions + gt : Dict + Dictionary with ground-truth + atol : Float + Absolute tolerance + rtol : Float + Relative tolerance + """ + for key in gt.keys(): + if key in pred.keys(): + # assert key in pred and key in gt + if is_dict(pred[key]): + 
check_assert(pred[key], gt[key]) + elif is_seq(pred[key]): + for val1, val2 in zip(pred[key], gt[key]): + if is_tensor(val1): + assert torch.allclose(val1, val2, atol=atol, rtol=rtol), \ + f'Assert error in {key} : {val1.mean().item()} x {val2.mean().item()}' + else: + assert val1 == val2, \ + f'Assert error in {key} : {val1} x {val2}' + else: + if is_tensor(pred[key]): + assert torch.allclose(pred[key], gt[key], atol=atol, rtol=rtol), \ + f'Assert error in {key} : {pred[key].mean().item()} x {gt[key].mean().item()}' + else: + assert pred[key] == gt[key], \ + f'Assert error in {key} : {pred[key]} x {gt[key]}' + + +def interleave(data, b): + """Interleave data considering multiple batches""" + data_interleave = data.unsqueeze(1).expand(-1, b, *data.shape[1:]) + return data_interleave.reshape(-1, *data.shape[1:]) diff --git a/vidar/utils/types.py b/vidar/utils/types.py new file mode 100644 index 0000000000000000000000000000000000000000..853a54cb3bc15d4c8281c925b3f79f9bfe6bda85 --- /dev/null +++ b/vidar/utils/types.py @@ -0,0 +1,61 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +from argparse import Namespace + +import numpy as np +import torch + + +def is_numpy(data): + """Checks if data is a numpy array.""" + return isinstance(data, np.ndarray) + + +def is_tensor(data): + """Checks if data is a torch tensor.""" + return type(data) == torch.Tensor + + +def is_tuple(data): + """Checks if data is a tuple.""" + return isinstance(data, tuple) + + +def is_list(data): + """Checks if data is a list.""" + return isinstance(data, list) + + +def is_double_list(data): + """Checks if data is a double list (list of lists)""" + return is_list(data) and len(data) > 0 and is_list(data[0]) + + +def is_dict(data): + """Checks if data is a dictionary.""" + return isinstance(data, dict) + + +def is_str(data): + """Checks if data is a string.""" + return isinstance(data, str) + + +def is_int(data): + """Checks if data is an integer.""" + return isinstance(data, int) + + +def is_seq(data): + """Checks if data is a list or tuple.""" + return is_tuple(data) or is_list(data) + + +def is_namespace(data): + """Check if data is a Namespace""" + return isinstance(data, Namespace) + + +def exists(data): + """Check if data exists (it is not None)""" + return data is not None diff --git a/vidar/utils/viz.py b/vidar/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..74e331b1c4d325e5b294cd0f3043490b6c1bc161 --- /dev/null +++ b/vidar/utils/viz.py @@ -0,0 +1,247 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. 
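+#
+# Illustrative usage sketch (`depth` is a hypothetical [B,1,H,W] depth tensor):
+#
+#   colormap = viz_depth(depth)   # [H,W,3] np.array in [0,1], plasma colormap by default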
+ +import flow_vis +import numpy as np +import torch +from matplotlib.cm import get_cmap + +from vidar.utils.decorators import iterate1 +from vidar.utils.depth import depth2inv +from vidar.utils.types import is_tensor, is_list + + +def flow_to_color(flow_uv, clip_flow=None): + """ + Calculate color from optical flow + + Parameters + ---------- + flow_uv : np.Array + Optical flow [H,W,2] + clip_flow : Float + Clipping value for optical flow + + Returns + ------- + colors : np.array + Optical flow colormap [H,W,3] + """ + # Clip if requested + if clip_flow is not None: + flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) + # Get optical flow channels + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + # Calculate maximum radian + rad_max = np.sqrt(2) * clip_flow if clip_flow is not None else \ + np.max(np.sqrt(np.square(u) + np.square(v))) + # Normalize optical flow channels + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + # Return colormap [0,1] + return flow_vis.flow_uv_to_colors(u, v, convert_to_bgr=False) / 255 + + +@iterate1 +@iterate1 +def viz_inv_depth(inv_depth, normalizer=None, percentile=95, + colormap='plasma', filter_zeros=False): + """ + Converts an inverse depth map to a colormap for visualization. + + Parameters + ---------- + inv_depth : torch.Tensor + Inverse depth map to be converted [B,1,H,W] + normalizer : Float + Value for inverse depth map normalization + percentile : Float + Percentile value for automatic normalization + colormap : String + Colormap to be used + filter_zeros : Bool + If True, do not consider zero values during normalization + + Returns + ------- + colormap : np.Array [H,W,3] + Colormap generated from the inverse depth map + """ + if is_list(inv_depth): + return [viz_inv_depth( + inv[0], normalizer, percentile, colormap, filter_zeros) + for inv in inv_depth] + # If a tensor is provided, convert to numpy + if is_tensor(inv_depth): + # If it has a batch size, use first one + if len(inv_depth.shape) == 4: + inv_depth = inv_depth[0] + # Squeeze if depth channel exists + if len(inv_depth.shape) == 3: + inv_depth = inv_depth.squeeze(0) + inv_depth = inv_depth.detach().cpu().numpy() + cm = get_cmap(colormap) + if normalizer is None: + if (inv_depth > 0).sum() == 0: + normalizer = 1.0 + else: + normalizer = np.percentile( + inv_depth[inv_depth > 0] if filter_zeros else inv_depth, percentile) + inv_depth = inv_depth / (normalizer + 1e-6) + colormap = cm(np.clip(inv_depth, 0., 1.0))[:, :, :3] + colormap[inv_depth == 0] = 0 + return colormap + + +@iterate1 +@iterate1 +def viz_depth(depth, *args, **kwargs): + """Same as viz_inv_depth, but takes depth as input instead""" + return viz_inv_depth(depth2inv(depth), *args, **kwargs) + + +@iterate1 +@iterate1 +def viz_normals(normals): + """ + Converts normals map to a colormap for visualization. 
+ + Parameters + ---------- + normals : torch.Tensor + Inverse depth map to be converted [B,3,H,W] + + Returns + ------- + colormap : np.Array + Colormap generated from the normals map [H,W,3] + """ + # If a tensor is provided, convert to numpy + if is_tensor(normals): + normals = normals.permute(1, 2, 0).detach().cpu().numpy() + return (normals + 1) / 2 + + +@iterate1 +@iterate1 +def viz_optical_flow(optflow, clip_value=100.): + """ + Returns a colorized version of an optical flow map + + Parameters + ---------- + optflow : torch.Tensor + Optical flow to be colorized (NOT in batch) [2,H,W] + clip_value : Float + Optical flow clip value for visualization + + Returns + ------- + colorized : np.Array + Colorized version of the input optical flow [H,W,3] + """ + # If a tensor is provided, convert to numpy + if is_list(optflow): + return [viz_optical_flow(opt[0]) for opt in optflow] + if is_tensor(optflow): + if len(optflow.shape) == 4: + optflow = optflow[0] + optflow = optflow.permute(1, 2, 0).detach().cpu().numpy() + # Return colorized optical flow + return flow_to_color(optflow, clip_flow=clip_value) + + +@iterate1 +@iterate1 +def viz_photo(photo, colormap='viridis', normalize=False): + """ + Returns a colorized version of the photometric loss + + Parameters + ---------- + photo : torch.Tensor + Per-pixel photometric error + colormap : String + Which colormap to use + normalize : Bool + Whether the photometric error should be normalized between [0,1] + + Returns + ------- + colorized : np.Array + Colorized version of the photometric error [H,W,3] + """ + if is_tensor(photo): + if len(photo.shape) == 4: + photo = photo[0] + if len(photo.shape) == 3: + photo = photo.squeeze(0) + photo = photo.detach().cpu().numpy() + cm = get_cmap(colormap) + if normalize: + photo -= photo.min() + photo /= photo.max() + colormap = cm(np.clip(photo, 0., 1.0))[:, :, :3] + colormap[photo == 0] = 0 + return colormap + + +@iterate1 +@iterate1 +def viz_semantic(semantic, ontology): + """ + Returns a colorized version of a semantic map + + Parameters + ---------- + semantic : torch.Tensor + Semantic map to be colorized [B,1,H,W] + ontology : Dict + Dictionary mapping between class and color + + Returns + ------- + colorized : np.Array + Colorized version of the semantic map [H,W,3] + """ + # If it is a tensor, cast to numpy + if is_tensor(semantic): + if semantic.dim() == 3: + semantic = semantic.squeeze(0) + semantic = semantic.detach().cpu().numpy() + # Create and populate color map + color = np.zeros((semantic.shape[0], semantic.shape[1], 3)) + for key in ontology.keys(): + key_color = np.array(ontology[key]['color']) + if is_tensor(key_color): + key_color = key_color.detach().cpu().numpy() + color[semantic == int(key)] = key_color / 255. 
+ # Return colored semantic map + return color + + +@iterate1 +@iterate1 +def viz_camera(camera): + """ + Returns a colorized version of a camera viewing rays + + Parameters + ---------- + camera : Camera of torch.Tensor + Input camera or viewing rays + + Returns + ------- + colorized : np.Array + Colorized version of the camera viewing rays [H,W,3] + """ + if is_tensor(camera): + # If it's a tensor, reshape it + rays = camera[-3:].permute(1, 2, 0).detach().cpu().numpy() + else: + # If it's a camera, get viewing rays + rays = camera.no_translation().get_viewdirs(normalize=True, flatten=False, to_world=True) + rays = rays[0].permute(1, 2, 0).detach().cpu().numpy() + return (rays + 1) / 2 diff --git a/vidar/utils/volume.py b/vidar/utils/volume.py new file mode 100644 index 0000000000000000000000000000000000000000..9792bb5d59e51e1de389c8996a0b6351f8831dc4 --- /dev/null +++ b/vidar/utils/volume.py @@ -0,0 +1,146 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import numpy as np +import torch + +from vidar.geometry.camera import Camera +from vidar.utils.tensor import grid_sample +from vidar.utils.types import is_tensor + + +def warp_bins(rgb, cam, bins): + """ + Warp an image based on depth bins + + Parameters + ---------- + rgb : torch.Tensor [B,?,H,W] + Input image for warping + cam : Camera + Input camera + bins : torch.Tensor + Depth bins for warping + + Returns + ------- + warped : torch.Tensor + Warped images for each depth bin + """ + ones = torch.ones((1, *cam.hw)).to(rgb.device) + volume = torch.stack([depth * ones for depth in bins], 1) + coords_volume = cam.coords_from_cost_volume(volume) + return grid_sample( + rgb.repeat(len(bins), 1, 1, 1), coords_volume[0].type(rgb.dtype), + padding_mode='zeros', mode='bilinear', align_corners=True) + + +def sample(grid, pred): + """ + Sample a grid based on predictions + + Parameters + ---------- + grid : torch.Tensor + Grid to be sampled [B,?,H,W] + pred : torch.Tensor + Coordinate predictions [B,2,H,W] + + Returns + ------- + values : torch.Tensor + Sampled grid[B,?,H,W] + """ + n, _, h, w = grid.shape + coords = pred.permute(1, 2, 0).reshape(-1, 1, 1, 1).repeat(1, 1, 1, 2) + coords = 2 * coords / (n - 1) - 1 + grid = grid.permute(2, 3, 0, 1).reshape(-1, 1, n, 1).repeat(1, 1, 1, 2) + values = grid_sample(grid, coords, + padding_mode='zeros', mode='bilinear', align_corners=True) + return values.reshape(h, w, 1, 1).permute(2, 3, 0, 1) + + +def compute_depth_bin(min_depth, max_depth, num_bins, i): + """ + Calculate a single SID depth bin + + Parameters + ---------- + min_depth : Float + Minimum depth value + max_depth : Float + Maximum depth value + num_bins : Int + Number of depth bins + i : Int + Index of the depth bin in the interval + + Returns + ------- + bin : torch.Tensor + Corresponding depth bin + """ + return torch.exp(np.log(min_depth) + np.log(max_depth / min_depth) * i / (num_bins - 1)).\ + clamp(min=min_depth, max=max_depth) + + +def uncompute_depth_bin(min_depth, max_depth, num_bins, depth): + """ + Recover the SID bin index from a depth value + + Parameters + ---------- + min_depth : Float + Minimum depth value + max_depth : Float + Maximum depth value + num_bins : Int + Number of depth bins + depth : torch.Tensor + Depth value + + Returns + ------- + index : torch.Tensor + Index for the depth value in the SID interval + """ + return (num_bins - 1) * ((torch.log(depth) - np.log(min_depth)) / + np.log(max_depth / min_depth)).clamp(min=0, max=num_bins) + + +def compute_depth_bins(min_depth, 
max_depth, num_bins, mode): + """ + Compute depth bins for an interval + + Parameters + ---------- + min_depth : Float + Minimum depth value + max_depth : Float + Maximum depth value + num_bins : Int + Number of depth bins + mode : String + Depth discretization mode + + Returns + ------- + bins : torch.Tensor + Discretized depth bins + """ + if is_tensor(min_depth): + min_depth = min_depth.detach().cpu() + if is_tensor(max_depth): + max_depth = max_depth.detach().cpu() + if mode == 'inverse': + depth_bins = 1. / np.linspace( + 1. / max_depth, 1. / min_depth, num_bins)[::-1] + elif mode == 'linear': + depth_bins = np.linspace( + min_depth, max_depth, num_bins) + elif mode == 'sid': + depth_bins = np.array( + [np.exp(np.log(min_depth) + np.log(max_depth / min_depth) * i / (num_bins - 1)) + for i in range(num_bins)]) + else: + raise NotImplementedError + return torch.from_numpy(depth_bins).float() diff --git a/vidar/utils/write.py b/vidar/utils/write.py new file mode 100644 index 0000000000000000000000000000000000000000..89a10398e81f6d2d5f7be965df9f6b4000b7783f --- /dev/null +++ b/vidar/utils/write.py @@ -0,0 +1,132 @@ +# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. + +import os +import pickle as pkl + +import cv2 +import numpy as np +import torchvision.transforms as transforms + +from vidar.utils.decorators import multi_write +from vidar.utils.types import is_tensor, is_numpy + + +def create_folder(filename): + """Create a new folder if it doesn't exist""" + if '/' in filename: + os.makedirs(os.path.dirname(filename), exist_ok=True) + + +def write_pickle(filename, data): + """ + Write a pickle file + + Parameters + ---------- + filename : String + File where the pickle file will be saved + data : Value + Data to be saved + """ + create_folder(filename) + if not filename.endswith('.pkl'): + filename = filename + '.pkl' + pkl.dump(data, open(filename, 'wb')) + + +def write_npz(filename, data): + """ + Write a numpy compressed file + + Parameters + ---------- + filename : String + File where the numpy file will be saved + data : Value + Data to be saved + """ + np.savez_compressed(filename, **data) + + +@multi_write +@multi_write +def write_image(filename, image): + """ + Write an image to file + + Parameters + ---------- + filename : String + File where image will be saved + image : np.Array [H,W,3] + RGB image + """ + # Create folder if it doesn't exist + create_folder(filename) + # If image is a tensor + if is_tensor(image): + if len(image.shape) == 4: + image = image[0] + image = image.detach().cpu().numpy().transpose(1, 2, 0) + cv2.imwrite(filename, image[:, :, ::-1] * 255) + # If image is a numpy array + elif is_numpy(image): + cv2.imwrite(filename, image[:, :, ::-1] * 255) + # Otherwise, assume it's a PIL image + else: + image.save(filename) + + +@multi_write +def write_depth(filename, depth, intrinsics=None): + """ + Write a depth map to file, and optionally its corresponding intrinsics. 
+
+    Parameters
+    ----------
+    filename : String
+        File where depth map will be saved (.npz or .png)
+    depth : np.Array or torch.Tensor
+        Depth map [H,W]
+    intrinsics : np.Array
+        Optional camera intrinsics matrix [3,3]
+    """
+    # If depth is a tensor
+    if is_tensor(depth):
+        depth = depth.detach().squeeze().cpu().numpy()
+    # If intrinsics is a tensor
+    if is_tensor(intrinsics):
+        intrinsics = intrinsics.detach().cpu().numpy()
+    # If we are saving as a .npz
+    if filename.endswith('.npz'):
+        np.savez_compressed(filename, depth=depth, intrinsics=intrinsics)
+    # If we are saving as a .png
+    elif filename.endswith('.png'):
+        depth = transforms.ToPILImage()((depth * 256).astype(np.int32))
+        depth.save(filename)
+    # Something is wrong
+    else:
+        raise NotImplementedError('Depth filename not valid.')
+
+
+@multi_write
+def write_optical_flow(filename, optflow):
+    """
+    Write an optical flow map to file.
+
+    Parameters
+    ----------
+    filename : String
+        File where the optical flow map will be saved (.npz)
+    optflow : np.Array or torch.Tensor
+        Optical flow map [H,W]
+    """
+    # If optical flow is a tensor
+    if is_tensor(optflow):
+        optflow = optflow.detach().squeeze().cpu().numpy()
+    # If we are saving as a .npz
+    if filename.endswith('.npz'):
+        np.savez_compressed(filename, optflow=optflow)
+    # Something is wrong
+    else:
+        raise NotImplementedError('Optical flow filename not valid.')
\ No newline at end of file