akhaliq (HF staff) and ashawkey committed
Commit 363b2a6 (0 parents)

Duplicate from ashawkey/stable-dreamfusion

Co-authored-by: ashawkey <ashawkey@users.noreply.huggingface.co>

This view is limited to 50 files because the commit contains too many changes.
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Stable Dreamfusion
+ emoji: 🍍
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.5
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: ashawkey/stable-dreamfusion
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
activation.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ from torch.autograd import Function
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ class _trunc_exp(Function):
+     @staticmethod
+     @custom_fwd(cast_inputs=torch.float)
+     def forward(ctx, x):
+         ctx.save_for_backward(x)
+         return torch.exp(x)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, g):
+         x = ctx.saved_tensors[0]
+         return g * torch.exp(x.clamp(-15, 15))
+
+ trunc_exp = _trunc_exp.apply
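
`trunc_exp` behaves like `torch.exp` in the forward pass, but the backward pass clamps the saved input to [-15, 15] before re-exponentiating, so a very large density logit cannot produce an overflowing gradient under mixed precision. A minimal usage sketch (the tensor here is illustrative):

```python
import torch
from activation import trunc_exp

x = torch.randn(8, requires_grad=True)
y = trunc_exp(x)     # forward: identical to torch.exp(x)
y.sum().backward()
# backward: grad is 1 * exp(clamp(x, -15, 15)), finite even for huge x
print(x.grad)
```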
app.py ADDED
@@ -0,0 +1,227 @@
+ import torch
+ import argparse
+
+ from nerf.provider import NeRFDataset
+ from nerf.utils import *
+
+ import gradio as gr
+ import gc
+
+ print(f'[INFO] loading options..')
+
+ # fake config object; not meant to be used from the command line, options should only be changed from the Gradio UI.
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--text', default=None, help="text prompt")
+ # parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
+ # parser.add_argument('-O2', action='store_true', help="equals --fp16 --dir_text")
+ parser.add_argument('--test', action='store_true', help="test mode")
+ parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
+ parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
+ parser.add_argument('--workspace', type=str, default='trial_gradio')
+ parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
+ parser.add_argument('--seed', type=int, default=0)
+
+ ### training options
+ parser.add_argument('--iters', type=int, default=10000, help="training iters")
+ parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
+ parser.add_argument('--ckpt', type=str, default='latest')
+ parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
+ parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+ parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
+ parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
+ parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+ parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
+ parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
+ # model options
+ parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
+ parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
+ # network backbone
+ parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
+ parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
+ # rendering resolution in training, decrease this if CUDA OOM.
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
+ parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
+ parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
+
+ ### dataset options
+ parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
+ parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+ parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
+ parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
+ parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
+ parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
+ parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
+ parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
+
+ parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
+ parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
+ parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
+
+ ### GUI options
+ parser.add_argument('--gui', action='store_true', help="start a GUI")
+ parser.add_argument('--W', type=int, default=800, help="GUI width")
+ parser.add_argument('--H', type=int, default=800, help="GUI height")
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
+ parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
+ parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
+ parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
+ parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
+
+ opt = parser.parse_args()
+
+ # default to use -O !!!
+ opt.fp16 = True
+ opt.dir_text = True
+ opt.cuda_ray = True
+ # opt.lambda_entropy = 1e-4
+ # opt.lambda_opacity = 0
+
+ if opt.backbone == 'vanilla':
+     from nerf.network import NeRFNetwork
+ elif opt.backbone == 'tcnn':
+     from nerf.network_tcnn import NeRFNetwork
+ elif opt.backbone == 'grid':
+     from nerf.network_grid import NeRFNetwork
+ else:
+     raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')
+
+ print(opt)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ print(f'[INFO] loading models..')
+
+ if opt.guidance == 'stable-diffusion':
+     from nerf.sd import StableDiffusion
+     guidance = StableDiffusion(device)
+ elif opt.guidance == 'clip':
+     from nerf.clip import CLIP
+     guidance = CLIP(device)
+ else:
+     raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
+
+ train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
+ valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
+ test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
+
+ print(f'[INFO] everything loaded!')
+
+ trainer = None
+ model = None
+
+ # define UI
+
+ with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
+
+     # title
+     gr.Markdown('[Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion) Text-to-3D Example')
+
+     # inputs
+     prompt = gr.Textbox(label="Prompt", max_lines=1, value="a DSLR photo of a koi fish")
+     iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=5000, step=100)
+     seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
+     button = gr.Button('Generate')
+
+     # outputs
+     image = gr.Image(label="image", visible=True)
+     video = gr.Video(label="video", visible=False)
+     logs = gr.Textbox(label="logging")
+
+     # gradio main func
+     def submit(text, iters, seed):
+
+         global trainer, model
+
+         # seed
+         opt.seed = seed
+         opt.text = text
+         opt.iters = iters
+
+         seed_everything(seed)
+
+         # clean up
+         if trainer is not None:
+             del model
+             del trainer
+             gc.collect()
+             torch.cuda.empty_cache()
+             print('[INFO] clean up!')
+
+         # simply reload everything...
+         model = NeRFNetwork(opt)
+         optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+         scheduler = lambda optimizer: torch.optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
+
+         trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)
+
+         # train (each epoch contains only 8 steps, so we can get some vis every ~10s)
+         STEPS = 8
+         max_epochs = np.ceil(opt.iters / STEPS).astype(np.int32)
+
+         # we have to get the explicit training loop out here to yield progressive results...
+         loader = iter(valid_loader)
+
+         start_t = time.time()
+
+         for epoch in range(max_epochs):
+
+             trainer.train_gui(train_loader, step=STEPS)
+
+             # manual test and get intermediate results
+             try:
+                 data = next(loader)
+             except StopIteration:
+                 loader = iter(valid_loader)
+                 data = next(loader)
+
+             trainer.model.eval()
+
+             if trainer.ema is not None:
+                 trainer.ema.store()
+                 trainer.ema.copy_to()
+
+             with torch.no_grad():
+                 with torch.cuda.amp.autocast(enabled=trainer.fp16):
+                     preds, preds_depth = trainer.test_step(data, perturb=False)
+
+             if trainer.ema is not None:
+                 trainer.ema.restore()
+
+             pred = preds[0].detach().cpu().numpy()
+             # pred_depth = preds_depth[0].detach().cpu().numpy()
+
+             pred = (pred * 255).astype(np.uint8)
+
+             yield {
+                 image: gr.update(value=pred, visible=True),
+                 video: gr.update(visible=False),
+                 logs: f"training iters: {epoch * STEPS} / {iters}, lr: {trainer.optimizer.param_groups[0]['lr']:.6f}",
+             }
+
+         # test
+         trainer.test(test_loader)
+
+         results = glob.glob(os.path.join(opt.workspace, 'results', '*rgb*.mp4'))
+         assert len(results) > 0, "cannot retrieve results!"
+         results.sort(key=lambda x: os.path.getmtime(x)) # sort by mtime
+
+         end_t = time.time()
+
+         yield {
+             image: gr.update(visible=False),
+             video: gr.update(value=results[-1], visible=True),
+             logs: f"Generation Finished in {(end_t - start_t) / 60:.4f} minutes!",
+         }
+
+     button.click(
+         submit,
+         [prompt, iters, seed],
+         [image, video, logs]
+     )
+
+ # concurrency_count: only allow ONE running job at a time, else the GPU will OOM.
+ demo.queue(concurrency_count=1)
+
+ demo.launch()
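
`submit` works because Gradio treats a generator callback as a streaming job: each `yield {component: gr.update(...)}` pushes an intermediate update to the page, which is how the loop above surfaces a render every ~10s before the final video. A stripped-down sketch of the same pattern (component names are illustrative; assumes the gradio 3.x API used by this Space):

```python
import gradio as gr

with gr.Blocks() as demo:
    out = gr.Textbox(label="progress")
    btn = gr.Button("run")

    def run():
        # each yield streams an intermediate update, like submit() above
        for i in range(3):
            yield {out: gr.update(value=f"step {i}")}

    btn.click(run, [], [out])

demo.queue(concurrency_count=1)  # serialize jobs, as app.py does to avoid OOM
demo.launch()
```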
assets/update_logs.md ADDED
@@ -0,0 +1,9 @@
+ ### 2022.10.9
+ * The shading (partially) starts to work; at least it no longer empties the scene. For some prompts it gives better results (a less severe Janus problem). The textureless rendering mode is still disabled.
+ * Enable shading by default (--albedo_iters 1000).
+
+ ### 2022.10.5
+ * Basic reproduction finished.
+ * The non-cuda_ray path and the --tcnn backbone are not working yet; to be fixed.
+ * Shading is not working and is disabled in utils.py for now. Surface normals are bad.
+ * Use an entropy loss to regularize weights_sum (alpha); the original L2 regularization always leads to degenerated geometry...
docker/Dockerfile ADDED
@@ -0,0 +1,53 @@
+ FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
+
+ # Remove any third-party apt sources to avoid issues with expiring keys.
+ RUN rm -f /etc/apt/sources.list.d/*.list
+
+ RUN apt-get update
+
+ RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/Madrid apt-get install -y tzdata
+
+ # Install some basic utilities
+ RUN apt-get install -y \
+     curl \
+     ca-certificates \
+     sudo \
+     git \
+     bzip2 \
+     libx11-6 \
+     python3 \
+     python3-pip \
+     libglfw3-dev \
+     libgles2-mesa-dev \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create a working directory
+ RUN mkdir /app
+ WORKDIR /app
+
+ RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
+
+ RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
+
+ WORKDIR /app/stable-dreamfusion
+
+ RUN pip3 install -r requirements.txt
+ RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
+
+ # Needs the nvidia runtime; if you hit a "No CUDA runtime is found" error during the build, see the first answer at https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime
+ RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+
+ RUN pip3 install git+https://github.com/openai/CLIP.git
+ RUN bash scripts/install_ext.sh
+
+ # Set the default command to python3
+ #CMD ["python3"]
docker/README.md ADDED
@@ -0,0 +1,80 @@
+ ### Docker installation
+
+ ## Build image
+ To build the docker image on your own machine, which may take 15-30 mins:
+ ```
+ docker build -t stable-dreamfusion:latest .
+ ```
+
+ If you hit the error **No CUDA runtime is found** while building the wheels for tiny-cuda-nn, you need to set up the nvidia runtime for docker:
+ ```
+ sudo apt-get install nvidia-container-runtime
+ ```
+ Then edit `/etc/docker/daemon.json` and add the default-runtime:
+ ```
+ {
+     "runtimes": {
+         "nvidia": {
+             "path": "nvidia-container-runtime",
+             "runtimeArgs": []
+         }
+     },
+     "default-runtime": "nvidia"
+ }
+ ```
+ And restart docker:
+ ```
+ sudo systemctl restart docker
+ ```
+ Now you can build tiny-cuda-nn inside docker.
+
+ ## Download image
+ To download the image (~6GB) instead:
+ ```
+ docker pull supercabb/stable-dreamfusion:3080_0.0.1
+ docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
+ ```
+
+ ## Use image
+
+ You can launch an interactive shell inside the container:
+
+ ```
+ docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
+ ```
+ From this shell, all the code in the repo should work.
+
+ To run any single command `<command...>` inside the docker container:
+ ```
+ docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command...>"
+ ```
+ To train:
+ ```
+ export TOKEN="#HUGGING FACE ACCESS TOKEN#"
+ docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
+ && python3 main.py --text \"a hamburger\" --workspace trial -O"
+ ```
+ Run test without gui:
+ ```
+ export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
+ docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
+ -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
+ main.py --workspace trial -O --test"
+ ```
+ Run test with gui:
+ ```
+ export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
+ xhost +
+ docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
+ -v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
+ main.py --workspace trial -O --test --gui"
+ xhost -
+ ```
encoding.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def get_encoder(encoding, input_dim=3,
+                 multires=6,
+                 degree=4,
+                 num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
+                 **kwargs):
+
+     if encoding == 'None':
+         return lambda x, **kwargs: x, input_dim
+
+     elif encoding == 'frequency':
+         from freqencoder import FreqEncoder
+         encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+     elif encoding == 'sphere_harmonics':
+         from shencoder import SHEncoder
+         encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+     elif encoding == 'hashgrid':
+         from gridencoder import GridEncoder
+         encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
+
+     elif encoding == 'tiledgrid':
+         from gridencoder import GridEncoder
+         encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
+
+     else:
+         raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
+
+     return encoder, encoder.output_dim
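
`get_encoder` is a small factory: it returns an `(encoder, output_dim)` pair, where `output_dim` is the input width the downstream MLP should expect. A hedged usage sketch (shapes follow the defaults above; the CUDA extensions expect GPU tensors):

```python
import torch
from encoding import get_encoder

# frequency encoding: output_dim = input_dim + input_dim * 2 * multires
encoder, out_dim = get_encoder('frequency', input_dim=3, multires=6)
x = torch.rand(4096, 3, device='cuda')
feats = encoder(x)  # [4096, out_dim] == [4096, 39]
assert feats.shape[-1] == out_dim
```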
freqencoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .freq import FreqEncoder
freqencoder/backend.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ from torch.utils.cpp_extension import load
+
+ _src_path = os.path.dirname(os.path.abspath(__file__))
+
+ nvcc_flags = [
+     '-O3', '-std=c++14',
+     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+     '-use_fast_math'
+ ]
+
+ if os.name == "posix":
+     c_flags = ['-O3', '-std=c++14']
+ elif os.name == "nt":
+     c_flags = ['/O2', '/std:c++17']
+
+     # find cl.exe
+     def find_cl_path():
+         import glob
+         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+             if paths:
+                 return paths[0]
+
+     # If cl.exe is not on path, try to find it.
+     if os.system("where cl.exe >nul 2>nul") != 0:
+         cl_path = find_cl_path()
+         if cl_path is None:
+             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+         os.environ["PATH"] += ";" + cl_path
+
+ _backend = load(name='_freqencoder',
+                 extra_cflags=c_flags,
+                 extra_cuda_cflags=nvcc_flags,
+                 sources=[os.path.join(_src_path, 'src', f) for f in [
+                     'freqencoder.cu',
+                     'bindings.cpp',
+                 ]],
+                 )
+
+ __all__ = ['_backend']
freqencoder/freq.py ADDED
@@ -0,0 +1,77 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ try:
+     import _freqencoder as _backend
+ except ImportError:
+     from .backend import _backend
+
+
+ class _freq_encoder(Function):
+     @staticmethod
+     @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
+     def forward(ctx, inputs, degree, output_dim):
+         # inputs: [B, input_dim], float
+         # RETURN: [B, F], float
+
+         if not inputs.is_cuda: inputs = inputs.cuda()
+         inputs = inputs.contiguous()
+
+         B, input_dim = inputs.shape # batch size, coord dim
+
+         outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
+
+         _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+
+         ctx.save_for_backward(inputs, outputs)
+         ctx.dims = [B, input_dim, degree, output_dim]
+
+         return outputs
+
+     @staticmethod
+     #@once_differentiable
+     @custom_bwd
+     def backward(ctx, grad):
+         # grad: [B, C], C = output_dim
+
+         grad = grad.contiguous()
+         inputs, outputs = ctx.saved_tensors
+         B, input_dim, degree, output_dim = ctx.dims
+
+         grad_inputs = torch.zeros_like(inputs)
+         _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+
+         return grad_inputs, None, None
+
+
+ freq_encode = _freq_encoder.apply
+
+
+ class FreqEncoder(nn.Module):
+     def __init__(self, input_dim=3, degree=4):
+         super().__init__()
+
+         self.input_dim = input_dim
+         self.degree = degree
+         self.output_dim = input_dim + input_dim * 2 * degree
+
+     def __repr__(self):
+         return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
+
+     def forward(self, inputs, **kwargs):
+         # inputs: [..., input_dim]
+         # return: [..., output_dim]
+
+         prefix_shape = list(inputs.shape[:-1])
+         inputs = inputs.reshape(-1, self.input_dim)
+
+         outputs = freq_encode(inputs, self.degree, self.output_dim)
+
+         outputs = outputs.reshape(prefix_shape + [self.output_dim])
+
+         return outputs
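
FreqEncoder maps each coordinate x to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(degree-1) x), cos(2^(degree-1) x)], hence output_dim = input_dim + input_dim * 2 * degree. A usage sketch (illustrative values; the extension kernels run on GPU):

```python
import torch
from freqencoder import FreqEncoder

enc = FreqEncoder(input_dim=3, degree=4)  # output_dim = 3 + 3 * 2 * 4 = 27
x = torch.rand(1024, 3, device='cuda')
y = enc(x)                                # [1024, 27]
```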
freqencoder/setup.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ _src_path = os.path.dirname(os.path.abspath(__file__))
+
+ nvcc_flags = [
+     '-O3', '-std=c++14',
+     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+     '-use_fast_math'
+ ]
+
+ if os.name == "posix":
+     c_flags = ['-O3', '-std=c++14']
+ elif os.name == "nt":
+     c_flags = ['/O2', '/std:c++17']
+
+     # find cl.exe
+     def find_cl_path():
+         import glob
+         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+             if paths:
+                 return paths[0]
+
+     # If cl.exe is not on path, try to find it.
+     if os.system("where cl.exe >nul 2>nul") != 0:
+         cl_path = find_cl_path()
+         if cl_path is None:
+             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+         os.environ["PATH"] += ";" + cl_path
+
+ setup(
+     name='freqencoder', # package name, import this to use python API
+     ext_modules=[
+         CUDAExtension(
+             name='_freqencoder', # extension name, import this to use CUDA API
+             sources=[os.path.join(_src_path, 'src', f) for f in [
+                 'freqencoder.cu',
+                 'bindings.cpp',
+             ]],
+             extra_compile_args={
+                 'cxx': c_flags,
+                 'nvcc': nvcc_flags,
+             }
+         ),
+     ],
+     cmdclass={
+         'build_ext': BuildExtension,
+     }
+ )
freqencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
+ #include <torch/extension.h>
+
+ #include "freqencoder.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
+     m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
+ }
freqencoder/src/freqencoder.cu ADDED
@@ -0,0 +1,129 @@
+ #include <stdint.h>
+
+ #include <cuda.h>
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <torch/torch.h>
+
+ #include <algorithm>
+ #include <stdexcept>
+
+ #include <cstdio>
+
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+ inline constexpr __device__ float PI() { return 3.141592653589793f; }
+
+ template <typename T>
+ __host__ __device__ T div_round_up(T val, T divisor) {
+     return (val + divisor - 1) / divisor;
+ }
+
+ // inputs: [B, D]
+ // outputs: [B, C], C = D + D * deg * 2
+ __global__ void kernel_freq(
+     const float * __restrict__ inputs,
+     uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+     float * outputs
+ ) {
+     // parallel on per-element
+     const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+     if (t >= B * C) return;
+
+     // get index
+     const uint32_t b = t / C;
+     const uint32_t c = t - b * C; // t % C;
+
+     // locate
+     inputs += b * D;
+     outputs += t;
+
+     // write self
+     if (c < D) {
+         outputs[0] = inputs[c];
+     // write freq
+     } else {
+         const uint32_t col = c / D - 1;
+         const uint32_t d = c % D;
+         const uint32_t freq = col / 2;
+         const float phase_shift = (col % 2) * (PI() / 2);
+         outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
+     }
+ }
+
+ // grad: [B, C], C = D + D * deg * 2
+ // outputs: [B, C]
+ // grad_inputs: [B, D]
+ __global__ void kernel_freq_backward(
+     const float * __restrict__ grad,
+     const float * __restrict__ outputs,
+     uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+     float * grad_inputs
+ ) {
+     // parallel on per-element
+     const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+     if (t >= B * D) return;
+
+     const uint32_t b = t / D;
+     const uint32_t d = t - b * D; // t % D;
+
+     // locate
+     grad += b * C;
+     outputs += b * C;
+     grad_inputs += t;
+
+     // register
+     float result = grad[d];
+     grad += D;
+     outputs += D;
+
+     for (uint32_t f = 0; f < deg; f++) {
+         result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
+         grad += 2 * D;
+         outputs += 2 * D;
+     }
+
+     // write
+     grad_inputs[0] = result;
+ }
+
+
+ void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
+     CHECK_CUDA(inputs);
+     CHECK_CUDA(outputs);
+
+     CHECK_CONTIGUOUS(inputs);
+     CHECK_CONTIGUOUS(outputs);
+
+     CHECK_IS_FLOATING(inputs);
+     CHECK_IS_FLOATING(outputs);
+
+     static constexpr uint32_t N_THREADS = 128;
+
+     kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
+ }
+
+
+ void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
+     CHECK_CUDA(grad);
+     CHECK_CUDA(outputs);
+     CHECK_CUDA(grad_inputs);
+
+     CHECK_CONTIGUOUS(grad);
+     CHECK_CONTIGUOUS(outputs);
+     CHECK_CONTIGUOUS(grad_inputs);
+
+     CHECK_IS_FLOATING(grad);
+     CHECK_IS_FLOATING(outputs);
+     CHECK_IS_FLOATING(grad_inputs);
+
+     static constexpr uint32_t N_THREADS = 128;
+
+     kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
+ }
@@ -0,0 +1,10 @@
 
+ #pragma once
+
+ #include <stdint.h>
+ #include <torch/torch.h>
+
+ // _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+ void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
+
+ // _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+ void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
gridencoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .grid import GridEncoder
gridencoder/backend.py ADDED
@@ -0,0 +1,40 @@
+ import os
+ from torch.utils.cpp_extension import load
+
+ _src_path = os.path.dirname(os.path.abspath(__file__))
+
+ nvcc_flags = [
+     '-O3', '-std=c++14',
+     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ ]
+
+ if os.name == "posix":
+     c_flags = ['-O3', '-std=c++14']
+ elif os.name == "nt":
+     c_flags = ['/O2', '/std:c++17']
+
+     # find cl.exe
+     def find_cl_path():
+         import glob
+         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+             if paths:
+                 return paths[0]
+
+     # If cl.exe is not on path, try to find it.
+     if os.system("where cl.exe >nul 2>nul") != 0:
+         cl_path = find_cl_path()
+         if cl_path is None:
+             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+         os.environ["PATH"] += ";" + cl_path
+
+ _backend = load(name='_grid_encoder',
+                 extra_cflags=c_flags,
+                 extra_cuda_cflags=nvcc_flags,
+                 sources=[os.path.join(_src_path, 'src', f) for f in [
+                     'gridencoder.cu',
+                     'bindings.cpp',
+                 ]],
+                 )
+
+ __all__ = ['_backend']
gridencoder/grid.py ADDED
@@ -0,0 +1,154 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+ from torch.cuda.amp import custom_bwd, custom_fwd
+
+ try:
+     import _gridencoder as _backend
+ except ImportError:
+     from .backend import _backend
+
+ _gridtype_to_id = {
+     'hash': 0,
+     'tiled': 1,
+ }
+
+ class _grid_encode(Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
+         # inputs: [B, D], float in [0, 1]
+         # embeddings: [sO, C], float
+         # offsets: [L + 1], int
+         # RETURN: [B, F], float
+
+         inputs = inputs.contiguous()
+
+         B, D = inputs.shape # batch size, coord dim
+         L = offsets.shape[0] - 1 # level
+         C = embeddings.shape[1] # embedding dim for each level
+         S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+         H = base_resolution # base resolution
+
+         # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
+         # if C % 2 != 0, force float, since half for atomicAdd is very slow.
+         if torch.is_autocast_enabled() and C % 2 == 0:
+             embeddings = embeddings.to(torch.half)
+
+         # L first, optimize cache for cuda kernel, but needs an extra permute later
+         outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+         if calc_grad_inputs:
+             dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+         else:
+             dy_dx = None
+
+         _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
+
+         # permute back to [B, L * C]
+         outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+         ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+         ctx.dims = [B, D, C, L, S, H, gridtype]
+         ctx.align_corners = align_corners
+
+         return outputs
+
+     @staticmethod
+     #@once_differentiable
+     @custom_bwd
+     def backward(ctx, grad):
+
+         inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+         B, D, C, L, S, H, gridtype = ctx.dims
+         align_corners = ctx.align_corners
+
+         # grad: [B, L * C] --> [L, B, C]
+         grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+         grad_embeddings = torch.zeros_like(embeddings)
+
+         if dy_dx is not None:
+             grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+         else:
+             grad_inputs = None
+
+         _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
+
+         if dy_dx is not None:
+             grad_inputs = grad_inputs.to(inputs.dtype)
+
+         return grad_inputs, grad_embeddings, None, None, None, None, None, None
+
+
+ grid_encode = _grid_encode.apply
+
+
+ class GridEncoder(nn.Module):
+     def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
+         super().__init__()
+
+         # the finest resolution desired at the last level; if provided, overrides per_level_scale
+         if desired_resolution is not None:
+             per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
+
+         self.input_dim = input_dim # coord dims, 2 or 3
+         self.num_levels = num_levels # num levels, each level multiply resolution by 2
+         self.level_dim = level_dim # encode channels per level
+         self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
+         self.log2_hashmap_size = log2_hashmap_size
+         self.base_resolution = base_resolution
+         self.output_dim = num_levels * level_dim
+         self.gridtype = gridtype
+         self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
+         self.align_corners = align_corners
+
+         # allocate parameters
+         offsets = []
+         offset = 0
+         self.max_params = 2 ** log2_hashmap_size
+         for i in range(num_levels):
+             resolution = int(np.ceil(base_resolution * per_level_scale ** i))
+             params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
+             params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
+             offsets.append(offset)
+             offset += params_in_level
+         offsets.append(offset)
+         offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+         self.register_buffer('offsets', offsets)
+
+         self.n_params = offsets[-1] * level_dim
+
+         # parameters
+         self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         std = 1e-4
+         self.embeddings.data.uniform_(-std, std)
+
+     def __repr__(self):
+         return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
+
+     def forward(self, inputs, bound=1):
+         # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
+         # return: [..., num_levels * level_dim]
+
+         inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+
+         #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
+         prefix_shape = list(inputs.shape[:-1])
+         inputs = inputs.view(-1, self.input_dim)
+
+         outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
+         outputs = outputs.view(prefix_shape + [self.output_dim])
+
+         #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
+         return outputs
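
The constructor's level geometry is easy to verify by hand: with base resolution H, desired final resolution R and L levels, the per-level growth factor is b = exp2(log2(R / H) / (L - 1)), and level i runs at ceil(H * b^i) with its parameter count capped at 2^log2_hashmap_size (beyond which lookups are hashed). A standalone sketch of that schedule (mirrors GridEncoder's defaults, illustrative only):

```python
import numpy as np

num_levels, base_res, desired_res, log2_hashmap_size = 16, 16, 2048, 19
per_level_scale = np.exp2(np.log2(desired_res / base_res) / (num_levels - 1))

for i in range(num_levels):
    resolution = int(np.ceil(base_res * per_level_scale ** i))
    # dense grid while it fits, hashed (capped) once it exceeds the table size
    params = min(2 ** log2_hashmap_size, (resolution + 1) ** 3)
    print(f"level {i:2d}: resolution {resolution:4d}, params {params}")
```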
gridencoder/setup.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ _src_path = os.path.dirname(os.path.abspath(__file__))
+
+ nvcc_flags = [
+     '-O3', '-std=c++14',
+     '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ ]
+
+ if os.name == "posix":
+     c_flags = ['-O3', '-std=c++14']
+ elif os.name == "nt":
+     c_flags = ['/O2', '/std:c++17']
+
+     # find cl.exe
+     def find_cl_path():
+         import glob
+         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+             paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+             if paths:
+                 return paths[0]
+
+     # If cl.exe is not on path, try to find it.
+     if os.system("where cl.exe >nul 2>nul") != 0:
+         cl_path = find_cl_path()
+         if cl_path is None:
+             raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+         os.environ["PATH"] += ";" + cl_path
+
+ setup(
+     name='gridencoder', # package name, import this to use python API
+     ext_modules=[
+         CUDAExtension(
+             name='_gridencoder', # extension name, import this to use CUDA API
+             sources=[os.path.join(_src_path, 'src', f) for f in [
+                 'gridencoder.cu',
+                 'bindings.cpp',
+             ]],
+             extra_compile_args={
+                 'cxx': c_flags,
+                 'nvcc': nvcc_flags,
+             }
+         ),
+     ],
+     cmdclass={
+         'build_ext': BuildExtension,
+     }
+ )
gridencoder/src/bindings.cpp ADDED
@@ -0,0 +1,8 @@
+ #include <torch/extension.h>
+
+ #include "gridencoder.h"
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+     m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+ }
gridencoder/src/gridencoder.cu ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <algorithm>
9
+ #include <stdexcept>
10
+
11
+ #include <stdint.h>
12
+ #include <cstdio>
13
+
14
+
15
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
16
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
17
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
18
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
19
+
20
+
21
+ // just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
22
+ static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
23
+ // requires CUDA >= 10 and ARCH >= 70
24
+ // this is very slow compared to float or __half2, and never used.
25
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
26
+ }
+
+
+ template <typename T>
+ static inline __host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+ }
+
+
+ template <uint32_t D>
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
+
+ // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
+ // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
+ // coordinates.
+ constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
+
+ uint32_t result = 0;
+ #pragma unroll
+ for (uint32_t i = 0; i < D; ++i) {
+ result ^= pos_grid[i] * primes[i];
+ }
+
+ return result;
+ }
+
+
+ template <uint32_t D, uint32_t C>
+ __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
+ uint32_t stride = 1;
+ uint32_t index = 0;
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+ index += pos_grid[d] * stride;
+ stride *= align_corners ? resolution : (resolution + 1);
+ }
+
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
+ // gridtype: 0 == hash, 1 == tiled
+ if (gridtype == 0 && stride > hashmap_size) {
+ index = fast_hash<D>(pos_grid);
+ }
+
+ return (index % hashmap_size) * C + ch;
+ }
+
+
+ template <typename scalar_t, uint32_t D, uint32_t C>
+ __global__ void kernel_grid(
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ outputs,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ scalar_t * __restrict__ dy_dx,
+ const uint32_t gridtype,
+ const bool align_corners
+ ) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ grid += (uint32_t)offsets[level] * C;
+ inputs += b * D;
+ outputs += level * B * C + b * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+ // if input out of bound, just set output to 0
+ if (flag_oob) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = 0;
+ }
+ if (dy_dx) {
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[d * C + ch] = 0;
+ }
+ }
+ }
+ return;
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+
+ // interpolate
+ scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ // writing to register (fast)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results[ch] += w * grid[index + ch];
+ }
+
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = results[ch];
+ }
+
+ // prepare dy_dx
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
+ if (dy_dx) {
+
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+ for (uint32_t gd = 0; gd < D; gd++) {
+
+ scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
+ float w = scale;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
+
+ if ((idx & (1 << nd)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ pos_grid_local[gd] = pos_grid[gd];
+ uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+ pos_grid_local[gd] = pos_grid[gd] + 1;
+ uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
+ }
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[gd * C + ch] = results_grad[ch];
+ }
+ }
+ }
+ }
+
+
+ template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
+ __global__ void kernel_grid_backward(
+ const scalar_t * __restrict__ grad,
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ grad_grid,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners
+ ) {
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
+
+ // locate
+ grad_grid += offsets[level] * C;
+ inputs += b * D;
+ grad += level * B * C + b * C + ch; // L, B, C
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // check input range (should be in [0, 1])
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ return; // grad is init as 0, so we simply return.
+ }
+ }
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ grad_cur[c] = grad[c];
+ }
+
+ // interpolate
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
+
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
+ // TODO: use float which is better than __half, if N_C % 2 != 0
+ if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c += 2) {
+ // process two __half at once (by interpreting as a __half2)
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+ atomicAdd((__half2*)&grad_grid[index + c], v);
+ }
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
+ } else {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+ }
+ }
+ }
+ }
+
+
+ template <typename scalar_t, uint32_t D, uint32_t C>
+ __global__ void kernel_input_backward(
+ const scalar_t * __restrict__ grad,
+ const scalar_t * __restrict__ dy_dx,
+ scalar_t * __restrict__ grad_inputs,
+ uint32_t B, uint32_t L
+ ) {
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D;
+
+ dy_dx += b * L * D * C;
+
+ scalar_t result = 0;
+
+ # pragma unroll
+ for (int l = 0; l < L; l++) {
+ # pragma unroll
+ for (int ch = 0; ch < C; ch++) {
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+ }
+ }
+
+ grad_inputs[t] = result;
+ }
+
+
+ template <typename scalar_t, uint32_t D>
+ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
+ switch (C) {
+ case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+ }
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
+ // H: base resolution
+ // dy_dx: [B, L * D * C]
+ template <typename scalar_t>
+ void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+ case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
+ }
+
+ }
+
+ template <typename scalar_t, uint32_t D>
+ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 256;
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
+ switch (C) {
+ case 1:
+ kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 2:
+ kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 4:
+ kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 8:
+ kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+ if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+ }
+
+
+ // grad: [L, B, C], float
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // grad_embeddings: [sO, C]
+ // H: base resolution
+ template <typename scalar_t>
+ void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+ case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
+ }
+ }
+
+
+
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(outputs);
+ // CHECK_CUDA(dy_dx);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(outputs);
+ // CHECK_CONTIGUOUS(dy_dx);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(outputs);
+ // CHECK_IS_FLOATING(dy_dx);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
+ grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
+ }));
+ }
+
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(grad_embeddings);
+ // CHECK_CUDA(dy_dx);
+ // CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(grad_embeddings);
+ // CHECK_CONTIGUOUS(dy_dx);
+ // CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(grad_embeddings);
+ // CHECK_IS_FLOATING(dy_dx);
+ // CHECK_IS_FLOATING(grad_inputs);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "grid_encode_backward", ([&] {
+ grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
+ }));
+
+ }
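
To make the kernel math above easier to follow, a small pure-Python sketch of the per-level resolution and indexing logic (intuition only, not part of the extension; the channel offset `* C + ch` and the tiled grid type's behavior beyond the hash fallback are omitted):

    # Pure-Python mirror of the indexing used by the kernels above.
    # Conventions match the CUDA code: S = per-level log2 scale,
    # H = base resolution, gridtype 0 == hash, 1 == tiled.
    import math

    PRIMES = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]

    def level_resolution(level, S, H):
        # matches: scale = exp2f(level * S) * H - 1; resolution = ceil(scale) + 1
        scale = 2.0 ** (level * S) * H - 1.0
        return math.ceil(scale) + 1

    def fast_hash(pos_grid):
        # XOR of per-dimension coordinates times fixed primes, as in fast_hash<D>,
        # with explicit 32-bit wrap-around (implicit for uint32_t in CUDA)
        result = 0
        for i, p in enumerate(pos_grid):
            result ^= (p * PRIMES[i]) & 0xFFFFFFFF
        return result & 0xFFFFFFFF

    def grid_index(pos_grid, hashmap_size, resolution, align_corners=False, gridtype=0):
        # dense stride indexing first; fall back to hashing once the level overflows
        index, stride = 0, 1
        for p in pos_grid:
            if stride > hashmap_size:
                break
            index += p * stride
            stride *= resolution if align_corners else resolution + 1
        if gridtype == 0 and stride > hashmap_size:
            index = fast_hash(pos_grid)
        return index % hashmap_size
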
gridencoder/src/gridencoder.h ADDED
@@ -0,0 +1,15 @@
+ #ifndef _HASH_ENCODE_H
+ #define _HASH_ENCODE_H
+
+ #include <stdint.h>
+ #include <torch/torch.h>
+
+ // inputs: [B, D], float, in [0, 1]
+ // embeddings: [sO, C], float
+ // offsets: [L + 1], uint32_t
+ // outputs: [B, L * C], float
+ // H: base resolution
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);
+
+ #endif
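
A hedged sketch of the calling convention declared in this header, assuming the compiled extension is importable as `_backend` (see the JIT-load sketch after bindings.cpp). Note that the .cu source documents the output layout as [L, B, C] (level-first), which is what the kernels actually write; the header comment's [B, L * C] describes the reshaped view the Python wrapper exposes. The call itself is left commented since it needs the built extension:

    # Shape setup for grid_encode_forward, following the comments in this header
    # and in gridencoder.cu. `_backend` is the hypothetical compiled module.
    import torch

    B, D, C, L = 4, 3, 2, 16   # batch, input dim, channels per level, levels
    S, H = 1.0, 16             # log2 per-level scale, base resolution
    offsets = torch.arange(L + 1, dtype=torch.int32, device='cuda') * 4096  # [L + 1]
    sO = int(offsets[-1])

    inputs = torch.rand(B, D, device='cuda')        # in [0, 1]
    embeddings = torch.zeros(sO, C, device='cuda')  # [sO, C]
    outputs = torch.empty(L, B, C, device='cuda')   # level-first layout

    # _backend.grid_encode_forward(inputs, embeddings, offsets, outputs,
    #                              B, D, C, L, S, H, None, 0, False)
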
main.py ADDED
@@ -0,0 +1,160 @@
+ import torch
+ import argparse
+
+ from nerf.provider import NeRFDataset
+ from nerf.utils import *
+ from optimizer import Shampoo
+
+ from nerf.gui import NeRFGUI
+
+ # torch.autograd.set_detect_anomaly(True)
+
+ if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--text', default=None, help="text prompt")
+ parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
+ parser.add_argument('-O2', action='store_true', help="equals --fp16 --dir_text")
+ parser.add_argument('--test', action='store_true', help="test mode")
+ parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
+ parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
+ parser.add_argument('--workspace', type=str, default='workspace')
+ parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
+ parser.add_argument('--seed', type=int, default=0)
+
+ ### training options
+ parser.add_argument('--iters', type=int, default=10000, help="training iters")
+ parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
+ parser.add_argument('--ckpt', type=str, default='latest')
+ parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
+ parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+ parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
+ parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
+ parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+ parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
+ parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
+ # model options
+ parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
+ parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
+ # network backbone
+ parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
+ parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
+ # rendering resolution in training, decrease this if CUDA OOM.
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
+ parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
+ parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
+
+ ### dataset options
+ parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
+ parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+ parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
+ parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
+ parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
+ parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
+ parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
+ parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
+
+ parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
+ parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
+ parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
+
+ ### GUI options
+ parser.add_argument('--gui', action='store_true', help="start a GUI")
+ parser.add_argument('--W', type=int, default=800, help="GUI width")
+ parser.add_argument('--H', type=int, default=800, help="GUI height")
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
+ parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
+ parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
+ parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
+ parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
+
+ opt = parser.parse_args()
+
+ if opt.O:
+ opt.fp16 = True
+ opt.dir_text = True
+ # use occupancy grid to prune ray sampling, faster rendering.
+ opt.cuda_ray = True
+ # opt.lambda_entropy = 1e-4
+ # opt.lambda_opacity = 0
+
+ elif opt.O2:
+ opt.fp16 = True
+ opt.dir_text = True
+ opt.lambda_entropy = 1e-4 # necessary to keep non-empty
+ opt.lambda_opacity = 3e-3 # no occupancy grid, so use a stronger opacity loss.
+
+ if opt.backbone == 'vanilla':
+ from nerf.network import NeRFNetwork
+ elif opt.backbone == 'tcnn':
+ from nerf.network_tcnn import NeRFNetwork
+ elif opt.backbone == 'grid':
+ from nerf.network_grid import NeRFNetwork
+ else:
+ raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')
+
+ print(opt)
+
+ seed_everything(opt.seed)
+
+ model = NeRFNetwork(opt)
+
+ print(model)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ if opt.test:
+ guidance = None # no need to load guidance model at test
+
+ trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
+
+ if opt.gui:
+ gui = NeRFGUI(opt, trainer)
+ gui.render()
+
+ else:
+ test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
+ trainer.test(test_loader)
+
+ if opt.save_mesh:
+ trainer.save_mesh(resolution=256)
+
+ else:
+
+ if opt.guidance == 'stable-diffusion':
+ from nerf.sd import StableDiffusion
+ guidance = StableDiffusion(device)
+ elif opt.guidance == 'clip':
+ from nerf.clip import CLIP
+ guidance = CLIP(device)
+ else:
+ raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
+
+ optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+ # optimizer = lambda model: Shampoo(model.get_params(opt.lr))
+
+ train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
+
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
+ # scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)
+
+ trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)
+
+ if opt.gui:
+ trainer.train_loader = train_loader # attach dataloader to trainer
+
+ gui = NeRFGUI(opt, trainer)
+ gui.render()
+
+ else:
+ valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
+
+ max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
+ trainer.train(train_loader, valid_loader, max_epoch)
+
+ # also test
+ test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
+ trainer.test(test_loader)
+
+ if opt.save_mesh:
+ trainer.save_mesh(resolution=256)
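
The LambdaLR above decays the learning rate smoothly to a factor of 0.1 over opt.iters and then holds it. A standalone check of that curve:

    # Standalone check of the LambdaLR decay used above:
    # lr(iter) = lr0 * 0.1 ** min(iter / iters, 1)
    lr0, iters = 1e-3, 10000
    for it in [0, 2500, 5000, 10000, 20000]:
        lr = lr0 * 0.1 ** min(it / iters, 1)
        print(it, f'{lr:.2e}')
    # 0     1.00e-03
    # 2500  5.62e-04
    # 5000  3.16e-04
    # 10000 1.00e-04
    # 20000 1.00e-04  (held at the floor after opt.iters)
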
nerf/clip.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torch.nn as nn
+
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+
+ import clip
+
+ class CLIP(nn.Module):
+ def __init__(self, device):
+ super().__init__()
+
+ self.device = device
+
+ self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
+
+ # image augmentation
+ self.aug = T.Compose([
+ T.Resize((224, 224)),
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))
+
+
+ def get_text_embeds(self, prompt):
+
+ text = clip.tokenize(prompt).to(self.device)
+ text_z = self.clip_model.encode_text(text)
+ text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+
+ return text_z
+
+
+ def train_step(self, text_z, pred_rgb):
+
+ pred_rgb = self.aug(pred_rgb)
+
+ image_z = self.clip_model.encode_image(pred_rgb)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ loss = - (image_z * text_z).sum(-1).mean()
+
+ return loss
+
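
A minimal usage sketch for this CLIP guidance (illustrative values; assumes a CUDA device and the `clip` package): encode the prompt once, then compute the negative cosine-similarity loss on each rendered image, which is how the trainer consumes pred_rgb.

    # Hedged usage sketch for the CLIP class above; pred_rgb is a stand-in
    # for a rendered image batch in [0, 1].
    import torch

    device = torch.device('cuda')
    guidance = CLIP(device)

    text_z = guidance.get_text_embeds('a photo of a hamburger')            # encode once
    pred_rgb = torch.rand(1, 3, 128, 128, device=device, requires_grad=True)
    loss = guidance.train_step(text_z, pred_rgb)                           # negative cosine similarity
    loss.backward()                                                        # gradients flow back to the render
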
nerf/gui.py ADDED
@@ -0,0 +1,465 @@
+ import math
+ import torch
+ import numpy as np
+ import dearpygui.dearpygui as dpg
+ from scipy.spatial.transform import Rotation as R
+
+ from nerf.utils import *
+
+
+ class OrbitCamera:
+ def __init__(self, W, H, r=2, fovy=60):
+ self.W = W
+ self.H = H
+ self.radius = r # camera distance from center
+ self.fovy = fovy # in degree
+ self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
+ self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
+ self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
+
+ # pose
+ @property
+ def pose(self):
+ # first move camera to radius
+ res = np.eye(4, dtype=np.float32)
+ res[2, 3] -= self.radius
+ # rotate
+ rot = np.eye(4, dtype=np.float32)
+ rot[:3, :3] = self.rot.as_matrix()
+ res = rot @ res
+ # translate
+ res[:3, 3] -= self.center
+ return res
+
+ # intrinsics
+ @property
+ def intrinsics(self):
+ focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
+ return np.array([focal, focal, self.W // 2, self.H // 2])
+
+ def orbit(self, dx, dy):
+ # rotate along camera up/side axis!
+ side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
+ rotvec_x = self.up * np.deg2rad(-0.1 * dx)
+ rotvec_y = side * np.deg2rad(-0.1 * dy)
+ self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
+
+ def scale(self, delta):
+ self.radius *= 1.1 ** (-delta)
+
+ def pan(self, dx, dy, dz=0):
+ # pan in camera coordinate system (careful on the sensitivity!)
+ self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])
+
+
+ class NeRFGUI:
+ def __init__(self, opt, trainer, debug=True):
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
+ self.W = opt.W
+ self.H = opt.H
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
+ self.debug = debug
+ self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
+ self.training = False
+ self.step = 0 # training step
+
+ self.trainer = trainer
+ self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+ self.spp = 1 # sample per pixel
+ self.light_dir = np.array([opt.light_theta, opt.light_phi])
+ self.ambient_ratio = 1.0
+ self.mode = 'image' # choose from ['image', 'depth']
+ self.shading = 'albedo'
+
+ self.dynamic_resolution = True
+ self.downscale = 1
+ self.train_steps = 16
+
+ dpg.create_context()
+ self.register_dpg()
+ self.test_step()
+
+
+ def __del__(self):
+ dpg.destroy_context()
+
+
+ def train_step(self):
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ self.step += self.train_steps
+ self.need_update = True
+
+ dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
+
+ # dynamic train steps
+ # max allowed train time per-frame is 500 ms
+ full_t = t / self.train_steps * 16
+ train_steps = min(16, max(4, int(16 * 500 / full_t)))
+ if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
+ self.train_steps = train_steps
+
+
+ def prepare_buffer(self, outputs):
+ if self.mode == 'image':
+ return outputs['image']
+ else:
+ return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
+
+
+ def test_step(self):
+
+ if self.need_update or self.spp < self.opt.max_spp:
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ # update dynamic resolution
+ if self.dynamic_resolution:
+ # max allowed infer time per-frame is 200 ms
+ full_t = t / (self.downscale ** 2)
+ downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
+ if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
+ self.downscale = downscale
+
+ if self.need_update:
+ self.render_buffer = self.prepare_buffer(outputs)
+ self.spp = 1
+ self.need_update = False
+ else:
+ self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
+ self.spp += 1
+
+ dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
+ dpg.set_value("_log_spp", self.spp)
+ dpg.set_value("_texture", self.render_buffer)
+
+
+ def register_dpg(self):
+
+ ### register texture
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+
+ ### register window
+
+ # the rendered image, as the primary window
+ with dpg.window(tag="_primary_window", width=self.W, height=self.H):
+
+ # add the texture
+ dpg.add_image("_texture")
+
+ dpg.set_primary_window("_primary_window", True)
+
+ # control window
+ with dpg.window(label="Control", tag="_control_window", width=400, height=300):
+
+ # text prompt
+ if self.opt.text is not None:
+ dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")
+
+ # button theme
+ with dpg.theme() as theme_button:
+ with dpg.theme_component(dpg.mvButton):
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
+
+ # time
+ if not self.opt.test:
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train time: ")
+ dpg.add_text("no data", tag="_log_train_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("Infer time: ")
+ dpg.add_text("no data", tag="_log_infer_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("SPP: ")
+ dpg.add_text("1", tag="_log_spp")
+
+ # train button
+ if not self.opt.test:
+ with dpg.collapsing_header(label="Train", default_open=True):
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train: ")
+
+ def callback_train(sender, app_data):
+ if self.training:
+ self.training = False
+ dpg.configure_item("_button_train", label="start")
+ else:
+ self.training = True
+ dpg.configure_item("_button_train", label="stop")
+
+ dpg.add_button(label="start", tag="_button_train", callback=callback_train)
+ dpg.bind_item_theme("_button_train", theme_button)
+
+ def callback_reset(sender, app_data):
+ @torch.no_grad()
+ def weight_reset(m: nn.Module):
+ reset_parameters = getattr(m, "reset_parameters", None)
+ if callable(reset_parameters):
+ m.reset_parameters()
+ self.trainer.model.apply(fn=weight_reset)
+ self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
+ self.need_update = True
+
+ dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
+ dpg.bind_item_theme("_button_reset", theme_button)
+
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("Checkpoint: ")
+
+ def callback_save(sender, app_data):
+ self.trainer.save_checkpoint(full=True, best=False)
+ dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="save", tag="_button_save", callback=callback_save)
+ dpg.bind_item_theme("_button_save", theme_button)
+
+ dpg.add_text("", tag="_log_ckpt")
+
+ # save mesh
+ with dpg.group(horizontal=True):
+ dpg.add_text("Marching Cubes: ")
+
+ def callback_mesh(sender, app_data):
+ self.trainer.save_mesh(resolution=256, threshold=10)
+ dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
+ dpg.bind_item_theme("_button_mesh", theme_button)
+
+ dpg.add_text("", tag="_log_mesh")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("", tag="_log_train_log")
+
+
+ # rendering options
+ with dpg.collapsing_header(label="Options", default_open=True):
+
+ # dynamic rendering resolution
+ with dpg.group(horizontal=True):
+
+ def callback_set_dynamic_resolution(sender, app_data):
+ if self.dynamic_resolution:
+ self.dynamic_resolution = False
+ self.downscale = 1
+ else:
+ self.dynamic_resolution = True
+ self.need_update = True
+
+ dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
+ dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
+
+ # mode combo
+ def callback_change_mode(sender, app_data):
+ self.mode = app_data
+ self.need_update = True
+
+ dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
+
+ # bg_color picker
+ def callback_change_bg(sender, app_data):
+ self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
+ self.need_update = True
+
+ dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
+
+ # dt_gamma slider
+ def callback_set_dt_gamma(sender, app_data):
+ self.opt.dt_gamma = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
+
+ # max_steps slider
+ def callback_set_max_steps(sender, app_data):
+ self.opt.max_steps = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
+
+ # aabb slider
+ def callback_set_aabb(sender, app_data, user_data):
+ # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
+ self.trainer.model.aabb_infer[user_data] = app_data
+
+ # also change train aabb ? [better not...]
+ #self.trainer.model.aabb_train[user_data] = app_data
+
+ self.need_update = True
+
+ dpg.add_separator()
+ dpg.add_text("Axis-aligned bounding box:")
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
+
+ # light dir
+ def callback_set_light_dir(sender, app_data, user_data):
+ self.light_dir[user_data] = app_data
+ self.need_update = True
+
+ dpg.add_separator()
+ dpg.add_text("Plane Light Direction:")
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)
+
+ # ambient ratio
+ def callback_set_abm_ratio(sender, app_data):
+ self.ambient_ratio = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)
+
+ # shading mode
+ def callback_change_shading(sender, app_data):
+ self.shading = app_data
+ self.need_update = True
+
+ dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)
+
+
+ # debug info
+ if self.debug:
+ with dpg.collapsing_header(label="Debug"):
+ # pose
+ dpg.add_separator()
+ dpg.add_text("Camera Pose:")
+ dpg.add_text(str(self.cam.pose), tag="_log_pose")
+
+
+ ### register camera handler
+
+ def callback_camera_drag_rotate(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_wheel_scale(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ delta = app_data
+
+ self.cam.scale(delta)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_drag_pan(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ with dpg.handler_registry():
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
+
+
+ dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)
+
+ # TODO: seems dearpygui doesn't support resizing texture...
+ # def callback_resize(sender, app_data):
+ # self.W = app_data[0]
+ # self.H = app_data[1]
+ # # how to reload texture ???
+
+ # dpg.set_viewport_resize_callback(callback_resize)
+
+ ### global theme
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
+
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
+
+ dpg.setup_dearpygui()
+
+ #dpg.show_metrics()
+
+ dpg.show_viewport()
+
+
+ def render(self):
+
+ while dpg.is_dearpygui_running():
+ # update texture every frame
+ if self.training:
+ self.train_step()
+ self.test_step()
+ dpg.render_dearpygui_frame()
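
For reference, a standalone sketch of the pose composition used by OrbitCamera.pose above: translate the camera back by `radius` along its own z-axis, apply the orbit rotation, then offset by the look-at `center`:

    # Standalone mirror of OrbitCamera.pose (values are the class defaults).
    import numpy as np
    from scipy.spatial.transform import Rotation as R

    radius, center = 2.0, np.zeros(3, dtype=np.float32)
    rot = R.from_quat([1, 0, 0, 0])    # scipy quats are [x, y, z, w]

    res = np.eye(4, dtype=np.float32)
    res[2, 3] -= radius                # camera sits at z = -radius in its own frame
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = rot.as_matrix()
    pose = T @ res                     # rotate the whole camera rig
    pose[:3, 3] -= center              # shift by the look-at point
    print(pose)
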
nerf/network.py ADDED
@@ -0,0 +1,174 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from activation import trunc_exp
+ from .renderer import NeRFRenderer
+
+ import numpy as np
+ from encoding import get_encoder
+
+ from .utils import safe_normalize
+
+ class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x):
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+ if l != self.num_layers - 1:
+ x = F.relu(x, inplace=True)
+ return x
+
+
+ class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ num_layers=5,
+ hidden_dim=128,
+ num_layers_bg=2,
+ hidden_dim_bg=64,
+ ):
+
+ super().__init__(opt)
+
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+ self.encoder, self.in_dim = get_encoder('frequency', input_dim=3)
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
+
+ # background network
+ if self.bg_radius > 0:
+ self.num_layers_bg = num_layers_bg
+ self.hidden_dim_bg = hidden_dim_bg
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+ else:
+ self.bg_net = None
+
+ def gaussian(self, x):
+ # x: [B, N, 3]
+
+ d = (x ** 2).sum(-1)
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
+
+ return g
+
+ def common_forward(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ # sigma
+ h = self.encoder(x, bound=self.bound)
+
+ h = self.sigma_net(h)
+
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
+ albedo = torch.sigmoid(h[..., 1:])
+
+ return sigma, albedo
+
+ # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
+ def finite_difference_normal(self, x, epsilon=1e-2):
+ # x: [N, 3]
+ dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
+ dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
+
+ normal = torch.stack([
+ 0.5 * (dx_pos - dx_neg) / epsilon,
+ 0.5 * (dy_pos - dy_neg) / epsilon,
+ 0.5 * (dz_pos - dz_neg) / epsilon
+ ], dim=-1)
+
+ return normal
+
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], view direction, normalized in [-1, 1]
+ # l: [3], plane light direction, normalized in [-1, 1]
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
+
+ if shading == 'albedo':
+ # no need to query normal
+ sigma, color = self.common_forward(x)
+ normal = None
+
+ else:
+ # query normal
+
+ # sigma, albedo = self.common_forward(x)
+ # normal = self.finite_difference_normal(x)
+
+ with torch.enable_grad():
+ x.requires_grad_(True)
+ sigma, albedo = self.common_forward(x)
+ # query gradient
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
+
+ # normalize...
+ normal = safe_normalize(normal)
+ normal[torch.isnan(normal)] = 0
+
+ # lambertian shading
+ lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
+
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
+ elif shading == 'normal':
+ color = (normal + 1) / 2
+ else: # 'lambertian'
+ color = albedo * lambertian.unsqueeze(-1)
+
+ return sigma, color, normal
+
+
+ def density(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ sigma, albedo = self.common_forward(x)
+
+ return {
+ 'sigma': sigma,
+ 'albedo': albedo,
+ }
+
+
+ def background(self, d):
+
+ h = self.encoder_bg(d) # [N, C]
+
+ h = self.bg_net(h)
+
+ # sigmoid activation for rgb
+ rgbs = torch.sigmoid(h)
+
+ return rgbs
+
+ # optimizer utils
+ def get_params(self, lr):
+
+ params = [
+ # {'params': self.encoder.parameters(), 'lr': lr * 10},
+ {'params': self.sigma_net.parameters(), 'lr': lr},
+ ]
+
+ if self.bg_radius > 0:
+ # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
+
+ return params
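
A quick standalone look at the density blob added in common_forward above: g(x) = 5 * exp(-|x|^2 / (2 * 0.2^2)) is ~5 at the origin and near zero beyond |x| ≈ 0.6, biasing early optimization toward placing density at the scene center:

    # Evaluate the density blob from gaussian() at a few radii.
    import torch

    x = torch.tensor([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.6, 0.0, 0.0]])
    d = (x ** 2).sum(-1)
    g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
    print(g)  # tensor([5.0000, 3.0327, 0.0555])
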
nerf/network_grid.py ADDED
@@ -0,0 +1,181 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from activation import trunc_exp
6
+ from .renderer import NeRFRenderer
7
+
8
+ import numpy as np
9
+ from encoding import get_encoder
10
+
11
+ from .utils import safe_normalize
12
+
13
+ class MLP(nn.Module):
14
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
15
+ super().__init__()
16
+ self.dim_in = dim_in
17
+ self.dim_out = dim_out
18
+ self.dim_hidden = dim_hidden
19
+ self.num_layers = num_layers
20
+
21
+ net = []
22
+ for l in range(num_layers):
23
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
24
+
25
+ self.net = nn.ModuleList(net)
26
+
27
+ def forward(self, x):
28
+ for l in range(self.num_layers):
29
+ x = self.net[l](x)
30
+ if l != self.num_layers - 1:
31
+ x = F.relu(x, inplace=True)
32
+ return x
33
+
34
+
35
+ class NeRFNetwork(NeRFRenderer):
36
+ def __init__(self,
37
+ opt,
38
+ num_layers=3,
39
+ hidden_dim=64,
40
+ num_layers_bg=2,
41
+ hidden_dim_bg=64,
42
+ ):
43
+
44
+ super().__init__(opt)
45
+
46
+ self.num_layers = num_layers
47
+ self.hidden_dim = hidden_dim
48
+
49
+ self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)
50
+
51
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
52
+
53
+ # background network
54
+ if self.bg_radius > 0:
55
+ self.num_layers_bg = num_layers_bg
56
+ self.hidden_dim_bg = hidden_dim_bg
57
+
58
+ # use a very simple network to avoid it learning the prompt...
59
+ # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
60
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
61
+
62
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
63
+
64
+ else:
65
+ self.bg_net = None
66
+
67
+ # add a density blob to the scene center
68
+ def gaussian(self, x):
69
+ # x: [B, N, 3]
70
+
71
+ d = (x ** 2).sum(-1)
72
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
73
+
74
+ return g
75
+
76
+ def common_forward(self, x):
77
+ # x: [N, 3], in [-bound, bound]
78
+
79
+ # sigma
80
+ h = self.encoder(x, bound=self.bound)
81
+
82
+ h = self.sigma_net(h)
83
+
84
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
85
+ albedo = torch.sigmoid(h[..., 1:])
86
+
87
+ return sigma, albedo
88
+
89
+ # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
90
+ def finite_difference_normal(self, x, epsilon=1e-2):
91
+ # x: [N, 3]
92
+ dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
93
+ dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
94
+ dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
95
+ dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
96
+ dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
97
+ dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
98
+
99
+ normal = torch.stack([
100
+ 0.5 * (dx_pos - dx_neg) / epsilon,
101
+ 0.5 * (dy_pos - dy_neg) / epsilon,
102
+ 0.5 * (dz_pos - dz_neg) / epsilon
103
+ ], dim=-1)
104
+
105
+ return normal
106
+
107
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
108
+ # x: [N, 3], in [-bound, bound]
109
+ # d: [N, 3], view direction, normalized in [-1, 1]
110
+ # l: [3], light direction, normalized in [-1, 1]
111
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
112
+
113
+ if shading == 'albedo':
114
+ # no need to query normal
115
+ sigma, color = self.common_forward(x)
116
+ normal = None
117
+
118
+ else:
119
+ # query normal
120
+
121
+ sigma, albedo = self.common_forward(x)
122
+ normal = self.finite_difference_normal(x)
123
+
124
+ # with torch.enable_grad():
125
+ # x.requires_grad_(True)
126
+ # sigma, albedo = self.common_forward(x)
127
+ # # query gradient
128
+ # normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
129
+
130
+ # normalize...
131
+ normal = safe_normalize(normal)
132
+ normal[torch.isnan(normal)] = 0
133
+
134
+ # lambertian shading
135
+ lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
136
+
137
+ if shading == 'textureless':
138
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
139
+ elif shading == 'normal':
140
+ color = (normal + 1) / 2
141
+ else: # 'lambertian'
142
+ color = albedo * lambertian.unsqueeze(-1)
143
+
144
+ return sigma, color, normal
145
+
146
+
147
+ def density(self, x):
148
+ # x: [N, 3], in [-bound, bound]
149
+
150
+ sigma, albedo = self.common_forward(x)
151
+
152
+ return {
153
+ 'sigma': sigma,
154
+ 'albedo': albedo,
155
+ }
156
+
157
+
158
+ def background(self, d):
159
+
160
+ h = self.encoder_bg(d) # [N, C]
161
+
162
+ h = self.bg_net(h)
163
+
164
+ # sigmoid activation for rgb
165
+ rgbs = torch.sigmoid(h)
166
+
167
+ return rgbs
168
+
169
+ # optimizer utils
170
+ def get_params(self, lr):
171
+
172
+ params = [
173
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
174
+ {'params': self.sigma_net.parameters(), 'lr': lr},
175
+ ]
176
+
177
+ if self.bg_radius > 0:
178
+ params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
179
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
180
+
181
+ return params
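
Two details above are worth unpacking. gaussian() adds a density blob 5 * exp(-|x|^2 / (2 * 0.2^2)) to the pre-activation sigma, biasing early training toward geometry at the scene center, and finite_difference_normal() estimates the density gradient with central differences at the cost of six extra sigma queries per point. A minimal standalone sketch of the same central-difference estimator (central_difference_grad is an illustrative name; f is any scalar field):

import torch

def central_difference_grad(f, x, eps=1e-2):
    # f: callable [N, 3] -> [N]; x: [N, 3]
    offsets = eps * torch.eye(3, device=x.device)  # one axis-aligned step per row
    grads = [(f(x + o) - f(x - o)) / (2 * eps) for o in offsets]
    return torch.stack(grads, dim=-1)  # [N, 3] gradient estimate
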
nerf/network_tcnn.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from activation import trunc_exp
6
+ from .renderer import NeRFRenderer
7
+ from encoding import get_encoder
8
+
9
+ import numpy as np
10
+ import tinycudann as tcnn
11
+
12
+ class MLP(nn.Module):
13
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
14
+ super().__init__()
15
+ self.dim_in = dim_in
16
+ self.dim_out = dim_out
17
+ self.dim_hidden = dim_hidden
18
+ self.num_layers = num_layers
19
+
20
+ net = []
21
+ for l in range(num_layers):
22
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
23
+
24
+ self.net = nn.ModuleList(net)
25
+
26
+ def forward(self, x):
27
+ for l in range(self.num_layers):
28
+ x = self.net[l](x)
29
+ if l != self.num_layers - 1:
30
+ x = F.relu(x, inplace=True)
31
+ return x
32
+
33
+
34
+ class NeRFNetwork(NeRFRenderer):
35
+ def __init__(self,
36
+ opt,
37
+ num_layers=3,
38
+ hidden_dim=64,
39
+ num_layers_bg=2,
40
+ hidden_dim_bg=64,
41
+ ):
42
+
43
+ super().__init__(opt)
44
+
45
+ self.num_layers = num_layers
46
+ self.hidden_dim = hidden_dim
47
+
48
+ per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1))
49
+
50
+ self.encoder = tcnn.Encoding(
51
+ n_input_dims=3,
52
+ encoding_config={
53
+ "otype": "HashGrid",
54
+ "n_levels": 16,
55
+ "n_features_per_level": 2,
56
+ "log2_hashmap_size": 19,
57
+ "base_resolution": 16,
58
+ "per_level_scale": per_level_scale,
59
+ },
60
+ )
61
+
62
+ self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True)
63
+
64
+ # background network
65
+ if self.bg_radius > 0:
66
+ self.num_layers_bg = num_layers_bg
67
+ self.hidden_dim_bg = hidden_dim_bg
68
+
69
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
70
+
71
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
72
+
73
+ else:
74
+ self.bg_net = None
75
+
76
+ def gaussian(self, x):
77
+ # x: [B, N, 3]
78
+
79
+ d = (x ** 2).sum(-1)
80
+ g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
81
+
82
+ return g
83
+
84
+ def common_forward(self, x):
85
+ # x: [N, 3], in [-bound, bound]
86
+
87
+ # sigma
88
+ h = (x + self.bound) / (2 * self.bound) # to [0, 1]
89
+ h = self.encoder(h)
90
+
91
+ h = self.sigma_net(h)
92
+
93
+ sigma = trunc_exp(h[..., 0] + self.gaussian(x))
94
+ albedo = torch.sigmoid(h[..., 1:])
95
+
96
+ return sigma, albedo
97
+
98
+
99
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
100
+ # x: [N, 3], in [-bound, bound]
101
+ # d: [N, 3], view direction, normalized in [-1, 1]
102
+ # l: [3], light direction, normalized in [-1, 1]
103
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only)
104
+
105
+ if shading == 'albedo':
106
+ # no need to query normal
107
+ sigma, color = self.common_forward(x)
108
+ normal = None
109
+
110
+ else:
111
+ # query normal
112
+ has_grad = torch.is_grad_enabled()
113
+
114
+ with torch.enable_grad():
115
+ x.requires_grad_(True)
116
+ sigma, albedo = self.common_forward(x)
117
+ # query gradient
118
+ normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
119
+
120
+ # normalize...
121
+ normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9)
122
+ normal[torch.isnan(normal)] = 0
123
+
124
+ if not has_grad:
125
+ normal = normal.detach()
126
+
127
+ # lambertian shading
128
+ lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,]
129
+
130
+ if shading == 'textureless':
131
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
132
+ elif shading == 'normal':
133
+ color = (normal + 1) / 2
134
+ else: # 'lambertian'
135
+ color = albedo * lambertian.unsqueeze(-1)
136
+
137
+ return sigma, color, normal
138
+
139
+
140
+ def density(self, x):
141
+ # x: [N, 3], in [-bound, bound]
142
+
143
+ sigma, albedo = self.common_forward(x)
144
+
145
+ return {
146
+ 'sigma': sigma, 'albedo': albedo,
147
+ }
148
+
149
+
150
+ def background(self, d):
151
+ # d: [N, 3], view direction, normalized in [-1, 1]
152
+
153
+ h = self.encoder_bg(d) # [N, C]
154
+
155
+ h = self.bg_net(h)
156
+
157
+ # sigmoid activation for rgb
158
+ rgbs = torch.sigmoid(h)
159
+
160
+ return rgbs
161
+
162
+ # optimizer utils
163
+ def get_params(self, lr):
164
+
165
+ params = [
166
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
167
+ {'params': self.sigma_net.parameters(), 'lr': lr},
168
+ ]
169
+
170
+ if self.bg_radius > 0:
171
+ params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
172
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
173
+
174
+ return params
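
The per_level_scale line encodes the usual Instant-NGP geometric progression: with 16 hash-grid levels growing from base_resolution 16 to a desired resolution of 2048 * bound, each level scales by exp2(log2(2048 * bound / 16) / 15). A quick standalone sanity check of that relationship (variable names are illustrative):

import numpy as np

bound = 1
n_levels, base_res, target_res = 16, 16, 2048 * bound
per_level_scale = np.exp2(np.log2(target_res / base_res) / (n_levels - 1))
# the finest level reaches the target resolution (up to float error)
assert np.isclose(base_res * per_level_scale ** (n_levels - 1), target_res)
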
nerf/provider.py ADDED
@@ -0,0 +1,214 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import json
5
+ import tqdm
6
+ import random
7
+ import numpy as np
8
+ from scipy.spatial.transform import Slerp, Rotation
9
+
10
+ import trimesh
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+
15
+ from .utils import get_rays, safe_normalize
16
+
17
+ def visualize_poses(poses, size=0.1):
18
+ # poses: [B, 4, 4]
19
+
20
+ axes = trimesh.creation.axis(axis_length=4)
21
+ sphere = trimesh.creation.icosphere(radius=1)
22
+ objects = [axes, sphere]
23
+
24
+ for pose in poses:
25
+ # a camera is visualized with 8 line segments.
26
+ pos = pose[:3, 3]
27
+ a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
28
+ b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
29
+ c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
30
+ d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
31
+
32
+ segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
33
+ segs = trimesh.load_path(segs)
34
+ objects.append(segs)
35
+
36
+ trimesh.Scene(objects).show()
37
+
38
+ def get_view_direction(thetas, phis, overhead, front):
39
+ # phis [B,]; thetas: [B,]
40
+ # front = 0 [0, front)
41
+ # side (left) = 1 [front, 180)
42
+ # back = 2 [180, 180+front)
43
+ # side (right) = 3 [180+front, 360)
44
+ # top = 4 [0, overhead]
45
+ # bottom = 5 [180-overhead, 180]
46
+ res = torch.zeros(thetas.shape[0], dtype=torch.long)
47
+ # first determine by phis
48
+ res[(phis < front)] = 0
49
+ res[(phis >= front) & (phis < np.pi)] = 1
50
+ res[(phis >= np.pi) & (phis < (np.pi + front))] = 2
51
+ res[(phis >= (np.pi + front))] = 3
52
+ # override by thetas
53
+ res[thetas <= overhead] = 4
54
+ res[thetas >= (np.pi - overhead)] = 5
55
+ return res
56
+
57
+
58
+ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
59
+ ''' generate random poses from an orbit camera
60
+ Args:
61
+ size: batch size of generated poses.
62
+ device: where to allocate the output.
63
+ radius_range: [min, max], camera distance to the scene origin
64
+ theta_range: [min, max], should be in [0, pi]
65
+ phi_range: [min, max], should be in [0, 2 * pi]
66
+ Return:
67
+ poses: [size, 4, 4]
68
+ '''
69
+
70
+ theta_range = np.deg2rad(theta_range)
71
+ phi_range = np.deg2rad(phi_range)
72
+ angle_overhead = np.deg2rad(angle_overhead)
73
+ angle_front = np.deg2rad(angle_front)
74
+
75
+ radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
76
+ thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
77
+ phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
78
+
79
+ centers = torch.stack([
80
+ radius * torch.sin(thetas) * torch.sin(phis),
81
+ radius * torch.cos(thetas),
82
+ radius * torch.sin(thetas) * torch.cos(phis),
83
+ ], dim=-1) # [B, 3]
84
+
85
+ targets = 0
86
+
87
+ # jitters
88
+ if jitter:
89
+ centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
90
+ targets = targets + torch.randn_like(centers) * 0.2
91
+
92
+ # lookat
93
+ forward_vector = safe_normalize(targets - centers)
94
+ up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
95
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
96
+
97
+ if jitter:
98
+ up_noise = torch.randn_like(up_vector) * 0.02
99
+ else:
100
+ up_noise = 0
101
+
102
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
103
+
104
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
105
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
106
+ poses[:, :3, 3] = centers
107
+
108
+ if return_dirs:
109
+ dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
110
+ else:
111
+ dirs = None
112
+
113
+ return poses, dirs
114
+
115
+
116
+ def circle_poses(device, radius=1.25, theta=60, phi=0, return_dirs=False, angle_overhead=30, angle_front=60):
117
+
118
+ theta = np.deg2rad(theta)
119
+ phi = np.deg2rad(phi)
120
+ angle_overhead = np.deg2rad(angle_overhead)
121
+ angle_front = np.deg2rad(angle_front)
122
+
123
+ thetas = torch.FloatTensor([theta]).to(device)
124
+ phis = torch.FloatTensor([phi]).to(device)
125
+
126
+ centers = torch.stack([
127
+ radius * torch.sin(thetas) * torch.sin(phis),
128
+ radius * torch.cos(thetas),
129
+ radius * torch.sin(thetas) * torch.cos(phis),
130
+ ], dim=-1) # [B, 3]
131
+
132
+ # lookat
133
+ forward_vector = - safe_normalize(centers)
134
+ up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0)
135
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
136
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
137
+
138
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
139
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
140
+ poses[:, :3, 3] = centers
141
+
142
+ if return_dirs:
143
+ dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
144
+ else:
145
+ dirs = None
146
+
147
+ return poses, dirs
148
+
149
+
150
+ class NeRFDataset:
151
+ def __init__(self, opt, device, type='train', H=256, W=256, size=100):
152
+ super().__init__()
153
+
154
+ self.opt = opt
155
+ self.device = device
156
+ self.type = type # train, val, test
157
+
158
+ self.H = H
159
+ self.W = W
160
+ self.radius_range = opt.radius_range
161
+ self.fovy_range = opt.fovy_range
162
+ self.size = size
163
+
164
+ self.training = self.type in ['train', 'all']
165
+
166
+ self.cx = self.H / 2
167
+ self.cy = self.W / 2
168
+
169
+ # [debug] visualize poses
170
+ # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
171
+ # visualize_poses(poses.detach().cpu().numpy())
172
+
173
+
174
+ def collate(self, index):
175
+
176
+ B = len(index) # always 1
177
+
178
+ if self.training:
179
+ # random pose on the fly
180
+ poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)
181
+
182
+ # random focal
183
+ fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
184
+ focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
185
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
186
+ else:
187
+ # circle pose
188
+ phi = (index[0] / self.size) * 360
189
+ poses, dirs = circle_poses(self.device, radius=self.radius_range[1] * 1.2, theta=60, phi=phi, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
190
+
191
+ # fixed focal
192
+ fov = (self.fovy_range[1] + self.fovy_range[0]) / 2
193
+ focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
194
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
195
+
196
+
197
+ # sample a low-resolution but full image for CLIP
198
+ rays = get_rays(poses, intrinsics, self.H, self.W, -1)
199
+
200
+ data = {
201
+ 'H': self.H,
202
+ 'W': self.W,
203
+ 'rays_o': rays['rays_o'],
204
+ 'rays_d': rays['rays_d'],
205
+ 'dir': dirs,
206
+ }
207
+
208
+ return data
209
+
210
+
211
+ def dataloader(self):
212
+ loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
213
+ loader._data = self # an ugly fix... we need to access dataset in trainer.
214
+ return loader
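
Both rand_poses and circle_poses place the camera on a sphere with a y-up spherical convention, center = r * (sin θ sin φ, cos θ, sin θ cos φ), then build a look-at frame toward the target. A standalone sketch of just that placement (orbit_center is an illustrative helper, not defined in this repo):

import numpy as np

def orbit_center(radius, theta_deg, phi_deg):
    # y-up spherical convention matching rand_poses / circle_poses
    theta, phi = np.deg2rad(theta_deg), np.deg2rad(phi_deg)
    return radius * np.array([
        np.sin(theta) * np.sin(phi),  # x
        np.cos(theta),                # y (up)
        np.sin(theta) * np.cos(phi),  # z
    ])

print(orbit_center(1.25, 90, 0))  # ~[0, 0, 1.25]: on the +z axis, looking back at the origin
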
nerf/renderer.py ADDED
@@ -0,0 +1,645 @@
1
+ import os
2
+ import math
3
+ import cv2
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import mcubes
12
+ import raymarching
13
+ from .utils import custom_meshgrid, safe_normalize
14
+
15
+ def sample_pdf(bins, weights, n_samples, det=False):
16
+ # This implementation is from NeRF
17
+ # bins: [B, T], old_z_vals
18
+ # weights: [B, T - 1], bin weights.
19
+ # return: [B, n_samples], new_z_vals
20
+
21
+ # Get pdf
22
+ weights = weights + 1e-5 # prevent nans
23
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
24
+ cdf = torch.cumsum(pdf, -1)
25
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
26
+ # Take uniform samples
27
+ if det:
28
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
29
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
30
+ else:
31
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
32
+
33
+ # Invert CDF
34
+ u = u.contiguous()
35
+ inds = torch.searchsorted(cdf, u, right=True)
36
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
37
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
38
+ inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
39
+
40
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
41
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
42
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
43
+
44
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
45
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
46
+ t = (u - cdf_g[..., 0]) / denom
47
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
48
+
49
+ return samples
50
+
51
+
52
+ def plot_pointcloud(pc, color=None):
53
+ # pc: [N, 3]
54
+ # color: [N, 3/4]
55
+ print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
56
+ pc = trimesh.PointCloud(pc, color)
57
+ # axis
58
+ axes = trimesh.creation.axis(axis_length=4)
59
+ # sphere
60
+ sphere = trimesh.creation.icosphere(radius=1)
61
+ trimesh.Scene([pc, axes, sphere]).show()
62
+
63
+
64
+ class NeRFRenderer(nn.Module):
65
+ def __init__(self, opt):
66
+ super().__init__()
67
+
68
+ self.opt = opt
69
+ self.bound = opt.bound
70
+ self.cascade = 1 + math.ceil(math.log2(opt.bound))
71
+ self.grid_size = 128
72
+ self.cuda_ray = opt.cuda_ray
73
+ self.min_near = opt.min_near
74
+ self.density_thresh = opt.density_thresh
75
+ self.bg_radius = opt.bg_radius
76
+
77
+ # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
78
+ # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
79
+ aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
80
+ aabb_infer = aabb_train.clone()
81
+ self.register_buffer('aabb_train', aabb_train)
82
+ self.register_buffer('aabb_infer', aabb_infer)
83
+
84
+ # extra state for cuda raymarching
85
+ if self.cuda_ray:
86
+ # density grid
87
+ density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
88
+ density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
89
+ self.register_buffer('density_grid', density_grid)
90
+ self.register_buffer('density_bitfield', density_bitfield)
91
+ self.mean_density = 0
92
+ self.iter_density = 0
93
+ # step counter
94
+ step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
95
+ self.register_buffer('step_counter', step_counter)
96
+ self.mean_count = 0
97
+ self.local_step = 0
98
+
99
+
100
+ def forward(self, x, d):
101
+ raise NotImplementedError()
102
+
103
+ def density(self, x):
104
+ raise NotImplementedError()
105
+
106
+ def color(self, x, d, mask=None, **kwargs):
107
+ raise NotImplementedError()
108
+
109
+ def reset_extra_state(self):
110
+ if not self.cuda_ray:
111
+ return
112
+ # density grid
113
+ self.density_grid.zero_()
114
+ self.mean_density = 0
115
+ self.iter_density = 0
116
+ # step counter
117
+ self.step_counter.zero_()
118
+ self.mean_count = 0
119
+ self.local_step = 0
120
+
121
+ @torch.no_grad()
122
+ def export_mesh(self, path, resolution=None, S=128):
123
+
124
+ if resolution is None:
125
+ resolution = self.grid_size
126
+
127
+ density_thresh = min(self.mean_density, self.density_thresh)
128
+
129
+ sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)
130
+
131
+ # query
132
+ X = torch.linspace(-1, 1, resolution).split(S)
133
+ Y = torch.linspace(-1, 1, resolution).split(S)
134
+ Z = torch.linspace(-1, 1, resolution).split(S)
135
+
136
+ for xi, xs in enumerate(X):
137
+ for yi, ys in enumerate(Y):
138
+ for zi, zs in enumerate(Z):
139
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
140
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
141
+ val = self.density(pts.to(self.density_bitfield.device))
142
+ sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
143
+
144
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
145
+
146
+ vertices = vertices / (resolution - 1.0) * 2 - 1
147
+ vertices = vertices.astype(np.float32)
148
+ triangles = triangles.astype(np.int32)
149
+
150
+ v = torch.from_numpy(vertices).to(self.density_bitfield.device)
151
+ f = torch.from_numpy(triangles).int().to(self.density_bitfield.device)
152
+
153
+ # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
154
+ # mesh.export(os.path.join(path, f'mesh.ply'))
155
+
156
+ # texture?
157
+ def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
158
+ # v, f: torch Tensor
159
+ device = v.device
160
+ v_np = v.cpu().numpy() # [N, 3]
161
+ f_np = f.cpu().numpy() # [M, 3]
162
+
163
+ print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
164
+
165
+ # unwrap uvs
166
+ import xatlas
167
+ import nvdiffrast.torch as dr
168
+ from sklearn.neighbors import NearestNeighbors
169
+ from scipy.ndimage import binary_dilation, binary_erosion
170
+
171
+ glctx = dr.RasterizeCudaContext()
172
+
173
+ atlas = xatlas.Atlas()
174
+ atlas.add_mesh(v_np, f_np)
175
+ chart_options = xatlas.ChartOptions()
176
+ chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
177
+ atlas.generate(chart_options=chart_options)
178
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
179
+
180
+ # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
181
+
182
+ vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
183
+ ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
184
+
185
+ # render uv maps
186
+ uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
187
+ uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
188
+
189
+ if ssaa > 1:
190
+ h = int(h0 * ssaa)
191
+ w = int(w0 * ssaa)
192
+ else:
193
+ h, w = h0, w0
194
+
195
+ rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
196
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
197
+ mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
198
+
199
+ # masked query
200
+ xyzs = xyzs.view(-1, 3)
201
+ mask = (mask > 0).view(-1)
202
+
203
+ sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
204
+ feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
205
+
206
+ if mask.any():
207
+ xyzs = xyzs[mask] # [M, 3]
208
+
209
+ # batched inference to avoid OOM
210
+ all_sigmas = []
211
+ all_feats = []
212
+ head = 0
213
+ while head < xyzs.shape[0]:
214
+ tail = min(head + 640000, xyzs.shape[0])
215
+ results_ = self.density(xyzs[head:tail])
216
+ all_sigmas.append(results_['sigma'].float())
217
+ all_feats.append(results_['albedo'].float())
218
+ head += 640000
219
+
220
+ sigmas[mask] = torch.cat(all_sigmas, dim=0)
221
+ feats[mask] = torch.cat(all_feats, dim=0)
222
+
223
+ sigmas = sigmas.view(h, w, 1)
224
+ feats = feats.view(h, w, -1)
225
+ mask = mask.view(h, w)
226
+
227
+ ### alpha mask
228
+ # deltas = 2 * np.sqrt(3) / 1024
229
+ # alphas = 1 - torch.exp(-sigmas * deltas)
230
+ # alphas_mask = alphas > 0.5
231
+ # feats = feats * alphas_mask
232
+
233
+ # quantize [0.0, 1.0] to [0, 255]
234
+ feats = feats.cpu().numpy()
235
+ feats = (feats * 255).astype(np.uint8)
236
+
237
+ # alphas = alphas.cpu().numpy()
238
+ # alphas = (alphas * 255).astype(np.uint8)
239
+
240
+ ### NN search as an antialiasing ...
241
+ mask = mask.cpu().numpy()
242
+
243
+ inpaint_region = binary_dilation(mask, iterations=3)
244
+ inpaint_region[mask] = 0
245
+
246
+ search_region = mask.copy()
247
+ not_search_region = binary_erosion(search_region, iterations=2)
248
+ search_region[not_search_region] = 0
249
+
250
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
251
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
252
+
253
+ knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
254
+ _, indices = knn.kneighbors(inpaint_coords)
255
+
256
+ feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
257
+
258
+ # do ssaa after the NN search, in numpy
259
+ feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
260
+
261
+ if ssaa > 1:
262
+ # alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST)
263
+ feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
264
+
265
+ # cv2.imwrite(os.path.join(path, f'alpha.png'), alphas)
266
+ cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
267
+
268
+ # save obj (v, vt, f /)
269
+ obj_file = os.path.join(path, f'{name}mesh.obj')
270
+ mtl_file = os.path.join(path, f'{name}mesh.mtl')
271
+
272
+ print(f'[INFO] writing obj mesh to {obj_file}')
273
+ with open(obj_file, "w") as fp:
274
+ fp.write(f'mtllib {name}mesh.mtl \n')
275
+
276
+ print(f'[INFO] writing vertices {v_np.shape}')
277
+ for v in v_np:
278
+ fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
279
+
280
+ print(f'[INFO] writing vertices texture coords {vt_np.shape}')
281
+ for v in vt_np:
282
+ fp.write(f'vt {v[0]} {1 - v[1]} \n')
283
+
284
+ print(f'[INFO] writing faces {f_np.shape}')
285
+ fp.write(f'usemtl mat0 \n')
286
+ for i in range(len(f_np)):
287
+ fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
288
+
289
+ with open(mtl_file, "w") as fp:
290
+ fp.write(f'newmtl mat0 \n')
291
+ fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
292
+ fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
293
+ fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
294
+ fp.write(f'Tr 1.000000 \n')
295
+ fp.write(f'illum 1 \n')
296
+ fp.write(f'Ns 0.000000 \n')
297
+ fp.write(f'map_Kd {name}albedo.png \n')
298
+
299
+ _export(v, f)
300
+
301
+ def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
302
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
303
+ # bg_color: [BN, 3] in range [0, 1]
304
+ # return: image: [B, N, 3], depth: [B, N]
305
+
306
+ prefix = rays_o.shape[:-1]
307
+ rays_o = rays_o.contiguous().view(-1, 3)
308
+ rays_d = rays_d.contiguous().view(-1, 3)
309
+
310
+ N = rays_o.shape[0] # N = B * N, in fact
311
+ device = rays_o.device
312
+
313
+ results = {}
314
+
315
+ # choose aabb
316
+ aabb = self.aabb_train if self.training else self.aabb_infer
317
+
318
+ # sample steps
319
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
320
+ nears.unsqueeze_(-1)
321
+ fars.unsqueeze_(-1)
322
+
323
+ # random sample light_d if not provided
324
+ if light_d is None:
325
+ # gaussian noise around the ray origin, so the light always faces the view direction (avoids dark faces)
326
+ light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
327
+ light_d = safe_normalize(light_d)
328
+
329
+ #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
330
+
331
+ z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
332
+ z_vals = z_vals.expand((N, num_steps)) # [N, T]
333
+ z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
334
+
335
+ # perturb z_vals
336
+ sample_dist = (fars - nears) / num_steps
337
+ if perturb:
338
+ z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
339
+ #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
340
+
341
+ # generate xyzs
342
+ xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
343
+ xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
344
+
345
+ #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
346
+
347
+ # query SDF and RGB
348
+ density_outputs = self.density(xyzs.reshape(-1, 3))
349
+
350
+ #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
351
+ for k, v in density_outputs.items():
352
+ density_outputs[k] = v.view(N, num_steps, -1)
353
+
354
+ # upsample z_vals (nerf-like)
355
+ if upsample_steps > 0:
356
+ with torch.no_grad():
357
+
358
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
359
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
360
+
361
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
362
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
363
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
364
+
365
+ # sample new z_vals
366
+ z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
367
+ new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
368
+
369
+ new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
370
+ new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
371
+
372
+ # only forward new points to save computation
373
+ new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
374
+ #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
375
+ for k, v in new_density_outputs.items():
376
+ new_density_outputs[k] = v.view(N, upsample_steps, -1)
377
+
378
+ # re-order
379
+ z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
380
+ z_vals, z_index = torch.sort(z_vals, dim=1)
381
+
382
+ xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
383
+ xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
384
+
385
+ for k in density_outputs:
386
+ tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
387
+ density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
388
+
389
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
390
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
391
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
392
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
393
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
394
+
395
+ dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
396
+ for k, v in density_outputs.items():
397
+ density_outputs[k] = v.view(-1, v.shape[-1])
398
+
399
+ sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
400
+ rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
401
+
402
+ #print(xyzs.shape, 'valid_rgb:', mask.sum().item())
403
+ # orientation loss
404
+ if normals is not None:
405
+ normals = normals.view(N, -1, 3)
406
+ # print(weights.shape, normals.shape, dirs.shape)
407
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
408
+ results['loss_orient'] = loss_orient.mean()
409
+
410
+ # calculate weight_sum (mask)
411
+ weights_sum = weights.sum(dim=-1) # [N]
412
+
413
+ # calculate depth
414
+ ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
415
+ depth = torch.sum(weights * ori_z_vals, dim=-1)
416
+
417
+ # calculate color
418
+ image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
419
+
420
+ # mix background color
421
+ if self.bg_radius > 0:
422
+ # use the bg model to calculate bg_color
423
+ # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
424
+ bg_color = self.background(rays_d.reshape(-1, 3)) # [N, 3]
425
+ elif bg_color is None:
426
+ bg_color = 1
427
+
428
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
429
+
430
+ image = image.view(*prefix, 3)
431
+ depth = depth.view(*prefix)
432
+
433
+ mask = (nears < fars).reshape(*prefix)
434
+
435
+ results['image'] = image
436
+ results['depth'] = depth
437
+ results['weights_sum'] = weights_sum
438
+ results['mask'] = mask
439
+
440
+ return results
441
+
442
+
443
+ def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
444
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
445
+ # return: image: [B, N, 3], depth: [B, N]
446
+
447
+ prefix = rays_o.shape[:-1]
448
+ rays_o = rays_o.contiguous().view(-1, 3)
449
+ rays_d = rays_d.contiguous().view(-1, 3)
450
+
451
+ N = rays_o.shape[0] # N = B * N, in fact
452
+ device = rays_o.device
453
+
454
+ # pre-calculate near far
455
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)
456
+
457
+ # random sample light_d if not provided
458
+ if light_d is None:
459
+ # gaussian noise around the ray origin, so the light always faces the view direction (avoids dark faces)
460
+ light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
461
+ light_d = safe_normalize(light_d)
462
+
463
+ results = {}
464
+
465
+ if self.training:
466
+ # setup counter
467
+ counter = self.step_counter[self.local_step % 16]
468
+ counter.zero_() # set to 0
469
+ self.local_step += 1
470
+
471
+ xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
472
+
473
+ #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
474
+
475
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
476
+
477
+ #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
478
+
479
+ weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
480
+
481
+ # orientation loss
482
+ if normals is not None:
483
+ weights = 1 - torch.exp(-sigmas)
484
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
485
+ results['loss_orient'] = loss_orient.mean()
486
+
487
+ else:
488
+
489
+ # allocate outputs
490
+ dtype = torch.float32
491
+
492
+ weights_sum = torch.zeros(N, dtype=dtype, device=device)
493
+ depth = torch.zeros(N, dtype=dtype, device=device)
494
+ image = torch.zeros(N, 3, dtype=dtype, device=device)
495
+
496
+ n_alive = N
497
+ rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
498
+ rays_t = nears.clone() # [N]
499
+
500
+ step = 0
501
+
502
+ while step < max_steps: # hard coded max step
503
+
504
+ # count alive rays
505
+ n_alive = rays_alive.shape[0]
506
+
507
+ # exit loop
508
+ if n_alive <= 0:
509
+ break
510
+
511
+ # decide compact_steps
512
+ n_step = max(min(N // n_alive, 8), 1)
513
+
514
+ xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
515
+
516
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
517
+
518
+ raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
519
+
520
+ rays_alive = rays_alive[rays_alive >= 0]
521
+ #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
522
+
523
+ step += n_step
524
+
525
+ # mix background color
526
+ if self.bg_radius > 0:
527
+
528
+ # use the bg model to calculate bg_color
529
+ # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
530
+ bg_color = self.background(rays_d) # [N, 3]
531
+
532
+ elif bg_color is None:
533
+ bg_color = 1
534
+
535
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
536
+ image = image.view(*prefix, 3)
537
+
538
+ depth = torch.clamp(depth - nears, min=0) / (fars - nears)
539
+ depth = depth.view(*prefix)
540
+
541
+ weights_sum = weights_sum.reshape(*prefix)
542
+
543
+ mask = (nears < fars).reshape(*prefix)
544
+
545
+ results['image'] = image
546
+ results['depth'] = depth
547
+ results['weights_sum'] = weights_sum
548
+ results['mask'] = mask
549
+
550
+ return results
551
+
552
+
553
+ @torch.no_grad()
554
+ def update_extra_state(self, decay=0.95, S=128):
555
+ # call before each epoch to update extra states.
556
+
557
+ if not self.cuda_ray:
558
+ return
559
+
560
+ ### update density grid
561
+ tmp_grid = - torch.ones_like(self.density_grid)
562
+
563
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
564
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
565
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
566
+
567
+ for xs in X:
568
+ for ys in Y:
569
+ for zs in Z:
570
+
571
+ # construct points
572
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
573
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
574
+ indices = raymarching.morton3D(coords).long() # [N]
575
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
576
+
577
+ # cascading
578
+ for cas in range(self.cascade):
579
+ bound = min(2 ** cas, self.bound)
580
+ half_grid_size = bound / self.grid_size
581
+ # scale to current cascade's resolution
582
+ cas_xyzs = xyzs * (bound - half_grid_size)
583
+ # add noise in [-hgs, hgs]
584
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
585
+ # query density
586
+ sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
587
+ # assign
588
+ tmp_grid[cas, indices] = sigmas
589
+
590
+ # ema update
591
+ valid_mask = self.density_grid >= 0
592
+ self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
593
+ self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
594
+ self.iter_density += 1
595
+
596
+ # convert to bitfield
597
+ density_thresh = min(self.mean_density, self.density_thresh)
598
+ self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
599
+
600
+ ### update step counter
601
+ total_step = min(16, self.local_step)
602
+ if total_step > 0:
603
+ self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
604
+ self.local_step = 0
605
+
606
+ # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
607
+
608
+
609
+ def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
610
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
611
+ # return: pred_rgb: [B, N, 3]
612
+
613
+ if self.cuda_ray:
614
+ _run = self.run_cuda
615
+ else:
616
+ _run = self.run
617
+
618
+ B, N = rays_o.shape[:2]
619
+ device = rays_o.device
620
+
621
+ # never stage when cuda_ray
622
+ if staged and not self.cuda_ray:
623
+ depth = torch.empty((B, N), device=device)
624
+ image = torch.empty((B, N, 3), device=device)
625
+ weights_sum = torch.empty((B, N), device=device)
626
+
627
+ for b in range(B):
628
+ head = 0
629
+ while head < N:
630
+ tail = min(head + max_ray_batch, N)
631
+ results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
632
+ depth[b:b+1, head:tail] = results_['depth']
633
+ weights_sum[b:b+1, head:tail] = results_['weights_sum']
634
+ image[b:b+1, head:tail] = results_['image']
635
+ head += max_ray_batch
636
+
637
+ results = {}
638
+ results['depth'] = depth
639
+ results['image'] = image
640
+ results['weights_sum'] = weights_sum
641
+
642
+ else:
643
+ results = _run(rays_o, rays_d, **kwargs)
644
+
645
+ return results
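
Both run() and run_cuda() composite samples with the standard volume-rendering quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), w_i = alpha_i * prod_{j<i} (1 - alpha_j), and the pixel color is sum_i w_i * rgb_i. A minimal standalone version of that weighting, mirroring the tensor shapes used in run() (composite is an illustrative name):

import torch

def composite(sigmas, rgbs, deltas):
    # sigmas: [N, T], rgbs: [N, T, 3], deltas: [N, T] step sizes along each ray
    alphas = 1 - torch.exp(-sigmas * deltas)  # [N, T]
    # transmittance: probability the ray survives all earlier samples
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-15], dim=-1),
        dim=-1)[:, :-1]
    weights = alphas * trans  # [N, T]
    image = (weights.unsqueeze(-1) * rgbs).sum(dim=-2)  # [N, 3]
    return image, weights
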
nerf/sd.py ADDED
@@ -0,0 +1,203 @@
1
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
2
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
3
+
4
+ # suppress partial model loading warning
5
+ logging.set_verbosity_error()
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ import time
13
+
14
+ class StableDiffusion(nn.Module):
15
+ def __init__(self, device):
16
+ super().__init__()
17
+
18
+ try:
19
+ self.token = os.environ['TOKEN']
20
+ print(f'[INFO] loaded hugging face access token from environment variable TOKEN')
21
+ except KeyError:
22
+ self.token = True
23
+ print(f'[INFO] trying to load hugging face access token from the default location, make sure you have run `huggingface-cli login`.')
24
+
25
+ self.device = device
26
+ self.num_train_timesteps = 1000
27
+ self.min_step = int(self.num_train_timesteps * 0.02)
28
+ self.max_step = int(self.num_train_timesteps * 0.98)
29
+
30
+ print(f'[INFO] loading stable diffusion...')
31
+
32
+ # 1. Load the autoencoder model which will be used to decode the latents into image space.
33
+ self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=self.token).to(self.device)
34
+
35
+ # 2. Load the tokenizer and text encoder to tokenize and encode the text.
36
+ self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
37
+ self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
38
+
39
+ # 3. The UNet model for generating the latents.
40
+ self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=self.token).to(self.device)
41
+
42
+ # 4. Create a scheduler for inference
43
+ self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=self.num_train_timesteps)
44
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
45
+
46
+ print(f'[INFO] loaded stable diffusion!')
47
+
48
+ def get_text_embeds(self, prompt):
49
+ # Tokenize text and get embeddings
50
+ text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
51
+
52
+ with torch.no_grad():
53
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
54
+
55
+ # Do the same for unconditional embeddings
56
+ uncond_input = self.tokenizer([''] * len(prompt), padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
57
+
58
+ with torch.no_grad():
59
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
60
+
61
+ # Cat for final embeddings
62
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
63
+ return text_embeddings
64
+
65
+
66
+ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100):
67
+
68
+ # interp to 512x512 to be fed into vae.
69
+
70
+ # _t = time.time()
71
+ pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
72
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
73
+
74
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
75
+ t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)
76
+
77
+ # encode image into latents with vae, requires grad!
78
+ # _t = time.time()
79
+ latents = self.encode_imgs(pred_rgb_512)
80
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
81
+
82
+ # predict the noise residual with unet, NO grad!
83
+ # _t = time.time()
84
+ with torch.no_grad():
85
+ # add noise
86
+ noise = torch.randn_like(latents)
87
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
88
+ # pred noise
89
+ latent_model_input = torch.cat([latents_noisy] * 2)
90
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
91
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
92
+
93
+ # perform guidance (high scale from paper!)
94
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
95
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
96
+
97
+ # w(t), sigma_t^2
98
+ w = (1 - self.alphas[t])
99
+ # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
100
+ grad = w * (noise_pred - noise)
101
+
102
+ # clip grad for stable training?
103
+ # grad = grad.clamp(-1, 1)
104
+
105
+ # manually backward, since we omitted an item in grad and cannot simply autodiff.
106
+ # _t = time.time()
107
+ latents.backward(gradient=grad, retain_graph=True)
108
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
109
+
110
+ return 0 # dummy loss value
111
+
112
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
113
+
114
+ if latents is None:
115
+ latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
116
+
117
+ self.scheduler.set_timesteps(num_inference_steps)
118
+
119
+ with torch.autocast('cuda'):
120
+ for i, t in enumerate(self.scheduler.timesteps):
121
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
122
+ latent_model_input = torch.cat([latents] * 2)
123
+
124
+ # predict the noise residual
125
+ with torch.no_grad():
126
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
127
+
128
+ # perform guidance
129
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
130
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
131
+
132
+ # compute the previous noisy sample x_t -> x_t-1
133
+ latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
134
+
135
+ return latents
136
+
137
+ def decode_latents(self, latents):
138
+
139
+ latents = 1 / 0.18215 * latents
140
+
141
+ with torch.no_grad():
142
+ imgs = self.vae.decode(latents).sample
143
+
144
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
145
+
146
+ return imgs
147
+
148
+ def encode_imgs(self, imgs):
149
+ # imgs: [B, 3, H, W]
150
+
151
+ imgs = 2 * imgs - 1
152
+
153
+ posterior = self.vae.encode(imgs).latent_dist
154
+ latents = posterior.sample() * 0.18215
155
+
156
+ return latents
157
+
158
+ def prompt_to_img(self, prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
159
+
160
+ if isinstance(prompts, str):
161
+ prompts = [prompts]
162
+
163
+ # Prompts -> text embeds
164
+ text_embeds = self.get_text_embeds(prompts) # [2, 77, 768]
165
+
166
+ # Text embeds -> img latents
167
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
168
+
169
+ # Img latents -> imgs
170
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
171
+
172
+ # Img to Numpy
173
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
174
+ imgs = (imgs * 255).round().astype('uint8')
175
+
176
+ return imgs
177
+
178
+
179
+ if __name__ == '__main__':
180
+
181
+ import argparse
182
+ import matplotlib.pyplot as plt
183
+
184
+ parser = argparse.ArgumentParser()
185
+ parser.add_argument('prompt', type=str)
186
+ parser.add_argument('-H', type=int, default=512)
187
+ parser.add_argument('-W', type=int, default=512)
188
+ parser.add_argument('--steps', type=int, default=50)
189
+ opt = parser.parse_args()
190
+
191
+ device = torch.device('cuda')
192
+
193
+ sd = StableDiffusion(device)
194
+
195
+ imgs = sd.prompt_to_img(opt.prompt, opt.H, opt.W, opt.steps)
196
+
197
+ # visualize image
198
+ plt.imshow(imgs[0])
199
+ plt.show()
200
+
201
+
202
+
203
+
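
The core trick in train_step is that latents.backward(gradient=grad) injects the score-distillation gradient without ever forming a scalar loss; it is equivalent to backpropagating loss = (grad.detach() * latents).sum(), whose derivative with respect to latents is exactly grad. A tiny standalone check of that identity (all names illustrative):

import torch

x = torch.randn(4, requires_grad=True)
y = x * 3                  # stand-in for the differentiable latents
grad = torch.randn(4)      # stand-in for w(t) * (noise_pred - noise)

y.backward(gradient=grad)  # the pattern used in train_step
assert torch.allclose(x.grad, grad * 3)  # chain rule: dy/dx = 3
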
nerf/utils.py ADDED
@@ -0,0 +1,950 @@
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import math
5
+ import imageio
6
+ import random
7
+ import warnings
8
+ import tensorboardX
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ import time
14
+ from datetime import datetime
15
+
16
+ import cv2
17
+ import matplotlib.pyplot as plt
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.optim as optim
22
+ import torch.nn.functional as F
23
+ import torch.distributed as dist
24
+ from torch.utils.data import Dataset, DataLoader
25
+
26
+ import trimesh
27
+ from rich.console import Console
28
+ from torch_ema import ExponentialMovingAverage
29
+
30
+ from packaging import version as pver
31
+
32
+ def custom_meshgrid(*args):
33
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
34
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
35
+ return torch.meshgrid(*args)
36
+ else:
37
+ return torch.meshgrid(*args, indexing='ij')
38
+
39
+ def safe_normalize(x, eps=1e-20):
40
+ return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
41
+
42
+ @torch.cuda.amp.autocast(enabled=False)
43
+ def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
44
+ ''' get rays
45
+ Args:
46
+ poses: [B, 4, 4], cam2world
47
+ intrinsics: [4]
48
+ H, W, N: int
49
+ error_map: [B, 128 * 128], sample probability based on training error
50
+ Returns:
51
+ rays_o, rays_d: [B, N, 3]
52
+ inds: [B, N]
53
+ '''
54
+
55
+ device = poses.device
56
+ B = poses.shape[0]
57
+ fx, fy, cx, cy = intrinsics
58
+
59
+ i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
60
+ i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
61
+ j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
62
+
63
+ results = {}
64
+
65
+ if N > 0:
66
+ N = min(N, H*W)
67
+
68
+ if error_map is None:
69
+ inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
70
+ inds = inds.expand([B, N])
71
+ else:
72
+
73
+ # weighted sample on a low-reso grid
74
+ inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)
75
+
76
+ # map to the original resolution with random perturb.
77
+ inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
78
+ sx, sy = H / 128, W / 128
79
+ inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
80
+ inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
81
+ inds = inds_x * W + inds_y
82
+
83
+ results['inds_coarse'] = inds_coarse # need this when updating error_map
84
+
85
+ i = torch.gather(i, -1, inds)
86
+ j = torch.gather(j, -1, inds)
87
+
88
+ results['inds'] = inds
89
+
90
+ else:
91
+ inds = torch.arange(H*W, device=device).expand([B, H*W])
92
+
93
+ zs = torch.ones_like(i)
94
+ xs = (i - cx) / fx * zs
95
+ ys = (j - cy) / fy * zs
96
+ directions = torch.stack((xs, ys, zs), dim=-1)
97
+ directions = safe_normalize(directions)
98
+ rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
99
+
100
+ rays_o = poses[..., :3, 3] # [B, 3]
101
+ rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
102
+
103
+ results['rays_o'] = rays_o
104
+ results['rays_d'] = rays_d
105
+
106
+ return results
107
+
108
+
109
+ def seed_everything(seed):
110
+ random.seed(seed)
111
+ os.environ['PYTHONHASHSEED'] = str(seed)
112
+ np.random.seed(seed)
113
+ torch.manual_seed(seed)
114
+ torch.cuda.manual_seed(seed)
115
+ #torch.backends.cudnn.deterministic = True
116
+ #torch.backends.cudnn.benchmark = True
117
+
118
+
119
+ def torch_vis_2d(x, renormalize=False):
120
+ # x: [3, H, W] or [1, H, W] or [H, W]
121
+ import matplotlib.pyplot as plt
122
+ import numpy as np
123
+ import torch
124
+
125
+ if isinstance(x, torch.Tensor):
126
+ if len(x.shape) == 3:
127
+ x = x.permute(1,2,0).squeeze()
128
+ x = x.detach().cpu().numpy()
129
+
130
+ print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
131
+
132
+ x = x.astype(np.float32)
133
+
134
+ # renormalize
135
+ if renormalize:
136
+ x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
137
+
138
+ plt.imshow(x)
139
+ plt.show()
140
+
141
+ @torch.jit.script
142
+ def linear_to_srgb(x):
143
+ return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
144
+
145
+
146
+ @torch.jit.script
147
+ def srgb_to_linear(x):
148
+ return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
149
+
150
+
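+ # Editorial sanity check: the two jit-scripted conversions above should be
+ # approximately inverse on [0, 1] (0.41666 only approximates 1/2.4, so the
+ # roundtrip is close but not bit-exact).
+ def _demo_srgb_roundtrip():
+     import torch
+     x = torch.linspace(0, 1, 101)
+     assert torch.allclose(srgb_to_linear(linear_to_srgb(x)), x, atol=1e-3)
+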
151
+ class Trainer(object):
152
+ def __init__(self,
153
+ name, # name of this experiment
154
+ opt, # extra conf
155
+ model, # network
156
+ guidance, # guidance network
157
+ criterion=None, # loss function, if None, assume inline implementation in train_step
158
+ optimizer=None, # optimizer
159
+ ema_decay=None, # if use EMA, set the decay
160
+ lr_scheduler=None, # scheduler
161
+ metrics=[], # metrics for evaluation; if empty, use val_loss to measure performance, else use the first metric.
162
+ local_rank=0, # which GPU am I
163
+ world_size=1, # total num of GPUs
164
+ device=None, # device to use; None picks automatically (cuda if available, else cpu)
165
+ mute=False, # whether to mute all print
166
+ fp16=False, # whether to use amp mixed-precision (fp16) training
167
+ eval_interval=1, # evaluate once every `eval_interval` epochs
168
+ max_keep_ckpt=2, # max num of saved ckpts on disk
169
+ workspace='workspace', # workspace to save logs & ckpts
170
+ best_mode='min', # 'min': the smaller the result, the better; 'max': the opposite
171
+ use_loss_as_metric=True, # use loss as the first metric
172
+ report_metric_at_train=False, # also report metrics at training
173
+ use_checkpoint="latest", # which ckpt to use at init time
174
+ use_tensorboardX=True, # whether to use tensorboard for logging
175
+ scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
176
+ ):
177
+
178
+ self.name = name
179
+ self.opt = opt
180
+ self.mute = mute
181
+ self.metrics = metrics
182
+ self.local_rank = local_rank
183
+ self.world_size = world_size
184
+ self.workspace = workspace
185
+ self.ema_decay = ema_decay
186
+ self.fp16 = fp16
187
+ self.best_mode = best_mode
188
+ self.use_loss_as_metric = use_loss_as_metric
189
+ self.report_metric_at_train = report_metric_at_train
190
+ self.max_keep_ckpt = max_keep_ckpt
191
+ self.eval_interval = eval_interval
192
+ self.use_checkpoint = use_checkpoint
193
+ self.use_tensorboardX = use_tensorboardX
194
+ self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
195
+ self.scheduler_update_every_step = scheduler_update_every_step
196
+ self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
197
+ self.console = Console()
198
+
199
+ model.to(self.device)
200
+ if self.world_size > 1:
201
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
202
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
203
+ self.model = model
204
+
205
+ # guide model
206
+ self.guidance = guidance
207
+
208
+ # text prompt
209
+ if self.guidance is not None:
210
+
211
+ for p in self.guidance.parameters():
212
+ p.requires_grad = False
213
+
214
+ self.prepare_text_embeddings()
215
+
216
+ else:
217
+ self.text_z = None
218
+
219
+ if isinstance(criterion, nn.Module):
220
+ criterion.to(self.device)
221
+ self.criterion = criterion
222
+
223
+ if optimizer is None:
224
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
225
+ else:
226
+ self.optimizer = optimizer(self.model)
227
+
228
+ if lr_scheduler is None:
229
+ self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
230
+ else:
231
+ self.lr_scheduler = lr_scheduler(self.optimizer)
232
+
233
+ if ema_decay is not None:
234
+ self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
235
+ else:
236
+ self.ema = None
237
+
238
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
239
+
240
+ # variable init
241
+ self.epoch = 0
242
+ self.global_step = 0
243
+ self.local_step = 0
244
+ self.stats = {
245
+ "loss": [],
246
+ "valid_loss": [],
247
+ "results": [], # metrics[0], or valid_loss
248
+ "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
249
+ "best_result": None,
250
+ }
251
+
252
+ # auto fix
253
+ if len(metrics) == 0 or self.use_loss_as_metric:
254
+ self.best_mode = 'min'
255
+
256
+ # workspace prepare
257
+ self.log_ptr = None
258
+ if self.workspace is not None:
259
+ os.makedirs(self.workspace, exist_ok=True)
260
+ self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
261
+ self.log_ptr = open(self.log_path, "a+")
262
+
263
+ self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
264
+ self.best_path = f"{self.ckpt_path}/{self.name}.pth"
265
+ os.makedirs(self.ckpt_path, exist_ok=True)
266
+
267
+ self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
268
+ self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
269
+
270
+ if self.workspace is not None:
271
+ if self.use_checkpoint == "scratch":
272
+ self.log("[INFO] Training from scratch ...")
273
+ elif self.use_checkpoint == "latest":
274
+ self.log("[INFO] Loading latest checkpoint ...")
275
+ self.load_checkpoint()
276
+ elif self.use_checkpoint == "latest_model":
277
+ self.log("[INFO] Loading latest checkpoint (model only)...")
278
+ self.load_checkpoint(model_only=True)
279
+ elif self.use_checkpoint == "best":
280
+ if os.path.exists(self.best_path):
281
+ self.log("[INFO] Loading best checkpoint ...")
282
+ self.load_checkpoint(self.best_path)
283
+ else:
284
+ self.log(f"[INFO] {self.best_path} not found, loading latest ...")
285
+ self.load_checkpoint()
286
+ else: # path to ckpt
287
+ self.log(f"[INFO] Loading {self.use_checkpoint} ...")
288
+ self.load_checkpoint(self.use_checkpoint)
289
+
290
+ # calculate the text embs.
291
+ def prepare_text_embeddings(self):
292
+
293
+ if self.opt.text is None:
294
+ self.log(f"[WARN] text prompt is not provided.")
295
+ self.text_z = None
296
+ return
297
+
298
+ if not self.opt.dir_text:
299
+ self.text_z = self.guidance.get_text_embeds([self.opt.text])
300
+ else:
301
+ self.text_z = []
302
+ for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
303
+ text = f"{self.opt.text}, {d} view"
304
+ text_z = self.guidance.get_text_embeds([text])
305
+ self.text_z.append(text_z)
306
+
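+ # Editorial sketch of the prompt augmentation above (the default prompt here
+ # is hypothetical): with dir_text enabled, each sampled camera direction
+ # later indexes one of these six prompts in train_step; 'side' appears twice
+ # because indices 1 and 3 both map to side views.
+ def _demo_dir_prompts(text='a DSLR photo of a hamburger'):
+     return [f"{text}, {d} view"
+             for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']]
+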
307
+ def __del__(self):
308
+ if self.log_ptr:
309
+ self.log_ptr.close()
310
+
311
+
312
+ def log(self, *args, **kwargs):
313
+ if self.local_rank == 0:
314
+ if not self.mute:
315
+ #print(*args)
316
+ self.console.print(*args, **kwargs)
317
+ if self.log_ptr:
318
+ print(*args, file=self.log_ptr)
319
+ self.log_ptr.flush() # write immediately to file
320
+
321
+ ### ------------------------------
322
+
323
+ def train_step(self, data):
324
+
325
+ rays_o = data['rays_o'] # [B, N, 3]
326
+ rays_d = data['rays_d'] # [B, N, 3]
327
+
328
+ B, N = rays_o.shape[:2]
329
+ H, W = data['H'], data['W']
330
+
331
+ # TODO: shading is not working right now...
332
+ if self.global_step < self.opt.albedo_iters:
333
+ shading = 'albedo'
334
+ ambient_ratio = 1.0
335
+ else:
336
+ rand = random.random()
337
+ if rand > 0.8:
338
+ shading = 'albedo'
339
+ ambient_ratio = 1.0
340
+ # elif rand > 0.4:
341
+ # shading = 'textureless'
342
+ # ambient_ratio = 0.1
343
+ else:
344
+ shading = 'lambertian'
345
+ ambient_ratio = 0.1
346
+
347
+ # _t = time.time()
348
+ bg_color = torch.rand((B * N, 3), device=rays_o.device) # pixel-wise random
349
+ outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
350
+ pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
351
+ # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')
352
+
353
+ # print(shading)
354
+ # torch_vis_2d(pred_rgb[0])
355
+
356
+ # text embeddings
357
+ if self.opt.dir_text:
358
+ dirs = data['dir'] # [B,]
359
+ text_z = self.text_z[dirs]
360
+ else:
361
+ text_z = self.text_z
362
+
363
+ # encode pred_rgb to latents
364
+ # _t = time.time()
365
+ loss = self.guidance.train_step(text_z, pred_rgb)
366
+ # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')
367
+
368
+ # occupancy loss
369
+ pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
370
+
371
+ if self.opt.lambda_opacity > 0:
372
+ loss_opacity = (pred_ws ** 2).mean()
373
+ loss = loss + self.opt.lambda_opacity * loss_opacity
374
+
375
+ if self.opt.lambda_entropy > 0:
376
+ alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
377
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
378
+ loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
379
+
380
+ loss = loss + self.opt.lambda_entropy * loss_entropy
381
+
382
+ if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
383
+ loss_orient = outputs['loss_orient']
384
+ loss = loss + self.opt.lambda_orient * loss_orient
385
+
386
+ return pred_rgb, pred_ws, loss
387
+
388
+ def eval_step(self, data):
389
+
390
+ rays_o = data['rays_o'] # [B, N, 3]
391
+ rays_d = data['rays_d'] # [B, N, 3]
392
+
393
+ B, N = rays_o.shape[:2]
394
+ H, W = data['H'], data['W']
395
+
396
+ shading = data['shading'] if 'shading' in data else 'albedo'
397
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
398
+ light_d = data['light_d'] if 'light_d' in data else None
399
+
400
+ outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
401
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
402
+ pred_depth = outputs['depth'].reshape(B, H, W)
403
+ pred_ws = outputs['weights_sum'].reshape(B, H, W)
404
+ # mask_ws = outputs['mask'].reshape(B, H, W) # near < far
405
+
406
+ # loss_ws = pred_ws.sum() / mask_ws.sum()
407
+ # loss_ws = pred_ws.mean()
408
+
409
+ alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
410
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
411
+ loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
412
+
413
+ loss = self.opt.lambda_entropy * loss_entropy
414
+
415
+ return pred_rgb, pred_depth, loss
416
+
417
+ def test_step(self, data, bg_color=None, perturb=False):
418
+ rays_o = data['rays_o'] # [B, N, 3]
419
+ rays_d = data['rays_d'] # [B, N, 3]
420
+
421
+ B, N = rays_o.shape[:2]
422
+ H, W = data['H'], data['W']
423
+
424
+ if bg_color is not None:
425
+ bg_color = bg_color.to(rays_o.device)
426
+ else:
427
+ bg_color = torch.ones(3, device=rays_o.device) # [3]
428
+
429
+ shading = data['shading'] if 'shading' in data else 'albedo'
430
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
431
+ light_d = data['light_d'] if 'light_d' in data else None
432
+
433
+ outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))
434
+
435
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
436
+ pred_depth = outputs['depth'].reshape(B, H, W)
437
+
438
+ return pred_rgb, pred_depth
439
+
440
+
441
+ def save_mesh(self, save_path=None, resolution=128):
442
+
443
+ if save_path is None:
444
+ save_path = os.path.join(self.workspace, 'mesh')
445
+
446
+ self.log(f"==> Saving mesh to {save_path}")
447
+
448
+ os.makedirs(save_path, exist_ok=True)
449
+
450
+ self.model.export_mesh(save_path, resolution=resolution)
451
+
452
+ self.log(f"==> Finished saving mesh.")
453
+
454
+ ### ------------------------------
455
+
456
+ def train(self, train_loader, valid_loader, max_epochs):
457
+
458
+ assert self.text_z is not None, 'Training must provide a text prompt!'
459
+
460
+ if self.use_tensorboardX and self.local_rank == 0:
461
+ self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
462
+
463
+ start_t = time.time()
464
+
465
+ for epoch in range(self.epoch + 1, max_epochs + 1):
466
+ self.epoch = epoch
467
+
468
+ self.train_one_epoch(train_loader)
469
+
470
+ if self.workspace is not None and self.local_rank == 0:
471
+ self.save_checkpoint(full=True, best=False)
472
+
473
+ if self.epoch % self.eval_interval == 0:
474
+ self.evaluate_one_epoch(valid_loader)
475
+ self.save_checkpoint(full=False, best=True)
476
+
477
+ end_t = time.time()
478
+
479
+ self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")
480
+
481
+ if self.use_tensorboardX and self.local_rank == 0:
482
+ self.writer.close()
483
+
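+ # Editorial construction sketch (hypothetical stand-ins: guidance=None skips
+ # text-embedding preparation and use_checkpoint='scratch' skips checkpoint
+ # loading, so this only exercises __init__, logging and workspace creation):
+ def _demo_trainer_init():
+     import argparse
+     import torch.nn as nn
+     opt = argparse.Namespace()   # stand-in for the real config object
+     net = nn.Linear(4, 4)        # stand-in for the NeRF network
+     return Trainer('demo', opt, net, guidance=None,
+                    workspace='workspace_demo', use_checkpoint='scratch')
+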
484
+ def evaluate(self, loader, name=None):
485
+ self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
486
+ self.evaluate_one_epoch(loader, name)
487
+ self.use_tensorboardX = use_tensorboardX
488
+
489
+ def test(self, loader, save_path=None, name=None, write_video=True):
490
+
491
+ if save_path is None:
492
+ save_path = os.path.join(self.workspace, 'results')
493
+
494
+ if name is None:
495
+ name = f'{self.name}_ep{self.epoch:04d}'
496
+
497
+ os.makedirs(save_path, exist_ok=True)
498
+
499
+ self.log(f"==> Start Test, save results to {save_path}")
500
+
501
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
502
+ self.model.eval()
503
+
504
+ if write_video:
505
+ all_preds = []
506
+ all_preds_depth = []
507
+
508
+ with torch.no_grad():
509
+
510
+ for i, data in enumerate(loader):
511
+
512
+ with torch.cuda.amp.autocast(enabled=self.fp16):
513
+ preds, preds_depth = self.test_step(data)
514
+
515
+ pred = preds[0].detach().cpu().numpy()
516
+ pred = (pred * 255).astype(np.uint8)
517
+
518
+ pred_depth = preds_depth[0].detach().cpu().numpy()
519
+ pred_depth = (pred_depth * 255).astype(np.uint8)
520
+
521
+ if write_video:
522
+ all_preds.append(pred)
523
+ all_preds_depth.append(pred_depth)
524
+ else:
525
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
526
+ cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)
527
+
528
+ pbar.update(loader.batch_size)
529
+
530
+ if write_video:
531
+ all_preds = np.stack(all_preds, axis=0)
532
+ all_preds_depth = np.stack(all_preds_depth, axis=0)
533
+
534
+ imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
535
+ imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)
536
+
537
+ self.log(f"==> Finished Test.")
538
+
539
+ # [GUI] train text step.
540
+ def train_gui(self, train_loader, step=16):
541
+
542
+ self.model.train()
543
+
544
+ total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
545
+
546
+ loader = iter(train_loader)
547
+
548
+ for _ in range(step):
549
+
550
+ # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
551
+ try:
552
+ data = next(loader)
553
+ except StopIteration:
554
+ loader = iter(train_loader)
555
+ data = next(loader)
556
+
557
+ # update grid every 16 steps
558
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
559
+ with torch.cuda.amp.autocast(enabled=self.fp16):
560
+ self.model.update_extra_state()
561
+
562
+ self.global_step += 1
563
+
564
+ self.optimizer.zero_grad()
565
+
566
+ with torch.cuda.amp.autocast(enabled=self.fp16):
567
+ pred_rgbs, pred_ws, loss = self.train_step(data)
568
+
569
+ self.scaler.scale(loss).backward()
570
+ self.scaler.step(self.optimizer)
571
+ self.scaler.update()
572
+
573
+ if self.scheduler_update_every_step:
574
+ self.lr_scheduler.step()
575
+
576
+ total_loss += loss.detach()
577
+
578
+ if self.ema is not None:
579
+ self.ema.update()
580
+
581
+ average_loss = total_loss.item() / step
582
+
583
+ if not self.scheduler_update_every_step:
584
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
585
+ self.lr_scheduler.step(average_loss)
586
+ else:
587
+ self.lr_scheduler.step()
588
+
589
+ outputs = {
590
+ 'loss': average_loss,
591
+ 'lr': self.optimizer.param_groups[0]['lr'],
592
+ }
593
+
594
+ return outputs
595
+
596
+
597
+ # [GUI] test on a single image
598
+ def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
599
+
600
+ # render resolution (may need to downscale for a better frame rate)
601
+ rH = int(H * downscale)
602
+ rW = int(W * downscale)
603
+ intrinsics = intrinsics * downscale
604
+
605
+ pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
606
+
607
+ rays = get_rays(pose, intrinsics, rH, rW, -1)
608
+
609
+ # from degree theta/phi to 3D normalized vec
610
+ light_d = np.deg2rad(light_d)
611
+ light_d = np.array([
612
+ np.sin(light_d[0]) * np.sin(light_d[1]),
613
+ np.cos(light_d[0]),
614
+ np.sin(light_d[0]) * np.cos(light_d[1]),
615
+ ], dtype=np.float32)
616
+ light_d = torch.from_numpy(light_d).to(self.device)
617
+
618
+ data = {
619
+ 'rays_o': rays['rays_o'],
620
+ 'rays_d': rays['rays_d'],
621
+ 'H': rH,
622
+ 'W': rW,
623
+ 'light_d': light_d,
624
+ 'ambient_ratio': ambient_ratio,
625
+ 'shading': shading,
626
+ }
627
+
628
+ self.model.eval()
629
+
630
+ if self.ema is not None:
631
+ self.ema.store()
632
+ self.ema.copy_to()
633
+
634
+ with torch.no_grad():
635
+ with torch.cuda.amp.autocast(enabled=self.fp16):
636
+ # here spp is used as perturb random seed!
637
+ preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)
638
+
639
+ if self.ema is not None:
640
+ self.ema.restore()
641
+
642
+ # interpolation to the original resolution
643
+ if downscale != 1:
644
+ # have to permute twice with torch...
645
+ preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
646
+ preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
647
+
648
+ outputs = {
649
+ 'image': preds[0].detach().cpu().numpy(),
650
+ 'depth': preds_depth[0].detach().cpu().numpy(),
651
+ }
652
+
653
+ return outputs
654
+
655
+ def train_one_epoch(self, loader):
656
+ self.log(f"==> Start Training {self.workspace} Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
657
+
658
+ total_loss = 0
659
+ if self.local_rank == 0 and self.report_metric_at_train:
660
+ for metric in self.metrics:
661
+ metric.clear()
662
+
663
+ self.model.train()
664
+
665
+ # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
666
+ # ref: https://pytorch.org/docs/stable/data.html
667
+ if self.world_size > 1:
668
+ loader.sampler.set_epoch(self.epoch)
669
+
670
+ if self.local_rank == 0:
671
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
672
+
673
+ self.local_step = 0
674
+
675
+ for data in loader:
676
+
677
+ # update grid every 16 steps
678
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
679
+ with torch.cuda.amp.autocast(enabled=self.fp16):
680
+ self.model.update_extra_state()
681
+
682
+ self.local_step += 1
683
+ self.global_step += 1
684
+
685
+ self.optimizer.zero_grad()
686
+
687
+ with torch.cuda.amp.autocast(enabled=self.fp16):
688
+ pred_rgbs, pred_ws, loss = self.train_step(data)
689
+
690
+ self.scaler.scale(loss).backward()
691
+ self.scaler.step(self.optimizer)
692
+ self.scaler.update()
693
+
694
+ if self.scheduler_update_every_step:
695
+ self.lr_scheduler.step()
696
+
697
+ loss_val = loss.item()
698
+ total_loss += loss_val
699
+
700
+ if self.local_rank == 0:
701
+ # if self.report_metric_at_train:
702
+ # for metric in self.metrics:
703
+ # metric.update(preds, truths)
704
+
705
+ if self.use_tensorboardX:
706
+ self.writer.add_scalar("train/loss", loss_val, self.global_step)
707
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
708
+
709
+ if self.scheduler_update_every_step:
710
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
711
+ else:
712
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
713
+ pbar.update(loader.batch_size)
714
+
715
+ if self.ema is not None:
716
+ self.ema.update()
717
+
718
+ average_loss = total_loss / self.local_step
719
+ self.stats["loss"].append(average_loss)
720
+
721
+ if self.local_rank == 0:
722
+ pbar.close()
723
+ if self.report_metric_at_train:
724
+ for metric in self.metrics:
725
+ self.log(metric.report(), style="red")
726
+ if self.use_tensorboardX:
727
+ metric.write(self.writer, self.epoch, prefix="train")
728
+ metric.clear()
729
+
730
+ if not self.scheduler_update_every_step:
731
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
732
+ self.lr_scheduler.step(average_loss)
733
+ else:
734
+ self.lr_scheduler.step()
735
+
736
+ self.log(f"==> Finished Epoch {self.epoch}.")
737
+
738
+
739
+ def evaluate_one_epoch(self, loader, name=None):
740
+ self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")
741
+
742
+ if name is None:
743
+ name = f'{self.name}_ep{self.epoch:04d}'
744
+
745
+ total_loss = 0
746
+ if self.local_rank == 0:
747
+ for metric in self.metrics:
748
+ metric.clear()
749
+
750
+ self.model.eval()
751
+
752
+ if self.ema is not None:
753
+ self.ema.store()
754
+ self.ema.copy_to()
755
+
756
+ if self.local_rank == 0:
757
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
758
+
759
+ with torch.no_grad():
760
+ self.local_step = 0
761
+
762
+ for data in loader:
763
+ self.local_step += 1
764
+
765
+ with torch.cuda.amp.autocast(enabled=self.fp16):
766
+ preds, preds_depth, loss = self.eval_step(data)
767
+
768
+ # all_gather/reduce the statistics (NCCL only supports all_*)
769
+ if self.world_size > 1:
770
+ dist.all_reduce(loss, op=dist.ReduceOp.SUM)
771
+ loss = loss / self.world_size
772
+
773
+ preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
774
+ dist.all_gather(preds_list, preds)
775
+ preds = torch.cat(preds_list, dim=0)
776
+
777
+ preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
778
+ dist.all_gather(preds_depth_list, preds_depth)
779
+ preds_depth = torch.cat(preds_depth_list, dim=0)
780
+
781
+ loss_val = loss.item()
782
+ total_loss += loss_val
783
+
784
+ # only rank = 0 will perform evaluation.
785
+ if self.local_rank == 0:
786
+
787
+ # save image
788
+ save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
789
+ save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')
790
+
791
+ #self.log(f"==> Saving validation image to {save_path}")
792
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
793
+
794
+ pred = preds[0].detach().cpu().numpy()
795
+ pred = (pred * 255).astype(np.uint8)
796
+
797
+ pred_depth = preds_depth[0].detach().cpu().numpy()
798
+ pred_depth = (pred_depth * 255).astype(np.uint8)
799
+
800
+ cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
801
+ cv2.imwrite(save_path_depth, pred_depth)
802
+
803
+ pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
804
+ pbar.update(loader.batch_size)
805
+
806
+
807
+ average_loss = total_loss / self.local_step
808
+ self.stats["valid_loss"].append(average_loss)
809
+
810
+ if self.local_rank == 0:
811
+ pbar.close()
812
+ if not self.use_loss_as_metric and len(self.metrics) > 0:
813
+ result = self.metrics[0].measure()
814
+ self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
815
+ else:
816
+ self.stats["results"].append(average_loss) # if no metric, choose best by min loss
817
+
818
+ for metric in self.metrics:
819
+ self.log(metric.report(), style="blue")
820
+ if self.use_tensorboardX:
821
+ metric.write(self.writer, self.epoch, prefix="evaluate")
822
+ metric.clear()
823
+
824
+ if self.ema is not None:
825
+ self.ema.restore()
826
+
827
+ self.log(f"++> Evaluate epoch {self.epoch} Finished.")
828
+
829
+ def save_checkpoint(self, name=None, full=False, best=False):
830
+
831
+ if name is None:
832
+ name = f'{self.name}_ep{self.epoch:04d}'
833
+
834
+ state = {
835
+ 'epoch': self.epoch,
836
+ 'global_step': self.global_step,
837
+ 'stats': self.stats,
838
+ }
839
+
840
+ if self.model.cuda_ray:
841
+ state['mean_count'] = self.model.mean_count
842
+ state['mean_density'] = self.model.mean_density
843
+
844
+ if full:
845
+ state['optimizer'] = self.optimizer.state_dict()
846
+ state['lr_scheduler'] = self.lr_scheduler.state_dict()
847
+ state['scaler'] = self.scaler.state_dict()
848
+ if self.ema is not None:
849
+ state['ema'] = self.ema.state_dict()
850
+
851
+ if not best:
852
+
853
+ state['model'] = self.model.state_dict()
854
+
855
+ file_path = f"{name}.pth"
856
+
857
+ self.stats["checkpoints"].append(file_path)
858
+
859
+ if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
860
+ old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
861
+ if os.path.exists(old_ckpt):
862
+ os.remove(old_ckpt)
863
+
864
+ torch.save(state, os.path.join(self.ckpt_path, file_path))
865
+
866
+ else:
867
+ if len(self.stats["results"]) > 0:
868
+ if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
869
+ self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
870
+ self.stats["best_result"] = self.stats["results"][-1]
871
+
872
+ # save ema results
873
+ if self.ema is not None:
874
+ self.ema.store()
875
+ self.ema.copy_to()
876
+
877
+ state['model'] = self.model.state_dict()
878
+
879
+ if self.ema is not None:
880
+ self.ema.restore()
881
+
882
+ torch.save(state, self.best_path)
883
+ else:
884
+ self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
885
+
886
+ def load_checkpoint(self, checkpoint=None, model_only=False):
887
+ if checkpoint is None:
888
+ checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
889
+ if checkpoint_list:
890
+ checkpoint = checkpoint_list[-1]
891
+ self.log(f"[INFO] Latest checkpoint is {checkpoint}")
892
+ else:
893
+ self.log("[WARN] No checkpoint found, model randomly initialized.")
894
+ return
895
+
896
+ checkpoint_dict = torch.load(checkpoint, map_location=self.device)
897
+
898
+ if 'model' not in checkpoint_dict:
899
+ self.model.load_state_dict(checkpoint_dict)
900
+ self.log("[INFO] loaded model.")
901
+ return
902
+
903
+ missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
904
+ self.log("[INFO] loaded model.")
905
+ if len(missing_keys) > 0:
906
+ self.log(f"[WARN] missing keys: {missing_keys}")
907
+ if len(unexpected_keys) > 0:
908
+ self.log(f"[WARN] unexpected keys: {unexpected_keys}")
909
+
910
+ if self.ema is not None and 'ema' in checkpoint_dict:
911
+ try:
912
+ self.ema.load_state_dict(checkpoint_dict['ema'])
913
+ self.log("[INFO] loaded EMA.")
914
+ except:
915
+ self.log("[WARN] failed to load EMA.")
916
+
917
+ if self.model.cuda_ray:
918
+ if 'mean_count' in checkpoint_dict:
919
+ self.model.mean_count = checkpoint_dict['mean_count']
920
+ if 'mean_density' in checkpoint_dict:
921
+ self.model.mean_density = checkpoint_dict['mean_density']
922
+
923
+ if model_only:
924
+ return
925
+
926
+ self.stats = checkpoint_dict['stats']
927
+ self.epoch = checkpoint_dict['epoch']
928
+ self.global_step = checkpoint_dict['global_step']
929
+ self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
930
+
931
+ if self.optimizer and 'optimizer' in checkpoint_dict:
932
+ try:
933
+ self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
934
+ self.log("[INFO] loaded optimizer.")
935
+ except:
936
+ self.log("[WARN] Failed to load optimizer.")
937
+
938
+ if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
939
+ try:
940
+ self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
941
+ self.log("[INFO] loaded scheduler.")
942
+ except:
943
+ self.log("[WARN] Failed to load scheduler.")
944
+
945
+ if self.scaler and 'scaler' in checkpoint_dict:
946
+ try:
947
+ self.scaler.load_state_dict(checkpoint_dict['scaler'])
948
+ self.log("[INFO] loaded scaler.")
949
+ except:
950
+ self.log("[WARN] Failed to load scaler.")
optimizer.py ADDED
@@ -0,0 +1,470 @@
1
+ import numpy as np
2
+ import torch
3
+ import enum
4
+ import itertools
5
+ from dataclasses import dataclass
6
+ import torch.optim as optim
7
+
8
+ @torch.no_grad()
9
+ def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
10
+ """Power iteration.
11
+ Compute the maximum eigenvalue of mat, for scaling.
12
+ v is a random vector with values in (-1, 1)
13
+ Args:
14
+ mat_g: the symmetric PSD matrix.
15
+ error_tolerance: Iterative exit condition.
16
+ num_iters: Number of iterations.
17
+ Returns:
18
+ eigen vector, eigen value, num_iters
19
+ """
20
+ v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
21
+ error = 1
22
+ iters = 0
23
+ singular_val = 0
24
+ while error > error_tolerance and iters < num_iters:
25
+ v = v / torch.norm(v)
26
+ mat_v = torch.mv(mat_g, v)
27
+ s_v = torch.dot(v, mat_v)
28
+ error = torch.abs(s_v - singular_val)
29
+ v = mat_v
30
+ singular_val = s_v
31
+ iters += 1
32
+ return singular_val, v / torch.norm(v), iters
33
+
34
+
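+ # Editorial sanity check (CUDA-only, since these helpers allocate on
+ # `mat_g.get_device()`): the returned eigenvalue should match the largest
+ # eigenvalue reported by torch.linalg.eigvalsh on a small symmetric PSD matrix.
+ def _demo_power_iter():
+     import torch
+     if not torch.cuda.is_available():
+         return
+     a = torch.randn(8, 8, device='cuda')
+     mat_g = a @ a.t()  # symmetric PSD
+     max_ev, _, _ = PowerIter(mat_g)
+     assert torch.allclose(max_ev, torch.linalg.eigvalsh(mat_g).max(), rtol=1e-3)
+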
35
+ @torch.no_grad()
36
+ def MatPower(mat_m, p):
37
+ """Computes mat_m^p, for p a positive integer.
38
+ Args:
39
+ mat_m: a square matrix
40
+ p: a positive integer
41
+ Returns:
42
+ mat_m^p
43
+ """
44
+ if p in [1, 2, 4, 8, 16, 32]:
45
+ p_done = 1
46
+ res = mat_m
47
+ while p_done < p:
48
+ res = torch.matmul(res, res)
49
+ p_done *= 2
50
+ return res
51
+
52
+ power = None
53
+ while p > 0:
54
+ if p % 2 == 1:
55
+ power = torch.matmul(mat_m, power) if power is not None else mat_m
56
+ p //= 2
57
+ mat_m = torch.matmul(mat_m, mat_m)
58
+ return power
59
+
60
+
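+ # Editorial sanity check (pure torch, runs on CPU): both the repeated-squaring
+ # fast path (p a power of two) and the generic binary-exponentiation path
+ # should agree with torch.linalg.matrix_power.
+ def _demo_mat_power():
+     import torch
+     m = torch.randn(5, 5, dtype=torch.float64)
+     for p in (4, 7):  # fast path, then generic path
+         assert torch.allclose(MatPower(m, p), torch.linalg.matrix_power(m, p))
+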
61
+ @torch.no_grad()
62
+ def ComputePower(mat_g, p,
63
+ iter_count=100,
64
+ error_tolerance=1e-6,
65
+ ridge_epsilon=1e-6):
66
+ """A method to compute G^{-1/p} using a coupled Newton iteration.
67
+ See for example equation 3.2 on page 9 of:
68
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
69
+ by Chun-Hua Guo and Nicholas J. Higham
70
+ SIAM Journal on Matrix Analysis and Applications,
71
+ 2006, Vol. 28, No. 3 : pp. 788-804
72
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
73
+ Args:
74
+ mat_g: A square positive semidefinite matrix
75
+ p: a positive integer
76
+ iter_count: Stop iterating after this many rounds.
77
+ error_tolerance: Threshold for stopping iteration
78
+ ridge_epsilon: We add this times I to G, to make is positive definite.
79
+ For scaling, we multiply it by the largest eigenvalue of G.
80
+ Returns:
81
+ (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
82
+ """
83
+ shape = list(mat_g.shape)
84
+ if len(shape) == 1:
85
+ return torch.pow(mat_g + ridge_epsilon, -1/p)
86
+ identity = torch.eye(shape[0], device=mat_g.get_device())
87
+ if shape[0] == 1:
88
+ return identity
89
+ alpha = -1.0/p
90
+ max_ev, _, _ = PowerIter(mat_g)
91
+ ridge_epsilon *= max_ev
92
+ mat_g += ridge_epsilon * identity
93
+ z = (1 + p) / (2 * torch.norm(mat_g))
94
+ # The best value for z is
95
+ # (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
96
+ # (c_max^{1+1/p} - c_min^{1+1/p})
97
+ # where c_max and c_min are the largest and smallest singular values of
98
+ # mat_g.
99
+ # The above estimate assumes that c_max > c_min * 2^p
100
+ # Can replace above line by the one below, but it is less accurate,
101
+ # hence needs more iterations to converge.
102
+ # z = (1 + p) / tf.trace(mat_g)
103
+ # If we want the method to always converge, use z = 1 / norm(mat_g)
104
+ # or z = 1 / tf.trace(mat_g), but these can result in many
105
+ # extra iterations.
106
+
107
+ mat_root = identity * torch.pow(z, 1.0/p)
108
+ mat_m = mat_g * z
109
+ error = torch.max(torch.abs(mat_m - identity))
110
+ count = 0
111
+ while error > error_tolerance and count < iter_count:
112
+ tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
113
+ new_mat_root = torch.matmul(mat_root, tmp_mat_m)
114
+ mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
115
+ new_error = torch.max(torch.abs(mat_m - identity))
116
+ if new_error > error * 1.2:
117
+ break
118
+ mat_root = new_mat_root
119
+ error = new_error
120
+ count += 1
121
+ return mat_root
122
+
123
+
124
+
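+ # Editorial sanity check (CUDA-only, as above): ComputePower(G, p) should
+ # approximate (G + rI)^{-1/p}, so for p=2, root @ G @ root should be close to
+ # the identity on a well-conditioned PSD matrix. Note ComputePower adds the
+ # ridge in place, hence the clone.
+ def _demo_compute_power():
+     import torch
+     if not torch.cuda.is_available():
+         return
+     a = torch.randn(6, 6, device='cuda')
+     mat_g = a @ a.t() + torch.eye(6, device='cuda')  # well-conditioned PSD
+     root = ComputePower(mat_g.clone(), 2)  # ~ G^{-1/2}
+     assert torch.allclose(root @ mat_g @ root, torch.eye(6, device='cuda'), atol=1e-2)
+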
125
+ # Grafting is a technique to fix the layerwise scale of Shampoo optimizer.
126
+ # https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
127
+ # allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
128
+ # is already well tuned. Grafting onto Shampoo means take the Shampoo direction,
129
+ # but use the step magnitude from the grafted optimizer such as Adagrad or SGD.
130
+ class LayerwiseGrafting(enum.IntEnum):
131
+ NONE = 0
132
+ SGD = 1
133
+ ADAGRAD = 2
134
+
135
+
136
+ @dataclass
137
+ class ShampooHyperParams:
138
+ """Shampoo hyper parameters."""
139
+ beta2: float = 0.9
140
+ diagonal_eps: float = 1e-6
141
+ matrix_eps: float = 1e-12
142
+ weight_decay: float = 0.0
143
+ inverse_exponent_override: int = 2 # fixed exponent for preconditioner, if >0
144
+ start_preconditioning_step: int = 1
145
+ # Performance tuning params for controlling memory and compute requirements.
146
+ # How often to compute preconditioner.
147
+ preconditioning_compute_steps: int = 1
148
+ # How often to compute statistics.
149
+ statistics_compute_steps: int = 1
150
+ # Block size for large layers (if > 0).
151
+ # Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
152
+ # Block size should be as large as feasible under memory/time constraints.
153
+ block_size: int = 128
154
+ # Automatic shape interpretation (for eg: [4, 3, 1024, 512] would result in
155
+ # 12 x [1024, 512] L and R statistics. Disabled by default which results in
156
+ # Shampoo constructing statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
157
+ best_effort_shape_interpretation: bool = True
158
+ # Type of grafting (SGD or AdaGrad).
159
+ # https://arxiv.org/pdf/2002.11803.pdf
160
+ graft_type: int = LayerwiseGrafting.ADAGRAD
161
+ # Nesterov momentum
162
+ nesterov: bool = True
163
+
164
+
165
+ class Graft:
166
+ """Base class to perform grafting onto Shampoo. This class does no grafting.
167
+ """
168
+
169
+ def __init__(self, hps, unused_var):
170
+ self.hps = hps
171
+
172
+ def add_statistics(self, grad):
173
+ pass
174
+
175
+ def precondition_gradient(self, grad):
176
+ return grad
177
+
178
+ def update_momentum(self, update, unused_beta1):
179
+ return update
180
+
181
+
182
+ class SGDGraft(Graft):
183
+ """Graft using SGD+momentum.
184
+ momentum maintains an exponentially weighted moving average of gradients.
185
+ """
186
+
187
+ def __init__(self, hps, var):
188
+ super(SGDGraft, self).__init__(hps, var)
189
+ self.momentum = torch.zeros_like(var.data, device=var.get_device())
190
+
191
+ def update_momentum(self, update, beta1):
192
+ self.momentum.mul_(beta1).add_(update)
193
+ return self.momentum
194
+
195
+
196
+ class AdagradGraft(SGDGraft):
197
+ """Graft using Adagrad.
198
+ Essentially an implementation of Adagrad with momentum.
199
+ """
200
+
201
+ def __init__(self, hps, var):
202
+ super(AdagradGraft, self).__init__(hps, var)
203
+ self.statistics = torch.zeros_like(var.data, device=var.get_device())
204
+
205
+ def add_statistics(self, grad):
206
+ self.statistics.add_(grad * grad)
207
+
208
+ def precondition_gradient(self, grad):
209
+ return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
210
+
211
+
212
+ class BlockPartitioner:
213
+ """Partitions a tensor into smaller tensors for preconditioning.
214
+ For example, if a variable has shape (4096, 512), we might split the
215
+ 4096 into 4 blocks, so we effectively have 4 variables of size
216
+ (1024, 512) each.
217
+ """
218
+
219
+ def __init__(self, var, hps):
220
+ self._shape = var.shape
221
+ self._splits = []
222
+ self._split_sizes = []
223
+ split_sizes = []
224
+ # We split var into smaller blocks. Here we store the metadata to make
225
+ # that split.
226
+ for i, d in enumerate(var.shape):
227
+ if hps.block_size > 0 and d > hps.block_size:
228
+ # d-1, otherwise split appends a 0-size array.
229
+ nsplit = (d-1) // hps.block_size
230
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
231
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
232
+ sizes[-1] = d - indices[-1]
233
+ self._splits.append((i, indices))
234
+ self._split_sizes.append((i, sizes))
235
+ split_sizes.append(sizes)
236
+ else:
237
+ split_sizes.append(np.array([d], dtype=np.int32))
238
+ self._num_splits = len(split_sizes)
239
+ self._preconditioner_shapes = []
240
+ for t in itertools.product(*split_sizes):
241
+ self._preconditioner_shapes.extend([[d, d] for d in t])
242
+
243
+ def shapes_for_preconditioners(self):
244
+ return self._preconditioner_shapes
245
+
246
+ def num_splits(self):
247
+ return self._num_splits
248
+
249
+ def partition(self, tensor):
250
+ """Partition tensor into blocks."""
251
+
252
+ assert tensor.shape == self._shape
253
+ tensors = [tensor]
254
+ for (i, sizes) in self._split_sizes:
255
+ tensors_local = []
256
+ for t in tensors:
257
+ tensors_local.extend(
258
+ torch.split(t, tuple(sizes), dim=i))
259
+ tensors = tensors_local
260
+ return tensors
261
+
262
+ def merge_partitions(self, partitions):
263
+ """Merge partitions back to original shape."""
264
+
265
+ for (i, indices) in reversed(self._splits):
266
+ n = len(indices) + 1
267
+ partial_merged_tensors = []
268
+ ind = 0
269
+ while ind < len(partitions):
270
+ partial_merged_tensors.append(
271
+ torch.cat(partitions[ind:ind + n], axis=i))
272
+ ind += n
273
+ partitions = partial_merged_tensors
274
+ assert len(partitions) == 1
275
+ return partitions[0]
276
+
277
+
278
+ def _merge_small_dims(shape_to_merge, max_dim):
279
+ """Merge small dimensions.
280
+ If there are some small dimensions, we collapse them:
281
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
282
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
283
+ Args:
284
+ shape_to_merge: Shape to merge small dimensions.
285
+ max_dim: Maximal dimension of output shape used in merging.
286
+ Returns:
287
+ Merged shape.
288
+ """
289
+ resulting_shape = []
290
+ product = 1
291
+ for d in shape_to_merge:
292
+ if product * d <= max_dim:
293
+ product *= d
294
+ else:
295
+ if product > 1:
296
+ resulting_shape.append(product)
297
+ product = d
298
+ if product > 1:
299
+ resulting_shape.append(product)
300
+ return resulting_shape
301
+
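+ # Editorial check of the docstring examples above (pure Python, no torch):
+ def _demo_merge_small_dims():
+     assert _merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024) == [1024, 2048, 12]
+     assert _merge_small_dims([1, 2, 768, 1, 2048], 1024) == [2, 768, 2048]
+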
302
+
303
+ class Preconditioner:
304
+ """Compute statistics/shape from gradients for preconditioning."""
305
+
306
+ def __init__(self, var, hps):
307
+ self._hps = hps
308
+ self._original_shape = var.shape
309
+ self._transformed_shape = var.shape
310
+ if hps.best_effort_shape_interpretation:
311
+ self._transformed_shape = _merge_small_dims(
312
+ self._original_shape, hps.block_size)
313
+
314
+ reshaped_var = torch.reshape(var, self._transformed_shape)
315
+ self._partitioner = BlockPartitioner(reshaped_var, hps)
316
+ shapes = self._partitioner.shapes_for_preconditioners()
317
+ rank = len(self._transformed_shape)
318
+ device = var.get_device()
319
+ if rank <= 1:
320
+ self.statistics = []
321
+ self.preconditioners = []
322
+ else:
323
+ eps = self._hps.matrix_eps
324
+ self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
325
+ self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]
326
+
327
+ def add_statistics(self, grad):
328
+ """Compute statistics from gradients and add to the correct state entries.
329
+ Args:
330
+ grad: Gradient to compute statistics from.
331
+ """
332
+ if not self.statistics: return
333
+ reshaped_grad = torch.reshape(grad, self._transformed_shape)
334
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
335
+ w1 = self._hps.beta2
336
+ w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
337
+ rank = len(self._transformed_shape)
338
+ for j, grad in enumerate(partitioned_grads):
339
+ for i in range(rank):
340
+ axes = list(range(i)) + list(range(i + 1, rank))
341
+ stat = torch.tensordot(grad, grad, [axes, axes])
342
+ self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)
343
+
344
+ def exponent_for_preconditioner(self):
345
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
346
+ if self._hps.inverse_exponent_override > 0:
347
+ return self._hps.inverse_exponent_override
348
+ return 2 * len(self._transformed_shape)
349
+
350
+ def compute_preconditioners(self):
351
+ """Compute L^{-1/exp} for each stats matrix L."""
352
+ exp = self.exponent_for_preconditioner()
353
+ eps = self._hps.matrix_eps
354
+ for i, stat in enumerate(self.statistics):
355
+ self.preconditioners[i] = ComputePower(
356
+ stat, exp, ridge_epsilon=eps)
357
+
358
+ def preconditioned_grad(self, grad):
359
+ """Precondition the gradient.
360
+ Args:
361
+ grad: A gradient tensor to precondition.
362
+ Returns:
363
+ A preconditioned gradient.
364
+ """
365
+ if not self.preconditioners: return grad
366
+ reshaped_grad = torch.reshape(grad, self._transformed_shape)
367
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
368
+ preconditioned_partitioned_grads = []
369
+ num_splits = self._partitioner.num_splits()
370
+ for i, grad in enumerate(partitioned_grads):
371
+ preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
372
+ num_splits]
373
+ rank = len(grad.shape)
374
+ precond_grad = grad
375
+ for j in range(rank):
376
+ preconditioner = preconditioners_for_grad[j]
377
+ precond_grad = torch.tensordot(
378
+ precond_grad, preconditioner, [[0], [0]])
379
+ preconditioned_partitioned_grads.append(precond_grad)
380
+ merged_grad = self._partitioner.merge_partitions(
381
+ preconditioned_partitioned_grads)
382
+ return torch.reshape(merged_grad, self._original_shape)
383
+
384
+
385
+ STEP = 'step'
386
+ MOMENTUM = 'momentum'
387
+ PRECONDITIONER = 'preconditioner'
388
+ GRAFT = 'graft'
389
+
390
+
391
+ class Shampoo(optim.Optimizer):
392
+ """The Shampoo optimizer."""
393
+
394
+ def __init__(self,
395
+ params,
396
+ lr=1.0,
397
+ momentum=0.9,
398
+ hyperparams=ShampooHyperParams()):
399
+ defaults = dict(lr=lr, momentum=momentum)
400
+ self.hps = hyperparams
401
+ super(Shampoo, self).__init__(params, defaults)
402
+
403
+ def init_var_state(self, var, state):
404
+ """Initialize the PyTorch state for a single variable."""
405
+ state[STEP] = 0
406
+ state[MOMENTUM] = torch.zeros_like(var.data, device=var.get_device())
407
+ state[PRECONDITIONER] = Preconditioner(var, self.hps)
408
+ if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
409
+ state[GRAFT] = AdagradGraft(self.hps, var)
410
+ elif self.hps.graft_type == LayerwiseGrafting.SGD:
411
+ state[GRAFT] = SGDGraft(self.hps, var)
412
+ else:
413
+ state[GRAFT] = Graft(self.hps, var)
414
+
415
+ def step(self, closure=None):
416
+ hps = self.hps
417
+ for group in self.param_groups:
418
+ lr = group['lr']
419
+ for p in group['params']:
420
+ if p.grad is None: continue
421
+ grad = p.grad.data
422
+ if grad.is_sparse:
423
+ raise RuntimeError('Shampoo does not support sparse yet')
424
+ state = self.state[p]
425
+ if not state:
426
+ self.init_var_state(p, state)
427
+ state[STEP] += 1
428
+
429
+ preconditioner = state[PRECONDITIONER]
430
+ graft = state[GRAFT]
431
+
432
+ # Gather statistics, compute preconditioners
433
+ graft.add_statistics(grad)
434
+ if state[STEP] % hps.statistics_compute_steps == 0:
435
+ preconditioner.add_statistics(grad)
436
+ if state[STEP] % hps.preconditioning_compute_steps == 0:
437
+ preconditioner.compute_preconditioners()
438
+
439
+ # Precondition gradients
440
+ graft_grad = graft.precondition_gradient(grad)
441
+ shampoo_grad = grad
442
+ if state[STEP] >= self.hps.start_preconditioning_step:
443
+ shampoo_grad = preconditioner.preconditioned_grad(grad)
444
+
445
+ # Grafting
446
+ graft_norm = torch.norm(graft_grad)
447
+ shampoo_norm = torch.norm(shampoo_grad)
448
+ shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
449
+
450
+ # Weight decay
451
+ if self.hps.weight_decay != 0.0:
452
+ shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
453
+ graft_grad.add_(p.data, alpha=self.hps.weight_decay)
454
+
455
+ # Momentum and Nesterov momentum, if needed
456
+ state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
457
+ graft_momentum = graft.update_momentum(grad, group['momentum'])
458
+
459
+ if state[STEP] >= self.hps.start_preconditioning_step:
460
+ momentum_update = state[MOMENTUM]
461
+ wd_update = shampoo_grad
462
+ else:
463
+ momentum_update = graft_momentum
464
+ wd_update = graft_grad
465
+
466
+ if hps.nesterov:
467
+ momentum_update.mul_(group['momentum']).add_(wd_update)
468
+
469
+ # Final update
470
+ p.data.add_(momentum_update, alpha=-lr)
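+
+ # Editorial usage sketch (CUDA-only, since the momentum/graft buffers are
+ # allocated on var.get_device()): Shampoo drops in like any torch optimizer;
+ # the tiny model and block size here are hypothetical.
+ def _demo_shampoo():
+     import torch
+     import torch.nn as nn
+     if not torch.cuda.is_available():
+         return
+     net = nn.Linear(16, 4).cuda()
+     optimizer = Shampoo(net.parameters(), lr=0.1,
+                         hyperparams=ShampooHyperParams(block_size=8))
+     loss = net(torch.randn(2, 16, device='cuda')).pow(2).mean()
+     loss.backward()
+     optimizer.step()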
raymarching/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .raymarching import *
raymarching/backend.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_raymarching',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'raymarching.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
raymarching/raymarching.py ADDED
@@ -0,0 +1,373 @@
1
+ import numpy as np
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Function
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _raymarching as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+
15
+ # ----------------------------------------
16
+ # utils
17
+ # ----------------------------------------
18
+
19
+ class _near_far_from_aabb(Function):
20
+ @staticmethod
21
+ @custom_fwd(cast_inputs=torch.float32)
22
+ def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
23
+ ''' near_far_from_aabb, CUDA implementation
24
+ Calculate rays' intersection time (near and far) with aabb
25
+ Args:
26
+ rays_o: float, [N, 3]
27
+ rays_d: float, [N, 3]
28
+ aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
29
+ min_near: float, scalar
30
+ Returns:
31
+ nears: float, [N]
32
+ fars: float, [N]
33
+ '''
34
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
35
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
36
+
37
+ rays_o = rays_o.contiguous().view(-1, 3)
38
+ rays_d = rays_d.contiguous().view(-1, 3)
39
+
40
+ N = rays_o.shape[0] # num rays
41
+
42
+ nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
43
+ fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
44
+
45
+ _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
46
+
47
+ return nears, fars
48
+
49
+ near_far_from_aabb = _near_far_from_aabb.apply
50
+
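+ # Editorial usage sketch (assumes the CUDA extension above built and a GPU is
+ # present; values are hypothetical): a ray shot from z=-3 towards +z through
+ # the [-1, 1]^3 box should enter at t=2 and leave at t=4.
+ def _demo_near_far():
+     import torch
+     if not torch.cuda.is_available():
+         return
+     rays_o = torch.tensor([[0.0, 0.0, -3.0]], device='cuda')
+     rays_d = torch.tensor([[0.0, 0.0, 1.0]], device='cuda')
+     aabb = torch.tensor([-1, -1, -1, 1, 1, 1], dtype=torch.float32, device='cuda')
+     nears, fars = near_far_from_aabb(rays_o, rays_d, aabb, 0.2)
+     assert abs(nears.item() - 2.0) < 1e-4 and abs(fars.item() - 4.0) < 1e-4
+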
51
+
52
+ class _sph_from_ray(Function):
53
+ @staticmethod
54
+ @custom_fwd(cast_inputs=torch.float32)
55
+ def forward(ctx, rays_o, rays_d, radius):
56
+ ''' sph_from_ray, CUDA implementation
57
+ get spherical coordinate on the background sphere from rays.
58
+ Assume rays_o are inside the Sphere(radius).
59
+ Args:
60
+ rays_o: [N, 3]
61
+ rays_d: [N, 3]
62
+ radius: scalar, float
63
+ Return:
64
+ coords: [N, 2], in [-1, 1], theta and phi on the sphere (the farther surface intersection).
65
+ '''
66
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
67
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
68
+
69
+ rays_o = rays_o.contiguous().view(-1, 3)
70
+ rays_d = rays_d.contiguous().view(-1, 3)
71
+
72
+ N = rays_o.shape[0] # num rays
73
+
74
+ coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
75
+
76
+ _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
77
+
78
+ return coords
79
+
80
+ sph_from_ray = _sph_from_ray.apply
81
+
82
+
83
+ class _morton3D(Function):
84
+ @staticmethod
85
+ def forward(ctx, coords):
86
+ ''' morton3D, CUDA implementation
87
+ Args:
88
+ coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
89
+ TODO: check if the coord range is valid! (current 128 is safe)
90
+ Returns:
91
+ indices: [N], int32, in [0, 128^3)
92
+
93
+ '''
94
+ if not coords.is_cuda: coords = coords.cuda()
95
+
96
+ N = coords.shape[0]
97
+
98
+ indices = torch.empty(N, dtype=torch.int32, device=coords.device)
99
+
100
+ _backend.morton3D(coords.int(), N, indices)
101
+
102
+ return indices
103
+
104
+ morton3D = _morton3D.apply
105
+
106
+ class _morton3D_invert(Function):
107
+ @staticmethod
108
+ def forward(ctx, indices):
109
+ ''' morton3D_invert, CUDA implementation
110
+ Args:
111
+ indices: [N], int32, in [0, 128^3)
112
+ Returns:
113
+ coords: [N, 3], int32, in [0, 128)
114
+
115
+ '''
116
+ if not indices.is_cuda: indices = indices.cuda()
117
+
118
+ N = indices.shape[0]
119
+
120
+ coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
121
+
122
+ _backend.morton3D_invert(indices.int(), N, coords)
123
+
124
+ return coords
125
+
126
+ morton3D_invert = _morton3D_invert.apply
127
+
128
+
129
+ class _packbits(Function):
130
+ @staticmethod
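+ # Editorial sanity check (same assumptions as above: built extension plus a
+ # GPU): morton3D and morton3D_invert should be inverses on the [0, 128)^3 grid.
+ def _demo_morton_roundtrip():
+     import torch
+     if not torch.cuda.is_available():
+         return
+     coords = torch.randint(0, 128, (64, 3), dtype=torch.int32, device='cuda')
+     assert (morton3D_invert(morton3D(coords)) == coords).all()
+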
131
+ @custom_fwd(cast_inputs=torch.float32)
132
+ def forward(ctx, grid, thresh, bitfield=None):
133
+ ''' packbits, CUDA implementation
134
+ Pack up the density grid into a bit field to accelerate ray marching.
135
+ Args:
136
+ grid: float, [C, H * H * H], assume H % 2 == 0
137
+ thresh: float, threshold
138
+ Returns:
139
+ bitfield: uint8, [C, H * H * H / 8]
140
+ '''
141
+ if not grid.is_cuda: grid = grid.cuda()
142
+ grid = grid.contiguous()
143
+
144
+ C = grid.shape[0]
145
+ H3 = grid.shape[1]
146
+ N = C * H3 // 8
147
+
148
+ if bitfield is None:
149
+ bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
150
+
151
+ _backend.packbits(grid, N, thresh, bitfield)
152
+
153
+ return bitfield
154
+
155
+ packbits = _packbits.apply
156
+
157
+ # ----------------------------------------
158
+ # train functions
159
+ # ----------------------------------------
160
+
161
+ class _march_rays_train(Function):
162
+ @staticmethod
163
+ @custom_fwd(cast_inputs=torch.float32)
164
+ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
165
+ ''' march rays to generate points (forward only)
166
+ Args:
167
+ rays_o/d: float, [N, 3]
168
+ bound: float, scalar
169
+ density_bitfield: uint8: [CHHH // 8]
170
+ C: int
171
+ H: int
172
+ nears/fars: float, [N]
173
+ step_counter: int32, (2), used to count the actual number of generated points.
174
+ mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
175
+ perturb: bool
176
+ align: int, pad output so its size is divisible by align; set to -1 to disable.
177
+ force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
178
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0 (a very significant effect, but it generally leads to worse quality)
179
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
180
+ Returns:
181
+ xyzs: float, [M, 3], all generated points' coords (all rays concatenated; use `rays` to extract the points belonging to each ray)
182
+ dirs: float, [M, 3], all generated points' view dirs.
183
+ deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
184
+ rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
185
+ '''
186
+
187
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
188
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
189
+ if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
190
+
191
+ rays_o = rays_o.contiguous().view(-1, 3)
192
+ rays_d = rays_d.contiguous().view(-1, 3)
193
+ density_bitfield = density_bitfield.contiguous()
194
+
195
+ N = rays_o.shape[0] # num rays
196
+ M = N * max_steps # init max points number in total
197
+
198
+ # running average based on the previous epoch (mimics `measured_batch_size_before_compaction` in instant-ngp)
199
+ # It estimates the max point count to enable faster training, but rays will be randomly dropped if it is underestimated.
200
+ if not force_all_rays and mean_count > 0:
201
+ if align > 0:
202
+ mean_count += align - mean_count % align
203
+ M = mean_count
204
+
205
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
206
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
207
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
208
+ rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
209
+
210
+ if step_counter is None:
211
+ step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
212
+
213
+ if perturb:
214
+ noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
215
+ else:
216
+ noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
217
+
218
+ _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m (step_counter[0]) is the number of points actually generated
219
+
220
+ #print(step_counter, M)
221
+
222
+ # only used at the first (few) epochs.
223
+ if force_all_rays or mean_count <= 0:
224
+ m = step_counter[0].item() # D2H copy
225
+ if align > 0:
226
+ m += align - m % align
227
+ xyzs = xyzs[:m]
228
+ dirs = dirs[:m]
229
+ deltas = deltas[:m]
230
+
231
+ torch.cuda.empty_cache()
232
+
233
+ return xyzs, dirs, deltas, rays
234
+
235
+ march_rays_train = _march_rays_train.apply
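# --- Editorial sketch (not part of the original file): minimal usage of
# march_rays_train, with illustrative shapes and values. Function.apply only
# accepts positional arguments, so the optional arguments are written out in
# order (step_counter, mean_count, perturb, align, force_all_rays, dt_gamma,
# max_steps). Assumes the extension is built and a CUDA device is available.
def _march_rays_train_example(device='cuda'):
    N, C, H, bound = 4096, 4, 128, 2.0
    rays_o = torch.zeros(N, 3, device=device)
    rays_d = torch.nn.functional.normalize(torch.randn(N, 3, device=device), dim=-1)
    # an all-ones bitfield marks every cell occupied, so points are generated
    density_bitfield = torch.full((C * H ** 3 // 8,), 255, dtype=torch.uint8, device=device)
    nears = torch.full((N,), 0.05, device=device)
    fars = torch.full((N,), 2.0, device=device)
    xyzs, dirs, deltas, rays = march_rays_train(
        rays_o, rays_d, bound, density_bitfield, C, H, nears, fars,
        None, -1, True, -1, True, 0, 1024)
    # xyzs/dirs: [M, 3], deltas: [M, 2], rays: [N, 3] = (ray index, offset, count)
    return xyzs, dirs, deltas, rays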
236
+
237
+
238
+ class _composite_rays_train(Function):
239
+ @staticmethod
240
+ @custom_fwd(cast_inputs=torch.float32)
241
+ def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
242
+ ''' composite rays' rgbs, according to the ray marching formula.
243
+ Args:
244
+ rgbs: float, [M, 3]
245
+ sigmas: float, [M,]
246
+ deltas: float, [M, 2]
247
+ rays: int32, [N, 3]
248
+ Returns:
249
+ weights_sum: float, [N,], the alpha channel
250
+ depth: float, [N,], the depth
251
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
252
+ '''
253
+
254
+ sigmas = sigmas.contiguous()
255
+ rgbs = rgbs.contiguous()
256
+
257
+ M = sigmas.shape[0]
258
+ N = rays.shape[0]
259
+
260
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
261
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
262
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
263
+
264
+ _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
265
+
266
+ ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
267
+ ctx.dims = [M, N, T_thresh]
268
+
269
+ return weights_sum, depth, image
270
+
271
+ @staticmethod
272
+ @custom_bwd
273
+ def backward(ctx, grad_weights_sum, grad_depth, grad_image):
274
+
275
+ # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
276
+
277
+ grad_weights_sum = grad_weights_sum.contiguous()
278
+ grad_image = grad_image.contiguous()
279
+
280
+ sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
281
+ M, N, T_thresh = ctx.dims
282
+
283
+ grad_sigmas = torch.zeros_like(sigmas)
284
+ grad_rgbs = torch.zeros_like(rgbs)
285
+
286
+ _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
287
+
288
+ return grad_sigmas, grad_rgbs, None, None, None
289
+
290
+
291
+ composite_rays_train = _composite_rays_train.apply
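# --- Editorial sketch (not part of the original file): the compositing that
# the CUDA kernel implements, written for a single ray in plain PyTorch as a
# readability reference. It omits the T_thresh early-termination check.
def _composite_one_ray_reference(sigmas, rgbs, deltas):
    # sigmas: [K], rgbs: [K, 3], deltas: [K, 2] for one ray's K samples
    alphas = 1 - torch.exp(-sigmas * deltas[:, 0])                      # [K]
    # transmittance before sample i: T_i = prod_{j < i} (1 - alpha_j)
    T = torch.cumprod(torch.cat([alphas.new_ones(1), 1 - alphas[:-1]]), dim=0)
    weights = alphas * T                                                # [K]
    ts = torch.cumsum(deltas[:, 1], dim=0)  # marched distance at each sample
    weights_sum = weights.sum()             # the alpha channel
    depth = (weights * ts).sum()
    image = (weights[:, None] * rgbs).sum(0)  # RGB, already alpha-multiplied
    return weights_sum, depth, image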
292
+
293
+ # ----------------------------------------
294
+ # infer functions
295
+ # ----------------------------------------
296
+
297
+ class _march_rays(Function):
298
+ @staticmethod
299
+ @custom_fwd(cast_inputs=torch.float32)
300
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
301
+ ''' march rays to generate points (forward only, for inference)
302
+ Args:
303
+ n_alive: int, number of alive rays
304
+ n_step: int, how many steps we march
305
+ rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use the first n_alive)
306
+ rays_t: float, [N], the alive rays' time; we only use the first n_alive.
307
+ rays_o/d: float, [N, 3]
308
+ bound: float, scalar
309
+ density_bitfield: uint8, [C * H * H * H // 8]
310
+ C: int
311
+ H: int
312
+ near/far: float, [N]
313
+ align: int, pad output so its size is divisible by align; set to -1 to disable.
314
+ perturb: bool/int, int > 0 is used as the random seed.
315
+ dt_gamma: float, called cone_angle in instant-ngp; exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse rendering quality)
316
+ max_steps: int, max number of sampled points along each ray; also affects min_stepsize.
317
+ Returns:
318
+ xyzs: float, [n_alive * n_step, 3], all generated points' coords
319
+ dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
320
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
321
+ '''
322
+
323
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
324
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
325
+
326
+ rays_o = rays_o.contiguous().view(-1, 3)
327
+ rays_d = rays_d.contiguous().view(-1, 3)
328
+
329
+ M = n_alive * n_step
330
+
331
+ if align > 0:
332
+ M += align - (M % align)
333
+
334
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
335
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
336
+ deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
337
+
338
+ if perturb:
339
+ # torch.manual_seed(perturb) # test_gui uses spp index as seed
340
+ noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
341
+ else:
342
+ noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
343
+
344
+ _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
345
+
346
+ return xyzs, dirs, deltas
347
+
348
+ march_rays = _march_rays.apply
349
+
350
+
351
+ class _composite_rays(Function):
352
+ @staticmethod
353
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
354
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
355
+ ''' composite rays' rgbs, according to the ray marching formula. (for inference)
356
+ Args:
357
+ n_alive: int, number of alive rays
358
+ n_step: int, how many steps we march
359
+ rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
360
+ rays_t: float, [N], the alive rays' time
361
+ sigmas: float, [n_alive * n_step,]
362
+ rgbs: float, [n_alive * n_step, 3]
363
+ deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
364
+ In-place Outputs:
365
+ weights_sum: float, [N,], the alpha channel
366
+ depth: float, [N,], the depth value
367
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
368
+ '''
369
+ _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
370
+ return tuple()
371
+
372
+
373
+ composite_rays = _composite_rays.apply
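# --- Editorial sketch (not part of the original file): how march_rays and
# composite_rays cooperate at inference time in an alive-ray loop. `model` is
# an assumed callable mapping (xyzs, dirs) -> (sigmas, rgbs); composite_rays
# updates its output tensors in place and marks terminated rays with -1.
def _render_rays_sketch(model, rays_o, rays_d, nears, fars,
                        bound, density_bitfield, C, H, max_steps=1024):
    N, device = rays_o.shape[0], rays_o.device
    rays_alive = torch.arange(N, dtype=torch.int32, device=device)
    rays_t = nears.clone()
    weights_sum = torch.zeros(N, device=device)
    depth = torch.zeros(N, device=device)
    image = torch.zeros(N, 3, device=device)
    step = 0
    while step < max_steps:
        n_alive = rays_alive.shape[0]
        if n_alive == 0:
            break
        n_step = max(min(N // n_alive, 8), 1)  # march further as rays die off
        xyzs, dirs, deltas = march_rays(
            n_alive, n_step, rays_alive, rays_t, rays_o, rays_d,
            bound, density_bitfield, C, H, nears, fars)
        sigmas, rgbs = model(xyzs, dirs)
        composite_rays(n_alive, n_step, rays_alive, rays_t,
                       sigmas, rgbs, deltas, weights_sum, depth, image)
        rays_alive = rays_alive[rays_alive >= 0]  # compact terminated rays
        step += n_step
    return weights_sum, depth, image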
raymarching/setup.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ '''
33
+ Usage:
34
+
35
+ python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
36
+
37
+ python setup.py install # build extensions and install (copy) to PATH.
38
+ pip install . # ditto but better (e.g., dependency & metadata handling)
39
+
40
+ python setup.py develop # build extensions and install (symbolic) to PATH.
41
+ pip install -e . # ditto but better (e.g., dependency & metadata handling)
42
+
43
+ '''
44
+ setup(
45
+ name='raymarching', # package name, import this to use python API
46
+ ext_modules=[
47
+ CUDAExtension(
48
+ name='_raymarching', # extension name, import this to use CUDA API
49
+ sources=[os.path.join(_src_path, 'src', f) for f in [
50
+ 'raymarching.cu',
51
+ 'bindings.cpp',
52
+ ]],
53
+ extra_compile_args={
54
+ 'cxx': c_flags,
55
+ 'nvcc': nvcc_flags,
56
+ }
57
+ ),
58
+ ],
59
+ cmdclass={
60
+ 'build_ext': BuildExtension,
61
+ }
62
+ )
raymarching/src/bindings.cpp ADDED
@@ -0,0 +1,19 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include "raymarching.h"
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ // utils
7
+ m.def("packbits", &packbits, "packbits (CUDA)");
8
+ m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
9
+ m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
10
+ m.def("morton3D", &morton3D, "morton3D (CUDA)");
11
+ m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
12
+ // train
13
+ m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
14
+ m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
15
+ m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
16
+ // infer
17
+ m.def("march_rays", &march_rays, "march rays (CUDA)");
18
+ m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
19
+ }
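# --- Editorial sketch (not part of the commit): exercising the compiled
# bindings directly, here to sanity-check the Morton encode/decode round
# trip. Assumes the extension was built via `pip install ./raymarching`;
# the raw bindings fill pre-allocated output tensors in place.
import torch
import _raymarching as _backend

n = 1024
coords = torch.randint(0, 128, (n, 3), dtype=torch.int32, device='cuda')
indices = torch.empty(n, dtype=torch.int32, device='cuda')
_backend.morton3D(coords, n, indices)            # encode: interleave xyz bits

recovered = torch.empty(n, 3, dtype=torch.int32, device='cuda')
_backend.morton3D_invert(indices, n, recovered)  # decode back to coordinates
assert torch.equal(coords, recovered)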
raymarching/src/raymarching.cu ADDED
@@ -0,0 +1,914 @@
1
+ #include <cuda.h>
2
+ #include <cuda_fp16.h>
3
+ #include <cuda_runtime.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+ #include <torch/torch.h>
7
+
8
+ #include <cstdio>
9
+ #include <stdint.h>
10
+ #include <stdexcept>
11
+ #include <limits>
12
+
13
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
14
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
15
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
16
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
17
+
18
+
19
+ inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
20
+ inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
21
+ inline constexpr __device__ float PI() { return 3.141592653589793f; }
22
+ inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
23
+
24
+
25
+ template <typename T>
26
+ inline __host__ __device__ T div_round_up(T val, T divisor) {
27
+ return (val + divisor - 1) / divisor;
28
+ }
29
+
30
+ inline __host__ __device__ float signf(const float x) {
31
+ return copysignf(1.0, x);
32
+ }
33
+
34
+ inline __host__ __device__ float clamp(const float x, const float min, const float max) {
35
+ return fminf(max, fmaxf(min, x));
36
+ }
37
+
38
+ inline __host__ __device__ void swapf(float& a, float& b) {
39
+ float c = a; a = b; b = c;
40
+ }
41
+
42
+ inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
43
+ const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
44
+ int exponent;
45
+ frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
46
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
47
+ }
48
+
49
+ inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
50
+ const float mx = dt * H * 0.5;
51
+ int exponent;
52
+ frexpf(mx, &exponent);
53
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
54
+ }
55
+
56
+ inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
57
+ {
58
+ v = (v * 0x00010001u) & 0xFF0000FFu;
59
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
60
+ v = (v * 0x00000011u) & 0xC30C30C3u;
61
+ v = (v * 0x00000005u) & 0x49249249u;
62
+ return v;
63
+ }
64
+
65
+ inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
66
+ {
67
+ uint32_t xx = __expand_bits(x);
68
+ uint32_t yy = __expand_bits(y);
69
+ uint32_t zz = __expand_bits(z);
70
+ return xx | (yy << 1) | (zz << 2);
71
+ }
72
+
73
+ inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
74
+ {
75
+ x = x & 0x49249249;
76
+ x = (x | (x >> 2)) & 0xc30c30c3;
77
+ x = (x | (x >> 4)) & 0x0f00f00f;
78
+ x = (x | (x >> 8)) & 0xff0000ff;
79
+ x = (x | (x >> 16)) & 0x0000ffff;
80
+ return x;
81
+ }
82
+
83
+
84
+ ////////////////////////////////////////////////////
85
+ ///////////// utils /////////////
86
+ ////////////////////////////////////////////////////
87
+
88
+ // rays_o/d: [N, 3]
89
+ // nears/fars: [N]
90
+ // scalar_t should always be float in use.
91
+ template <typename scalar_t>
92
+ __global__ void kernel_near_far_from_aabb(
93
+ const scalar_t * __restrict__ rays_o,
94
+ const scalar_t * __restrict__ rays_d,
95
+ const scalar_t * __restrict__ aabb,
96
+ const uint32_t N,
97
+ const float min_near,
98
+ scalar_t * nears, scalar_t * fars
99
+ ) {
100
+ // parallel per ray
101
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
102
+ if (n >= N) return;
103
+
104
+ // locate
105
+ rays_o += n * 3;
106
+ rays_d += n * 3;
107
+
108
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
109
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
110
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
111
+
112
+ // get near far (assume cube scene)
113
+ float near = (aabb[0] - ox) * rdx;
114
+ float far = (aabb[3] - ox) * rdx;
115
+ if (near > far) swapf(near, far);
116
+
117
+ float near_y = (aabb[1] - oy) * rdy;
118
+ float far_y = (aabb[4] - oy) * rdy;
119
+ if (near_y > far_y) swapf(near_y, far_y);
120
+
121
+ if (near > far_y || near_y > far) {
122
+ nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
123
+ return;
124
+ }
125
+
126
+ if (near_y > near) near = near_y;
127
+ if (far_y < far) far = far_y;
128
+
129
+ float near_z = (aabb[2] - oz) * rdz;
130
+ float far_z = (aabb[5] - oz) * rdz;
131
+ if (near_z > far_z) swapf(near_z, far_z);
132
+
133
+ if (near > far_z || near_z > far) {
134
+ nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
135
+ return;
136
+ }
137
+
138
+ if (near_z > near) near = near_z;
139
+ if (far_z < far) far = far_z;
140
+
141
+ if (near < min_near) near = min_near;
142
+
143
+ nears[n] = near;
144
+ fars[n] = far;
145
+ }
146
+
147
+
148
+ void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
149
+
150
+ static constexpr uint32_t N_THREAD = 128;
151
+
152
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
153
+ rays_o.scalar_type(), "near_far_from_aabb", ([&] {
154
+ kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
155
+ }));
156
+ }
157
+
158
+
159
+ // rays_o/d: [N, 3]
160
+ // radius: float
161
+ // coords: [N, 2]
162
+ template <typename scalar_t>
163
+ __global__ void kernel_sph_from_ray(
164
+ const scalar_t * __restrict__ rays_o,
165
+ const scalar_t * __restrict__ rays_d,
166
+ const float radius,
167
+ const uint32_t N,
168
+ scalar_t * coords
169
+ ) {
170
+ // parallel per ray
171
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
172
+ if (n >= N) return;
173
+
174
+ // locate
175
+ rays_o += n * 3;
176
+ rays_d += n * 3;
177
+ coords += n * 2;
178
+
179
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
180
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
181
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
182
+
183
+ // solve t from || o + td || = radius
184
+ const float A = dx * dx + dy * dy + dz * dz;
185
+ const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
186
+ const float C = ox * ox + oy * oy + oz * oz - radius * radius;
187
+
188
+ const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
189
+
190
+ // solve theta, phi (assume y is the up axis)
191
+ const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
192
+ const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
193
+ const float phi = atan2(z, x); // [-PI, PI)
194
+
195
+ // normalize to [-1, 1]
196
+ coords[0] = 2 * theta * RPI() - 1;
197
+ coords[1] = phi * RPI();
198
+ }
199
+
200
+
201
+ void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
202
+
203
+ static constexpr uint32_t N_THREAD = 128;
204
+
205
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
206
+ rays_o.scalar_type(), "sph_from_ray", ([&] {
207
+ kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
208
+ }));
209
+ }
210
+
211
+
212
+ // coords: int32, [N, 3]
213
+ // indices: int32, [N]
214
+ __global__ void kernel_morton3D(
215
+ const int * __restrict__ coords,
216
+ const uint32_t N,
217
+ int * indices
218
+ ) {
219
+ // parallel
220
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
221
+ if (n >= N) return;
222
+
223
+ // locate
224
+ coords += n * 3;
225
+ indices[n] = __morton3D(coords[0], coords[1], coords[2]);
226
+ }
227
+
228
+
229
+ void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
230
+ static constexpr uint32_t N_THREAD = 128;
231
+ kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
232
+ }
233
+
234
+
235
+ // indices: int32, [N]
236
+ // coords: int32, [N, 3]
237
+ __global__ void kernel_morton3D_invert(
238
+ const int * __restrict__ indices,
239
+ const uint32_t N,
240
+ int * coords
241
+ ) {
242
+ // parallel
243
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
244
+ if (n >= N) return;
245
+
246
+ // locate
247
+ coords += n * 3;
248
+
249
+ const int ind = indices[n];
250
+
251
+ coords[0] = __morton3D_invert(ind >> 0);
252
+ coords[1] = __morton3D_invert(ind >> 1);
253
+ coords[2] = __morton3D_invert(ind >> 2);
254
+ }
255
+
256
+
257
+ void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
258
+ static constexpr uint32_t N_THREAD = 128;
259
+ kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
260
+ }
261
+
262
+
263
+ // grid: float, [C, H, H, H]
264
+ // N: int, C * H * H * H / 8
265
+ // density_thresh: float
266
+ // bitfield: uint8, [N]
267
+ template <typename scalar_t>
268
+ __global__ void kernel_packbits(
269
+ const scalar_t * __restrict__ grid,
270
+ const uint32_t N,
271
+ const float density_thresh,
272
+ uint8_t * bitfield
273
+ ) {
274
+ // parallel per byte
275
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
276
+ if (n >= N) return;
277
+
278
+ // locate
279
+ grid += n * 8;
280
+
281
+ uint8_t bits = 0;
282
+
283
+ #pragma unroll
284
+ for (uint8_t i = 0; i < 8; i++) {
285
+ bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
286
+ }
287
+
288
+ bitfield[n] = bits;
289
+ }
290
+
291
+
292
+ void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
293
+
294
+ static constexpr uint32_t N_THREAD = 128;
295
+
296
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
297
+ grid.scalar_type(), "packbits", ([&] {
298
+ kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
299
+ }));
300
+ }
301
+
302
+ ////////////////////////////////////////////////////
303
+ ///////////// training /////////////
304
+ ////////////////////////////////////////////////////
305
+
306
+ // rays_o/d: [N, 3]
307
+ // grid: [CHHH / 8]
308
+ // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
309
+ // dirs: [M, 3]
310
+ // rays: [N, 3], idx, offset, num_steps
311
+ template <typename scalar_t>
312
+ __global__ void kernel_march_rays_train(
313
+ const scalar_t * __restrict__ rays_o,
314
+ const scalar_t * __restrict__ rays_d,
315
+ const uint8_t * __restrict__ grid,
316
+ const float bound,
317
+ const float dt_gamma, const uint32_t max_steps,
318
+ const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
319
+ const scalar_t* __restrict__ nears,
320
+ const scalar_t* __restrict__ fars,
321
+ scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
322
+ int * rays,
323
+ int * counter,
324
+ const scalar_t* __restrict__ noises
325
+ ) {
326
+ // parallel per ray
327
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
328
+ if (n >= N) return;
329
+
330
+ // locate
331
+ rays_o += n * 3;
332
+ rays_d += n * 3;
333
+
334
+ // ray marching
335
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
336
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
337
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
338
+ const float rH = 1 / (float)H;
339
+ const float H3 = H * H * H;
340
+
341
+ const float near = nears[n];
342
+ const float far = fars[n];
343
+ const float noise = noises[n];
344
+
345
+ const float dt_min = 2 * SQRT3() / max_steps;
346
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
347
+
348
+ float t0 = near;
349
+
350
+ // perturb
351
+ t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
352
+
353
+ // first pass: estimation of num_steps
354
+ float t = t0;
355
+ uint32_t num_steps = 0;
356
+
357
+ //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
358
+
359
+ while (t < far && num_steps < max_steps) {
360
+ // current point
361
+ const float x = clamp(ox + t * dx, -bound, bound);
362
+ const float y = clamp(oy + t * dy, -bound, bound);
363
+ const float z = clamp(oz + t * dz, -bound, bound);
364
+
365
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
366
+
367
+ // get mip level
368
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
369
+
370
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
371
+ const float mip_rbound = 1 / mip_bound;
372
+
373
+ // convert to nearest grid position
374
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
375
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
376
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
377
+
378
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
379
+ const bool occ = grid[index / 8] & (1 << (index % 8));
380
+
381
+ // if occupied, advance a small step, and write to output
382
+ //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
383
+
384
+ if (occ) {
385
+ num_steps++;
386
+ t += dt;
387
+ // else, skip a large step (basically skip a voxel grid)
388
+ } else {
389
+ // calc distance to next voxel
390
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
391
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
392
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
393
+
394
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
395
+ // step until next voxel
396
+ do {
397
+ t += clamp(t * dt_gamma, dt_min, dt_max);
398
+ } while (t < tt);
399
+ }
400
+ }
401
+
402
+ //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
403
+
404
+ // second pass: really locate and write points & dirs
405
+ uint32_t point_index = atomicAdd(counter, num_steps);
406
+ uint32_t ray_index = atomicAdd(counter + 1, 1);
407
+
408
+ //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
409
+
410
+ // write rays
411
+ rays[ray_index * 3] = n;
412
+ rays[ray_index * 3 + 1] = point_index;
413
+ rays[ray_index * 3 + 2] = num_steps;
414
+
415
+ if (num_steps == 0) return;
416
+ if (point_index + num_steps > M) return;
417
+
418
+ xyzs += point_index * 3;
419
+ dirs += point_index * 3;
420
+ deltas += point_index * 2;
421
+
422
+ t = t0;
423
+ uint32_t step = 0;
424
+
425
+ float last_t = t;
426
+
427
+ while (t < far && step < num_steps) {
428
+ // current point
429
+ const float x = clamp(ox + t * dx, -bound, bound);
430
+ const float y = clamp(oy + t * dy, -bound, bound);
431
+ const float z = clamp(oz + t * dz, -bound, bound);
432
+
433
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
434
+
435
+ // get mip level
436
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
437
+
438
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
439
+ const float mip_rbound = 1 / mip_bound;
440
+
441
+ // convert to nearest grid position
442
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
443
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
444
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
445
+
446
+ // query grid
447
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
448
+ const bool occ = grid[index / 8] & (1 << (index % 8));
449
+
450
+ // if occupied, advance a small step, and write to output
451
+ if (occ) {
452
+ // write step
453
+ xyzs[0] = x;
454
+ xyzs[1] = y;
455
+ xyzs[2] = z;
456
+ dirs[0] = dx;
457
+ dirs[1] = dy;
458
+ dirs[2] = dz;
459
+ t += dt;
460
+ deltas[0] = dt;
461
+ deltas[1] = t - last_t; // used to calc depth
462
+ last_t = t;
463
+ xyzs += 3;
464
+ dirs += 3;
465
+ deltas += 2;
466
+ step++;
467
+ // else, skip a large step (basically skip a voxel grid)
468
+ } else {
469
+ // calc distance to next voxel
470
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
471
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
472
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
473
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
474
+ // step until next voxel
475
+ do {
476
+ t += clamp(t * dt_gamma, dt_min, dt_max);
477
+ } while (t < tt);
478
+ }
479
+ }
480
+ }
481
+
482
+ void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
483
+
484
+ static constexpr uint32_t N_THREAD = 128;
485
+
486
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
487
+ rays_o.scalar_type(), "march_rays_train", ([&] {
488
+ kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
489
+ }));
490
+ }
491
+
492
+
493
+ // sigmas: [M]
494
+ // rgbs: [M, 3]
495
+ // deltas: [M, 2]
496
+ // rays: [N, 3], idx, offset, num_steps
497
+ // weights_sum: [N], final pixel alpha
498
+ // depth: [N,]
499
+ // image: [N, 3]
500
+ template <typename scalar_t>
501
+ __global__ void kernel_composite_rays_train_forward(
502
+ const scalar_t * __restrict__ sigmas,
503
+ const scalar_t * __restrict__ rgbs,
504
+ const scalar_t * __restrict__ deltas,
505
+ const int * __restrict__ rays,
506
+ const uint32_t M, const uint32_t N, const float T_thresh,
507
+ scalar_t * weights_sum,
508
+ scalar_t * depth,
509
+ scalar_t * image
510
+ ) {
511
+ // parallel per ray
512
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
513
+ if (n >= N) return;
514
+
515
+ // locate
516
+ uint32_t index = rays[n * 3];
517
+ uint32_t offset = rays[n * 3 + 1];
518
+ uint32_t num_steps = rays[n * 3 + 2];
519
+
520
+ // empty ray, or a ray that exceeds the max step count.
521
+ if (num_steps == 0 || offset + num_steps > M) {
522
+ weights_sum[index] = 0;
523
+ depth[index] = 0;
524
+ image[index * 3] = 0;
525
+ image[index * 3 + 1] = 0;
526
+ image[index * 3 + 2] = 0;
527
+ return;
528
+ }
529
+
530
+ sigmas += offset;
531
+ rgbs += offset * 3;
532
+ deltas += offset * 2;
533
+
534
+ // accumulate
535
+ uint32_t step = 0;
536
+
537
+ scalar_t T = 1.0f;
538
+ scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
539
+
540
+ while (step < num_steps) {
541
+
542
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
543
+ const scalar_t weight = alpha * T;
544
+
545
+ r += weight * rgbs[0];
546
+ g += weight * rgbs[1];
547
+ b += weight * rgbs[2];
548
+
549
+ t += deltas[1]; // real delta
550
+ d += weight * t;
551
+
552
+ ws += weight;
553
+
554
+ T *= 1.0f - alpha;
555
+
556
+ // minimal remaining transmittance
557
+ if (T < T_thresh) break;
558
+
559
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
560
+
561
+ // locate
562
+ sigmas++;
563
+ rgbs += 3;
564
+ deltas += 2;
565
+
566
+ step++;
567
+ }
568
+
569
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
570
+
571
+ // write
572
+ weights_sum[index] = ws; // weights_sum
573
+ depth[index] = d;
574
+ image[index * 3] = r;
575
+ image[index * 3 + 1] = g;
576
+ image[index * 3 + 2] = b;
577
+ }
578
+
579
+
580
+ void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
581
+
582
+ static constexpr uint32_t N_THREAD = 128;
583
+
584
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
585
+ sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
586
+ kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
587
+ }));
588
+ }
589
+
590
+
591
+ // grad_weights_sum: [N,]
592
+ // grad_image: [N, 3]
593
+ // sigmas: [M]
594
+ // rgbs: [M, 3]
595
+ // deltas: [M, 2]
596
+ // rays: [N, 3], idx, offset, num_steps
597
+ // weights_sum: [N,], weights_sum here
598
+ // image: [N, 3]
599
+ // grad_sigmas: [M]
600
+ // grad_rgbs: [M, 3]
601
+ template <typename scalar_t>
602
+ __global__ void kernel_composite_rays_train_backward(
603
+ const scalar_t * __restrict__ grad_weights_sum,
604
+ const scalar_t * __restrict__ grad_image,
605
+ const scalar_t * __restrict__ sigmas,
606
+ const scalar_t * __restrict__ rgbs,
607
+ const scalar_t * __restrict__ deltas,
608
+ const int * __restrict__ rays,
609
+ const scalar_t * __restrict__ weights_sum,
610
+ const scalar_t * __restrict__ image,
611
+ const uint32_t M, const uint32_t N, const float T_thresh,
612
+ scalar_t * grad_sigmas,
613
+ scalar_t * grad_rgbs
614
+ ) {
615
+ // parallel per ray
616
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
617
+ if (n >= N) return;
618
+
619
+ // locate
620
+ uint32_t index = rays[n * 3];
621
+ uint32_t offset = rays[n * 3 + 1];
622
+ uint32_t num_steps = rays[n * 3 + 2];
623
+
624
+ if (num_steps == 0 || offset + num_steps > M) return;
625
+
626
+ grad_weights_sum += index;
627
+ grad_image += index * 3;
628
+ weights_sum += index;
629
+ image += index * 3;
630
+ sigmas += offset;
631
+ rgbs += offset * 3;
632
+ deltas += offset * 2;
633
+ grad_sigmas += offset;
634
+ grad_rgbs += offset * 3;
635
+
636
+ // accumulate
637
+ uint32_t step = 0;
638
+
639
+ scalar_t T = 1.0f;
640
+ const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
641
+ scalar_t r = 0, g = 0, b = 0, ws = 0;
642
+
643
+ while (step < num_steps) {
644
+
645
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
646
+ const scalar_t weight = alpha * T;
647
+
648
+ r += weight * rgbs[0];
649
+ g += weight * rgbs[1];
650
+ b += weight * rgbs[2];
651
+ ws += weight;
652
+
653
+ T *= 1.0f - alpha;
654
+
655
+ // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
656
+ // write grad_rgbs
657
+ grad_rgbs[0] = grad_image[0] * weight;
658
+ grad_rgbs[1] = grad_image[1] * weight;
659
+ grad_rgbs[2] = grad_image[2] * weight;
660
+
661
+ // write grad_sigmas
662
+ grad_sigmas[0] = deltas[0] * (
663
+ grad_image[0] * (T * rgbs[0] - (r_final - r)) +
664
+ grad_image[1] * (T * rgbs[1] - (g_final - g)) +
665
+ grad_image[2] * (T * rgbs[2] - (b_final - b)) +
666
+ grad_weights_sum[0] * (1 - ws_final)
667
+ );
668
+
669
+ //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
670
+ // minimal remaining transmittance
671
+ if (T < T_thresh) break;
672
+
673
+ // locate
674
+ sigmas++;
675
+ rgbs += 3;
676
+ deltas += 2;
677
+ grad_sigmas++;
678
+ grad_rgbs += 3;
679
+
680
+ step++;
681
+ }
682
+ }
683
+
684
+
685
+ void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
686
+
687
+ static constexpr uint32_t N_THREAD = 128;
688
+
689
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
690
+ grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
691
+ kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
692
+ }));
693
+ }
694
+
695
+
696
+ ////////////////////////////////////////////////////
697
+ ///////////// inference /////////////
698
+ ////////////////////////////////////////////////////
699
+
700
+ template <typename scalar_t>
701
+ __global__ void kernel_march_rays(
702
+ const uint32_t n_alive,
703
+ const uint32_t n_step,
704
+ const int* __restrict__ rays_alive,
705
+ const scalar_t* __restrict__ rays_t,
706
+ const scalar_t* __restrict__ rays_o,
707
+ const scalar_t* __restrict__ rays_d,
708
+ const float bound,
709
+ const float dt_gamma, const uint32_t max_steps,
710
+ const uint32_t C, const uint32_t H,
711
+ const uint8_t * __restrict__ grid,
712
+ const scalar_t* __restrict__ nears,
713
+ const scalar_t* __restrict__ fars,
714
+ scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
715
+ const scalar_t* __restrict__ noises
716
+ ) {
717
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
718
+ if (n >= n_alive) return;
719
+
720
+ const int index = rays_alive[n]; // ray id
721
+ const float noise = noises[n];
722
+
723
+ // locate
724
+ rays_o += index * 3;
725
+ rays_d += index * 3;
726
+ xyzs += n * n_step * 3;
727
+ dirs += n * n_step * 3;
728
+ deltas += n * n_step * 2;
729
+
730
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
731
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
732
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
733
+ const float rH = 1 / (float)H;
734
+ const float H3 = H * H * H;
735
+
736
+ float t = rays_t[index]; // current ray's t
737
+ const float near = nears[index], far = fars[index];
738
+
739
+ const float dt_min = 2 * SQRT3() / max_steps;
740
+ const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
741
+
742
+ // march for n_step steps, record points
743
+ uint32_t step = 0;
744
+
745
+ // introduce some randomness
746
+ t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
747
+
748
+ float last_t = t;
749
+
750
+ while (t < far && step < n_step) {
751
+ // current point
752
+ const float x = clamp(ox + t * dx, -bound, bound);
753
+ const float y = clamp(oy + t * dy, -bound, bound);
754
+ const float z = clamp(oz + t * dz, -bound, bound);
755
+
756
+ const float dt = clamp(t * dt_gamma, dt_min, dt_max);
757
+
758
+ // get mip level
759
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
760
+
761
+ const float mip_bound = fminf(scalbnf(1, level), bound);
762
+ const float mip_rbound = 1 / mip_bound;
763
+
764
+ // convert to nearest grid position
765
+ const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
766
+ const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
767
+ const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
768
+
769
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
770
+ const bool occ = grid[index / 8] & (1 << (index % 8));
771
+
772
+ // if occupied, advance a small step, and write to output
773
+ if (occ) {
774
+ // write step
775
+ xyzs[0] = x;
776
+ xyzs[1] = y;
777
+ xyzs[2] = z;
778
+ dirs[0] = dx;
779
+ dirs[1] = dy;
780
+ dirs[2] = dz;
781
+ // calc dt
782
+ t += dt;
783
+ deltas[0] = dt;
784
+ deltas[1] = t - last_t; // used to calc depth
785
+ last_t = t;
786
+ // step
787
+ xyzs += 3;
788
+ dirs += 3;
789
+ deltas += 2;
790
+ step++;
791
+
792
+ // else, skip a large step (basically skip a voxel grid)
793
+ } else {
794
+ // calc distance to next voxel
795
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
796
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
797
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
798
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
799
+ // step until next voxel
800
+ do {
801
+ t += clamp(t * dt_gamma, dt_min, dt_max);
802
+ } while (t < tt);
803
+ }
804
+ }
805
+ }
806
+
807
+
808
+ void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
809
+ static constexpr uint32_t N_THREAD = 128;
810
+
811
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
812
+ rays_o.scalar_type(), "march_rays", ([&] {
813
+ kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
814
+ }));
815
+ }
816
+
817
+
818
+ template <typename scalar_t>
819
+ __global__ void kernel_composite_rays(
820
+ const uint32_t n_alive,
821
+ const uint32_t n_step,
822
+ const float T_thresh,
823
+ int* rays_alive,
824
+ scalar_t* rays_t,
825
+ const scalar_t* __restrict__ sigmas,
826
+ const scalar_t* __restrict__ rgbs,
827
+ const scalar_t* __restrict__ deltas,
828
+ scalar_t* weights_sum, scalar_t* depth, scalar_t* image
829
+ ) {
830
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
831
+ if (n >= n_alive) return;
832
+
833
+ const int index = rays_alive[n]; // ray id
834
+
835
+ // locate
836
+ sigmas += n * n_step;
837
+ rgbs += n * n_step * 3;
838
+ deltas += n * n_step * 2;
839
+
840
+ rays_t += index;
841
+ weights_sum += index;
842
+ depth += index;
843
+ image += index * 3;
844
+
845
+ scalar_t t = rays_t[0]; // current ray's t
846
+
847
+ scalar_t weight_sum = weights_sum[0];
848
+ scalar_t d = depth[0];
849
+ scalar_t r = image[0];
850
+ scalar_t g = image[1];
851
+ scalar_t b = image[2];
852
+
853
+ // accumulate
854
+ uint32_t step = 0;
855
+ while (step < n_step) {
856
+
857
+ // ray is terminated if delta == 0
858
+ if (deltas[0] == 0) break;
859
+
860
+ const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
861
+
862
+ /*
863
+ T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
864
+ w_i = alpha_i * T_i
865
+ -->
866
+ T_i = 1 - \sum_{j=0}^{i-1} w_j
867
+ */
868
+ const scalar_t T = 1 - weight_sum;
869
+ const scalar_t weight = alpha * T;
870
+ weight_sum += weight;
871
+
872
+ t += deltas[1]; // real delta
873
+ d += weight * t;
874
+ r += weight * rgbs[0];
875
+ g += weight * rgbs[1];
876
+ b += weight * rgbs[2];
877
+
878
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
879
+
880
+ // ray is terminated if T is too small
881
+ // use a larger bound to further accelerate inference
882
+ if (T < T_thresh) break;
883
+
884
+ // locate
885
+ sigmas++;
886
+ rgbs += 3;
887
+ deltas += 2;
888
+ step++;
889
+ }
890
+
891
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
892
+
893
+ // rays_alive = -1 means ray is terminated early.
894
+ if (step < n_step) {
895
+ rays_alive[n] = -1;
896
+ } else {
897
+ rays_t[0] = t;
898
+ }
899
+
900
+ weights_sum[0] = weight_sum; // write back the accumulated alpha
901
+ depth[0] = d;
902
+ image[0] = r;
903
+ image[1] = g;
904
+ image[2] = b;
905
+ }
906
+
907
+
908
+ void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
909
+ static constexpr uint32_t N_THREAD = 128;
910
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
911
+ image.scalar_type(), "composite_rays", ([&] {
912
+ kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
913
+ }));
914
+ }
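# --- Editorial sketch (not part of the commit): a pure-Python mirror of the
# __expand_bits / __morton3D bit tricks above, to make the interleaving easy
# to see (x lands in bit 0, y in bit 1, z in bit 2 of each 3-bit group).
def expand_bits(v: int) -> int:
    # spread the low 10 bits of v three positions apart, exactly as the
    # multiply-and-mask sequence in __expand_bits does
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3D(x: int, y: int, z: int) -> int:
    return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)

assert morton3D(1, 0, 0) == 0b001
assert morton3D(0, 1, 0) == 0b010
assert morton3D(3, 0, 0) == 0b001001  # x bits land at positions 0, 3, 6, ...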
raymarching/src/raymarching.h ADDED
@@ -0,0 +1,18 @@
1
+ #pragma once
2
+
3
+ #include <stdint.h>
4
+ #include <torch/torch.h>
5
+
6
+
7
+ void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8
+ void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
9
+ void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
10
+ void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
11
+ void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
12
+
13
+ void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
14
+ void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
15
+ void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
16
+
17
+ void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
18
+ void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ torch-ema
2
+ ninja
3
+ trimesh
4
+ opencv-python
5
+ tensorboardX
6
+ torch
7
+ numpy
8
+ pandas
9
+ tqdm
10
+ matplotlib
11
+ PyMCubes
12
+ rich
13
+ dearpygui
14
+ scipy
15
+ huggingface_hub
16
+ diffusers
17
+ transformers
18
+ xatlas
19
+ scikit-learn
20
+ imageio
21
+ imageio-ffmpeg
scripts/install_ext.sh ADDED
@@ -0,0 +1,4 @@
1
+ pip install ./raymarching
2
+ pip install ./shencoder
3
+ pip install ./freqencoder
4
+ pip install ./gridencoder
scripts/run.sh ADDED
@@ -0,0 +1,5 @@
1
+ #! /bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of cthulhu" --workspace trial_cthulhu
4
+ CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel" --workspace trial_squirrel
5
+ CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a cat lying on its side batting at a ball of yarn" --workspace trial_cat_lying
shencoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sphere_harmonics import SHEncoder
shencoder/backend.py ADDED
1
+ import os
2
+ from torch.utils.cpp_extension import load
3
+
4
+ _src_path = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ nvcc_flags = [
7
+ '-O3', '-std=c++14',
8
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
9
+ ]
10
+
11
+ if os.name == "posix":
12
+ c_flags = ['-O3', '-std=c++14']
13
+ elif os.name == "nt":
14
+ c_flags = ['/O2', '/std:c++17']
15
+
16
+ # find cl.exe
17
+ def find_cl_path():
18
+ import glob
19
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
20
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
21
+ if paths:
22
+ return paths[0]
23
+
24
+ # If cl.exe is not on path, try to find it.
25
+ if os.system("where cl.exe >nul 2>nul") != 0:
26
+ cl_path = find_cl_path()
27
+ if cl_path is None:
28
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
29
+ os.environ["PATH"] += ";" + cl_path
30
+
31
+ _backend = load(name='_sh_encoder',
32
+ extra_cflags=c_flags,
33
+ extra_cuda_cflags=nvcc_flags,
34
+ sources=[os.path.join(_src_path, 'src', f) for f in [
35
+ 'shencoder.cu',
36
+ 'bindings.cpp',
37
+ ]],
38
+ )
39
+
40
+ __all__ = ['_backend']
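# --- Editorial note (not part of the commit): this JIT-compiled backend is
# the fallback path; the pattern used by sphere_harmonics.py below prefers
# the prebuilt extension and only compiles on demand if it is missing:
#
#   try:
#       import _shencoder as _backend   # prebuilt via setup.py / pip
#   except ImportError:
#       from .backend import _backend   # torch JIT-compiles the sources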
shencoder/setup.py ADDED
@@ -0,0 +1,50 @@
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4
+
5
+ _src_path = os.path.dirname(os.path.abspath(__file__))
6
+
7
+ nvcc_flags = [
8
+ '-O3', '-std=c++14',
9
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
10
+ ]
11
+
12
+ if os.name == "posix":
13
+ c_flags = ['-O3', '-std=c++14']
14
+ elif os.name == "nt":
15
+ c_flags = ['/O2', '/std:c++17']
16
+
17
+ # find cl.exe
18
+ def find_cl_path():
19
+ import glob
20
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
21
+ paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
22
+ if paths:
23
+ return paths[0]
24
+
25
+ # If cl.exe is not on path, try to find it.
26
+ if os.system("where cl.exe >nul 2>nul") != 0:
27
+ cl_path = find_cl_path()
28
+ if cl_path is None:
29
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
30
+ os.environ["PATH"] += ";" + cl_path
31
+
32
+ setup(
33
+ name='shencoder', # package name, import this to use python API
34
+ ext_modules=[
35
+ CUDAExtension(
36
+ name='_shencoder', # extension name, import this to use CUDA API
37
+ sources=[os.path.join(_src_path, 'src', f) for f in [
38
+ 'shencoder.cu',
39
+ 'bindings.cpp',
40
+ ]],
41
+ extra_compile_args={
42
+ 'cxx': c_flags,
43
+ 'nvcc': nvcc_flags,
44
+ }
45
+ ),
46
+ ],
47
+ cmdclass={
48
+ 'build_ext': BuildExtension,
49
+ }
50
+ )
shencoder/sphere_harmonics.py ADDED
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.autograd import Function
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.cuda.amp import custom_bwd, custom_fwd
8
+
9
+ try:
10
+ import _shencoder as _backend
11
+ except ImportError:
12
+ from .backend import _backend
13
+
14
+ class _sh_encoder(Function):
15
+ @staticmethod
16
+ @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
17
+ def forward(ctx, inputs, degree, calc_grad_inputs=False):
18
+ # inputs: [B, input_dim], float in [-1, 1]
19
+ # RETURN: [B, F], float
20
+
21
+ inputs = inputs.contiguous()
22
+ B, input_dim = inputs.shape # batch size, coord dim
23
+ output_dim = degree ** 2
24
+
25
+ outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
26
+
27
+ if calc_grad_inputs:
28
+ dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
29
+ else:
30
+ dy_dx = None
31
+
32
+ _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)
33
+
34
+ ctx.save_for_backward(inputs, dy_dx)
35
+ ctx.dims = [B, input_dim, degree]
36
+
37
+ return outputs
38
+
39
+ @staticmethod
40
+ #@once_differentiable
41
+ @custom_bwd
42
+ def backward(ctx, grad):
43
+ # grad: [B, C * C]
44
+
45
+ inputs, dy_dx = ctx.saved_tensors
46
+
47
+ if dy_dx is not None:
48
+ grad = grad.contiguous()
49
+ B, input_dim, degree = ctx.dims
50
+ grad_inputs = torch.zeros_like(inputs)
51
+ _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
52
+ return grad_inputs, None, None
53
+ else:
54
+ return None, None, None
55
+
56
+
57
+
58
+ sh_encode = _sh_encoder.apply
59
+
60
+
61
+ class SHEncoder(nn.Module):
62
+ def __init__(self, input_dim=3, degree=4):
63
+ super().__init__()
64
+
65
+ self.input_dim = input_dim # coord dims, must be 3
66
+ self.degree = degree # 1 ~ 8
67
+ self.output_dim = degree ** 2
68
+
69
+ assert self.input_dim == 3, "SH encoder only supports input dim == 3"
70
+ assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"
71
+
72
+ def __repr__(self):
73
+ return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"
74
+
75
+ def forward(self, inputs, size=1):
76
+ # inputs: [..., input_dim], normalized real world positions in [-size, size]
77
+ # return: [..., degree^2]
78
+
79
+ inputs = inputs / size # [-1, 1]
80
+
81
+ prefix_shape = list(inputs.shape[:-1])
82
+ inputs = inputs.reshape(-1, self.input_dim)
83
+
84
+ outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
85
+ outputs = outputs.reshape(prefix_shape + [self.output_dim])
86
+
87
+ return outputs
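# --- Editorial sketch (not part of the original file): minimal SHEncoder
# usage with illustrative shapes. Inputs must live on a CUDA device (the
# kernel is CUDA-only) and should be unit-norm directions.
def _sh_encoder_example(device='cuda'):
    encoder = SHEncoder(input_dim=3, degree=4)
    dirs = torch.nn.functional.normalize(torch.randn(4096, 3, device=device), dim=-1)
    feats = encoder(dirs)  # [4096, 16], since output_dim == degree ** 2
    return feats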
shencoder/src/bindings.cpp ADDED
1
+ #include <torch/extension.h>
2
+
3
+ #include "shencoder.h"
4
+
5
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6
+ m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
7
+ m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
8
+ }
shencoder/src/shencoder.cu ADDED
@@ -0,0 +1,439 @@
1
+ #include <stdint.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_runtime.h>
6
+
7
+ #include <ATen/cuda/CUDAContext.h>
8
+ #include <torch/torch.h>
9
+
10
+ #include <algorithm>
11
+ #include <stdexcept>
12
+
13
+ #include <cstdio>
14
+
15
+
16
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
17
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
18
+ #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
19
+ #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
20
+
21
+
22
+ template <typename T>
23
+ __host__ __device__ T div_round_up(T val, T divisor) {
24
+ return (val + divisor - 1) / divisor;
25
+ }
26
+
27
+ template <typename scalar_t>
28
+ __global__ void kernel_sh(
29
+ const scalar_t * __restrict__ inputs,
30
+ scalar_t * outputs,
31
+ uint32_t B, uint32_t D, uint32_t C,
32
+ scalar_t * dy_dx
33
+ ) {
34
+ const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
35
+ if (b >= B) return;
36
+
37
+ const uint32_t C2 = C * C;
38
+
39
+ // locate
40
+ inputs += b * D;
41
+ outputs += b * C2;
42
+
43
+ scalar_t x = inputs[0], y = inputs[1], z = inputs[2];
44
+
45
+ scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
46
+ scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
47
+ scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;
48
+
49
+ auto write_sh = [&]() {
50
+ outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi))
51
+ if (C <= 1) { return; }
52
+ outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi))
53
+ outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi))
54
+ outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi))
55
+ if (C <= 2) { return; }
56
+ outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi))
57
+ outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi))
58
+ outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
59
+ outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi))
60
+ outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
61
+ if (C <= 3) { return; }
62
+ outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
63
+ outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi))
64
+ outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
65
+ outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
66
+ outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
67
+ outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
68
+ outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
69
+ if (C <= 4) { return; }
70
+ outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
71
+ outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
72
+ outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
73
+ outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
74
+ outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
75
+ outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
76
+ outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
77
+ outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
78
+ outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
79
+ if (C <= 5) { return; }
80
+ outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
81
+ outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
82
+ outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
83
+ outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
84
+ outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
85
+ outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
86
+ outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
87
+ outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
88
+ outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
89
+ outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
90
+ outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
91
+ if (C <= 6) { return; }
92
+ outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
93
+ outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
94
+ outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
95
+ outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
96
+ outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
97
+ outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
98
+ outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
99
+ outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
100
+ outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
101
+ outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
102
+ outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
103
+ outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
104
+ outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
105
+ if (C <= 7) { return; }
106
+ outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
107
+ outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
108
+ outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
109
+ outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
110
+ outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
111
+ outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
112
+ outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
113
+ outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
114
+ outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
115
+ outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
116
+ outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
117
+ outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
118
+ outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
119
+ outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
120
+ outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
121
+ };
122
+
123
+ write_sh();
124
+
125
+ if (dy_dx) {
126
+ scalar_t *dx = dy_dx + b * D * C2;
127
+ scalar_t *dy = dx + C2;
128
+ scalar_t *dz = dy + C2;
129
+
130
+ auto write_sh_dx = [&]() {
131
+ dx[0] = 0.0f ; // 0
132
+ if (C <= 1) { return; }
133
+ dx[1] = 0.0f ; // 0
134
+ dx[2] = 0.0f ; // 0
135
+ dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
136
+ if (C <= 2) { return; }
137
+ dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi))
138
+ dx[5] = 0.0f ; // 0
139
+ dx[6] = 0.0f ; // 0
140
+ dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
141
+ dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
142
+ if (C <= 3) { return; }
143
+ dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi))
144
+ dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi))
145
+ dx[11] = 0.0f ; // 0
146
+ dx[12] = 0.0f ; // 0
147
+ dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
148
+ dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
149
+ dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
150
+ if (C <= 4) { return; }
151
+ dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
152
+ dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
153
+ dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
154
+ dx[19] = 0.0f ; // 0
155
+ dx[20] = 0.0f ; // 0
156
+ dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
157
+ dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
158
+ dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
159
+ dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
160
+ if (C <= 5) { return; }
161
+ dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
162
+ dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
163
+ dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
164
+ dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
165
+ dx[29] = 0.0f ; // 0
166
+ dx[30] = 0.0f ; // 0
167
+ dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
168
+ dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
169
+ dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
170
+ dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
171
+ dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
172
+ if (C <= 6) { return; }
173
+ dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
174
+ dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
175
+ dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
176
+ dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
177
+ dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
178
+ dx[41] = 0.0f ; // 0
179
+ dx[42] = 0.0f ; // 0
180
+ dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
181
+ dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
182
+ dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
183
+ dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
184
+ dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
185
+ dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
186
+ if (C <= 7) { return; }
187
+ dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
188
+ dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
189
+ dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
190
+ dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
191
+ dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
192
+ dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
193
+ dx[55] = 0.0f ; // 0
194
+ dx[56] = 0.0f ; // 0
195
+ dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
196
+ dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
197
+ dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
198
+ dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
199
+ dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
200
+ dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
201
+ dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
202
+ };
203
+
204
+ auto write_sh_dy = [&]() {
205
+ dy[0] = 0.0f ; // 0
206
+ if (C <= 1) { return; }
207
+ dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
208
+ dy[2] = 0.0f ; // 0
209
+ dy[3] = 0.0f ; // 0
210
+ if (C <= 2) { return; }
211
+ dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
212
+ dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
213
+ dy[6] = 0.0f ; // 0
214
+ dy[7] = 0.0f ; // 0
215
+ dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
216
+ if (C <= 3) { return; }
217
+ dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
218
+ dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
219
+ dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
220
+ dy[12] = 0.0f ; // 0
221
+ dy[13] = 0.0f ; // 0
222
+ dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi))
223
+ dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi))
224
+ if (C <= 4) { return; }
225
+ dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
226
+ dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
227
+ dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
228
+ dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
229
+ dy[20] = 0.0f ; // 0
230
+ dy[21] = 0.0f ; // 0
231
+ dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
232
+ dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
233
+ dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
234
+ if (C <= 5) { return; }
235
+ dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
236
+ dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
237
+ dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
238
+ dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
239
+ dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
240
+ dy[30] = 0.0f ; // 0
241
+ dy[31] = 0.0f ; // 0
242
+ dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
243
+ dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
244
+ dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
245
+ dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
246
+ if (C <= 6) { return; }
247
+ dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
248
+ dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
249
+ dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
250
+ dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
251
+ dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
252
+ dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
253
+ dy[42] = 0.0f ; // 0
254
+ dy[43] = 0.0f ; // 0
255
+ dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
256
+ dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
257
+ dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
258
+ dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
259
+ dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
260
+ if (C <= 7) { return; }
261
+ dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
262
+ dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
263
+ dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
264
+ dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
265
+ dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
266
+ dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
267
+ dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
268
+ dy[56] = 0.0f ; // 0
269
+ dy[57] = 0.0f ; // 0
270
+ dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
271
+ dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
272
+ dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
273
+ dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
274
+ dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
275
+ dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
276
+ };
277
+
278
+ auto write_sh_dz = [&]() {
279
+ dz[0] = 0.0f ; // 0
280
+ if (C <= 1) { return; }
281
+ dz[1] = 0.0f ; // 0
282
+ dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi))
283
+ dz[3] = 0.0f ; // 0
284
+ if (C <= 2) { return; }
285
+ dz[4] = 0.0f ; // 0
286
+ dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
287
+ dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi))
288
+ dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi))
289
+ dz[8] = 0.0f ; // 0
290
+ if (C <= 3) { return; }
291
+ dz[9] = 0.0f ; // 0
292
+ dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi))
293
+ dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi))
294
+ dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
295
+ dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi))
296
+ dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
297
+ dz[15] = 0.0f ; // 0
298
+ if (C <= 4) { return; }
299
+ dz[16] = 0.0f ; // 0
300
+ dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
301
+ dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
302
+ dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
303
+ dz[20] = 14.809976568128603f*z2*z - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)); z2*z keeps the table pow()-free, consistent with the other entries
304
+ dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
305
+ dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
306
+ dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
307
+ dz[24] = 0.0f ; // 0
308
+ if (C <= 5) { return; }
309
+ dz[25] = 0.0f ; // 0
310
+ dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
311
+ dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
312
+ dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
313
+ dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
314
+ dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
315
+ dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
316
+ dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
317
+ dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
318
+ dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
319
+ dz[35] = 0.0f ; // 0
320
+ if (C <= 6) { return; }
321
+ dz[36] = 0.0f ; // 0
322
+ dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
323
+ dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
324
+ dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
325
+ dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
326
+ dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
327
+ dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
328
+ dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
329
+ dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
330
+ dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
331
+ dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
332
+ dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
333
+ dz[48] = 0.0f ; // 0
334
+ if (C <= 7) { return; }
335
+ dz[49] = 0.0f ; // 0
336
+ dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
337
+ dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
338
+ dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
339
+ dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
340
+ dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
341
+ dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
342
+ dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
343
+ dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
344
+ dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
345
+ dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
346
+ dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
347
+ dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
348
+ dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
349
+ dz[63] = 0.0f ; // 0
350
+ };
351
+ write_sh_dx();
352
+ write_sh_dy();
353
+ write_sh_dz();
354
+ }
355
+ }
356
+
357
+
358
+ template <typename scalar_t>
359
+ __global__ void kernel_sh_backward(
360
+ const scalar_t * __restrict__ grad,
361
+ const scalar_t * __restrict__ inputs,
362
+ uint32_t B, uint32_t D, uint32_t C,
363
+ const scalar_t * __restrict__ dy_dx,
364
+ scalar_t * grad_inputs
365
+ ) {
366
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
367
+ const uint32_t b = t / D;
368
+ if (b >= B) return;
369
+
370
+ const uint32_t d = t - b * D;
371
+ const uint32_t C2 = C * C;
372
+
373
+ // locate
374
+ grad += b * C2;
375
+ dy_dx += b * D * C2 + d * C2;
376
+
377
+ for (uint32_t ch = 0; ch < C2; ch++) { // C2 is unsigned; avoid a signed/unsigned comparison
378
+ grad_inputs[t] += grad[ch] * dy_dx[ch];
379
+ //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
380
+ }
381
+
382
+ }
383
+
384
+ // inputs: [B, D], float, in [0, 1]
385
+ // outputs: [B, L * C], float
386
+ template <typename scalar_t>
387
+ void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
388
+ static constexpr uint32_t N_THREADS = 256;
389
+ kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
390
+ }
391
+
392
+
393
+ template <typename scalar_t>
394
+ void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
395
+ static constexpr uint32_t N_THREADS = 256;
396
+ kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
397
+ }
398
+
399
+
400
+ void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
401
+ CHECK_CUDA(inputs);
402
+ CHECK_CUDA(outputs);
403
+ // CHECK_CUDA(dy_dx);
404
+
405
+ CHECK_CONTIGUOUS(inputs);
406
+ CHECK_CONTIGUOUS(outputs);
407
+ // CHECK_CONTIGUOUS(dy_dx);
408
+
409
+ CHECK_IS_FLOATING(inputs);
410
+ CHECK_IS_FLOATING(outputs);
411
+ // CHECK_IS_FLOATING(dy_dx);
412
+
413
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
414
+ inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
415
+ sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
416
+ }));
417
+ }
418
+
419
+ void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
420
+ CHECK_CUDA(grad);
421
+ CHECK_CUDA(inputs);
422
+ CHECK_CUDA(dy_dx);
423
+ CHECK_CUDA(grad_inputs);
424
+
425
+ CHECK_CONTIGUOUS(grad);
426
+ CHECK_CONTIGUOUS(inputs);
427
+ CHECK_CONTIGUOUS(dy_dx);
428
+ CHECK_CONTIGUOUS(grad_inputs);
429
+
430
+ CHECK_IS_FLOATING(grad);
431
+ CHECK_IS_FLOATING(inputs);
432
+ CHECK_IS_FLOATING(dy_dx);
433
+ CHECK_IS_FLOATING(grad_inputs);
434
+
435
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
436
+ grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
437
+ sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
438
+ }));
439
+ }
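
As a quick sanity check on the hard-coded tables in kernel_sh, the first two SH bands (degree 2, four output channels) can be reproduced in plain PyTorch and compared against the CUDA path. This is an editor's verification sketch, not part of the repo; it reuses the constants from outputs[0..3] above and again assumes SHEncoder is importable:

    import torch
    import torch.nn.functional as F
    from shencoder import SHEncoder

    def sh_degree2_reference(d):
        # Same constants as outputs[0..3] in kernel_sh above.
        x, y, z = d.unbind(-1)
        return torch.stack([
            torch.full_like(x, 0.28209479177387814),  # 1/(2*sqrt(pi))
            -0.48860251190291987 * y,                 # l=1, m=-1
             0.48860251190291987 * z,                 # l=1, m=0
            -0.48860251190291987 * x,                 # l=1, m=1
        ], dim=-1)

    dirs = F.normalize(torch.randn(256, 3, device='cuda'), dim=-1)
    enc = SHEncoder(input_dim=3, degree=2)
    print(torch.allclose(enc(dirs), sh_degree2_reference(dirs), atol=1e-5))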