diff --git a/.gitattributes b/.gitattributes
index c7d9f3332a950355d5a77d85000f05e6f45435ea..1e344cbe6b61128fc25d39a484f466b934eb568d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+**/*.gif filter=lfs diff=lfs merge=lfs -text
+**/*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..74a3b92ba2e58b784915ce0133c8d8cbc5c39b07
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+*.pyc
+.idea
+*ZZZ*
+sandbox
+__pycache__
+wandb
+temp
+code
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..01e1483b59134eab44449bafe21dfec846e66308
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "externals/camviz"]
+ path = externals/camviz
+ url = git@github.com:TRI-ML/camviz.git
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..f59548d1a37a9aeec57a98ee62a7244716eefb6c
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,409 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
+
+
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..96585d88d394ab2aba5d315c50d4c41c09bfc0ac
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,56 @@
+PROJECT ?= vidar
+WORKSPACE ?= /workspace/$(PROJECT)
+DOCKER_IMAGE ?= ${PROJECT}:latest
+
+SHMSIZE ?= 444G
+WANDB_MODE ?= run
+DOCKER_OPTS := \
+ --name ${PROJECT} \
+ --rm -it \
+ --shm-size=${SHMSIZE} \
+ -e AWS_DEFAULT_REGION \
+ -e AWS_ACCESS_KEY_ID \
+ -e AWS_SECRET_ACCESS_KEY \
+ -e WANDB_API_KEY \
+ -e WANDB_ENTITY \
+ -e WANDB_MODE \
+ -e HOST_HOSTNAME= \
+ -e OMP_NUM_THREADS=1 -e KMP_AFFINITY="granularity=fine,compact,1,0" \
+ -e OMPI_ALLOW_RUN_AS_ROOT=1 \
+ -e OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \
+ -e NCCL_DEBUG=VERSION \
+ -e DISPLAY=${DISPLAY} \
+ -e XAUTHORITY \
+ -e NVIDIA_DRIVER_CAPABILITIES=all \
+ -v ~/.aws:/root/.aws \
+ -v /root/.ssh:/root/.ssh \
+ -v ~/.cache:/root/.cache \
+ -v /dev/null:/dev/raw1394 \
+ -v /mnt/fsx/tmp:/tmp \
+ -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \
+ -v /var/run/docker.sock:/var/run/docker.sock \
+ -v /home/jiadingfang/datasets:/data \
+ -v ${PWD}:${WORKSPACE} \
+ -w ${WORKSPACE} \
+ --privileged \
+ --ipc=host \
+ --network=host
+
+NGPUS=$(shell nvidia-smi -L | wc -l)
+
+all: clean
+
+clean:
+ find . -name "*.pyc" | xargs rm -f && \
+ find . -name "__pycache__" | xargs rm -rf
+
+docker-build:
+ docker build \
+ -f docker/Dockerfile \
+ -t ${DOCKER_IMAGE} .
+
+docker-interactive: docker-build
+ nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash
+
+docker-run: docker-build
+ nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash -c "${COMMAND}"
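+
+# Example (illustrative command only; the entry-point script name is an assumption, not defined in this Makefile):
+#   make docker-run COMMAND="python scripts/run.py configs/overfit/kitti_tiny.yaml"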
diff --git a/configs/overfit/ddad_tiny.yaml b/configs/overfit/ddad_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..121e845fc91c8e033c957d73a9c2b6356f5b21fc
--- /dev/null
+++ b/configs/overfit/ddad_tiny.yaml
@@ -0,0 +1,31 @@
+wrapper:
+ recipe: wrapper|default
+ max_epochs: 1
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/pose_net|default
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|ddad_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ pose:
+ recipe: optimizers|adam_20_05
+datasets:
+ train:
+ recipe: datasets/ddad_tiny|train_selfsup_front
+ validation:
+ recipe: datasets/ddad_tiny|validation_front
+
diff --git a/configs/overfit/generic.yaml b/configs/overfit/generic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..36be6b63bbeca731d631c8bad8fe49c16b36e3f2
--- /dev/null
+++ b/configs/overfit/generic.yaml
@@ -0,0 +1,27 @@
+wrapper:
+ recipe: wrapper|default
+ max_epochs: 1
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/pose_net|default
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ pose:
+ recipe: optimizers|adam_20_05
+datasets:
+ train:
+ recipe: datasets/generic|default_train
+ validation:
+ recipe: datasets/generic|default_validation
diff --git a/configs/overfit/kitti_tiny.yaml b/configs/overfit/kitti_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8810a5f7944502ca2e94e0f79b882bfb814af47
--- /dev/null
+++ b/configs/overfit/kitti_tiny.yaml
@@ -0,0 +1,31 @@
+wrapper:
+ recipe: wrapper|default
+ max_epochs: 1
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/pose_net|default
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ pose:
+ recipe: optimizers|adam_20_05
+datasets:
+ train:
+ recipe: datasets/kitti_tiny|train_selfsup_mr
+ validation:
+ recipe: datasets/kitti_tiny|validation_mr
+
diff --git a/configs/papers/define/scannet_temporal_test.yaml b/configs/papers/define/scannet_temporal_test.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..9338a7da0c89fe7dafb46fd11680e18bb58aca4f
--- /dev/null
+++ b/configs/papers/define/scannet_temporal_test.yaml
@@ -0,0 +1,79 @@
+wrapper:
+ seed: 42
+ min_epochs: 0
+ max_epochs: 0
+ find_unused_parameters: True
+ flip_lr_prob: 0.5
+ validate_first: True
+ validate_flipped: False
+ sync_batch_norm: True
+evaluation:
+ rgb:
+ only_first: False
+ depth:
+ crop: ''
+ min_depth: 0.2
+ max_depth: 10.0
+ scale_output: resize
+ median_scaling: True
+ post_process: False
+arch:
+ model:
+ checkpoint: /data/vidar/models/scannet_full.ckpt
+ file: perceiver/DefineGenericModel
+ use_pose_noise: []
+ use_virtual_cameras: []
+ virtual_cameras_eval: False
+ use_virtual_rgb: False
+ augment_canonical: False
+ encode_train: all
+ encode_eval: all
+ decode_train: all
+ decode_eval: all
+ decode_encodes: False
+ task_weights: [1.0,1.0]
+ scale_loss: True
+ sample_decoded_queries: 0.0
+ network:
+ tasks: [depth]
+ depth_range: [0.1,10.]
+ image_shape: [128,192]
+ num_bands_orig: 20
+ num_bands_dirs: 10
+ max_resolution_orig: 60
+ max_resolution_dirs: 60
+ to_world: True
+ d_latents: 512
+ num_latents: 1024
+ hidden_dropout_prob: 0.25
+ num_cross_attention_heads: 1
+ num_self_attends_per_block: 8
+ num_self_attention_heads: 8
+ decoder_num_heads: 1
+ rgb_feat_dim: 960
+ rgb_feat_type: resnet_all
+ downsample_encoder: 4
+ downsample_decoder: 4
+ upsample_convex: 'convex'
+ encoder_with_rgb: True
+ decoder_with_rgb: False
+ output_mode: inv_depth
+ sample_encoding_rays: 0
+ with_monodepth: False
+datasets:
+ validation:
+ name: [ScanNetTemporal]
+ path: [/data/vidar/scannet_processed]
+ split: [test]
+ context: [-2,2]
+ stride: [10]
+ labels: [depth,pose]
+ cameras: [[0]]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 4
+ augmentation:
+ resize: [128,192]
+ resize_supervision: True
+ preserve_depth: True
diff --git a/configs/papers/define/scannet_temporal_test_context_1.yaml b/configs/papers/define/scannet_temporal_test_context_1.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..67d4544a838105f58fc3e1557fcbf6a5fa3c6246
--- /dev/null
+++ b/configs/papers/define/scannet_temporal_test_context_1.yaml
@@ -0,0 +1,79 @@
+wrapper:
+ seed: 42
+ min_epochs: 0
+ max_epochs: 0
+ find_unused_parameters: True
+ flip_lr_prob: 0.5
+ validate_first: True
+ validate_flipped: False
+ sync_batch_norm: True
+evaluation:
+ rgb:
+ only_first: False
+ depth:
+ crop: ''
+ min_depth: 0.2
+ max_depth: 10.0
+ scale_output: resize
+ median_scaling: True
+ post_process: False
+arch:
+ model:
+ checkpoint: /data/vidar/models/scannet_full.ckpt
+ file: perceiver/DefineGenericModel
+ use_pose_noise: []
+ use_virtual_cameras: []
+ virtual_cameras_eval: False
+ use_virtual_rgb: False
+ augment_canonical: False
+ encode_train: all
+ encode_eval: all
+ decode_train: all
+ decode_eval: all
+ decode_encodes: False
+ task_weights: [1.0,1.0]
+ scale_loss: True
+ sample_decoded_queries: 0.0
+ network:
+ tasks: [depth]
+ depth_range: [0.1,10.]
+ image_shape: [128,192]
+ num_bands_orig: 20
+ num_bands_dirs: 10
+ max_resolution_orig: 60
+ max_resolution_dirs: 60
+ to_world: True
+ d_latents: 512
+ num_latents: 1024
+ hidden_dropout_prob: 0.25
+ num_cross_attention_heads: 1
+ num_self_attends_per_block: 8
+ num_self_attention_heads: 8
+ decoder_num_heads: 1
+ rgb_feat_dim: 960
+ rgb_feat_type: resnet_all
+ downsample_encoder: 4
+ downsample_decoder: 4
+ upsample_convex: 'convex'
+ encoder_with_rgb: True
+ decoder_with_rgb: False
+ output_mode: inv_depth
+ sample_encoding_rays: 0
+ with_monodepth: False
+datasets:
+ validation:
+ name: [ScanNetTemporal]
+ path: [/data/vidar/scannet_processed]
+ split: [test]
+ context: [1]
+ stride: [10]
+ labels: [depth,pose]
+ cameras: [[0]]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 4
+ augmentation:
+ resize: [128,192]
+ resize_supervision: True
+ preserve_depth: True
diff --git a/configs/papers/depthformer/inference_kitti.yaml b/configs/papers/depthformer/inference_kitti.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0b8ec992b1b4508ec2f39f00cdfa452b9852fc8
--- /dev/null
+++ b/configs/papers/depthformer/inference_kitti.yaml
@@ -0,0 +1,31 @@
+wrapper:
+ recipe: wrapper|default
+arch:
+ model:
+ file: depth/DepthFormerModel
+ checkpoint: /data/vidar/models/papers/final/DepthFormer_MR_selfsup_KITTI.ckpt
+ warp_context: [-1,1]
+ match_context: [-1]
+ motion_masking: True
+ matching_augmentation: False
+ freeze_teacher_and_pose: 40
+ networks:
+ transformer:
+ recipe: networks/transformer|depthformer
+ mono_depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ multi_depth:
+ recipe: networks/multi_depth_res_net|depthformer
+ pose:
+ recipe: networks/pose_net|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+datasets:
+ validation:
+ recipe: datasets/kitti|validation_mr
+ labels: [depth,pose]
+ context: [-1,1]
+save:
+ recipe: save|depth_splitname
diff --git a/configs/papers/fsm/inference_ddad.yaml b/configs/papers/fsm/inference_ddad.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9eb1b4e943d85cf653d1764e6e127030bccd6ae8
--- /dev/null
+++ b/configs/papers/fsm/inference_ddad.yaml
@@ -0,0 +1,19 @@
+wrapper:
+ recipe: wrapper|default
+arch:
+ model:
+ file: depth/FSMModel
+ checkpoint: /data/vidar/models/papers/final/FSM_MR_6cams_DDAD.ckpt
+ networks:
+ depth:
+ recipe: networks/focal_depth_res_net|fsm_ddad
+ pose:
+ recipe: networks/conv_pose_net|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|ddad_resize
+datasets:
+ validation:
+ recipe: datasets/ddad|validation_6cams
+save:
+ recipe: save|depth_splitname
diff --git a/configs/papers/packnet/inference_packnet.yaml b/configs/papers/packnet/inference_packnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ec7bbe8d096bdd50c712eaa86fd0e090760b2e8
--- /dev/null
+++ b/configs/papers/packnet/inference_packnet.yaml
@@ -0,0 +1,20 @@
+wrapper:
+ recipe: wrapper|default
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ checkpoint: /data/vidar/models/papers/final/PackNet_MR_selfsup_KITTI.ckpt
+ networks:
+ depth:
+ recipe: networks/packnet|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/conv_pose_net|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+datasets:
+ validation:
+ recipe: datasets/kitti|validation_mr
+save:
+ recipe: save|depth_splitname
diff --git a/configs/papers/packnet/inference_resnet.yaml b/configs/papers/packnet/inference_resnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a0bba4dacf818fe9182cef1b2dffdd5cfa9db1f
--- /dev/null
+++ b/configs/papers/packnet/inference_resnet.yaml
@@ -0,0 +1,20 @@
+wrapper:
+ recipe: wrapper|default
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ checkpoint: /data/vidar/models/papers/final/ResNet18_MR_selfsup_KITTI.ckpt
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/pose_net|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+datasets:
+ validation:
+ recipe: datasets/kitti|validation_mr
+save:
+ recipe: save|depth_splitname
diff --git a/configs/papers/selfcalib/ds_euroc.yaml b/configs/papers/selfcalib/ds_euroc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63664c7229c50aadfc548d9626103c345054110e
--- /dev/null
+++ b/configs/papers/selfcalib/ds_euroc.yaml
@@ -0,0 +1,40 @@
+wrapper:
+ recipe: wrapper|default
+ validate_first: False
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ use_gt_intrinsics: False
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,80.0]
+ pose:
+ recipe: networks/pose_net|default
+ intrinsics:
+ file: intrinsics/IntrinsicsNet
+ camera_model: 'DS'
+ shape: [256, 384]
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ pose:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ intrinsics:
+ recipe: optimizers|adam_20_05
+ lr: 0.01
+datasets:
+ train:
+ recipe: datasets/euroc|train_selfsup
+ validation:
+ recipe: datasets/euroc|validation
diff --git a/configs/papers/selfcalib/eucm_euroc.yaml b/configs/papers/selfcalib/eucm_euroc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6b4afceb737c59f9b9374475965f4ebd7c8ea844
--- /dev/null
+++ b/configs/papers/selfcalib/eucm_euroc.yaml
@@ -0,0 +1,40 @@
+wrapper:
+ recipe: wrapper|default
+ validate_first: False
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ use_gt_intrinsics: False
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,80.0]
+ pose:
+ recipe: networks/pose_net|default
+ intrinsics:
+ file: intrinsics/IntrinsicsNet
+ camera_model: 'EUCM'
+ shape: [256, 384]
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ pose:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ intrinsics:
+ recipe: optimizers|adam_20_05
+ lr: 0.01
+datasets:
+ train:
+ recipe: datasets/euroc|train_selfsup
+ validation:
+ recipe: datasets/euroc|validation
diff --git a/configs/papers/selfcalib/ucm_euroc.yaml b/configs/papers/selfcalib/ucm_euroc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e8e2edc37ce9a96d1fbb5594c2a0cb0e4da84684
--- /dev/null
+++ b/configs/papers/selfcalib/ucm_euroc.yaml
@@ -0,0 +1,40 @@
+wrapper:
+ recipe: wrapper|default
+ validate_first: False
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ use_gt_intrinsics: False
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,80.0]
+ pose:
+ recipe: networks/pose_net|default
+ intrinsics:
+ file: intrinsics/IntrinsicsNet
+ camera_model: 'UCM'
+ shape: [256, 384]
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ pose:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ intrinsics:
+ recipe: optimizers|adam_20_05
+ lr: 0.01
+datasets:
+ train:
+ recipe: datasets/euroc|train_selfsup
+ validation:
+ recipe: datasets/euroc|validation
diff --git a/configs/papers/selfcalib/ucm_kitti.yaml b/configs/papers/selfcalib/ucm_kitti.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c47d53942450bf5aae7dd0ff7a01c9bb07f95c5
--- /dev/null
+++ b/configs/papers/selfcalib/ucm_kitti.yaml
@@ -0,0 +1,44 @@
+wrapper:
+ recipe: wrapper|default
+ validate_first: False
+arch:
+ model:
+ file: depth/SelfSupervisedModel
+ use_gt_intrinsics: False
+ networks:
+ depth:
+ recipe: networks/mono_depth_res_net|default
+ depth_range: [0.1,100.0]
+ pose:
+ recipe: networks/pose_net|default
+ intrinsics:
+ file: intrinsics/IntrinsicsNet
+ camera_model: 'UCM'
+ shape: [192, 640]
+ losses:
+ reprojection:
+ recipe: losses/reprojection|default
+ smoothness:
+ recipe: losses/smoothness|default
+evaluation:
+ depth:
+ recipe: evaluation/depth|kitti_resize
+optimizers:
+ depth:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ pose:
+ recipe: optimizers|adam_20_05
+ lr: 0.0002
+ intrinsics:
+ recipe: optimizers|adam_20_05
+ lr: 0.01
+datasets:
+ train:
+ recipe: datasets/kitti|train_selfsup_mr
+ dataloader:
+ batch_size: 8
+ validation:
+ recipe: datasets/kitti|validation_mr
+ dataloader:
+ batch_size: 1
diff --git a/configs/recipes/checkpoint.yaml b/configs/recipes/checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffcb92e67db721cbbc10bc72d1ac01048e30cbf3
--- /dev/null
+++ b/configs/recipes/checkpoint.yaml
@@ -0,0 +1,8 @@
+default:
+ folder: /data/vidar/checkpoints
+ s3_bucket: miru-us-east-1/vidar/checkpoints
+ save_code: True
+ keep_top: 5
+default_local:
+ folder: /data/vidar/checkpoints
+ keep_top: 5
diff --git a/configs/recipes/datasets/ddad.yaml b/configs/recipes/datasets/ddad.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..583ebfbc3e1d4d1bc5d1ebce4bfc7d046950f76b
--- /dev/null
+++ b/configs/recipes/datasets/ddad.yaml
@@ -0,0 +1,32 @@
+train_selfsup_6cams:
+ name: [Ouroboros]
+ path: [/data/vidar/DDAD/ddad.json]
+ split: [train]
+ augmentation:
+ jittering: [0.2, 0.2, 0.2, 0.05]
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: [-1,1]
+ labels: [depth]
+ cameras: [[1],[5],[6],[7],[8],[9]]
+ depth_type: [lidar]
+ repeat: [100]
+validation_6cams:
+ name: [Ouroboros]
+ path: [/data/vidar/DDAD/ddad.json]
+ split: [val]
+ augmentation:
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: []
+ labels: [depth]
+ cameras: [[1],[5],[6],[7],[8],[9]]
+ depth_type: [lidar]
+
+
diff --git a/configs/recipes/datasets/ddad_tiny.yaml b/configs/recipes/datasets/ddad_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..08bbc9263e1ec555736a7a8266f3e879d81f61f5
--- /dev/null
+++ b/configs/recipes/datasets/ddad_tiny.yaml
@@ -0,0 +1,32 @@
+train_selfsup_front:
+ name: [Ouroboros]
+ path: [/data/vidar/DDAD_tiny/ddad_tiny.json]
+ split: [train]
+ augmentation:
+ jittering: [0.2, 0.2, 0.2, 0.05]
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: [-1,1]
+ labels: [depth]
+ cameras: [[1]]
+ depth_type: [lidar]
+ repeat: [100]
+validation_front:
+ name: [Ouroboros]
+ path: [/data/vidar/DDAD_tiny/ddad_tiny.json]
+ split: [train]
+ augmentation:
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: []
+ labels: [depth]
+ cameras: [[1]]
+ depth_type: [lidar]
+
+
diff --git a/configs/recipes/datasets/euroc.yaml b/configs/recipes/datasets/euroc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..02949ad0297a2961ce5fe9f80093f1964db9d8eb
--- /dev/null
+++ b/configs/recipes/datasets/euroc.yaml
@@ -0,0 +1,30 @@
+train_selfsup:
+ name: [EUROC]
+ path: [/data/vidar/euroc/euroc_cam/cam0]
+ augmentation:
+ jittering: [0.2, 0.2, 0.2, 0.05]
+ resize: [256, 384]
+ dataloader:
+ batch_size: 16
+ pin_memory: True
+ num_workers: 16
+ context: [-1, 1]
+ strides: [[49999872, 50000128]]
+ cameras: [[0]]
+ split: [euroc-train]
+ labels: []
+ repeat: [1]
+validation:
+ name: [EUROC]
+ path: [/data/vidar/euroc/euroc_has_depth/V2_01_easy_has_depth/mav0/]
+ augmentation:
+ resize: [256, 384]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ split: [euroc-val]
+ labels: [depth]
+ context: []
+ strides: [[0, 0]]
+ cameras: [[0]]
diff --git a/configs/recipes/datasets/generic.yaml b/configs/recipes/datasets/generic.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..48bd3864c6a18e7e57f0f7dbe063dcfb83f0228f
--- /dev/null
+++ b/configs/recipes/datasets/generic.yaml
@@ -0,0 +1,30 @@
+default_train:
+ name: [Generic]
+ path: [data/generic]
+ split: ['']
+ context: [-1,1]
+ cameras: [[0]]
+ labels: []
+ extension: [jpg]
+ repeat: [100]
+ augmentation:
+ jittering: [0.2, 0.2, 0.2, 0.05]
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+default_validation:
+ name: [Generic]
+ path: [data/generic]
+ split: ['']
+ context: []
+ cameras: [[0]]
+ labels: []
+ extension: [jpg]
+ augmentation:
+ resize: [384, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
diff --git a/configs/recipes/datasets/kitti.yaml b/configs/recipes/datasets/kitti.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d481a4a8289661b6337b9ea67d541c225f8d7b9
--- /dev/null
+++ b/configs/recipes/datasets/kitti.yaml
@@ -0,0 +1,31 @@
+train_selfsup_mr:
+ name: [KITTI]
+ path: [/data/vidar/KITTI_raw]
+ split: [data_splits/eigen_zhou_files.txt]
+ augmentation:
+ jittering: [0.2, 0.2, 0.2, 0.05]
+ resize: [192, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: [-1,1]
+ labels: []
+ cameras: [[0]]
+ single_intrinsics: [True]
+ repeat: [1]
+validation_mr:
+ name: [KITTI]
+ path: [/data/vidar/KITTI_raw]
+ split: [data_splits/eigen_test_files.txt]
+ augmentation:
+ resize: [192, 640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: []
+ labels: [depth]
+ cameras: [[0]]
+ single_intrinsics: [True]
+ depth_type: [velodyne,groundtruth]
diff --git a/configs/recipes/datasets/kitti_tiny.yaml b/configs/recipes/datasets/kitti_tiny.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f0f547ec63eb3fb05f1a5a46f8aec10963634b38
--- /dev/null
+++ b/configs/recipes/datasets/kitti_tiny.yaml
@@ -0,0 +1,32 @@
+train_selfsup_mr:
+ name: [KITTI]
+ path: [/data/vidar/KITTI_tiny]
+ split: [kitti_tiny.txt]
+ augmentation:
+ jittering: [0.2,0.2,0.2,0.05]
+ resize: [192,640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: [-1,1]
+ labels: []
+ cameras: [[0]]
+ single_intrinsics: [True]
+ repeat: [100]
+ depth_type: [velodyne]
+validation_mr:
+ name: [KITTI]
+ path: [/data/vidar/KITTI_tiny]
+ split: [kitti_tiny.txt]
+ augmentation:
+ resize: [192,640]
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: []
+ labels: [depth]
+ cameras: [[0]]
+ single_intrinsics: [True]
+ depth_type: [velodyne]
diff --git a/configs/recipes/datasets/vkitti2.yaml b/configs/recipes/datasets/vkitti2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d3dca65f125b41ddfa01409936fd2c947e6baba9
--- /dev/null
+++ b/configs/recipes/datasets/vkitti2.yaml
@@ -0,0 +1,16 @@
+tiny:
+ name: [VKITTI2]
+ path: [/data/vidar/VKITTI2_tiny]
+ split: ['']
+ augmentation:
+ resize: [192, 640]
+ resize_supervision: True
+ dataloader:
+ batch_size: 1
+ pin_memory: True
+ num_workers: 16
+ context: [-1,1]
+ labels: [depth,pose]
+ labels_context: [depth,pose]
+ cameras: [[0]]
+ depth_type: [zbuffer]
\ No newline at end of file
diff --git a/configs/recipes/evaluation/depth.yaml b/configs/recipes/evaluation/depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bf7ced1473d04ab3c4a6c3ae41b6b2b4c59d5004
--- /dev/null
+++ b/configs/recipes/evaluation/depth.yaml
@@ -0,0 +1,21 @@
+kitti_resize:
+ crop: garg
+ min_depth: 0.0
+ max_depth: 80.0
+ scale_output: resize
+ median_scaling: True
+ post_process: True
+kitti_crop:
+ crop: garg
+ min_depth: 0.0
+ max_depth: 80.0
+ scale_output: top-center
+ median_scaling: True
+ post_process: True
+ddad_resize:
+ crop: ''
+ min_depth: 0.0
+ max_depth: 200.0
+ scale_output: resize
+ median_scaling: True
+ post_process: True
diff --git a/configs/recipes/losses/reprojection.yaml b/configs/recipes/losses/reprojection.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9206e53f8593d257e761df0bb8174e348dff974
--- /dev/null
+++ b/configs/recipes/losses/reprojection.yaml
@@ -0,0 +1,9 @@
+default:
+ file: ReprojectionLoss
+ automasking: True
+ reprojection_reduce_op: min
+ jitter_identity_reprojection: 0.00001
+ photometric:
+ file: PhotometricLoss
+ weight: 1.0
+ alpha: 0.85
\ No newline at end of file
diff --git a/configs/recipes/losses/smoothness.yaml b/configs/recipes/losses/smoothness.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f9be726f317d6e03af9464fabba8ba38a1da1554
--- /dev/null
+++ b/configs/recipes/losses/smoothness.yaml
@@ -0,0 +1,5 @@
+default:
+ file: SmoothnessLoss
+ normalize: True
+ weight: 0.0001
+ gamma: 0.5
diff --git a/configs/recipes/losses/supervised_depth.yaml b/configs/recipes/losses/supervised_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b2736b71fb79ddd3c5f3504e3678e98e58747654
--- /dev/null
+++ b/configs/recipes/losses/supervised_depth.yaml
@@ -0,0 +1,40 @@
+huber:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: huber
+ mask_sparse: True
+abs_rel:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: abs_rel
+ mask_sparse: True
+l1log:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: l1log
+ mask_sparse: True
+l1:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: l1
+ mask_sparse: True
+mse:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: mse
+ mask_sparse: True
+rmse:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: rmse
+ mask_sparse: True
+mixture:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: mixture
+ mask_sparse: True
+cross_entropy:
+ file: SupervisedDepthLoss
+ gamma: 0.5
+ method: cross_entropy
+ mask_sparse: True
\ No newline at end of file
diff --git a/configs/recipes/networks/conv_pose_net.yaml b/configs/recipes/networks/conv_pose_net.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f1a0a8e0252cce457f6ab422b538d3ff9c63e2ab
--- /dev/null
+++ b/configs/recipes/networks/conv_pose_net.yaml
@@ -0,0 +1,5 @@
+default:
+ file: pose/ConvPoseNet
+ version: 18
+ pretrained: True
+ num_rgb_in: 2
\ No newline at end of file
diff --git a/configs/recipes/networks/focal_depth_res_net.yaml b/configs/recipes/networks/focal_depth_res_net.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..01f37b3d3be58f291fcbf3a4402b91584b45841e
--- /dev/null
+++ b/configs/recipes/networks/focal_depth_res_net.yaml
@@ -0,0 +1,12 @@
+fsm_ddad:
+ file: depth/FocalDepthResNet
+ encoder:
+ version: 18
+ pretrained: True
+ num_rgb_in: 1
+ decoder:
+ use_skips: True
+ activation: sigmoid
+ num_ch_out: 1
+ min_depth: 0.005
+ max_depth: 0.3
diff --git a/configs/recipes/networks/mono_depth_res_net.yaml b/configs/recipes/networks/mono_depth_res_net.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ea5c0f134d394c739040112bc245f7060076594
--- /dev/null
+++ b/configs/recipes/networks/mono_depth_res_net.yaml
@@ -0,0 +1,10 @@
+default:
+ file: depth/MonoDepthResNet
+ encoder:
+ version: 18
+ pretrained: True
+ num_rgb_in: 1
+ decoder:
+ use_skips: True
+ activation: sigmoid
+ num_ch_out: 1
diff --git a/configs/recipes/networks/multi_depth_res_net.yaml b/configs/recipes/networks/multi_depth_res_net.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5704bb61c6eb44367c354a44b1749ed8cee5d4a6
--- /dev/null
+++ b/configs/recipes/networks/multi_depth_res_net.yaml
@@ -0,0 +1,17 @@
+depthformer:
+ file: depth/MultiDepthResNet
+ encoder:
+ version: 18
+ pretrained: True
+ input_shape: [192,640]
+ adaptive_bins: True
+ depth_range: [0.1,100.]
+ depth_bin_range: [0.1,100.0]
+ depth_binning: sid
+ num_depth_bins: 128
+ decoder:
+ use_skips: True
+ use_aux_depth: False
+ activation: sigmoid
+ num_ch_out: 1
+ depth_range: [0.1,100.]
diff --git a/configs/recipes/networks/packnet.yaml b/configs/recipes/networks/packnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a8229c10e8fb31c8fe411105845054a2e32ca16
--- /dev/null
+++ b/configs/recipes/networks/packnet.yaml
@@ -0,0 +1,10 @@
+default:
+ file: depth/PackNet
+ encoder:
+ version: 18
+ pretrained: True
+ num_rgb_in: 1
+ decoder:
+ use_skips: True
+ activation: sigmoid
+ num_ch_out: 1
diff --git a/configs/recipes/networks/pose_net.yaml b/configs/recipes/networks/pose_net.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c180cea4674b2d5af7c49b6b524bfb9403b2ed66
--- /dev/null
+++ b/configs/recipes/networks/pose_net.yaml
@@ -0,0 +1,5 @@
+default:
+ file: pose/PoseNet
+ version: 18
+ pretrained: True
+ num_rgb_in: 2
\ No newline at end of file
diff --git a/configs/recipes/networks/transformer.yaml b/configs/recipes/networks/transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd0da952c5d19cc2bd003dfde69fb3a776613e29
--- /dev/null
+++ b/configs/recipes/networks/transformer.yaml
@@ -0,0 +1,12 @@
+depthformer:
+ file: transformers/MatchModule
+ context_adjustment:
+ expansion_ratio: 4
+ feat_dim: 16
+ num_blocks: 8
+ channel_dim: 128
+ nheads: 8
+ num_attn_layers: 6
+ min_depth: 0.1
+ max_depth: 100.0
+ num_bins: 128
\ No newline at end of file
diff --git a/configs/recipes/optimizers.yaml b/configs/recipes/optimizers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6fb5ebd1ec4652f818a5647e3261dd8a1a55f99b
--- /dev/null
+++ b/configs/recipes/optimizers.yaml
@@ -0,0 +1,7 @@
+adam_20_05:
+ name: Adam
+ lr: 0.0001
+ scheduler:
+ name: StepLR
+ step_size: 20
+ gamma: 0.5
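+# Note: other configs in this repo reference this entry as `recipe: optimizers|adam_20_05`,
+# which appears to follow a <recipe file>|<entry name> convention.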
diff --git a/configs/recipes/save.yaml b/configs/recipes/save.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a90afd4e668fb10ae2bb83de5f0e4e4a9bc6d74
--- /dev/null
+++ b/configs/recipes/save.yaml
@@ -0,0 +1,5 @@
+depth_splitname:
+ folder: /data/vidar/save
+ naming: splitname
+ rgb: [tgt]
+ depth: [viz,npz]
\ No newline at end of file
diff --git a/configs/recipes/wandb.yaml b/configs/recipes/wandb.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..759968b204b14882a9631665b7d4b18ec1d9c34e
--- /dev/null
+++ b/configs/recipes/wandb.yaml
@@ -0,0 +1,5 @@
+default:
+ folder: /data/vidar/wandb
+ entity: tri
+ project: vidar
+ num_validation_logs: 5
diff --git a/configs/recipes/wrapper.yaml b/configs/recipes/wrapper.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7293836d994ee5e8af4c593df953cc4386353278
--- /dev/null
+++ b/configs/recipes/wrapper.yaml
@@ -0,0 +1,10 @@
+default:
+ seed: 42
+ min_epochs: 0
+ max_epochs: 50
+ grad_scaler: False
+ find_unused_parameters: True
+ flip_lr_prob: 0.5
+ validate_first: True
+ validate_flipped: True
+ sync_batch_norm: True
diff --git a/data/generic/scene1/15616458296936490.jpg b/data/generic/scene1/15616458296936490.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a9b8a988f7fea63f4ffbb6088abde4e664f1d8f3
Binary files /dev/null and b/data/generic/scene1/15616458296936490.jpg differ
diff --git a/data/generic/scene1/15616458297936490.jpg b/data/generic/scene1/15616458297936490.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fb0177112b0b33f74199eb742872ddc1498e366e
Binary files /dev/null and b/data/generic/scene1/15616458297936490.jpg differ
diff --git a/data/generic/scene1/15616458298936492.jpg b/data/generic/scene1/15616458298936492.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..17b4ef41aa30deac5c330b0bd076f2e0fba22e81
Binary files /dev/null and b/data/generic/scene1/15616458298936492.jpg differ
diff --git a/data/generic/scene1/15616458299936482.jpg b/data/generic/scene1/15616458299936482.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..32fb198b40b9da2268b44ef815601992a16a4dbb
Binary files /dev/null and b/data/generic/scene1/15616458299936482.jpg differ
diff --git a/data/generic/scene1/15616458300936480.jpg b/data/generic/scene1/15616458300936480.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d225cb4f6abec315cc5b4e3354881ecf3e11a808
Binary files /dev/null and b/data/generic/scene1/15616458300936480.jpg differ
diff --git a/data/masks/ddad/01.png b/data/masks/ddad/01.png
new file mode 100755
index 0000000000000000000000000000000000000000..5c6a1bfcff7e67517af1b44942622ce8e75b9b14
--- /dev/null
+++ b/data/masks/ddad/01.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04c221df55ea2cb011114b9418a114dbb9017c2090405344df505114ea1865bf
+size 4943
diff --git a/data/masks/ddad/05.png b/data/masks/ddad/05.png
new file mode 100755
index 0000000000000000000000000000000000000000..23d82588e15fe7b0119f5a2d564ef9a2604ccf26
--- /dev/null
+++ b/data/masks/ddad/05.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aaa8817848b87232baf54a1c78fda5db0b91c0a2439be666c6243cdbc96a8bc8
+size 8006
diff --git a/data/masks/ddad/06.png b/data/masks/ddad/06.png
new file mode 100755
index 0000000000000000000000000000000000000000..bd7a27a416acfaf90133e46d8ddb97daecd2188f
--- /dev/null
+++ b/data/masks/ddad/06.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9466db07da39e90f1f92a01ce53331a1f7badf3ddd56ea05f9aa564cccdedba0
+size 5831
diff --git a/data/masks/ddad/07.png b/data/masks/ddad/07.png
new file mode 100755
index 0000000000000000000000000000000000000000..89498a2b4b0317388313d0fa7a9eb71d6f24384e
--- /dev/null
+++ b/data/masks/ddad/07.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b99acc91e02eeb16e931e628c63a556d9580a134fab7556ae09787027ab9e22
+size 6886
diff --git a/data/masks/ddad/08.png b/data/masks/ddad/08.png
new file mode 100755
index 0000000000000000000000000000000000000000..25490ecd2e440d292a13ea8651409b9aa75a63cb
--- /dev/null
+++ b/data/masks/ddad/08.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea91d428b0562e1592077b734277fb8634d99c66a123ef33a6ac52a9724a5aca
+size 7199
diff --git a/data/masks/ddad/09.png b/data/masks/ddad/09.png
new file mode 100755
index 0000000000000000000000000000000000000000..39c0af2a34d2349f084d63e9b6343100c36c7ba3
--- /dev/null
+++ b/data/masks/ddad/09.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e48cffda84af21884c0f4e1e1965d4a70e9724fd895a8e936bf16ffded90a5c
+size 6517
diff --git a/demos/display_datasets/config.yaml b/demos/display_datasets/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..606012e38377ceef980a46bc26e09111b557b6ef
--- /dev/null
+++ b/demos/display_datasets/config.yaml
@@ -0,0 +1,34 @@
+
+# To download datasets use the following command:
+# wget -P /data/vidar https://tri-ml-public.s3.amazonaws.com/github/vidar/datasets/{DATASET}.tar
+# Don't forget to untar it afterwards, with:
+# tar xvf /data/vidar/{DATASET}.tar -C /data/vidar
+
+datasets:
+ kitti:
+ name: [KITTI]
+ path: [/data/vidar/KITTI_tiny]
+ split: [kitti_tiny.txt]
+ context: [-1,3]
+ cameras: [[0,1]]
+ labels: [depth,pose]
+    labels_context: [depth,pose]
+ depth_type: [velodyne]
+ vkitti2:
+ name: [VKITTI2]
+ path: [/data/vidar/VKITTI2_tiny]
+ split: [train]
+ context: [-2,2]
+ cameras: [[0,1]]
+ labels: [depth,pose,optical_flow]
+ labels_context: [depth,pose,optical_flow]
+ ddad:
+ name: [Ouroboros]
+ path: [/data/vidar/DDAD_tiny/ddad_tiny.json]
+ split: [train]
+ context: [-1,3]
+ cameras: [[1,5,6,7,8,9]]
+ labels: [depth,pose]
+ labels_context: [depth,pose]
+ depth_type: [lidar]
+ virtual: [False]
\ No newline at end of file
diff --git a/demos/display_datasets/display_datasets.py b/demos/display_datasets/display_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..86fdf6f6191acbe91f8ff8938fcacd776f10e45b
--- /dev/null
+++ b/demos/display_datasets/display_datasets.py
@@ -0,0 +1,14 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import sys
+
+from display.display_sample import display_sample
+from vidar.utils.config import read_config
+from vidar.utils.data import set_random_seed
+from vidar.utils.setup import setup_datasets
+
+set_random_seed(42)
+
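+# Usage (assumes the command-line argument names one of the dataset entries in config.yaml,
+# i.e. kitti, vkitti2 or ddad):
+#   python demos/display_datasets/display_datasets.py kitti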
+cfg = read_config('demos/display_datasets/config.yaml')
+datasets = setup_datasets(cfg.datasets, stack=False)
+display_sample(datasets[0][sys.argv[1]][0][0], flip=False)
diff --git a/demos/run_network/config.yaml b/demos/run_network/config.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..d865f6752785a0c93360f33343c0f8a073fb82ed
--- /dev/null
+++ b/demos/run_network/config.yaml
@@ -0,0 +1,10 @@
+encoder:
+ version: 18
+ pretrained: True
+ num_rgb_in: 1
+decoder:
+ use_skips: True
+ activation: sigmoid
+ num_ch_out: 1
+depth_range: [0.1,200.]
+
diff --git a/demos/run_network/run_network.py b/demos/run_network/run_network.py
new file mode 100755
index 0000000000000000000000000000000000000000..cfcef146ed3229c6e2f52b8cda12dadea761bb68
--- /dev/null
+++ b/demos/run_network/run_network.py
@@ -0,0 +1,16 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.arch.networks.depth.MonoDepthResNet import MonoDepthResNet
+from vidar.utils.config import read_config
+
+### Create network
+
+cfg = read_config('demos/run_network/config.yaml')
+net = MonoDepthResNet(cfg)
+
+### Create dummy input and run network
+
+rgb = torch.randn((2, 3, 128, 128))
+depth = net(rgb=rgb)['depths']
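+
+# Note (assumption about the return format, not enforced by this demo): 'depths' is expected
+# to hold the predicted depth maps for the 2-image dummy batch, one entry per output scale.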
diff --git a/display/display_sample.py b/display/display_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c544f21d486a50c9186b16a9deda3552e72fe048
--- /dev/null
+++ b/display/display_sample.py
@@ -0,0 +1,130 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import json
+
+import numpy as np
+from camviz import BBox3D
+from camviz import Camera as CameraCV
+from camviz import Draw
+
+from vidar.geometry.camera import Camera
+from vidar.geometry.pose import Pose
+from vidar.utils.data import make_batch, fold_batch, modrem
+from vidar.utils.flip import flip_batch
+from vidar.utils.viz import viz_depth, viz_optical_flow, viz_semantic
+
+
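+# Step n entries forward/backward through the sorted keys of `dic`, wrapping around
+# (used below to cycle through the timesteps available for a given label).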
+def change_key(dic, c, n):
+ steps = sorted(dic.keys())
+ return steps[(steps.index(c) + n) % len(steps)]
+
+
+def display_sample(data, flip=False):
+
+ tasks = ['rgb', 'depth', 'fwd_optical_flow', 'bwd_optical_flow', 'semantic']
+ cam_colors = ['red', 'blu', 'gre', 'yel', 'mag', 'cya'] * 100
+
+ data = make_batch(data)
+ if flip:
+ data = flip_batch(data)
+ data = fold_batch(data)
+
+ rgb = data['rgb']
+ intrinsics = data['intrinsics']
+ depth = data['depth']
+ pose = data['pose']
+
+ pose = Pose.from_dict(pose, to_global=True)
+ cam = Camera.from_dict(intrinsics, rgb, pose)
+
+ num_cams = rgb[0].shape[0]
+ wh = rgb[0].shape[-2:][::-1]
+
+ keys = [key for key in tasks if key in data.keys()]
+
+ points = {}
+ for key, val in cam.items():
+ points[key] = cam[key].reconstruct_depth_map(
+ depth[key], to_world=True).reshape(num_cams, 3, -1).permute(0, 2, 1)
+
+ draw = Draw((wh[0] * 4, wh[1] * 3), width=2100)
+ draw.add2DimageGrid('cam', (0.0, 0.0, 0.5, 1.0), n=(3, 2), res=wh)
+ draw.add3Dworld('wld', (0.5, 0.0, 1.0, 1.0), pose=cam[0].Tcw.T[0])
+
+ draw.addTexture('cam', n=num_cams)
+ draw.addBuffer3f('lidar', 1000000, n=num_cams)
+ draw.addBuffer3f('color', 1000000, n=num_cams)
+
+ with_bbox3d = 'bbox3d' in data
+ if with_bbox3d:
+ bbox3d_corners = [[BBox3D(b) for b in bb] for bb in data['bbox3d']['corners']]
+
+ with_pointcache = 'pointcache' in data
+ if with_pointcache:
+ pointcache = np.concatenate([np.concatenate(pp, 0) for pp in data['pointcache']['points']], 0)
+ draw.addBufferf('pointcache', pointcache[:, :3])
+
+ camcv = []
+ for i in range(num_cams):
+ camcv.append({key: CameraCV.from_vidar(val, i) for key, val in cam.items()})
+
+ t, k = 0, 0
+ key = keys[k]
+ change = True
+ color = True
+
+ while draw.input():
+ if draw.SPACE:
+ color = not color
+ change = True
+ if draw.RIGHT:
+ change = True
+ k = (k + 1) % len(keys)
+ while t not in data[keys[k]].keys():
+ k = (k + 1) % len(keys)
+ key = keys[k]
+ if draw.LEFT:
+ change = True
+ k = (k - 1) % len(keys)
+ while t not in data[keys[k]].keys():
+ k = (k - 1) % len(keys)
+ key = keys[k]
+ if draw.UP:
+ change = True
+ t = change_key(data[key], t, 1)
+ while t not in data[keys[k]].keys():
+ t = change_key(data[key], t, 1)
+ if draw.DOWN:
+ change = True
+ t = change_key(data[key], t, -1)
+ while t not in data[keys[k]].keys():
+ t = change_key(data[key], t, -1)
+ if change:
+ change = False
+ for i in range(num_cams):
+ img = data[key][t][i]
+ if key == 'depth':
+ img = viz_depth(img, filter_zeros=True)
+ elif key in ['fwd_optical_flow', 'bwd_optical_flow']:
+ img = viz_optical_flow(img)
+ elif key == 'semantic':
+ ontology = json.load(open('vidar/datasets/ontologies/%s.json' % data['tag'][0]))
+ img = viz_semantic(img, ontology)
+ draw.updTexture('cam%d' % i, img)
+ draw.updBufferf('lidar%d' % i, points[t][i])
+ draw.updBufferf('color%d' % i, data['rgb'][t][i])
+
+ draw.clear()
+ for i in range(num_cams):
+ draw['cam%d%d' % modrem(i, 2)].image('cam%d' % i)
+ draw['wld'].size(1).color(cam_colors[i]).points('lidar%d' % i, ('color%d' % i) if color else None)
+ for cam_key, cam_val in camcv[i].items():
+ clr = cam_colors[i] if cam_key == t else 'gra'
+ tex = 'cam%d' % i if cam_key == t else None
+ draw['wld'].object(cam_val, color=clr, tex=tex)
+ if with_bbox3d:
+ [[draw['wld'].object(b) for b in bb] for bb in bbox3d_corners]
+ if with_pointcache:
+ draw['wld'].color('whi').points('pointcache')
+
+ draw.update(30)
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..60a39bf920b89ddc87f4279c9ec8899d7c9d2224
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,107 @@
+FROM nvidia/cuda:11.3.1-devel-ubuntu18.04
+
+ENV PROJECT=vidar
+ENV LC_ALL=C.UTF-8
+ENV LANG=C.UTF-8
+
+ENV PYTHON_VERSION=3.8
+ENV PYTORCH_VERSION=1.10.0+cu113
+ENV TORCHVISION_VERSION=0.11.1+cu113
+ENV CUDNN_VERSION=8.2.1.32-1+cuda11.3
+ENV NCCL_VERSION=2.9.9-1+cuda11.3
+
+# Install basic libraries
+RUN apt-get update && apt-get install -y \
+ build-essential cmake g++-4.8 git curl docker.io vim wget ca-certificates
+
+# Install python and pip
+RUN apt-get install -y python${PYTHON_VERSION} python3-pip
+RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python && \
+ ln -s /usr/bin/pip3 /usr/bin/pip
+
+# Upgrade pip
+RUN pip install --upgrade pip
+
+# Install pytorch and torchvision
+RUN pip install \
+ torch==${PYTORCH_VERSION} \
+ torchvision==${TORCHVISION_VERSION} \
+ -f https://download.pytorch.org/whl/torch_stable.html
+
+# Install CUDNN and NCCL
+RUN apt-get install -y \
+ libcudnn8=${CUDNN_VERSION} \
+ libnccl2=${NCCL_VERSION}
+
+# Install extra packages (apt-get)
+RUN apt-get install -y \
+ ffmpeg \
+ tmux
+
+# Install extra packages (pip)
+RUN pip install \
+ tqdm==4.61.0 \
+ boto3==1.17.83 \
+ termcolor==1.1.0 \
+ pyyaml==5.4.1 \
+ wandb==0.10.31 \
+ opencv-python==4.5.2.52 \
+ flow_vis==0.1 \
+ matplotlib==3.3.4 \
+ fire==0.4.0 \
+ pyquaternion==0.9.9 \
+ pandas==1.1.5 \
+ xarray==0.16.2 \
+ diskcache==5.2.1 \
+ tenacity==7.0.0 \
+ pycocotools==2.0.2 \
+ awscli==1.19.101 \
+# timm==0.4.9 \
+ ref==0.0.2.2 \
+ positional-encodings==4.0.0 \
+ einops==0.3.2 \
+ wget \
+ ftfy \
+ regex \
+ tqdm
+
+# Install CamViz dependencies
+RUN pip install \
+ pygame==2.0.1 \
+ PyOpenGL==3.1.5 \
+ PyOpenGL-accelerate==3.1.5
+RUN apt-get install -y \
+ mesa-utils \
+ freeglut3-dev \
+ libsdl2-2.0-0 \
+ python-pygame
+
+# Install PyTorch3D
+RUN pip install pytorch3d
+
+# Install CuPY
+RUN pip install cupy
+
+# Install huggingface transformers
+RUN pip install transformers
+
+# Install extras (should be moved to top when stable)
+RUN pip install lpips wget scikit-image pyhocon dotmap path sacremoses filelock huggingface_hub
+RUN pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'
+RUN pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
+
+# Install DGP (dataset utils)
+WORKDIR /workspace
+RUN git clone https://github.com/VitorGuizilini-TRI/dgp.git
+ENV PYTHONPATH="/workspace/dgp:$PYTHONPATH"
+
+# Create workspace folder
+RUN mkdir -p /workspace/experiments
+RUN mkdir -p /workspace/${PROJECT}
+WORKDIR /workspace/${PROJECT}
+# Copy project to workspace folder
+COPY . /workspace/${PROJECT}
+
+# Set environment variables
+ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH"
+ENV PYTHONPATH="/workspace/${PROJECT}/externals/camviz:$PYTHONPATH"
diff --git a/externals/camviz/.gitignore b/externals/camviz/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..1b58d2a8a27eda23d7a749f31cbf8b507da576ba
--- /dev/null
+++ b/externals/camviz/.gitignore
@@ -0,0 +1,3 @@
+*.pyc
+.idea
+__pycache__
diff --git a/externals/camviz/LICENSE.md b/externals/camviz/LICENSE.md
new file mode 100755
index 0000000000000000000000000000000000000000..996346997eb8db1521d947c13a73422a9d77eeb8
--- /dev/null
+++ b/externals/camviz/LICENSE.md
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Toyota Research Institute (TRI)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/externals/camviz/README.md b/externals/camviz/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..9f413a6635ae386ef1417fa58b761040ee0dd81c
--- /dev/null
+++ b/externals/camviz/README.md
@@ -0,0 +1,78 @@
+
+
+
+
+
+## CamViz
+
+[Overview](#overview) // [Installation](#install) // [Demos](#demos) // [License](#license)
+
+
+
+
+
+
+## Overview
+
+**CamViz** is a visualization library developed by the TRI-ML team with the goal of providing an interface for the visualization of monocular depth estimation results, both as depth maps and reconstructed pointclouds. It uses [PyGame](https://www.pygame.org/news) for window display and input management (mouse and keyboard), and [OpenGL](https://www.opengl.org//) for 2D and 3D drawing and rendering. It provides an easy and intuitive way to:
+- Store information as textures and data buffers for efficient display
+- Create 2D environments for image display and 3D environments for pointcloud visualization
+- Use a pinhole camera class that handles most basic geometric operations (reconstruction, projection, transformation between coordinate frames, etc.)
+
+Although **CamViz** works as a standalone library, it was designed specifically to be used in conjunction with other TRI-ML repositories, in particular [PackNet-SFM](https://github.com/tri-ml/packnet-sfm) and [DDAD](https://github.com/tri-ml/ddad). To facilitate integration, it is also provided as a submodule in those repositories.
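+
+As a minimal sketch of the workflow (the screen names, window layout, and random image/pointcloud below are placeholders, not part of the library), a typical **CamViz** loop creates a window, registers textures and data buffers, and redraws them every frame:
+
+```
+import numpy as np
+from camviz import Draw
+
+wh = (640, 480)
+
+# One window: a 2D image screen on the left, a 3D world screen on the right
+draw = Draw((wh[0] * 2, wh[1]), title='CamViz sketch')
+draw.add2Dimage('img', (0.0, 0.0, 0.5, 1.0), res=wh)
+draw.add3Dworld('wld', (0.5, 0.0, 1.0, 1.0))
+
+# Store an image as a texture and a pointcloud as a 3D data buffer
+draw.addTexture('img', np.random.rand(wh[1], wh[0], 3))
+draw.addBufferf('pts', np.random.rand(10000, 3))
+
+# Draw until the window is closed
+while draw.input():
+    draw.clear()
+    draw['img'].image('img')
+    draw['wld'].size(2).color('whi').points('pts')
+    draw.update(30)
+```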
+
+## Installation
+
+We provide a `requirements.txt` file with all the required libraries (tested on Ubuntu 18.04). To start using **CamViz**, all you need to do is:
+
+```
+git clone git@github.com:TRI-ML/camviz.git
+cd camviz
+pip install -r requirements.txt
+PYTHONPATH=$PYTHONPATH:/path/to/camviz
+```
+
+## Demos
+
+The **CamViz** repository comes with a demo that visualizes a predicted monocular pointcloud (pre-computed and provided as part of the repository). We plan to include more demos as new functionality is added, usually tied to scientific publications.
+To run it, type the following command from the root folder:
+
+```
+python demos/pointcloud.py
+```
+
+The output should look like this:
+
+
+
+
+
+From this initial display you can:
+- Zoom in/out on the images with the mouse wheel, and translate within image boundaries.
+- Move freely within the 3D viewer (translation, rotation and zoom in/out) with the mouse.
+- Change color modes with the `enter` key.
+
+## License
+
+The source code is released under the [MIT license](LICENSE.md).
diff --git a/externals/camviz/camviz/__init__.py b/externals/camviz/camviz/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..99f399f2b8f1ebcb3492ba0c760ae3cbad564a36
--- /dev/null
+++ b/externals/camviz/camviz/__init__.py
@@ -0,0 +1,7 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from camviz.draw.draw import Draw
+from camviz.objects.camera import Camera
+from camviz.objects.pose import Pose
+from camviz.objects.bbox2d import BBox2D
+from camviz.objects.bbox3d import BBox3D
diff --git a/externals/camviz/camviz/containers/buffer.py b/externals/camviz/camviz/containers/buffer.py
new file mode 100755
index 0000000000000000000000000000000000000000..34d84064358403247a7f3d81fe3bb4d7c229a6d2
--- /dev/null
+++ b/externals/camviz/camviz/containers/buffer.py
@@ -0,0 +1,107 @@
+
+import numpy as np
+from OpenGL.GL import \
+ glGenBuffers, glBindBuffer, glBufferData, glBufferSubData, GL_ARRAY_BUFFER, GL_STATIC_DRAW
+
+from camviz.utils.utils import numpyf
+from camviz.utils.types import is_tuple, is_list, is_tensor
+from camviz.utils.cmaps import jet
+
+
+class Buffer:
+ """
+ Initialize a data buffer
+
+ Parameters
+ ----------
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ dtype : numpy type (e.g. np.float32)
+ Numpy data type
+ gltype : OpenGL type (e.g. GL_FLOAT)
+ OpenGL data type
+ """
+ def __init__(self, data, dtype, gltype):
+ # Initialize buffer ID and max size
+ self.id, self.max = glGenBuffers(1), 0
+ # Store data types
+ self.dtype, self.gltype = dtype, gltype
+ if is_tuple(data):
+ # If data is a tuple, store dimensions
+ data, (self.n, self.d) = None, data
+ else:
+ # Process data and store dimensions
+ data = self.process(data)
+ self.n, self.d = data.shape[:2]
+ # If size is larger than available, recreate buffer
+ if self.n > self.max:
+ self._create(data)
+
+ @property
+ def size(self):
+ """Get buffer size"""
+ return self.n * self.d * np.dtype(self.dtype).itemsize
+
+ def process(self, data):
+ """
+ Process data buffer to get relevant information
+
+ Parameters
+ ----------
+ data : list or np.array or torch.Tensor
+ Data to be processed
+
+ Returns
+ -------
+ data : np.array
+ Processed data
+ """
+ # If it's a list
+ if is_list(data):
+ data = numpyf(data)
+ # If it's a tensor
+ if is_tensor(data):
+ # If tensor is a grid with 3D coordinates [3,H,W]
+ if data.dim() == 3 and data.shape[0] == 3:
+ data = data.permute(1, 2, 0).reshape(-1, 3)
+ data = data.detach().cpu().numpy()
+ # If it's not the correct type, convert
+ if data.dtype != self.dtype:
+ data = data.astype(self.dtype)
+ # Expand if necessary
+ if len(data.shape) == 1:
+ data = np.expand_dims(data, 1)
+ # Return data
+ return data
+
+ def _create(self, data):
+ """Create a new data buffer"""
+ self.max = self.n
+ glBindBuffer(GL_ARRAY_BUFFER, self.id)
+ glBufferData(GL_ARRAY_BUFFER, self.size, data, GL_STATIC_DRAW)
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ def update(self, data):
+ """Update data buffer"""
+ # Process data
+ data = self.process(data)
+ # Get dimensions or initialize as zero
+ self.n = 0 if data.size == 0 else data.shape[0]
+ # If dimensions are larger than available, recreate
+ if self.n > self.max:
+ self._create(data)
+ # Otherwise
+ else:
+ # Bind buffer and copy data
+ glBindBuffer(GL_ARRAY_BUFFER, self.id)
+ glBufferSubData(GL_ARRAY_BUFFER, 0, self.size, data.astype(self.dtype))
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ def clear(self):
+ """Clear buffer"""
+ self.n = 0
+
+ def updateJET(self, data):
+ """Update buffer using a JET colormap"""
+ self.update(jet(data))
diff --git a/externals/camviz/camviz/containers/texture.py b/externals/camviz/camviz/containers/texture.py
new file mode 100755
index 0000000000000000000000000000000000000000..e64052e868aa0bcee6f484b6a5010c3eae7e7cb4
--- /dev/null
+++ b/externals/camviz/camviz/containers/texture.py
@@ -0,0 +1,144 @@
+
+import cv2
+import numpy as np
+import pygame
+from OpenGL.GL import \
+ glEnable, glDisable, glTexParameterf, \
+ glBindTexture, glGenTextures, glTexImage2D, glTexSubImage2D, \
+ GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_TEXTURE_WRAP_T, GL_TEXTURE_MAG_FILTER, \
+ GL_TEXTURE_MIN_FILTER, GL_REPEAT, GL_NEAREST, GL_RGB, GL_RGBA, GL_UNSIGNED_BYTE
+
+from camviz.utils.types import is_str, is_tensor, is_numpy, is_tuple
+
+
+def load(image):
+ """
+ Load an image as a texture surface
+
+ Parameters
+ ----------
+ image : np.array or str
+ Input image
+ Returns
+ -------
+ surface : pygame surface
+ Output surface for texture buffer
+ """
+ # Return None if image is None
+ if image is None:
+ return None
+ # If image is a string, load from file
+ if is_str(image):
+ surface = pygame.image.load(image)
+ # If it's a numpy array
+ else:
+ # Convert to uint8 and transpose
+ image = image.astype(np.uint8)
+ image = np.transpose(image, (1, 0, 2))
+ # Create surface
+ surface = pygame.surfarray.make_surface(image)
+ # Return surface
+ return pygame.image.tostring(surface, "RGBA", 1)
+
+
+class Texture:
+ """
+ Initialize a texture buffer
+
+ Parameters
+ ----------
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ """
+ def __init__(self, data=None):
+ # Create a new texture ID
+ self.id = glGenTextures(1)
+ # If data exists create texture buffer from it
+ if data is not None:
+ self._create(data)
+ # Otherwise, just store dimensions
+ else:
+ self.wh = None
+
+ def process(self, image):
+ """Process a new image to produce a texture buffer"""
+ # If it's a tensor
+ if is_tensor(image):
+ # Detach and transpose
+ image = image.detach().cpu().numpy()
+ if len(image.shape) == 4:
+ image = image[0]
+ if len(image.shape) == 3 and image.shape[0] == 3:
+ image = image.transpose((1, 2, 0))
+ # If it's a numpy array
+ if is_numpy(image):
+ # Resize to proper shape
+ if image.shape[0] != self.wh[0] or image.shape[1] != self.wh[1]:
+ image = cv2.resize(image, self.wh, interpolation=cv2.INTER_LINEAR)
+ # Squeeze if necessary
+ if len(image.shape) == 3 and image.shape[2] == 1:
+ image = image.squeeze(-1)
+ # Stack to 3 channels if necessary
+ if len(image.shape) == 2:
+ image = np.stack([image] * 3, axis=2)
+ # Return image
+ return image * 255
+
+ def _create(self, data):
+ """Create a texture buffer from data"""
+ # If it's tuple, it only contains dimensions
+ if is_tuple(data):
+ image = None
+ w, h = data[:2]
+ # If it's a tensor, convert to numpy
+ elif is_tensor(data):
+ image = data.detach().cpu().numpy().transpose(1, 2, 0) * 255
+ h, w = data.shape[-2:]
+ # Otherwise, it contains data and dimensions
+ else:
+ image = data * 255
+ h, w = data.shape[:2]
+ # Store dimensions
+ self.wh = (int(w), int(h))
+ # Bind and fill texture
+ glBindTexture(GL_TEXTURE_2D, self.id)
+ glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, self.wh[0], self.wh[1],
+ 0, GL_RGBA, GL_UNSIGNED_BYTE, load(image))
+ glBindTexture(GL_TEXTURE_2D, 0)
+ # Return image
+ return image
+
+ def update(self, image):
+ """Update texture buffer from an image"""
+ # Return None if image is None
+ if image is None:
+ return None
+ # If there are no stored dimensions, create a new texture buffer
+ if self.wh is None:
+ self._create(image)
+ # Otherwise, update texture buffer
+ else:
+ # Process image
+ image = self.process(image)
+ # Bind and update buffer
+ glBindTexture(GL_TEXTURE_2D, self.id)
+ glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, self.wh[0], self.wh[1],
+ GL_RGBA, GL_UNSIGNED_BYTE, load(image))
+ glBindTexture(GL_TEXTURE_2D, 0)
+
+ def bind(self):
+ """Bind and store data in texture buffer"""
+ glEnable(GL_TEXTURE_2D)
+ glBindTexture(GL_TEXTURE_2D, self.id)
+
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S , GL_REPEAT )
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T , GL_REPEAT )
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER , GL_NEAREST)
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER , GL_NEAREST)
+
+ @staticmethod
+ def unbind():
+ """Unbind texture buffer"""
+ glDisable(GL_TEXTURE_2D)
+ glBindTexture(GL_TEXTURE_2D, 0)
diff --git a/externals/camviz/camviz/data/buffer.py b/externals/camviz/camviz/data/buffer.py
new file mode 100755
index 0000000000000000000000000000000000000000..c074ca4f9f776fd37cefd7e063be00badcf9f9ca
--- /dev/null
+++ b/externals/camviz/camviz/data/buffer.py
@@ -0,0 +1,105 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from OpenGL.GL import \
+ glGenBuffers, glBindBuffer, glBufferData, glBufferSubData, GL_ARRAY_BUFFER, GL_STATIC_DRAW
+
+from camviz.utils.utils import numpyf
+from camviz.utils.types import is_tuple, is_list, is_tensor
+from camviz.utils.cmaps import jet
+
+
+class Buffer:
+ """
+ Initialize a data buffer
+
+ Parameters
+ ----------
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ dtype : numpy type (e.g. np.float32)
+ Numpy data type
+ gltype : OpenGL type (e.g. GL_FLOAT)
+ OpenGL data type
+ """
+ def __init__(self, data, dtype, gltype):
+ # Initialize buffer ID and max size
+ self.id, self.max = glGenBuffers(1), 0
+ # Store data types
+ self.dtype, self.gltype = dtype, gltype
+ if is_tuple(data):
+ # If data is a tuple, store dimensions
+ data, (self.n, self.d) = None, data
+ else:
+ # Process data and store dimensions
+ data = self.process(data)
+ self.n, self.d = data.shape[:2]
+ # If size is larger than available, recreate buffer
+ if self.n > self.max:
+ self._create(data)
+
+ @property
+ def size(self):
+ """Get buffer size"""
+ return self.n * self.d * np.dtype(self.dtype).itemsize
+
+ def process(self, data):
+ """
+ Process data buffer to get relevant information
+
+ Parameters
+ ----------
+ data : list or np.array or torch.Tensor
+ Data to be processed
+
+ Returns
+ -------
+ data : np.array
+ Processed data
+ """
+ # If it's a list
+ if is_list(data):
+ data = numpyf(data)
+ # If it's a tensor
+ if is_tensor(data):
+ data = data.detach().cpu().numpy()
+ # If it's not the correct type, convert
+ if data.dtype != self.dtype:
+ data = data.astype(self.dtype)
+ # Expand if necessary
+ if len(data.shape) == 1:
+ data = np.expand_dims(data, 1)
+ # Return data
+ return data
+
+ def _create(self, data):
+ """Create a new data buffer"""
+ self.max = self.n
+ glBindBuffer(GL_ARRAY_BUFFER, self.id)
+ glBufferData(GL_ARRAY_BUFFER, self.size, data, GL_STATIC_DRAW)
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ def update(self, data):
+ """Update data buffer"""
+ # Process data
+ data = self.process(data)
+ # Get dimensions or initialize as zero
+ self.n = 0 if data.size == 0 else data.shape[0]
+ # If dimensions are larger than available, recreate
+ if self.n > self.max:
+ self._create(data)
+ # Otherwise
+ else:
+ # Bind buffer and copy data
+ glBindBuffer(GL_ARRAY_BUFFER, self.id)
+ glBufferSubData(GL_ARRAY_BUFFER, 0, self.size, data.astype(self.dtype))
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ def clear(self):
+ """Clear buffer"""
+ self.n = 0
+
+ def updateJET(self, data):
+ """Update buffer using a JET colormap"""
+ self.update(jet(data))
diff --git a/externals/camviz/camviz/data/texture.py b/externals/camviz/camviz/data/texture.py
new file mode 100755
index 0000000000000000000000000000000000000000..0d8a453e75f6ce8829cd00f13f193aa7c6312ebd
--- /dev/null
+++ b/externals/camviz/camviz/data/texture.py
@@ -0,0 +1,137 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import cv2
+import numpy as np
+import pygame
+from OpenGL.GL import \
+ glEnable, glDisable, glTexParameterf, \
+ glBindTexture, glGenTextures, glTexImage2D, glTexSubImage2D, \
+ GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_TEXTURE_WRAP_T, GL_TEXTURE_MAG_FILTER, \
+ GL_TEXTURE_MIN_FILTER, GL_REPEAT, GL_NEAREST, GL_RGB, GL_RGBA, GL_UNSIGNED_BYTE
+
+from camviz.utils.types import is_str, is_tensor, is_numpy, is_tuple
+
+
+def load(image):
+ """
+ Load an image as a texture surface
+
+ Parameters
+ ----------
+ image : np.array or str
+ Input image
+ Returns
+ -------
+ surface : pygame surface
+ Output surface for texture buffer
+ """
+ # Return None if image is None
+ if image is None:
+ return None
+ # If image is a string, load from file
+ if is_str(image):
+ surface = pygame.image.load(image)
+ # If it's a numpy array
+ else:
+ # Convert to uint8 and transpose
+ image = image.astype(np.uint8)
+ image = np.transpose(image, (1, 0, 2))
+ # Create surface
+ surface = pygame.surfarray.make_surface(image)
+ # Return surface
+ return pygame.image.tostring(surface, "RGBA", 1)
+
+
+class Texture:
+ """
+ Initialize a texture buffer
+
+ Parameters
+ ----------
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ """
+ def __init__(self, data=None):
+ # Create a new texture ID
+ self.id = glGenTextures(1)
+ # If data exists create texture buffer from it
+ if data is not None:
+ self._create(data)
+ # Otherwise, just store dimensions
+ else:
+ self.wh = None
+
+ def process(self, image):
+ """Process a new image to produce a texture buffer"""
+ # If it's a tensor
+ if is_tensor(image):
+ # Detach and transpose
+ image = image.detach().cpu().numpy()
+ if len(image.shape) == 3:
+ image = image.transpose((1, 2, 0))
+ # If it's a numpy array
+ if is_numpy(image):
+ # Resize to proper shape
+ if image.shape[0] != self.wh[0] or image.shape[1] != self.wh[1]:
+ image = cv2.resize(image, self.wh, interpolation=cv2.INTER_LINEAR)
+ # Squeeze if necessary
+ if len(image.shape) == 3 and image.shape[2] == 1:
+ image = image.squeeze(-1)
+ # Stack to 3 channels if necessary
+ if len(image.shape) == 2:
+ image = np.stack([image] * 3, axis=2)
+ # Return image
+ return image
+
+ def _create(self, data):
+ """Create a texture buffer from data"""
+ # If it's tuple, it only contains dimensions
+ if is_tuple(data):
+ image, (w, h) = None, data[:2]
+ # Otherwise, it contains data and dimensions
+ else:
+ image, (h, w) = data, data.shape[:2]
+ # Store dimensions
+ self.wh = (int(w), int(h))
+ # Bind and fill texture
+ glBindTexture(GL_TEXTURE_2D, self.id)
+ glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, self.wh[0], self.wh[1],
+ 0, GL_RGBA, GL_UNSIGNED_BYTE, load(image))
+ glBindTexture(GL_TEXTURE_2D, 0)
+ # Return image
+ return image
+
+ def update(self, image):
+ """Update texture buffer from an image"""
+ # Return None if image is None
+ if image is None:
+ return None
+ # If there are no stored dimensions, create a new texture buffer
+ if self.wh is None:
+ self._create(image)
+ # Otherwise, update texture buffer
+ else:
+ # Process image
+ image = self.process(image)
+ # Bind and update buffer
+ glBindTexture(GL_TEXTURE_2D, self.id)
+ glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, self.wh[0], self.wh[1],
+ GL_RGBA, GL_UNSIGNED_BYTE, load(image))
+ glBindTexture(GL_TEXTURE_2D, 0)
+
+ def bind(self):
+ """Bind and store data in texture buffer"""
+ glEnable(GL_TEXTURE_2D)
+ glBindTexture(GL_TEXTURE_2D, self.id)
+
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S , GL_REPEAT )
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T , GL_REPEAT )
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER , GL_NEAREST)
+ glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER , GL_NEAREST)
+
+ @staticmethod
+ def unbind():
+ """Unbind texture buffer"""
+ glDisable(GL_TEXTURE_2D)
+ glBindTexture(GL_TEXTURE_2D, 0)
diff --git a/externals/camviz/camviz/draw/__init__.py b/externals/camviz/camviz/draw/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..6a0810f6a8884401e5b7063de351170926551ee8
--- /dev/null
+++ b/externals/camviz/camviz/draw/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from camviz.draw.draw import Draw
diff --git a/externals/camviz/camviz/draw/draw.py b/externals/camviz/camviz/draw/draw.py
new file mode 100644
index 0000000000000000000000000000000000000000..928c55b4e655931f637941d6c4e1a5266af1548a
--- /dev/null
+++ b/externals/camviz/camviz/draw/draw.py
@@ -0,0 +1,351 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import time
+
+import numpy as np
+import pygame
+from OpenGL.GL import glReadPixels, glViewport, glScissor, \
+ glClear, glClearColor, glPixelStorei, \
+ GL_BGR, GL_UNSIGNED_BYTE, GL_COLOR_BUFFER_BIT, GL_DEPTH_BUFFER_BIT, \
+ GL_PACK_ALIGNMENT, GL_RGBA
+from PIL import Image, ImageOps
+from pygame.locals import *
+
+from camviz.draw.draw_buffer import drawBuffer
+from camviz.draw.draw_input import DrawInput
+from camviz.draw.draw_texture import DrawTexture
+from camviz.opengl.opengl_colors import setColor
+from camviz.opengl.opengl_shapes import setPointSize, setLineWidth
+from camviz.screen.screen2Dimage import Screen2Dimage
+from camviz.screen.screen3Dworld import Screen3Dworld
+from camviz.utils.types import is_tuple, is_list
+from camviz.utils.utils import labelrc
+
+
+class Draw(DrawInput, DrawTexture, drawBuffer):
+
+ def __init__(self, wh, rc=None, title=None, scale=1.0, width=1600):
+ """
+ Draw class for display visualization
+
+ Parameters
+ ----------
+ wh : tuple (width, height)
+ Window dimensions
+ rc : tuple (row, column)
+ Number of rows and columns (multiplying wh)
+ title : str
+ Window title
+ scale : float
+ Scale for width/height window dimensions
+ width : int
+ If not None, fix the window width to this value in pixels (height is scaled to preserve the aspect ratio)
+ """
+ super().__init__()
+ # Initialize pygame display
+ pygame.init()
+ # Initialize title
+ if title is not None:
+ pygame.display.set_caption(title)
+ # Initialize parameters
+ wh = [int(val * scale) for val in wh]
+ if width is not None:
+ wh[0], wh[1] = width, width * wh[1] // wh[0]
+ self.wh = self.curr_color = self.curr_size = self.curr_width = None
+ self.screens, self.textures, self.buffers = {}, {}, {}
+ self.idx_screen = None
+ # Set size and color
+ self.setSize(wh, rc)
+ self.color('whi').size(1).width(1)
+
+ def setSize(self, wh, rc=None):
+ """Set window size"""
+ # Multiply row and column to produce correct dimensions
+ if rc is not None:
+ wh = (wh[0] * rc[1], wh[1] * rc[0])
+ # Store dimensions
+ self.wh = wh
+ # Initialize display
+ pygame.display.set_mode(self.wh, DOUBLEBUF|OPENGL)
+
+ def __getitem__(self, name):
+ """Get screen from name"""
+ return self.screen(name)
+
+ def scr(self, name):
+ """Get screen from name"""
+ return self.screens[name]
+
+ def tex(self, name):
+ """Get texture from name"""
+ return self.textures[name]
+
+ def buf(self, name):
+ """Get buffer from name"""
+ return self.buffers[name]
+
+ def object(self, obj, *args, **kwargs):
+ """Display object on screen"""
+ obj.display(self, *args, **kwargs)
+
+ def to_image(self):
+ """Convert window into a numpy image"""
+ x, y, w, h = 0, 0, self.wh[0], self.wh[1]
+ data = glReadPixels(x, y, w, h, GL_BGR, GL_UNSIGNED_BYTE)
+ image = Image.frombytes("RGB", (w, h), data)
+ image = image.transpose(Image.FLIP_TOP_BOTTOM)
+ return np.asarray(image)
+
+ def currScreen(self):
+ """Return current screen"""
+ return self.screens[self.idx_screen]
+
+ def addScreen(self, luwh):
+ """
+ Add a new screen to the draw window
+
+ Parameters
+ ----------
+ luwh : tuple (left, up, width, height)
+ Screen dimensions (percentage or pixels)
+
+ Returns
+ -------
+ l, u, w, h : int
+ Dimensions
+ """
+ # Parse dimensions
+ l, u, w, h = luwh
+ w, h = w - l, h - u
+ # Convert percentages to pixels
+ if isinstance(l, float):
+ l = int(l * self.wh[0])
+ if isinstance(u, float):
+ u = int(u * self.wh[1])
+ if isinstance(w, float):
+ w = int(w * self.wh[0])
+ if isinstance(h, float):
+ h = int(h * self.wh[1])
+ # Get screen index and return dimensions
+ self.idx_screen = len(self.screens)
+ return l, u, w, h
+
+ def screen(self, name):
+ """Set which screen will be used for drawing"""
+ self.idx_screen = name
+ # Get parameters
+ d = self.wh[1]
+ l, u, w, h = self.currScreen().luwh
+ u = d - (h + u)
+ # Create viewport and cropping
+ glViewport(l, u, w, h)
+ glScissor(l, u, w, h)
+ # Set background color
+ glClearColor(0.0, 0.0, 0.0, 1.0)
+ # Prepare current screen
+ self.currScreen().prepare()
+ return self
+
+ def add2Dimage(self, name, luwh, res=None):
+ """
+ Add 2D image screen
+
+ Parameters
+ ----------
+ name : str
+ Screen name
+ luwh : tuple (left, up, width, height)
+ Screen dimensions (pixels or percentage)
+ res : tuple (width, height)
+ Screen resolution
+ """
+ # If name is a tuple, create labels for rows and columns
+ if is_tuple(name):
+ name = labelrc(name)
+ # If name is a list, create several screens
+ if is_list(name):
+ for i in range(len(name)):
+ luwh_i = list(luwh)
+ d = (luwh[3] - luwh[1]) / len(name)
+ luwh_i[1] = luwh[1] + i * d
+ luwh_i[3] = luwh_i[1] + d
+ for j in range(len(name[i])):
+ d = (luwh[2] - luwh[0]) / len(name[i])
+ luwh_i[0] = luwh[0] + j * d
+ luwh_i[2] = luwh_i[0] + d
+ self.add2Dimage(name[i][j], luwh_i, res)
+ # Else, create a single screen
+ else:
+ self.screens[name] = Screen2Dimage(self.addScreen(luwh), res)
+
+ def add2DimageRow(self, name, luwh, n, res=None):
+ """
+ Add row with multiple 2D image screens
+
+ Parameters
+ ----------
+ name : str
+ Screen name
+ luwh : tuple (left, up, width, height)
+ Screen dimensions (pixels or percentage)
+ n : int
+ Number of columns in the row
+ res : tuple (width, height)
+ Screen resolution
+ """
+ for i in range(n):
+ # Copy dimension vector
+ luwh_i = [val for val in luwh]
+ # Offset rows
+ luwh_i[0] = luwh[0] + (i / n) * (luwh[2] - luwh[0])
+ luwh_i[2] = luwh[0] + ((i + 1) / n) * (luwh[2] - luwh[0])
+ # Create 2D image screen
+ self.add2Dimage('%s%d' % (name, i), luwh_i, res)
+
+ def add2DimageCol(self, name, luwh, n, res=None):
+ """
+ Add column with multiple 2D image screens
+
+ Parameters
+ ----------
+ name : str
+ Screen name
+ luwh : tuple (left, up, width, height)
+ Screen dimensions (pixels or percentage)
+ n : int
+ Number of rows in the column
+ res : tuple (width, height)
+ Screen resolution
+ """
+ for i in range(n):
+ # Copy dimension vector
+ luwh_i = [val for val in luwh]
+ # Offset columns
+ luwh_i[1] = luwh[1] + (i / n) * (luwh[3] - luwh[1])
+ luwh_i[3] = luwh[1] + ((i + 1) / n) * (luwh[3] - luwh[1])
+ # Create 2D image screen
+ self.add2Dimage('%s%d' % (name, i), luwh_i, res)
+
+ def add2DimageGrid(self, name, luwh, n, res=None):
+ """
+ Add grid with multiple 2D image screens
+
+ Parameters
+ ----------
+ name : str
+ Screen name
+ luwh : tuple (left, up, width, height)
+ Screen dimensions (pixels or percentage)
+ n : tuple (int, int)
+ Number of rows and columns in the grid
+ res : tuple (width, height)
+ Screen resolution
+ """
+ for i in range(n[0]):
+ for j in range(n[1]):
+ # Copy dimension vector
+ luwh_i = [val for val in luwh]
+ # Offset columns
+ luwh_i[1] = luwh[1] + (i / n[0]) * (luwh[3] - luwh[1])
+ luwh_i[3] = luwh[1] + ((i + 1) / n[0]) * (luwh[3] - luwh[1])
+ # Offset rows
+ luwh_i[0] = luwh[0] + (j / n[1]) * (luwh[2] - luwh[0])
+ luwh_i[2] = luwh[0] + ((j + 1) / n[1]) * (luwh[2] - luwh[0])
+ # Create 2D image screen
+ self.add2Dimage('%s%d%d' % (name, i, j), luwh_i, res)
+
+ def add3Dworld(self, name, luwh=(0.0, 0.0, 1.0, 1.0), **kwargs):
+ """
+ Add a 3D world screen
+
+ Parameters
+ ----------
+ name : str
+ Screen name
+ luwh : tuple
+ Screen dimensions (left, up, width, height), in pixels or percentage
+ """
+ # If name is a tuple, create labels for rows and columns
+ if is_tuple(name):
+ name = labelrc(name)
+ # If name is a list, create several screens
+ if is_list(name):
+ for i in range(len(name)):
+ luwh_i = list(luwh)
+ d = (luwh[3] - luwh[1]) / len(name)
+ luwh_i[1] = luwh[1] + i * d
+ luwh_i[3] = luwh_i[1] + d
+ for j in range(len(name[i])):
+ d = (luwh[2] - luwh[0]) / len(name[i])
+ luwh_i[0] = luwh[0] + j * d
+ luwh_i[2] = luwh_i[0] + d
+ self.add3Dworld(name[i][j], luwh_i, **kwargs)
+ # Else, create a single screen
+ else:
+ self.screens[name] = Screen3Dworld(self.addScreen(luwh), **kwargs)
+
+ @staticmethod
+ def clear():
+ """Clear window"""
+ glClear(GL_COLOR_BUFFER_BIT|GL_DEPTH_BUFFER_BIT)
+
+ def populate(self, data, fit=False):
+ """
+ Populate screens with information from a dictionary
+
+ Parameters
+ ----------
+ data : dict
+ Dictionary with information for each screen
+ fit : bool
+ If true, resize screens to fit the data showing
+ """
+ for key, val in data.items():
+ # If it's a tuple, use key and val from positions 0 and 1
+ if is_tuple(val):
+ self.screen(key).image(val[0], data=val[1], fit=fit)
+ # Else, use key and val directly
+ else:
+ self.screen(key).image(key, data=val, fit=fit)
+
+ def size(self, n):
+ """Set point size"""
+ self.curr_size = n
+ setPointSize(n)
+ return self
+
+ def width(self, n):
+ """Set line width"""
+ self.curr_width = n
+ setLineWidth(n)
+ return self
+
+ def color(self, clr):
+ """Set plot color"""
+ self.curr_color = clr
+ setColor(clr)
+ return self
+
+ def setCSW(self, csw):
+ """Set color, size and width (CSW) simultaneously"""
+ self.color(csw[0]).size(csw[1]).width(csw[2])
+
+ def getCSW(self):
+ """Get color, size and width (CSW) information"""
+ return self.curr_color, self.curr_size, self.curr_width
+
+ @staticmethod
+ def halt(n):
+ """Stop for n milliseconds"""
+ time.sleep(n/1000)
+
+ def save(self, filename):
+ """Save window as an image file"""
+ # Get dimensions
+ width, height = self.wh
+ # Store window information in a variable
+ glPixelStorei(GL_PACK_ALIGNMENT, 1)
+ data = glReadPixels(0, 0, width, height, GL_RGBA, GL_UNSIGNED_BYTE)
+ image = Image.frombytes("RGBA", (width, height), data)
+ image = ImageOps.flip(image) # glReadPixels returns rows bottom-up, so flip vertically
+ # Save image and halt for a bit
+ image.save(filename, 'PNG')
+ self.halt(1000)
diff --git a/externals/camviz/camviz/draw/draw_buffer.py b/externals/camviz/camviz/draw/draw_buffer.py
new file mode 100755
index 0000000000000000000000000000000000000000..1a833c34f7e2ab297af7a9abf1aaa46850f4f0b2
--- /dev/null
+++ b/externals/camviz/camviz/draw/draw_buffer.py
@@ -0,0 +1,241 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from OpenGL.GL import glEnableClientState, glDisableClientState, \
+ glPolygonMode, glVertexPointer, glBindBuffer, glColorPointer, \
+ glDrawArrays, glDrawElements, glBegin, glEnd, glVertex2fv, glVertex3fv, \
+ GL_ARRAY_BUFFER, GL_FILL, GL_ELEMENT_ARRAY_BUFFER, \
+ GL_FLOAT, GL_UNSIGNED_INT, GL_POINTS, GL_FRONT_AND_BACK, GL_COLOR_ARRAY, \
+ GL_VERTEX_ARRAY, GL_LINE, GL_LINES, GL_LINE_LOOP, GL_LINE_STRIP, GL_QUADS, GL_TRIANGLES
+
+from camviz.containers.buffer import Buffer
+from camviz.opengl.opengl_shapes import drawConnects, drawMatches, drawAxis, drawEllipse
+from camviz.utils.utils import grid_idx
+from camviz.utils.types import is_str, is_list, is_tuple, is_int
+from camviz.utils.cmaps import jet
+
+
+class drawBuffer:
+ """Draw subclass containing data buffer methods"""
+ def __init__(self):
+ pass
+
+ def addBuffer(self, name, data, dtype, gltype, n=None):
+ """
+ Create a new data buffer
+
+ Parameters
+ ----------
+ name : str
+ Buffer name
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ dtype : numpy type (e.g. np.float32)
+ Numpy data type
+ gltype : OpenGL type (e.g. GL_FLOAT)
+ OpenGL data type
+ n : int or tuple
+ Number of textures to be added
+ """
+ # If it's a list, create one buffer for each item
+ if is_list(name):
+ for i in range(len(name)):
+ self.addBuffer(name[i], data[i] if is_list(data) else data, dtype, gltype)
+ # Otherwise, create a single buffer
+ else:
+ if n is not None:
+ if is_tuple(n):
+ for i in range(n[0]):
+ for j in range(n[1]):
+ self.buffers['%s%d%d' % (name, i, j)] = Buffer(data, dtype, gltype)
+ elif is_int(n):
+ for i in range(n):
+ self.buffers['%s%d' % (name, i)] = Buffer(data, dtype, gltype)
+ self.buffers[name] = Buffer(data, dtype, gltype)
+
+ def addBufferf(self, name, data=0):
+ """Create a buffer with float32 values (2D or 3D is determined from data)"""
+ self.addBuffer(name, data, np.float32, GL_FLOAT)
+
+ def addBufferu(self, name, data=0):
+ """Create a buffer with unsigned 32 values (2D or 3D is determined from data)"""
+ self.addBuffer(name, data, np.uint32, GL_UNSIGNED_INT)
+
+ def addBuffer2f(self, name, data=0, n=None):
+ """Create a 2D empty buffer with float32 values"""
+ self.addBuffer(name, (data, 2), np.float32, GL_FLOAT, n)
+
+ def addBuffer3f(self, name, data=0, n=None):
+ """Create a 3D empty buffer with float32 values"""
+ self.addBuffer(name, (data, 3), np.float32, GL_FLOAT, n)
+
+ def addbufferIDX(self, name, data=0):
+ """Create an index buffer for shape drawing"""
+ self.addBufferu(name, grid_idx(data))
+
+ def addBufferJET(self, name, data=0):
+ """Create a JET colormap buffer from data"""
+ self.addBufferf(name, jet(data))
+
+ def addBuffer3JET(self, name, data=0):
+ """Create an empty 3D colormap buffer from data"""
+ self.addBuffer3f(name, data)
+
+ def updBufferf(self, name, data):
+ """Update a buffer with float32 values"""
+ self.buffers[name].update(data)
+
+ def clrBuffer(self, name):
+ """Clear a buffer"""
+ self.buffers[name].clear()
+
+ def points(self, *args, **kwargs):
+ """Draw points"""
+ return self._drawSomething(GL_POINTS, *args, **kwargs)
+
+ def lines( self, *args, **kwargs):
+ """Draw lines"""
+ return self._drawSomething(GL_LINES, *args, **kwargs)
+
+ def strips(self, *args, **kwargs):
+ """Draw strips (connecting adjacent vertices)"""
+ return self._drawSomething(GL_LINE_STRIP, *args, **kwargs)
+
+ def loop( self, *args, **kwargs):
+ """Draw loops (strips with last vertices connected)"""
+ return self._drawSomething(GL_LINE_LOOP, *args, **kwargs)
+
+ def quads( self, *args, **kwargs):
+ """Draw quadratics"""
+ return self._drawSomething(GL_QUADS, *args, **kwargs)
+
+ def tris( self, *args, **kwargs):
+ """Draw triangles"""
+ return self._drawSomething(GL_TRIANGLES, *args, **kwargs)
+
+ def grid( self, *args, **kwargs):
+ """Draw a grid"""
+ return self._drawSomething(GL_QUADS, *args, **kwargs)
+
+ def matches(self, *args, **kwargs):
+ """Draw matches from two sets of points"""
+ drawMatches(*args, **kwargs)
+ return self
+
+ def connects(self, *args, **kwargs):
+ """Draw a connection from one point to many points"""
+ drawConnects(*args, **kwargs)
+ return self
+
+ def axis(self, *args, **kwargs):
+ """Draw coordinate axis"""
+ drawAxis(*args, **kwargs)
+ return self
+
+ def ellipse(self, *args, **kwargs):
+ """Draw ellipse"""
+ drawEllipse(*args, **kwargs)
+ return self
+
+ def _drawSomething(self, shape, *args, **kwargs):
+ """
+ Base function for shape drawing
+
+ Parameters
+ ----------
+ shape : opengl shape
+ OpenGL shape to draw (e.g. GL_POINTS)
+ args : args
+ Extra draw arguments
+ kwargs : kwargs
+ Extra draw arguments
+ """
+ # If it's a string, draw buffer
+ if is_str(args[0]):
+ return self._drawBuffer(shape, *args, **kwargs)
+ # Otherwise, copy data and draw
+ else:
+ return self._drawBase(shape, *args, **kwargs)
+
+ def _drawBuffer(self, shape, vert, color=None, idx=None, wire=None):
+ """
+ Draw from a buffer
+
+ Parameters
+ ----------
+ shape : opengl type
+ OpenGL shape to draw (e.g. GL_POINTS)
+ vert : buffer
+ Buffer with vertices
+ color : buffer
+ Buffer with colors
+ idx : buffer
+ Buffer with indexes
+ wire : buffer
+ Buffer with wire (color and width)
+ """
+ # If wire is available
+ if wire is not None:
+ csw = self.getCSW()
+ self.width(wire[1])
+ color_wire = wire[0] if wire[0] in self.buffers else None
+ if wire[0] not in self.buffers:
+ self.color(wire[0])
+ glPolygonMode(GL_FRONT_AND_BACK, GL_LINE)
+ self._drawBuffer(shape, vert, color=color_wire, idx=idx, wire=None)
+ glPolygonMode(GL_FRONT_AND_BACK, GL_FILL)
+ self.setCSW(csw)
+ # If vert is available
+ if vert is not None:
+ vert = self.buffers[vert]
+ glEnableClientState(GL_VERTEX_ARRAY)
+ glBindBuffer(GL_ARRAY_BUFFER, vert.id)
+ glVertexPointer(vert.d, vert.gltype, 0, None)
+ # If color is available
+ if color is not None:
+ color = self.buffers[color]
+ glEnableClientState(GL_COLOR_ARRAY)
+ glBindBuffer(GL_ARRAY_BUFFER, color.id)
+ glColorPointer(color.d, color.gltype, 0, None)
+ # If idx is available
+ if idx is None:
+ glDrawArrays(shape, 0, vert.n)
+ else:
+ idx = self.buffers[idx]
+ glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, idx.id)
+ glDrawElements(shape, idx.n, idx.gltype, None)
+ # Bind buffers
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+ # Unbind vertices
+ if vert is not None:
+ glDisableClientState(GL_VERTEX_ARRAY)
+ # Unbind colors
+ if color is not None:
+ glDisableClientState(GL_COLOR_ARRAY)
+ # Return self
+ return self
+
+ def _drawBase(self, shape, verts):
+ """
+ Draw a shape by copying data (very slow)
+
+ Parameters
+ ----------
+ shape : opengl shape
+ OpenGL shape to draw (e.g. GL_POINTS)
+ verts : np.array
+ Vertices to draw
+ """
+ # If there are no vertices, do nothing
+ if len(verts) == 0:
+ return self
+ # Select 2D or 3D vertex draw function
+ glVertex = glVertex2fv if len(verts[0]) == 2 else glVertex3fv
+ # Draw vertices
+ glBegin(shape)
+ for vert in verts:
+ glVertex(vert)
+ glEnd()
+ # Return self
+ return self
diff --git a/externals/camviz/camviz/draw/draw_input.py b/externals/camviz/camviz/draw/draw_input.py
new file mode 100755
index 0000000000000000000000000000000000000000..28ec09d9756fa92b54f86aec8c07ad19ffb8f88a
--- /dev/null
+++ b/externals/camviz/camviz/draw/draw_input.py
@@ -0,0 +1,322 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import pygame
+
+
+class DrawInput:
+ """Draw subclass containing input controls"""
+ def __init__(self):
+ # Initialize basic keys
+ self.UP = self.DOWN = self.LEFT = self.RIGHT = False
+ self.RCTRL = self.LCTRL = self.RALT = self.LALT = self.RSHIFT = self.LSHIFT = False
+ self.SPACE = self.RETURN = self.PGUP = self.PGDOWN = False
+ # Initialize letter keys
+ self.KEY_Q = self.KEY_W = self.KEY_E = self.KEY_R = self.KEY_T = False
+ self.KEY_A = self.KEY_S = self.KEY_D = self.KEY_F = self.KEY_G = False
+ self.KEY_Z = self.KEY_X = self.KEY_C = self.KEY_V = self.KEY_B = False
+ # Initialize number keys
+ self.KEY_0 = self.KEY_1 = self.KEY_2 = self.KEY_3 = self.KEY_4 = self.KEY_5 = \
+ self.KEY_6 = self.KEY_7 = self.KEY_8 = self.KEY_9 = False
+ # Initialize mouse keys
+ self.mouse_pos = self.motion_type = None
+ self.tmp_screen, self.tmp_focus = None, False
+ self.mouse_down = False
+
+ def change_keys(self, key, flag):
+ """
+ Change key to flag value
+
+ Parameters
+ ----------
+ key : pygame key
+ Key in consideration
+ flag : bool
+ State to set the key [True or False]
+ """
+ if key == pygame.K_UP:
+ self.UP = flag
+ if key == pygame.K_DOWN:
+ self.DOWN = flag
+ if key == pygame.K_LEFT:
+ self.LEFT = flag
+ if key == pygame.K_RIGHT:
+ self.RIGHT = flag
+
+ if key == pygame.K_RCTRL:
+ self.RCTRL = flag
+ if key == pygame.K_LCTRL:
+ self.LCTRL = flag
+ if key == pygame.K_RALT:
+ self.RALT = flag
+ if key == pygame.K_LALT:
+ self.LALT = flag
+ if key == pygame.K_RSHIFT:
+ self.RSHIFT = flag
+ if key == pygame.K_LSHIFT:
+ self.LSHIFT = flag
+ if key == pygame.K_PAGEUP:
+ self.PGUP = flag
+ if key == pygame.K_PAGEDOWN:
+ self.PGDOWN = flag
+
+ if key == pygame.K_SPACE:
+ self.SPACE = flag
+ if key == pygame.K_RETURN:
+ self.RETURN = flag
+
+ if key == pygame.K_q:
+ self.KEY_Q = flag
+ if key == pygame.K_w:
+ self.KEY_W = flag
+ if key == pygame.K_e:
+ self.KEY_E = flag
+ if key == pygame.K_r:
+ self.KEY_R = flag
+ if key == pygame.K_t:
+ self.KEY_T = flag
+ if key == pygame.K_a:
+ self.KEY_A = flag
+ if key == pygame.K_s:
+ self.KEY_S = flag
+ if key == pygame.K_d:
+ self.KEY_D = flag
+ if key == pygame.K_f:
+ self.KEY_F = flag
+ if key == pygame.K_g:
+ self.KEY_G = flag
+
+ if key == pygame.K_0:
+ self.KEY_0 = flag
+ if key == pygame.K_1:
+ self.KEY_1 = flag
+ if key == pygame.K_2:
+ self.KEY_2 = flag
+ if key == pygame.K_3:
+ self.KEY_3 = flag
+ if key == pygame.K_4:
+ self.KEY_4 = flag
+ if key == pygame.K_5:
+ self.KEY_5 = flag
+ if key == pygame.K_6:
+ self.KEY_6 = flag
+ if key == pygame.K_7:
+ self.KEY_7 = flag
+ if key == pygame.K_8:
+ self.KEY_8 = flag
+ if key == pygame.K_9:
+ self.KEY_9 = flag
+
+ def input(self):
+ """
+ Parse keyboard and mouse input
+ """
+ events = pygame.event.get() # Get events
+ pos = pygame.mouse.get_pos() # Get mouse position
+
+ # If mouse is not pressing down
+ if self.mouse_down is False:
+ # Get current screen based on mouse position
+ screen = None
+ for key, scr in self.screens.items():
+ if scr.inside(pos):
+ screen = scr
+ break
+ # Set screen focus based on mouse position
+ focus = screen is not None and pygame.mouse.get_focused()
+ if not focus:
+ self.mouse_pos = None
+ else:
+ # Use stored screen and focus
+ screen, focus = self.tmp_screen, self.tmp_focus
+
+ # For each event
+ for event in events:
+ # If x button is pressed
+ if event.type == pygame.QUIT:
+ return False
+ # If key has been pressed down
+ if event.type == pygame.KEYDOWN:
+ # If escape has been pressed, exit
+ if event.key == pygame.K_ESCAPE:
+ return False
+ # If p has been pressed, return virtual camera pose
+ if event.key == pygame.K_p:
+ if self.currScreen().viewer is not None:
+ print('(%7.5f, %7.5f, %7.5f, %1.5f, %1.5f, %1.5f, %1.5f)' %
+ self.currScreen().viewer.current7())
+ # Change key to pressed
+ self.change_keys(event.key, True)
+ # If key has been released
+ if event.type == pygame.KEYUP:
+ self.change_keys(event.key, False)
+ # If mouse button has been pressed down
+ if event.type == pygame.MOUSEBUTTONDOWN:
+ self.mouse_down = True
+ self.tmp_screen, self.tmp_focus = screen, focus
+ # If it's a 3D world screen
+ if focus and screen.mode == '3D_WORLD':
+ if event.button == 4: # Wheel forward
+ if self.RALT: # Going for rotation in Z
+ screen.viewer.rotateZ(5.0 if self.RCTRL else 0.05 if self.LCTRL else 0.5)
+ else: # Going for translation in Z
+ screen.viewer.translateZ(+(5.0 if self.RCTRL else 0.2 if self.LCTRL else 1.0))
+ if event.button == 5: # Wheel backwards
+ if self.RALT: # Going for rotation in Z
+ screen.viewer.rotateZ(-5.0 if self.RCTRL else -0.05 if self.LCTRL else -0.5)
+ else: # Going for translation in Z
+ screen.viewer.translateZ(-(5.0 if self.RCTRL else 0.2 if self.LCTRL else 1.0))
+ if event.button == 1: # Left button
+ self.motion_type, self.mouse_pos = 1, pos
+ if event.button == 3: # Right button
+ self.motion_type, self.mouse_pos = 3, pos
+ if event.button == 2: # Wheel press
+ screen.reset()
+ # If it's a 2D image screen
+ if focus and screen.mode == '2D_IMAGE':
+ if event.button == 1: # Left button
+ self.motion_type, self.mouse_pos = 1, pos
+ if event.button == 2: # Wheel press
+ screen.res = list(screen.orig_res)
+ else:
+ # Change resolution
+ rel = [(pos[0] - screen.luwh[0]) / screen.luwh[2] * screen.res[2] + screen.res[0],
+ (pos[1] - screen.luwh[1]) / screen.luwh[3] * screen.res[3] + screen.res[1]]
+ # Get speed multiplier
+ mlt = 1.20 if self.RSHIFT else 1.05
+ # Wheel forward
+ if event.button == 4:
+ if screen.res[0] < 0.95 * screen.res[2] and \
+ screen.res[1] < 0.95 * screen.res[3]:
+ screen.res[2] = (screen.res[0] + screen.res[2] - rel[0])/mlt + rel[0] - screen.res[0]
+ screen.res[0] = (screen.res[0] - rel[0])/mlt + rel[0]
+ screen.res[3] = (screen.res[1] + screen.res[3] - rel[1])/mlt + rel[1] - screen.res[1]
+ screen.res[1] = (screen.res[1] - rel[1])/mlt + rel[1]
+ # Wheel backwards
+ elif event.button == 5:
+ screen.res[2] = (screen.res[0] + screen.res[2] - rel[0])*mlt + rel[0] - screen.res[0]
+ screen.res[0] = (screen.res[0] - rel[0])*mlt + rel[0]
+ screen.res[3] = (screen.res[1] + screen.res[3] - rel[1])*mlt + rel[1] - screen.res[1]
+ screen.res[1] = (screen.res[1] - rel[1])*mlt + rel[1]
+ # Change resolution
+ screen.res[0] = max(screen.res[0], screen.orig_res[0])
+ screen.res[1] = max(screen.res[1], screen.orig_res[1])
+ screen.res[2] = min(screen.res[2], screen.orig_res[2])
+ screen.res[3] = min(screen.res[3], screen.orig_res[3])
+ # If mouse button has been released
+ if event.type == pygame.MOUSEBUTTONUP:
+ self.mouse_down = False
+ self.mouse_pos = None
+ # If screen has focus
+ if focus:
+ # If it's a 3D world screen
+ if screen.mode == '3D_WORLD':
+ # Get new mouse position
+ if self.mouse_pos is not None:
+ dX = pos[0] - self.mouse_pos[0]
+ dY = pos[1] - self.mouse_pos[1]
+ self.mouse_pos = pos
+ # If left button
+ if self.motion_type == 1:
+ mlin = 1.00 if self.RCTRL else 0.02 if self.LCTRL else 0.10
+ screen.viewer.translateX(- dX * mlin)
+ screen.viewer.translateY(- dY * mlin)
+ # If right button
+ elif self.motion_type == 3:
+ mang = 0.25 if self.RCTRL else 0.01 if self.LCTRL else 0.05
+ if screen.ref == 'cam': # Rotation in camera reference
+ screen.viewer.rotateX(- dY * mang)
+ screen.viewer.rotateY(+ dX * mang)
+ elif screen.ref == 'lidar': # Rotation in lidar reference
+ screen.viewer.rotateX(- dY * mang)
+ screen.viewer.rotateZ(- dX * mang)
+ # If it's a 2D image screen
+ elif screen.mode == '2D_IMAGE':
+ # Get new mouse position
+ if self.mouse_pos is not None:
+ mlin = 5.00 if self.RCTRL else 1.00
+ dX = pos[0] - self.mouse_pos[0]
+ dY = pos[1] - self.mouse_pos[1]
+ self.mouse_pos = pos
+ # Resize and move screen center around
+ screen.res[0] -= dX * mlin
+ screen.res[2] -= dX * mlin
+ if screen.res[0] < screen.orig_res[0] or \
+ screen.res[2] > screen.orig_res[2]:
+ screen.res[0] += dX * mlin
+ screen.res[2] += dX * mlin
+ screen.res[1] -= dY * mlin
+ screen.res[3] -= dY * mlin
+ if screen.res[1] < screen.orig_res[1] or \
+ screen.res[3] > screen.orig_res[3]:
+ screen.res[1] += dY * mlin
+ screen.res[3] += dY * mlin
+ # Continue and return True
+ return True
+
+ @staticmethod
+ def update(wait):
+ """Update window after every wait milisseconds"""
+ pygame.display.flip()
+ pygame.time.wait(wait)
+
+ @staticmethod
+ def control(obj):
+ """Control an object with keyboard"""
+ # Get velocity values
+ dlin, dang = 0.2, 5.0
+ # Get keys
+ keys = pygame.key.get_pressed()
+ # Check for changes
+ change = False
+ # Translate in +Z if UP
+ if keys[pygame.K_UP]:
+ change = True
+ obj.translateZ(+dlin)
+ # Translate in -Z if DOWN
+ if keys[pygame.K_DOWN ]:
+ change = True
+ obj.translateZ(-dlin)
+ # Translate in -X if LEFT
+ if keys[pygame.K_LEFT]:
+ change = True
+ obj.translateX(-dlin)
+ # Translate in +X if RIGHT
+ if keys[pygame.K_RIGHT]:
+ change = True
+ obj.translateX(+dlin)
+ # Translate in -Y if Q
+ if keys[pygame.K_q]:
+ change = True
+ obj.translateY(-dlin)
+ # Translate in +Y if A
+ if keys[pygame.K_a]:
+ change = True
+ obj.translateY(+dlin)
+ # Rotate in +Y if S
+ if keys[pygame.K_s]:
+ change = True
+ obj.rotateY(+dang)
+ # Rotate in -Y if F
+ if keys[pygame.K_f]:
+ change = True
+ obj.rotateY(-dang)
+ # Rotate in -X if E
+ if keys[pygame.K_e]:
+ change = True
+ obj.rotateX(-dang)
+ # Rotate in +X if D
+ if keys[pygame.K_d]:
+ change = True
+ obj.rotateX(+dang)
+ # Rotate in +Z if W
+ if keys[pygame.K_w]:
+ change = True
+ obj.rotateZ(+dang)
+ # Rotate in -Z if R
+ if keys[pygame.K_r]:
+ change = True
+ obj.rotateZ(-dang)
+ # Return change value
+ return change
diff --git a/externals/camviz/camviz/draw/draw_texture.py b/externals/camviz/camviz/draw/draw_texture.py
new file mode 100755
index 0000000000000000000000000000000000000000..c64ce3afe0c045bb33b9e3286dfedb89b30d5a9e
--- /dev/null
+++ b/externals/camviz/camviz/draw/draw_texture.py
@@ -0,0 +1,99 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from OpenGL.GL import \
+ glTexCoord2f, glBegin, glEnd, glVertex2fv, glVertex3fv, \
+ GL_QUADS
+
+from camviz.containers.texture import Texture
+from camviz.opengl.opengl_colors import White
+from camviz.utils.utils import labelrc, numpyf
+from camviz.utils.types import is_tuple, is_list, is_int
+
+class DrawTexture:
+ """Draw subclass containing texture methods"""
+ def addTexture(self, name, data=None, n=None):
+ """
+ Create a new texture buffer
+
+ Parameters
+ ----------
+ name : str
+ Buffer name
+ data : np.array [N,D] or tuple (n,d)
+ Data to be added to the buffer
+ If it's a tuple, create a data buffer of that size
+ n : int or tuple
+ Number of textures to be added
+ """
+ # If it's a tuple, create individual names for each texture
+ if is_tuple(name):
+ name = labelrc(name)
+ # If it's a list, add each item to its own texture
+ if is_list(name):
+ for i in range(len(name)):
+ self.addTexture(name[i], data[i] if is_list(data) else data)
+ # Otherwise, create a single texture from data
+ else:
+ if n is not None:
+ if is_tuple(n):
+ for i in range(n[0]):
+ for j in range(n[1]):
+ self.textures['%s%d%d' % (name, i, j)] = Texture(data)
+ elif is_int(n):
+ for i in range(n):
+ self.textures['%s%d' % (name, i)] = Texture(data)
+ self.textures[name] = Texture(data)
+
+ def updTexture(self, name, data):
+ """Update texture with new data"""
+ self.textures[name].update(data)
+
+ def image(self, name, data=None, verts=None, fit=False):
+ """
+ Display a texture on screen
+
+ Parameters
+ ----------
+ name : str
+ Name of the texture
+ data : np.array
+ Update texture with new data before displaying
+ verts : np.array
+ Vertices for the texture borders on screen
+ fit : bool
+ If true, resize screen to fit new image
+ """
+ # If no name is provided, return None
+ if name is None:
+ return
+ # Get texture ID from name
+ tex = self.textures[name]
+ # Resize screen to fit the texture if necessary
+ if fit is True:
+ self.currScreen().setRes(tex.wh)
+ # If data is provided, update texture first
+ if data is not None:
+ tex.update(data)
+ # If verts is not provided, create them based on screen dimension
+ if verts is None:
+ verts = [[tex.wh[0], 0.0 ], [tex.wh[0], tex.wh[1]],
+ [ 0.0 , tex.wh[1]], [ 0.0 , 0.0 ]]
+ verts = numpyf(verts)
+ # Draw texture
+ White()
+ tex.bind()
+ glBegin(GL_QUADS)
+ glVertex = glVertex2fv if len(verts[0]) == 2 else glVertex3fv
+ glTexCoord2f(1.0, 1.0)
+ glVertex(verts[0])
+ glTexCoord2f(1.0, 0.0)
+ glVertex(verts[1])
+ glTexCoord2f(0.0, 0.0)
+ glVertex(verts[2])
+ glTexCoord2f(0.0, 1.0)
+ glVertex(verts[3])
+ glEnd()
+ tex.unbind()
+ # Return self
+ return self
+
diff --git a/externals/camviz/camviz/objects/__init__.py b/externals/camviz/camviz/objects/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..520f40c05b716320ee33af7bcfdf1d8554e6a1c3
--- /dev/null
+++ b/externals/camviz/camviz/objects/__init__.py
@@ -0,0 +1,3 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from camviz.objects.camera import Camera
diff --git a/externals/camviz/camviz/objects/bbox2d.py b/externals/camviz/camviz/objects/bbox2d.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b0a3db541f111e3799710e27628910fa6461a16
--- /dev/null
+++ b/externals/camviz/camviz/objects/bbox2d.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+
+from camviz.objects.object import Object
+
+
+class BBox2D(Object):
+ """
+ Bounding Box 2D draw class
+
+ Parameters
+ ----------
+ points : np.array
+ List of points for the bounding box dimension (left, top, right, bottom)
+ pose : np.array
+ Bounding box pose on the screen (right, down)
+ """
+ def __init__(self, points, pose=None):
+ super().__init__(pose=pose)
+ self.pts = np.array([[points[0], points[1]],
+ [points[2], points[1]],
+ [points[2], points[3]],
+ [points[0], points[3]]])
+
+ def draw(self, draw, color_line='gre', color_edge=None):
+ """
+ Draw 2D bounding box on screen
+
+ Parameters
+ ----------
+ draw : camviz.Draw
+ Draw instance
+ color_line : str
+ Line color
+ color_edge : str
+ Edge color
+ """
+ # Set color line if provided
+ if color_line is not None:
+ draw.color(color_line).width(2).lines(
+ self.pts[[0, 1, 1, 2, 2, 3, 3, 0]])
+ # Set color edge if provided
+ if color_edge is not None:
+ draw.color(color_edge).size(4).points(self.pts)
diff --git a/externals/camviz/camviz/objects/bbox3d.py b/externals/camviz/camviz/objects/bbox3d.py
new file mode 100755
index 0000000000000000000000000000000000000000..d385c13c808b555d424bb66444613fcf64ca50e3
--- /dev/null
+++ b/externals/camviz/camviz/objects/bbox3d.py
@@ -0,0 +1,42 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from camviz.objects.object import Object
+
+
+class BBox3D(Object):
+ """
+ Bounding Box 3D draw class
+
+ Parameters
+ ----------
+ points : np.array
+ List of points for the bounding box dimension (assuming center is 0,0,0)
+ Order: +++, +-+, +--, ++-, -++, --+, ---, -+-
+ pose : np.array
+ Bounding box pose (x-forward, y-left, z-up)
+ """
+ def __init__(self, points, pose=None):
+ super().__init__(pose=pose)
+ self.pts = points
+
+ def draw(self, draw, color_line='gre', color_edge=None):
+ """
+        Draw 3D bounding box on screen
+
+ Parameters
+ ----------
+ draw : camviz.Draw
+ Draw instance
+ color_line : str
+ Line color
+ color_edge : str
+ Edge color
+ """
+ # Set color line if provided
+ if color_line is not None:
+ draw.color(color_line).width(2).lines(
+ self.pts[[0, 1, 1, 2, 2, 3, 3, 0, 4, 5, 5, 6,
+ 6, 7, 7, 4, 0, 4, 1, 5, 2, 6, 3, 7]])
+ # Set color edge if provided
+ if color_edge is not None:
+ draw.color(color_edge).size(4).points(self.pts)
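
For reference, a small standalone numpy helper (hypothetical, not part of camviz) that builds the eight corners in the order documented above from full box extents, suitable for passing as `points`:

```python
import numpy as np

def bbox3d_corners(w, h, l):
    """Eight box corners centered at the origin, ordered +++, +-+, +--, ++-, -++, --+, ---, -+-"""
    signs = [(+1, +1, +1), (+1, -1, +1), (+1, -1, -1), (+1, +1, -1),
             (-1, +1, +1), (-1, -1, +1), (-1, -1, -1), (-1, +1, -1)]
    return np.array([[sx * w / 2, sy * h / 2, sz * l / 2] for sx, sy, sz in signs])

corners = bbox3d_corners(2.0, 1.0, 4.0)  # e.g. a 2 x 1 x 4 box, centered at (0, 0, 0)
```
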
diff --git a/externals/camviz/camviz/objects/camera.py b/externals/camviz/camviz/objects/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..3559597bf7b77bdfd4b860e07c8a086b021ed7f1
--- /dev/null
+++ b/externals/camviz/camviz/objects/camera.py
@@ -0,0 +1,188 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+
+from camviz.objects.object import Object
+from camviz.utils.geometry import transpose, invert
+from camviz.utils.types import is_list, is_float
+from camviz.utils.utils import numpyf, add_row0, add_col1, image_grid
+
+
+def camviz_camera(camera):
+ """
+ Converts packnet-sfm cameras to camviz cameras
+
+ Parameters
+ ----------
+ camera : Camera or list[Camera]
+ Input packnet-sfm cameras
+
+ Returns
+ -------
+ camera_cv : camviz.objects.camera.Camera
+ output camviz cameras
+ """
+ # Create a list of cameras if necessary
+ if is_list(camera):
+ return [camviz_camera(cam) for cam in camera]
+ # Return a list of cameras for each batch camera
+ return [Camera(cam=cam) for cam in camera]
+
+
+class Camera(Object):
+ """
+ Create a camera class
+
+ Parameters
+ ----------
+ scale : float
+ Scale used when drawing the object
+ wh : tuple
+ Image dimensions (width, height)
+ K : np.array
+ Camera intrinsics [3,3]
+ pose : np.array
+ Object pose
+ """
+ def __init__(self, scale=1.0, wh=None, K=None, pose=None):
+ # Initialize object super-class
+ super().__init__(scale, pose)
+        # If intrinsics are provided, use them
+ if K is not None:
+ self.K = transpose(numpyf(K))
+ self.iK = np.linalg.inv(self.K)
+        # If image dimensions are provided, use them to build the image-plane corner vertices
+ if wh is not None:
+ if not isinstance(wh, (list, tuple)):
+ wh = wh.shape[:2]
+ self.w, self.h = wh
+ uv = numpyf([[self.w - 1, 0 ],
+ [self.w - 1, self.h - 1],
+ [ 0 , self.h - 1],
+ [ 0 , 0 ]])
+ self.v = add_row0(self.i2c(scale, uv))
+
+ @staticmethod
+ def from_vidar(cam, b, scale=1.0):
+ return Camera(K=cam.K[b][:3, :3],
+ pose=cam.Tcw.T[b] if cam.Twc is not None else None,
+ wh=cam.wh, scale=scale)
+
+ def i2c(self, depth=1.0, uv=None):
+ """
+ Project an image to camera coordinates using a depth map
+
+ Parameters
+ ----------
+ depth : float or np.array
+ Depth values for lifting
+ uv : np.array
+ Image grid for lifting
+
+ Returns
+ -------
+ xyz : np.array
+ Lifted 3D points in camera frame of reference
+ """
+        # If no grid is provided, use the depth map
+ if uv is None:
+ if not is_float(depth):
+ # Create image grid from depth values
+ uv = image_grid(depth)
+ else:
+ # Impossible to create an image grid
+ raise ValueError('No available grid for camera')
+ # Add third unitary coordinate to the image grid
+ if uv.shape[1] == 2:
+ uv = add_col1(uv)
+ # A depth map was provided, create a grid from it
+ elif uv.shape[1] > 3:
+ uv = image_grid(uv)
+ # If there are individual depth values per image grid cell
+ if not is_float(depth):
+ if len(depth.shape) == 1:
+ depth = depth[:, np.newaxis]
+ elif depth.shape[1] > 1:
+ if len(depth.shape) == 3:
+ depth = depth[:, :, 0]
+ depth = depth.reshape(-1, 1)
+ return (uv @ self.iK) * depth
+
+ def c2i(self, xyz, filter=False, padding=0, return_z=False):
+ """
+ Project 3D points in camera frame of reference to the image plane
+
+ Parameters
+ ----------
+ xyz : np.array
+ 3D points to be projected
+ filter : bool
+ Filter points outside boundaries
+ padding : int or float
+ Padding for filtering
+ return_z : bool
+ Return z values as well or not
+
+ Returns
+ -------
+        uv : np.array
+            2D coordinates of projected points
+        depth : np.array
+            Depth values (only returned if return_z is enabled)
+        idx : np.array
+            Valid indices (only returned if filter is enabled)
+ """
+ uv = (xyz / xyz[:, 2:] @ self.K)[:, :2]
+ if filter:
+ idx = (uv[:, 0] > -padding) & (uv[:, 0] < self.w + padding) & \
+ (uv[:, 1] > -padding) & (uv[:, 1] < self.h + padding) & (xyz[:, 2] > 0)
+ if return_z:
+ return uv[idx], xyz[idx, 2:], idx
+ else:
+ return uv[idx], idx
+ else:
+ if return_z:
+ return uv, xyz[:, 2:]
+ else:
+ return uv
+
+ def c2w(self, xyz):
+ """Transform 3D points in camera frame of reference to world frame of reference"""
+ if xyz.shape[1] == 3:
+ xyz = add_col1(xyz)
+ return (xyz @ self.Tt)[:, :3]
+
+ def w2c(self, xyz):
+ """Transform 3D points in world frame of reference to camera frame of reference"""
+ if xyz.shape[1] == 3:
+ xyz = add_col1(xyz)
+ return (xyz @ invert(self.Tt))[:, :3]
+
+ def i2w(self, depth=1.0, uv=None):
+ """Lift 2D image points to 3D space in world frame of reference"""
+ return self.c2w(self.i2c(depth, uv))
+
+ def w2i(self, xyz, filter=False, padding=0, return_z=False):
+ """Project 3D points in world frame of reference to the image plane"""
+ return self.c2i(self.w2c(xyz), filter=filter,
+ padding=padding, return_z=return_z)
+
+ def draw(self, draw, tex=None, axes=True, color='gra'):
+ """
+ Draw a camera in a 3D screen
+
+ Parameters
+ ----------
+ draw : Draw
+ Draw class to be used
+        tex : str
+ Optional texture to draw on the camera image plane
+ axes : bool
+ True if coordinate axes should be drawn as well
+ color : str
+ Which color should be used for the camera
+ """
+ draw.image(tex, verts=self.v[:4])
+ draw.color(color).width(4).connects(self.v[4], self.v[:4]).loop(self.v[:4])
+ if axes:
+ draw.axis(0.25 * self.scale)
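
The lifting and projection above are plain pinhole geometry. As a sanity check, a standalone numpy sketch of the same round trip for a single pixel, written with the untransposed intrinsics `K` (the class stores `K` transposed internally):

```python
import numpy as np

K = np.array([[500.0,   0.0, 320.0],
              [  0.0, 500.0, 240.0],
              [  0.0,   0.0,   1.0]])

uv1 = np.array([100.0, 50.0, 1.0])    # homogeneous pixel coordinate
depth = 5.0

xyz = np.linalg.inv(K) @ uv1 * depth  # lift to the camera frame (what i2c does)
uv_back = (K @ (xyz / xyz[2]))[:2]    # project back to the image (what c2i does)

assert np.allclose(uv_back, uv1[:2])  # the round trip recovers the original pixel
```
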
diff --git a/externals/camviz/camviz/objects/object.py b/externals/camviz/camviz/objects/object.py
new file mode 100755
index 0000000000000000000000000000000000000000..4a0e95186e815e34d814047ed69c171c3f52252b
--- /dev/null
+++ b/externals/camviz/camviz/objects/object.py
@@ -0,0 +1,111 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from OpenGL.GL import *
+
+from camviz.objects.pose import Pose
+
+
+class Object:
+ """
+ Base object draw class
+
+ Parameters
+ ----------
+ scale : float
+ Scale used when drawing the object
+ pose : np.array
+ Object pose
+ """
+ def __init__(self, scale=1.0, pose=None):
+ self.scale = scale
+ self.pose = pose if isinstance(pose, Pose) else Pose(pose)
+
+ @property
+ def t(self):
+ """Return pose translation"""
+ return self.pose.t
+
+ @property
+ def R(self):
+ """Return pose rotation"""
+ return self.pose.R
+
+ @property
+ def T(self):
+ """Return pose transformation"""
+ return self.pose.T
+
+ @property
+ def Rt(self):
+ """Return pose rotation transposed"""
+ return self.pose.Rt
+
+ @property
+ def Tt(self):
+ """Return pose transformation transposed"""
+ return self.pose.Tt
+
+ def translateX(self, m):
+ """Translate object in X by m"""
+ return self.pose.translateX(m)
+
+ def translateY(self, m):
+ """Translate object in Y by m"""
+ return self.pose.translateY(m)
+
+ def translateZ(self, m):
+ """Translate object in Z by m"""
+ return self.pose.translateZ(m)
+
+ def rotateX(self, d):
+ """Rotate object in X by d degrees"""
+ return self.pose.rotateX(d)
+
+ def rotateY(self, d):
+ """Rotate object in Y by d degrees"""
+ return self.pose.rotateY(d)
+
+ def rotateZ(self, d):
+ """Rotate object in Z by d degrees"""
+ return self.pose.rotateZ(d)
+
+ def rotateI(self, d):
+ """Rotate object in X by d degrees (from the camera's perspective)"""
+ return self.pose.rotateI(d)
+
+ def rotateJ(self, d):
+ """Rotate object in Y by d degrees (from the camera's perspective)"""
+ return self.pose.rotateJ(d)
+
+ def rotateK(self, d):
+ """Rotate object in Z by d degrees (from the camera's perspective)"""
+ return self.pose.rotateK(d)
+
+ def setPose(self, pose):
+ """Set object pose"""
+ return self.pose.setPose(pose)
+
+ def display(self, *args, align=None, **kwargs):
+ """
+ Display object
+
+ Parameters
+ ----------
+ args : args
+ Extra draw arguments
+ align : camviz.Pose
+ Pose used to align the object
+ kwargs : kwargs
+ Extra draw arguments
+ """
+ # Get transformation (aligned or not)
+ if align is not None:
+ T = (align @ self.pose).Tt
+ else:
+ T = self.Tt
+
+ # Draw object
+ glPushMatrix()
+ glMultMatrixf(T)
+ self.draw(*args, **kwargs)
+ glPopMatrix()
diff --git a/externals/camviz/camviz/objects/pointcloud.py b/externals/camviz/camviz/objects/pointcloud.py
new file mode 100755
index 0000000000000000000000000000000000000000..ddb786bbfd6e96d1140a0096f2148881b975a5be
--- /dev/null
+++ b/externals/camviz/camviz/objects/pointcloud.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from camviz.objects.object import *
+
+
+class Pointcloud(Object):
+ """
+    Pointcloud draw class
+
+ Parameters
+ ----------
+ scale : float
+ Scale used when drawing the object
+ pts : np.array
+ Pointcloud points
+ pose : np.array
+        Pointcloud pose
+ draw : camviz.Draw
+ Draw instance
+ """
+ def __init__(self, scale=1.0, pts=None, pose=None, draw=None):
+ super().__init__(scale, pose)
+ if draw is not None:
+ draw.addBufferf('pts', pts)
+ self.pts = 'pts'
+ else:
+ self.pts = pts
+
+ def draw(self, draw, size=1, color='whi'):
+ """
+ Draw pointcloud on screen
+
+ Parameters
+ ----------
+ draw : camviz.Draw
+ Draw instance
+ size : int
+ Point size
+ color : str
+ Point color
+ """
+ draw.color(color).size(size).points(self.pts)
+
diff --git a/externals/camviz/camviz/objects/pose.py b/externals/camviz/camviz/objects/pose.py
new file mode 100755
index 0000000000000000000000000000000000000000..9d92457de71c127c2bbaea7d2ebe63769034904b
--- /dev/null
+++ b/externals/camviz/camviz/objects/pose.py
@@ -0,0 +1,192 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from copy import deepcopy
+
+import numpy as np
+
+from camviz.objects.quaternion import Quaternion
+from camviz.utils.geometry import unitX, unitY, unitZ
+from camviz.utils.utils import numpyf, add_col1
+
+
+def rot2quat(R):
+ """Convert rotation matrix to quaternion"""
+ qw = np.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2
+ qx = (R[1, 2] - R[2, 1]) / (4.0 * qw)
+ qy = (R[2, 0] - R[0, 2]) / (4.0 * qw)
+ qz = (R[0, 1] - R[1, 0]) / (4.0 * qw)
+ return Quaternion(qw, qx, qy, qz)
+
+
+class Pose:
+
+ def __init__(self, pose=None, align=False):
+ """
+ Pose class
+
+ Parameters
+ ----------
+ pose : np.array
+ Initial pose
+        align : bool
+            If True, apply a fixed axis-permutation rotation to align the pose with the camera convention
+ """
+ self.q = self.M = None
+ # If pose is provided, use it
+ if pose is not None:
+ self.setPose(pose, align)
+ # Otherwise, set to identity
+ else:
+ self.reset()
+
+ def copy(self):
+ """Return a copy of the instance"""
+ return deepcopy(self)
+
+ @property
+ def t(self):
+ """Return pose translation"""
+ return self.M[:3, 3]
+
+ @property
+ def R(self):
+        """Return pose rotation"""
+ return self.M[:3, :3]
+
+ @property
+ def T(self):
+ """Return pose transformation"""
+ return self.M
+
+ @property
+ def Rt(self):
+ """Return pose rotation transposed"""
+ return self.R.T
+
+ @property
+ def Tt(self):
+ """Return pose transformation transposed"""
+ return self.T.T
+
+ @property
+ def inv(self):
+ """Return inverted pose"""
+ Tinv = self.T.copy()
+ Tinv[:3, :3] = np.transpose(self.T[:3, :3])
+ Tinv[:3, -1] = np.matmul(-1. * Tinv[:3, :3], self.T[:3, -1])
+ return Pose(Tinv)
+
+ @property
+ def Tinv(self):
+ """Return inverted pose transformation"""
+ return self.inv.T
+
+ def translateX(self, m):
+ """Translate object in X by m"""
+ return self.translate(unitX(m))
+
+ def translateY(self, m):
+ """Translate object in Y by m"""
+ return self.translate(unitY(m))
+
+ def translateZ(self, m):
+ """Translate object in Z by m"""
+ return self.translate(unitZ(m))
+
+ def rotateX(self, d, M=None):
+ """Rotate object in X by d degrees"""
+ return self.rotate(d, (self.M if M is None else M)[:3, 0])
+
+ def rotateY(self, d, M=None):
+ """Rotate object in Y by d degrees"""
+ return self.rotate(d, (self.M if M is None else M)[:3, 1])
+
+ def rotateZ(self, d, M=None):
+ """Rotate object in Z by d degrees"""
+ return self.rotate(d, (self.M if M is None else M)[:3, 2])
+
+ def rotateI(self, d):
+ """Rotate object in X by d degrees (from the camera's perspective)"""
+ return self.rotate(d, unitX(1))
+
+ def rotateJ(self, d):
+ """Rotate object in Y by d degrees (from the camera's perspective)"""
+ return self.rotate(d, unitY(1))
+
+ def rotateK(self, d):
+ """Rotate object in Z by d degrees (from the camera's perspective)"""
+ return self.rotate(d, unitZ(1))
+
+ def setPose(self, mat, align=False):
+ """
+ Set pose value
+
+ Parameters
+ ----------
+ mat : np.array
+ New pose value
+        align : bool
+            If True, apply a fixed axis-permutation rotation to align the pose with the camera convention
+ """
+ # Convert to numpy
+ mat = numpyf(mat)
+        # If mat is a 1-dimensional vector
+ if len(mat.shape) == 1:
+ # If it has 16 values, reshape and use it as a transformation matrix
+ if mat.shape[0] == 16:
+ self.M = np.reshape(mat, (4, 4))
+ self.q = rot2quat(self.M)
+            # If it has 7 values, treat it as translation + quaternion
+ if mat.shape[0] == 7:
+ self.M = numpyf(np.identity(4))
+ self.M[:3, 3] = [mat[0], mat[1], mat[2]]
+ self.q = Quaternion(mat[3], mat[4], mat[5], mat[6])
+ self.M[:3, :3] = self.q.rotmat().T
+ # If it's two-dimensional, treat it as a transformation matrix
+ elif len(mat.shape) == 2:
+ if mat.shape[0] == 4 and mat.shape[1] == 4:
+ self.M = mat
+ self.q = rot2quat(self.M)
+ # Update transformation matrix
+ self.M = numpyf(self.M)
+ # Align if necessary
+ if align:
+ R = np.array([[0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1]])
+ self.M = R @ self.M
+ self.q = rot2quat(self.M)
+
+ def reset(self):
+ """Reset pose"""
+ self.q = Quaternion()
+ self.M = numpyf(np.identity(4))
+
+ def translate(self, axis):
+ """Translate pose in a certain axis"""
+ self.M[:3, 3] += self.q.rotate(numpyf(axis))
+ return self
+
+ def rotate(self, deg, axis):
+ """Rotate pose by deg in a certain axis"""
+ self.q *= Quaternion(numpyf(axis), deg)
+ self.M[:3, :3] = self.q.rotmat().T
+ return self
+
+ def current7(self):
+ """Return current translation and quaternion values"""
+ t, q = self.M[:3, 3], self.q.coefs
+ return t[0], t[1], t[2], q[0], q[1], q[2], q[3]
+
+ def __matmul__(self, other):
+ """Multiply pose with something else"""
+ # Pose x Pose
+ if isinstance(other, Pose):
+ return Pose(self.M @ other.T)
+ # Pose x points
+ elif other.shape[1] == 3:
+ return (add_col1(other) @ self.Tt)[:, :3]
+ # Generic multiplication
+ else:
+ return self.M @ other.T
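
The `inv` property above uses the closed-form inverse of a rigid transform, T^-1 = [R^T | -R^T t]; a standalone numpy check of that identity (same math, independent of the class):

```python
import numpy as np

# A rigid transform: 30-degree rotation around Z plus a translation
c, s = np.cos(np.pi / 6), np.sin(np.pi / 6)
T = np.eye(4)
T[:3, :3] = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
T[:3, 3] = [1.0, 2.0, 3.0]

# Closed-form inverse, as computed in Pose.inv
Tinv = T.copy()
Tinv[:3, :3] = T[:3, :3].T
Tinv[:3, 3] = -Tinv[:3, :3] @ T[:3, 3]

assert np.allclose(Tinv, np.linalg.inv(T))
```
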
diff --git a/externals/camviz/camviz/objects/quaternion.py b/externals/camviz/camviz/objects/quaternion.py
new file mode 100755
index 0000000000000000000000000000000000000000..fb70c5d815aafb9e917536a08fa419f7b6ceb8dd
--- /dev/null
+++ b/externals/camviz/camviz/objects/quaternion.py
@@ -0,0 +1,72 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import math
+
+import numpy as np
+
+
+class Quaternion:
+ """Quaternion class"""
+ def __init__(self, *args):
+ # If no arguments are provided, create an identity quaternion
+ if len(args) == 0:
+ self.coefs = 1.0, 0.0, 0.0, 0.0
+ # If a single argument is provided
+ elif len(args) == 1:
+ # If it's a tuple, it contains coefficients
+ if isinstance(args[0], tuple):
+ self.coefs = args[0]
+ # Otherwise, assume it's a rotation matrix
+ else:
+ R = np.array(args[0])
+ w = np.sqrt(1.0 + R[0, 0] + R[1, 1] + R[2, 2]) / 2
+ x = (R[1, 2] - R[2, 1]) / (4.0 * w)
+ y = (R[2, 0] - R[0, 2]) / (4.0 * w)
+ z = (R[0, 1] - R[1, 0]) / (4.0 * w)
+ self.coefs = w, x, y, z
+ # If two arguments are provided, assume it's axis and degree
+ elif len(args) == 2:
+ v, d = args
+ r = d * math.pi / 360.0
+ c, s = math.cos(r), math.sin(r)
+ self.coefs = c, v[0] * s, v[1] * s, v[2] * s
+ # If there are four arguments, assume each individual coefficient is provided
+ elif len(args) == 4:
+ self.coefs = args
+
+ def __getitem__(self, idx):
+ """Return quaternion coefficients"""
+ return self.coefs[idx]
+
+ def __mul__(self, r):
+        """Multiply two quaternions"""
+ q, r = self.coefs, r.coefs
+ return Quaternion(r[0] * q[0] - r[1] * q[1] - r[2] * q[2] - r[3] * q[3],
+ r[0] * q[1] + r[1] * q[0] - r[2] * q[3] + r[3] * q[2],
+ r[0] * q[2] + r[1] * q[3] + r[2] * q[0] - r[3] * q[1],
+ r[0] * q[3] - r[1] * q[2] + r[2] * q[1] + r[3] * q[0])
+
+ def invert(self):
+ """Return inverted quaternion"""
+ w, x, y, z = self.coefs
+ d = np.sqrt(w * w + x * x + y * y + z * z)
+ return Quaternion(w / d, - x / d, - y / d, - z / d)
+
+ def rotate(self, p):
+ """Rotate points"""
+ vec = self.coefs[1:]
+
+ uv = np.cross(p, vec)
+ uuv = np.cross(uv, vec)
+
+ return p + 2 * (self.coefs[0] * uv + uuv)
+
+ def rotmat(self):
+ """Return rotation matrix"""
+ w, x, y, z = self.coefs
+ xx, yy, zz = x * x, y * y, z * z
+ return np.array([[1-2*yy-2*zz, 2*x*y-2*z*w, 2*x*z+2*y*w],
+ [2*x*y+2*z*w, 1-2*xx-2*zz, 2*y*z-2*x*w],
+ [2*x*z-2*y*w, 2*y*z+2*x*w, 1-2*xx-2*yy]])
+
+
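
The two-argument constructor above is the usual axis-angle parameterization with the half angle in radians (`d * pi / 360`); a short standalone sketch of that relation:

```python
import math

axis, deg = (0.0, 0.0, 1.0), 90.0   # 90-degree rotation around Z
half = deg * math.pi / 360.0        # half angle in radians, as in the constructor
w = math.cos(half)
x, y, z = (a * math.sin(half) for a in axis)

# For a unit axis the resulting quaternion is unit-norm
assert abs(w * w + x * x + y * y + z * z - 1.0) < 1e-9
```
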
diff --git a/externals/camviz/camviz/opengl/opengl_colors.py b/externals/camviz/camviz/opengl/opengl_colors.py
new file mode 100755
index 0000000000000000000000000000000000000000..dbc7a63d48de6f60a156f0bb4ff5402da66a9b67
--- /dev/null
+++ b/externals/camviz/camviz/opengl/opengl_colors.py
@@ -0,0 +1,63 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from OpenGL.GL import glColor3fv
+
+def Red(n=1.0):
+ """Change to red color"""
+ glColor3fv((n, 0.0, 0.0))
+
+def Green(n=1.0):
+ """Change to green color"""
+ glColor3fv((0.0, n, 0.0))
+
+def Blue(n=1.0):
+ """Change to blue color"""
+ glColor3fv((0.0, 0.0, n))
+
+def Yellow(n=1.0):
+ """Change to yellow color"""
+ glColor3fv((n, n, 0.0))
+
+def Magenta(n=1.0):
+ """Change to magenta color"""
+ glColor3fv((n, 0.0, n))
+
+def Cyan(n=1.0):
+ """Change to cyan color"""
+ glColor3fv((0.0, n, n))
+
+def Black():
+ """Change to black color"""
+ glColor3fv((0.0, 0.0, 0.0))
+
+def White():
+ """Change to white color"""
+ glColor3fv((1.0, 1.0, 1.0))
+
+def Gray():
+ """Change to gray color"""
+ glColor3fv((0.5, 0.5, 0.5))
+
+def setColor(clr, n=1.0):
+    """Change to a specific color based on a string"""
+ if clr == 'red':
+ Red(n)
+ if clr == 'gre':
+ Green(n)
+ if clr == 'blu':
+ Blue(n)
+ if clr == 'yel':
+ Yellow(n)
+ if clr == 'mag':
+ Magenta(n)
+ if clr == 'cya':
+ Cyan(n)
+ if clr == 'blk':
+ Black()
+ if clr == 'whi':
+ White()
+ if clr == 'gra':
+ Gray()
+ # If clr is a tuple, create that specific color
+ if isinstance(clr, tuple):
+ glColor3fv(clr)
diff --git a/externals/camviz/camviz/opengl/opengl_shapes.py b/externals/camviz/camviz/opengl/opengl_shapes.py
new file mode 100755
index 0000000000000000000000000000000000000000..8a1fe36ec8659c30f6732b18b47e1ccdd2681f2f
--- /dev/null
+++ b/externals/camviz/camviz/opengl/opengl_shapes.py
@@ -0,0 +1,181 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from OpenGL.GL import \
+ glPointSize, glLineWidth, glVertex2fv, glVertex3fv, \
+ glPushMatrix, glPopMatrix, glMultMatrixf, glScalef, glBegin, glEnd, \
+ GL_LINES, GL_LINE_LOOP
+from OpenGL.GLU import \
+ gluSphere, gluNewQuadric
+
+from camviz.opengl.opengl_colors import Green, Blue, Red
+from camviz.utils.utils import numpyf, add_list
+from camviz.utils.types import is_numpy, is_double_list
+
+
+def vertex_line(pt1, pt2):
+ """Create line vertices"""
+ glVertex = glVertex2fv if len(pt1) == 2 else glVertex3fv
+ glVertex(pt1)
+ glVertex(pt2)
+
+def has_multiple(data):
+ """Checks if data has multiple entries (list of lists or (n,d) numpy array)"""
+ return (is_numpy(data) and len(data.shape) > 1) or is_double_list(data)
+
+def setPointSize(n=1):
+ """Set point size"""
+ glPointSize(n)
+
+def setLineWidth(n=1):
+ """Set line width"""
+ glLineWidth(n)
+
+def drawLine(pt1, pt2):
+ """Draw a line from two points"""
+ glBegin(GL_LINES)
+ vertex_line(pt1, pt2)
+ glEnd()
+
+def drawMatches(pts1, pts2):
+ """Draw 1 to 1 matches between two sets of points"""
+ glBegin(GL_LINES)
+ for i in range(len(pts1)):
+ vertex_line([pts1[i, 0], pts1[i, 1], pts1[i, 2]],
+ [pts2[i, 0], pts2[i, 1], pts2[i, 2]])
+ glEnd()
+
+def drawConnects(vert1, verts2):
+ """Draw connections from each vert1 to all vert2"""
+ vert1, verts2 = numpyf(vert1), numpyf(verts2)
+ glBegin(GL_LINES)
+ for vert2 in verts2:
+ vertex_line(vert1, vert2)
+ glEnd()
+
+def drawRect(lu=None, rd=None, ct=None, wh=None, x=False):
+ """
+ Draw a rectangle
+
+ Parameters
+ ----------
+ lu : np.array
+ Left/Up point
+ rd : np.array
+ Right/Down point
+ ct : np.array
+ Center point
+ wh : np.array
+ Width/height
+ x : bool
+ Draw an x inside the rectangle
+ """
+    # If center and width/height are provided, compute the corners from them
+    if ct is not None and wh is not None:
+        lu = [ct[0] - wh[0] / 2, ct[1] - wh[1] / 2]
+        rd = [ct[0] + wh[0] / 2, ct[1] + wh[1] / 2]
+        ld, ru = [lu[0], rd[1]], [rd[0], lu[1]]
+    # Else, compute the corners from the left/up point plus right/down or width/height
+    elif lu is not None:
+        if rd is not None:
+            ld, ru = [lu[0], rd[1]], [rd[0], lu[1]]
+        elif wh is not None:
+            rd = [lu[0] + wh[0], lu[1] + wh[1]]
+            ld, ru = [lu[0], rd[1]], [rd[0], lu[1]]
+    # Wrong parameters
+    else:
+        raise ValueError('wrong drawRect parameters')
+ # Draw rectangle
+ glBegin(GL_LINE_LOOP)
+ vertex_line(lu, ld)
+ vertex_line(rd, ru)
+ # Draw x
+ if x:
+ vertex_line(lu, rd)
+ vertex_line(ld, ru)
+ glEnd()
+
+def drawCross(ct, sz):
+ """
+ Draw a cross on screen
+
+ Parameters
+ ----------
+ ct : np.array
+ Cross center
+ sz : float
+ Cross size
+ """
+ # If there are multiple centers, draw one cross for each
+ if has_multiple(ct):
+ [drawCross(pt, sz) for pt in ct]
+ # If there is a single center
+ else:
+ # Get borders
+ u, d = ct[1] - sz / 2, ct[1] + sz / 2
+ l, r = ct[0] - sz / 2, ct[0] + sz / 2
+ # Draw cross
+ glBegin(GL_LINES)
+ vertex_line([l, ct[1]], [r, ct[1]])
+ vertex_line([ct[0], u], [ct[0], d])
+ glEnd()
+
+def drawAxis(scale=1.0, center=(0, 0, 0), width=None):
+ """
+ Draw a xyz axis on screen
+
+ Parameters
+ ----------
+ scale : float
+ Axis scale
+ center : np.array
+ Axis center
+ width : int
+ Axis line width
+ """
+ # Set width if provided
+ if width is not None:
+ setLineWidth(width)
+ # Convert center to numpy
+ center = numpyf(center)
+ # Draw axis
+ glBegin(GL_LINES)
+ Green()
+ vertex_line(center, add_list(center, (scale, 0, 0)))
+ Blue()
+ vertex_line(center, add_list(center, (0, scale, 0)))
+ Red()
+ vertex_line(center, add_list(center, (0, 0, scale)))
+ glEnd()
+
+def drawEllipse(mean, cov):
+ """
+ Draw an ellipse on screen
+
+ Parameters
+ ----------
+ mean : np.array
+ Ellipse mean
+ cov : np.array
+ Ellipse covariance
+ """
+ # If there are multiple means, draw one ellipse for each
+ if len(mean.shape) > 1:
+ for i in range(mean.shape[0]):
+ drawEllipse(mean[i], cov[i])
+ # Else, draw a single ellipse
+ else:
+ # Get eigenvalue and eigenvector
+ val, vec = np.linalg.eig(cov)
+ # Get transformation matrix
+ Tt = np.eye(4)
+ Tt[:3, :3] = vec
+ Tt[3, :3] = mean
+ # Apply transformation matrix and draw
+ glPushMatrix()
+ glMultMatrixf(Tt)
+ glScalef(2.0 * np.sqrt(val[0]),
+ 2.0 * np.sqrt(val[1]),
+ 2.0 * np.sqrt(val[2]))
+ gluSphere(gluNewQuadric(), 1.00, 100, 20)
+ glPopMatrix()
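
`drawEllipse` scales a unit sphere by `2 * sqrt(eigenvalue)` along the covariance eigenvectors, i.e. it renders a 2-sigma ellipsoid. A standalone numpy sketch of that decomposition, without any OpenGL calls:

```python
import numpy as np

cov = np.array([[4.0, 1.0, 0.0],
                [1.0, 2.0, 0.0],
                [0.0, 0.0, 0.5]])

val, vec = np.linalg.eig(cov)     # eigenvalues are variances along the principal axes
semi_axes = 2.0 * np.sqrt(val)    # scale factors used by drawEllipse (2-sigma extent)

# Columns of `vec` are the principal directions; the drawn surface is
# { mean + vec @ np.diag(semi_axes) @ u : ||u|| = 1 }
print(semi_axes)
```
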
diff --git a/externals/camviz/camviz/screen/screen.py b/externals/camviz/camviz/screen/screen.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac4ecdaa04901025a252e5327374b4f69543747b
--- /dev/null
+++ b/externals/camviz/camviz/screen/screen.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from copy import deepcopy
+
+
+class Screen:
+ def __init__(self, luwh, mode):
+ """
+ Screen class
+
+ Parameters
+ ----------
+ luwh : tuple
+            Left/up/width/height values
+ mode : str
+ Screen mode ('2D_IMAGE' or '3D_WORLD')
+ """
+ assert mode in ['2D_IMAGE', '3D_WORLD'], 'Invalid screen mode'
+ self.luwh, self.mode = luwh, mode
+ self.origin = self.viewer = None
+
+ def inside(self, pos):
+ """
+ Check if a 2D coordinate is inside the screen
+
+ Parameters
+ ----------
+ pos : np.array
+            Position to check
+
+ Returns
+ -------
+ inside : bool
+ Whether pos is inside the screen
+ """
+ return self.luwh[0] < pos[0] < self.luwh[0] + self.luwh[2] and \
+ self.luwh[1] < pos[1] < self.luwh[1] + self.luwh[3]
+
+ def saveViewer(self):
+ """Save current virtual viewer camera (pose and intrinsics)"""
+ self.origin = deepcopy(self.viewer)
+
+ def reset(self):
+ """Reset current virtual viewer camera"""
+ self.viewer = deepcopy(self.origin)
+
diff --git a/externals/camviz/camviz/screen/screen2Dimage.py b/externals/camviz/camviz/screen/screen2Dimage.py
new file mode 100755
index 0000000000000000000000000000000000000000..0b1cbda5451bc4c12e667de536881acbee4db372
--- /dev/null
+++ b/externals/camviz/camviz/screen/screen2Dimage.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+from OpenGL.GL import GL_PROJECTION, GL_DEPTH_TEST, GL_MODELVIEW
+from OpenGL.GL import glMatrixMode, glLoadIdentity, glDisable
+from OpenGL.GLU import gluOrtho2D
+
+from camviz.screen.screen import Screen
+
+
+class Screen2Dimage(Screen):
+ """
+ 2D screen for image display
+
+ Parameters
+ ----------
+ luwh : tuple
+ Left/up/width/height values
+ res : tuple
+ Image resolution
+ """
+ def __init__(self, luwh, res):
+ super().__init__(luwh, '2D_IMAGE')
+ # Get resolution from dimensions if not provided
+ if res is None:
+ res = (self.luwh[2], self.luwh[3])
+ # Initialize values
+ self.setRes(res)
+ self.orig_res = list(self.res)
+ self.background = 'whi'
+
+ def setRes(self, res):
+ """Set new resolution"""
+ self.res = [0, 0, res[0], res[1]]
+ self.prepare()
+
+ def prepare(self):
+ """Prepare screen for display"""
+ glMatrixMode(GL_PROJECTION)
+ glLoadIdentity()
+ glDisable(GL_DEPTH_TEST)
+ gluOrtho2D(self.res[0], self.res[2],
+ self.res[3], self.res[1])
+ glMatrixMode(GL_MODELVIEW)
+ glLoadIdentity()
+
diff --git a/externals/camviz/camviz/screen/screen3Dworld.py b/externals/camviz/camviz/screen/screen3Dworld.py
new file mode 100755
index 0000000000000000000000000000000000000000..ec72a73b270658e9d70668f1091ed586e7420240
--- /dev/null
+++ b/externals/camviz/camviz/screen/screen3Dworld.py
@@ -0,0 +1,110 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from OpenGL.GL import GL_PROJECTION, GL_DEPTH_TEST, GL_MODELVIEW
+from OpenGL.GL import glMatrixMode, glEnable, glLoadIdentity, glMultMatrixf
+from OpenGL.GLU import gluPerspective, gluLookAt
+
+from camviz.objects.pose import Pose
+from camviz.screen.screen import Screen
+
+
+class Screen3Dworld(Screen):
+ """
+ 3D screen for virtual world display
+
+ Parameters
+ ----------
+ luwh : tuple
+ Left/up/width/height values
+    wh : tuple (width, height)
+ Virtual camera image dimensions
+ K : np.array [3,3]
+ Virtual camera intrinsics
+ nf : tuple
+ Near/far display parameters
+ background : str
+ Background color ['bla', 'whi']
+ pose : tuple
+ Virtual camera pose
+ ref : str
+ Coordinate reference system ['cam', 'lidar']
+ """
+ def __init__(self, luwh, wh=None, K=None, nf=(0.01, 10000.0),
+ background='bla', pose=None, ref='cam'):
+ super().__init__(luwh, '3D_WORLD')
+ self.wh, self.K, self.nf = wh, K, nf
+ self.viewer = self.origin = self.P = None
+ self.background = background
+ self.ref = ref
+ # Start and prepare screen
+ self.start()
+ self.prepare()
+ # Rotate if using a lidar frame of reference
+ if ref == 'lidar':
+ self.viewer.rotateY(-90).rotateZ(90)
+ self.saveViewer()
+ # Set viewer pose if provided
+ if pose is not None:
+ self.viewer.setPose(pose)
+ self.saveViewer()
+
+ def start(self):
+ """Start viewer"""
+ self.viewer = Pose()
+ self.origin = Pose()
+ if self.wh is not None and self.K is not None:
+ self.calibrate()
+
+ def prepare(self):
+ """Prepare screen for display"""
+ glMatrixMode(GL_PROJECTION)
+ glLoadIdentity()
+
+ glEnable(GL_DEPTH_TEST)
+
+ # If calibration matrix is provided, use it
+ if self.P is not None:
+ glMultMatrixf(self.P)
+ # Otherwise, use a default perspective
+ else:
+ gluPerspective(45, self.luwh[2] / self.luwh[3],
+ self.nf[0], self.nf[1])
+
+ glMatrixMode(GL_MODELVIEW)
+ glLoadIdentity()
+
+ # Determine look vector
+ T = self.viewer.T
+ gluLookAt(
+ T[0, 3],
+ T[1, 3],
+ T[2, 3],
+ T[0, 3] + T[0, 2],
+ T[1, 3] + T[1, 2],
+ T[2, 3] + T[2, 2],
+ - T[0, 1],
+ - T[1, 1],
+ - T[2, 1],
+ )
+
+ def calibrate(self):
+ """Calibrate screen for display"""
+ # Convert intrinsics to numpy if needed
+ if isinstance(self.K, list):
+ self.K = np.array(self.K)
+
+ # Create transformation matrix
+ self.P = np.zeros(16)
+
+ self.P[0] = 2 * self.K[0, 0] / self.wh[0]
+ self.P[5] = 2 * self.K[1, 1] / self.wh[1]
+
+ self.P[8] = 2.0 * (self.K[0, 2] / self.wh[0]) - 1.0
+ self.P[9] = 2.0 * (self.K[1, 2] / self.wh[1]) - 1.0
+
+ self.P[10] = - 1.0 * (self.nf[1] + self.nf[0]) / (self.nf[1] - self.nf[0])
+ self.P[14] = - 2.0 * (self.nf[1] * self.nf[0]) / (self.nf[1] - self.nf[0])
+ self.P[11] = - 1.0
+
+ self.P = np.reshape(self.P, (4, 4))
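
`calibrate` packs a pinhole-style projection into a flat 16-vector; because `glMultMatrixf` reads that buffer column-major, the effective OpenGL matrix is the transpose of the `(4, 4)` reshape. A standalone numpy sketch (hypothetical fx/fy/cx/cy values) checking that the near and far planes map to NDC depths -1 and +1 under that interpretation:

```python
import numpy as np

fx = fy = 500.0
cx, cy, w, h = 320.0, 240.0, 640.0, 480.0
near, far = 0.01, 10000.0

P = np.zeros(16)
P[0], P[5] = 2 * fx / w, 2 * fy / h
P[8], P[9] = 2.0 * (cx / w) - 1.0, 2.0 * (cy / h) - 1.0
P[10] = -(far + near) / (far - near)
P[14] = -2.0 * far * near / (far - near)
P[11] = -1.0

P_gl = P.reshape(4, 4).T            # column-major interpretation by OpenGL

def ndc_z(z_cam):
    p = P_gl @ np.array([0.0, 0.0, z_cam, 1.0])
    return p[2] / p[3]              # perspective divide

assert np.isclose(ndc_z(-near), -1.0)   # the GL camera looks down -Z
assert np.isclose(ndc_z(-far), +1.0)
```
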
diff --git a/externals/camviz/camviz/utils/__init__.py b/externals/camviz/camviz/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..394ee97e5a51c8683a46d1e83df26245c04c4a56
--- /dev/null
+++ b/externals/camviz/camviz/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
diff --git a/externals/camviz/camviz/utils/cmaps.py b/externals/camviz/camviz/utils/cmaps.py
new file mode 100755
index 0000000000000000000000000000000000000000..d975b445e541f57240b1eb8c274934fbbd68f66c
--- /dev/null
+++ b/externals/camviz/camviz/utils/cmaps.py
@@ -0,0 +1,63 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from matplotlib.cm import get_cmap
+
+from camviz.utils.types import is_numpy, is_tensor
+
+
+def jet(data, range=None, exp=1.0):
+ """
+ Creates a JET colormap from data
+
+ Parameters
+ ----------
+ data : np.array [N,1]
+ Data to be converted into a colormap
+ range : tuple (min,max)
+ Optional range value for the colormap (if None, use min and max from data)
+ exp : float
+ Exponential value to weight the color differently
+
+ Returns
+ -------
+ colormap : np.array [N,3]
+ Colormap obtained from data
+ """
+ # Return if data is not available
+ if data is None or data.size == 0 or isinstance(data, tuple):
+ return data
+ else:
+ # If data is a tensor, convert to numpy
+ if is_tensor(data):
+ data = data.detach().cpu().numpy()
+        # If data is [N,1], remove the second dimension
+ if len(data.shape) > 1:
+ data = data.reshape(-1)
+ # Determine range if not available
+ if range is None:
+ data = data.copy() - np.min(data)
+ data = data / (np.max(data) + 1e-6)
+ else:
+ data = (data - range[0]) / (range[1] - range[0])
+ data = np.maximum(np.minimum(data, 1.0), 0.0)
+ # Use exponential if requested
+ if exp != 1.0:
+ data = data ** exp
+ # Initialize colormap
+ jet = np.ones((data.shape[0], 3), dtype=np.float32)
+ # First stage
+ idx = (data <= 0.33)
+ jet[idx, 1] = data[idx] / 0.33
+ jet[idx, 0] = 0.0
+ # Second stage
+ idx = (data > 0.33) & (data <= 0.67)
+ jet[idx, 0] = (data[idx] - 0.33) / 0.33
+ jet[idx, 2] = 1.0 - jet[idx, 0]
+ # Third stage
+ idx = data > 0.67
+ jet[idx, 1] = 1.0 - (data[idx] - 0.67) / 0.33
+ jet[idx, 2] = 0.0
+ # Return colormap
+ return jet
+
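
A short usage sketch of `jet` above (assuming camviz and its torch dependency are importable): it maps a 1D array of scalars to an `[N,3]` array of RGB colors, which is how pointcloud heights are colored in the demo later in this diff.

```python
import numpy as np
from camviz.utils.cmaps import jet

heights = np.random.uniform(-2.0, 2.0, size=1000)   # e.g. per-point heights
colors = jet(heights, range=(-2.0, 2.0), exp=1.0)   # [1000, 3] float32 RGB colors
assert colors.shape == (1000, 3)
```
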
diff --git a/externals/camviz/camviz/utils/geometry.py b/externals/camviz/camviz/utils/geometry.py
new file mode 100755
index 0000000000000000000000000000000000000000..afcce0e117b29d1220726c005e4ed4bb2190287a
--- /dev/null
+++ b/externals/camviz/camviz/utils/geometry.py
@@ -0,0 +1,23 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+
+def unitX(m=1.0):
+    """Return a unit vector on X"""
+ return np.array((m, 0, 0))
+
+def unitY(m=1.0):
+    """Return a unit vector on Y"""
+ return np.array((0, m, 0))
+
+def unitZ(m=1.0):
+    """Return a unit vector on Z"""
+ return np.array((0, 0, m))
+
+def transpose(data):
+ """Transpose numpy array"""
+ return data.T
+
+def invert(data):
+ """Invert numpy array"""
+ return np.linalg.inv(data)
diff --git a/externals/camviz/camviz/utils/image.py b/externals/camviz/camviz/utils/image.py
new file mode 100755
index 0000000000000000000000000000000000000000..134bcd7a8c7d8487856247df5092b416ddc5dde5
--- /dev/null
+++ b/externals/camviz/camviz/utils/image.py
@@ -0,0 +1,26 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from PIL import Image
+
+def load_image(file, shape=None):
+ """
+    Load an image and optionally resize it
+
+ Parameters
+ ----------
+ file : str
+ Image filename
+ shape : tuple (width, height)
+ Optional reshape size
+
+ Returns
+ -------
+    image : np.array [H,W] or [H,W,C]
+ Loaded image
+ """
+ image = Image.open(file)
+ if shape:
+ image = image.resize(shape, resample=Image.ANTIALIAS)
+ return np.array(image)
+
diff --git a/externals/camviz/camviz/utils/types.py b/externals/camviz/camviz/utils/types.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca4954dbd84e67eac4134322ec6a327aae967766
--- /dev/null
+++ b/externals/camviz/camviz/utils/types.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+import torch
+
+
+def is_numpy(data):
+ """Checks if data is a numpy array."""
+ return isinstance(data, np.ndarray)
+
+
+def is_tensor(data):
+ """Checks if data is a torch tensor."""
+ return type(data) == torch.Tensor
+
+
+def is_tuple(data):
+ """Checks if data is a tuple."""
+ return isinstance(data, tuple)
+
+
+def is_list(data):
+ """Checks if data is a list."""
+ return isinstance(data, list)
+
+
+def is_double_list(data):
+ """Checks if data is a double list (list of lists)"""
+ return is_list(data) and is_list(data[0])
+
+
+def is_dict(data):
+ """Checks if data is a dictionary."""
+ return isinstance(data, dict)
+
+
+def is_str(data):
+ """Checks if data is a string."""
+ return isinstance(data, str)
+
+
+def is_int(data):
+ """Checks if data is an integer."""
+ return isinstance(data, int)
+
+
+def is_float(data):
+ """Checks if data is a float value."""
+ return isinstance(data, float)
+
+
+def is_seq(data):
+ """Checks if data is a list or tuple."""
+ return is_tuple(data) or is_list(data)
+
diff --git a/externals/camviz/camviz/utils/utils.py b/externals/camviz/camviz/utils/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..b91f02dcf0718df574404b1fb7ec249f1495540d
--- /dev/null
+++ b/externals/camviz/camviz/utils/utils.py
@@ -0,0 +1,83 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+from matplotlib.cm import get_cmap
+
+from camviz.utils.types import is_numpy, is_tensor
+
+
+def add_row0(npy):
+ """Add a row with zeros to a numpy array"""
+ return np.vstack([npy, np.zeros((1, npy.shape[1]))])
+
+def add_col1(npy):
+ """Add a column with ones to a numpy array"""
+ return np.hstack([npy, np.ones((npy.shape[0], 1))])
+
+def flatten(lst):
+ """Flatten a list of lists into a list"""
+ return [l for ls in lst for l in ls]
+
+def change_coords(xyz1):
+ """Flip coordinates from camera to lidar frame of reference"""
+ xyz2 = xyz1[:, [0, 2, 1]]
+ xyz2[:, 1] *= -1
+ return xyz2
+
+def numpyf(data):
+ """Convert data to a numpy array if necessary"""
+ return data if is_numpy(data) else \
+ data.cpu().detach().numpy() if is_tensor(data) else \
+ np.array(data, dtype=np.float32)
+
+def labelrc(tup):
+ """Create row and column labels for buffers"""
+ if len(tup) == 2:
+ tup = ('', tup[0], tup[1])
+ return [['%s%d%d' % (tup[0], j, i)
+ for i in range(tup[2])] for j in range(tup[1])]
+
+def add_list(lst1, lst2):
+ """Add two lists element-wise"""
+ return [l1 + l2 for l1, l2 in zip(lst1, lst2)]
+
+def image_grid(mat):
+    """Create a [H*W,3] grid of homogeneous pixel coordinates (u, v, 1) from a [H,W,...] array"""
+    i, j = mat.shape[:2]
+    u, v = np.meshgrid(np.arange(j), np.arange(i))
+    return np.stack([u.reshape(-1), v.reshape(-1), np.ones(i * j)], 1)
+
+def alternate_points(x1, x2):
+    """Interleave two [N,3] point arrays into a single [2N,3] array"""
+    x = np.zeros((x1.shape[0] + x2.shape[0], 3))
+    for i in range(x1.shape[0]):
+        x[2*i], x[2*i+1] = x1[i], x2[i]
+    return x
+
+
+def vis_inverse_depth(inv_depth, normalizer=None, percentile=95, colormap='plasma'):
+    """Colorize an inverse depth map with a matplotlib colormap, normalizing by a value or a percentile"""
+    cm = get_cmap(colormap)
+    if normalizer:
+        inv_depth /= normalizer
+    else:
+        inv_depth /= (np.percentile(inv_depth, percentile) + 1e-6)
+    return cm(np.clip(inv_depth, 0., 1.0))[:, :, :3]
+
+
+def grid_idx(grid):
+    """Build quad vertex indices (four per quad) for a regular grid of [nx,ny] vertices"""
+
+ nx, ny = grid.shape[:2]
+ nqx, nqy = nx - 1, ny - 1
+ nqt = nqx * nqy
+
+ idx = np.zeros(4 * nqt)
+ cnt_idx, cnt_data = 0, 0
+ for i in range(nx):
+ for j in range(ny):
+ if i < nqx and j < nqy:
+ idx[cnt_idx + 0] = cnt_data
+ idx[cnt_idx + 1] = cnt_data + 1
+ idx[cnt_idx + 2] = cnt_data + ny + 1
+ idx[cnt_idx + 3] = cnt_data + ny
+ cnt_idx += 4
+ cnt_data += 1
+
+ return idx
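
Two of the helpers above are easiest to understand from tiny inputs (standalone sketch, assuming camviz is importable): `labelrc` expands a `(prefix, rows, cols)` tuple into per-cell buffer names, and `grid_idx` emits four vertex indices per quad of a regular vertex grid.

```python
import numpy as np
from camviz.utils.utils import labelrc, grid_idx

# Per-cell names for a 2x2 layout of buffers/screens
assert labelrc(('cam', 2, 2)) == [['cam00', 'cam01'], ['cam10', 'cam11']]

# Quad indices for a 2x2 grid of vertices: a single quad looping 0 -> 1 -> 3 -> 2
grid = np.zeros((2, 2, 3))
assert list(grid_idx(grid)) == [0, 1, 3, 2]
```
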
diff --git a/externals/camviz/demos/data/ddad_eval.npz b/externals/camviz/demos/data/ddad_eval.npz
new file mode 100755
index 0000000000000000000000000000000000000000..3d08956bcbdfdf1a595ad76b90b3675765dd765a
--- /dev/null
+++ b/externals/camviz/demos/data/ddad_eval.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f552b98d21c639df946df0ddf5e80d0423b8f392035720155cb30524e8d55a1
+size 1862400
diff --git a/externals/camviz/demos/pointcloud.py b/externals/camviz/demos/pointcloud.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd56e01d068600ec2e0d381c1c70eb2a31b00a16
--- /dev/null
+++ b/externals/camviz/demos/pointcloud.py
@@ -0,0 +1,72 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+
+import camviz as cv
+
+# Load evaluation data
+data = np.load('demos/data/ddad_eval.npz')
+
+# Get image resolution
+wh = data['rgb'].shape[:2][::-1]
+
+# Create draw tool with specific width and height window dimensions
+draw = cv.Draw(wh=(2000, 900), title='CamViz Pointcloud Demo')
+
+# Create image screen to show the RGB image
+draw.add2Dimage('rgb', luwh=(0.00, 0.00, 0.33, 0.50), res=wh)
+
+# Create image screen to show the depth visualization
+draw.add2Dimage('viz', luwh=(0.00, 0.50, 0.33, 1.00), res=wh)
+
+# Create world screen at specific position inside the window (% left/up/right/down)
+draw.add3Dworld('wld', luwh=(0.33, 0.00, 1.00, 1.00),
+ pose=(7.25323, -3.80291, -5.89996, 0.98435, 0.07935, 0.15674, 0.01431))
+
+# Parse dictionary information
+rgb = data['rgb']
+intrinsics = data['intrinsics']
+depth = data['depth']
+viz = data['viz']
+
+# Create camera from intrinsics and image dimensions (width and height)
+camera = cv.objects.Camera(K=intrinsics, wh=wh)
+
+# Project depth maps from image (i) to camera (c) coordinates
+points = camera.i2c(depth)
+
+# Create pointcloud colors
+rgb_clr = rgb.reshape(-1, 3) # RGB colors
+viz_clr = viz.reshape(-1, 3) # Depth visualization colors
+hgt_clr = cv.utils.cmaps.jet(-points[:, 1]) # Height colors
+
+# Create RGB and visualization textures
+draw.addTexture('rgb', rgb) # Create texture buffer to store rgb image
+draw.addTexture('viz', viz) # Create texture buffer to store visualization image
+
+# Create buffers to store data for display
+draw.addBufferf('pts', points) # Create data buffer to store depth points
+draw.addBufferf('clr', rgb_clr) # Create data buffer to store rgb points color
+draw.addBufferf('viz', viz_clr) # Create data buffer to store depth visualization colors
+draw.addBufferf('hgt', hgt_clr) # Create data buffer to store pointcloud heights
+
+# Color dictionary
+color_dict = {0: 'clr', 1: 'viz', 2: 'hgt'}
+
+# Display loop
+color_mode = 0
+while draw.input():
+ # If RETURN is pressed, switch color mode
+ if draw.RETURN:
+ color_mode = (color_mode + 1) % len(color_dict)
+ # Clear window
+ draw.clear()
+ # Draw image textures on their respective screens
+ draw['rgb'].image('rgb')
+ draw['viz'].image('viz')
+ # Draw points and colors from buffer
+ draw['wld'].size(2).points('pts', color_dict[color_mode])
+ # Draw camera with texture as image
+ draw['wld'].object(camera, tex='rgb')
+ # Update window
+ draw.update(30)
diff --git a/externals/camviz/media/figs/demo_pointcloud.png b/externals/camviz/media/figs/demo_pointcloud.png
new file mode 100755
index 0000000000000000000000000000000000000000..fa5ccc13c91fa44bfae11637d2faa2c422677866
--- /dev/null
+++ b/externals/camviz/media/figs/demo_pointcloud.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6974972ece410a4e2462166e71a58db55887ec8448dbd7e01e887284543a65df
+size 1357373
diff --git a/externals/camviz/media/figs/tri-logo.png b/externals/camviz/media/figs/tri-logo.png
new file mode 100755
index 0000000000000000000000000000000000000000..926ca3c21e08abfd3f0293881b35fa5069bc6af9
--- /dev/null
+++ b/externals/camviz/media/figs/tri-logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fff5679bfb6c236e48d49d48d6d33cfa673ea91079ffc988c09a466d98adaf6
+size 9260
diff --git a/externals/camviz/media/gifs/fsm.gif b/externals/camviz/media/gifs/fsm.gif
new file mode 100755
index 0000000000000000000000000000000000000000..ba5d555e598c1d0ec54e2c83a7959a5c6e655200
--- /dev/null
+++ b/externals/camviz/media/gifs/fsm.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7243ba0e55a93e05dfc586313b52c1d9e17ae5cc0f299ea927700035a1b57605
+size 3919285
diff --git a/externals/camviz/media/gifs/guda.gif b/externals/camviz/media/gifs/guda.gif
new file mode 100755
index 0000000000000000000000000000000000000000..0b09b753c16e1e07fa8c6c8cfdcff18af84d3ebf
--- /dev/null
+++ b/externals/camviz/media/gifs/guda.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fffe3626fa2d8cdd8bfd432d7557db8eeaca7fa6a057e6d8d9be81bbaba7733f
+size 4059325
diff --git a/externals/camviz/media/gifs/packnet-ddad.gif b/externals/camviz/media/gifs/packnet-ddad.gif
new file mode 100755
index 0000000000000000000000000000000000000000..afb9941c5bee2027106259b30152c24f394ea105
--- /dev/null
+++ b/externals/camviz/media/gifs/packnet-ddad.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d24eb15de8c40c9797d9f1222250ca64c64eeafa04dcf31d19f1309210c3c49
+size 4145246
diff --git a/externals/camviz/media/gifs/packnet-san.gif b/externals/camviz/media/gifs/packnet-san.gif
new file mode 100755
index 0000000000000000000000000000000000000000..294a9cad627e0097c0504a826d1242ab3112a7e7
--- /dev/null
+++ b/externals/camviz/media/gifs/packnet-san.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc1648b6be7ef1a0833c9e62fa63ecc8d1d22bee4b550bf6c3c47489f807010c
+size 9579461
diff --git a/externals/camviz/requirements.txt b/externals/camviz/requirements.txt
new file mode 100755
index 0000000000000000000000000000000000000000..2f41d3e366d68b51fc5ec26c03717892efbab076
--- /dev/null
+++ b/externals/camviz/requirements.txt
@@ -0,0 +1,10 @@
+matplotlib==3.3.4
+numpy==1.19.5
+opencv-python==4.5.1.48
+Pillow==8.2.0
+pygame==2.0.1
+PyOpenGL==3.1.5
+pyparsing==2.4.7
+python-dateutil==2.8.1
+torch==1.8.1
+
diff --git a/media/figs/camviz_ddad.jpg b/media/figs/camviz_ddad.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..07ea4ccfb2f70f5033a48ef031dceb9a2ded0e15
Binary files /dev/null and b/media/figs/camviz_ddad.jpg differ
diff --git a/media/figs/camviz_kitti.jpg b/media/figs/camviz_kitti.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..31b08b0c039c72485fad682ad6fa5da59c80d76b
Binary files /dev/null and b/media/figs/camviz_kitti.jpg differ
diff --git a/media/figs/depthformer.gif b/media/figs/depthformer.gif
new file mode 100644
index 0000000000000000000000000000000000000000..84f57099046f4b134a9d7dea8a9b9bcec9b0063d
--- /dev/null
+++ b/media/figs/depthformer.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e884d6cc819906c1c90582dd6b3b4ec95da17c8ffa1ee4db2883a16f31c3c04
+size 16166886
diff --git a/media/figs/fsm.gif b/media/figs/fsm.gif
new file mode 100755
index 0000000000000000000000000000000000000000..ba5d555e598c1d0ec54e2c83a7959a5c6e655200
--- /dev/null
+++ b/media/figs/fsm.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7243ba0e55a93e05dfc586313b52c1d9e17ae5cc0f299ea927700035a1b57605
+size 3919285
diff --git a/media/figs/overfit_kitti.jpg b/media/figs/overfit_kitti.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1904875589652d526ff8634f04a7964d03ff3dcf
Binary files /dev/null and b/media/figs/overfit_kitti.jpg differ
diff --git a/media/figs/packnet.gif b/media/figs/packnet.gif
new file mode 100644
index 0000000000000000000000000000000000000000..afb9941c5bee2027106259b30152c24f394ea105
--- /dev/null
+++ b/media/figs/packnet.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d24eb15de8c40c9797d9f1222250ca64c64eeafa04dcf31d19f1309210c3c49
+size 4145246
diff --git a/media/figs/self-calibration.gif b/media/figs/self-calibration.gif
new file mode 100644
index 0000000000000000000000000000000000000000..a194e1f09ba3122e2700da44448dd518fe9f2fde
--- /dev/null
+++ b/media/figs/self-calibration.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fabc94a693efa3044c9588c60fb528858427a23f24f6af789ecb1ad50d929f18
+size 7183493
diff --git a/media/figs/tri-logo.png b/media/figs/tri-logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..926ca3c21e08abfd3f0293881b35fa5069bc6af9
--- /dev/null
+++ b/media/figs/tri-logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fff5679bfb6c236e48d49d48d6d33cfa673ea91079ffc988c09a466d98adaf6
+size 9260
diff --git a/media/tests/ddad.png b/media/tests/ddad.png
new file mode 100644
index 0000000000000000000000000000000000000000..f099701bc286cb987a0ccf334c540faee432b466
--- /dev/null
+++ b/media/tests/ddad.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b2de170d88a1b4af88b3f99c021f862ef71b96d6204f0d598f5a08cd6dbf59e
+size 350154
diff --git a/media/tests/kitti.png b/media/tests/kitti.png
new file mode 100644
index 0000000000000000000000000000000000000000..9a8ceac798f4f0852358443286d54507a0ad00c9
--- /dev/null
+++ b/media/tests/kitti.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:401c2f28e2721202fad012426fd06bbf2215209b7057405e4b2c75859cd24599
+size 920682
diff --git a/scripts/run.py b/scripts/run.py
new file mode 100755
index 0000000000000000000000000000000000000000..547f7d3f2b4ff82601646d61b4da5eb4ace8a620
--- /dev/null
+++ b/scripts/run.py
@@ -0,0 +1,25 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+
+import fire
+import torch
+
+from vidar.core.trainer import Trainer
+from vidar.core.wrapper import Wrapper
+from vidar.utils.config import read_config
+
+
+def train(cfg, **kwargs):
+
+ os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'
+
+ cfg = read_config(cfg, **kwargs)
+
+ wrapper = Wrapper(cfg, verbose=True)
+ trainer = Trainer(cfg)
+ trainer.learn(wrapper)
+
+
+if __name__ == '__main__':
+ fire.Fire(train)
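
The entry point is exposed through python-fire, so training is normally launched as `python scripts/run.py <config.yaml> [--key value]`, with extra flags forwarded to `read_config`. A hedged sketch of the equivalent direct call (the config path is hypothetical, and it assumes the repository root is on the Python path):

```python
# Equivalent to: python scripts/run.py configs/my_experiment.yaml
from scripts.run import train

train('configs/my_experiment.yaml')   # hypothetical config file; kwargs are forwarded to read_config
```
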
diff --git a/scripts/run_ddp.py b/scripts/run_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd1df1f1c688a9e0ac2aff9a5562163e91a6424
--- /dev/null
+++ b/scripts/run_ddp.py
@@ -0,0 +1,47 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+
+import fire
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from vidar.core.trainer import Trainer
+from vidar.core.wrapper import Wrapper
+from vidar.utils.config import read_config
+
+
+def train(cfg, **kwargs):
+
+ os.environ['DIST_MODE'] = 'ddp'
+
+ cfg = read_config(cfg, **kwargs)
+
+ mp.spawn(main_worker,
+ nprocs=torch.cuda.device_count(),
+ args=(cfg,), join=True)
+
+
+def main_worker(gpu, cfg):
+
+ torch.cuda.set_device(gpu)
+ world_size = torch.cuda.device_count()
+
+ os.environ['RANK'] = str(gpu)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '12355'
+ os.environ['DIST_MODE'] = 'ddp'
+
+ dist.init_process_group(backend='nccl', world_size=world_size, rank=gpu)
+
+ wrapper = Wrapper(cfg, verbose=True)
+ trainer = Trainer(cfg)
+ trainer.learn(wrapper)
+
+ dist.destroy_process_group()
+
+
+if __name__ == '__main__':
+ fire.Fire(train)
diff --git a/vidar/__init__.py b/vidar/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/__init__.py b/vidar/arch/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/blocks/depth/SigmoidToDepth.py b/vidar/arch/blocks/depth/SigmoidToDepth.py
new file mode 100755
index 0000000000000000000000000000000000000000..1bf753e00f4e1b822c2dcdc2a8545e2f04fd65a3
--- /dev/null
+++ b/vidar/arch/blocks/depth/SigmoidToDepth.py
@@ -0,0 +1,30 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch.nn as nn
+
+from vidar.utils.decorators import iterate2
+
+
+class SigmoidToDepth(nn.Module, ABC):
+ """
+ Converts sigmoid to depth map
+
+ Parameters
+ ----------
+ min_depth : Float
+ Minimum depth value
+    max_depth : Float
+ Maximum depth value
+ """
+ def __init__(self, min_depth, max_depth):
+ super().__init__()
+ self.min_depth = min_depth
+ self.max_depth = max_depth
+ self.diff_depth = (self.max_depth - self.min_depth)
+
+ @iterate2
+ def forward(self, sigmoid):
+ """Convert sigmoid to depth"""
+ return self.min_depth + self.diff_depth * sigmoid
diff --git a/vidar/arch/blocks/depth/SigmoidToInvDepth.py b/vidar/arch/blocks/depth/SigmoidToInvDepth.py
new file mode 100755
index 0000000000000000000000000000000000000000..70eb16903861fa711d0ea169ab59f95f9b280396
--- /dev/null
+++ b/vidar/arch/blocks/depth/SigmoidToInvDepth.py
@@ -0,0 +1,35 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch.nn as nn
+
+from vidar.utils.decorators import iterate2
+from vidar.utils.depth import inv2depth
+
+
+class SigmoidToInvDepth(nn.Module, ABC):
+ """
+ Converts sigmoid to inverse depth map
+
+ Parameters
+ ----------
+ min_depth : Float
+ Minimum depth value
+    max_depth : Float
+ Maximum depth value
+    return_depth : Bool
+ Whether the inverse depth map is inverted to depth when returning
+ """
+ def __init__(self, min_depth, max_depth, return_depth=False):
+ super().__init__()
+ self.min_inv_depth = 1. / max_depth
+ self.max_inv_depth = 1. / min_depth
+ self.diff_inv_depth = (self.max_inv_depth - self.min_inv_depth)
+ self.return_depth = return_depth
+
+ @iterate2
+ def forward(self, sigmoid):
+ """Convert sigmoid to inverse depth"""
+ inv_depth = self.min_inv_depth + self.diff_inv_depth * sigmoid
+ return inv_depth if not self.return_depth else inv2depth(inv_depth)
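
The mapping above sends a sigmoid of 0 to `1 / max_depth` and a sigmoid of 1 to `1 / min_depth`, so the decoded depth spans `[min_depth, max_depth]`; a quick torch sketch of the two endpoints:

```python
import torch

min_depth, max_depth = 0.1, 100.0
min_inv, max_inv = 1.0 / max_depth, 1.0 / min_depth

sigmoid = torch.tensor([0.0, 1.0])
inv_depth = min_inv + (max_inv - min_inv) * sigmoid
depth = 1.0 / inv_depth

assert torch.allclose(depth, torch.tensor([max_depth, min_depth]))
```
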
diff --git a/vidar/arch/blocks/depth/SigmoidToLogDepth.py b/vidar/arch/blocks/depth/SigmoidToLogDepth.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b1a4c631f5b028bfccd1d1894a01b0dfa6b2897
--- /dev/null
+++ b/vidar/arch/blocks/depth/SigmoidToLogDepth.py
@@ -0,0 +1,21 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+
+from vidar.utils.decorators import iterate2
+
+
+class SigmoidToLogDepth(nn.Module, ABC):
+ """
+ Converts sigmoid to a log depth map
+ """
+ def __init__(self):
+ super().__init__()
+
+ @iterate2
+ def forward(self, sigmoid):
+ """Convert sigmoid to log depth"""
+ return torch.exp(sigmoid)
diff --git a/vidar/arch/blocks/image/ViewSynthesis.py b/vidar/arch/blocks/image/ViewSynthesis.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb0eed8d05ecbf4870209f37293d005d5acc5ded
--- /dev/null
+++ b/vidar/arch/blocks/image/ViewSynthesis.py
@@ -0,0 +1,116 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from vidar.utils.flow import coords_from_optical_flow
+from vidar.utils.tensor import grid_sample, interpolate
+from vidar.utils.types import is_list
+
+
+class ViewSynthesis(nn.Module, ABC):
+ """
+ Class for view synthesis calculation based on image warping
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg=None):
+ super().__init__()
+ self.grid_sample = partial(
+ grid_sample, mode='bilinear', padding_mode='border', align_corners=True)
+ self.interpolate = partial(
+ interpolate, mode='bilinear', scale_factor=None, align_corners=True)
+ self.grid_sample_zeros = partial(
+ grid_sample, mode='nearest', padding_mode='zeros', align_corners=True)
+ self.upsample_depth = cfg.has('upsample_depth', True) if cfg is not None else True
+
+ @staticmethod
+ def get_num_scales(depths, optical_flow):
+ """Return number of scales based on input"""
+ if depths is not None:
+ return len(depths)
+ if optical_flow is not None:
+ return len(optical_flow)
+ else:
+ raise ValueError('Invalid inputs for view synthesis')
+
+ @staticmethod
+ def get_tensor_ones(depths, optical_flow, scale):
+ """Return unitary tensor based on input"""
+ if depths is not None:
+ return torch.ones_like(depths[scale])
+ elif optical_flow is not None:
+ b, _, h, w = optical_flow[scale].shape
+ return torch.ones((b, 1, h, w), device=optical_flow[scale].device)
+ else:
+ raise ValueError('Invalid inputs for view synthesis')
+
+ def get_coords(self, rgbs, depths, cams, optical_flow, context, scale, tgt):
+ """
+ Calculate projection coordinates for warping
+
+ Parameters
+ ----------
+ rgbs : list[torch.Tensor]
+ Input images (for dimensions) [B,3,H,W]
+ depths : list[torch.Tensor]
+ Target depth maps [B,1,H,W]
+ cams : list[Camera]
+ Input cameras
+ optical_flow : list[torch.Tensor]
+ Input optical flow for alternative warping
+ context : list[Int]
+ Context indices
+ scale : Int
+ Current scale
+ tgt : Int
+ Target index
+
+ Returns
+ -------
+ output : Dict
+ Dictionary containing warped images and masks
+ """
+ if depths is not None and cams is not None:
+ cams_tgt = cams[0] if is_list(cams) else cams
+ cams_ctx = cams[1] if is_list(cams) else cams
+ depth = self.interpolate(depths[scale], size=rgbs[tgt][0].shape[-2:]) \
+ if self.upsample_depth else depths[scale]
+ return {
+ ctx: cams_tgt[tgt].coords_from_depth(depth, cams_ctx[ctx]) for ctx in context
+ }
+ elif optical_flow is not None:
+ return {
+ ctx: coords_from_optical_flow(
+ optical_flow[scale]).permute(0, 2, 3, 1) for ctx in context
+ }
+ else:
+ raise ValueError('Invalid input for view synthesis')
+
+ def forward(self, rgbs, depths=None, cams=None,
+ optical_flow=None, return_masks=False, tgt=0):
+
+ context = [key for key in rgbs.keys() if key != tgt]
+ num_scales = self.get_num_scales(depths, optical_flow)
+
+ warps, masks = [], []
+ for scale in range(num_scales):
+ src = 0 if self.upsample_depth else scale
+ coords = self.get_coords(rgbs, depths, cams, optical_flow, context, scale, tgt=tgt)
+ warps.append([self.grid_sample(
+ rgbs[ctx][src], coords[ctx].type(rgbs[ctx][src].dtype)) for ctx in context])
+ if return_masks:
+ ones = self.get_tensor_ones(depths, optical_flow, scale)
+ masks.append([self.grid_sample_zeros(
+ ones, coords[ctx].type(ones.dtype)) for ctx in context])
+
+ return {
+ 'warps': warps,
+ 'masks': masks if return_masks else None
+ }
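
ViewSynthesis warps each context image into the target view by sampling it at projected coordinates with `grid_sample`, and builds a validity mask by nearest-neighbor sampling a tensor of ones with zero padding. Below is a minimal standalone sketch of that sampling step; it uses an identity grid instead of camera- or flow-derived coordinates and does not rely on the repo's Camera or Config classes.

```python
# Standalone sketch of the warping step used by ViewSynthesis: sample a source
# image at normalized coordinates and build a validity mask. Coordinates here
# come from an identity grid purely for illustration.
import torch
import torch.nn.functional as F

def warp_with_coords(src, coords):
    """Warp src [B,C,H,W] with normalized coords [B,H,W,2] in [-1, 1]."""
    warped = F.grid_sample(src, coords, mode='bilinear',
                           padding_mode='border', align_corners=True)
    ones = torch.ones_like(src[:, :1])
    mask = F.grid_sample(ones, coords, mode='nearest',
                         padding_mode='zeros', align_corners=True)
    return warped, mask

b, c, h, w = 1, 3, 8, 12
src = torch.rand(b, c, h, w)
# Identity sampling grid in normalized [-1, 1] coordinates
ys = torch.linspace(-1, 1, h)
xs = torch.linspace(-1, 1, w)
grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
coords = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # [1,H,W,2]

warped, mask = warp_with_coords(src, coords)
print(torch.allclose(warped, src, atol=1e-6), mask.min().item())  # True 1.0
```
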
diff --git a/vidar/arch/losses/BaseLoss.py b/vidar/arch/losses/BaseLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d614b61af482fb04830b0a32cee2ed9045f03e82
--- /dev/null
+++ b/vidar/arch/losses/BaseLoss.py
@@ -0,0 +1,125 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from collections import OrderedDict
+from functools import partial
+
+import torch.nn as nn
+
+from vidar.utils.tensor import same_shape, interpolate
+
+
+class BaseLoss(nn.Module, ABC):
+ """
+ Base class for loss calculation
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg=None):
+ super().__init__()
+
+ self.losses = OrderedDict()
+ self.blocks = OrderedDict()
+
+ self.nearest = partial(interpolate, scale_factor=None, mode='nearest', align_corners=None)
+ self.bilinear = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True)
+
+ if cfg is not None:
+ self.gamma = cfg.has('gamma', 1.0)
+ self.weight = cfg.has('weight', 1.0)
+ self.scales = cfg.has('scales', 99)
+
+ self.flag_mask_sparse = cfg.has('mask_sparse', False)
+ self.flag_mask_range = cfg.has('mask_range', None)
+
+ def forward(self, *args, **kwargs):
+ """Forward method"""
+ raise NotImplementedError('Forward not implemented for {}'.format(self.__class__.__name__))
+
+ def get_weights(self, scales):
+ """Get scale weights"""
+ return [self.weight * self.gamma ** i for i in range(scales)]
+
+ def get_scales(self, scales):
+ """Get number of scales"""
+ return min(self.scales, len(scales))
+
+ @staticmethod
+ def interp(dst, src, fn):
+ """Interpolate dst to match src using fn"""
+ if dst is None or dst.dim() == 3:
+ return dst
+ assert dst.dim() == src.dim()
+ if dst.dim() == 4 and not same_shape(dst.shape, src.shape):
+ dst = fn(dst, size=src)
+ return dst
+
+ def interp_bilinear(self, dst, src):
+ """Bilinear interpolation"""
+ return self.interp(dst, src, self.bilinear)
+
+ def interp_nearest(self, dst, src):
+ """Nearest-neighbor interpolation"""
+ return self.interp(dst, src, self.nearest)
+
+ def mask_sparse(self, mask, gt):
+ """Mask based on sparse GT"""
+ if mask is None:
+ return mask
+ if self.flag_mask_sparse:
+ mask *= gt.sum(1) > 0
+ return mask
+
+ def mask_range(self, mask, gt):
+ """Mask based on depth range"""
+ if mask is None:
+ return mask
+ if self.flag_mask_range is None:
+ return mask
+ mask *= (gt.sum(1) >= self.flag_mask_range[0]) & \
+ (gt.sum(1) <= self.flag_mask_range[1])
+ return mask
+
+ @staticmethod
+ def flatten(pred, gt, mask=None, soft_mask=None):
+ """
+ Flatten 2D inputs for loss calculation
+
+ Parameters
+ ----------
+ pred : torch.Tensor
+ Input predictions
+ gt : torch.Tensor
+ Input ground-truth
+ mask : torch.Tensor or None
+ Input mask (binary)
+ soft_mask : torch.Tensor or None
+ Input soft mask (probability)
+
+ Returns
+ -------
+ pred, gt, mask, soft_mask : torch.Tensor
+ Flattened inputs
+ """
+ if pred.dim() == 4:
+ pred = pred.permute(0, 2, 3, 1)
+ pred = pred.reshape(-1, pred.shape[-1])
+
+ if gt.dim() == 4:
+ gt = gt.permute(0, 2, 3, 1)
+ gt = gt.reshape(-1, gt.shape[-1])
+
+ if mask is not None:
+ if mask.dim() == 4:
+ mask = mask.permute(0, 2, 3, 1)
+ mask = mask.reshape(-1)
+
+ if soft_mask is not None:
+ if soft_mask.dim() == 4:
+ soft_mask = soft_mask.permute(0, 2, 3, 1)
+ soft_mask = soft_mask.reshape(-1)
+
+ return pred, gt, mask, soft_mask
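
BaseLoss weights the loss at scale i by `weight * gamma ** i` before averaging across scales. A tiny sketch of that weighting, with illustrative numbers rather than values from any config:

```python
# Minimal sketch of BaseLoss-style scale weighting: scale i gets weight
# weight * gamma ** i, and the weighted losses are averaged. Numbers are
# illustrative only.
import torch

weight, gamma = 1.0, 0.5
per_scale_losses = [torch.tensor(0.8), torch.tensor(0.6), torch.tensor(0.5)]

weights = [weight * gamma ** i for i in range(len(per_scale_losses))]
weighted = [w * l for w, l in zip(weights, per_scale_losses)]
total = sum(weighted) / len(weighted)
print(weights, total.item())  # [1.0, 0.5, 0.25] 0.4083...
```
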
diff --git a/vidar/arch/losses/ConsistencyLoss.py b/vidar/arch/losses/ConsistencyLoss.py
new file mode 100755
index 0000000000000000000000000000000000000000..15a2a27859d8f22ae4600a90ca0802cea8f054b1
--- /dev/null
+++ b/vidar/arch/losses/ConsistencyLoss.py
@@ -0,0 +1,101 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from functools import partial
+
+import torch
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.utils.data import get_mask_from_list
+from vidar.utils.tensor import interpolate, same_shape, multiply_mask, masked_average
+from vidar.utils.types import is_list
+
+
+class ConsistencyLoss(BaseLoss, ABC):
+ """
+ Consistency loss class
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self.interpolate = partial(
+ interpolate, mode='nearest', scale_factor=None, align_corners=None)
+
+ def calculate(self, teacher, student, confidence_mask, valid_mask=None):
+ """
+ Calculate consistency loss
+
+ Parameters
+ ----------
+ teacher : torch.Tensor
+ Teacher depth predictions [B,1,H,W]
+ student : torch.Tensor
+ Student depth predictions [B,1,H,W]
+ confidence_mask : torch.Tensor
+ Confidence mask for pixel selection [B,1,H,W]
+ valid_mask : torch.Tensor
+ Valid mask for pixel selection [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Consistency loss [1]
+ """
+ if not same_shape(teacher.shape[-2:], student.shape[-2:]):
+ teacher = self.interpolate(teacher, size=student.shape[-2:])
+ if not same_shape(confidence_mask.shape, teacher.shape):
+ confidence_mask = self.interpolate(confidence_mask, size=teacher.shape[-2:])
+ if valid_mask is not None and not same_shape(valid_mask.shape, teacher.shape):
+ valid_mask = self.interpolate(valid_mask, size=teacher.shape[-2:])
+ non_confidence_mask = (1 - confidence_mask).float()
+ consistency_loss = torch.abs(student - teacher.detach())
+ return masked_average(consistency_loss,
+ multiply_mask(non_confidence_mask, valid_mask))
+
+ def forward(self, teacher, student, confidence_mask, valid_mask=None):
+ """
+ Forward loop for loss calculation
+
+ Parameters
+ ----------
+ teacher : list[torch.Tensor]
+ Teacher depth predictions [B,1,H,W]
+ student : list[torch.Tensor]
+ Student depth predictions [B,1,H,W]
+ confidence_mask : list[torch.Tensor]
+ Confidence mask for pixel selection [B,1,H,W]
+ valid_mask : list[torch.Tensor]
+ Valid mask for pixel selection [B,1,H,W]
+
+ Returns
+ -------
+ output : Dict
+ Dictionary with loss and metrics
+ """
+ scales = self.get_scales(student)
+ weights = self.get_weights(scales)
+
+ losses, metrics = [], {}
+
+ for i in range(scales):
+ teacher_i = teacher[i] if is_list(teacher) else teacher
+ student_i = student[i] if is_list(student) else student
+ confidence_mask_i = get_mask_from_list(confidence_mask, i)
+ valid_mask_i = get_mask_from_list(valid_mask, i)
+
+ loss_i = weights[i] * self.calculate(
+ teacher_i, student_i, confidence_mask_i, valid_mask_i)
+
+ metrics[f'consistency_loss/{i}'] = loss_i.detach()
+ losses.append(loss_i)
+
+ loss = sum(losses) / len(losses)
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ }
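
ConsistencyLoss penalizes the L1 distance between the student prediction and the detached teacher, averaged only over pixels outside the confidence mask (and inside the valid mask, if given). A standalone sketch of that masked average, with made-up shapes and values:

```python
# Standalone sketch of the consistency term: L1 between student and a detached
# teacher, averaged over pixels where the confidence mask is 0 (and the valid
# mask, if any, is 1).
import torch

def consistency(teacher, student, confidence_mask, valid_mask=None):
    mask = (1 - confidence_mask).float()
    if valid_mask is not None:
        mask = mask * valid_mask.float()
    diff = torch.abs(student - teacher.detach())
    return (diff * mask).sum() / mask.sum().clamp(min=1)

teacher = torch.full((1, 1, 4, 4), 2.0)
student = torch.full((1, 1, 4, 4), 2.5, requires_grad=True)
confidence_mask = torch.zeros(1, 1, 4, 4)   # no pixels trusted -> all contribute
loss = consistency(teacher, student, confidence_mask)
loss.backward()
print(loss.item())  # 0.5
```
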
diff --git a/vidar/arch/losses/MultiCamPhotometricLoss.py b/vidar/arch/losses/MultiCamPhotometricLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e3b5ed313f6b5a6fc229daea437ad85f0802918
--- /dev/null
+++ b/vidar/arch/losses/MultiCamPhotometricLoss.py
@@ -0,0 +1,173 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.arch.losses.MultiViewPhotometricLoss import MultiViewPhotometricLoss
+from vidar.arch.networks.layers.fsm.utils import coords_from_motion, warp_from_coords, mask_from_coords
+from vidar.utils.depth import inv2depth
+from vidar.utils.tensor import match_scales
+from vidar.utils.types import is_list, is_double_list
+
+
+class MultiCamPhotometricLoss(MultiViewPhotometricLoss):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # Large value for loss masking
+ self.inf = 999999
+ self.align_corners = True
+
+ def warp(self, rgb_context, inv_depths, cam, cam_context,
+ scene_flow=None, with_mask=True):
+ # Initialize empty warp and mask list
+ warps_context, masks_context = [], []
+ # If mask is available, use it instead of calculating
+ if is_list(with_mask):
+ masks_context, with_mask = with_mask, False
+ # Match inverse depth scales on reference images if necessary
+ rgbs_context = rgb_context if is_double_list(rgb_context) else \
+ [match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners)
+ for rgb in rgb_context]
+ # Warp each reference image to target
+ for j, (ref_rgbs, ref_cam) in enumerate(zip(rgbs_context, cam_context)):
+ # Get warping coordinates
+ ref_coords = [coords_from_motion(ref_cam, inv2depth(inv_depths[i]), cam) for i in range(self.n)]
+ # Get warped images
+ warps_context.append([warp_from_coords(
+ ref_rgbs[i], ref_coords[i], align_corners=self.align_corners, padding_mode='zeros')  # alternative: 'reflection'
+ for i in range(self.n)])
+ # Get warped masks if requested
+ if with_mask:
+ masks_context.append([mask_from_coords(ref_coords[i]) for i in range(self.n)])
+ # Return warped reference images
+ return warps_context, masks_context
+
+ def reduce_photometric_loss_min(self, photometric_losses,
+ unwarped_photometric_losses=None):
+ """
+ Reduce photometric losses using minimum reprojection error
+
+ Parameters
+ ----------
+ photometric_losses : list[Tensor]
+ Photometric losses for each warped image [B,3,H,W]
+
+ unwarped_photometric_losses : list[Tensor]
+ Unwarped photometric losses for each reference image [B,3,H,W]
+
+ Returns
+ -------
+ reduced_photometric_loss : Tensor
+ Reduced loss value (single value)
+ min_photometric_loss : Tensor
+ Masked photometric loss [B,1,H,W]
+ """
+ # Calculate minimum photometric losses
+ min_photometric_loss = [torch.cat(losses, 1).min(1, True)[0]
+ for losses in photometric_losses]
+ # Get invalid minimum mask
+ valid_mask = [warped < self.inf for warped in min_photometric_loss]
+ # If unwarped photometric losses are provided
+ if unwarped_photometric_losses is not None and \
+ len(unwarped_photometric_losses[0]) > 0:
+ # Calculate minimum unwarped photometric losses
+ min_unwarped_photometric_loss = [torch.cat(losses, 1).min(1, True)[0]
+ for losses in unwarped_photometric_losses]
+ # Get minimum mask (warped < unwarped)
+ minimum_mask = [warped < unwarped for warped, unwarped in
+ zip(min_photometric_loss, min_unwarped_photometric_loss)]
+ # Update valid mask with minimum mask
+ valid_mask = [minimum & valid for minimum, valid in
+ zip(minimum_mask, valid_mask)]
+ # Get reduced photometric loss
+ reduced_photometric_loss = sum(
+ [loss[mask].mean() for mask, loss in
+ zip(valid_mask, min_photometric_loss)]) / len(min_photometric_loss)
+ # Mask min photometric loss for visualization
+ for i in range(len(min_photometric_loss)):
+ min_photometric_loss[i][~valid_mask[i]] = 0
+ # Store and return reduced photometric loss
+ return reduced_photometric_loss, min_photometric_loss
+
+ def reduce_photometric_loss_mean(self, photometric_losses,
+ unwarped_photometric_losses=None):
+ """
+ Reduce photometric losses using mean reprojection error
+
+ Parameters
+ ----------
+ photometric_losses : list[Tensor]
+ Photometric losses for each warped image [B,3,H,W]
+
+ unwarped_photometric_losses : list[Tensor]
+ Unwarped photometric losses for each reference image [B,3,H,W]
+
+ Returns
+ -------
+ reduced_photometric_loss : Tensor
+ Reduced loss value (single value)
+ photometric_loss : list[Tensor]
+ Per-pixel photometric loss of the first context image, for visualization [B,1,H,W]
+ """
+ valid_mask = [[w < self.inf for w in warped] for warped in photometric_losses]
+ if unwarped_photometric_losses is not None and \
+ len(unwarped_photometric_losses[0]) > 0:
+ # Get minimum mask (warped < unwarped)
+ minimum_mask = [[w < u for w, u in zip(warped, unwarped)] for warped, unwarped in
+ zip(photometric_losses, unwarped_photometric_losses)]
+ # Update valid mask with minimum mask
+ valid_mask = [[m & v for m, v in zip(minimum, valid)]
+ for minimum, valid in zip(minimum_mask, valid_mask)]
+ reduced_photometric_loss = []
+ for i in range(len(photometric_losses)):
+ for j in range(len(photometric_losses[i])):
+ loss = photometric_losses[i][j][valid_mask[i][j]].mean()
+ if not torch.isnan(loss):
+ reduced_photometric_loss.append(loss)
+ reduced_photometric_loss = sum(reduced_photometric_loss) / len(reduced_photometric_loss)
+ # Store and return reduced photometric loss
+ return reduced_photometric_loss, [photometric_losses[0][0]]
+
+ def forward(self, rgb, rgb_context, inv_depths,
+ cam, cam_context, return_logs=False, progress=0.0,
+ opt_flow=None, scene_flow=None, with_mask=False, automask=None):
+ # Initialize photometric losses
+ photometric_losses = [[] for _ in range(self.n)]
+ unwarped_photometric_losses = [[] for _ in range(self.n)]
+ # Create RGB scales
+ rgbs = match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners)
+ rgbs_context = [match_scales(rgb, inv_depths, self.n, align_corners=self.align_corners)
+ for rgb in rgb_context]
+ # Warp context to target
+ warps_context, masks_context = self.warp(
+ rgbs_context, inv_depths, cam, cam_context,
+ scene_flow=scene_flow, with_mask=with_mask)
+ for j in range(len(rgbs_context)):
+ # Calculate and store image loss
+ photometric_loss = self.calc_photometric_loss(warps_context[j], rgbs)
+ for i in range(self.n):
+ if with_mask:
+ # Apply mask if available
+ photometric_loss[i][~masks_context[j][i]] = self.inf
+ # Stack photometric losses for each scale
+ photometric_losses[i].append(photometric_loss[i])
+ # If using automask, calculate and store unwarped losses
+ if self.automask_loss and automask is not False:
+ unwarped_image_loss = self.calc_photometric_loss(rgbs_context[j], rgbs)
+ for i in range(self.n):
+ unwarped_photometric_losses[i].append(unwarped_image_loss[i])
+ # Calculate reduced photometric loss
+ reduced_loss, masked_loss = self.reduce_photometric_loss_min(
+ photometric_losses, unwarped_photometric_losses)
+ # Include smoothness loss if requested
+ if self.smooth_loss_weight > 0.0:
+ reduced_loss += self.calc_smoothness_loss(inv_depths, rgbs)
+ # Return loss, warped images, masks, and per-pixel photometric error
+ return {
+ 'loss': reduced_loss.unsqueeze(0),
+ 'metrics': {},
+ 'warp': warps_context,
+ 'masks': masks_context,
+ 'photo': masked_loss,
+ }
+
+########################################################################################################################
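
`reduce_photometric_loss_min` assigns invalid warped pixels a large constant (`self.inf`) so they never win the per-pixel minimum over context images, and then averages only pixels whose minimum stays below that constant. A standalone sketch of this masked-minimum reduction, with illustrative values:

```python
# Standalone sketch of the masked-minimum reduction: invalid warped pixels are
# set to a large constant so they never win the per-pixel minimum, and only
# pixels whose minimum stays below the constant are averaged.
import torch

INF = 999999.0

losses = [torch.rand(1, 1, 4, 4), torch.rand(1, 1, 4, 4)]
masks = [torch.ones(1, 1, 4, 4, dtype=torch.bool),
         torch.ones(1, 1, 4, 4, dtype=torch.bool)]
masks[1][..., :2] = False  # pretend part of the second warp is invalid

# Mask out invalid pixels, then take the minimum over the context dimension
masked = [torch.where(m, l, torch.full_like(l, INF)) for l, m in zip(losses, masks)]
min_loss = torch.cat(masked, 1).min(1, keepdim=True)[0]

valid = min_loss < INF
reduced = min_loss[valid].mean()
print(valid.all().item(), reduced.item())
```
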
diff --git a/vidar/arch/losses/MultiViewPhotometricLoss.py b/vidar/arch/losses/MultiViewPhotometricLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cebbac03c9f8b34786663be0d603bbdd2c0613d5
--- /dev/null
+++ b/vidar/arch/losses/MultiViewPhotometricLoss.py
@@ -0,0 +1,323 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as tf
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.geometry.camera import Camera
+from vidar.utils.depth import inv2depth
+from vidar.utils.tensor import match_scales
+
+
+def view_synthesis(ref_image, depth, ref_cam, cam,
+ mode='bilinear', padding_mode='zeros', align_corners=True):
+ assert depth.shape[1] == 1, 'Depth map should have C=1'
+ # Reconstruct world points from target_camera
+ world_points = cam.reconstruct(depth, frame='w')
+ # Project world points onto reference camera
+ ref_coords = ref_cam.project(world_points, frame='w')
+ # View-synthesis given the projected reference points
+ return tf.grid_sample(ref_image, ref_coords, mode=mode,
+ padding_mode=padding_mode, align_corners=align_corners)
+
+
+def gradient_x(image):
+ return image[:, :, :, :-1] - image[:, :, :, 1:]
+
+
+def gradient_y(image):
+ return image[:, :, :-1, :] - image[:, :, 1:, :]
+
+
+def inv_depths_normalize(inv_depths):
+ mean_inv_depths = [inv_depth.mean(2, True).mean(3, True) for inv_depth in inv_depths]
+ return [inv_depth / mean_inv_depth.clamp(min=1e-6)
+ for inv_depth, mean_inv_depth in zip(inv_depths, mean_inv_depths)]
+
+
+def calc_smoothness(inv_depths, images, num_scales):
+ inv_depths_norm = inv_depths_normalize(inv_depths)
+ inv_depth_gradients_x = [gradient_x(d) for d in inv_depths_norm]
+ inv_depth_gradients_y = [gradient_y(d) for d in inv_depths_norm]
+
+ image_gradients_x = [gradient_x(image) for image in images]
+ image_gradients_y = [gradient_y(image) for image in images]
+
+ weights_x = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_x]
+ weights_y = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_y]
+
+ # Weight inverse depth gradients by image gradients (edge-aware smoothness)
+ smoothness_x = [inv_depth_gradients_x[i] * weights_x[i] for i in range(num_scales)]
+ smoothness_y = [inv_depth_gradients_y[i] * weights_y[i] for i in range(num_scales)]
+ return smoothness_x, smoothness_y
+
+
+def SSIM(x, y, C1=1e-4, C2=9e-4, kernel_size=3, stride=1):
+ """
+ Structural Similarity (SSIM) distance between two images.
+
+ Parameters
+ ----------
+ x,y : torch.Tensor
+ Input images [B,3,H,W]
+ C1,C2 : float
+ SSIM parameters
+ kernel_size,stride : int
+ Convolutional parameters
+
+ Returns
+ -------
+ ssim : torch.Tensor
+ SSIM distance [1]
+ """
+ pool2d = nn.AvgPool2d(kernel_size, stride=stride)
+ refl = nn.ReflectionPad2d(1)
+
+ x, y = refl(x), refl(y)
+ mu_x = pool2d(x)
+ mu_y = pool2d(y)
+
+ mu_x_mu_y = mu_x * mu_y
+ mu_x_sq = mu_x.pow(2)
+ mu_y_sq = mu_y.pow(2)
+
+ sigma_x = pool2d(x.pow(2)) - mu_x_sq
+ sigma_y = pool2d(y.pow(2)) - mu_y_sq
+ sigma_xy = pool2d(x * y) - mu_x_mu_y
+ v1 = 2 * sigma_xy + C2
+ v2 = sigma_x + sigma_y + C2
+
+ ssim_n = (2 * mu_x_mu_y + C1) * v1
+ ssim_d = (mu_x_sq + mu_y_sq + C1) * v2
+ ssim = ssim_n / ssim_d
+
+ return ssim
+
+
+class MultiViewPhotometricLoss(BaseLoss, ABC):
+ """
+ Self-Supervised multiview photometric loss.
+ It takes two images, a depth map and a pose transformation to produce a
+ reconstruction of one image from the perspective of the other, and calculates
+ the difference between them
+
+ Parameters
+ ----------
+ num_scales : int
+ Number of inverse depth map scales to consider
+ ssim_loss_weight : float
+ Weight for the SSIM loss
+ occ_reg_weight : float
+ Weight for the occlusion regularization loss
+ smooth_loss_weight : float
+ Weight for the smoothness loss
+ C1,C2 : float
+ SSIM parameters
+ photometric_reduce_op : str
+ Method to reduce the photometric loss
+ disp_norm : bool
+ True if inverse depth maps are mean-normalized before the smoothness calculation
+ clip_loss : float
+ Threshold for photometric loss clipping
+ progressive_scaling : float
+ Training percentage for progressive scaling (0.0 to disable)
+ padding_mode : str
+ Padding mode for view synthesis
+ automask_loss : bool
+ True if automasking is enabled for the photometric loss
+ kwargs : dict
+ Extra parameters
+ """
+ def __init__(self, num_scales=4, ssim_loss_weight=0.85, occ_reg_weight=0.1, smooth_loss_weight=0.1,
+ C1=1e-4, C2=9e-4, photometric_reduce_op='mean', disp_norm=True, clip_loss=0.5,
+ progressive_scaling=0.0, padding_mode='zeros', automask_loss=False, **kwargs):
+ super().__init__()
+ self.n = num_scales
+ self.ssim_loss_weight = ssim_loss_weight
+ self.occ_reg_weight = occ_reg_weight
+ self.smooth_loss_weight = smooth_loss_weight
+ self.C1 = C1
+ self.C2 = C2
+ self.photometric_reduce_op = photometric_reduce_op
+ self.disp_norm = disp_norm
+ self.clip_loss = clip_loss
+ self.padding_mode = padding_mode
+ self.automask_loss = automask_loss
+
+ # Asserts
+ if self.automask_loss:
+ assert self.photometric_reduce_op == 'min', \
+ 'For automasking only the min photometric_reduce_op is supported.'
+
+########################################################################################################################
+
+ @property
+ def logs(self):
+ """Returns class logs."""
+ return {
+ 'num_scales': self.n,
+ }
+
+########################################################################################################################
+
+ def warp_ref_image(self, inv_depths, ref_image, K, ref_K, pose):
+ """
+ Warps a reference image to produce a reconstruction of the original one.
+
+ Parameters
+ ----------
+ inv_depths : list[torch.Tensor]
+ Inverse depth map of the original image [B,1,H,W]
+ ref_image : torch.Tensor
+ Reference RGB image [B,3,H,W]
+ K : torch.Tensor
+ Original camera intrinsics [B,3,3]
+ ref_K : torch.Tensor
+ Reference camera intrinsics [B,3,3]
+ pose : Pose
+ Original -> Reference camera transformation
+
+ Returns
+ -------
+ ref_warped : torch.Tensor
+ Warped reference image (reconstructing the original one) [B,3,H,W]
+ """
+ B, _, H, W = ref_image.shape
+ device = ref_image.device
+ # Generate cameras for all scales
+ cams, ref_cams = [], []
+ for i in range(self.n):
+ _, _, DH, DW = inv_depths[i].shape
+ scale_factor = DW / float(W)
+ cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
+ ref_cams.append(Camera(K=ref_K.float(), Tcw=pose).scaled(scale_factor).to(device))
+ # View synthesis
+ depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
+ ref_images = match_scales(ref_image, inv_depths, self.n)
+ ref_warped = [view_synthesis(
+ ref_images[i], depths[i], ref_cams[i], cams[i],
+ padding_mode=self.padding_mode) for i in range(self.n)]
+ # Return warped reference image
+ return ref_warped
+
+########################################################################################################################
+
+ def SSIM(self, x, y, kernel_size=3):
+ """
+ Calculates the SSIM (Structural Similarity) loss
+
+ Parameters
+ ----------
+ x,y : torch.Tensor
+ Input images [B,3,H,W]
+ kernel_size : int
+ Convolutional parameter
+
+ Returns
+ -------
+ ssim : torch.Tensor
+ SSIM loss [1]
+ """
+ ssim_value = SSIM(x, y, C1=self.C1, C2=self.C2, kernel_size=kernel_size)
+ return torch.clamp((1. - ssim_value) / 2., 0., 1.)
+
+ def calc_photometric_loss(self, t_est, images):
+ """
+ Calculates the photometric loss (L1 + SSIM)
+
+ Parameters
+ ----------
+ t_est : list[torch.Tensor]
+ List of warped reference images in multiple scales [B,3,H,W]
+ images : list[torch.Tensor]
+ List of original images in multiple scales [B,3,H,W]
+
+ Returns
+ -------
+ photometric_loss : list[torch.Tensor]
+ Photometric loss [B,1,H,W]
+ """
+ # L1 loss
+ l1_loss = [torch.abs(t_est[i] - images[i])
+ for i in range(self.n)]
+ # SSIM loss
+ if self.ssim_loss_weight > 0.0:
+ ssim_loss = [self.SSIM(t_est[i], images[i], kernel_size=3)
+ for i in range(self.n)]
+ # Weighted Sum: alpha * ssim + (1 - alpha) * l1
+ photometric_loss = [self.ssim_loss_weight * ssim_loss[i].mean(1, True) +
+ (1 - self.ssim_loss_weight) * l1_loss[i].mean(1, True)
+ for i in range(self.n)]
+ else:
+ photometric_loss = l1_loss
+ # Clip loss
+ if self.clip_loss > 0.0:
+ for i in range(self.n):
+ mean, std = photometric_loss[i].mean(), photometric_loss[i].std()
+ photometric_loss[i] = torch.clamp(
+ photometric_loss[i], max=float(mean + self.clip_loss * std))
+ # Return total photometric loss
+ return photometric_loss
+
+ def reduce_photometric_loss(self, photometric_losses):
+ """
+ Combine the photometric loss from all context images
+
+ Parameters
+ ----------
+ photometric_losses : list[list[torch.Tensor]]
+ Pixel-wise photometric losses from the entire context [B,3,H,W]
+
+ Returns
+ -------
+ photometric_loss : torch.Tensor
+ Reduced photometric loss [1]
+ """
+ # Reduce function
+ def reduce_function(losses):
+ if self.photometric_reduce_op == 'mean':
+ return sum([l.mean() for l in losses]) / len(losses)
+ elif self.photometric_reduce_op == 'min':
+ return torch.cat(losses, 1).min(1, True)[0].mean()
+ else:
+ raise NotImplementedError(
+ 'Unknown photometric_reduce_op: {}'.format(self.photometric_reduce_op))
+ # Reduce photometric loss
+ photometric_loss = sum([reduce_function(photometric_losses[i])
+ for i in range(self.n)]) / self.n
+ # Store and return reduced photometric loss
+ return photometric_loss
+
+########################################################################################################################
+
+ def calc_smoothness_loss(self, inv_depths, images):
+ """
+ Calculates the smoothness loss for inverse depth maps.
+
+ Parameters
+ ----------
+ inv_depths : list[torch.Tensor]
+ Predicted inverse depth maps for all scales [B,1,H,W]
+ images : list[torch.Tensor]
+ Original images for all scales [B,3,H,W]
+
+ Returns
+ -------
+ smoothness_loss : torch.Tensor
+ Smoothness loss [1]
+ """
+ # Calculate smoothness gradients
+ smoothness_x, smoothness_y = calc_smoothness(inv_depths, images, self.n)
+ # Calculate smoothness loss
+ smoothness_loss = sum([(smoothness_x[i].abs().mean() +
+ smoothness_y[i].abs().mean()) / 2 ** i
+ for i in range(self.n)]) / self.n
+ # Apply smoothness loss weight
+ smoothness_loss = self.smooth_loss_weight * smoothness_loss
+ # Store and return smoothness loss
+ return smoothness_loss
+
+
+
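
`view_synthesis` back-projects the target depth into 3D with the target camera, projects the points into the reference camera, and samples the reference image at the resulting coordinates. Below is a minimal pinhole version of that chain with explicit intrinsics and pose matrices instead of the repo's Camera class; with an identity relative pose it reduces to an identity warp, which is what the final check prints.

```python
# Minimal pinhole sketch of the view-synthesis chain: back-project, transform,
# project, and sample. Intrinsics, pose, and sizes below are illustrative.
import torch
import torch.nn.functional as F

def synthesize(ref_image, depth, K, T_ref_from_tgt):
    b, _, h, w = depth.shape
    # Pixel grid in homogeneous coordinates [B,3,H*W]
    ys, xs = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                            torch.arange(w, dtype=torch.float32), indexing='ij')
    pix = torch.stack([xs, ys, torch.ones_like(xs)], 0).view(3, -1).unsqueeze(0)
    # Back-project with the target camera, transform, project with the reference camera
    cam_points = torch.inverse(K) @ pix * depth.view(b, 1, -1)
    cam_points = torch.cat([cam_points, torch.ones(b, 1, h * w)], 1)
    ref_points = (K @ T_ref_from_tgt[:, :3]) @ cam_points
    ref_pix = ref_points[:, :2] / ref_points[:, 2:3].clamp(min=1e-6)
    # Normalize to [-1, 1] for grid_sample (align_corners=True convention)
    ref_pix = ref_pix.view(b, 2, h, w).permute(0, 2, 3, 1)
    ref_pix[..., 0] = 2 * ref_pix[..., 0] / (w - 1) - 1
    ref_pix[..., 1] = 2 * ref_pix[..., 1] / (h - 1) - 1
    return F.grid_sample(ref_image, ref_pix, mode='bilinear',
                         padding_mode='zeros', align_corners=True)

b, h, w = 1, 6, 8
K = torch.tensor([[[10.0, 0, 3.5], [0, 10.0, 2.5], [0, 0, 1.0]]])
T = torch.eye(4).unsqueeze(0)          # identity relative pose
ref = torch.rand(b, 3, h, w)
depth = torch.full((b, 1, h, w), 5.0)
print(torch.allclose(synthesize(ref, depth, K, T), ref, atol=1e-4))  # True
```
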
diff --git a/vidar/arch/losses/PhotometricLoss.py b/vidar/arch/losses/PhotometricLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29fbf532fd05876a6361c9819f462e9d7cb857c
--- /dev/null
+++ b/vidar/arch/losses/PhotometricLoss.py
@@ -0,0 +1,34 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from functools import partial
+
+import torch
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.arch.losses.SSIMLoss import SSIMLoss
+from vidar.utils.tensor import interpolate
+
+
+class PhotometricLoss(BaseLoss, ABC):
+ def __init__(self, cfg):
+ super().__init__()
+ self.alpha = cfg.alpha
+ self.ssim_loss = SSIMLoss()
+ self.interpolate = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True)
+
+ def forward(self, pred, gt):
+
+ pred = self.interpolate(pred, size=gt)
+ l1_loss = torch.abs(pred - gt).mean(1, True)
+
+ if self.alpha == 0.0:
+ photometric_loss = l1_loss
+ else:
+ ssim_loss = self.ssim_loss(pred, gt)['loss'].mean(1, True)
+ photometric_loss = self.alpha * ssim_loss + (1.0 - self.alpha) * l1_loss
+
+ return {
+ 'loss': photometric_loss,
+ }
+
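
PhotometricLoss blends a per-pixel SSIM term and an L1 term with weight `alpha`. A usage sketch follows, assuming the vidar package is importable and that `vidar.utils.tensor.interpolate` accepts a reference tensor as `size` (which is how `BaseLoss.interp` calls it above); the `Cfg` stand-in and `alpha=0.85` are illustrative assumptions, not values from a released configuration.

```python
import torch
from vidar.arch.losses.PhotometricLoss import PhotometricLoss

class Cfg:           # minimal stand-in for the repo's Config object (assumption)
    alpha = 0.85     # SSIM weight; assumed value

loss_fn = PhotometricLoss(Cfg())
pred, gt = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
out = loss_fn(pred, gt)['loss']    # per-pixel photometric loss, [B,1,H,W]
print(out.shape, out.mean().item())
```
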
diff --git a/vidar/arch/losses/ReprojectionLoss.py b/vidar/arch/losses/ReprojectionLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f208cad2bb1731a1015184afa46a96619453aa4
--- /dev/null
+++ b/vidar/arch/losses/ReprojectionLoss.py
@@ -0,0 +1,230 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from functools import partial
+
+import torch
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.utils.config import cfg_has
+from vidar.utils.data import get_from_list, get_mask_from_list
+from vidar.utils.tensor import interpolate, multiply_args, masked_average
+from vidar.utils.types import is_list
+
+
+class ReprojectionLoss(BaseLoss, ABC):
+ """
+ Reprojection loss class
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self.automasking = cfg.automasking
+ self.reprojection_reduce_op = cfg.reprojection_reduce_op
+ self.jitter_identity_reprojection = cfg.jitter_identity_reprojection
+ self.logvar_weight = cfg_has(cfg, 'logvar_weight', 0.0)
+ self.feature_weight = cfg_has(cfg, 'feature_weight', 0.0)
+
+ self.interpolate = partial(
+ interpolate, mode='bilinear', scale_factor=None, align_corners=True)
+
+ self.inf = 1e6
+
+ @staticmethod
+ def compute_reprojection_mask(reprojection_loss, identity_reprojection_loss):
+ """
+ Compute reprojection mask based for automasking
+
+ Parameters
+ ----------
+ reprojection_loss : torch.Tensor
+ Warped reprojection loss [B,1,H,W]
+ identity_reprojection_loss : torch.Tensor
+ Identity reprojection loss [B,1,H,W]
+
+ Returns
+ -------
+ mask : torch.Tensor
+ Reprojection mask for automasking [B,1,H,W]
+ """
+ if identity_reprojection_loss is None:
+ reprojection_mask = torch.ones_like(reprojection_loss)
+ else:
+ all_losses = torch.cat([reprojection_loss, identity_reprojection_loss], dim=1)
+ idxs = torch.argmin(all_losses, dim=1, keepdim=True)
+ reprojection_mask = (idxs == 0)
+ return reprojection_mask
+
+ def reduce_reprojection(self, reprojection_losses, overlap_mask=None):
+ """
+ Combine multi-image reprojection losses
+
+ Parameters
+ ----------
+ reprojection_losses : list[torch.Tensor]
+ Per-image reprojection losses
+ overlap_mask : list[torch.Tensor] or None
+ Valid mask to remove pixels
+
+ Returns
+ -------
+ reprojection_loss : torch.Tensor
+ Output loss [1]
+ overlap_mask : torch.Tensor
+ Reduced overlap mask
+ """
+ if is_list(reprojection_losses):
+ reprojection_losses = torch.cat(reprojection_losses, 1)
+ if self.reprojection_reduce_op == 'mean':
+ assert overlap_mask is None, 'Not implemented yet'
+ reprojection_loss = reprojection_losses.mean(1, keepdim=True)
+ elif self.reprojection_reduce_op == 'min':
+ if overlap_mask is not None:
+ if is_list(overlap_mask):
+ overlap_mask = torch.cat(overlap_mask, 1)
+ reprojection_losses[~overlap_mask.bool()] = self.inf
+ reprojection_loss, _ = torch.min(reprojection_losses, dim=1, keepdim=True)
+ overlap_mask = reprojection_loss < self.inf
+ reprojection_loss[~overlap_mask] = 0.0 # For visualization purposes
+ else:
+ raise ValueError(
+ f'Invalid reprojection reduce operation: {self.reprojection_reduce_op}')
+ return reprojection_loss, overlap_mask
+
+ def calculate(self, rgb, rgb_context, warps, logvar=None,
+ valid_mask=None, overlap_mask=None):
+ """
+ Calculate reprojection loss
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Target image [B,3,H,W]
+ rgb_context : list[torch.Tensor]
+ List of context images [B,3,H,W]
+ warps : list[torch.Tensor]
+ List of warped images from view synthesis [B,3,H,W]
+ logvar : list[torch.Tensor]
+ Log variance for log-likelihood calculation
+ valid_mask : torch.Tensor or None
+ Valid mask for pixel filtering
+ overlap_mask : torch.Tensor or None
+ Overlap mask for pixel filtering
+
+ Returns
+ -------
+ average_loss : torch.Tensor
+ Output loss [1]
+ reprojection_mask : torch.Tensor
+ Combined reprojection mask (overlap + reprojection + valid) [B,1,H,W]
+ reprojection_loss : torch.Tensor
+ Per-pixel loss [B,1,H,W]
+ overlap_mask : torch.Tensor
+ Combined overlap mask [B,1,H,W]
+ """
+ reprojection_losses = [
+ self.losses['photometric'](warp, rgb)['loss'] for warp in warps]
+ reprojection_loss, overlap_mask = self.reduce_reprojection(
+ reprojection_losses, overlap_mask=overlap_mask)
+
+ if 'featuremetric' in self.losses.keys():
+ featuremetric_loss = [
+ self.losses['featuremetric'](warp, rgb)['loss'] for warp in warps
+ ]
+ reduced_featuremetric_loss = torch.cat(featuremetric_loss, 1).mean()
+
+ if self.automasking:
+ reprojection_identity_losses = [
+ self.losses['photometric'](context, rgb)['loss'] for context in rgb_context]
+ reprojection_identity_loss, _ = self.reduce_reprojection(
+ reprojection_identity_losses)
+ if self.jitter_identity_reprojection > 0:
+ reprojection_identity_loss += self.jitter_identity_reprojection * \
+ torch.randn(reprojection_identity_loss.shape, device=reprojection_identity_loss.device)
+ else:
+ reprojection_identity_loss = None
+
+ reprojection_mask = self.compute_reprojection_mask(
+ reprojection_loss, reprojection_identity_loss,
+ # reprojection_mask=valid_mask
+ # reprojection_mask=multiply_any(reprojection_mask, overlap_mask)
+ )
+ reprojection_mask = multiply_args(reprojection_mask, valid_mask, overlap_mask)
+
+ if logvar is not None and self.logvar_weight > 0.0:
+ logvar = self.interpolate(logvar, reprojection_loss.shape[-2:])
+ reprojection_loss = reprojection_loss * torch.exp(-logvar)
+
+ average_loss = masked_average(reprojection_loss, reprojection_mask)
+ # reprojection_loss *= reprojection_mask # REMOVE FOR VISUALIZATION
+
+ if logvar is not None and self.logvar_weight > 0.0:
+ average_loss += self.logvar_weight * masked_average(logvar, reprojection_mask)
+
+ if 'featuremetric' in self.losses.keys() and self.feature_weight > 0.0:
+ featuremetric_loss = [self.losses['featuremetric'](warp, rgb)['loss'] for warp in warps]
+ reduced_featuremetric_loss = torch.cat(featuremetric_loss, 1).mean()
+ average_loss += self.feature_weight * reduced_featuremetric_loss
+
+ return average_loss, reprojection_mask, reprojection_loss, overlap_mask
+
+ def forward(self, rgb, rgb_context, warps, logvar=None,
+ valid_mask=None, overlap_mask=None):
+ """
+ Calculate reprojection loss
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Target image [B,3,H,W]
+ rgb_context : list[torch.Tensor]
+ List of context images [B,3,H,W]
+ warps : list[torch.Tensor]
+ List of warped images from view synthesis [B,3,H,W]
+ logvar : list[torch.Tensor]
+ Log variance for log-likelihood calculation
+ valid_mask : torch.Tensor or None
+ Valid mask for pixel filtering
+ overlap_mask : torch.Tensor or None
+ Overlap mask for pixel filtering
+
+ Returns
+ -------
+ output : Dict
+ Dictionary with loss, metrics, masks, photometric errors, and overlap
+ """
+ scales = self.get_scales(warps)
+ weights = self.get_weights(scales)
+
+ losses, masks, photos, overlaps, metrics = [], [], [], [], {}
+
+ for i in range(scales):
+ rgb_i, rgb_context_i, warps_i = rgb[0], rgb_context[0], warps[i]
+ valid_mask_i = get_mask_from_list(valid_mask, i)
+ overlap_mask_i = get_mask_from_list(overlap_mask, i)
+ logvar_i = get_from_list(logvar, i)
+
+ loss_i, mask_i, photo_i, overlap_mask_i = self.calculate(
+ rgb_i, rgb_context_i, warps_i, logvar=logvar_i,
+ valid_mask=valid_mask_i, overlap_mask=overlap_mask_i)
+ loss_i = weights[i] * loss_i
+
+ metrics[f'reprojection_loss/{i}'] = loss_i.detach()
+
+ losses.append(loss_i)
+ masks.append(mask_i)
+ photos.append(photo_i)
+ overlaps.append(overlap_mask_i)
+
+ loss = sum(losses) / len(losses)
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ 'mask': masks,
+ 'photo': photos,
+ 'overlap': overlaps,
+ }
\ No newline at end of file
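
Automasking compares the warped reprojection loss against an identity loss computed on the unwarped context frames, jitters the identity loss to break ties on static pixels, and keeps only pixels where warping wins. A standalone sketch of that argmin-based mask, independent of the repo's Config and photometric classes:

```python
# Standalone sketch of the automasking step: jitter the identity (unwarped)
# loss, take the per-pixel argmin against the warped loss, and keep pixels
# where the warped loss wins. Values are illustrative.
import torch

def automask(reprojection_loss, identity_loss, jitter=1e-5):
    identity_loss = identity_loss + jitter * torch.randn_like(identity_loss)
    all_losses = torch.cat([reprojection_loss, identity_loss], dim=1)
    idxs = torch.argmin(all_losses, dim=1, keepdim=True)
    return idxs == 0  # True where the warped loss is the smallest

reprojection_loss = torch.full((1, 1, 4, 4), 0.2)
identity_loss = torch.full((1, 1, 4, 4), 0.3)
identity_loss[..., :2] = 0.1          # pretend two columns are static/occluded
mask = automask(reprojection_loss, identity_loss)
loss = (reprojection_loss * mask).sum() / mask.sum().clamp(min=1)
print(mask.float().mean().item(), loss.item())  # 0.5 0.2
```
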
diff --git a/vidar/arch/losses/SSIMLoss.py b/vidar/arch/losses/SSIMLoss.py
new file mode 100755
index 0000000000000000000000000000000000000000..eefa3912488cacd50a441d3a6a26f3632692a4fd
--- /dev/null
+++ b/vidar/arch/losses/SSIMLoss.py
@@ -0,0 +1,62 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+
+
+class SSIMLoss(BaseLoss, ABC):
+ """SSIM (Structural Similarity Index Metric) loss class"""
+ def __init__(self):
+ super().__init__()
+
+ self.mu_x_pool = nn.AvgPool2d(3, 1)
+ self.mu_y_pool = nn.AvgPool2d(3, 1)
+
+ self.sig_x_pool = nn.AvgPool2d(3, 1)
+ self.sig_y_pool = nn.AvgPool2d(3, 1)
+ self.sig_xy_pool = nn.AvgPool2d(3, 1)
+
+ self.refl = nn.ReflectionPad2d(1)
+
+ self.C1 = 0.01 ** 2
+ self.C2 = 0.03 ** 2
+
+ def forward(self, x, y):
+ """
+ Calculates SSIM loss
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input image 1 [B,3,H,W]
+ y : torch.Tensor
+ Input image 2 [B,3,H,W]
+
+ Returns
+ -------
+ output : Dict
+ Dictionary with loss
+ """
+ x = self.refl(x)
+ y = self.refl(y)
+
+ mu_x = self.mu_x_pool(x)
+ mu_y = self.mu_y_pool(y)
+
+ sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
+ sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
+ sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
+
+ SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
+ SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
+
+ loss = torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
+
+ return {
+ 'loss': loss,
+ }
+
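
A short usage sketch for SSIMLoss, assuming the vidar package is importable: the loss is per-pixel, lies in [0, 1], and keeps the input resolution thanks to the reflection padding.

```python
import torch
from vidar.arch.losses.SSIMLoss import SSIMLoss

ssim = SSIMLoss()
x, y = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
out = ssim(x, y)['loss']
print(out.shape)                        # torch.Size([2, 3, 32, 32]), resolution preserved
print(0.0 <= out.min().item() and out.max().item() <= 1.0)
print(ssim(x, x)['loss'].max().item())  # ~0 for identical inputs
```
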
diff --git a/vidar/arch/losses/SmoothnessLoss.py b/vidar/arch/losses/SmoothnessLoss.py
new file mode 100755
index 0000000000000000000000000000000000000000..248a6aa8c63bd814e05b1049d6c66af87abd2c8e
--- /dev/null
+++ b/vidar/arch/losses/SmoothnessLoss.py
@@ -0,0 +1,93 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.utils.tensor import same_shape, interpolate_image
+
+
+class SmoothnessLoss(BaseLoss, ABC):
+ """
+ Smoothness loss class
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self.normalize = cfg.normalize
+
+ def calculate(self, rgb, depth):
+ """
+ Calculate smoothness loss
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Input image [B,3,H,W]
+ depth : torch.Tensor
+ Predicted depth map [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Smoothness loss [1]
+ """
+ if self.normalize:
+ mean_depth = depth.mean(2, True).mean(3, True)
+ norm_depth = depth / (mean_depth + 1e-7)
+ else:
+ norm_depth = depth
+
+ grad_depth_x = torch.abs(norm_depth[:, :, :, :-1] - norm_depth[:, :, :, 1:])
+ grad_depth_y = torch.abs(norm_depth[:, :, :-1, :] - norm_depth[:, :, 1:, :])
+
+ grad_rgb_x = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), 1, keepdim=True)
+ grad_rgb_y = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), 1, keepdim=True)
+
+ grad_depth_x *= torch.exp(-1.0 * grad_rgb_x)
+ grad_depth_y *= torch.exp(-1.0 * grad_rgb_y)
+
+ return grad_depth_x.mean() + grad_depth_y.mean()
+
+ def forward(self, rgb, depth):
+ """
+ Calculate smoothness loss
+
+ Parameters
+ ----------
+ rgb : list[torch.Tensor]
+ Input images [B,3,H,W]
+ depth : list[torch.Tensor]
+ Predicted depth maps [B,1,H,W]
+
+ Returns
+ -------
+ output : Dict
+ Dictionary with loss and metrics
+ """
+ scales = self.get_scales(rgb)
+ weights = self.get_weights(scales)
+
+ losses, metrics = [], {}
+
+ for i in range(scales):
+ rgb_i, depth_i = rgb[i], depth[i]
+ if not same_shape(rgb_i.shape[-2:], depth_i.shape[-2:]):
+ rgb_i = interpolate_image(rgb_i, shape=depth_i.shape[-2:])
+
+ loss_i = weights[i] * self.calculate(rgb_i, depth_i)
+
+ metrics[f'smoothness_loss/{i}'] = loss_i.detach()
+ losses.append(loss_i)
+
+ loss = sum(losses) / len(losses)
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ }
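
SmoothnessLoss penalizes depth gradients, down-weighted by `exp(-|image gradient|)` so depth is allowed to change at image edges, optionally after mean-normalizing the depth. A standalone version of the `calculate` step:

```python
# Standalone version of the edge-aware smoothness term: depth gradients are
# down-weighted by exp(-|image gradient|). Mean normalization of the depth is
# optional, as in the class above.
import torch

def edge_aware_smoothness(rgb, depth, normalize=True):
    if normalize:
        depth = depth / (depth.mean(2, True).mean(3, True) + 1e-7)
    d_dx = torch.abs(depth[:, :, :, :-1] - depth[:, :, :, 1:])
    d_dy = torch.abs(depth[:, :, :-1, :] - depth[:, :, 1:, :])
    i_dx = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), 1, keepdim=True)
    i_dy = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), 1, keepdim=True)
    return (d_dx * torch.exp(-i_dx)).mean() + (d_dy * torch.exp(-i_dy)).mean()

rgb = torch.rand(1, 3, 8, 8)
flat_depth = torch.full((1, 1, 8, 8), 4.0)
noisy_depth = flat_depth + 0.5 * torch.rand(1, 1, 8, 8)
print(edge_aware_smoothness(rgb, flat_depth).item())   # 0.0 for constant depth
print(edge_aware_smoothness(rgb, noisy_depth).item())  # > 0
```
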
diff --git a/vidar/arch/losses/SupervisedDepthLoss.py b/vidar/arch/losses/SupervisedDepthLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..50be573a76c52688bcddd34b5d1df551e42ea4be
--- /dev/null
+++ b/vidar/arch/losses/SupervisedDepthLoss.py
@@ -0,0 +1,281 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.losses.BaseLoss import BaseLoss
+from vidar.utils.data import get_mask_from_list
+from vidar.utils.depth import depth2index, get_depth_bins
+from vidar.utils.depth import depth2inv
+from vidar.utils.types import is_list
+
+
+class BerHuLoss(nn.Module, ABC):
+ """BerHu Loss"""
+ def __init__(self, threshold=0.2):
+ super().__init__()
+ self.threshold = threshold
+
+ def forward(self, pred, gt):
+ huber_c = self.threshold * torch.max(pred - gt)
+ diff = (pred - gt).abs()
+ diff2 = diff[diff > huber_c] ** 2
+ return torch.cat((diff, diff2))
+
+
+class SilogLoss(nn.Module, ABC):
+ """Scale Invariant Logarithmic Loss"""
+ def __init__(self, ratio=10., var_ratio=0.85):
+ super().__init__()
+ self.ratio = ratio
+ self.var_ratio = var_ratio
+
+ def forward(self, pred, gt):
+ log_diff = torch.log(pred) - torch.log(gt)
+ silog1 = (log_diff ** 2).mean()
+ silog2 = log_diff.mean() ** 2
+ return torch.sqrt(silog1 - self.var_ratio * silog2) * self.ratio
+
+
+class RMSELoss(nn.Module, ABC):
+ """Root Mean Squared Error Loss"""
+ def __init__(self):
+ super().__init__()
+ self.criterion = nn.MSELoss(reduction='none')
+
+ def forward(self, pred, gt):
+ return torch.sqrt(self.criterion(pred, gt))
+
+
+class L1LogLoss(nn.Module, ABC):
+ """Root Mean Squared Error Loss"""
+ def __init__(self):
+ super().__init__()
+ self.criterion = nn.L1Loss(reduction='none')
+
+ def forward(self, pred, gt):
+ return self.criterion(torch.log(pred), torch.log(gt))
+
+
+class MixtureLoss(nn.Module, ABC):
+ """Root Mean Squared Error Loss"""
+ def __init__(self):
+ super().__init__()
+
+ @staticmethod
+ def laplacian(mu, std, gt):
+ std = std + 1e-12
+ return 0.5 * torch.exp(-(torch.abs(mu - gt) / std)) / std
+
+ def forward(self, pred, gt):
+ mu0, mu1 = pred[:, [0]], pred[:, [1]]
+ std0, std1 = pred[:, [2]], pred[:, [3]]
+ w0 = pred[:, [4]]
+ w1 = 1.0 - w0
+ return (- torch.log(w0 * self.laplacian(mu0, std0, gt) +
+ w1 * self.laplacian(mu1, std1, gt))).mean()
+
+
+class RootAbsRelLoss(nn.Module, ABC):
+ """Root Mean Squared Error Loss"""
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred, gt):
+ return torch.sqrt(torch.abs(pred - gt) / gt)
+
+
+class SquareAbsRelLoss(nn.Module, ABC):
+ """Root Mean Squared Error Loss"""
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred, gt):
+ return (torch.abs(pred - gt) / gt) ** 2
+
+
+class CrossEntropyLoss(nn.Module, ABC):
+ """Supervised Loss"""
+ def __init__(self):
+ super().__init__()
+ self.gamma = 2.0
+ self.alpha = 0.25
+
+ self.bootstrap_ratio = 0.3
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='none')
+
+ def forward(self, pred, gt):
+ min_depth, max_depth = 1.0, 100.0
+ bins = get_depth_bins('linear', min_depth, max_depth, 100).to(pred.device)
+ gt = depth2index(gt, bins).squeeze(1)
+
+ loss_ce = self.cross_entropy_loss(pred, gt.to(torch.long))
+ num_bootstrapping = int(self.bootstrap_ratio * pred.shape[0])
+ image_errors, _ = loss_ce.view(-1, ).sort()
+ worst_errors = image_errors[-num_bootstrapping:]
+ return torch.mean(worst_errors)
+
+
+def get_criterion(method):
+ """Determines the supervised loss to be used"""
+ if method == 'l1':
+ return nn.L1Loss()
+ elif method == 'l1log':
+ return L1LogLoss()
+ elif method == 'mse':
+ return nn.MSELoss(reduction='none')
+ elif method == 'rmse':
+ return RMSELoss()
+ elif method == 'huber':
+ return nn.SmoothL1Loss(reduction='none')
+ elif method == 'berhu':
+ return BerHuLoss()
+ elif method == 'silog':
+ return SilogLoss()
+ elif method == 'abs_rel':
+ return lambda x, y: torch.abs(x - y) / y
+ elif method == 'root_abs_rel':
+ return RootAbsRelLoss()
+ elif method == 'square_abs_rel':
+ return SquareAbsRelLoss()
+ elif method == 'mixture':
+ return MixtureLoss()
+ elif method == 'cross_entropy':
+ return CrossEntropyLoss()
+ else:
+ raise ValueError('Unknown supervised loss {}'.format(method))
+
+
+class LossWrapper(nn.Module):
+ """
+ Wrapper for supervised depth criteria
+
+ Parameters
+ ----------
+ method : String
+ Which supervised loss to use
+ """
+ def __init__(self, method):
+ super().__init__()
+ self.criterion = get_criterion(method)
+
+ def forward(self, pred, gt, soft_mask=None):
+ """
+ Calculate supervised depth loss
+
+ Parameters
+ ----------
+ pred : torch.Tensor
+ Predicted depth [B,1,H,W]
+ gt : torch.Tensor
+ Ground-truth depth [B,1,H,W]
+ soft_mask : torch.Tensor or None
+ Mask for pixel weighting [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Supervised depth loss [1]
+ """
+ loss = self.criterion(pred, gt)
+ if soft_mask is not None:
+ loss = loss * soft_mask.detach().view(-1, 1)
+ return loss.mean()
+
+
+class SupervisedDepthLoss(BaseLoss, ABC):
+ def __init__(self, cfg):
+ """
+ Supervised loss class
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ super().__init__(cfg)
+ self.criterion = LossWrapper(cfg.method)
+ self.inverse = cfg.has('inverse', False)
+
+ def calculate(self, pred, gt, mask=None, soft_mask=None):
+ """
+ Calculate supervised depth loss
+
+ Parameters
+ ----------
+ pred : torch.Tensor
+ Predicted depth [B,1,H,W]
+ gt : torch.Tensor
+ Ground-truth depth [B,1,H,W]
+ mask : torch.Tensor
+ Mask for pixel filtering [B,1,H,W]
+ soft_mask : torch.Tensor or None
+ Mask for pixel weighting [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Supervised depth loss [1]
+ """
+ # Interpolations
+ pred = self.interp_nearest(pred, gt)
+ mask = self.interp_nearest(mask, gt)
+ soft_mask = self.interp_bilinear(soft_mask, gt)
+
+ # Flatten tensors
+ pred, gt, mask, soft_mask = self.flatten(pred, gt, mask, soft_mask)
+
+ # Masks
+ mask = self.mask_sparse(mask, gt)
+ mask = self.mask_range(mask, gt)
+
+ # Calculate loss
+ return self.criterion(pred[mask], gt[mask],
+ soft_mask=soft_mask[mask] if soft_mask is not None else None)
+
+ def forward(self, pred, gt, mask=None, soft_mask=None):
+ """
+ Supervised depth loss
+
+ Parameters
+ ----------
+ pred : list[torch.Tensor]
+ Predicted depths [B,1,H,W]
+ gt : torch.Tensor
+ Ground-truth depth [B,1,H,W]
+ mask : torch.Tensor
+ Mask for pixel filtering [B,1,H,W]
+ soft_mask : torch.Tensor or None
+ Mask for pixel weighting [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Supervised depth loss [1]
+ """
+ if self.inverse:
+ pred, gt = depth2inv(pred), depth2inv(gt)
+
+ scales = self.get_scales(pred)
+ weights = self.get_weights(scales)
+
+ losses, metrics = [], {}
+
+ for i in range(scales):
+ pred_i, gt_i = pred[i], gt[i] if is_list(gt) else gt
+ mask_i = get_mask_from_list(mask, i, return_ones=gt_i)
+ soft_mask_i = get_mask_from_list(soft_mask, i)
+
+ loss_i = weights[i] * self.calculate(pred_i, gt_i, mask_i, soft_mask_i)
+
+ metrics[f'supervised_depth_loss/{i}'] = loss_i.detach()
+ losses.append(loss_i)
+
+ loss = sum(losses) / len(losses)
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ }
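
`get_criterion` selects the per-pixel criterion by name; the `silog` option, for instance, operates on log-depth differences `d = log(pred) - log(gt)` and returns `sqrt(mean(d^2) - var_ratio * mean(d)^2) * ratio`. A standalone sketch with the class defaults `ratio=10` and `var_ratio=0.85`:

```python
# Standalone sketch of the 'silog' criterion selected above, using the class
# defaults ratio=10 and var_ratio=0.85. Input values are illustrative.
import torch

def silog(pred, gt, ratio=10.0, var_ratio=0.85):
    d = torch.log(pred) - torch.log(gt)
    return torch.sqrt((d ** 2).mean() - var_ratio * d.mean() ** 2) * ratio

gt = torch.rand(1, 1, 8, 8) * 50 + 1.0     # depths in [1, 51]
print(silog(gt * 1.05, gt).item())          # ~0.19: small relative error
print(silog(gt * 2.00, gt).item())          # ~2.68: global scale error is damped
print(silog(gt + 5.0, gt).item())           # additive error: depends on per-pixel structure
```
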
diff --git a/vidar/arch/models/BaseModel.py b/vidar/arch/models/BaseModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..1e3c5cec6b05f78b0515414b30eb212023df29a1
--- /dev/null
+++ b/vidar/arch/models/BaseModel.py
@@ -0,0 +1,43 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from vidar.utils.config import cfg_has
+
+
+class BaseModel(nn.Module):
+ """Base model super class, that all other models inherit"""
+ def __init__(self, cfg=None):
+ super().__init__()
+
+ self.blocks = torch.nn.ModuleDict()
+ self.networks = torch.nn.ModuleDict()
+ self.losses = torch.nn.ModuleDict()
+
+ if cfg is not None:
+ self.num_scales = cfg_has(cfg.model, 'num_scales', 99)
+
+ def _forward_unimplemented(self, *args):
+ pass
+
+ def forward(self, *args, **kwargs):
+ """Model forward pass"""
+ raise NotImplementedError(
+ 'Please implement forward function in your own subclass model.')
+
+ def get_num_scales(self, scales):
+ """Return number of predicted scales"""
+ return min(self.num_scales, len(scales))
+
+ def compute_pose(self, rgb, net, tgt=0, ctx=None, invert=True):
+ """Compute poses from pairs of images"""
+ if ctx is None:
+ ctx = [key for key in rgb.keys() if key != tgt]
+ return {idx: net(
+ [rgb[tgt], rgb[idx]], invert=(idx < tgt) and invert)['transformation']
+ for idx in ctx}
+
+ def set_attr(self, cfg, key, default):
+ """Set an attribute for the model"""
+ self.__setattr__(key, cfg_has(cfg, key, default))
\ No newline at end of file
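
BaseModel keeps its sub-modules in `ModuleDict`s (`blocks`, `networks`, `losses`) and leaves `forward` to subclasses. Below is a minimal, repo-independent sketch of a model wired that way; the tiny depth network and L1 loss are stand-ins, not classes from the repo.

```python
# Minimal, repo-independent sketch of the BaseModel pattern: sub-modules live
# in ModuleDicts and the subclass implements forward(). The depth "network"
# and loss below are stand-ins for illustration only.
import torch
import torch.nn as nn

class TinyDepthModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.networks = nn.ModuleDict()
        self.losses = nn.ModuleDict()
        self.networks['depth'] = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1), nn.Softplus())
        self.losses['supervised'] = nn.L1Loss()

    def forward(self, batch):
        depth = self.networks['depth'](batch['rgb'])
        output = {'predictions': {'depth': depth}}
        if self.training and 'depth' in batch:
            output['loss'] = self.losses['supervised'](depth, batch['depth'])
        return output

model = TinyDepthModel().train()
batch = {'rgb': torch.rand(2, 3, 32, 32), 'depth': torch.rand(2, 1, 32, 32)}
out = model(batch)
print(out['predictions']['depth'].shape, out['loss'].item())
```
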
diff --git a/vidar/arch/models/__init__.py b/vidar/arch/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/models/depth/DepthFormerModel.py b/vidar/arch/models/depth/DepthFormerModel.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fc0f758f9df0973332cbe0884f7d74c0976381
--- /dev/null
+++ b/vidar/arch/models/depth/DepthFormerModel.py
@@ -0,0 +1,550 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import random
+from abc import ABC
+from functools import partial
+
+import torch
+
+from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.arch.models.utils import make_rgb_scales, break_context, create_cameras
+from vidar.utils.data import get_from_dict
+from vidar.utils.depth import inv2depth
+from vidar.utils.tensor import interpolate, multiply_args
+from vidar.utils.types import is_str
+
+
+def curr_stereo(val):
+ return is_str(val) and val.startswith('0')
+
+
+class DepthFormerModel(BaseModel, ABC):
+ """
+ Depthformer base model (https://arxiv.org/abs/2204.07616)
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.set_attr(cfg.model, 'warp_context', None)
+ self.set_attr(cfg.model, 'match_context', None)
+
+ self.motion_masking = cfg.model.motion_masking
+ self.matching_augmentation = cfg.model.matching_augmentation
+ self.freeze_teacher_and_pose = cfg.model.freeze_teacher_and_pose
+
+ self.view_synthesis = ViewSynthesis()
+
+ self.interpolate_nearest = partial(
+ interpolate, mode='nearest', scale_factor=None, align_corners=None)
+
+ self.set_attr(cfg.model, 'spatial_weight', [0.0, 0.0])
+ self.set_attr(cfg.model, 'spatio_temporal_weight', [0.0, 0.0])
+ self.set_attr(cfg.model, 'run_mono', True)
+ self.set_attr(cfg.model, 'run_multi', True)
+
+ self.set_attr(cfg.model, 'use_gt_pose', False)
+ self.set_attr(cfg.model, 'use_gt_depth', False)
+ self.set_attr(cfg.model, 'display', False)
+ self.set_attr(cfg.model, 'mono_type', 'mono')
+ self.set_attr(cfg.model, 'multi_temporal_only', False)
+
+ self.set_attr(cfg.model, 'apply_consistency', True)
+
+ @staticmethod
+ def process_stereo(batch):
+ """Process batch to recover stereo / monocular information"""
+ batch = {key: val for key, val in batch.items()}
+ new_intrinsics = {0: batch['intrinsics'][0]}
+ for key in batch['pose'].keys():
+ if not is_str(key) and key != 0:
+ new_intrinsics[key] = batch['intrinsics'][0]
+ batch['intrinsics'] = new_intrinsics
+ suffixes = ['', 'r', 's', 't', 'u', 'v']
+ if batch['rgb'][0].dim() == 5:
+ # Change RGB
+ rgb_stereo = {}
+ for key, val in batch['rgb'].items():
+ rgb_stereo[key] = val[:, 0]
+ for i in range(1, val.shape[1]):
+ rgb_stereo['%d%s' % (key, suffixes[i])] = val[:, i]
+ batch['rgb'] = rgb_stereo
+ # Change pose
+ if 'pose' in batch:
+ pose_stereo = {}
+ for key, val in batch['pose'].items():
+ pose_stereo[key] = val[:, 0]
+ for i in range(1, val.shape[1]):
+ pose_stereo['%d%s' % (key, suffixes[i])] = val[:, i]
+ batch['pose'] = pose_stereo
+ # Change intrinsics
+ if 'intrinsics' in batch:
+ intrinsics_stereo = {}
+ for key, val in batch['intrinsics'].items():
+ intrinsics_stereo[key] = val[:, 0]
+ for i in range(1, val.shape[1]):
+ intrinsics_stereo['%d%s' % (key, suffixes[i])] = val[:, i]
+ for key in batch['pose'].keys():
+ if not is_str(key) and key != 0:
+ for suffix in ['r', 's', 't', 'u', 'v']:
+ if '0%s' % suffix in intrinsics_stereo.keys():
+ intrinsics_stereo['%d%s' % (key, suffix)] = intrinsics_stereo['0%s' % suffix]
+ batch['intrinsics'] = intrinsics_stereo
+ return batch
+
+ @staticmethod
+ def get_stereo_pose(batch, pose, context):
+ """Get poses from stereo images"""
+ for key in context:
+ if key not in pose.keys() and curr_stereo(key):
+ pose[key] = batch['pose'][key]
+ return pose
+
+ @staticmethod
+ def pose_context(pose, context):
+ """Extract context poses from a pose dictionary"""
+ for key in context:
+ if not is_str(key):
+ for key2 in list(pose.keys()):
+ if curr_stereo(key2):
+ new_key = '%d%s' % (key, key2[1:])
+ if new_key in context:
+ pose[new_key] = pose[key2] @ pose[key]
+ return pose
+
+ @staticmethod
+ def conf_mask(depth1, depth2, thr=1.0):
+ """Calculate confidence masks"""
+ mask1 = ((depth1 - depth2) / depth2).abs() < thr
+ mask2 = ((depth2 - depth1) / depth1).abs() < thr
+ return mask1 * mask2
+
+ def forward(self, batch, epoch=0):
+ """Model forward pass"""
+
+ mono_depth_string = 'mono_depth'
+ multi_depth_string = 'multi_depth'
+
+ predictions = {}
+
+ batch = {key: val for key, val in batch.items()}
+ batch['rgb'] = {key: val for key, val in batch['rgb'].items()}
+
+ ### TRANSFORMER
+
+ batch = self.process_stereo(batch)
+ batch_rgb = batch['rgb']
+ rgbs = make_rgb_scales(batch_rgb, ratio_scales=(0.5, 4))
+
+ loss_auto_encoder = None
+ rgbs_pseudo = rgbs
+
+ ### Get images and contexts
+
+ device = rgbs[0][0].device
+ batch_size = rgbs[0][0].shape[0]
+
+ rgb, rgb_context = break_context(
+ rgbs_pseudo, tgt=0, ctx=self.match_context, scl=0, stack=True)
+
+ rgbs0 = {key: val[0] for key, val in rgbs.items()}
+
+ ### Warp pose
+
+ warp_context_pose = [idx for idx in self.warp_context if not is_str(idx)]
+ if not self.use_gt_pose:
+ pose_warp = self.compute_pose(
+ rgbs0, self.networks['pose'],
+ ctx=warp_context_pose, invert=True)
+ else:
+ pose_warp = {key: batch['pose'][key] for key in warp_context_pose}
+
+ pose_warp = self.get_stereo_pose(batch, pose_warp, self.warp_context)
+ pose_warp = self.pose_context(pose_warp, self.warp_context)
+
+ ### Match pose
+
+ if self.run_multi:
+ match_context_pose = [idx for idx in self.match_context if not is_str(idx)]
+ if not self.use_gt_pose:
+ with torch.no_grad():
+ pose_match = self.compute_pose(
+ rgbs0, self.networks['pose'],
+ ctx=match_context_pose, invert=True)
+ else:
+ pose_match = {key: batch['pose'][key] for key in match_context_pose}
+ pose_match = self.get_stereo_pose(batch, pose_match, self.match_context)
+ pose_match = self.pose_context(pose_match, self.match_context)
+ else:
+ pose_match = None
+
+ ### Augmentation Mask
+
+ augmentation_mask = torch.zeros([batch_size, 1, 1, 1], device=device).float()
+ if self.run_multi:
+ if self.training and self.matching_augmentation:
+ for batch_idx in range(batch_size):
+ rand_num = random.random()
+ if rand_num < 0.25:
+ rgb_context[batch_idx] = \
+ torch.stack([rgb[batch_idx] for _ in self.match_context], 0)
+ augmentation_mask[batch_idx] += 1
+ elif rand_num < 0.5:
+ pose_match[-1][batch_idx] *= 0
+ augmentation_mask[batch_idx] += 1
+
+ ### Warp cameras
+
+ intrinsics = batch['intrinsics']
+ cams_warp = create_cameras(rgbs[0][0], intrinsics, pose_warp)
+
+ ### Monocular depth
+
+ if self.run_mono:
+
+ if self.mono_type == 'multi':
+ ctx = [ctx for ctx in self.match_context if curr_stereo(ctx)]
+ pose_match2 = {key: val for key, val in pose_match.items() if curr_stereo(key)}
+ cams_match2 = create_cameras(rgbs[0][0], intrinsics, pose_match2)
+ rgb2, rgb_context2 = break_context(
+ rgbs, tgt=0, ctx=ctx, scl=0, stack=True)
+ mono_depth_output = self.networks[mono_depth_string](
+ rgb=rgb2, rgb_context=rgb_context2, cams=cams_match2, intrinsics=intrinsics, mode='multi')
+ predictions['depth_lowest_mono'] = {
+ 0: [inv2depth(mono_depth_output['lowest_cost'].unsqueeze(1)).detach()]}
+ predictions['volume_mono'] = {0: mono_depth_output['cost_volume']}
+ predictions['mask_confidence_mono'] = {
+ 0: [mono_depth_output['confidence_mask'].unsqueeze(1)]}
+ elif self.mono_type == 'mono':
+ mono_depth_output = self.networks[mono_depth_string](
+ rgb=rgb, intrinsics=intrinsics)
+ else:
+ raise ValueError
+
+ if self.use_gt_depth:
+ depth_mono = [batch['depth'][0][:, 0]]
+ else:
+ depth_mono = mono_depth_output['depths']
+
+ predictions['depth_mono'] = {0: depth_mono}
+ else:
+ mono_depth_output = depth_mono = None
+
+ ### Multi-frame depth
+
+ if self.run_multi:
+
+ if self.multi_temporal_only:
+ ctx = [ctx for ctx in self.match_context if not curr_stereo(ctx)]
+ pose_match3 = {key: val for key, val in pose_match.items() if not is_str(key)}
+ cams_match3 = create_cameras(rgbs[0][0], intrinsics[0], pose_match3)
+ rgb3, rgb_context3 = break_context(
+ rgbs_pseudo, tgt=0, ctx=ctx, scl=0, stack=True)
+ multi_depth_output = self.networks[multi_depth_string](
+ rgb=rgb3, rgb_context=rgb_context3, cams=cams_match3,
+ intrinsics=intrinsics, networks=self.networks,
+ )
+ else:
+ cams_match = create_cameras(rgbs[0][0], intrinsics, pose_match)
+ multi_depth_output = self.networks[multi_depth_string](
+ rgb=rgb, rgb_context=rgb_context, cams=cams_match,
+ intrinsics=intrinsics, networks=self.networks,
+ )
+
+ if self.use_gt_depth:
+ depth_multi = [batch['depth'][0][:, 0]]
+ else:
+ depth_multi = multi_depth_output['depths']
+
+ predictions['depth_multi'] = {0: depth_multi}
+ predictions['volume_multi'] = {0: multi_depth_output['cost_volume']}
+ predictions['depth_lowest_multi'] = {
+ 0: [inv2depth(d.unsqueeze(1)).detach() for d in multi_depth_output['lowest_cost']]}
+ predictions['mask_confidence_multi'] = {
+ 0: [multi_depth_output['confidence_mask'].unsqueeze(1)]}
+
+ if 'ssim_lowest_cost' in multi_depth_output:
+ predictions['depth_lowest_ssim'] = {
+ 0: [inv2depth(multi_depth_output['ssim_lowest_cost'].unsqueeze(1)).detach()]}
+
+ else:
+
+ multi_depth_output = depth_multi = None
+
+ ### Confidence mask
+
+ if self.run_multi:
+ shape = rgbs0[0].shape[-2:]
+ lowest_cost = self.interpolate_nearest(
+ multi_depth_output['lowest_cost'][0].unsqueeze(1), size=shape).squeeze(1).to(device)
+ confidence_mask = self.interpolate_nearest(
+ multi_depth_output['confidence_mask'].unsqueeze(1), size=shape).to(device)
+ if self.motion_masking and self.run_mono:
+
+ if 'regression' in multi_depth_output:
+ inv_depth_low_res = multi_depth_output['regression']['disp_pred_low_res']
+ inv_depth_low_res = self.interpolate_nearest(
+ inv_depth_low_res.unsqueeze(1), size=shape).squeeze(1).to(device)
+ lowest_cost = inv_depth_low_res
+
+ matching_depth = 1. / lowest_cost.unsqueeze(1).to(device)
+ confidence_mask *= self.conf_mask(matching_depth, depth_mono[0])
+ # predictions['mask_confidence'] = {0: [confidence_mask.unsqueeze(1)]}
+ confidence_mask = confidence_mask * (1 - augmentation_mask)
+ else:
+ confidence_mask = None
+
+ ########## LOSSES
+
+ loss, metrics = [], {}
+ mono_metrics, multi_metrics = {}, {}
+ mono_visuals, multi_visuals = {}, {}
+
+ valid_mask = get_from_dict(batch, 'mask')
+ if valid_mask is not None:
+ valid_mask = valid_mask[:, 0]
+
+ predictions['output_mono'] = mono_depth_output
+ predictions['output_multi'] = multi_depth_output
+
+        if multi_depth_output is not None and 'depth_regr' in multi_depth_output:
+ predictions['depth_regr'] = {
+ 0: [d.unsqueeze(1) for d in multi_depth_output['depth_regr']]
+ }
+
+ if 'cal' in self.networks['multi_depth'].networks.keys():
+ cal = self.networks['multi_depth'].networks['cal']
+ depth1 = depth_multi[0]
+ depth2 = predictions['depth_regr'][0][0]
+ from vidar.utils.tensor import interpolate
+ depth2 = interpolate(depth2, size=depth1.shape[-2:], scale_factor=None, mode='nearest', align_corners=None)
+ predictions['depth_regr'][0].insert(0, cal(depth1, depth2, rgb))
+
+ if not self.training:
+ return {
+ 'predictions': predictions,
+ }
+
+ gt_depth = None if 'depth' not in batch else batch['depth'][0]
+
+ ### Temporal losses
+
+ cams_warp_temp = {key: val for key, val in cams_warp.items()
+ if not curr_stereo(key) or key == 0}
+
+ if len(cams_warp_temp) > 0 \
+ and self.spatial_weight[0] < 1.0 \
+ and self.spatio_temporal_weight[0] < 1.0:
+
+ if self.run_mono:
+ mono_loss_temp, mono_metrics_temp, mono_visuals_temp = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_mono, cams_warp_temp, logvar=None, valid_mask=valid_mask
+ )
+ loss.append((1 - self.spatial_weight[0]) *
+ (1 - self.spatio_temporal_weight[0]) * mono_loss_temp)
+ mono_metrics.update(**mono_metrics_temp)
+ mono_visuals.update(**{f'temp_{key}': val for key, val in mono_visuals_temp.items()})
+ metrics.update({f'mono_temp_{key}': val for key, val in mono_metrics_temp.items()})
+
+ if len(cams_warp_temp) > 0 \
+ and self.spatial_weight[1] < 1.0 \
+ and self.spatio_temporal_weight[1] < 1.0:
+
+ if self.run_multi:
+ multi_loss_temp, multi_metrics_temp, multi_visuals_temp = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_multi, cams_warp_temp, logvar=None,
+ depths_consistency=depth_mono if self.apply_consistency else None,
+ confidence_mask=confidence_mask if self.apply_consistency else None,
+ valid_mask=valid_mask,
+ )
+ loss.append((1 - self.spatial_weight[1]) *
+ (1 - self.spatio_temporal_weight[1]) * multi_loss_temp)
+ multi_metrics.update(**multi_metrics_temp)
+ multi_visuals.update(**{f'temp_{key}': val for key, val in multi_visuals_temp.items()})
+ metrics.update({f'multi_temp_{key}': val for key, val in multi_metrics_temp.items()})
+
+        ### Spatial losses
+
+ cams_warp_spat = {key: val for key, val in cams_warp.items()
+ if curr_stereo(key) or key == 0}
+
+ if len(cams_warp_spat) > 0 \
+ and self.spatial_weight[0] > 0.0 \
+ and self.spatio_temporal_weight[0] < 1.0:
+
+ if self.run_mono:
+ mono_loss_spat, mono_metrics_spat, mono_visuals_spat = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_mono, cams_warp_spat, logvar=None, valid_mask=valid_mask
+ )
+ loss.append(self.spatial_weight[0] *
+ (1 - self.spatio_temporal_weight[0]) * mono_loss_spat)
+ mono_metrics.update(**mono_metrics_spat)
+ mono_visuals.update(**{f'spat_{key}': val for key, val in mono_visuals_spat.items()})
+ metrics.update({f'mono_spat_{key}': val for key, val in mono_metrics_spat.items()})
+
+ if len(cams_warp_spat) > 0 \
+ and self.spatial_weight[1] > 0.0 \
+ and self.spatio_temporal_weight[1] < 1.0:
+
+ if self.run_multi:
+ multi_loss_spat, multi_metrics_spat, multi_visuals_spat = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_multi, cams_warp_spat, logvar=None,
+ depths_consistency=depth_mono if self.apply_consistency else None,
+ valid_mask=valid_mask,
+ confidence_mask=confidence_mask if self.apply_consistency else None
+ )
+ loss.append(self.spatial_weight[1] *
+ (1 - self.spatio_temporal_weight[1]) * multi_loss_spat)
+ multi_metrics.update(**multi_metrics_spat)
+ multi_visuals.update(**{f'spat_{key}': val for key, val in multi_visuals_spat.items()})
+ metrics.update({f'multi_spat_{key}': val for key, val in multi_metrics_spat.items()})
+
+        ### Spatio-temporal losses
+
+ if self.spatio_temporal_weight[0] > 0.0:
+
+ if self.run_mono:
+ mono_loss_both, mono_metrics_both, mono_visuals_both = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_mono, cams_warp, logvar=None, valid_mask=valid_mask
+ )
+ loss.append(self.spatio_temporal_weight[0] * mono_loss_both)
+ mono_metrics.update(**mono_metrics_both)
+ mono_visuals.update(**{f'both_{key}': val for key, val in mono_visuals_both.items()})
+ metrics.update({f'mono_both_{key}': val for key, val in mono_metrics_both.items()})
+
+ if self.spatio_temporal_weight[1] > 0.0:
+
+ if self.run_multi:
+ multi_loss_both, multi_metrics_both, multi_visuals_both = \
+ self.compute_loss_and_metrics(
+ rgbs, depth_multi, cams_warp, logvar=None,
+ depths_consistency=depth_mono if self.apply_consistency else None,
+ confidence_mask=confidence_mask if self.apply_consistency else None,
+ valid_mask=valid_mask,
+ )
+ loss.append(self.spatio_temporal_weight[1] * multi_loss_both)
+ multi_metrics.update(**multi_metrics_both)
+ multi_visuals.update(**{f'both_{key}': val for key, val in multi_visuals_both.items()})
+ metrics.update({f'multi_both_{key}': val for key, val in multi_metrics_both.items()})
+
+        ### Auto-encoder and depth regression losses
+
+ if loss_auto_encoder is not None:
+ loss.append(loss_auto_encoder)
+
+ if 'depth_regr' in predictions:
+ regr_loss, regr_metrics, regr_visuals = \
+ self.compute_loss_and_metrics(
+ rgbs, predictions['depth_regr'][0], cams_warp, logvar=None, valid_mask=valid_mask
+ )
+ loss.append(regr_loss)
+
+ depth_pred = [predictions['depth_regr'][0][0]]
+ depth_gt = depth_mono[0].detach()
+ supervision_output = self.losses['supervision'](depth_pred, depth_gt)
+ loss.append(supervision_output['loss'])
+
+ loss = sum(loss)
+
+ metrics.update({
+ 'min_depth_bin': self.networks[multi_depth_string].networks['encoder'].min_depth_bin,
+ 'max_depth_bin': self.networks[multi_depth_string].networks['encoder'].max_depth_bin,
+ })
+
+ visuals = {
+ **{f'mono_{key}': val for key, val in mono_visuals.items()},
+ **{f'multi_{key}': val for key, val in multi_visuals.items()},
+ }
+
+ if self.run_mono and self.run_multi and \
+ self.training and epoch < self.freeze_teacher_and_pose:
+ self.networks[multi_depth_string].networks['encoder'].update_adaptive_depth_bins(depth_mono[0])
+ if 'lowest_cost' in mono_depth_output:
+ self.networks[mono_depth_string].networks['encoder'].min_depth_bin = \
+ self.networks[multi_depth_string].networks['encoder'].min_depth_bin
+ self.networks[mono_depth_string].networks['encoder'].max_depth_bin = \
+ self.networks[multi_depth_string].networks['encoder'].max_depth_bin
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ 'visuals': visuals,
+ 'predictions': predictions,
+ }
+
+ def compute_loss_and_metrics(self, rgbs, depths, cams, depths_consistency=None,
+ logvar=None, valid_mask=None, confidence_mask=None):
+ """
+ Compute model loss and metrics
+
+ Parameters
+ ----------
+ rgbs : list[torch.Tensor]
+ Input RGB images
+ depths : list[torch.Tensor]
+ Predicted depth maps
+        cams : Dict
+            Dictionary with target and context cameras
+ depths_consistency : list[torch.Tensor]
+ Depth maps used for consistency loss calculation
+ logvar : list[torch.Tensor]
+ Predicted log-variance for depth maps
+ valid_mask : list[torch.Tensor]
+ Valid mask for masking out pixels
+ confidence_mask : list[torch.Tensor]
+ Confidence mask for consistency calculation
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Final loss
+ metrics : Dict
+ Dictionary with calculated metrics
+ visuals : Dict
+ Dictionary with calculated visualizations
+ """
+ num_scales = self.get_num_scales(depths)
+
+ rgb_tgt = [rgbs[0][i] for i in range(num_scales)]
+ rgb_ctx = [[rgbs[j][i] for j in cams.keys() if j != 0] for i in range(num_scales)]
+
+ loss, metrics, visuals = [], {}, {}
+
+ if 'reprojection' in self.losses:
+ synth = self.view_synthesis(rgbs, depths, cams, return_masks=True)
+ reprojection_mask = multiply_args(valid_mask, confidence_mask)
+ reprojection_output = self.losses['reprojection'](
+ rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar,
+ valid_mask=reprojection_mask, overlap_mask=synth['masks'])
+ loss.append(reprojection_output['loss'])
+ metrics.update(reprojection_output['metrics'])
+ visuals['synth'] = synth
+ visuals['reproj'] = reprojection_output
+
+ if 'smoothness' in self.losses:
+ smoothness_output = self.losses['smoothness'](rgb_tgt, depths)
+ loss.append(smoothness_output['loss'])
+ metrics.update(smoothness_output['metrics'])
+
+ if 'consistency' in self.losses and depths_consistency is not None:
+ consistency_output = self.losses['consistency'](
+ depths_consistency, depths,
+ confidence_mask=reprojection_output['mask'],
+ valid_mask=valid_mask
+ )
+ loss.append(consistency_output['loss'])
+ metrics.update(consistency_output['metrics'])
+
+ loss = sum(loss)
+
+ return loss, metrics, visuals
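+
+    # Hedged shape sketch (not from the original source): `rgbs` is a dictionary of per-frame
+    # image pyramids keyed by frame index, `depths` a list of per-scale depth maps [B,1,H,W],
+    # and `cams` a dictionary of cameras keyed by frame index (0 = target); the reprojection,
+    # smoothness and consistency terms are summed into a single scalar loss.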
diff --git a/vidar/arch/models/depth/DepthModel.py b/vidar/arch/models/depth/DepthModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..25567a40a56b7a239e81e404e78b80604c51b61e
--- /dev/null
+++ b/vidar/arch/models/depth/DepthModel.py
@@ -0,0 +1,28 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from vidar.arch.models.BaseModel import BaseModel
+
+
+class DepthModel(BaseModel):
+ """
+ Base depth estimation model
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ def forward(self, batch, **kwargs):
+ """Model forward pass"""
+
+ depth_output = self.networks['depth'](batch['rgb'][0])
+ return {
+ 'loss': 0.0,
+ 'metrics': {},
+ 'predictions': {
+ 'depth': {0: depth_output['depths']}
+ }
+ }
diff --git a/vidar/arch/models/depth/FSMModel.py b/vidar/arch/models/depth/FSMModel.py
new file mode 100644
index 0000000000000000000000000000000000000000..83afbcdd84610ffae465db680278b9d519f76dde
--- /dev/null
+++ b/vidar/arch/models/depth/FSMModel.py
@@ -0,0 +1,311 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import random
+
+import numpy as np
+import torch
+
+from vidar.arch.losses.MultiCamPhotometricLoss import MultiCamPhotometricLoss
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.arch.networks.layers.fsm.camera import Camera
+from vidar.arch.networks.layers.fsm.pose import Pose
+from vidar.arch.networks.layers.fsm.utils import \
+ CameraNormalizer, flip_batch_input, flip_output, filter_dict, coords_from_motion, mask_from_coords
+from vidar.utils.depth import depth2inv, inv2depth
+from vidar.utils.types import is_list, is_seq
+
+
+def split_batch(tensor, n=1):
+ """Split a tensor batch-wise"""
+ if is_list(tensor):
+ split = [split_batch(t, n=n) for t in tensor]
+ return list(map(list, zip(*split)))
+ return torch.split(tensor, split_size_or_sections=n, dim=0)
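+
+# Hedged usage sketch (not from the original source), assuming a batch of 4 images:
+#   x = torch.rand(4, 3, 64, 64)
+#   chunks = split_batch(x)      # tuple of 4 tensors, each [1, 3, 64, 64]
+#   pairs = split_batch([x, x])  # list of 4 entries, each a [chunk_of_first, chunk_of_second] pair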
+
+
+def global_cameras(intrinsics, pose, pose_context, hw=None):
+ """Create global cameras for target and source poses + target intrinsics"""
+ cam = camera_from_intrinsics_pose(pose, intrinsics, hw=hw)
+ cam_context = camera_from_intrinsics_pose(pose_context, intrinsics, pose, hw=hw)
+ return cam, cam_context
+
+
+def camera_from_intrinsics_pose(pose, intrinsics, orig_pose=None, hw=None):
+ """
+ Create one or more cameras from pose and intrinsics
+
+ Parameters
+ ----------
+ pose : torch.Tensor or list[torch.Tensor]
+ Poses to be used [B,4,4]
+ intrinsics : torch.Tensor or list[torch.Tensor]
+ Intrinsics to be used [B,3,3]
+ orig_pose : torch.Tensor or list[torch.Tensor]
+ Original poses from which pose is generated [B,4,4]
+ hw : tuple
+ Camera image dimensions
+
+ Returns
+ -------
+ camera : Camera
+ Camera instance created from the input
+ """
+ # If pose is a sequence, do it for each one
+ if is_seq(pose):
+ # If intrinsics is not a sequence, repeat it
+ if not is_seq(intrinsics):
+ intrinsics = [intrinsics] * len(pose)
+ # If orig pose is not a sequence, repeat it
+ if not is_seq(orig_pose):
+ orig_pose = [orig_pose] * len(pose)
+ # Recursive loop for each item
+ return [camera_from_intrinsics_pose(p, i, o, hw=hw)
+ for p, i, o in zip(pose, intrinsics, orig_pose)]
+ # Compound original pose if available
+ if orig_pose is not None:
+ pose = Pose(orig_pose) @ Pose(pose).inverse()
+ # Return camera
+ return Camera(K=intrinsics, Twc=pose, hw=hw)
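+
+# Hedged usage sketch (not from the original source), with K [B,3,3] and poses [B,4,4]:
+#   cam, cam_ctx = global_cameras(K, pose_tgt, [pose_prev, pose_next], hw=(H, W))
+#   # `cam` is the target Camera; `cam_ctx` is a list of context Cameras whose poses are
+#   # composed as Pose(pose_tgt) @ Pose(pose_ctx).inverse() via `camera_from_intrinsics_pose`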
+
+
+class FSMModel(BaseModel):
+ """
+ Full Surround Monodepth (FSM) model (https://arxiv.org/abs/2104.00152)
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg, **kwargs):
+ super().__init__(cfg)
+
+ self.rotation_mode = 'euler'
+ self.flip_lr_prob = 0.0
+ self.upsample_depth_maps = False
+
+ norm_focal = []
+ self.focal = None if len(norm_focal) == 0 else CameraNormalizer(focal=norm_focal)
+
+ pairs = [
+ [0, 1], [1, 0],
+ [0, 2], [2, 0],
+ [1, 4], [4, 1],
+ [2, 3], [3, 2],
+ [3, 5], [5, 3],
+ [4, 5], [5, 4],
+ ]
+
+ stereo_weight = 0.1
+ stereo_context = True
+ gt_pose = False
+
+ self.multicam_loss = MultiCamPhotometricLoss(**kwargs)
+ self.pairs = pairs
+ self.stereo_weight = stereo_weight
+ self.gt_pose = gt_pose
+ self.stereo_context = stereo_context
+
+ self._input_keys = ['rgb', 'rgb_context', 'intrinsics', 'extrinsics',
+ 'pose_context', 'filename']
+ self.networks = torch.nn.ModuleDict()
+
+ def compute_depth_net(self, batch, force_flip=False):
+ """Computes inverse depth maps from single images"""
+ # Randomly flip and estimate inverse depth maps
+ flag_flip_lr = random.random() < self.flip_lr_prob if self.training else force_flip
+ output = self.depth_net_flipping(batch, flag_flip_lr)
+ if self.focal is not None:
+ output['inv_depths'] = self.focal.unormalize(output['inv_depths'])
+ # Return inverse depth maps
+ return output
+
+ def compute_pose_net(self, image, contexts):
+ """Compute poses from image and a sequence of context images"""
+ pose_vec = self.networks['pose'](image, contexts)
+ return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
+ for i in range(pose_vec.shape[1])]
+
+ def depth_net_flipping(self, batch, flip):
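+        """Run the depth network, optionally on horizontally flipped inputs (un-flipping the output during training)"""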
+
+ # Which keys are being passed to the depth network
+ batch_input = {key: batch[key] for key in filter_dict(batch, self._input_keys)}
+ if self.focal is not None:
+ batch_input['rgb'] = self.focal.normalize(batch_input['rgb'], batch['intrinsics'])
+ if flip:
+ # Run depth network with flipped inputs
+ output = self.networks['depth'](**flip_batch_input(batch_input))
+ # Flip output back if training
+ if self.training:
+ output = flip_output(output)
+ else:
+ # Run depth network
+ output = self.networks['depth'](**batch_input)
+ return output
+
+ def forward2(self, batch, return_logs=False, force_flip=False):
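+        """Run the depth network and, if context images are available, the pose network"""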
+ # Generate inverse depth predictions
+ depth_output = self.compute_depth_net(batch, force_flip=force_flip)
+ # Generate pose predictions if available
+ pose_output = None
+ if 'rgb_context' in batch and self.networks['pose'] is not None:
+ pose_output = self.compute_pose_net(
+ batch['rgb'], batch['rgb_context'])
+ # Return output dictionary
+ return {
+ **depth_output,
+ 'poses': pose_output,
+ }
+
+ def forward(self, batch, return_logs=False, progress=0.0, **kwargs):
+
+ new_batch = {}
+ new_batch['rgb'] = batch['rgb'][0]
+ if self.training:
+ new_batch['rgb_context'] = [batch['rgb'][1], batch['rgb'][-1]]
+ new_batch['pose'] = batch['pose'][0]
+ new_batch['pose_context'] = [batch['pose'][1], batch['pose'][-1]]
+ new_batch['intrinsics'] = batch['intrinsics'][0]
+ new_batch['filename'] = batch['filename']
+ batch = new_batch
+
+ if self.training:
+ batch['rgb'] = batch['rgb'][0]
+ batch['rgb_context'] = [b[0] for b in batch['rgb_context']]
+ batch['pose'] = batch['pose'][0]
+ batch['pose_context'] = [b[0] for b in batch['pose_context']]
+ batch['intrinsics'] = batch['intrinsics'][0]
+
+ if self.networks['depth'] is not None:
+ output_self_sup = self.forward2(batch)
+ depth = inv2depth(output_self_sup['inv_depths'])
+ else:
+ output_self_sup = {}
+ depth = batch['depth']
+
+ if not self.training:
+ output_new = {
+ 'predictions': {
+ 'depth': {
+ 0: [1. / d for d in output_self_sup['inv_depths']]
+ }
+ }
+ }
+ output_self_sup = output_new
+ return output_self_sup
+
+ rgb = batch['rgb']
+ rgb_context = batch['rgb_context']
+ intrinsics = batch['intrinsics']
+ pose = batch['extrinsics'] if 'extrinsics' in batch else batch['pose']
+
+ pose_context_gt = batch['pose_context']
+ if self.gt_pose:
+ pose_context = batch['pose_context']
+ else:
+ pose_context = output_self_sup['poses']
+
+ for i in range(len(pose_context)):
+ pose_context[i] = pose_context[i].mat
+
+ rgb_i, rgb_context_i = split_batch(rgb), split_batch(rgb_context)
+ pose_i, pose_context_i = split_batch(pose), split_batch(pose_context)
+ intrinsics_i, inv_depth_i = split_batch(intrinsics), depth2inv(split_batch(depth))
+ cam_i, cam_context_i = global_cameras(intrinsics_i, pose_i, pose_context_i, hw=rgb.shape[2:])
+
+ _, pose_context_i_gt = split_batch(pose), split_batch(pose_context_gt)
+ _, cam_context_i_gt = global_cameras(intrinsics_i, pose_i, pose_context_i_gt, hw=rgb.shape[2:])
+
+ n_tgt = len(rgb_i)
+
+ mono_coords = [coords_from_motion(
+ cam_context_i[tgt], inv2depth(inv_depth_i[tgt]), cam_i[tgt])
+ for tgt in range(n_tgt)]
+ mono_masks = [mask_from_coords(coords) for coords in mono_coords]
+
+ filename = batch['filename']
+ try:
+            filename = ['camera' + f[0].split('/')[-2][-3:] + '_mask.npy' for f in filename]
+ # filename = ['camera' + f.split('/')[-2][-3:]+ '_mask.npy' for f in filename]
+ masks = [torch.tensor(np.load(f)).unsqueeze(0).unsqueeze(0) for f in filename]
+ for tgt in range(n_tgt):
+ for i in range(len(mono_masks[tgt])):
+ for j in range(len(mono_masks[tgt][i])):
+ for k in range(len(mono_masks[tgt][i][j])):
+ # write_image('debug/camera_%d/mask_%d_%d_%d.png' % (tgt, i, j, k),
+ # mono_masks[tgt][i][j][k])
+ resized_mask = torch.nn.functional.interpolate(
+ masks[tgt], mono_masks[tgt][i][j][k].shape[1:], mode='nearest').squeeze(0).bool()
+ mono_masks[tgt][i][j][k] *= resized_mask.to(mono_masks[tgt][i][j][k].device)
+ with_masks = True
+        except Exception:
+            # Per-camera masks are optional; fall back to no masks if they cannot be loaded
+            with_masks = False
+
+ mono = []
+ outputs = []
+
+ for tgt in range(n_tgt):
+ output = self.multicam_loss(
+ rgb_i[tgt], rgb_context_i[tgt], inv_depth_i[tgt],
+ cam_i[tgt], cam_context_i[tgt], with_mask=mono_masks[tgt])
+ if not torch.isnan(output['loss']):
+ mono.append(output['loss'])
+ outputs.append(output)
+
+ stereo = []
+ if not self.stereo_context and self.stereo_weight > 0:
+
+ stereo_coords = [coords_from_motion(
+ [cam_i[src]], inv2depth(inv_depth_i[tgt]), cam_i[tgt])
+ for tgt, src in self.pairs]
+ stereo_masks = [mask_from_coords(coords) for coords in stereo_coords]
+
+ if with_masks:
+ for tgt in range(len(self.pairs)):
+ for i in range(len(stereo_masks[tgt])):
+ for j in range(len(stereo_masks[tgt][i])):
+ for k in range(len(stereo_masks[tgt][i][j])):
+ hw = stereo_masks[tgt][i][j][k].shape[1:]
+ h_st, h_fn = int(0.15 * hw[0]), int(0.75 * hw[0])
+ stereo_masks[tgt][i][j][k][:, :h_st] = 0
+ stereo_masks[tgt][i][j][k][:, h_fn:] = 0
+
+ for k, (tgt, src) in enumerate(self.pairs):
+ output = self.multicam_loss(
+ rgb_i[tgt], [rgb_i[src]], inv_depth_i[tgt],
+ cam_i[tgt], [cam_i[src]], with_mask=stereo_masks[k], automask=False)
+ if not torch.isnan(output['loss']):
+ stereo.append(self.stereo_weight * output['loss'])
+
+ elif self.stereo_context and self.stereo_weight > 0:
+
+ stereo_coords = [coords_from_motion(
+ [cam_i[src]] + cam_context_i[src], inv2depth(inv_depth_i[tgt]), cam_i[tgt])
+ for tgt, src in self.pairs]
+ stereo_masks = [mask_from_coords(coords) for coords in stereo_coords]
+
+ for tgt in range(len(self.pairs)):
+ for i in range(len(stereo_masks[tgt])):
+ for j in range(len(stereo_masks[tgt][i])):
+ for k in range(len(stereo_masks[tgt][i][j])):
+ hw = stereo_masks[tgt][i][j][k].shape[1:]
+ h_st, h_fn = int(0.15 * hw[0]), int(0.75 * hw[0])
+ stereo_masks[tgt][i][j][k][:, :h_st] = 0
+ stereo_masks[tgt][i][j][k][:, h_fn:] = 0
+
+ for k, (tgt, src) in enumerate(self.pairs):
+ output = self.multicam_loss(
+ rgb_i[tgt], [rgb_i[src]] + rgb_context_i[src], inv_depth_i[tgt],
+ cam_i[tgt], [cam_i[src]] + cam_context_i[src], with_mask=stereo_masks[k], automask=False)
+ if not torch.isnan(output['loss']):
+ stereo.append(self.stereo_weight * output['loss'])
+
+ losses = mono + stereo
+ loss = sum(losses) / len(losses)
+ output_self_sup['loss'] = loss.unsqueeze(0)
+
+ return {
+ **output_self_sup,
+ 'metrics': {},
+ }
diff --git a/vidar/arch/models/depth/SelfSupervisedModel.py b/vidar/arch/models/depth/SelfSupervisedModel.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c360b6202ccaf97f8775899240aadea836ab74
--- /dev/null
+++ b/vidar/arch/models/depth/SelfSupervisedModel.py
@@ -0,0 +1,176 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.arch.models.utils import make_rgb_scales, create_cameras
+from vidar.utils.data import get_from_dict
+from vidar.utils.config import cfg_has
+
+
+class SelfSupervisedModel(BaseModel, ABC):
+ """
+ Self-supervised depth estimation model
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self.view_synthesis = ViewSynthesis()
+ self.set_attr(cfg.model, 'use_gt_pose', False)
+ self.set_attr(cfg.model, 'use_gt_intrinsics', True)
+
+ if not self.use_gt_intrinsics:
+ self.camera_model = cfg_has(cfg.networks.intrinsics, 'camera_model', 'UCM')
+ if self.camera_model == 'UCM':
+ from vidar.geometry.camera_ucm import UCMCamera
+ self.camera_class = UCMCamera
+ elif self.camera_model == 'EUCM':
+ from vidar.geometry.camera_eucm import EUCMCamera
+ self.camera_class = EUCMCamera
+ elif self.camera_model == 'DS':
+ from vidar.geometry.camera_ds import DSCamera
+ self.camera_class = DSCamera
+ else:
+ raise NotImplementedError('Invalid camera type')
+
+ def forward(self, batch, epoch=0):
+ """Model forward pass"""
+
+ rgb = batch['rgb']
+ if self.use_gt_intrinsics:
+ intrinsics = get_from_dict(batch, 'intrinsics')
+ else:
+ intrinsics = self.networks['intrinsics'](rgb=rgb[0])
+
+ valid_mask = get_from_dict(batch, 'mask')
+
+ if self.use_gt_intrinsics:
+ depth_output = self.networks['depth'](rgb=rgb[0], intrinsics=intrinsics[0])
+ else:
+ depth_output = self.networks['depth'](rgb=rgb[0])
+ pred_depth = depth_output['depths']
+
+ predictions = {
+ 'depth': {0: pred_depth},
+ }
+
+ pred_logvar = get_from_dict(depth_output, 'logvar')
+ if pred_logvar is not None:
+ predictions['logvar'] = {0: pred_logvar}
+
+ if not self.training:
+ return {
+ 'predictions': predictions,
+ }
+
+ if self.use_gt_pose:
+ assert 'pose' in batch, 'You need input pose'
+ pose = batch['pose']
+ elif 'pose' in self.networks:
+ pose = self.compute_pose(rgb, self.networks['pose'], tgt=0, invert=True)
+ predictions['pose'] = pose
+ else:
+ pose = None
+
+ if not self.use_gt_intrinsics:
+ cams = {0: self.camera_class(I=intrinsics)}
+ for key in pose.keys():
+ cams[key] = self.camera_class(I=intrinsics, Tcw=pose[key])
+ else:
+ cams = create_cameras(rgb[0], intrinsics[0], pose)
+
+ gt_depth = None if 'depth' not in batch else batch['depth'][0]
+ loss, metrics = self.compute_loss_and_metrics(
+ rgb, pred_depth, cams, gt_depth=gt_depth,
+ logvar=pred_logvar, valid_mask=valid_mask
+ )
+
+ if not self.use_gt_intrinsics:
+ if self.camera_model == 'UCM':
+ fx, fy, cx, cy, alpha = intrinsics[0].squeeze()
+ intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'alpha':alpha}
+ metrics.update(intrinsics_metrics)
+ elif self.camera_model == 'EUCM':
+ fx, fy, cx, cy, alpha, beta = intrinsics[0].squeeze()
+ intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'alpha':alpha, 'beta':beta}
+ metrics.update(intrinsics_metrics)
+ elif self.camera_model == 'DS':
+ fx, fy, cx, cy, xi, alpha = intrinsics[0].squeeze()
+ intrinsics_metrics = {'fx': fx, 'fy':fy, 'cx':cx, 'cy':cy, 'xi':xi, 'alpha':alpha}
+ metrics.update(intrinsics_metrics)
+ else:
+ raise NotImplementedError('Invalid camera type')
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ 'predictions': predictions,
+ }
+
+ def compute_loss_and_metrics(self, rgb, depths, cams, gt_depth=None,
+ logvar=None, valid_mask=None):
+ """
+ Compute loss and metrics for training
+
+ Parameters
+ ----------
+ rgb : Dict
+ Dictionary with input images [B,3,H,W]
+ depths : list[torch.Tensor]
+ List with target depth maps in different scales [B,1,H,W]
+ cams : Dict
+ Dictionary with cameras for each input image
+ gt_depth : torch.Tensor
+ Ground-truth depth map for supervised training
+ logvar : list[torch.Tensor]
+ Log-variance maps for uncertainty training
+ valid_mask : list[torch.Tensor]
+ Binary mask for masking out invalid pixels [B,1,H,W]
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Training loss
+ metrics : Dict
+ Dictionary with training metrics
+ """
+ tgt = 0
+ ctx = [key for key in rgb.keys() if key != tgt]
+
+ num_scales = self.get_num_scales(depths)
+
+ rgbs = make_rgb_scales(rgb, pyramid=depths)
+ rgb_tgt = [rgbs[tgt][i] for i in range(num_scales)]
+ rgb_ctx = [[rgbs[j][i] for j in ctx] for i in range(num_scales)]
+
+ loss, metrics = [], {}
+
+ if 'reprojection' in self.losses:
+ synth = self.view_synthesis(
+ rgbs, depths=depths, cams=cams, return_masks=True)
+ reprojection_output = self.losses['reprojection'](
+ rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar,
+ valid_mask=valid_mask, overlap_mask=synth['masks'])
+ loss.append(reprojection_output['loss'])
+ metrics.update(**reprojection_output['metrics'])
+ if 'smoothness' in self.losses:
+ smoothness_output = self.losses['smoothness'](rgb_tgt, depths)
+ loss.append(smoothness_output['loss'])
+ metrics.update(**smoothness_output['metrics'])
+ if 'supervision' in self.losses and gt_depth is not None:
+ supervision_output = self.losses['supervision'](depths, gt_depth)
+ loss.append(supervision_output['loss'])
+ metrics.update(**supervision_output['metrics'])
+ if 'normals' in self.losses and gt_depth is not None:
+ normals_output = self.losses['normals'](depths, gt_depth, cams[0])
+ loss.append(normals_output['loss'])
+ metrics.update(**normals_output['metrics'])
+
+ loss = sum(loss)
+
+ return loss, metrics
diff --git a/vidar/arch/models/depth/SupervisedModel.py b/vidar/arch/models/depth/SupervisedModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..44bfb7d24549c286c7207af7cfa925f115bd1d1c
--- /dev/null
+++ b/vidar/arch/models/depth/SupervisedModel.py
@@ -0,0 +1,92 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.utils.decorators import iterate1
+from vidar.utils.tensor import interpolate_image
+
+
+@iterate1
+def make_rgb_scales(rgb, pyramid):
+ return [interpolate_image(rgb, shape=pyr.shape[-2:]) for pyr in pyramid]
+
+
+class SupervisedModel(BaseModel, ABC):
+ """
+ Supervised depth estimation model
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ def forward(self, batch, epoch):
+ """Model forward pass"""
+
+ rgb = batch['rgb']
+
+ depth_output = self.networks['depth'](rgb=rgb[0])
+ depths = depth_output['depths']
+
+ if not self.training:
+ return {
+ 'predictions': {
+ 'depth': {0: depths}
+ },
+ }
+
+ losses = self.compute_losses(rgb, depths, batch['depth'])
+
+ return {
+ 'loss': losses['loss'],
+ 'metrics': {
+ },
+ 'predictions': {
+ 'depth': {0: depths}
+ },
+ }
+
+ def compute_losses(self, rgb, depths, gt_depths):
+ """
+ Compute loss and metrics for training
+
+ Parameters
+ ----------
+ rgb : Dict
+ Dictionary with input images [B,3,H,W]
+ depths : list[torch.Tensor]
+ List with target depth maps in different scales [B,1,H,W]
+ gt_depths : Dict
+ Dictionary with ground-truth depth maps
+
+ Returns
+ -------
+ loss : torch.Tensor
+ Training loss
+ metrics : Dict
+ Dictionary with training metrics
+ """
+ tgt = 0
+
+ rgbs = make_rgb_scales(rgb, depths)
+ rgb_tgt = [rgbs[tgt][i] for i in range(len(rgbs[tgt]))]
+
+ supervision_output = self.losses['supervision'](depths, gt_depths[tgt])
+ smoothness_output = self.losses['smoothness'](rgb_tgt, depths)
+
+ loss = supervision_output['loss'] + \
+ smoothness_output['loss']
+
+ metrics = {
+ **supervision_output['metrics'],
+ **smoothness_output['metrics'],
+ }
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ }
diff --git a/vidar/arch/models/depth/__init__.py b/vidar/arch/models/depth/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/models/perceiver/DefineGenericModel.py b/vidar/arch/models/perceiver/DefineGenericModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..dd624bb41c762a4956a3c6539e0765243bc2f9b8
--- /dev/null
+++ b/vidar/arch/models/perceiver/DefineGenericModel.py
@@ -0,0 +1,421 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import random
+from abc import ABC
+
+import numpy as np
+import torch
+
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.arch.models.perceiver.DefineModel import augment_canonical, create_virtual_cameras
+from vidar.geometry.camera_nerf import CameraNerf
+from vidar.geometry.pose import Pose
+from vidar.utils.data import flatten
+from vidar.utils.types import is_list
+
+
+def sample_pred_gt(output, gt, pred):
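+    """Subsample ground truth at the decoder's query indices and reshape predictions to match"""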
+ for i in range(len(gt)):
+ b, n, h, w = gt[i].shape
+ query_idx = output['query_idx'][i]
+ gt[i] = torch.stack([gt[i][j].view(n, -1)[:, query_idx[j]]
+ for j in range(b)], 0)
+ pred[i][0] = pred[i][0].permute(0, 2, 1)
+ return gt, pred
+
+
+def parse_idx(all_idx, idxs):
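+    """Expand index specifications (lists and '*' wildcards) into explicit [camera, context] pairs"""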
+ new_idxs = []
+ for idx in idxs:
+ if is_list(idx[0]) and is_list(idx[1]):
+ for i in idx[0]:
+ for j in idx[1]:
+ new_idxs.append([i, j])
+ elif is_list(idx[0]):
+ for i in idx[0]:
+ new_idxs.append([i, idx[1]])
+ elif is_list(idx[1]):
+ for i in idx[1]:
+ new_idxs.append([idx[0], i])
+ elif idx[0] == '*':
+ for i in all_idx:
+ if i[1] == idx[1]:
+ new_idxs.append(i)
+ elif idx[1] == '*':
+ for i in all_idx:
+ if i[0] == idx[0]:
+ new_idxs.append(i)
+ else:
+ new_idxs.append(idx)
+ return new_idxs
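+
+# Hedged usage sketch (not from the original source):
+#   all_idx = [[0, 0], [1, 0], [0, 1], [1, 1]]
+#   parse_idx(all_idx, [['*', 0]])     # -> [[0, 0], [1, 0]]   (all cameras, context 0)
+#   parse_idx(all_idx, [[[0, 1], 1]])  # -> [[0, 1], [1, 1]]   (cameras 0 and 1, context 1)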
+
+
+class DefineGenericModel(BaseModel, ABC):
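+    """
+    Generic DeFiNe model: a Perceiver IO network (`DefineNet`) encodes posed images and
+    decodes depth (and optionally RGB), including for sampled virtual viewpoints
+
+    Parameters
+    ----------
+    cfg : Config
+        Configuration with parameters
+    """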
+
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ from vidar.arch.networks.perceiver.DefineNet import DefineNet
+ self.networks['perceiver'] = DefineNet(cfg.model.network)
+ self.weights = cfg.model.task_weights
+ self.use_pose_noise = cfg.model.use_pose_noise
+ self.use_virtual_cameras = cfg.model.use_virtual_cameras
+ self.virtual_cameras_eval = cfg.model.virtual_cameras_eval
+ self.use_virtual_rgb = cfg.model.use_virtual_rgb
+ self.augment_canonical = cfg.model.augment_canonical
+ self.scale_loss = cfg.model.scale_loss
+
+ self.encode_train = cfg.model.encode_train
+ self.decode_train = cfg.model.decode_train
+
+ self.encode_eval = cfg.model.encode_eval
+ self.decode_eval = cfg.model.decode_eval
+
+ if self.encode_eval == 'same':
+ self.encode_eval = self.encode_train
+ if self.decode_eval == 'same':
+ self.decode_eval = self.decode_train
+
+ self.decode_encodes = cfg.model.decode_encodes
+ self.sample_decoded_queries = cfg.model.sample_decoded_queries
+
+        # Default to no display; only overridden if a recognized option is configured
+        self.display = None
+        if cfg.model.has('display'):
+            if cfg.model.display == 'interactive':
+                from vidar.arch.models.perceiver.display import DisplayInteractive
+                self.display = DisplayInteractive()
+            elif cfg.model.display == 'training':
+                from vidar.arch.models.perceiver.display import DisplayTraining
+                self.display = DisplayTraining()
+
+ def get_idx(self, all_idx):
+
+ n = len(all_idx)
+
+ num_encodes = self.encode_train if self.training else self.encode_eval
+ num_decodes = self.decode_train if self.training else self.decode_eval
+
+ encode_idx = None
+ if is_list(num_encodes):
+ if num_encodes[0].startswith('+'):
+ encode_idx = parse_idx(all_idx, num_encodes[1:])
+ elif num_encodes[0].startswith('-'):
+ num_encodes_parse = parse_idx(all_idx, num_encodes[1:])
+ encode_idx = [
+ idx for idx in all_idx if idx not in num_encodes_parse]
+ if len(num_encodes[0]) > 1:
+ num = int(num_encodes[0][1:])
+ encode_idx = np.random.permutation(encode_idx)
+ encode_idx = encode_idx[:num]
+ elif num_encodes == 'all':
+ num_encodes = n
+ elif num_encodes < 0:
+ num_encodes = n + num_encodes
+
+ decode_idx = None
+ if is_list(num_decodes):
+ if num_decodes[0].startswith('+'):
+ decode_idx = parse_idx(all_idx, num_decodes[1:])
+ elif num_decodes[0].startswith('-'):
+ num_decodes_parse = parse_idx(all_idx, num_decodes[1:])
+ decode_idx = [
+ idx for idx in all_idx if idx not in num_decodes_parse]
+ if len(num_decodes[0]) > 1:
+ num = int(num_decodes[0][1:])
+ decode_idx = np.random.permutation(decode_idx)
+ decode_idx = decode_idx[:num]
+ elif num_decodes == 'all':
+ num_decodes = n
+ elif num_decodes == 'remaining':
+ num_decodes = n - num_encodes
+ elif num_decodes < 0:
+            num_decodes = n + num_decodes
+
+ if self.training:
+ # Shuffle indices and separate encode and decode indices
+ all_idx = np.random.permutation(all_idx)
+
+ if encode_idx is None:
+ encode_idx = all_idx[:num_encodes]
+ if decode_idx is None:
+ decode_idx = all_idx[-num_decodes:]
+
+ encode_idx = [list(idx) for idx in encode_idx]
+ decode_idx = [list(idx) for idx in decode_idx]
+
+ if self.decode_encodes:
+ decode_idx += [idx for idx in encode_idx if idx not in decode_idx]
+
+ return encode_idx, decode_idx
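+
+    # Hedged configuration sketch (not from the original source): with integer settings,
+    # e.g. encode_train=2 and decode_train='remaining', two randomly chosen (camera, context)
+    # pairs are encoded and the rest decoded; list-valued settings such as ['+', [0, 0]] select
+    # explicit pairs, and ['-', ...] selects their complement (optionally subsampled, e.g. a
+    # leading '-2' keeps two of the complement pairs at random).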
+
+ def parse_output(self, output, key, encode_idx, predictions):
+ if key in output.keys():
+ pred = [[val] for val in output[key]]
+ if not self.training:
+ predictions_key = {}
+ for idx, (i, j) in enumerate(encode_idx):
+ if j not in predictions_key.keys():
+ predictions_key[j] = []
+ predictions_key[j].append(output[key][idx])
+ predictions_key = {key: [torch.stack(val, 1)]
+ for key, val in predictions_key.items()}
+ predictions[key] = predictions_key
+ else:
+ pred = None
+ return pred
+
+ def forward(self, batch, epoch=0, collate=True):
+
+        # Run independently on each element when given a list of batches
+
+ if is_list(batch):
+ output = [self.forward(b, collate=False) for b in batch]
+ loss = [out['loss'] for out in output]
+ return {'loss': sum(loss) / len(loss), 'predictions': {}, 'metrics': {}}
+
+ if not collate:
+ for key in ['rgb', 'intrinsics', 'pose', 'depth']:
+ batch[key] = {k: v.unsqueeze(0) for k, v in batch[key].items()}
+
+ # Unsqueeze batch data if there is only one camera
+
+ key_dim = {'rgb': 4, 'depth': 4, 'intrinsics': 3, 'pose': 3}
+ for key in ['rgb', 'depth', 'intrinsics', 'pose']:
+ for ctx in batch[key].keys():
+ if batch[key][ctx].dim() == key_dim[key]:
+ batch[key][ctx] = batch[key][ctx].unsqueeze(1)
+ for key in ['intrinsics', 'pose']:
+ for ctx in batch[key].keys():
+ if batch[key][ctx].dim() == 3:
+ batch[key][ctx] = batch[key][ctx].unsqueeze(1)
+
+ ###
+
+ rgb = batch['rgb']
+ intrinsics = batch['intrinsics']
+ pose = batch['pose']
+ depth = batch['depth']
+
+ # Get context keys, batch size and number of cameras
+ ctx = [key for key in rgb.keys() if key != 'virtual']
+ b, n = rgb[0].shape[:2]
+
+ # Create all indices in a list
+ ii, jj = list(range(n)), ctx
+ all_idx = flatten([[[i, j] for i in ii] for j in jj])
+ encode_idx, decode_idx = self.get_idx(all_idx)
+
+ # Prepare pose and add jittering if requested
+
+ pose = [{j: pose[j][i] for j in ctx} for i in range(b)]
+
+ pose0 = [Pose().to(rgb[0].device) for _ in range(b)]
+        if self.training and len(self.use_pose_noise) > 0:
+ if random.random() < self.use_pose_noise[0]:
+ for i in range(b):
+ pose0[i].translateUp(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].translateLeft(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].translateForward(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].rotateRoll(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ pose0[i].rotatePitch(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ pose0[i].rotateYaw(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ for i in range(b):
+ pose[i][0][[0]] = pose0[i].T
+
+ pose = [Pose.from_dict(
+ p, to_global=True, zero_origin=False, to_matrix=True) for p in pose]
+ pose = {j: torch.stack([pose[i][j] for i in range(b)], 0) for j in ctx}
+
+ # Augment canonical pose if requested
+ if self.training and self.augment_canonical:
+ pose = augment_canonical(pose)
+
+ # Separate batch data per camera
+ rgb = [{j: rgb[j][:, i] for j in ctx} for i in range(n)]
+ intrinsics = [{j: intrinsics[0][:, i] for j in ctx} for i in range(n)]
+ pose = [{j: pose[j][:, i] for j in ctx} for i in range(n)]
+ depth = [{j: depth[j][:, i] for j in ctx} for i in range(n)]
+
+ # Create camera with batch information
+ cams = [{j: CameraNerf(K=intrinsics[i][0], Twc=pose[i][j], hw=rgb[i][j])
+ for j in ctx} for i in range(n)]
+
+ # Create encode dictionary
+ encode_data = [{
+ 'rgb': rgb[i][j],
+ 'cam': cams[i][j],
+ 'gt_depth': depth[i][j],
+ } for i, j in encode_idx]
+
+ # Create decode dictionary
+ decode_data = [{
+ 'rgb': rgb[i][j],
+ 'cam': cams[i][j],
+ 'gt_depth': depth[i][j],
+ } for i, j in decode_idx]
+
+ # Run PerceiverIO (encode and decode)
+ perceiver_output = self.networks['perceiver'](
+ encode_data=encode_data,
+ decode_data=decode_data,
+ sample_queries=self.sample_decoded_queries,
+ filter_invalid=False,
+ )
+ output = perceiver_output['output']
+
+ predictions = {}
+
+ # DEPTH
+
+ # Get predicted depths
+ pred_depths = self.parse_output(
+ output, 'depth', decode_idx, predictions)
+ # Get predicted monocular depths
+ pred_depths_mono = self.parse_output(
+ output, 'depth_mono', decode_idx, predictions)
+
+ # RGB
+
+ # Get predicted RGB
+ pred_rgbs = self.parse_output(output, 'rgb', decode_idx, predictions)
+
+ # VIRTUAL
+
+ if len(self.use_virtual_cameras) > 0 and (self.training or self.virtual_cameras_eval):
+
+ virtual_data = create_virtual_cameras(
+ decode_data,
+ n_samples=self.use_virtual_cameras[1],
+ cam_noise=self.use_virtual_cameras[2:-1],
+ center_noise=self.use_virtual_cameras[-1],
+ thr=0.1, tries=10, decay=0.9,
+ )
+ virtual_output = self.networks['perceiver'].decode(
+ latent=perceiver_output['latent'], data=virtual_data,
+ sources=['cam'], field='cam',
+ sample_queries=self.sample_decoded_queries,
+ filter_invalid=True
+ )['output']
+
+ gt_virtual_cams = [data['cam'] for data in virtual_data]
+
+ if pred_depths is not None:
+ pred_depths_virtual = [[depth]
+ for depth in virtual_output['depth']]
+ gt_depths_virtual = [data['gt_depth'] for data in virtual_data]
+
+ if not self.training:
+ batch['depth']['virtual'] = torch.stack(
+ gt_depths_virtual, 1)
+ predictions['depth']['virtual'] = [torch.stack(
+ [pred[0] for pred in pred_depths_virtual], 1)]
+
+ if 'query_idx' in virtual_output:
+ gt_depths_virtual, pred_depths_virtual = sample_pred_gt(
+ virtual_output, gt_depths_virtual, pred_depths_virtual)
+ else:
+ pred_depths_virtual, gt_depths_virtual = None, None
+
+ if pred_rgbs is not None:
+ pred_rgbs_virtual = [[rgb] for rgb in virtual_output['rgb']]
+ gt_rgbs_virtual = [data['rgb'] for data in virtual_data]
+
+ if not self.training:
+ batch['rgb']['virtual'] = torch.stack(gt_rgbs_virtual, 1)
+ predictions['rgb']['virtual'] = [torch.stack(
+ [pred[0] for pred in pred_rgbs_virtual], 1)]
+
+ if 'query_idx' in virtual_output:
+ gt_rgbs_virtual, pred_rgbs_virtual = sample_pred_gt(
+ virtual_output, gt_rgbs_virtual, pred_rgbs_virtual)
+ else:
+ pred_rgbs_virtual, gt_rgbs_virtual = None, None
+
+ else:
+
+ virtual_data = virtual_output = None
+
+ ##########################################################
+
+ if self.display is not None:
+ self.display.loop(self.networks, encode_data,
+ decode_data, output, virtual_data)
+
+ ##########################################################
+
+ if not self.training:
+ return {
+ 'predictions': predictions,
+ 'batch': batch,
+ }
+
+ # Get GT images and depths
+ gt_depths = [depth[i][j] for i, j in decode_idx]
+ gt_rgbs = [rgb[i][j] for i, j in decode_idx]
+
+ if 'query_idx' in output:
+ if pred_depths is not None:
+ gt_depths, pred_depths = sample_pred_gt(
+ output, gt_depths, pred_depths)
+ if pred_rgbs is not None:
+ gt_rgbs, pred_rgbs = sample_pred_gt(output, gt_rgbs, pred_rgbs)
+
+ loss, metrics = self.compute_loss_and_metrics(
+ pred_rgbs, gt_rgbs,
+ pred_depths, gt_depths,
+ )
+
+ if len(self.use_virtual_cameras) > 0:
+ virtual_loss, _ = self.compute_loss_and_metrics(
+ pred_rgbs_virtual if self.use_virtual_rgb else None,
+ gt_rgbs_virtual if self.use_virtual_rgb else None,
+ pred_depths_virtual, gt_depths_virtual,
+ )
+ loss = loss + self.use_virtual_cameras[0] * virtual_loss
+
+ if pred_depths_mono is not None:
+ mono_loss, _ = self.compute_loss_and_metrics(
+ None, None, pred_depths_mono, gt_depths,
+ )
+ loss = loss + mono_loss
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ 'predictions': predictions,
+ }
+
+ def compute_loss_and_metrics(self, pred_rgbs, gt_rgbs, pred_depths, gt_depths):
+
+ loss, metrics = [], {}
+
+ # Calculate RGB losses
+ if pred_rgbs is not None and 'rgb' in self.losses and self.weights[0] > 0.0:
+ loss_rgb = []
+ for pred, gt in zip(pred_rgbs, gt_rgbs):
+ rgb_output = self.losses['rgb'](pred, gt)
+ loss_rgb.append(self.weights[0] * rgb_output['loss'])
+ loss.append(sum(loss_rgb) / len(loss_rgb))
+
+ # Calculate depth losses
+ if pred_depths is not None and 'depth' in self.losses and self.weights[1] > 0.0:
+ loss_depth = []
+ for pred, gt in zip(pred_depths, gt_depths):
+ depth_output = self.losses['depth'](pred, gt)
+ loss_depth.append(self.weights[1] * depth_output['loss'])
+ loss.append(sum(loss_depth) / len(loss_depth))
+
+ if len(loss) == 2 and self.scale_loss:
+ ratio_rgb_depth = loss[1].item() / loss[0].item()
+ loss[0] = loss[0] * ratio_rgb_depth
+
+ loss = sum(loss) / len(loss)
+
+ return loss, metrics
diff --git a/vidar/arch/models/perceiver/DefineModel.py b/vidar/arch/models/perceiver/DefineModel.py
new file mode 100755
index 0000000000000000000000000000000000000000..248f46a1e81974678130f9e4bf35c293a97940f1
--- /dev/null
+++ b/vidar/arch/models/perceiver/DefineModel.py
@@ -0,0 +1,354 @@
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+
+import random
+from abc import ABC
+
+import torch
+
+from vidar.arch.models.BaseModel import BaseModel
+from vidar.geometry.camera_nerf import CameraNerf
+from vidar.geometry.pose import Pose
+from vidar.utils.data import flatten
+from vidar.utils.types import is_list, is_int
+
+
+def augment_canonical(pose):
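+    """Randomly re-anchor all poses to one of the cameras, so the canonical frame varies across batches"""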
+
+ ctx = list(pose.keys())
+ num = list(range(pose[0].shape[1]))
+
+ i = random.choice(ctx)
+ j = random.choice(num)
+
+ base = Pose(pose[i][:, j]).inverse().T
+ for key in ctx:
+ for n in num:
+ pose[key][:, n] = pose[key][:, n] @ base
+
+ return pose
+
+
+def parse_output(output, key, encode_idx, predictions):
+ if key in output.keys():
+ pred = [[val] for val in output[key]]
+ predictions_key = {}
+ for idx, (i, j) in enumerate(encode_idx):
+ if j not in predictions_key.keys():
+ predictions_key[j] = []
+ predictions_key[j].append(output[key][idx])
+ predictions_key = {key: [torch.stack(val, 1)]
+ for key, val in predictions_key.items()}
+ predictions[key] = predictions_key
+ else:
+ pred = None
+ return pred
+
+
+def create_virtual_cameras(encode_data, n_samples=1, cam_noise=None, center_noise=None,
+ downsample=1.0, thr=0.1, tries=10, decay=0.9):
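+    """Sample virtual viewpoints by perturbing real cameras and reprojecting the GT point cloud into
+    them, retrying with decayed noise until enough valid depth points project into the new view"""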
+
+ gt_cams = [datum['cam'] for datum in encode_data]
+ gt_depths = [datum['gt_depth'] for datum in encode_data]
+ gt_rgbs = [datum['rgb'] for datum in encode_data]
+
+ if not is_int(thr):
+ n = gt_rgbs[0].shape[-2] * gt_rgbs[0].shape[-1]
+ thr = int(n * thr)
+
+ pcls_proj = [cam.scaled(downsample).reconstruct_depth_map(depth, to_world=True)
+ for cam, depth in zip(gt_cams, gt_depths)]
+ pcls_proj = [pcl.reshape(*pcl.shape[:2], -1) for pcl in pcls_proj]
+ pcl_proj = torch.cat([pcl for pcl in pcls_proj], -1)
+ clr_proj = torch.cat([rgb.reshape(*rgb.shape[:2], -1)
+ for rgb in gt_rgbs], -1)
+
+ gt_pcl_centers = [pcl.mean(-1) for pcl in pcls_proj]
+
+ virtual_data = []
+ for gt_rgb, gt_depth, gt_cam, gt_pcl_center in zip(gt_rgbs, gt_depths, gt_cams, gt_pcl_centers):
+ for i in range(n_samples):
+
+ cam = gt_cam.clone()
+ pcl_center = gt_pcl_center.clone()
+
+ weight = 1.0
+ rgb_proj = depth_proj = None
+ for j in range(tries):
+
+ if center_noise is not None:
+ pcl_center_noise = weight * center_noise * \
+ (2 * torch.rand_like(gt_pcl_center) - 1)
+ pcl_center = gt_pcl_center + pcl_center_noise
+
+ if cam_noise is not None:
+ cam.look_at(pcl_center)
+ cam.Twc.translateUp(
+ weight * cam_noise[0] * (2 * random.random() - 1))
+ cam.Twc.translateLeft(
+ weight * cam_noise[1] * (2 * random.random() - 1))
+ cam.Twc.translateForward(
+ weight * cam_noise[2] * (2 * random.random() - 1))
+ cam.look_at(pcl_center)
+
+ rgb_proj_try, depth_proj_try = cam.project_pointcloud(
+ pcl_proj, clr_proj)
+
+ valid = (depth_proj_try > 0).sum() > thr
+ if valid:
+ rgb_proj, depth_proj = rgb_proj_try, depth_proj_try
+ break
+ else:
+ weight = weight * decay
+
+ if rgb_proj is None and depth_proj is None:
+ rgb_proj, depth_proj = gt_rgb, gt_depth
+ cam = gt_cam.clone()
+
+ virtual_data.append({
+ 'cam': cam,
+ 'rgb': rgb_proj.contiguous(),
+ 'gt_depth': depth_proj.contiguous(),
+ })
+
+ return virtual_data
+
+
+def ControlVidarCamera(draw, cam, tvel=0.2, rvel=0.1):
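+    """Translate/rotate a camera from the keyboard state of a visualization `draw` object; returns True if the pose changed"""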
+ change = False
+ if draw.UP:
+ cam.Twc.translateForward(tvel)
+ change = True
+ if draw.DOWN:
+ cam.Twc.translateBackward(tvel)
+ change = True
+ if draw.LEFT:
+ cam.Twc.translateLeft(tvel)
+ change = True
+ if draw.RIGHT:
+ cam.Twc.translateRight(tvel)
+ change = True
+ if draw.PGUP:
+ cam.Twc.translateUp(tvel)
+ change = True
+ if draw.PGDOWN:
+ cam.Twc.translateDown(tvel)
+ change = True
+ if draw.KEY_A:
+ cam.Twc.rotateYaw(-rvel)
+ change = True
+ if draw.KEY_D:
+ cam.Twc.rotateYaw(+rvel)
+ change = True
+ if draw.KEY_W:
+ cam.Twc.rotatePitch(+rvel)
+ change = True
+ if draw.KEY_S:
+ cam.Twc.rotatePitch(-rvel)
+ change = True
+ if draw.KEY_Q:
+ cam.Twc.rotateRoll(-rvel)
+ change = True
+ if draw.KEY_E:
+ cam.Twc.rotateRoll(+rvel)
+ change = True
+ return change
+
+
+class HuggingModel(BaseModel, ABC):
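+    """
+    Perceiver-based model (`HuggingNet` backbone) that encodes posed images and decodes depth
+    and RGB, optionally for sampled virtual viewpoints
+
+    Parameters
+    ----------
+    cfg : Config
+        Configuration with parameters
+    """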
+
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ from vidar.arch.networks.perceiver.HuggingNet import HuggingNet
+ self.networks['perceiver'] = HuggingNet(cfg.model.network)
+ self.weights = [1.0, 1.0]
+ self.use_pose_noise = cfg.model.use_pose_noise
+ self.use_virtual_cameras = cfg.model.use_virtual_cameras
+ self.use_virtual_rgb = cfg.model.use_virtual_rgb
+ self.augment_canonical = cfg.model.augment_canonical
+
+ def forward(self, batch, epoch=0, collate=True):
+
+ if is_list(batch):
+ output = [self.forward(b, collate=False) for b in batch]
+ loss = [out['loss'] for out in output]
+ return {'loss': sum(loss) / len(loss), 'predictions': {}, 'metrics': {}}
+
+ if not collate:
+ for key in ['rgb', 'intrinsics', 'pose', 'depth']:
+ batch[key] = {k: v.unsqueeze(0) for k, v in batch[key].items()}
+
+ rgb = batch['rgb']
+ intrinsics = batch['intrinsics']
+ pose = batch['pose']
+ depth = batch['depth']
+
+ ctx = [0] # rgb.keys()
+ b, n = rgb[0].shape[:2]
+
+ ii, jj = range(n), ctx
+
+ pose = [{key: val[i] for key, val in pose.items()} for i in range(b)]
+
+ pose0 = [Pose().to(rgb[0].device) for _ in range(b)]
+        if self.training and len(self.use_pose_noise) > 0:
+ if random.random() < self.use_pose_noise[0]:
+ for i in range(b):
+ pose0[i].translateUp(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].translateLeft(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].translateForward(
+ self.use_pose_noise[1] * (2 * random.random() - 1))
+ pose0[i].rotateRoll(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ pose0[i].rotatePitch(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ pose0[i].rotateYaw(
+ torch.pi * self.use_pose_noise[2] * (2 * random.random() - 1))
+ for i in range(b):
+ pose[i][0][[0]] = pose0[i].T
+
+ pose = [Pose.from_dict(
+ p, to_global=True, zero_origin=False, to_matrix=True) for p in pose]
+ pose = {key: torch.stack([pose[i][key] for i in range(
+ len(pose))], 0) for key in pose[0].keys()}
+
+ if self.training and self.augment_canonical:
+ pose = augment_canonical(pose)
+
+ rgb = [{key: val[:, i] for key, val in rgb.items()} for i in range(n)]
+ intrinsics = [{key: val[:, i]
+ for key, val in intrinsics.items()} for i in range(n)]
+ pose = [{key: val[:, i] for key, val in pose.items()}
+ for i in range(n)]
+ depth = [{key: val[:, i] for key, val in depth.items()}
+ for i in range(n)]
+
+ cams = [{j: CameraNerf(K=intrinsics[i][0], Twc=pose[i][j], hw=rgb[i][j])
+ for j in ctx} for i in range(n)]
+
+ encode_idx = flatten([[[i, j] for i in ii] for j in jj])
+
+ encode_data = [{
+ 'rgb': rgb[i][j],
+ 'cam': cams[i][j],
+ 'gt_depth': depth[i][j],
+ } for i, j in encode_idx]
+
+ perceiver_output = self.networks['perceiver'](
+ encode_data=encode_data,
+ )
+
+ output = perceiver_output['output']
+
+ predictions = {}
+
+ # DEPTH
+
+ pred_depths = parse_output(output, 'depth', encode_idx, predictions)
+ pred_depths_mono = parse_output(
+ output, 'depth_mono', encode_idx, predictions)
+
+ if pred_depths is not None and pred_depths_mono is not None:
+ for i in range(len(pred_depths)):
+ pred_depths[i] += pred_depths_mono[i]
+
+ # RGB
+
+ pred_rgbs = parse_output(output, 'rgb', encode_idx, predictions)
+
+ # VIRTUAL
+
+ if len(self.use_virtual_cameras) > 0:
+
+ virtual_data = create_virtual_cameras(
+ encode_data,
+ n_samples=self.use_virtual_cameras[1],
+ cam_noise=self.use_virtual_cameras[2:-1],
+ center_noise=self.use_virtual_cameras[-1],
+ thr=0.1, tries=10, decay=0.9,
+ )
+ virtual_output = self.networks['perceiver'].decode(
+ latent=perceiver_output['latent'], data=virtual_data,
+ sources=['cam'], field='cam',
+ )['output']
+
+ gt_virtual_cams = [data['cam'] for data in virtual_data]
+
+ if pred_depths is not None:
+ pred_depths_virtual = [[depth]
+ for depth in virtual_output['depth']]
+ gt_depths_virtual = [data['gt_depth'] for data in virtual_data]
+
+ batch['depth']['virtual'] = torch.stack(gt_depths_virtual, 1)
+ predictions['depth']['virtual'] = [torch.stack(
+ [pred[0] for pred in pred_depths_virtual], 1)]
+ else:
+ pred_depths_virtual, gt_depths_virtual = None, None
+
+ if pred_rgbs is not None:
+ pred_rgbs_virtual = [[rgb] for rgb in virtual_output['rgb']]
+ gt_rgbs_virtual = [data['rgb'] for data in virtual_data]
+
+ batch['rgb']['virtual'] = torch.stack(gt_rgbs_virtual, 1)
+ predictions['rgb']['virtual'] = [torch.stack(
+ [pred[0] for pred in pred_rgbs_virtual], 1)]
+ else:
+ pred_rgbs_virtual, gt_rgbs_virtual = None, None
+
+ ##########################################################
+
+ # display_data_interactive(self.networks, encode_data, encode_data, output, virtual_data)
+ # display_data(rgb, depth, cams, ctx, n, batch, predictions, gt_virtual_cams)
+
+ ##########################################################
+
+ if not self.training:
+ return {
+ 'predictions': predictions,
+ 'batch': batch,
+ }
+
+ gt_depths = [depth[i][j] for i, j in encode_idx]
+ gt_rgbs = [rgb[i][j] for i, j in encode_idx]
+
+ loss, metrics = self.compute_loss_and_metrics(
+ pred_rgbs, gt_rgbs,
+ pred_depths, gt_depths,
+ )
+
+ if len(self.use_virtual_cameras) > 0:
+ virtual_loss, _ = self.compute_loss_and_metrics(
+ pred_rgbs_virtual if self.use_virtual_rgb else None,
+ gt_rgbs_virtual if self.use_virtual_rgb else None,
+ pred_depths_virtual, gt_depths_virtual,
+ )
+ loss = loss + self.use_virtual_cameras[0] * virtual_loss
+
+ return {
+ 'loss': loss,
+ 'metrics': metrics,
+ 'predictions': predictions,
+ }
+
+ def compute_loss_and_metrics(self, pred_rgbs, gt_rgbs, pred_depths, gt_depths):
+
+ loss, metrics = [], {}
+
+ if pred_rgbs is not None and 'rgb' in self.losses and self.weights[0] > 0.0:
+ loss_rgb = []
+ for pred, gt in zip(pred_rgbs, gt_rgbs):
+ rgb_output = self.losses['rgb'](pred, gt)
+ loss_rgb.append(self.weights[0] * rgb_output['loss'])
+ loss.append(sum(loss_rgb) / len(loss_rgb))
+ if pred_depths is not None and 'depth' in self.losses and self.weights[1] > 0.0:
+ loss_depth = []
+ for pred, gt in zip(pred_depths, gt_depths):
+ depth_output = self.losses['depth'](pred, gt)
+ loss_depth.append(self.weights[1] * depth_output['loss'])
+ loss.append(sum(loss_depth) / len(loss_depth))
+
+ loss = sum(loss) / len(loss)
+
+ return loss, metrics
diff --git a/vidar/arch/models/utils.py b/vidar/arch/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..33c2ebb92be2f1c8519b27d64b8e7a3aed301d8b
--- /dev/null
+++ b/vidar/arch/models/utils.py
@@ -0,0 +1,112 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.geometry.camera_full import CameraFull
+from vidar.utils.decorators import iterate1
+from vidar.utils.tensor import interpolate_image
+from vidar.utils.types import is_dict
+
+
+@iterate1
+def make_rgb_scales(rgb, pyramid=None, ratio_scales=None):
+ """
+ Create different RGB scales to correspond with predictions
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Input image [B,3,H,W]
+ pyramid : list[torch.Tensor]
+ List with tensors at different scales
+    ratio_scales : Tuple
+        Alternative to pyramid: (downsampling ratio, number of scales) used to generate the scales
+
+ Returns
+ -------
+ pyramid : list[torch.Tensor]
+ List with the input image at the same resolutions as pyramid
+ """
+ assert pyramid is None or ratio_scales is None
+ if pyramid is not None:
+ return [interpolate_image(rgb, shape=pyr.shape[-2:]) for pyr in pyramid]
+ elif ratio_scales is not None:
+ return [interpolate_image(rgb, scale_factor=ratio_scales[0] ** i)
+ for i in range(ratio_scales[1])]
+ else:
+ raise NotImplementedError('Invalid option')
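+
+# Hedged usage sketch (not from the original source):
+#   rgb = torch.rand(2, 3, 64, 64)
+#   pyr = [torch.rand(2, 1, 64, 64), torch.rand(2, 1, 32, 32)]
+#   make_rgb_scales(rgb, pyramid=pyr)            # [rgb at 64x64, rgb at 32x32]
+#   make_rgb_scales(rgb, ratio_scales=(0.5, 3))  # [rgb at 1.0x, 0.5x and 0.25x resolution]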
+
+
+def break_context(dic, tgt=0, ctx=None, scl=None, stack=False):
+ """
+ Separate a dictionary between target and context information
+
+ Parameters
+ ----------
+ dic : Dict
+ Input dictionary
+ tgt : Int
+ Which key corresponds to target
+    ctx : list[Int]
+        Which keys correspond to context (if None, use all non-target keys)
+    scl : Int
+        Which scale should be used (if None, assume there are no scales)
+ stack : Bool
+ Stack output context or not
+
+ Returns
+ -------
+ tgt : torch.Tensor
+ Target information
+ ctx : list[torch.Tensor] or torch.Tensor
+ Context information (list or stacked)
+ """
+ # Get remaining frames if context is not provided
+ if ctx is None:
+ ctx = [key for key in dic.keys() if key != tgt]
+ # Get all scales or a single scale
+ if scl is None:
+ tgt, ctx = dic[tgt], [dic[key] for key in ctx if key != tgt]
+ else:
+ tgt, ctx = dic[tgt][scl], [dic[key][scl] for key in ctx if key != tgt]
+ # Stack context if requested
+ if stack:
+ ctx = torch.stack(ctx, 1)
+ # Return target and context
+ return tgt, ctx
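+
+# Hedged usage sketch (not from the original source), with a dictionary keyed by frame index:
+#   dic = {0: t0, 1: t1, -1: tm1}                     # each value [B,C,H,W]
+#   tgt, ctx = break_context(dic, tgt=0)              # t0, [t1, tm1]
+#   tgt, ctx = break_context(dic, tgt=0, stack=True)  # t0, tensor of shape [B,2,C,H,W]
+# If values are per-scale lists, pass scl to select one scale, e.g. break_context(dic, tgt=0, scl=0).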
+
+
+def create_cameras(rgb, intrinsics, pose, zero_origin=True, scaled=None):
+ """
+    Create cameras from batch information
+
+    Parameters
+ ----------
+ rgb : Dict
+ Dictionary with images
+ intrinsics : Dict
+ Dictionary with camera intrinsics
+ pose : Dict
+ Dictionary with camera poses
+ zero_origin : Bool
+ Zero target camera to the origin or not
+ scaled : Float
+ Scale factor for the output cameras
+
+ Returns
+ -------
+ cams : Dict
+ Dictionary with output cameras
+ """
+ if pose is None:
+ return None
+ cams = {key: CameraFull(
+ K=intrinsics[key] if is_dict(intrinsics) else intrinsics,
+ Twc=pose[key],
+ hw=rgb[key] if is_dict(rgb) else rgb,
+ ).scaled(scaled).to(pose[key].device) for key in pose.keys()}
+ if zero_origin:
+ cams[0] = CameraFull(
+ K=intrinsics[0] if is_dict(intrinsics) else intrinsics,
+ hw=rgb[0] if is_dict(rgb) else rgb,
+ ).scaled(scaled).to(rgb.device)
+ return cams
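+
+# Hedged usage sketch (not from the original source): with a target image, shared intrinsics and
+# a pose dictionary keyed by context index,
+#   cams = create_cameras(rgb[0], intrinsics[0], {1: T_01, -1: T_0m1})
+# returns a dictionary of CameraFull objects; with zero_origin=True a target camera (key 0) at the
+# identity pose is added (or overwritten if already present).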
diff --git a/vidar/arch/networks/BaseNet.py b/vidar/arch/networks/BaseNet.py
new file mode 100755
index 0000000000000000000000000000000000000000..6489f5b287dd0d433743515cae08a6a5ca5b9fd8
--- /dev/null
+++ b/vidar/arch/networks/BaseNet.py
@@ -0,0 +1,54 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth
+from vidar.utils.config import cfg_has
+
+
+class BaseNet(nn.Module):
+ """Base network class, that all other networks inherit"""
+ def __init__(self, cfg):
+ super().__init__()
+ self.networks = torch.nn.ModuleDict()
+ self.blocks = torch.nn.ModuleDict()
+
+ if cfg_has(cfg, 'depth_range'):
+ self.to_depth = SigmoidToInvDepth(
+ cfg.depth_range[0], cfg.depth_range[1], return_depth=True)
+ else:
+ self.to_depth = None
+
+ def _forward_unimplemented(self, *args):
+ raise NotImplementedError('Subclasses of BaseNet must implement forward()')
+
+ def set_attr(self, cfg, key, default):
+ """Set a network attribute"""
+ self.__setattr__(key, cfg_has(cfg, key, default))
+
+ def train(self, mode=True):
+ """Set all networks and blocks to train or val"""
+ super().train(mode=mode)
+ for key, val in self.networks.items():
+ val.train(mode=mode)
+ for key, val in self.blocks.items():
+ val.train(mode=mode)
+
+ def eval(self):
+ self.train(mode=False)
+
+ def sigmoid_to_depth(self, sigmoids):
+ """Convert sigmoids to depth values"""
+ return self.to_depth(sigmoids) if self.to_depth is not None else sigmoids
+
+ def load(self, ckpt, name):
+ """Loads a checkpoint onto the network"""
+ state_dict = torch.load(ckpt, map_location='cpu')['state_dict']
+ updated_state_dict = {}
+ for key, val in state_dict.items():
+ idx = key.find(name)
+ if idx > -1:
+ updated_state_dict[key[idx + len(name) + 1:]] = val
+ self.load_state_dict(updated_state_dict)
+
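The `load` method above keeps only the checkpoint entries that belong to a given sub-network, stripping everything up to and including the network name. A standalone sketch of that key remapping, with hypothetical key names:

```python
state_dict = {
    'model.networks.mono_depth.decoder.conv.weight': 'w0',   # kept, prefix stripped
    'model.networks.pose.decoder.conv.weight': 'w1',         # dropped ('mono_depth' not found)
}
name = 'mono_depth'

updated = {}
for key, val in state_dict.items():
    idx = key.find(name)
    if idx > -1:
        updated[key[idx + len(name) + 1:]] = val

print(updated)  # {'decoder.conv.weight': 'w0'}
```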
diff --git a/vidar/arch/networks/__init__.py b/vidar/arch/networks/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/networks/decoders/DepthDecoder.py b/vidar/arch/networks/decoders/DepthDecoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..846e77fda2f187c7cf490f3c8d1eaaa3dc04116c
--- /dev/null
+++ b/vidar/arch/networks/decoders/DepthDecoder.py
@@ -0,0 +1,81 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from vidar.arch.networks.layers.convs import ConvBlock, Conv3x3, upsample
+from vidar.utils.config import cfg_has
+
+
+class DepthDecoder(nn.Module, ABC):
+ """
+ Depth decoder network
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.num_scales = cfg_has(cfg, 'num_scales', 4)
+ self.use_skips = cfg.use_skips
+
+ self.num_ch_enc = cfg.num_ch_enc
+ self.num_ch_dec = np.array([16, 32, 64, 128, 256])
+ self.num_ch_out = cfg.num_ch_out
+
+ self.convs = OrderedDict()
+ for i in range(self.num_scales, -1, -1):
+
+ num_ch_in = self.num_ch_enc[-1] if i == self.num_scales else self.num_ch_dec[i + 1]
+ num_ch_out = self.num_ch_dec[i]
+ self.convs[('upconv', i, 0)] = ConvBlock(
+ num_ch_in, num_ch_out)
+
+ num_ch_in = self.num_ch_dec[i]
+ if self.use_skips and i > 0:
+ num_ch_in += self.num_ch_enc[i - 1]
+ num_ch_out = self.num_ch_dec[i]
+ self.convs[('upconv', i, 1)] = ConvBlock(
+ num_ch_in, num_ch_out)
+
+ for i in range(self.num_scales):
+ self.convs[('outconv', i)] = Conv3x3(
+ self.num_ch_dec[i], self.num_ch_out)
+
+ self.decoder = nn.ModuleList(list(self.convs.values()))
+
+ if cfg.activation == 'sigmoid':
+ self.activation = nn.Sigmoid()
+ elif cfg.activation == 'identity':
+ self.activation = nn.Identity()
+ elif cfg.activation == 'softmax':
+ self.activation = nn.Softmax(dim=1)
+ else:
+ raise ValueError('Invalid activation function {}'.format(cfg.activation))
+
+ def forward(self, input_features):
+ """Network forward pass"""
+
+ outputs = {}
+
+ x = input_features[-1]
+ for i in range(self.num_scales, -1, -1):
+ x = self.convs[('upconv', i, 0)](x)
+ x = [upsample(x)]
+ if self.use_skips and i > 0:
+ x += [input_features[i - 1]]
+ x = torch.cat(x, 1)
+ x = self.convs[('upconv', i, 1)](x)
+ if i in range(self.num_scales):
+ outputs[('features', i)] = x
+ outputs[('output', i)] = self.activation(
+ self.convs[('outconv', i)](x))
+
+ return outputs
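The decoder follows the monodepth2 layout: five up-convolution stages whose input channels depend on the encoder skip connections. A small sketch of the channel bookkeeping for a ResNet-18-style encoder (num_ch_enc = [64, 64, 128, 256, 512], num_scales = 4), mirroring the constructor loop above:

```python
num_scales = 4
num_ch_enc = [64, 64, 128, 256, 512]   # ResNet-18 style encoder
num_ch_dec = [16, 32, 64, 128, 256]
use_skips = True

for i in range(num_scales, -1, -1):
    in0 = num_ch_enc[-1] if i == num_scales else num_ch_dec[i + 1]
    in1 = num_ch_dec[i] + (num_ch_enc[i - 1] if use_skips and i > 0 else 0)
    print(i, ('upconv0', in0, num_ch_dec[i]), ('upconv1', in1, num_ch_dec[i]))
# scale 4: (512 -> 256) then (256 + 256 -> 256) ... scale 0: (32 -> 16) then (16 -> 16)
```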
diff --git a/vidar/arch/networks/decoders/PoseDecoder.py b/vidar/arch/networks/decoders/PoseDecoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..89764d89998233d01ba8a41463217634678be0d7
--- /dev/null
+++ b/vidar/arch/networks/decoders/PoseDecoder.py
@@ -0,0 +1,55 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+
+
+class PoseDecoder(nn.Module, ABC):
+ """
+ Pose decoder network
+
+ Parameters
+ ----------
+ num_ch_enc : list[Int]
+ Number of channels for each encoder feature map
+ num_input_features : Int
+ Number of encoder feature lists used as input
+ num_frames_to_predict_for : Int
+ Number of poses to predict (defaults to num_input_features - 1)
+ stride : Int
+ Convolution stride of the pose layers
+ output_multiplier : Float
+ Scale factor applied to the predicted pose
+ """
+ def __init__(self, num_ch_enc, num_input_features,
+ num_frames_to_predict_for=None,
+ stride=1, output_multiplier=0.01):
+ super().__init__()
+
+ self.num_encoder_channels = num_ch_enc
+ self.num_input_features = num_input_features
+ self.output_multiplier = output_multiplier
+
+ if num_frames_to_predict_for is None:
+ num_frames_to_predict_for = num_input_features - 1
+ self.num_output_predictions = num_frames_to_predict_for
+
+ self.convs = {
+ 'squeeze': nn.Conv2d(self.num_encoder_channels[-1], 256, 1),
+ ('pose', 0): nn.Conv2d(num_input_features * 256, 256, 3, stride, 1),
+ ('pose', 1): nn.Conv2d(256, 256, 3, stride, 1),
+ ('pose', 2): nn.Conv2d(256, 6 * num_frames_to_predict_for, 1),
+ }
+
+ self.net = nn.ModuleList(list(self.convs.values()))
+ self.relu = nn.ReLU()
+
+ def forward(self, all_features):
+ """Network forward pass"""
+
+ last_features = [f[-1] for f in all_features]
+ last_features = [self.relu(self.convs['squeeze'](f)) for f in last_features]
+ cat_features = torch.cat(last_features, 1)
+
+ for i in range(3):
+ cat_features = self.convs[('pose', i)](cat_features)
+ if i < 2:
+ cat_features = self.relu(cat_features)
+
+ output = self.output_multiplier * \
+ cat_features.mean(3).mean(2).view(-1, self.num_output_predictions, 1, 6)
+ return torch.split(output, split_size_or_sections=3, dim=-1)
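PoseDecoder is built from plain constructor arguments rather than a Config object, so it can be exercised directly. A minimal usage sketch with made-up shapes, using the module path from this diff:

```python
import torch
from vidar.arch.networks.decoders.PoseDecoder import PoseDecoder

decoder = PoseDecoder(num_ch_enc=[64, 64, 128, 256, 512],
                      num_input_features=1, num_frames_to_predict_for=2)

# One list of encoder features per input image; only the last map of each list is used
feats = [[torch.rand(4, 512, 6, 20)]]
axisangle, translation = decoder(feats)
print(axisangle.shape, translation.shape)  # torch.Size([4, 2, 1, 3]) each
```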
diff --git a/vidar/arch/networks/depth/FocalDepthResNet.py b/vidar/arch/networks/depth/FocalDepthResNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5070a3cce0643ce05ae20e51d6a696d53c5ea330
--- /dev/null
+++ b/vidar/arch/networks/depth/FocalDepthResNet.py
@@ -0,0 +1,48 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder
+from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder
+from vidar.utils.depth import inv2depth, depth2inv
+
+
+class FocalDepthResNet(BaseNet, ABC):
+ """
+ Depth network with focal length normalization
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.networks['encoder'] = ResNetEncoder(cfg.encoder)
+ cfg.decoder.num_ch_enc = self.networks['encoder'].num_ch_enc
+ self.networks['decoder'] = DepthDecoder(cfg.decoder)
+ self.scale_inv_depth = SigmoidToInvDepth(
+ min_depth=cfg.min_depth, max_depth=cfg.max_depth)
+
+ def forward(self, rgb, intrinsics, **kwargs):
+ """Network forward pass"""
+
+ x = self.networks['encoder'](rgb)
+ x = self.networks['decoder'](x)
+ inv_depths = [x[('output', i)] for i in range(4)]
+
+ if self.training:
+ inv_depths = [self.scale_inv_depth(inv_depth)[0] for inv_depth in inv_depths]
+ else:
+ inv_depths = [self.scale_inv_depth(inv_depths[0])[0]]
+
+ depths = inv2depth(inv_depths)
+ depths = [d * intrinsics[:, 0, 0].view(rgb.shape[0], 1, 1, 1) for d in depths]
+ inv_depths = depth2inv(depths)
+
+ return {
+ 'inv_depths': inv_depths
+ }
diff --git a/vidar/arch/networks/depth/MonoDepthResNet.py b/vidar/arch/networks/depth/MonoDepthResNet.py
new file mode 100755
index 0000000000000000000000000000000000000000..7d48e40f4da31e79780e62ecccc77a324ebe814e
--- /dev/null
+++ b/vidar/arch/networks/depth/MonoDepthResNet.py
@@ -0,0 +1,46 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder
+from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder
+from vidar.utils.config import cfg_has
+
+
+class MonoDepthResNet(BaseNet, ABC):
+ """
+ Single-frame monocular depth network
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.num_scales = cfg_has(cfg, 'num_scales', 4)
+ self.set_attr(cfg, 'scale_intrinsics', False)
+
+ self.networks['mono_encoder'] = ResNetEncoder(cfg.encoder)
+ cfg.decoder.num_ch_enc = self.networks['mono_encoder'].num_ch_enc
+ self.networks['mono_depth'] = DepthDecoder(cfg.decoder)
+
+ def forward(self, rgb, intrinsics=None):
+ """Network forward pass"""
+
+ features = self.networks['mono_encoder'](rgb)
+ output = self.networks['mono_depth'](features)
+
+ sigmoids = [output[('output', i)] for i in range(self.num_scales)]
+ depths = self.sigmoid_to_depth(sigmoids)
+
+ if intrinsics is not None and self.scale_intrinsics:
+ depths = [d * intrinsics[:, 0, 0].view(
+ rgb.shape[0], 1, 1, 1) for d in depths]
+
+ return {
+ 'features': features,
+ 'depths': depths,
+ }
diff --git a/vidar/arch/networks/depth/MultiDepthResNet.py b/vidar/arch/networks/depth/MultiDepthResNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..65eb95d4cc372d7213040027478ec095e1f69525
--- /dev/null
+++ b/vidar/arch/networks/depth/MultiDepthResNet.py
@@ -0,0 +1,50 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder
+from vidar.arch.networks.encoders.MultiResNetEncoder import MultiResNetEncoder
+# from vidar.arch.networks.encoders.MultiResNetEncoderStereoTwin import ResnetEncoderMatchingStereo
+from vidar.utils.config import cfg_has
+
+
+class MultiDepthResNet(BaseNet, ABC):
+ """
+ Multi-frame monocular depth network
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.num_scales = cfg_has(cfg, 'num_scales', 4)
+ self.set_attr(cfg, 'scale_intrinsics', False)
+
+ self.networks['encoder'] = MultiResNetEncoder(cfg.encoder)
+ cfg.decoder.num_ch_enc = self.networks['encoder'].num_ch_enc
+ self.networks['depth'] = DepthDecoder(cfg.decoder)
+
+ def forward(self, rgb, rgb_context, cams,
+ intrinsics=None, networks=None):
+ """Network forward pass"""
+
+ encoder_output = self.networks['encoder'](
+ rgb, rgb_context, cams, networks=networks)
+
+ network_output = {
+ **encoder_output,
+ }
+
+ output = self.networks['depth'](encoder_output['features'])
+ sigmoids = [output[('output', i)] for i in range(self.num_scales)]
+ network_output['depths'] = self.sigmoid_to_depth(sigmoids)
+
+ if intrinsics is not None and self.scale_intrinsics:
+ network_output['depths'] = [d * intrinsics[0][:, 0, 0].view(
+ rgb.shape[0], 1, 1, 1) for d in network_output['depths']]
+
+ return network_output
diff --git a/vidar/arch/networks/depth/PackNet.py b/vidar/arch/networks/depth/PackNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a797d269355d341183dc37717a129f47127e6a4
--- /dev/null
+++ b/vidar/arch/networks/depth/PackNet.py
@@ -0,0 +1,166 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.layers.packnet.packnet import \
+ PackLayerConv3d, UnpackLayerConv3d, Conv2D, ResidualBlock, InvDepth
+from vidar.utils.depth import inv2depth
+
+
+class PackNet(BaseNet, ABC):
+ """
+ PackNet depth network (https://arxiv.org/abs/1905.02693)
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ # Configuration parameters
+ self.min_depth = cfg.has('min_depth', 0.5)
+ self.dropout = cfg.has('dropout', 0.0)
+
+ # Input/output channels
+ in_channels = 3
+ out_channels = 1
+
+ # Hyper-parameters
+ ni, no = 64, out_channels
+ n1, n2, n3, n4, n5 = 64, 64, 128, 256, 512
+ num_blocks = [2, 2, 3, 3]
+ pack_kernel = [5, 3, 3, 3, 3]
+ unpack_kernel = [3, 3, 3, 3, 3]
+ iconv_kernel = [3, 3, 3, 3, 3]
+
+ # Initial convolutional layer
+ self.pre_calc = Conv2D(in_channels, ni, 5, 1)
+
+ # Channel bookkeeping for skip connections and inverse-depth concatenations
+ n1o, n1i = n1, n1 + ni + no
+ n2o, n2i = n2, n2 + n1 + no
+ n3o, n3i = n3, n3 + n2 + no
+ n4o, n4i = n4, n4 + n3
+ n5o, n5i = n5, n5 + n4
+
+ # Encoder
+
+ self.pack1 = PackLayerConv3d(n1, pack_kernel[0])
+ self.pack2 = PackLayerConv3d(n2, pack_kernel[1])
+ self.pack3 = PackLayerConv3d(n3, pack_kernel[2])
+ self.pack4 = PackLayerConv3d(n4, pack_kernel[3])
+ self.pack5 = PackLayerConv3d(n5, pack_kernel[4])
+
+ self.conv1 = Conv2D(ni, n1, 7, 1)
+ self.conv2 = ResidualBlock(n1, n2, num_blocks[0], 1, dropout=self.dropout)
+ self.conv3 = ResidualBlock(n2, n3, num_blocks[1], 1, dropout=self.dropout)
+ self.conv4 = ResidualBlock(n3, n4, num_blocks[2], 1, dropout=self.dropout)
+ self.conv5 = ResidualBlock(n4, n5, num_blocks[3], 1, dropout=self.dropout)
+
+ # Decoder
+
+ self.unpack5 = UnpackLayerConv3d(n5, n5o, unpack_kernel[0])
+ self.unpack4 = UnpackLayerConv3d(n5, n4o, unpack_kernel[1])
+ self.unpack3 = UnpackLayerConv3d(n4, n3o, unpack_kernel[2])
+ self.unpack2 = UnpackLayerConv3d(n3, n2o, unpack_kernel[3])
+ self.unpack1 = UnpackLayerConv3d(n2, n1o, unpack_kernel[4])
+
+ self.iconv5 = Conv2D(n5i, n5, iconv_kernel[0], 1)
+ self.iconv4 = Conv2D(n4i, n4, iconv_kernel[1], 1)
+ self.iconv3 = Conv2D(n3i, n3, iconv_kernel[2], 1)
+ self.iconv2 = Conv2D(n2i, n2, iconv_kernel[3], 1)
+ self.iconv1 = Conv2D(n1i, n1, iconv_kernel[4], 1)
+
+ # Depth Layers
+
+ self.unpack_disps = nn.PixelShuffle(2)
+ self.unpack_disp4 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None)
+ self.unpack_disp3 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None)
+ self.unpack_disp2 = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None)
+
+ self.disp4_layer = InvDepth(n4, out_channels=out_channels, min_depth=self.min_depth)
+ self.disp3_layer = InvDepth(n3, out_channels=out_channels, min_depth=self.min_depth)
+ self.disp2_layer = InvDepth(n2, out_channels=out_channels, min_depth=self.min_depth)
+ self.disp1_layer = InvDepth(n1, out_channels=out_channels, min_depth=self.min_depth)
+
+ self.init_weights()
+
+ def init_weights(self):
+ """Weight initialization"""
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.Conv3d)):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, rgb, intrinsics=None):
+ """Network forward pass"""
+
+ # Initial convolution
+
+ x = self.pre_calc(rgb)
+
+ # Encoder
+
+ x1 = self.conv1(x)
+ x1p = self.pack1(x1)
+ x2 = self.conv2(x1p)
+ x2p = self.pack2(x2)
+ x3 = self.conv3(x2p)
+ x3p = self.pack3(x3)
+ x4 = self.conv4(x3p)
+ x4p = self.pack4(x4)
+ x5 = self.conv5(x4p)
+ x5p = self.pack5(x5)
+
+ # Skips
+
+ skip1 = x
+ skip2 = x1p
+ skip3 = x2p
+ skip4 = x3p
+ skip5 = x4p
+
+ # Decoder
+
+ unpack5 = self.unpack5(x5p)
+ concat5 = torch.cat((unpack5, skip5), 1)
+ iconv5 = self.iconv5(concat5)
+
+ unpack4 = self.unpack4(iconv5)
+ concat4 = torch.cat((unpack4, skip4), 1)
+ iconv4 = self.iconv4(concat4)
+ inv_depth4 = self.disp4_layer(iconv4)
+ up_inv_depth4 = self.unpack_disp4(inv_depth4)
+
+ unpack3 = self.unpack3(iconv4)
+ concat3 = torch.cat((unpack3, skip3, up_inv_depth4), 1)
+ iconv3 = self.iconv3(concat3)
+ inv_depth3 = self.disp3_layer(iconv3)
+ up_inv_depth3 = self.unpack_disp3(inv_depth3)
+
+ unpack2 = self.unpack2(iconv3)
+ concat2 = torch.cat((unpack2, skip2, up_inv_depth3), 1)
+ iconv2 = self.iconv2(concat2)
+ inv_depth2 = self.disp2_layer(iconv2)
+ up_inv_depth2 = self.unpack_disp2(inv_depth2)
+
+ unpack1 = self.unpack1(iconv2)
+ concat1 = torch.cat((unpack1, skip1, up_inv_depth2), 1)
+ iconv1 = self.iconv1(concat1)
+ inv_depth1 = self.disp1_layer(iconv1)
+
+ if self.training:
+ inv_depths = [inv_depth1, inv_depth2, inv_depth3, inv_depth4]
+ else:
+ inv_depths = [inv_depth1]
+
+ return {
+ 'depths': inv2depth(inv_depths),
+ }
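A quick sanity check on input resolution, assuming the five packing stages above each halve the spatial size (so H and W should be divisible by 32) and the inverse-depth heads sit at full, 1/2, 1/4 and 1/8 resolution:

```python
# Shape bookkeeping sketch for PackNet (hypothetical KITTI-style input)
H, W = 192, 640
assert H % 32 == 0 and W % 32 == 0, 'five packing stages each halve the resolution'
print([(H >> s, W >> s) for s in range(4)])
# [(192, 640), (96, 320), (48, 160), (24, 80)] -> resolutions of inv_depth1..inv_depth4
```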
diff --git a/vidar/arch/networks/encoders/MultiResNetEncoder.py b/vidar/arch/networks/encoders/MultiResNetEncoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..021c2539edda918dd3d5d6f4aec898c49f88f7a1
--- /dev/null
+++ b/vidar/arch/networks/encoders/MultiResNetEncoder.py
@@ -0,0 +1,223 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+from vidar.utils.config import cfg_has
+from vidar.utils.data import flatten
+from vidar.utils.depth import depth2inv
+from vidar.utils.tensor import grid_sample
+
+RESNET_VERSIONS = {
+ 18: models.resnet18,
+ 34: models.resnet34,
+ 50: models.resnet50,
+ 101: models.resnet101,
+ 152: models.resnet152
+}
+
+
+class MultiResNetEncoder(nn.Module, ABC):
+ """
+ Multi-frame depth encoder
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.adaptive_bins = cfg.adaptive_bins
+ self.depth_binning = cfg.depth_binning
+ self.set_missing_to_max = True
+
+ self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+
+ self.depth_range = cfg.depth_range
+ self.min_depth_bin = cfg.depth_bin_range[0]
+ self.max_depth_bin = cfg.depth_bin_range[1]
+ self.num_depth_bins = cfg.num_depth_bins
+
+ self.min_depth_bin = torch.nn.Parameter(torch.tensor(
+ self.min_depth_bin), requires_grad=False)
+ self.max_depth_bin = torch.nn.Parameter(torch.tensor(
+ self.max_depth_bin), requires_grad=False)
+
+ self.matching_height = cfg.input_shape[0] // 4
+ self.matching_width = cfg.input_shape[1] // 4
+
+ self.depth_bins = None
+ self.warp_depths = None
+
+ assert cfg.version in RESNET_VERSIONS, \
+ '{} is not a valid number of resnet layers'.format(cfg.version)
+ encoder = RESNET_VERSIONS[cfg.version](cfg.pretrained)
+
+ self.layer0 = nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
+ self.layer1 = nn.Sequential(encoder.maxpool, encoder.layer1)
+ self.layer2 = encoder.layer2
+ self.layer3 = encoder.layer3
+ self.layer4 = encoder.layer4
+
+ if cfg.version > 34:
+ self.num_ch_enc[1:] *= 4
+
+ self.prematching_conv = nn.Sequential(
+ nn.Conv2d(64, out_channels=16, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(inplace=True)
+ )
+
+ self.double_volume = cfg_has(cfg, 'double_volume', False)
+
+ self.reduce_conv = nn.Sequential(
+ nn.Conv2d(self.num_ch_enc[1] + self.num_depth_bins * (2 if self.double_volume else 1),
+ out_channels=self.num_ch_enc[1],
+ kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True)
+ )
+
+ self.grid_sample = partial(
+ grid_sample, padding_mode='zeros', mode='bilinear', align_corners=True)
+
+ self.volume_masking = cfg_has(cfg, 'volume_masking', False)
+
+ def update_adaptive_depth_bins(self, depth):
+ """Change depth bins based on predicted depth"""
+ min_depth = depth.detach().min(-1)[0].min(-1)[0]
+ max_depth = depth.detach().max(-1)[0].max(-1)[0]
+
+ min_depth = min_depth.mean().cpu().item()
+ max_depth = max_depth.mean().cpu().item()
+
+ min_depth = max(self.depth_range[0], min_depth * 0.9)
+ max_depth = min(self.depth_range[1], max_depth * 1.1)
+
+ self.min_depth_bin = nn.Parameter(
+ self.min_depth_bin * 0.99 + min_depth * 0.01, requires_grad=False)
+ self.max_depth_bin = nn.Parameter(
+ self.max_depth_bin * 0.99 + max_depth * 0.01, requires_grad=False)
+
+ def compute_depth_bins(self, min_depth_bin, max_depth_bin, device):
+ """Compute depth bins based on minimum and maximum values"""
+ min_depth_bin = min_depth_bin.cpu()
+ max_depth_bin = max_depth_bin.cpu()
+
+ if self.depth_binning == 'inverse':
+ self.depth_bins = 1. / np.linspace(
+ 1. / max_depth_bin, 1. / min_depth_bin, self.num_depth_bins)[::-1]
+ elif self.depth_binning == 'linear':
+ self.depth_bins = np.linspace(
+ min_depth_bin, max_depth_bin, self.num_depth_bins)
+ elif self.depth_binning == 'sid':
+ self.depth_bins = np.array(
+ [np.exp(np.log(min_depth_bin) + np.log(max_depth_bin / min_depth_bin) * i / (self.num_depth_bins - 1))
+ for i in range(self.num_depth_bins)])
+ else:
+ raise NotImplementedError
+ self.depth_bins = torch.from_numpy(self.depth_bins).float().to(device)
+
+ ones = torch.ones((1, self.matching_height, self.matching_width),
+ dtype=torch.float, device=device)
+ return torch.stack([depth * ones for depth in self.depth_bins], 1)
+
+ def feature_extraction(self, image, return_all_feats=False):
+ """Extract features from input images"""
+ image = (image - 0.45) / 0.225
+ feats_0 = self.layer0(image)
+ feats_1 = self.layer1(feats_0)
+ return [feats_0, feats_1] if return_all_feats else feats_1
+
+ def indices_to_inv_depth(self, indices):
+ """Convert bin indices to inverse depth values"""
+ batch, height, width = indices.shape
+ depth = self.depth_bins[indices.reshape(-1)]
+ return 1 / depth.reshape((batch, height, width))
+
+ def compute_confidence_mask(self, cost_volume, num_bins_threshold=None):
+ """Compute confidence mask based on cost volume"""
+ if num_bins_threshold is None:
+ num_bins_threshold = self.num_depth_bins
+ return ((cost_volume > 0).sum(1) == num_bins_threshold).float()
+
+ def forward(self, rgb, rgb_context=None,
+ cams=None, mode='multi', networks=None):
+ """Network forward pass"""
+
+ feats = self.feature_extraction(rgb, return_all_feats=True)
+ current_feats = feats[-1]
+
+ if mode == 'mono':
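+ # Note: the *_mono layers used below are not created in __init__; this branch
+ # assumes they are attached to the module externally before calling with mode='mono'.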
+ feats.append(self.layer2_mono(feats[-1]))
+ feats.append(self.layer3_mono(feats[-1]))
+ feats.append(self.layer4_mono(feats[-1]))
+ return {
+ 'features': feats,
+ }
+
+ output = {}
+
+ with torch.no_grad():
+ if self.warp_depths is None or self.adaptive_bins:
+ self.warp_depths = self.compute_depth_bins(
+ self.min_depth_bin, self.max_depth_bin, device=rgb.device)
+
+ b, n, c, h, w = rgb_context.shape
+ rgb_context = rgb_context.reshape(b * n, c, h, w)
+ feats_context = self.feature_extraction(rgb_context, return_all_feats=True)
+
+ output_transformer = networks['transformer'](rgb, rgb_context, rgb.device, cams[-1].scaled(1/4))
+ output['depth_regr'] = [
+ output_transformer['depth1_low'],
+ ]
+ output['depth_regr'] = flatten(output['depth_regr'])
+
+ if 'ssim_lowest_cost' in output_transformer.keys():
+ output['lowest_cost_ssim'] = output_transformer['ssim_lowest_cost']
+
+ mask3d = output_transformer['warped_mask']
+
+ mask2d = (mask3d.sum(0) == mask3d.shape[0]).float()
+ mask2d[:, :2, :] = 0
+ mask2d[:, -2:, :] = 0
+ mask2d[:, :, :2] = 0
+ mask2d[:, :, -2:] = 0
+
+ output['confidence_mask_transformer'] = mask2d
+ output['confidence_mask_transformer3d'] = mask3d
+ output['lowest_cost_transformer1'] = depth2inv(output_transformer['depth1_low'])
+ output['lowest_cost_transformer2'] = depth2inv(output_transformer['depth2_low'])
+ output['cost_volume_transformer'] = output_transformer['attn_weight_softmax'][0].permute(0, 3, 1, 2)
+
+ cost_volume = output['cost_volume_transformer']
+ confidence_mask = output['confidence_mask_transformer']
+ lowest_cost = output['lowest_cost_transformer1']
+
+ if 'ssim_lowest_cost' in output_transformer:
+ output['ssim_lowest_cost'] = output_transformer['ssim_lowest_cost']
+
+ confidence_mask = self.compute_confidence_mask(
+ cost_volume.detach() * confidence_mask.detach())
+ cost_volume = cost_volume * confidence_mask.unsqueeze(1)
+
+ post_matching_feats = self.reduce_conv(
+ torch.cat([current_feats, cost_volume], 1))
+
+ feats.append(self.layer2(post_matching_feats))
+ feats.append(self.layer3(feats[-1]))
+ feats.append(self.layer4(feats[-1]))
+
+ output.update(**{
+ 'features': feats,
+ 'lowest_cost': lowest_cost,
+ 'confidence_mask': confidence_mask,
+ 'cost_volume': cost_volume,
+ })
+
+ return output
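The three binning modes spread the same [min, max] range differently; 'inverse' and 'sid' concentrate bins at short range. A small numpy sketch of the formulas used in `compute_depth_bins`, with illustrative values not tied to any config:

```python
import numpy as np

min_d, max_d, n = 0.5, 8.0, 4

linear = np.linspace(min_d, max_d, n)
inverse = 1. / np.linspace(1. / max_d, 1. / min_d, n)[::-1]
sid = np.array([np.exp(np.log(min_d) + np.log(max_d / min_d) * i / (n - 1))
                for i in range(n)])

print(linear)    # [0.5  3.0  5.5  8.0 ]
print(inverse)   # [0.5  0.73 1.33 8.0 ] -> dense near the camera
print(sid)       # [0.5  1.26 3.17 8.0 ] -> geometric spacing
```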
diff --git a/vidar/arch/networks/encoders/ResNetEncoder.py b/vidar/arch/networks/encoders/ResNetEncoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..b4cb41bf97745a762d2b154f616db20a4ed2a2e2
--- /dev/null
+++ b/vidar/arch/networks/encoders/ResNetEncoder.py
@@ -0,0 +1,112 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torchvision.models as models
+
+RESNET_VERSIONS = {
+ 18: models.resnet18,
+ 34: models.resnet34,
+ 50: models.resnet50,
+ 101: models.resnet101,
+ 152: models.resnet152
+}
+
+##################
+
+
+class ResNetMultiInput(models.ResNet, ABC):
+ """ResNet encoder with multiple inputs"""
+ def __init__(self, block_type, block_channels, num_input_rgb):
+ super().__init__(block_type, block_channels)
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ num_input_rgb * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block_type, 64, block_channels[0])
+ self.layer2 = self._make_layer(block_type, 128, block_channels[1], stride=2)
+ self.layer3 = self._make_layer(block_type, 256, block_channels[2], stride=2)
+ self.layer4 = self._make_layer(block_type, 512, block_channels[3], stride=2)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+def resnet_multi_input(num_layers, num_input_rgb, pretrained=True):
+ """Create a resnet encoder with multiple input images by copying the first layer"""
+ assert num_layers in [18, 50], 'Can only run with 18 or 50 layer resnet'
+
+ block_channels = {
+ 18: [2, 2, 2, 2],
+ 50: [3, 4, 6, 3]
+ }[num_layers]
+
+ block_type = {
+ 18: models.resnet.BasicBlock,
+ 50: models.resnet.Bottleneck
+ }[num_layers]
+
+ model = ResNetMultiInput(block_type, block_channels, num_input_rgb)
+
+ if pretrained:
+ loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
+ loaded['conv1.weight'] = torch.cat(
+ [loaded['conv1.weight']] * num_input_rgb, 1) / num_input_rgb
+ model.load_state_dict(loaded)
+
+ return model
+
+
+class ResNetEncoder(nn.Module, ABC):
+ """
+ ResNet encoder network
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+ self.features = []
+
+ assert cfg.version in RESNET_VERSIONS, f'Invalid ResNet version: {cfg.version}'
+
+ if cfg.num_rgb_in > 1:
+ self.encoder = resnet_multi_input(
+ cfg.version, cfg.num_rgb_in, cfg.pretrained)
+ else:
+ self.encoder = RESNET_VERSIONS[cfg.version](cfg.pretrained)
+
+ if cfg.version > 34:
+ self.num_ch_enc[1:] *= 4
+
+ def forward(self, input_image):
+ """Network forward pass"""
+
+ x = (input_image - 0.45) / 0.225
+ x = self.encoder.conv1(x)
+ x = self.encoder.bn1(x)
+
+ self.features.clear()
+ self.features.append(self.encoder.relu(x))
+ self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
+ self.features.append(self.encoder.layer2(self.features[-1]))
+ self.features.append(self.encoder.layer3(self.features[-1]))
+ self.features.append(self.encoder.layer4(self.features[-1]))
+
+ return self.features
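A minimal usage sketch; `SimpleNamespace` stands in for the repository's Config object, since the encoder only reads plain attributes (version, num_rgb_in, pretrained), and it assumes a torchvision version compatible with the positional `pretrained` flag used above:

```python
import torch
from types import SimpleNamespace
from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder

cfg = SimpleNamespace(version=18, num_rgb_in=1, pretrained=False)
encoder = ResNetEncoder(cfg)

feats = encoder(torch.rand(2, 3, 192, 640))
print([tuple(f.shape[1:]) for f in feats])
# [(64, 96, 320), (64, 48, 160), (128, 24, 80), (256, 12, 40), (512, 6, 20)]
```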
diff --git a/vidar/arch/networks/intrinsics/IntrinsicsNet.py b/vidar/arch/networks/intrinsics/IntrinsicsNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c9acfad2a3fa1eef4cd6d6453136850b10ca314
--- /dev/null
+++ b/vidar/arch/networks/intrinsics/IntrinsicsNet.py
@@ -0,0 +1,105 @@
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+
+from abc import ABC
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.utils.config import cfg_has
+
+
+class IntrinsicsNet(BaseNet, ABC):
+ """
+ Self-calibration network that learns camera intrinsics as free parameters
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ assert cfg_has(cfg, 'shape')
+ self.image_shape = cfg.shape
+ self.camera_model = cfg_has(cfg, 'camera_model', 'UCM')
+ self.sigmoid_init = nn.Parameter(torch.tensor(
+ self.setup_sigmoid_init(cfg), dtype=torch.float), requires_grad=True)
+ self.scale = nn.Parameter(torch.tensor(
+ self.setup_scale(cfg), dtype=torch.float), requires_grad=False)
+ self.offset = nn.Parameter(torch.tensor(
+ self.setup_offset(cfg), dtype=torch.float), requires_grad=False)
+
+ self.tanh = nn.Tanh()
+ self.sigmoid = nn.Sigmoid()
+
+ def __len__(self):
+ if self.camera_model == 'Pinhole':
+ return 4
+ elif self.camera_model == 'UCM':
+ return 5
+ elif self.camera_model == 'EUCM':
+ return 6
+ elif self.camera_model == 'DS':
+ return 6
+ else:
+ raise NotImplementedError('Camera model {} is not implemented. Please choose from {{Pinhole, UCM, EUCM, DS}}.'.format(self.camera_model))
+
+ def setup_sigmoid_init(self, cfg):
+ if cfg_has(cfg, 'sigmoid_init'):
+ assert len(cfg.sigmoid_init) == self.__len__()
+ return np.array(cfg.sigmoid_init)
+ else:
+ if self.camera_model == 'Pinhole':
+ return np.array([0.0, 0.0, 0.0, 0.0])
+ elif self.camera_model == 'UCM':
+ return np.array([0.0, 0.0, 0.0, 0.0, 0.0])
+ elif self.camera_model == 'EUCM':
+ return np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+ elif self.camera_model == 'DS':
+ return np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+ else:
+ raise NotImplementedError('Camera model {} is not implemented. Please choose from {{Pinhole, UCM, EUCM, DS}}.'.format(self.camera_model))
+
+ def setup_scale(self, cfg):
+ if cfg_has(cfg, 'scale'):
+ assert len(cfg.scale) == self.__len__()
+ return np.array(cfg.scale)
+ else:
+ h, w = self.image_shape
+ fx_scale, fy_scale = (h + w), (h + w)
+ cx_scale = w
+ cy_scale = h
+ alpha_scale = 1.0
+ beta_scale = 2.0
+ xi_scale = 2.0
+
+ if self.camera_model == 'Pinhole':
+ return np.array([fx_scale, fy_scale, cx_scale, cy_scale])
+ elif self.camera_model == 'UCM':
+ return np.array([fx_scale, fy_scale, cx_scale, cy_scale, alpha_scale])
+ elif self.camera_model == 'EUCM':
+ return np.array([fx_scale, fy_scale, cx_scale, cy_scale, alpha_scale, beta_scale])
+ elif self.camera_model == 'DS':
+ return np.array([fx_scale, fy_scale, cx_scale, cy_scale, xi_scale, alpha_scale])
+ else:
+ raise NotImplementedError('Camera model {} is not implemented. Please choose from {{Pinhole, UCM, EUCM, DS}}.'.format(self.camera_model))
+
+ def setup_offset(self, cfg):
+ if cfg_has(cfg, 'offset'):
+ assert len(cfg.offset) == self.__len__()
+ return np.array(cfg.offset)
+ else:
+ if self.camera_model == 'Pinhole':
+ return np.zeros(4)
+ elif self.camera_model == 'UCM':
+ return np.zeros(5)
+ elif self.camera_model == 'EUCM':
+ return np.zeros(6)
+ elif self.camera_model == 'DS':
+ return np.array([0.0, 0.0, 0.0, 0.0, -1.0, 0.0])
+ else:
+ raise NotImplementedError('Camera model {} is not implemented. Please choose from {{Pinhole, UCM, EUCM, DS}}.'.format(self.camera_model))
+
+
+ def forward(self, rgb):
+ B = rgb.shape[0]
+
+ self.scale.requires_grad = False
+ self.offset.requires_grad = False
+
+ intrinsics = self.sigmoid(self.sigmoid_init) * self.scale + self.offset
+
+ return intrinsics.unsqueeze(0).repeat(B, 1)
\ No newline at end of file
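The learned parameters are squashed through a sigmoid and then mapped to physical ranges by the per-model scale and offset. A worked example for the default UCM initialization (sigmoid_init = 0, shape = (192, 640)):

```python
import numpy as np

h, w = 192, 640
scale = np.array([h + w, h + w, w, h, 1.0])   # [fx, fy, cx, cy, alpha] scales for UCM
offset = np.zeros(5)

sigmoid_init = np.zeros(5)
I = 1.0 / (1.0 + np.exp(-sigmoid_init)) * scale + offset
print(I)  # [416. 416. 320.  96.  0.5] -> fx, fy, cx, cy, alpha
```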
diff --git a/vidar/arch/networks/layers/__init__.py b/vidar/arch/networks/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/networks/layers/convs.py b/vidar/arch/networks/layers/convs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d441d2384358683b3d92a277d175cc584e4cd180
--- /dev/null
+++ b/vidar/arch/networks/layers/convs.py
@@ -0,0 +1,43 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def upsample(x):
+ """Upsample input tensor by a factor of 2"""
+ return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+class ConvBlock(nn.Module, ABC):
+ """Layer to perform a convolution followed by ELU"""
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super().__init__()
+ self.conv = Conv3x3(in_channels, out_channels, kernel_size=kernel_size)
+ self.nonlin = nn.ELU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.nonlin(out)
+ return out
+
+
+class Conv3x3(nn.Module, ABC):
+ """Layer to pad and convolve input"""
+ def __init__(self, in_channels, out_channels, use_refl=True, kernel_size=3):
+ super().__init__()
+ if kernel_size == 3:
+ if use_refl:
+ self.pad = nn.ReflectionPad2d(1)
+ else:
+ self.pad = nn.ZeroPad2d(1)
+ else:
+ self.pad = nn.Identity()
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=kernel_size)
+
+ def forward(self, x):
+ out = self.pad(x)
+ out = self.conv(out)
+ return out
diff --git a/vidar/arch/networks/layers/depthformer/context_adjustment.py b/vidar/arch/networks/layers/depthformer/context_adjustment.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dfcf0317b3298f0f965c0711b9b972269c88fd1
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/context_adjustment.py
@@ -0,0 +1,72 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm
+
+
+class ContextAdjustmentLayer(nn.Module):
+ """
+ Context adjustment layer
+ Based on https://github.com/mli0603/stereo-transformer/blob/main/module/context_adjustment_layer.py
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ num_blocks = cfg.num_blocks
+ feature_dim = cfg.feat_dim
+ expansion = cfg.expansion_ratio
+
+ self.num_blocks = num_blocks
+
+ self.in_conv = nn.Conv2d(4, feature_dim, kernel_size=3, padding=1)
+ self.layers = nn.ModuleList([ResBlock(feature_dim, expansion) for _ in range(num_blocks)])
+ self.out_conv = nn.Conv2d(feature_dim, 1, kernel_size=3, padding=1)
+
+ def forward(self, depth_raw, img):
+ """Network forward pass"""
+
+ eps = 1e-6
+ mean_depth_pred = depth_raw.mean()
+ std_depth_pred = depth_raw.std() + eps
+ depth_pred_normalized = (depth_raw - mean_depth_pred) / std_depth_pred
+
+ feat = self.in_conv(torch.cat([depth_pred_normalized, img], dim=1))
+ for layer in self.layers:
+ feat = layer(feat, depth_pred_normalized)
+
+ depth_res = self.out_conv(feat)
+ depth_final = depth_pred_normalized + depth_res
+
+ return depth_final * std_depth_pred + mean_depth_pred
+
+
+class ResBlock(nn.Module):
+ def __init__(self, n_feats, expansion_ratio, res_scale=1.0):
+ """
+ ResNet block
+
+ Parameters
+ ----------
+ n_feats : Int
+ Number of layer features
+ expansion_ratio : Int
+ Expansion ratio for middle layer
+ res_scale : Float
+ Scale ratio for residual connections
+ """
+ super(ResBlock, self).__init__()
+ self.res_scale = res_scale
+ self.module = nn.Sequential(
+ weight_norm(nn.Conv2d(n_feats + 1, n_feats * expansion_ratio, kernel_size=3, padding=1)),
+ nn.ReLU(inplace=True),
+ weight_norm(nn.Conv2d(n_feats * expansion_ratio, n_feats, kernel_size=3, padding=1))
+ )
+
+ def forward(self, x, depth):
+ return x + self.module(torch.cat([depth, x], dim=1)) * self.res_scale
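A minimal usage sketch for the context adjustment layer; `SimpleNamespace` is used as a stand-in config, since only num_blocks, feat_dim and expansion_ratio are read:

```python
import torch
from types import SimpleNamespace
from vidar.arch.networks.layers.depthformer.context_adjustment import ContextAdjustmentLayer

cfg = SimpleNamespace(num_blocks=2, feat_dim=16, expansion_ratio=4)
cal = ContextAdjustmentLayer(cfg)

depth_raw = torch.rand(1, 1, 48, 160)      # coarse depth prediction
img = torch.rand(1, 3, 48, 160)            # matching RGB crop
depth_refined = cal(depth_raw, img)
print(depth_refined.shape)                 # torch.Size([1, 1, 48, 160])
```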
diff --git a/vidar/arch/networks/layers/depthformer/feature_extraction.py b/vidar/arch/networks/layers/depthformer/feature_extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dd0a3258130e83f7577817b74dc0059b60322b2
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/feature_extraction.py
@@ -0,0 +1,106 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.models.resnet import BasicBlock
+
+
+class SppBackbone(nn.Module):
+ """
+ Feature extraction network
+ Based on https://github.com/mli0603/stereo-transformer/blob/main/module/feat_extractor_backbone.py
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+ self.inplanes = 32
+ self.in_conv = nn.Sequential(
+ nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=2, bias=False),
+ nn.BatchNorm2d(16),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(16),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True)
+ )
+
+ self.resblock_1 = self._make_layer(BasicBlock, 64, 3, 2)
+ self.resblock_2 = self._make_layer(BasicBlock, 128, 3, 2)
+
+ self.branch1 = nn.Sequential(
+ nn.AvgPool2d((16, 16), stride=(16, 16)),
+ nn.Conv2d(128, 32, kernel_size=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True)
+ )
+
+ self.branch2 = nn.Sequential(
+ nn.AvgPool2d((8, 8), stride=(8, 8)),
+ nn.Conv2d(128, 32, kernel_size=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True)
+ )
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d((4, 4), stride=(4, 4)),
+ nn.Conv2d(128, 32, kernel_size=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True)
+ )
+
+ self.branch4 = nn.Sequential(
+ nn.AvgPool2d((2, 2), stride=(2, 2)),
+ nn.Conv2d(128, 32, kernel_size=1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True)
+ )
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ """Create intermediate layer"""
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion)
+ )
+ layers = [block(self.inplanes, planes, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, target, context):
+ """Network forward pass"""
+
+ _, _, h, w = target.shape
+
+ both = torch.cat([target, context], dim=0)
+
+ output = self.in_conv(both)
+
+ output_1 = self.resblock_1(output)
+ output_2 = self.resblock_2(output_1)
+
+ h_spp, w_spp = math.ceil(h / 16), math.ceil(w / 16)
+ spp_1 = self.branch1(output_2)
+ spp_1 = F.interpolate(spp_1, size=(h_spp, w_spp), mode='bilinear', align_corners=False)
+ spp_2 = self.branch2(output_2)
+ spp_2 = F.interpolate(spp_2, size=(h_spp, w_spp), mode='bilinear', align_corners=False)
+ spp_3 = self.branch3(output_2)
+ spp_3 = F.interpolate(spp_3, size=(h_spp, w_spp), mode='bilinear', align_corners=False)
+ spp_4 = self.branch4(output_2)
+ spp_4 = F.interpolate(spp_4, size=(h_spp, w_spp), mode='bilinear', align_corners=False)
+ output_3 = torch.cat([spp_1, spp_2, spp_3, spp_4], dim=1) # 1/16
+
+ return [both, output, output_1, output_2, output_3]
+
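The backbone runs target and context through the same trunk (stacked on the batch dimension) and returns features at five resolutions. A shape sketch, noting that the config argument is not read by the constructor above:

```python
import torch
from vidar.arch.networks.layers.depthformer.feature_extraction import SppBackbone

backbone = SppBackbone(cfg=None)  # cfg is unused in __init__
target, context = torch.rand(1, 3, 128, 192), torch.rand(1, 3, 128, 192)

feats = backbone(target, context)
print([tuple(f.shape) for f in feats])
# [(2, 3, 128, 192), (2, 32, 64, 96), (2, 64, 32, 48), (2, 128, 16, 24), (2, 128, 8, 12)]
```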
diff --git a/vidar/arch/networks/layers/depthformer/regression.py b/vidar/arch/networks/layers/depthformer/regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..4517bfcf4c505185810ff0fb345d3f0dfe0577b0
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/regression.py
@@ -0,0 +1,118 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from vidar.arch.networks.layers.depthformer.context_adjustment import ContextAdjustmentLayer
+from vidar.utils.volume import compute_depth_bin
+
+
+class RegressionHead(nn.Module):
+ """
+ Regression head that converts attention weights into depth predictions
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+ self.cal = ContextAdjustmentLayer(cfg.context_adjustment)
+ self.phi = nn.Parameter(torch.tensor(0.0, requires_grad=True))
+ self.monocular = True
+
+ @staticmethod
+ def _compute_unscaled_pos_shift(w, device):
+ return torch.linspace(0, w - 1, w)[None, None, None, :].to(device)
+
+ @staticmethod
+ def _compute_low_res_depth(pos_shift, attn_weight):
+ high_response = torch.argmax(attn_weight, dim=-1) # NxHxW
+ response_range = torch.stack([high_response - 1, high_response, high_response + 1], dim=-1)
+ attn_weight_pad = F.pad(attn_weight, [1, 1], value=0.0)
+ attn_weight_rw = torch.gather(attn_weight_pad, -1, response_range + 1)
+
+ norm = attn_weight_rw.sum(-1, keepdim=True)
+ norm[norm < 0.1] = 1.0
+
+ attn_weight_rw = attn_weight_rw / norm
+ pos_pad = F.pad(pos_shift, [1, 1]).expand_as(attn_weight_pad).clone()
+ pos_pad[..., -1] = pos_shift[..., -1] + 1
+ pos_rw = torch.gather(pos_pad, -1, response_range + 1)
+ depth_pred_low_res = (attn_weight_rw * pos_rw)
+
+ depth_pred_low_res = depth_pred_low_res.sum(-1)
+
+ return depth_pred_low_res, norm, high_response
+
+ def upsample(self, x, depth_pred, scale=1.0):
+ _, _, h, w = x.size()
+ depth_pred_attn = depth_pred * scale
+ depth_pred = F.interpolate(depth_pred_attn[None,], size=(h, w), mode='nearest')
+ depth_pred_final = self.cal(depth_pred, x)
+ return depth_pred_final.squeeze(1), depth_pred_attn.squeeze(1)
+
+ def softmax(self, attn):
+ bs, h, w, d = attn.shape
+ similarity_matrix = torch.cat([attn, self.phi.expand(bs, h, w, 1).to(attn.device)], -1)
+ attn_softmax = F.softmax(similarity_matrix, dim=-1)
+ return attn_softmax
+
+ def forward(self, attn_weight, target, context, sampled_rows, sampled_cols, min_depth, max_depth, num_bins):
+
+ stride = [1]
+
+ outputs = []
+ for s in stride:
+ output = self.forward2(
+ attn_weight, target, sampled_cols, min_depth, max_depth, num_bins, s)
+ outputs.append(output)
+ final_output = {}
+ for key in outputs[0].keys():
+ final_output[key] = [o[key] for o in outputs]
+ return final_output
+
+ def forward2(self, attn_weight, target, sampled_cols, min_depth, max_depth, num_bins, stride=1):
+
+ bs, _, h, w = target.size()
+ output = {}
+
+ if stride > 1:
+ shape = list(attn_weight.shape)
+ shape[-1] = shape[-1] // stride
+ attn_weight_tmp = torch.zeros(shape, dtype=attn_weight.dtype, device=attn_weight.device)
+ for i in range(0, shape[-1]):
+ attn_weight_tmp[..., i] = attn_weight[..., i * stride:(i + 1) * stride].mean(-1)
+ attn_weight = attn_weight_tmp
+
+ attn_ot = self.softmax(attn_weight)
+ attn_ot = attn_ot[..., :-1]
+ output['attn_weight_softmax'] = attn_ot
+
+ pos_shift = self._compute_unscaled_pos_shift(attn_weight.shape[3], attn_weight.device)
+
+ depth_pred_low_res1, matched_attn1, high_response1 = self._compute_low_res_depth(pos_shift, attn_ot)
+ depth_pred_low_res2, matched_attn2, high_response2 = self._compute_low_res_depth(pos_shift, attn_ot)
+
+ output['high_response'] = high_response1
+
+ if sampled_cols is not None:
+ output['depth_pred1'], output['depth_pred1_low'] = \
+ self.upsample(target, depth_pred_low_res1)
+ output['depth_pred2'], output['depth_pred2_low'] = \
+ self.upsample(target, depth_pred_low_res2)
+ else:
+ output['depth_pred_low'] = depth_pred_low_res1
+ output['depth_pred'] = depth_pred_low_res1
+
+ if self.monocular:
+
+ num_bins = num_bins // stride
+
+ depth1 = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred1'])
+ depth1_low = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred1_low'])
+
+ depth2 = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred2'])
+ depth2_low = compute_depth_bin(min_depth, max_depth, num_bins, output['depth_pred2_low'])
+
+ output['depth1'] = depth1
+ output['depth1_low'] = depth1_low
+
+ output['depth2'] = depth2
+ output['depth2_low'] = depth2_low
+
+ return output
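`_compute_low_res_depth` turns each attention row into a sub-bin position with a windowed soft-argmax: it keeps the best bin and its two neighbours, renormalizes their weights and takes the expected position. Illustrative numbers:

```python
import torch

attn = torch.tensor([0.05, 0.10, 0.60, 0.20, 0.05])   # attention over 5 depth bins
top = int(attn.argmax())                               # best bin: 2
window = attn[top - 1:top + 2]                         # tensor([0.10, 0.60, 0.20])
pos = torch.arange(top - 1, top + 2).float()           # tensor([1., 2., 3.])
print((window / window.sum() * pos).sum())             # ~2.11, a fractional bin index
```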
diff --git a/vidar/arch/networks/layers/depthformer/tokenizer.py b/vidar/arch/networks/layers/depthformer/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f9134b258b054f8ff74f8ad92ee095045053f9
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/tokenizer.py
@@ -0,0 +1,114 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+from torch import nn
+from torchvision.models.densenet import _DenseBlock
+
+
+def center_crop(layer, max_height, max_width):
+ """Center-crop a feature map to the given height and width"""
+ _, _, h, w = layer.size()
+ xy1 = (w - max_width) // 2
+ xy2 = (h - max_height) // 2
+ return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
+
+
+class TransitionUp(nn.Module):
+ """Transposed convolution for upsampling"""
+ def __init__(self, in_channels: int, out_channels: int, scale: int = 2):
+ super().__init__()
+ if scale == 2:
+ self.convTrans = nn.ConvTranspose2d(
+ in_channels=in_channels, out_channels=out_channels,
+ kernel_size=3, stride=2, padding=0, bias=True)
+ elif scale == 4:
+ self.convTrans = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=in_channels, out_channels=out_channels,
+ kernel_size=3, stride=2, padding=0, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ConvTranspose2d(
+ in_channels=out_channels, out_channels=out_channels,
+ kernel_size=3, stride=2, padding=0, bias=True)
+ )
+
+ def forward(self, x, skip):
+ out = self.convTrans(x)
+ out = center_crop(out, skip.size(2), skip.size(3))
+ out = torch.cat([out, skip], 1)
+ return out
+
+
+class DoubleConv(nn.Module):
+ """Helper class with two convolutional layers, plus BatchNorm and ReLU"""
+ def __init__(self, in_channels, out_channels):
+ super(DoubleConv, self).__init__()
+
+ self.double_conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Tokenizer(nn.Module):
+ """
+ Feature tokenization network
+ Based on https://github.com/mli0603/stereo-transformer/blob/main/module/feat_extractor_tokenizer.py
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super(Tokenizer, self).__init__()
+
+ block_config = [4, 4, 4, 4]
+ backbone_feat_channel = [64, 128, 128]
+ hidden_dim = cfg.channel_dim
+ growth_rate = 4
+
+ backbone_feat_channel.reverse()
+ block_config.reverse()
+
+ self.num_resolution = len(backbone_feat_channel)
+ self.block_config = block_config
+ self.growth_rate = growth_rate
+
+ self.bottle_neck = _DenseBlock(
+ block_config[0], backbone_feat_channel[0], 4, drop_rate=0.0, growth_rate=growth_rate)
+ up = []
+ dense_block = []
+ prev_block_channels = growth_rate * block_config[0]
+ for i in range(self.num_resolution):
+ if i == self.num_resolution - 1:
+ up.append(TransitionUp(prev_block_channels, hidden_dim, 4))
+ dense_block.append(DoubleConv(hidden_dim + 3, hidden_dim))
+ else:
+ up.append(TransitionUp(prev_block_channels, prev_block_channels))
+ cur_channels_count = prev_block_channels + backbone_feat_channel[i + 1]
+ dense_block.append(
+ _DenseBlock(block_config[i + 1], cur_channels_count, 4, drop_rate=0.0, growth_rate=growth_rate))
+ prev_block_channels = growth_rate * block_config[i + 1]
+
+ self.up = nn.ModuleList(up)
+ self.dense_block = nn.ModuleList(dense_block)
+
+ def forward(self, features):
+ """Network forward pass"""
+
+ features.reverse()
+ output = self.bottle_neck(features[0])
+ output = output[:, -(self.block_config[0] * self.growth_rate):]
+ for i in range(self.num_resolution):
+ hs = self.up[i](output, features[i + 1])
+ output = self.dense_block[i](hs)
+ if i < self.num_resolution - 1:
+ output = output[:, -(self.block_config[i + 1] * self.growth_rate):]
+ return output
diff --git a/vidar/arch/networks/layers/depthformer/transformer.py b/vidar/arch/networks/layers/depthformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea81a7bcc0bdee245657e4fd7c3c18ce656737a4
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/transformer.py
@@ -0,0 +1,223 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import copy
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from vidar.arch.losses.SSIMLoss import SSIMLoss
+from vidar.utils.tensor import grid_sample
+from vidar.utils.volume import compute_depth_bins, compute_depth_bin
+
+
+def get_clones(module, N):
+ """Create clones of a module"""
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def prepareA(feat):
+ """Reorganize features in one way"""
+ return feat.permute(1, 3, 2, 0).flatten(2).permute(1, 2, 0)
+
+
+def prepareB(x):
+ """Reorganize features in another way"""
+ d, c, h, w = x.shape
+ return x.permute(1, 2, 3, 0).reshape(c, h * w, d).permute(2, 1, 0)
+
+
+def unprepare(feat, shape):
+ """Return features back to original shape"""
+ b, c, h, w = shape
+ return feat.permute(2, 0, 1).reshape(c, w, h, b).permute(3, 0, 2, 1)
+
+
+class Transformer(nn.Module):
+ """
+ Transformer network for Feature Matching (https://arxiv.org/abs/2204.07616)
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.hidden_dim = cfg.channel_dim
+ self.num_attn_layers = cfg.num_attn_layers
+
+ self_attn_layer = TransformerSelfAttnLayer(self.hidden_dim, cfg.nheads)
+ self.self_attn_layers = get_clones(self_attn_layer, self.num_attn_layers)
+
+ cross_attn_layer = TransformerCrossAttnLayer(self.hidden_dim, cfg.nheads)
+ self.cross_attn_layers = get_clones(cross_attn_layer, self.num_attn_layers)
+
+ self.norm = nn.LayerNorm(self.hidden_dim)
+
+ self.grid_sample = partial(
+ grid_sample, padding_mode='zeros', mode='bilinear', align_corners=True)
+
+ self.grid_sample_nearest = partial(
+ grid_sample, padding_mode='zeros', mode='nearest', align_corners=True)
+
+ def _alternating_attn(self, feat1, feat2,
+ cam=None, min_depth=None, max_depth=None, num_bins=None):
+ """Perform self- and cross-attention between two feature maps"""
+ device = feat1.device
+ cam = cam.to(device)
+ h, w = cam.hw
+
+ depth_bins = compute_depth_bins(min_depth, max_depth, num_bins, 'sid').to(device)
+
+ ones = torch.ones((1, h, w), dtype=feat1.dtype, device=device)
+ warped_depth = torch.stack([depth * ones for depth in depth_bins], 1)
+ coords = cam.coords_from_cost_volume(warped_depth)[0]
+
+ coords[coords < -1] = -2
+ coords[coords > +1] = +2
+
+ repeated_feat2 = feat2.repeat([num_bins, 1, 1, 1])
+ warped = self.grid_sample(repeated_feat2, coords.type(repeated_feat2.dtype))
+
+ repeated_ones = ones.repeat([num_bins, 1, 1, 1])
+ warped_mask = self.grid_sample_nearest(repeated_ones, coords.type(repeated_ones.dtype))
+
+ with torch.no_grad():
+ ssim_volume = SSIMLoss()(feat1, warped)['loss'].mean(1).unsqueeze(0)
+ lowest_cost = 1. / compute_depth_bin(min_depth, max_depth, num_bins, torch.min(ssim_volume, 1)[1])
+
+ feat1 = prepareB(feat1)
+ feat2 = prepareB(warped)
+
+ attn_weight = None
+ for idx, (self_attn, cross_attn) in \
+ enumerate(zip(self.self_attn_layers, self.cross_attn_layers)):
+ feat1 = self_attn(feat1)
+ feat1, feat2, attn_weight = cross_attn(feat1, feat2)
+
+ return {
+ 'attn_weight': attn_weight,
+ 'warped_mask': warped_mask,
+ 'ssim_lowest_cost': lowest_cost,
+ 'ssim_cost_volume': ssim_volume,
+ }
+
+ def forward(self, feat1, feat2, cam=None, min_depth=None, max_depth=None, num_bins=None):
+ """Network forward pass"""
+
+ bs, c, hn, w = feat1.shape
+
+ transformer_output = self._alternating_attn(
+ feat1, feat2, cam=cam, min_depth=min_depth, max_depth=max_depth, num_bins=num_bins)
+ transformer_output['attn_weight'] = \
+ transformer_output['attn_weight'].view(bs, hn, w, num_bins)
+
+ return transformer_output
+
+
+class TransformerSelfAttnLayer(nn.Module):
+ """Self-attention layer for transformers"""
+ def __init__(self, hidden_dim, nheads):
+ super().__init__()
+ self.self_attn = MultiheadAttentionRelative(hidden_dim, nheads)
+
+ self.norm1 = nn.LayerNorm(hidden_dim)
+
+ def forward(self, feat):
+ feat_out = self.norm1(feat)
+ feat_out, _, _ = self.self_attn(query=feat_out, key=feat_out, value=feat_out)
+ return feat + feat_out
+
+
+class TransformerCrossAttnLayer(nn.Module):
+ """Cross-attention layer for transformers"""
+ def __init__(self, hidden_dim, nheads):
+ super().__init__()
+ self.cross_attn = MultiheadAttentionRelative(hidden_dim, nheads)
+
+ self.norm1 = nn.LayerNorm(hidden_dim)
+ self.norm2 = nn.LayerNorm(hidden_dim)
+
+ def forward(self, feat1, feat2):
+
+ feat1_2 = self.norm1(feat1)
+ feat2_2 = self.norm1(feat2)
+
+ feat1_2, attn_weight, raw_attn = self.cross_attn(query=feat1_2, key=feat2_2, value=feat2_2)
+ feat1 = feat1 + feat1_2
+
+ return feat1, feat2, raw_attn
+
+ @torch.no_grad()
+ def _generate_square_subsequent_mask(self, sz):
+ mask = torch.triu(torch.ones(sz, sz), diagonal=1)
+ mask[mask == 1] = float('-inf')
+ return mask
+
+
+class MultiheadAttentionRelative(nn.MultiheadAttention):
+ """Multi-head attention layer"""
+ def __init__(self, embed_dim, num_heads):
+ super().__init__(
+ embed_dim, num_heads, dropout=0.0, bias=True,
+ add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
+
+ def forward(self, query, key, value):
+
+ w2, bsz2, embed_dim2 = key.size()
+ w, bsz, embed_dim = query.size()
+
+ head_dim = embed_dim // self.num_heads
+
+ if torch.equal(query, key) and torch.equal(key, value):
+ q, k, v = F.linear(
+ query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+ elif torch.equal(key, value):
+ _b = self.in_proj_bias
+ _start = 0
+ _end = embed_dim
+ _w = self.in_proj_weight[_start:_end, :]
+ if _b is not None:
+ _b = _b[_start:_end]
+ q = F.linear(query, _w, _b)
+
+ if key is None:
+ assert value is None
+ k = None
+ v = None
+ else:
+ _b = self.in_proj_bias
+ _start = embed_dim
+ _end = None
+ _w = self.in_proj_weight[_start:, :]
+ if _b is not None:
+ _b = _b[_start:]
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
+ else:
+ raise ValueError('Invalid key/query/value')
+
+ scaling = float(head_dim) ** -0.5
+ q = q * scaling
+
+ q = q.contiguous().view(w, bsz, self.num_heads, head_dim)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz, self.num_heads, head_dim)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz, self.num_heads, head_dim)
+
+ attn = torch.einsum('wnec,vnec->newv', q, k)
+ raw_attn = attn
+ attn = F.softmax(attn, dim=-1)
+
+ v_out = torch.bmm(attn.view(bsz * self.num_heads, w, w2),
+ v.permute(1, 2, 0, 3).view(bsz * self.num_heads, w2, head_dim))
+ v_out = v_out.reshape(bsz, self.num_heads, w, head_dim).permute(2, 0, 1, 3).reshape(w, bsz, embed_dim)
+ v_out = F.linear(v_out, self.out_proj.weight, self.out_proj.bias)
+
+ attn = attn.sum(dim=1) / self.num_heads
+ raw_attn = raw_attn.sum(dim=1)
+
+ return v_out, attn, raw_attn
diff --git a/vidar/arch/networks/layers/depthformer/transformer_net.py b/vidar/arch/networks/layers/depthformer/transformer_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..7158d19d6f08704bcf98b71f46e148e750a51792
--- /dev/null
+++ b/vidar/arch/networks/layers/depthformer/transformer_net.py
@@ -0,0 +1,106 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.networks.layers.depthformer.feature_extraction import SppBackbone as Backbone
+from vidar.arch.networks.layers.depthformer.regression import RegressionHead
+from vidar.arch.networks.layers.depthformer.tokenizer import Tokenizer
+from vidar.arch.networks.layers.depthformer.transformer import Transformer
+
+
+def batched_index_select(source, dim, index):
+ """Gather entries of `source` along `dim` using per-sample (batched) indices"""
+ views = [source.shape[0]] + [1 if i != dim else -1 for i in range(1, len(source.shape))]
+ expanse = list(source.shape)
+ expanse[0] = -1
+ expanse[dim] = -1
+ index = index.view(views).expand(expanse)
+ return torch.gather(source, dim, index)
+
+
+class TransformerNet(nn.Module):
+ """
+ Full matching network: backbone, tokenizer, transformer and regression head
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ decoder_type : String
+ Type of decoder head (only 'regression' is used here)
+ """
+ def __init__(self, cfg, decoder_type='regression'):
+ super().__init__()
+
+ self.backbone = Backbone(cfg)
+ self.tokenizer = Tokenizer(cfg)
+ self.transformer = Transformer(cfg)
+
+ self.min_depth = cfg.min_depth
+ self.max_depth = cfg.max_depth
+ self.num_bins = cfg.num_bins
+
+ self.decoder_type = decoder_type
+
+ self.regression_head = RegressionHead(cfg)
+
+ self._reset_parameters()
+ self._disable_batchnorm_tracking()
+ self._relu_inplace()
+
+ def _reset_parameters(self):
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.zeros_(m.bias)
+
+ def _disable_batchnorm_tracking(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.track_running_stats = False
+
+ def _relu_inplace(self):
+ for m in self.modules():
+ if isinstance(m, nn.ReLU):
+ m.inplace = True
+
+ def fix_layers(self):
+ def iterate(module):
+ for key in module.keys():
+ if type(module[key]) == nn.BatchNorm2d:
+ module[key] = nn.InstanceNorm2d(
+ module[key].num_features,
+ module[key].eps,
+ module[key].momentum,
+ module[key].affine,
+ )
+ iterate(module[key]._modules)
+ iterate(self._modules)
+
+ def forward(self, target, context, sampled_rows, sampled_cols, cam=None):
+
+ bs, _, h, w = target.size()
+
+ feat_all = self.backbone(target, context)
+ feat = [feat_all[0]] + feat_all[2:]
+
+ feat1 = [f[[0]] for f in feat]
+ feat2 = [f[[1]] for f in feat]
+
+ feat1 = self.tokenizer(feat1)
+ feat2 = self.tokenizer(feat2)
+
+ if sampled_cols is not None:
+ feat1 = batched_index_select(feat1, 3, sampled_cols)
+ feat2 = batched_index_select(feat2, 3, sampled_cols)
+ if sampled_rows is not None:
+ feat1 = batched_index_select(feat1, 2, sampled_rows)
+ feat2 = batched_index_select(feat2, 2, sampled_rows)
+
+ output_transformer = self.transformer(
+ feat1, feat2, cam=cam, min_depth=self.min_depth, max_depth=self.max_depth, num_bins=self.num_bins)
+
+ output_regression = self.regression_head(
+ output_transformer['attn_weight'], target, context, sampled_rows, sampled_cols,
+ min_depth=self.min_depth, max_depth=self.max_depth, num_bins=self.num_bins,
+ )
+
+ return {
+ **output_transformer,
+ **output_regression,
+ }
diff --git a/vidar/arch/networks/layers/fsm/camera.py b/vidar/arch/networks/layers/fsm/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..83f536f291657aff2f7d5ab3f10c30182ef82ee6
--- /dev/null
+++ b/vidar/arch/networks/layers/fsm/camera.py
@@ -0,0 +1,289 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from functools import lru_cache
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.networks.layers.fsm.camera_utils import scale_intrinsics, invert_intrinsics
+from vidar.arch.networks.layers.fsm.pose import Pose
+from vidar.utils.tensor import pixel_grid
+from vidar.utils.types import is_tensor, is_list
+
+
+class Camera(nn.Module):
+ """
+ Differentiable camera class implementing reconstruction and projection
+ functions for a pinhole model.
+ """
+ def __init__(self, K, Tcw=None, Twc=None, hw=None):
+ """
+ Initializes the Camera class
+
+ Parameters
+ ----------
+ K : torch.Tensor
+ Camera intrinsics [B,3,3]
+ Tcw : Pose or torch.Tensor
+ Camera -> World pose transformation [B,4,4]
+ Twc : Pose or torch.Tensor
+ World -> Camera pose transformation [B,4,4]
+ hw : tuple or torch.Tensor
+ Camera width and height, or a tensor with the proper shape
+ """
+ super().__init__()
+ assert Tcw is None or Twc is None, 'Provide at most one of Tcw or Twc'
+ self.K = K
+ self.hw = None if hw is None else hw.shape[-2:] if is_tensor(hw) else hw[-2:]
+ if Tcw is not None:
+ self.Tcw = Tcw if isinstance(Tcw, Pose) else Pose(Tcw)
+ elif Twc is not None:
+ self.Tcw = Twc.inverse() if isinstance(Twc, Pose) else Pose(Twc).inverse()
+ else:
+ self.Tcw = Pose.identity(len(self.K))
+
+ def __len__(self):
+ """Batch size of the camera intrinsics"""
+ return len(self.K)
+
+ def __getitem__(self, idx):
+ """Return single camera from a batch position"""
+ return Camera(K=self.K[idx].unsqueeze(0),
+ hw=self.hw, Tcw=self.Tcw[idx]).to(self.device)
+
+ @property
+ def wh(self):
+ """Return camera width and height"""
+ return None if self.hw is None else self.hw[::-1]
+
+ @property
+ def pose(self):
+ """Return camera pose"""
+ return self.Twc.mat
+
+ @property
+ def device(self):
+ """Return camera device"""
+ return self.K.device
+
+ def invert_pose(self):
+ """Return new camera with inverted pose"""
+ return Camera(K=self.K, Tcw=self.Twc)
+
+ def to(self, *args, **kwargs):
+ """Moves object to a specific device"""
+ self.K = self.K.to(*args, **kwargs)
+ self.Tcw = self.Tcw.to(*args, **kwargs)
+ return self
+
+ @property
+ def fx(self):
+ """Focal length in x"""
+ return self.K[:, 0, 0]
+
+ @property
+ def fy(self):
+ """Focal length in y"""
+ return self.K[:, 1, 1]
+
+ @property
+ def cx(self):
+ """Principal point in x"""
+ return self.K[:, 0, 2]
+
+ @property
+ def cy(self):
+ """Principal point in y"""
+ return self.K[:, 1, 2]
+
+ @property
+ @lru_cache()
+ def Twc(self):
+ """World -> Camera pose transformation (inverse of Tcw)"""
+ return self.Tcw.inverse()
+
+ @property
+ @lru_cache()
+ def Kinv(self):
+ """Inverse intrinsics (for lifting)"""
+ return invert_intrinsics(self.K)
+
+ def equal(self, cam):
+ """Check if two cameras are the same"""
+ return torch.allclose(self.K, cam.K) and \
+ torch.allclose(self.Tcw.mat, cam.Tcw.mat)
+
+ def scaled(self, x_scale, y_scale=None):
+ """
+ Returns a scaled version of the camera (changing intrinsics)
+
+ Parameters
+ ----------
+ x_scale : float
+ Resize scale in x
+ y_scale : float
+ Resize scale in y. If None, use the same as x_scale
+
+ Returns
+ -------
+ camera : Camera
+ Scaled version of the current camera
+ """
+ # If single value is provided, use for both dimensions
+ if y_scale is None:
+ y_scale = x_scale
+ # If no scaling is necessary, return same camera
+ if x_scale == 1. and y_scale == 1.:
+ return self
+ # Scale intrinsics
+ K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
+ # Scale image dimensions
+ hw = None if self.hw is None else (int(self.hw[0] * y_scale),
+ int(self.hw[1] * x_scale))
+ # Return scaled camera
+ return Camera(K=K, Tcw=self.Tcw, hw=hw)
+
+ def scaled_K(self, shape):
+ """Return scaled intrinsics to match a shape"""
+ if self.hw is None:
+ return self.K
+ else:
+ y_scale, x_scale = [sh / hw for sh, hw in zip(shape[-2:], self.hw)]
+ return scale_intrinsics(self.K, x_scale, y_scale)
+
+ def scaled_Kinv(self, shape):
+ """Return scaled inverse intrinsics to match a shape"""
+ return invert_intrinsics(self.scaled_K(shape))
+
+ def reconstruct(self, depth, frame='w', scene_flow=None, return_grid=False):
+ """
+ Reconstructs pixel-wise 3D points from a depth map.
+
+ Parameters
+ ----------
+ depth : torch.Tensor
+ Depth map for the camera [B,1,H,W]
+ frame : str
+ Reference frame: 'c' for camera and 'w' for world
+ scene_flow : torch.Tensor
+ Optional per-point scene flow to be added (camera reference frame) [B,3,H,W]
+ return_grid : bool
+ Return pixel grid as well
+
+ Returns
+ -------
+ points : torch.Tensor
+ Pixel-wise 3D points [B,3,H,W]
+ """
+ # If depth is a list, return each reconstruction
+ if is_list(depth):
+ return [self.reconstruct(d, frame, scene_flow, return_grid) for d in depth]
+ # Dimension assertions
+ assert depth.dim() == 4 and depth.shape[1] == 1, \
+ 'Wrong dimensions for camera reconstruction'
+
+ # Create flat index grid [B,3,H,W]
+ B, _, H, W = depth.shape
+ grid = pixel_grid((H, W), B, device=depth.device, normalize=False, with_ones=True)
+ flat_grid = grid.view(B, 3, -1)
+
+ # Get inverse intrinsics
+ Kinv = self.Kinv if self.hw is None else self.scaled_Kinv(depth.shape)
+
+ # Estimate the outward rays in the camera frame
+ Xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W)
+ # Scale rays to metric depth
+ Xc = Xnorm * depth
+
+ # Add scene flow if provided
+ if scene_flow is not None:
+ Xc = Xc + scene_flow
+
+ # If in camera frame of reference
+ if frame == 'c':
+ pass
+ # If in world frame of reference
+ elif frame == 'w':
+ Xc = self.Twc @ Xc
+ # If none of the above
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+ # Return points and grid if requested
+ return (Xc, grid) if return_grid else Xc
+
+ def project(self, X, frame='w', normalize=True, return_z=False):
+ """
+ Projects 3D points onto the image plane
+
+ Parameters
+ ----------
+ X : torch.Tensor
+ 3D points to be projected [B,3,H,W]
+ frame : str
+ Reference frame: 'c' for camera and 'w' for world
+ normalize : bool
+ Normalize grid coordinates
+ return_z : bool
+ Return the projected z coordinate as well
+
+ Returns
+ -------
+ points : torch.Tensor
+ 2D projected points that are within the image boundaries [B,H,W,2]
+ """
+ assert 2 < X.dim() <= 4 and X.shape[1] == 3, \
+ 'Wrong dimensions for camera projection'
+
+ # Determine if input is a grid
+ is_grid = X.dim() == 4
+ # If it's a grid, flatten it
+ X_flat = X.view(X.shape[0], 3, -1) if is_grid else X
+
+ # Get dimensions
+ hw = X.shape[2:] if is_grid else self.hw
+ # Get intrinsics
+ K = self.scaled_K(X.shape) if is_grid else self.K
+
+ # Project 3D points onto the camera image plane
+ if frame == 'c':
+ Xc = K.bmm(X_flat)
+ elif frame == 'w':
+ Xc = K.bmm(self.Tcw @ X_flat)
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ # Extract coordinates
+ Z = Xc[:, 2].clamp(min=1e-5)
+ XZ = Xc[:, 0] / Z
+ YZ = Xc[:, 1] / Z
+
+ # Normalize points
+ if normalize and hw is not None:
+ XZ = 2 * XZ / (hw[1] - 1) - 1.
+ YZ = 2 * YZ / (hw[0] - 1) - 1.
+
+ # Clamp out-of-bounds pixels
+ Xmask = ((XZ > 1) + (XZ < -1)).detach()
+ XZ[Xmask] = 2.
+ Ymask = ((YZ > 1) + (YZ < -1)).detach()
+ YZ[Ymask] = 2.
+
+ # Stack X and Y coordinates
+ XY = torch.stack([XZ, YZ], dim=-1)
+ # Reshape coordinates to a grid if possible
+ if is_grid and hw is not None:
+ XY = XY.view(X.shape[0], hw[0], hw[1], 2)
+
+ # If also returning depth
+ if return_z:
+ # Reshape depth values to a grid if possible
+ if is_grid and hw is not None:
+ Z = Z.view(X.shape[0], hw[0], hw[1], 1).permute(0, 3, 1, 2)
+ # Otherwise, reshape to an array
+ else:
+ Z = Z.view(X.shape[0], -1, 1).permute(0, 2, 1)
+ # Return coordinates and depth values
+ return XY, Z
+ else:
+ # Return coordinates
+ return XY
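+
+# Minimal usage sketch (illustrative; the intrinsics and sizes are made-up
+# assumptions): reconstruct() lifts a depth map to 3D points and project() maps
+# them back to normalized image coordinates in [-1,1].
+#   B, H, W = 1, 8, 12
+#   K = torch.tensor([[[100., 0., W / 2], [0., 100., H / 2], [0., 0., 1.]]])
+#   cam = Camera(K=K)                                # identity world pose
+#   depth = torch.full((B, 1, H, W), 5.0)
+#   points = cam.reconstruct(depth, frame='w')       # [B,3,H,W]
+#   coords = cam.project(points, frame='w')          # [B,H,W,2]
+#   assert points.shape == (B, 3, H, W) and coords.shape == (B, H, W, 2)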
diff --git a/vidar/arch/networks/layers/fsm/camera_utils.py b/vidar/arch/networks/layers/fsm/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42b646086443ebd49f57e1b5b5b4fa81101b545
--- /dev/null
+++ b/vidar/arch/networks/layers/fsm/camera_utils.py
@@ -0,0 +1,112 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as tf
+
+
+def construct_K(fx, fy, cx, cy, dtype=torch.float, device=None):
+ """Construct a [3,3] camera intrinsics from pinhole parameters"""
+ return torch.tensor([[fx, 0, cx],
+ [ 0, fy, cy],
+ [ 0, 0, 1]], dtype=dtype, device=device)
+
+
+def scale_intrinsics(K, x_scale, y_scale):
+ """Scale intrinsics given x_scale and y_scale factors"""
+ K = K.clone()
+ K[..., 0, 0] *= x_scale
+ K[..., 1, 1] *= y_scale
+ # K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5
+ # K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5
+ K[..., 0, 2] = K[..., 0, 2] * x_scale
+ K[..., 1, 2] = K[..., 1, 2] * y_scale
+ return K
+
+
+def invert_intrinsics(K):
+ """Invert camera intrinsics"""
+ Kinv = K.clone()
+ Kinv[:, 0, 0] = 1. / K[:, 0, 0]
+ Kinv[:, 1, 1] = 1. / K[:, 1, 1]
+ Kinv[:, 0, 2] = -1. * K[:, 0, 2] / K[:, 0, 0]
+ Kinv[:, 1, 2] = -1. * K[:, 1, 2] / K[:, 1, 1]
+ return Kinv
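+
+# Sanity-check sketch (illustrative values, chosen so the arithmetic is exact in
+# float32): for a zero-skew pinhole K, the closed-form inverse above satisfies
+# K @ Kinv == I.
+#   K = torch.tensor([[[512., 0., 256.], [0., 512., 128.], [0., 0., 1.]]])
+#   assert torch.allclose(K.bmm(invert_intrinsics(K)), torch.eye(3).unsqueeze(0))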
+
+
+def view_synthesis(ref_image, depth, ref_cam, cam, scene_flow=None,
+ mode='bilinear', padding_mode='zeros', align_corners=True):
+ """
+ Synthesize an image from another plus a depth map.
+
+ Parameters
+ ----------
+ ref_image : torch.Tensor
+ Reference image to be warped [B,3,H,W]
+ depth : torch.Tensor
+ Depth map from the original image [B,1,H,W]
+ ref_cam : Camera
+ Camera class for the reference image
+ cam : Camera
+ Camera class for the original image
+ scene_flow : torch.Tensor
+ Scene flow used for warping [B,3,H,W]
+ mode : str
+ Mode for grid sampling
+ padding_mode : str
+ Padding mode for grid sampling
+ align_corners : bool
+ Corner alignment for grid sampling
+
+ Returns
+ -------
+ ref_warped : torch.Tensor
+ Warped reference image in the original frame of reference [B,3,H,W]
+ """
+ assert depth.shape[1] == 1, 'Depth map should have C=1'
+ # Reconstruct world points from target_camera
+ world_points = cam.reconstruct(depth, frame='w', scene_flow=scene_flow)
+ # Project world points onto reference camera
+ ref_coords = ref_cam.project(world_points, frame='w')
+ # View-synthesis given the projected reference points
+ return tf.grid_sample(ref_image, ref_coords, mode=mode,
+ padding_mode=padding_mode, align_corners=align_corners)
+
+
+def view_synthesis_generic(ref_image, depth, ref_cam, cam,
+ mode='bilinear', padding_mode='zeros', align_corners=True,
+ progress=0.0):
+ """
+ Synthesize an image from another plus a depth map.
+
+ Parameters
+ ----------
+ ref_image : torch.Tensor
+ Reference image to be warped [B,3,H,W]
+ depth : torch.Tensor
+ Depth map from the original image [B,1,H,W]
+ ref_cam : Camera
+ Camera class for the reference image
+ cam : Camera
+ Camera class for the original image
+ mode : str
+ Interpolation mode
+ padding_mode : str
+ Padding mode for interpolation
+ align_corners : bool
+ Corner alignment for grid sampling
+ progress : float
+ Training progress (percentage)
+
+ Returns
+ -------
+ ref_warped : torch.Tensor
+ Warped reference image in the original frame of reference [B,3,H,W]
+ """
+ assert depth.shape[1] == 1, 'Depth map should have C=1'
+ # Reconstruct world points from target_camera
+ world_points = cam.reconstruct(depth, frame='w')
+ # Project world points onto reference camera
+ ref_coords = ref_cam.project(world_points, progress=progress, frame='w')
+ # View-synthesis given the projected reference points
+ return tf.grid_sample(ref_image, ref_coords, mode=mode,
+ padding_mode=padding_mode, align_corners=align_corners)
diff --git a/vidar/arch/networks/layers/fsm/pose.py b/vidar/arch/networks/layers/fsm/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfea5b408e729c65fbca6b217655ed0ee908f8c4
--- /dev/null
+++ b/vidar/arch/networks/layers/fsm/pose.py
@@ -0,0 +1,108 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.geometry.pose_utils import invert_pose, pose_vec2mat
+
+
+class Pose:
+ """
+ Pose class, that encapsulates a [4,4] transformation matrix
+ for a specific reference frame
+ """
+ def __init__(self, mat):
+ """
+ Initializes a Pose object.
+
+ Parameters
+ ----------
+ mat : torch.Tensor
+ Transformation matrix [B,4,4]
+ """
+ assert tuple(mat.shape[-2:]) == (4, 4)
+ if mat.dim() == 2:
+ mat = mat.unsqueeze(0)
+ assert mat.dim() == 3
+ self.mat = mat
+
+ def __len__(self):
+ """Batch size of the transformation matrix"""
+ return len(self.mat)
+
+ def __getitem__(self, i):
+ return Pose(self.mat[i].unsqueeze(0)).to(self.device)
+
+ @property
+ def device(self):
+ """Return pose device"""
+ return self.mat.device
+
+ @classmethod
+ def identity(cls, N=1, device=None, dtype=torch.float):
+ """Initializes as a [4,4] identity matrix"""
+ return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1]))
+
+ @classmethod
+ def from_vec(cls, vec, mode):
+ """Initializes from a [B,6] batch vector"""
+ mat = pose_vec2mat(vec, mode) # [B,3,4]
+ pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
+ pose[:, :3, :3] = mat[:, :3, :3]
+ pose[:, :3, -1] = mat[:, :3, -1]
+ return cls(pose)
+
+ @property
+ def shape(self):
+ """Returns the transformation matrix shape"""
+ return self.mat.shape
+
+ def item(self):
+ """Returns the transformation matrix"""
+ return self.mat
+
+ def repeat(self, *args, **kwargs):
+ """Repeats the transformation matrix multiple times"""
+ self.mat = self.mat.repeat(*args, **kwargs)
+ return self
+
+ def inverse(self):
+ """Returns a new Pose that is the inverse of this one"""
+ return Pose(invert_pose(self.mat))
+
+ def to(self, *args, **kwargs):
+ """Moves object to a specific device"""
+ self.mat = self.mat.to(*args, **kwargs)
+ return self
+
+ def transform_pose(self, pose):
+ """Creates a new pose object that compounds this and another one (self * pose)"""
+ assert tuple(pose.shape[-2:]) == (4, 4)
+ return Pose(self.mat.bmm(pose.item()))
+
+ def transform_points(self, points):
+ """Transforms 3D points using this object"""
+ assert 2 < points.dim() <= 4 and points.shape[1] == 3, \
+ 'Wrong dimensions for transform_points'
+ # Determine if input is a grid
+ is_grid = points.dim() == 4
+ # If it's a grid, flatten it
+ points_flat = points.view(points.shape[0], 3, -1) if is_grid else points
+ # Transform points
+ out = self.mat[:, :3, :3].bmm(points_flat) + \
+ self.mat[:, :3, -1].unsqueeze(-1)
+ # Return transformed points
+ return out.view(points.shape) if is_grid else out
+
+ def __matmul__(self, other):
+ """Transforms the input (Pose or 3D points) using this object"""
+ if isinstance(other, Pose):
+ return self.transform_pose(other)
+ elif isinstance(other, torch.Tensor):
+ if other.shape[1] == 3 and other.dim() > 2:
+ assert other.dim() == 3 or other.dim() == 4
+ return self.transform_points(other)
+ else:
+ raise ValueError('Unknown tensor dimensions {}'.format(other.shape))
+ else:
+ raise NotImplementedError()
+
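+# Usage sketch (illustrative values; pose.inverse() relies on invert_pose imported
+# above): poses compose and transform points via the @ operator.
+#   T = torch.eye(4).unsqueeze(0)
+#   T[:, 0, 3] = 1.0                                 # translate 1m along x
+#   pose = Pose(T)
+#   points = torch.zeros(1, 3, 5)                    # five points at the origin
+#   moved = pose @ points                            # [1,3,5], all at x=1
+#   assert torch.allclose(moved[:, 0], torch.ones(1, 5))
+#   back = pose @ pose.inverse()                     # compose back to identity
+#   assert torch.allclose(back.mat, torch.eye(4).unsqueeze(0), atol=1e-6)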
diff --git a/vidar/arch/networks/layers/fsm/utils.py b/vidar/arch/networks/layers/fsm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c6dda22891a38da274cdd3879f248881a2a4f57
--- /dev/null
+++ b/vidar/arch/networks/layers/fsm/utils.py
@@ -0,0 +1,298 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as tfunc
+
+from vidar.utils.types import is_list
+
+
+def coords_from_motion(ref_camera, tgt_depth, tgt_camera, scene_flow=None):
+ """
+ Get coordinates from motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_camera : Camera
+ Reference camera
+ tgt_depth : Tensor
+ Target depth map [B,1,H,W]
+ tgt_camera : Camera
+ Target camera
+ scene_flow : torch.Tensor
+ Optional per-point scene flow [B,3,H,W]
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ """
+ # If there are multiple reference cameras, iterate for each
+ if is_list(ref_camera):
+ return [coords_from_motion(camera, tgt_depth, tgt_camera, scene_flow)
+ for camera in ref_camera]
+ # If there are multiple depth maps, iterate for each
+ if is_list(tgt_depth):
+ return [coords_from_motion(ref_camera, depth, tgt_camera, scene_flow)
+ for depth in tgt_depth]
+ # Reconstruct and reproject points to generate warping coordinates
+ world_points = tgt_camera.reconstruct(tgt_depth, frame='w', scene_flow=scene_flow)
+ return ref_camera.project(world_points, frame='w').permute(0, 3, 1, 2).contiguous()
+
+
+def mask_from_coords(coords):
+ """
+ Get overlap mask from coordinates
+
+ Parameters
+ ----------
+ coords : Tensor
+ Warping coordinates [B,2,H,W]
+
+ Returns
+ -------
+ mask : Tensor
+ Overlap mask [B,1,H,W]
+ """
+ # If there are multiple warping coordinates, iterate for each
+ if is_list(coords):
+ return [mask_from_coords(coord) for coord in coords]
+ # Create and return mask
+ b, _, h, w = coords.shape
+ mask = torch.ones((b, 1, h, w), dtype=torch.float32, device=coords.device, requires_grad=False)
+ mask = warp_from_coords(mask, coords, mode='bilinear', padding_mode='zeros', align_corners=True)
+ return mask.bool()
+
+
+def warp_from_coords(tensor, coords, mask=False, mode='bilinear',
+ padding_mode='zeros', align_corners=True):
+ """
+ Warp an image from a coordinate map
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ Input tensor for warping [B,?,H,W]
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ mask : bool
+ Whether the warped tensor is masked for non-overlapping regions
+ mode : str
+ Warping mode
+ padding_mode : str
+ Padding mode
+ align_corners : bool
+ Align corners flag
+
+ Returns
+ -------
+ warp : torch.Tensor
+ Warped tensor [B,?,H,W]
+ """
+ # Sample grid from data with coordinates
+ warp = tfunc.grid_sample(tensor, coords.permute(0, 2, 3, 1).contiguous(),
+ mode=mode, padding_mode=padding_mode,
+ align_corners=align_corners)
+ # If masking
+ if mask:
+ mask = torch.ones_like(tensor, requires_grad=False)
+ mask = tfunc.grid_sample(mask, coords.permute(0, 2, 3, 1).contiguous(),
+ mode=mode, padding_mode='zeros', align_corners=align_corners)
+ warp = warp * (mask >= 1.0).detach()
+ # Returned warped tensor
+ return warp
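+
+# Usage sketch (illustrative sizes; relies on torch.nn.functional.affine_grid to
+# build an identity grid): warping with identity coordinates reproduces the input.
+#   b, c, h, w = 1, 3, 6, 8
+#   image = torch.rand(b, c, h, w)
+#   theta = torch.eye(2, 3).unsqueeze(0)             # identity affine transform
+#   coords = tfunc.affine_grid(theta, (b, c, h, w), align_corners=True)
+#   coords = coords.permute(0, 3, 1, 2)              # [B,2,H,W], as expected here
+#   warped = warp_from_coords(image, coords, align_corners=True)
+#   assert torch.allclose(warped, image, atol=1e-5)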
+
+
+def filter_dict(dictionary, keywords):
+ """
+ Returns only the keywords that are part of a dictionary
+
+ Parameters
+ ----------
+ dictionary : dict
+ Dictionary for filtering
+ keywords : list of str
+ Keywords that will be filtered
+
+ Returns
+ -------
+ keywords : list of str
+ List containing the keywords that are keys in dictionary
+ """
+ return [key for key in keywords if key in dictionary]
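+
+# Usage sketch (illustrative keys): only keywords present in the dictionary
+# survive, preserving the order of the keyword list.
+#   batch = {'rgb': None, 'intrinsics': None}
+#   assert filter_dict(batch, ['rgb', 'input_depth', 'intrinsics']) == ['rgb', 'intrinsics']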
+
+
+def merge_outputs(*outputs):
+ """
+ Merges model outputs for logging
+
+ Parameters
+ ----------
+ outputs : tuple of dict
+ Outputs to be merged
+
+ Returns
+ -------
+ output : dict
+ Dictionary with a "metrics" key containing a dictionary with various metrics and
+ all other keys that are not "loss" (it is handled differently).
+ """
+ ignore = ['loss'] # Keys to ignore
+ combine = ['metrics'] # Keys to combine
+ merge = {key: {} for key in combine}
+ for output in outputs:
+ # Iterate over all keys
+ for key, val in output.items():
+ # Combine these keys
+ if key in combine:
+ for sub_key, sub_val in output[key].items():
+ assert sub_key not in merge[key].keys(), \
+ 'Combining duplicated key {} to {}'.format(sub_key, key)
+ merge[key][sub_key] = sub_val
+ # Ignore these keys
+ elif key not in ignore:
+ assert key not in merge.keys(), \
+ 'Adding duplicated key {}'.format(key)
+ merge[key] = val
+ return merge
+
+
+def flip_batch_input(batch):
+ """
+ Flip batch input information (copies data first)
+
+ Parameters
+ ----------
+ batch : dict
+ Batch information
+
+ Returns
+ -------
+ batch : dict
+ Flipped batch
+ """
+ # Flip images and input depth
+ for key in filter_dict(batch, [
+ 'rgb', 'input_depth'
+ ]):
+ batch[key] = flip_lr(batch[key])
+ # Flip context images
+ for key in filter_dict(batch, [
+ 'rgb_context',
+ ]):
+ batch[key] = [flip_lr(img) for img in batch[key]]
+ # Flip intrinsics
+ for key in filter_dict(batch, [
+ 'intrinsics'
+ ]):
+ batch[key] = batch[key].clone()
+ batch[key][:, 0, 2] = batch['rgb'].shape[3] - batch[key][:, 0, 2]
+ # Return flipped batch
+ return batch
+
+
+def flip_output(output):
+ """
+ Flip output information
+
+ Parameters
+ ----------
+ output : dict
+ Dictionary of model outputs (e.g. with keys like 'inv_depths' and 'uncertainty')
+
+ Returns
+ -------
+ output : dict
+ Flipped output
+ """
+ # Flip list of tensors
+ for key in filter_dict(output, [
+ 'inv_depths', 'uncertainty', 'logits_semantic'
+ ]):
+ output[key] = [flip_lr(val) for val in output[key]]
+ return output
+
+
+class CameraNormalizer:
+ """
+ Camera normalizer class.
+ Initialized with a desired focal length, and normalizes images to match it.
+ These images can then be unnormalized to recover the original resolution/intrinsics.
+
+ Parameters
+ ----------
+ focal : tuple
+ Focal lengths (fx, fy)
+ """
+ def __init__(self, focal, mode='reflect'):
+ self.focal = focal
+ self.mode = mode
+ self.diffs = []
+
+ def normalize(self, rgb, intrinsics):
+ """
+ Normalize input image
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Input image [B,3,H,W]
+ intrinsics : torch.Tensor
+ Input intrinsics [B,3,3]
+
+ Returns
+ -------
+ rgb_pad : torch.Tensor
+ Normalized image with padding [B,3,H,W]
+ """
+ rgb_pad = []
+ self.diffs.clear()
+ # Process each image independently
+ for i in range(len(rgb)):
+ rgb_i = rgb[i].unsqueeze(0)
+ intrinsics_i = intrinsics[i]
+ wh_orig = rgb_i.shape[2:]
+ # Get resize ratio
+ ratio = [float(self.focal[1] / intrinsics_i[1, 1]),
+ float(self.focal[0] / intrinsics_i[0, 0])]
+ wh_norm = [int(o * r) for o, r in zip(wh_orig, ratio)]
+ # Resize image
+ rgb_i_norm = torch.nn.functional.interpolate(
+ rgb_i, size=wh_norm, mode='bilinear', align_corners=True)
+ # Pad image
+ diff = [int(o - n) for o, n in zip(wh_orig, wh_norm)]
+ rgb_i_pad = torch.nn.functional.pad(
+ rgb_i_norm, pad=[diff[1] // 2, (diff[1] + 1) // 2,
+ diff[0] // 2, (diff[0] + 1) // 2], mode=self.mode)
+ rgb_pad.append(rgb_i_pad)
+ self.diffs.append(diff)
+ # Return concatenation of all images
+ return torch.cat(rgb_pad, 0)
+
+ def unormalize(self, rgb):
+ """
+ Unnormalize image, undoing the previous normalization.
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Normalized image with padding [B,3,H,W]
+
+ Returns
+ -------
+ orig_rgb : torch.Tensor
+ Original image
+ """
+ # If it's a list, unnormalize each one
+ if is_list(rgb):
+ return [self.unormalize(r) for r in rgb]
+ rgb_orig = []
+ hw = rgb.shape[2:]
+ for i in range(len(rgb)):
+ rgb_i = rgb[i].unsqueeze(0)
+ diff_i = self.diffs[i]
+ rgb_i = rgb_i[:, :,
+ (diff_i[0] // 2): (hw[0] - (diff_i[0] + 1) // 2),
+ (diff_i[1] // 2): (hw[1] - (diff_i[1] + 1) // 2)]
+ rgb_i_orig = torch.nn.functional.interpolate(
+ rgb_i, size=hw, mode='bilinear', align_corners=True)
+ rgb_orig.append(rgb_i_orig)
+ return torch.cat(rgb_orig, 0)
\ No newline at end of file
diff --git a/vidar/arch/networks/layers/inits.py b/vidar/arch/networks/layers/inits.py
new file mode 100644
index 0000000000000000000000000000000000000000..5438597d5bd8145769b2539b2602f51d75e612c5
--- /dev/null
+++ b/vidar/arch/networks/layers/inits.py
@@ -0,0 +1,14 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+def weights_init_xavier(m):
+ """Xavier weight initialization"""
+ if isinstance(m, nn.Conv2d):
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ torch.nn.init.zeros_(m.bias)
+
+
diff --git a/vidar/arch/networks/layers/packnet/__init__.py b/vidar/arch/networks/layers/packnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/arch/networks/layers/packnet/packnet.py b/vidar/arch/networks/layers/packnet/packnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14fa5fa9bbf4f069d1756ed2dcb57b8b33fac20
--- /dev/null
+++ b/vidar/arch/networks/layers/packnet/packnet.py
@@ -0,0 +1,288 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+
+class Conv2D(nn.Module):
+ """
+ 2D convolution with GroupNorm and ELU
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ kernel_size : int
+ Kernel size
+ stride : int
+ Stride
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.conv_base = nn.Conv2d(
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride)
+ self.pad = nn.ConstantPad2d([kernel_size // 2] * 4, value=0)
+ self.normalize = torch.nn.GroupNorm(16, out_channels)
+ self.activ = nn.ELU(inplace=True)
+
+ def forward(self, x):
+ """Runs the Conv2D layer."""
+ x = self.conv_base(self.pad(x))
+ return self.activ(self.normalize(x))
+
+
+class ResidualConv(nn.Module):
+ """2D Convolutional residual block with GroupNorm and ELU"""
+ def __init__(self, in_channels, out_channels, stride, dropout=None):
+ """
+ Initializes a ResidualConv object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ stride : int
+ Stride
+ dropout : float
+ Dropout value
+ """
+ super().__init__()
+ self.conv1 = Conv2D(in_channels, out_channels, 3, stride)
+ self.conv2 = Conv2D(out_channels, out_channels, 3, 1)
+ self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
+ self.normalize = torch.nn.GroupNorm(16, out_channels)
+ self.activ = nn.ELU(inplace=True)
+
+ if dropout:
+ # self.dropout = nn.Dropout(dropout)
+ self.conv3 = nn.Sequential(self.conv3, nn.Dropout2d(dropout))
+ else:
+ self.dropout = None
+
+ def forward(self, x):
+ """Runs the ResidualConv layer."""
+ x_out = self.conv1(x)
+ x_out = self.conv2(x_out)
+ shortcut = self.conv3(x)
+ # if self.dropout:
+ # shortcut = self.dropout(shortcut)
+ return self.activ(self.normalize(x_out + shortcut))
+
+
+def ResidualBlock(in_channels, out_channels, num_blocks, stride, dropout=None):
+ """
+ Returns a ResidualBlock with various ResidualConv layers.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ num_blocks : int
+ Number of residual blocks
+ stride : int
+ Stride
+ dropout : float
+ Dropout value
+ """
+ layers = [ResidualConv(in_channels, out_channels, stride, dropout=dropout)]
+ for i in range(1, num_blocks):
+ layers.append(ResidualConv(out_channels, out_channels, 1, dropout=dropout))
+ return nn.Sequential(*layers)
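+
+# Shape sketch (illustrative sizes): only the first ResidualConv applies the
+# stride, so a stride-2 block halves the spatial resolution exactly once.
+#   block = ResidualBlock(16, 32, num_blocks=2, stride=2)
+#   x = torch.randn(1, 16, 64, 64)
+#   assert block(x).shape == (1, 32, 32, 32)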
+
+
+class InvDepth(nn.Module):
+ """Inverse depth layer"""
+ def __init__(self, in_channels, out_channels=1, min_depth=0.5):
+ """
+ Initializes an InvDepth object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ min_depth : float
+ Minimum depth value to calculate
+ """
+ super().__init__()
+ self.min_depth = min_depth
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1)
+ self.pad = nn.ConstantPad2d([1] * 4, value=0)
+ self.activ = nn.Sigmoid()
+
+ def forward(self, x):
+ """Runs the InvDepth layer."""
+ x = self.conv1(self.pad(x))
+ return self.activ(x) / self.min_depth
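+
+# Range sketch (illustrative sizes): sigmoid lies in (0,1), so the predicted
+# inverse depth lies in (0, 1/min_depth), i.e. depth = 1/inv_depth > min_depth.
+#   layer = InvDepth(in_channels=8, min_depth=0.5)
+#   inv_depth = layer(torch.randn(1, 8, 16, 16))
+#   assert inv_depth.shape == (1, 1, 16, 16) and float(inv_depth.max()) < 2.0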
+
+
+def packing(x, r=2):
+ """
+ Takes a [B,C,H,W] tensor and returns a [B,(r^2)C,H/r,W/r] tensor, by concatenating
+ neighbor spatial pixels as extra channels. It is the inverse of nn.PixelShuffle
+ (if you apply both sequentially you should get the same tensor)
+
+ Parameters
+ ----------
+ x : torch.Tensor [B,C,H,W]
+ Input tensor
+ r : int
+ Packing ratio
+
+ Returns
+ -------
+ out : torch.Tensor [B,(r^2)C,H/r,W/r]
+ Packed tensor
+ """
+ b, c, h, w = x.shape
+ out_channel = c * (r ** 2)
+ out_h, out_w = h // r, w // r
+ x = x.contiguous().view(b, c, out_h, r, out_w, r)
+ return x.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, out_channel, out_h, out_w)
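+
+# Inverse-check sketch (illustrative sizes): packing followed by nn.PixelShuffle
+# with the same ratio recovers the original tensor, as stated in the docstring.
+#   x = torch.randn(2, 3, 8, 8)
+#   assert torch.allclose(nn.PixelShuffle(2)(packing(x, r=2)), x)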
+
+
+class PackLayerConv2d(nn.Module):
+ """
+ Packing layer with 2d convolutions. Takes a [B,C,H,W] tensor, packs it
+ into [B,(r^2)C,H/r,W/r] and then convolves it to produce [B,C,H/r,W/r].
+ """
+ def __init__(self, in_channels, kernel_size, r=2):
+ """
+ Initializes a PackLayerConv2d object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ kernel_size : int
+ Kernel size
+ r : int
+ Packing ratio
+ """
+ super().__init__()
+ self.conv = Conv2D(in_channels * (r ** 2), in_channels, kernel_size, 1)
+ self.pack = partial(packing, r=r)
+
+ def forward(self, x):
+ """Runs the PackLayerConv2d layer."""
+ x = self.pack(x)
+ x = self.conv(x)
+ return x
+
+
+class UnpackLayerConv2d(nn.Module):
+ """
+ Unpacking layer with 2d convolutions. Takes a [B,C,H,W] tensor, convolves it
+ to produce [B,(r^2)C,H,W] and then unpacks it to produce [B,C,rH,rW].
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, r=2):
+ """
+ Initializes a UnpackLayerConv2d object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ kernel_size : int
+ Kernel size
+ r : int
+ Packing ratio
+ """
+ super().__init__()
+ self.conv = Conv2D(in_channels, out_channels * (r ** 2), kernel_size, 1)
+ self.unpack = nn.PixelShuffle(r)
+
+ def forward(self, x):
+ """Runs the UnpackLayerConv2d layer."""
+ x = self.conv(x)
+ x = self.unpack(x)
+ return x
+
+
+class PackLayerConv3d(nn.Module):
+ """
+ Packing layer with 3d convolutions. Takes a [B,C,H,W] tensor, packs it
+ into [B,(r^2)C,H/r,W/r] and then convolves it to produce [B,C,H/r,W/r].
+ """
+ def __init__(self, in_channels, kernel_size, r=2, d=8):
+ """
+ Initializes a PackLayerConv3d object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ kernel_size : int
+ Kernel size
+ r : int
+ Packing ratio
+ d : int
+ Number of 3D features
+ """
+ super().__init__()
+ self.conv = Conv2D(in_channels * (r ** 2) * d, in_channels, kernel_size, 1)
+ self.pack = partial(packing, r=r)
+ self.conv3d = nn.Conv3d(1, d, kernel_size=(3, 3, 3),
+ stride=(1, 1, 1), padding=(1, 1, 1))
+
+ def forward(self, x):
+ """Runs the PackLayerConv3d layer."""
+ x = self.pack(x)
+ x = x.unsqueeze(1)
+ x = self.conv3d(x)
+ b, c, d, h, w = x.shape
+ x = x.view(b, c * d, h, w)
+ x = self.conv(x)
+ return x
+
+
+class UnpackLayerConv3d(nn.Module):
+ """
+ Unpacking layer with 3d convolutions. Takes a [B,C,H,W] tensor, convolves it
+ to produce [B,(r^2)C,H,W] and then unpacks it to produce [B,C,rH,rW].
+ """
+ def __init__(self, in_channels, out_channels, kernel_size, r=2, d=8):
+ """
+ Initializes a UnpackLayerConv3d object.
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels
+ out_channels : int
+ Number of output channels
+ kernel_size : int
+ Kernel size
+ r : int
+ Packing ratio
+ d : int
+ Number of 3D features
+ """
+ super().__init__()
+ self.conv = Conv2D(in_channels, out_channels * (r ** 2) // d, kernel_size, 1)
+ self.unpack = nn.PixelShuffle(r)
+ self.conv3d = nn.Conv3d(1, d, kernel_size=(3, 3, 3),
+ stride=(1, 1, 1), padding=(1, 1, 1))
+
+ def forward(self, x):
+ """Runs the UnpackLayerConv3d layer."""
+ x = self.conv(x)
+ x = x.unsqueeze(1)
+ x = self.conv3d(x)
+ b, c, d, h, w = x.shape
+ x = x.view(b, c * d, h, w)
+ x = self.unpack(x)
+ return x
+
diff --git a/vidar/arch/networks/layers/resnet_encoder.py b/vidar/arch/networks/layers/resnet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f69984cc5586665e82ab40ea1c2fdf8ae357b40
--- /dev/null
+++ b/vidar/arch/networks/layers/resnet_encoder.py
@@ -0,0 +1,46 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch.nn as nn
+import torchvision.models as models
+
+
+class ResNetEncoder(nn.Module):
+ """Creates a ResNet encoder with different parameters"""
+ def __init__(self, name):
+ super().__init__()
+ if name == 'densenet121':
+ self.base_model = models.densenet121(pretrained=True).features
+ self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5']
+ self.feat_out_channels = [64, 64, 128, 256, 1024]
+ elif name == 'densenet161':
+ self.base_model = models.densenet161(pretrained=True).features
+ self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5']
+ self.feat_out_channels = [96, 96, 192, 384, 2208]
+ elif name == 'resnet50':
+ self.base_model = models.resnet50(pretrained=True)
+ self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
+ self.feat_out_channels = [64, 256, 512, 1024, 2048]
+ elif name == 'resnet101':
+ self.base_model = models.resnet101(pretrained=True)
+ self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
+ self.feat_out_channels = [64, 256, 512, 1024, 2048]
+ elif name == 'resnext50':
+ self.base_model = models.resnext50_32x4d(pretrained=True)
+ self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
+ self.feat_out_channels = [64, 256, 512, 1024, 2048]
+ elif name == 'resnext101':
+ self.base_model = models.resnext101_32x8d(pretrained=True)
+ self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
+ self.feat_out_channels = [64, 256, 512, 1024, 2048]
+ else:
+ raise NotImplementedError('Unsupported encoder: {}'.format(name))
+
+ def forward(self, x):
+ features, skips = [x], [x]
+ for key, val in self.base_model._modules.items():
+ if not any(name in key for name in ['fc', 'avgpool']):
+ feature = val(features[-1])
+ features.append(feature)
+ if any(name in key for name in self.feat_names):
+ skips.append(feature)
+ return skips
diff --git a/vidar/arch/networks/perceiver/DefineNet.py b/vidar/arch/networks/perceiver/DefineNet.py
new file mode 100755
index 0000000000000000000000000000000000000000..8e087a8293072816219e66b595b3a5f9b1c56e15
--- /dev/null
+++ b/vidar/arch/networks/perceiver/DefineNet.py
@@ -0,0 +1,371 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import abc
+from transformers import PerceiverModel, PerceiverConfig
+
+from vidar.arch.networks.perceiver.externals.modeling_perceiver import PerceiverDepthDecoder, PerceiverRGBDecoder, build_position_encoding
+from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth
+from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder
+from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder
+from vidar.utils.config import Config
+from vidar.utils.networks import freeze_layers_and_norms
+from vidar.utils.tensor import interpolate
+from vidar.utils.types import is_int
+
+
+class DownSampleRGB(nn.Module):
+ def __init__(self, out_dim):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, out_dim, kernel_size=7, stride=2, padding=3)
+ self.norm = torch.nn.BatchNorm2d(out_dim)
+ self.actv = torch.nn.ReLU()
+ self.pool = torch.nn.MaxPool2d(2, stride=2)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.actv(x)
+ x = self.pool(x)
+ return x
+
+
+class DefineNet(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.tasks = cfg.tasks
+ self.to_world = cfg.to_world
+ self.depth_range = cfg.depth_range
+ self.rgb_feat_dim = cfg.rgb_feat_dim
+ self.rgb_feat_type = cfg.rgb_feat_type
+ self.encoder_with_rgb = cfg.encoder_with_rgb
+ self.decoder_with_rgb = cfg.decoder_with_rgb
+ self.output_mode = cfg.output_mode
+ self.sample_encoding_rays = cfg.sample_encoding_rays
+ self.with_monodepth = cfg.with_monodepth
+
+ self.upsample_convex = cfg.upsample_convex
+ self.downsample_encoder = cfg.downsample_encoder
+ self.downsample_decoder = cfg.downsample_decoder
+
+ self.image_shape = [s // self.downsample_encoder for s in cfg.image_shape]
+
+ self.fourier_encoding_orig, _ = build_position_encoding(
+ position_encoding_type='fourier',
+ fourier_position_encoding_kwargs={
+ 'num_bands': cfg.num_bands_orig,
+ 'max_resolution': [cfg.max_resolution_orig] * 3,
+ 'concat_pos': True,
+ 'sine_only': False,
+ }
+ )
+
+ self.fourier_encoding_dirs, _ = build_position_encoding(
+ position_encoding_type='fourier',
+ fourier_position_encoding_kwargs={
+ 'num_bands': cfg.num_bands_dirs,
+ 'max_resolution': [cfg.num_bands_dirs] * 3,
+ 'concat_pos': True,
+ 'sine_only': False,
+ }
+ )
+
+ tot_encoder = self.fourier_encoding_orig.output_size() + \
+ self.fourier_encoding_dirs.output_size()
+ if self.encoder_with_rgb:
+ tot_encoder += self.rgb_feat_dim
+
+ tot_decoder = self.fourier_encoding_orig.output_size() + \
+ self.fourier_encoding_dirs.output_size()
+ if self.decoder_with_rgb:
+ tot_decoder += self.rgb_feat_dim
+
+ tot_decoder_depth = tot_decoder
+ tot_decoder_rgb = tot_decoder
+
+ self.config = PerceiverConfig(
+ train_size=self.image_shape,
+ d_latents=cfg.d_latents,
+ d_model=tot_encoder,
+ num_latents=cfg.num_latents,
+ hidden_act='gelu',
+ hidden_dropout_prob=cfg.hidden_dropout_prob,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ num_blocks=1,
+ num_cross_attention_heads=cfg.num_cross_attention_heads,
+ num_self_attends_per_block=cfg.num_self_attends_per_block,
+ num_self_attention_heads=cfg.num_self_attention_heads,
+ qk_channels=None,
+ v_channels=None,
+ )
+
+ if 'depth' in self.tasks:
+ self.decoder = PerceiverDepthDecoder(
+ self.config,
+ num_channels=tot_decoder_depth,
+ use_query_residual=False,
+ output_num_channels=1,
+ position_encoding_type="none",
+ min_depth=self.depth_range[0],
+ max_depth=self.depth_range[1],
+ num_heads=cfg.decoder_num_heads,
+ upsample_mode=cfg.upsample_convex,
+ upsample_value=cfg.downsample_decoder,
+ output_mode=cfg.output_mode
+ )
+ if 'rgb' in self.tasks:
+ self.decoder_rgb = PerceiverRGBDecoder(
+ self.config,
+ num_channels=tot_decoder_rgb,
+ use_query_residual=False,
+ output_num_channels=3,
+ position_encoding_type="none",
+ num_heads=cfg.decoder_num_heads,
+ upsample_mode=cfg.upsample_convex,
+ upsample_value=cfg.downsample_decoder,
+ )
+
+ self.model = PerceiverModel(
+ self.config,
+ )
+
+ if self.rgb_feat_type == 'convnet':
+ self.feature = DownSampleRGB(out_dim=self.rgb_feat_dim)
+ elif self.rgb_feat_type in ['resnet', 'resnet_all', 'resnet_all_rgb']:
+ self.feature = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1))
+
+ if self.with_monodepth:
+ self.mono_encoder = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1))
+ self.mono_decoder = DepthDecoder(Config(
+ num_scales=4, use_skips=True, num_ch_enc=self.feature.num_ch_enc,
+ num_ch_out=1, activation='sigmoid',
+ ))
+ self.sigmoid_to_depth = SigmoidToInvDepth(
+ min_depth=self.depth_range[0], max_depth=self.depth_range[1], return_depth=True)
+
+ def get_rgb_feat(self, rgb):
+ if self.rgb_feat_type == 'convnet':
+ return {
+ 'feat': self.feature(rgb)
+ }
+ elif self.rgb_feat_type == 'resnet':
+ return {
+ 'feat': self.feature(rgb)[1]
+ }
+ elif self.rgb_feat_type.startswith('resnet_all'):
+ all_feats = self.feature(rgb)
+ feats = all_feats[1:]
+ for i in range(1, len(feats)):
+ feats[i] = interpolate(
+ feats[i], size=feats[0], scale_factor=None, mode='bilinear', align_corners=True)
+ if self.rgb_feat_type.endswith('rgb'):
+ feats = feats + [interpolate(
+ rgb, size=feats[0], scale_factor=None, mode='bilinear', align_corners=True)]
+ feat = torch.cat(feats, 1)
+ return {
+ 'all_feats': all_feats,
+ 'feat': feat
+ }
+
+ def run_monodepth(self, rgb, freeze):
+ freeze_layers_and_norms(self.mono_encoder, flag_freeze=freeze)
+ freeze_layers_and_norms(self.mono_decoder, flag_freeze=freeze)
+ mono_features = self.mono_encoder(rgb)
+ mono_output = self.mono_decoder(mono_features)
+ sigmoids = [mono_output[('output', i)] for i in range(1)]
+ return self.sigmoid_to_depth(sigmoids)[0]
+
+ def embeddings(self, data, sources, downsample):
+
+ if 'rgb' in sources:
+ assert 'rgb' in data[0].keys()
+ b = [datum['rgb'].shape[0] for datum in data]
+ rgb = torch.cat([datum['rgb'] for datum in data], 0)
+ output_feats = self.get_rgb_feat(rgb)
+ feats = torch.split(output_feats['feat'], b)
+ for i in range(len(data)):
+ data[i]['feat'] = feats[i]
+
+ if self.with_monodepth:
+ depth = self.run_monodepth(rgb, freeze=False)
+ depth = torch.split(depth, b)
+ for i in range(len(data)):
+ data[i]['depth_mono'] = depth[i]
+
+ encodings = []
+ for datum in data:
+
+ encoding = OrderedDict()
+
+ if 'cam' in sources:
+ assert 'cam' in data[0].keys()
+
+ cam = datum['cam'].scaled(1. / downsample)
+ orig = cam.get_origin(flatten=True)
+
+ if self.to_world:
+ dirs = cam.get_viewdirs(normalize=True, flatten=True, to_world=True)
+ else:
+ dirs = cam.no_translation().get_viewdirs(normalize=True, flatten=True, to_world=True)
+
+ orig_encodings = self.fourier_encoding_orig(
+ index_dims=None, pos=orig, batch_size=orig.shape[0], device=orig.device)
+ dirs_encodings = self.fourier_encoding_dirs(
+ index_dims=None, pos=dirs, batch_size=dirs.shape[0], device=dirs.device)
+
+ encoding['cam'] = torch.cat([orig_encodings, dirs_encodings], -1)
+
+ if 'rgb' in sources:
+ rgb = datum['feat']
+ rgb_flat = rgb.view(*rgb.shape[:-2], -1).permute(0, 2, 1)
+ encoding['rgb'] = rgb_flat
+
+ encoding['all'] = torch.cat([val for val in encoding.values()], -1)
+ encodings.append(encoding)
+
+ return encodings
+
+ @staticmethod
+ def sample_decoder(data, embeddings, field, sample_queries, filter_invalid):
+
+ query_idx = []
+
+ if filter_invalid:
+ tot_min = []
+
+ for i in range(len(embeddings)):
+ for b in range(data[i]['rgb'].shape[0]):
+ tot_min.append((data[i]['rgb'][b].mean(0) >= 0).sum())
+ tot_min = min(tot_min)
+
+ tot = embeddings[0][field][0].shape[0]
+ tot = int(sample_queries * tot)
+ tot = min([tot, tot_min])
+
+ for i in range(len(embeddings)):
+ idx = []
+
+ for b in range(data[i]['rgb'].shape[0]):
+ if filter_invalid:
+
+ valid = data[i]['rgb'][b].mean(0, keepdim=True) >= 0
+ valid = valid.view(1, -1).permute(1, 0)
+
+ num = embeddings[i][field][0].shape[0]
+ all_idx = torch.arange(num, device=valid.device).unsqueeze(1)
+ valid_idx = all_idx[valid]
+
+ num = valid_idx.shape[0]
+ idx_i = torch.randperm(num)[tot:]
+ valid[valid_idx[idx_i]] = 0
+ idx_i = all_idx[valid]
+
+ else:
+
+ num = embeddings[i][field][0].shape[0]
+ tot = int(sample_queries * num)
+ idx_i = torch.randperm(num)[:tot]
+
+ idx.append(idx_i)
+
+ idx = torch.stack(idx, 0)
+ embeddings[i][field] = torch.stack(
+ [embeddings[i][field][b][idx[b]] for b in range(idx.shape[0])], 0)
+
+ query_idx.append(idx)
+
+ return query_idx, embeddings
+
+ def forward(self, encode_data, decode_data=None,
+ sample_queries=0, filter_invalid=False):
+
+ encode_field = 'all' if self.encoder_with_rgb else 'cam'
+ decode_field = 'all' if self.decoder_with_rgb else 'cam'
+
+ encode_sources = ['rgb', 'cam']
+ decode_sources = ['cam']
+
+ shape = encode_data[0]['cam'].hw
+
+ output = {}
+
+ encode_dict = self.encode(
+ data=encode_data, field=encode_field, sources=encode_sources
+ )
+
+ if 'depth_mono' in encode_data[0].keys():
+ output['depth_mono'] = [datum['depth_mono'] for datum in encode_data]
+
+ decode_embeddings = encode_dict['embeddings'] if decode_data is None else None
+
+ decode_dict = self.decode(
+ latent=encode_dict['latent'], shape=shape,
+ data=decode_data, embeddings=decode_embeddings,
+ field=decode_field, sources=decode_sources,
+ sample_queries=sample_queries, filter_invalid=filter_invalid
+ )
+
+ output.update(decode_dict['output'])
+
+ return {
+ 'output': output,
+ 'encode_embeddings': encode_dict['embeddings'],
+ 'decode_embeddings': decode_dict['embeddings'],
+ 'latent': encode_dict['latent'],
+ }
+
+ def encode(self, field, sources, data=None, embeddings=None):
+ assert data is not None or embeddings is not None
+ assert data is None or embeddings is None
+
+ if embeddings is None:
+ embeddings = self.embeddings(data, sources=sources, downsample=self.downsample_encoder)
+
+ all_embeddings = torch.cat([emb[field] for emb in embeddings], 1)
+
+ if self.training and self.sample_encoding_rays > 0:
+ tot = self.sample_encoding_rays if is_int(self.sample_encoding_rays) \
+ else int(self.sample_encoding_rays * all_embeddings.shape[1])
+ all_embeddings = torch.stack([all_embeddings[i, torch.randperm(all_embeddings.shape[1])[:tot], :]
+ for i in range(all_embeddings.shape[0])], 0)
+
+ return {
+ 'embeddings': embeddings,
+ 'latent': self.model(inputs=all_embeddings).last_hidden_state,
+ }
+
+ def decode(self, latent, field, sources=None, data=None, embeddings=None, shape=None,
+ sample_queries=0, filter_invalid=False):
+ assert data is not None or embeddings is not None
+ assert data is None or embeddings is None
+
+ if embeddings is None:
+ shape = data[0]['cam'].hw
+ shape = [s // self.downsample_decoder for s in shape]
+ embeddings = self.embeddings(data, sources=sources, downsample=self.downsample_decoder)
+
+ output = {}
+
+ if self.training and (sample_queries > 0): # or filter_invalid):
+ output['query_idx'], embeddings = self.sample_decoder(
+ data, embeddings, field, sample_queries, filter_invalid)
+ shape = None
+
+ if 'rgb' in self.tasks:
+ output['rgb'] = [
+ self.decoder_rgb(query=emb[field], z=latent, shape=shape).logits
+ for emb in embeddings]
+
+ if 'depth' in self.tasks:
+ output['depth'] = [
+ self.decoder(query=emb[field], z=latent, shape=shape).logits
+ for emb in embeddings]
+
+ return {
+ 'embeddings': embeddings,
+ 'output': output,
+ }
diff --git a/vidar/arch/networks/perceiver/MLPNet.py b/vidar/arch/networks/perceiver/MLPNet.py
new file mode 100755
index 0000000000000000000000000000000000000000..00543ad989cc8bb6d0c2a9655d23da559e90abd2
--- /dev/null
+++ b/vidar/arch/networks/perceiver/MLPNet.py
@@ -0,0 +1,504 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from vidar.arch.networks.layers.define.decoders.camera import CameraDecoder
+from vidar.arch.networks.layers.define.decoders.depth import DepthDecoder
+from vidar.arch.networks.layers.define.decoders.multiview import MultiviewDecoder
+from vidar.arch.networks.layers.define.decoders.normals import NormalsDecoder
+from vidar.arch.networks.layers.define.decoders.rgb import RGBDecoder
+from vidar.arch.networks.layers.define.decoders.semantic import SemanticDecoder
+from vidar.arch.networks.layers.define.decoders.volumetric import VolumetricDecoder
+from vidar.arch.networks.layers.define.embeddings.camera import CameraEmbeddings
+from vidar.arch.networks.layers.define.embeddings.image import ImageEmbeddings
+from vidar.arch.networks.layers.define.embeddings.multiview import MultiviewEmbeddings
+from vidar.arch.networks.layers.define.embeddings.projection import ProjectionEmbeddings
+from vidar.arch.networks.layers.define.embeddings.volumetric import VolumetricEmbeddings
+from vidar.arch.networks.layers.define.perceiver.model import PerceiverModel
+from vidar.utils.data import make_list, str_not_in
+from vidar.utils.networks import freeze_layers_and_norms
+from vidar.utils.types import is_int, is_dict, is_tensor, is_list
+
+
+def enable_dropout(model):
+ """ Function to enable the dropout layers during test-time """
+ for m in model.modules():
+ if m.__class__.__name__.startswith('Dropout'):
+ m.train()
+
+
+class MLPNet(nn.Module):
+
+ DECODER_CLASSES = {
+ 'rgb': RGBDecoder,
+ 'depth': DepthDecoder,
+ 'normals': NormalsDecoder,
+ 'semantic': SemanticDecoder,
+ 'camera': CameraDecoder,
+ 'volumetric': VolumetricDecoder,
+ 'multiview': MultiviewDecoder,
+ }
+
+ EMBEDDING_CLASSES = {
+ 'image': ImageEmbeddings,
+ 'camera': CameraEmbeddings,
+ 'volumetric': VolumetricEmbeddings,
+ 'multiview': MultiviewEmbeddings,
+ 'projection': ProjectionEmbeddings,
+ }
+
+ def __init__(self, cfg):
+ super().__init__()
+
+ # VARIATIONAL
+
+ self.is_variational = cfg.has('variational')
+ self.variational_params = None if not self.is_variational else {
+ 'kld_weight': cfg.variational.kld_weight,
+ 'encode_vae': cfg.variational.encode_vae,
+ 'soft_mask_weight': cfg.variational.has('soft_mask_weight', 1.0),
+ }
+
+ # Parameters
+
+ self.encode_sources = cfg.encode_sources
+ self.decode_sources = cfg.decode_sources
+ self.extra_sources = cfg.has('extra_sources', [])
+ self.encode_cameras = cfg.has('encode_cameras', 'mod')
+ self.latent_grid = cfg.latent.has('mode') and cfg.latent.mode == 'grid'
+ self.latent_grid_dim = cfg.latent.has('grid_dim', 0)
+
+ self.use_image_embeddings = cfg.encoder.has('use_image_embeddings', True)
+
+ if len(self.encode_sources) == 0 or not is_list(self.encode_sources[0]):
+ self.encode_sources = [self.encode_sources]
+
+ # Downsample
+
+ self.downsample_encoder = cfg.downsample_encoder
+ self.downsample_decoder = cfg.downsample_decoder
+
+ # Create embeddings
+
+ self.embeddings = nn.ModuleDict()
+ for key in cfg.embeddings.keys():
+ mode = key if '_' not in key else key.split('_')[0]
+ self.embeddings[key] = self.EMBEDDING_CLASSES[mode](cfg.embeddings.dict[key])
+
+ # Parse shared decoder parameters
+
+ if cfg.decoders.has('shared'):
+ for task in cfg.decoders.keys():
+ if task != 'shared':
+ for key, val in cfg.decoders.shared.items():
+ if key not in cfg.decoders.dict[task].keys():
+ cfg.decoders.dict[task].dict[key] = val
+
+ # Embeddings dimension
+
+ tot_encoders = [self.total_dimension(sources) for sources in self.encode_sources]
+
+ self.models = nn.ModuleList()
+ for i in range(len(tot_encoders)):
+ cfg.encoder.d_model = tot_encoders[i]
+ is_variational = self.is_variational and self.variational_params['encode_vae'][i]
+ self.models.append(PerceiverModel(cfg, is_variational=is_variational))
+
+ # Decoders [###### MOVED TO AFTER PERCEIVER MODEL ######]
+
+ self.decoders = nn.ModuleDict()
+ for i in range(len(self.decode_sources)):
+ is_variational = self.is_variational and \
+ self.variational_params['encode_vae'][self.decode_sources[i][1]]
+ self.decode_sources[i] += [is_variational]
+ for name, _, _, embeddings, decoders, _ in self.decode_sources:
+ tot_decoder = self.total_dimension(embeddings)
+ for dec in decoders:
+ if dec not in self.decoders.keys():
+ cfg.decoders.dict[dec].name = name
+ cfg.decoders.dict[dec].num_channels = tot_decoder
+ cfg.decoders.dict[dec].d_latents = cfg.latent.dim
+ self.decoders[dec] = self.DECODER_CLASSES[dec.split('_')[0]](cfg.decoders.dict[dec])
+
+ self.bypass_self_attention = cfg.has('bypass_self_attention', True)
+ self.freeze_encoder = cfg.encoder.has('freeze', False)
+ self.really_encode = cfg.encoder.has('really_encode', True)
+
+ @property
+ def tasks(self):
+ return self.decoders.keys()
+
+ def total_dimension(self, sources):
+ if not self.use_image_embeddings and 'image' in sources:
+ sources = list(sources)
+ sources.remove('image')
+ return sum([self.embeddings[source].channels for source in sources if source not in self.extra_sources]) + \
+ self.latent_grid_dim
+
+ def get_embeddings(self, data, camera_mode, sources, downsample, encoded=None, sample=None,
+ previous=None, monodepth=None):
+
+ cam_mode = f'cam_{camera_mode}'
+
+ if 'image' in sources:
+ rgb_dict = {key: val['rgb'] for key, val in data.items()}
+ feats_dict = self.embeddings['image'](rgb_dict)
+ for key in data.keys():
+ data[key]['feats'] = feats_dict[key]
+
+ if monodepth is not None:
+ monodepth['depth_idx'] = {}
+
+ embeddings = OrderedDict()
+ for key, val in data.items():
+
+ embedding = OrderedDict()
+
+ cam_scaled = val[cam_mode].scaled(1. / downsample)
+ cam_scaled2 = None
+
+ embedding['info'] = {}
+
+ if self.training and sample is not None and sample > 0.0:
+ if previous is None:
+ idx, start = self.sample_decoder_idx(sample, cam_scaled.batch_size, cam_scaled.hw)
+ else:
+ idx, start = previous['info'][key]['idx'], previous['info'][key]['start']
+ if start is not None:
+ cam_scaled2 = cam_scaled.offset_start(start).scaled(1. / sample)
+ val['cam_scaled'] = embedding['info']['cam_scaled'] = cam_scaled2
+ else:
+ idx = start = None
+ val['cam_scaled'] = embedding['info']['cam_scaled'] = cam_scaled
+
+ if monodepth is not None:
+ b, _, h, w = monodepth['depth'][key][0].shape
+ depth = monodepth['depth'][key][0].view(b, 1, -1).permute(0, 2, 1)
+ if idx is not None:
+ depth = torch.stack([depth[j][idx[j]] for j in range(len(idx))], 0)
+ # from vidar.utils.viz import viz_depth
+ # from vidar.utils.write import write_image
+ # write_image('depth.png', viz_depth(depth1.view(b, 1, 192, 640)[0]))
+ # write_image('depth_sub.png', viz_depth(depth2.view(b, 1, 48, 160)[0]))
+ # import sys
+ # sys.exit()
+ monodepth['depth_idx'][key] = depth
+
+ embedding['info']['idx'] = idx
+ embedding['info']['start'] = start
+
+ for source in sources:
+ if source.startswith('camera'):
+ val['camera'], embedding[source] = self.embeddings[source](
+ cam_scaled, key, idx,
+ meta=val['meta'] if 'meta' in val else None)
+ if start is not None:
+ h, w = cam_scaled2.hw
+ b, n, c = val['camera'].shape
+ val['camera'] = val['camera'].permute(0, 2, 1).reshape(b, c, h, w)
+
+ if 'image' in sources:
+ rgb = val['feats']
+ embedding['image'] = rgb.view(*rgb.shape[:-2], -1).permute(0, 2, 1)
+ if idx is not None:
+ embedding['image'] = torch.stack([
+ embedding['image'][i][idx[i]] for i in range(embedding['image'].shape[0])], 0)
+
+ for source in sources:
+ if source.startswith('volumetric'):
+ embedding['info']['z_samples'], embedding[source] = self.embeddings[source](
+ cam_scaled, key, data, previous, idx)
+
+ for source in sources:
+ if source.startswith('multiview'):
+ if encoded is None:
+ encoded = {'data': data}
+ multiview_previous = previous if previous is not None else monodepth
+ embedding['info']['z_samples'], embedding['info']['xyz'], embedding[source] = \
+ self.embeddings[source](cam_scaled, key, data, encoded, multiview_previous, idx)
+ embedding.pop('image', None)
+
+ for source in sources:
+ if source.startswith('projection'):
+ embedding[source] = self.embeddings[source](
+ cam_scaled2 if cam_scaled2 is not None else cam_scaled,
+ key, encoded, embedding['info'], idx)
+
+ if idx is not None and previous is None:
+ data_key = data[key]
+ for key_gt in data_key.keys():
+ if data_key[key_gt] is not None and not key_gt.startswith('cam'):
+ if is_tensor(data_key[key_gt]) and data_key[key_gt].dim() == 4:
+ data_key[key_gt] = data_key[key_gt].permute(0, 2, 3, 1).reshape(
+ data_key[key_gt].shape[0], -1, data_key[key_gt].shape[1])
+ data_key[key_gt] = torch.stack([
+ data_key[key_gt][i, idx[i]] for i in range(data_key[key_gt].shape[0])], 0)
+ if start is not None:
+ h, w = cam_scaled2.hw
+ b, n, c = data_key[key_gt].shape
+ data_key[key_gt] = data_key[key_gt].permute(0, 2, 1).reshape(b, c, h, w)
+
+ embeddings[key] = embedding
+
+ if monodepth is not None:
+ monodepth.pop('depth_idx')
+
+ return embeddings
+
+ @staticmethod
+ def sample_decoder_idx(sample_decoder, b, hw):
+ n = hw[0] * hw[1]
+ if is_int(sample_decoder):
+ idx, start = [], []
+ for _ in range(b):
+ start_i = torch.randint(0, sample_decoder, (2,))
+ idx_i = torch.arange(0, n).reshape(hw)
+ idx_i = idx_i[start_i[0]::sample_decoder, start_i[1]::sample_decoder].reshape(-1)
+ idx.append(idx_i)
+ start.append(start_i)
+ start = torch.stack(start, 0)
+ else:
+ tot = int(sample_decoder * n)
+ idx = [torch.randperm(n)[:tot] for _ in range(b)]
+ start = None
+ idx = torch.stack(idx, 0)
+ return idx, start
+
+ def bypass_encode(self, data, scene, idx):
+ return {
+ 'embeddings': None,
+ 'latent': self.models[idx].embeddings(batch_size=1, scene=scene) if self.bypass_self_attention else
+ self.models[idx](data=None)['last_hidden_state'],
+ 'data': data,
+ }
+
+ def encode(self, sources=None, data=None, embeddings=None, sample_encoder=0, scene=None):
+ assert data is not None or embeddings is not None
+ assert data is None or embeddings is None
+
+ sources = sources if sources is not None else self.encode_sources
+
+ return [self.single_encode(sources, data, embeddings, sample_encoder, scene, idx=i)
+ for i in range(len(sources))]
+
+ def single_encode(self, sources=None, data=None, embeddings=None, sample_encoder=0, scene=None, idx=0):
+ assert data is not None or embeddings is not None
+ assert data is None or embeddings is None
+
+ # Freeze encoder if requested
+ for model in self.models:
+ freeze_layers_and_norms(model, flag_freeze=self.freeze_encoder)
+
+ # Get default sources if they are not provided
+ sources = sources if sources is not None else self.encode_sources
+ sources = sources[idx]
+
+ camera_mode = self.encode_cameras[idx] if is_list(self.encode_cameras) else self.encode_cameras
+
+ # Don't encode if there is no data or sources to use
+ if len(data) == 0 or len(sources) == 0:
+ return self.bypass_encode(data, scene, idx=idx)
+
+ # Create embeddings if they are not provided
+ if embeddings is None:
+ embeddings = self.get_embeddings(
+ data=data, sources=sources, camera_mode=camera_mode, downsample=self.downsample_encoder)
+ embeddings = {key: torch.cat([val[source] for source in sources if source in val], -1) for key, val in embeddings.items()}
+
+ # Sample embeddings if requested
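+ # (an integer keeps that many embedding positions per batch element; a value in (0, 1)
+ # keeps that fraction of them, drawn at random without replacement)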
+ if self.training and sample_encoder > 0:
+ for key in embeddings.keys():
+ tot = sample_encoder if is_int(sample_encoder) \
+ else int(sample_encoder * embeddings[key].shape[1])
+ embeddings[key] = torch.stack([
+ embeddings[key][i, torch.randperm(embeddings[key].shape[1])[:tot], :]
+ for i in range(embeddings[key].shape[0])], 0)
+
+ # Don't encode if not requested
+ if not self.really_encode:
+ return self.bypass_encode(data, scene, idx=idx)
+
+ # Encode embeddings
+ encode_output = self.models[idx](data=embeddings, scene=scene)
+
+ # Return embeddings, latent space, and updated data
+ return {
+ 'embeddings': embeddings,
+ 'latent': encode_output['last_hidden_state'],
+ 'data': data,
+ }
+
+ def decode(self, encoded, sources=None, data=None, embeddings=None, sample_decoder=0, scene=None, monodepth=None):
+ assert data is not None or embeddings is not None
+ assert data is None or embeddings is None
+
+ # Get default sources if they are not provided
+ sources = sources if sources is not None else self.decode_sources
+
+ # Initialize structures
+ outputs, previous = [], None
+ merged_output = {'losses': {}, 'embeddings': {}, 'output': {}}
+
+ # Decode output for each source
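+ # As unpacked below, each entry of `sources` is expected to hold
+ # [name, encoder index, camera mode, embedding sources, tasks, is_variational]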
+ for source in sources:
+ output = self.single_decode(encoded, source[1], source[2], source[3], source[4],
+ data, embeddings, previous, sample_decoder,
+ is_variational=source[5], scene=scene, monodepth=monodepth)
+ previous = output['output']
+ outputs.append(output)
+
+ # Combine all outputs
+ for output, source in zip(outputs, sources):
+ name = '' if len(source[0]) == 0 else '_' + source[0]
+ merged_output['losses'].update({'%s%s' % (key, name): val
+ for key, val in output['losses'].items()})
+ merged_output['embeddings'].update({'%s%s' % (key, name): val
+ for key, val in output['embeddings'].items()})
+ merged_output['output'].update({'%s%s' % (key, name): val
+ for key, val in output['output'].items()})
+
+ # Return merged output
+ return merged_output
+
+ def single_decode(self, encoded, idx, camera_mode, sources, tasks, data=None,
+ embeddings=None, previous=None, sample_decoder=0,
+ is_variational=False, scene=None, monodepth=None):
+
+ # Create embeddings if they are not provided
+ if embeddings is None:
+ embeddings = self.get_embeddings(
+ data=data, sources=sources, camera_mode=camera_mode, downsample=self.downsample_decoder,
+ encoded=encoded[idx], sample=sample_decoder, previous=previous, monodepth=monodepth,
+ )
+
+ latent = encoded[idx]['latent']
+
+ # Disabled experimental path (self.latent_grid): sample grid features at the embedding
+ # coordinates from the first latent and decode from the second latent instead
+ if False:  # if self.latent_grid:
+ latent1, latent2 = latent
+ for key in embeddings.keys():
+ xyz = embeddings[key]['info']['xyz']
+ embeddings[key]['grid'] = latent1.sample(xyz).squeeze(1).permute(0, 2, 3, 1)
+ sources = [s for s in sources] + ['grid']
+ latent = latent2
+
+ # Initialize output and losses
+ output, losses = {}, {}
+
+ # Get additional information and stack embeddings according to source
+ info = {key: val['info'] for key, val in embeddings.items()}
+ source_embeddings = {key: torch.cat([
+ val[source] for source in sources if source not in self.extra_sources], -1)
+ for key, val in embeddings.items()}
+ extra_embeddings = {key: torch.cat([
+ val[source] for source in self.extra_sources], -1) if len(self.extra_sources) > 0 else None
+ for key, val in embeddings.items()}
+
+ # Expand latent space if batch dimensions do not agree
+ batch_size = source_embeddings[list(source_embeddings.keys())[0]].shape[0]
+ if not is_dict(latent):
+ if latent.shape[0] == 1 and latent.shape[0] != batch_size:
+ latent = latent.repeat(batch_size, 1, 1)
+ else:
+ for key in latent.keys():
+ if latent[key].shape[0] == 1 and latent[key].shape[0] != batch_size:
+ latent[key] = latent[key].repeat(batch_size, 1, 1)
+
+ # Sample from latent space if it's variational
+ if is_variational:
+ output_variational = self.sample_from_latent(latent)
+ losses.update(**{key: val for key, val in output_variational.items() if 'loss' in key})
+ latent = output_variational['sampled_latent']
+
+ # Decode all tasks from each embedding
+ for task in tasks:
+ output[task] = {key: make_list(self.decoders[task](
+ query=val, z=latent[key] if is_dict(latent) else latent,
+ key=key, info=info, previous=previous,
+ extra=extra_embeddings[key], scene=scene)['predictions'])
+ for key, val in source_embeddings.items()}
+ # Break volumetric into rgb and depth predictions
+ if task.startswith('volumetric'):
+ for task_key in ['rgb', 'depth']:
+ output[task.replace('volumetric', task_key)] = {
+ key: [v[task_key] for v in val] for key, val in output[task].items()}
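+ # The multiview head also returns rgb and depth; split them the same way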
+ if task.startswith('multiview'):
+ for task_key in ['rgb', 'depth']:
+ output[task.replace('multiview', task_key)] = {
+ key: [v[task_key] for v in val] for key, val in output[task].items()}
+
+ output['info'] = info
+
+ # Return losses, embeddings, and output
+ return {
+ 'losses': losses,
+ 'embeddings': embeddings,
+ 'output': output,
+ }
+
+ def multi_decode(self, encoded, sources=None, data=None,
+ embeddings=None, sample_decoder=0, num_evaluations=None, scene=None, monodepth=None):
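+ # Decode `num_evaluations` times (e.g. drawing different variational latent samples) and
+ # accumulate the per-task predictions; outside of training, also aggregate them into
+ # '<task>_mean' (per-pixel mean) and 'stddev_<task>' (channel-summed standard deviation).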
+ output = {}
+
+ for i in range(num_evaluations):
+ output_i = self.decode(
+ encoded, sources, data, embeddings, sample_decoder, scene=scene, monodepth=monodepth)
+
+ if i == 0:
+ output = {key: val for key, val in output_i['output'].items()}
+ else:
+ for task in output_i['output'].keys():
+ if str_not_in(task, ['info', 'volumetric']):
+ for key in output_i['output'][task].keys():
+ output[task][key].extend(output_i['output'][task][key])
+
+ if not self.training:
+ for task in list(output.keys()):
+ if str_not_in(task, ['info', 'volumetric']):
+ output[f'{task}_mean'] = {}
+ output[f'stddev_{task}'] = {}
+ for key in output[task].keys():
+ val = torch.stack(output[task][key], 0)
+ output[f'{task}_mean'][key] = [val.mean(0)]
+ output[f'stddev_{task}'][key] = [val.std(0).sum(1, keepdim=True)]
+
+ return {
+ 'losses': None,
+ 'embeddings': embeddings,
+ 'output': output,
+ }
+
+ def sample_from_latent(self, latent):
+
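+ # Variational sampling via the reparameterization trick: the last dimension of the latent
+ # is split into mean and log-variance halves and a sample mu + eps * std (eps ~ N(0, I))
+ # is drawn. During training, a KL divergence term against a standard normal,
+ # -0.5 * mean(1 + logvar - mu^2 - exp(logvar)), scaled by `kld_weight`, is also returned
+ # as 'kld_loss'. Dictionary latents are handled recursively, one entry per key.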
+ if is_dict(latent):
+ latents = {key: self.sample_from_latent(val) for key, val in latent.items()}
+ output = {
+ 'sampled_latent': {key: val['sampled_latent'] for key, val in latents.items()}
+ }
+ if self.training:
+ output['kld_loss'] = sum([val['kld_loss'] for val in latents.values()]) / len(latents)
+ return output
+
+ n = latent.shape[-1] // 2
+ mu, logvar = latent[:, :, :n], latent[:, :, n:]
+ logvar = logvar.clamp(max=10)
+
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ sampled_latent = eps * std + mu
+
+ output = {
+ 'sampled_latent': sampled_latent
+ }
+
+ if self.training:
+ output['kld_loss'] = self.variational_params['kld_weight'] * torch.mean(
+ - 0.5 * torch.mean(1 + logvar - mu ** 2 - logvar.exp(), dim=[1, 2]), dim=0)
+
+ return output
+
diff --git a/vidar/arch/networks/perceiver/externals/modeling_perceiver.py b/vidar/arch/networks/perceiver/externals/modeling_perceiver.py
new file mode 100755
index 0000000000000000000000000000000000000000..988eefb9cd54d2227e92cd68f114dfced5f4df53
--- /dev/null
+++ b/vidar/arch/networks/perceiver/externals/modeling_perceiver.py
@@ -0,0 +1,3751 @@
+# coding=utf-8
+# Copyright 2021 Deepmind and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Perceiver model."""
+
+from transformers import PerceiverModel, PerceiverConfig, PreTrainedModel, apply_chunking_to_forward
+from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions
+from transformers.utils import ModelOutput
+import abc
+import math
+from dataclasses import dataclass
+from functools import reduce
+from operator import __add__
+from typing import Any, Callable, Mapping, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn, Tensor
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from packaging import version
+import torch.nn.functional as tfn
+
+
+def upsample_tensor(tensor, mask, up=8):
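+ # Convex upsampling (in the style of RAFT): `mask` provides, for every pixel of the up x up
+ # upsampled output, 9 softmax-normalized weights over the 3x3 neighborhood of the
+ # corresponding coarse pixel, and the output is the weighted sum of that neighborhood.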
+ b, c, h, w = tensor.shape
+ mask = mask.view(b, 1, 9, up, up, h, w)
+ mask = torch.softmax(mask, dim=2)
+
+ up_tensor = tfn.unfold(tensor, [3, 3], padding=1)
+ up_tensor = up_tensor.view(b, -1, 9, 1, 1, h, w)
+
+ up_tensor = torch.sum(mask * up_tensor, dim=2)
+ up_tensor = up_tensor.permute(0, 1, 4, 2, 5, 3)
+ return up_tensor.reshape(b, -1, up * h, up * w)
+
+
+class NewGELUActivation(nn.Module):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
+
+class GELUActivation(nn.Module):
+ """
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, use_gelu_python: bool = False):
+ super().__init__()
+ if use_gelu_python:
+ self.act = self._gelu_python
+ else:
+ self.act = nn.functional.gelu
+
+ def _gelu_python(self, input: Tensor) -> Tensor:
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class FastGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+class QuickGELUActivation(nn.Module):
+ """
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input * torch.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Module):
+ """
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
+ it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
+ https://arxiv.org/abs/2004.09602.
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+ initially created.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, min: float, max: float):
+ if min > max:
+ raise ValueError(
+ f"min should be < max (got min: {min}, max: {max})")
+
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.clip(nn.functional.gelu(x), self.min, self.max)
+
+
+class SiLUActivation(nn.Module):
+ """
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
+ later.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return nn.functional.silu(input)
+
+
+class MishActivation(nn.Module):
+ """
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
+ """
+
+ def __init__(self):
+ super().__init__()
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"):
+ self.act = self._mish_python
+ else:
+ self.act = nn.functional.mish
+
+ def _mish_python(self, input: Tensor) -> Tensor:
+ return input * torch.tanh(nn.functional.softplus(input))
+
+ def forward(self, input: Tensor) -> Tensor:
+ return self.act(input)
+
+
+class LinearActivation(nn.Module):
+ """
+ Applies the linear activation function, i.e. forwarding input directly to output.
+ """
+
+ def forward(self, input: Tensor) -> Tensor:
+ return input
+
+
+ACT2FN = {
+ "gelu": GELUActivation(),
+ "gelu_10": ClippedGELUActivation(-10, 10),
+ "gelu_fast": FastGELUActivation(),
+ "gelu_new": NewGELUActivation(),
+ "gelu_python": GELUActivation(use_gelu_python=True),
+ "linear": LinearActivation(),
+ "mish": MishActivation(),
+ "quick_gelu": QuickGELUActivation(),
+ "relu": nn.ReLU(),
+ "sigmoid": nn.Sigmoid(),
+ "silu": SiLUActivation(),
+ "swish": SiLUActivation(),
+ "tanh": nn.Tanh(),
+}
+
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(
+ f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+#from ...activations import ACT2FN
+# from ...file_utils import (
+# ModelOutput,
+# add_start_docstrings,
+# add_start_docstrings_to_model_forward,
+# replace_return_docstrings,
+# )
+#from ...modeling_outputs import BaseModelOutputWithCrossAttentions
+# from ...modeling_utils import (
+# PreTrainedModel,
+# apply_chunking_to_forward,
+# find_pruneable_heads_and_indices,
+# prune_linear_layer,
+# )
+
+
+#from ...utils import logging
+#from .configuration_perceiver import PerceiverConfig
+
+
+ModalitySizeType = Mapping[str, int]
+PreprocessorOutputType = Tuple[torch.Tensor,
+ Optional[torch.Tensor], torch.Tensor]
+PreprocessorType = Callable[..., PreprocessorOutputType]
+PostprocessorType = Callable[..., Any]
+
+#logger = logging.get_logger(__name__)
+
+#_CHECKPOINT_FOR_DOC = "deepmind/language-perceiver"
+_CONFIG_FOR_DOC = "PerceiverConfig"
+#_TOKENIZER_FOR_DOC = "PerceiverTokenizer"
+
+PERCEIVER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "deepmind/language-perceiver",
+ # See all Perceiver models at https://huggingface.co/models?filter=perceiver
+]
+
+
+@dataclass
+class PerceiverModelOutput(ModelOutput):
+ """
+ Base class for Perceiver base model's outputs, with potential hidden states, attentions and cross-attentions.
+
+ Args:
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ logits: torch.FloatTensor = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class PerceiverDecoderOutput(ModelOutput):
+ """
+ Base class for Perceiver decoder outputs, with potential cross-attentions.
+
+ Args:
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
+ Output of the basic decoder.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ logits: torch.FloatTensor = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class PerceiverMaskedLMOutput(ModelOutput):
+ """
+ Base class for Perceiver's masked language model outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Masked language modeling (MLM) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_latents,
+ num_latents)`. Attentions weights after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class PerceiverClassifierOutput(ModelOutput):
+ """
+ Base class for Perceiver's outputs of sequence/image classification models, optical flow and multimodal
+ autoencoding.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
+ plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
+ used to compute the weighted average in the cross-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class PerceiverEmbeddings(nn.Module):
+ """Construct the latent embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(
+ config.num_latents, config.d_latents))
+
+ def forward(self, batch_size):
+ return self.latents.expand(batch_size, -1, -1) # Thanks, Phil Wang
+
+
+class PerceiverSelfAttention(nn.Module):
+ """Multi-headed {cross, self}-attention. Can be used both in the encoder as well as in the decoder."""
+
+ def __init__(
+ self,
+ config,
+ is_cross_attention=False,
+ qk_channels=None,
+ v_channels=None,
+ num_heads=1,
+ q_dim=None,
+ kv_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ # Q and K must have the same number of channels.
+ # Default to preserving Q's input's shape.
+ if qk_channels is None:
+ qk_channels = q_dim
+ # V's num_channels determines the shape of the output of QKV-attention.
+ # Default to the same number of channels used in the key-query operation.
+ if v_channels is None:
+ v_channels = qk_channels
+ if qk_channels % num_heads != 0:
+ raise ValueError(
+ f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).")
+ if v_channels % num_heads != 0:
+ raise ValueError(
+ f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).")
+
+ self.qk_channels = qk_channels
+ self.v_channels = v_channels
+ self.qk_channels_per_head = self.qk_channels // num_heads
+ self.v_channels_per_head = self.v_channels // num_heads
+
+ # Layer normalization
+ self.layernorm1 = nn.LayerNorm(q_dim)
+ self.layernorm2 = nn.LayerNorm(
+ kv_dim) if is_cross_attention else nn.Identity()
+
+ # Projection matrices
+ self.query = nn.Linear(q_dim, qk_channels)
+ self.key = nn.Linear(kv_dim, qk_channels)
+ self.value = nn.Linear(kv_dim, v_channels)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x, channels_per_head):
+ new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ inputs=None,
+ inputs_mask=None,
+ output_attentions=False,
+ ):
+ hidden_states = self.layernorm1(hidden_states)
+ inputs = self.layernorm2(inputs)
+
+ # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention
+ # module, the keys and values come from the inputs; the attention mask needs to be such that the inputs'
+ # non-relevant tokens are not attended to.
+ is_cross_attention = inputs is not None
+ queries = self.query(hidden_states)
+
+ if is_cross_attention:
+ keys = self.key(inputs)
+ values = self.value(inputs)
+ attention_mask = inputs_mask
+ else:
+ keys = self.key(hidden_states)
+ values = self.value(hidden_states)
+
+ # Reshape channels for multi-head attention.
+ # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head)
+ queries = self.transpose_for_scores(queries, self.qk_channels_per_head)
+ keys = self.transpose_for_scores(keys, self.qk_channels_per_head)
+ values = self.transpose_for_scores(values, self.v_channels_per_head)
+
+ # Take the dot product between the queries and keys to get the raw attention scores.
+ attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
+
+ batch_size, num_heads, seq_len, q_head_dim = queries.shape
+ _, _, _, v_head_dim = values.shape
+ hiddens = self.num_heads * v_head_dim
+
+ attention_scores = attention_scores / math.sqrt(q_head_dim)
+
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, values)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (hiddens,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (
+ context_layer,)
+
+ return outputs
+
+
+class PerceiverSelfOutput(nn.Module):
+ def __init__(self, config, input_channels, output_channels):
+ super().__init__()
+ self.dense = nn.Linear(input_channels, output_channels)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ return hidden_states
+
+
+class PerceiverAttention(nn.Module):
+ """Attention module, including a dense block."""
+
+ def __init__(
+ self,
+ config,
+ is_cross_attention=False,
+ qk_channels=None,
+ v_channels=None,
+ num_heads=1,
+ q_dim=None,
+ kv_dim=None,
+ use_query_residual=True,
+ ):
+ super().__init__()
+ # MultiHead attention
+ if is_cross_attention and qk_channels is None:
+ if config.cross_attention_shape_for_attention == "q":
+ qk_channels = q_dim
+ elif config.cross_attention_shape_for_attention == "kv":
+ qk_channels = kv_dim
+ else:
+ raise ValueError(
+ f"Unknown value {config.cross_attention_shape_for_attention} for "
+ "cross_attention_shape_for_attention."
+ )
+ else:
+ if qk_channels is None:
+ qk_channels = q_dim
+ if v_channels is None:
+ v_channels = qk_channels
+ self.self = PerceiverSelfAttention(
+ config,
+ is_cross_attention=is_cross_attention,
+ qk_channels=qk_channels,
+ v_channels=v_channels,
+ num_heads=num_heads,
+ q_dim=q_dim,
+ kv_dim=kv_dim,
+ )
+ # dense block
+ output_channels = None
+ if is_cross_attention:
+ output_channels = q_dim
+ else:
+ if output_channels is None:
+ output_channels = v_channels
+ self.output = PerceiverSelfOutput(
+ config, input_channels=self.self.v_channels, output_channels=output_channels)
+ self.use_query_residual = use_query_residual
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - \
+ len(heads)
+ self.self.all_head_size = self.self.attention_head_size * \
+ self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ inputs=None,
+ inputs_mask=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ inputs,
+ inputs_mask,
+ output_attentions,
+ )
+
+ # Output projection
+ attention_output = self.output(self_outputs[0])
+
+ # Optionally include a residual to the original queries.
+ # Consider omitting the residual if the semantics of query and output
+ # are different, e.g. if queries are positions and outputs are pixels.
+ if self.use_query_residual:
+ attention_output = attention_output + hidden_states
+
+ # add attentions if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+ return outputs
+
+
+class PerceiverMLP(nn.Module):
+ """A Transformer-style dense module to follow attention."""
+
+ def __init__(self, config, input_size, widening_factor):
+ super().__init__()
+ self.dense1 = nn.Linear(input_size, widening_factor * input_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+ self.dense2 = nn.Linear(widening_factor * input_size, input_size)  # project back to input_size
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense1(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.dense2(hidden_states)
+ return hidden_states
+
+
+class PerceiverLayer(nn.Module):
+ def __init__(
+ self,
+ config,
+ is_cross_attention=False,
+ qk_channels=None,
+ v_channels=None,
+ num_heads=1,
+ q_dim=None,
+ kv_dim=None,
+ widening_factor=4,
+ use_query_residual=True,
+ ):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = PerceiverAttention(
+ config,
+ is_cross_attention=is_cross_attention,
+ qk_channels=qk_channels,
+ v_channels=v_channels,
+ num_heads=num_heads,
+ q_dim=q_dim,
+ kv_dim=kv_dim,
+ use_query_residual=use_query_residual,
+ )
+ self.layernorm = nn.LayerNorm(q_dim)
+ self.mlp = PerceiverMLP(config, input_size=q_dim,
+ widening_factor=widening_factor)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ inputs=None,
+ inputs_mask=None,
+ output_attentions=False,
+ ):
+ attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ inputs,
+ inputs_mask,
+ output_attentions,
+ )
+ attention_output = attention_outputs[0]
+
+ # add attentions if we output attention weights
+ outputs = attention_outputs[1:]
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+
+ layer_output = layer_output + attention_output # residual connection
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ layer_output = self.layernorm(attention_output)
+ layer_output = self.mlp(layer_output)
+ return layer_output
+
+
+class PerceiverEncoder(nn.Module):
+ """The Perceiver Encoder: a scalable, fully attentional encoder."""
+
+ def __init__(self, config, kv_dim=None):
+ super().__init__()
+ self.config = config
+
+ # Check that we can use multihead-attention with these shapes.
+ if config.d_latents % config.num_self_attention_heads != 0:
+ raise ValueError(
+ f"num_z_channels ({config.d_latents}) must be divisible by"
+ f" num_self_attend_heads ({config.num_self_attention_heads})."
+ )
+ if config.d_latents % config.num_cross_attention_heads != 0:
+ raise ValueError(
+ f"num_z_channels ({config.d_latents}) must be divisible by"
+ f" num_cross_attend_heads ({config.num_cross_attention_heads})."
+ )
+
+ # Construct the cross attention layer.
+ self.cross_attention = PerceiverLayer(
+ config,
+ is_cross_attention=True,
+ qk_channels=config.qk_channels,
+ v_channels=config.v_channels,
+ num_heads=config.num_cross_attention_heads,
+ q_dim=config.d_latents,
+ kv_dim=kv_dim,
+ widening_factor=config.cross_attention_widening_factor,
+ use_query_residual=config.use_query_residual,
+ )
+
+ # Construct a single block of self-attention layers.
+ # We get deeper architectures by applying this block more than once.
+ self_attention_layers = []
+ for _ in range(config.num_self_attends_per_block):
+ layer = PerceiverLayer(
+ config,
+ is_cross_attention=False,
+ qk_channels=config.qk_channels,
+ v_channels=config.v_channels,
+ num_heads=config.num_self_attention_heads,
+ q_dim=config.d_latents,
+ kv_dim=config.d_latents,
+ widening_factor=config.self_attention_widening_factor,
+ )
+ self_attention_layers.append(layer)
+
+ self.self_attends = nn.ModuleList(self_attention_layers)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ inputs=None,
+ inputs_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions else None
+
+ # Apply the cross-attention between the latents (hidden_states) and inputs:
+ layer_outputs = self.cross_attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=None,
+ inputs=inputs,
+ inputs_mask=inputs_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[1],)
+
+ # Apply the block of self-attention layers more than once:
+ for _ in range(self.config.num_blocks):
+ for i, layer_module in enumerate(self.self_attends):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attentions = all_self_attentions + \
+ (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithCrossAttentions(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class PerceiverPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = PerceiverConfig
+ base_model_prefix = "perceiver"
+ main_input_name = "inputs"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif hasattr(module, "latents"):
+ module.latents.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding):
+ module.position_embeddings.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.ParameterDict):
+ for modality in module.keys():
+ module[modality].data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(
+ mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+PERCEIVER_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+PERCEIVER_MODEL_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`PerceiverConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ decoder (*DecoderType*, *optional*):
+ Optional decoder to use to decode the latent representation of the encoder. Examples include
+ *transformers.models.perceiver.modeling_perceiver.PerceiverBasicDecoder*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationDecoder*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder*.
+ input_preprocessor (*PreprocessorType*, *optional*):
+ Optional input preprocessor to use. Examples include
+ *transformers.models.perceiver.modeling_perceiver.PerceiverImagePreprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPreprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverTextPreprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor*.
+ output_postprocessor (*PostprocessorType*, *optional*):
+ Optional output postprocessor to use. Examples include
+ *transformers.models.perceiver.modeling_perceiver.PerceiverImagePostprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverAudioPostprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverClassificationPostprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverProjectionPostprocessor*,
+ *transformers.models.perceiver.modeling_perceiver.PerceiverMultimodalPostprocessor*.
+
+ Note that you can define your own decoders, preprocessors and/or postprocessors to fit your use-case.
+"""
+
+PERCEIVER_INPUTS_DOCSTRING = r"""
+ Args:
+ inputs (`torch.FloatTensor`):
+ Inputs to the perceiver. Can be anything: images, text, audio, video, etc.
+ attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# @add_start_docstrings(
+# """The Perceiver: a scalable, fully attentional architecture.""",
+# PERCEIVER_MODEL_START_DOCSTRING,
+# )
+class PerceiverModel(PerceiverPreTrainedModel):
+ def __init__(
+ self,
+ config,
+ decoder=None,
+ input_preprocessor: PreprocessorType = None,
+ output_postprocessor: PostprocessorType = None,
+ ):
+ super().__init__(config)
+ self.config = config
+
+ self.input_preprocessor = input_preprocessor
+ self.output_postprocessor = output_postprocessor
+ self.embeddings = PerceiverEmbeddings(config)
+ self.encoder = PerceiverEncoder(
+ config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
+ )
+ self.decoder = decoder
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.latents
+
+ def set_input_embeddings(self, value):
+ self.embeddings.latents = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+ # @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs,
+ attention_mask=None,
+ subsampled_output_points=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel
+ >>> from transformers.models.perceiver.modeling_perceiver import (
+ ... PerceiverTextPreprocessor,
+ ... PerceiverImagePreprocessor,
+ ... PerceiverClassificationDecoder,
+ ... )
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> # EXAMPLE 1: using the Perceiver to classify texts
+ >>> # - we define a TextPreprocessor, which can be used to embed tokens
+ >>> # - we define a ClassificationDecoder, which can be used to decode the
+ >>> # final hidden states of the latents to classification logits
+ >>> # using trainable position embeddings
+ >>> config = PerceiverConfig()
+ >>> preprocessor = PerceiverTextPreprocessor(config)
+ >>> decoder = PerceiverClassificationDecoder(
+ ... config,
+ ... num_channels=config.d_latents,
+ ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
+ ... use_query_residual=True,
+ ... )
+ >>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)
+
+ >>> # you can then do a forward pass as follows:
+ >>> tokenizer = PerceiverTokenizer()
+ >>> text = "hello world"
+ >>> inputs = tokenizer(text, return_tensors="pt").input_ids
+
+ >>> with torch.no_grad():
+ ... outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+
+ >>> # to train, one can train the model using standard cross-entropy:
+ >>> criterion = torch.nn.CrossEntropyLoss()
+
+ >>> labels = torch.tensor([1])
+ >>> loss = criterion(logits, labels)
+
+ >>> # EXAMPLE 2: using the Perceiver to classify images
+ >>> # - we define an ImagePreprocessor, which can be used to embed images
+ >>> preprocessor = PerceiverImagePreprocessor(
+ ... config,
+ ... prep_type="conv1x1",
+ ... spatial_downsample=1,
+ ... out_channels=256,
+ ... position_encoding_type="trainable",
+ ... concat_or_add_pos="concat",
+ ... project_pos_dim=256,
+ ... trainable_position_encoding_kwargs=dict(
+ ... num_channels=256,
+ ... index_dims=config.image_size ** 2,
+ ... ),
+ ... )
+
+ >>> model = PerceiverModel(
+ ... config,
+ ... input_preprocessor=preprocessor,
+ ... decoder=PerceiverClassificationDecoder(
+ ... config,
+ ... num_channels=config.d_latents,
+ ... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
+ ... use_query_residual=True,
+ ... ),
+ ... )
+
+ >>> # you can then do a forward pass as follows:
+ >>> feature_extractor = PerceiverFeatureExtractor()
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> inputs = feature_extractor(image, return_tensors="pt").pixel_values
+
+ >>> with torch.no_grad():
+ ... outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+
+ >>> # to train, one can train the model using standard cross-entropy:
+ >>> criterion = torch.nn.CrossEntropyLoss()
+
+ >>> labels = torch.tensor([1])
+ >>> loss = criterion(logits, labels)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # if self.input_preprocessor is not None:
+ # inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs)
+ # else:
+ # modality_sizes = None
+ # inputs_without_pos = None
+ if inputs.size()[-1] != self.config.d_model:
+ raise ValueError(
+ f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
+ "Make sure to set config.d_model appropriately."
+ )
+
+ batch_size, seq_length, _ = inputs.size()
+ device = inputs.device
+
+ # If no attention mask is provided, make them all ones
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ ((batch_size, seq_length)), device=device)
+ # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ extended_attention_mask = self.invert_attention_mask(attention_mask)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_blocks x num_heads]
+ # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N]
+ head_mask = self.get_head_mask(
+ head_mask, self.config.num_blocks * self.config.num_self_attends_per_block)
+
+ embedding_output = self.embeddings(batch_size=batch_size)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=None,
+ head_mask=head_mask,
+ inputs=inputs,
+ inputs_mask=extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+
+ logits = None
+ # if self.decoder:
+ # if subsampled_output_points is not None:
+ # output_modality_sizes = {
+ # "audio": subsampled_output_points["audio"].shape[0],
+ # "image": subsampled_output_points["image"].shape[0],
+ # "label": 1,
+ # }
+ # else:
+ # output_modality_sizes = None
+ # decoder_query = self.decoder.decoder_query(
+ # inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
+ # )
+ # decoder_outputs = self.decoder(
+ # decoder_query,
+ # z=sequence_output,
+ # query_mask=extended_attention_mask,
+ # output_attentions=output_attentions,
+ # )
+ # logits = decoder_outputs.logits
+ #
+ # # add cross-attentions of decoder
+ # if output_attentions and decoder_outputs.cross_attentions is not None:
+ # if return_dict:
+ # encoder_outputs.cross_attentions = (
+ # encoder_outputs.cross_attentions + decoder_outputs.cross_attentions
+ # )
+ # else:
+ # encoder_outputs = encoder_outputs + decoder_outputs.cross_attentions
+ #
+ # if self.output_postprocessor:
+ # logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes)
+
+ if not return_dict:
+ if logits is not None:
+ return (logits, sequence_output) + encoder_outputs[1:]
+ else:
+ return (sequence_output,) + encoder_outputs[1:]
+
+ return PerceiverModelOutput(
+ logits=logits,
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING)
+class PerceiverForMaskedLM(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ text_preprocessor = PerceiverTextPreprocessor(config)
+
+ trainable_position_encoding_kwargs_decoder = dict(
+ num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
+ )
+
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=text_preprocessor,
+ decoder=PerceiverBasicDecoder(
+ config,
+ output_num_channels=config.d_latents,
+ # we need to define the seq_len of the inputs beforehand
+ output_index_dims=config.max_position_embeddings,
+ num_channels=text_preprocessor.num_channels,
+ qk_channels=8 * 32,
+ v_channels=text_preprocessor.num_channels,
+ num_heads=8,
+ use_query_residual=False,
+ final_project=False,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
+ ),
+ )
+ self.embedding_decoder = PerceiverEmbeddingDecoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ input_ids=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverTokenizer, PerceiverForMaskedLM
+ >>> import torch
+
+ >>> tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
+ >>> model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")
+
+ >>> # training
+ >>> text = "This is an incomplete sentence where some words are missing."
+ >>> inputs = tokenizer(text, padding="max_length", return_tensors="pt")
+ >>> # mask " missing."
+ >>> inputs["input_ids"][0, 52:61] = tokenizer.mask_token_id
+ >>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids
+
+ >>> outputs = model(**inputs, labels=labels)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+
+ >>> # inference
+ >>> text = "This is an incomplete sentence where some words are missing."
+ >>> encoding = tokenizer(text, padding="max_length", return_tensors="pt")
+
+ >>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space.
+ >>> encoding["input_ids"][0, 52:61] = tokenizer.mask_token_id
+
+ >>> # forward pass
+ >>> with torch.no_grad():
+ ... outputs = model(**encoding)
+ >>> logits = outputs.logits
+
+ >>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
+ >>> tokenizer.decode(masked_tokens_predictions)
+ ' missing.'
+ ```"""
+ if inputs is not None and input_ids is not None:
+ raise ValueError("You cannot use both `inputs` and `input_ids`")
+ elif inputs is None and input_ids is not None:
+ inputs = input_ids
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.embedding_decoder(
+ outputs.logits if return_dict else outputs[0], embedding_layer=self.perceiver.input_preprocessor.embeddings
+ )
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(
+ logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return PerceiverMaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings("""Example use of Perceiver for text classification.""", PERCEIVER_START_DOCSTRING)
+class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ trainable_position_encoding_kwargs_decoder = dict(
+ num_channels=config.d_latents, index_dims=1)
+
+ self.num_labels = config.num_labels
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=PerceiverTextPreprocessor(config),
+ decoder=PerceiverClassificationDecoder(
+ config,
+ num_channels=config.d_latents,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
+ use_query_residual=True,
+ ),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ input_ids=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the classification/regression loss. Indices should be in `[0, ..., config.num_labels -
+ 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels >
+ 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverTokenizer, PerceiverForSequenceClassification
+
+ >>> tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
+ >>> model = PerceiverForSequenceClassification.from_pretrained("deepmind/language-perceiver")
+
+ >>> text = "hello world"
+ >>> inputs = tokenizer(text, return_tensors="pt").input_ids
+ >>> outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+ ```"""
+ if inputs is not None and input_ids is not None:
+ raise ValueError("You cannot use both `inputs` and `input_ids`")
+ elif inputs is None and input_ids is not None:
+ inputs = input_ids
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings(
+ """
+Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+This model uses learned position embeddings. In other words, this model is not given any privileged information about
+the structure of images. As shown in the paper, this model can achieve a top-1 accuracy of 72.7 on ImageNet.
+
+[`PerceiverForImageClassificationLearned`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+(with `prep_type="conv1x1"`) to preprocess the input images, and
+[`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+[`PerceiverModel`] into classification logits.
+""",
+ # PERCEIVER_START_DOCSTRING,
+# )
+
+
+class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ trainable_position_encoding_kwargs_preprocessor = dict(
+ num_channels=256, index_dims=config.image_size ** 2)
+ trainable_position_encoding_kwargs_decoder = dict(
+ num_channels=config.d_latents, index_dims=1)
+
+ self.num_labels = config.num_labels
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=PerceiverImagePreprocessor(
+ config,
+ prep_type="conv1x1",
+ spatial_downsample=1,
+ out_channels=256,
+ position_encoding_type="trainable",
+ concat_or_add_pos="concat",
+ project_pos_dim=256,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_preprocessor,
+ ),
+ decoder=PerceiverClassificationDecoder(
+ config,
+ num_channels=config.d_latents,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
+ use_query_residual=True,
+ ),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ pixel_values=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationLearned
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-learned")
+ >>> model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
+ >>> outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ ```"""
+ if inputs is not None and pixel_values is not None:
+ raise ValueError("You cannot use both `inputs` and `pixel_values`")
+ elif inputs is None and pixel_values is not None:
+ inputs = pixel_values
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings(
+# """
+# Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+# This model uses fixed 2D Fourier position embeddings. As shown in the paper, this model can achieve a top-1 accuracy of
+# 79.0 on ImageNet, and 84.5 when pre-trained on a large-scale dataset (i.e. JFT).
+
+# [`PerceiverForImageClassificationFourier`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+# (with `prep_type="pixels"`) to preprocess the input images, and
+# [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+# [`PerceiverModel`] into classification logits.
+# """,
+# PERCEIVER_START_DOCSTRING,
+# )
+class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ fourier_position_encoding_kwargs_preprocessor = dict(
+ concat_pos=True, max_resolution=(224, 224), num_bands=64, sine_only=False
+ )
+ trainable_position_encoding_kwargs_decoder = dict(
+ num_channels=config.d_latents, index_dims=1)
+
+ self.num_labels = config.num_labels
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=PerceiverImagePreprocessor(
+ config,
+ prep_type="pixels",
+ spatial_downsample=1,
+ fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+ ),
+ decoder=PerceiverClassificationDecoder(
+ config,
+ num_channels=config.d_latents,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
+ use_query_residual=True,
+ ),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ pixel_values=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationFourier
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-fourier")
+ >>> model = PerceiverForImageClassificationFourier.from_pretrained("deepmind/vision-perceiver-fourier")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
+ >>> outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ ```"""
+ if inputs is not None and pixel_values is not None:
+ raise ValueError("You cannot use both `inputs` and `pixel_values`")
+ elif inputs is None and pixel_values is not None:
+ inputs = pixel_values
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings(
+# """
+# Example use of Perceiver for image classification, for tasks such as ImageNet.
+
+# This model uses a 2D conv+maxpool preprocessing network. As shown in the paper, this model can achieve a top-1 accuracy
+# of 82.1 on ImageNet.
+
+# [`PerceiverForImageClassificationConvProcessing`] uses [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`]
+# (with `prep_type="conv"`) to preprocess the input images, and
+# [`~models.perceiver.modeling_perceiver.PerceiverClassificationDecoder`] to decode the latent representation of
+# [`PerceiverModel`] into classification logits.
+# """,
+# PERCEIVER_START_DOCSTRING,
+# )
+class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ fourier_position_encoding_kwargs_preprocessor = dict(
+ concat_pos=True, max_resolution=(56, 56), num_bands=64, sine_only=False
+ )
+ trainable_position_encoding_kwargs_decoder = dict(
+ num_channels=config.d_latents, index_dims=1)
+
+ self.num_labels = config.num_labels
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=PerceiverImagePreprocessor(
+ config,
+ prep_type="conv",
+ spatial_downsample=1,
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+ ),
+ decoder=PerceiverClassificationDecoder(
+ config,
+ num_channels=config.d_latents,
+ trainable_position_encoding_kwargs=trainable_position_encoding_kwargs_decoder,
+ use_query_residual=True,
+ ),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ pixel_values=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverFeatureExtractor, PerceiverForImageClassificationConvProcessing
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = PerceiverFeatureExtractor.from_pretrained("deepmind/vision-perceiver-conv")
+ >>> model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
+ >>> outputs = model(inputs=inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ ```"""
+ if inputs is not None and pixel_values is not None:
+ raise ValueError("You cannot use both `inputs` and `pixel_values`")
+ elif inputs is None and pixel_values is not None:
+ inputs = pixel_values
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings(
+# """
+# Example use of Perceiver for optical flow, for tasks such as Sintel and KITTI. [`PerceiverForOpticalFlow`] uses
+# [`~models.perceiver.modeling_perceiver.PerceiverImagePreprocessor`] (with *prep_type="patches"*) to preprocess the
+# input images, and [`~models.perceiver.modeling_perceiver.PerceiverOpticalFlowDecoder`] to decode the latent
+# representation of [`PerceiverModel`].
+
+# As input, one concatenates 2 subsequent frames along the channel dimension and extract a 3 x 3 patch around each pixel
+# (leading to 3 x 3 x 3 x 2 = 54 values for each pixel). Fixed Fourier position encodings are used to encode the position
+# of each pixel in the patch. Next, one applies the Perceiver encoder. To decode, one queries the latent representation
+# using the same encoding used for the input.
+# """,
+# PERCEIVER_START_DOCSTRING,
+# )
+class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ fourier_position_encoding_kwargs_preprocessor = dict(
+ num_bands=64,
+ max_resolution=config.train_size,
+ sine_only=False,
+ concat_pos=True,
+ )
+ fourier_position_encoding_kwargs_decoder = dict(
+ concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
+ )
+
+ image_preprocessor = PerceiverImagePreprocessor(
+ config,
+ prep_type="patches",
+ spatial_downsample=1,
+ conv_after_patching=True,
+ conv_after_patching_in_channels=54,
+ temporal_downsample=2,
+ position_encoding_type="fourier",
+ # position_encoding_kwargs
+ fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+ )
+
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=image_preprocessor,
+ decoder=PerceiverOpticalFlowDecoder(
+ config,
+ num_channels=image_preprocessor.num_channels,
+ output_image_shape=config.train_size,
+ rescale_factor=100.0,
+ # decoder kwargs
+ use_query_residual=False,
+ output_num_channels=2,
+ # We query the decoder using the first frame features
+ # rather than a standard decoder position encoding.
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_decoder,
+ ),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the optical flow loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverForOpticalFlow
+ >>> import torch
+
+ >>> model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
+
+ >>> # in the Perceiver IO paper, the authors extract a 3 x 3 patch around each pixel,
+ >>> # leading to 3 x 3 x 3 = 27 values for each pixel (as each pixel also has 3 color channels)
+ >>> # patches have shape (batch_size, num_frames, num_channels, height, width)
+ >>> # the authors train on resolutions of 368 x 496
+ >>> patches = torch.randn(1, 2, 27, 368, 496)
+ >>> outputs = model(inputs=patches)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ raise NotImplementedError(
+ "Optical flow training is not yet supported")
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# @add_start_docstrings(
+ """
+Example use of Perceiver for multimodal (video) autoencoding, for tasks such as Kinetics-700.
+
+[`PerceiverForMultimodalAutoencoding`] uses [`~models.perceiver.modeling_perceiver.PerceiverMultimodalPreprocessor`] to
+preprocess the 3 modalities: images, audio and class labels. This preprocessor uses modality-specific preprocessors to
+preprocess every modality separately, after which they are concatenated. Trainable position embeddings are used to pad
+each modality to the same number of channels to make concatenation along the time dimension possible. Next, one applies
+the Perceiver encoder.
+
+[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] is used to decode the latent representation of
+[`PerceiverModel`]. This decoder uses each modality-specific decoder to construct queries. The decoder queries are
+created based on the inputs after preprocessing. However, autoencoding an entire video in a single forward pass is
+computationally infeasible, hence one only uses parts of the decoder queries to do cross-attention with the latent
+representation. This is determined by the subsampled indices for each modality, which can be provided as additional
+input to the forward pass of [`PerceiverForMultimodalAutoencoding`].
+
+[`~models.perceiver.modeling_perceiver.PerceiverMultimodalDecoder`] also pads the decoder queries of the different
+modalities to the same number of channels, in order to concatenate them along the time dimension. Next, cross-attention
+is performed with the latent representation of [`PerceiverModel`].
+
+Finally, [`~models.perceiver.modeling_perceiver.PerceiverMultiModalPostprocessor`] is used to turn this tensor into an
+actual video. It first splits up the output into the different modalities, and then applies the respective
+postprocessor for each modality.
+
+Note that, by masking the classification label during evaluation (i.e. simply providing a tensor of zeros for the
+"label" modality), this auto-encoding model becomes a Kinetics 700 video classifier.
+#""",
+# PERCEIVER_START_DOCSTRING,
+# )
+
+
+class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ n_audio_samples = config.num_frames * config.audio_samples_per_frame
+
+ input_preprocessor = PerceiverMultimodalPreprocessor(
+ min_padding_size=4,
+ modalities={
+ "audio": PerceiverAudioPreprocessor(
+ config,
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=dict(
+ num_bands=192,
+ max_resolution=(n_audio_samples,),
+ sine_only=False,
+ concat_pos=True,
+ ),
+ prep_type="patches",
+ samples_per_patch=config.samples_per_patch,
+ ),
+ "image": PerceiverImagePreprocessor(
+ config,
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=dict(
+ num_bands=32,
+ max_resolution=(config.num_frames,
+ config.image_size, config.image_size),
+ sine_only=False,
+ concat_pos=True,
+ ),
+ prep_type="patches",
+ spatial_downsample=4,
+ temporal_downsample=1,
+ ),
+ "label": PerceiverOneHotPreprocessor(config),
+ },
+ mask_probs={"image": 0.0, "audio": 0.0, "label": 1.0},
+ )
+
+ image_decoder = PerceiverBasicVideoAutoencodingDecoder(
+ config,
+ # Autoencoding, don't pass inputs to the queries.
+ concat_preprocessed_input=False,
+ output_shape=config.output_shape,
+ output_num_channels=512,
+ use_query_residual=False,
+ position_encoding_only=True,
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=dict(
+ num_bands=32,
+ max_resolution=(config.num_frames,
+ config.image_size, config.image_size),
+ sine_only=False,
+ concat_pos=True,
+ ),
+ )
+
+ decoder = PerceiverMultimodalDecoder(
+ config,
+ # Autoencoding, don't pass inputs to the queries.
+ concat_preprocessed_input=False,
+ # Modality specific decoders are used ONLY to generate queries.
+ # All modalities are decoded together using a unified decoder.
+ modalities={
+ "audio": PerceiverBasicDecoder(
+ config,
+ # Autoencoding, don't pass inputs to the queries.
+ concat_preprocessed_input=False,
+ output_index_dims=(n_audio_samples //
+ config.samples_per_patch,),
+ output_num_channels=512,
+ use_query_residual=False,
+ position_encoding_only=True,
+ position_encoding_type="fourier",
+ fourier_position_encoding_kwargs=dict(
+ num_bands=192,
+ max_resolution=(n_audio_samples,),
+ sine_only=False,
+ concat_pos=True,
+ ),
+ ),
+ "image": image_decoder,
+ "label": PerceiverClassificationDecoder(
+ config,
+ # Autoencoding, don't pass inputs to the queries.
+ concat_preprocessed_input=False,
+ use_query_residual=False,
+ position_encoding_only=True,
+ position_encoding_type="trainable",
+ trainable_position_encoding_kwargs=dict(
+ num_channels=1024,
+ index_dims=1,
+ ),
+ ),
+ },
+ num_outputs=None,
+ output_num_channels=512,
+ use_query_residual=False,
+ )
+
+ output_postprocessor = PerceiverMultimodalPostprocessor(
+ modalities={
+ "audio": PerceiverAudioPostprocessor(config, in_channels=512),
+ "image": PerceiverProjectionPostprocessor(in_channels=512, out_channels=3),
+ "label": PerceiverClassificationPostprocessor(config, in_channels=512),
+ }
+ )
+
+ self.perceiver = PerceiverModel(
+ config,
+ input_preprocessor=input_preprocessor,
+ decoder=decoder,
+ output_postprocessor=output_postprocessor,
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ # @replace_return_docstrings(output_type=PerceiverClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ inputs=None,
+ attention_mask=None,
+ subsampled_output_points=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ labels=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import PerceiverForMultimodalAutoencoding
+ >>> import torch
+ >>> import numpy as np
+
+ >>> # create multimodal inputs
+ >>> images = torch.randn((1, 16, 3, 224, 224))
+ >>> audio = torch.randn((1, 30720, 1))
+ >>> inputs = dict(image=images, audio=audio, label=torch.zeros((images.shape[0], 700)))
+
+ >>> model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver")
+
+ >>> # in the Perceiver IO paper, videos are auto-encoded in chunks
+ >>> # each chunk subsamples different index dimensions of the image and audio modality decoder queries
+ >>> nchunks = 128
+ >>> image_chunk_size = np.prod((16, 224, 224)) // nchunks
+ >>> audio_chunk_size = audio.shape[1] // model.config.samples_per_patch // nchunks
+ >>> # process the first chunk
+ >>> chunk_idx = 0
+ >>> subsampling = {
+ ... "image": torch.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1)),
+ ... "audio": torch.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1)),
+ ... "label": None,
+ ... }
+
+ >>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.perceiver(
+ inputs=inputs,
+ attention_mask=attention_mask,
+ subsampled_output_points=subsampled_output_points,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ logits = outputs.logits if return_dict else outputs[0]
+
+ loss = None
+ if labels is not None:
+ raise NotImplementedError(
+ "Multimodal autoencoding training is not yet supported")
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return PerceiverClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+
+# Below: position encodings
+
+
+def build_position_encoding(
+ position_encoding_type,
+ out_channels=None,
+ project_pos_dim=-1,
+ trainable_position_encoding_kwargs=None,
+ fourier_position_encoding_kwargs=None,
+):
+ """
+ Builds the position encoding.
+
+ Args:
+
+ - out_channels: refers to the number of channels of the position encodings.
+ - project_pos_dim: if specified, will project the position encodings to this dimension.
+
+ """
+
+ if position_encoding_type == "trainable":
+ if not trainable_position_encoding_kwargs:
+ raise ValueError(
+ "Make sure to pass trainable_position_encoding_kwargs")
+ output_pos_enc = PerceiverTrainablePositionEncoding(
+ **trainable_position_encoding_kwargs)
+ elif position_encoding_type == "fourier":
+ # We don't use the index_dims argument, as this is only known during the forward pass
+ if not fourier_position_encoding_kwargs:
+ raise ValueError(
+ "Make sure to pass fourier_position_encoding_kwargs")
+ output_pos_enc = PerceiverFourierPositionEncoding(
+ **fourier_position_encoding_kwargs)
+ else:
+ raise ValueError(
+ f"Unknown position encoding type: {position_encoding_type}.")
+
+ # Optionally, project the position encoding to a target dimension:
+ positions_projection = nn.Linear(
+ out_channels, project_pos_dim) if project_pos_dim > 0 else nn.Identity()
+
+ return output_pos_enc, positions_projection
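+
+# A minimal usage sketch (illustrative only, not part of the library API), mirroring the
+# trainable decoder/preprocessor kwargs used by the classification models above; the exact
+# forward signatures of the encoding classes are defined elsewhere in this module.
+#
+#   pos_enc, pos_proj = build_position_encoding(
+#       position_encoding_type="trainable",
+#       out_channels=256,
+#       project_pos_dim=256,
+#       trainable_position_encoding_kwargs=dict(num_channels=256, index_dims=224 ** 2),
+#   )
+#   # pos_proj is nn.Linear(256, 256) here; with project_pos_dim <= 0 it would be nn.Identity().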
+
+
+# Below: Perceiver decoders
+
+
+class PerceiverAbstractDecoder(nn.Module, metaclass=abc.ABCMeta):
+ """Perceiver abstract decoder."""
+
+ @abc.abstractmethod
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def num_query_channels(self):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def forward(self, query, z, query_mask=None):
+ raise NotImplementedError
+
+
+class PerceiverProjectionDecoder(PerceiverAbstractDecoder):
+ """
+ Baseline projection decoder (no cross-attention).
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.classifier = nn.Linear(config.d_latents, config.num_labels)
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ return None
+
+ def forward(self, query, z, query_mask=None):
+ # (batch_size, num_latents, d_latents) -> (batch_size, d_latents)
+ z = torch.mean(z, dim=1)
+ # (batch_size, d_latents) -> (batch_size, config.num_labels)
+ logits = self.classifier(z)
+ return logits
+
+
+class PerceiverBasicDecoder(PerceiverAbstractDecoder):
+ """
+ Cross-attention-based decoder. This class can be used to decode the final hidden states of the latents using a
+ cross-attention operation, in which the latents produce keys and values.
+
+ The shape of the output of this class depends on how one defines the output queries (also called decoder queries).
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ output_num_channels (`int`, *optional*):
+ The number of channels in the output. Will only be used in case *final_project* is set to `True`.
+ position_encoding_type (`str`, *optional*, defaults to "trainable"):
+ The type of position encoding to use. Can be either "trainable", "fourier", or "none".
+ output_index_dims (`int`, *optional*):
+ The number of dimensions of the output queries. Ignored if 'position_encoding_type' == 'none'.
+ num_channels (`int`, *optional*, defaults to 128):
+ The number of channels of the decoder queries. Ignored if 'position_encoding_type' == 'none'.
+ qk_channels (`int`, *optional*):
+ The number of channels of the queries and keys in the cross-attention layer.
+ v_channels (`int`, *optional*):
+ The number of channels of the values in the cross-attention layer.
+ num_heads (`int`, *optional*, defaults to 1):
+ The number of attention heads in the cross-attention layer.
+ widening_factor (`int`, *optional*, defaults to 1):
+ The widening factor of the cross-attention layer.
+ use_query_residual (`bool`, *optional*, defaults to `False`):
+ Whether to use a residual connection between the query and the output of the cross-attention layer.
+ concat_preprocessed_input (`bool`, *optional*, defaults to `False`):
+ Whether to concatenate the preprocessed input to the query.
+ final_project (`bool`, *optional*, defaults to `True`):
+ Whether to project the output of the cross-attention layer to a target dimension.
+ position_encoding_only (`bool`, *optional*, defaults to `False`):
+ Whether to only use this class to define output queries.
+ """
+
+ def __init__(
+ self,
+ config,
+ output_num_channels,
+ position_encoding_type="trainable",
+ # The following 2 arguments are ignored if position_encoding_type == 'none':
+ output_index_dims=None,
+ num_channels=128,
+ subsampled_index_dims=None,
+ qk_channels=None,
+ v_channels=None,
+ num_heads=1,
+ widening_factor=1,
+ use_query_residual=False,
+ concat_preprocessed_input=False,
+ final_project=True,
+ position_encoding_only=False,
+ **position_encoding_kwargs,
+ ):
+ super().__init__()
+
+ self.output_num_channels = output_num_channels
+ # If `none`, the decoder will not construct any position encodings.
+ # You should construct your own when querying the decoder.
+ self.output_position_encodings = None
+ self.position_encoding_type = position_encoding_type
+ self.position_encoding_kwargs = position_encoding_kwargs
+ if position_encoding_type != "none":
+ self.output_position_encodings, self.positions_projection = build_position_encoding(
+ position_encoding_type=position_encoding_type, **position_encoding_kwargs
+ )
+
+ self.output_index_dims = output_index_dims
+ self.num_channels = num_channels
+ if subsampled_index_dims is None:
+ subsampled_index_dims = output_index_dims
+ self.subsampled_index_dims = subsampled_index_dims
+ self.concat_preprocessed_input = concat_preprocessed_input
+ self.final_project = final_project
+ self.position_encoding_only = position_encoding_only
+
+ # for multimodal autoencoding, we don't need the decoder cross-attention and final layer
+ # so then we will set position_encoding_only to True
+ if not self.position_encoding_only:
+ self.decoding_cross_attention = PerceiverLayer(
+ config,
+ is_cross_attention=True,
+ qk_channels=qk_channels,
+ v_channels=v_channels,
+ num_heads=num_heads,
+ q_dim=num_channels,
+ kv_dim=config.d_latents,
+ widening_factor=widening_factor,
+ use_query_residual=use_query_residual,
+ )
+ self.final_layer = nn.Linear(
+ num_channels, output_num_channels) if final_project else nn.Identity()
+
+ @property
+ def num_query_channels(self) -> int:
+ if self.position_encoding_type == "none": # Queries come from elsewhere
+ raise ValueError(
+ "You cannot calculate number of decoder query channels when position_encoding_type is set to none"
+ )
+ if self.position_encoding_only:
+ if "project_pos_dim" in self.position_encoding_kwargs:
+ return self.position_encoding_kwargs["project_pos_dim"]
+ return self.output_position_encodings.output_size()
+ if self.final_project:
+ return self.output_num_channels
+ return self.num_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if self.position_encoding_type == "none": # Queries come from elsewhere
+ raise ValueError(
+ "You cannot construct decoder queries when position_encoding_type is set to none")
+ if subsampled_points is not None:
+ # subsampled_points are the indices if the inputs would be flattened
+ # however, the inputs aren't flattened, that's why we use unravel_index
+ # to get the indices for the unflattened array
+ # unravel_index returns a tuple (x_idx, y_idx, ...)
+ # stack to get the [n, d] tensor of coordinates
+ indices = list(
+ torch.from_numpy(x) for x in np.unravel_index(subsampled_points.cpu(), self.output_index_dims)
+ )
+ pos = torch.stack(indices, dim=1)
+ batch_size = inputs.shape[0]
+ # Map these coordinates to [-1, 1]
+ pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
+ pos = torch.broadcast_to(
+ pos[None], [batch_size, pos.shape[0], pos.shape[1]])
+ # Construct the position encoding.
+ if self.position_encoding_type == "trainable":
+ pos_emb = self.output_position_encodings(batch_size)
+ elif self.position_encoding_type == "fourier":
+ pos_emb = self.output_position_encodings(
+ self.output_index_dims, batch_size=batch_size, device=inputs.device, pos=pos
+ )
+
+ # Optionally project them to a target dimension.
+ pos_emb = self.positions_projection(pos_emb)
+ pos_emb = torch.reshape(
+ pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
+ else:
+ batch_size = inputs.shape[0]
+ index_dims = inputs.shape[2:]
+
+ # Construct the position encoding.
+ if self.position_encoding_type == "trainable":
+ pos_emb = self.output_position_encodings(batch_size)
+ elif self.position_encoding_type == "fourier":
+ pos_emb = self.output_position_encodings(
+ index_dims, batch_size, device=inputs.device)
+
+ # Optionally project them to a target dimension.
+ pos_emb = self.positions_projection(pos_emb)
+
+ if self.concat_preprocessed_input:
+ if inputs_without_pos is None:
+ raise ValueError(
+ "Value is required for inputs_without_pos if concat_preprocessed_input is True")
+ pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)
+
+ return pos_emb
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ # Cross-attention decoding.
+ # key, value: B x N x K; query: B x M x K
+ # Attention maps -> B x N x M
+ # Output -> B x M x K
+ cross_attentions = () if output_attentions else None
+
+ layer_outputs = self.decoding_cross_attention(
+ query,
+ attention_mask=query_mask,
+ head_mask=None,
+ inputs=z,
+ inputs_mask=None,
+ output_attentions=output_attentions,
+ )
+ output = layer_outputs[0]
+
+ if output_attentions:
+ cross_attentions = cross_attentions + (layer_outputs[1],)
+
+ logits = self.final_layer(output)
+
+ return PerceiverDecoderOutput(logits=logits, cross_attentions=cross_attentions)
+
+
+class PerceiverClassificationDecoder(PerceiverAbstractDecoder):
+ """
+ Cross-attention based classification decoder. Light-weight wrapper of [`PerceiverBasicDecoder`] for logit output.
+ Will turn the output of the Perceiver encoder which is of shape (batch_size, num_latents, d_latents) to a tensor of
+ shape (batch_size, num_labels). The queries are of shape (batch_size, 1, num_labels).
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config, **decoder_kwargs):
+ super().__init__()
+
+ self.num_labels = config.num_labels
+ self.decoder = PerceiverBasicDecoder(
+ config,
+ output_num_channels=self.num_labels,
+ output_index_dims=1, # Predict a single logit array.
+ **decoder_kwargs,
+ )
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ return self.decoder.decoder_query(
+ inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_points
+ )
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+
+ # B x 1 x num_classes -> B x num_classes
+ logits = decoder_outputs.logits[:, 0, :]
+
+ return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
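+
+# Shape sketch (descriptive, assuming the decoder kwargs used by the classification models
+# above): the decoder query is (batch_size, 1, d_latents), i.e. a single trainable index;
+# cross-attention against latents of shape (batch_size, num_latents, d_latents) followed by
+# the final projection yields (batch_size, 1, num_labels), and the [:, 0, :] slice above
+# returns (batch_size, num_labels) logits.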
+
+
+class PerceiverOpticalFlowDecoder(PerceiverAbstractDecoder):
+ """Cross-attention based optical flow decoder."""
+
+ def __init__(self, config, output_image_shape, output_num_channels=2, rescale_factor=100.0, **decoder_kwargs):
+ super().__init__()
+
+ self.output_image_shape = output_image_shape
+ self.output_num_channels = output_num_channels
+ self.rescale_factor = rescale_factor
+ self.decoder = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels, **decoder_kwargs)
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if subsampled_points is not None:
+ raise ValueError("FlowDecoder doesn't support subsampling yet.")
+ return inputs
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+ preds = decoder_outputs.logits
+ # Output flow and rescale.
+ preds /= self.rescale_factor
+ preds = preds.reshape(
+ [preds.shape[0]] + list(self.output_image_shape) + [preds.shape[-1]])
+ return PerceiverDecoderOutput(logits=preds, cross_attentions=decoder_outputs.cross_attentions)
+
+
+class PerceiverRGBDecoder(PerceiverAbstractDecoder):
+ """Cross-attention based optical flow decoder."""
+
+ def __init__(self, config, output_num_channels=3, upsample_mode=None, upsample_value=None, **decoder_kwargs):
+ super().__init__()
+
+ if upsample_value != 1:
+ self.upsample_mode = upsample_mode
+ self.upsample_value = upsample_value
+ if self.upsample_mode == 'convex':
+ output_num_channels_mask = 9 * upsample_value ** 2
+ self.decoder_mask = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels_mask, **decoder_kwargs)
+ elif self.upsample_mode == 'bilinear':
+ from vidar.utils.tensor import interpolate
+ from functools import partial
+ self.interpolate = partial(interpolate, scale_factor=upsample_value,
+ size=None, mode='bilinear', align_corners=True)
+ elif self.upsample_mode == 'unpack':
+ output_num_channels *= upsample_value ** 2
+ self.interpolate = torch.nn.PixelShuffle(upsample_value)
+ else:
+ self.upsample_mode = None
+
+ self.output_num_channels = output_num_channels
+ self.decoder = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels, **decoder_kwargs)
+ self.sigmoid = torch.nn.Sigmoid()
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if subsampled_points is not None:
+ raise ValueError("FlowDecoder doesn't support subsampling yet.")
+ return inputs
+
+ def forward(self, query, z, shape=None, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+ pred = decoder_outputs.logits
+
+ if shape is not None:
+ pred = pred.reshape(
+ [pred.shape[0]] + list(shape) + [pred.shape[-1]]).permute(0, 3, 1, 2)
+
+ pred = self.sigmoid(pred)
+
+ if self.upsample_mode == 'convex':
+ mask = self.decoder_mask(
+ query, z, output_attentions=output_attentions).logits
+ if shape is not None:
+ mask = mask.reshape(
+ [mask.shape[0]] + list(shape) + [mask.shape[-1]]).permute(0, 3, 1, 2)
+ pred = upsample_tensor(pred, mask, up=self.upsample_value)
+ elif self.upsample_mode == 'bilinear':
+ pred = self.interpolate(pred)
+ elif self.upsample_mode == 'unpack':
+ pred = self.interpolate(pred)
+
+ return PerceiverDecoderOutput(logits=pred, cross_attentions=decoder_outputs.cross_attentions)
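+
+# Note on the upsampling modes above (descriptive, partly assumed): 'bilinear' and 'unpack'
+# (PixelShuffle) operate on the prediction alone, while 'convex' additionally decodes a mask
+# with 9 * upsample_value ** 2 channels, presumably per-pixel weights over a 3x3 neighborhood
+# consumed by `upsample_tensor` (defined or imported earlier in this module) for RAFT-style
+# convex upsampling. PerceiverDepthDecoder below reuses the same scheme.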
+
+
+class PerceiverDepthDecoder(PerceiverAbstractDecoder):
+ """Cross-attention based optical flow decoder."""
+
+ def __init__(self, config, output_num_channels=1, upsample_mode=None, upsample_value=None,
+ min_depth=0.1, max_depth=100.0, output_mode='inv_depth', **decoder_kwargs):
+ super().__init__()
+
+ if upsample_value != 1:
+ self.upsample_mode = upsample_mode
+ self.upsample_value = upsample_value
+ if self.upsample_mode == 'convex':
+ output_num_channels_mask = 9 * upsample_value ** 2
+ self.decoder_mask = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels_mask, **decoder_kwargs)
+ elif self.upsample_mode == 'bilinear':
+ from vidar.utils.tensor import interpolate
+ from functools import partial
+ self.interpolate = partial(interpolate, scale_factor=upsample_value,
+ size=None, mode='bilinear', align_corners=True)
+ elif self.upsample_mode == 'unpack':
+ output_num_channels *= upsample_value ** 2
+ self.interpolate = torch.nn.PixelShuffle(upsample_value)
+ else:
+ self.upsample_mode = None
+
+ self.output_num_channels = output_num_channels
+ self.decoder = PerceiverBasicDecoder(
+ config, output_num_channels=output_num_channels, **decoder_kwargs)
+
+ self.output_mode = output_mode
+ if self.output_mode == 'inv_depth':
+ from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth
+ self.sigmoid = torch.nn.Sigmoid()
+ self.sigmoid_to_depth = SigmoidToInvDepth(
+ min_depth=min_depth, max_depth=max_depth, return_depth=True)
+ elif self.output_mode == 'log_depth':
+ from vidar.arch.blocks.depth.SigmoidToLogDepth import SigmoidToLogDepth
+ self.sigmoid_to_log_depth = SigmoidToLogDepth()
+ else:
+ raise ValueError('Invalid depth output mode')
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ if subsampled_points is not None:
+ raise ValueError("FlowDecoder doesn't support subsampling yet.")
+ return inputs
+
+ def forward(self, query, z, shape=None, query_mask=None, output_attentions=False):
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+ pred = decoder_outputs.logits
+
+ if shape is not None:
+ pred = pred.reshape(
+ [pred.shape[0]] + list(shape) + [pred.shape[-1]]).permute(0, 3, 1, 2)
+
+ if self.output_mode == 'inv_depth':
+ pred = self.sigmoid_to_depth(self.sigmoid(pred))
+ elif self.output_mode == 'log_depth':
+ pred = self.sigmoid_to_log_depth(pred)
+
+ if self.upsample_mode == 'convex':
+ mask = self.decoder_mask(
+ query, z, output_attentions=output_attentions).logits
+ if shape is not None:
+ mask = mask.reshape(
+ [mask.shape[0]] + list(shape) + [mask.shape[-1]]).permute(0, 3, 1, 2)
+ pred = upsample_tensor(pred, mask, up=self.upsample_value)
+ elif self.upsample_mode == 'bilinear':
+ pred = self.interpolate(pred)
+ elif self.upsample_mode == 'unpack':
+ pred = self.interpolate(pred)
+
+ return PerceiverDecoderOutput(logits=pred, cross_attentions=decoder_outputs.cross_attentions)
+
+
+class PerceiverBasicVideoAutoencodingDecoder(PerceiverAbstractDecoder):
+ """
+ Cross-attention based video-autoencoding decoder. Light-weight wrapper of [*PerceiverBasicDecoder*] with video
+ reshaping logic.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ output_shape (`List[int]`):
+ Shape of the output as (batch_size, num_frames, height, width), excluding the channel dimension.
+ position_encoding_type (`str`):
+ The type of position encoding to use. Can be either "trainable", "fourier", or "none".
+ """
+
+ def __init__(self, config, output_shape, position_encoding_type, **decoder_kwargs):
+ super().__init__()
+ if len(output_shape) != 4: # B, T, H, W
+ raise ValueError(
+ f"Expected rank 4 output_shape, got {output_shape}.")
+ # Build the decoder components:
+ self.output_shape = output_shape
+ self.output_num_channels = decoder_kwargs["output_num_channels"]
+
+ self.decoder = PerceiverBasicDecoder(
+ config,
+ output_index_dims=self.output_shape[1:4], # T*H*W
+ position_encoding_type=position_encoding_type,
+ **decoder_kwargs,
+ )
+
+ @property
+ def num_query_channels(self) -> int:
+ return self.decoder.num_query_channels
+
+ def decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
+ return self.decoder.decoder_query(
+ inputs,
+ modality_sizes=modality_sizes,
+ inputs_without_pos=inputs_without_pos,
+ subsampled_points=subsampled_points,
+ )
+
+ def forward(self, query, z, query_mask=None):
+ decoder_outputs = self.decoder(query, z)
+ logits = decoder_outputs.logits
+
+ logits = torch.reshape(logits, self.output_shape + [logits.shape[-1]])
+ return PerceiverDecoderOutput(logits=logits, cross_attentions=decoder_outputs.cross_attentions)
+
+
+def restructure(modality_sizes: ModalitySizeType, inputs: torch.Tensor) -> Mapping[str, torch.Tensor]:
+ """
+ Partitions a [B, N, C] tensor into tensors for each modality.
+
+ Args:
+ modality_sizes
+ dict specifying the size of the modality
+ inputs:
+ input tensor
+
+ Returns:
+ dict mapping name of modality to its associated tensor.
+ """
+ outputs = {}
+ index = 0
+ # Apply a predictable ordering to the modalities
+ for modality in sorted(modality_sizes.keys()):
+ size = modality_sizes[modality]
+ inp = inputs[:, index: index + size]
+ index += size
+ outputs[modality] = inp
+ return outputs
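+
+# Illustrative example: with modality_sizes = {"audio": 4, "image": 6} and inputs of shape
+# [B, 10, C], restructure returns {"audio": inputs[:, 0:4], "image": inputs[:, 4:10]};
+# modalities are sliced in sorted-name order, matching the ordering used by
+# PerceiverMultimodalDecoder.decoder_query below.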
+
+
+class PerceiverMultimodalDecoder(PerceiverAbstractDecoder):
+ """
+ Multimodal decoding by composing uni-modal decoders. The *modalities* argument of the constructor is a dictionary
+ mapping modality name to the decoder of that modality. That decoder will be used to construct queries for that
+ modality. Modality-specific queries are padded with trainable modality-specific parameters, after which they are
+ concatenated along the time dimension.
+
+ Next, there is a shared cross attention operation across all modalities.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ modalities (`Dict[str, PerceiverAbstractDecoder]`):
+ Dictionary mapping modality name to the decoder of that modality.
+ num_outputs (`int`):
+ The number of outputs of the decoder.
+ output_num_channels (`int`):
+ The number of channels in the output.
+ min_padding_size (`int`, *optional*, defaults to 2):
+ The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
+ channels across all modalities plus min_padding_size.
+ subsampled_index_dims (`Dict[str, PerceiverAbstractDecoder]`, *optional*):
+ Dictionary mapping modality name to the subsampled index dimensions to use for the decoder query of that
+ modality.
+ """
+
+ def __init__(
+ self,
+ config,
+ modalities,
+ num_outputs,
+ output_num_channels,
+ min_padding_size=2,
+ subsampled_index_dims=None,
+ **decoder_kwargs
+ ):
+ super().__init__()
+ self.modalities = nn.ModuleDict(modalities)
+ self.subsampled_index_dims = subsampled_index_dims
+ self.min_padding_size = min_padding_size
+ self.output_num_channels = output_num_channels
+ self.num_outputs = num_outputs
+ self.decoder = PerceiverBasicDecoder(
+ config,
+ output_index_dims=(num_outputs,),
+ output_num_channels=output_num_channels,
+ position_encoding_type="none",
+ num_channels=self.num_query_channels,
+ **decoder_kwargs,
+ )
+ self.padding = nn.ParameterDict(
+ {
+ modality: nn.Parameter(torch.randn(
+ 1, self.num_query_channels - decoder.num_query_channels))
+ for modality, decoder in modalities.items()
+ }
+ )
+
+ @property
+ def num_query_channels(self) -> int:
+ max_channel_size = max(
+ decoder.num_query_channels for _, decoder in self.modalities.items())
+ common_channel_size = max_channel_size + self.min_padding_size
+ return common_channel_size
+
+ def decoder_query(self, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None):
+ # Partition the flat inputs among the different modalities
+ inputs = restructure(modality_sizes, inputs)
+
+ # Obtain modality-specific decoders' queries
+ subsampled_points = subsampled_points or dict()
+
+ decoder_queries = dict()
+ for modality, decoder in self.modalities.items():
+ # Get input_without_pos for this modality if it exists.
+ input_without_pos = None
+ if inputs_without_pos is not None:
+ input_without_pos = inputs_without_pos.get(modality, None)
+ query = decoder.decoder_query(
+ inputs=inputs[modality],
+ modality_sizes=None,
+ inputs_without_pos=input_without_pos,
+ subsampled_points=subsampled_points.get(modality, None),
+ )
+ decoder_queries[modality] = query
+
+ # Pad all queries with trainable position encodings to make them have the same channels
+
+ def embed(modality, x):
+ x = torch.reshape(
+ x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]])
+ pos = self.padding[modality]
+ pos = torch.broadcast_to(
+ pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]])
+ return torch.cat([x, pos], dim=2)
+
+ # Apply a predictable ordering to the modalities
+ return torch.cat(
+ [embed(modality, decoder_queries[modality]) for modality in sorted(self.modalities.keys())], dim=1
+ )
+
+ def forward(self, query, z, query_mask=None, output_attentions=False):
+ # B x 1 x num_classes -> B x num_classes
+ decoder_outputs = self.decoder(
+ query, z, output_attentions=output_attentions)
+
+ return decoder_outputs
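+
+# Descriptive sketch of the query construction above: each modality decoder produces queries
+# with its own channel count; num_query_channels is the maximum of those plus min_padding_size,
+# and the trainable `padding` parameters pad every modality's queries up to that common width
+# before they are concatenated along the index (sequence) dimension and decoded with a single
+# shared cross-attention.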
+
+
+# Below: IO pre- and post-processor classes for Perceiver.
+def space_to_depth(frames: torch.Tensor, temporal_block_size: int = 1, spatial_block_size: int = 1) -> torch.Tensor:
+ """
+ Space to depth transform. Rearranges blocks of spatial data, into depth.
+
+ This function assumes the channels to be first, but will place the channels last after transformation.
+
+ Based on https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/15.
+ """
+ if len(frames.shape) == 4:
+ batch_size, num_channels, height, width = frames.shape
+ # split up dimensions (height by spatial_block_size, width by spatial_block_size)
+ frames = frames.view(
+ batch_size,
+ num_channels,
+ height // spatial_block_size,
+ spatial_block_size,
+ width // spatial_block_size,
+ spatial_block_size,
+ )
+ # move blocks to last dimension: (batch_size, H//bs, W//bs, bs, bs, C)
+ frames = frames.permute(0, 2, 4, 3, 5, 1).contiguous()
+ # concatenate blocks along channel dimension: (batch_size, H//bs, W//bs, bs*bs*C)
+ frames = frames.view(
+ batch_size,
+ height // spatial_block_size,
+ width // spatial_block_size,
+ (spatial_block_size ** 2) * num_channels,
+ )
+ return frames
+ elif len(frames.shape) == 5:
+ batch_size, time, num_channels, height, width = frames.shape
+ # split up dimensions (time by temporal_block_size, height by spatial_block_size, width by spatial_block_size)
+ frames = frames.view(
+ batch_size,
+ time // temporal_block_size,
+ temporal_block_size,
+ num_channels,
+ height // spatial_block_size,
+ spatial_block_size,
+ width // spatial_block_size,
+ spatial_block_size,
+ )
+ # move blocks to last dimension: (batch_size, T//ts, H//bs, W//bs, ts, bs, bs, C)
+ frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
+ # concatenate blocks along channel dimension: (batch_size, T//ts, H//bs, W//bs, ts*bs*bs*C)
+ frames = frames.view(
+ batch_size,
+ time // temporal_block_size,
+ height // spatial_block_size,
+ width // spatial_block_size,
+ temporal_block_size * (spatial_block_size ** 2) * num_channels,
+ )
+ return frames
+ else:
+ raise ValueError(
+ "Frames should be of rank 4 (batch, channels, height, width)"
+ " or rank 5 (batch, time, channels, height, width)"
+ )
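+
+# Shape sketch: a rank-4 input (B, C, H, W) with spatial_block_size=s becomes
+# (B, H // s, W // s, s * s * C); a rank-5 input (B, T, C, H, W) with temporal_block_size=t
+# additionally folds time, giving (B, T // t, H // s, W // s, t * s * s * C).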
+
+
+class Conv2dSamePadding(nn.Conv2d):
+ """
+ Conv2d layer with padding="same" support. Source:
+ https://gist.github.com/sumanmichael/4de9dee93f972d47c80c4ade8e149ea6
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Conv2dSamePadding, self).__init__(*args, **kwargs)
+ self.zero_pad_2d = nn.ZeroPad2d(
+ reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2)
+ for k in self.kernel_size[::-1]])
+ )
+
+ def forward(self, input):
+ return self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
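+
+# Descriptive note: for each kernel dimension k, the ZeroPad2d above pads
+# (k // 2 + k % 2 - 1, k // 2) pixels, i.e. (3, 3) for the 7x7 kernel used by
+# Conv2DDownsample below, emulating TensorFlow-style "same" padding for odd kernel sizes.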
+
+
+class Conv2DDownsample(nn.Module):
+ """Downsamples 4x by applying a 2D convolution and doing max pooling."""
+
+ def __init__(
+ self,
+ num_layers: int = 1,
+ in_channels: int = 3,
+ out_channels: int = 64,
+ use_batchnorm: bool = True,
+ ):
+ """
+ Constructs a Conv2DDownsample model.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 64):
+ The number of conv output channels.
+ use_batchnorm (`bool`, *optional*, defaults to `True`):
+ Whether to use batchnorm.
+ """
+ super().__init__()
+
+ self.conv = Conv2dSamePadding(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, bias=False
+ )
+ self.batchnorm = nn.BatchNorm2d(
+ num_features=out_channels) if use_batchnorm else nn.Identity()
+ self.relu = nn.ReLU()
+ self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ out = self.conv(inputs)
+ out = self.batchnorm(out)
+ out = self.relu(out)
+ out = self.max_pool(out)
+ return out
+
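+# Shape sketch (illustrative; not part of the original module). One block reduces
+# spatial resolution by 4: a stride-2 "same"-padded conv followed by stride-2
+# max pooling:
+#
+#     block = Conv2DDownsample(in_channels=3, out_channels=64)
+#     out = block(torch.randn(1, 3, 224, 224))
+#     assert out.shape == (1, 64, 56, 56)
+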
+
+def generate_fourier_features(pos, num_bands, max_resolution=(224, 224), concat_pos=True, sine_only=False):
+ """
+ Generate a Fourier frequency position encoding with linear spacing.
+
+ Args:
+ pos (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`):
+ The Tensor containing the position of n points in d dimensional space.
+ num_bands (`int`):
+ The number of frequency bands (K) to use.
+ max_resolution (`Tuple[int]`, *optional*, defaults to (224, 224)):
+ The maximum resolution (i.e. the number of pixels per dim). A tuple representing resolution for each dimension.
+ concat_pos (`bool`, *optional*, defaults to `True`):
+ Whether to concatenate the input position encoding to the Fourier features.
+ sine_only (`bool`, *optional*, defaults to `False`):
+ Whether to use a single phase (sin) or two (sin/cos) for each frequency band.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, sequence_length, n_channels)`: The Fourier position embeddings. If
+ `concat_pos` is `True` and `sine_only` is `False`, output dimensions are ordered as: [dim_1, dim_2, ..., dim_d,
+ sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ..., sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d), cos(pi*f_1*dim_1),
+ ..., cos(pi*f_K*dim_1), ..., cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)], where dim_i is pos[:, i] and f_k is the
+ kth frequency band.
+ """
+
+ batch_size = pos.shape[0]
+ device = pos.device
+
+ min_freq = 1.0
+ # Nyquist frequency at the target resolution:
+ freq_bands = torch.stack(
+ [torch.linspace(start=min_freq, end=res / 2, steps=num_bands, device=device) for res in max_resolution], dim=0
+ )
+
+ # Get frequency bands for each spatial dimension.
+ # Output is size [n, d * num_bands]
+ per_pos_features = pos[0, :, :][:, :, None] * freq_bands[None, :, :]
+ per_pos_features = torch.reshape(
+ per_pos_features, [-1, np.prod(per_pos_features.shape[1:])])
+
+ if sine_only:
+ # Output is size [n, d * num_bands]
+ per_pos_features = torch.sin(np.pi * (per_pos_features))
+ else:
+ # Output is size [n, 2 * d * num_bands]
+ per_pos_features = torch.cat(
+ [torch.sin(np.pi * per_pos_features), torch.cos(np.pi * per_pos_features)], dim=-1
+ )
+ # Concatenate the raw input positions.
+ if concat_pos:
+ # Adds d bands to the encoding.
+ per_pos_features = torch.cat(
+ [pos, per_pos_features.expand(batch_size, -1, -1)], dim=-1)
+ return per_pos_features
+
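+# Channel-count sketch (illustrative; not part of the original module). For
+# d-dimensional positions with K bands the output has d*K channels with
+# sine_only, 2*d*K with sin/cos, plus d more when concat_pos is True:
+#
+#     pos = build_linear_positions((16, 16)).reshape(1, 256, 2)     # (B, n, d)
+#     feats = generate_fourier_features(pos, num_bands=64, max_resolution=(16, 16))
+#     assert feats.shape == (1, 256, 2 + 2 * 2 * 64)                # 258 channels
+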
+
+def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
+ """
+ Generate an array of position indices for an N-D input array.
+
+ Args:
+ index_dims (`List[int]`):
+ The shape of the index dimensions of the input array.
+ output_range (`Tuple[float]`, *optional*, defaults to `(-1.0, 1.0)`):
+ The min and max values taken by each input index dimension.
+
+ Returns:
+ `torch.FloatTensor` of shape `(index_dims[0], index_dims[1], .., index_dims[-1], N)`.
+ """
+
+ def _linspace(n_xels_per_dim):
+ return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
+
+ dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
+ array_index_grid = torch.meshgrid(*dim_ranges)
+
+ return torch.stack(array_index_grid, dim=-1)
+
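+# Value sketch (illustrative; not part of the original module). Positions are
+# linearly spaced in output_range and stacked per dimension, e.g. a 3x3 grid
+# yields coordinates in {-1.0, 0.0, 1.0}:
+#
+#     grid = build_linear_positions([3, 3])
+#     assert grid.shape == (3, 3, 2)
+#     assert grid[0, 0].tolist() == [-1.0, -1.0] and grid[2, 2].tolist() == [1.0, 1.0]
+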
+
+class PerceiverAbstractPositionEncoding(nn.Module, metaclass=abc.ABCMeta):
+ """Perceiver abstract position encoding."""
+
+ @property
+ @abc.abstractmethod
+ def num_dimensions(self) -> int:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def output_size(self, *args, **kwargs) -> int:
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def forward(self, batch_size, pos):
+ raise NotImplementedError
+
+
+class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
+ """Trainable position encoding."""
+
+ def __init__(self, index_dims, num_channels=128):
+ super().__init__()
+ self._num_channels = num_channels
+ self._index_dims = index_dims
+ index_dim = np.prod(index_dims)
+ self.position_embeddings = nn.Parameter(
+ torch.randn(index_dim, num_channels))
+
+ @property
+ def num_dimensions(self) -> int:
+ if isinstance(self._index_dims, int):
+ return 1
+ return len(self._index_dims)
+
+ def output_size(self, *args, **kwargs) -> int:
+ return self._num_channels
+
+ def forward(self, batch_size):
+ position_embeddings = self.position_embeddings
+
+ if batch_size is not None:
+ position_embeddings = position_embeddings.expand(
+ batch_size, -1, -1)
+ return position_embeddings
+
+
+def _check_or_build_spatial_positions(pos, index_dims, batch_size):
+ """
+ Checks or builds spatial position features (x, y, ...).
+
+ Args:
+ pos (`torch.FloatTensor`):
+ None, or an array of position features. If None, position features are built. Otherwise, their size is checked.
+ index_dims (`List[int]`):
+ An iterable giving the spatial/index size of the data to be featurized.
+ batch_size (`int`):
+ The batch size of the data to be featurized.
+
+ Returns:
+ `torch.FloatTensor` of shape `(batch_size, prod(index_dims))` an array of position features.
+ """
+ if pos is None:
+ pos = build_linear_positions(index_dims)
+ pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape)
+ pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
+ # else:
+ # # Just a warning label: you probably don't want your spatial features to
+ # # have a different spatial layout than your pos coordinate system.
+ # # But feel free to override if you think it'll work!
+ # if pos.shape[-1] != len(index_dims):
+ # raise ValueError("Spatial features have the wrong number of dimensions.")
+ return pos
+
+
+class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
+ """Fourier (Sinusoidal) position encoding."""
+
+ def __init__(self, num_bands, max_resolution, concat_pos=True, sine_only=False):
+ super().__init__()
+ self.num_bands = num_bands
+ self.max_resolution = max_resolution
+ self.concat_pos = concat_pos
+ self.sine_only = sine_only
+
+ @property
+ def num_dimensions(self) -> int:
+ return len(self.max_resolution)
+
+ def output_size(self):
+ """Returns size of positional encodings last dimension."""
+ num_dims = len(self.max_resolution)
+ encoding_size = self.num_bands * num_dims
+ if not self.sine_only:
+ encoding_size *= 2
+ if self.concat_pos:
+ encoding_size += self.num_dimensions
+
+ return encoding_size
+
+ def forward(self, index_dims, batch_size, device, pos=None):
+ pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
+ fourier_pos_enc = generate_fourier_features(
+ pos,
+ num_bands=self.num_bands,
+ max_resolution=self.max_resolution,
+ concat_pos=self.concat_pos,
+ sine_only=self.sine_only,
+ ).to(device)
+ return fourier_pos_enc
+
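+# Worked example (illustrative; not part of the original module). With
+# num_bands=64, max_resolution=(224, 224), concat_pos=True and sine_only=False,
+# output_size() is 64 * 2 * 2 + 2 = 258, matching generate_fourier_features above:
+#
+#     pos_enc = PerceiverFourierPositionEncoding(num_bands=64, max_resolution=(224, 224))
+#     assert pos_enc.output_size() == 258
+#     emb = pos_enc(index_dims=(224, 224), batch_size=1, device=torch.device("cpu"))
+#     assert emb.shape == (1, 224 * 224, 258)
+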
+
+class AbstractPreprocessor(nn.Module):
+ @property
+ def num_channels(self) -> int:
+ """Returns size of preprocessor output."""
+ raise NotImplementedError()
+
+
+class PerceiverTextPreprocessor(AbstractPreprocessor):
+ """
+ Text preprocessing for Perceiver Encoder. Can be used to embed `inputs` and add positional encodings.
+
+ The dimensionality of the embeddings is determined by the `d_model` attribute of the configuration.
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embeddings = nn.Embedding(
+ num_embeddings=config.vocab_size, embedding_dim=config.d_model)
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.d_model)
+
+ @property
+ def num_channels(self) -> int:
+ return self.config.d_model
+
+ def forward(self, inputs):
+ embeddings = self.embeddings(inputs)
+
+ seq_length = inputs.shape[1]
+ position_ids = torch.arange(0, seq_length, device=inputs.device)
+ embeddings = embeddings + self.position_embeddings(position_ids)
+
+ return embeddings, None, None
+
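+# Shape sketch (illustrative; assumes a PerceiverConfig-like object with
+# vocab_size, d_model and max_position_embeddings >= 128):
+#
+#     preprocessor = PerceiverTextPreprocessor(config)
+#     token_ids = torch.randint(0, config.vocab_size, (2, 128))
+#     embeddings, _, _ = preprocessor(token_ids)
+#     assert embeddings.shape == (2, 128, config.d_model)
+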
+
+class PerceiverEmbeddingDecoder(nn.Module):
+ """
+ Module to decode embeddings (for masked language modeling).
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.vocab_size = config.vocab_size
+ self.bias = nn.Parameter(torch.zeros(self.vocab_size))
+
+ def forward(self, hidden_states, embedding_layer):
+ batch_size, seq_len, d_model = hidden_states.shape
+ output = torch.matmul(hidden_states.reshape(
+ [-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
+ output = output + self.bias
+
+ return output.reshape([batch_size, seq_len, self.vocab_size])
+
+
+class PerceiverMultimodalPostprocessor(nn.Module):
+ """
+ Multimodal postprocessing for Perceiver. Can be used to combine modality-specific postprocessors into a single
+ postprocessor.
+
+ Args:
+ modalities (`Dict[str, PostprocessorType]`):
+ Dictionary mapping modality name to postprocessor class for that modality.
+ input_is_dict (`bool`, *optional*, defaults to `False`):
+ If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
+ False, input is a tensor which is sliced up during postprocessing by *modality_sizes*.
+ """
+
+ def __init__(self, modalities: Mapping[str, PostprocessorType], input_is_dict: bool = False):
+ super().__init__()
+ self.modalities = nn.ModuleDict(modalities)
+ self.input_is_dict = input_is_dict
+
+ def forward(
+ self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None
+ ) -> Mapping[str, torch.Tensor]:
+ if not self.input_is_dict:
+ # Slice up modalities by their sizes.
+ if modality_sizes is None:
+ raise ValueError(
+ "Modality sizes should be specified if input is not a dictionary.")
+ inputs = restructure(modality_sizes=modality_sizes, inputs=inputs)
+
+ outputs = {
+ modality: postprocessor(
+ inputs[modality], pos=pos, modality_sizes=None)
+ for modality, postprocessor in self.modalities.items()
+ }
+ return outputs
+
+
+class PerceiverClassificationPostprocessor(nn.Module):
+ """
+ Classification postprocessing for Perceiver. Can be used to convert the decoder output to classification logits.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ in_channels (`int`):
+ Number of channels in the input.
+ """
+
+ def __init__(self, config, in_channels):
+ super().__init__()
+ self.classifier = nn.Linear(in_channels, config.num_labels)
+
+ def forward(self, inputs, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
+ logits = self.classifier(inputs)
+ return logits[:, 0, :]
+
+
+class PerceiverAudioPostprocessor(nn.Module):
+ """
+ Audio postprocessing for Perceiver. Can be used to convert the decoder output to audio features.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ in_channels (`int`):
+ Number of channels in the input.
+ postproc_type (`str`, *optional*, defaults to `"patches"`):
+ Postprocessor type to use. Currently, only "patches" is supported.
+ """
+
+ def __init__(self, config, in_channels, postproc_type: str = "patches"):
+ super().__init__()
+
+ # to be supported: 'conv', 'patches', 'pixels'
+ if postproc_type not in ("patches",):
+ raise ValueError("Invalid postproc_type!")
+
+ # Architecture parameters:
+ self.classifier = nn.Linear(in_channels, config.samples_per_patch)
+
+ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
+
+ logits = self.classifier(inputs)
+ return torch.reshape(logits, [inputs.shape[0], -1])
+
+
+class PerceiverProjectionPostprocessor(nn.Module):
+ """
+ Projection postprocessing for Perceiver. Can be used to project the channels of the decoder output to a lower
+ dimension.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input.
+ out_channels (`int`):
+ Number of channels in the output.
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.classifier = nn.Linear(in_channels, out_channels)
+
+ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
+ logits = self.classifier(inputs)
+ return logits
+
+
+class PerceiverImagePreprocessor(AbstractPreprocessor):
+ """
+ Image preprocessing for Perceiver Encoder.
+
+ Note: the *out_channels* argument refers to the output channels of a convolutional layer, if *prep_type* is set to
+ "conv1x1" or "conv". If one adds absolute position embeddings, one must make sure the *num_channels* of the
+ position encoding kwargs are set equal to the *out_channels*.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ prep_type (`str`, *optional*, defaults to `"conv"`):
+ Preprocessing type. Can be "conv1x1", "conv", "patches", "pixels".
+ spatial_downsample (`int`, *optional*, defaults to 4):
+ Spatial downsampling factor.
+ temporal_downsample (`int`, *optional*, defaults to 1):
+ Temporal downsampling factor (only relevant in case a time dimension is present).
+ position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
+ Position encoding type. Can be "fourier" or "trainable".
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input.
+ out_channels (`int`, *optional*, defaults to 64):
+ Number of channels in the output.
+ conv_after_patching (`bool`, *optional*, defaults to `False`):
+ Whether to apply a convolutional layer after patching.
+ conv_after_patching_in_channels (`int`, *optional*, defaults to 54):
+ Number of channels in the input of the convolutional layer after patching.
+ conv2d_use_batchnorm (`bool`, *optional*, defaults to `True`):
+ Whether to use batch normalization in the convolutional layer.
+ concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
+ How to concatenate the position encoding to the input. Can be "concat" or "add".
+ project_pos_dim (`int`, *optional*, defaults to -1):
+ Dimension of the position encoding to project to. If -1, no projection is applied.
+ **position_encoding_kwargs (`Dict`, *optional*):
+ Keyword arguments for the position encoding.
+ """
+
+ def __init__(
+ self,
+ config,
+ prep_type="conv",
+ spatial_downsample: int = 4,
+ temporal_downsample: int = 1,
+ position_encoding_type: str = "fourier",
+ in_channels: int = 3,
+ out_channels: int = 64,
+ conv_after_patching: bool = False,
+ # only relevant when conv_after_patching = True
+ conv_after_patching_in_channels: int = 54,
+ conv2d_use_batchnorm: bool = True,
+ concat_or_add_pos: str = "concat",
+ project_pos_dim: int = -1,
+ **position_encoding_kwargs,
+ ):
+ super().__init__()
+ self.config = config
+
+ if prep_type not in ("conv", "patches", "pixels", "conv1x1"):
+ raise ValueError(f"Prep_type {prep_type} is invalid")
+
+ if concat_or_add_pos not in ["concat", "add"]:
+ raise ValueError(
+ f"Invalid value {concat_or_add_pos} for concat_or_add_pos.")
+
+ self.in_channels = in_channels
+ self.prep_type = prep_type
+ self.spatial_downsample = spatial_downsample
+ self.temporal_downsample = temporal_downsample
+ self.position_encoding_type = position_encoding_type
+ self.concat_or_add_pos = concat_or_add_pos
+ self.conv_after_patching = conv_after_patching
+ self.out_channels = out_channels
+
+ if self.prep_type == "conv":
+ # Downsampling with conv is currently restricted
+ convnet_num_layers = math.log(spatial_downsample, 4)
+ convnet_num_layers_is_int = convnet_num_layers == np.round(
+ convnet_num_layers)
+ if not convnet_num_layers_is_int or temporal_downsample != 1:
+ raise ValueError(
+ "Only powers of 4 expected for spatial and 1 expected for temporal downsampling with conv."
+ )
+ self.convnet = Conv2DDownsample(
+ in_channels=in_channels,
+ num_layers=int(convnet_num_layers),
+ out_channels=out_channels,
+ use_batchnorm=conv2d_use_batchnorm,
+ )
+
+ elif self.prep_type == "conv1x1":
+ if temporal_downsample != 1:
+ raise ValueError("Conv1x1 does not downsample in time.")
+ self.convnet_1x1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(1, 1),
+ # spatial_downsample is unconstrained for 1x1 convolutions.
+ stride=(spatial_downsample, spatial_downsample),
+ )
+
+ # Position embeddings
+ self.project_pos_dim = project_pos_dim
+ self.position_embeddings, self.positions_projection = build_position_encoding(
+ position_encoding_type=position_encoding_type,
+ out_channels=out_channels,
+ project_pos_dim=project_pos_dim,
+ **position_encoding_kwargs,
+ )
+
+ # Optional convolutional layer after patches.
+ self.conv_after_patches = (
+ nn.Linear(conv_after_patching_in_channels,
+ self.out_channels) if conv_after_patching else nn.Identity()
+ )
+
+ @property
+ def num_channels(self) -> int:
+        # The input has 2 index dimensions for images and 3 for videos (an extra
+        # temporal dimension); is_temporal flags whether that temporal dimension is present.
+ is_temporal = self.position_embeddings.num_dimensions > 2
+
+ # position embedding
+ if self.project_pos_dim > 0:
+ pos_dim = self.project_pos_dim
+ else:
+ pos_dim = self.position_embeddings.output_size()
+ if self.concat_or_add_pos == "add":
+ return pos_dim
+
+ # inputs
+ if self.conv_after_patching or self.prep_type in ("conv1x1", "conv"):
+ inp_dim = self.out_channels
+ elif self.prep_type == "pixels":
+ inp_dim = self.in_channels
+ if not is_temporal:
+ inp_dim = math.ceil(inp_dim / self.spatial_downsample)
+ elif self.prep_type == "patches":
+ if self.conv_after_patching:
+ inp_dim = self.out_channels
+ else:
+ inp_dim = self.in_channels * self.spatial_downsample ** 2
+ if is_temporal:
+ inp_dim *= self.temporal_downsample
+
+ return inp_dim + pos_dim
+
+ def _build_network_inputs(self, inputs: torch.Tensor, pos: torch.Tensor, network_input_is_1d: bool = True):
+ """
+ Construct the final input, including position encoding.
+
+ This method expects the inputs to always have channels as last dimension.
+
+ """
+ batch_size = inputs.shape[0]
+ index_dims = inputs.shape[1:-1]
+ indices = np.prod(index_dims)
+
+ # Flatten input features to a 1D index dimension if necessary.
+ if len(inputs.shape) > 3 and network_input_is_1d:
+ inputs = torch.reshape(inputs, [batch_size, indices, -1])
+
+ # Construct the position encoding.
+ if self.position_encoding_type == "trainable":
+ pos_enc = self.position_embeddings(batch_size)
+ elif self.position_encoding_type == "fourier":
+ pos_enc = self.position_embeddings(
+ index_dims, batch_size, device=inputs.device)
+
+ # Optionally project them to a target dimension.
+ pos_enc = self.positions_projection(pos_enc)
+
+ if not network_input_is_1d:
+ # Reshape pos to match the input feature shape
+ # if the network takes non-1D inputs
+ sh = inputs.shape
+ pos_enc = torch.reshape(pos_enc, list(sh)[:-1] + [-1])
+ if self.concat_or_add_pos == "concat":
+ inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
+ elif self.concat_or_add_pos == "add":
+ inputs_with_pos = inputs + pos_enc
+ return inputs_with_pos, inputs
+
+ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+ if self.prep_type == "conv":
+ # Convnet image featurization.
+ # Downsamples spatially by a factor of 4
+ inputs = self.convnet(inputs)
+
+ elif self.prep_type == "conv1x1":
+ # map inputs to self.out_channels
+ inputs = self.convnet_1x1(inputs)
+
+ elif self.prep_type == "pixels":
+ # if requested, downsamples in the crudest way
+ if inputs.ndim == 4:
+ inputs = inputs[:: self.spatial_downsample,
+ :: self.spatial_downsample]
+ elif inputs.ndim == 5:
+ inputs = inputs[
+ :, :: self.temporal_downsample, :, :: self.spatial_downsample, :: self.spatial_downsample
+ ]
+ else:
+ raise ValueError("Unsupported data format for pixels.")
+
+ elif self.prep_type == "patches":
+ # Space2depth featurization.
+ # Video: B x T x C x H x W
+ inputs = space_to_depth(
+ inputs, temporal_block_size=self.temporal_downsample, spatial_block_size=self.spatial_downsample
+ )
+
+ if inputs.ndim == 5 and inputs.shape[1] == 1:
+ # for flow
+ inputs = inputs.squeeze(dim=1)
+
+ # Optionally apply conv layer.
+ inputs = self.conv_after_patches(inputs)
+
+ if self.prep_type != "patches":
+ # move channels to last dimension, as the _build_network_inputs method below expects this
+ if inputs.ndim == 4:
+ inputs = torch.moveaxis(inputs, 1, -1)
+ elif inputs.ndim == 5:
+ inputs = torch.moveaxis(inputs, 2, -1)
+ else:
+ raise ValueError("Unsupported data format for conv1x1.")
+
+ inputs, inputs_without_pos = self._build_network_inputs(
+ inputs, pos, network_input_is_1d)
+ modality_sizes = None # Size for each modality, only needed for multimodal
+
+ return inputs, modality_sizes, inputs_without_pos
+
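+# Channel sketch (illustrative; the concrete numbers are assumptions). With
+# prep_type="conv", out_channels=64, concat_or_add_pos="concat" and a Fourier
+# encoding of 258 channels (e.g. num_bands=64 over a 2D grid), num_channels is
+# 64 + 258 = 322; with concat_or_add_pos="add" it would instead equal the
+# position-encoding width.
+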
+
+class PerceiverOneHotPreprocessor(AbstractPreprocessor):
+ """
+ One-hot preprocessor for Perceiver Encoder. Can be used to add a dummy index dimension to the input.
+
+ Args:
+ config ([`PerceiverConfig`]):
+ Model configuration.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config: PerceiverConfig = config
+
+ @property
+ def num_channels(self) -> int:
+ return self.config.num_labels
+
+ def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+ # Add a dummy index dimension.
+ inputs = inputs[:, None, :]
+
+ # No position encodings, so the 1st (input) and 3rd (inputs_without_pos)
+ # outputs are identical.
+ return inputs, None, inputs
+
+
+class PerceiverAudioPreprocessor(AbstractPreprocessor):
+ """
+ Audio preprocessing for Perceiver Encoder.
+
+ Args:
+ config ([*PerceiverConfig*]):
+ Model configuration.
+ prep_type (`str`, *optional*, defaults to `"patches"`):
+ Preprocessor type to use. Only "patches" is supported.
+ samples_per_patch (`int`, *optional*, defaults to 96):
+ Number of samples per patch.
+ position_encoding_type (`str`, *optional*, defaults to `"fourier"`):
+ Type of position encoding to use. Can be "trainable" or "fourier".
+ concat_or_add_pos (`str`, *optional*, defaults to `"concat"`):
+ How to concatenate the position encoding to the input. Can be "concat" or "add".
+ out_channels (`int`, *optional*, defaults to 64):
+ Number of channels in the output.
+ project_pos_dim (`int`, *optional*, defaults to -1):
+ Dimension of the position encoding to project to. If -1, no projection is applied.
+ **position_encoding_kwargs (`Dict`, *optional*):
+ Keyword arguments for the position encoding.
+ """
+
+ def __init__(
+ self,
+ config,
+ prep_type: str = "patches",
+ samples_per_patch: int = 96,
+ position_encoding_type: str = "fourier",
+ concat_or_add_pos: str = "concat",
+ out_channels=64,
+ project_pos_dim=-1,
+ **position_encoding_kwargs,
+ ):
+ super().__init__()
+ self.config = config
+
+ if prep_type not in ("patches",):
+ raise ValueError(
+ f"Prep_type {prep_type} is invalid, can only be 'patches'.")
+
+ if concat_or_add_pos not in ["concat", "add"]:
+ raise ValueError(
+ f"Concat_or_pos {concat_or_add_pos} is invalid, can only be 'concat' or 'add'.")
+
+ self.samples_per_patch = samples_per_patch
+ self.position_encoding_type = position_encoding_type
+ self.concat_or_add_pos = concat_or_add_pos
+ self.project_pos_dim = project_pos_dim
+
+ # Position embeddings
+ self.position_embeddings, self.positions_projection = build_position_encoding(
+ position_encoding_type=position_encoding_type,
+ out_channels=out_channels,
+ project_pos_dim=project_pos_dim,
+ **position_encoding_kwargs,
+ )
+
+ @property
+ def num_channels(self) -> int:
+ # position embedding
+ if self.project_pos_dim > 0:
+ pos_dim = self.project_pos_dim
+ else:
+ pos_dim = self.position_embeddings.output_size()
+ if self.concat_or_add_pos == "add":
+ return pos_dim
+ return self.samples_per_patch + pos_dim
+
+ def _build_network_inputs(self, inputs, pos):
+ """Construct the final input, including position encoding."""
+ batch_size = inputs.shape[0]
+ index_dims = inputs.shape[1:-1]
+
+ # Construct the position encoding.
+ if self.position_encoding_type == "trainable":
+ pos_enc = self.position_embeddings(batch_size)
+ elif self.position_encoding_type == "fourier":
+ pos_enc = self.position_embeddings(
+ index_dims, batch_size, device=inputs.device)
+
+ # Optionally project them to a target dimension.
+ pos_enc = self.positions_projection(pos_enc)
+
+ if self.concat_or_add_pos == "concat":
+ inputs_with_pos = torch.cat([inputs, pos_enc], dim=-1)
+ elif self.concat_or_add_pos == "add":
+ inputs_with_pos = inputs + pos_enc
+
+ return inputs_with_pos, inputs
+
+ def forward(self, inputs, pos, network_input_is_1d: bool = True):
+ inputs = torch.reshape(
+ inputs, [inputs.shape[0], -1, self.samples_per_patch])
+
+ inputs, inputs_without_pos = self._build_network_inputs(inputs, pos)
+ modality_sizes = None # Size for each modality, only needed for multimodal
+
+ return inputs, modality_sizes, inputs_without_pos
+
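+# Shape sketch (illustrative; the concrete numbers are assumptions). Raw audio is
+# regrouped into patches of samples_per_patch values before the position encoding
+# is attached: a (2, 1920, 1) waveform with samples_per_patch=96 becomes
+# (2, 20, 96) patches, and with a concatenated position encoding of pos_dim
+# channels the final output has 96 + pos_dim channels per patch.
+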
+
+class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
+ """
+ Multimodal preprocessing for Perceiver Encoder.
+
+ Inputs for each modality are preprocessed, then padded with trainable position embeddings to have the same number
+ of channels.
+
+ Args:
+ modalities (`Dict[str, PreprocessorType]`):
+ Dict mapping modality name to preprocessor.
+ mask_probs (`Dict[str, float]`):
+ Dict mapping modality name to masking probability of that modality.
+ min_padding_size (`int`, *optional*, defaults to 2):
+ The minimum padding size for all modalities. The final output will have num_channels equal to the maximum
+ channels across all modalities plus min_padding_size.
+ """
+
+ def __init__(
+ self,
+ modalities: Mapping[str, PreprocessorType],
+ mask_probs: Optional[Mapping[str, float]] = None,
+ min_padding_size: int = 2,
+ ):
+ super().__init__()
+ self.modalities = modalities
+ self.min_padding_size = min_padding_size
+ self.mask_probs = mask_probs if mask_probs is not None else dict()
+ self.padding = nn.ParameterDict(
+ {
+ modality: nn.Parameter(torch.randn(
+ 1, self.num_channels - preprocessor.num_channels))
+ for modality, preprocessor in modalities.items()
+ }
+ )
+ self.mask = nn.ParameterDict(
+ {modality: nn.Parameter(torch.randn(1, self.num_channels))
+ for modality, _ in self.mask_probs.items()}
+ )
+
+ @property
+ def num_channels(self) -> int:
+ max_channel_size = max(processor.num_channels for _,
+ processor in self.modalities.items())
+ common_channel_size = max_channel_size + self.min_padding_size
+ return common_channel_size
+
+ def forward(
+ self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
+ ) -> PreprocessorOutputType:
+ padded = {}
+ modality_sizes = {}
+ inputs_without_pos = {}
+ for modality, preprocessor in self.modalities.items():
+ # preprocess each modality using the respective preprocessor.
+ output, _, inputs_without_pos[modality] = preprocessor(
+ inputs[modality], pos=pos, network_input_is_1d=network_input_is_1d
+ )
+
+ # pad to the same common_channel_size.
+ batch_size, num_samples, num_channels = output.shape
+ pos_enc = self.padding[modality].expand(batch_size, -1, -1)
+
+ padding = torch.broadcast_to(
+ pos_enc,
+ [batch_size, num_samples, self.num_channels - num_channels],
+ )
+ output_padded = torch.cat([output, padding], dim=2)
+
+ # mask if required
+ if modality in self.mask_probs:
+ mask_token = self.mask[modality].expand(batch_size, -1, -1)
+ mask_prob = self.mask_probs[modality]
+ mask = torch.bernoulli(torch.full(
+ [batch_size, num_samples], mask_prob))
+ mask = torch.unsqueeze(mask, dim=2).to(mask_token.device)
+ output_padded = (1 - mask) * output_padded + mask * mask_token
+
+ padded[modality] = output_padded
+ modality_sizes[modality] = output_padded.shape[1]
+
+ # Apply a predictable ordering to the modalities
+ padded_ls = [padded[k] for k in sorted(padded.keys())]
+
+ # Finally, concatenate along the time dimension
+ final_inputs = torch.cat(padded_ls, dim=1)
+
+ return final_inputs, modality_sizes, inputs_without_pos
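+
+
+# Padding sketch (illustrative; the concrete numbers are assumptions). If an image
+# preprocessor reports 243 channels and an audio preprocessor reports 401
+# channels, then with min_padding_size=2 the common width is
+# num_channels = 401 + 2 = 403; each modality is padded up to 403 channels and the
+# modalities are concatenated along the sequence (time) dimension.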
diff --git a/vidar/arch/networks/pose/ConvPoseNet.py b/vidar/arch/networks/pose/ConvPoseNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..abed96abd72f2350779c181b80d65e0d733f54ab
--- /dev/null
+++ b/vidar/arch/networks/pose/ConvPoseNet.py
@@ -0,0 +1,86 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+def conv_gn(in_planes, out_planes, kernel_size=3):
+ """
+ Convolutional block with GroupNorm
+
+ Parameters
+ ----------
+ in_planes : int
+ Number of input channels
+ out_planes : int
+ Number of output channels
+ kernel_size : int
+ Convolutional kernel size
+
+ Returns
+ -------
+ layers : nn.Sequential
+ Sequence of Conv2D + GroupNorm + ReLU
+ """
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2, stride=2),
+ nn.GroupNorm(16, out_planes),
+ nn.ReLU(inplace=True)
+ )
+
+
+class ConvPoseNet(nn.Module):
+ """Pose network """
+
+ def __init__(self, nb_ref_imgs=2, rotation_mode='euler', **kwargs):
+
+ nb_ref_imgs = 2
+ rotation_mode = 'euler'
+
+ super().__init__()
+ self.nb_ref_imgs = nb_ref_imgs
+ self.rotation_mode = rotation_mode
+
+ conv_channels = [16, 32, 64, 128, 256, 256, 256]
+ self.conv1 = conv_gn(3 * (1 + self.nb_ref_imgs), conv_channels[0], kernel_size=7)
+ self.conv2 = conv_gn(conv_channels[0], conv_channels[1], kernel_size=5)
+ self.conv3 = conv_gn(conv_channels[1], conv_channels[2])
+ self.conv4 = conv_gn(conv_channels[2], conv_channels[3])
+ self.conv5 = conv_gn(conv_channels[3], conv_channels[4])
+ self.conv6 = conv_gn(conv_channels[4], conv_channels[5])
+ self.conv7 = conv_gn(conv_channels[5], conv_channels[6])
+
+ self.pose_pred = nn.Conv2d(conv_channels[6], 6 * self.nb_ref_imgs,
+ kernel_size=1, padding=0)
+
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize weights"""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ nn.init.xavier_uniform_(m.weight.data)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, image, context):
+ """Network forward pass"""
+
+ assert (len(context) == self.nb_ref_imgs)
+ input = [image]
+ input.extend(context)
+ input = torch.cat(input, 1)
+ out_conv1 = self.conv1(input)
+ out_conv2 = self.conv2(out_conv1)
+ out_conv3 = self.conv3(out_conv2)
+ out_conv4 = self.conv4(out_conv3)
+ out_conv5 = self.conv5(out_conv4)
+ out_conv6 = self.conv6(out_conv5)
+ out_conv7 = self.conv7(out_conv6)
+
+ pose = self.pose_pred(out_conv7)
+ pose = pose.mean(3).mean(2)
+ pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6)
+
+ return pose
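+
+
+# Usage sketch (illustrative; not part of the original module). With the default
+# two reference images the network expects the target and context frames
+# concatenated along channels (3 * 3 = 9) and predicts a scaled 6-DoF vector
+# (3 translation + 3 rotation) per reference image:
+#
+#     net = ConvPoseNet(nb_ref_imgs=2)
+#     image = torch.randn(1, 3, 192, 640)
+#     context = [torch.randn(1, 3, 192, 640) for _ in range(2)]
+#     pose = net(image, context)
+#     assert pose.shape == (1, 2, 6)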
diff --git a/vidar/arch/networks/pose/PoseNet.py b/vidar/arch/networks/pose/PoseNet.py
new file mode 100755
index 0000000000000000000000000000000000000000..3dccc04d40f89886e8cfa700fa37d40927912363
--- /dev/null
+++ b/vidar/arch/networks/pose/PoseNet.py
@@ -0,0 +1,49 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.decoders.PoseDecoder import PoseDecoder
+from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder
+from vidar.geometry.pose_utils import vec2mat
+
+
+class PoseNet(BaseNet, ABC):
+ """
+ Pose Network
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.networks['pose_encoder'] = \
+ ResNetEncoder(cfg)
+
+ self.networks['pose'] = \
+ PoseDecoder(
+ num_ch_enc=self.networks['pose_encoder'].num_ch_enc,
+ num_input_features=1,
+ num_frames_to_predict_for=2
+ )
+
+ def forward(self, rgb, invert):
+ """Network forward pass"""
+
+ rgb = torch.cat(rgb[::-1] if invert else rgb, 1)
+ feats = self.networks['pose_encoder'](rgb)
+ rotation, translation = self.networks['pose']([feats])
+ transformation = vec2mat(
+ rotation[:, 0], translation[:, 0], invert=invert)
+
+ return {
+ 'rotation': rotation,
+ 'translation': translation,
+ 'transformation': transformation,
+ }
+
diff --git a/vidar/arch/networks/transformers/MatchModule.py b/vidar/arch/networks/transformers/MatchModule.py
new file mode 100644
index 0000000000000000000000000000000000000000..44114c9d118bf647de4806e6d059e68d37109cf5
--- /dev/null
+++ b/vidar/arch/networks/transformers/MatchModule.py
@@ -0,0 +1,46 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+
+import torch
+
+from vidar.arch.networks.BaseNet import BaseNet
+from vidar.arch.networks.layers.depthformer.transformer_net import TransformerNet
+from vidar.utils.config import cfg_has
+
+
+class MatchModule(BaseNet, ABC):
+ """
+ Feature matching module (https://arxiv.org/abs/2204.07616)
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+
+ self.downsample = cfg_has(cfg, 'downsample', 3)
+ self.set_attr(cfg, 'monocular', False)
+ self.set_attr(cfg, 'decoder_type', 'regression')
+ self.set_attr(cfg, 'pos_enc', True)
+ self.set_attr(cfg, 'calc_right', True)
+
+ self.match_module = TransformerNet(cfg, decoder_type=self.decoder_type)
+ if cfg_has(cfg, 'fix_layers', True):
+ self.match_module.fix_layers()
+ self.set_attr(cfg, 'preprocess', False)
+
+ def forward(self, target, context, device, cam):
+ """Network forward pass"""
+
+ bs, _, h, w = target.size()
+
+ downsample = 4
+ col_offset = int(downsample / 2)
+ row_offset = int(downsample / 2)
+ sampled_cols = torch.arange(col_offset, w, downsample)[None,].expand(bs, -1).to(device)
+ sampled_rows = torch.arange(row_offset, h, downsample)[None,].expand(bs, -1).to(device)
+
+ return self.match_module(target, context, sampled_rows, sampled_cols, cam)
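+
+
+# Sampling sketch (illustrative; not part of the original module). With the
+# hard-coded downsample factor of 4 and an input width of 640, the sampled
+# columns are torch.arange(2, 640, 4), i.e. 160 positions per row (one per
+# 4-pixel block, offset to the block center); sampled_rows is built analogously
+# from the height.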
diff --git a/vidar/core/__init__.py b/vidar/core/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/core/checkpoint.py b/vidar/core/checkpoint.py
new file mode 100755
index 0000000000000000000000000000000000000000..5ed6eba1473283a7e35006f33aa77e474a34f41d
--- /dev/null
+++ b/vidar/core/checkpoint.py
@@ -0,0 +1,251 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+
+from vidar.utils.config import cfg_has
+from vidar.utils.logging import pcolor
+
+
+class ModelCheckpoint:
+ """
+ Class for model checkpointing
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ verbose : Bool
+ Print information on screen if enabled
+ """
+ def __init__(self, cfg, verbose=False):
+ super().__init__()
+
+ # Create checkpoint folder
+ self.folder = cfg_has(cfg, 'folder', None)
+ self.name = cfg_has(cfg, 'name', datetime.now().strftime("%Y-%m-%d_%Hh%Mm%Ss"))
+ if self.folder:
+ self.path = os.path.join(self.folder, self.name)
+ os.makedirs(self.path, exist_ok=True)
+ else:
+ self.path = None
+
+ # Exclude folders
+ self.excludes = ['sandbox']
+
+ # If there is no folder, only track metrics
+ self.tracking_only = self.path is None
+
+ # Store arguments
+ self.keep_top = cfg_has(cfg, 'keep_top', -1)
+ self.dataset = cfg_has(cfg, 'dataset', [])
+ self.monitor = cfg_has(cfg, 'monitor', [])
+ self.mode = cfg_has(cfg, 'mode', [])
+
+ # Number of metrics to track
+ self.num_tracking = len(self.mode)
+
+ # Prepare s3 bucket
+ if cfg_has(cfg, 's3_bucket'):
+ self.s3_path = f's3://{cfg.s3_bucket}/{self.name}'
+ self.s3_url = f'https://s3.console.aws.amazon.com/s3/buckets/{self.s3_path[5:]}'
+ else:
+ self.s3_path = self.s3_url = None
+
+ # Get starting information
+        self.torch_inf = torch.tensor(np.inf)
+ mode_dict = {
+ 'min': (self.torch_inf, 'min'),
+ 'max': (-self.torch_inf, 'max'),
+ 'auto': (-self.torch_inf, 'max') if \
+ 'acc' in self.monitor or \
+ 'a1' in self.monitor or \
+ 'fmeasure' in self.monitor \
+ else (self.torch_inf, 'min'),
+ }
+
+ if self.mode:
+ self.top = [[] for _ in self.mode]
+ self.store_val = [[] for _ in self.mode]
+ self.previous = [0 for _ in self.mode]
+ self.best = [mode_dict[m][0] for m in self.mode]
+ self.mode = [mode_dict[m][1] for m in self.mode]
+ else:
+ self.top = []
+
+ # Print if requested
+ if verbose:
+ self.print()
+
+ # Save if requested
+ if cfg_has(cfg, 'save_code', False):
+ self.save_code()
+ if self.s3_url:
+ self.sync_s3(verbose=False)
+
+ def print(self):
+ """Print information on screen"""
+ font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
+ font_name = {'color': 'red', 'attrs': ('bold',)}
+ font_underline = {'color': 'red', 'attrs': ('underline',)}
+
+ print(pcolor('#' * 60, **font_base))
+ if self.path:
+ print(pcolor('### Checkpoint: ', **font_base) + \
+ pcolor('{}/{}'.format(self.folder, self.name), **font_name))
+ if self.s3_url:
+ print(pcolor('### ', **font_base) + \
+ pcolor('{}'.format(self.s3_url), **font_underline))
+ else:
+ print(pcolor('### Checkpoint: ', **font_base) + \
+ pcolor('Tracking only', **font_name))
+ print(pcolor('#' * 60, **font_base))
+
+ @staticmethod
+ def save_model(wrapper, name, epoch):
+ """Save model"""
+ torch.save({
+ 'config': wrapper.cfg, 'epoch': epoch,
+ 'state_dict': wrapper.arch.state_dict(),
+ }, name)
+
+ @staticmethod
+ def del_model(name):
+ """Delete model"""
+ if os.path.isfile(name):
+ os.remove(name)
+
+ def save_code(self):
+ """Save code in the models folder"""
+ excludes = ' '.join([f'--exclude {exclude}' for exclude in self.excludes])
+ os.system(f"tar cfz {self.path}/{self.name}.tar.gz {excludes} *")
+
+ def sync_s3(self, verbose=True):
+ """Sync saved models with the s3 bucket"""
+
+ font_base = {'color': 'magenta', 'attrs': ('bold', 'dark')}
+ font_name = {'color': 'magenta', 'attrs': ('bold',)}
+
+ if verbose:
+ print(pcolor('Syncing ', **font_base) +
+ pcolor('{}'.format(self.path), **font_name) +
+ pcolor(' -> ', **font_base) +
+ pcolor('{}'.format(self.s3_path), **font_name))
+
+ command = f'aws s3 sync {self.path} {self.s3_path} ' \
+ f'--acl bucket-owner-full-control --quiet --delete'
+ os.system(command)
+
+ def print_improvements(self, key, value, idx, is_best):
+ """Print color-coded changes in tracked metrics"""
+
+ font1 = {'color': 'cyan', 'attrs':('dark', 'bold')}
+ font2 = {'color': 'cyan', 'attrs': ('bold',)}
+ font3 = {'color': 'yellow', 'attrs': ('bold',)}
+ font4 = {'color': 'green', 'attrs': ('bold',)}
+ font5 = {'color': 'red', 'attrs': ('bold',)}
+
+ current_inf = self.best[idx] == self.torch_inf or \
+ self.best[idx] == -self.torch_inf
+
+ print(
+ pcolor(f'{key}', **font2) + \
+ pcolor(f' ({self.mode[idx]}) : ', **font1) + \
+ ('' if current_inf else
+ pcolor('%3.6f' % self.previous[idx], **font3) +
+ pcolor(f' -> ', **font1)) + \
+ (pcolor('%3.6f' % value, **font4) if is_best else
+ pcolor('%3.6f' % value, **font5)) +
+ ('' if current_inf else
+ pcolor(' (%3.6f)' % self.best[idx], **font2))
+ )
+
+ def save(self, wrapper, epoch, verbose=True):
+ """Save model"""
+ # Do nothing if no path is provided
+ if self.path:
+
+ name = '%03d.ckpt' % epoch
+ folder = os.path.join(self.path, 'models')
+
+ os.makedirs(folder, exist_ok=True)
+ folder_name = os.path.join(folder, name)
+ self.save_model(wrapper, folder_name, epoch)
+ self.top.append(folder_name)
+ if 0 < self.keep_top < len(self.top):
+ self.del_model(self.top.pop(0))
+ if self.s3_url:
+ self.sync_s3(verbose=False)
+
+ if verbose:
+ print()
+
+ def check_and_save(self, wrapper, metrics, prefixes, epoch, verbose=True):
+ """Check if model should be saved and maybe save it"""
+ # Not tracking any metric, save every iteration
+ if self.num_tracking == 0:
+ # Do nothing if no path is provided
+ if self.path:
+
+ name = '%03d.ckpt' % epoch
+ folder = os.path.join(self.path, 'models')
+
+ os.makedirs(folder, exist_ok=True)
+ folder_name = os.path.join(folder, name)
+ self.save_model(wrapper, folder_name, epoch)
+ self.top.append(folder_name)
+ if 0 < self.keep_top < len(self.top):
+ self.del_model(self.top.pop(0))
+ if self.s3_url:
+ self.sync_s3(verbose=False)
+
+ # Check if saving for every metric
+ else:
+
+ for idx in range(self.num_tracking):
+
+ key = '{}-{}'.format(prefixes[self.dataset[idx]], self.monitor[idx])
+ value = metrics[key]
+
+ if self.mode[idx] == 'min':
+ is_best = value < self.best[idx]
+ will_store = len(self.store_val[idx]) < self.keep_top or \
+ value < np.max(self.store_val[idx])
+ store_idx = 0 if len(self.store_val[idx]) == 0 else int(np.argmax(self.store_val[idx]))
+ else:
+ is_best = value > self.best[idx]
+ will_store = len(self.store_val[idx]) < self.keep_top or \
+ value > np.min(self.store_val[idx])
+ store_idx = 0 if len(self.store_val[idx]) == 0 else int(np.argmin(self.store_val[idx]))
+
+ if verbose:
+ self.print_improvements(key, value, idx, is_best)
+
+ self.previous[idx] = value
+
+ if is_best:
+ self.best[idx] = value
+
+ if is_best or will_store:
+
+ if self.path:
+
+ name = '%03d_%3.6f.ckpt' % (epoch, value)
+ folder = os.path.join(self.path, key)
+
+ os.makedirs(folder, exist_ok=True)
+ folder_name = os.path.join(folder, name)
+ self.save_model(wrapper, folder_name, epoch)
+ self.top[idx].append(folder_name)
+ self.store_val[idx].append(value)
+ if 0 < self.keep_top < len(self.top[idx]):
+ self.del_model(self.top[idx].pop(store_idx))
+ self.store_val[idx].pop(store_idx)
+ if self.s3_url:
+ self.sync_s3(verbose=False)
+
+ if verbose:
+ print()
diff --git a/vidar/core/logger.py b/vidar/core/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..98d2ea88ecd4392f3e2086a52ce2125cabc677c8
--- /dev/null
+++ b/vidar/core/logger.py
@@ -0,0 +1,311 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import wandb
+
+from vidar.utils.config import cfg_has
+from vidar.utils.distributed import world_size
+from vidar.utils.logging import pcolor
+from vidar.utils.types import is_dict, is_tensor, is_seq, is_namespace
+from vidar.utils.viz import viz_depth, viz_inv_depth, viz_normals, viz_optical_flow, viz_camera
+
+
+class WandbLogger:
+ """
+ Wandb logger class to monitor training
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ verbose : Bool
+ Print information on screen if enabled
+ """
+ def __init__(self, cfg, verbose=False):
+ super().__init__()
+
+ self.num_logs = {
+ 'train': cfg_has(cfg, 'num_train_logs', 0),
+ 'val': cfg_has(cfg, 'num_validation_logs', 0),
+ 'test': cfg_has(cfg, 'num_test_logs', 0),
+ }
+
+ self._name = cfg.name if cfg_has(cfg, 'name') else None
+ self._dir = cfg.folder
+ self._entity = cfg.entity
+ self._project = cfg.project
+
+ self._tags = cfg_has(cfg, 'tags', '')
+ self._notes = cfg_has(cfg, 'notes', '')
+
+ self._id = None
+ self._anonymous = None
+ self._log_model = True
+
+ self._experiment = self._create_experiment()
+ self._metrics = OrderedDict()
+
+ self.only_first = cfg_has(cfg, 'only_first', False)
+
+ cfg.name = self.run_name
+ cfg.url = self.run_url
+
+ if verbose:
+ self.print()
+
+ @staticmethod
+ def finish():
+ """Finish wandb session"""
+ wandb.finish()
+
+ def print(self):
+ """Print information on screen"""
+
+ font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
+ font_name = {'color': 'red', 'attrs': ('bold',)}
+ font_underline = {'color': 'red', 'attrs': ('underline',)}
+
+ print(pcolor('#' * 60, **font_base))
+ print(pcolor('### WandB: ', **font_base) + \
+ pcolor('{}'.format(self.run_name), **font_name))
+ print(pcolor('### ', **font_base) + \
+ pcolor('{}'.format(self.run_url), **font_underline))
+ print(pcolor('#' * 60, **font_base))
+
+ def __getstate__(self):
+ """Get the current logger state"""
+ state = self.__dict__.copy()
+ state['_id'] = self._experiment.id if self._experiment is not None else None
+ state['_experiment'] = None
+ return state
+
+ def _create_experiment(self):
+ """Creates and returns a new experiment"""
+ experiment = wandb.init(
+ name=self._name, dir=self._dir, project=self._project,
+ anonymous=self._anonymous, reinit=True, id=self._id, notes=self._notes,
+ resume='allow', tags=self._tags, entity=self._entity
+ )
+ wandb.run.save()
+ return experiment
+
+ def watch(self, model: nn.Module, log='gradients', log_freq=100):
+ """Watch training parameters"""
+ self.experiment.watch(model, log=log, log_freq=log_freq)
+
+ @property
+ def experiment(self):
+ """Returns the experiment (creates a new if it doesn't exist)"""
+ if self._experiment is None:
+ self._experiment = self._create_experiment()
+ return self._experiment
+
+ @property
+ def run_name(self):
+ """Returns run name"""
+ return wandb.run.name if self._experiment else None
+
+ @property
+ def run_url(self):
+ """Returns run URL"""
+ return f'https://app.wandb.ai/' \
+ f'{wandb.run.entity}/' \
+ f'{wandb.run.project}/runs/' \
+ f'{wandb.run.id}' if self._experiment else None
+
+ def log_config(self, cfg):
+ """Log model configuration"""
+ cfg = recursive_convert_config(deepcopy(cfg))
+ self.experiment.config.update(cfg, allow_val_change=True)
+
+ def log_metrics(self, metrics):
+ """Log training metrics"""
+ self._metrics.update(metrics)
+ if 'epochs' in metrics or 'samples' in metrics:
+ self.experiment.log(self._metrics)
+ self._metrics.clear()
+
+ def log_images(self, batch, output, prefix, ontology=None):
+ """
+ Log images depending on its nature
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary containing batch information
+ output : Dict
+ Dictionary containing output information
+ prefix : String
+ Prefix string for the log name
+ ontology : Dict
+ Dictionary with ontology information
+ """
+ for data, suffix in zip([batch, output['predictions']], ['-gt', '-pred']):
+ for key in data.keys():
+ if key.startswith('rgb'):
+ self._metrics.update(log_rgb(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif key.startswith('depth'):
+ self._metrics.update(log_depth(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif key.startswith('inv_depth'):
+ self._metrics.update(log_inv_depth(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif 'normals' in key:
+ self._metrics.update(log_normals(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif key.startswith('stddev'):
+ self._metrics.update(log_stddev(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif key.startswith('logvar'):
+ self._metrics.update(log_logvar(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif 'optical_flow' in key:
+ self._metrics.update(log_optical_flow(
+ key, prefix + suffix, data, only_first=self.only_first))
+ elif 'mask' in key or 'valid' in key:
+ self._metrics.update(log_rgb(
+ key, prefix, data, only_first=self.only_first))
+ # elif 'camera' in key:
+ # self._metrics.update(log_camera(
+ # key, prefix + suffix, data, only_first=self.only_first))
+ # elif 'uncertainty' in key:
+ # self._metrics.update(log_uncertainty(key, prefix, data))
+ # elif 'semantic' in key and ontology is not None:
+ # self._metrics.update(log_semantic(key, prefix, data, ontology=ontology))
+ # if 'scene_flow' in key:
+ # self._metrics.update(log_scene_flow(key, prefix_idx, data))
+ # elif 'score' in key:
+ # # Log score as image heatmap
+ # self._metrics.update(log_keypoint_score(key, prefix, data))
+
+ def log_data(self, mode, batch, output, dataset, prefix, ontology=None):
+ """Helper function used to log images"""
+ idx = batch['idx'][0]
+ num_logs = self.num_logs[mode]
+ if num_logs > 0:
+ interval = (len(dataset) // world_size() // num_logs) * world_size()
+ if interval == 0 or (idx % interval == 0 and idx < interval * num_logs):
+ prefix = '{}-{}-{}'.format(mode, prefix, batch['idx'][0].item())
+ # batch, output = prepare_logging(batch, output)
+ self.log_images(batch, output, prefix, ontology=ontology)
+
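+# Interval sketch (illustrative; the concrete numbers are assumptions). With a
+# 1000-sample dataset, 2 workers and num_validation_logs=5,
+# interval = (1000 // 2 // 5) * 2 = 200, so images are logged for sample indices
+# 0, 200, 400, 600 and 800.
+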
+
+def recursive_convert_config(cfg):
+ """Convert configuration to dictionary recursively"""
+ cfg = cfg.__dict__
+ for key, val in cfg.items():
+ if is_namespace(val):
+ cfg[key] = recursive_convert_config(val)
+ return cfg
+
+
+def prep_image(key, prefix, image):
+ """Prepare image for logging"""
+ if is_tensor(image):
+ if image.dim() == 2:
+ image = image.unsqueeze(0)
+ if image.dim() == 4:
+ image = image[0]
+ image = image.detach().permute(1, 2, 0).cpu().numpy()
+ prefix_key = '{}-{}'.format(prefix, key)
+ return {prefix_key: wandb.Image(image, caption=key)}
+
+
+def log_sequence(key, prefix, data, i, only_first, fn):
+ """Logs a sequence of images (list, tuple or dict)"""
+ log = {}
+ if is_dict(data):
+ for ctx, dict_val in data.items():
+ if is_seq(dict_val):
+ if only_first:
+ dict_val = dict_val[:1]
+ for idx, list_val in enumerate(dict_val):
+ if list_val.dim() == 5:
+ for j in range(list_val.shape[1]):
+ log.update(fn('%s(%s_%d)_%d' % (key, str(ctx), j, idx), prefix, list_val[:, j], i))
+ else:
+ log.update(fn('%s(%s)_%d' % (key, str(ctx), idx), prefix, list_val, i))
+ else:
+ if dict_val.dim() == 5:
+ for j in range(dict_val.shape[1]):
+ log.update(fn('%s(%s_%d)' % (key, str(ctx), j), prefix, dict_val[:, j], i))
+ else:
+ log.update(fn('%s(%s)' % (key, str(ctx)), prefix, dict_val, i))
+ elif is_seq(data):
+ if only_first:
+ data = data[:1]
+ for idx, list_val in enumerate(data):
+ log.update(fn('%s_%d' % (key, idx), prefix, list_val, i))
+ else:
+ log.update(fn('%s' % key, prefix, data, i))
+ return log
+
+
+def log_rgb(key, prefix, batch, i=0, only_first=None):
+ """Log RGB image"""
+ rgb = batch[key] if is_dict(batch) else batch
+ if is_seq(rgb) or is_dict(rgb):
+ return log_sequence(key, prefix, rgb, i, only_first, log_rgb)
+ return prep_image(key, prefix, rgb[i].clamp(min=0.0, max=1.0))
+
+
+def log_depth(key, prefix, batch, i=0, only_first=None):
+ """Log depth map"""
+ depth = batch[key] if is_dict(batch) else batch
+ if is_seq(depth) or is_dict(depth):
+ return log_sequence(key, prefix, depth, i, only_first, log_depth)
+ return prep_image(key, prefix, viz_depth(depth[i], filter_zeros=True))
+
+
+def log_inv_depth(key, prefix, batch, i=0, only_first=None):
+ """Log inverse depth map"""
+ inv_depth = batch[key] if is_dict(batch) else batch
+ if is_seq(inv_depth) or is_dict(inv_depth):
+ return log_sequence(key, prefix, inv_depth, i, only_first, log_inv_depth)
+ return prep_image(key, prefix, viz_inv_depth(inv_depth[i]))
+
+
+def log_normals(key, prefix, batch, i=0, only_first=None):
+ """Log normals"""
+ normals = batch[key] if is_dict(batch) else batch
+ if is_seq(normals) or is_dict(normals):
+ return log_sequence(key, prefix, normals, i, only_first, log_normals)
+ return prep_image(key, prefix, viz_normals(normals[i]))
+
+
+def log_optical_flow(key, prefix, batch, i=0, only_first=None):
+ """Log optical flow"""
+ optical_flow = batch[key] if is_dict(batch) else batch
+ if is_seq(optical_flow) or is_dict(optical_flow):
+ return log_sequence(key, prefix, optical_flow, i, only_first, log_optical_flow)
+ return prep_image(key, prefix, viz_optical_flow(optical_flow[i]))
+
+
+def log_stddev(key, prefix, batch, i=0, only_first=None):
+ """Log standard deviation"""
+ stddev = batch[key] if is_dict(batch) else batch
+ if is_seq(stddev) or is_dict(stddev):
+ return log_sequence(key, prefix, stddev, i, only_first, log_stddev)
+ return prep_image(key, prefix, viz_inv_depth(stddev[i], colormap='jet'))
+
+
+def log_logvar(key, prefix, batch, i=0, only_first=None):
+ """Log standard deviation"""
+ logvar = batch[key] if is_dict(batch) else batch
+ if is_seq(logvar) or is_dict(logvar):
+ return log_sequence(key, prefix, logvar, i, only_first, log_logvar)
+ return prep_image(key, prefix, viz_inv_depth(torch.exp(logvar[i]), colormap='jet'))
+
+
+def log_camera(key, prefix, batch, i=0, only_first=None):
+ """Log camera"""
+ camera = batch[key] if is_dict(batch) else batch
+ if is_seq(camera) or is_dict(camera):
+ return log_sequence(key, prefix, camera, i, only_first, log_camera)
+ return prep_image(key, prefix, viz_camera(camera[i]))
+
diff --git a/vidar/core/saver.py b/vidar/core/saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e0506a8496ea70d38ad7346daeb5da1e95ac1c
--- /dev/null
+++ b/vidar/core/saver.py
@@ -0,0 +1,214 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+
+from vidar.utils.config import cfg_has
+from vidar.utils.data import make_list
+from vidar.utils.types import is_dict, is_list
+from vidar.utils.viz import viz_depth, viz_optical_flow
+from vidar.utils.write import write_depth, write_image, write_pickle, write_npz
+
+
+class Saver:
+ """
+    Saver class to store batch data and predictions to disk
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ ckpt : String
+ Name of the model checkpoint (used to create the save folder)
+ """
+ def __init__(self, cfg, ckpt=None):
+ self.folder = cfg_has(cfg, 'folder', None)
+
+ self.rgb = make_list(cfg.rgb) if cfg_has(cfg, 'rgb') else []
+ self.depth = make_list(cfg.depth) if cfg_has(cfg, 'depth') else []
+ self.pose = make_list(cfg.pose) if cfg_has(cfg, 'pose') else []
+ self.optical_flow = make_list(cfg.optical_flow) if cfg_has(cfg, 'optical_flow') else []
+
+ self.store_data = cfg_has(cfg, 'store_data', False)
+ self.separate = cfg.has('separate', False)
+
+ self.ckpt = None if ckpt is None else \
+ os.path.splitext(os.path.basename(ckpt))[0]
+
+ self.naming = cfg_has(cfg, 'naming', 'filename')
+ assert self.naming in ['filename', 'splitname'], \
+ 'Invalid naming for saver: {}'.format(self.naming)
+
+ def get_filename(self, path, batch, idx, i):
+ """Get filename based on input information"""
+ if self.naming == 'filename':
+ filename = os.path.join(path, batch['filename'][0][i]).replace('{}', 'rgb')
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ return filename
+ elif self.naming == 'splitname':
+ if self.separate:
+ return os.path.join(path, '%010d' % idx, '%010d' % idx)
+ else:
+ return os.path.join(path, '%010d' % idx)
+ else:
+ raise NotImplementedError('Invalid naming for saver: {}'.format(self.naming))
+
+ def save_data(self, batch, output, prefix):
+ """
+ Prepare for data saving
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary with batch information
+ output : Dict
+ Dictionary with output information
+ prefix : String
+ Prefix string for the log name
+ """
+ if self.folder is None:
+ return
+
+ idx = batch['idx']
+ predictions = output['predictions']
+
+ path = os.path.join(self.folder, prefix)
+ if self.ckpt is not None:
+ path = os.path.join(path, self.ckpt)
+ os.makedirs(path, exist_ok=True)
+
+ self.save(batch, predictions, path, idx, 0)
+
+ def save(self, batch, predictions, path, idx, i):
+ """
+ Save batch and prediction information
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary with batch information
+ predictions : Dict
+ Dictionary with output predictions
+ path : String
+ Path where data will be saved
+ idx : Int
+ Batch index in the split
+ i : Int
+ Index within batch
+
+ Returns
+ -------
+ data : Dict
+ Dictionary with output data that was saved
+ """
+
+ filename = self.get_filename(path, batch, idx, i)
+
+ raw_intrinsics = batch['raw_intrinsics'][0][i].cpu() if 'raw_intrinsics' in batch else \
+ batch['intrinsics'][0][i].cpu() if 'intrinsics' in batch else None
+ intrinsics = batch['intrinsics'][0][i].cpu() if 'intrinsics' in batch else None
+
+ data = {
+ 'raw_intrinsics': raw_intrinsics,
+ 'intrinsics': intrinsics,
+ }
+
+ for key in batch.keys():
+
+ if key.startswith('rgb'):
+ data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()}
+ for ctx in batch[key].keys():
+ rgb = batch[key][ctx][i].cpu()
+ if 'gt' in self.rgb:
+ if rgb.dim() == 5:
+ for j in range(rgb.shape[1]):
+ write_image('%s_%s(%d_%d)_gt.png' % (filename, key, j, ctx),
+ rgb[:, j])
+ else:
+ write_image('%s_%s(%d)_gt.png' % (filename, key, ctx),
+ rgb)
+
+ if key.startswith('depth'):
+ data[key + '_gt'] = {k: v[i].cpu() for k, v in batch[key].items()}
+ for ctx in batch[key].keys():
+ depth = batch[key][ctx][i].cpu()
+ if 'gt_png' in self.depth:
+ write_depth('%s_%s(%d)_gt.png' % (filename, key, ctx),
+ depth)
+ if 'gt_npz' in self.depth:
+ write_depth('%s_%s(%d)_gt.npz' % (filename, key, ctx),
+ depth, intrinsics=raw_intrinsics)
+ if 'gt_viz' in self.depth:
+ write_image('%s_%s(%d)_gt_viz.png' % (filename, key, ctx),
+ viz_depth(depth, filter_zeros=True))
+
+ if key.startswith('pose'):
+ pose = {k: v[i].cpu() for k, v in batch[key].items()}
+ data[key + '_gt'] = pose
+ if 'gt' in self.pose:
+ write_pickle('%s_%s_gt' % (filename, key),
+ pose)
+
+ for key in predictions.keys():
+
+ if key.startswith('rgb'):
+ data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()}
+ for ctx in predictions[key].keys():
+ rgb = predictions[key][ctx][i].cpu()
+ if 'pred' in self.rgb:
+ if rgb.dim() == 5:
+ for j in range(rgb.shape[1]):
+ write_image('%s_%s(%d_%d)_pred.png' % (filename, key, j, ctx),
+ rgb[:, j])
+ else:
+ write_image('%s_%s(%d)_pred.png' % (filename, key, ctx),
+ rgb)
+
+ if key.startswith('depth'):
+ data[key + '_pred'] = {k: v[i].cpu() for k, v in predictions[key].items()}
+ for ctx in predictions[key].keys():
+ depth = predictions[key][ctx][0][i].cpu()
+ if 'png' in self.depth:
+ write_depth('%s_%s(%d)_pred.png' % (filename, key, ctx),
+ depth)
+ if 'npz' in self.depth:
+ write_depth('%s_%s(%d)_pred.npz' % (filename, key, ctx),
+ depth, intrinsics=intrinsics)
+ if 'viz' in self.depth:
+ write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
+ viz_depth(depth))
+
+ if key.startswith('pose'):
+ pose = {k: v[i].cpu() for k, v in predictions[key].items()}
+ data[key + '_pred'] = pose
+ if 'pred' in self.pose:
+ write_pickle('%s_%s_pred' % (filename, key),
+ pose)
+
+ if key.startswith('fwd_optical_flow'):
+ optical_flow = {k: v[i].cpu() for k, v in predictions[key].items()}
+ data[key + '_pred'] = optical_flow
+ if 'npz' in self.optical_flow:
+ write_npz('%s_%s_pred' % (filename, key),
+ {'fwd_optical_flow': optical_flow})
+ if 'viz' in self.optical_flow:
+ for ctx in optical_flow.keys():
+ write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx),
+ viz_optical_flow(optical_flow[ctx]))
+
+ if key.startswith('mask'):
+ if is_dict(predictions[key]):
+ data[key] = {k: v[i].cpu() for k, v in predictions[key].items()}
+ for ctx in data[key].keys():
+ write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), predictions[key][ctx][0])
+ elif is_list(predictions[key]):
+ data[key] = [v[i].cpu() for v in predictions[key]]
+ for ctx in range(len(data[key])):
+ write_image('%s_%s(%d)_pred_viz.png' % (filename, key, ctx), predictions[key][ctx][0])
+ else:
+ data[key] = predictions[key][i].cpu()
+ write_image('%s_%s_pred_viz.png' % (filename, key), predictions[key][0])
+
+ if self.store_data:
+ write_pickle('%s' % filename, data)
+
+ return data
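+
+ # Usage sketch (illustrative, not part of the original file), mirroring how
+ # Trainer.validate drives the saver (see vidar/core/trainer.py below):
+ #
+ # saver = Saver(cfg.save, ckpt=None)
+ # output, _ = wrapper.validation_step(batch, epoch=0)
+ # saver.save_data(batch, output, prefix='validation')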
diff --git a/vidar/core/trainer.py b/vidar/core/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bcbe596068dbe1d349760f2c67508e46882f297
--- /dev/null
+++ b/vidar/core/trainer.py
@@ -0,0 +1,472 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from collections import OrderedDict
+
+import torch
+from tqdm import tqdm
+
+from vidar.core.checkpoint import ModelCheckpoint
+from vidar.core.logger import WandbLogger
+from vidar.core.saver import Saver
+from vidar.utils.config import cfg_has, dataset_prefix
+from vidar.utils.data import make_list, keys_in
+from vidar.utils.distributed import on_rank_0, rank, world_size, print0, dist_mode
+from vidar.utils.logging import pcolor, AvgMeter
+from vidar.utils.setup import setup_dataloader, reduce
+from vidar.utils.types import is_dict, is_seq, is_numpy, is_tensor, is_list
+
+
+def sample_to_cuda(sample, proc_rank, dtype=None):
+ """
+ Copy sample to GPU
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample information
+ proc_rank : Int
+ Process rank
+ dtype : torch.Type
+ Data type for conversion
+
+ Returns
+ -------
+ sample : Dict
+ Dictionary with sample on the GPU
+ """
+ # Do nothing if cuda is not available
+ if not torch.cuda.is_available():
+ return sample
+ # If it's a sequence (list or tuple)
+ if is_seq(sample):
+ return [sample_to_cuda(val, proc_rank, dtype) for val in sample]
+ # If it's a dictionary
+ elif is_dict(sample):
+ return {key: sample_to_cuda(sample[key], proc_rank, dtype) for key in sample.keys()}
+ # If it's a torch tensor
+ elif is_tensor(sample):
+ dtype = dtype if torch.is_floating_point(sample) else None
+ return sample.to(f'cuda:{proc_rank}', dtype=dtype)
+ # If it's a numpy array
+ elif is_numpy(sample):
+ tensor_data = torch.Tensor(sample)
+ dtype = dtype if torch.is_floating_point(tensor_data) else None
+ return tensor_data.to(f'cuda:{proc_rank}', dtype=dtype)
+ # Otherwise, do nothing
+ else:
+ return sample
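+
+# Usage sketch (illustrative, not part of the original file): the recursion keeps
+# the nested batch structure intact, e.g.
+#
+# batch = {'idx': torch.tensor([0]), 'rgb': {0: torch.rand(1, 3, 192, 640)}}
+# batch = sample_to_cuda(batch, proc_rank=0)
+# # -> same keys, every tensor moved to 'cuda:0' (no-op when CUDA is unavailable)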
+
+
+class Trainer:
+ """
+ Trainer class for model optimization and inference
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ ckpt : String
+ Name of the model checkpoint to start from
+ """
+ def __init__(self, cfg, ckpt=None):
+ super().__init__()
+
+ self.avg_losses = {}
+ self.min_epochs = cfg_has(cfg.wrapper, 'min_epochs', 0)
+ self.max_epochs = cfg_has(cfg.wrapper, 'max_epochs', 100)
+
+ self.validate_first = cfg_has(cfg.wrapper, 'validate_first', False)
+ self.find_unused_parameters = cfg_has(cfg.wrapper, 'find_unused_parameters', False)
+ self.grad_scaler = cfg_has(cfg.wrapper, 'grad_scaler', False) and torch.cuda.is_available()
+
+ self.saver = self.logger = self.checkpoint = None
+ self.prep_logger_and_checkpoint(cfg)
+
+ self.prep_saver(cfg, ckpt)
+
+ self.all_modes = ['train', 'mixed', 'validation', 'test']
+ self.train_modes = ['train', 'mixed']
+
+ self.current_epoch = 0
+ self.training_bar_metrics = cfg_has(cfg.wrapper, 'training_bar_metrics', [])
+
+ @property
+ def progress(self):
+ """Current epoch progress (percentage)"""
+ return self.current_epoch / self.max_epochs
+
+ @property
+ def proc_rank(self):
+ """Process rank"""
+ return rank()
+
+ @property
+ def world_size(self):
+ """World size"""
+ return world_size()
+
+ @property
+ def is_rank_0(self):
+ """True if worker is on rank 0"""
+ return self.proc_rank == 0
+
+ def param_logs(self, optimizers):
+ """Returns various logs for tracking"""
+ params = OrderedDict()
+ for key, val in optimizers.items():
+ params[f'{key}_learning_rate'] = val['optimizer'].param_groups[0]['lr']
+ params[f'{key}_weight_decay'] = val['optimizer'].param_groups[0]['weight_decay']
+ params['progress'] = self.progress
+ return {
+ **params,
+ }
+
+ @on_rank_0
+ def prep_logger_and_checkpoint(self, cfg):
+ """Prepare logger and checkpoint class if requested"""
+ add_logger = cfg_has(cfg, 'wandb')
+ add_checkpoint = cfg_has(cfg, 'checkpoint')
+
+ if add_logger:
+ self.logger = WandbLogger(cfg.wandb, verbose=True)
+ if add_checkpoint and not cfg_has(cfg.checkpoint, 'name'):
+ cfg.checkpoint.name = self.logger.run_name
+ else:
+ self.logger = None
+
+ if add_checkpoint:
+ self.checkpoint = ModelCheckpoint(cfg.checkpoint, verbose=True)
+ else:
+ self.checkpoint = None
+
+ if add_logger:
+ self.logger.log_config(cfg)
+
+ def prep_saver(self, cfg, ckpt=None):
+ """Prepare saver class if requested"""
+ ckpt = ckpt if ckpt is not None else cfg.arch.model.has('checkpoint', None)
+ add_saver = cfg_has(cfg, 'save')
+
+ if add_saver:
+ print0(pcolor('#' * 60, color='red', attrs=('dark',)))
+ print0(pcolor('### Saving data to: %s' % cfg.save.folder, color='red'))
+ print0(pcolor('#' * 60, color='red', attrs=('dark',)))
+ self.saver = Saver(cfg.save, ckpt)
+
+ @on_rank_0
+ def check_and_save(self, wrapper, output, prefixes):
+ """Check for conditions and save if it's time"""
+ if self.checkpoint is not None:
+ self.checkpoint.check_and_save(
+ wrapper, output, prefixes, epoch=self.current_epoch)
+
+ @on_rank_0
+ def log_losses_and_metrics(self, metrics=None, optimizers=None):
+ """Log losses and metrics on wandb"""
+ if self.logger is not None:
+ self.logger.log_metrics({
+ '{}'.format(key): val.get() for key, val in self.avg_losses.items()
+ })
+ if optimizers is not None:
+ self.logger.log_metrics(self.param_logs(optimizers))
+ if metrics is not None:
+ self.logger.log_metrics({
+ **metrics, 'epochs': self.current_epoch,
+ })
+
+ @on_rank_0
+ def print_logger_and_checkpoint(self):
+ """Print logger and checkpoint information"""
+ font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
+ font_name = {'color': 'red', 'attrs': ('bold',)}
+ font_underline = {'color': 'red', 'attrs': ('underline',)}
+
+ if self.logger or self.checkpoint:
+ print(pcolor('#' * 120, **font_base))
+ if self.logger:
+ print(pcolor('### WandB: ', **font_base) + \
+ pcolor('{}'.format(self.logger.run_name), **font_name) + \
+ pcolor(' - ', **font_base) + \
+ pcolor('{}'.format(self.logger.run_url), **font_underline))
+ if self.checkpoint and self.checkpoint.s3_url is not None:
+ print(pcolor('### Checkpoint: ', **font_base) + \
+ pcolor('{}'.format(self.checkpoint.s3_url), **font_underline))
+ if self.logger or self.checkpoint:
+ print(pcolor('#' * 120 + '\n', **font_base))
+
+ @on_rank_0
+ def update_train_progress_bar(self, progress_bar):
+ """Update training progress bar on screen"""
+ string = '| {} | Loss {:.3f}'.format(
+ self.current_epoch, self.avg_losses['loss'].get())
+ bar_keys = self.training_bar_metrics
+ for key in keys_in(self.avg_losses, bar_keys):
+ name, abbrv = (key[0], key[1]) if is_list(key) else (key, key)
+ string += ' | {} {:.2f}'.format(abbrv, self.avg_losses[name].get())
+ progress_bar.set_description(string)
+
+ @on_rank_0
+ def update_averages(self, output):
+ """Update loss averages"""
+ averages = {'loss': output['loss'], **output['metrics']}
+ for key in averages.keys():
+ if key not in self.avg_losses.keys():
+ self.avg_losses[key] = AvgMeter(50)
+ self.avg_losses[key](averages[key].item() if is_tensor(averages[key]) else averages[key])
+
+ def train_progress_bar(self, dataloader, ncols=None, aux_dataloader=None):
+ """Print training progress bar on screen"""
+ full_dataloader = dataloader if aux_dataloader is None else zip(dataloader, aux_dataloader)
+ return tqdm(enumerate(full_dataloader, 0),
+ unit='im', unit_scale=self.world_size * dataloader.batch_size,
+ total=len(dataloader), smoothing=0,
+ disable=not self.is_rank_0, ncols=ncols)
+
+ def val_progress_bar(self, dataloader, prefix, ncols=None):
+ """Print validation progress bar on screen"""
+ return tqdm(enumerate(dataloader, 0),
+ unit='im', unit_scale=self.world_size * dataloader.batch_size,
+ total=len(dataloader), smoothing=0,
+ disable=not self.is_rank_0, ncols=ncols,
+ desc=prefix)
+
+ def prepare_distributed_model(self, wrapper):
+ """Prepare model for distributed training or not (CPU/GPU/DDP)"""
+ if dist_mode() == 'cpu':
+ # CPU training: keep the architecture as-is
+ wrapper.arch = wrapper.arch
+ elif dist_mode() == 'gpu':
+ # Single-GPU training: move the wrapper to the target device
+ wrapper = wrapper.cuda(self.proc_rank)
+ wrapper.arch = wrapper.arch
+ elif dist_mode() == 'ddp':
+ wrapper = wrapper.cuda(self.proc_rank)
+ wrapper.arch = torch.nn.parallel.DistributedDataParallel(
+ wrapper.arch, device_ids=[self.proc_rank],
+ find_unused_parameters=self.find_unused_parameters,
+ broadcast_buffers=True)
+ else:
+ raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
+ return wrapper
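+
+ # Note (added remark, not part of the original file): only wrapper.arch is
+ # wrapped for DDP, so code that needs the underlying networks (e.g.
+ # Wrapper.train_custom in vidar/core/wrapper.py) unwraps it via `arch.module`.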
+
+ def prepare_dataloaders(self, wrapper):
+ """Prepare dataloaders for training and inference"""
+ font1 = {'color': 'blue', 'attrs': ('dark', 'bold')}
+ font2 = {'color': 'blue', 'attrs': ('bold',)}
+
+ print0(pcolor('#' * 60, **font1))
+
+ if dist_mode() == 'cpu':
+ print0(pcolor(f'### ', **font1) +
+ pcolor(f'CPU Training', **font2))
+ elif dist_mode() == 'gpu':
+ print0(pcolor(f'### ', **font1) +
+ pcolor(f'GPU Training', **font2))
+ elif dist_mode() == 'ddp':
+ print0(pcolor(f'### ', **font1) +
+ pcolor(f'DDP Training ', **font2) +
+ pcolor(f'with ', **font1) +
+ pcolor(f'{self.world_size}', **font2) +
+ pcolor(f' GPUs', **font1))
+
+ # Send wrapper to GPU
+ wrapper = self.prepare_distributed_model(wrapper)
+
+ for key in wrapper.datasets_cfg.keys():
+ wrapper.datasets_cfg[key] = make_list(wrapper.datasets_cfg[key])
+
+ # Prepare dataloaders
+ dataloaders = {
+ key: setup_dataloader(val, wrapper.datasets_cfg[key][0].dataloader, key)
+ for key, val in wrapper.datasets.items() if key in wrapper.datasets_cfg.keys()
+ }
+
+ # Prepare prefixes
+
+ prefixes = {
+ key: [dataset_prefix(wrapper.datasets_cfg[key][n], n) for n in range(len(val))]
+ for key, val in wrapper.datasets_cfg.items() if 'name' in wrapper.datasets_cfg[key][0].__dict__.keys()
+ }
+
+ # Reduce information
+ reduced_dataloaders = reduce(dataloaders, self.all_modes, self.train_modes)
+ reduced_prefixes = reduce(prefixes, self.all_modes, self.train_modes)
+
+ print0(pcolor('#' * 60, **font1))
+
+ return reduced_dataloaders, reduced_prefixes
+
+ def filter_optimizers(self, optimizers):
+ """Filter optimizers to find those being used at each epoch"""
+ in_optimizers, out_optimizers = {}, {}
+ for key, val in optimizers.items():
+ if 'stop_epoch' not in val['settings'] or \
+ val['settings']['stop_epoch'] >= self.current_epoch:
+ in_optimizers[key] = val['optimizer']
+ else:
+ out_optimizers[key] = val['optimizer']
+
+ if rank() == 0:
+
+ string = pcolor('Optimizing: ', color='yellow')
+ for key, val in in_optimizers.items():
+ string += pcolor('{}'.format(key), color='green', attrs=('bold', 'dark'))
+ string += pcolor(' ({}) '.format(val.param_groups[0]['lr']),
+ color='green', attrs=('dark',))
+ for key, val in out_optimizers.items():
+ string += pcolor('{}'.format(key), color='cyan', attrs=('bold', 'dark'))
+ string += pcolor(' ({}) '.format(val.param_groups[0]['lr']),
+ color='cyan', attrs=('dark',))
+
+ print(pcolor('#' * 120, color='yellow', attrs=('dark',)))
+ print(string)
+ print(pcolor('#' * 120, color='yellow', attrs=('dark',)))
+ print()
+
+ return in_optimizers, out_optimizers
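+
+ # Sketch of the expected optimizer dictionary (illustrative names, not part of
+ # the original file):
+ #
+ # optimizers = {
+ # 'depth': {'optimizer': torch.optim.Adam(...), 'settings': {'stop_epoch': 20}},
+ # 'pose': {'optimizer': torch.optim.Adam(...), 'settings': {}},
+ # }
+ #
+ # A network keeps being optimized while stop_epoch >= current_epoch (or when no
+ # stop_epoch is given); otherwise it is moved to out_optimizers and later frozen
+ # by Wrapper.train_custom.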
+
+ def learn(self, wrapper):
+ """Entry-point class for training a model"""
+ # Get optimizers and schedulers
+ optimizers, schedulers = wrapper.configure_optimizers_and_schedulers()
+
+ # Get gradient scaler if requested
+ scaler = torch.cuda.amp.GradScaler() if self.grad_scaler else None
+
+ # Get learn information
+ dataloaders, prefixes = self.prepare_dataloaders(wrapper)
+ aux_dataloader = None if 'mixed' not in dataloaders else dataloaders['mixed']
+
+ # Check for train and validation dataloaders
+ has_train_dataloader = 'train' in dataloaders
+ has_validation_dataloader = 'validation' in dataloaders
+
+ # Validate before training if requested
+ if self.validate_first and has_validation_dataloader:
+ validation_output = self.validate('validation', dataloaders, prefixes, wrapper)
+ self.post_validation(validation_output, optimizers, prefixes['validation'], wrapper)
+ else:
+ self.current_epoch += 1
+
+ # Epoch loop
+ if has_train_dataloader:
+ for epoch in range(self.current_epoch, self.max_epochs + 1):
+
+ # Train and log
+ self.train(dataloaders['train'], optimizers, schedulers, wrapper, scaler=scaler,
+ aux_dataloader=aux_dataloader)
+
+ # Validate, save and log
+ if has_validation_dataloader:
+ validation_output = self.validate('validation', dataloaders, prefixes, wrapper)
+ self.post_validation(validation_output, optimizers, prefixes['validation'], wrapper)
+
+ # Take a scheduler step
+ if wrapper.update_schedulers == 'epoch':
+ for scheduler in schedulers.values():
+ scheduler.step()
+
+ # Finish logger if available
+ if self.logger:
+ self.logger.finish()
+
+ def train(self, dataloader, optimizers, schedulers, wrapper, scaler=None, aux_dataloader=None):
+ """Training loop for each epoch"""
+ # Choose which optimizers to use
+ in_optimizers, out_optimizers = self.filter_optimizers(optimizers)
+
+ # Set wrapper to train
+ wrapper.train_custom(in_optimizers, out_optimizers)
+
+ # Shuffle dataloader sampler
+ if hasattr(dataloader.sampler, "set_epoch"):
+ dataloader.sampler.set_epoch(self.current_epoch)
+
+ # Shuffle auxiliary dataloader sampler
+ if aux_dataloader is not None:
+ if hasattr(aux_dataloader.sampler, "set_epoch"):
+ aux_dataloader.sampler.set_epoch(self.current_epoch)
+
+ # Prepare progress bar
+ progress_bar = self.train_progress_bar(
+ dataloader, aux_dataloader=aux_dataloader, ncols=120)
+
+ # Zero gradients for the first iteration
+ for optimizer in in_optimizers.values():
+ optimizer.zero_grad()
+
+ # Loop through all batches
+ for i, batch in progress_bar:
+
+ # Send samples to GPU and take a training step
+ batch = sample_to_cuda(batch, self.proc_rank)
+ output = wrapper.training_step(batch, epoch=self.current_epoch)
+
+ # Step optimizer
+ if wrapper.update_schedulers == 'step':
+ for scheduler in schedulers.values():
+ scheduler.step()
+
+ # Backprop through loss
+ if scaler is None:
+ output['loss'].backward()
+ else:
+ scaler.scale(output['loss']).backward()
+
+ for optimizer in in_optimizers.values():
+ if not output['loss'].isnan().any():
+ if scaler is None:
+ optimizer.step()
+ else:
+ scaler.step(optimizer)
+ else:
+ print('NAN DETECTED!', i, batch['idx'])
+ optimizer.zero_grad()
+ if scaler is not None:
+ scaler.update()
+
+ self.update_averages(output)
+ self.update_train_progress_bar(progress_bar)
+
+ # Return outputs for epoch end
+ return wrapper.training_epoch_end()
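+
+ # Note (added remark, not part of the original file): with grad_scaler enabled
+ # the loop above follows the usual torch.cuda.amp recipe, i.e.
+ # scaler.scale(loss).backward() -> scaler.step(optimizer) -> scaler.update();
+ # optimizer steps are skipped (and a message printed) whenever the loss is NaN.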
+
+ @torch.no_grad()
+ def validate(self, mode, dataloaders, prefixes, wrapper):
+ """Validation loop"""
+ # Set wrapper to eval
+ wrapper.eval_custom()
+ # For all validation datasets
+ dataset_outputs = []
+ for dataset_idx, (dataset, dataloader, prefix) in \
+ enumerate(zip(wrapper.datasets[mode], dataloaders[mode], prefixes[mode])):
+ # Prepare progress bar for that dataset
+ progress_bar = self.val_progress_bar(dataloader, prefix, ncols=120)
+ # For all batches
+ batch_outputs = []
+ for batch_idx, batch in progress_bar:
+ # Send batch to GPU and take a validation step
+ batch = sample_to_cuda(batch, self.proc_rank)
+ output, results = wrapper.validation_step(batch, epoch=self.current_epoch)
+ if 'batch' in output:
+ batch = output['batch']
+ batch_outputs += results
+ if self.logger:
+ self.logger.log_data('val', batch, output, dataset, prefix)
+ if self.saver:
+ self.saver.save_data(batch, output, prefix)
+ # Append dataset outputs to list of all outputs
+ dataset_outputs.append(batch_outputs)
+ # Get results from validation epoch end
+ return wrapper.validation_epoch_end(dataset_outputs, prefixes[mode])
+
+ def post_validation(self, output, optimizers, prefixes, wrapper):
+ """Post-processing steps for validation"""
+ self.check_and_save(wrapper, output, prefixes)
+ self.log_losses_and_metrics(output, optimizers)
+ self.print_logger_and_checkpoint()
+ self.current_epoch += 1
+
+ def test(self, wrapper):
+ """Test a model by running validation once"""
+ dataloaders, prefixes = self.prepare_dataloaders(wrapper)
+ self.validate('validation', dataloaders, prefixes, wrapper)
+
diff --git a/vidar/core/wrapper.py b/vidar/core/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..953d226d98428bde3f958ef928b9a4dfa84281e1
--- /dev/null
+++ b/vidar/core/wrapper.py
@@ -0,0 +1,262 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+import random
+from abc import ABC
+from collections import OrderedDict
+
+import torch
+
+from vidar.utils.config import cfg_has, read_config
+from vidar.utils.data import set_random_seed
+from vidar.utils.distributed import rank, world_size
+from vidar.utils.flip import flip_batch, flip_output
+from vidar.utils.logging import pcolor, set_debug
+from vidar.utils.networks import load_checkpoint, save_checkpoint, freeze_layers_and_norms
+from vidar.utils.setup import setup_arch, setup_datasets, setup_metrics
+from vidar.utils.types import is_str
+
+
+class Wrapper(torch.nn.Module, ABC):
+ """
+ Wrapper class combining the model architecture, datasets and evaluation metrics
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration with parameters
+ ckpt : String
+ Name of the model checkpoint to start from
+ verbose : Bool
+ Print information on screen if enabled
+ """
+ def __init__(self, cfg, ckpt=None, verbose=False):
+ super().__init__()
+
+ if verbose and rank() == 0:
+ font = {'color': 'cyan', 'attrs': ('bold', 'dark')}
+ print(pcolor('#' * 100, **font))
+ print(pcolor('#' * 42 + ' VIDAR WRAPPER ' + '#' * 43, **font))
+ print(pcolor('#' * 100, **font))
+
+ # Get configuration
+ cfg = read_config(cfg) if is_str(cfg) else cfg
+ self.cfg = cfg
+
+ # Data augmentations
+ self.flip_lr_prob = cfg_has(cfg.wrapper, 'flip_lr_prob', 0.0)
+ self.validate_flipped = cfg_has(cfg.wrapper, 'validate_flipped', False)
+
+ # Set random seed
+ set_random_seed(cfg.wrapper.seed + rank())
+ set_debug(cfg_has(cfg.wrapper, 'debug', False))
+
+ # Setup architecture, datasets and tasks
+ self.arch = setup_arch(cfg.arch, checkpoint=ckpt, verbose=verbose) if cfg_has(cfg, 'arch') else None
+ self.datasets, self.datasets_cfg = setup_datasets(
+ cfg.datasets, verbose=verbose) if cfg_has(cfg, 'datasets') else (None, None)
+ self.metrics = setup_metrics(cfg.evaluation) if cfg_has(cfg, 'evaluation') else {}
+
+ sync_batch_norm = cfg_has(cfg.wrapper, 'sync_batch_norm', False)
+ if sync_batch_norm and os.environ['DIST_MODE'] == 'ddp':
+ self.arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.arch)
+
+ self.mixed_precision = cfg_has(cfg.wrapper, 'mixed_precision', False)
+
+ self.update_schedulers = None
+
+ def save(self, filename, epoch=None):
+ """Save checkpoint"""
+ save_checkpoint(filename, self, epoch=epoch)
+
+ def load(self, checkpoint, strict=True, verbose=False):
+ """Load checkpoint"""
+ load_checkpoint(self, checkpoint, strict=strict, verbose=verbose)
+
+ def train_custom(self, in_optimizers, out_optimizers):
+ """Customized training flag for the model"""
+ self.arch.train()
+ for key in in_optimizers.keys():
+ arch = self.arch.module if hasattr(self.arch, 'module') else self.arch
+ freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=False)
+ for key in out_optimizers.keys():
+ arch = self.arch.module if hasattr(self.arch, 'module') else self.arch
+ freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=True)
+
+ def eval_custom(self):
+ """Customized evaluation flag for the model"""
+ self.arch.eval()
+
+ def configure_optimizers_and_schedulers(self):
+ """Configure depth and pose optimizers and the corresponding scheduler"""
+
+ if not cfg_has(self.cfg, 'optimizers'):
+ return None, None
+
+ optimizers = OrderedDict()
+ schedulers = OrderedDict()
+
+ for key, val in self.cfg.optimizers.__dict__.items():
+ assert key in self.arch.networks, f'There is no network for optimizer {key}'
+ optimizers[key] = {
+ 'optimizer': getattr(torch.optim, val.name)(**{
+ 'lr': val.lr,
+ 'weight_decay': cfg_has(val, 'weight_decay', 0.0),
+ 'params': self.arch.networks[key].parameters(),
+ }),
+ 'settings': {} if not cfg_has(val, 'settings') else val.settings.__dict__
+ }
+ if cfg_has(val, 'scheduler'):
+ if val.scheduler.name == 'CosineAnnealingWarmUpRestarts':
+ from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
+ epoch = float(len(self.datasets['train']) / (
+ world_size() * self.datasets_cfg['train'].dataloader.batch_size * self.datasets_cfg['train'].repeat[0]))
+ schedulers[key] = CosineAnnealingWarmupRestarts(**{
+ 'optimizer': optimizers[key]['optimizer'],
+ 'first_cycle_steps': int(val.scheduler.first_cycle_steps * epoch),
+ 'cycle_mult': val.scheduler.cycle_mult,
+ 'min_lr': val.scheduler.min_lr,
+ 'max_lr': val.scheduler.max_lr,
+ 'warmup_steps': int(val.scheduler.warmup_steps * epoch),
+ 'gamma': val.scheduler.gamma,
+ })
+ self.update_schedulers = 'step'
+ elif val.scheduler.name == 'LinearWarmUp':
+ from externals.huggingface.transformers.src.transformers.optimization import get_linear_schedule_with_warmup
+ schedulers[key] = get_linear_schedule_with_warmup(**{
+ 'optimizer': optimizers[key]['optimizer'],
+ 'num_warmup_steps': val.scheduler.num_warmup_steps,
+ 'num_training_steps': val.scheduler.num_training_steps,
+ })
+ self.update_schedulers = 'step'
+ else:
+ schedulers[key] = getattr(torch.optim.lr_scheduler, val.scheduler.name)(**{
+ 'optimizer': optimizers[key]['optimizer'],
+ 'step_size': val.scheduler.step_size,
+ 'gamma': val.scheduler.gamma,
+ })
+ self.update_schedulers = 'epoch'
+
+ # Return optimizer and scheduler
+ return optimizers, schedulers
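+
+ # Configuration sketch (field names taken from the code above; the network key
+ # and values are assumptions):
+ #
+ # optimizers:
+ # depth:
+ # name: Adam
+ # lr: 0.0002
+ # weight_decay: 0.0
+ # scheduler:
+ # name: StepLR
+ # step_size: 30
+ # gamma: 0.5
+ #
+ # Each key must match an entry in self.arch.networks. Generic torch schedulers
+ # (e.g. StepLR) are stepped once per epoch, while CosineAnnealingWarmUpRestarts
+ # and LinearWarmUp set self.update_schedulers = 'step' and are stepped per batch.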
+
+ def run_arch(self, batch, epoch, flip, unflip):
+ """
+ Run model on a batch
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary with batch information
+ epoch : Int
+ Current epoch
+ flip : Bool
+ Batch should be flipped
+ unflip : Bool
+ Output should be unflipped
+
+ Returns
+ -------
+ output : Dict
+ Dictionary with model outputs
+ """
+ batch = flip_batch(batch) if flip else batch
+ output = self.arch(batch, epoch=epoch)
+ return flip_output(output) if flip and unflip else output
+
+ def training_step(self, batch, epoch):
+ """Processes a training batch"""
+ flip_lr = False if self.flip_lr_prob == 0 else \
+ random.random() < self.flip_lr_prob
+
+ if self.mixed_precision:
+ with torch.cuda.amp.autocast():
+ output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False)
+ else:
+ output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False)
+
+ losses = {key: val for key, val in output.items() if key.startswith('loss')}
+
+ return {
+ **losses,
+ 'metrics': output['metrics']
+ }
+
+ def validation_step(self, batch, epoch):
+ """Processes a validation batch"""
+ # from vidar.utils.data import break_batch
+ # batch = break_batch(batch)
+
+ if self.mixed_precision:
+ with torch.cuda.amp.autocast():
+ output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False)
+ flipped_output = None if not self.validate_flipped else \
+ self.run_arch(batch, epoch=epoch, flip=True, unflip=True)
+ else:
+ output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False)
+ flipped_output = None if not self.validate_flipped else \
+ self.run_arch(batch, epoch=epoch, flip=True, unflip=True)
+
+ if 'batch' in output:
+ batch = output['batch']
+
+ results = self.evaluate(batch, output, flipped_output)
+
+ results = [{
+ 'idx': batch['idx'][i],
+ **{key: val[i] for key, val in results['metrics'].items()}
+ } for i in range(len(batch['idx']))]
+
+ return output, results
+
+ @staticmethod
+ def training_epoch_end():
+ """Finishes a training epoch (do nothing for now)"""
+ return {}
+
+ def validation_epoch_end(self, output, prefixes):
+ """Finishes a validation epoch"""
+ if isinstance(output[0], dict):
+ output = [output]
+
+ metrics_dict = {}
+ for task in self.metrics:
+ metrics_dict.update(
+ self.metrics[task].reduce(
+ output, self.datasets['validation'], prefixes))
+
+ return metrics_dict
+
+ def evaluate(self, batch, output, flipped_output=None):
+ """
+ Evaluate batch to produce predictions and metrics for different tasks
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary with batch information
+ output : Dict
+ Dictionary with output information
+ flipped_output : Dict
+ Dictionary with flipped output information
+
+ Returns
+ -------
+ results: Dict
+ Dictionary with evaluation results
+ """
+ # Evaluate different tasks
+ metrics, predictions = OrderedDict(), OrderedDict()
+ for task in self.metrics:
+ task_metrics, task_predictions = \
+ self.metrics[task].evaluate(batch, output['predictions'],
+ flipped_output['predictions'] if flipped_output else None)
+ metrics.update(task_metrics)
+ predictions.update(task_predictions)
+ # Create results dictionary with metrics and predictions
+ results = {'metrics': metrics, 'predictions': predictions}
+ # Return final results
+ return results
+
+
+
diff --git a/vidar/datasets/BaseDataset.py b/vidar/datasets/BaseDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fc241194616d62d3810ab72d414dfc70b872b54
--- /dev/null
+++ b/vidar/datasets/BaseDataset.py
@@ -0,0 +1,173 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+from abc import ABC
+
+from torch.utils.data import Dataset
+
+from vidar.utils.types import is_list
+
+
+class BaseDataset(Dataset, ABC):
+ """
+ Base dataset class
+
+ Parameters
+ ----------
+ path : String
+ Dataset location
+ context : Tuple
+ Temporal context
+ cameras : Tuple
+ Camera names
+ labels : Tuple
+ Labels to be loaded
+ labels_context : Tuple
+ Context labels to be loaded
+ data_transform : Function
+ Transformations to be applied to sample
+ ontology : String
+ Which semantic ontology should be used
+ return_ontology : Bool
+ Whether the ontology should be returned
+ virtual : Bool
+ Whether the dataset is virtual or not
+ kwargs : Dict
+ Additional parameters
+ """
+ def __init__(self, path, context, cameras, labels=(), labels_context=(),
+ data_transform=None, ontology=None, return_ontology=False, virtual=False,
+ **kwargs):
+ super().__init__()
+
+ self.path = path
+ self.labels = labels
+ self.labels_context = labels_context
+ self.cameras = cameras
+ self.data_transform = data_transform
+
+ self.num_cameras = len(cameras) if is_list(cameras) else cameras
+
+ self.bwd_contexts = [ctx for ctx in context if ctx < 0]
+ self.fwd_contexts = [ctx for ctx in context if ctx > 0]
+
+ self.bwd_context = 0 if len(context) == 0 else - min(0, min(context))
+ self.fwd_context = 0 if len(context) == 0 else max(0, max(context))
+
+ self.context = [v for v in range(- self.bwd_context, 0)] + \
+ [v for v in range(1, self.fwd_context + 1)]
+
+ self.num_context = self.bwd_context + self.fwd_context
+ self.with_context = self.num_context > 0
+
+ self.ontology = ontology
+ self.return_ontology = return_ontology
+ self.virtual = virtual
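+
+ # Example (illustrative, not part of the original file): context=(-2, -1, 1)
+ # yields bwd_contexts=[-2, -1], fwd_contexts=[1], bwd_context=2, fwd_context=1
+ # and self.context=[-2, -1, 1], so with_context is True.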
+
+ def relative_path(self, filename):
+ """Return paths relative to the dataset folder, without extension"""
+ return {key: os.path.splitext(val.replace(self.path + '/', ''))[0]
+ for key, val in filename.items()}
+
+ # Label properties
+
+ @property
+ def with_depth(self):
+ """If dataset contains depth"""
+ return 'depth' in self.labels
+
+ @property
+ def with_input_depth(self):
+ """If dataset contains input depth"""
+ return 'input_depth' in self.labels
+
+ @property
+ def with_pose(self):
+ """If dataset contains pose"""
+ return 'pose' in self.labels
+
+ @property
+ def with_semantic(self):
+ """If dataset contains semantic"""
+ return 'semantic' in self.labels
+
+ @property
+ def with_instance(self):
+ """If dataset contains instance"""
+ return 'instance' in self.labels
+
+ @property
+ def with_optical_flow(self):
+ """If dataset contains optical flow"""
+ return 'optical_flow' in self.labels
+
+ @property
+ def with_scene_flow(self):
+ """If dataset contains scene flow"""
+ return 'scene_flow' in self.labels
+
+ @property
+ def with_bbox2d(self):
+ """If dataset contains 2d bounding boxes"""
+ return 'bbox2d' in self.labels
+
+ @property
+ def with_bbox3d(self):
+ """If dataset contains 3d bounding boxes"""
+ return 'bbox3d' in self.labels
+
+ @property
+ def with_lidar(self):
+ """If dataset contains lidar"""
+ return 'lidar' in self.labels
+
+ @property
+ def with_radar(self):
+ """If dataset contains radar"""
+ return 'radar' in self.labels
+
+ @property
+ def with_pointcache(self):
+ """If dataset contains pointcaches"""
+ return 'pointcache' in self.labels
+
+ # Label context properties
+
+ @property
+ def with_depth_context(self):
+ """If dataset contains context depth"""
+ return 'depth' in self.labels_context
+
+ @property
+ def with_input_depth_context(self):
+ """If dataset contains context input depth"""
+ return 'input_depth' in self.labels_context
+
+ @property
+ def with_semantic_context(self):
+ """If dataset contains context semantic"""
+ return 'semantic' in self.labels_context
+
+ @property
+ def with_instance_context(self):
+ """If dataset contains context instance"""
+ return 'instance' in self.labels_context
+
+ @property
+ def with_optical_flow_context(self):
+ """If dataset contains context optical flow"""
+ return 'optical_flow' in self.labels_context
+
+ @property
+ def with_scene_flow_context(self):
+ """If dataset contains context scene flow"""
+ return 'scene_flow' in self.labels_context
+
+ @property
+ def with_bbox2d_context(self):
+ """If dataset contains context 2d bounding boxes"""
+ return 'bbox2d' in self.labels_context
+
+ @property
+ def with_bbox3d_context(self):
+ """If dataset contains context 3d bounding boxes"""
+ return 'bbox3d' in self.labels_context
diff --git a/vidar/datasets/EUROCDataset.py b/vidar/datasets/EUROCDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70c0151d94448ff11d7bbc9fb1a5a632cf0051e
--- /dev/null
+++ b/vidar/datasets/EUROCDataset.py
@@ -0,0 +1,177 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import re
+from collections import defaultdict
+import os
+from abc import ABC
+
+from PIL import Image
+import numpy as np
+from vidar.utils.read import read_image
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.utils.misc import stack_sample
+
+
+def dummy_calibration(image):
+ """Return dummy pinhole intrinsics centered on the image"""
+ w, h = [float(d) for d in image.size]
+ return np.array([[1000. , 0. , w / 2. - 0.5],
+ [0. , 1000. , h / 2. - 0.5],
+ [0. , 0. , 1. ]])
+
+
+def get_idx(filename):
+ """Extract the integer frame index from a filename"""
+ return int(re.search(r'\d+', filename).group())
+
+
+class EUROCDataset(BaseDataset, ABC):
+ """
+ EuRoC dataset class
+
+ Parameters
+ ----------
+ split : String
+ Split file
+ strides : Tuple
+ Which context strides to use
+ spaceholder : String
+ Format pattern used to build context filenames from frame indices
+ data_transform : Function
+ Transformations applied to the sample
+ """
+ def __init__(self, split, strides=(1,), spaceholder='{:19}', **kwargs):
+ super().__init__(**kwargs)
+
+ self.split = split
+ self.spaceholder = spaceholder
+
+ self.backward_context = strides[0]
+ self.forward_context = strides[1]
+ self.has_context = self.backward_context + self.forward_context > 0
+
+ self.file_tree = defaultdict(list)
+ self.read_files(self.path)
+
+ self.files = []
+ for k, v in self.file_tree.items():
+ file_set = set(self.file_tree[k])
+ files = [fname for fname in sorted(v) if self._has_context(fname, file_set)]
+ self.files.extend([[k, fname] for fname in files])
+
+ def read_files(self, directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True):
+ """Read input images"""
+ files = defaultdict(list)
+ for entry in os.scandir(directory):
+ relpath = os.path.relpath(entry.path, directory)
+ if entry.is_dir():
+ d_files = self.read_files(entry.path, ext=ext, skip_empty=skip_empty)
+ if skip_empty and not len(d_files):
+ continue
+ self.file_tree[entry.path] = d_files[entry.path]
+ elif entry.is_file():
+ if ext is None or entry.path.lower().endswith(tuple(ext)):
+ files[directory].append(relpath)
+ return files
+
+ def __len__(self):
+ return len(self.files)
+
+ def _change_idx(self, idx, filename):
+ """Prepare name strings according to index"""
+ _, ext = os.path.splitext(os.path.basename(filename))
+ return self.spaceholder.format(idx) + ext
+
+ def _has_context(self, filename, file_set):
+ """Check if image has context"""
+ context_paths = self._get_context_file_paths(filename, file_set)
+ return len(context_paths) >= len(self.context)
+
+ def _get_context_file_paths(self, filename, file_set):
+ """Get file path for contexts"""
+ fidx = get_idx(filename)
+ idxs = [-self.backward_context, -self.forward_context, self.backward_context, self.forward_context]
+ potential_files = [self._change_idx(fidx + i, filename) for i in idxs]
+ return [fname for fname in potential_files if fname in file_set]
+
+ def _read_rgb_context_files(self, session, filename):
+ """Read context images"""
+ file_set = set(self.file_tree[session])
+ context_paths = self._get_context_file_paths(filename, file_set)
+ return [self._read_rgb_file(session, fname) for fname in context_paths]
+
+ def _read_rgb_file(self, session, filename):
+ """Read target images"""
+ gray_image = read_image(os.path.join(self.path, session, filename))
+ gray_image_np = np.array(gray_image)
+ rgb_image_np = np.stack([gray_image_np for _ in range(3)], axis=2)
+ return Image.fromarray(rgb_image_np)
+
+ def _read_npy_depth(self, session, depth_filename):
+ """Read depth from numpy file"""
+ depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename)
+ return np.load(depth_file_path)
+
+ def _read_depth(self, session, depth_filename):
+ """Get the depth map from a file."""
+ return self._read_npy_depth(session, depth_filename)
+
+ def _has_depth(self, session, depth_filename):
+ """Check if depth map exists"""
+ depth_file_path = os.path.join(self.path, session, '../../depth_maps', depth_filename)
+ return os.path.isfile(depth_file_path)
+
+ def __getitem__(self, idx):
+ """Get dataset sample"""
+
+ samples = []
+
+ session, filename = self.files[idx]
+ image = self._read_rgb_file(session, filename)
+
+ sample = {
+ 'idx': idx,
+ 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
+ 'rgb': {0: image},
+ 'intrinsics': {0: dummy_calibration(image)},
+ }
+
+ if self.has_context:
+ image_context = self._read_rgb_context_files(session, filename)
+ sample['rgb'].update({
+ key: val for key, val in zip(self.context, image_context)
+ })
+
+ depth_filename = filename.split('.')[0] + 'depth.npy'
+ if self.with_depth:
+ if self._has_depth(session, depth_filename):
+ sample['depth'] = {0: self._read_depth(session, depth_filename)}
+
+ samples.append(sample)
+
+ if self.data_transform:
+ samples = self.data_transform(samples)
+
+ return stack_sample(samples)
+
+
+if __name__ == "__main__":
+
+ data_dir = '/data/vidar/euroc/euroc_cam/cam0'
+ euroc_dataset = EUROCDataset(path=data_dir,
+ strides=[49999872, 50000128],
+ context=[-1,1],
+ split='{:19}',
+ labels=['depth'],
+ cameras=[[0]],
+ )
+ print(len(euroc_dataset))
+ print('\nsample 0:')
+ print(euroc_dataset[0].keys())
+ print(euroc_dataset[0]['filename'])
+ print(euroc_dataset[0]['rgb'])
+ print(euroc_dataset[0]['intrinsics'])
+
+ print('\nsample 1:')
+ print(euroc_dataset[1].keys())
+ print(euroc_dataset[1]['filename'])
+ print(euroc_dataset[1]['rgb'])
+ print(euroc_dataset[1]['intrinsics'])
\ No newline at end of file
diff --git a/vidar/datasets/GenericDataset.py b/vidar/datasets/GenericDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9af533139b2ac5c64ea218fe92b74425c112e8c
--- /dev/null
+++ b/vidar/datasets/GenericDataset.py
@@ -0,0 +1,86 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.utils.FolderTree import FolderTree
+from vidar.datasets.utils.misc import stack_sample
+from vidar.utils.read import read_image
+
+
+class GenericDataset(BaseDataset):
+ def __init__(self, tag=None, single_folder=False, split=None, extension='png', **kwargs):
+ """
+ Generic dataset, used to load information from folders
+
+ Parameters
+ ----------
+ tag : String
+ Dataset tag
+ single_folder : Bool
+ Whether the dataset is a single folder
+ split : String
+ Dataset split
+ extension : String
+ Image file extension to search for
+ kwargs : Dict
+ Additional arguments
+ """
+ super().__init__(**kwargs)
+ self.tag = 'generic' if tag is None else tag
+ if split is None or split == '':
+ split = ('', )
+ self.rgb_tree = FolderTree(
+ self.path, context=self.context, sub_folders=split,
+ single_folder=single_folder, suffix=f'.{extension}')
+
+ def __len__(self):
+ """Dataset length"""
+ return len(self.rgb_tree)
+
+ @staticmethod
+ def get_intrinsics(rgb):
+ """Return dummy intrinsics"""
+ return np.array([[rgb.size[0] / 2., 0., rgb.size[0] / 2.],
+ [0., rgb.size[1], rgb.size[1] / 2.],
+ [0., 0., 1.]])
+
+ def __getitem__(self, idx):
+ """Get dataset sample"""
+ samples = []
+
+ for _ in self.cameras:
+
+ # Filename
+ filename = self.rgb_tree.get_item(idx)
+
+ # Base sample
+ sample = {
+ 'idx': idx,
+ 'tag': self.tag,
+ 'filename': self.relative_path(filename),
+ 'splitname': '%010d' % idx
+ }
+
+ # Image
+ sample['rgb'] = read_image(filename)
+
+ # Intrinsics
+ sample['intrinsics'] = {
+ 0: self.get_intrinsics(sample['rgb'][0])
+ }
+
+ # If with context
+ if self.with_context:
+ filename_context = self.rgb_tree.get_context(idx)
+ sample['rgb'].update(read_image(filename_context))
+
+ # Stack sample
+ samples.append(sample)
+
+ # Transform data
+ if self.data_transform:
+ samples = self.data_transform(samples)
+
+ # Return stacked sample
+ return stack_sample(samples)
+
+
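+# Usage sketch (illustrative; path, extension and camera list are assumptions):
+#
+# dataset = GenericDataset(
+# path='/data/my_frames', split=None, extension='jpg',
+# context=[-1, 1], cameras=[0], labels=(), data_transform=None)
+# sample = dataset[0] # dict with 'rgb', 'intrinsics', 'filename', ...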
diff --git a/vidar/datasets/KITTIDataset.py b/vidar/datasets/KITTIDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5927f4e74bb2e0e4630603a6a0570bcac22a800e
--- /dev/null
+++ b/vidar/datasets/KITTIDataset.py
@@ -0,0 +1,485 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import glob
+import os
+
+import numpy as np
+
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.KITTIDataset_utils import \
+ pose_from_oxts_packet, read_calib_file, transform_from_rot_trans
+from vidar.datasets.utils.misc import \
+ invert_pose, stack_sample, make_relative_pose
+from vidar.utils.read import read_image
+
+# Cameras from the stereo pair (left is the origin)
+IMAGE_FOLDER = {
+ 'left': 'image_02',
+ 'right': 'image_03',
+}
+
+
+# Name of different calibration files
+CALIB_FILE = {
+ 'cam2cam': 'calib_cam_to_cam.txt',
+ 'velo2cam': 'calib_velo_to_cam.txt',
+ 'imu2velo': 'calib_imu_to_velo.txt',
+}
+
+
+PNG_DEPTH_DATASETS = ['groundtruth']
+OXTS_POSE_DATA = 'oxts'
+
+
+def read_npz_depth(file, depth_type):
+ """Reads a .npz depth map given a certain depth_type"""
+ depth = np.load(file)[depth_type + '_depth'].astype(np.float32)
+ return np.expand_dims(depth, axis=2)
+
+
+def read_png_depth(file):
+ """Reads a .png depth map"""
+ depth_png = np.array(read_image(file), dtype=int)
+ assert (np.max(depth_png) > 255), 'Wrong .png depth file'
+ depth = depth_png.astype(np.float64) / 256. # np.float is deprecated in recent NumPy
+ depth[depth_png == 0] = -1.
+ return np.expand_dims(depth, axis=2)
+
+
+class KITTIDataset(BaseDataset):
+ """
+ KITTI dataset class
+
+ Parameters
+ ----------
+ path : String
+ Path to the dataset
+ split : String
+ Split file, with paths to the images to be used
+ data_transform : Function
+ Transformations applied to the sample
+ depth_type : String
+ Which depth type to load
+ input_depth_type : String
+ Which input depth type to load
+ tag : String
+ Dataset tag
+ single_intrinsics : Bool
+ Whether to use a single normalized intrinsics matrix (scaled by image size) for all samples
+ """
+ def __init__(self, split, tag=None,
+ depth_type=None, input_depth_type=None,
+ single_intrinsics=False, **kwargs):
+ # Initialize base class
+ super().__init__(**kwargs)
+ self.tag = 'kitti' if tag is None else tag
+
+ self.baseline = 0.5407
+
+ self.backward_context_paths = []
+ self.forward_context_paths = []
+
+ self.single_intrinsics = None if not single_intrinsics else \
+ np.array([[0.58, 0.00, 0.5, 0],
+ [0.00, 1.92, 0.5, 0],
+ [0.00, 0.00, 1.0, 0],
+ [0.00, 0.00, 0.0, 1]], dtype=np.float32)
+
+ self.split = split.split('/')[-1].split('.')[0]
+
+ self.depth_type = depth_type
+ self.input_depth_type = input_depth_type
+
+ self._cache = {}
+ self.pose_cache = {}
+ self.oxts_cache = {}
+ self.calibration_cache = {}
+ self.imu2velo_calib_cache = {}
+ self.sequence_origin_cache = {}
+
+ with open(os.path.join(self.path, split), "r") as f:
+ data = f.readlines()
+
+ self.paths = []
+ # Get file list from data
+ for i, fname in enumerate(data):
+ path = os.path.join(self.path, fname.split()[0])
+ add_flag = True
+ if add_flag and self.with_depth:
+ # Check if depth file exists
+ depth = self._get_depth_file(path, self.depth_type)
+ add_flag = depth is not None and os.path.exists(depth)
+ if add_flag and self.with_input_depth:
+ # Check if input depth file exists
+ depth = self._get_depth_file(path, self.input_depth_type)
+ add_flag = depth is not None and os.path.exists(depth)
+ if add_flag:
+ self.paths.append(path)
+
+ # If using context, filter file list
+ if self.with_context:
+ paths_with_context = []
+ for stride in [1]:
+ for idx, file in enumerate(self.paths):
+ backward_context_idxs, forward_context_idxs = \
+ self._get_sample_context(
+ file, self.bwd_context, self.fwd_context, stride)
+ if backward_context_idxs is not None and forward_context_idxs is not None:
+ exists = True
+ if self.with_depth_context:
+ _, depth_context_files = self._get_context_files(
+ self.paths[idx], backward_context_idxs + forward_context_idxs)
+ for depth_file in depth_context_files:
+ exists = os.path.exists(depth_file)
+ if not exists:
+ break
+ if exists:
+ paths_with_context.append(self.paths[idx])
+ self.forward_context_paths.append(forward_context_idxs)
+ self.backward_context_paths.append(backward_context_idxs[::-1])
+ self.paths = paths_with_context
+
+ if len(self.cameras) > 1:
+ self.paths = [im.replace('image_03', 'image_02') for im in self.paths]
+
+ if 1 in self.cameras:
+ self.paths_stereo = [im.replace('image_02', 'image_03') for im in self.paths]
+ else:
+ self.paths_stereo = None
+
+ @staticmethod
+ def _get_next_file(idx, file):
+ """Get next file given next idx and current file"""
+ base, ext = os.path.splitext(os.path.basename(file))
+ return os.path.join(os.path.dirname(file), str(idx).zfill(len(base)) + ext)
+
+ @staticmethod
+ def _get_parent_folder(image_file):
+ """Get the parent folder from image_file"""
+ return os.path.abspath(os.path.join(image_file, "../../../.."))
+
+ @staticmethod
+ def _get_intrinsics(image_file, calib_data):
+ """Get intrinsics from the calib_data dictionary"""
+ for cam in ['left', 'right']:
+ # Check for both cameras, if found replace and return intrinsics
+ if IMAGE_FOLDER[cam] in image_file:
+ return np.reshape(calib_data[IMAGE_FOLDER[cam].replace('image', 'P_rect')], (3, 4))[:, :3]
+
+ @staticmethod
+ def _read_raw_calib_file(folder):
+ """Read raw calibration files from folder"""
+ return read_calib_file(os.path.join(folder, CALIB_FILE['cam2cam']))
+
+ @staticmethod
+ def _get_keypoints(filename, size):
+ """Get keypoints from file"""
+ filename = filename. \
+ replace('KITTI_tiny', 'KITTI_tiny_keypoints'). \
+ replace('.png', '.txt.npz')
+ keypoints = np.load(filename)['data']
+ keypoints_coord, keypoints_desc = keypoints[:, :2], keypoints[:, 2:]
+ keypoints_coord[:, 0] *= size[0] / 320
+ keypoints_coord[:, 1] *= size[1] / 240
+ return keypoints_coord, keypoints_desc
+
+ def get_filename(self, sample_idx):
+ """
+ Returns the filename for an index, following DGP structure
+
+ Parameters
+ ----------
+ sample_idx : Int
+ Sample index
+
+ Returns
+ -------
+ filename : String
+ Filename for that sample
+ """
+ filename = os.path.splitext(self.paths[sample_idx].replace(self.path + '/', ''))[0]
+ for cam in ['left', 'right']:
+ filename = filename.replace('{}/data'.format(IMAGE_FOLDER[cam]),
+ 'proj_depth/{}/%s' % IMAGE_FOLDER[cam])
+ return filename
+########################################################################################################################
+#### DEPTH
+########################################################################################################################
+
+ def _read_depth(self, depth_file):
+ """Get the depth map from a file"""
+ if depth_file.endswith('.npz'):
+ return read_npz_depth(depth_file, 'velodyne')
+ elif depth_file.endswith('.png'):
+ return read_png_depth(depth_file)
+ else:
+ raise NotImplementedError(
+ 'Depth type {} not implemented'.format(self.depth_type))
+
+ @staticmethod
+ def _get_depth_file(image_file, depth_type):
+ """Get the corresponding depth file from an image file"""
+ for cam in ['left', 'right']:
+ if IMAGE_FOLDER[cam] in image_file:
+ depth_file = image_file.replace(
+ IMAGE_FOLDER[cam] + '/data', 'proj_depth/{}/{}'.format(depth_type, IMAGE_FOLDER[cam]))
+ if depth_type not in PNG_DEPTH_DATASETS:
+ depth_file = depth_file.replace('png', 'npz')
+ return depth_file
+
+ def _get_sample_context(self, sample_name,
+ backward_context, forward_context, stride=1):
+ """
+ Get a sample context
+
+ Parameters
+ ----------
+ sample_name : String
+ Path + Name of the sample
+ backward_context : Int
+ Size of backward context
+ forward_context : Int
+ Size of forward context
+ stride : Int
+ Stride value to consider when building the context
+
+ Returns
+ -------
+ backward_context : list[Int]
+ List containing the indexes for the backward context
+ forward_context : list[Int]
+ List containing the indexes for the forward context
+ """
+ base, ext = os.path.splitext(os.path.basename(sample_name))
+ parent_folder = os.path.dirname(sample_name)
+ f_idx = int(base)
+
+ # Check number of files in folder
+ if parent_folder in self._cache:
+ max_num_files = self._cache[parent_folder]
+ else:
+ max_num_files = len(glob.glob(os.path.join(parent_folder, '*' + ext)))
+ self._cache[parent_folder] = max_num_files
+
+ # Check bounds
+ if (f_idx - backward_context * stride) < 0 or (
+ f_idx + forward_context * stride) >= max_num_files:
+ return None, None
+
+ # Backward context
+ c_idx = f_idx
+ backward_context_idxs = []
+ while len(backward_context_idxs) < backward_context and c_idx > 0:
+ c_idx -= stride
+ filename = self._get_next_file(c_idx, sample_name)
+ if os.path.exists(filename):
+ backward_context_idxs.append(c_idx)
+ if c_idx < 0:
+ return None, None
+
+ # Forward context
+ c_idx = f_idx
+ forward_context_idxs = []
+ while len(forward_context_idxs) < forward_context and c_idx < max_num_files:
+ c_idx += stride
+ filename = self._get_next_file(c_idx, sample_name)
+ if os.path.exists(filename):
+ forward_context_idxs.append(c_idx)
+ if c_idx >= max_num_files:
+ return None, None
+
+ return backward_context_idxs, forward_context_idxs
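+
+ # Example (illustrative): for sample_name '.../0000000010.png' with
+ # backward_context=1, forward_context=1 and stride=1 this returns ([9], [11]),
+ # provided both neighboring frames exist on disk.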
+
+ def _get_context_files(self, sample_name, idxs):
+ """
+ Returns image and depth context files
+
+ Parameters
+ ----------
+ sample_name : String
+ Name of current sample
+ idxs : list[Int]
+ Context indexes
+
+ Returns
+ -------
+ image_context_paths : list[String]
+ List of image names for the context
+ depth_context_paths : list[String]
+ List of depth names for the context
+ """
+ image_context_paths = [self._get_next_file(i, sample_name) for i in idxs]
+ if self.with_depth:
+ depth_context_paths = [self._get_depth_file(f, self.depth_type) for f in image_context_paths]
+ return image_context_paths, depth_context_paths
+ else:
+ return image_context_paths, None
+
+########################################################################################################################
+#### POSE
+########################################################################################################################
+
+ def _get_imu2cam_transform(self, image_file):
+ """Gets the transformation between IMU and camera from an image file"""
+ parent_folder = self._get_parent_folder(image_file)
+ if image_file in self.imu2velo_calib_cache:
+ return self.imu2velo_calib_cache[image_file]
+
+ cam2cam = read_calib_file(os.path.join(parent_folder, CALIB_FILE['cam2cam']))
+ imu2velo = read_calib_file(os.path.join(parent_folder, CALIB_FILE['imu2velo']))
+ velo2cam = read_calib_file(os.path.join(parent_folder, CALIB_FILE['velo2cam']))
+
+ velo2cam_mat = transform_from_rot_trans(velo2cam['R'], velo2cam['T'])
+ imu2velo_mat = transform_from_rot_trans(imu2velo['R'], imu2velo['T'])
+ cam_2rect_mat = transform_from_rot_trans(cam2cam['R_rect_00'], np.zeros(3))
+
+ imu2cam = cam_2rect_mat @ velo2cam_mat @ imu2velo_mat
+ self.imu2velo_calib_cache[image_file] = imu2cam
+ return imu2cam
+
+ @staticmethod
+ def _get_oxts_file(image_file):
+ """Gets the oxts file from an image file"""
+ # find oxts pose file
+ for cam in ['left', 'right']:
+ # Check for both cameras, if found replace and return file name
+ if IMAGE_FOLDER[cam] in image_file:
+ return image_file.replace(IMAGE_FOLDER[cam], OXTS_POSE_DATA).replace('.png', '.txt')
+ # Something went wrong (invalid image file)
+ raise ValueError('Invalid KITTI path for pose supervision.')
+
+ def _get_oxts_data(self, image_file):
+ """Gets the oxts data from an image file"""
+ oxts_file = self._get_oxts_file(image_file)
+ if oxts_file in self.oxts_cache:
+ oxts_data = self.oxts_cache[oxts_file]
+ else:
+ oxts_data = np.loadtxt(oxts_file, delimiter=' ', skiprows=0)
+ self.oxts_cache[oxts_file] = oxts_data
+ return oxts_data
+
+ def _get_pose(self, image_file, camera):
+ """Gets the pose information from an image file"""
+ if image_file in self.pose_cache:
+ return self.pose_cache[image_file]
+ # Find origin frame in this sequence to determine scale & origin translation
+ base, ext = os.path.splitext(os.path.basename(image_file))
+ origin_frame = os.path.join(os.path.dirname(image_file), str(0).zfill(len(base)) + ext)
+ # Get origin data
+ origin_oxts_data = self._get_oxts_data(origin_frame)
+ lat = origin_oxts_data[0]
+ scale = np.cos(lat * np.pi / 180.)
+ # Get origin pose
+ origin_R, origin_t = pose_from_oxts_packet(origin_oxts_data, scale)
+ origin_pose = transform_from_rot_trans(origin_R, origin_t)
+ # Compute current pose
+ oxts_data = self._get_oxts_data(image_file)
+ R, t = pose_from_oxts_packet(oxts_data, scale)
+ pose = transform_from_rot_trans(R, t)
+ # Compute odometry pose
+ imu2cam = self._get_imu2cam_transform(image_file)
+ odo_pose = (imu2cam @ np.linalg.inv(origin_pose) @
+ pose @ np.linalg.inv(imu2cam)).astype(np.float32)
+ odo_pose = invert_pose(odo_pose)
+ # Cache and return pose
+ self.pose_cache[image_file] = odo_pose
+ if camera == 1:
+ odo_pose[0, -1] -= self.baseline
+ return odo_pose
+
+########################################################################################################################
+
+ def __len__(self):
+ """Dataset length"""
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ """Get dataset sample"""
+
+ samples = []
+
+ for camera in self.cameras:
+
+ # Filename
+ filename = self.paths[idx] if camera == 0 else self.paths_stereo[idx]
+
+ # Base sample
+ sample = {
+ 'idx': idx,
+ 'tag': self.tag,
+ 'filename': self.relative_path({0: filename}),
+ 'splitname': '%s_%010d' % (self.split, idx)
+ }
+
+ # Image
+ sample['rgb'] = {0: read_image(filename)}
+
+ # Intrinsics
+ parent_folder = self._get_parent_folder(filename)
+ if parent_folder in self.calibration_cache:
+ c_data = self.calibration_cache[parent_folder]
+ else:
+ c_data = self._read_raw_calib_file(parent_folder)
+ self.calibration_cache[parent_folder] = c_data
+
+ # Return individual or single intrinsics
+ if self.single_intrinsics is not None:
+ intrinsics = self.single_intrinsics.copy()
+ intrinsics[0, :] *= sample['rgb'][0].size[0]
+ intrinsics[1, :] *= sample['rgb'][0].size[1]
+ sample['intrinsics'] = {0: intrinsics}
+ else:
+ sample['intrinsics'] = {0: self._get_intrinsics(filename, c_data)}
+
+ # Add pose information if requested
+ if self.with_pose:
+ sample['pose'] = {0: self._get_pose(filename, camera)}
+
+ # Add depth information if requested
+ if self.with_depth:
+ sample['depth'] = {0: self._read_depth(
+ self._get_depth_file(filename, self.depth_type))}
+
+ # Add input depth information if requested
+ if self.with_input_depth:
+ sample['input_depth'] = {0: self._read_depth(
+ self._get_depth_file(filename, self.input_depth_type))}
+
+ # Add context information if requested
+ if self.with_context:
+
+ # Add context images
+ all_context_idxs = self.backward_context_paths[idx] + \
+ self.forward_context_paths[idx]
+ image_context_paths, depth_context_paths = \
+ self._get_context_files(filename, all_context_idxs)
+ image_context = [read_image(f) for f in image_context_paths]
+ sample['rgb'].update({
+ key: val for key, val in zip(self.context, image_context)
+ })
+
+ # Add context poses
+ if self.with_pose:
+ image_context_pose = [self._get_pose(f, camera) for f in image_context_paths]
+ sample['pose'].update({
+ key: val for key, val in zip(self.context, image_context_pose)
+ })
+
+ # Add context depth
+ if self.with_depth_context:
+ depth_context = [self._read_depth(self._get_depth_file(
+ path, self.depth_type)) for path in depth_context_paths]
+ sample['depth'].update({
+ key: val for key, val in zip(self.context, depth_context)
+ })
+
+ # Stack sample
+ samples.append(sample)
+
+ # Make relative poses
+ samples = make_relative_pose(samples)
+
+ # Transform data
+ if self.data_transform:
+ samples = self.data_transform(samples)
+
+ # Return stacked sample
+ return stack_sample(samples)
+
+########################################################################################################################
diff --git a/vidar/datasets/KITTIDataset_utils.py b/vidar/datasets/KITTIDataset_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..0a1ca1965ad21220bdece990d2cd1dd02e120869
--- /dev/null
+++ b/vidar/datasets/KITTIDataset_utils.py
@@ -0,0 +1,223 @@
+"""Provides helper methods for loading and parsing KITTI data."""
+
+from collections import namedtuple
+
+import numpy as np
+
+__author__ = "Lee Clement"
+__email__ = "lee.clement@robotics.utias.utoronto.ca"
+
+# Per dataformat.txt
+OxtsPacket = namedtuple('OxtsPacket',
+ 'lat, lon, alt, ' +
+ 'roll, pitch, yaw, ' +
+ 'vn, ve, vf, vl, vu, ' +
+ 'ax, ay, az, af, al, au, ' +
+ 'wx, wy, wz, wf, wl, wu, ' +
+ 'pos_accuracy, vel_accuracy, ' +
+ 'navstat, numsats, ' +
+ 'posmode, velmode, orimode')
+
+# Bundle into an easy-to-access structure
+OxtsData = namedtuple('OxtsData', 'packet, T_w_imu')
+
+
+def rotx(t):
+ """
+ Rotation about the x-axis
+
+ Parameters
+ ----------
+ t : Float
+ Theta angle
+
+ Returns
+ -------
+ matrix : np.Array
+ Rotation matrix [3,3]
+ """
+ c = np.cos(t)
+ s = np.sin(t)
+ return np.array([[1, 0, 0],
+ [0, c, -s],
+ [0, s, c]])
+
+
+def roty(t):
+ """
+ Rotation about the y-axis
+
+ Parameters
+ ----------
+ t : Float
+ Theta angle
+
+ Returns
+ -------
+ matrix : np.Array
+ Rotation matrix [3,3]
+ """
+ c = np.cos(t)
+ s = np.sin(t)
+ return np.array([[c, 0, s],
+ [0, 1, 0],
+ [-s, 0, c]])
+
+
+def rotz(t):
+ """
+ Rotation about the z-axis
+
+ Parameters
+ ----------
+ t : Float
+ Theta angle
+
+ Returns
+ -------
+ matrix : np.Array
+ Rotation matrix [3,3]
+ """
+ c = np.cos(t)
+ s = np.sin(t)
+ return np.array([[c, -s, 0],
+ [s, c, 0],
+ [0, 0, 1]])
+
+
+def transform_from_rot_trans(R, t):
+ """
+ Transformation matrix from rotation matrix and translation vector.
+
+ Parameters
+ ----------
+ R : np.Array
+ Rotation matrix [3,3]
+ t : np.Array
+ translation vector [3]
+
+ Returns
+ -------
+ matrix : np.Array
+ Transformation matrix [4,4]
+ """
+ R = R.reshape(3, 3)
+ t = t.reshape(3, 1)
+ return np.vstack((np.hstack([R, t]), [0, 0, 0, 1]))
+
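+# Illustrative usage sketch (not part of the original utilities): composing a
+# planar pose from a yaw angle and a 2D offset with the helpers above. The
+# numeric values are arbitrary and only serve as an example.
+def example_planar_pose(yaw=0.1, x=1.0, y=2.0):
+    """Build a 4x4 homogeneous transform from a yaw rotation and a 2D offset."""
+    R = rotz(yaw)              # rotation about the vertical axis
+    t = np.array([x, y, 0.0])  # translation in the horizontal plane
+    return transform_from_rot_trans(R, t)
+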
+
+def read_calib_file(filepath):
+ """
+ Read in a calibration file and parse into a dictionary
+
+ Parameters
+ ----------
+ filepath : String
+ File path to read from
+
+ Returns
+ -------
+ calib : Dict
+ Dictionary with calibration values
+ """
+ data = {}
+
+ with open(filepath, 'r') as f:
+ for line in f.readlines():
+ key, value = line.split(':', 1)
+ # The only non-float values in these files are dates, which
+ # we don't care about anyway
+ try:
+ data[key] = np.array([float(x) for x in value.split()])
+ except ValueError:
+ pass
+
+ return data
+
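+# Illustrative usage sketch: loading a calibration file and reshaping one entry
+# into a 3x4 projection matrix. The key name 'P_rect_02' is an assumption based
+# on the usual KITTI camera calibration format and may differ between files.
+def example_projection_matrix(calib_path, key='P_rect_02'):
+    """Load a calibration file and return one projection matrix as [3,4]."""
+    calib = read_calib_file(calib_path)
+    return calib[key].reshape(3, 4)
+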
+
+def pose_from_oxts_packet(raw_data, scale):
+ """
+    Helper method to compute an SE(3) pose matrix from an OXTS packet
+
+ Parameters
+ ----------
+    raw_data : OxtsPacket or Iterable
+        OXTS packet fields to read from
+ scale : Float
+ Oxts scale
+
+ Returns
+ -------
+ R : np.Array
+ Rotation matrix [3,3]
+ t : np.Array
+ Translation vector [3]
+ """
+ packet = OxtsPacket(*raw_data)
+ er = 6378137. # earth radius (approx.) in meters
+
+ # Use a Mercator projection to get the translation vector
+ tx = scale * packet.lon * np.pi * er / 180.
+ ty = scale * er * \
+ np.log(np.tan((90. + packet.lat) * np.pi / 360.))
+ tz = packet.alt
+ t = np.array([tx, ty, tz])
+
+ # Use the Euler angles to get the rotation matrix
+ Rx = rotx(packet.roll)
+ Ry = roty(packet.pitch)
+ Rz = rotz(packet.yaw)
+ R = Rz.dot(Ry.dot(Rx))
+
+    # Return the rotation matrix and translation vector separately
+    return R, t
+
+
+def load_oxts_packets_and_poses(oxts_files):
+ """
+    Read OXTS ground-truth data from a list of files.
+    Poses are given in an East-North-Up coordinate system
+    whose origin is the first GPS position.
+
+    Parameters
+    ----------
+    oxts_files : list[String]
+        List of OXTS files to read from
+
+    Returns
+    -------
+    oxts : list[OxtsData]
+        List of OXTS packets with their corresponding IMU poses (T_w_imu)
+ """
+ # Scale for Mercator projection (from first lat value)
+ scale = None
+ # Origin of the global coordinate system (first GPS position)
+ origin = None
+
+ oxts = []
+
+ for filename in oxts_files:
+ with open(filename, 'r') as f:
+ for line in f.readlines():
+ line = line.split()
+ # Last five entries are flags and counts
+ line[:-5] = [float(x) for x in line[:-5]]
+ line[-5:] = [int(float(x)) for x in line[-5:]]
+
+ packet = OxtsPacket(*line)
+
+ if scale is None:
+ scale = np.cos(packet.lat * np.pi / 180.)
+
+ R, t = pose_from_oxts_packet(packet, scale)
+
+ if origin is None:
+ origin = t
+
+ T_w_imu = transform_from_rot_trans(R, t - origin)
+
+ oxts.append(OxtsData(packet, T_w_imu))
+
+ return oxts
+
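+# Illustrative usage sketch: recovering a vehicle trajectory from OXTS files.
+# 'oxts_paths' is a hypothetical, time-ordered list of per-frame OXTS text files;
+# the translation of each T_w_imu is taken as the position in the local ENU frame.
+def example_trajectory(oxts_paths):
+    """Return an [N,3] array with the IMU position of each OXTS packet."""
+    data = load_oxts_packets_and_poses(oxts_paths)
+    return np.stack([entry.T_w_imu[:3, 3] for entry in data], axis=0)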
+
diff --git a/vidar/datasets/OuroborosDataset.py b/vidar/datasets/OuroborosDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..004688ae84564c53fb71658ebe9532742f1dffd0
--- /dev/null
+++ b/vidar/datasets/OuroborosDataset.py
@@ -0,0 +1,487 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+import pickle
+
+import numpy as np
+from dgp.utils.camera import Camera
+from dgp.utils.pose import Pose
+
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.utils.misc import \
+ stack_sample, make_relative_pose
+from vidar.utils.data import make_list
+from vidar.utils.read import read_image
+from vidar.utils.types import is_str
+
+
+def load_from_file(filename, key):
+ """Load data cache from a file"""
+ data = np.load(filename, allow_pickle=True)[key]
+ if len(data.shape) == 0:
+ data = None
+ return data
+
+
+def save_to_file(filename, key, value):
+ """Save data to a cache file"""
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ np.savez_compressed(filename, **{key: value})
+
+
+def generate_proj_maps(camera, Xw, shape):
+ """Render pointcloud on image.
+
+ Parameters
+ ----------
+ camera: Camera
+ Camera object with appropriately set extrinsics wrt world.
+ Xw: np.Array
+ 3D point cloud (x, y, z) in the world coordinate. [N,3]
+ shape: np.Array
+ Output depth image shape [H, W]
+
+ Returns
+ -------
+ depth: np.Array
+ Rendered depth image
+ """
+ assert len(shape) == 2, 'Shape needs to be 2-tuple.'
+ # Move point cloud to the camera's (C) reference frame from the world (W)
+ Xc = camera.p_cw * Xw
+    # Project the points, now expressed in the camera frame, onto the image plane
+    uv = Camera(K=camera.K).project(Xc).astype(int)
+    # Keep the depth (z) of each point in the camera frame
+    z_c = Xc[:, 2]
+
+ # Create an empty image to overlay
+ H, W = shape
+ proj_depth = np.zeros((H, W), dtype=np.float32)
+ in_view = np.logical_and.reduce([(uv >= 0).all(axis=1), uv[:, 0] < W, uv[:, 1] < H, z_c > 0])
+ uv, z_c = uv[in_view], z_c[in_view]
+ proj_depth[uv[:, 1], uv[:, 0]] = z_c
+
+ # Return projected maps
+ return proj_depth
+
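+# Illustrative usage sketch: rendering a depth map from a world-frame point cloud
+# with the helper above. 'K', 'pose_wc' and 'points_world' are placeholders; the
+# Camera and Pose constructors follow the same DGP calls already used in this file.
+def example_render_depth(K, pose_wc, points_world, shape=(384, 640)):
+    """Project a world-frame point cloud into a camera and return its depth map."""
+    camera = Camera(K=K, p_cw=Pose.from_matrix(pose_wc).inverse())
+    return generate_proj_maps(camera, points_world, shape)
+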
+
+class OuroborosDataset(BaseDataset):
+ """
+ DGP dataset class
+
+ Parameters
+ ----------
+ path : String
+ Path to the dataset
+ split : String {'train', 'val', 'test'}
+ Which dataset split to use
+ cameras : list[String]
+ Which cameras to get information from
+ depth_type : String
+ Which lidar will be used to generate ground-truth information
+ input_depth_type : String
+ Which lidar will be used as input to the networks
+    with_pose : Bool
+        If enabled, pose estimates are also returned
+    with_extra_context : Bool
+        If enabled, extra context information (e.g. depth, semantic, instance) is also returned
+ back_context : Int
+ Size of the backward context
+ forward_context : Int
+ Size of the forward context
+ data_transform : Function
+ Transformations applied to the sample
+ dataset : String ['synchronized', 'parallel_domain']
+ Which dataset will be used
+ only_cache : Bool
+ Only use cached pointcloud information, without loading the sensor
+ """
+ def __init__(self, split, tag=None,
+ depth_type=None, input_depth_type=None,
+ masks=None, **kwargs):
+ super().__init__(**kwargs)
+ self.tag = 'ouroboros' if tag is None else tag
+
+ cameras = [c if is_str(c) else 'camera_%02d' % c for c in self.cameras]
+
+ # Store variables
+ self.split = split
+ self.dataset_idx = 0
+ self.sensors = list(cameras)
+
+ # Store task information
+ self.depth_type = depth_type
+ self.input_depth_type = input_depth_type
+ self.only_cache = False
+
+ self.masks_path = masks
+
+ # Add requested annotations
+ requested_annotations = []
+
+ # Add depth sensor
+ if self.with_depth and not self.only_cache and \
+ self.depth_type != 'zbuffer':
+ self.sensors.append(depth_type)
+ self.depth_idx = len(self.sensors) - 1
+
+ # Add input depth sensor
+ if self.with_input_depth and not self.only_cache and \
+ self.input_depth_type != 'zbuffer' and \
+ self.input_depth_type != self.depth_type:
+ self.sensors.append(input_depth_type)
+ self.input_depth_idx = len(self.sensors) - 1
+
+ # Add radar sensor
+ if self.with_radar:
+ self.sensors.append('radar')
+ self.radar_idx = len(self.sensors) - 1
+
+ # Choose which dataset to use
+ if not self.virtual:
+ from dgp.datasets.synchronized_dataset import SynchronizedSceneDataset
+ dataset = SynchronizedSceneDataset
+ extra_args = {}
+ else:
+ from dgp.datasets.pd_dataset import ParallelDomainSceneDataset
+ dataset = ParallelDomainSceneDataset
+ extra_args = {
+ 'use_virtual_camera_datums': False,
+ }
+
+ # Initialize chosen dataset
+ self.dataset = dataset(
+ scene_dataset_json=self.path,
+ split=split,
+ datum_names=self.sensors,
+ backward_context=self.bwd_context,
+ forward_context=self.fwd_context,
+ requested_annotations=requested_annotations,
+ only_annotated_datums=False,
+ **extra_args,
+ )
+
+ def depth_to_world_points(self, depth, datum_idx):
+ """
+ Unproject depth from a camera's perspective into a world-frame pointcloud
+
+ Parameters
+ ----------
+ depth : np.Array
+ Depth map to be lifted [H,W]
+ datum_idx : Int
+ Index of the camera
+
+ Returns
+ -------
+ pointcloud : np.Array
+ Lifted 3D pointcloud [Nx3]
+ """
+ # Access data
+ intrinsics = self.get_current('intrinsics', datum_idx)
+ pose = self.get_current('pose', datum_idx)
+ # Create pixel grid for 3D unprojection
+ h, w = depth.shape[:2]
+ uv = np.mgrid[:w, :h].transpose(2, 1, 0).reshape(-1, 2).astype(np.float32)
+ # Unproject grid to 3D in the camera frame of reference
+ pcl = Camera(K=intrinsics).unproject(uv) * depth.reshape(-1, 1)
+ # Return pointcloud in world frame of reference
+ return pose * pcl
+
+ def create_camera(self, datum_idx, context=None):
+ """
+ Create current camera
+
+ Parameters
+ ----------
+ datum_idx : Int
+ Index of the camera
+    context : Int
+        Context index to use; if None, information from the current frame is used
+
+ Returns
+ -------
+ camera : Camera
+ DGP camera
+ """
+ camera_pose = self.get_current_or_context('pose', datum_idx, context)
+ camera_intrinsics = self.get_current_or_context('intrinsics', datum_idx, context)
+ return Camera(K=camera_intrinsics, p_cw=camera_pose.inverse())
+
+ def create_proj_maps(self, filename, camera_idx, depth_idx, depth_type,
+ world_points=None, context=None):
+ """
+ Creates the depth map for a camera by projecting LiDAR information.
+ It also caches the depth map following DGP folder structure, so it's not recalculated
+
+ Parameters
+ ----------
+ filename : String
+ Filename used for loading / saving
+ camera_idx : Int
+ Camera sensor index
+ depth_idx : Int
+ Depth sensor index
+ depth_type : String
+ Which depth type will be loaded
+ world_points : np.Array [Nx3]
+ Points that will be projected (optional)
+    context : Int
+        Context index to use; if None, information from the current frame is used
+
+ Returns
+ -------
+ depth : np.Array
+ Depth map for that datum in that sample [H,W]
+ """
+        # If we want the z-buffer (simulation)
+        if depth_type == 'zbuffer':
+            sensor_name = self.get_current('datum_name', camera_idx)
+            filename = filename.replace(self.sensors[camera_idx], sensor_name)
+            filename = '{}/{}.npz'.format(
+                os.path.dirname(self.path), filename.format('depth'))
+            return np.load(filename)['data']
+        # Otherwise, we want projected information
+        filename_depth = '{}/{}.npz'.format(
+            os.path.dirname(self.path), filename.format('projected/depth/{}'.format(depth_type)))
+        # Load the cached depth map and return it if it exists
+        try:
+            depth = load_from_file(filename_depth, 'depth')
+            return depth
+        except Exception:
+            pass
+ # Calculate world points if needed
+ if world_points is None:
+ # Get lidar information
+ lidar_pose = self.get_current_or_context('pose', depth_idx, context)
+ lidar_points = self.get_current_or_context('point_cloud', depth_idx, context)
+ world_points = lidar_pose * lidar_points
+
+ # Create camera
+ camera = self.create_camera(camera_idx, context)
+ image_shape = self.get_current_or_context('rgb', camera_idx, context).size[::-1]
+ # Generate depth maps
+ depth = generate_proj_maps(camera, world_points, image_shape)
+ # Save depth map
+ save_to_file(filename_depth, 'depth', depth)
+ # Return depth
+ return depth
+
+
+ def get_current(self, key, sensor_idx, as_dict=False):
+ """Return current timestep of a key from a sensor"""
+ current = self.sample_dgp[self.bwd_context][sensor_idx][key]
+ return current if not as_dict else {0: current}
+
+ def get_backward(self, key, sensor_idx):
+ """Return backward timesteps of a key from a sensor"""
+ return [] if self.bwd_context == 0 else \
+ [self.sample_dgp[i][sensor_idx][key] for i in range(0, self.bwd_context)]
+
+ def get_forward(self, key, sensor_idx):
+ """Return forward timesteps of a key from a sensor"""
+ return [] if self.fwd_context == 0 else \
+ [self.sample_dgp[i][sensor_idx][key]
+ for i in range(self.bwd_context + 1,
+ self.bwd_context + self.fwd_context + 1)]
+
+ def get_context(self, key, sensor_idx, as_dict=False):
+ """Get both backward and forward contexts"""
+ context = self.get_backward(key, sensor_idx) + self.get_forward(key, sensor_idx)
+ if not as_dict:
+ return context
+ else:
+ return {key: val for key, val in zip(self.context, context)}
+
+ def get_current_or_context(self, key, sensor_idx, context=None, as_dict=False):
+ """Return current or context information for a given key and sensor index"""
+ if context is None:
+ return self.get_current(key, sensor_idx, as_dict=as_dict)
+ else:
+ return self.get_context(key, sensor_idx, as_dict=as_dict)[context]
+
+ def has_dgp_key(self, key, sensor_idx):
+ """Returns True if the DGP sample contains a certain key"""
+ return key in self.sample_dgp[self.bwd_context][sensor_idx].keys()
+
+ def get_filename(self, sample_idx, datum_idx, context=0):
+ """
+ Returns the filename for an index, following DGP structure
+
+ Parameters
+ ----------
+ sample_idx : Int
+ Sample index
+ datum_idx : Int
+ Datum index
+ context : Int
+ Context offset for the sample
+
+ Returns
+ -------
+ filename : String
+ Filename for the datum in that sample
+ """
+ scene_idx, sample_idx_in_scene, _ = self.dataset.dataset_item_index[sample_idx]
+ scene_dir = self.dataset.scenes[scene_idx].directory
+ filename = self.dataset.get_datum(
+ scene_idx, sample_idx_in_scene + context, self.sensors[datum_idx]).datum.image.filename
+ return os.path.splitext(os.path.join(os.path.basename(scene_dir),
+ filename.replace('rgb', '{}')))[0]
+
+ def __len__(self):
+ """Length of dataset"""
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ """Get dataset sample"""
+
+ # Get DGP sample (if single sensor, make it a list)
+ self.sample_dgp = self.dataset[idx]
+ self.sample_dgp = [make_list(sample) for sample in self.sample_dgp]
+
+ # Reorganize sensors to the right order
+ sensor_names = [self.get_current('datum_name', i).lower() for i in range(len(self.sensors))]
+ indexes = [sensor_names.index(v) for v in self.sensors]
+ self.sample_dgp = [[s[idx] for idx in indexes] for s in self.sample_dgp]
+
+ # Loop over all cameras
+ samples = []
+ for i in range(self.num_cameras):
+
+ # Filename
+ filename = self.get_filename(idx, i)
+
+ # Base sample
+ sample = {
+ 'idx': idx,
+ 'tag': self.tag,
+ 'filename': self.relative_path({0: filename}),
+ 'splitname': '%s_%010d' % (self.split, idx),
+ 'sensor_name': self.get_current('datum_name', i),
+ }
+
+ # Image and intrinsics
+ sample.update({
+ 'rgb': self.get_current('rgb', i, as_dict=True),
+ 'intrinsics': self.get_current('intrinsics', i, as_dict=True),
+ })
+
+ # If masks are returned
+ if self.masks_path is not None:
+ sample.update({
+ 'mask': read_image(os.path.join(
+ self.masks_path, '%02d.png' % self.cameras[i]))
+ })
+
+ # If depth is returned
+ if self.with_depth:
+ # Get depth maps
+ depth = self.create_proj_maps(
+ filename, i, self.depth_idx, self.depth_type)
+ # Include depth map
+ sample.update({
+ 'depth': {0: depth}
+ })
+
+ # If input depth is returned
+ if self.with_input_depth:
+ sample.update({
+                    'input_depth': {0: self.create_proj_maps(
+                        filename, i, self.input_depth_idx, self.input_depth_type)}
+ })
+
+ # If pose is returned
+ if self.with_pose:
+ sample.update({
+ 'extrinsics': {key: val.inverse().matrix for key, val in
+ self.get_current('extrinsics', i, as_dict=True).items()},
+ 'pose': {key: val.inverse().matrix for key, val in
+ self.get_current('pose', i, as_dict=True).items()},
+ })
+
+ # If context is returned
+ if self.with_context:
+
+ # Include context images
+ sample['rgb'].update(self.get_context('rgb', i, as_dict=True))
+
+ # Create contexts filenames if extra context is required
+ filename_context = []
+ for context in range(-self.bwd_context, 0):
+ filename_context.append(self.get_filename(idx, i, context))
+ for context in range(1, self.fwd_context + 1):
+ filename_context.append(self.get_filename(idx, i, context))
+ sample['filename_context'] = filename_context
+
+ # If context pose is returned
+ if self.with_pose:
+ # Get original values to calculate relative motion
+ inv_orig_extrinsics = Pose.from_matrix(sample['extrinsics'][0]).inverse()
+ sample['extrinsics'].update(
+ {key: (inv_orig_extrinsics * val.inverse()).matrix for key, val in zip(
+ self.context, self.get_context('extrinsics', i))})
+ sample['pose'].update(
+ {key: (val.inverse()).matrix for key, val in zip(
+ self.context, self.get_context('pose', i))})
+
+ # If context depth is returned
+ if self.with_depth_context:
+ depth_context = [
+ self.create_proj_maps(
+ filename, i, self.depth_idx, self.depth_type,
+ context=k)
+ for k, filename in enumerate(filename_context)]
+                    sample['depth'].update(
+                        {key: val for key, val in zip(
+                            self.context, depth_context)})
+
+
+ samples.append(sample)
+
+ # Make relative poses
+ samples = make_relative_pose(samples)
+
+ # Add LiDAR information
+
+ lidar_sample = {}
+ if self.with_lidar:
+
+ # Include pointcloud information
+ lidar_sample.update({
+ 'lidar_pointcloud': self.get_current('point_cloud', self.depth_idx),
+ })
+
+ # If pose is included
+ if self.with_pose:
+ lidar_sample.update({
+ 'lidar_extrinsics': self.get_current('extrinsics', self.depth_idx).matrix,
+ 'lidar_pose': self.get_current('pose', self.depth_idx).matrix,
+ })
+
+ # If extra context is included
+ if self.with_extra_context:
+ lidar_sample['lidar_context'] = self.get_context('point_cloud', self.depth_idx)
+ # If context pose is included
+ if self.with_pose:
+ # Get original values to calculate relative motion
+ orig_extrinsics = Pose.from_matrix(lidar_sample['lidar_extrinsics'])
+ orig_pose = Pose.from_matrix(lidar_sample['lidar_pose'])
+ lidar_sample.update({
+ 'lidar_extrinsics_context':
+ [(orig_extrinsics.inverse() * extrinsics).inverse().matrix
+ for extrinsics in self.get_context('extrinsics', self.depth_idx)],
+ 'lidar_pose_context':
+ [(orig_pose.inverse() * pose).inverse().matrix
+ for pose in self.get_context('pose', self.depth_idx)],
+ })
+
+
+ # Apply same data transformations for all sensors
+ if self.data_transform:
+ samples = self.data_transform(samples)
+ # lidar_sample = self.data_transform(lidar_sample)
+
+ # Return sample (stacked if necessary)
+ return stack_sample(samples, lidar_sample)
diff --git a/vidar/datasets/ScanNetTemporalDataset.py b/vidar/datasets/ScanNetTemporalDataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..42524c99d3e0ab5834c9e2c58ef9bf5f12806353
--- /dev/null
+++ b/vidar/datasets/ScanNetTemporalDataset.py
@@ -0,0 +1,116 @@
+
+import os
+
+import cv2
+import numpy as np
+
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.utils.FolderTree import FolderTree
+from vidar.datasets.utils.misc import invert_pose, stack_sample, make_relative_pose
+from vidar.utils.read import read_image
+
+
+class ScanNetTemporalDataset(BaseDataset):
+ def __init__(self, tag=None, single_folder=False, split=None, stride=1, **kwargs):
+ super().__init__(**kwargs)
+ self.tag = 'scannet_temporal' if tag is None else tag
+ if split is None or split == '':
+            split = ''
+ self.rgb_tree = FolderTree(
+ os.path.join(self.path, split),
+ context=self.context, sub_folders=['color'], stride=stride,
+ single_folder=single_folder, suffix='.jpg')
+
+ def __len__(self):
+ """Dataset length"""
+ return len(self.rgb_tree)
+
+ @staticmethod
+ def get_intrinsics(rgb):
+ """Return dummy intrinsics"""
+ return np.array([[rgb.size[0] / 2., 0., rgb.size[0] / 2.],
+ [0., rgb.size[1], rgb.size[1] / 2.],
+ [0., 0., 1.]])
+
+ @staticmethod
+ def load_intrinsics(filename):
+ filename_intrinsics = {key: '/'.join(val.split('/')[:-2]) + '/intrinsic/intrinsic_depth.txt'
+ for key, val in filename.items()}
+ return {key: np.genfromtxt(val).astype(np.float32).reshape((4, 4))[:3, :3]
+ for key, val in filename_intrinsics.items()}
+
+ @staticmethod
+ def load_depth(filename):
+ try:
+ filename_depth = {key: val.replace('color', 'depth').replace('.jpg', '.png')
+ for key, val in filename.items()}
+ return {key: (cv2.imread(val, cv2.IMREAD_ANYDEPTH).astype(np.float32) / 1000.0)
+ for key, val in filename_depth.items()}
+        except Exception:
+ filename_depth = {key: val.replace('color', 'depth').replace('.jpg', '.npy')
+ for key, val in filename.items()}
+ return {key: (np.load(val) / 1000.0).astype(np.float32)
+ for key, val in filename_depth.items()}
+
+ @staticmethod
+ def load_pose(filename):
+ filename_pose = {key: val.replace('color', 'pose').replace('.jpg', '.txt')
+ for key, val in filename.items()}
+ return {key: invert_pose(np.genfromtxt(val).astype(np.float32).reshape((4, 4)))
+ for key, val in filename_pose.items()}
+
+ def __getitem__(self, idx):
+ """Get dataset sample given an index."""
+
+ samples = []
+
+ for _ in self.cameras:
+
+ # Filename
+ filename = self.rgb_tree.get_item(idx)
+
+ # Base sample
+ sample = {
+ 'idx': idx,
+ 'tag': self.tag,
+ 'filename': self.relative_path(filename),
+ 'splitname': '%010d' % idx
+ }
+
+ # Image
+ sample['rgb'] = read_image(filename)
+
+ # Intrinsics
+ sample['intrinsics'] = self.load_intrinsics(filename)
+
+ if self.with_depth:
+ sample['depth'] = self.load_depth(filename)
+
+ if self.with_pose:
+ sample['pose'] = self.load_pose(filename)
+
+ # If with context
+ if self.with_context:
+ filename_context = self.rgb_tree.get_context(idx)
+ sample['rgb'].update(read_image(filename_context))
+ if self.with_depth:
+ sample['depth'].update(self.load_depth(filename_context))
+ if self.with_pose:
+ sample['pose'].update(self.load_pose(filename_context))
+
+ # Stack sample
+ samples.append(sample)
+
+ # Make relative poses
+ samples = make_relative_pose(samples)
+
+ # Transform data
+ if self.data_transform:
+ samples = self.data_transform(samples)
+
+ # Return stacked sample
+ return stack_sample(samples)
diff --git a/vidar/datasets/VKITTI2Dataset.py b/vidar/datasets/VKITTI2Dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5985d56c80ff7240cbda51500961f97c6a9a374
--- /dev/null
+++ b/vidar/datasets/VKITTI2Dataset.py
@@ -0,0 +1,413 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import csv
+import os
+
+import cv2
+import numpy as np
+
+from vidar.datasets.BaseDataset import BaseDataset
+from vidar.datasets.utils.FolderTree import FolderTree
+from vidar.datasets.utils.misc import \
+ convert_ontology, initialize_ontology, stack_sample, make_relative_pose
+from vidar.utils.data import dict_remove_nones
+from vidar.utils.decorators import iterate1
+from vidar.utils.read import read_image
+
+
+def make_tree(path, sub_folder, camera, mode, context):
+ """
+ Create a folder tree for a certain task
+
+ Parameters
+ ----------
+ path : String
+ Data path
+ sub_folder : String
+ Subfolder path
+ camera : Int
+ Camera index
+ mode : String
+ Which task we are using
+ context : list[Int]
+ Context samples
+
+ Returns
+ -------
+ tree : FolderTree
+ Folder tree containing task data
+ """
+ path = os.path.join(path, sub_folder)
+ sub_folders = '{}/frames/{}/Camera_{}'.format(mode, sub_folder, camera)
+ return FolderTree(path, sub_folders=sub_folders, context=context)
+
+
+def semantic_color_to_id(semantic_color, ontology):
+ """
+ Convert semantic color to semantic ID
+
+ Parameters
+ ----------
+ semantic_color : numpy.Array
+ Matrix with semantic colors [H, W, 3]
+ ontology : Dict
+ Ontology dictionary, with {id: color}
+
+ Returns
+ -------
+ semantic_id : numpy.Array
+ Matrix with semantic IDs [H, W]
+ """
+ # Create semantic ID map
+ semantic_id = np.zeros(semantic_color.shape[:2])
+ # Loop over every ontology item and assign ID to color
+ for key, val in ontology.items():
+ idx = (semantic_color[:, :, 0] == val['color'][0]) & \
+ (semantic_color[:, :, 1] == val['color'][1]) & \
+ (semantic_color[:, :, 2] == val['color'][2])
+ semantic_id[idx] = key
+ # Return semantic ID map
+ return semantic_id
+
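+# Illustrative usage sketch: mapping a tiny color-coded semantic image to class IDs
+# with a toy two-class ontology. Colors and names are made up for demonstration.
+def example_semantic_ids():
+    """Convert a 2x2 semantic color image into a 2x2 map of class IDs."""
+    ontology = {
+        0: {'name': 'road', 'color': np.array([100, 60, 100])},
+        1: {'name': 'sky', 'color': np.array([90, 200, 255])},
+    }
+    semantic_color = np.array([[[100, 60, 100], [90, 200, 255]],
+                               [[100, 60, 100], [100, 60, 100]]])
+    return semantic_color_to_id(semantic_color, ontology)
+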
+
+class VKITTI2Dataset(BaseDataset):
+ """
+ VKITTI2 dataset class
+
+ Parameters
+ ----------
+ path : String
+ Path to the dataset
+ split : String {'train', 'val', 'test'}
+ Which dataset split to use
+ ontology : String
+ Which ontology should be used
+ return_ontology : Bool
+ Returns ontology information in the sample
+ data_transform : Function
+ Transformations applied to the sample
+ """
+ def __init__(self, split, tag=None, **kwargs):
+ super().__init__(**kwargs)
+ self.tag = 'vkitti2' if tag is None else tag
+
+ # Store variables
+ self.split = split
+ self.mode = 'clone'
+
+ # Initialize ontology
+ if self.with_semantic:
+ self.ontology, self.ontology_convert = initialize_ontology('vkitti2', self.ontology)
+ self.return_ontology = self.return_ontology
+
+ # Create RGB tree
+ self.rgb_tree = make_tree(
+ self.path, 'rgb', 0, self.mode, self.context)
+
+ # Create semantic tree
+ if self.with_semantic:
+ self.semantic_tree = make_tree(
+ self.path, 'classSegmentation', 0, self.mode, self.context)
+
+ # Create instance tree
+ if self.with_instance:
+ self.instance_tree = make_tree(
+ self.path, 'instanceSegmentation', 0, self.mode, self.context)
+
+ def __len__(self):
+ """Dataset length"""
+ return len(self.rgb_tree)
+
+ @staticmethod
+ @iterate1
+ def _get_depth(filename):
+ """Get depth map from filename"""
+ filename = filename.replace('rgb', 'depth').replace('jpg', 'png')
+ return cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.
+
+ @staticmethod
+ @iterate1
+ def _get_intrinsics(filename, camera, mode):
+ """Get intrinsics from filename"""
+ # Get sample number in the scene
+ number = int(filename.split('/')[-1].replace('rgb_', '').replace('.jpg', ''))
+ # Get intrinsic filename
+ filename_idx = filename.rfind(mode) + len(mode)
+ filename_intrinsics = os.path.join(filename[:filename_idx].replace(
+ '/rgb/', '/textgt/'), 'intrinsic.txt')
+ # Open intrinsic file
+ with open(filename_intrinsics, 'r') as f:
+ # Get intrinsic parameters
+ lines = list(csv.reader(f, delimiter=' '))[1:]
+ params = [float(p) for p in lines[number * 2 + camera][2:]]
+ # Build intrinsics matrix
+ intrinsics = np.array([[params[0], 0.0, params[2]],
+ [0.0, params[1], params[3]],
+ [0.0, 0.0, 1.0]]).astype(np.float32)
+ # Return intrinsics
+ return intrinsics
+
+ @staticmethod
+ @iterate1
+ def _get_pose(filename, camera, mode):
+ """Get pose from filename"""
+ # Get sample number in the scene
+ number = int(filename.split('/')[-1].replace('rgb_', '').replace('.jpg', ''))
+ # Get intrinsic filename
+ filename_idx = filename.rfind(mode) + len(mode)
+ filename_pose = os.path.join(filename[:filename_idx].replace(
+ '/rgb/', '/textgt/'), 'extrinsic.txt')
+        # Open extrinsics (pose) file
+ with open(filename_pose, 'r') as f:
+ # Get pose parameters
+ lines = list(csv.reader(f, delimiter=' '))[1:]
+ pose = np.array([float(p) for p in lines[number * 2 + camera][2:]]).reshape(4, 4)
+ # Return pose
+ return pose
+
+ @staticmethod
+ def _get_ontology(filename, mode):
+ """Get ontology from filename"""
+ # Get ontology filename
+ filename_idx = filename.rfind(mode) + len(mode)
+ filename_ontology = os.path.join(filename[:filename_idx].replace(
+ '/classSegmentation/', '/textgt/'), 'colors.txt')
+ # Open ontology file
+ with open(filename_ontology, 'r') as f:
+ # Get ontology parameters
+ lines = list(csv.reader(f, delimiter=' '))[1:]
+ from collections import OrderedDict
+ ontology = OrderedDict()
+ for i, line in enumerate(lines):
+ ontology[i] = {
+ 'name': line[0],
+ 'color': np.array([int(clr) for clr in line[1:]])
+ }
+ return ontology
+
+ def _get_semantic(self, filename):
+ """Get semantic from filename"""
+ # Get semantic color map
+ semantic_color = {key: np.array(val) for key, val in read_image(filename).items()}
+ # Return semantic id map
+ semantic_id = {key: semantic_color_to_id(val, self.ontology) for key, val in semantic_color.items()}
+ return convert_ontology(semantic_id, self.ontology_convert)
+
+ @staticmethod
+ def _get_instance(filename):
+ """Get instance from filename"""
+ # Get instance id map
+ return np.array(read_image(filename))
+
+ @staticmethod
+ def _get_bbox3d(filename):
+
+ bboxes3d_dim = []
+ bboxes3d_pos = []
+ bboxes3d_rot = []
+ bboxes3d_idx = []
+
+ k = int(filename.split('/')[-1][4:-4])
+ bb = '/'.join(filename.replace('/rgb/', '/textgt/').split('/')[:-4])
+ bb += '/pose.txt'
+
+ with open(bb, 'r') as file:
+ for i, f in enumerate(file):
+ if i == 0:
+ continue
+ line = [float(a) for a in f.split(' ')]
+ if line[0] == k and line[1] == 0:
+ bboxes3d_dim.append(np.array([line[6], line[5], line[4]]))
+ bboxes3d_pos.append(np.array(line[13:16]))
+ # bboxes3d_rot.append(np.array([line[18], line[17], line[16]]))
+ bboxes3d_rot.append(np.array([line[17], line[16], line[18]]))
+ bboxes3d_idx.append(np.array([line[2]]))
+
+ return {
+ 'dim': np.stack(bboxes3d_dim, 0),
+ 'pos': np.stack(bboxes3d_pos, 0),
+ 'rot': np.stack(bboxes3d_rot, 0),
+ 'idx': np.stack(bboxes3d_idx, 0),
+ }
+
+ @staticmethod
+ @iterate1
+ def _get_optical_flow(filename, mode):
+ """
+ Get optical flow from filename. Code obtained here:
+ https://europe.naverlabs.com/research/computer-vision-research-naver-labs-europe/proxy-virtual-worlds-vkitti-2/
+ """
+ # Get filename
+ if mode == 'bwd':
+ filename = filename.replace('rgb', 'backwardFlow')
+ elif mode == 'fwd':
+ filename = filename.replace('/rgb/', '/forwardFlow/').replace('rgb_', 'flow_')
+ else:
+ raise ValueError('Invalid optical flow mode')
+ filename = filename.replace('jpg', 'png')
+ # Return None if file does not exist
+ if not os.path.exists(filename):
+ return None
+ else:
+ # Get optical flow
+ optical_flow = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ h, w = optical_flow.shape[:2]
+ # Get invalid optical flow pixels
+ invalid = optical_flow[..., 0] == 0
+ # Normalize and scale optical flow values
+ optical_flow = 2.0 / (2 ** 16 - 1.0) * optical_flow[..., 2:0:-1].astype('f4') - 1.
+ optical_flow[..., 0] *= w - 1
+ optical_flow[..., 1] *= h - 1
+ # Remove invalid pixels
+ optical_flow[invalid] = 0
+ return optical_flow
+
+ @staticmethod
+ @iterate1
+ def _get_scene_flow(filename, mode):
+ """Get scene flow from filename. Code obtained here:
+ https://europe.naverlabs.com/research/computer-vision-research-naver-labs-europe/proxy-virtual-worlds-vkitti-2/
+ """
+ # Get filename
+ if mode == 'bwd':
+ filename = filename.replace('rgb', 'backwardSceneFlow')
+ elif mode == 'fwd':
+ filename = filename.replace('/rgb/', '/forwardSceneFlow/').replace('rgb_', 'sceneFlow_')
+ else:
+ raise ValueError('Invalid scene flow mode')
+ filename = filename.replace('jpg', 'png')
+ # Return None if file does not exist
+ if not os.path.exists(filename):
+ return None
+ else:
+ # Get scene flow
+ scene_flow = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+            # Return normalized and scaled scene flow (-10m to 10m)
+ return (scene_flow[:, :, ::-1] * 2. / 65535. - 1.) * 10.
+
+ def __getitem__(self, idx):
+ """Get dataset sample"""
+
+ samples = []
+
+ for camera in self.cameras:
+
+ # Get filename
+ filename = self.rgb_tree.get_item(idx)
+ filename = {key: val.replace('Camera_0', 'Camera_{}'.format(camera))
+ for key, val in filename.items()}
+
+ # Base sample
+ sample = {
+ 'idx': idx,
+ 'tag': self.tag,
+ 'filename': self.relative_path(filename),
+ 'splitname': '%s_%010d' % (self.split, idx),
+ }
+
+ # Image and intrinsics
+ sample.update({
+ 'rgb': read_image(filename),
+ 'intrinsics': self._get_intrinsics(filename, camera, self.mode),
+ })
+
+ # If returning pose
+ if self.with_pose:
+ sample['pose'] = self._get_pose(filename, camera, self.mode)
+
+ # If returning depth
+ if self.with_depth:
+ sample['depth'] = self._get_depth(filename)
+
+ # If returning input depth
+ if self.with_input_depth:
+ sample['input_depth'] = self._get_depth(filename)
+
+            # If returning semantic (kept in a separate variable so the RGB filename is preserved)
+            if self.with_semantic:
+                filename_semantic = self.semantic_tree.get_item(idx)
+                sample.update({'semantic': self._get_semantic(filename_semantic)})
+                # If returning ontology
+                if self.return_ontology:
+                    sample.update({'ontology': self._get_ontology(filename_semantic, self.mode)})
+
+            # If returning instance
+            if self.with_instance:
+                filename_instance = self.instance_tree.get_item(idx)
+                sample.update({'instance': self._get_instance(filename_instance)})
+
+ # If returning 3D bounding boxes
+ if self.with_bbox3d:
+ filename = self.rgb_tree.get_item(idx)
+ sample.update({
+ 'bboxes3d': self._get_bbox3d(filename)
+ })
+
+ # If returning optical flow
+ if self.with_optical_flow:
+ sample['bwd_optical_flow'] = \
+ dict_remove_nones(self._get_optical_flow(filename, 'bwd'))
+ sample['fwd_optical_flow'] = \
+ dict_remove_nones(self._get_optical_flow(filename, 'fwd'))
+
+ # If returning scene flow
+ if self.with_scene_flow:
+ sample['bwd_scene_flow'] = \
+ dict_remove_nones(self._get_scene_flow(filename, 'bwd'))
+ sample['fwd_scene_flow'] = \
+ dict_remove_nones(self._get_scene_flow(filename, 'fwd'))
+
+ # If returning context information
+ if self.with_context:
+
+ # Get context filenames
+ filename_context = self.rgb_tree.get_context(idx)
+ filename_context = {key: val.replace('Camera_0', 'Camera_{}'.format(camera))
+ for key, val in filename_context.items()}
+
+ # Get RGB context
+ sample['rgb'].update(read_image(filename_context))
+
+ # Get pose context
+ if self.with_pose:
+ sample['pose'].update(self._get_pose(filename_context, camera, self.mode))
+
+ # Get depth context
+ if self.with_depth_context:
+ sample['depth'].update(self._get_depth(filename_context))
+
+ # Get input depth context
+ if self.with_input_depth_context:
+ sample['input_depth'].update(self._get_depth(filename_context))
+
+ # Get semantic context
+ if self.with_semantic_context:
+ sample['semantic'].update(self._get_semantic(self.semantic_tree.get_context(idx)))
+
+ # Get optical flow context
+ if self.with_optical_flow_context:
+ sample['bwd_optical_flow'].update(
+ dict_remove_nones(self._get_optical_flow(filename_context, 'bwd')))
+ sample['fwd_optical_flow'].update(
+ dict_remove_nones(self._get_optical_flow(filename_context, 'fwd')))
+
+ # Get scene flow context
+ if self.with_scene_flow_context:
+ sample['bwd_scene_flow'].update(
+ dict_remove_nones(self._get_scene_flow(filename_context, 'bwd')))
+ sample['fwd_scene_flow'].update(
+ dict_remove_nones(self._get_scene_flow(filename_context, 'fwd')))
+
+ # Stack sample
+ samples.append(sample)
+
+ # Make relative poses
+ samples = make_relative_pose(samples)
+
+ # Transform data
+ if self.data_transform:
+ samples = self.data_transform(samples)
+
+ # Return stacked sample
+ return stack_sample(samples)
+
diff --git a/vidar/datasets/__init__.py b/vidar/datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/datasets/augmentations/__init__.py b/vidar/datasets/augmentations/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/datasets/augmentations/crop.py b/vidar/datasets/augmentations/crop.py
new file mode 100755
index 0000000000000000000000000000000000000000..36a18ee99b7b807ea066ff9a8c2992c218e4caa2
--- /dev/null
+++ b/vidar/datasets/augmentations/crop.py
@@ -0,0 +1,158 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from copy import deepcopy
+
+import numpy as np
+
+from vidar.utils.data import keys_with
+from vidar.utils.decorators import iterate1
+
+
+@iterate1
+def crop_pil(image, borders):
+ """
+ Crop a PIL Image
+
+ Parameters
+ ----------
+ image : PIL Image
+ Input image
+ borders : Tuple
+ Borders used for cropping (left, top, right, lower)
+
+ Returns
+ -------
+ image : PIL.Image
+ Cropped image
+ """
+ return image.crop(borders)
+
+
+@iterate1
+def crop_npy(depth, borders):
+ """
+ Crop a numpy depth map
+
+ Parameters
+ ----------
+ depth : np.Array
+ Input numpy array
+ borders : Tuple
+ Borders used for cropping
+
+ Returns
+    -------
+    depth : np.Array
+        Cropped numpy array
+    """
+    # Crop rows [top:lower] and columns [left:right]
+    return depth[borders[1]:borders[3], borders[0]:borders[2]]
+
+
+@iterate1
+def crop_intrinsics(intrinsics, borders):
+ """
+ Crop camera intrinsics matrix
+
+ Parameters
+ ----------
+ intrinsics : np.Array
+ Original intrinsics matrix [3,3]
+ borders : Tuple
+        Borders used for cropping
+
+    Returns
+ -------
+ intrinsics : np.Array
+ Cropped intrinsics matrix [3,3]
+ """
+ intrinsics = np.copy(intrinsics)
+ intrinsics[0, 2] -= borders[0]
+ intrinsics[1, 2] -= borders[1]
+ return intrinsics
+
+
+def crop_sample_input(sample, borders):
+ """
+ Crops the input information of a sample
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ borders : Tuple
+ Borders used for cropping
+
+ Returns
+ -------
+ sample : Dict
+ Cropped sample
+ """
+ # Intrinsics
+ for key in keys_with(sample, 'intrinsics', without='raw'):
+ # Create copy of full intrinsics
+ if f'raw_{key}' not in sample.keys():
+ sample[f'raw_{key}'] = deepcopy(sample[key])
+ sample[key] = crop_intrinsics(sample[key], borders)
+ # RGB
+ for key in keys_with(sample, 'rgb', without='raw'):
+ sample[key] = crop_pil(sample[key], borders)
+ # Mask
+ for key in keys_with(sample, 'mask', without='raw'):
+ sample[key] = crop_pil(sample[key], borders)
+ # Input depth
+ for key in keys_with(sample, 'input_depth'):
+ sample[key] = crop_npy(sample[key], borders)
+ # Return cropped sample
+ return sample
+
+
+def crop_sample_supervision(sample, borders):
+ """
+ Crops the output information of a sample
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ borders : Tuple
+ Borders used for cropping
+
+ Returns
+ -------
+ sample : Dict
+ Cropped sample
+ """
+ for key in keys_with(sample, 'depth', without='input_depth'):
+ sample[key] = crop_npy(sample[key], borders)
+ for key in keys_with(sample, 'optical_flow'):
+ sample[key] = crop_npy(sample[key], borders)
+ for key in keys_with(sample, 'scene_flow'):
+ sample[key] = crop_npy(sample[key], borders)
+ # Return cropped sample
+ return sample
+
+
+def crop_sample(sample, borders):
+ """
+ Crops a sample, including image, intrinsics and depth maps.
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ borders : Tuple
+ Borders used for cropping
+
+ Returns
+ -------
+ sample : Dict
+ Cropped sample
+ """
+ # Crop input information
+ sample = crop_sample_input(sample, borders)
+ # Crop output information
+ sample = crop_sample_supervision(sample, borders)
+ # Return cropped sample
+ return sample
+
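+# Illustrative usage sketch: cropping a minimal sample made of a blank PIL image and
+# a hypothetical intrinsics matrix. Border order follows PIL: (left, top, right, lower).
+def example_crop():
+    """Crop a dummy 640x480 sample to its central 320x240 region."""
+    from PIL import Image
+    sample = {
+        'rgb': {0: Image.new('RGB', (640, 480))},
+        'intrinsics': {0: np.array([[500., 0., 320.],
+                                    [0., 500., 240.],
+                                    [0., 0., 1.]])},
+    }
+    return crop_sample(sample, (160, 120, 480, 360))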
+
diff --git a/vidar/datasets/augmentations/image.py b/vidar/datasets/augmentations/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..a891a225a702ee9a573498f360cf4b0167e4ca74
--- /dev/null
+++ b/vidar/datasets/augmentations/image.py
@@ -0,0 +1,138 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import random
+
+import torch
+import torchvision.transforms as transforms
+
+from vidar.utils.data import keys_in
+from vidar.utils.decorators import iterate1
+
+
+def random_colorjitter(parameters):
+ """
+ Creates a reusable color jitter transformation
+
+ Parameters
+ ----------
+ parameters : Tuple
+        Color jittering parameters (brightness, contrast, saturation, hue)
+
+ Returns
+ -------
+ transform : torchvision.Transform
+ Color jitter transformation with fixed parameters
+ """
+ # Get and unpack values
+ brightness, contrast, saturation, hue = parameters
+ brightness = [max(0, 1 - brightness), 1 + brightness]
+ contrast = [max(0, 1 - contrast), 1 + contrast]
+ saturation = [max(0, 1 - saturation), 1 + saturation]
+ hue = [-hue, hue]
+
+ # Initialize transformation list
+ all_transforms = []
+
+ # Add brightness transformation
+ if brightness is not None:
+ brightness_factor = random.uniform(brightness[0], brightness[1])
+ all_transforms.append(transforms.Lambda(
+ lambda img: transforms.functional.adjust_brightness(img, brightness_factor)))
+ # Add contrast transformation
+ if contrast is not None:
+ contrast_factor = random.uniform(contrast[0], contrast[1])
+ all_transforms.append(transforms.Lambda(
+ lambda img: transforms.functional.adjust_contrast(img, contrast_factor)))
+ # Add saturation transformation
+ if saturation is not None:
+ saturation_factor = random.uniform(saturation[0], saturation[1])
+ all_transforms.append(transforms.Lambda(
+ lambda img: transforms.functional.adjust_saturation(img, saturation_factor)))
+ # Add hue transformation
+ if hue is not None:
+ hue_factor = random.uniform(hue[0], hue[1])
+ all_transforms.append(transforms.Lambda(
+ lambda img: transforms.functional.adjust_hue(img, hue_factor)))
+ # Shuffle transformation order
+ random.shuffle(all_transforms)
+ # Return composed transformation
+ return transforms.Compose(all_transforms)
+
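+# Illustrative usage sketch: building a jitter transform with fixed random factors
+# and applying it to a blank PIL image. The jitter magnitudes are arbitrary.
+def example_colorjitter():
+    """Apply a randomly parameterized color jitter to a dummy image."""
+    from PIL import Image
+    jitter = random_colorjitter((0.2, 0.2, 0.2, 0.05))
+    return jitter(Image.new('RGB', (64, 64), (128, 128, 128)))
+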
+
+def colorjitter_sample(samples, parameters, background=None, prob=1.0):
+ """
+ Jitters input images as data augmentation.
+
+ Parameters
+ ----------
+ samples : Dict
+ Input sample
+ parameters : tuple (brightness, contrast, saturation, hue, color)
+ Color jittering parameters
+    background: None or String
+        Which background color ('white' or 'black') should be preserved during jittering
+ prob : Float
+ Jittering probability
+
+ Returns
+ -------
+    samples : Dict
+        Jittered samples
+ """
+ if random.random() < prob:
+ # Prepare jitter transformation
+ colorjitter_transform = random_colorjitter(parameters[:4])
+ # Prepare color transformation if requested
+ if len(parameters) > 4 and parameters[4] > 0:
+ matrix = (random.uniform(1. - parameters[4], 1 + parameters[4]), 0, 0, 0,
+ 0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0, 0,
+ 0, 0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0)
+ else:
+ matrix = None
+ for sample in samples:
+ # Jitter sample keys
+ for key in keys_in(sample, ['rgb']):
+ for ctx in sample[key].keys():
+ bkg, color = [], {'white': (255, 255, 255), 'black': (0, 0, 0)}
+ if background is not None:
+ for i in range(sample[key][ctx].size[0]):
+ for j in range(sample[key][ctx].size[1]):
+ if sample[key][ctx].getpixel((i,j)) == color[background]:
+ bkg.append((i,j))
+ sample[key][ctx] = colorjitter_transform(sample[key][ctx])
+ if matrix is not None:
+ sample[key][ctx] = sample[key][ctx].convert('RGB', matrix)
+ if background is not None:
+ for ij in bkg:
+ sample[key][ctx].putpixel(ij, color[background])
+    # Return jittered samples
+ return samples
+
+
+@iterate1
+def normalize_sample(sample, mean, std):
+ """
+ Normalize sample
+
+ Parameters
+ ----------
+ sample : Dict
+ Input sample dictionary
+    mean : list[Float]
+        Normalization mean [3]
+    std : list[Float]
+        Normalization standard deviation [3]
+
+ Returns
+ -------
+ sample : Dict
+ Normalized sample
+ """
+ # Get mean and std values in the right shape
+ mean = torch.tensor(mean).reshape(3, 1, 1)
+ std = torch.tensor(std).reshape(3, 1, 1)
+ # Apply mean and std to every image
+    for key_sample in keys_in(sample, ['rgb']):
+        sample[key_sample] = {key: (val - mean) / std
+                              for key, val in sample[key_sample].items()}
+ return sample
diff --git a/vidar/datasets/augmentations/misc.py b/vidar/datasets/augmentations/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cbf175195b7a9ed51dab7b6341e0ce5ea2dad1e
--- /dev/null
+++ b/vidar/datasets/augmentations/misc.py
@@ -0,0 +1,98 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from copy import deepcopy
+
+import numpy as np
+
+from vidar.utils.data import keys_in
+from vidar.utils.decorators import iterate1
+
+
+def duplicate_sample(sample, keys):
+ """
+ Duplicates sample images and contexts to preserve their un-augmented versions.
+
+ Parameters
+ ----------
+    sample : Dict
+        Input sample
+    keys : list[String]
+        Keys that should be duplicated
+
+    Returns
+    -------
+    sample : Dict
+        Sample including ['raw_' + key] copies of the requested keys
+ """
+ for key in keys_in(sample, keys):
+ sample[f'raw_{key}'] = deepcopy(sample[key])
+ # Return duplicated sample
+ return sample
+
+
+@iterate1
+def mask_depth_number(depth, num_points):
+ """
+ Mask depth map to remove valid pixels given the target number of points to keep.
+
+ Parameters
+ ----------
+ depth : np.Array
+ Depth map to be masked
+ num_points : Int
+        Number of input depth points that should be kept
+
+    Returns
+ -------
+ depth : np.Array
+ Masked depth map (modification done in-place!)
+ """
+ # Find probability of maintaining
+ total_points = depth.shape[0] * depth.shape[1]
+ rnd = np.random.rand(depth.shape[0], depth.shape[1])
+ percentile = 100 * num_points / total_points
+ # Mask depth map
+ mask = rnd < np.percentile(rnd, q=100 - percentile)
+ depth[mask] = 0.0
+ # Return depth map
+ return depth
+
+
+@iterate1
+def mask_depth_percentage(depth, percentage):
+ """
+ Mask depth map to remove valid pixels given a range of percentages.
+
+ Parameters
+ ----------
+ depth : np.Array
+ Depth map to be masked
+ percentage : Tuple
+        Min/max fraction of valid pixels to be maintained (min, max)
+
+    Returns
+ -------
+ depth : np.Array
+ Masked depth map (modification done in-place!)
+ """
+ # Find probability of maintaining
+ rnd = np.random.uniform(low=percentage[0], high=percentage[1], size=1)[0]
+ # Mask depth map
+ depth[np.random.rand(*depth.shape) > rnd] = 0.0
+ # Return depth map
+ return depth
+
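+# Illustrative usage sketch: sparsifying a dense synthetic depth map with the two
+# masking helpers above. Shapes and point counts are arbitrary.
+def example_sparsify_depth():
+    """Keep roughly 1000 points of a dense map, then 20-50% of another copy."""
+    depth = np.random.uniform(1.0, 80.0, size=(192, 640))
+    by_number = mask_depth_number(depth.copy(), 1000)
+    by_percentage = mask_depth_percentage(depth.copy(), (0.2, 0.5))
+    return by_number, by_percentage
+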
+
+def clip_depth(sample, max_value):
+ """Clip depth map to a maximum range"""
+ for i in range(len(sample)):
+ if 'depth' in sample[i]:
+ for ctx in sample[i]['depth'].keys():
+ sample[i]['depth'][ctx][sample[i]['depth'][ctx] > max_value] = max_value
+ return sample
+
+
+def mask_depth_range(sample, depth_range):
+ """Mask out depth map within a range"""
+ for i in range(len(sample)):
+ if 'depth' in sample[i]:
+ for ctx in sample[i]['depth'].keys():
+ sample[i]['depth'][ctx][sample[i]['depth'][ctx] < depth_range[0]] = 0.0
+ sample[i]['depth'][ctx][sample[i]['depth'][ctx] > depth_range[1]] = 0.0
+ return sample
diff --git a/vidar/datasets/augmentations/resize.py b/vidar/datasets/augmentations/resize.py
new file mode 100644
index 0000000000000000000000000000000000000000..0deec59b69fa761d7aac94526d4f2e37d1b54b34
--- /dev/null
+++ b/vidar/datasets/augmentations/resize.py
@@ -0,0 +1,312 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from copy import deepcopy
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from torchvision.transforms import InterpolationMode
+
+from vidar.utils.data import keys_with
+from vidar.utils.decorators import iterate1
+from vidar.utils.types import is_seq
+
+
+@iterate1
+def resize_pil(image, shape, interpolation=InterpolationMode.LANCZOS):
+ """
+ Resizes input image
+
+ Parameters
+ ----------
+ image : Image PIL
+ Input image
+ shape : Tuple
+ Output shape [H,W]
+ interpolation : Int
+ Interpolation mode
+
+ Returns
+ -------
+ image : Image PIL
+ Resized image
+ """
+ transform = transforms.Resize(shape, interpolation=interpolation)
+ return transform(image)
+
+
+@iterate1
+def resize_npy(depth, shape, expand=True):
+ """
+ Resizes depth map
+
+ Parameters
+ ----------
+ depth : np.Array
+ Depth map [h,w]
+ shape : Tuple
+ Output shape (H,W)
+ expand : Bool
+ Expand output to [H,W,1]
+
+ Returns
+ -------
+ depth : np.Array
+ Resized depth map [H,W]
+ """
+ # If a single number is provided, use resize ratio
+ if not is_seq(shape):
+ shape = tuple(int(s * shape) for s in depth.shape)
+ # Resize depth map
+ depth = cv2.resize(depth, dsize=tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
+ # Return resized depth map
+ return np.expand_dims(depth, axis=2) if expand else depth
+
+
+@iterate1
+def resize_npy_preserve(depth, shape):
+ """
+ Resizes depth map preserving all valid depth pixels
+ Multiple downsampled points can be assigned to the same pixel.
+
+ Parameters
+ ----------
+ depth : np.Array
+ Depth map [h,w]
+ shape : Tuple
+ Output shape (H,W)
+
+ Returns
+ -------
+ depth : np.Array
+ Resized depth map [H,W,1]
+ """
+ # If a single number is provided, use resize ratio
+ if not is_seq(shape):
+ shape = tuple(int(s * shape) for s in depth.shape)
+ # Store dimensions and reshapes to single column
+ depth = np.squeeze(depth)
+ h, w = depth.shape
+ x = depth.reshape(-1)
+ # Create coordinate grid
+ uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
+ # Filters valid points
+ idx = x > 0
+ crd, val = uv[idx], x[idx]
+ # Downsamples coordinates
+ crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32)
+ crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32)
+ # Filters points inside image
+ idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
+ crd, val = crd[idx], val[idx]
+ # Creates downsampled depth image and assigns points
+ depth = np.zeros(shape)
+ depth[crd[:, 0], crd[:, 1]] = val
+ # Return resized depth map
+ return np.expand_dims(depth, axis=2)
+
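+# Illustrative usage sketch: comparing plain nearest-neighbor resizing with the
+# valid-pixel-preserving variant on a sparse synthetic depth map.
+def example_resize_sparse_depth():
+    """Downsample a sparse 192x640 depth map to 96x320 with both strategies."""
+    depth = np.zeros((192, 640), dtype=np.float32)
+    depth[::8, ::8] = 10.0                              # sparse valid measurements
+    nearest = resize_npy(depth, (96, 320))              # may drop valid pixels
+    preserved = resize_npy_preserve(depth, (96, 320))   # keeps every valid pixel
+    return nearest, preserved
+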
+
+@iterate1
+def resize_torch_preserve(depth, shape):
+ """
+ Resizes depth map preserving all valid depth pixels
+ Multiple downsampled points can be assigned to the same pixel.
+
+ Parameters
+ ----------
+ depth : torch.Tensor
+ Depth map [B,1,h,w]
+ shape : Tuple
+ Output shape (H,W)
+
+ Returns
+ -------
+ depth : torch.Tensor
+ Resized depth map [B,1,H,W]
+ """
+ if depth.dim() == 4:
+ return torch.stack([resize_torch_preserve(depth[i], shape)
+ for i in range(depth.shape[0])], 0)
+ # If a single number is provided, use resize ratio
+ if not is_seq(shape):
+ shape = tuple(int(s * shape) for s in depth.shape)
+ # Store dimensions and reshapes to single column
+ c, h, w = depth.shape
+ # depth = np.squeeze(depth)
+ # h, w = depth.shape
+ x = depth.reshape(-1)
+ # Create coordinate grid
+ uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
+ # Filters valid points
+ idx = x > 0
+ crd, val = uv[idx], x[idx]
+ # Downsamples coordinates
+ crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32)
+ crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32)
+ # Filters points inside image
+ idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
+ crd, val = crd[idx], val[idx]
+ # Creates downsampled depth image and assigns points
+ depth = torch.zeros(shape, device=depth.device, dtype=depth.dtype)
+ depth[crd[:, 0], crd[:, 1]] = val
+ # Return resized depth map
+ return depth.unsqueeze(0)
+
+
+@iterate1
+def resize_npy_multiply(data, shape):
+ """Resize a numpy array and scale its content accordingly"""
+    ratio_h = shape[0] / data.shape[0]
+    ratio_w = shape[1] / data.shape[1]
+    out = resize_npy(data, shape, expand=False)
+    out[..., 0] *= ratio_w    # horizontal (u) component scales with width
+    out[..., 1] *= ratio_h    # vertical (v) component scales with height
+ return out
+
+
+@iterate1
+def resize_intrinsics(intrinsics, original, resized):
+ """
+ Resize camera intrinsics matrix to match a target resolution
+
+ Parameters
+ ----------
+ intrinsics : np.Array
+ Original intrinsics matrix [3,3]
+ original : Tuple
+ Original image resolution [W,H]
+ resized : Tuple
+ Target image resolution [w,h]
+ Returns
+ -------
+ intrinsics : np.Array
+ Resized intrinsics matrix [3,3]
+ """
+ intrinsics = np.copy(intrinsics)
+ intrinsics[0] *= resized[0] / original[0]
+ intrinsics[1] *= resized[1] / original[1]
+ return intrinsics
+
+
+@iterate1
+def resize_sample_input(sample, shape, shape_supervision=None,
+ depth_downsample=1.0, preserve_depth=False,
+ pil_interpolation=InterpolationMode.LANCZOS):
+ """
+ Resizes the input information of a sample
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ shape : tuple (H,W)
+ Output shape
+ shape_supervision : Tuple
+ Output supervision shape (H,W)
+ depth_downsample: Float
+ Resize ratio for depth maps
+ preserve_depth : Bool
+ Preserve depth maps when resizing
+ pil_interpolation : Int
+ Interpolation mode
+
+ Returns
+ -------
+ sample : Dict
+ Resized sample
+ """
+ # Intrinsics
+ for key in keys_with(sample, 'intrinsics', without='raw'):
+ if f'raw_{key}' not in sample.keys():
+ sample[f'raw_{key}'] = deepcopy(sample[key])
+ sample[key] = resize_intrinsics(sample[key], list(sample['rgb'].values())[0].size, shape[::-1])
+ # RGB
+ for key in keys_with(sample, 'rgb', without='raw'):
+ sample[key] = resize_pil(sample[key], shape, interpolation=pil_interpolation)
+ # Mask
+ for key in keys_with(sample, 'mask', without='raw'):
+ sample[key] = resize_pil(sample[key], shape, interpolation=InterpolationMode.NEAREST)
+ # Input depth
+ for key in keys_with(sample, 'input_depth'):
+ shape_depth = [int(s * depth_downsample) for s in shape]
+ resize_npy_depth = resize_npy_preserve if preserve_depth else resize_npy
+ sample[key] = resize_npy_depth(sample[key], shape_depth)
+ return sample
+
+
+@iterate1
+def resize_sample_supervision(sample, shape, depth_downsample=1.0, preserve_depth=False):
+ """
+ Resizes the output information of a sample
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ shape : Tuple
+ Output shape (H,W)
+ depth_downsample: Float
+ Resize ratio for depth maps
+ preserve_depth : Bool
+ Preserve depth maps when resizing
+
+ Returns
+ -------
+ sample : Dict
+ Resized sample
+ """
+ # Depth
+ for key in keys_with(sample, 'depth', without='input_depth'):
+ shape_depth = [int(s * depth_downsample) for s in shape]
+ resize_npy_depth = resize_npy_preserve if preserve_depth else resize_npy
+ sample[key] = resize_npy_depth(sample[key], shape_depth)
+ # Semantic
+ for key in keys_with(sample, 'semantic'):
+ sample[key] = resize_npy(sample[key], shape, expand=False)
+ # Optical flow
+ for key in keys_with(sample, 'optical_flow'):
+ sample[key] = resize_npy_multiply(sample[key], shape)
+ # Scene flow
+ for key in keys_with(sample, 'scene_flow'):
+ sample[key] = resize_npy(sample[key], shape, expand=False)
+ # Return resized sample
+ return sample
+
+
+def resize_sample(sample, shape, shape_supervision=None, depth_downsample=1.0, preserve_depth=False,
+ pil_interpolation=InterpolationMode.LANCZOS):
+ """
+ Resizes a sample, including image, intrinsics and depth maps.
+
+ Parameters
+ ----------
+ sample : Dict
+ Dictionary with sample values
+ shape : Tuple
+ Output shape (H,W)
+ shape_supervision : Tuple
+ Output shape (H,W)
+ depth_downsample: Float
+ Resize ratio for depth maps
+ preserve_depth : Bool
+ Preserve depth maps when resizing
+ pil_interpolation : Int
+ Interpolation mode
+
+ Returns
+ -------
+ sample : Dict
+ Resized sample
+ """
+    # Default the supervision shape to the input shape if not provided
+    if shape_supervision is None:
+        shape_supervision = shape
+    # Resize input information
+    sample = resize_sample_input(sample, shape,
+                                 depth_downsample=depth_downsample,
+                                 preserve_depth=preserve_depth,
+                                 pil_interpolation=pil_interpolation)
+    # Resize output information
+    sample = resize_sample_supervision(sample, shape_supervision,
+                                       depth_downsample=depth_downsample,
+                                       preserve_depth=preserve_depth)
+ # Return resized sample
+ return sample
diff --git a/vidar/datasets/augmentations/tensor.py b/vidar/datasets/augmentations/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90bf7de2bdef5d5d3f1773ecdedf21645c54755
--- /dev/null
+++ b/vidar/datasets/augmentations/tensor.py
@@ -0,0 +1,53 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torchvision.transforms as transforms
+
+from vidar.utils.decorators import iterate1
+
+
+@iterate1
+def to_tensor(matrix, tensor_type='torch.FloatTensor'):
+ """Casts a matrix to a torch.Tensor"""
+ return torch.Tensor(matrix).type(tensor_type)
+
+
+@iterate1
+def to_tensor_image(image, tensor_type='torch.FloatTensor'):
+ """Casts an image to a torch.Tensor"""
+ transform = transforms.ToTensor()
+ return transform(image).type(tensor_type)
+
+
+@iterate1
+def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
+ """
+ Casts the keys of sample to tensors.
+
+ Parameters
+ ----------
+ sample : Dict
+ Input sample
+ tensor_type : String
+ Type of tensor we are casting to
+
+ Returns
+ -------
+ sample : Dict
+ Sample with keys cast as tensors
+ """
+ # Convert using torchvision
+ keys = ['rgb', 'mask', 'input_depth', 'depth', 'disparity',
+ 'optical_flow', 'scene_flow']
+ for key_sample, val_sample in sample.items():
+ for key in keys:
+ if key in key_sample:
+ sample[key_sample] = to_tensor_image(val_sample, tensor_type)
+ # Convert from numpy
+ keys = ['intrinsics', 'extrinsics', 'pose', 'pointcloud', 'semantic']
+ for key_sample, val_sample in sample.items():
+ for key in keys:
+ if key in key_sample:
+ sample[key_sample] = to_tensor(val_sample, tensor_type)
+ # Return converted sample
+ return sample
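+
+
+# Illustrative usage sketch: converting a minimal sample with an RGB image and an
+# intrinsics matrix into torch tensors, assuming (as elsewhere in this codebase)
+# that the iterate1 decorator maps the function over lists of samples.
+def example_to_tensor():
+    """Cast a dummy sample to torch tensors."""
+    import numpy as np
+    from PIL import Image
+    sample = {
+        'rgb': Image.new('RGB', (32, 32)),
+        'intrinsics': np.eye(3),
+    }
+    return to_tensor_sample([sample])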
diff --git a/vidar/datasets/utils/FolderTree.py b/vidar/datasets/utils/FolderTree.py
new file mode 100644
index 0000000000000000000000000000000000000000..667a86052903187dd9cb982951a00be674e50c30
--- /dev/null
+++ b/vidar/datasets/utils/FolderTree.py
@@ -0,0 +1,132 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+from glob import glob
+
+import numpy as np
+
+from vidar.utils.data import make_list
+from vidar.utils.types import is_list
+
+
+class FolderTree:
+ """
+ Creates a dataset tree folder structure for file loading.
+
+ Parameters
+ ----------
+ path : String
+ Path where dataset is stored
+ prefix : String
+ Optional prefix for each filename
+ suffix : String
+        Optional suffix for each filename (e.g. file extension)
+ sub_folders : list[String]
+ Optional list of sub_folders located inside each data folder where data is stored
+ deep : Int
+ How deep in the folder structure we should go
+ single_folder : Bool
+ Whether the folder structure should be ignored, and a single folder is used
+ nested : bool
+ If true, go one folder deeper to find scenes
+ stride: Int
+ Stride for context generation
+ context : Tuple
+ Which context should be used
+ """
+ def __init__(self, path, prefix='', suffix='', sub_folders=('',), deep=1,
+ single_folder=False, nested=False, stride=1, context=()):
+
+ # Store context information
+ self.context = list(context)
+ if 0 not in self.context:
+ self.context.append(0)
+ self.num_context = 0 if len(self.context) == 0 else max(self.context) - min(self.context)
+ self.with_context = self.num_context > 0
+ self.min_context = 0 if not self.with_context else min(self.context)
+
+ self.stride = stride
+ self.pad_numbers = False
+
+ # Initialize empty folder tree
+ self.folder_tree = []
+
+ # If we are providing a file list, treat each line as a scene
+ if is_list(path):
+ self.folder_tree = [[file] for file in path]
+ # If we are providing a folder
+ else:
+ # Get folders
+ string = '*' + '/*' * (deep - 1)
+ folders = glob(os.path.join(path, string))
+ folders.sort()
+
+ # If nesting, go one folder deeper in order to find the scenes
+ if nested:
+ upd_folders = []
+ for folder in folders:
+ new_folders = glob(os.path.join(folder, '*'))
+ upd_folders.extend(new_folders)
+ folders = upd_folders
+ folders.sort()
+
+ if single_folder:
+ # Use current folder as the only one
+ self.folder_tree.append(folders)
+ else:
+ # Populate folder tree
+ for folder in folders:
+ # For each sub-folder
+ for sub_folder in make_list(sub_folders):
+ # Get and sort files in each folder
+ files = glob(os.path.join(folder, sub_folder, '{}*{}'.format(prefix, suffix)))
+ if self.pad_numbers:
+ for i in range(len(files)):
+ pref, suf = files[i].split('/')[:-1], files[i].split('/')[-1]
+ num, ext = suf.split('.')
+ files[i] = '/'.join(pref) + ('/%010d' % int(num)) + '.' + ext
+ files.sort()
+ if self.pad_numbers:
+ for i in range(len(files)):
+ pref, suf = files[i].split('/')[:-1], files[i].split('/')[-1]
+ num, ext = suf.split('.')
+ files[i] = '/'.join(pref) + ('/%d' % int(num)) + '.' + ext
+ if self.stride > 1:
+ files = files[::self.stride]
+ # Only store if there are more images than context
+ if len(files) > self.num_context:
+ self.folder_tree.append(files)
+
+ # Get size of each folder
+ self.slices = [len(folder) for folder in self.folder_tree]
+ # Compensate for context size
+ if self.with_context:
+ self.slices = [s - self.num_context for s in self.slices]
+ # Create cumulative size and get total
+ self.slices = [0] + list(np.cumsum(self.slices))
+ self.total = self.slices[-1]
+
+ def __len__(self):
+ """Dataset size"""
+ return self.total
+
+ def get_idxs(self, idx):
+ """Get folder and file indexes given dataset index"""
+ idx1 = np.searchsorted(self.slices, idx, side='right') - 1
+ idx2 = idx - self.slices[idx1]
+ return idx1, idx2
+
+ def get_item(self, idx, return_loc=False):
+ """Return filename item given index"""
+ idx1, idx2 = self.get_idxs(idx)
+ item = {0: self.folder_tree[idx1][idx2 - self.min_context]}
+ if return_loc:
+ return item, idx2 - self.min_context
+ else:
+ return item
+
+ def get_context(self, idx):
+ """Return forward context given index."""
+ idx1, idx2 = self.get_idxs(idx)
+ return {ctx: self.folder_tree[idx1][idx2 - self.min_context + ctx] for ctx in self.context}
+
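+# Illustrative usage sketch (path and arguments below are hypothetical):
+#   tree = FolderTree('/data/scenes', suffix='.png', context=(-1, 1))
+#   len(tree)            # number of frames with a full (-1, 0, +1) context window
+#   tree.get_item(0)     # {0: filename} of the first frame that has full context
+#   tree.get_context(0)  # {-1: ..., 0: ..., 1: ...} filenames for each context offset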
diff --git a/vidar/datasets/utils/__init__.py b/vidar/datasets/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vidar/datasets/utils/misc.py b/vidar/datasets/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b3fd2d7ffe0f32327c9d50c79cde5700b9ec77
--- /dev/null
+++ b/vidar/datasets/utils/misc.py
@@ -0,0 +1,322 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import json
+import os
+import random
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+from vidar.utils.decorators import iterate1
+from vidar.utils.types import is_seq, is_tensor, is_dict, is_int
+
+
+def stack_sample(sample, lidar_sample=None, radar_sample=None):
+ """
+ Stack samples from multiple cameras
+
+ Parameters
+ ----------
+ sample : list[Dict]
+ List of camera samples
+ lidar_sample : list[Dict]
+ List of lidar samples
+ radar_sample : list[Dict]
+ List of radar samples
+
+ Returns
+ -------
+ stacked_sample: Dict
+ Stacked sample
+ """
+ # If there are no samples, return None
+ if len(sample) == 0:
+ return None
+ # If there is only one sensor, return it directly
+ if len(sample) == 1:
+ sample = sample[0]
+ return sample
+ # Otherwise, stack sample
+ first_sample = sample[0]
+ stacked_sample = {}
+ for key, val in first_sample.items():
+ # Global keys (do not stack)
+ if key in ['idx', 'dataset_idx']:
+ stacked_sample[key] = first_sample[key]
+ # Meta keys
+ elif key in ['meta']:
+ stacked_sample[key] = {}
+ for key2 in first_sample[key].keys():
+ stacked_sample[key][key2] = {}
+ for key3 in first_sample[key][key2].keys():
+ stacked_sample[key][key2][key3] = torch.stack(
+ [torch.tensor(s[key][key2][key3]) for s in sample], 0)
+ # Stack tensors
+ elif is_tensor(val):
+ stacked_sample[key] = torch.stack([s[key] for s in sample], 0)
+ # Stack list
+ elif is_seq(first_sample[key]):
+ stacked_sample[key] = []
+ # Stack list of torch tensors
+ if is_tensor(first_sample[key][0]):
+ for i in range(len(first_sample[key])):
+ stacked_sample[key].append(
+ torch.stack([s[key][i] for s in sample], 0))
+ else:
+ stacked_sample[key] = [s[key] for s in sample]
+ # Repeat for dictionaries
+ elif is_dict(first_sample[key]):
+ stacked_sample[key] = stack_sample([s[key] for s in sample])
+ # Append lists
+ else:
+ stacked_sample[key] = [s[key] for s in sample]
+
+ # Return stacked sample
+ return stacked_sample
+
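+# Minimal sketch of the stacking behavior (shapes are illustrative):
+#   cams = [{'rgb': torch.rand(3, 192, 640)} for _ in range(6)]
+#   stack_sample(cams)['rgb'].shape  # torch.Size([6, 3, 192, 640]) -- leading camera dimension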
+
+def merge_sample(samples):
+ """Merge information from multiple samples"""
+ merged_sample = {}
+ for sample in samples:
+ for key, val in sample.items():
+ if key not in merged_sample:
+ merged_sample[key] = val
+ else:
+ merged_sample[key] = merge_sample([merged_sample[key], val])
+ return merged_sample
+
+
+def parse_crop(cfg, shape):
+ """Parse crop information to generate borders"""
+ borders = None
+ if cfg.has('crop_borders'):
+ borders = parse_crop_borders(cfg.crop_borders, shape)
+ if cfg.has('crop_random'):
+ if borders is None:
+ borders = [0, 0, shape[1], shape[0]]
+ borders = parse_crop_random(borders, cfg.crop_random)
+ return borders
+
+
+def parse_crop_borders(borders, shape):
+ """
+ Calculate borders for cropping.
+
+ Parameters
+ ----------
+ borders : Tuple
+ Border input for parsing. Can be one of the following forms:
+ (int, int, int, int): y, height, x, width
+ (int, int): y, x --> crop starts at (y, x) and extends to the image borders,
+ i.e. height = image_height - y and width = image_width - x
+ Negative numbers are taken from image borders, according to the shape argument
+ Float values for y and x are treated as a fraction of the image size, in which case
+ the (height, width) crop is centered at that point.
+ shape : Tuple
+ Image shape (image_height, image_width), used to determine negative crop boundaries
+
+ Returns
+ -------
+ borders : Tuple
+ Parsed borders for cropping (left, top, right, bottom)
+ """
+ # Return full image if there are no borders to crop
+ if len(borders) == 0:
+ return 0, 0, shape[1], shape[0]
+ # Copy borders for modification
+ borders = list(borders).copy()
+ # If borders are 4-dimensional
+ if len(borders) == 4:
+ borders = [borders[2], borders[0], borders[3], borders[1]]
+ if is_int(borders[0]):
+ # If horizontal cropping is integer (regular cropping)
+ borders[0] += shape[1] if borders[0] < 0 else 0
+ borders[2] += shape[1] if borders[2] <= 0 else borders[0]
+ else:
+ # If horizontal cropping is float (center cropping)
+ center_w, half_w = borders[0] * shape[1], borders[2] / 2
+ borders[0] = int(center_w - half_w)
+ borders[2] = int(center_w + half_w)
+ if is_int(borders[1]):
+ # If vertical cropping is integer (regular cropping)
+ borders[1] += shape[0] if borders[1] < 0 else 0
+ borders[3] += shape[0] if borders[3] <= 0 else borders[1]
+ else:
+ # If vertical cropping is float (center cropping)
+ center_h, half_h = borders[1] * shape[0], borders[3] / 2
+ borders[1] = int(center_h - half_h)
+ borders[3] = int(center_h + half_h)
+ # If borders are 2-dimensional
+ elif len(borders) == 2:
+ borders = [borders[1], borders[0]]
+ if is_int(borders[0]):
+ # If cropping is integer (regular cropping)
+ borders = (max(0, borders[0]),
+ max(0, borders[1]),
+ shape[1] + min(0, borders[0]),
+ shape[0] + min(0, borders[1]))
+ else:
+ # If cropping is float (center cropping)
+ center_w, half_w = borders[0] * shape[1], borders[1] / 2
+ center_h, half_h = borders[0] * shape[0], borders[1] / 2
+ borders = (int(center_w - half_w), int(center_h - half_h),
+ int(center_w + half_w), int(center_h + half_h))
+ # Otherwise, invalid
+ else:
+ raise NotImplementedError('Crop tuple must have 2 or 4 values.')
+ # Assert that borders are valid
+ assert 0 <= borders[0] < borders[2] <= shape[1] and \
+ 0 <= borders[1] < borders[3] <= shape[0], 'Crop borders {} are invalid'.format(borders)
+ # Return updated borders
+ return borders
+
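+# Worked examples for parse_crop_borders with shape = (H, W) = (480, 640) (values are illustrative):
+#   parse_crop_borders((100, 300, 50, 400), (480, 640))   # -> (50, 100, 450, 400), a 400x300 crop at (x=50, y=100)
+#   parse_crop_borders((-100, 0, -200, 0), (480, 640))    # -> (440, 380, 640, 480), the bottom-right 200x100 corner
+#   parse_crop_borders((0.5, 256, 0.5, 512), (480, 640))  # -> (64, 112, 576, 368), a 512x256 crop centered in the image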
+
+def parse_crop_random(borders, shape):
+ """
+ Create borders for random cropping.
+ Crops are generated anywhere in the image inside the borders
+
+ Parameters
+ ----------
+ borders : Tuple
+ Area of the image where random cropping can happen (left, top, right, bottom)
+ shape : Tuple
+ Cropped output shape (height, width)
+
+ Returns
+ -------
+ borders : tuple
+ Parsed borders for cropping (left, top, right, bottom)
+ """
+ # Return full borders if there is no random crop
+ if len(shape) == 0:
+ return borders
+ # Check if random crop is valid
+ assert 0 < shape[1] <= borders[2] - borders[0] and \
+ 0 < shape[0] <= borders[3] - borders[1], 'Random crop must be smaller than the image'
+ # Sample a crop
+ x = random.randint(borders[0], borders[2] - shape[1])
+ y = random.randint(borders[1], borders[3] - shape[0])
+ # Return new borders
+ return x, y, x + shape[1], y + shape[0]
+
+
+@iterate1
+def invert_pose(pose):
+ """
+ Inverts a transformation matrix (pose)
+
+ Parameters
+ ----------
+ pose : np.Array
+ Input pose [4, 4]
+
+ Returns
+ -------
+ inv_pose : np.Array
+ Inverted pose [4, 4]
+ """
+ inv_pose = np.eye(4)
+ inv_pose[:3, :3] = np.transpose(pose[:3, :3])
+ inv_pose[:3, -1] = - inv_pose[:3, :3] @ pose[:3, -1]
+ # return np.linalg.inv(pose)
+ return inv_pose
+
+
+def make_relative_pose(samples):
+ """
+ Convert sample poses to a relative frame of reference (based on the first target frame)
+
+ Parameters
+ ----------
+ samples : list[Dict]
+ Input samples
+
+ Returns
+ -------
+ samples : list[Dict]
+ Relative samples
+ """
+ # Do nothing if there is no pose
+ if 'pose' not in samples[0]:
+ return samples
+ # Get inverse current poses
+ inv_pose = [invert_pose(samples[i]['pose'][0]) for i in range(len(samples))]
+ # For each camera
+ for i in range(len(samples)):
+ # For each context
+ for j in samples[0]['pose'].keys():
+ if j == 0:
+ if i > 0:
+ samples[i]['pose'][j] = \
+ samples[i]['pose'][j] @ inv_pose[0]
+ else:
+ samples[i]['pose'][j] = \
+ samples[i]['pose'][j] @ inv_pose[i]
+ return samples
+
+
+def dummy_intrinsics(image):
+ """
+ Return dummy intrinsics calculated based on image resolution
+
+ Parameters
+ ----------
+ image : PIL Image
+ Image from which intrinsics will be calculated
+
+ Returns
+ -------
+ intrinsics : np.Array
+ Image intrinsics (fx = w/2, fy = h/2, cx = w/2 - 0.5, cy = h/2 - 0.5) [3,3]
+ """
+ w, h = [float(d) for d in image.size]
+ return np.array([[w/2, 0., w/2. - 0.5],
+ [0., h/2, h/2. - 0.5],
+ [0., 0., 1.]])
+
+
+def load_ontology(name, filter_list=None):
+ """Loads ontology from file and optionally filters it"""
+ filename = 'vidar/ontologies/{}.json'.format(name)
+ if os.path.exists(filename):
+ ontology = json.load(open(filename, 'r'))
+ if filter_list is not None and len(filter_list) > 0:
+ ontology = filter_ontology(ontology, filter_list)
+ return ontology
+ else:
+ return None
+
+
+def save_ontology(ontology, name):
+ """Save ontology to a JSON file"""
+ if is_seq(ontology):
+ ontology = ontology[0]
+ for key in ontology.keys():
+ ontology[key]['color'] = ontology[key]['color'].tolist()
+ json.dump(ontology, open('ontologies/{}.json'.format(name), 'w'))
+
+
+def filter_ontology(ontology, values):
+ """Filter ontology to remove certain classes"""
+ new_ontology = OrderedDict()
+ for i, val in enumerate(values[1:]):
+ new_ontology[i] = ontology[str(val)]
+ return new_ontology
+
+
+def convert_ontology(semantic_id, ontology_convert):
+ """Convert from one ontology to another"""
+ if ontology_convert is None:
+ return semantic_id
+ else:
+ semantic_id_convert = semantic_id.copy()
+ for key, val in ontology_convert.items():
+ semantic_id_convert[semantic_id == key] = val
+ return semantic_id_convert
+
+
+def initialize_ontology(base, ontology):
+ """Initialize ontology and conversion table if necessary"""
+ return load_ontology(base), None
diff --git a/vidar/datasets/utils/transforms.py b/vidar/datasets/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd48f67f38494f6e48116b047d3f8434351f5478
--- /dev/null
+++ b/vidar/datasets/utils/transforms.py
@@ -0,0 +1,98 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from functools import partial
+
+from vidar.datasets.augmentations.image import \
+ colorjitter_sample, normalize_sample
+from vidar.datasets.augmentations.crop import \
+ crop_sample_input, crop_sample
+from vidar.datasets.augmentations.misc import \
+ duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range
+from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input
+from vidar.datasets.augmentations.tensor import to_tensor_sample
+from vidar.datasets.utils.misc import parse_crop
+from vidar.utils.types import is_list
+
+
+def train_transforms(sample, cfg):
+ """
+ Training data augmentation transformations
+
+ Parameters
+ ----------
+ sample : Dict
+ Sample to be augmented
+ cfg : Config
+ Configuration for transformations
+
+ Returns
+ -------
+ sample : Dict
+ Augmented sample
+ """
+ # Resize
+ if cfg.has('resize'):
+ resize_fn = resize_sample if cfg.has('resize_supervision') else resize_sample_input
+ shape_supervision = None if not cfg.has('resize_supervision') else \
+ cfg.resize if not is_list(cfg.resize_supervision) else cfg.resize_supervision
+ sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision,
+ depth_downsample=cfg.has('depth_downsample', 1.0),
+ preserve_depth=cfg.has('preserve_depth', False))
+ # Crop
+ if cfg.has('crop_borders') or cfg.has('crop_random'):
+ crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input
+ sample = [crop_fn(s, parse_crop(cfg, s['rgb'][0].size[::-1])) for s in sample]
+ # Clip depth to a maximum value
+ if cfg.has('clip_depth'):
+ sample = clip_depth(sample, cfg.clip_depth)
+ if cfg.has('mask_depth_range'):
+ sample = mask_depth_range(sample, cfg.mask_depth_range)
+ # Change input depth density
+ if 'input_depth' in sample:
+ if cfg.has('input_depth_number'):
+ sample['input_depth'] = mask_depth_number(
+ sample['input_depth'], cfg.input_depth_number)
+ if cfg.has('input_depth_percentage'):
+ sample['input_depth'] = mask_depth_percentage(
+ sample['input_depth'], cfg.input_depth_percentage)
+ # Apply jittering
+ if cfg.has('jittering'):
+ sample = duplicate_sample(sample, ['rgb'])
+ sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0)
+ # Convert to tensor
+ sample = to_tensor_sample(sample)
+ if cfg.has('normalization'):
+ sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1])
+ # Return augmented sample
+ return sample
+
+
+def no_transform(sample):
+ """No transformation, only convert sample to tensors"""
+ sample = to_tensor_sample(sample)
+ return sample
+
+
+def get_transforms(mode, cfg=None):
+ """
+ Get data augmentation transformations for each split
+
+ Parameters
+ ----------
+ mode : String {'train', 'none'}
+ Mode from which we want the data augmentation transformations
+ cfg : Config
+ Configuration file
+
+ Returns
+ -------
+ transform : Partial function
+ Data augmentation transformation for that mode
+ """
+ if mode == 'train':
+ return partial(train_transforms, cfg=cfg)
+ elif mode == 'none':
+ return partial(no_transform)
+ else:
+ raise ValueError('Unknown mode {}'.format(mode))
+
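+
+# Minimal usage sketch (the Config object and its fields are assumptions for illustration):
+#   transform = get_transforms('train', cfg)  # cfg may define resize, crop_*, jittering, normalization, ...
+#   augmented = transform(sample)             # resize -> crop -> depth masking -> jitter -> to_tensor -> normalize
+#   to_tensor_only = get_transforms('none')   # only converts the sample to tensors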
diff --git a/vidar/geometry/camera.py b/vidar/geometry/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba80045ad235422ccd5aa651210c86fef102d62
--- /dev/null
+++ b/vidar/geometry/camera.py
@@ -0,0 +1,528 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from abc import ABC
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics
+from vidar.geometry.pose import Pose
+from vidar.geometry.pose_utils import invert_pose
+from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave
+from vidar.utils.types import is_tensor, is_seq
+
+
+class Camera(nn.Module, ABC):
+ """
+ Camera class for 3D reconstruction
+
+ Parameters
+ ----------
+ K : torch.Tensor
+ Camera intrinsics [B,3,3]
+ hw : Tuple
+ Camera height and width
+ Twc : Pose or torch.Tensor
+ Camera pose (world to camera) [B,4,4]
+ Tcw : Pose or torch.Tensor
+ Camera pose (camera to world) [B,4,4]
+ """
+ def __init__(self, K, hw, Twc=None, Tcw=None):
+ super().__init__()
+
+ # Asserts
+
+ assert Twc is None or Tcw is None
+
+ # Fold if multi-batch
+
+ if K.dim() == 4:
+ K = rearrange(K, 'b n h w -> (b n) h w')
+ if Twc is not None:
+ Twc = rearrange(Twc, 'b n h w -> (b n) h w')
+ if Tcw is not None:
+ Tcw = rearrange(Tcw, 'b n h w -> (b n) h w')
+
+ # Intrinsics
+
+ if same_shape(K.shape[-2:], (3, 3)):
+ self._K = torch.eye(4, dtype=K.dtype, device=K.device).repeat(K.shape[0], 1, 1)
+ self._K[:, :3, :3] = K
+ else:
+ self._K = K
+
+ # Pose
+
+ if Twc is None and Tcw is None:
+ self._Twc = torch.eye(4, dtype=K.dtype, device=K.device).unsqueeze(0).repeat(K.shape[0], 1, 1)
+ else:
+ self._Twc = invert_pose(Tcw) if Tcw is not None else Twc
+ if is_tensor(self._Twc):
+ self._Twc = Pose(self._Twc)
+
+ # Resolution
+
+ self._hw = hw
+ if is_tensor(self._hw):
+ self._hw = self._hw.shape[-2:]
+
+ def __getitem__(self, idx):
+ """Return batch-wise pose"""
+ if is_seq(idx):
+ return type(self).from_list([self.__getitem__(i) for i in idx])
+ else:
+ return type(self)(
+ K=self._K[[idx]],
+ Twc=self._Twc[[idx]] if self._Twc is not None else None,
+ hw=self._hw,
+ )
+
+ def __len__(self):
+ """Return length as intrinsics batch"""
+ return self._K.shape[0]
+
+ def __eq__(self, cam):
+ """Check if two cameras are the same"""
+ if not isinstance(cam, type(self)):
+ return False
+ if self._hw[0] != cam.hw[0] or self._hw[1] != cam.hw[1]:
+ return False
+ if not torch.allclose(self._K, cam.K):
+ return False
+ if not torch.allclose(self._Twc.T, cam.Twc.T):
+ return False
+ return True
+
+ def clone(self):
+ """Return a copy of this camera"""
+ return deepcopy(self)
+
+ @property
+ def pose(self):
+ """Return camera pose (world to camera)"""
+ return self._Twc.T
+
+ @property
+ def K(self):
+ """Return camera intrinsics"""
+ return self._K
+
+ @K.setter
+ def K(self, K):
+ """Set camera intrinsics"""
+ self._K = K
+
+ @property
+ def invK(self):
+ """Return inverse of camera intrinsics"""
+ return invert_intrinsics(self._K)
+
+ @property
+ def batch_size(self):
+ """Return batch size"""
+ return self._Twc.T.shape[0]
+
+ @property
+ def hw(self):
+ """Return camera height and width"""
+ return self._hw
+
+ @hw.setter
+ def hw(self, hw):
+ """Set camera height and width"""
+ self._hw = hw
+
+ @property
+ def wh(self):
+ """Get camera width and height"""
+ return self._hw[::-1]
+
+ @property
+ def n_pixels(self):
+ """Return number of pixels"""
+ return self._hw[0] * self._hw[1]
+
+ @property
+ def fx(self):
+ """Return horizontal focal length"""
+ return self._K[:, 0, 0]
+
+ @property
+ def fy(self):
+ """Return vertical focal length"""
+ return self._K[:, 1, 1]
+
+ @property
+ def cx(self):
+ """Return horizontal principal point"""
+ return self._K[:, 0, 2]
+
+ @property
+ def cy(self):
+ """Return vertical principal point"""
+ return self._K[:, 1, 2]
+
+ @property
+ def fxy(self):
+ """Return focal length"""
+ return torch.tensor([self.fx, self.fy], dtype=self.dtype, device=self.device)
+
+ @property
+ def cxy(self):
+ """Return principal points"""
+ return self._K[:, :2, 2]
+ # return torch.tensor([self.cx, self.cy], dtype=self.dtype, device=self.device)
+
+ @property
+ def Tcw(self):
+ """Return camera pose (camera to world)"""
+ return None if self._Twc is None else self._Twc.inverse()
+
+ @Tcw.setter
+ def Tcw(self, Tcw):
+ """Set camera pose (camera to world)"""
+ self._Twc = Tcw.inverse()
+
+ @property
+ def Twc(self):
+ """Return camera pose (world to camera)"""
+ return self._Twc
+
+ @Twc.setter
+ def Twc(self, Twc):
+ """Set camera pose (world to camera)"""
+ self._Twc = Twc
+
+ @property
+ def dtype(self):
+ """Return tensor type"""
+ return self._K.dtype
+
+ @property
+ def device(self):
+ """Return device"""
+ return self._K.device
+
+ def detach_pose(self):
+ """Detach pose from the graph"""
+ return type(self)(K=self._K, hw=self._hw,
+ Twc=self._Twc.detach() if self._Twc is not None else None)
+
+ def detach_K(self):
+ """Detach intrinsics from the graph"""
+ return type(self)(K=self._K.detach(), hw=self._hw, Twc=self._Twc)
+
+ def detach(self):
+ """Detach camera from the graph"""
+ return type(self)(K=self._K.detach(), hw=self._hw,
+ Twc=self._Twc.detach() if self._Twc is not None else None)
+
+ def inverted_pose(self):
+ """Invert camera pose"""
+ return type(self)(K=self._K, hw=self._hw,
+ Twc=self._Twc.inverse() if self._Twc is not None else None)
+
+ def no_translation(self):
+ """Return new camera without translation"""
+ Twc = self.pose.clone()
+ Twc[:, :-1, -1] = 0
+ return type(self)(K=self._K, hw=self._hw, Twc=Twc)
+
+ @staticmethod
+ def from_dict(K, hw, Twc=None):
+ """Create cameras from a pose dictionary"""
+ return {key: Camera(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()}
+
+ # @staticmethod
+ # def from_dict(K, hw, Twc=None):
+ # return {key: Camera(K=K[(0, 0)], hw=hw[(0, 0)], Twc=val) for key, val in Twc.items()}
+
+ @staticmethod
+ def from_list(cams):
+ """Create cameras from a list"""
+ K = torch.cat([cam.K for cam in cams], 0)
+ Twc = torch.cat([cam.Twc.T for cam in cams], 0)
+ return Camera(K=K, Twc=Twc, hw=cams[0].hw)
+
+ def scaled(self, scale_factor):
+ """Return a scaled camera"""
+ if scale_factor is None or scale_factor == 1:
+ return self
+ if is_seq(scale_factor):
+ if len(scale_factor) == 4:
+ scale_factor = scale_factor[-2:]
+ scale_factor = [float(scale_factor[i]) / float(self._hw[i]) for i in range(2)]
+ else:
+ scale_factor = [scale_factor] * 2
+ return type(self)(
+ K=scale_intrinsics(self._K, scale_factor),
+ hw=[int(self._hw[i] * scale_factor[i]) for i in range(len(self._hw))],
+ Twc=self._Twc
+ )
+
+ def offset_start(self, start):
+ """Offset camera intrinsics based on a crop"""
+ new_cam = self.clone()
+ start = start.to(self.device)
+ new_cam.K[:, 0, 2] -= start[:, 1]
+ new_cam.K[:, 1, 2] -= start[:, 0]
+ return new_cam
+
+ def interpolate(self, rgb):
+ """Interpolate an image to fit the camera"""
+ if rgb.dim() == 5:
+ rgb = rearrange(rgb, 'b n c h w -> (b n) c h w')
+ return interpolate(rgb, scale_factor=None, size=self.hw, mode='bilinear', align_corners=True)
+
+ def interleave_K(self, b):
+ """Interleave camera intrinsics to fit multiple batches"""
+ return type(self)(
+ K=interleave(self._K, b),
+ Twc=self._Twc,
+ hw=self._hw,
+ )
+
+ def interleave_Twc(self, b):
+ """Interleave camera pose to fit multiple batches"""
+ return type(self)(
+ K=self._K,
+ Twc=interleave(self._Twc, b),
+ hw=self._hw,
+ )
+
+ def interleave(self, b):
+ """Interleave camera to fit multiple batches"""
+ return type(self)(
+ K=interleave(self._K, b),
+ Twc=interleave(self._Twc, b),
+ hw=self._hw,
+ )
+
+ def Pwc(self, from_world=True):
+ """Return projection matrix"""
+ return self._K[:, :3] if not from_world or self._Twc is None else \
+ torch.matmul(self._K, self._Twc.T)[:, :3]
+
+ def to_world(self, points):
+ """Transform points to world coordinates"""
+ if points.dim() > 3:
+ points = points.reshape(points.shape[0], 3, -1)
+ return points if self.Tcw is None else self.Tcw * points
+
+ def from_world(self, points):
+ """Transform points back to camera coordinates"""
+ if points.dim() > 3:
+ points = points.reshape(points.shape[0], 3, -1)
+ return points if self._Twc is None else \
+ torch.matmul(self._Twc.T, cat_channel_ones(points, 1))[:, :3]
+
+ def to(self, *args, **kwargs):
+ """Copy camera to device"""
+ self._K = self._K.to(*args, **kwargs)
+ if self._Twc is not None:
+ self._Twc = self._Twc.to(*args, **kwargs)
+ return self
+
+ def cuda(self, *args, **kwargs):
+ """Copy camera to CUDA"""
+ return self.to('cuda')
+
+ def relative_to(self, cam):
+ """Create a new camera relative to another camera"""
+ return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc.inverse())
+
+ def global_from(self, cam):
+ """Create a new camera in global coordinates relative to another camera"""
+ return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc)
+
+ def reconstruct_depth_map(self, depth, to_world=False):
+ """
+ Reconstruct a depth map from the camera viewpoint
+
+ Parameters
+ ----------
+ depth : torch.Tensor
+ Input depth map [B,1,H,W]
+ to_world : Bool
+ Transform points to world coordinates
+
+ Returns
+ -------
+ points : torch.Tensor
+ Output 3D points [B,3,H,W]
+ """
+ if depth is None:
+ return None
+ b, _, h, w = depth.shape
+ grid = pixel_grid(depth, with_ones=True, device=depth.device).view(b, 3, -1)
+ points = depth.view(b, 1, -1) * torch.matmul(self.invK[:, :3, :3], grid)
+ if to_world and self.Tcw is not None:
+ points = self.Tcw * points
+ return points.view(b, 3, h, w)
+
+ def reconstruct_cost_volume(self, volume, to_world=True, flatten=True):
+ """
+ Reconstruct a cost volume from the camera viewpoint
+
+ Parameters
+ ----------
+ volume : torch.Tensor
+ Input depth map [B,1,D,H,W]
+ to_world : Bool
+ Transform points to world coordinates
+ flatten: Bool
+ Flatten volume points
+
+ Returns
+ -------
+ points : torch.Tensor
+ Output 3D points [B,3,D,H,W]
+ """
+ c, d, h, w = volume.shape
+ grid = pixel_grid((h, w), with_ones=True, device=volume.device).view(3, -1).repeat(1, d)
+ points = torch.stack([
+ (volume.view(c, -1) * torch.matmul(invK[:3, :3].unsqueeze(0), grid)).view(3, d * h * w)
+ for invK in self.invK], 0)
+ if to_world and self.Tcw is not None:
+ points = self.Tcw * points
+ if flatten:
+ return points.view(-1, 3, d, h * w).permute(0, 2, 1, 3)
+ else:
+ return points.view(-1, 3, d, h, w)
+
+ def project_points(self, points, from_world=True, normalize=True, return_z=False):
+ """
+ Project points back to image plane
+
+ Parameters
+ ----------
+ points : torch.Tensor
+ Input 3D points [B,3,H,W] or [B,3,N]
+ from_world : Bool
+ Whether points are in the global frame
+ normalize : Bool
+ Whether projections should be normalized to [-1,1]
+ return_z : Bool
+ Whether projected depth is return as well
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Projected 2D coordinates [B,H,W,2] (or [B,N,2] for point inputs)
+ depth : torch.Tensor
+ Projected depth [B,1,H,W] (or [B,N] for point inputs)
+ """
+ is_depth_map = points.dim() == 4
+ hw = self._hw if not is_depth_map else points.shape[-2:]
+
+ if is_depth_map:
+ points = points.reshape(points.shape[0], 3, -1)
+ b, _, n = points.shape
+
+ points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1))
+
+ coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7)
+ depth = points[:, 2]
+
+ if not is_depth_map:
+ if normalize:
+ coords = norm_pixel_grid(coords, hw=self._hw, in_place=True)
+ invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
+ (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth < 0)
+ coords[invalid.unsqueeze(1).repeat(1, 2, 1)] = -2
+ if return_z:
+ return coords.permute(0, 2, 1), depth
+ else:
+ return coords.permute(0, 2, 1)
+
+ coords = coords.view(b, 2, *hw)
+ depth = depth.view(b, 1, *hw)
+
+ if normalize:
+ coords = norm_pixel_grid(coords, hw=self._hw, in_place=True)
+ invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
+ (coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth[:, 0] < 0)
+ coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2
+
+ if return_z:
+ return coords.permute(0, 2, 3, 1), depth
+ else:
+ return coords.permute(0, 2, 3, 1)
+
+ def project_cost_volume(self, points, from_world=True, normalize=True):
+ """
+ Project points back to image plane
+
+ Parameters
+ ----------
+ points : torch.Tensor
+ Input 3D points [B,3,H,W] or [B,3,N]
+ from_world : Bool
+ Whether points are in the global frame
+ normalize : Bool
+ Whether projections should be normalized to [-1,1]
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Projected 2D coordinates [B,D,H,W,2]
+ """
+ if points.dim() == 4:
+ points = points.permute(0, 2, 1, 3).reshape(points.shape[0], 3, -1)
+ b, _, n = points.shape
+
+ points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1))
+
+ coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7)
+ coords = coords.view(b, 2, -1, *self._hw).permute(0, 2, 3, 4, 1)
+
+ if normalize:
+ coords[..., 0] /= self._hw[1] - 1
+ coords[..., 1] /= self._hw[0] - 1
+ return 2 * coords - 1
+ else:
+ return coords
+
+ def coords_from_cost_volume(self, volume, ref_cam=None):
+ """
+ Get warp coordinates from a cost volume
+
+ Parameters
+ ----------
+ volume : torch.Tensor
+ Input cost volume [B,1,D,H,W]
+ ref_cam : Camera
+ Optional to generate cross-camera coordinates
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Projected 2D coordinates [B,D,H,W,2]
+ """
+ if ref_cam is None:
+ return self.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=False), from_world=True)
+ else:
+ return ref_cam.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=True), from_world=True)
+
+ def coords_from_depth(self, depth, ref_cam=None):
+ """
+ Get warp coordinates from a depth map
+
+ Parameters
+ ----------
+ depth : torch.Tensor
+ Input depth map [B,1,H,W]
+ ref_cam : Camera
+ Optional to generate cross-camera coordinates
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Projected 2D coordinates [B,H,W,2]
+ """
+ if ref_cam is None:
+ return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True)
+ else:
+ return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True)
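+
+
+# Illustrative round-trip sketch for the pinhole Camera (tensors below are synthetic placeholders):
+#   K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])  # [1,3,3]
+#   cam = Camera(K=K, hw=(480, 640))           # identity pose, so camera frame == world frame
+#   depth = torch.ones(1, 1, 480, 640)
+#   points = cam.reconstruct_depth_map(depth)  # [1,3,480,640] back-projected 3D points
+#   coords = cam.project_points(points)        # [1,480,640,2] normalized warp coordinates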
diff --git a/vidar/geometry/camera_ds.py b/vidar/geometry/camera_ds.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb39de6560c4dab67cd04fcd779b24545b0946e
--- /dev/null
+++ b/vidar/geometry/camera_ds.py
@@ -0,0 +1,213 @@
+from functools import lru_cache
+import torch
+import torch.nn as nn
+
+from vidar.geometry.pose import Pose
+from vidar.utils.tensor import pixel_grid
+
+########################################################################################################################
+
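+# Note on the intrinsics vector layout used below (inferred from the properties of this class):
+#   I = [fx, fy, cx, cy, xi, alpha], so a camera could be built, illustratively, as
+#   cam = DSCamera(I=torch.tensor([[500., 500., 320., 240., -0.2, 0.6]]))  # placeholder values
+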
+class DSCamera(nn.Module):
+ """
+ Differentiable camera class implementing reconstruction and projection
+ functions for the double sphere (DS) camera model.
+ """
+ def __init__(self, I, Tcw=None):
+ """
+ Initializes the Camera class
+
+ Parameters
+ ----------
+ I : torch.Tensor [6]
+ Camera intrinsics parameter vector
+ Tcw : Pose
+ Camera -> World pose transformation
+ """
+ super().__init__()
+ self.I = I
+ if Tcw is None:
+ self.Tcw = Pose.identity(len(I))
+ elif isinstance(Tcw, Pose):
+ self.Tcw = Tcw
+ else:
+ self.Tcw = Pose(Tcw)
+
+ self.Tcw.to(self.I.device)
+
+ def __len__(self):
+ """Batch size of the camera intrinsics"""
+ return len(self.I)
+
+ def to(self, *args, **kwargs):
+ """Moves object to a specific device"""
+ self.I = self.I.to(*args, **kwargs)
+ self.Tcw = self.Tcw.to(*args, **kwargs)
+ return self
+
+########################################################################################################################
+
+ @property
+ def fx(self):
+ """Focal length in x"""
+ return self.I[:, 0].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def fy(self):
+ """Focal length in y"""
+ return self.I[:, 1].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cx(self):
+ """Principal point in x"""
+ return self.I[:, 2].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cy(self):
+ """Principal point in y"""
+ return self.I[:, 3].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def xi(self):
+ """alpha in DS model"""
+ return self.I[:, 4].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def alpha(self):
+ """beta in DS model"""
+ return self.I[:, 5].unsqueeze(1).unsqueeze(2)
+
+ @property
+ @lru_cache()
+ def Twc(self):
+ """World -> Camera pose transformation (inverse of Tcw)"""
+ return self.Tcw.inverse()
+
+########################################################################################################################
+
+ def reconstruct(self, depth, frame='w'):
+ """
+ Reconstructs pixel-wise 3D points from a depth map.
+
+ Parameters
+ ----------
+ depth : torch.Tensor [B,1,H,W]
+ Depth map for the camera
+ frame : 'w'
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.tensor [B,3,H,W]
+ Pixel-wise 3D points
+ """
+
+ if depth is None:
+ return None
+ b, c, h, w = depth.shape
+ assert c == 1
+
+ grid = pixel_grid(depth, with_ones=True, device=depth.device)
+
+ # Estimate the outward rays in the camera frame
+ fx, fy, cx, cy, xi, alpha = self.fx, self.fy, self.cx, self.cy, self.xi, self.alpha
+
+ if torch.any(torch.isnan(alpha)):
+ raise ValueError('alpha is nan')
+
+ u = grid[:,0,:,:]
+ v = grid[:,1,:,:]
+
+ mx = (u - cx) / fx
+ my = (v - cy) / fy
+ r_square = mx ** 2 + my ** 2
+ mz = (1 - alpha ** 2 * r_square) / (alpha * torch.sqrt(1 - (2 * alpha - 1) * r_square) + (1 - alpha))
+ coeff = (mz * xi + torch.sqrt(mz ** 2 + (1 - xi ** 2) * r_square)) / (mz ** 2 + r_square)
+
+ x = coeff * mx
+ y = coeff * my
+ z = coeff * mz - xi
+ z = z.clamp(min=1e-7)
+
+ x_norm = x / z
+ y_norm = y / z
+ z_norm = z / z
+ xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1)
+
+ # Scale rays to metric depth
+ Xc = xnorm * depth
+
+ # If in camera frame of reference
+ if frame == 'c':
+ return Xc
+ # If in world frame of reference
+ elif frame == 'w':
+ return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w)
+ # If none of the above
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ def project(self, X, frame='w'):
+ """
+ Projects 3D points onto the image plane
+
+ Parameters
+ ----------
+ X : torch.Tensor [B,3,H,W]
+ 3D points to be projected
+ frame : 'w'
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.Tensor [B,H,W,2]
+ 2D projected points that are within the image boundaries
+ """
+ B, C, H, W = X.shape
+ assert C == 3
+
+ # Project 3D points onto the camera image plane
+ if frame == 'c':
+ X = X
+ elif frame == 'w':
+ X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W)
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ fx, fy, cx, cy, xi, alpha = self.fx, self.fy, self.cx, self.cy, self.xi, self.alpha
+ x, y, z = X[:,0,:], X[:,1,:], X[:,2,:]
+ z = z.clamp(min=1e-7)
+ d_1 = torch.sqrt( x ** 2 + y ** 2 + z ** 2 )
+ d_2 = torch.sqrt( x ** 2 + y ** 2 + (xi * d_1 + z) ** 2 )
+
+ Xnorm = fx * x / (alpha * d_2 + (1 - alpha) * (xi * d_1 + z)) + cx
+ Ynorm = fy * y / (alpha * d_2 + (1 - alpha) * (xi * d_1 + z)) + cy
+ Xnorm = 2 * Xnorm / (W-1) - 1
+ Ynorm = 2 * Ynorm / (H-1) - 1
+
+ coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2)
+ z = z.unsqueeze(1)
+
+ invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
+ (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0)
+ coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2
+
+ # Return pixel coordinates
+ return coords.permute(0, 2, 3, 1)
+
+ def reconstruct_depth_map(self, depth, to_world=True):
+ if to_world:
+ return self.reconstruct(depth, frame='w')
+ else:
+ return self.reconstruct(depth, frame='c')
+
+ def project_points(self, points, from_world=True, normalize=True, return_z=False):
+ if from_world:
+ return self.project(points, frame='w')
+ else:
+ return self.project(points, frame='c')
+
+ def coords_from_depth(self, depth, ref_cam=None):
+ if ref_cam is None:
+ return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True)
+ else:
+ return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True)
\ No newline at end of file
diff --git a/vidar/geometry/camera_eucm.py b/vidar/geometry/camera_eucm.py
new file mode 100644
index 0000000000000000000000000000000000000000..472db8bd98fe0fa72ab1de4ff7ffb343fb3636bb
--- /dev/null
+++ b/vidar/geometry/camera_eucm.py
@@ -0,0 +1,215 @@
+from functools import lru_cache
+import torch
+import torch.nn as nn
+
+from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics
+from vidar.geometry.pose import Pose
+from vidar.geometry.pose_utils import invert_pose
+from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave
+from vidar.utils.types import is_tensor, is_seq
+
+########################################################################################################################
+
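+# Note on the intrinsics vector layout used below (inferred from the properties of this class):
+#   I = [fx, fy, cx, cy, alpha, beta]. Projection follows the EUCM equations implemented in project():
+#   d = sqrt(beta * (x^2 + y^2) + z^2),  u = fx * x / (alpha * d + (1 - alpha) * z) + cx  (likewise for v).
+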
+class EUCMCamera(nn.Module):
+ """
+ Differentiable camera class implementing reconstruction and projection
+ functions for the extended unified camera model (EUCM).
+ """
+ def __init__(self, I, Tcw=None):
+ """
+ Initializes the Camera class
+
+ Parameters
+ ----------
+ I : torch.Tensor [6]
+ Camera intrinsics parameter vector
+ Tcw : Pose
+ Camera -> World pose transformation
+ """
+ super().__init__()
+ self.I = I
+ if Tcw is None:
+ self.Tcw = Pose.identity(len(I))
+ elif isinstance(Tcw, Pose):
+ self.Tcw = Tcw
+ else:
+ self.Tcw = Pose(Tcw)
+
+ self.Tcw.to(self.I.device)
+
+ def __len__(self):
+ """Batch size of the camera intrinsics"""
+ return len(self.I)
+
+ def to(self, *args, **kwargs):
+ """Moves object to a specific device"""
+ self.I = self.I.to(*args, **kwargs)
+ self.Tcw = self.Tcw.to(*args, **kwargs)
+ return self
+
+########################################################################################################################
+
+ @property
+ def fx(self):
+ """Focal length in x"""
+ return self.I[:, 0].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def fy(self):
+ """Focal length in y"""
+ return self.I[:, 1].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cx(self):
+ """Principal point in x"""
+ return self.I[:, 2].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cy(self):
+ """Principal point in y"""
+ return self.I[:, 3].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def alpha(self):
+ """alpha in EUCM model"""
+ return self.I[:, 4].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def beta(self):
+ """beta in EUCM model"""
+ return self.I[:, 5].unsqueeze(1).unsqueeze(2)
+
+ @property
+ @lru_cache()
+ def Twc(self):
+ """World -> Camera pose transformation (inverse of Tcw)"""
+ return self.Tcw.inverse()
+
+########################################################################################################################
+
+ def reconstruct(self, depth, frame='w'):
+ """
+ Reconstructs pixel-wise 3D points from a depth map.
+
+ Parameters
+ ----------
+ depth : torch.Tensor [B,1,H,W]
+ Depth map for the camera
+ frame : 'w'
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.tensor [B,3,H,W]
+ Pixel-wise 3D points
+ """
+
+ if depth is None:
+ return None
+ b, c, h, w = depth.shape
+ assert c == 1
+
+ grid = pixel_grid(depth, with_ones=True, device=depth.device)
+
+ # Estimate the outward rays in the camera frame
+ fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta
+
+ if torch.any(torch.isnan(alpha)):
+ raise ValueError('alpha is nan')
+
+ u = grid[:,0,:,:]
+ v = grid[:,1,:,:]
+
+ mx = (u - cx) / fx
+ my = (v - cy) / fy
+ r_square = mx ** 2 + my ** 2
+ mz = (1 - beta * alpha ** 2 * r_square) / (alpha * torch.sqrt(1 - (2 * alpha - 1) * beta * r_square) + (1 - alpha))
+ coeff = 1 / torch.sqrt(mx ** 2 + my ** 2 + mz ** 2)
+
+ x = coeff * mx
+ y = coeff * my
+ z = coeff * mz
+ z = z.clamp(min=1e-7)
+
+ x_norm = x / z
+ y_norm = y / z
+ z_norm = z / z
+ xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1)
+
+ # Scale rays to metric depth
+ Xc = xnorm * depth
+
+ # If in camera frame of reference
+ if frame == 'c':
+ return Xc
+ # If in world frame of reference
+ elif frame == 'w':
+ return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w)
+ # If none of the above
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ def project(self, X, frame='w'):
+ """
+ Projects 3D points onto the image plane
+
+ Parameters
+ ----------
+ X : torch.Tensor [B,3,H,W]
+ 3D points to be projected
+ frame : 'w'
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.Tensor [B,H,W,2]
+ 2D projected points that are within the image boundaries
+ """
+ B, C, H, W = X.shape
+ assert C == 3
+
+ # Project 3D points onto the camera image plane
+ if frame == 'c':
+ X = X
+ elif frame == 'w':
+ X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W)
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ fx, fy, cx, cy, alpha, beta = self.fx, self.fy, self.cx, self.cy, self.alpha, self.beta
+ x, y, z = X[:,0,:], X[:,1,:], X[:,2,:]
+ d = torch.sqrt( beta * ( x ** 2 + y ** 2 ) + z ** 2 )
+ z = z.clamp(min=1e-7)
+
+ Xnorm = fx * x / (alpha * d + (1 - alpha) * z + 1e-7) + cx
+ Ynorm = fy * y / (alpha * d + (1 - alpha) * z + 1e-7) + cy
+ Xnorm = 2 * Xnorm / (W-1) - 1
+ Ynorm = 2 * Ynorm / (H-1) - 1
+
+ coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2)
+ z = z.unsqueeze(1)
+
+ invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
+ (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0)
+ coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2
+
+ # Return pixel coordinates
+ return coords.permute(0, 2, 3, 1)
+
+ def reconstruct_depth_map(self, depth, to_world=True):
+ if to_world:
+ return self.reconstruct(depth, frame='w')
+ else:
+ return self.reconstruct(depth, frame='c')
+
+ def project_points(self, points, from_world=True, normalize=True, return_z=False):
+ if from_world:
+ return self.project(points, frame='w')
+ else:
+ return self.project(points, frame='c')
+
+ def coords_from_depth(self, depth, ref_cam=None):
+ if ref_cam is None:
+ return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True)
+ else:
+ return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True)
\ No newline at end of file
diff --git a/vidar/geometry/camera_full.py b/vidar/geometry/camera_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..21372ef5e1339484ca71474523357dc1b74e2019
--- /dev/null
+++ b/vidar/geometry/camera_full.py
@@ -0,0 +1,230 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+from torch_scatter import scatter_min
+
+from vidar.geometry.camera import Camera
+from vidar.utils.tensor import unnorm_pixel_grid
+
+
+class CameraFull(Camera):
+ """Camera class with additional functionality"""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.convert_matrix = torch.tensor(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ ).unsqueeze(0)
+
+ @staticmethod
+ def from_list(cams):
+ """Create cameras from a list"""
+ K = torch.cat([cam.K for cam in cams], 0)
+ Twc = torch.cat([cam.Twc.T for cam in cams], 0)
+ return CameraFull(K=K, Twc=Twc, hw=cams[0].hw)
+
+ def switch(self):
+ """Switch camera between conventions"""
+ T = self.convert_matrix.to(self.device)
+ Twc = T @ self.Twc.T @ T
+ return type(self)(K=self.K, Twc=Twc, hw=self.hw)
+
+ def bwd(self):
+ """Switch camera to the backwards convention"""
+ T = self.convert_matrix.to(self.device)
+ Tcw = T @ self.Twc.T @ T
+ return type(self)(K=self.K, Tcw=Tcw, hw=self.hw)
+
+ def fwd(self):
+ """Switch camera to the forward convention"""
+ T = self.convert_matrix.to(self.device)
+ Twc = T @ self.Tcw.T @ T
+ return type(self)(K=self.K, Twc=Twc, hw=self.hw)
+
+ def look_at(self, at, up=torch.Tensor([0, 1, 0])):
+ """
+ Set a direction for the camera to point (in-place)
+
+ Parameters
+ ----------
+ at : torch.Tensor
+ Where the camera should be pointing at [B,3]
+ up : torch.Tensor
+ Up direction [B,3]
+ """
+ eps = 1e-5
+ eye = self.Tcw.T[:, :3, -1]
+
+ at = at.unsqueeze(0)
+ up = up.unsqueeze(0).to(at.device)
+
+ z_axis = at - eye
+ z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps
+
+ up = up.expand(z_axis.shape)
+ x_axis = torch.cross(up, z_axis)
+ x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps
+
+ y_axis = torch.cross(z_axis, x_axis)
+ y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps
+
+ R = torch.stack((x_axis, y_axis, z_axis), dim=-1)
+
+ Tcw = self.Tcw
+ Tcw.T[:, :3, :3] = R
+ self.Twc = Tcw.inverse()
+
+ def get_origin(self, flatten=False):
+ """Return camera origin"""
+ orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw)
+ if flatten:
+ orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1)
+ return orig
+
+ def get_viewdirs(self, normalize=False, flatten=False, to_world=False):
+ """Return camera viewing rays"""
+ ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device)
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ if normalize:
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+ if to_world:
+ rays = self.to_world(rays).reshape(len(self), 3, *self.hw)
+ if flatten:
+ rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1)
+ return rays
+
+ def get_render_rays(self, near=None, far=None, n_rays=None, gt=None):
+ """
+ Get render rays
+
+ Parameters
+ ----------
+ near : Float
+ Near plane
+ far : Float
+ Far plane
+ n_rays : Int
+ Number of rays
+ gt : torch.Tensor
+ Ground-truth values for concatenation
+
+ Returns
+ -------
+ rays : torch.Tensor
+ Camera viewing rays
+ """
+ b = len(self)
+
+ ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
+
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+
+ rays[:, 1] = - rays[:, 1]
+ rays[:, 2] = - rays[:, 2]
+
+ orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
+ rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw)
+
+ info = [orig, rays]
+ if near is not None:
+ info = info + [near * ones]
+ if far is not None:
+ info = info + [far * ones]
+ if gt is not None:
+ info = info + [gt]
+
+ rays = torch.cat(info, 1)
+ rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1])
+
+ if n_rays is not None:
+ idx = torch.randint(0, self.n_pixels, (n_rays,))
+ rays = rays[:, idx, :]
+
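+ # Each returned ray is laid out as [origin(3), direction(3), near(1)?, far(1)?, gt(...)?],
+ # flattened to shape [B, H*W, C] (or [B, n_rays, C] when n_rays is given)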
+ return rays
+
+ def get_plucker(self):
+ """Get plucker vectors"""
+ b = len(self)
+ ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+ orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
+
+ orig = orig.view(1, 3, -1).permute(0, 2, 1)
+ rays = rays.view(1, 3, -1).permute(0, 2, 1)
+
+ cross = torch.cross(orig, rays, dim=-1)
+ plucker = torch.cat((rays, cross), dim=-1)
+
+ return plucker
+
+ def project_pointcloud(self, pcl_src, rgb_src, thr=1):
+ """
+ Project pointcloud to the camera plane
+
+ Parameters
+ ----------
+ pcl_src : torch.Tensor
+ Input 3D pointcloud
+ rgb_src : torch.Tensor
+ Pointcloud color information
+ thr : Int
+ Threshold for the number of valid points
+
+ Returns
+ -------
+ rgb_tgt : torch.Tensor
+ Projected image [B,3,H,W]
+ depth_tgt : torch.Tensor
+ Projected depth map [B,1,H,W]
+ """
+ if rgb_src.dim() == 4:
+ rgb_src = rgb_src.view(*rgb_src.shape[:2], -1)
+
+ # Get projected coordinates and depth values
+ uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True)
+
+ rgbs_tgt, depths_tgt = [], []
+
+ b = pcl_src.shape[0]
+ for i in range(b):
+ uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1)
+
+ # Remove out-of-bounds coordinates and points behind the camera
+ idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \
+ (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0)
+
+ # Unnormalize coordinates and flatten them into linear indices for the scatter operation
+ uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long()
+ uv = uv[:, 0] + uv[:, 1] * self.hw[1]
+
+ # Min scatter operation (only keep the closest depth)
+ depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device)
+ depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt)
+ depth_tgt[depth_tgt == 1e10] = 0.
+
+ num_valid = (depth_tgt > 0).sum()
+ if num_valid > thr:
+
+ # Substitute invalid values with zero
+ invalid = argmin == argmin.max()
+ argmin[invalid] = 0
+ rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin]
+ rgb_tgt[invalid] = -1
+
+ else:
+
+ rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype)
+
+ # Reshape outputs
+ rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2)
+ depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1])
+
+ rgbs_tgt.append(rgb_tgt)
+ depths_tgt.append(depth_tgt)
+
+ rgb_tgt = torch.cat(rgbs_tgt, 0)
+ depth_tgt = torch.cat(depths_tgt, 0)
+
+ return rgb_tgt, depth_tgt
diff --git a/vidar/geometry/camera_nerf.py b/vidar/geometry/camera_nerf.py
new file mode 100755
index 0000000000000000000000000000000000000000..7cf46d9ea3b9517edd6701122df8b66044e709dd
--- /dev/null
+++ b/vidar/geometry/camera_nerf.py
@@ -0,0 +1,193 @@
+
+import torch
+from torch_scatter import scatter_min
+
+from vidar.geometry.camera import Camera
+from vidar.utils.tensor import unnorm_pixel_grid
+
+
+class CameraNerf(Camera):
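+ """Camera variant with NeRF-style ray and rendering utilities; largely mirrors CameraFull"""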
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.convert_matrix = torch.tensor(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ ).unsqueeze(0)
+
+ @staticmethod
+ def from_list(cams):
+ K = torch.cat([cam.K for cam in cams], 0)
+ Twc = torch.cat([cam.Twc.T for cam in cams], 0)
+ return CameraNerf(K=K, Twc=Twc, hw=cams[0].hw)
+
+ @staticmethod
+ def from_dict(K, hw, Twc=None):
+ return {key: CameraNerf(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()}
+
+ def switch(self):
+ T = self.convert_matrix.to(self.device)
+ Twc = T @ self.Twc.T @ T
+ return type(self)(K=self.K, Twc=Twc, hw=self.hw)
+
+ def bwd(self):
+ T = self.convert_matrix.to(self.device)
+ Tcw = T @ self.Twc.T @ T
+ return type(self)(K=self.K, Tcw=Tcw, hw=self.hw)
+
+ def fwd(self):
+ T = self.convert_matrix.to(self.device)
+ Twc = T @ self.Tcw.T @ T
+ return type(self)(K=self.K, Twc=Twc, hw=self.hw)
+
+ def look_at(self, at, up=torch.Tensor([0, 1, 0])):
+
+ eps = 1e-5
+ eye = self.Tcw.T[:, :3, -1]
+
+ at = at.unsqueeze(0)
+ up = up.unsqueeze(0).to(at.device)
+ up /= up.norm(dim=-1, keepdim=True) + eps
+
+ z_axis = at - eye
+ z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps
+
+ up = up.expand(z_axis.shape)
+ x_axis = torch.cross(up, z_axis)
+ x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps
+
+ y_axis = torch.cross(z_axis, x_axis)
+ y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps
+
+ R = torch.stack((x_axis, y_axis, z_axis), dim=-1)
+
+ Tcw = self.Tcw
+ Tcw.T[:, :3, :3] = R
+ self.Twc = Tcw.inverse()
+
+ def get_origin(self, flatten=False):
+ orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw)
+ if flatten:
+ orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1)
+ return orig
+
+ def get_viewdirs(self, normalize=False, flatten=False, to_world=False):
+
+ ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device)
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ if normalize:
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+ if to_world:
+ rays = self.to_world(rays).reshape(len(self), 3, *self.hw)
+ if flatten:
+ rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1)
+ return rays
+
+ def get_render_rays(self, near=None, far=None, n_rays=None, gt=None):
+
+ b = len(self)
+
+ ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
+
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+
+ rays[:, 1] = - rays[:, 1]
+ rays[:, 2] = - rays[:, 2]
+
+ orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
+ rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw)
+
+ info = [orig, rays]
+ if near is not None:
+ info = info + [near * ones]
+ if far is not None:
+ info = info + [far * ones]
+ if gt is not None:
+ info = info + [gt]
+
+ rays = torch.cat(info, 1)
+ rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1])
+
+ if n_rays is not None:
+ idx = torch.randint(0, self.n_pixels, (n_rays,))
+ rays = rays[:, idx, :]
+
+ return rays
+
+ def get_plucker(self):
+
+ b = len(self)
+ ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
+ rays = self.reconstruct_depth_map(ones, to_world=False)
+ rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
+ orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
+
+ orig = orig.view(1, 3, -1).permute(0, 2, 1)
+ rays = rays.view(1, 3, -1).permute(0, 2, 1)
+
+ cross = torch.cross(orig, rays, dim=-1)
+ plucker = torch.cat((rays, cross), dim=-1)
+
+ return plucker
+
+ def project_pointcloud(self, pcl_src, rgb_src, thr=1):
+
+ if rgb_src.dim() == 4:
+ rgb_src = rgb_src.view(*rgb_src.shape[:2], -1)
+
+ # Get projected coordinates and depth values
+ uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True)
+
+ rgbs_tgt, depths_tgt = [], []
+
+ b = pcl_src.shape[0]
+ for i in range(b):
+ uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1)
+
+ # Remove out-of-bounds coordinates and points behind the camera
+ idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \
+ (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0)
+
+ # Unnormalize coordinates and flatten them into linear indices for the scatter operation
+ uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long()
+ uv = uv[:, 0] + uv[:, 1] * self.hw[1]
+
+ # Min scatter operation (only keep the closest depth)
+ depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device)
+ depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt)
+ depth_tgt[depth_tgt == 1e10] = 0.
+
+ num_valid = (depth_tgt > 0).sum()
+ if num_valid > thr:
+
+ # Substitute invalid values with zero
+ invalid = argmin == argmin.max()
+ argmin[invalid] = 0
+ rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin]
+ rgb_tgt[invalid] = -1
+
+ else:
+
+ rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype)
+
+ # Reshape outputs
+ rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2)
+ depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1])
+
+ rgbs_tgt.append(rgb_tgt)
+ depths_tgt.append(depth_tgt)
+
+ rgb_tgt = torch.cat(rgbs_tgt, 0)
+ depth_tgt = torch.cat(depths_tgt, 0)
+
+ return rgb_tgt, depth_tgt
+
+ def reconstruct_depth_map_rays(self, depth, to_world=False):
+ if depth is None:
+ return None
+ b, _, h, w = depth.shape
+ rays = self.get_viewdirs(normalize=True, to_world=False)
+ points = (rays * depth).view(b, 3, -1)
+ if to_world and self.Tcw is not None:
+ points = self.Tcw * points
+ return points.view(b, 3, h, w)
diff --git a/vidar/geometry/camera_ucm.py b/vidar/geometry/camera_ucm.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e9058ea78b4e30ac73d6136c1a75a579fddfdb
--- /dev/null
+++ b/vidar/geometry/camera_ucm.py
@@ -0,0 +1,212 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from functools import lru_cache
+import torch
+import torch.nn as nn
+
+from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics
+from vidar.geometry.pose import Pose
+from vidar.geometry.pose_utils import invert_pose
+from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave
+from vidar.utils.types import is_tensor, is_seq
+
+########################################################################################################################
+
+class UCMCamera(nn.Module):
+ """
+ Differentiable camera class implementing reconstruction and projection
+ functions for the unified camera model (UCM).
+ """
+ def __init__(self, I, Tcw=None):
+ """
+ Initializes the UCMCamera class
+
+ Parameters
+ ----------
+ I : torch.Tensor [B,5]
+ Camera intrinsics parameter vector (fx, fy, cx, cy, alpha)
+ Tcw : Pose
+ Camera -> World pose transformation
+ """
+ super().__init__()
+ self.I = I
+ if Tcw is None:
+ self.Tcw = Pose.identity(len(I))
+ elif isinstance(Tcw, Pose):
+ self.Tcw = Tcw
+ else:
+ self.Tcw = Pose(Tcw)
+
+ self.Tcw.to(self.I.device)
+
+ def __len__(self):
+ """Batch size of the camera intrinsics"""
+ return len(self.I)
+
+ def to(self, *args, **kwargs):
+ """Moves object to a specific device"""
+ self.I = self.I.to(*args, **kwargs)
+ self.Tcw = self.Tcw.to(*args, **kwargs)
+ return self
+
+########################################################################################################################
+
+ @property
+ def fx(self):
+ """Focal length in x"""
+ return self.I[:, 0].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def fy(self):
+ """Focal length in y"""
+ return self.I[:, 1].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cx(self):
+ """Principal point in x"""
+ return self.I[:, 2].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def cy(self):
+ """Principal point in y"""
+ return self.I[:, 3].unsqueeze(1).unsqueeze(2)
+
+ @property
+ def alpha(self):
+ """alpha in UCM model"""
+ return self.I[:, 4].unsqueeze(1).unsqueeze(2)
+
+ @property
+ @lru_cache()
+ def Twc(self):
+ """World -> Camera pose transformation (inverse of Tcw)"""
+ return self.Tcw.inverse()
+
+########################################################################################################################
+
+ def reconstruct(self, depth, frame='w'):
+ """
+ Reconstructs pixel-wise 3D points from a depth map.
+
+ Parameters
+ ----------
+ depth : torch.Tensor [B,1,H,W]
+ Depth map for the camera
+ frame : String
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.tensor [B,3,H,W]
+ Pixel-wise 3D points
+ """
+
+ if depth is None:
+ return None
+ b, c, h, w = depth.shape
+ assert c == 1
+
+ grid = pixel_grid(depth, with_ones=True, device=depth.device)
+
+ # Estimate the outward rays in the camera frame
+ fx, fy, cx, cy, alpha = self.fx, self.fy, self.cx, self.cy, self.alpha # [B,1,1]
+
+ if torch.any(torch.isnan(alpha)):
+ raise ValueError('alpha is nan')
+
+ u = grid[:,0,:,:]
+ v = grid[:,1,:,:]
+
+ mx = (u - cx) / fx * (1 - alpha)
+ my = (v - cy) / fy * (1 - alpha)
+ r_square = mx ** 2 + my ** 2
+ xi = alpha / (1 - alpha) # [B, 1, 1]
+ coeff = (xi + torch.sqrt(1 + (1 - xi ** 2) * r_square)) / (1 + r_square) # [B, H, W]
+
+ x = coeff * mx
+ y = coeff * my
+ z = coeff * 1 - xi
+ z = z.clamp(min=1e-7)
+
+ x_norm = x / z
+ y_norm = y / z
+ z_norm = z / z
+ xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1).float()
+
+ # Scale rays to metric depth
+ Xc = xnorm * depth
+
+ # If in camera frame of reference
+ if frame == 'c':
+ return Xc
+ # If in world frame of reference
+ elif frame == 'w':
+ return (self.Twc * Xc.view(b, 3, -1)).view(b,3,h,w)
+ # If none of the above
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ def project(self, X, frame='w'):
+ """
+ Projects 3D points onto the image plane
+
+ Parameters
+ ----------
+ X : torch.Tensor [B,3,H,W]
+ 3D points to be projected
+ frame : String
+ Reference frame: 'c' for camera and 'w' for world
+
+ Returns
+ -------
+ points : torch.Tensor [B,H,W,2]
+ 2D projected points in normalized [-1,1] coordinates (invalid points are set to -2)
+ """
+ B, C, H, W = X.shape
+ assert C == 3
+
+ # Project 3D points onto the camera image plane
+ if frame == 'c':
+ X = X
+ elif frame == 'w':
+ X = (self.Tcw * X.view(B,3,-1)).view(B,3,H,W)
+ else:
+ raise ValueError('Unknown reference frame {}'.format(frame))
+
+ d = torch.norm(X, dim=1)
+ fx, fy, cx, cy, alpha = self.fx, self.fy, self.cx, self.cy, self.alpha
+ x, y, z = X[:,0,:], X[:,1,:], X[:,2,:]
+ z = z.clamp(min=1e-7)
+
+ Xnorm = fx * x / (alpha * d + (1 - alpha) * z + 1e-7) + cx
+ Ynorm = fy * y / (alpha * d + (1 - alpha) * z + 1e-7) + cy
+ Xnorm = 2 * Xnorm / (W-1) - 1
+ Ynorm = 2 * Ynorm / (H-1) - 1
+
+ coords = torch.stack([Xnorm, Ynorm], dim=-1).permute(0,3,1,2)
+ z = z.unsqueeze(1)
+
+ invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
+ (coords[:, 1] < -1) | (coords[:, 1] > 1) | (z[:, 0] < 0)
+ coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2
+
+ # Return pixel coordinates
+ return coords.permute(0, 2, 3, 1)
+
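+ # Illustrative sketch (not part of the original code): round-tripping a depth map
+ # through the UCM model. The intrinsics values and shapes below are assumptions.
+ #
+ # I = torch.tensor([[500., 500., 320., 240., 0.5]]) # fx, fy, cx, cy, alpha
+ # cam = UCMCamera(I)
+ # depth = 5. * torch.ones(1, 1, 480, 640)
+ # points = cam.reconstruct(depth, frame='c') # [1,3,480,640] camera-frame points
+ # coords = cam.project(points, frame='c') # [1,480,640,2], normalized to [-1,1]
+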
+ def reconstruct_depth_map(self, depth, to_world=True):
+ if to_world:
+ return self.reconstruct(depth, frame='w')
+ else:
+ return self.reconstruct(depth, frame='c')
+
+ def project_points(self, points, from_world=True, normalize=True, return_z=False):
+ if from_world:
+ return self.project(points, frame='w')
+ else:
+ return self.project(points, frame='c')
+
+ def coords_from_depth(self, depth, ref_cam=None):
+ if ref_cam is None:
+ return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True)
+ else:
+ return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True)
\ No newline at end of file
diff --git a/vidar/geometry/camera_utils.py b/vidar/geometry/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..461d140a4ebb0ad16ed1157f129bb773fb4467ad
--- /dev/null
+++ b/vidar/geometry/camera_utils.py
@@ -0,0 +1,31 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from vidar.utils.types import is_seq
+
+
+def invert_intrinsics(K):
+ """Invert camera intrinsics"""
+ Kinv = K.clone()
+ Kinv[:, 0, 0] = 1. / K[:, 0, 0]
+ Kinv[:, 1, 1] = 1. / K[:, 1, 1]
+ Kinv[:, 0, 2] = -1. * K[:, 0, 2] / K[:, 0, 0]
+ Kinv[:, 1, 2] = -1. * K[:, 1, 2] / K[:, 1, 1]
+ return Kinv
+
+
+def scale_intrinsics(K, ratio):
+ """Scale intrinsics given a ratio (tuple for individual hw ratios, float for the same ratio)"""
+ if is_seq(ratio):
+ ratio_h, ratio_w = ratio
+ else:
+ ratio_h = ratio_w = ratio
+
+ K = K.clone()
+ K[..., 0, 0] *= ratio_w
+ K[..., 1, 1] *= ratio_h
+ # K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5
+ # K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5
+ K[..., 0, 2] = K[..., 0, 2] * ratio_w
+ K[..., 1, 2] = K[..., 1, 2] * ratio_h
+
+ return K
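+
+
+# Illustrative sketch (not part of the original code, assuming torch is imported):
+# resizing an image by half scales focal lengths and principal point by the same ratio.
+# The intrinsics values below are assumptions.
+#
+# K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])
+# K_half = scale_intrinsics(K, 0.5) # fx = fy = 250, cx = 160, cy = 120
+# K_rect = scale_intrinsics(K, (0.5, 1.0)) # only the vertical (height) ratio is halved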
diff --git a/vidar/geometry/pose.py b/vidar/geometry/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..b385cb772a658028ebc3f7ad0666c66846747a4b
--- /dev/null
+++ b/vidar/geometry/pose.py
@@ -0,0 +1,182 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.geometry.pose_utils import invert_pose, pose_vec2mat, to_global_pose, euler2mat
+from vidar.utils.types import is_int
+
+
+def from_dict_sample(T, to_global=False, zero_origin=False, to_matrix=False):
+ """
+ Create poses from a sample dictionary
+
+ Parameters
+ ----------
+ T : Dict
+ Dictionary containing input poses [B,4,4]
+ to_global : Bool
+ Whether poses should be converted to global frame of reference
+ zero_origin : Bool
+ Whether the target camera should be the center of the frame of reference
+ to_matrix : Bool
+ Whether output poses should be classes or tensors
+
+ Returns
+ -------
+ pose : Dict
+ Dictionary containing output poses
+ """
+ pose = {key: Pose(val) for key, val in T.items()}
+ if to_global:
+ pose = to_global_pose(pose, zero_origin=zero_origin)
+ if to_matrix:
+ pose = {key: val.T for key, val in pose.items()}
+ return pose
+
+
+def from_dict_batch(T, **kwargs):
+ """Create poses from a batch dictionary"""
+ pose_batch = [from_dict_sample({key: val[b] for key, val in T.items()}, **kwargs)
+ for b in range(T[0].shape[0])]
+ return {key: torch.stack([v[key] for v in pose_batch], 0) for key in pose_batch[0]}
+
+
+class Pose:
+ """
+ Pose class for 3D operations
+
+ Parameters
+ ----------
+ T : torch.Tensor or Int
+ Transformation matrix [B,4,4], or batch size (poses initialized as identity)
+ """
+ def __init__(self, T=1):
+ if is_int(T):
+ T = torch.eye(4).repeat(T, 1, 1)
+ self.T = T if T.dim() == 3 else T.unsqueeze(0)
+
+ def __len__(self):
+ """Return batch size"""
+ return len(self.T)
+
+ def __getitem__(self, i):
+ """Return batch-wise pose"""
+ return Pose(self.T[[i]])
+
+ def __mul__(self, data):
+ """Transforms data (pose or 3D points)"""
+ if isinstance(data, Pose):
+ return Pose(self.T.bmm(data.T))
+ elif isinstance(data, torch.Tensor):
+ return self.T[:, :3, :3].bmm(data) + self.T[:, :3, -1].unsqueeze(-1)
+ else:
+ raise NotImplementedError()
+
+ def detach(self):
+ """Return detached pose"""
+ return Pose(self.T.detach())
+
+ @property
+ def shape(self):
+ """Return pose shape"""
+ return self.T.shape
+
+ @property
+ def device(self):
+ """Return pose device"""
+ return self.T.device
+
+ @property
+ def dtype(self):
+ """Return pose type"""
+ return self.T.dtype
+
+ @classmethod
+ def identity(cls, N=1, device=None, dtype=torch.float):
+ """Initializes as a [4,4] identity matrix"""
+ return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1]))
+
+ @staticmethod
+ def from_dict(T, to_global=False, zero_origin=False, to_matrix=False):
+ """Create poses from a dictionary"""
+ if T[0].dim() == 3:
+ return from_dict_sample(T, to_global=to_global, zero_origin=zero_origin, to_matrix=to_matrix)
+ elif T[0].dim() == 4:
+ return from_dict_batch(T, to_global=to_global, zero_origin=zero_origin, to_matrix=True)
+
+ @classmethod
+ def from_vec(cls, vec, mode):
+ """Initializes from a [B,6] batch vector"""
+ mat = pose_vec2mat(vec, mode)
+ pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
+ pose[:, :3, :3] = mat[:, :3, :3]
+ pose[:, :3, -1] = mat[:, :3, -1]
+ return cls(pose)
+
+ def repeat(self, *args, **kwargs):
+ """Repeats the transformation matrix multiple times"""
+ self.T = self.T.repeat(*args, **kwargs)
+ return self
+
+ def inverse(self):
+ """Returns a new Pose that is the inverse of this one"""
+ return Pose(invert_pose(self.T))
+
+ def to(self, *args, **kwargs):
+ """Copy pose to device"""
+ self.T = self.T.to(*args, **kwargs)
+ return self
+
+ def cuda(self, *args, **kwargs):
+ """Copy pose to CUDA"""
+ self.to('cuda')
+ return self
+
+ def translate(self, xyz):
+ """Translate pose"""
+ self.T[:, :3, -1] = self.T[:, :3, -1] + xyz.to(self.device)
+ return self
+
+ def rotate(self, rpw):
+ """Rotate pose"""
+ rot = euler2mat(rpw)
+ T = invert_pose(self.T).clone()
+ T[:, :3, :3] = T[:, :3, :3] @ rot.to(self.device)
+ self.T = invert_pose(T)
+ return self
+
+ def rotateRoll(self, r):
+ """Rotate pose in the roll axis"""
+ return self.rotate(torch.tensor([[0, 0, r]]))
+
+ def rotatePitch(self, p):
+ """Rotate pose in the pitcv axis"""
+ return self.rotate(torch.tensor([[p, 0, 0]]))
+
+ def rotateYaw(self, w):
+ """Rotate pose in the yaw axis"""
+ return self.rotate(torch.tensor([[0, w, 0]]))
+
+ def translateForward(self, t):
+ """Translate pose forward"""
+ return self.translate(torch.tensor([[0, 0, -t]]))
+
+ def translateBackward(self, t):
+ """Translate pose backward"""
+ return self.translate(torch.tensor([[0, 0, +t]]))
+
+ def translateLeft(self, t):
+ """Translate pose left"""
+ return self.translate(torch.tensor([[+t, 0, 0]]))
+
+ def translateRight(self, t):
+ """Translate pose right"""
+ return self.translate(torch.tensor([[-t, 0, 0]]))
+
+ def translateUp(self, t):
+ """Translate pose up"""
+ return self.translate(torch.tensor([[0, +t, 0]]))
+
+ def translateDown(self, t):
+ """Translate pose down"""
+ return self.translate(torch.tensor([[0, -t, 0]]))
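+
+
+# Illustrative sketch (not part of the original code): composing and applying poses.
+# The values below are assumptions.
+#
+# T = Pose.identity(1).translateForward(2.0) # adds (0, 0, -2) to the translation
+# points = torch.rand(1, 3, 100) # [B,3,N] 3D points
+# moved = T * points # rigid transform applied to the points
+# back = T.inverse() * moved # recovers the original points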
diff --git a/vidar/geometry/pose_utils.py b/vidar/geometry/pose_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a57ee1531b5eb69012f361a6c3bae8136c940f2
--- /dev/null
+++ b/vidar/geometry/pose_utils.py
@@ -0,0 +1,206 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as F
+
+from vidar.utils.decorators import iterate1
+
+
+def to_global_pose(pose, zero_origin=False):
+ """Get global pose coordinates from current and context poses"""
+ if zero_origin:
+ pose[0].T[[0]] = torch.eye(4, device=pose[0].device, dtype=pose[0].dtype)
+ for b in range(1, len(pose[0])):
+ pose[0].T[[b]] = (pose[0][b] * pose[0][0]).T.float()
+ for key in pose.keys():
+ if key != 0:
+ pose[key] = pose[key] * pose[0]
+ return pose
+
+
+# def to_global_pose(pose, zero_origin=False):
+# """Get global pose coordinates from current and context poses"""
+# if zero_origin:
+# pose[(0, 0)].T = torch.eye(4, device=pose[(0, 0)].device, dtype=pose[(0, 0)].dtype). \
+# repeat(pose[(0, 0)].shape[0], 1, 1)
+# for key in pose.keys():
+# if key[0] == 0 and key[1] != 0:
+# pose[key].T = (pose[key] * pose[(0, 0)]).T
+# for key in pose.keys():
+# if key[0] != 0:
+# pose[key] = pose[key] * pose[(0, 0)]
+# return pose
+
+
+def euler2mat(angle):
+ """Convert euler angles to rotation matrix"""
+ B = angle.size(0)
+ x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
+
+ cosz = torch.cos(z)
+ sinz = torch.sin(z)
+
+ zeros = z.detach() * 0
+ ones = zeros.detach() + 1
+ zmat = torch.stack([ cosz, -sinz, zeros,
+ sinz, cosz, zeros,
+ zeros, zeros, ones], dim=1).view(B, 3, 3)
+
+ cosy = torch.cos(y)
+ siny = torch.sin(y)
+
+ ymat = torch.stack([ cosy, zeros, siny,
+ zeros, ones, zeros,
+ -siny, zeros, cosy], dim=1).view(B, 3, 3)
+
+ cosx = torch.cos(x)
+ sinx = torch.sin(x)
+
+ xmat = torch.stack([ ones, zeros, zeros,
+ zeros, cosx, -sinx,
+ zeros, sinx, cosx], dim=1).view(B, 3, 3)
+
+ rot_mat = xmat.bmm(ymat).bmm(zmat)
+ return rot_mat
+
+
+def pose_vec2mat(vec, mode='euler'):
+ """Convert translation and Euler rotation to a [B,4,4] torch.Tensor transformation matrix"""
+ if mode is None:
+ return vec
+ trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:]
+ if mode == 'euler':
+ rot_mat = euler2mat(rot)
+ else:
+ raise ValueError('Rotation mode not supported {}'.format(mode))
+ mat = torch.cat([rot_mat, trans], dim=2) # [B,3,4]
+ return mat
+
+
+@iterate1
+def invert_pose(T):
+ """Invert a [B,4,4] torch.Tensor pose"""
+ Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([len(T), 1, 1])
+ Tinv[:, :3, :3] = torch.transpose(T[:, :3, :3], -2, -1)
+ Tinv[:, :3, -1] = torch.bmm(-1. * Tinv[:, :3, :3], T[:, :3, -1].unsqueeze(-1)).squeeze(-1)
+ return Tinv
+ # return torch.linalg.inv(T)
+
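+# Illustrative sketch (not part of the original code): for a rigid transform
+# T = [R | t], invert_pose builds [R^T | -R^T t] explicitly instead of calling a
+# general matrix inverse. The translation below is an assumption.
+#
+# T = torch.eye(4).unsqueeze(0)
+# T[:, :3, -1] = torch.tensor([1., 2., 3.])
+# assert torch.allclose(T @ invert_pose(T), torch.eye(4).unsqueeze(0), atol=1e-6)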
+
+def tvec_to_translation(tvec):
+ """Convert translation vector to translation matrix (no rotation)"""
+ batch_size = tvec.shape[0]
+ T = torch.eye(4).to(device=tvec.device).repeat(batch_size, 1, 1)
+ t = tvec.contiguous().view(-1, 3, 1)
+ T[:, :3, 3, None] = t
+ return T
+
+
+def euler2rot(euler):
+ """Convert Euler parameters to a [B,3,3] torch.Tensor rotation matrix"""
+ euler_norm = torch.norm(euler, 2, 2, True)
+ axis = euler / (euler_norm + 1e-7)
+
+ cos_a = torch.cos(euler_norm)
+ sin_a = torch.sin(euler_norm)
+ cos1_a = 1 - cos_a
+
+ x = axis[..., 0].unsqueeze(1)
+ y = axis[..., 1].unsqueeze(1)
+ z = axis[..., 2].unsqueeze(1)
+
+ x_sin = x * sin_a
+ y_sin = y * sin_a
+ z_sin = z * sin_a
+ x_cos1 = x * cos1_a
+ y_cos1 = y * cos1_a
+ z_cos1 = z * cos1_a
+
+ xx_cos1 = x * x_cos1
+ yy_cos1 = y * y_cos1
+ zz_cos1 = z * z_cos1
+ xy_cos1 = x * y_cos1
+ yz_cos1 = y * z_cos1
+ zx_cos1 = z * x_cos1
+
+ batch_size = euler.shape[0]
+ rot = torch.zeros((batch_size, 4, 4)).to(device=euler.device)
+
+ rot[:, 0, 0] = torch.squeeze(xx_cos1 + cos_a)
+ rot[:, 0, 1] = torch.squeeze(xy_cos1 - z_sin)
+ rot[:, 0, 2] = torch.squeeze(zx_cos1 + y_sin)
+ rot[:, 1, 0] = torch.squeeze(xy_cos1 + z_sin)
+ rot[:, 1, 1] = torch.squeeze(yy_cos1 + cos_a)
+ rot[:, 1, 2] = torch.squeeze(yz_cos1 - x_sin)
+ rot[:, 2, 0] = torch.squeeze(zx_cos1 - y_sin)
+ rot[:, 2, 1] = torch.squeeze(yz_cos1 + x_sin)
+ rot[:, 2, 2] = torch.squeeze(zz_cos1 + cos_a)
+ rot[:, 3, 3] = 1
+
+ return rot
+
+
+def vec2mat(euler, translation, invert=False):
+ """Convert Euler rotation and translation to a [B,4,4] torch.Tensor transformation matrix"""
+ R = euler2rot(euler)
+ t = translation.clone()
+
+ if invert:
+ R = R.transpose(1, 2)
+ t *= -1
+
+ T = tvec_to_translation(t)
+
+ if invert:
+ M = torch.matmul(R, T)
+ else:
+ M = torch.matmul(T, R)
+
+ return M
+
+
+def rot2quat(R):
+ """Convert a [B,3,3] rotation matrix to [B,4] quaternions"""
+ b, _, _ = R.shape
+ q = torch.ones((b, 4), device=R.device)
+
+ R00 = R[:, 0, 0]
+ R01 = R[:, 0, 1]
+ R02 = R[:, 0, 2]
+ R10 = R[:, 1, 0]
+ R11 = R[:, 1, 1]
+ R12 = R[:, 1, 2]
+ R20 = R[:, 2, 0]
+ R21 = R[:, 2, 1]
+ R22 = R[:, 2, 2]
+
+ q[:, 3] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
+ q[:, 0] = (R21 - R12) / (4 * q[:, 3])
+ q[:, 1] = (R02 - R20) / (4 * q[:, 3])
+ q[:, 2] = (R10 - R01) / (4 * q[:, 3])
+
+ return q
+
+
+def quat2rot(q):
+ """Convert [B,4] quaternions to [B,3,3] rotation matrix"""
+ b, _ = q.shape
+ q = F.normalize(q, dim=1)
+ R = torch.ones((b, 3, 3), device=q.device)
+
+ qr = q[:, 0]
+ qi = q[:, 1]
+ qj = q[:, 2]
+ qk = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
+ R[:, 0, 1] = 2 * (qj * qi - qk * qr)
+ R[:, 0, 2] = 2 * (qi * qk + qr * qj)
+ R[:, 1, 0] = 2 * (qj * qi + qk * qr)
+ R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
+ R[:, 1, 2] = 2 * (qj * qk - qi * qr)
+ R[:, 2, 0] = 2 * (qk * qi - qj * qr)
+ R[:, 2, 1] = 2 * (qj * qk + qi * qr)
+ R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
+
+ return R
diff --git a/vidar/metrics/base.py b/vidar/metrics/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..61f303f9823d80fb1dee9eb7e779cdd68c6800e7
--- /dev/null
+++ b/vidar/metrics/base.py
@@ -0,0 +1,163 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from collections import OrderedDict
+from functools import partial
+
+import numpy as np
+import torch
+
+from vidar.utils.distributed import reduce_value
+from vidar.utils.tensor import same_shape, interpolate
+
+
+class BaseEvaluation:
+ """
+ Base class for evaluation metrics
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+ name : String
+ Evaluation name
+ task : String
+ Task referent to the evaluation
+ metrics : String
+ Metrics name
+ """
+ def __init__(self, cfg, name, task, metrics):
+ self.name = name
+ self.task = task
+ self.width = 32 + 11 * len(metrics)
+ self.metrics = metrics
+ self.modes = ['']
+
+ self.font1 = {'color': 'magenta', 'attrs': ('bold',)}
+ self.font2 = {'color': 'cyan', 'attrs': ()}
+
+ self.nearest = partial(interpolate, scale_factor=None, mode='nearest', align_corners=None)
+ self.bilinear = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True)
+
+ self.only_first = cfg.has('only_first', False)
+
+ @property
+ def horz_line(self):
+ """Print horizontal line"""
+ return '|{:<}|'.format('*' * self.width)
+
+ @property
+ def metr_line(self):
+ """Print metrics line"""
+ return '| {:^30} |' + ' {:^8} |' * len(self.metrics)
+
+ @property
+ def outp_line(self):
+ """Print output line"""
+ return '{:<30}' + ' | {:^8.3f}' * len(self.metrics)
+
+ @staticmethod
+ def wrap(string):
+ """Wrap line around vertical bars"""
+ return '| {} |'.format(string)
+
+ def check_name(self, key):
+ """Check name for prefixes"""
+ return key.startswith(self.name) or \
+ key.startswith('fwd_' + self.name) or \
+ key.startswith('bwd_' + self.name)
+
+ def reduce_fn(self, *args, **kwargs):
+ """Reduce function"""
+ raise NotImplementedError('reduce_fn not implemented for {}'.format(self.__class__.__name__))
+
+ def populate_metrics_dict(self, *args, **kwargs):
+ """Populate metrics function"""
+ raise NotImplementedError('populate_metrics_dict not implemented for {}'.format(self.__class__.__name__))
+
+ def print(self, *args, **kwargs):
+ """Print function"""
+ raise NotImplementedError('print not implemented for {}'.format(self.__class__.__name__))
+
+ @staticmethod
+ def interp(dst, src, fn):
+ """Interpolate dst to be the size of src using the interpolation function fn"""
+ if dst is None:
+ return dst
+ assert dst.dim() == src.dim()
+ if dst.dim() == 4 and not same_shape(dst.shape, src.shape):
+ dst = fn(dst, size=src)
+ return dst
+
+ def interp_bilinear(self, dst, src):
+ """Bilinear interpolation"""
+ return self.interp(dst, src, self.bilinear)
+
+ def interp_nearest(self, dst, src):
+ """Nearest interpolation"""
+ return self.interp(dst, src, self.nearest)
+
+ def reduce(self, output, dataloaders, prefixes, verbose=True):
+ """Reduce function"""
+ reduced_data = self.reduce_metrics(output, dataloaders)
+ metrics_dict = self.create_metrics_dict(reduced_data, prefixes)
+ if verbose:
+ self.print(reduced_data, prefixes)
+ return metrics_dict
+
+ def create_metrics_dict(self, reduced_data, prefixes):
+ """Create metrics dictionary"""
+ metrics_dict = {}
+ # For all datasets
+ for n, metrics in enumerate(reduced_data):
+ if metrics: # If there are calculated metrics
+ self.populate_metrics_dict(metrics, metrics_dict, prefixes[n])
+ # Return metrics dictionary
+ return metrics_dict
+
+ def reduce_metrics(self, dataset_outputs, datasets, ontology=None, strict=True):
+ """Reduce metrics"""
+ # If there is only one dataset, wrap in a list
+ if isinstance(dataset_outputs[0], dict):
+ dataset_outputs = [dataset_outputs]
+ # List storing metrics for all datasets
+ all_metrics_dict = []
+ # Loop over all datasets and all batches
+ for batch_outputs, dataset in zip(dataset_outputs, datasets):
+ # Initialize metrics dictionary
+ metrics_dict = OrderedDict()
+ # Get length, names and dimensions
+ length = len(dataset)
+ names = [key for key in list(batch_outputs[0].keys()) if self.check_name(key)]
+ dims = [tuple(batch_outputs[0][name].size()) for name in names]
+ # Get data device
+ device = batch_outputs[0]['idx'].device
+ # Count how many times each sample was seen
+ if strict:
+ seen = torch.zeros(length, device=device)
+ for output in batch_outputs:
+ seen[output['idx']] += 1
+ seen = reduce_value(seen, average=False, name='idx')
+ assert not np.any(seen.cpu().numpy() == 0), \
+ 'Not all samples were seen during evaluation'
+ # Reduce relevant metrics
+ for name, dim in zip(names, dims):
+ metrics = torch.zeros([length] + list(dim), device=device)
+
+ # Count how many times each sample was seen
+ if not strict:
+ seen = torch.zeros(length, device=device)
+ for output in batch_outputs:
+ if name in output:
+ seen[output['idx']] += 1
+ seen = reduce_value(seen, average=False, name='idx')
+
+ for output in batch_outputs:
+ if name in output:
+ metrics[output['idx']] = output[name]
+ metrics = reduce_value(metrics, average=False, name=name)
+ metrics_dict[name] = self.reduce_fn(metrics, seen)
+ # Append metrics dictionary to the list
+ all_metrics_dict.append(metrics_dict)
+ # Return list of metrics dictionary
+ return all_metrics_dict
+
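+# Illustrative sketch (not part of the original code): conceptually, reduce_metrics
+# accumulates one row per dataset sample (indexed by the batch 'idx' entry), sums the
+# per-sample rows and the 'seen' counters across distributed workers via reduce_value,
+# and lets the subclass reduce_fn turn them into a single vector:
+#
+# metrics = torch.zeros(len(dataset), num_metrics) # one row per sample
+# seen = torch.zeros(len(dataset)) # how often each sample was evaluated
+# for output in batch_outputs:
+# metrics[output['idx']] = output[name]
+# seen[output['idx']] += 1
+# reduced = self.reduce_fn(metrics, seen) # e.g. average over samples with seen > 0
+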
diff --git a/vidar/metrics/depth.py b/vidar/metrics/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a8013d167719bc1f0587b1c24b9388ad1ce544
--- /dev/null
+++ b/vidar/metrics/depth.py
@@ -0,0 +1,223 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.metrics.base import BaseEvaluation
+from vidar.metrics.utils import create_crop_mask, scale_output
+from vidar.utils.config import cfg_has
+from vidar.utils.data import dict_remove_nones
+from vidar.utils.depth import post_process_depth
+from vidar.utils.distributed import on_rank_0
+from vidar.utils.logging import pcolor
+from vidar.utils.types import is_dict
+
+
+class DepthEvaluation(BaseEvaluation):
+ """
+ Depth evaluation metrics
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg,
+ name='depth', task='depth',
+ metrics=('abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'silog', 'a1', 'a2', 'a3'),
+ )
+
+ self.min_depth = cfg.min_depth
+ self.max_depth = cfg.max_depth
+ self.crop = cfg_has(cfg, 'crop', '')
+ self.scale_output = cfg_has(cfg, 'scale_output', 'resize')
+
+ self.post_process = cfg_has(cfg, 'post_process', False)
+ self.median_scaling = cfg_has(cfg, 'median_scaling', False)
+ self.valid_threshold = cfg.has('valid_threshold', None)
+
+ if self.post_process:
+ self.modes += ['pp']
+ if self.median_scaling:
+ self.modes += ['gt']
+ if self.post_process and self.median_scaling:
+ self.modes += ['pp_gt']
+
+ @staticmethod
+ def reduce_fn(metrics, seen):
+ """Reduce function"""
+ valid = seen.view(-1) > 0
+ return (metrics[valid] / seen.view(-1, 1)[valid]).mean(0)
+
+ def populate_metrics_dict(self, metrics, metrics_dict, prefix):
+ """Populate metrics function"""
+ for metric in metrics:
+ if metric.startswith(self.name):
+ name, suffix = metric.split('|')
+ for i, key in enumerate(self.metrics):
+ metrics_dict[f'{prefix}-{name}|{key}_{suffix}'] = \
+ metrics[metric][i].item()
+
+ @on_rank_0
+ def print(self, reduced_data, prefixes):
+ """Print function"""
+ print()
+ print(self.horz_line)
+ print(self.metr_line.format(*((self.name.upper(),) + self.metrics)))
+ for n, metrics in enumerate(reduced_data):
+ if sum([self.name in key for key in metrics.keys()]) == 0:
+ continue
+ print(self.horz_line)
+ print(self.wrap(pcolor('*** {:<114}'.format(prefixes[n]), **self.font1)))
+ print(self.horz_line)
+ for key, metric in sorted(metrics.items()):
+ if self.name in key:
+ print(self.wrap(pcolor(self.outp_line.format(
+ *((key.upper(),) + tuple(metric.tolist()))), **self.font2)))
+ print(self.horz_line)
+ print()
+
+ def compute(self, gt, pred, use_gt_scale=True, mask=None):
+ """
+ Compute depth metrics
+
+ Parameters
+ ----------
+ gt : torch.Tensor
+ Ground-truth depth maps [B,1,H,W]
+ pred : torch.Tensor
+ Predicted depth map [B,1,H,W]
+ use_gt_scale : Bool
+ Use median-scaling
+ mask : torch.Tensor or None
+ Mask to remove pixels from evaluation
+
+ Returns
+ -------
+ metrics : torch.Tensor
+ Depth metrics
+ """
+ # Match predicted depth map to ground-truth resolution
+ pred = scale_output(pred, gt, self.scale_output)
+ # Create crop mask if requested
+ crop_mask = create_crop_mask(self.crop, gt)
+ # For each batch sample
+ metrics = []
+ for i, (pred_i, gt_i) in enumerate(zip(pred, gt)):
+
+ # Squeeze GT and PRED
+ gt_i, pred_i = torch.squeeze(gt_i), torch.squeeze(pred_i)
+ mask_i = torch.squeeze(mask[i]) if mask is not None else None
+
+ # Keep valid pixels (min/max depth and crop)
+ valid = (gt_i > self.min_depth) & (gt_i < self.max_depth)
+ # Remove invalid predicted pixels as well
+ valid = valid & (pred_i > 0)
+ # Apply crop mask if requested
+ valid = valid & crop_mask.bool() if crop_mask is not None else valid
+ # Apply provided mask if available
+ valid = valid & mask_i.bool() if mask is not None else valid
+
+ # Invalid evaluation
+ if self.valid_threshold is not None and valid.sum() < self.valid_threshold:
+ return None
+
+ # Keep only valid pixels
+ gt_i, pred_i = gt_i[valid], pred_i[valid]
+ # GT median scaling if needed
+ if use_gt_scale:
+ pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i)
+ # Clamp PRED depth values to min/max values
+ pred_i = pred_i.clamp(self.min_depth, self.max_depth)
+
+ # Calculate depth metrics
+
+ thresh = torch.max((gt_i / pred_i), (pred_i / gt_i))
+ a1 = (thresh < 1.25).float().mean()
+ a2 = (thresh < 1.25 ** 2).float().mean()
+ a3 = (thresh < 1.25 ** 3).float().mean()
+
+ diff_i = gt_i - pred_i
+ abs_rel = torch.mean(torch.abs(diff_i) / gt_i)
+ sq_rel = torch.mean(diff_i ** 2 / gt_i)
+ rmse = torch.sqrt(torch.mean(diff_i ** 2))
+ rmse_log = torch.sqrt(torch.mean((torch.log(gt_i) - torch.log(pred_i)) ** 2))
+
+ err = torch.log(pred_i) - torch.log(gt_i)
+ silog = torch.sqrt(torch.mean(err ** 2) - torch.mean(err) ** 2) * 100
+
+ metrics.append([abs_rel, sq_rel, rmse, rmse_log, silog, a1, a2, a3])
+
+ # Return metrics
+ return torch.tensor(metrics, dtype=gt.dtype)
+
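+ # Illustrative sketch (not part of the original code): a tiny numeric check of the
+ # thresholded-accuracy and abs_rel definitions used in compute (values are assumptions).
+ #
+ # gt = torch.tensor([2.0, 4.0])
+ # pred = torch.tensor([2.2, 3.0])
+ # thresh = torch.max(gt / pred, pred / gt) # [1.10, 1.33]
+ # a1 = (thresh < 1.25).float().mean() # 0.5 (only the first pixel passes)
+ # abs_rel = (torch.abs(gt - pred) / gt).mean() # (0.1 + 0.25) / 2 = 0.175
+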
+ def evaluate(self, batch, output, flipped_output=None):
+ """
+ Evaluate predictions
+
+ Parameters
+ ----------
+ batch : Dict
+ Dictionary containing ground-truth information
+ output : Dict
+ Dictionary containing predictions
+ flipped_output : Dict or None
+ Dictionary of predictions from a horizontally flipped input, used for post-processing
+
+ Returns
+ -------
+ metrics : Dict
+ Dictionary with calculated metrics
+ predictions : Dict
+ Dictionary with additional predictions
+ """
+ metrics, predictions = {}, {}
+ if self.name not in batch:
+ return metrics, predictions
+ # For each output item
+ for key, val in output.items():
+ # If it corresponds to this task
+ if key.startswith(self.name) and 'debug' not in key:
+ # Loop over every context
+ val = val if is_dict(val) else {0: val}
+ for ctx in val.keys():
+ # Loop over every scale
+ for i in range(1 if self.only_first else len(val[ctx])):
+
+ pred = val[ctx][i]
+ gt = batch[self.name][ctx]
+
+ if self.post_process:
+ pred_flipped = flipped_output[key][ctx][i]
+ pred_pp = post_process_depth(pred, pred_flipped, method='mean')
+ else:
+ pred_pp = None
+
+ if i > 0:
+ pred = self.interp_nearest(pred, val[ctx][0])
+ if self.post_process:
+ pred_pp = self.interp_nearest(pred_pp, val[ctx][0])
+
+ if pred.dim() == 4:
+ suffix = '(%s)' % str(ctx) + ('_%d' % i if not self.only_first else '')
+ for mode in self.modes:
+ metrics[f'{key}|{mode}{suffix}'] = \
+ self.compute(
+ gt=gt,
+ pred=pred_pp if 'pp' in mode else pred,
+ use_gt_scale='gt' in mode,
+ mask=None,
+ )
+ elif pred.dim() == 5:
+ for j in range(pred.shape[1]):
+ suffix = '(%s_%d)' % (str(ctx), j) + ('_%d' % i if not self.only_first else '')
+ for mode in self.modes:
+ metrics[f'{key}|{mode}{suffix}'] = self.compute(
+ gt=gt[:, j],
+ pred=pred_pp[:, j] if 'pp' in mode else pred[:, j],
+ use_gt_scale='gt' in mode,
+ mask=None,
+ )
+
+ return dict_remove_nones(metrics), predictions
+
diff --git a/vidar/metrics/utils.py b/vidar/metrics/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..91a632391be37cccd5304fbeadf74f955b0d245f
--- /dev/null
+++ b/vidar/metrics/utils.py
@@ -0,0 +1,85 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+
+from vidar.utils.tensor import interpolate_image
+
+
+def scale_output(pred, gt, scale_fn):
+ """
+ Match depth maps to ground-truth resolution
+
+ Parameters
+ ----------
+ pred : torch.Tensor
+ Predicted depth maps [B,1,h,w]
+ gt : torch.tensor
+ Ground-truth depth maps [B,1,H,W]
+ scale_fn : String
+ How to scale output to GT resolution
+ Resize: Bilinear interpolation to the ground-truth resolution
+ top-center: Place the prediction at the bottom center and zero-pad the top and side borders
+
+ Returns
+ -------
+ pred : torch.tensor
+ Uncropped predicted depth maps [B,1,H,W]
+ """
+ if pred.dim() == 5 and gt.dim() == 5:
+ return torch.stack([scale_output(pred[:, i], gt[:, i], scale_fn) for i in range(pred.shape[1])], 1)
+ # Return depth map if scaling is not required
+ if scale_fn == 'none':
+ return pred
+ elif scale_fn == 'resize':
+ # Resize depth map to GT resolution
+ return interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)
+ else:
+ # Create empty depth map with GT resolution
+ pred_uncropped = torch.zeros(gt.shape, dtype=pred.dtype, device=pred.device)
+ # Uncrop top vertically and center horizontally
+ if scale_fn == 'top-center':
+ top, left = gt.shape[2] - pred.shape[2], (gt.shape[3] - pred.shape[3]) // 2
+ pred_uncropped[:, :, top:(top + pred.shape[2]), left:(left + pred.shape[3])] = pred
+ else:
+ raise NotImplementedError('Depth scale function {} not implemented.'.format(scale_fn))
+ # Return uncropped depth map
+ return pred_uncropped
+
+
+def create_crop_mask(crop, gt):
+ """
+ Create crop mask for evaluation
+
+ Parameters
+ ----------
+ crop : String
+ Type of crop
+ gt : torch.Tensor
+ Ground-truth depth map (for dimensions)
+
+ Returns
+ -------
+ crop_mask: torch.Tensor
+ Mask for evaluation
+ """
+ # Return None if mask is not required
+ if crop in ('', None):
+ return None
+ # Create empty mask
+ batch_size, _, gt_height, gt_width = gt.shape
+ crop_mask = torch.zeros(gt.shape[-2:]).byte().type_as(gt)
+ # Get specific mask
+ if crop == 'eigen_nyu':
+ crop_mask[20:459, 24:615] = 1
+ elif crop == 'bts_nyu':
+ crop_mask[45:471, 41:601] = 1
+ elif crop == 'garg':
+ y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
+ x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
+ crop_mask[y1:y2, x1:x2] = 1
+ elif crop == 'eigen':
+ y1, y2 = int(0.3324324 * gt_height), int(0.91351351 * gt_height)
+ x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
+ crop_mask[y1:y2, x1:x2] = 1
+ # Return crop mask
+ return crop_mask
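+
+
+# Illustrative sketch (not part of the original code): the 'garg' and 'eigen' crops are
+# defined as fixed fractions of the ground-truth resolution, so the same setting works
+# across image sizes. For a hypothetical KITTI-sized [1,1,375,1242] depth map:
+#
+# gt = torch.zeros(1, 1, 375, 1242)
+# mask = create_crop_mask('garg', gt) # ones inside rows 153:371 and columns 44:1197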
diff --git a/vidar/utils/config.py b/vidar/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c8d1b97e3e007db5bfea85b051c12ff3b9dfc7
--- /dev/null
+++ b/vidar/utils/config.py
@@ -0,0 +1,468 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import importlib
+import os
+from argparse import Namespace
+
+import torch
+import yaml
+
+from vidar.utils.data import make_list, num_trainable_params
+from vidar.utils.distributed import print0
+from vidar.utils.logging import pcolor
+from vidar.utils.networks import load_checkpoint
+from vidar.utils.types import is_dict, is_list, is_namespace
+
+
+def cfg_has(*args):
+ """
+ Check if a key is in configuration
+
+ Parameters
+ ----------
+ args : Tuple (Config, String, Value)
+ Inputs:
+ length 2 = configuration/name,
+ length 3 = configuration/name/default
+
+ Returns
+ -------
+ Flag : Bool or Value
+ True/False if key is in configuration, key value/default if default is provided
+ """
+ if len(args) == 2:
+ cfg, name = args
+ if not is_list(name):
+ return name in cfg.__dict__.keys()
+ else:
+ return all([n in cfg.__dict__.keys() for n in name])
+ elif len(args) == 3:
+ cfg, name, default = args
+ has = name in cfg.__dict__.keys()
+ return cfg.__dict__[name] if has else default
+ else:
+ raise ValueError('Wrong number of arguments for cfg_has')
+
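+# Illustrative sketch (not part of the original code): the two- and three-argument
+# forms of cfg_has. The configuration below is an assumption.
+#
+# cfg = Config(min_depth=0.1, max_depth=80.0)
+# cfg_has(cfg, 'min_depth') # True
+# cfg_has(cfg, ['min_depth', 'crop']) # False (all listed keys must be present)
+# cfg_has(cfg, 'crop', '') # '' (default returned when the key is missing)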
+
+def cfg_add_to_dict(dic, cfg, key, i=None):
+ """
+ Add configuration key to dictionary
+
+ Parameters
+ ----------
+ dic : Dict
+ Input dictionary
+ cfg : Config
+ Input configuration
+ key : String
+ Input key
+ i : Int
+ Optional list index
+ """
+ if cfg_has(cfg, key):
+ dic[key] = cfg.__dict__[key] if i is None \
+ else cfg.__dict__[key][0] if len(cfg.__dict__[key]) == 1 \
+ else cfg.__dict__[key][i]
+
+
+def cfg_from_dict(dic):
+ """
+ Create configuration from dictionary
+
+ Parameters
+ ----------
+ dic : Dict
+ Input dictionary
+
+ Returns
+ -------
+ cfg : Config
+ Output configuration
+ """
+ for key, val in dic.items():
+ if is_dict(val):
+ dic[key] = cfg_from_dict(val)
+ return Config(**dic)
+
+
+def update_cfg(cfg):
+ """
+ Update configuration with hard-coded information
+
+ Parameters
+ ----------
+ cfg : Config
+ Input configuration
+
+ Returns
+ -------
+ cfg : Config
+ Updated configuration
+ """
+ if not torch.cuda.is_available():
+ cfg.setup.grad_scaler = False
+ return cfg
+
+
+def to_namespace(data):
+ """
+ Convert dictionary to namespace
+
+ Parameters
+ ----------
+ data : Dict or Config
+ Input dictionary
+
+ Returns
+ -------
+ cfg : Config
+ Output configuration
+ """
+ for key in data.keys():
+ if is_dict(data[key]):
+ data[key] = to_namespace(data[key])
+ return Config(**data)
+
+
+def merge_dict(default, config):
+ """
+ Merge two dictionaries
+
+ Parameters
+ ----------
+ default : Dict
+ Dictionary with default values
+ config : Dict
+ Dictionary with values to update
+
+ Returns
+ -------
+ cfg : Dict
+ Updated dictionary
+ """
+ if is_namespace(default):
+ default = default.__dict__
+ for key in config.keys():
+ if key not in default.keys():
+ default[key] = {}
+ if not is_dict(config[key]):
+ default[key] = config[key]
+ else:
+ default[key] = merge_dict(default[key], config[key])
+ return default
+
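+# Illustrative sketch (not part of the original code): values from 'config' overwrite
+# the corresponding entries in 'default', recursing into nested dictionaries.
+# The dictionaries below are assumptions.
+#
+# default = {'optimizer': {'name': 'Adam', 'lr': 1e-4}}
+# config = {'optimizer': {'lr': 2e-4}, 'scheduler': {'name': 'StepLR'}}
+# merge_dict(default, config)
+# # -> {'optimizer': {'name': 'Adam', 'lr': 2e-4}, 'scheduler': {'name': 'StepLR'}}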
+
+def update_from_kwargs(cfg, **kwargs):
+ """
+ Update configuration based on keyword arguments
+
+ Parameters
+ ----------
+ cfg : Config
+ Input configuration
+ kwargs : Dict
+ Keyword arguments
+
+ Returns
+ -------
+ cfg : Config
+ Updated configuration
+ """
+ if kwargs is not None:
+ for key, val in kwargs.items():
+ key_split = key.split('.')
+ dic = cfg.__dict__
+ for k in key_split[:-1]:
+ dic = dic[k].__dict__
+ dic[key_split[-1]] = val
+ return cfg
+
+
+def recursive_recipe(cfg, super_key=None):
+ """
+ Add recipe parameters to configuration
+
+ Parameters
+ ----------
+ cfg : Config
+ Input configuration
+ super_key : String
+ Which recipe entry to use
+
+ Returns
+ -------
+ cfg : Config
+ Updated configuration
+ """
+ for key in list(cfg.keys()):
+ if is_dict(cfg[key]):
+ cfg[key] = recursive_recipe(cfg[key], super_key=key)
+ elif key == 'recipe':
+ recipe = 'configs/recipes/' + cfg.pop(key)
+ if '|' in recipe:
+ recipe, super_key = recipe.split('|')
+ recipe = read_config(recipe + '.yaml')
+ while '.' in super_key:
+ split = super_key.split('.')
+ recipe = recipe.__dict__[split[0]]
+ super_key = '.'.join(split[1:])
+ recipe = recipe.__dict__[super_key].__dict__
+ cfg = merge_dict(recipe, cfg)
+ return cfg
+
+
+def read_config(path, **kwargs):
+ """
+ Create configuration from file
+
+ Parameters
+ ----------
+ path : String
+ Configuration path
+ kwargs : Dict
+ Keyword arguments to update configuration
+
+ Returns
+ -------
+ cfg : Config
+ Output configuration
+ """
+ """Read configuration from file"""
+ with open(path) as cfg:
+ config = yaml.load(cfg, Loader=yaml.FullLoader)
+ config = recursive_recipe(config)
+ cfg = to_namespace(config)
+ if kwargs is not None:
+ cfg = update_from_kwargs(cfg, **kwargs)
+ return cfg
+
+
+def is_recursive(val):
+ """
+ Check if configuration entry is recursive
+
+ Parameters
+ ----------
+ val : Config
+ Input Configuration
+
+ Returns
+ -------
+ Flag : Bool
+ True/False if is recursive or not
+ """
+ return 'file' in val.__dict__.keys()
+
+
+def get_folder_name(path, mode, root='vidar/arch'):
+ """
+ Get folder and name from configuration path
+
+ Parameters
+ ----------
+ path : String
+ Input path
+ mode : String
+ Which mode to use (e.g., models, networks, losses)
+ root : String
+ Which folder to use
+
+ Returns
+ -------
+ folder : String
+ Output folder
+ name : String
+ Output name
+ """
+ """Get folder and name from configuration path"""
+ folder, name = os.path.dirname(path), os.path.basename(path)
+ folder = os.path.join(root, mode, folder)
+ if folder.endswith('/'):
+ folder = folder[:-1]
+ return folder, name
+
+
+def recursive_assignment(model, cfg, mode, verbose=True):
+ """
+ Recursively assign information from a configuration
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ Which network we are using
+ cfg : Config
+ Input Configuration
+ mode : String
+ Which mode we are using (e.g., models, networks, losses)
+ verbose : Bool
+ Print information on screen
+ """
+ font = {'color': 'yellow', 'attrs': ('dark',)}
+ for key, val in cfg.__dict__.items():
+ cls = cfg.__dict__[key]
+ if is_namespace(cls):
+ if is_recursive(val):
+ folder, name = get_folder_name(val.file, mode)
+ getattr(model, mode)[key] = load_class(name, folder)(cls)
+ if verbose:
+ string = '######### {}'.format(getattr(model, mode)[key].__class__.__name__)
+ num_params = num_trainable_params(getattr(model, mode)[key])
+ if num_params > 0:
+ string += f' ({num_params:,} parameters)'
+ print0(pcolor(string, **font))
+ if cfg_has(val, 'checkpoint'):
+ model_attr = getattr(model, mode)[key]
+ load_checkpoint(model_attr, val.checkpoint, strict=False, verbose=verbose, prefix=key)
+ recursive_assignment(getattr(model, mode)[key], cls, mode, verbose=verbose)
+ if key == 'blocks':
+ for key2, val2 in cfg.__dict__[key].__dict__.items():
+ cls2 = cfg.__dict__[key].__dict__[key2]
+ if is_recursive(val2):
+ folder, name = get_folder_name(val2.file, 'blocks')
+ model.blocks[key2] = load_class(name, folder)(cls2)
+ recursive_assignment(model.blocks[key2], cls2, 'blocks', verbose=verbose)
+
+
+def load_class(filename, paths, concat=True, methodname=None):
+ """
+ Look for a file in different locations and return its method with the same name
+ Optionally, you can use concat to search in path.filename instead
+
+ Parameters
+ ----------
+ filename : String
+ Name of the file we are searching for
+ paths : String or list[String]
+ Folders in which the file will be searched
+ concat : Bool
+ Flag to concatenate filename to each path during the search
+ methodname : String or list[String]
+ Method name (if None, use filename;
+ if it's a String, use it as the method name;
+ if it's a list, return the first method name found)
+
+ Returns
+ -------
+ method : Function
+ Loaded method
+ """
+ # If method name is not given, use filename
+ methodname = make_list(filename if methodname is None else methodname)
+ # for each path in paths
+ for path in make_list(paths):
+ # Create full path
+ path = path.replace('/', '.')
+ full_path = '{}.{}'.format(path, filename) if concat else path
+ # Get module
+ module = importlib.import_module(full_path)
+ # Try all method names
+ for name in methodname:
+ method = getattr(module, name, None)
+ # Return if found
+ if method is not None:
+ return method
+ # Didn't find anything
+ raise ValueError('Unknown class {}'.format(filename))
+
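+# Illustrative sketch (not part of the original code): the call below would import
+# vidar.arch.networks.depth.MonoDepthNet and return the attribute with the same name
+# as the file. The module and class names are hypothetical and only illustrate the
+# lookup rule used by recursive_assignment.
+#
+# cls = load_class('MonoDepthNet', 'vidar/arch/networks/depth')
+# net = cls(cfg) # instantiated with a Config, as in recursive_assignment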
+
+def get_from_cfg_list(cfg, key, idx):
+ """
+ Get configuration value from a list
+
+ Parameters
+ ----------
+ cfg : Config
+ Input configuration
+ key : String
+ Input configuration key
+ idx : Int
+ List index
+
+ Returns
+ -------
+ data : Value
+ Key value at that index if it's a list, otherwise return the key value directly
+ """
+ if key not in cfg.__dict__.keys():
+ return None
+ data = cfg.__dict__[key]
+ return data if not is_list(data) else data[idx] if len(data) > 1 else data[0]
+
+
+def dataset_prefix(cfg, idx):
+ """
+ Create dataset prefix based on configuration information
+
+ Parameters
+ ----------
+ cfg : Config
+ Input configuration
+ idx : Int
+ Input index for information retrieval
+
+ Returns
+ -------
+ prefix : String
+ Dataset prefix
+ """
+ # Dataset path is always available
+ # prefix = cfg.name[idx]
+ prefix = '{}'.format(os.path.splitext(get_from_cfg_list(cfg, 'path', idx).split('/')[-1])[0])
+ # If split is available
+ val = get_from_cfg_list(cfg, 'split', idx)
+ if val is not None:
+ prefix += '-{}'.format(os.path.splitext(os.path.basename(val))[0])
+ # If input depth type is available
+ val = get_from_cfg_list(cfg, 'input_depth_type', idx)
+ if val is not None and val not in [None, '']:
+ prefix += '-+{}'.format(val)
+ # If depth type is available
+ val = get_from_cfg_list(cfg, 'depth_type', idx)
+ if val is not None and val not in [None, '']:
+ prefix += '-{}'.format(val)
+ # If there is camera information
+ val = get_from_cfg_list(cfg, 'cameras', idx)
+ if val is not None and is_list(val) and len(val) > 0:
+ prefix += '-cam{}'.format(val[0])
+ # Return prefix
+ return prefix
+
+
+class Config(Namespace):
+ """
+ Configuration class for passing arguments between other classes
+
+ Parameters
+ ----------
+ kwargs: Dict
+ Arguments to create configuration
+ """
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def from_file(file):
+ """Read configuration from file"""
+ return read_config(file)
+
+ @property
+ def dict(self):
+ """Return configuration as dictionary"""
+ return self.__dict__
+
+ def keys(self):
+ """Return dictionary keys of configuration"""
+ return self.dict.keys()
+
+ def items(self):
+ """Return dictionary items of configuration"""
+ return self.dict.items()
+
+ def values(self):
+ """Return dictionary values of configuration"""
+ return self.dict.values()
+
+ def has(self, *args):
+ """Check if configuration has certain parameters"""
+ return cfg_has(self, *args)
+
diff --git a/vidar/utils/data.py b/vidar/utils/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..540bb918f9176a65dcd39d3252fdf9698cf1e6cc
--- /dev/null
+++ b/vidar/utils/data.py
@@ -0,0 +1,245 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import random
+from collections import OrderedDict
+from inspect import signature
+
+import numpy as np
+import torch
+
+from vidar.utils.decorators import iterate1, iterate2
+from vidar.utils.types import is_list, is_double_list, is_tuple, is_tensor, is_dict, is_seq
+
+KEYS_IMAGE = [
+ 'rgb', 'mask',
+ 'input_depth', 'depth',
+ 'bwd_optical_flow', 'fwd_optical_flow',
+ ]
+
+KEYS_MATRIX = [
+ 'intrinsics', 'extrinsics', 'pose', 'semantic',
+ ]
+
+
+def modrem(v, n):
+ """Return round division and remainder"""
+ return v // n, v % n
+
+
+def flatten(lst):
+ """Flatten a list of lists into a list"""
+ return [l for ls in lst for l in ls] if is_double_list(lst) else lst
+
+
+def keys_with(dic, string, without=()):
+ """Return keys from a dictionary that contain a certain string"""
+ return [key for key in dic if string in key and not any(w in key for w in make_list(without))]
+
+
+def keys_startswith(dic, string):
+ """Return keys from a dictionary that contain a certain string"""
+ return [key for key in dic if key.startswith(string)]
+
+
+def keys_in(dic, keys):
+ """Return only keys in a dictionary"""
+ return [key for key in keys if key in dic]
+
+
+def str_not_in(string, keys):
+ """Return True if the string does not contain any of the given keys"""
+ for key in keys:
+ if key in string:
+ return False
+ return True
+
+
+def make_list(var, n=None):
+ """Wraps the input into a list, and optionally repeats it to be size n"""
+ var = var if is_list(var) or is_tuple(var) else [var]
+ if n is None:
+ return var
+ else:
+ assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
+ return var * n if len(var) == 1 else var
+
+
+def filter_args(func, keys):
+ """Filters a dictionary, so it only contains keys that are arguments of a function"""
+ filtered = {}
+ sign = list(signature(func).parameters.keys())
+ for k, v in {**keys}.items():
+ if k in sign:
+ filtered[k] = v
+ return filtered
+
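+# Illustrative sketch (not part of the original code): keep only the dictionary entries
+# that match a function signature. The function and arguments below are assumptions.
+#
+# def resize(image, size, mode='bilinear'):
+# ...
+#
+# kwargs = {'size': (192, 640), 'mode': 'nearest', 'unused_flag': True}
+# filter_args(resize, kwargs) # {'size': (192, 640), 'mode': 'nearest'}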
+
+def dict_remove_nones(dic):
+ """Filters dictionary to remove keys with None values"""
+ return {key: val for key, val in dic.items() if val is not None}
+
+
+@iterate1
+def matmul1(v1, v2):
+ """Iteratively multiply tensors"""
+ return v1 @ v2
+
+
+@iterate2
+def matmul2(v1, v2):
+ """Iteratively multiply tensors"""
+ return v1 @ v2
+
+
+@iterate1
+def unsqueeze(x):
+ """Iteratively unsqueeze tensors to batch size 1"""
+ return x.unsqueeze(0) if is_tensor(x) else x
+
+
+@iterate1
+def fold(data, n):
+ """Iteratively folds first and second dimensions into one"""
+ shape = list(data.shape)
+ if len(shape) == n + 1:
+ shape = [shape[0] * shape[1]] + shape[2:]
+ return data.view(*shape)
+ else:
+ return data
+
+
+@iterate1
+def expand(data, n, d):
+ """Iteratively folds first and second dimensions into one"""
+ shape = list(data.shape)
+ if len(shape) == n:
+ return data.unsqueeze(d)
+ else:
+ return data
+
+
+def fold_batch(batch, device=None):
+ """Combine the first (batch) and second (camera) dimensions of a batch"""
+ if is_seq(batch):
+ return [fold_batch(b, device=device) for b in batch]
+ for key in keys_in(batch, KEYS_IMAGE):
+ batch[key] = fold(batch[key], 4)
+ for key in keys_in(batch, KEYS_MATRIX):
+ batch[key] = fold(batch[key], 3)
+ if device is not None:
+ batch = batch_to_device(batch, device)
+ return batch
+
+
+def expand_batch(batch, d, device=None):
+ """Expand the batch to include an additional dimension (0 for batch, 1 for camera)"""
+ if is_seq(batch):
+ return [expand_batch(b, d, device=device) for b in batch]
+ d = {'batch': 0, 'camera': 1}[d]
+ for key in keys_in(batch, KEYS_IMAGE):
+ batch[key] = expand(batch[key], 4, d)
+ for key in keys_in(batch, KEYS_MATRIX):
+ batch[key] = expand(batch[key], 3, d)
+ if device is not None:
+ batch = batch_to_device(batch, device)
+ return batch
+
+
+def batch_to_device(batch, device):
+ """Copy batch information to device"""
+ if is_dict(batch):
+ return {key: batch_to_device(val, device) for key, val in batch.items()}
+ if is_list(batch):
+ return [batch_to_device(val, device) for val in batch]
+ if is_tensor(batch):
+ return batch.to(device)
+ return batch
+
+
+def num_trainable_params(model):
+ """Return number of trainable parameters"""
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def set_random_seed(seed):
+ """Set random seeds for Python, NumPy and PyTorch (no-op for negative seeds)"""
+ if seed >= 0:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def make_batch(batch, device=None):
+ """Transforms a sample into a batch"""
+ for key in batch.keys():
+ if is_dict(batch[key]):
+ batch[key] = make_batch(batch[key])
+ elif is_tensor(batch[key]):
+ batch[key] = batch[key].unsqueeze(0)
+ if device is not None:
+ batch = batch_to_device(batch, device)
+ return batch
+
+
+def break_key(sample, n=None):
+ """Break a multi-camera sample key, so different cameras have their own entries (context, camera)"""
+ if sample is None:
+ return sample
+ new_sample = OrderedDict()
+ for ctx in sample.keys():
+ if is_dict(sample[ctx]):
+ for key2, val in sample[ctx].items():
+ if val.dim() == 1:
+ val = val.unsqueeze(1)
+ for i in range(val.shape[1]):
+ if (ctx, i) not in new_sample.keys():
+ new_sample[(ctx, i)] = {}
+ new_sample[(ctx, i)][key2] = val[:, [i]]
+ elif sample[ctx].dim() == n + 1:
+ for i in range(sample[ctx].shape[1]):
+ new_sample[(ctx, i)] = sample[ctx][:, i]
+ return new_sample
+
+
+def break_batch(batch):
+ """Break a multi-camera batch, so different cameras have their own entries (context, camera)"""
+ for key in keys_in(batch, KEYS_IMAGE):
+ for ctx in list(batch[key].keys()):
+ if batch[key][ctx].dim() == 5:
+ for n in range(batch[key][ctx].shape[1]):
+ batch[key][(ctx,n)] = batch[key][ctx][:, n]
+ batch[key].pop(ctx)
+ for key in keys_in(batch, KEYS_MATRIX):
+ for ctx in list(batch[key].keys()):
+ if batch[key][ctx].dim() == 4:
+ for n in range(batch[key][ctx].shape[1]):
+ batch[key][(ctx,n)] = batch[key][ctx][:, n]
+ batch[key].pop(ctx)
+ return batch
+
+
+def dict_has(dic, key):
+ """Check if a dictionary has a certain key"""
+ return key in dic
+
+
+def get_from_dict(dic, key):
+ """Get value from a dictionary (return None if key is not in dictionary)"""
+ return None if key not in dic else dic[key]
+
+
+def get_mask_from_list(mask, i, return_ones=None):
+ """Retrieve mask from a list (if it's not a list, return the mask itself, and create one if requested)"""
+ if return_ones is None:
+ return None if mask is None else mask[i] if is_list(mask) else mask
+ else:
+ mask = torch.ones_like(return_ones[i] if is_list(return_ones) else return_ones).bool() if mask is None \
+ else mask[i].clone().bool() if is_list(mask) else mask.clone().bool()
+ if mask.dim() == 4:
+ return mask[:, [0]]
+ elif mask.dim() == 3:
+ return mask[..., [0]]
+
+
+def get_from_list(lst, i):
+ """Get information from a list (return None if input is None, and return value directly if it's not a list)"""
+ return None if lst is None else lst[i] if is_seq(lst) else lst
diff --git a/vidar/utils/decorators.py b/vidar/utils/decorators.py
new file mode 100755
index 0000000000000000000000000000000000000000..38d11b5d34014241ece78df8308907ee90e2913f
--- /dev/null
+++ b/vidar/utils/decorators.py
@@ -0,0 +1,60 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from vidar.utils.types import is_seq, is_dict
+
+
+def iterate1(func):
+ """Decorator to iterate over a list (first argument)"""
+ def inner(var, *args, **kwargs):
+ if is_seq(var):
+ return [func(v, *args, **kwargs) for v in var]
+ elif is_dict(var):
+ return {key: func(val, *args, **kwargs) for key, val in var.items()}
+ else:
+ return func(var, *args, **kwargs)
+ return inner
+
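+# Illustrative sketch (not part of the original code): a decorated function is applied
+# element-wise to lists/tuples and value-wise to dictionaries, and behaves normally on
+# a single input. The function below is an assumption.
+#
+# @iterate1
+# def double(x):
+# return 2 * x
+#
+# double(3) # 6
+# double([1, 2, 3]) # [2, 4, 6]
+# double({'a': 1}) # {'a': 2}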
+
+def iterate2(func):
+ """Decorator to iterate over a list (second argument)"""
+ def inner(self, var, *args, **kwargs):
+ if is_seq(var):
+ return [func(self, v, *args, **kwargs) for v in var]
+ elif is_dict(var):
+ return {key: func(self, val, *args, **kwargs) for key, val in var.items()}
+ else:
+ return func(self, var, *args, **kwargs)
+ return inner
+
+
+def iterate12(func):
+ """Decorator to iterate over a list (first argument)"""
+ def inner(var1, var2, *args, **kwargs):
+ if is_seq(var1) and is_seq(var2):
+ return [func(v1, v2, *args, **kwargs) for v1, v2 in zip(var1, var2)]
+ elif is_dict(var1) and is_dict(var2):
+ return {key: func(val1, val2, *args, **kwargs)
+ for key, val1, val2 in zip(var1.keys(), var1.values(), var2.values())}
+ else:
+ return func(var1, var2, *args, **kwargs)
+ return inner
+
+
+def multi_write(func):
+ """Decorator to write multiple files"""
+ def inner(filename, data, **kwargs):
+ if is_seq(data):
+ for i in range(len(data)):
+ filename_i, ext = filename.split('.')
+ filename_i = '%s_%d.%s' % (filename_i, i, ext)
+ func(filename_i, data[i], **kwargs)
+ return
+ elif is_dict(data):
+ for key, val in data.items():
+ filename_i, ext = filename.split('.')
+ filename_i = '%s(%s).%s' % (filename_i, key, ext)
+ func(filename_i, val, **kwargs)
+ return
+ else:
+ return func(filename, data)
+ return inner
diff --git a/vidar/utils/depth.py b/vidar/utils/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..59493454bc15dbf5ddd6e3460f233444ddfab2f3
--- /dev/null
+++ b/vidar/utils/depth.py
@@ -0,0 +1,287 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as tfn
+
+from vidar.geometry.camera import Camera
+from vidar.utils.decorators import iterate1
+from vidar.utils.types import is_tensor, is_numpy
+
+
+@iterate1
+@iterate1
+def inv2depth(inv_depth):
+ """
+ Invert an inverse depth map to produce a depth map
+
+ Parameters
+ ----------
+ inv_depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array]
+ Inverse depth map [B,1,H,W]
+
+ Returns
+ -------
+ depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array]
+ Depth map [B,1,H,W]
+ """
+ if is_tensor(inv_depth):
+ depth = 1. / inv_depth.clamp(min=1e-6, max=None)
+ elif is_numpy(inv_depth):
+ depth = 1. / inv_depth.clip(min=1e-6, max=None)
+ else:
+ raise NotImplementedError('Wrong inverse depth format.')
+ depth[inv_depth <= 0.] = 0.
+ return depth
+
+
+@iterate1
+@iterate1
+def depth2inv(depth):
+ """
+ Invert a depth map to produce an inverse depth map
+
+ Parameters
+ ----------
+ depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array]
+ Depth map [B,1,H,W]
+
+ Returns
+ -------
+ inv_depth : torch.Tensor or list[torch.Tensor] or np.array or list[np.array]
+ Inverse depth map [B,1,H,W]
+
+ """
+ if is_tensor(depth):
+ inv_depth = 1. / depth.clamp(min=1e-6, max=None)
+ elif is_numpy(depth):
+ inv_depth = 1. / depth.clip(min=1e-6, max=None)
+ else:
+ raise NotImplementedError('Wrong depth format')
+ inv_depth[depth <= 0.] = 0.
+ return inv_depth
+
+
+def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
+ """
+ Fuse inverse depth and flipped inverse depth maps
+
+ Parameters
+ ----------
+ inv_depth : torch.Tensor
+ Inverse depth map [B,1,H,W]
+ inv_depth_hat : torch.Tensor
+ Flipped inverse depth map produced from a flipped image [B,1,H,W]
+ method : String
+ Method that will be used to fuse the inverse depth maps
+
+ Returns
+ -------
+ fused_inv_depth : torch.Tensor [B,1,H,W]
+ Fused inverse depth map
+ """
+ if method == 'mean':
+ return 0.5 * (inv_depth + inv_depth_hat)
+ elif method == 'max':
+ return torch.max(inv_depth, inv_depth_hat)
+ elif method == 'min':
+ return torch.min(inv_depth, inv_depth_hat)
+ else:
+ raise ValueError('Unknown post-process method {}'.format(method))
+
+
+def post_process_inv_depth(inv_depth, inv_depth_flipped, method='mean'):
+ """
+ Post-process an inverse and flipped inverse depth map
+
+ Parameters
+ ----------
+ inv_depth : torch.Tensor
+ Inverse depth map [B,1,H,W]
+ inv_depth_flipped : torch.Tensor
+ Inverse depth map produced from a flipped image [B,1,H,W]
+ method : String
+ Method that will be used to fuse the inverse depth maps
+
+ Returns
+ -------
+ inv_depth_pp : torch.Tensor
+ Post-processed inverse depth map [B,1,H,W]
+ """
+ from vidar.utils.flip import flip_lr
+ B, C, H, W = inv_depth.shape
+ inv_depth_hat = inv_depth_flipped
+ inv_depth_fused = fuse_inv_depth(inv_depth, inv_depth_hat, method=method)
+ xs = torch.linspace(0., 1., W, device=inv_depth.device,
+ dtype=inv_depth.dtype).repeat(B, C, H, 1)
+ mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
+ mask_hat = flip_lr(mask)
+ post_processed = mask_hat * inv_depth + mask * inv_depth_hat + \
+ (1.0 - mask - mask_hat) * inv_depth_fused
+ mask0, mask_hat0 = inv_depth == 0, inv_depth_hat == 0
+ post_processed[mask_hat0] = inv_depth[mask_hat0]
+ post_processed[mask0] = inv_depth_hat[mask0]
+ post_processed[mask0 & mask_hat0] = 0
+ return post_processed
+
+
+def post_process_depth(depth, depth_flipped, method='mean'):
+ """Post-process a depth map and flipped depth map"""
+ return inv2depth(post_process_inv_depth(
+ depth2inv(depth), depth2inv(depth_flipped), method=method))
+
+
+def calculate_normals(points, camera=None, intrinsics=None, pad_last=True):
+ """
+ Calculate normals from a pointcloud map or from a depth map + intrinsics
+
+ Parameters
+ ----------
+ points : torch.Tensor
+ A pointcloud map [B,3,H,W] containing 3D coordinates
+ A depth map [B,1,H,W] containing depth values
+ camera : Camera
+ Camera used for normal calculation, in case a depth map is provided
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3] necessary in case a depth map is provided, to create the pointcloud map
+ pad_last : Bool
+ If true, pad the last row and column with zeros
+
+ Returns
+ -------
+ normals : torch.Tensor
+ Normal map [B,3,H,W] containing normal estimates
+ """
+ if intrinsics is None and camera is None:
+ return points
+ # Create pointcloud map if intrinsics are provided
+ if camera is not None:
+ points = camera.reconstruct_depth_map(points)
+ elif intrinsics is not None:
+ points = Camera(K=intrinsics, hw=points).reconstruct_depth_map(points)
+ # Prepare points for cross-product
+ p0 = points[:, :, :-1, :-1]
+ p1 = points[:, :, 1:, :-1]
+ p2 = points[:, :, :-1, 1:]
+ # Calculate and normalize normals
+ normals = torch.cross(p1 - p0, p2 - p0, 1)
+ normals = normals / normals.norm(dim=1, keepdim=True)
+ # Pad normals
+ if pad_last:
+ normals = torch.nn.functional.pad(normals, [0, 1, 0, 1], mode='replicate')
+    # Substitute nan values with zero
+ normals[torch.isnan(normals)] = 0.0
+ # Return normals
+ return normals
+
+
+def calc_dot_prod(pts, nrm):
+ """
+ Calculate dot product of 3D points and their normals
+
+ Parameters
+ ----------
+ pts : torch.Tensor
+ Input 3D points [B,3,H,W]
+ nrm : torch.Tensor
+ Input 3D normal vectors [B,3,H,W]
+
+ Returns
+ -------
+ dots : torch.Tensor
+ Output dot product [B,1,H,W]
+ """
+ pts = pts / pts.norm(dim=1, keepdim=True)
+ nrm = nrm / nrm.norm(dim=1, keepdim=True)
+ dots = torch.sum(pts * nrm, dim=1, keepdim=True)
+ return dots
+
+
+def get_depth_bins(mode, min_val, max_val, num_vals):
+ """
+    Create discretized depth bins
+
+ Parameters
+ ----------
+ mode : String
+ Discretization mode
+ min_val : Float
+ Minimum depth value
+ max_val : Float
+ Maximum depth value
+ num_vals : Int
+ Number of intervals
+
+ Returns
+ -------
+ bins : torch.Tensor
+ Discretized depth values [num_vals]
+ """
+ if mode == 'inverse':
+        depth_bins = 1. / torch.linspace(
+            1. / max_val, 1. / min_val, num_vals).flip(0)
+ elif mode == 'linear':
+ depth_bins = torch.linspace(
+ min_val, max_val, num_vals)
+ elif mode == 'sid':
+        depth_bins = torch.tensor(
+            [min_val * (max_val / min_val) ** (i / (num_vals - 1))
+             for i in range(num_vals)])
+ else:
+ raise NotImplementedError
+ return depth_bins.float()
+
+
+def depth2index(depth, bins):
+ """
+ Convert a depth map to discretized indexes
+
+ Parameters
+ ----------
+ depth : torch.Tensor
+ Input depth map [B,1,H,W]
+ bins : torch.Tensor
+ Discretized depth bins [D]
+
+ Returns
+ -------
+ idx : torch.Tensor
+ Discretized depth indexes [B,1,H,W]
+ """
+ if depth.dim() == 2:
+ depth = tfn.relu(depth - bins.reshape(1, -1))
+ elif depth.dim() == 4:
+ depth = tfn.relu(depth - bins.reshape(1, -1, 1, 1))
+ else:
+ raise ValueError('Invalid depth dimension')
+ idx = torch.min(depth, dim=1)[1]
+ # idx[(idx < 0) | (idx == len(bins) - 1)] = -1
+ idx[(idx < 0)] = 0
+ idx[idx > len(bins) - 1] = len(bins) - 1
+ return idx.unsqueeze(1)
+
+
+def index2depth(idx, bins):
+ """
+ Converts discretized indexes to depth map
+
+ Parameters
+ ----------
+ idx : torch.Tensor
+ Discretized indexes [B,1,H,W]
+ bins : torch.Tensor
+ Discretized depth bins [D]
+
+ Returns
+ -------
+ depth : torch.Tensor
+ Output depth map [B,1,H,W]
+ """
+ if idx.dim() == 4:
+ b, _, h, w = idx.shape
+ bins = bins.reshape(1, -1, 1, 1).repeat(idx.shape[0], 1, idx.shape[2], idx.shape[3]).to(idx.device)
+ elif idx.dim() == 3:
+ b, _, n = idx.shape
+ bins = bins.reshape(1, -1, 1).repeat(idx.shape[0], 1, idx.shape[2]).to(idx.device)
+ else:
+ raise ValueError('Invalid depth dimension')
+ return torch.gather(bins, 1, idx)
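+
+
+# Illustrative round trip between the discretization helpers above. The values
+# and the 'linear' mode are arbitrary; any depth map [B,1,H,W] within the
+# chosen range could be used instead of the random one below.
+def _example_depth_discretization():
+    """Minimal sketch: quantize a depth map into bins and recover a
+    piecewise-constant approximation from the indexes."""
+    bins = get_depth_bins('linear', 1.0, 100.0, 64)     # [64]
+    depth = 1.0 + 99.0 * torch.rand(2, 1, 32, 32)       # random depth in [1,100]
+    idx = depth2index(depth, bins)                      # [2,1,32,32] integer bin indexes
+    depth_quantized = index2depth(idx, bins)            # [2,1,32,32] bin-valued depth
+    return depth_quantized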
diff --git a/vidar/utils/distributed.py b/vidar/utils/distributed.py
new file mode 100755
index 0000000000000000000000000000000000000000..ab0564a6e55ec4a1b5aa9611c530ed266733a22c
--- /dev/null
+++ b/vidar/utils/distributed.py
@@ -0,0 +1,74 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+
+import torch.distributed as dist
+
+
+def dist_mode():
+ return os.getenv('DIST_MODE')
+
+
+def rank():
+ """Returns process rank"""
+ if dist_mode() in ['cpu', 'gpu', None]:
+ return 0
+ elif dist_mode() == 'ddp':
+ return int(os.environ['RANK'])
+ else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
+
+
+def world_size():
+ """Returns world size"""
+ if dist_mode() in ['cpu', 'gpu', None]:
+ return 1
+ elif dist_mode() == 'ddp':
+ return int(os.environ['WORLD_SIZE'])
+ else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
+
+
+def on_rank_0(func):
+ """Decorator to run function only on rank 0"""
+ def wrapper(*args, **kwargs):
+ if rank() == 0:
+ return func(*args, **kwargs)
+ return wrapper
+
+
+@on_rank_0
+def print0(string='\n'):
+ """Function to print only on rank 0"""
+ print(string)
+
+
+def reduce_value(value, average, name):
+ """
+ Reduce the mean value of a tensor from all GPUs
+
+ Parameters
+ ----------
+ value : torch.Tensor
+ Value to be reduced
+ average : Bool
+ Whether values will be averaged or not
+ name : String
+ Value name
+
+ Returns
+ -------
+ value : torch.Tensor
+ reduced value
+ """
+ if dist_mode() == 'cpu':
+ return value
+ elif dist_mode() == 'gpu':
+ return value
+ elif dist_mode() == 'ddp':
+ dist.all_reduce(tensor=value, op=dist.ReduceOp.SUM, async_op=False)
+ if average:
+ value /= world_size()
+ return value
+ else:
+        raise ValueError('Wrong distributed mode {}'.format(dist_mode()))
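+
+
+# Illustrative sketch of how reduce_value is typically used: averaging a scalar
+# tensor across all DDP workers before logging. The `loss` tensor is a
+# hypothetical placeholder; under 'cpu'/'gpu' modes the value passes through
+# unchanged.
+def _example_reduce_loss(loss):
+    """Minimal sketch: average a detached loss tensor across processes."""
+    avg_loss = reduce_value(loss.detach(), average=True, name='loss')
+    print0('loss: {}'.format(avg_loss.item()))
+    return avg_loss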
diff --git a/vidar/utils/flip.py b/vidar/utils/flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2f360a314ac9cef635dbc7c5a6d6af8d4907c4
--- /dev/null
+++ b/vidar/utils/flip.py
@@ -0,0 +1,182 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+from pytorch3d.transforms.rotation_conversions import \
+ matrix_to_euler_angles, euler_angles_to_matrix
+
+from vidar.utils.data import keys_in
+from vidar.utils.decorators import iterate1, iterate12
+from vidar.utils.types import is_tensor, is_list, is_seq
+
+
+def flip_lr_fn(tensor):
+ """Function to flip a tensor from left to right"""
+ return torch.flip(tensor, [-1])
+
+
+def flip_flow_lr_fn(flow):
+ """Function to flip a flow tensor from left to right"""
+ flow_flip = torch.flip(flow, [3])
+ flow_flip[:, :1, :, :] *= -1
+ return flow_flip.contiguous()
+
+
+def flip_intrinsics_lr_fn(K, shape):
+ """Function to flip a 3x3 intrinsic matrix from left to right"""
+ K = K.clone()
+ K[:, 0, 2] = shape[-1] - K[:, 0, 2]
+ return K
+
+
+def flip_pose_lr_fn(T):
+ """Function to flip a 4x4 transformation matrix from left to right"""
+ rot = T[:, :3, :3]
+ axis = matrix_to_euler_angles(rot, convention='XYZ')
+ axis[:, [1, 2]] = axis[:, [1, 2]] * -1
+ rot = euler_angles_to_matrix(axis, convention='XYZ')
+ T[:, :3, :3] = rot
+ T[:, 0, -1] = - T[:, 0, -1]
+ return T
+
+
+@iterate1
+def flip_lr(tensor, flip=True):
+ """Flip a tensor from left to right"""
+ # Not flipping option
+ if not flip:
+ return tensor
+ # If it's a list, repeat
+ if is_list(tensor):
+ return [flip_lr(t) for t in tensor]
+ # Return flipped tensor
+ if tensor.dim() == 5:
+ return torch.stack([flip_lr_fn(tensor[:, i])
+ for i in range(tensor.shape[1])], 1)
+ else:
+ return flip_lr_fn(tensor)
+
+
+@iterate1
+def flip_flow_lr(flow, flip=True):
+ """Flip a flow tensor from left to right"""
+ # Not flipping option
+ if not flip:
+ return flow
+ # If it's a list, repeat
+ if is_list(flow):
+ return [flip_flow_lr(f) for f in flow]
+ # Flip flow and invert first dimension
+ if flow.dim() == 5:
+ return torch.stack([flip_flow_lr_fn(flow[:, i])
+ for i in range(flow.shape[1])], 1)
+ else:
+ return flip_flow_lr_fn(flow)
+
+
+@iterate12
+def flip_intrinsics_lr(K, shape, flip=True):
+ """Flip a 3x3 camera intrinsic matrix from left to right"""
+ # Not flipping option
+ if not flip:
+ return K
+    # If shape is a tensor, use its dimensions
+ if is_tensor(shape):
+ shape = shape.shape
+ # Flip horizontal information (first row)
+ if K.dim() == 4:
+ return torch.stack([flip_intrinsics_lr_fn(K[:, i], shape)
+ for i in range(K.shape[1])], 1)
+ else:
+ return flip_intrinsics_lr_fn(K, shape)
+
+
+def flip_pose_lr(pose, flip=True):
+ """Flip a 4x4 transformation matrix from left to right"""
+ # Not flipping option
+ if not flip:
+ return pose
+ # Repeat for all pose keys
+ for key in pose.keys():
+ # Get pose key
+ if key == 0:
+ if pose[key].dim() == 3:
+ continue
+ elif pose[key].dim() == 4:
+ T = pose[key][:, 1:].clone()
+ else:
+ raise ValueError('Invalid pose dimension')
+ else:
+ T = pose[key].clone()
+ # Flip pose
+ if T.dim() == 4:
+ T = torch.stack([flip_pose_lr_fn(T[:, i])
+ for i in range(T.shape[1])], 1)
+ else:
+ T = flip_pose_lr_fn(T)
+ # Store flipped value back
+ if key == 0:
+ pose[key][:, 1:] = T
+ else:
+ pose[key] = T
+ # Return flipped pose
+ return pose
+
+
+def flip_batch(batch, flip=True):
+ """Flip a batch from left to right"""
+ # Not flipping option
+ if not flip:
+ return batch
+ # If it's a list, repeat
+ if is_seq(batch):
+ return [flip_batch(b) for b in batch]
+ # Flip batch
+ flipped_batch = {}
+ # Keys to not flip
+ for key in keys_in(batch, ['idx', 'filename', 'splitname']):
+ flipped_batch[key] = batch[key]
+ # Tensor flipping
+ for key in keys_in(batch, ['rgb', 'mask', 'input_depth', 'depth', 'semantic']):
+ flipped_batch[key] = flip_lr(batch[key])
+ # Intrinsics flipping
+ for key in keys_in(batch, ['intrinsics']):
+ flipped_batch[key] = flip_intrinsics_lr(batch[key], batch['rgb'])
+ # Pose flipping
+ for key in keys_in(batch, ['pose']):
+ flipped_batch[key] = flip_pose_lr(batch[key])
+ return flipped_batch
+
+
+def flip_predictions(predictions, flip=True):
+ """Flip predictions from left to right"""
+ # Not flipping option
+ if not flip:
+ return predictions
+ # Flip predictions
+ flipped_predictions = {}
+ for key in predictions.keys():
+ if key.startswith('depth'):
+ flipped_predictions[key] = flip_lr(predictions[key])
+ if key.startswith('pose'):
+ flipped_predictions[key] = flip_pose_lr(predictions[key])
+ # Return flipped predictions
+ return flipped_predictions
+
+
+def flip_output(output, flip=True):
+ """Flip output from left to right"""
+ # Not flipping option
+ if not flip:
+ return output
+ # If it's a list, repeat
+ if is_seq(output):
+ return [flip_output(b) for b in output]
+ # Flip output
+ flipped_output = {}
+ # Do not flip loss and metrics
+ for key in keys_in(output, ['loss', 'metrics']):
+ flipped_output[key] = output[key]
+ # Flip predictions
+ flipped_output['predictions'] = flip_predictions(output['predictions'])
+ # Return flipped output
+ return flipped_output
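+
+
+# Illustrative end-to-end sketch of the helpers above: run a model on a
+# horizontally flipped batch and flip its output back, e.g. for test-time
+# augmentation. The `model` and `batch` names are hypothetical placeholders;
+# `model` is assumed to return a dictionary with 'loss', 'metrics' and
+# 'predictions' entries, matching what flip_output expects.
+def _example_flipped_inference(model, batch):
+    """Minimal sketch: forward pass on a flipped batch, then unflip the output."""
+    flipped_batch = flip_batch(batch)
+    flipped_output = model(flipped_batch)
+    return flip_output(flipped_output)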
diff --git a/vidar/utils/flow.py b/vidar/utils/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ffbe6a4b6ec60835e3749ef063ff2fee141ceb1
--- /dev/null
+++ b/vidar/utils/flow.py
@@ -0,0 +1,293 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import torch
+import torch.nn.functional as tfn
+
+from vidar.utils.data import make_list
+from vidar.utils.flow_triangulation_support import bearing_grid, mult_rotation_bearing, triangulation
+from vidar.utils.tensor import pixel_grid, norm_pixel_grid, unnorm_pixel_grid
+from vidar.utils.types import is_list
+
+
+def warp_from_coords(tensor, coords, mode='bilinear',
+ padding_mode='zeros', align_corners=True):
+ """
+ Warp an image from a coordinate map
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ Input tensor for warping [B,?,H,W]
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ mode : String
+ Warping mode
+ padding_mode : String
+ Padding mode
+ align_corners : Bool
+ Align corners flag
+
+ Returns
+ -------
+ warp : torch.Tensor
+ Warped tensor [B,?,H,W]
+ """
+ # Sample grid from data with coordinates
+ warp = tfn.grid_sample(tensor, coords.permute(0, 2, 3, 1),
+ mode=mode, padding_mode=padding_mode,
+ align_corners=align_corners)
+    # Return warped tensor
+ return warp
+
+
+def coords_from_optical_flow(optflow):
+ """
+    Get warping coordinates from optical flow
+
+ Parameters
+ ----------
+ optflow : torch.Tensor
+ Input optical flow tensor [B,2,H,W]
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ """
+ # Create coordinate with optical flow
+ coords = pixel_grid(optflow, device=optflow) + optflow
+ # Normalize and return coordinate grid
+ return norm_pixel_grid(coords)
+
+
+def warp_depth_from_motion(ref_depth, tgt_depth, ref_cam):
+ """
+ Warp depth map using motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_depth : torch.Tensor
+ Reference depth map [B,1,H,W]
+ tgt_depth : torch.Tensor
+ Target depth map [B,1,H,W]
+ ref_cam : Camera
+ Reference camera
+
+ Returns
+ -------
+ warp : torch.Tensor
+ Warped depth map [B,1,H,W]
+ """
+ ref_depth = reproject_depth_from_motion(ref_depth, ref_cam)
+ return warp_from_motion(ref_depth, tgt_depth, ref_cam)
+
+
+def reproject_depth_from_motion(ref_depth, ref_cam):
+ """
+ Calculate reprojected depth from motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_depth : torch.Tensor
+ Reference depth map [B,1,H,W]
+ ref_cam : Camera
+ Reference camera
+
+ Returns
+ -------
+    depth : torch.Tensor
+        Reprojected depth map [B,1,H,W]
+ """
+ ref_points = ref_cam.reconstruct_depth_map(ref_depth, to_world=True)
+ return ref_cam.project_points(ref_points, from_world=False, return_z=True)[1]
+
+
+def warp_from_motion(ref_rgb, tgt_depth, ref_cam):
+ """
+ Warp image using motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_rgb : torch.Tensor
+ Reference image [B,3,H,W]
+ tgt_depth : torch.Tensor
+ Target depth map [B,1,H,W]
+ ref_cam : Camera
+ Reference camera
+
+ Returns
+ -------
+ warp : torch.Tensor
+ Warped image [B,3,H,W]
+ """
+ tgt_points = ref_cam.reconstruct_depth_map(tgt_depth, to_world=False)
+ return warp_from_coords(ref_rgb, ref_cam.project_points(tgt_points, from_world=True).permute(0, 3, 1, 2))
+
+
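+# Illustrative sketch of view synthesis with the helper above: warp a reference
+# image into the target frame and compare it against the target image, as done
+# in photometric self-supervision. The tensors and the `ref_cam` object are
+# hypothetical placeholders; `ref_cam` is assumed to already encode the
+# relative pose between the two views, following the Camera conventions used
+# elsewhere in this repository.
+def _example_photometric_warp(ref_rgb, tgt_rgb, tgt_depth, ref_cam):
+    """Minimal sketch: synthesize the target view from the reference image."""
+    synthesized = warp_from_motion(ref_rgb, tgt_depth, ref_cam)   # [B,3,H,W]
+    photometric_error = (synthesized - tgt_rgb).abs().mean(1, keepdim=True)
+    return synthesized, photometric_error
+
+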
+def coords_from_motion(ref_camera, tgt_depth, tgt_camera):
+ """
+ Get coordinates from motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_camera : Camera
+ Reference camera
+ tgt_depth : torch.Tensor
+ Target depth map [B,1,H,W]
+ tgt_camera : Camera
+ Target camera
+
+ Returns
+ -------
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ """
+ if is_list(ref_camera):
+ return [coords_from_motion(camera, tgt_depth, tgt_camera)
+ for camera in ref_camera]
+ # If there are multiple depth maps, iterate for each
+ if is_list(tgt_depth):
+ return [coords_from_motion(ref_camera, depth, tgt_camera)
+ for depth in tgt_depth]
+ world_points = tgt_camera.reconstruct_depth_map(tgt_depth, to_world=True)
+ return ref_camera.project_points(world_points, from_world=True).permute(0, 3, 1, 2)
+
+
+def optflow_from_motion(ref_camera, tgt_depth):
+ """
+ Get optical flow from motion (depth + ego-motion) information
+
+ Parameters
+ ----------
+ ref_camera : Camera
+ Reference camera
+ tgt_depth : torch.Tensor
+ Target depth map
+
+ Returns
+ -------
+ optflow : torch.Tensor
+ Optical flow map [B,2,H,W]
+ """
+ coords = ref_camera.coords_from_depth(tgt_depth).permute(0, 3, 1, 2)
+ return optflow_from_coords(coords)
+
+
+def optflow_from_coords(coords):
+ """
+ Get optical flow from coordinates
+
+ Parameters
+ ----------
+ coords : torch.Tensor
+ Input warping coordinates [B,2,H,W]
+
+ Returns
+ -------
+ optflow : torch.Tensor
+ Optical flow map [B,2,H,W]
+ """
+ return unnorm_pixel_grid(coords) - pixel_grid(coords, device=coords)
+
+
+def warp_from_optflow(ref_rgb, tgt_optflow):
+ """
+ Warp image using optical flow information
+
+ Parameters
+ ----------
+ ref_rgb : torch.Tensor
+ Reference image [B,3,H,W]
+ tgt_optflow : torch.Tensor
+ Target optical flow [B,2,H,W]
+
+ Returns
+ -------
+ warp : torch.Tensor
+ Warped image [B,3,H,W]
+ """
+ coords = coords_from_optical_flow(tgt_optflow)
+ return warp_from_coords(ref_rgb, coords, align_corners=True,
+ mode='bilinear', padding_mode='zeros')
+
+
+def reverse_optflow(tgt_optflow, ref_optflow):
+ """
+ Reverse optical flow
+
+ Parameters
+ ----------
+ tgt_optflow : torch.Tensor
+ Target optical flow [B,2,H,W]
+ ref_optflow : torch.Tensor
+ Reference optical flow [B,2,H,W]
+
+ Returns
+ -------
+ optflow : torch.Tensor
+ Reversed optical flow [B,2,H,W]
+ """
+ return - warp_from_optflow(tgt_optflow, ref_optflow)
+
+
+def mask_from_coords(coords, align_corners=True):
+ """
+ Get overlap mask from coordinates
+
+ Parameters
+ ----------
+ coords : torch.Tensor
+ Warping coordinates [B,2,H,W]
+ align_corners : Bool
+ Align corners flag
+
+ Returns
+ -------
+ mask : torch.Tensor
+ Overlap mask [B,1,H,W]
+ """
+ if is_list(coords):
+ return [mask_from_coords(coord) for coord in coords]
+ b, _, h, w = coords.shape
+ mask = torch.ones((b, 1, h, w), dtype=torch.float32, device=coords.device, requires_grad=False)
+    mask = warp_from_coords(mask, coords, mode='nearest', padding_mode='zeros', align_corners=align_corners)
+ return mask.bool()
+
+
+def depth_from_optflow(rgb, intrinsics, pose_context, flows,
+ residual=False, clip_range=None):
+ """
+ Get depth from optical flow + camera information
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Base image [B,3,H,W]
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3]
+ pose_context : torch.Tensor or list[torch.Tensor]
+ List of relative context camera poses [B,4,4]
+ flows : torch.Tensor or list[torch.Tensor]
+ List of target optical flows [B,2,H,W]
+ residual : Bool
+ Return residual error with depth
+ clip_range : Tuple
+ Depth range clipping values
+
+ Returns
+ -------
+ depth : torch.Tensor
+ Depth map [B,1,H,W]
+ """
+ # Make lists if necessary
+ flows = make_list(flows)
+ pose_context = make_list(pose_context)
+ # Extract rotations and translations
+ rotations = [p[:, :3, :3] for p in pose_context]
+ translations = [p[:, :3, -1] for p in pose_context]
+ # Get bearings
+ bearings = bearing_grid(rgb, intrinsics).to(rgb.device)
+ rot_bearings = [mult_rotation_bearing(rotation, bearings)
+ for rotation in rotations]
+ # Return triangulation results
+ return triangulation(rot_bearings, translations, flows, intrinsics,
+ clip_range=clip_range, residual=residual)
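+
+
+# Illustrative sketch of triangulating depth from optical flow with the helper
+# above. All tensors are hypothetical placeholders with the documented shapes;
+# `flow_tgt_to_ref` is assumed to map target pixels to the reference view and
+# `pose_tgt_to_ref` to be the corresponding relative pose [B,4,4].
+def _example_depth_from_flow(rgb, intrinsics, pose_tgt_to_ref, flow_tgt_to_ref):
+    """Minimal sketch: recover a depth map [B,1,H,W] from a single flow field."""
+    depth = depth_from_optflow(
+        rgb, intrinsics, pose_tgt_to_ref, flow_tgt_to_ref,
+        residual=False, clip_range=(0.1, 100.0))
+    return depth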
diff --git a/vidar/utils/flow_triangulation_support.py b/vidar/utils/flow_triangulation_support.py
new file mode 100755
index 0000000000000000000000000000000000000000..286005be792e4fe562d0695001ab3c2d72c0f0a3
--- /dev/null
+++ b/vidar/utils/flow_triangulation_support.py
@@ -0,0 +1,223 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+import torch
+import torch.nn.functional as tfunc
+
+from vidar.utils.tensor import pixel_grid, cat_channel_ones
+
+
+def bearing_grid(rgb, intrinsics):
+ """
+ Create a homogeneous bearing grid from camera intrinsics and a base image
+
+ Parameters
+ ----------
+ rgb : torch.Tensor
+ Base image for dimensions [B,3,H,W]
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3]
+
+ Returns
+ -------
+ grid : torch.Tensor
+ Bearing grid [B,3,H,W]
+ """
+ # Create pixel grid from base image
+ b, _, h, w = rgb.shape
+ grid = pixel_grid((h, w), b).to(rgb.device)
+ # Normalize pixel grid with camera parameters
+ grid[:, 0] = (grid[:, 0] - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
+ grid[:, 1] = (grid[:, 1] - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
+ # Return bearing grid (with 1s as extra dimension)
+ return cat_channel_ones(grid)
+
+
+def mult_rotation_bearing(rotation, bearing):
+ """
+ Rotates a bearing grid
+
+ Parameters
+ ----------
+ rotation : torch.Tensor
+ Rotation matrix [B,3,3]
+ bearing : torch.Tensor
+ Bearing grid [B,3,H,W]
+
+ Returns
+ -------
+ rot_bearing : torch.Tensor
+ Rotated bearing grid [B,3,H,W]
+ """
+ # Multiply rotation and bearing
+ product = torch.bmm(rotation, bearing.view(bearing.shape[0], 3, -1))
+ # Return product with bearing shape
+ return product.view(bearing.shape)
+
+
+def pre_triangulation(ref_bearings, ref_translations, tgt_flows,
+ intrinsics, concat=True):
+ """
+ Triangulates bearings and flows
+
+ Parameters
+ ----------
+ ref_bearings : list[torch.Tensor]
+ Reference bearings [B,3,H,W]
+ ref_translations : list[torch.Tensor]
+ Reference translations [B,3]
+ tgt_flows : list[torch.Tensor]
+ Target optical flow values [B,2,H,W]
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3]
+ concat : Bool
+ True if cross product results are concatenated
+
+ Returns
+ -------
+ rs : torch.Tensor or list[torch.Tensor]
+ Bearing x translation cross product [B,3,H,W] (concatenated or not)
+ ss : torch.Tensor or list[torch.Tensor]
+ Bearing x bearing cross product [B,3,H,W] (concatenated or not)
+ """
+ # Get target bearings from flow
+ tgt_bearings = [flow2bearing(flow, intrinsics, normalize=True)
+ for flow in tgt_flows]
+ # Bearings x translation cross product
+ rs = [torch.cross(tgt_bearing, ref_translation[:, :, None, None].expand_as(tgt_bearing), dim=1)
+ for tgt_bearing, ref_translation in zip(tgt_bearings, ref_translations)]
+ # Bearings x bearings cross product
+ ss = [torch.cross(tgt_bearing, ref_bearing, dim=1)
+ for tgt_bearing, ref_bearing in zip(tgt_bearings, ref_bearings)]
+ if concat:
+ # If results are to be concatenated
+ return torch.cat(rs, dim=1), torch.cat(ss, dim=1)
+ else:
+ # Otherwise, return as lists
+ return rs, ss
+
+
+def depth_ls2views(r, s, clip_range=None):
+ """
+ Least-squares depth estimation from two views
+
+ Parameters
+ ----------
+ r : torch.Tensor
+ Bearing x translation cross product between images [B,3,H,W]
+ s : torch.Tensor
+        Bearing x bearing cross product between images [B,3,H,W]
+ clip_range : Tuple
+ Depth clipping range (min, max)
+
+ Returns
+ -------
+ depth : torch.Tensor
+ Calculated depth [B,1,H,W]
+ error : torch.Tensor
+ Calculated error [B,1,H,W]
+ hessian : torch.Tensor
+ Calculated hessian [B,1,H,W]
+
+ """
+ # Calculate matrices
+ hessian = (s * s).sum(dim=1, keepdims=True)
+ depth = -(s * r).sum(dim=1, keepdims=True) / (hessian + 1e-30)
+ error = (r * r).sum(dim=1, keepdims=True) - hessian * (depth ** 2)
+
+ # Clip depth and other matrices if requested
+ if clip_range is not None:
+
+ invalid_mask = (depth <= clip_range[0])
+ invalid_mask |= (depth >= clip_range[1])
+
+ depth[invalid_mask] = 0
+ error[invalid_mask] = 0
+ hessian[invalid_mask] = 0
+ # Return calculated matrices
+ return depth, error, hessian
+
+
+def flow2bearing(flow, intrinsics, normalize=True):
+ """
+ Convert optical flow to bearings
+
+ Parameters
+ ----------
+ flow : torch.Tensor
+ Input optical flow [B,2,H,W]
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3]
+ normalize : Bool
+ True if bearings are normalized
+
+ Returns
+ -------
+ bearings : torch.Tensor
+ Calculated bearings [B,3,H,W]
+ """
+ # Create initial grid
+ height, width = flow.shape[2:]
+ xx, yy = np.meshgrid(range(width), range(height))
+ # Initialize bearing matrix
+ bearings = torch.zeros_like(flow)
+ # Populate bearings
+ match = (flow[:, 0] + torch.from_numpy(xx).to(flow.device),
+ flow[:, 1] + torch.from_numpy(yy).to(flow.device))
+ bearings[:, 0] = (match[0] - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
+ bearings[:, 1] = (match[1] - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
+    # Concatenate 1s as an extra channel
+ bearings = cat_channel_ones(bearings)
+ # Normalize if necessary
+ if normalize:
+ bearings = tfunc.normalize(bearings)
+ # Return bearings
+ return bearings
+
+
+def triangulation(ref_bearings, ref_translations,
+ tgt_flows, intrinsics, clip_range=None, residual=False):
+ """
+ Triangulate optical flow points to produce depth estimates
+
+ Parameters
+ ----------
+ ref_bearings : list[torch.Tensor]
+ Reference bearings [B,3,H,W]
+ ref_translations : list[torch.Tensor]
+ Reference translations [B,3]
+ tgt_flows : list[torch.Tensor]
+ Target optical flow to reference [B,2,H,W]
+ intrinsics : torch.Tensor
+ Camera intrinsics [B,3,3]
+ clip_range : Tuple
+ Depth clipping range
+ residual : Bool
+        True to return residual error and square root of Hessian
+
+ Returns
+ -------
+ depth : torch.Tensor
+ Estimated depth [B,1,H,W]
+ error : torch.Tensor
+ Estimated error [B,1,H,W]
+ sqrt_hessian : torch.Tensor
+        Square root of Hessian [B,1,H,W]
+ """
+ # Pre-triangulate flows
+ rs, ss = pre_triangulation(ref_bearings, ref_translations, tgt_flows, intrinsics, concat=False)
+ # Calculate list of triangulations
+ outputs = [depth_ls2views(*rs_ss, clip_range=clip_range) for rs_ss in zip(rs, ss)]
+ # Calculate predicted hessian and depths
+ hessian = sum([output[2] for output in outputs])
+ depth = sum([output[0] * output[2] for output in outputs]) / (hessian + 1e-12)
+ # Return depth + residual error and hessian matrix
+ if residual:
+ error = torch.sqrt(sum([output[2] * (depth - output[0]) ** 2 + output[1]
+ for output in outputs]).clamp_min(0))
+ sqrt_hessian = torch.sqrt(hessian)
+ return depth, (error, sqrt_hessian)
+ # Return depth
+ else:
+ return depth
+
diff --git a/vidar/utils/logging.py b/vidar/utils/logging.py
new file mode 100755
index 0000000000000000000000000000000000000000..b0e85b2a5cad2e51e886c94d20efcab9078ecf5a
--- /dev/null
+++ b/vidar/utils/logging.py
@@ -0,0 +1,144 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import argparse
+import os
+from functools import partial
+
+from termcolor import colored
+
+from vidar.utils.distributed import on_rank_0
+
+
+def pcolor(string, color, on_color=None, attrs=None):
+ """
+ Produces a colored string for printing
+
+ Parameters
+ ----------
+ string : String
+ String that will be colored
+ color : String
+ Color to use
+ on_color : String
+ Background color to use
+ attrs : list[String]
+ Different attributes for the string
+
+ Returns
+ -------
+ string: String
+ Colored string
+ """
+ return colored(string, color, on_color, attrs)
+
+
+@on_rank_0
+def print_config(config):
+ """
+ Prints header for model configuration
+
+ Parameters
+ ----------
+ config : Config
+ Model configuration
+ """
+ header_colors = {
+ 0: ('red', ('bold', 'dark')),
+ 1: ('cyan', ('bold','dark')),
+ 2: ('green', ('bold', 'dark')),
+ 3: ('green', ('bold', 'dark')),
+ }
+ line_colors = ('blue', ())
+
+ # Recursive print function
+ def print_recursive(rec_args, pad=3, level=0):
+ # if level == 0:
+ # print(pcolor('config:',
+ # color=header_colors[level][0],
+ # attrs=header_colors[level][1]))
+ for key, val in rec_args.__dict__.items():
+ if isinstance(val, argparse.Namespace):
+ print(pcolor('{} {}:'.format('-' * pad, key),
+ color=header_colors[level][0],
+ attrs=header_colors[level][1]))
+ print_recursive(val, pad + 2, level + 1)
+ else:
+ print('{}: {}'.format(pcolor('{} {}'.format('-' * pad, key),
+ color=line_colors[0],
+ attrs=line_colors[1]), val))
+
+ # Color partial functions
+ pcolor1 = partial(pcolor, color='blue', attrs=('bold', 'dark'))
+ pcolor2 = partial(pcolor, color='blue', attrs=('bold',))
+ # Config and name
+ line = pcolor1('#' * 120)
+ # if 'default' in config.__dict__.keys():
+ # path = pcolor1('### Config: ') + \
+ # pcolor2('{}'.format(config.default.replace('/', '.'))) + \
+ # pcolor1(' -> ') + \
+ # pcolor2('{}'.format(config.config.replace('/', '.')))
+ # if 'name' in config.__dict__.keys():
+ # name = pcolor1('### Name: ') + \
+ # pcolor2('{}'.format(config.name))
+ # # Add wandb link if available
+ # if not config.wandb.dry_run:
+ # name += pcolor1(' -> ') + \
+ # pcolor2('{}'.format(config.wandb.url))
+ # # Add s3 link if available
+ # if config.checkpoint.s3_path is not '':
+ # name += pcolor1('\n### s3:') + \
+ # pcolor2(' {}'.format(config.checkpoint.s3_url))
+ # # # Create header string
+ # # header = '%s\n%s\n%s\n%s' % (line, path, name, line)
+
+ # Print header, config and header again
+ print()
+ # print(header)
+ print_recursive(config)
+ # print(header)
+ print()
+
+
+def set_debug(debug):
+ """
+ Enable or disable debug terminal logging
+
+ Parameters
+ ----------
+ debug : Bool
+ Debugging flag (True to enable)
+ """
+    # Disable verbose external logging when not debugging
+ if not debug:
+ os.environ['NCCL_DEBUG'] = ''
+ os.environ['WANDB_SILENT'] = 'true'
+ # warnings.filterwarnings("ignore")
+ # logging.disable(logging.CRITICAL)
+
+
+class AvgMeter:
+ """Average meter for logging"""
+ def __init__(self, n_max=100):
+ self.n_max = n_max
+ self.values = []
+
+ def __call__(self, value):
+ """Append new value and returns average"""
+ self.values.append(value)
+ if len(self.values) > self.n_max:
+ self.values.pop(0)
+ return self.get()
+
+ def get(self):
+ """Get current average"""
+ return sum(self.values) / len(self.values)
+
+ def reset(self):
+ """Reset meter"""
+ self.values.clear()
+
+ def get_and_reset(self):
+ """Get current average and reset"""
+ average = self.get()
+ self.reset()
+ return average
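+
+
+# Illustrative sketch of the meter above: keep a running average of the last
+# n_max loss values inside a training loop. The `losses` iterable is a
+# hypothetical placeholder for per-step scalar values.
+def _example_avg_meter(losses):
+    """Minimal sketch: feed scalar values and report the running average."""
+    meter = AvgMeter(n_max=100)
+    for value in losses:
+        meter(value)
+    print('running average: {}'.format(meter.get()))
+    return meter.get_and_reset()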
diff --git a/vidar/utils/networks.py b/vidar/utils/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..29bc8216b6738d301c33c3b4eecbc51611f93e5d
--- /dev/null
+++ b/vidar/utils/networks.py
@@ -0,0 +1,203 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import math
+
+import torch
+import torch.nn as nn
+
+from vidar.utils.distributed import print0, rank, dist_mode
+from vidar.utils.logging import pcolor
+from vidar.utils.tensor import same_shape
+from vidar.utils.types import is_list
+
+
+def freeze_layers(network, layers=('ALL',), flag_freeze=True):
+ """
+ Freeze layers of a network (weights and biases)
+
+ Parameters
+ ----------
+ network : nn.Module
+ Network to be modified
+ layers : List or Tuple
+ List of layers to freeze/unfreeze ('ALL' for everything)
+ flag_freeze : Bool
+ Whether the layers will be frozen (True) or not (False)
+ """
+ if len(layers) > 0:
+ for name, parameters in network.named_parameters():
+ for layer in layers:
+ if layer in name or layer == 'ALL':
+ parameters.requires_grad_(not flag_freeze)
+
+
+def freeze_norms(network, layers=('ALL',), flag_freeze=True):
+ """
+ Freeze layers of a network (normalization)
+
+ Parameters
+ ----------
+ network : nn.Module
+ Network to be modified
+ layers : List or Tuple
+ List of layers to freeze/unfreeze ('ALL' for everything)
+ flag_freeze : Bool
+ Whether the layers will be frozen (True) or not (False)
+ """
+ if len(layers) > 0:
+ for name, module in network.named_modules():
+ for layer in layers:
+ if layer in name or layer == 'ALL':
+ if isinstance(module, nn.BatchNorm2d):
+ if hasattr(module, 'weight'):
+ module.weight.requires_grad_(not flag_freeze)
+ if hasattr(module, 'bias'):
+ module.bias.requires_grad_(not flag_freeze)
+ if flag_freeze:
+ module.eval()
+ else:
+ module.train()
+
+
+def freeze_layers_and_norms(network, layers=('ALL',), flag_freeze=True):
+ """Freeze layers and normalizations of a network"""
+ freeze_layers(network, layers, flag_freeze)
+ freeze_norms(network, layers, flag_freeze)
+
+
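+# Illustrative sketch of the helpers above: freeze only the encoder of a
+# hypothetical depth network while leaving the rest trainable. The `depth_net`
+# argument and the 'encoder' substring are placeholders; any substring of the
+# registered parameter/module names can be used.
+def _example_freeze_encoder(depth_net):
+    """Minimal sketch: freeze encoder weights, biases and batch-norm layers."""
+    freeze_layers_and_norms(depth_net, layers=('encoder',), flag_freeze=True)
+    return depth_net
+
+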
+def make_val_fit(model, key, val, updated_state_dict, strict=False):
+ """
+ Parse state dictionary to fit a model, and make tensors fit if requested
+
+ Parameters
+ ----------
+ model : nn.Module
+ Network to be used
+ key : String
+ Which key will be used
+ val : torch.Tensor
+ Key value
+ updated_state_dict : Dict
+ Updated dictionary
+ strict : Bool
+ True if no changes are allowed, False if tensors can be changed to fit
+
+ Returns
+ -------
+ fit : Int
+ Number of tensors that fit the model
+ """
+ fit = 0
+ val_new = model.state_dict()[key]
+ if same_shape(val.shape, val_new.shape):
+ updated_state_dict[key] = val
+ fit += 1
+ elif not strict:
+ for i in range(val.dim()):
+ if val.shape[i] != val_new.shape[i]:
+ if val_new.shape[i] > val.shape[i]:
+ ratio = math.ceil(val_new.shape[i] / val.shape[i])
+ val = torch.cat([val] * ratio, i)
+ if val.shape[i] != val_new.shape[i]:
+ val = val[:val_new.shape[i]]
+ if same_shape(val.shape, val_new.shape):
+ updated_state_dict[key] = val
+ fit += 1
+                elif val_new.shape[i] < val.shape[i]:
+ val = val[:val_new.shape[i]]
+ if same_shape(val.shape, val_new.shape):
+ updated_state_dict[key] = val
+ fit += 1
+ assert fit <= 1 # Each tensor cannot fit 2 or more times
+ return fit
+
+
+def load_checkpoint(model, checkpoint, strict=False, verbose=False, prefix=None):
+ """
+ Load checkpoint into a model
+
+ Parameters
+ ----------
+ model : nn.Module
+ Input network
+ checkpoint : String or list[String]
+ Checkpoint path (if it's a list, load them in order)
+ strict : Bool
+ True if all tensors are required, False if can be partially loaded
+ verbose : Bool
+ Print information on screen
+ prefix : String
+ Prefix used to change keys
+
+ Returns
+ -------
+ model: nn.Module
+ Loaded network
+ """
+ if is_list(checkpoint):
+ for ckpt in checkpoint:
+ load_checkpoint(model, ckpt, strict, verbose)
+ return model
+
+ font1 = {'color': 'magenta', 'attrs': ('bold', 'dark')}
+ font2 = {'color': 'magenta', 'attrs': ('bold',)}
+
+ if verbose:
+ print0(pcolor('#' * 60, **font1))
+ print0(pcolor('###### Loading from checkpoint: ', **font1) +
+ pcolor('{}'.format(checkpoint), **font2))
+
+ state_dict = torch.load(
+ checkpoint,
+ map_location='cpu' if dist_mode() == 'cpu' else 'cuda:{}'.format(rank())
+ )['state_dict']
+ updated_state_dict = {}
+
+ total, fit = len(model.state_dict()), 0
+ for key, val in state_dict.items():
+
+ for start in ['model.', 'module.']:
+ if key.startswith(start):
+ key = key[len(start):]
+ if prefix is not None:
+ idx = key.find(prefix)
+ if idx > -1:
+ key = key[(idx + len(prefix) + 1):]
+ if key in model.state_dict().keys():
+ fit += make_val_fit(model, key, val, updated_state_dict, strict=strict)
+
+ model.load_state_dict(updated_state_dict, strict=strict)
+
+ if verbose:
+ color = 'red' if fit == 0 else 'yellow' if fit < total else 'green'
+ print0(pcolor('###### Loaded ', **font1) + \
+ pcolor('{}/{}'.format(fit,total), color=color, attrs=('bold',)) + \
+ pcolor(' tensors', **font1))
+ print0(pcolor('#' * 60, **font1))
+
+ return model
+
+
+def save_checkpoint(filename, wrapper, epoch=None):
+ """
+ Save checkpoint to disk
+
+ Parameters
+ ----------
+ filename : String
+ Name of the file
+ wrapper : nn.Module
+ Model wrapper to save
+ epoch : Int
+ Training epoch
+ """
+ if epoch is None:
+ torch.save({
+ 'state_dict': wrapper.state_dict(),
+ }, filename)
+ else:
+ torch.save({
+ 'epoch': epoch,
+ 'config': wrapper.cfg,
+ 'state_dict': wrapper.arch.state_dict(),
+ }, filename)
diff --git a/vidar/utils/read.py b/vidar/utils/read.py
new file mode 100755
index 0000000000000000000000000000000000000000..2b85c499fb9ff49315127b54d196437e56c7133a
--- /dev/null
+++ b/vidar/utils/read.py
@@ -0,0 +1,74 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import pickle as pkl
+
+import numpy as np
+from PIL import Image
+
+from vidar.utils.decorators import iterate1
+
+
+def read_pickle(filename):
+ """
+ Read pickle file
+
+ Parameters
+ ----------
+ filename : String
+ File to read from
+
+ Returns
+ -------
+ data : Value
+ Data loaded from file
+ """
+ if not filename.endswith('.pkl'):
+ filename += '.pkl'
+ return pkl.load(open(filename, 'rb'))
+
+
+@iterate1
+def read_image(path):
+ """
+ Read an image using PIL
+
+ Parameters
+ ----------
+ path : String
+ Path to the image
+
+ Returns
+ -------
+ image : PIL Image
+ Loaded image
+ """
+ return Image.open(path)
+
+
+@iterate1
+def read_depth(file):
+ """
+ Load a depth map from file
+
+ Parameters
+ ----------
+ file : String
+        Depth map filename (.npz or .png)
+
+ Returns
+ -------
+ depth : np.array
+ Depth map (invalid pixels are 0) [H,W]
+ """
+ # If loading a .npz array
+ if file.endswith('npz'):
+ return np.load(file)['depth']
+ # If loading a .png image
+ elif file.endswith('png'):
+ depth_png = np.array(read_image(file), dtype=int)
+ assert (np.max(depth_png) > 255), 'Wrong .png depth file'
+        return depth_png.astype(np.float64) / 256.
+ # Invalid type
+ else:
+ raise NotImplementedError('Depth extension not supported.')
+
diff --git a/vidar/utils/reduce.py b/vidar/utils/reduce.py
new file mode 100755
index 0000000000000000000000000000000000000000..717b5d5a827a8b2510eb068bfbbe8c5044882d46
--- /dev/null
+++ b/vidar/utils/reduce.py
@@ -0,0 +1,76 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from collections import OrderedDict
+
+
+def average_key(batch_list, key):
+ """
+ Average key in a list of batches
+
+ Parameters
+ ----------
+ batch_list : list[Dict]
+ List containing dictionaries with the same keys
+ key : String
+ Key to be averaged
+
+ Returns
+ -------
+ average : Float
+ Average of the value contained in key for all batches
+ """
+ values = [batch[key] for batch in batch_list]
+ return sum(values) / len(values)
+
+
+def average_sub_key(batch_list, key, sub_key):
+ """
+ Average subkey in a dictionary in a list of batches
+
+ Parameters
+ ----------
+ batch_list : list[Dict]
+ List containing dictionaries with the same keys
+ key : String
+ Key to be averaged
+ sub_key : String
+ Sub key to be averaged (belonging to key)
+
+ Returns
+ -------
+ average : Float
+ Average of the value contained in the sub_key of key for all batches
+ """
+ values = [batch[key][sub_key] for batch in batch_list]
+ return sum(values) / len(values)
+
+
+def average_loss_and_metrics(batch_list, prefix):
+ """
+ Average loss and metrics values in a list of batches
+
+ Parameters
+ ----------
+ batch_list : list[Dict]
+ List containing dictionaries with the same keys
+ prefix : String
+ Prefix string for metrics logging
+
+ Returns
+ -------
+ values : Dict
+ Dictionary containing a 'loss' float entry and a 'metrics' dict entry
+ """
+ values = OrderedDict()
+
+ key = 'loss'
+ values['{}-{}'.format(prefix, key)] = \
+ average_key(batch_list, key)
+
+ key = 'metrics'
+ for sub_key in batch_list[0][key].keys():
+ values['{}-{}'.format(prefix, sub_key)] = \
+ average_sub_key(batch_list, key, sub_key)
+
+ return values
+
diff --git a/vidar/utils/setup.py b/vidar/utils/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1533bc90d773c30eaf6ce1098ae56674cd73d9
--- /dev/null
+++ b/vidar/utils/setup.py
@@ -0,0 +1,361 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import time
+from collections import OrderedDict
+from copy import deepcopy
+
+import numpy as np
+import torch
+from torch.utils.data import ConcatDataset, DataLoader
+
+from vidar.datasets.utils.transforms import get_transforms
+from vidar.metrics.depth import DepthEvaluation
+from vidar.utils.config import get_folder_name, load_class, \
+ recursive_assignment, cfg_has, cfg_add_to_dict, get_from_cfg_list
+from vidar.utils.config import merge_dict, to_namespace
+from vidar.utils.data import flatten, keys_in
+from vidar.utils.decorators import iterate1
+from vidar.utils.distributed import print0, rank, world_size, dist_mode
+from vidar.utils.logging import pcolor
+from vidar.utils.networks import load_checkpoint, save_checkpoint
+from vidar.utils.types import is_namespace
+
+
+def setup_arch(cfg, checkpoint=None, verbose=False):
+ """
+ Set architecture up for training/inference
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+ checkpoint : String
+ Checkpoint to be loaded
+ verbose : Bool
+ Print information on screen
+
+ Returns
+ -------
+ model: nn.Module
+ Model ready to go
+ """
+ font = {'color': 'green'}
+
+ if verbose:
+ print0(pcolor('#' * 60, **font))
+ print0(pcolor('### Preparing Architecture', **font))
+ print0(pcolor('#' * 60, **font))
+
+ font1 = {'color': 'yellow', 'attrs': ('dark',)}
+ font2 = {'color': 'yellow', 'attrs': ('dark', 'bold')}
+
+ folder, name = get_folder_name(cfg.model.file, 'models')
+ model = load_class(name, folder)(cfg)
+
+ if cfg_has(cfg, 'model'):
+ if verbose:
+ print0(pcolor('###### Model:', **font2))
+ print0(pcolor('######### %s' % model.__class__.__name__, **font1))
+ recursive_assignment(model, cfg.model, 'models', verbose=verbose)
+
+ if cfg_has(cfg, 'networks'):
+ if verbose:
+ print0(pcolor('###### Networks:', **font2))
+ recursive_assignment(model, cfg.networks, 'networks', verbose=verbose)
+
+ if cfg_has(cfg, 'losses'):
+ if verbose:
+ print0(pcolor('###### Losses:', **font2))
+ recursive_assignment(model, cfg.losses, 'losses', verbose=verbose)
+
+ if checkpoint is not None:
+ model = load_checkpoint(model, checkpoint,
+ strict=True, verbose=verbose)
+ elif cfg_has(cfg.model, 'checkpoint'):
+ model = load_checkpoint(model, cfg.model.checkpoint,
+ strict=cfg.model.has('checkpoint_strict', False), verbose=verbose)
+
+ if cfg.model.has('checkpoint_save'):
+ save_checkpoint(cfg.model.checkpoint_save, model)
+
+ return model
+
+
+def setup_dataset(cfg, root='vidar/datasets', verbose=False):
+ """
+ Set dataset up for training/inference
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+ root : String
+ Where the dataset is located
+ verbose : Bool
+ Print information on screen
+
+ Returns
+ -------
+ dataset : Dataset
+ Dataset ready to go
+ """
+ shared_keys = ['context', 'labels', 'labels_context']
+
+ num_datasets = 0
+ for key, val in cfg.__dict__.items():
+ if key not in shared_keys and not is_namespace(val):
+ num_datasets = max(num_datasets, len(val))
+
+ datasets = []
+ for i in range(num_datasets):
+ args = {}
+ for key, val in cfg.__dict__.items():
+ if not is_namespace(val):
+ cfg_add_to_dict(args, cfg, key, i if key not in shared_keys else None)
+
+ args['data_transform'] = get_transforms('train', cfg.augmentation) \
+ if cfg_has(cfg, 'augmentation') else get_transforms('none')
+
+ name = get_from_cfg_list(cfg, 'name', i)
+ repeat = get_from_cfg_list(cfg, 'repeat', i)
+ cameras = get_from_cfg_list(cfg, 'cameras', i)
+
+ context = cfg.context
+ labels = cfg.labels
+
+ dataset = load_class(name + 'Dataset', root)(**args)
+
+ if cfg_has(cfg, 'repeat') and repeat > 1:
+ dataset = ConcatDataset([dataset for _ in range(repeat)])
+
+ if verbose:
+ string = f'######### {name}: {len(dataset)} samples'
+ if cfg_has(cfg, 'repeat'):
+ string += f' (x{repeat})'
+ if cfg_has(cfg, 'context'):
+ string += f' | context {context}'.replace(', ', ',')
+ if cfg_has(cfg, 'cameras'):
+ string += f' | cameras {cameras}'.replace(', ', ',')
+ if cfg_has(cfg, 'labels'):
+ string += f' | labels {labels}'.replace(', ', ',')
+ print0(pcolor(string , color='yellow', attrs=('dark',)))
+
+ datasets.append(dataset)
+
+ return datasets
+
+
+def setup_datasets(cfg, verbose=False, concat_modes=('train', 'mixed'), stack=True):
+ """
+ Set multiple datasets up for training/inference
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+ verbose : Bool
+ Print information on screen
+ concat_modes : String
+ Which dataset modes are going to be concatenated into a single one
+ stack : Bool
+ Whether datasets are stacked together
+
+ Returns
+ -------
+ datasets : Dict
+ Datasets ready to go
+ datasets_cfg : Dict
+ Dataset configurations
+ """
+ if verbose:
+ print0(pcolor('#' * 60, 'green'))
+ print0(pcolor('### Preparing Datasets', 'green'))
+ print0(pcolor('#' * 60, 'green'))
+
+ font = {'color': 'yellow', 'attrs': ('bold', 'dark')}
+
+ datasets_cfg = {}
+ for key in cfg.__dict__.keys():
+ datasets_cfg[key] = cfg.__dict__[key]
+ for mode in ['train', 'validation']:
+ if key.startswith(mode) and key != mode and mode in cfg.__dict__.keys():
+ datasets_cfg[key] = to_namespace(merge_dict(deepcopy(
+ cfg.__dict__[mode].__dict__), cfg.__dict__[key].__dict__))
+
+ datasets = {}
+ for key, val in list(datasets_cfg.items()):
+ if 'name' in val.__dict__.keys():
+ if verbose:
+ print0(pcolor('###### {}'.format(key), **font))
+ datasets[key] = setup_dataset(val, verbose=verbose)
+ datasets_cfg[key] = [datasets_cfg[key]] * len(datasets[key])
+ for mode in concat_modes:
+ if key.startswith(mode) and len(datasets[key]) > 1:
+ datasets[key] = ConcatDataset(datasets[key])
+ else:
+ datasets_cfg.pop(key)
+
+ if stack:
+ datasets = stack_datasets(datasets)
+
+ modes = ['train', 'mixed', 'validation', 'test']
+ reduced_datasets_cfg = {key: [] for key in modes}
+ for key, val in datasets_cfg.items():
+ for mode in modes:
+ if key.startswith(mode):
+ reduced_datasets_cfg[mode].append(val)
+ for key in list(reduced_datasets_cfg.keys()):
+ reduced_datasets_cfg[key] = flatten(reduced_datasets_cfg[key])
+ if len(reduced_datasets_cfg[key]) == 0:
+ reduced_datasets_cfg.pop(key)
+ datasets_cfg = reduced_datasets_cfg
+
+ if 'train' in datasets_cfg:
+ datasets_cfg['train'] = datasets_cfg['train'][0]
+
+ return datasets, datasets_cfg
+
+
+def setup_metrics(cfg):
+ """
+ Set metrics up for evaluation
+
+ Parameters
+ ----------
+ cfg : Config
+ Configuration file
+
+ Returns
+ -------
+ tasks : Dict
+ Dictionary containing metric classes for requested tasks
+ """
+
+ methods = {
+ 'depth': DepthEvaluation,
+ }
+
+    available_tasks = [key for key in cfg.__dict__.keys() if key != 'tasks']
+ requested_tasks = cfg_has(cfg, 'tasks', available_tasks)
+ tasks = [task for task in available_tasks if task in requested_tasks and task in methods]
+
+ return {task: methods[task](cfg.__dict__[task]) for task in tasks}
+
+
+def worker_init_fn(worker_id):
+ """Function to initialize workers"""
+ time_seed = np.array(time.time(), dtype=np.int32)
+ np.random.seed(time_seed + worker_id)
+
+
+def get_datasampler(dataset, shuffle):
+ """Return distributed data sampler"""
+ return torch.utils.data.distributed.DistributedSampler(
+ dataset, shuffle=shuffle,
+ num_replicas=world_size(), rank=rank())
+
+
+def no_collate(batch):
+ """Dummy function to use when dataset is not to be collated"""
+ return batch
+
+
+@iterate1
+def setup_dataloader(dataset, cfg, mode):
+ """
+ Create a dataloader class
+
+ Parameters
+ ----------
+    dataset : Dataset
+        List of datasets from which to create dataloaders
+    cfg : Config
+        Model configuration (cf. configs/default_config.py)
+    mode : String {'train', 'validation', 'test'}
+        Mode from which we want the dataloader
+
+ Returns
+ -------
+ dataloaders : list[Dataloader]
+ List of created dataloaders for each input dataset
+ """
+ ddp = dist_mode() == 'ddp'
+ shuffle = 'train' in mode
+ return DataLoader(dataset,
+ batch_size=cfg_has(cfg, 'batch_size', 1),
+ pin_memory=cfg_has(cfg, 'pin_memory', True),
+ num_workers=cfg_has(cfg, 'num_workers', 8),
+ worker_init_fn=worker_init_fn,
+ shuffle=False if ddp else shuffle,
+ sampler=get_datasampler(dataset, shuffle=shuffle) if ddp else None,
+ collate_fn=None if cfg_has(cfg, 'collate', True) else no_collate,
+ )
+
+
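+# Illustrative sketch tying the helpers above together: build datasets from a
+# configuration and wrap each split in a dataloader. The `cfg` object is a
+# hypothetical configuration namespace assumed to expose `datasets` and
+# `dataloader` sections; the exact layout depends on the configuration files.
+def _example_setup_dataloaders(cfg):
+    """Minimal sketch: datasets -> dataloaders for each available split."""
+    datasets, datasets_cfg = setup_datasets(cfg.datasets, verbose=True)
+    dataloaders = {mode: setup_dataloader(datasets[mode], cfg.dataloader, mode)
+                   for mode in datasets.keys()}
+    return dataloaders, datasets_cfg
+
+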
+def reduce(data, modes, train_modes):
+ """
+ Reduce dictionary values
+
+ Parameters
+ ----------
+ data : Dict
+ Dictionary with data for reduction
+ modes : String
+ Data mode ('train', 'validation', 'test')
+ train_modes : list[String]
+ Which modes are training modes
+
+ Returns
+ -------
+ reduced : Dict
+ Dictionary with reduced information
+ """
+ reduced = {
+ mode: flatten([val for key, val in data.items() if mode in key])
+ for mode in modes
+ }
+ for key, val in list(reduced.items()):
+ if len(val) == 0:
+ reduced.pop(key)
+ for mode in keys_in(reduced, train_modes):
+ reduced[mode] = reduced[mode][0]
+ return reduced
+
+
+def stack_datasets(datasets):
+ """
+ Stack datasets together for training/validation
+
+ Parameters
+ ----------
+ datasets : Dict
+ Dictionary containing datasets
+
+ Returns
+ -------
+    stacked_datasets : Dict
+ Dictionary containing stacked datasets
+ """
+ all_modes = ['train', 'mixed', 'validation', 'test']
+ train_modes = ['train', 'mixed']
+
+ stacked_datasets = OrderedDict()
+
+ for mode in all_modes:
+ stacked_datasets[mode] = []
+ for key, val in datasets.items():
+ if mode in key:
+ stacked_datasets[mode].append(val)
+ stacked_datasets[mode] = flatten(stacked_datasets[mode])
+
+ for mode in train_modes:
+ length = len(stacked_datasets[mode])
+ if length == 1:
+ stacked_datasets[mode] = stacked_datasets[mode][0]
+ elif length > 1:
+ stacked_datasets[mode] = ConcatDataset(stacked_datasets[mode])
+ for key in list(datasets.keys()):
+ if key.startswith(mode) and key != mode:
+ datasets.pop(key)
+
+ return stacked_datasets
diff --git a/vidar/utils/tensor.py b/vidar/utils/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..497404b1287cfa9b4acb067912cd11d07aa78c29
--- /dev/null
+++ b/vidar/utils/tensor.py
@@ -0,0 +1,315 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from functools import reduce
+
+import torch
+import torch.nn.functional as tfn
+
+from vidar.utils.decorators import iterate1
+from vidar.utils.types import is_tensor, is_dict, is_seq
+
+
+@iterate1
+def interpolate(tensor, size, scale_factor, mode, align_corners):
+ """
+ Interpolate a tensor to a different resolution
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ Input tensor [B,?,H,W]
+ size : Tuple
+ Interpolation size (H,W)
+ scale_factor : Float
+ Scale factor for interpolation
+ mode : String
+ Interpolation mode
+ align_corners : Bool
+ Corner alignment flag
+
+ Returns
+ -------
+ tensor : torch.Tensor
+ Interpolated tensor [B,?,h,w]
+ """
+ if is_tensor(size):
+ size = size.shape[-2:]
+ return tfn.interpolate(
+ tensor, size=size, scale_factor=scale_factor,
+ mode=mode, align_corners=align_corners, recompute_scale_factor=False,
+ )
+
+
+def masked_average(loss, mask, eps=1e-7):
+ """Calculates the average of a tensor considering mask information"""
+ return (loss * mask).sum() / (mask.sum() + eps)
+
+
+def multiply_mask(data, mask):
+ """Multiplies a tensor with a mask"""
+ return data if (data is None or mask is None) else data * mask
+
+
+def multiply_args(*args):
+ """Multiplies input arguments"""
+ valids = [v for v in args if v is not None]
+ return None if not valids else reduce((lambda x, y: x * y), valids)
+
+
+def grid_sample(tensor, grid, padding_mode, mode, align_corners):
+ return tfn.grid_sample(tensor, grid,
+ padding_mode=padding_mode, mode=mode, align_corners=align_corners)
+
+
+def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False):
+ """
+ Creates a pixel grid for image operations
+
+ Parameters
+ ----------
+ hw : Tuple
+ Height/width of the grid
+ b : Int
+ Batch size
+ with_ones : Bool
+ Stack an extra channel with 1s
+ device : String
+ Device where the grid will be created
+ normalize : Bool
+ Whether the grid is normalized between [-1,1]
+
+ Returns
+ -------
+ grid : torch.Tensor
+ Output pixel grid [B,2,H,W]
+ """
+ if is_tensor(hw):
+ b, hw = hw.shape[0], hw.shape[-2:]
+ if is_tensor(device):
+ device = device.device
+ hi, hf = 0, hw[0] - 1
+ wi, wf = 0, hw[1] - 1
+ yy, xx = torch.meshgrid([torch.linspace(hi, hf, hw[0], device=device),
+ torch.linspace(wi, wf, hw[1], device=device)], indexing='ij')
+ if with_ones:
+ grid = torch.stack([xx, yy, torch.ones(hw, device=device)], 0)
+ else:
+ grid = torch.stack([xx, yy], 0)
+ if b is not None:
+ grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)
+ if normalize:
+ grid = norm_pixel_grid(grid)
+ return grid
+
+
+def norm_pixel_grid(grid, hw=None, in_place=False):
+ """
+    Normalize a pixel grid to be between [-1,1]
+
+ Parameters
+ ----------
+ grid : torch.Tensor
+ Grid to be normalized [B,2,H,W]
+ hw : Tuple
+ Height/Width for normalization
+ in_place : Bool
+ Whether the operation is done in place or not
+
+ Returns
+ -------
+ grid : torch.Tensor
+ Normalized grid [B,2,H,W]
+ """
+ if hw is None:
+ hw = grid.shape[-2:]
+ if not in_place:
+ grid = grid.clone()
+ grid[:, 0] = 2.0 * grid[:, 0] / (hw[1] - 1) - 1.0
+ grid[:, 1] = 2.0 * grid[:, 1] / (hw[0] - 1) - 1.0
+ return grid
+
+
+def unnorm_pixel_grid(grid, hw=None, in_place=False):
+ """
+ Unnormalize pixel grid to be between [0,H] and [0,W]
+
+ Parameters
+ ----------
+ grid : torch.Tensor
+ Grid to be normalized [B,2,H,W]
+ hw : Tuple
+ Height/width for unnormalization
+ in_place : Bool
+ Whether the operation is done in place or not
+
+ Returns
+ -------
+ grid : torch.Tensor
+ Unnormalized grid [B,2,H,W]
+ """
+ if hw is None:
+ hw = grid.shape[-2:]
+ if not in_place:
+ grid = grid.clone()
+ grid[:, 0] = 0.5 * (hw[1] - 1) * (grid[:, 0] + 1)
+ grid[:, 1] = 0.5 * (hw[0] - 1) * (grid[:, 1] + 1)
+ return grid
+
+
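+# Illustrative round trip between the grid helpers above. The height/width and
+# batch size are arbitrary; the normalized grid is in the [-1,1] range used by
+# torch.nn.functional.grid_sample (after permuting to [B,H,W,2]).
+def _example_pixel_grid_roundtrip():
+    """Minimal sketch: create a pixel grid, normalize it to [-1,1], and undo it."""
+    grid = pixel_grid((4, 6), b=2)           # [2,2,4,6] pixel coordinates
+    normed = norm_pixel_grid(grid)           # [2,2,4,6] values in [-1,1]
+    restored = unnorm_pixel_grid(normed)     # back to pixel coordinates
+    return torch.allclose(grid, restored)
+
+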
+def match_scales(image, targets, num_scales,
+ mode='bilinear', align_corners=True):
+ """
+ Creates multiple resolution versions of the input to match another list of tensors
+
+ Parameters
+ ----------
+ image : torch.Tensor
+ Input image [B,?,H,W]
+ targets : list[torch.Tensor]
+ Target resolutions
+ num_scales : int
+ Number of scales to consider
+ mode : String
+ Interpolation mode
+ align_corners : Bool
+ Corner alignment flag
+
+ Returns
+ -------
+ images : list[torch.Tensor]
+ List containing tensors in the required resolutions
+ """
+ # For all scales
+ images = []
+ image_shape = image.shape[-2:]
+ for i in range(num_scales):
+        target_shape = targets[i].shape[-2:]
+ # If image shape is equal to target shape
+ if same_shape(image_shape, target_shape):
+ images.append(image)
+ else:
+ # Otherwise, interpolate
+ images.append(interpolate_image(
+ image, target_shape, mode=mode, align_corners=align_corners))
+ # Return scaled images
+ return images
+
+
+def cat_channel_ones(tensor, n=1):
+ """
+ Concatenate tensor with an extra channel of ones
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ Tensor to be concatenated
+ n : Int
+        Dimension along which the channel of ones is concatenated
+
+ Returns
+ -------
+ cat_tensor : torch.Tensor
+ Concatenated tensor
+ """
+ # Get tensor shape with 1 channel
+ shape = list(tensor.shape)
+ shape[n] = 1
+ # Return concatenation of tensor with ones
+ return torch.cat([tensor, torch.ones(shape,
+ device=tensor.device, dtype=tensor.dtype)], n)
+
+
+def same_shape(shape1, shape2):
+ """Checks if two shapes are the same"""
+ if len(shape1) != len(shape2):
+ return False
+ for i in range(len(shape1)):
+ if shape1[i] != shape2[i]:
+ return False
+ return True
+
+
+def interpolate_image(image, shape=None, scale_factor=None, mode='bilinear',
+ align_corners=True, recompute_scale_factor=False):
+ """
+ Interpolate an image to a different resolution
+
+ Parameters
+ ----------
+ image : torch.Tensor
+ Image to be interpolated [B,?,h,w]
+ shape : torch.Tensor or tuple
+ Output shape [H,W]
+ scale_factor : Float
+ Scale factor for output shape
+ mode : String
+ Interpolation mode
+ align_corners : Bool
+ True if corners will be aligned after interpolation
+ recompute_scale_factor : Bool
+ True if scale factor is recomputed
+
+ Returns
+ -------
+ image : torch.Tensor
+ Interpolated image [B,?,H,W]
+ """
+ assert shape is not None or scale_factor is not None, 'Invalid option for interpolate_image'
+ if mode == 'nearest':
+ align_corners = None
+ # Take last two dimensions as shape
+ if shape is not None:
+ if is_tensor(shape):
+ shape = shape.shape
+ if len(shape) > 2:
+ shape = shape[-2:]
+ # If the shapes are the same, do nothing
+ if same_shape(image.shape[-2:], shape):
+ return image
+ # Interpolate image to match the shape
+ return tfn.interpolate(image, size=shape, scale_factor=scale_factor,
+ mode=mode, align_corners=align_corners,
+ recompute_scale_factor=recompute_scale_factor)
+
+
+def check_assert(pred, gt, atol=1e-5, rtol=1e-5):
+ """
+ Check two dictionaries with allclose assertions
+
+ Parameters
+ ----------
+ pred : Dict
+ Dictionary with predictions
+ gt : Dict
+ Dictionary with ground-truth
+ atol : Float
+ Absolute tolerance
+ rtol : Float
+ Relative tolerance
+ """
+ for key in gt.keys():
+ if key in pred.keys():
+ # assert key in pred and key in gt
+ if is_dict(pred[key]):
+ check_assert(pred[key], gt[key])
+ elif is_seq(pred[key]):
+ for val1, val2 in zip(pred[key], gt[key]):
+ if is_tensor(val1):
+ assert torch.allclose(val1, val2, atol=atol, rtol=rtol), \
+ f'Assert error in {key} : {val1.mean().item()} x {val2.mean().item()}'
+ else:
+ assert val1 == val2, \
+ f'Assert error in {key} : {val1} x {val2}'
+ else:
+ if is_tensor(pred[key]):
+ assert torch.allclose(pred[key], gt[key], atol=atol, rtol=rtol), \
+ f'Assert error in {key} : {pred[key].mean().item()} x {gt[key].mean().item()}'
+ else:
+ assert pred[key] == gt[key], \
+ f'Assert error in {key} : {pred[key]} x {gt[key]}'
+
+
+def interleave(data, b):
+ """Interleave data considering multiple batches"""
+ data_interleave = data.unsqueeze(1).expand(-1, b, *data.shape[1:])
+ return data_interleave.reshape(-1, *data.shape[1:])
diff --git a/vidar/utils/types.py b/vidar/utils/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..853a54cb3bc15d4c8281c925b3f79f9bfe6bda85
--- /dev/null
+++ b/vidar/utils/types.py
@@ -0,0 +1,61 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+from argparse import Namespace
+
+import numpy as np
+import torch
+
+
+def is_numpy(data):
+ """Checks if data is a numpy array."""
+ return isinstance(data, np.ndarray)
+
+
+def is_tensor(data):
+ """Checks if data is a torch tensor."""
+ return type(data) == torch.Tensor
+
+
+def is_tuple(data):
+ """Checks if data is a tuple."""
+ return isinstance(data, tuple)
+
+
+def is_list(data):
+ """Checks if data is a list."""
+ return isinstance(data, list)
+
+
+def is_double_list(data):
+ """Checks if data is a double list (list of lists)"""
+ return is_list(data) and len(data) > 0 and is_list(data[0])
+
+
+def is_dict(data):
+ """Checks if data is a dictionary."""
+ return isinstance(data, dict)
+
+
+def is_str(data):
+ """Checks if data is a string."""
+ return isinstance(data, str)
+
+
+def is_int(data):
+ """Checks if data is an integer."""
+ return isinstance(data, int)
+
+
+def is_seq(data):
+ """Checks if data is a list or tuple."""
+ return is_tuple(data) or is_list(data)
+
+
+def is_namespace(data):
+ """Check if data is a Namespace"""
+ return isinstance(data, Namespace)
+
+
+def exists(data):
+ """Check if data exists (it is not None)"""
+ return data is not None
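
These predicates are typically composed into small recursive helpers. The sketch below uses a hypothetical `to_numpy_any` function (not part of the module) to show one such pattern.

```python
# Hypothetical helper built on the predicates above (illustrative only).
from vidar.utils.types import is_tensor, is_seq, is_dict

def to_numpy_any(data):
    """Recursively convert tensors inside nested lists/tuples/dicts to numpy."""
    if is_tensor(data):
        return data.detach().cpu().numpy()
    if is_seq(data):
        return type(data)(to_numpy_any(d) for d in data)
    if is_dict(data):
        return {key: to_numpy_any(val) for key, val in data.items()}
    return data  # anything else (including None) is returned untouched
```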
diff --git a/vidar/utils/viz.py b/vidar/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..74e331b1c4d325e5b294cd0f3043490b6c1bc161
--- /dev/null
+++ b/vidar/utils/viz.py
@@ -0,0 +1,247 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import flow_vis
+import numpy as np
+import torch
+from matplotlib.cm import get_cmap
+
+from vidar.utils.decorators import iterate1
+from vidar.utils.depth import depth2inv
+from vidar.utils.types import is_tensor, is_list
+
+
+def flow_to_color(flow_uv, clip_flow=None):
+ """
+ Calculate color from optical flow
+
+ Parameters
+ ----------
+ flow_uv : np.Array
+ Optical flow [H,W,2]
+ clip_flow : Float
+ Clipping value for optical flow
+
+ Returns
+ -------
+ colors : np.array
+ Optical flow colormap [H,W,3]
+ """
+ # Clip if requested
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
+ # Get optical flow channels
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+ # Calculate maximum flow magnitude (radius)
+ rad_max = np.sqrt(2) * clip_flow if clip_flow is not None else \
+ np.max(np.sqrt(np.square(u) + np.square(v)))
+ # Normalize optical flow channels
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ # Return colormap [0,1]
+ return flow_vis.flow_uv_to_colors(u, v, convert_to_bgr=False) / 255
+
+
+@iterate1
+@iterate1
+def viz_inv_depth(inv_depth, normalizer=None, percentile=95,
+ colormap='plasma', filter_zeros=False):
+ """
+ Converts an inverse depth map to a colormap for visualization.
+
+ Parameters
+ ----------
+ inv_depth : torch.Tensor
+ Inverse depth map to be converted [B,1,H,W]
+ normalizer : Float
+ Value for inverse depth map normalization
+ percentile : Float
+ Percentile value for automatic normalization
+ colormap : String
+ Colormap to be used
+ filter_zeros : Bool
+ If True, do not consider zero values during normalization
+
+ Returns
+ -------
+ colormap : np.Array [H,W,3]
+ Colormap generated from the inverse depth map
+ """
+ if is_list(inv_depth):
+ return [viz_inv_depth(
+ inv[0], normalizer, percentile, colormap, filter_zeros)
+ for inv in inv_depth]
+ # If a tensor is provided, convert to numpy
+ if is_tensor(inv_depth):
+ # If it has a batch size, use first one
+ if len(inv_depth.shape) == 4:
+ inv_depth = inv_depth[0]
+ # Squeeze if depth channel exists
+ if len(inv_depth.shape) == 3:
+ inv_depth = inv_depth.squeeze(0)
+ inv_depth = inv_depth.detach().cpu().numpy()
+ cm = get_cmap(colormap)
+ if normalizer is None:
+ if (inv_depth > 0).sum() == 0:
+ normalizer = 1.0
+ else:
+ normalizer = np.percentile(
+ inv_depth[inv_depth > 0] if filter_zeros else inv_depth, percentile)
+ inv_depth = inv_depth / (normalizer + 1e-6)
+ colormap = cm(np.clip(inv_depth, 0., 1.0))[:, :, :3]
+ colormap[inv_depth == 0] = 0
+ return colormap
+
+
+@iterate1
+@iterate1
+def viz_depth(depth, *args, **kwargs):
+ """Same as viz_inv_depth, but takes depth as input instead"""
+ return viz_inv_depth(depth2inv(depth), *args, **kwargs)
+
+
+@iterate1
+@iterate1
+def viz_normals(normals):
+ """
+ Converts normals map to a colormap for visualization.
+
+ Parameters
+ ----------
+ normals : torch.Tensor
+ Normals map to be converted [3,H,W]
+
+ Returns
+ -------
+ colormap : np.Array
+ Colormap generated from the normals map [H,W,3]
+ """
+ # If a tensor is provided, convert to numpy
+ if is_tensor(normals):
+ normals = normals.permute(1, 2, 0).detach().cpu().numpy()
+ return (normals + 1) / 2
+
+
+@iterate1
+@iterate1
+def viz_optical_flow(optflow, clip_value=100.):
+ """
+ Returns a colorized version of an optical flow map
+
+ Parameters
+ ----------
+ optflow : torch.Tensor
+ Optical flow to be colorized (NOT in batch) [2,H,W]
+ clip_value : Float
+ Optical flow clip value for visualization
+
+ Returns
+ -------
+ colorized : np.Array
+ Colorized version of the input optical flow [H,W,3]
+ """
+ # If a list is provided, colorize each entry
+ if is_list(optflow):
+ return [viz_optical_flow(opt[0]) for opt in optflow]
+ # If a tensor is provided, convert to numpy
+ if is_tensor(optflow):
+ if len(optflow.shape) == 4:
+ optflow = optflow[0]
+ optflow = optflow.permute(1, 2, 0).detach().cpu().numpy()
+ # Return colorized optical flow
+ return flow_to_color(optflow, clip_flow=clip_value)
+
+
+@iterate1
+@iterate1
+def viz_photo(photo, colormap='viridis', normalize=False):
+ """
+ Returns a colorized version of the photometric loss
+
+ Parameters
+ ----------
+ photo : torch.Tensor
+ Per-pixel photometric error
+ colormap : String
+ Which colormap to use
+ normalize : Bool
+ Whether the photometric error should be normalized between [0,1]
+
+ Returns
+ -------
+ colorized : np.Array
+ Colorized version of the photometric error [H,W,3]
+ """
+ if is_tensor(photo):
+ if len(photo.shape) == 4:
+ photo = photo[0]
+ if len(photo.shape) == 3:
+ photo = photo.squeeze(0)
+ photo = photo.detach().cpu().numpy()
+ cm = get_cmap(colormap)
+ if normalize:
+ photo -= photo.min()
+ photo /= photo.max()
+ colormap = cm(np.clip(photo, 0., 1.0))[:, :, :3]
+ colormap[photo == 0] = 0
+ return colormap
+
+
+@iterate1
+@iterate1
+def viz_semantic(semantic, ontology):
+ """
+ Returns a colorized version of a semantic map
+
+ Parameters
+ ----------
+ semantic : torch.Tensor
+ Semantic map to be colorized [B,1,H,W]
+ ontology : Dict
+ Dictionary mapping between class and color
+
+ Returns
+ -------
+ colorized : np.Array
+ Colorized version of the semantic map [H,W,3]
+ """
+ # If it is a tensor, cast to numpy
+ if is_tensor(semantic):
+ if semantic.dim() == 3:
+ semantic = semantic.squeeze(0)
+ semantic = semantic.detach().cpu().numpy()
+ # Create and populate color map
+ color = np.zeros((semantic.shape[0], semantic.shape[1], 3))
+ for key in ontology.keys():
+ key_color = np.array(ontology[key]['color'])
+ if is_tensor(key_color):
+ key_color = key_color.detach().cpu().numpy()
+ color[semantic == int(key)] = key_color / 255.
+ # Return colored semantic map
+ return color
+
+
+@iterate1
+@iterate1
+def viz_camera(camera):
+ """
+ Returns a colorized version of a camera viewing rays
+
+ Parameters
+ ----------
+ camera : Camera or torch.Tensor
+ Input camera or viewing rays
+
+ Returns
+ -------
+ colorized : np.Array
+ Colorized version of the camera viewing rays [H,W,3]
+ """
+ if is_tensor(camera):
+ # If it's a tensor, reshape it
+ rays = camera[-3:].permute(1, 2, 0).detach().cpu().numpy()
+ else:
+ # If it's a camera, get viewing rays
+ rays = camera.no_translation().get_viewdirs(normalize=True, flatten=False, to_world=True)
+ rays = rays[0].permute(1, 2, 0).detach().cpu().numpy()
+ return (rays + 1) / 2
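
A hedged usage sketch for the visualization helpers above. The `@iterate1` decorator (defined elsewhere in the repository) is assumed to pass plain tensors through unchanged, so single-sample tensors are used here; shapes and value ranges are placeholders.

```python
# Illustrative only.
import numpy as np
import torch
from vidar.utils.viz import viz_inv_depth, viz_optical_flow

inv_depth = torch.rand(1, 1, 192, 640)      # [B,1,H,W] inverse depth
flow = 10. * torch.randn(1, 2, 192, 640)    # [B,2,H,W] optical flow

depth_rgb = viz_inv_depth(inv_depth)        # np.array [H,W,3] in [0,1]
flow_rgb = viz_optical_flow(flow)           # np.array [H,W,3] in [0,1]

# Convert to uint8 for logging or saving
depth_u8 = (depth_rgb * 255).astype(np.uint8)
flow_u8 = (flow_rgb * 255).astype(np.uint8)
```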
diff --git a/vidar/utils/volume.py b/vidar/utils/volume.py
new file mode 100644
index 0000000000000000000000000000000000000000..9792bb5d59e51e1de389c8996a0b6351f8831dc4
--- /dev/null
+++ b/vidar/utils/volume.py
@@ -0,0 +1,146 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import numpy as np
+import torch
+
+from vidar.geometry.camera import Camera
+from vidar.utils.tensor import grid_sample
+from vidar.utils.types import is_tensor
+
+
+def warp_bins(rgb, cam, bins):
+ """
+ Warp an image based on depth bins
+
+ Parameters
+ ----------
+ rgb : torch.Tensor [B,?,H,W]
+ Input image for warping
+ cam : Camera
+ Input camera
+ bins : torch.Tensor
+ Depth bins for warping
+
+ Returns
+ -------
+ warped : torch.Tensor
+ Warped images for each depth bin
+ """
+ ones = torch.ones((1, *cam.hw)).to(rgb.device)
+ volume = torch.stack([depth * ones for depth in bins], 1)
+ coords_volume = cam.coords_from_cost_volume(volume)
+ return grid_sample(
+ rgb.repeat(len(bins), 1, 1, 1), coords_volume[0].type(rgb.dtype),
+ padding_mode='zeros', mode='bilinear', align_corners=True)
+
+
+def sample(grid, pred):
+ """
+ Sample a grid based on predictions
+
+ Parameters
+ ----------
+ grid : torch.Tensor
+ Grid to be sampled [B,?,H,W]
+ pred : torch.Tensor
+ Coordinate predictions [B,2,H,W]
+
+ Returns
+ -------
+ values : torch.Tensor
+ Sampled grid [B,?,H,W]
+ """
+ n, _, h, w = grid.shape
+ coords = pred.permute(1, 2, 0).reshape(-1, 1, 1, 1).repeat(1, 1, 1, 2)
+ coords = 2 * coords / (n - 1) - 1
+ grid = grid.permute(2, 3, 0, 1).reshape(-1, 1, n, 1).repeat(1, 1, 1, 2)
+ values = grid_sample(grid, coords,
+ padding_mode='zeros', mode='bilinear', align_corners=True)
+ return values.reshape(h, w, 1, 1).permute(2, 3, 0, 1)
+
+
+def compute_depth_bin(min_depth, max_depth, num_bins, i):
+ """
+ Calculate a single SID depth bin
+
+ Parameters
+ ----------
+ min_depth : Float
+ Minimum depth value
+ max_depth : Float
+ Maximum depth value
+ num_bins : Int
+ Number of depth bins
+ i : Int or torch.Tensor
+ Index of the depth bin in the interval (the trailing clamp expects a tensor result)
+
+ Returns
+ -------
+ bin : torch.Tensor
+ Corresponding depth bin
+ """
+ return torch.exp(np.log(min_depth) + np.log(max_depth / min_depth) * i / (num_bins - 1)).\
+ clamp(min=min_depth, max=max_depth)
+
+
+def uncompute_depth_bin(min_depth, max_depth, num_bins, depth):
+ """
+ Recover the SID bin index from a depth value
+
+ Parameters
+ ----------
+ min_depth : Float
+ Minimum depth value
+ max_depth : Float
+ Maximum depth value
+ num_bins : Int
+ Number of depth bins
+ depth : torch.Tensor
+ Depth value
+
+ Returns
+ -------
+ index : torch.Tensor
+ Index for the depth value in the SID interval
+ """
+ return (num_bins - 1) * ((torch.log(depth) - np.log(min_depth)) /
+ np.log(max_depth / min_depth)).clamp(min=0, max=num_bins)
+
+
+def compute_depth_bins(min_depth, max_depth, num_bins, mode):
+ """
+ Compute depth bins for an interval
+
+ Parameters
+ ----------
+ min_depth : Float
+ Minimum depth value
+ max_depth : Float
+ Maximum depth value
+ num_bins : Int
+ Number of depth bins
+ mode : String
+ Depth discretization mode
+
+ Returns
+ -------
+ bins : torch.Tensor
+ Discretized depth bins
+ """
+ if is_tensor(min_depth):
+ min_depth = min_depth.detach().cpu()
+ if is_tensor(max_depth):
+ max_depth = max_depth.detach().cpu()
+ if mode == 'inverse':
+ depth_bins = 1. / np.linspace(
+ 1. / max_depth, 1. / min_depth, num_bins)[::-1]
+ elif mode == 'linear':
+ depth_bins = np.linspace(
+ min_depth, max_depth, num_bins)
+ elif mode == 'sid':
+ depth_bins = np.array(
+ [np.exp(np.log(min_depth) + np.log(max_depth / min_depth) * i / (num_bins - 1))
+ for i in range(num_bins)])
+ else:
+ raise NotImplementedError
+ return torch.from_numpy(depth_bins).float()
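
A short sketch of the depth discretization helpers above, under the assumption that `compute_depth_bin` receives its index as a tensor (its `.clamp` call expects a tensor result); the depth range and bin count are placeholders.

```python
# Illustrative only.
import torch
from vidar.utils.volume import compute_depth_bin, compute_depth_bins, uncompute_depth_bin

min_depth, max_depth, num_bins = 1.0, 80.0, 64

# Discretize [min_depth, max_depth] with the three supported spacings
sid_bins = compute_depth_bins(min_depth, max_depth, num_bins, mode='sid')      # torch.Size([64])
lin_bins = compute_depth_bins(min_depth, max_depth, num_bins, mode='linear')
inv_bins = compute_depth_bins(min_depth, max_depth, num_bins, mode='inverse')

# Round-trip a few SID indices: index -> depth -> index
idx = torch.tensor([0., 31., 63.])
depths = compute_depth_bin(min_depth, max_depth, num_bins, idx)
recovered = uncompute_depth_bin(min_depth, max_depth, num_bins, depths)  # ~ [0., 31., 63.]
```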
diff --git a/vidar/utils/write.py b/vidar/utils/write.py
new file mode 100644
index 0000000000000000000000000000000000000000..89a10398e81f6d2d5f7be965df9f6b4000b7783f
--- /dev/null
+++ b/vidar/utils/write.py
@@ -0,0 +1,132 @@
+# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
+
+import os
+import pickle as pkl
+
+import cv2
+import numpy as np
+import torchvision.transforms as transforms
+
+from vidar.utils.decorators import multi_write
+from vidar.utils.types import is_tensor, is_numpy
+
+
+def create_folder(filename):
+ """Create a new folder if it doesn't exist"""
+ if '/' in filename:
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+
+def write_pickle(filename, data):
+ """
+ Write a pickle file
+
+ Parameters
+ ----------
+ filename : String
+ File where the pickle file will be saved
+ data : Value
+ Data to be saved
+ """
+ create_folder(filename)
+ if not filename.endswith('.pkl'):
+ filename = filename + '.pkl'
+ pkl.dump(data, open(filename, 'wb'))
+
+
+def write_npz(filename, data):
+ """
+ Write a numpy compressed file
+
+ Parameters
+ ----------
+ filename : String
+ File where the numpy file will be saved
+ data : Value
+ Data to be saved
+ """
+ np.savez_compressed(filename, **data)
+
+
+@multi_write
+@multi_write
+def write_image(filename, image):
+ """
+ Write an image to file
+
+ Parameters
+ ----------
+ filename : String
+ File where image will be saved
+ image : np.Array [H,W,3]
+ RGB image
+ """
+ # Create folder if it doesn't exist
+ create_folder(filename)
+ # If image is a tensor
+ if is_tensor(image):
+ if len(image.shape) == 4:
+ image = image[0]
+ image = image.detach().cpu().numpy().transpose(1, 2, 0)
+ cv2.imwrite(filename, image[:, :, ::-1] * 255)
+ # If image is a numpy array
+ elif is_numpy(image):
+ cv2.imwrite(filename, image[:, :, ::-1] * 255)
+ # Otherwise, assume it's a PIL image
+ else:
+ image.save(filename)
+
+
+@multi_write
+def write_depth(filename, depth, intrinsics=None):
+ """
+ Write a depth map to file, and optionally its corresponding intrinsics.
+
+ Parameters
+ ----------
+ filename : String
+ File where depth map will be saved (.npz or .png)
+ depth : np.Array or torch.Tensor
+ Depth map [H,W]
+ intrinsics : np.Array
+ Optional camera intrinsics matrix [3,3]
+ """
+ # If depth is a tensor
+ if is_tensor(depth):
+ depth = depth.detach().squeeze().cpu().numpy()
+ # If intrinsics is a tensor
+ if is_tensor(intrinsics):
+ intrinsics = intrinsics.detach().cpu().numpy()
+ # If we are saving as a .npz
+ if filename.endswith('.npz'):
+ np.savez_compressed(filename, depth=depth, intrinsics=intrinsics)
+ # If we are saving as a .png
+ elif filename.endswith('.png'):
+ depth = transforms.ToPILImage()((depth * 256).astype(np.int32))
+ depth.save(filename)
+ # Something is wrong
+ else:
+ raise NotImplementedError('Depth filename not valid.')
+
+
+@multi_write
+def write_optical_flow(filename, optflow):
+ """
+ Write an optical flow map to file.
+
+ Parameters
+ ----------
+ filename : String
+ File where the optical flow map will be saved (.npz)
+ optflow : np.Array or torch.Tensor
+ Optical flow map [H,W]
+ """
+ # If optical flow is a tensor
+ if is_tensor(optflow):
+ optflow = optflow.detach().squeeze().cpu().numpy()
+ # If we are saving as a .npz
+ if filename.endswith('.npz'):
+ np.savez_compressed(filename, optflow=optflow)
+ # Something is wrong
+ else:
+ raise NotImplementedError('Optical flow filename not valid.')
\ No newline at end of file
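
Finally, a hedged end-to-end sketch of the writers above; filenames and array contents are placeholders, and the `@multi_write` decorator is assumed to pass single inputs through unchanged.

```python
# Illustrative only.
import numpy as np
from vidar.utils.write import write_image, write_depth, write_pickle

rgb = np.random.rand(192, 640, 3)         # RGB image in [0,1], [H,W,3]
depth = 80. * np.random.rand(192, 640)    # depth map in meters, [H,W]
K = np.array([[720., 0., 320.],
              [0., 720., 96.],
              [0., 0., 1.]])

# write_image creates the output folder; the .npz writers rely on it already existing
write_image('outputs/sample/rgb.png', rgb)
write_depth('outputs/sample/depth.npz', depth, intrinsics=K)
write_pickle('outputs/sample/meta', {'frame': 0})   # '.pkl' is appended automatically
```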