alexlau commited on
Commit
19677a1
·
1 Parent(s): faeefa7

first deploy demo

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea
2
+ __pycache__
3
+ models
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import streamlit as st
4
+ from google_drive_downloader import GoogleDriveDownloader as gdd
5
+
6
+ from demo.src.models import load_trained_model
7
+ from demo.src.utils import render_predict_from_pose, predict_to_image
8
+ from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
9
+
10
+
11
+ if not os.path.isfile('models'):
12
+ model_path = os.path.join(MODEL_DIR, MODEL_NAME)
13
+ gdd.download_file_from_google_drive(file_id=FILE_ID,
14
+ dest_path=model_path,
15
+ unzip=False)
16
+ print(f'model downloaded from google drive: {model_path}')
17
+
18
+
19
+ @st.cache(show_spinner=False, allow_output_mutation=True)
20
+ def fetch_model():
21
+ model, state = load_trained_model(MODEL_DIR, MODEL_NAME)
22
+ return model, state
23
+
24
+
25
+ model, state = fetch_model()
26
+ pi = math.pi
27
+ st.set_page_config(page_title="DietNeRF Demo")
28
+ st.sidebar.header('SELECT YOUR VIEW DIRECTION')
29
+ theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
30
+ step=0.5, value=0.)
31
+ phi = st.sidebar.slider("Phi", min_value=0., max_value=0.5*pi,
32
+ step=0.1, value=1.)
33
+ radius = st.sidebar.slider("Radius", min_value=2., max_value=6.,
34
+ step=1., value=3.)
35
+
36
+
37
+ pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
38
+ im = predict_to_image(pred_color)
39
+
40
+ st.image(im, use_column_width=False)
demo/__init__.py ADDED
File without changes
demo/src/__init__.py ADDED
File without changes
demo/src/config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # for downloading model from google drive
2
+ FILE_ID = "1iytA1n2z4go3uVCwE__vIKouTKyIDjEq"
3
+ MODEL_DIR = './models'
4
+ MODEL_NAME = 'trained_model'
5
+
6
+
7
+ class NerfConfig:
8
+ # MODEL CONFIG
9
+ model = "nerf"
10
+ net_activation = "relu"
11
+ rgb_activation = "sigmoid"
12
+ sigma_activation = "relu"
13
+ min_deg_point = 0
14
+ max_deg_point = 10
15
+ deg_view = 4
16
+ # reduce num_coarse_samples, num_fine_samples for speedup
17
+ num_coarse_samples = 32
18
+ num_fine_samples = 64
19
+ use_viewdirs = True
20
+ near = 2
21
+ far = 6
22
+ noise_std = None
23
+ # TODO @Alex: set white_bkgd as flag if we add LLFF dataset
24
+ white_bkgd = True
25
+ net_depth = 8
26
+ net_width = 256
27
+ net_depth_condition = 1
28
+ net_width_condition = 128
29
+ skip_layer = 4
30
+ num_rgb_channels = 3
31
+ num_sigma_channels = 1
32
+ lindisp = True
33
+ legacy_posenc_order = False
34
+ randomized = True
35
+
36
+ # DATA CONFIG
37
+ W = 800
38
+ H = 800
39
+ IMAGE_SHAPE = (W, H, 3)
40
+ # TODO @Alex: flexible focal if we add LLFF dataset
41
+ FOCAL = 555.5555155968841
42
+ # reduce CHUNK if OOM
43
+ CHUNK = 4096
44
+ DOWNSAMPLE = 2
demo/src/models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import flax
3
+ from jax import random
4
+ from flax.training import checkpoints
5
+
6
+ from jaxnerf.nerf import models
7
+ from jaxnerf.nerf import utils
8
+ from demo.src.config import NerfConfig
9
+
10
+ rng = random.PRNGKey(0)
11
+ # TODO @Alex: make image size flexible if needed
12
+ dummy_rays = random.uniform(rng, shape=NerfConfig.IMAGE_SHAPE)
13
+ dummy_batch = {"rays": utils.Rays(dummy_rays, dummy_rays, dummy_rays)}
14
+ dummy_lr = 1e-2
15
+
16
+
17
+ def load_trained_model(model_dir, model_fn):
18
+ model, init_variables = init_model()
19
+ optimizer = flax.optim.Adam(dummy_lr).create(init_variables)
20
+ state = utils.TrainState(optimizer=optimizer)
21
+ del optimizer, init_variables
22
+ assert os.path.isfile(os.path.join(model_dir, model_fn))
23
+ state = checkpoints.restore_checkpoint(model_dir, state,
24
+ prefix=model_fn)
25
+ return model, state
26
+
27
+
28
+ def init_model():
29
+ _, key = random.split(rng)
30
+ model, init_variables = models.get_model(key, dummy_batch,
31
+ NerfConfig)
32
+ return model, init_variables
33
+
34
+
35
+ if __name__ == '__main__':
36
+ _model_dir = '../ship_fewshot_wsc'
37
+ _model_fn = 'checkpoint_345000'
38
+ _model, _state = load_trained_model(_model_dir, _model_fn)
demo/src/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import jax
3
+ from jax import random
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ from jaxnerf.nerf import clip_utils
8
+ from jaxnerf.nerf import utils
9
+ from demo.src.config import NerfConfig
10
+ from demo.src.models import init_model
11
+
12
+ model, _ = init_model()
13
+
14
+
15
+ def render_predict_from_pose(state, theta, phi, radius):
16
+ rng = random.PRNGKey(0)
17
+ partial_render_fn = partial(render_pfn, state.optimizer.target)
18
+ rays = _render_rays_from_pose(theta, phi, radius)
19
+ pred_color, pred_disp, _ = utils.render_image(
20
+ partial_render_fn, rays,
21
+ rng, False, chunk=NerfConfig.CHUNK)
22
+ return pred_color, pred_disp
23
+
24
+
25
+ def predict_to_image(pred_out):
26
+ image_arr = np.array(np.clip(pred_out, 0., 1.) * 255.).astype(np.uint8)
27
+ return Image.fromarray(image_arr)
28
+
29
+
30
+ def _render_rays_from_pose(theta, phi, radius):
31
+ camtoworld = np.array(clip_utils.pose_spherical(theta, phi, radius))
32
+ rays = _camtoworld_matrix_to_rays(camtoworld)
33
+ return rays
34
+
35
+
36
+ def _camtoworld_matrix_to_rays(camtoworld):
37
+ """ render one instance of rays given a camera to world matrix (4, 4) """
38
+ pixel_center = 0.
39
+ w, h = NerfConfig.W, NerfConfig.H
40
+ focal, downsample = NerfConfig.FOCAL, NerfConfig.DOWNSAMPLE
41
+ x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
42
+ np.arange(0, w, downsample, dtype=np.float32) + pixel_center, # X-Axis (columns)
43
+ np.arange(0, h, downsample, dtype=np.float32) + pixel_center, # Y-Axis (rows)
44
+ indexing="xy")
45
+ camera_dirs = np.stack([(x - w * 0.5) / focal,
46
+ -(y - h * 0.5) / focal,
47
+ -np.ones_like(x)],
48
+ axis=-1)
49
+ directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
50
+ origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
51
+ viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
52
+ return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
53
+
54
+
55
+ def _render_fn(variables, key_0, key_1, rays):
56
+ return jax.lax.all_gather(model.apply(
57
+ variables, key_0, key_1, rays, False),
58
+ axis_name="batch")
59
+
60
+
61
+ render_pfn = jax.pmap(_render_fn, in_axes=(None, None, None, 0),
62
+ donate_argnums=3, axis_name="batch")
jaxnerf/README.md ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # JaxNeRF
2
+
3
+ This is a [JAX](https://github.com/google/jax) implementation of
4
+ [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://www.matthewtancik.com/nerf).
5
+ This code is created and maintained by
6
+ [Boyang Deng](https://boyangdeng.com/),
7
+ [Jon Barron](https://jonbarron.info/),
8
+ and [Pratul Srinivasan](https://people.eecs.berkeley.edu/~pratul/).
9
+
10
+ <div align="center">
11
+ <img width="95%" alt="NeRF Teaser" src="https://raw.githubusercontent.com/bmild/nerf/master/imgs/pipeline.jpg">
12
+ </div>
13
+
14
+ Our JAX implementation currently supports:
15
+
16
+ <table class="tg">
17
+ <thead>
18
+ <tr>
19
+ <th class="tg-0lax"><span style="font-weight:bold">Platform</span></th>
20
+ <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Single-Host GPU</span></th>
21
+ <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Multi-Device TPU</span></th>
22
+ </tr>
23
+ </thead>
24
+ <tbody>
25
+ <tr>
26
+ <td class="tg-0lax"><span style="font-weight:bold">Type</span></td>
27
+ <td class="tg-0lax">Single-Device</td>
28
+ <td class="tg-0lax">Multi-Device</td>
29
+ <td class="tg-0lax">Single-Host</td>
30
+ <td class="tg-0lax">Multi-Host</td>
31
+ </tr>
32
+ <tr>
33
+ <td class="tg-0lax"><span style="font-weight:bold">Training</span></td>
34
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
35
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
36
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
37
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
38
+ </tr>
39
+ <tr>
40
+ <td class="tg-0lax"><span style="font-weight:bold">Evaluation</span></td>
41
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
42
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
43
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
44
+ <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
45
+ </tr>
46
+ </tbody>
47
+ </table>
48
+
49
+ The training job on 128 TPUv2 cores can be done in **2.5 hours (v.s 3 days for TF
50
+ NeRF)** for 1 million optimization steps. In other words, JaxNeRF trains to the best while trains very fast.
51
+
52
+ As for inference speed, here are the statistics of rendering an image with
53
+ 800x800 resolution (numbers are averaged over 50 rendering passes):
54
+
55
+ | Platform | 1 x NVIDIA V100 | 8 x NVIDIA V100 | 128 x TPUv2 |
56
+ |----------|:---------------:|:-----------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------:|
57
+ | TF NeRF | 27.74 secs | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> |
58
+ | JaxNeRF | 20.77 secs | 2.65 secs | 0.35 secs |
59
+
60
+
61
+ The code is tested and reviewed carefully to match the
62
+ [original TF NeRF implementation](https://github.com/bmild/nerf).
63
+ If you have any issues using this code, please do not open an issue as the repo
64
+ is shared by all projects under Google Research. Instead, just email
65
+ jaxnerf@google.com.
66
+
67
+ ## Installation
68
+ We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set
69
+ up the environment. Run the following commands:
70
+
71
+ ```
72
+ # Clone the repo
73
+ svn export https://github.com/google-research/google-research/trunk/jaxnerf
74
+ # Create a conda environment, note you can use python 3.6-3.8 as
75
+ # one of the dependencies (TensorFlow) hasn't supported python 3.9 yet.
76
+ conda create --name jaxnerf python=3.6.12; conda activate jaxnerf
77
+ # Prepare pip
78
+ conda install pip; pip install --upgrade pip
79
+ # Install requirements
80
+ pip install -r jaxnerf/requirements.txt
81
+ # [Optional] Install GPU and TPU support for Jax
82
+ # Remember to change cuda101 to your CUDA version, e.g. cuda110 for CUDA 11.0.
83
+ pip install --upgrade jax jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
84
+ ```
85
+
86
+ Then, you'll need to download the datasets
87
+ from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
88
+ Please download the `nerf_synthetic.zip` and `nerf_llff_data.zip` and unzip them
89
+ in the place you like. Let's assume they are placed under `/tmp/jaxnerf/data/`.
90
+
91
+ That's it for installation. You're good to go. **Notice:** For the following instructions, you don't need to enter the jaxnerf folder. Just stay in the parent folder.
92
+
93
+ ## Two Commands for Everything
94
+
95
+ ```
96
+ bash jaxnerf/train.sh demo /tmp/jaxnerf/data
97
+ bash jaxnerf/eval.sh demo /tmp/jaxnerf/data
98
+ ```
99
+
100
+ Once both jobs are done running (which may take a while if you only have 1 GPU
101
+ or CPU), you'll have a folder, `/tmp/jaxnerf/data/demo`, with:
102
+
103
+ * Trained NeRF models for all scenes in the blender dataset.
104
+ * Rendered images and depth maps for all test views.
105
+ * The collected PSNRs of all scenes in a TXT file.
106
+
107
+ Note that we used the `demo` config here which is basically the `blender` config
108
+ in the paper except smaller batch size and much less train steps. Of course, you
109
+ can use other configs to replace `demo` and other data locations to replace
110
+ `/tmp/jaxnerf/data`.
111
+
112
+ We provide 2 configurations in the folder `configs` which match the original
113
+ configurations used in the paper for the blender dataset and the LLFF dataset.
114
+ Be careful when you use them. Their batch sizes are large so you may get OOM error if you have limited resources, for example, 1 GPU with small memory. Also, they have many many train steps so you may need days to finish training all scenes.
115
+
116
+ ## Play with One Scene
117
+
118
+ You can also train NeRF on only one scene. The easiest way is to use given configs:
119
+
120
+ ```
121
+ python -m jaxnerf.train \
122
+ --data_dir=/PATH/TO/YOUR/SCENE/DATA \
123
+ --train_dir=/PATH/TO/THE/PLACE/YOU/WANT/TO/SAVE/CHECKPOINTS \
124
+ --config=configs/CONFIG_YOU_LIKE
125
+ ```
126
+
127
+ Evaluating NeRF on one scene is similar:
128
+
129
+ ```
130
+ python -m jaxnerf.eval \
131
+ --data_dir=/PATH/TO/YOUR/SCENE/DATA \
132
+ --train_dir=/PATH/TO/THE/PLACE/YOU/SAVED/CHECKPOINTS \
133
+ --config=configs/CONFIG_YOU_LIKE \
134
+ --chunk=4096
135
+ ```
136
+
137
+ The `chunk` parameter defines how many rays are feed to the model in one go.
138
+ We recommend you to use the largest value that fits to your device's memory but
139
+ small values are fine, only a bit slow.
140
+
141
+ You can also define your own configurations by passing command line flags. Please refer to the `define_flags` function in `nerf/utils.py` for all the flags and their meanings.
142
+
143
+ **Note**: For the ficus scene in the blender dataset, we noticed that it's sensible to different initializations,
144
+ e.g. using different random seeds, if using the original learning rate schedule in the paper.
145
+ Therefore, we provide a simple tweak (turned off by default) for more stable trainings: using `lr_delay_steps` and `lr_delay_mult`.
146
+ This allows the training to start from a smaller learning rate (`lr_init` * `lr_delay_mult`) in the first `lr_delay_steps`.
147
+ We didn't use them for our pretrained models
148
+ but we tested `lr_delay_steps=5000` with `lr_delay_mult=0.2` and it works quite smoothly.
149
+
150
+ ## Pretrained Models
151
+
152
+ We provide a collection of pretrained NeRF models that match the numbers
153
+ reported in the [paper](https://arxiv.org/abs/2003.08934). Actually, ours are
154
+ slightly better overall because we trained for more iterations (while still
155
+ being much faster!). You can find our pretrained models
156
+ [here](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip).
157
+ The performances (in PSNR) of our pretrained NeRF models are listed below:
158
+
159
+ ### Blender
160
+
161
+
162
+ | Scene | Chair | Drums | Ficus | Hotdog | Lego | Materials | Mic | Ship | Mean |
163
+ |---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
164
+ | TF NeRF | 33.00 | 25.01 | 30.13 | 36.18 | 32.54 | 29.62 | 32.91 | 28.65 | 31.01 |
165
+ | JaxNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** |
166
+
167
+ ### LLFF
168
+
169
+ | Scene | Room | Fern | Leaves | Fortress | Orchids | Flower | T-Rex | Horns | Mean |
170
+ |---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
171
+ | TF NeRF | 32.70 | **25.17** | 20.92 | 31.16 | **20.36** | 27.40 | 26.80 | 27.45 | 26.50 |
172
+ | JaxNeRF | **33.04** | 24.83 | **21.23** | **31.76** | 20.27 | **28.07** | **27.42** | **28.10** | **26.84** |
173
+
174
+ ## Citation
175
+ If you use this software package, please cite it as:
176
+
177
+ ```
178
+ @software{jaxnerf2020github,
179
+ author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
180
+ title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
181
+ url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
182
+ version = {0.0},
183
+ year = {2020},
184
+ }
185
+ ```
186
+
187
+ and also cite the original NeRF paper:
188
+
189
+ ```
190
+ @inproceedings{mildenhall2020nerf,
191
+ title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
192
+ author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
193
+ year={2020},
194
+ booktitle={ECCV},
195
+ }
196
+ ```
197
+
198
+ ## Acknowledgement
199
+ We'd like to thank
200
+ [Daniel Duckworth](http://www.stronglyconvex.com/),
201
+ [Dan Gnanapragasam](https://research.google/people/DanGnanapragasam/),
202
+ and [James Bradbury](https://twitter.com/jekbradbury)
203
+ for their help on reviewing and optimizing this code.
204
+ We'd like to also thank the amazing [JAX](https://github.com/google/jax) team for
205
+ very insightful and helpful discussions on how to use JAX for NeRF.
jaxnerf/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
jaxnerf/configs/blender.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 4096
9
+ randomized: true
jaxnerf/configs/demo.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 50000
jaxnerf/configs/diet_nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 500000
11
+ print_every: 100
12
+ render_every: 5000
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_factor: 4
18
+ sc_loss_every: 16
19
+ sc_loss_mult: 10
20
+ few_shot: 8
jaxnerf/configs/diet_nerf_tpu_vm_test.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 500000
11
+ print_every: 100
12
+ render_every: 5000
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_factor: 4
18
+ sc_loss_every: 16
19
+ sc_loss_mult: 10
20
+ few_shot: -1
jaxnerf/configs/eval_diet_nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 500000
11
+ print_every: 100
12
+ render_every: 5000
13
+ save_every: 5000
14
+ use_semantic_loss: true
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_factor: 4
18
+ sc_loss_every: 16
19
+ sc_loss_mult: 10
20
+ few_shot: 8
21
+ spherify: True
22
+ lindisp: True
jaxnerf/configs/llff.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: llff
2
+ batching: all_images
3
+ num_coarse_samples: 64
4
+ num_fine_samples: 128
5
+ use_viewdirs: true
6
+ white_bkgd: false
7
+ batch_size: 4096
8
+ randomized: true
9
+ near: 0.
10
+ far: 1.
11
+ factor: 4
12
+ llffhold: 8
13
+ noise_std: 1.
jaxnerf/configs/llff_360.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: llff
2
+ batching: all_images
3
+ num_coarse_samples: 64
4
+ num_fine_samples: 128
5
+ use_viewdirs: true
6
+ white_bkgd: false
7
+ batch_size: 4096
8
+ randomized: true
9
+ near: 0.2
10
+ far: 100.
11
+ factor: 8
12
+ llffhold: 8
13
+ noise_std: 1.
14
+ spherify: True
15
+ lindisp: True
jaxnerf/configs/nerf_tpu_vm_few_shot.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 500000
11
+ print_every: 100
12
+ render_every: 5000
13
+ save_every: 5000
14
+ use_semantic_loss: false
15
+ clip_model_name: openai/clip-vit-base-patch32
16
+ clip_output_dtype: float32
17
+ sc_loss_factor: 4
18
+ sc_loss_every: 16
19
+ sc_loss_mult: 10
20
+ few_shot: 8
jaxnerf/configs/orig_nerf_tpu_vm_full.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 100000
11
+ print_every: 1000
12
+ render_every: 5000
13
+ save_every: 5000
jaxnerf/configs/orig_nerf_tpu_vm_test.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset: blender
2
+ batching: single_image
3
+ factor: 0
4
+ num_coarse_samples: 64
5
+ num_fine_samples: 128
6
+ use_viewdirs: true
7
+ white_bkgd: true
8
+ batch_size: 1024
9
+ randomized: true
10
+ max_steps: 5000
11
+ print_every: 100
12
+ render_every: 500
13
+ save_every: 500
jaxnerf/eval.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Evaluation script for Nerf."""
18
+ import functools
19
+ from os import path
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ import flax
24
+ from flax.metrics import tensorboard
25
+ from flax.training import checkpoints
26
+ import jax
27
+ from jax import random
28
+ import numpy as np
29
+ import tensorflow as tf
30
+ import tensorflow_hub as tf_hub
31
+ #import wandb
32
+ import glob
33
+ import cv2
34
+ import os
35
+
36
+ from jaxnerf.nerf import datasets
37
+ from jaxnerf.nerf import models
38
+ from jaxnerf.nerf import utils
39
+
40
+ FLAGS = flags.FLAGS
41
+
42
+ utils.define_flags()
43
+
44
+ #LPIPS_TFHUB_PATH = "@neural-rendering/lpips/distance/1"
45
+
46
+
47
+ def compute_lpips(image1, image2, model):
48
+ """Compute the LPIPS metric."""
49
+ # The LPIPS model expects a batch dimension.
50
+ return model(
51
+ tf.convert_to_tensor(image1[None, Ellipsis]),
52
+ tf.convert_to_tensor(image2[None, Ellipsis]))[0]
53
+
54
+
55
+ def main(unused_argv):
56
+ # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
57
+ # LPIPS computation or dataset loading.
58
+ tf.config.experimental.set_visible_devices([], "GPU")
59
+ tf.config.experimental.set_visible_devices([], "TPU")
60
+
61
+ #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
62
+
63
+ rng = random.PRNGKey(20200823)
64
+
65
+ if FLAGS.config is not None:
66
+ utils.update_flags(FLAGS)
67
+ if FLAGS.train_dir is None:
68
+ raise ValueError("train_dir must be set. None set now.")
69
+ if FLAGS.data_dir is None:
70
+ raise ValueError("data_dir must be set. None set now.")
71
+
72
+ dataset = datasets.get_dataset("test", FLAGS)
73
+ rng, key = random.split(rng)
74
+ model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
75
+ optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
76
+ state = utils.TrainState(optimizer=optimizer)
77
+ del optimizer, init_variables
78
+
79
+ #lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)
80
+
81
+ # Rendering is forced to be deterministic even if training was randomized, as
82
+ # this eliminates "speckle" artifacts.
83
+ def render_fn(variables, key_0, key_1, rays):
84
+ return jax.lax.all_gather(
85
+ model.apply(variables, key_0, key_1, rays, False), axis_name="batch")
86
+
87
+ # pmap over only the data input.
88
+ render_pfn = jax.pmap(
89
+ render_fn,
90
+ in_axes=(None, None, None, 0),
91
+ donate_argnums=3,
92
+ axis_name="batch",
93
+ )
94
+
95
+ # Compiling to the CPU because it's faster and more accurate.
96
+ ssim_fn = jax.jit(
97
+ functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
98
+
99
+ last_step = 0
100
+ out_dir = path.join(FLAGS.train_dir,
101
+ "path_renders" if FLAGS.render_path else "test_preds")
102
+ if not FLAGS.eval_once:
103
+ summary_writer = tensorboard.SummaryWriter(
104
+ path.join(FLAGS.train_dir, "eval"))
105
+ while True:
106
+ state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
107
+ step = int(state.optimizer.state.step)
108
+ if step <= last_step:
109
+ continue
110
+ if FLAGS.save_output and (not utils.isdir(out_dir)):
111
+ utils.makedirs(out_dir)
112
+ psnr_values = []
113
+ ssim_values = []
114
+ #lpips_values = []
115
+ if not FLAGS.eval_once:
116
+ showcase_index = np.random.randint(0, dataset.size)
117
+ for idx in range(dataset.sizerender_image):
118
+ print(f"Evaluating {idx + 1}/{dataset.size}")
119
+ batch = next(dataset)
120
+ pred_color, pred_disp, pred_acc = utils.render_image(
121
+ functools.partial(render_pfn, state.optimizer.target),
122
+ batch["rays"],
123
+ rng,
124
+ FLAGS.dataset == "llff",
125
+ chunk=FLAGS.chunk)
126
+ if jax.host_id() != 0: # Only record via host 0.
127
+ continue
128
+ if not FLAGS.eval_once and idx == showcase_index:
129
+ showcase_color = pred_color
130
+ showcase_disp = pred_disp
131
+ showcase_acc = pred_acc
132
+ if not FLAGS.render_path:
133
+ showcase_gt = batch["pixels"]
134
+ if not FLAGS.render_path:
135
+ psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
136
+ ssim = ssim_fn(pred_color, batch["pixels"])
137
+ #lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
138
+ print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
139
+ psnr_values.append(float(psnr))
140
+ ssim_values.append(float(ssim))
141
+ #lpips_values.append(float(lpips))
142
+ if FLAGS.save_output:
143
+ utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
144
+ utils.save_img(pred_disp[Ellipsis, 0],
145
+ path.join(out_dir, "disp_{:03d}.png".format(idx)))
146
+ if (not FLAGS.eval_once) and (jax.host_id() == 0):
147
+ summary_writer.image("pred_color", showcase_color, step)
148
+ summary_writer.image("pred_disp", showcase_disp, step)
149
+ summary_writer.image("pred_acc", showcase_acc, step)
150
+ if not FLAGS.render_path:
151
+ summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
152
+ summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
153
+ #summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
154
+ summary_writer.image("target", showcase_gt, step)
155
+ if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
156
+ with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
157
+ f.write(" ".join([str(v) for v in psnr_values]))
158
+ with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
159
+ f.write(" ".join([str(v) for v in ssim_values]))
160
+ #with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
161
+ #f.write(" ".join([str(v) for v in lpips_values]))
162
+ with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
163
+ f.write("{}".format(np.mean(np.array(psnr_values))))
164
+ with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
165
+ f.write("{}".format(np.mean(np.array(ssim_values))))
166
+ #with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
167
+ #f.write("{}".format(np.mean(np.array(lpips_values))))
168
+ imglist = glob.glob(os.path.join(out_dir, "[0-9][0-9][0-9].png"))
169
+ sorted_files = sorted(imglist, key=lambda x: int(x.split('/')[-1].split('.')[0]))
170
+ imglist2 = glob.glob(os.path.join(out_dir, "disp_[0-9][0-9][0-9].png"))
171
+ sorted_files2 = sorted(imglist2, key=lambda x: int(x.split('/')[-1].split('.')[0].split('_')[-1]))
172
+ fourcc = cv2.VideoWriter_fourcc(*'MP4V')
173
+ fps = 10.0
174
+ out = cv2.VideoWriter(os.path.join(out_dir, "rendering_video.mp4"), fourcc, fps,
175
+ (2 * img.shape[1], img.shape[0]))
176
+
177
+ for i in range(len(imglist)):
178
+ img = cv2.imread(imglist[i], cv2.IMREAD_COLOR)
179
+ img2 = cv2.imread(imglist2[i], cv2.IMREAD_COLOR)
180
+ catimg = np.concatenate((img, img2), axis=1)
181
+ out.write(catimg)
182
+
183
+ out.release()
184
+ if FLAGS.eval_once:
185
+ break
186
+ if int(step) >= FLAGS.max_steps:
187
+ break
188
+ last_step = step
189
+
190
+
191
+ if __name__ == "__main__":
192
+ app.run(main)
jaxnerf/eval.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/bin/bash
16
+ CONFIG=$1
17
+ DATA_ROOT=$2
18
+ ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
19
+ if [ $CONFIG == "llff" ]
20
+ then
21
+ SCENES="room fern leaves fortress orchids flower trex horns"
22
+ DATA_FOLDER="nerf_llff_data"
23
+ else
24
+ SCENES="lego chair drums ficus hotdog materials mic ship"
25
+ DATA_FOLDER="nerf_synthetic"
26
+ fi
27
+
28
+ # launch evaluation jobs for all scenes.
29
+ for scene in $SCENES; do
30
+ python -m jaxnerf.eval \
31
+ --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
32
+ --train_dir="$ROOT_DIR"/"$scene" \
33
+ --chunk=4096 \
34
+ --config=configs/"$CONFIG"
35
+ done
36
+
37
+ # collect PSNR of all scenes.
38
+ touch "$ROOT_DIR"/psnr.txt
39
+ for scene in $SCENES; do
40
+ printf "${scene}: " >> "$ROOT_DIR"/psnr.txt
41
+ cat "$ROOT_DIR"/"$scene"/test_preds/psnr.txt >> \
42
+ "$ROOT_DIR"/psnr.txt
43
+ printf $'\n' >> "$ROOT_DIR"/psnr.txt
44
+ done
jaxnerf/example_data/imgs/r_0.png ADDED
jaxnerf/example_data/transforms_test.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
jaxnerf/example_data/transforms_train.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}
jaxnerf/nerf/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
jaxnerf/nerf/clip_utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+ from absl import flags
4
+ from functools import partial
5
+
6
+ import jax
7
+ from jax import random
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ from transformers import FlaxCLIPModel
11
+
12
+ FLAGS = flags.FLAGS
13
+ # import jmp
14
+ # my_policy = jmp.Policy(compute_dtype=np.float16,
15
+ # param_dtype=np.float16,
16
+ # output_dtype=np.float16)
17
+
18
+
19
+ @partial(jax.jit, static_argnums=[0, 1])
20
+ def update_semantic_loss(model, clip_model, rng, state, batch, lr):
21
+ # the batch is without shard
22
+ random_rays = batch["random_rays"]
23
+ #rng, key_0, key_1 = rng
24
+ rng, key_0, key_1 = random.split(rng,3)
25
+
26
+ def semantic_loss(variables):
27
+ # TODO @Alex: (alt) sample less along a ray/ sample on a strided grid (make change on model call)
28
+ # TODO @Alex: (alt) apply mixed precision
29
+ src_ret = model.apply(variables, key_0, key_1, random_rays, False)
30
+ src_image, _, _ = src_ret[-1]
31
+ # reshape flat pixel to an image (assume 3 channels & square shape)
32
+ w = int(math.sqrt(src_image.shape[0]))
33
+ src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
34
+ src_image = preprocess_for_CLIP(src_image)
35
+ src_embedding = clip_model.get_image_features(pixel_values=src_image)
36
+ src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
37
+ src_embedding = jnp.array(src_embedding)
38
+ target_embedding = batch["embedding"]
39
+ sc_loss = 0.5 * FLAGS.sc_loss_mult * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
40
+ return sc_loss * 1e-2
41
+
42
+ sc_loss, grad = jax.value_and_grad(semantic_loss)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
43
+ return sc_loss, grad
44
+
45
+ def trans_t(t):
46
+ return jnp.array([
47
+ [1, 0, 0, 0],
48
+ [0, 1, 0, 0],
49
+ [0, 0, 1, t],
50
+ [0, 0, 0, 1]], dtype=jnp.float32)
51
+
52
+
53
+ def rot_phi(phi):
54
+ return jnp.array([
55
+ [1, 0, 0, 0],
56
+ [0, jnp.cos(phi), -np.sin(phi), 0],
57
+ [0, jnp.sin(phi), jnp.cos(phi), 0],
58
+ [0, 0, 0, 1]], dtype=jnp.float32)
59
+
60
+
61
+ def rot_theta(th):
62
+ return jnp.array([
63
+ [np.cos(th), 0, -np.sin(th), 0],
64
+ [0, 1, 0, 0],
65
+ [np.sin(th), 0, jnp.cos(th), 0],
66
+ [0, 0, 0, 1]], dtype=jnp.float32)
67
+
68
+
69
+ def pose_spherical(theta, phi, radius):
70
+ c2w = trans_t(radius)
71
+ c2w = rot_phi(phi / 180. * jnp.pi) @ c2w
72
+ c2w = rot_theta(theta / 180. * jnp.pi) @ c2w
73
+ c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
74
+ return c2w
75
+
76
+
77
+ def random_pose(rng, bds):
78
+ rng, *rng_inputs = jax.random.split(rng, 3)
79
+ radius = random.uniform(rng_inputs[1], minval=bds[0], maxval=bds[1])
80
+ theta = random.uniform(rng_inputs[1], minval=0, maxval=2 * jnp.pi)
81
+ phi = random.uniform(rng_inputs[1], minval=0, maxval=np.pi / 2)
82
+ return pose_spherical(radius, theta, phi)
83
+
84
+
85
+ def preprocess_for_CLIP(image):
86
+ """
87
+ jax-based preprocessing for CLIP
88
+ image [B, 3, H, W]: batch image
89
+ return [B, 3, 224, 224]: pre-processed image for CLIP
90
+ """
91
+ B, D, H, W = image.shape
92
+ image = jax.image.resize(image, (B, D, 224, 224), 'bicubic') # assume that images have rectangle shape.
93
+ mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
94
+ std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
95
+ image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
96
+ return image
97
+
98
+
99
+ # TODO @Alex: VisionModel v.s. original CLIP? (differ by a projection matrix)
100
+ def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
101
+ if dtype == 'float16':
102
+ dtype = jnp.float16
103
+ elif dtype == 'float32':
104
+ dtype = jnp.float32
105
+ else:
106
+ raise ValueError
107
+
108
+ if model_name is None:
109
+ model_name = 'openai/clip-vit-base-patch32'
110
+ return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
111
+
112
+
113
+ # def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
114
+ # """
115
+ # target_emb [1, D]: pre-computed target embedding vector \phi(I)
116
+ # source_img [1, 3, H, W]: source image \hat{I}
117
+ # l: loss weight lambda
118
+ # return: SC_loss
119
+ # """
120
+ # # _,H,W,D = rays.shape
121
+ # rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
122
+ # (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
123
+ # _, H, W, _ = rays.shape
124
+ # source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
125
+ # np.reshape(rays, (2, -1, 3)),
126
+ # bds[0], bds[1], 1, rand=False),
127
+ # 0, 1)
128
+ # # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
129
+ # source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
130
+ # source_img = preprocess_for_CLIP(source_img)
131
+ # source_emb = CLIP_model.get_image_features(pixel_values=source_img)
132
+ # source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
133
+ # return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])
134
+
jaxnerf/nerf/datasets.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Different datasets implementation plus a general port for all the datasets."""
18
+ INTERNAL = False # pylint: disable=g-statement-before-imports
19
+ import json
20
+ import os
21
+ from os import path
22
+ import queue
23
+ import threading
24
+
25
+ if not INTERNAL:
26
+ import cv2 # pylint: disable=g-import-not-at-top
27
+ import jax
28
+ import numpy as np
29
+ from PIL import Image
30
+
31
+ from jaxnerf.nerf import utils
32
+ from jaxnerf.nerf import clip_utils
33
+
34
+ def get_dataset(split, args, clip_model = None):
35
+ return dataset_dict[args.dataset](split, args, clip_model)
36
+
37
+
38
+ def convert_to_ndc(origins, directions, focal, w, h, near=1.):
39
+ """Convert a set of rays to NDC coordinates."""
40
+ # Shift ray origins to near plane
41
+ t = -(near + origins[..., 2]) / directions[..., 2]
42
+ origins = origins + t[..., None] * directions
43
+
44
+ dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
45
+ ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))
46
+
47
+ # Projection
48
+ o0 = -((2 * focal) / w) * (ox / oz)
49
+ o1 = -((2 * focal) / h) * (oy / oz)
50
+ o2 = 1 + 2 * near / oz
51
+
52
+ d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
53
+ d1 = -((2 * focal) / h) * (dy / dz - oy / oz)
54
+ d2 = -2 * near / oz
55
+
56
+ origins = np.stack([o0, o1, o2], -1)
57
+ directions = np.stack([d0, d1, d2], -1)
58
+ return origins, directions
59
+
60
+
61
+ class Dataset(threading.Thread):
62
+ """Dataset Base Class."""
63
+
64
+ def __init__(self, split, flags, clip_model):
65
+ super(Dataset, self).__init__()
66
+ self.queue = queue.Queue(3) # Set prefetch buffer to 3 batches.
67
+ self.daemon = True
68
+ self.use_pixel_centers = flags.use_pixel_centers
69
+ self.split = split
70
+
71
+ if split == "train":
72
+ self._train_init(flags, clip_model)
73
+ elif split == "test":
74
+ self._test_init(flags)
75
+ else:
76
+ raise ValueError(
77
+ "the split argument should be either \"train\" or \"test\", set"
78
+ "to {} here.".format(split))
79
+ self.batch_size = flags.batch_size // jax.process_count()
80
+ self.batching = flags.batching
81
+ self.render_path = flags.render_path
82
+ self.far = flags.far
83
+ self.near = flags.near
84
+ self.max_steps = flags.max_steps
85
+ self.sc_loss_factor = flags.sc_loss_factor
86
+ self.start()
87
+
88
+ def __iter__(self):
89
+ return self
90
+
91
+ def __next__(self):
92
+ """Get the next training batch or test example.
93
+
94
+ Returns:
95
+ batch: dict, has "pixels" and "rays".
96
+ """
97
+ x = self.queue.get()
98
+ if self.split == "train":
99
+ return utils.shard(x)
100
+ else:
101
+ return utils.to_device(x)
102
+
103
+ def peek(self):
104
+ """Peek at the next training batch or test example without dequeuing it.
105
+
106
+ Returns:
107
+ batch: dict, has "pixels" and "rays".
108
+ """
109
+ x = self.queue.queue[0].copy() # Make a copy of the front of the queue.
110
+ if self.split == "train":
111
+ return utils.shard(x)
112
+ else:
113
+ return utils.to_device(x)
114
+
115
+ def run(self):
116
+ if self.split == "train":
117
+ next_func = self._next_train
118
+ else:
119
+ next_func = self._next_test
120
+ while True:
121
+ self.queue.put(next_func())
122
+
123
+ @property
124
+ def size(self):
125
+ return self.n_examples
126
+
127
+ def _train_init(self, flags, clip_model):
128
+ """Initialize training."""
129
+ self._load_renderings(flags, clip_model)
130
+ self._generate_rays()
131
+
132
+ if flags.batching == "all_images":
133
+ # flatten the ray and image dimension together.
134
+ self.images = self.images.reshape([-1, 3])
135
+ self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
136
+ self.rays)
137
+ elif flags.batching == "single_image":
138
+ self.images = self.images.reshape([-1, self.resolution, 3])
139
+ self.rays = utils.namedtuple_map(
140
+ lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
141
+ else:
142
+ raise NotImplementedError(
143
+ f"{flags.batching} batching strategy is not implemented.")
144
+
145
+ def _test_init(self, flags):
146
+ self._load_renderings(flags, clip_model = None)
147
+ self._generate_rays()
148
+ self.it = 0
149
+
150
+ def _next_train(self):
151
+ """Sample next training batch."""
152
+
153
+ if self.batching == "all_images":
154
+ ray_indices = np.random.randint(0, self.rays[0].shape[0],
155
+ (self.batch_size,))
156
+ batch_pixels = self.images[ray_indices]
157
+ batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
158
+ raise NotImplementedError("image_index not implemented for batching=all_images")
159
+
160
+ elif self.batching == "single_image":
161
+ image_index = np.random.randint(0, self.n_examples, ())
162
+ ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
163
+ (self.batch_size,))
164
+ batch_pixels = self.images[image_index][ray_indices]
165
+ batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
166
+ self.rays)
167
+ else:
168
+ raise NotImplementedError(
169
+ f"{self.batching} batching strategy is not implemented.")
170
+ return {"pixels": batch_pixels, "rays": batch_rays, "image_index": image_index}
171
+
172
+ def _next_test(self):
173
+ """Sample next test example."""
174
+ idx = self.it
175
+ self.it = (self.it + 1) % self.n_examples
176
+
177
+ if self.render_path:
178
+ return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
179
+ else:
180
+ return {"pixels": self.images[idx],
181
+ "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
182
+ "image_index": idx}
183
+
184
+ # TODO(bydeng): Swap this function with a more flexible camera model.
185
+ def _generate_rays(self):
186
+ """Generating rays for all images."""
187
+ pixel_center = 0.5 if self.use_pixel_centers else 0.0
188
+ x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
189
+ np.arange(self.w, dtype=np.float32) + pixel_center, # X-Axis (columns)
190
+ np.arange(self.h, dtype=np.float32) + pixel_center, # Y-Axis (rows)
191
+ indexing="xy")
192
+ camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
193
+ -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
194
+ axis=-1)
195
+ directions = ((camera_dirs[None, ..., None, :] *
196
+ self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
197
+ origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
198
+ directions.shape)
199
+ viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
200
+ self.rays = utils.Rays(
201
+ origins=origins, directions=directions, viewdirs=viewdirs)
202
+
203
+ def camtoworld_matrix_to_rays(self, camtoworld, downsample = 1):
204
+ """ render one instance of rays given a camera to world matrix (4, 4) """
205
+ pixel_center = 0.5 if self.use_pixel_centers else 0.0
206
+ # TODO @Alex: apply mesh downsampling here
207
+ x, y = np.meshgrid( # pylint: disable=unbalanced-tuple-unpacking
208
+ np.arange(self.w, step = downsample, dtype=np.float32) + pixel_center, # X-Axis (columns)
209
+ np.arange(self.h, step = downsample, dtype=np.float32) + pixel_center, # Y-Axis (rows)
210
+ indexing="xy")
211
+ camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
212
+ -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
213
+ axis=-1)
214
+ directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
215
+ origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
216
+ viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
217
+ return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
218
+
219
+ class Blender(Dataset):
220
+ """Blender Dataset."""
221
+
222
+ def _load_renderings(self, flags, clip_model = None):
223
+ """Load images from disk."""
224
+ if flags.render_path:
225
+ raise ValueError("render_path cannot be used for the blender dataset.")
226
+ cams, images, meta = self.load_files(flags.data_dir, self.split, flags.factor, flags.few_shot)
227
+
228
+ # load in CLIP precomputed image features
229
+ self.images = np.stack(images, axis=0)
230
+ if flags.white_bkgd:
231
+ self.images = (self.images[..., :3] * self.images[..., -1:] +
232
+ (1. - self.images[..., -1:]))
233
+ else:
234
+ self.images = self.images[..., :3]
235
+ self.h, self.w = self.images.shape[1:3]
236
+ self.resolution = self.h * self.w
237
+ self.camtoworlds = np.stack(cams, axis=0)
238
+ camera_angle_x = float(meta["camera_angle_x"])
239
+ self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
240
+ self.n_examples = self.images.shape[0]
241
+
242
+ if flags.use_semantic_loss and clip_model is not None:
243
+ embs = []
244
+ for img in self.images:
245
+ img = np.expand_dims(np.transpose(img,[2,0,1]), 0)
246
+ embs.append(clip_model.get_image_features(pixel_values = clip_utils.preprocess_for_CLIP(img)))
247
+ self.embeddings = np.concatenate(embs, 0)
248
+
249
+ self.image_idx = np.arange(self.images.shape[0])
250
+ np.random.shuffle(self.image_idx)
251
+ self.image_idx = self.image_idx.tolist()
252
+
253
+ # self.embeddings = utils.read_pickle(flags.precompute_pkl_path)
254
+ # self.precompute_pkl_path = flags.precompute_pkl_path
255
+
256
+
257
+ @staticmethod
258
+ def load_files(data_dir, split, factor, few_shot):
259
+ with utils.open_file(path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
260
+ meta = json.load(fp)
261
+ images = []
262
+ cams = []
263
+
264
+ frames = np.arange(len(meta["frames"]))
265
+ if few_shot > 0 and split == 'train':
266
+ np.random.shuffle(frames)
267
+ frames = frames[:few_shot]
268
+
269
+ for i in frames:
270
+ frame = meta["frames"][i]
271
+ fname = os.path.join(data_dir, frame["file_path"] + ".png")
272
+ with utils.open_file(fname, "rb") as imgin:
273
+ image = np.array(Image.open(imgin)).astype(np.float32) / 255.
274
+ if factor == 2:
275
+ [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
276
+ image = cv2.resize(image, (halfres_w, halfres_h),
277
+ interpolation=cv2.INTER_AREA)
278
+ elif factor == 4:
279
+ [halfres_h, halfres_w] = [hw // 4 for hw in image.shape[:2]]
280
+ image = cv2.resize(image, (halfres_w, halfres_h),
281
+ interpolation=cv2.INTER_AREA)
282
+ elif factor > 0:
283
+ raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
284
+ "set.".format(factor))
285
+ cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
286
+ images.append(image)
287
+ return cams, images, meta
288
+
289
+ def _next_train(self):
290
+ batch_dict = super(Blender, self)._next_train()
291
+ if self.batching == "single_image":
292
+ image_index = batch_dict.pop("image_index")
293
+ # target image for CLIP
294
+ '''
295
+ batch_dict["embedding"] = self.embeddings[image_index]
296
+
297
+ # source rays for CLIP (for constructing source image later)
298
+ src_seed = int(np.random.randint(0, self.max_steps, ()))
299
+ src_rng = jax.random.PRNGKey(src_seed)
300
+ src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
301
+ random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 16)
302
+ random_rays = utils.Rays(origins=np.reshape(random_rays[0], [-1,3]), directions=np.reshape(random_rays[1], [-1,3]), viewdirs=np.reshape(random_rays[2], [-1,3]))
303
+ batch_dict["random_rays"] = random_rays
304
+ '''
305
+ else:
306
+ raise NotImplementedError
307
+ return batch_dict
308
+
309
+ def get_clip_data(self):
310
+ if len(self.image_idx) == 0:
311
+ self.image_idx = np.arange(self.images.shape[0])
312
+ np.random.shuffle(self.image_idx)
313
+ self.image_idx = self.image_idx.tolist()
314
+ image_index = self.image_idx.pop()
315
+
316
+ batch_dict = {}
317
+ batch_dict["embedding"] = self.embeddings[image_index]
318
+
319
+ # source rays for CLIP (for constructing source image later)
320
+ src_seed = int(np.random.randint(0, self.max_steps, ()))
321
+ src_rng = jax.random.PRNGKey(src_seed)
322
+ src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
323
+ random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 16)
324
+ random_rays = utils.Rays(origins=np.reshape(random_rays[0], [-1,3]), directions=np.reshape(random_rays[1], [-1,3]), viewdirs=np.reshape(random_rays[2], [-1,3]))
325
+ batch_dict["random_rays"] = random_rays
326
+ return batch_dict
327
+
328
+ class LLFF(Dataset):
329
+ """LLFF Dataset."""
330
+
331
+ def _load_renderings(self, flags):
332
+ """Load images from disk."""
333
+ # Load images.
334
+ imgdir_suffix = ""
335
+ if flags.factor > 0:
336
+ imgdir_suffix = "_{}".format(flags.factor)
337
+ factor = flags.factor
338
+ else:
339
+ factor = 1
340
+ imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
341
+ if not utils.file_exists(imgdir):
342
+ raise ValueError("Image folder {} doesn't exist.".format(imgdir))
343
+ imgfiles = [
344
+ path.join(imgdir, f)
345
+ for f in sorted(utils.listdir(imgdir))
346
+ if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
347
+ ]
348
+ images = []
349
+ for imgfile in imgfiles:
350
+ with utils.open_file(imgfile, "rb") as imgin:
351
+ image = np.array(Image.open(imgin), dtype=np.float32) / 255.
352
+ images.append(image)
353
+ images = np.stack(images, axis=-1)
354
+
355
+ # Load poses and bds.
356
+ with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
357
+ "rb") as fp:
358
+ poses_arr = np.load(fp)
359
+ poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
360
+ bds = poses_arr[:, -2:].transpose([1, 0])
361
+ if poses.shape[-1] != images.shape[-1]:
362
+ raise RuntimeError("Mismatch between imgs {} and poses {}".format(
363
+ images.shape[-1], poses.shape[-1]))
364
+
365
+ # Update poses according to downsampling.
366
+ poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
367
+ poses[2, 4, :] = poses[2, 4, :] * 1. / factor
368
+
369
+ # Correct rotation matrix ordering and move variable dim to axis 0.
370
+ poses = np.concatenate(
371
+ [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
372
+ poses = np.moveaxis(poses, -1, 0).astype(np.float32)
373
+ images = np.moveaxis(images, -1, 0)
374
+ bds = np.moveaxis(bds, -1, 0).astype(np.float32)
375
+
376
+ # Rescale according to a default bd factor.
377
+ scale = 1. / (bds.min() * .75)
378
+ poses[:, :3, 3] *= scale
379
+ bds *= scale
380
+
381
+ # Recenter poses.
382
+ poses = self._recenter_poses(poses)
383
+
384
+ # Generate a spiral/spherical ray path for rendering videos.
385
+ if flags.spherify:
386
+ poses = self._generate_spherical_poses(poses, bds)
387
+ self.spherify = True
388
+ else:
389
+ self.spherify = False
390
+ if not flags.spherify and self.split == "test":
391
+ self._generate_spiral_poses(poses, bds)
392
+
393
+ # Select the split.
394
+ i_test = np.arange(images.shape[0])[::flags.llffhold]
395
+ i_train = np.array(
396
+ [i for i in np.arange(int(images.shape[0])) if i not in i_test])
397
+ if self.split == "train":
398
+ indices = i_train
399
+ else:
400
+ indices = i_test
401
+ images = images[indices]
402
+ poses = poses[indices]
403
+
404
+ self.images = images
405
+ self.camtoworlds = poses[:, :3, :4]
406
+ self.focal = poses[0, -1, -1]
407
+ self.h, self.w = images.shape[1:3]
408
+ self.resolution = self.h * self.w
409
+ if flags.render_path:
410
+ self.n_examples = self.render_poses.shape[0]
411
+ else:
412
+ self.n_examples = images.shape[0]
413
+
414
+ def _generate_rays(self):
415
+ """Generate normalized device coordinate rays for llff."""
416
+ if self.split == "test":
417
+ n_render_poses = self.render_poses.shape[0]
418
+ self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
419
+ axis=0)
420
+
421
+ super()._generate_rays()
422
+
423
+ if not self.spherify:
424
+ ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
425
+ self.rays.directions,
426
+ self.focal, self.w, self.h)
427
+ self.rays = utils.Rays(
428
+ origins=ndc_origins,
429
+ directions=ndc_directions,
430
+ viewdirs=self.rays.viewdirs)
431
+
432
+ # Split poses from the dataset and generated poses
433
+ if self.split == "test":
434
+ self.camtoworlds = self.camtoworlds[n_render_poses:]
435
+ split = [np.split(r, [n_render_poses], 0) for r in self.rays]
436
+ split0, split1 = zip(*split)
437
+ self.render_rays = utils.Rays(*split0)
438
+ self.rays = utils.Rays(*split1)
439
+
440
+ def _recenter_poses(self, poses):
441
+ """Recenter poses according to the original NeRF code."""
442
+ poses_ = poses.copy()
443
+ bottom = np.reshape([0, 0, 0, 1.], [1, 4])
444
+ c2w = self._poses_avg(poses)
445
+ c2w = np.concatenate([c2w[:3, :4], bottom], -2)
446
+ bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
447
+ poses = np.concatenate([poses[:, :3, :4], bottom], -2)
448
+ poses = np.linalg.inv(c2w) @ poses
449
+ poses_[:, :3, :4] = poses[:, :3, :4]
450
+ poses = poses_
451
+ return poses
452
+
453
+ def _poses_avg(self, poses):
454
+ """Average poses according to the original NeRF code."""
455
+ hwf = poses[0, :3, -1:]
456
+ center = poses[:, :3, 3].mean(0)
457
+ vec2 = self._normalize(poses[:, :3, 2].sum(0))
458
+ up = poses[:, :3, 1].sum(0)
459
+ c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
460
+ return c2w
461
+
462
+ def _viewmatrix(self, z, up, pos):
463
+ """Construct lookat view matrix."""
464
+ vec2 = self._normalize(z)
465
+ vec1_avg = up
466
+ vec0 = self._normalize(np.cross(vec1_avg, vec2))
467
+ vec1 = self._normalize(np.cross(vec2, vec0))
468
+ m = np.stack([vec0, vec1, vec2, pos], 1)
469
+ return m
470
+
471
+ def _normalize(self, x):
472
+ """Normalization helper function."""
473
+ return x / np.linalg.norm(x)
474
+
475
+ def _generate_spiral_poses(self, poses, bds):
476
+ """Generate a spiral path for rendering."""
477
+ c2w = self._poses_avg(poses)
478
+ # Get average pose.
479
+ up = self._normalize(poses[:, :3, 1].sum(0))
480
+ # Find a reasonable "focus depth" for this dataset.
481
+ close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
482
+ dt = .75
483
+ mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
484
+ focal = mean_dz
485
+ # Get radii for spiral path.
486
+ tt = poses[:, :3, 3]
487
+ rads = np.percentile(np.abs(tt), 90, 0)
488
+ c2w_path = c2w
489
+ n_views = 120
490
+ n_rots = 2
491
+ # Generate poses for spiral path.
492
+ render_poses = []
493
+ rads = np.array(list(rads) + [1.])
494
+ hwf = c2w_path[:, 4:5]
495
+ zrate = .5
496
+ for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
497
+ c = np.dot(c2w[:3, :4], (np.array(
498
+ [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
499
+ z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
500
+ render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
501
+ self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]
502
+
503
+ def _generate_spherical_poses(self, poses, bds):
504
+ """Generate a 360 degree spherical path for rendering."""
505
+ # pylint: disable=g-long-lambda
506
+ p34_to_44 = lambda p: np.concatenate([
507
+ p,
508
+ np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
509
+ ], 1)
510
+ rays_d = poses[:, :3, 2:3]
511
+ rays_o = poses[:, :3, 3:4]
512
+
513
+ def min_line_dist(rays_o, rays_d):
514
+ a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
515
+ b_i = -a_i @ rays_o
516
+ pt_mindist = np.squeeze(-np.linalg.inv(
517
+ (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
518
+ return pt_mindist
519
+
520
+ pt_mindist = min_line_dist(rays_o, rays_d)
521
+ center = pt_mindist
522
+ up = (poses[:, :3, 3] - center).mean(0)
523
+ vec0 = self._normalize(up)
524
+ vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
525
+ vec2 = self._normalize(np.cross(vec0, vec1))
526
+ pos = center
527
+ c2w = np.stack([vec1, vec2, vec0, pos], 1)
528
+ poses_reset = (
529
+ np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
530
+ rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
531
+ sc = 1. / rad
532
+ poses_reset[:, :3, 3] *= sc
533
+ bds *= sc
534
+ rad *= sc
535
+ centroid = np.mean(poses_reset[:, :3, 3], 0)
536
+ zh = centroid[2]
537
+ radcircle = np.sqrt(rad ** 2 - zh ** 2)
538
+ new_poses = []
539
+
540
+ for th in np.linspace(0., 2. * np.pi, 120):
541
+ camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
542
+ up = np.array([0, 0, -1.])
543
+ vec2 = self._normalize(camorigin)
544
+ vec0 = self._normalize(np.cross(vec2, up))
545
+ vec1 = self._normalize(np.cross(vec2, vec0))
546
+ pos = camorigin
547
+ p = np.stack([vec0, vec1, vec2, pos], 1)
548
+ new_poses.append(p)
549
+
550
+ new_poses = np.stack(new_poses, 0)
551
+ new_poses = np.concatenate([
552
+ new_poses,
553
+ np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
554
+ ], -1)
555
+ poses_reset = np.concatenate([
556
+ poses_reset[:, :3, :4],
557
+ np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
558
+ ], -1)
559
+ if self.split == "test":
560
+ self.render_poses = new_poses[:, :3, :4]
561
+ return poses_reset
562
+
563
+
564
+ dataset_dict = {"blender": Blender,
565
+ "llff": LLFF}
jaxnerf/nerf/model_utils.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Helper functions/classes for model definition."""
18
+
19
+ import functools
20
+ from typing import Any, Callable
21
+
22
+ from flax import linen as nn
23
+ import jax
24
+ from jax import lax
25
+ from jax import random
26
+ import jax.numpy as jnp
27
+
28
+
29
+ class MLP(nn.Module):
30
+ """A simple MLP."""
31
+ net_depth: int = 8 # The depth of the first part of MLP.
32
+ net_width: int = 256 # The width of the first part of MLP.
33
+ net_depth_condition: int = 1 # The depth of the second part of MLP.
34
+ net_width_condition: int = 128 # The width of the second part of MLP.
35
+ net_activation: Callable[..., Any] = nn.relu # The activation function.
36
+ skip_layer: int = 4 # The layer to add skip layers to.
37
+ num_rgb_channels: int = 3 # The number of RGB channels.
38
+ num_sigma_channels: int = 1 # The number of sigma channels.
39
+
40
+ @nn.compact
41
+ def __call__(self, x, condition=None):
42
+ """
43
+ Evaluate the MLP.
44
+
45
+ Args:
46
+ x: jnp.ndarray(float32), [batch, num_samples, feature], points.
47
+ condition: jnp.ndarray(float32), [batch, feature], if not None, this
48
+ variable will be part of the input to the second part of the MLP
49
+ concatenated with the output vector of the first part of the MLP. If
50
+ None, only the first part of the MLP will be used with input x. In the
51
+ original paper, this variable is the view direction.
52
+
53
+ Returns:
54
+ raw_rgb: jnp.ndarray(float32), with a shape of
55
+ [batch, num_samples, num_rgb_channels].
56
+ raw_sigma: jnp.ndarray(float32), with a shape of
57
+ [batch, num_samples, num_sigma_channels].
58
+ """
59
+ feature_dim = x.shape[-1]
60
+ num_samples = x.shape[1]
61
+ x = x.reshape([-1, feature_dim])
62
+ dense_layer = functools.partial(
63
+ nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
64
+ inputs = x
65
+ for i in range(self.net_depth):
66
+ x = dense_layer(self.net_width)(x)
67
+ x = self.net_activation(x)
68
+ if i % self.skip_layer == 0 and i > 0:
69
+ x = jnp.concatenate([x, inputs], axis=-1)
70
+ raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape(
71
+ [-1, num_samples, self.num_sigma_channels])
72
+
73
+ if condition is not None:
74
+ # Output of the first part of MLP.
75
+ bottleneck = dense_layer(self.net_width)(x)
76
+ # Broadcast condition from [batch, feature] to
77
+ # [batch, num_samples, feature] since all the samples along the same ray
78
+ # have the same viewdir.
79
+ condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
80
+ # Collapse the [batch, num_samples, feature] tensor to
81
+ # [batch * num_samples, feature] so that it can be fed into nn.Dense.
82
+ condition = condition.reshape([-1, condition.shape[-1]])
83
+ x = jnp.concatenate([bottleneck, condition], axis=-1)
84
+ # Here use 1 extra layer to align with the original nerf model.
85
+ for i in range(self.net_depth_condition):
86
+ x = dense_layer(self.net_width_condition)(x)
87
+ x = self.net_activation(x)
88
+ raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(
89
+ [-1, num_samples, self.num_rgb_channels])
90
+ return raw_rgb, raw_sigma
91
+
92
+
93
+ def cast_rays(z_vals, origins, directions):
94
+ return origins[..., None, :] + z_vals[..., None] * directions[..., None, :]
95
+
96
+
97
+ def sample_along_rays(key, origins, directions, num_samples, near, far,
98
+ randomized, lindisp):
99
+ """
100
+ Stratified sampling along the rays.
101
+
102
+ Args:
103
+ key: jnp.ndarray, random generator key.
104
+ origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
105
+ directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
106
+ num_samples: int.
107
+ near: float, near clip.
108
+ far: float, far clip.
109
+ randomized: bool, use randomized stratified sampling.
110
+ lindisp: bool, sampling linearly in disparity rather than depth.
111
+
112
+ Returns:
113
+ z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
114
+ points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
115
+ """
116
+ batch_size = origins.shape[0]
117
+
118
+ t_vals = jnp.linspace(0., 1., num_samples)
119
+ if lindisp:
120
+ z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
121
+ else:
122
+ z_vals = near * (1. - t_vals) + far * t_vals
123
+
124
+ if randomized:
125
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
126
+ upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
127
+ lower = jnp.concatenate([z_vals[..., :1], mids], -1)
128
+ t_rand = random.uniform(key, [batch_size, num_samples])
129
+ z_vals = lower + (upper - lower) * t_rand
130
+ else:
131
+ # Broadcast z_vals to make the returned shape consistent.
132
+ z_vals = jnp.broadcast_to(z_vals[None, ...], [batch_size, num_samples])
133
+
134
+ coords = cast_rays(z_vals, origins, directions)
135
+ return z_vals, coords
136
+
137
+
138
+ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
139
+ """
140
+ Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
141
+
142
+ Instead of computing [sin(x), cos(x)], we use the trig identity
143
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
144
+
145
+ Args:
146
+ x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
147
+ min_deg: int, the minimum (inclusive) degree of the encoding.
148
+ max_deg: int, the maximum (exclusive) degree of the encoding.
149
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
150
+
151
+ Returns:
152
+ encoded: jnp.ndarray, encoded variables.
153
+ """
154
+ if min_deg == max_deg:
155
+ return x
156
+ scales = jnp.array([2 ** i for i in range(min_deg, max_deg)])
157
+ if legacy_posenc_order:
158
+ xb = x[..., None, :] * scales[:, None]
159
+ four_feat = jnp.reshape(
160
+ jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),
161
+ list(x.shape[:-1]) + [-1])
162
+ else:
163
+ xb = jnp.reshape((x[..., None, :] * scales[:, None]),
164
+ list(x.shape[:-1]) + [-1])
165
+ four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
166
+ return jnp.concatenate([x] + [four_feat], axis=-1)
167
+
168
+
169
+ def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
170
+ """
171
+ Volumetric Rendering Function.
172
+
173
+ Args:
174
+ rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
175
+ sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
176
+ z_vals: jnp.ndarray(float32), [batch_size, num_samples].
177
+ dirs: jnp.ndarray(float32), [batch_size, 3].
178
+ white_bkgd: bool.
179
+
180
+ Returns:
181
+ comp_rgb: jnp.ndarray(float32), [batch_size, 3].
182
+ disp: jnp.ndarray(float32), [batch_size].
183
+ acc: jnp.ndarray(float32), [batch_size].
184
+ weights: jnp.ndarray(float32), [batch_size, num_samples]
185
+ """
186
+ eps = 1e-10
187
+ dists = jnp.concatenate([
188
+ z_vals[..., 1:] - z_vals[..., :-1],
189
+ jnp.broadcast_to([1e10], z_vals[..., :1].shape)
190
+ ], -1)
191
+ dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
192
+ # Note that we're quietly turning sigma from [..., 0] to [...].
193
+ alpha = 1.0 - jnp.exp(-sigma[..., 0] * dists)
194
+ accum_prod = jnp.concatenate([
195
+ jnp.ones_like(alpha[..., :1], alpha.dtype),
196
+ jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
197
+ ],
198
+ axis=-1)
199
+ weights = alpha * accum_prod
200
+
201
+ comp_rgb = (weights[..., None] * rgb).sum(axis=-2)
202
+ depth = (weights * z_vals).sum(axis=-1)
203
+ acc = weights.sum(axis=-1)
204
+ # Equivalent to (but slightly more efficient and stable than):
205
+ # disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
206
+ inv_eps = 1 / eps
207
+ disp = acc / depth
208
+ disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)
209
+ if white_bkgd:
210
+ comp_rgb = comp_rgb + (1. - acc[..., None])
211
+ return comp_rgb, disp, acc, weights
212
+
213
+
214
+ def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
215
+ """
216
+ Piecewise-Constant PDF sampling.
217
+
218
+ Args:
219
+ key: jnp.ndarray(float32), [2,], random number generator.
220
+ bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
221
+ weights: jnp.ndarray(float32), [batch_size, num_bins].
222
+ num_samples: int, the number of samples.
223
+ randomized: bool, use randomized samples.
224
+
225
+ Returns:
226
+ z_samples: jnp.ndarray(float32), [batch_size, num_samples].
227
+ """
228
+ # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
229
+ # avoids NaNs when the input is zeros or small, but has no effect otherwise.
230
+ eps = 1e-5
231
+ weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
232
+ padding = jnp.maximum(0, eps - weight_sum)
233
+ weights += padding / weights.shape[-1]
234
+ weight_sum += padding
235
+
236
+ # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
237
+ # starts with exactly 0 and ends with exactly 1.
238
+ pdf = weights / weight_sum
239
+ cdf = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1))
240
+ cdf = jnp.concatenate([
241
+ jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,
242
+ jnp.ones(list(cdf.shape[:-1]) + [1])
243
+ ],
244
+ axis=-1)
245
+
246
+ # Draw uniform samples.
247
+ if randomized:
248
+ # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
249
+ u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
250
+ else:
251
+ # Match the behavior of random.uniform() by spanning [0, 1-eps].
252
+ u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
253
+ u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])
254
+
255
+ # Identify the location in `cdf` that corresponds to a random sample.
256
+ # The final `True` index in `mask` will be the start of the sampled interval.
257
+ mask = u[..., None, :] >= cdf[..., :, None]
258
+
259
+ def find_interval(x):
260
+ # Grab the value where `mask` switches from True to False, and vice versa.
261
+ # This approach takes advantage of the fact that `x` is sorted.
262
+ x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
263
+ x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
264
+ return x0, x1
265
+
266
+ bins_g0, bins_g1 = find_interval(bins)
267
+ cdf_g0, cdf_g1 = find_interval(cdf)
268
+
269
+ t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
270
+ samples = bins_g0 + t * (bins_g1 - bins_g0)
271
+
272
+ # Prevent gradient from backprop-ing through `samples`.
273
+ return lax.stop_gradient(samples)
274
+
275
+
276
+ def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
277
+ randomized):
278
+ """
279
+ Hierarchical sampling.
280
+
281
+ Args:
282
+ key: jnp.ndarray(float32), [2,], random number generator.
283
+ bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
284
+ weights: jnp.ndarray(float32), [batch_size, num_bins].
285
+ origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
286
+ directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
287
+ z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
288
+ num_samples: int, the number of samples.
289
+ randomized: bool, use randomized samples.
290
+
291
+ Returns:
292
+ z_vals: jnp.ndarray(float32),
293
+ [batch_size, num_coarse_samples + num_fine_samples].
294
+ points: jnp.ndarray(float32),
295
+ [batch_size, num_coarse_samples + num_fine_samples, 3].
296
+ """
297
+ z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
298
+ randomized)
299
+ # Compute united z_vals and sample points
300
+ z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
301
+ coords = cast_rays(z_vals, origins, directions)
302
+ return z_vals, coords
303
+
304
+
305
+ def add_gaussian_noise(key, raw, noise_std, randomized):
306
+ """
307
+ Adds gaussian noise to `raw`, which can used to regularize it.
308
+
309
+ Args:
310
+ key: jnp.ndarray(float32), [2,], random number generator.
311
+ raw: jnp.ndarray(float32), arbitrary shape.
312
+ noise_std: float, The standard deviation of the noise to be added.
313
+ randomized: bool, add noise if randomized is True.
314
+
315
+ Returns:
316
+ raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
317
+ """
318
+ if (noise_std is not None) and randomized:
319
+ return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
320
+ else:
321
+ return raw
jaxnerf/nerf/models.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Different model implementation plus a general port for all the models."""
18
+ from typing import Any, Callable
19
+ from flax import linen as nn
20
+ from jax import random
21
+ import jax.numpy as jnp
22
+
23
+ from jaxnerf.nerf import model_utils
24
+ from jaxnerf.nerf import utils
25
+
26
+
27
+ def get_model(key, example_batch, args):
28
+ """A helper function that wraps around a 'model zoo'."""
29
+ model_dict = {"nerf": construct_nerf}
30
+ return model_dict[args.model](key, example_batch, args)
31
+
32
+
33
+ class NerfModel(nn.Module):
34
+ """Nerf NN Model with both coarse and fine MLPs."""
35
+ num_coarse_samples: int # The number of samples for the coarse nerf.
36
+ num_fine_samples: int # The number of samples for the fine nerf.
37
+ use_viewdirs: bool # If True, use viewdirs as an input.
38
+ near: float # The distance to the near plane
39
+ far: float # The distance to the far plane
40
+ noise_std: float # The std dev of noise added to raw sigma.
41
+ net_depth: int # The depth of the first part of MLP.
42
+ net_width: int # The width of the first part of MLP.
43
+ net_depth_condition: int # The depth of the second part of MLP.
44
+ net_width_condition: int # The width of the second part of MLP.
45
+ net_activation: Callable[..., Any] # MLP activation
46
+ skip_layer: int # How often to add skip connections.
47
+ num_rgb_channels: int # The number of RGB channels.
48
+ num_sigma_channels: int # The number of density channels.
49
+ white_bkgd: bool # If True, use a white background.
50
+ min_deg_point: int # The minimum degree of positional encoding for positions.
51
+ max_deg_point: int # The maximum degree of positional encoding for positions.
52
+ deg_view: int # The degree of positional encoding for viewdirs.
53
+ lindisp: bool # If True, sample linearly in disparity rather than in depth.
54
+ rgb_activation: Callable[..., Any] # Output RGB activation.
55
+ sigma_activation: Callable[..., Any] # Output sigma activation.
56
+ legacy_posenc_order: bool # Keep the same ordering as the original tf code.
57
+
58
+ @nn.compact
59
+ def __call__(self, rng_0, rng_1, rays, randomized):
60
+ """Nerf Model.
61
+
62
+ Args:
63
+ rng_0: jnp.ndarray, random number generator for coarse model sampling.
64
+ rng_1: jnp.ndarray, random number generator for fine model sampling.
65
+ rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
66
+ randomized: bool, use randomized stratified sampling.
67
+
68
+ Returns:
69
+ ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
70
+ """
71
+ # Stratified sampling along rays
72
+ key, rng_0 = random.split(rng_0)
73
+ z_vals, samples = model_utils.sample_along_rays(
74
+ key,
75
+ rays.origins,
76
+ rays.directions,
77
+ self.num_coarse_samples,
78
+ self.near,
79
+ self.far,
80
+ randomized,
81
+ self.lindisp,
82
+ )
83
+ samples_enc = model_utils.posenc(
84
+ samples,
85
+ self.min_deg_point,
86
+ self.max_deg_point,
87
+ self.legacy_posenc_order,
88
+ )
89
+
90
+ # Construct the "coarse" MLP.
91
+ coarse_mlp = model_utils.MLP(
92
+ net_depth=self.net_depth,
93
+ net_width=self.net_width,
94
+ net_depth_condition=self.net_depth_condition,
95
+ net_width_condition=self.net_width_condition,
96
+ net_activation=self.net_activation,
97
+ skip_layer=self.skip_layer,
98
+ num_rgb_channels=self.num_rgb_channels,
99
+ num_sigma_channels=self.num_sigma_channels)
100
+
101
+ # Point attribute predictions
102
+ if self.use_viewdirs:
103
+ viewdirs_enc = model_utils.posenc(
104
+ rays.viewdirs,
105
+ 0,
106
+ self.deg_view,
107
+ self.legacy_posenc_order,
108
+ )
109
+ raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
110
+ else:
111
+ viewdirs_enc = None
112
+ raw_rgb, raw_sigma = coarse_mlp(samples_enc)
113
+ # Add noises to regularize the density predictions if needed
114
+ key, rng_0 = random.split(rng_0)
115
+ raw_sigma = model_utils.add_gaussian_noise(
116
+ key,
117
+ raw_sigma,
118
+ self.noise_std,
119
+ randomized,
120
+ )
121
+ rgb = self.rgb_activation(raw_rgb)
122
+ sigma = self.sigma_activation(raw_sigma)
123
+ # Volumetric rendering.
124
+ comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
125
+ rgb,
126
+ sigma,
127
+ z_vals,
128
+ rays.directions,
129
+ white_bkgd=self.white_bkgd,
130
+ )
131
+ ret = [
132
+ (comp_rgb, disp, acc),
133
+ ]
134
+ # Hierarchical sampling based on coarse predictions
135
+ if self.num_fine_samples > 0:
136
+ z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
137
+ key, rng_1 = random.split(rng_1)
138
+ z_vals, samples = model_utils.sample_pdf(
139
+ key,
140
+ z_vals_mid,
141
+ weights[..., 1:-1],
142
+ rays.origins,
143
+ rays.directions,
144
+ z_vals,
145
+ self.num_fine_samples,
146
+ randomized,
147
+ )
148
+ samples_enc = model_utils.posenc(
149
+ samples,
150
+ self.min_deg_point,
151
+ self.max_deg_point,
152
+ self.legacy_posenc_order,
153
+ )
154
+
155
+ # Construct the "fine" MLP.
156
+ fine_mlp = model_utils.MLP(
157
+ net_depth=self.net_depth,
158
+ net_width=self.net_width,
159
+ net_depth_condition=self.net_depth_condition,
160
+ net_width_condition=self.net_width_condition,
161
+ net_activation=self.net_activation,
162
+ skip_layer=self.skip_layer,
163
+ num_rgb_channels=self.num_rgb_channels,
164
+ num_sigma_channels=self.num_sigma_channels)
165
+
166
+ if self.use_viewdirs:
167
+ raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
168
+ else:
169
+ raw_rgb, raw_sigma = fine_mlp(samples_enc)
170
+ key, rng_1 = random.split(rng_1)
171
+ raw_sigma = model_utils.add_gaussian_noise(
172
+ key,
173
+ raw_sigma,
174
+ self.noise_std,
175
+ randomized,
176
+ )
177
+ rgb = self.rgb_activation(raw_rgb)
178
+ sigma = self.sigma_activation(raw_sigma)
179
+ comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
180
+ rgb,
181
+ sigma,
182
+ z_vals,
183
+ rays.directions,
184
+ white_bkgd=self.white_bkgd,
185
+ )
186
+ ret.append((comp_rgb, disp, acc))
187
+ return ret
188
+
189
+
190
+ def construct_nerf(key, example_batch, args):
191
+ """Construct a Neural Radiance Field.
192
+
193
+ Args:
194
+ key: jnp.ndarray. Random number generator.
195
+ example_batch: dict, an example of a batch of data.
196
+ args: FLAGS class. Hyperparameters of nerf.
197
+
198
+ Returns:
199
+ model: nn.Model. Nerf model with parameters.
200
+ state: flax.Module.state. Nerf model state for stateful parameters.
201
+ """
202
+ net_activation = getattr(nn, str(args.net_activation))
203
+ rgb_activation = getattr(nn, str(args.rgb_activation))
204
+ sigma_activation = getattr(nn, str(args.sigma_activation))
205
+
206
+ # Assert that rgb_activation always produces outputs in [0, 1], and
207
+ # sigma_activation always produce non-negative outputs.
208
+ x = jnp.exp(jnp.linspace(-90, 90, 1024))
209
+ x = jnp.concatenate([-x[::-1], x], 0)
210
+
211
+ rgb = rgb_activation(x)
212
+ if jnp.any(rgb < 0) or jnp.any(rgb > 1):
213
+ raise NotImplementedError(
214
+ "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
215
+ .format(args.rgb_activation))
216
+
217
+ sigma = sigma_activation(x)
218
+ if jnp.any(sigma < 0):
219
+ raise NotImplementedError(
220
+ "Choice of sigma_activation `{}` produces negative densities".format(
221
+ args.sigma_activation))
222
+
223
+ model = NerfModel(
224
+ min_deg_point=args.min_deg_point,
225
+ max_deg_point=args.max_deg_point,
226
+ deg_view=args.deg_view,
227
+ num_coarse_samples=args.num_coarse_samples,
228
+ num_fine_samples=args.num_fine_samples,
229
+ use_viewdirs=args.use_viewdirs,
230
+ near=args.near,
231
+ far=args.far,
232
+ noise_std=args.noise_std,
233
+ white_bkgd=args.white_bkgd,
234
+ net_depth=args.net_depth,
235
+ net_width=args.net_width,
236
+ net_depth_condition=args.net_depth_condition,
237
+ net_width_condition=args.net_width_condition,
238
+ skip_layer=args.skip_layer,
239
+ num_rgb_channels=args.num_rgb_channels,
240
+ num_sigma_channels=args.num_sigma_channels,
241
+ lindisp=args.lindisp,
242
+ net_activation=net_activation,
243
+ rgb_activation=rgb_activation,
244
+ sigma_activation=sigma_activation,
245
+ legacy_posenc_order=args.legacy_posenc_order)
246
+ rays = example_batch["rays"]
247
+ key1, key2, key3 = random.split(key, num=3)
248
+
249
+ init_variables = model.init(
250
+ key1,
251
+ rng_0=key2,
252
+ rng_1=key3,
253
+ rays=utils.namedtuple_map(lambda x: x[0], rays),
254
+ randomized=args.randomized)
255
+
256
+ return model, init_variables
jaxnerf/nerf/precompute.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ command line example:
3
+ $ python -i -m jaxnerf.nerf.precompute --data_dir {path-to-data-dir} --split train \
4
+ --dataset blender --factor 4 --dtype float16
5
+ """
6
+ import os
7
+ import argparse
8
+ from typing import Optional
9
+
10
+ import jax.numpy as np
11
+
12
+ from jaxnerf.nerf import utils
13
+ from jaxnerf.nerf import clip_utils
14
+ from jaxnerf.nerf import datasets
15
+
16
+
17
+ def precompute_image_features(data_dir: str, split: str, dataset: str, factor: int, dtype: str,
18
+ model_name: Optional[str], render_path: Optional[str]):
19
+ if dataset == "blender":
20
+ if render_path:
21
+ raise ValueError("render_path cannot be used for the blender dataset.")
22
+
23
+ # image in numpy.ndarray
24
+ _, images, _ = datasets.Blender.load_files(data_dir, split, factor)
25
+ clip_model = clip_utils.init_CLIP(dtype, model_name)
26
+
27
+ # CLIP output in jax.numpy.ndarray
28
+ images = np.stack(images).transpose(0, 3, 1, 2)
29
+ images = images[:, :3, :, :]
30
+ images = clip_utils.preprocess_for_CLIP(images)
31
+ embeddings = clip_model.get_image_features(pixel_values=images)
32
+ embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)
33
+ print(f'completed precomputing CLIP embeddings: ({embeddings.shape[0]} images)')
34
+
35
+ # write as pickle
36
+ write_path = os.path.join(data_dir, f'clip_cache_{split}_factor{factor}_{dtype}.pkl')
37
+ utils.write_pickle(embeddings, write_path)
38
+ print(f'precompute written as pickle: {write_path}')
39
+
40
+ elif dataset == "llff":
41
+ raise NotImplementedError
42
+ else:
43
+ raise ValueError(f"invalid dataset: {dataset}")
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--data_dir", type=str, required=True)
49
+ parser.add_argument("--split", type=str, required=True, help="train/val/test")
50
+ parser.add_argument("--dataset", type=str, required=True)
51
+ parser.add_argument("--factor", type=int, required=True,
52
+ help="downsampling factor: 0/2/4")
53
+ parser.add_argument("--dtype", type=str, required=True,
54
+ help="float32/float16 (float16 is used to save memory)")
55
+ parser.add_argument("--model_name", type=str, required=False, default=None)
56
+ parser.add_argument("--render_path", type=str, required=False, default=None)
57
+ args = parser.parse_args()
58
+ precompute_image_features(args.data_dir, args.split, args.dataset, args.factor,
59
+ args.dtype, args.model_name, args.render_path)
jaxnerf/nerf/utils.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Utility functions."""
18
+ import collections
19
+ import os
20
+ from os import path
21
+ import pickle
22
+ from absl import flags
23
+ import flax
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import jax.scipy as jsp
27
+ import numpy as np
28
+ from PIL import Image
29
+ import yaml
30
+ from jaxnerf.nerf import datasets
31
+
32
+ BASE_DIR = "jaxnerf"
33
+ INTERNAL = False
34
+
35
+
36
+ @flax.struct.dataclass
37
+ class TrainState:
38
+ optimizer: flax.optim.Optimizer
39
+
40
+
41
+ @flax.struct.dataclass
42
+ class Stats:
43
+ loss: float
44
+ psnr: float
45
+ loss_c: float
46
+ psnr_c: float
47
+ weight_l2: float
48
+
49
+
50
+ Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))
51
+
52
+
53
+ def namedtuple_map(fn, tup):
54
+ """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
55
+ return type(tup)(*map(fn, tup))
56
+
57
+
58
+ def define_flags():
59
+ """Define flags for both training and evaluation modes."""
60
+ flags.DEFINE_string("train_dir", None, "where to store ckpts and logs")
61
+ flags.DEFINE_string("data_dir", None, "input data directory.")
62
+ flags.DEFINE_string("config", None,
63
+ "using config files to set hyperparameters.")
64
+
65
+ # CLIP part Flags
66
+ flags.DEFINE_bool("use_semantic_loss", True,
67
+ "whether use semantic loss or not")
68
+ flags.DEFINE_string("precompute_pkl_path", None,
69
+ "where to load the pickle file that precompute image features")
70
+ flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
71
+ flags.DEFINE_string("clip_output_dtype", "float32",
72
+ "float32/ float16 (float16 for memory saving)")
73
+ flags.DEFINE_integer("sc_loss_factor", 4,
74
+ "factor for downsampling image (0/2/4). "
75
+ "its compounded on top of another flag: factor")
76
+ flags.DEFINE_integer("sc_loss_every", 16,
77
+ "no. of steps to take before performing semantic loss evaluation")
78
+ flags.DEFINE_float("sc_loss_mult", 10.,
79
+ "weighting for semantic loss from CLIP")
80
+
81
+ # Dataset Flags
82
+ # TODO(pratuls): rename to dataset_loader and consider cleaning up
83
+ flags.DEFINE_enum("dataset", "blender",
84
+ list(k for k in datasets.dataset_dict.keys()),
85
+ "The type of dataset feed to nerf.")
86
+ flags.DEFINE_enum(
87
+ "batching", "single_image", ["single_image", "all_images"],
88
+ "source of ray sampling when collecting training batch,"
89
+ "single_image for sampling from only one image in a batch,"
90
+ "all_images for sampling from all the training images.")
91
+ flags.DEFINE_bool(
92
+ "white_bkgd", True, "using white color as default background."
93
+ "(used in the blender dataset only)")
94
+ flags.DEFINE_integer("batch_size", 1024,
95
+ "the number of rays in a mini-batch (for training).")
96
+ flags.DEFINE_integer("factor", 4,
97
+ "the downsample factor of images, 0 for no downsample.")
98
+ flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
99
+ flags.DEFINE_bool(
100
+ "render_path", False, "render generated path if set true."
101
+ "(used in the llff dataset only)")
102
+ flags.DEFINE_integer(
103
+ "llffhold", 8, "will take every 1/N images as LLFF test set."
104
+ "(used in the llff dataset only)")
105
+ flags.DEFINE_bool(
106
+ "use_pixel_centers", False,
107
+ "If True, generate rays through the center of each pixel. Note: While "
108
+ "this is the correct way to handle rays, it is not the way rays are "
109
+ "handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
110
+ "compared to Vanilla NeRF.")
111
+
112
+ # Model Flags
113
+ flags.DEFINE_string("model", "nerf", "name of model to use.")
114
+ flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
115
+ flags.DEFINE_float("far", 6., "far clip of volumentric rendering.")
116
+ flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
117
+ flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
118
+ flags.DEFINE_integer("net_depth_condition", 1,
119
+ "depth of the second part of MLP.")
120
+ flags.DEFINE_integer("net_width_condition", 128,
121
+ "width of the second part of MLP.")
122
+ flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay")
123
+ flags.DEFINE_integer(
124
+ "skip_layer", 4, "add a skip connection to the output vector of every"
125
+ "skip_layer layers.")
126
+ flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
127
+ flags.DEFINE_integer("num_sigma_channels", 1,
128
+ "the number of density channels.")
129
+ flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
130
+ flags.DEFINE_integer("min_deg_point", 0,
131
+ "Minimum degree of positional encoding for points.")
132
+ flags.DEFINE_integer("max_deg_point", 10,
133
+ "Maximum degree of positional encoding for points.")
134
+ flags.DEFINE_integer("deg_view", 4,
135
+ "Degree of positional encoding for viewdirs.")
136
+ flags.DEFINE_integer(
137
+ "num_coarse_samples", 64,
138
+ "the number of samples on each ray for the coarse model.")
139
+ flags.DEFINE_integer("num_fine_samples", 128,
140
+ "the number of samples on each ray for the fine model.")
141
+ flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
142
+ flags.DEFINE_float(
143
+ "noise_std", None, "std dev of noise added to regularize sigma output."
144
+ "(used in the llff dataset only)")
145
+ flags.DEFINE_bool("lindisp", False,
146
+ "sampling linearly in disparity rather than depth.")
147
+ flags.DEFINE_string("net_activation", "relu",
148
+ "activation function used within the MLP.")
149
+ flags.DEFINE_string("rgb_activation", "sigmoid",
150
+ "activation function used to produce RGB.")
151
+ flags.DEFINE_string("sigma_activation", "relu",
152
+ "activation function used to produce density.")
153
+ flags.DEFINE_bool(
154
+ "legacy_posenc_order", False,
155
+ "If True, revert the positional encoding feature order to an older version of this codebase."
156
+ )
157
+
158
+ # Train Flags
159
+ flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
160
+ flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
161
+ flags.DEFINE_integer(
162
+ "lr_delay_steps", 0, "The number of steps at the beginning of "
163
+ "training to reduce the learning rate by lr_delay_mult")
164
+ flags.DEFINE_float(
165
+ "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
166
+ "is < lr_delay_steps")
167
+ flags.DEFINE_float("grad_max_norm", 0.,
168
+ "The gradient clipping magnitude (disabled if == 0).")
169
+ flags.DEFINE_float("grad_max_val", 0.,
170
+ "The gradient clipping value (disabled if == 0).")
171
+
172
+ flags.DEFINE_integer("max_steps", 1000000,
173
+ "the number of optimization steps.")
174
+ flags.DEFINE_integer("save_every", 10000,
175
+ "the number of steps to save a checkpoint.")
176
+ flags.DEFINE_integer("print_every", 100,
177
+ "the number of steps between reports to tensorboard.")
178
+ flags.DEFINE_integer(
179
+ "render_every", 5000, "the number of steps to render a test image,"
180
+ "better to be x00 for accurate step time record.")
181
+ flags.DEFINE_integer("gc_every", 10000,
182
+ "the number of steps to run python garbage collection.")
183
+ flags.DEFINE_integer("few_shot", -1,
184
+ "the number of images.")
185
+
186
+ # Eval Flags
187
+ flags.DEFINE_bool(
188
+ "eval_once", True,
189
+ "evaluate the model only once if true, otherwise keeping evaluating new"
190
+ "checkpoints if there's any.")
191
+ flags.DEFINE_bool("save_output", True,
192
+ "save predicted images to disk if True.")
193
+ flags.DEFINE_integer(
194
+ "chunk", 8192,
195
+ "the size of chunks for evaluation inferences, set to the value that"
196
+ "fits your GPU/TPU memory.")
197
+
198
+ def update_flags(args):
199
+ """Update the flags in `args` with the contents of the config YAML file."""
200
+ pth = path.join(BASE_DIR, args.config + ".yaml")
201
+ with open_file(pth, "r") as fin:
202
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
203
+ # Only allow args to be updated if they already exist.
204
+ invalid_args = list(set(configs.keys()) - set(dir(args)))
205
+ if invalid_args:
206
+ raise ValueError(f"Invalid args {invalid_args} in {pth}.")
207
+ args.__dict__.update(configs)
208
+
209
+ def open_file(pth, mode="r"):
210
+ if not INTERNAL:
211
+ return open(pth, mode=mode)
212
+
213
+
214
+ def file_exists(pth):
215
+ if not INTERNAL:
216
+ return path.exists(pth)
217
+
218
+
219
+ def listdir(pth):
220
+ if not INTERNAL:
221
+ return os.listdir(pth)
222
+
223
+
224
+ def isdir(pth):
225
+ if not INTERNAL:
226
+ return path.isdir(pth)
227
+
228
+
229
+ def makedirs(pth):
230
+ if not INTERNAL:
231
+ os.makedirs(pth)
232
+
233
+
234
+ def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
235
+ """Render all the pixels of an image (in test mode).
236
+
237
+ Args:
238
+ render_fn: function, jit-ed render function.
239
+ rays: a `Rays` namedtuple, the rays to be rendered.
240
+ rng: jnp.ndarray, random number generator (used in training mode only).
241
+ normalize_disp: bool, if true then normalize `disp` to [0, 1].
242
+ chunk: int, the size of chunks to render sequentially.
243
+
244
+ Returns:
245
+ rgb: jnp.ndarray, rendered color image.
246
+ disp: jnp.ndarray, rendered disparity image.
247
+ acc: jnp.ndarray, rendered accumulated weights per pixel.
248
+ """
249
+ height, width = rays[0].shape[:2]
250
+ num_rays = height * width
251
+ rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)
252
+
253
+ unused_rng, key_0, key_1 = jax.random.split(rng, 3)
254
+ host_id = jax.host_id()
255
+ results = []
256
+ for i in range(0, num_rays, chunk):
257
+ # pylint: disable=cell-var-from-loop
258
+ chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
259
+ chunk_size = chunk_rays[0].shape[0]
260
+ rays_remaining = chunk_size % jax.device_count()
261
+ if rays_remaining != 0:
262
+ padding = jax.device_count() - rays_remaining
263
+ chunk_rays = namedtuple_map(
264
+ lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
265
+ else:
266
+ padding = 0
267
+ # After padding the number of chunk_rays is always divisible by
268
+ # host_count.
269
+ rays_per_host = chunk_rays[0].shape[0] // jax.process_count()
270
+ start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
271
+ chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
272
+ chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
273
+ results.append([unshard(x[0], padding) for x in chunk_results])
274
+ # pylint: enable=cell-var-from-loop
275
+ rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
276
+ # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
277
+ if normalize_disp:
278
+ disp = (disp - disp.min()) / (disp.max() - disp.min())
279
+ return (rgb.reshape((height, width, -1)), disp.reshape(
280
+ (height, width, -1)), acc.reshape((height, width, -1)))
281
+
282
+
283
+ def compute_psnr(mse):
284
+ """Compute psnr value given mse (we assume the maximum pixel value is 1).
285
+
286
+ Args:
287
+ mse: float, mean square error of pixels.
288
+
289
+ Returns:
290
+ psnr: float, the psnr value.
291
+ """
292
+ return -10. * jnp.log(mse) / jnp.log(10.)
293
+
294
+
295
+ def compute_ssim(img0,
296
+ img1,
297
+ max_val,
298
+ filter_size=11,
299
+ filter_sigma=1.5,
300
+ k1=0.01,
301
+ k2=0.03,
302
+ return_map=False):
303
+ """Computes SSIM from two images.
304
+
305
+ This function was modeled after tf.image.ssim, and should produce comparable
306
+ output.
307
+
308
+ Args:
309
+ img0: array. An image of size [..., width, height, num_channels].
310
+ img1: array. An image of size [..., width, height, num_channels].
311
+ max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
312
+ filter_size: int >= 1. Window size.
313
+ filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
314
+ k1: float > 0. One of the SSIM dampening parameters.
315
+ k2: float > 0. One of the SSIM dampening parameters.
316
+ return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned
317
+
318
+ Returns:
319
+ Each image's mean SSIM, or a tensor of individual values if `return_map`.
320
+ """
321
+ # Construct a 1D Gaussian blur filter.
322
+ hw = filter_size // 2
323
+ shift = (2 * hw - filter_size + 1) / 2
324
+ f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
325
+ filt = jnp.exp(-0.5 * f_i)
326
+ filt /= jnp.sum(filt)
327
+
328
+ # Blur in x and y (faster than the 2D convolution).
329
+ filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
330
+ filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")
331
+
332
+ # Vmap the blurs to the tensor size, and then compose them.
333
+ num_dims = len(img0.shape)
334
+ map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
335
+ for d in map_axes:
336
+ filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
337
+ filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
338
+ filt_fn = lambda z: filt_fn1(filt_fn2(z))
339
+
340
+ mu0 = filt_fn(img0)
341
+ mu1 = filt_fn(img1)
342
+ mu00 = mu0 * mu0
343
+ mu11 = mu1 * mu1
344
+ mu01 = mu0 * mu1
345
+ sigma00 = filt_fn(img0 ** 2) - mu00
346
+ sigma11 = filt_fn(img1 ** 2) - mu11
347
+ sigma01 = filt_fn(img0 * img1) - mu01
348
+
349
+ # Clip the variances and covariances to valid values.
350
+ # Variance must be non-negative:
351
+ sigma00 = jnp.maximum(0., sigma00)
352
+ sigma11 = jnp.maximum(0., sigma11)
353
+ sigma01 = jnp.sign(sigma01) * jnp.minimum(
354
+ jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))
355
+
356
+ c1 = (k1 * max_val) ** 2
357
+ c2 = (k2 * max_val) ** 2
358
+ numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
359
+ denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
360
+ ssim_map = numer / denom
361
+ ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
362
+ return ssim_map if return_map else ssim
363
+
364
+
365
+ def save_img(img, pth):
366
+ """Save an image to disk.
367
+
368
+ Args:
369
+ img: jnp.ndarry, [height, width, channels], img will be clipped to [0, 1]
370
+ before saved to pth.
371
+ pth: string, path to save the image to.
372
+ """
373
+ with open_file(pth, "wb") as imgout:
374
+ Image.fromarray(np.array(
375
+ (np.clip(img, 0., 1.) * 255.).astype(jnp.uint8))).save(imgout, "PNG")
376
+
377
+
378
+ def learning_rate_decay(step,
379
+ lr_init,
380
+ lr_final,
381
+ max_steps,
382
+ lr_delay_steps=0,
383
+ lr_delay_mult=1):
384
+ """Continuous learning rate decay function.
385
+
386
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
387
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
388
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
389
+ function of lr_delay_mult, such that the initial learning rate is
390
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
391
+ to the normal learning rate when steps>lr_delay_steps.
392
+
393
+ Args:
394
+ step: int, the current optimization step.
395
+ lr_init: float, the initial learning rate.
396
+ lr_final: float, the final learning rate.
397
+ max_steps: int, the number of steps during optimization.
398
+ lr_delay_steps: int, the number of steps to delay the full learning rate.
399
+ lr_delay_mult: float, the multiplier on the rate when delaying it.
400
+
401
+ Returns:
402
+ lr: the learning for current step 'step'.
403
+ """
404
+ if lr_delay_steps > 0:
405
+ # A kind of reverse cosine decay.
406
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
407
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
408
+ else:
409
+ delay_rate = 1.
410
+ t = np.clip(step / max_steps, 0, 1)
411
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
412
+ return delay_rate * log_lerp
413
+
414
+
415
+ def shard(xs):
416
+ """Split data into shards for multiple devices along the first dimension."""
417
+ '''
418
+ if 'embedding' in xs:
419
+ xs['pixels'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['pixels'])
420
+ xs['rays'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['rays'])
421
+ xs['embedding'] = np.stack([xs['embedding']]*jax.local_device_count(),0)
422
+ xs['random_rays'] = jax.tree_map(lambda x: np.stack([x]*jax.local_device_count(),0), xs['random_rays'])
423
+ else:
424
+ xs = jax.tree_map(
425
+ lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x
426
+ , xs)
427
+
428
+ return xs
429
+ '''
430
+ return jax.tree_map(
431
+ lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]) if len(x.shape) != 0 else x
432
+ , xs)
433
+
434
+
435
+ def to_device(xs):
436
+ """Transfer data to devices (GPU/TPU)."""
437
+ return jax.tree_map(jnp.array, xs)
438
+
439
+
440
+ def unshard(x, padding=0):
441
+ """Collect the sharded tensor to the shape before sharding."""
442
+ y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
443
+ if padding > 0:
444
+ y = y[:-padding]
445
+ return y
446
+
447
+
448
+ def write_pickle(data, fn):
449
+ with open(fn, 'wb') as f:
450
+ pickle.dump(data, f)
451
+ return None
452
+
453
+
454
+ def read_pickle(fn):
455
+ with open(fn, 'rb') as f:
456
+ data = pickle.load(f)
457
+ return data
jaxnerf/requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.16.4
2
+ jax>=0.2.6
3
+ jaxlib>=0.1.57
4
+ flax>=0.2.2
5
+ opencv-python>=4.4.0
6
+ Pillow>=7.2.0
7
+ pyyaml>=5.3.1
8
+ tensorboard>=2.4.0
9
+ tensorflow>=2.3.1
10
+ tensorflow-hub>=0.11.0
11
+ transformers==4.8.2
12
+ wandb==0.10.33
13
+ tqdm==4.61.2
14
+ # pip install git+https://github.com/deepmind/jmp # mixed precision for JAX
jaxnerf/run.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/bin/bash
16
+ set -e
17
+ set -x
18
+
19
+ virtualenv -p python3 .
20
+ source ./bin/activate
21
+
22
+ pip install -r jaxnerf/requirements.txt
23
+ pip uninstall jax
24
+ pip install --upgrade pip
25
+ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
26
+ python -m jaxnerf.train \
27
+ --data_dir=/mnt/data/NeRF_Data/nerf_synthetic/lego \
28
+ --train_dir=test_output \
29
+ --max_steps=5 \
30
+ --factor=2 \
31
+ --batch_size=512 \
32
+ --config=configs/orig_nerf_tpu_vm_test \
33
+ --precompute_pkl_path /mnt/data/NeRF_Data/nerf_synthetic/lego/clip_cache_train_factor4_float32.pkl
jaxnerf/train.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Lint as: python3
17
+ """Training script for Nerf."""
18
+ import functools
19
+ import gc
20
+ import time
21
+ from absl import app
22
+ from absl import flags
23
+ import flax
24
+ from flax.metrics import tensorboard
25
+ from flax.training import checkpoints
26
+ import jax
27
+ from jax import config
28
+ from jax import random
29
+ import jax.numpy as jnp
30
+ import numpy as np
31
+ # import wandb
32
+ from tqdm import tqdm
33
+
34
+ from jaxnerf.nerf import datasets
35
+ from jaxnerf.nerf import models
36
+ from jaxnerf.nerf import utils
37
+ from jaxnerf.nerf import clip_utils
38
+
39
+ FLAGS = flags.FLAGS
40
+
41
+ utils.define_flags()
42
+ config.parse_flags_with_absl()
43
+
44
+ # set up TPU for colab
45
+ import os
46
+ if "COLAB_TPU_ADDR" in os.environ:
47
+ import jax.tools.colab_tpu
48
+ jax.tools.colab_tpu.setup_tpu()
49
+ print(f"detected device: {jax.local_devices()}")
50
+
51
+
52
+ def train_step(model, clip_model, rng, state, batch, lr, step, K):#, clip_grad):
53
+ # TODO make clip_grad input enable
54
+ """One optimization step.
55
+
56
+ Args:
57
+ model: The linen model.
58
+ rng: jnp.ndarray, random number generator.
59
+ state: utils.TrainState, state of the model/optimizer.
60
+ batch: dict, a mini-batch of data for training.
61
+ lr: float, real-time learning rate.
62
+
63
+ Returns:
64
+ new_state: utils.TrainState, new training state.
65
+ stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)].
66
+ rng: jnp.ndarray, updated random number generator.
67
+ """
68
+ rng, key_0, key_1 = random.split(rng, 3)
69
+
70
+ def loss_fn(variables):
71
+ rays = batch["rays"]
72
+ ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
73
+ if len(ret) not in (1, 2):
74
+ raise ValueError(
75
+ "ret should contain either 1 set of output (coarse only), or 2 sets"
76
+ "of output (coarse as ret[0] and fine as ret[1]).")
77
+ # The main prediction is always at the end of the ret list.
78
+ rgb, unused_disp, unused_acc = ret[-1]
79
+ loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
80
+ psnr = utils.compute_psnr(loss)
81
+ if len(ret) > 1:
82
+ # If there are both coarse and fine predictions, we compute the loss for
83
+ # the coarse prediction (ret[0]) as well.
84
+ rgb_c, unused_disp_c, unused_acc_c = ret[0]
85
+ loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
86
+ psnr_c = utils.compute_psnr(loss_c)
87
+ else:
88
+ loss_c = 0.
89
+ psnr_c = 0.
90
+
91
+ def tree_sum_fn(fn):
92
+ return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
93
+ variables, initializer=0)
94
+
95
+ weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
96
+ tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))
97
+
98
+ total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
99
+ stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
100
+ psnr_c=psnr_c, weight_l2=weight_l2)
101
+ return total_loss, stats
102
+
103
+ (_, stats), grad = (
104
+ jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
105
+ grad = jax.lax.pmean(grad, axis_name="batch")
106
+ stats = jax.lax.pmean(stats, axis_name="batch")
107
+
108
+ # Clip the gradient by value.
109
+ if FLAGS.grad_max_val > 0:
110
+ clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
111
+ grad = jax.tree_util.tree_map(clip_fn, grad)
112
+
113
+ # Clip the (possibly value-clipped) gradient by norm.
114
+ if FLAGS.grad_max_norm > 0:
115
+ grad_norm = jnp.sqrt(
116
+ jax.tree_util.tree_reduce(
117
+ lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
118
+ mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
119
+ grad = jax.tree_util.tree_map(lambda z: mult * z, grad)
120
+
121
+ #return grad, state, rng
122
+ new_optimizer = state.optimizer.apply_gradient(grad, learning_rate =lr)
123
+ new_state = state.replace(optimizer=new_optimizer)
124
+ return new_state, stats, rng
125
+
126
+ def update_step(state, grad, lr):
127
+ new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
128
+ new_state = state.replace(optimizer=new_optimizer)
129
+ return new_state
130
+
131
+
132
+ def main(unused_argv):
133
+ #wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
134
+ rng = random.PRNGKey(20200823)
135
+ # Shift the numpy random seed by host_id() to shuffle data loaded by different
136
+ # hosts.
137
+ np.random.seed(20201473 + jax.host_id())
138
+
139
+ if FLAGS.config is not None:
140
+ utils.update_flags(FLAGS)
141
+ if FLAGS.batch_size % jax.device_count() != 0:
142
+ raise ValueError("Batch size must be divisible by the number of devices.")
143
+ if FLAGS.train_dir is None:
144
+ raise ValueError("train_dir must be set. None set now.")
145
+ if FLAGS.data_dir is None:
146
+ raise ValueError("data_dir must be set. None set now.")
147
+
148
+ # setup CLIP model
149
+ if FLAGS.use_semantic_loss:
150
+ clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
151
+ FLAGS.clip_model_name)
152
+ print('semantic loss ACTIVATED, CLIP is set up')
153
+ else:
154
+ clip_model = None
155
+ print('semantic loss DEACTIVATED, CLIP is set to None')
156
+
157
+ dataset = datasets.get_dataset("train", FLAGS, clip_model)
158
+ test_dataset = datasets.get_dataset("test", FLAGS, clip_model)
159
+
160
+ # setup NeRF model
161
+ rng, key = random.split(rng)
162
+ model, variables = models.get_model(key, dataset.peek(), FLAGS)
163
+ optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
164
+ state = utils.TrainState(optimizer=optimizer)
165
+ del optimizer, variables
166
+ learning_rate_fn = functools.partial(
167
+ utils.learning_rate_decay,
168
+ lr_init=FLAGS.lr_init,
169
+ lr_final=FLAGS.lr_final,
170
+ max_steps=FLAGS.max_steps,
171
+ lr_delay_steps=FLAGS.lr_delay_steps,
172
+ lr_delay_mult=FLAGS.lr_delay_mult)
173
+
174
+ train_pstep = jax.pmap(
175
+ functools.partial(train_step, model, clip_model),
176
+ axis_name="batch",
177
+ in_axes=(0, 0, 0, None, None, None),
178
+ donate_argnums=(2,))
179
+
180
+ update_pstep = jax.pmap(
181
+ functools.partial(update_step,),
182
+ axis_name="batch",
183
+ in_axes=(0, None, None),
184
+ donate_argnums=(0,))
185
+
186
+
187
+ def render_fn(variables, key_0, key_1, rays):
188
+ return jax.lax.all_gather(
189
+ model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
190
+ axis_name="batch")
191
+
192
+ render_pfn = jax.pmap(
193
+ render_fn,
194
+ in_axes=(None, None, None, 0), # Only distribute the data input.
195
+ donate_argnums=(3,),
196
+ axis_name="batch")
197
+
198
+ # Compiling to the CPU because it's faster and more accurate.
199
+ ssim_fn = jax.jit(
200
+ functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
201
+
202
+ if not utils.isdir(FLAGS.train_dir):
203
+ utils.makedirs(FLAGS.train_dir)
204
+ state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
205
+ # Resume training a the step of the last checkpoint.
206
+ init_step = state.optimizer.state.step + 1
207
+
208
+ # for distributive training
209
+ state = flax.jax_utils.replicate(state)
210
+ if jax.host_id() == 0:
211
+ summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
212
+
213
+ # Prefetch_buffer_size = 3 x batch_size
214
+ pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
215
+ n_local_devices = jax.local_device_count()
216
+ rng = rng + jax.host_id() # Make random seed separate across hosts.
217
+ keys = random.split(rng, n_local_devices) # For pmapping RNG keys.
218
+ gc.disable() # Disable automatic garbage collection for efficiency.
219
+ stats_trace = []
220
+ reset_timer = True
221
+
222
+ # for semantic loss update
223
+ cnter = 1
224
+ trigger = int(FLAGS.sc_loss_every / n_local_devices)
225
+
226
+ for step, batch in tqdm(zip(range(init_step, FLAGS.max_steps + 1), pdataset)):
227
+ if reset_timer:
228
+ t_loop_start = time.time()
229
+ reset_timer = False
230
+ lr = learning_rate_fn(step)
231
+
232
+ if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
233
+ # remove dimension for device coz its only run in host core
234
+ sc_batch = dataset.get_clip_data()
235
+ sc_loss, sc_grad = clip_utils.update_semantic_loss(model, clip_model,
236
+ keys[0], state, sc_batch, lr)
237
+ sc_grad = flax.jax_utils.replicate(sc_grad)
238
+ sc_grad = jax.tree_map( lambda x: x[0], sc_grad)
239
+
240
+ else:
241
+ sc_loss = 0.
242
+
243
+ state, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)#, grad)
244
+
245
+ if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
246
+ state = update_pstep(state, sc_grad, lr)
247
+
248
+ if jax.host_id() == 0:
249
+ stats_trace.append(stats)
250
+ if step % FLAGS.gc_every == 0:
251
+ gc.collect()
252
+
253
+ # Log training summaries. This is put behind a host_id check because in
254
+ # multi-host evaluation, all hosts need to run inference even though we
255
+ # only use host 0 to record results.
256
+ if jax.host_id() == 0:
257
+ if step % FLAGS.print_every == 0:
258
+ summary_writer.scalar("train_loss", stats.loss[0], step)
259
+ summary_writer.scalar("train_psnr", stats.psnr[0], step)
260
+ summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
261
+ summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
262
+ summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
263
+ avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
264
+ avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
265
+ stats_trace = []
266
+ summary_writer.scalar("train_avg_loss", avg_loss, step)
267
+ summary_writer.scalar("train_avg_psnr", avg_psnr, step)
268
+ summary_writer.scalar("learning_rate", lr, step)
269
+ steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
270
+ reset_timer = True
271
+ rays_per_sec = FLAGS.batch_size * steps_per_sec
272
+ summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
273
+ summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
274
+ precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
275
+ print(("{:" + "{:d}".format(precision) + "d}").format(step) +
276
+ f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
277
+ f"avg_loss={avg_loss:0.4f}, " +
278
+ f"weight_l2={stats.weight_l2[0]:0.2e}, " +
279
+ # f"sc_loss={sc_loss:0.4f}, " +
280
+ f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
281
+ if step % FLAGS.save_every == 0:
282
+ state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
283
+ checkpoints.save_checkpoint(
284
+ FLAGS.train_dir, state_to_save, int(step), keep=100)
285
+
286
+ # Test-set evaluation.
287
+ if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
288
+ # We reuse the same random number generator from the optimization step
289
+ # here on purpose so that the visualization matches what happened in
290
+ # training.
291
+ t_eval_start = time.time()
292
+ eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
293
+ state)).optimizer.target
294
+ test_case = next(test_dataset)
295
+ pred_color, pred_disp, pred_acc = utils.render_image(
296
+ functools.partial(render_pfn, eval_variables),
297
+ test_case["rays"],
298
+ keys[0],
299
+ FLAGS.dataset == "llff",
300
+ chunk=FLAGS.chunk)
301
+
302
+ # Log eval summaries on host 0.
303
+ if jax.host_id() == 0:
304
+ psnr = utils.compute_psnr(
305
+ ((pred_color - test_case["pixels"]) ** 2).mean())
306
+ ssim = ssim_fn(pred_color, test_case["pixels"])
307
+ eval_time = time.time() - t_eval_start
308
+ num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
309
+ rays_per_sec = num_rays / eval_time
310
+ summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
311
+ print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
312
+ summary_writer.scalar("test_psnr", psnr, step)
313
+ summary_writer.scalar("test_ssim", ssim, step)
314
+ summary_writer.image("test_pred_color", pred_color, step)
315
+ summary_writer.image("test_pred_disp", pred_disp, step)
316
+ summary_writer.image("test_pred_acc", pred_acc, step)
317
+ summary_writer.image("test_target", test_case["pixels"], step)
318
+
319
+ if FLAGS.max_steps % FLAGS.save_every != 0:
320
+ state = jax.device_get(jax.tree_map(lambda x: x[0], state))
321
+ checkpoints.save_checkpoint(
322
+ FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
323
+
324
+
325
+ if __name__ == "__main__":
326
+ app.run(main)
jaxnerf/train.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The Google Research Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ #!/bin/bash
16
+ CONFIG=$1
17
+ DATA_ROOT=$2
18
+ ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
19
+ if [ $CONFIG == "llff" ]
20
+ then
21
+ SCENES="room fern leaves fortress orchids flower trex horns"
22
+ DATA_FOLDER="nerf_llff_data"
23
+ else
24
+ SCENES="lego chair drums ficus hotdog materials mic ship"
25
+ DATA_FOLDER="nerf_synthetic"
26
+ fi
27
+
28
+ # launch training jobs for all scenes.
29
+ for scene in $SCENES; do
30
+ python -m jaxnerf.train \
31
+ --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
32
+ --train_dir="$ROOT_DIR"/"$scene" \
33
+ --config=configs/"$CONFIG"
34
+ done
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.16.4
2
+ jax>=0.2.6
3
+ jaxlib>=0.1.57
4
+ flax>=0.2.2
5
+ opencv-python>=4.4.0
6
+ Pillow>=7.2.0
7
+ streamlit==0.84.1
8
+ googledrivedownloader==0.4