Duplicate from ashawkey/stable-dreamfusion
Co-authored-by: ashawkey <ashawkey@users.noreply.huggingface.co>
Note: this view is limited to 50 files because the commit contains too many changes, so the listing below is truncated (it ends at the gridencoder/setup.py header).
- .gitattributes +31 -0
- LICENSE +201 -0
- README.md +14 -0
- activation.py +18 -0
- app.py +227 -0
- assets/update_logs.md +9 -0
- docker/Dockerfile +53 -0
- docker/README.md +80 -0
- encoding.py +33 -0
- freqencoder/__init__.py +1 -0
- freqencoder/backend.py +41 -0
- freqencoder/freq.py +77 -0
- freqencoder/setup.py +51 -0
- freqencoder/src/bindings.cpp +8 -0
- freqencoder/src/freqencoder.cu +129 -0
- freqencoder/src/freqencoder.h +10 -0
- gridencoder/__init__.py +1 -0
- gridencoder/backend.py +40 -0
- gridencoder/grid.py +154 -0
- gridencoder/setup.py +50 -0
- gridencoder/src/bindings.cpp +8 -0
- gridencoder/src/gridencoder.cu +479 -0
- gridencoder/src/gridencoder.h +15 -0
- main.py +160 -0
- nerf/clip.py +45 -0
- nerf/gui.py +465 -0
- nerf/network.py +174 -0
- nerf/network_grid.py +181 -0
- nerf/network_tcnn.py +174 -0
- nerf/provider.py +214 -0
- nerf/renderer.py +645 -0
- nerf/sd.py +203 -0
- nerf/utils.py +950 -0
- optimizer.py +470 -0
- raymarching/__init__.py +1 -0
- raymarching/backend.py +40 -0
- raymarching/raymarching.py +373 -0
- raymarching/setup.py +62 -0
- raymarching/src/bindings.cpp +19 -0
- raymarching/src/raymarching.cu +914 -0
- raymarching/src/raymarching.h +18 -0
- requirements.txt +21 -0
- scripts/install_ext.sh +4 -0
- scripts/run.sh +5 -0
- shencoder/__init__.py +1 -0
- shencoder/backend.py +40 -0
- shencoder/setup.py +50 -0
- shencoder/sphere_harmonics.py +87 -0
- shencoder/src/bindings.cpp +8 -0
- shencoder/src/shencoder.cu +439 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
```
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
```
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
@@ -0,0 +1,14 @@
```
---
title: Stable Dreamfusion
emoji: 🍍
colorFrom: gray
colorTo: indigo
sdk: gradio
sdk_version: 3.5
app_file: app.py
pinned: false
license: apache-2.0
duplicated_from: ashawkey/stable-dreamfusion
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
```
activation.py
ADDED
@@ -0,0 +1,18 @@
```
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))

trunc_exp = _trunc_exp.apply
```
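What trunc_exp buys over a plain torch.exp: the forward pass is an ordinary exponential, but the backward pass clamps the saved pre-activation to [-15, 15], so the gradient stays bounded even when the network predicts very large densities. A minimal sanity-check sketch (the tensor values are illustrative):

```
import torch
from activation import trunc_exp

x = torch.tensor([0.0, 20.0, 40.0], requires_grad=True)
y = trunc_exp(x)       # forward: identical to torch.exp(x)
y.sum().backward()
# grad is exp(clamp(x, -15, 15)): the 40.0 entry gets exp(15) ~ 3.3e6
# instead of exp(40) ~ 2.4e17, keeping the update numerically tame.
print(x.grad)
```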
app.py
ADDED
@@ -0,0 +1,227 @@
```
import torch
import argparse

from nerf.provider import NeRFDataset
from nerf.utils import *

import gradio as gr
import gc

print(f'[INFO] loading options..')

# fake config object, this should not be used in CMD, only allow change from gradio UI.
parser = argparse.ArgumentParser()
parser.add_argument('--text', default=None, help="text prompt")
# parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
# parser.add_argument('-O2', action='store_true', help="equals --fp16 --dir_text")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
parser.add_argument('--workspace', type=str, default='trial_gradio')
parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
parser.add_argument('--seed', type=int, default=0)

### training options
parser.add_argument('--iters', type=int, default=10000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
# model options
parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
# network backbone
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
# rendering resolution in training, decrease this if CUDA OOM.
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")

### dataset options
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")

### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=800, help="GUI width")
parser.add_argument('--H', type=int, default=800, help="GUI height")
parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

opt = parser.parse_args()

# default to use -O !!!
opt.fp16 = True
opt.dir_text = True
opt.cuda_ray = True
# opt.lambda_entropy = 1e-4
# opt.lambda_opacity = 0

if opt.backbone == 'vanilla':
    from nerf.network import NeRFNetwork
elif opt.backbone == 'tcnn':
    from nerf.network_tcnn import NeRFNetwork
elif opt.backbone == 'grid':
    from nerf.network_grid import NeRFNetwork
else:
    raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')

print(opt)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'[INFO] loading models..')

if opt.guidance == 'stable-diffusion':
    from nerf.sd import StableDiffusion
    guidance = StableDiffusion(device)
elif opt.guidance == 'clip':
    from nerf.clip import CLIP
    guidance = CLIP(device)
else:
    raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')

train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()

print(f'[INFO] everything loaded!')

trainer = None
model = None

# define UI

with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:

    # title
    gr.Markdown('[Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion) Text-to-3D Example')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="a DSLR photo of a koi fish")
    iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=5000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    # gradio main func
    def submit(text, iters, seed):

        global trainer, model

        # seed
        opt.seed = seed
        opt.text = text
        opt.iters = iters

        seed_everything(seed)

        # clean up
        if trainer is not None:
            del model
            del trainer
            gc.collect()
            torch.cuda.empty_cache()
            print('[INFO] clean up!')

        # simply reload everything...
        model = NeRFNetwork(opt)
        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
        scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))

        trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)

        # train (every ep only contain 8 steps, so we can get some vis every ~10s)
        STEPS = 8
        max_epochs = np.ceil(opt.iters / STEPS).astype(np.int32)

        # we have to get the explicit training loop out here to yield progressive results...
        loader = iter(valid_loader)

        start_t = time.time()

        for epoch in range(max_epochs):

            trainer.train_gui(train_loader, step=STEPS)

            # manual test and get intermediate results
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(valid_loader)
                data = next(loader)

            trainer.model.eval()

            if trainer.ema is not None:
                trainer.ema.store()
                trainer.ema.copy_to()

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=trainer.fp16):
                    preds, preds_depth = trainer.test_step(data, perturb=False)

            if trainer.ema is not None:
                trainer.ema.restore()

            pred = preds[0].detach().cpu().numpy()
            # pred_depth = preds_depth[0].detach().cpu().numpy()

            pred = (pred * 255).astype(np.uint8)

            yield {
                image: gr.update(value=pred, visible=True),
                video: gr.update(visible=False),
                logs: f"training iters: {epoch * STEPS} / {iters}, lr: {trainer.optimizer.param_groups[0]['lr']:.6f}",
            }

        # test
        trainer.test(test_loader)

        results = glob.glob(os.path.join(opt.workspace, 'results', '*rgb*.mp4'))
        assert len(results) > 0, "cannot retrieve results!"
        results.sort(key=lambda x: os.path.getmtime(x)) # sort by mtime

        end_t = time.time()

        yield {
            image: gr.update(visible=False),
            video: gr.update(value=results[-1], visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )

# concurrency_count: only allow ONE running progress, else GPU will OOM.
demo.queue(concurrency_count=1)

demo.launch()
```
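The submit handler above is a generator: each yield pushes an intermediate render and a log line to the UI without ending the click event, which is how the Space shows progress every few training steps. A stripped-down sketch of the same streaming pattern (component names here are illustrative, not from the repo):

```
import time
import gradio as gr

with gr.Blocks() as demo:
    button = gr.Button("Run")
    status = gr.Textbox(label="status")

    def run():
        # each yield pushes a partial update to the UI while work continues
        for step in range(5):
            time.sleep(1)  # stand-in for a chunk of training
            yield f"step {step + 1} / 5"
        yield "done"

    button.click(run, [], [status])

demo.queue()  # queueing must be enabled for generator (streaming) handlers
demo.launch()
```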
assets/update_logs.md
ADDED
@@ -0,0 +1,9 @@
### 2022.10.9
* Shading (partially) starts to work; at least it no longer empties the scene. For some prompts it gives better results (a less severe Janus problem). The textureless rendering mode is still disabled.
* Enabled shading by default (--albedo_iters 1000).

### 2022.10.5
* Basic reproduction finished.
* The non---cuda_ray and --tcnn modes are not working yet; needs a fix.
* Shading is not working and is disabled in utils.py for now; surface normals are bad.
* Use an entropy loss to regularize weights_sum (alpha); the original L2 regularizer always leads to degenerate geometry...
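For context on that last bullet: an alpha-entropy regularizer penalizes per-ray accumulated opacities that sit between 0 and 1, pushing renders toward binary occupancy instead of fog. A minimal sketch of one common formulation (hedged: the exact weighting in nerf/utils.py may differ, and the --lambda_entropy flag in app.py sets its scale):

```
import torch

def alpha_entropy_loss(weights_sum: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Binary entropy of the accumulated alpha along each ray.

    weights_sum: [N] accumulated opacity per ray, in [0, 1].
    """
    a = weights_sum.clamp(eps, 1 - eps)
    entropy = -(a * torch.log(a) + (1 - a) * torch.log(1 - a))
    return entropy.mean()

# loss = ... + lambda_entropy * alpha_entropy_loss(weights_sum)
```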
docker/Dockerfile
ADDED
@@ -0,0 +1,53 @@
```
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04

# Remove any third-party apt sources to avoid issues with expiring keys.
RUN rm -f /etc/apt/sources.list.d/*.list

RUN apt-get update

RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/Madrid apt-get install -y tzdata

# Install some basic utilities
RUN apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    python3 \
    python3-pip \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
 && rm -rf /var/lib/apt/lists/*

# Create a working directory
RUN mkdir /app
WORKDIR /app

RUN git clone https://github.com/ashawkey/stable-dreamfusion.git

RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116

WORKDIR /app/stable-dreamfusion

RUN pip3 install -r requirements.txt
RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/

# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

RUN pip3 install git+https://github.com/openai/CLIP.git
RUN bash scripts/install_ext.sh

# Set the default command to python3
#CMD ["python3"]
```
docker/README.md
ADDED
@@ -0,0 +1,80 @@
### Docker installation

## Build image
To build the Docker image on your own machine (this may take 15-30 minutes):
```
docker build -t stable-dreamfusion:latest .
```

If you hit the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn, you need to set up the nvidia runtime for Docker.
```
sudo apt-get install nvidia-container-runtime
```
Then edit `/etc/docker/daemon.json` and add the default-runtime:
```
{
    "runtimes": {
        "nvidia": {
            "path": "nvidia-container-runtime",
            "runtimeArgs": []
        }
    },
    "default-runtime": "nvidia"
}
```
And restart docker:
```
sudo systemctl restart docker
```
Now you can build tiny-cuda-nn inside docker.

## Download image
To download the image (~6GB) instead:
```
docker pull supercabb/stable-dreamfusion:3080_0.0.1
docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
```

## Use image

You can launch an interactive shell inside the container:

```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
```
From this shell, all the code in the repo should work.

To run any single command `<command...>` inside the docker container:
```
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command...>"
```
To train:
```
export TOKEN="#HUGGING FACE ACCESS TOKEN#"
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
&& python3 main.py --text \"a hamburger\" --workspace trial -O"
```
Run test without gui:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test"
```
Run test with gui:
```
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
xhost +
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
main.py --workspace trial -O --test --gui"
xhost -
```
encoding.py
ADDED
@@ -0,0 +1,33 @@
```
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
                **kwargs):

    if encoding == 'None':
        return lambda x, **kwargs: x, input_dim

    elif encoding == 'frequency':
        from freqencoder import FreqEncoder
        encoder = FreqEncoder(input_dim=input_dim, degree=multires)

    elif encoding == 'sphere_harmonics':
        from shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)

    elif encoding == 'hashgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

    elif encoding == 'tiledgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)

    else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')

    return encoder, encoder.output_dim
```
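A quick sketch of how get_encoder is typically called by the network code (assumes the CUDA extensions below are already built; shapes are illustrative):

```
import torch
from encoding import get_encoder

# build a hash-grid encoder for 3D positions; output_dim = num_levels * level_dim
encoder, output_dim = get_encoder('hashgrid', input_dim=3, desired_resolution=2048)
encoder = encoder.cuda()  # embeddings must live on the GPU for the CUDA kernel

x = torch.rand(1024, 3, device='cuda') * 2 - 1   # positions in [-1, 1]
feats = encoder(x, bound=1)                      # -> [1024, output_dim], here 32
```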
freqencoder/__init__.py
ADDED
@@ -0,0 +1 @@
```
from .freq import FreqEncoder
```
freqencoder/backend.py
ADDED
@@ -0,0 +1,41 @@
```
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_freqencoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'freqencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
```
freqencoder/freq.py
ADDED
@@ -0,0 +1,77 @@
```
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _freqencoder as _backend
except ImportError:
    from .backend import _backend


class _freq_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        # inputs: [B, input_dim], float
        # RETURN: [B, output_dim], float

        if not inputs.is_cuda: inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape # batch size, coord dim

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, output_dim]

        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

        return grad_inputs, None, None


freq_encode = _freq_encoder.apply


class FreqEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        self.output_dim = input_dim + input_dim * 2 * degree

    def __repr__(self):
        return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

    def forward(self, inputs, **kwargs):
        # inputs: [..., input_dim]
        # return: [..., output_dim]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = freq_encode(inputs, self.degree, self.output_dim)

        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
```
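A minimal usage sketch (requires a CUDA device, since the extension moves inputs to the GPU):

```
import torch
from freqencoder import FreqEncoder

enc = FreqEncoder(input_dim=3, degree=4)
x = torch.rand(8, 3, device='cuda')
y = enc(x)  # [8, 3 + 3*2*4] = [8, 27]: identity plus sin/cos at 4 frequencies
```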
freqencoder/setup.py
ADDED
@@ -0,0 +1,51 @@
```
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='freqencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_freqencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'freqencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
```
freqencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
```
#include <torch/extension.h>

#include "freqencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
    m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
```
freqencoder/src/freqencoder.cu
ADDED
@@ -0,0 +1,129 @@
```
#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

inline constexpr __device__ float PI() { return 3.141592653589793f; }

template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index
    const uint32_t b = t / C;
    const uint32_t c = t - b * C; // t % C;

    // locate
    inputs += b * D;
    outputs += t;

    // write self
    if (c < D) {
        outputs[0] = inputs[c];
    // write freq
    } else {
        const uint32_t col = c / D - 1;
        const uint32_t d = c % D;
        const uint32_t freq = col / 2;
        const float phase_shift = (col % 2) * (PI() / 2);
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}

// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D; // t % D;

    // locate
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // register
    float result = grad[d];
    grad += D;
    outputs += D;

    for (uint32_t f = 0; f < deg; f++) {
        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write
    grad_inputs[0] = result;
}


void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}


void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
```
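To make the kernel's column layout concrete: for each input dimension the first D columns are the identity, and the remaining columns come in D-wide (sin, cos) blocks per frequency, computed as sin(2^f * x + phase) with phase 0 or pi/2 (scalbnf(x, f) is x * 2^f, and cos(v) == sin(v + pi/2)). A hedged pure-PyTorch reference for checking the layout (not part of the repo):

```
import torch

def freq_encode_reference(x: torch.Tensor, deg: int) -> torch.Tensor:
    """Mirror of kernel_freq: [x, sin(2^0 x), cos(2^0 x), sin(2^1 x), ...]."""
    cols = [x]
    for f in range(deg):
        cols.append(torch.sin((2 ** f) * x))   # phase 0 block
        cols.append(torch.cos((2 ** f) * x))   # phase pi/2 block
    # each entry is a full D-wide block, matching the kernel's c/D indexing
    return torch.cat(cols, dim=-1)
```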
freqencoder/src/freqencoder.h
ADDED
@@ -0,0 +1,10 @@
```
#pragma once

#include <stdint.h>
#include <torch/torch.h>

// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);

// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
```
gridencoder/__init__.py
ADDED
@@ -0,0 +1 @@
```
from .grid import GridEncoder
```
gridencoder/backend.py
ADDED
@@ -0,0 +1,40 @@
```
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
```
gridencoder/grid.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    'hash': 0,
    'tiled': 1,
}

class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, L * C], float

        inputs = inputs.contiguous()

        B, D = inputs.shape # batch size, coord dim
        L = offsets.shape[0] - 1 # level
        C = embeddings.shape[1] # embedding dim for each level
        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
        else:
            dy_dx = None

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        return grad_inputs, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # num levels, each level multiply resolution by 2
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype] # "hash" or "tiled"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
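A minimal usage sketch of the encoder above (illustrative only; it assumes the CUDA extension builds, and that the one-line gridencoder/__init__.py re-exports GridEncoder):

# Hypothetical usage sketch, not part of the commit.
import torch
from gridencoder import GridEncoder

encoder = GridEncoder(input_dim=3, num_levels=16, level_dim=2).cuda()
x = torch.rand(4096, 3, device='cuda') * 2 - 1  # positions in [-bound, bound] with bound=1
feats = encoder(x, bound=1)                     # [4096, num_levels * level_dim] = [4096, 32]
print(encoder)                                  # __repr__ summarizes levels, resolutions, params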
gridencoder/setup.py
ADDED
@@ -0,0 +1,50 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
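For reference, the try/except in grid.py falls back to a gridencoder/backend.py JIT build when the ahead-of-time extension above is not installed. A rough sketch of what such a loader looks like (an assumption about backend.py's contents, shown only to contrast with the setup.py route):

# Hypothetical JIT-loading sketch; the actual backend.py may differ in details.
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

_backend = load(name='_gridencoder',
                extra_cuda_cflags=['-O3', '-std=c++14'],
                sources=[os.path.join(_src_path, 'src', f) for f in ['gridencoder.cu', 'bindings.cpp']])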
gridencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
gridencoder/src/gridencoder.cu
ADDED
@@ -0,0 +1,479 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <stdint.h>
#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, and never used.
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
}


template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}


template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");

    // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
    // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
    // coordinates.
    constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}


template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution : (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }
    // if input out of bound, just set output to 0
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);

    // interpolate
    scalar_t results[C] = {0}; // temp results in register

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

        // writing to register (fast)
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
    }

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }

    // prepare dy_dx
    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
    if (dy_dx) {

        dy_dx += b * D * L * C + level * D * C; // B L D C

        #pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {

            scalar_t results_grad[C] = {0};

            #pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

                #pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
                }
            }

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}


template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // check input range (should be in [0, 1])
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad is init as 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to register
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // interpolate
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float which is better than __half, if N_C % 2 != 0
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * __restrict__ grad_inputs,
    uint32_t B, uint32_t L
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;

    dy_dx += b * L * D * C;

    scalar_t result = 0;

    #pragma unroll
    for (int l = 0; l < L; l++) {
        #pragma unroll
        for (int ch = 0; ch < C; ch++) {
            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
        }
    }

    grad_inputs[t] = result;
}


template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}

template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
    switch (C) {
        case 1:
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}


// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}


void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grid_encode_forward", ([&] {
            grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}

void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad.scalar_type(), "grid_encode_backward", ([&] {
            grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
        }));
}
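To make the exp2f(level * S) * H - 1 bookkeeping in the kernels concrete, here is a small Python sketch (illustrative only; the desired_resolution value is just an example) of how each level's grid resolution follows from the base resolution H and the log2-scale S that grid.py passes down:

# Illustrative mirror of the kernels' per-level resolution math.
import numpy as np

H = 16                                             # base_resolution
per_level_scale = np.exp2(np.log2(2048 / H) / 15)  # e.g. desired_resolution=2048 over 16 levels
S = np.log2(per_level_scale)                       # grid.py sends log2 so the kernel can use exp2f

for level in range(16):
    scale = np.exp2(level * S) * H - 1.0   # matches exp2f(level * S) * H - 1.0f
    resolution = int(np.ceil(scale)) + 1   # matches (uint32_t)ceil(scale) + 1
    print(level, resolution)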
gridencoder/src/gridencoder.h
ADDED
@@ -0,0 +1,15 @@
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);

#endif
main.py
ADDED
@@ -0,0 +1,160 @@
import torch
import argparse

from nerf.provider import NeRFDataset
from nerf.utils import *
from optimizer import Shampoo

from nerf.gui import NeRFGUI

# torch.autograd.set_detect_anomaly(True)

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--text', default=None, help="text prompt")
    parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --dir_text")
    parser.add_argument('-O2', action='store_true', help="equals --fp16 --dir_text")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
    parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
    parser.add_argument('--seed', type=int, default=0)

    ### training options
    parser.add_argument('--iters', type=int, default=10000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
    parser.add_argument('--albedo_iters', type=int, default=1000, help="training iters that only use albedo shading")
    # model options
    parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
    parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
    # network backbone
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
    # rendering resolution in training, decrease this if CUDA OOM.
    parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
    parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
    parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")

    ### dataset options
    parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
    parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
    parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
    parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
    parser.add_argument('--dir_text', action='store_true', help="direction-encode the text prompt, by appending front/side/back/overhead view")
    parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
    parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

    parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
    parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
    parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")

    ### GUI options
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=800, help="GUI width")
    parser.add_argument('--H', type=int, default=800, help="GUI height")
    parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
    parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
    parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    opt = parser.parse_args()

    if opt.O:
        opt.fp16 = True
        opt.dir_text = True
        # use occupancy grid to prune ray sampling, faster rendering.
        opt.cuda_ray = True
        # opt.lambda_entropy = 1e-4
        # opt.lambda_opacity = 0

    elif opt.O2:
        opt.fp16 = True
        opt.dir_text = True
        opt.lambda_entropy = 1e-4 # necessary to keep non-empty
        opt.lambda_opacity = 3e-3 # no occupancy grid, so use a stronger opacity loss.

    if opt.backbone == 'vanilla':
        from nerf.network import NeRFNetwork
    elif opt.backbone == 'tcnn':
        from nerf.network_tcnn import NeRFNetwork
    elif opt.backbone == 'grid':
        from nerf.network_grid import NeRFNetwork
    else:
        raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')

    print(opt)

    seed_everything(opt.seed)

    model = NeRFNetwork(opt)

    print(model)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt.test:
        guidance = None # no need to load guidance model at test

        trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)

        if opt.gui:
            gui = NeRFGUI(opt, trainer)
            gui.render()

        else:
            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
            trainer.test(test_loader)

            if opt.save_mesh:
                trainer.save_mesh(resolution=256)

    else:

        if opt.guidance == 'stable-diffusion':
            from nerf.sd import StableDiffusion
            guidance = StableDiffusion(device)
        elif opt.guidance == 'clip':
            from nerf.clip import CLIP
            guidance = CLIP(device)
        else:
            raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')

        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
        # optimizer = lambda model: Shampoo(model.get_params(opt.lr))

        train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()

        scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
        # scheduler = lambda optimizer: optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=opt.iters, pct_start=0.1)

        trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)

        if opt.gui:
            trainer.train_loader = train_loader # attach dataloader to trainer

            gui = NeRFGUI(opt, trainer)
            gui.render()

        else:
            valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()

            max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
            trainer.train(train_loader, valid_loader, max_epoch)

            # also test
            test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
            trainer.test(test_loader)

            if opt.save_mesh:
                trainer.save_mesh(resolution=256)
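For orientation, a typical invocation of this script (the prompt and workspace name here are purely illustrative) is python main.py --text "a hamburger" -O --workspace trial_hamburger, where -O enables --fp16, --cuda_ray and --dir_text as defined above; adding --gui drives training through the dearpygui viewer instead of the headless loop.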
nerf/clip.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn

import torchvision.transforms as T
import torchvision.transforms.functional as TF

import clip

class CLIP(nn.Module):
    def __init__(self, device):
        super().__init__()

        self.device = device

        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)

        # image augmentation
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        # self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))


    def get_text_embeds(self, prompt):

        text = clip.tokenize(prompt).to(self.device)
        text_z = self.clip_model.encode_text(text)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)

        return text_z


    def train_step(self, text_z, pred_rgb):

        pred_rgb = self.aug(pred_rgb)

        image_z = self.clip_model.encode_image(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features

        loss = - (image_z * text_z).sum(-1).mean()

        return loss
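A minimal driver sketch for this guidance (illustrative only; the rendered image is a random stand-in for a NeRF view, and the [1, 512] shape assumes the ViT-B/16 weights load as above):

# Hypothetical driver sketch, not part of the commit.
import torch
from nerf.clip import CLIP

device = torch.device('cuda')
guidance = CLIP(device)

text_z = guidance.get_text_embeds("a DSLR photo of a hamburger")  # [1, 512], unit-norm
pred_rgb = torch.rand(1, 3, 64, 64, device=device)  # stand-in rendered view in [0, 1]
loss = guidance.train_step(text_z, pred_rgb)        # higher CLIP similarity -> lower loss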
nerf/gui.py
ADDED
@@ -0,0 +1,465 @@
import math
import torch
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R

from nerf.utils import *


class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r # camera distance from center
        self.fovy = fovy # in degree
        self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
        self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
        self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!

    # pose
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] -= self.radius
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
        rotvec_x = self.up * np.deg2rad(-0.1 * dx)
        rotvec_y = side * np.deg2rad(-0.1 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])


class NeRFGUI:
    def __init__(self, opt, trainer, debug=True):
        self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
        self.training = False
        self.step = 0 # training step

        self.trainer = trainer
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True # camera moved, should reset accumulation
        self.spp = 1 # sample per pixel
        self.light_dir = np.array([opt.light_theta, opt.light_phi])
        self.ambient_ratio = 1.0
        self.mode = 'image' # choose from ['image', 'depth']
        self.shading = 'albedo'

        self.dynamic_resolution = True
        self.downscale = 1
        self.train_steps = 16

        dpg.create_context()
        self.register_dpg()
        self.test_step()


    def __del__(self):
        dpg.destroy_context()


    def train_step(self):

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps


    def prepare_buffer(self, outputs):
        if self.mode == 'image':
            return outputs['image']
        else:
            return np.expand_dims(outputs['depth'], -1).repeat(3, -1)


    def test_step(self):

        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)


    def register_dpg(self):

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)


                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh(resolution=256, threshold=10)
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1 # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")


            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
|
365 |
+
self.need_update = True
|
366 |
+
|
367 |
+
dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)
|
368 |
+
|
369 |
+
|
370 |
+
# debug info
|
371 |
+
if self.debug:
|
372 |
+
with dpg.collapsing_header(label="Debug"):
|
373 |
+
# pose
|
374 |
+
dpg.add_separator()
|
375 |
+
dpg.add_text("Camera Pose:")
|
376 |
+
dpg.add_text(str(self.cam.pose), tag="_log_pose")
|
377 |
+
|
378 |
+
|
379 |
+
### register camera handler
|
380 |
+
|
381 |
+
def callback_camera_drag_rotate(sender, app_data):
|
382 |
+
|
383 |
+
if not dpg.is_item_focused("_primary_window"):
|
384 |
+
return
|
385 |
+
|
386 |
+
dx = app_data[1]
|
387 |
+
dy = app_data[2]
|
388 |
+
|
389 |
+
self.cam.orbit(dx, dy)
|
390 |
+
self.need_update = True
|
391 |
+
|
392 |
+
if self.debug:
|
393 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
394 |
+
|
395 |
+
|
396 |
+
def callback_camera_wheel_scale(sender, app_data):
|
397 |
+
|
398 |
+
if not dpg.is_item_focused("_primary_window"):
|
399 |
+
return
|
400 |
+
|
401 |
+
delta = app_data
|
402 |
+
|
403 |
+
self.cam.scale(delta)
|
404 |
+
self.need_update = True
|
405 |
+
|
406 |
+
if self.debug:
|
407 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
408 |
+
|
409 |
+
|
410 |
+
def callback_camera_drag_pan(sender, app_data):
|
411 |
+
|
412 |
+
if not dpg.is_item_focused("_primary_window"):
|
413 |
+
return
|
414 |
+
|
415 |
+
dx = app_data[1]
|
416 |
+
dy = app_data[2]
|
417 |
+
|
418 |
+
self.cam.pan(dx, dy)
|
419 |
+
self.need_update = True
|
420 |
+
|
421 |
+
if self.debug:
|
422 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
423 |
+
|
424 |
+
|
425 |
+
with dpg.handler_registry():
|
426 |
+
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
|
427 |
+
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
|
428 |
+
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
|
429 |
+
|
430 |
+
|
431 |
+
dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)
|
432 |
+
|
433 |
+
# TODO: seems dearpygui doesn't support resizing texture...
|
434 |
+
# def callback_resize(sender, app_data):
|
435 |
+
# self.W = app_data[0]
|
436 |
+
# self.H = app_data[1]
|
437 |
+
# # how to reload texture ???
|
438 |
+
|
439 |
+
# dpg.set_viewport_resize_callback(callback_resize)
|
440 |
+
|
441 |
+
### global theme
|
442 |
+
with dpg.theme() as theme_no_padding:
|
443 |
+
with dpg.theme_component(dpg.mvAll):
|
444 |
+
# set all padding to 0 to avoid scroll bar
|
445 |
+
dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
|
446 |
+
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
|
447 |
+
dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
|
448 |
+
|
449 |
+
dpg.bind_item_theme("_primary_window", theme_no_padding)
|
450 |
+
|
451 |
+
dpg.setup_dearpygui()
|
452 |
+
|
453 |
+
#dpg.show_metrics()
|
454 |
+
|
455 |
+
dpg.show_viewport()
|
456 |
+
|
457 |
+
|
458 |
+
def render(self):
|
459 |
+
|
460 |
+
while dpg.is_dearpygui_running():
|
461 |
+
# update texture every frame
|
462 |
+
if self.training:
|
463 |
+
self.train_step()
|
464 |
+
self.test_step()
|
465 |
+
dpg.render_dearpygui_frame()
|
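The `weight_reset` closure in `callback_reset` is a generic in-place re-initialization idiom: `nn.Module.apply` visits every submodule, and each layer that exposes `reset_parameters` (e.g. `nn.Linear`) re-initializes itself. A minimal standalone sketch of the same idiom, using a hypothetical toy model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
before = model[0].weight.clone()

@torch.no_grad()
def weight_reset(m: nn.Module):
    # only layers that define reset_parameters are touched
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

model.apply(weight_reset)
print(torch.equal(before, model[0].weight))  # almost surely False after re-init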
nerf/network.py ADDED @@ -0,0 +1,174 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from .renderer import NeRFRenderer

import numpy as np
from encoding import get_encoder

from .utils import safe_normalize

class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x


class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 opt,
                 num_layers=5,
                 hidden_dim=128,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.encoder, self.in_dim = get_encoder('frequency', input_dim=3)
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def gaussian(self, x):
        # x: [B, N, 3]

        d = (x ** 2).sum(-1)
        g = 5 * torch.exp(-d / (2 * 0.2 ** 2))

        return g

    def common_forward(self, x):
        # x: [N, 3], in [-bound, bound]

        # sigma
        h = self.encoder(x, bound=self.bound)

        h = self.sigma_net(h)

        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        # x: [N, 3]
        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))

        normal = torch.stack([
            0.5 * (dx_pos - dx_neg) / epsilon,
            0.5 * (dy_pos - dy_neg) / epsilon,
            0.5 * (dz_pos - dz_neg) / epsilon
        ], dim=-1)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # query normal

            # sigma, albedo = self.common_forward(x)
            # normal = self.finite_difference_normal(x)

            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

            # normalize...
            normal = safe_normalize(normal)
            normal[torch.isnan(normal)] = 0

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal


    def density(self, x):
        # x: [N, 3], in [-bound, bound]

        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }


    def background(self, d):

        h = self.encoder_bg(d) # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):

        params = [
            # {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params
nerf/network_grid.py ADDED @@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from .renderer import NeRFRenderer

import numpy as np
from encoding import get_encoder

from .utils import safe_normalize

class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x


class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)

        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)

            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    # add a density blob to the scene center
    def gaussian(self, x):
        # x: [B, N, 3]

        d = (x ** 2).sum(-1)
        g = 5 * torch.exp(-d / (2 * 0.2 ** 2))

        return g

    def common_forward(self, x):
        # x: [N, 3], in [-bound, bound]

        # sigma
        h = self.encoder(x, bound=self.bound)

        h = self.sigma_net(h)

        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
    def finite_difference_normal(self, x, epsilon=1e-2):
        # x: [N, 3]
        dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
        dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
        dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))

        normal = torch.stack([
            0.5 * (dx_pos - dx_neg) / epsilon,
            0.5 * (dy_pos - dy_neg) / epsilon,
            0.5 * (dz_pos - dz_neg) / epsilon
        ], dim=-1)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # query normal

            sigma, albedo = self.common_forward(x)
            normal = self.finite_difference_normal(x)

            # with torch.enable_grad():
            #     x.requires_grad_(True)
            #     sigma, albedo = self.common_forward(x)
            #     # query gradient
            #     normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

            # normalize...
            normal = safe_normalize(normal)
            normal[torch.isnan(normal)] = 0

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal


    def density(self, x):
        # x: [N, 3], in [-bound, bound]

        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }


    def background(self, d):

        h = self.encoder_bg(d) # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):

        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params
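Unlike `nerf/network.py`, this grid variant estimates normals with `finite_difference_normal` rather than autograd (the autograd path is left commented out), since central differences behave better over a discrete tiled-grid encoder. A standalone sketch of the same central-difference scheme on a field with a known gradient:

import torch

def field(x):
    # toy density with analytic gradient 2 * x
    return (x ** 2).sum(-1)

def fd_gradient(x, eps=1e-2):
    grads = []
    for i in range(3):
        offset = torch.zeros(1, 3)
        offset[0, i] = eps
        # central difference along axis i
        grads.append(0.5 * (field(x + offset) - field(x - offset)) / eps)
    return torch.stack(grads, dim=-1)

x = torch.tensor([[0.3, -0.1, 0.5]])
print(fd_gradient(x))  # [[0.6, -0.2, 1.0]] == 2 * x (exact for a quadratic)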
nerf/network_tcnn.py ADDED @@ -0,0 +1,174 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from .renderer import NeRFRenderer
from encoding import get_encoder

import numpy as np
import tinycudann as tcnn

class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x


class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": per_level_scale,
            },
        )

        self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3)

            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def gaussian(self, x):
        # x: [B, N, 3]

        d = (x ** 2).sum(-1)
        g = 5 * torch.exp(-d / (2 * 0.2 ** 2))

        return g

    def common_forward(self, x):
        # x: [N, 3], in [-bound, bound]

        # sigma
        h = (x + self.bound) / (2 * self.bound) # to [0, 1]
        h = self.encoder(h)

        h = self.sigma_net(h)

        sigma = trunc_exp(h[..., 0] + self.gaussian(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo


    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # query normal
            has_grad = torch.is_grad_enabled()

            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

            # normalize...
            normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9)
            normal[torch.isnan(normal)] = 0

            if not has_grad:
                normal = normal.detach()

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal


    def density(self, x):
        # x: [N, 3], in [-bound, bound]

        sigma, _ = self.common_forward(x)

        return {
            'sigma': sigma
        }


    def background(self, d):
        # x: [N, 2], in [-1, 1]

        h = self.encoder_bg(d) # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):

        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params
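The `per_level_scale` above is the Instant-NGP geometric progression: 16 hash-grid levels start at base resolution 16 and each grows by a constant factor chosen so the finest level lands exactly on `2048 * bound`. A quick check of that identity:

import numpy as np

bound = 1
n_levels, base_resolution = 16, 16
desired_resolution = 2048 * bound

per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (n_levels - 1))
finest = base_resolution * per_level_scale ** (n_levels - 1)
print(per_level_scale, finest)  # ~1.3819, 2048.0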
nerf/provider.py ADDED @@ -0,0 +1,214 @@
import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation

import trimesh

import torch
from torch.utils.data import DataLoader

from .utils import get_rays, safe_normalize

def visualize_poses(poses, size=0.1):
    # poses: [B, 4, 4]

    axes = trimesh.creation.axis(axis_length=4)
    sphere = trimesh.creation.icosphere(radius=1)
    objects = [axes, sphere]

    for pose in poses:
        # a camera is visualized with 8 line segments.
        pos = pose[:3, 3]
        a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
        d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
        segs = trimesh.load_path(segs)
        objects.append(segs)

    trimesh.Scene(objects).show()

def get_view_direction(thetas, phis, overhead, front):
    # phis [B,]; thetas: [B,]
    # front = 0 [0, front)
    # side (left) = 1 [front, 180)
    # back = 2 [180, 180+front)
    # side (right) = 3 [180+front, 360)
    # top = 4 [0, overhead]
    # bottom = 5 [180-overhead, 180]
    res = torch.zeros(thetas.shape[0], dtype=torch.long)
    # first determine by phis
    res[(phis < front)] = 0
    res[(phis >= front) & (phis < np.pi)] = 1
    res[(phis >= np.pi) & (phis < (np.pi + front))] = 2
    res[(phis >= (np.pi + front))] = 3
    # override by thetas
    res[thetas <= overhead] = 4
    res[thetas >= (np.pi - overhead)] = 5
    return res


def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 100], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
    ''' generate random poses from an orbit camera
    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius_range: [min, max], camera radius
        theta_range: [min, max], should be in [0, pi]
        phi_range: [min, max], should be in [0, 2 * pi]
    Return:
        poses: [size, 4, 4]
    '''

    theta_range = np.deg2rad(theta_range)
    phi_range = np.deg2rad(phi_range)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
    thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
    phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1) # [B, 3]

    targets = 0

    # jitters
    if jitter:
        centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
        targets = targets + torch.randn_like(centers) * 0.2

    # lookat
    forward_vector = safe_normalize(targets - centers)
    up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if jitter:
        up_noise = torch.randn_like(up_vector) * 0.02
    else:
        up_noise = 0

    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs


def circle_poses(device, radius=1.25, theta=60, phi=0, return_dirs=False, angle_overhead=30, angle_front=60):

    theta = np.deg2rad(theta)
    phi = np.deg2rad(phi)
    angle_overhead = np.deg2rad(angle_overhead)
    angle_front = np.deg2rad(angle_front)

    thetas = torch.FloatTensor([theta]).to(device)
    phis = torch.FloatTensor([phi]).to(device)

    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        radius * torch.cos(thetas),
        radius * torch.sin(thetas) * torch.cos(phis),
    ], dim=-1) # [B, 3]

    # lookat
    forward_vector = - safe_normalize(centers)
    up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    return poses, dirs


class NeRFDataset:
    def __init__(self, opt, device, type='train', H=256, W=256, size=100):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type # train, val, test

        self.H = H
        self.W = W
        self.radius_range = opt.radius_range
        self.fovy_range = opt.fovy_range
        self.size = size

        self.training = self.type in ['train', 'all']

        self.cx = self.H / 2
        self.cy = self.W / 2

        # [debug] visualize poses
        # poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
        # visualize_poses(poses.detach().cpu().numpy())


    def collate(self, index):

        B = len(index) # always 1

        if self.training:
            # random pose on the fly
            poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)

            # random focal
            fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
            focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
            intrinsics = np.array([focal, focal, self.cx, self.cy])
        else:
            # circle pose
            phi = (index[0] / self.size) * 360
            poses, dirs = circle_poses(self.device, radius=self.radius_range[1] * 1.2, theta=60, phi=phi, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = (self.fovy_range[1] + self.fovy_range[0]) / 2
            focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
            intrinsics = np.array([focal, focal, self.cx, self.cy])


        # sample a low-resolution but full image for CLIP
        rays = get_rays(poses, intrinsics, self.H, self.W, -1)

        data = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
        }

        return data


    def dataloader(self):
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self # an ugly fix... we need to access dataset in trainer.
        return loader
nerf/renderer.py ADDED @@ -0,0 +1,645 @@
import os
import math
import cv2
import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import mcubes
import raymarching
from .utils import custom_meshgrid, safe_normalize

def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
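# A quick sanity check for sample_pdf (a sketch, not from the original file):
# with flat weights the inverse-CDF sampling returns roughly evenly spaced
# z values, while putting all the mass in one bin pulls every new sample
# into that bin's z range.
#
#   bins = torch.linspace(0.0, 1.0, 9).unsqueeze(0)   # [1, T] old z_vals
#   flat = torch.ones(1, 8)                           # [1, T-1] uniform weights
#   peaked = torch.zeros(1, 8); peaked[0, 4] = 1.0    # all mass in bin 4
#   print(sample_pdf(bins, flat, 8, det=True))        # ~uniform over [0, 1]
#   print(sample_pdf(bins, peaked, 8, det=True))      # all inside [0.5, 0.625]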
|
50 |
+
|
51 |
+
|
52 |
+
def plot_pointcloud(pc, color=None):
|
53 |
+
# pc: [N, 3]
|
54 |
+
# color: [N, 3/4]
|
55 |
+
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
|
56 |
+
pc = trimesh.PointCloud(pc, color)
|
57 |
+
# axis
|
58 |
+
axes = trimesh.creation.axis(axis_length=4)
|
59 |
+
# sphere
|
60 |
+
sphere = trimesh.creation.icosphere(radius=1)
|
61 |
+
trimesh.Scene([pc, axes, sphere]).show()
|
62 |
+
|
63 |
+
|
64 |
+
class NeRFRenderer(nn.Module):
|
65 |
+
def __init__(self, opt):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.opt = opt
|
69 |
+
self.bound = opt.bound
|
70 |
+
self.cascade = 1 + math.ceil(math.log2(opt.bound))
|
71 |
+
self.grid_size = 128
|
72 |
+
self.cuda_ray = opt.cuda_ray
|
73 |
+
self.min_near = opt.min_near
|
74 |
+
self.density_thresh = opt.density_thresh
|
75 |
+
self.bg_radius = opt.bg_radius
|
76 |
+
|
77 |
+
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
|
78 |
+
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
|
79 |
+
aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
|
80 |
+
aabb_infer = aabb_train.clone()
|
81 |
+
self.register_buffer('aabb_train', aabb_train)
|
82 |
+
self.register_buffer('aabb_infer', aabb_infer)
|
83 |
+
|
84 |
+
# extra state for cuda raymarching
|
85 |
+
if self.cuda_ray:
|
86 |
+
# density grid
|
87 |
+
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
|
88 |
+
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
|
89 |
+
self.register_buffer('density_grid', density_grid)
|
90 |
+
self.register_buffer('density_bitfield', density_bitfield)
|
91 |
+
self.mean_density = 0
|
92 |
+
self.iter_density = 0
|
93 |
+
# step counter
|
94 |
+
step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
|
95 |
+
self.register_buffer('step_counter', step_counter)
|
96 |
+
self.mean_count = 0
|
97 |
+
self.local_step = 0
|
98 |
+
|
99 |
+
|
100 |
+
def forward(self, x, d):
|
101 |
+
raise NotImplementedError()
|
102 |
+
|
103 |
+
def density(self, x):
|
104 |
+
raise NotImplementedError()
|
105 |
+
|
106 |
+
def color(self, x, d, mask=None, **kwargs):
|
107 |
+
raise NotImplementedError()
|
108 |
+
|
109 |
+
def reset_extra_state(self):
|
110 |
+
if not self.cuda_ray:
|
111 |
+
return
|
112 |
+
# density grid
|
113 |
+
self.density_grid.zero_()
|
114 |
+
self.mean_density = 0
|
115 |
+
self.iter_density = 0
|
116 |
+
# step counter
|
117 |
+
self.step_counter.zero_()
|
118 |
+
self.mean_count = 0
|
119 |
+
self.local_step = 0
|
120 |
+
|
121 |
+
@torch.no_grad()
|
122 |
+
def export_mesh(self, path, resolution=None, S=128):
|
123 |
+
|
124 |
+
if resolution is None:
|
125 |
+
resolution = self.grid_size
|
126 |
+
|
127 |
+
density_thresh = min(self.mean_density, self.density_thresh)
|
128 |
+
|
129 |
+
sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
130 |
+
|
131 |
+
# query
|
132 |
+
X = torch.linspace(-1, 1, resolution).split(S)
|
133 |
+
Y = torch.linspace(-1, 1, resolution).split(S)
|
134 |
+
Z = torch.linspace(-1, 1, resolution).split(S)
|
135 |
+
|
136 |
+
for xi, xs in enumerate(X):
|
137 |
+
for yi, ys in enumerate(Y):
|
138 |
+
for zi, zs in enumerate(Z):
|
139 |
+
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
140 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
|
141 |
+
val = self.density(pts.to(self.density_bitfield.device))
|
142 |
+
sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
|
143 |
+
|
144 |
+
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
145 |
+
|
146 |
+
vertices = vertices / (resolution - 1.0) * 2 - 1
|
147 |
+
vertices = vertices.astype(np.float32)
|
148 |
+
triangles = triangles.astype(np.int32)
|
149 |
+
|
150 |
+
v = torch.from_numpy(vertices).to(self.density_bitfield.device)
|
151 |
+
f = torch.from_numpy(triangles).int().to(self.density_bitfield.device)
|
152 |
+
|
153 |
+
# mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
|
154 |
+
# mesh.export(os.path.join(path, f'mesh.ply'))
|
155 |
+
|
156 |
+
# texture?
|
157 |
+
def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
|
158 |
+
# v, f: torch Tensor
|
159 |
+
device = v.device
|
160 |
+
v_np = v.cpu().numpy() # [N, 3]
|
161 |
+
f_np = f.cpu().numpy() # [M, 3]
|
162 |
+
|
163 |
+
print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
|
164 |
+
|
165 |
+
# unwrap uvs
|
166 |
+
import xatlas
|
167 |
+
import nvdiffrast.torch as dr
|
168 |
+
from sklearn.neighbors import NearestNeighbors
|
169 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
170 |
+
|
171 |
+
glctx = dr.RasterizeCudaContext()
|
172 |
+
|
173 |
+
atlas = xatlas.Atlas()
|
174 |
+
atlas.add_mesh(v_np, f_np)
|
175 |
+
chart_options = xatlas.ChartOptions()
|
176 |
+
chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
|
177 |
+
atlas.generate(chart_options=chart_options)
|
178 |
+
vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
|
179 |
+
|
180 |
+
# vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
|
181 |
+
|
182 |
+
vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
|
183 |
+
ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
|
184 |
+
|
185 |
+
# render uv maps
|
186 |
+
uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
|
187 |
+
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
|
188 |
+
|
189 |
+
if ssaa > 1:
|
190 |
+
h = int(h0 * ssaa)
|
191 |
+
w = int(w0 * ssaa)
|
192 |
+
else:
|
193 |
+
h, w = h0, w0
|
194 |
+
|
195 |
+
rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
|
196 |
+
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
|
197 |
+
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
|
198 |
+
|
199 |
+
# masked query
|
200 |
+
xyzs = xyzs.view(-1, 3)
|
201 |
+
mask = (mask > 0).view(-1)
|
202 |
+
|
203 |
+
sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
|
204 |
+
feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
|
205 |
+
|
206 |
+
if mask.any():
|
207 |
+
xyzs = xyzs[mask] # [M, 3]
|
208 |
+
|
209 |
+
# batched inference to avoid OOM
|
210 |
+
all_sigmas = []
|
211 |
+
all_feats = []
|
212 |
+
head = 0
|
213 |
+
while head < xyzs.shape[0]:
|
214 |
+
tail = min(head + 640000, xyzs.shape[0])
|
215 |
+
results_ = self.density(xyzs[head:tail])
|
216 |
+
all_sigmas.append(results_['sigma'].float())
|
217 |
+
all_feats.append(results_['albedo'].float())
|
218 |
+
head += 640000
|
219 |
+
|
220 |
+
sigmas[mask] = torch.cat(all_sigmas, dim=0)
|
221 |
+
feats[mask] = torch.cat(all_feats, dim=0)
|
222 |
+
|
223 |
+
sigmas = sigmas.view(h, w, 1)
|
224 |
+
feats = feats.view(h, w, -1)
|
225 |
+
mask = mask.view(h, w)
|
226 |
+
|
227 |
+
### alpha mask
|
228 |
+
# deltas = 2 * np.sqrt(3) / 1024
|
229 |
+
# alphas = 1 - torch.exp(-sigmas * deltas)
|
230 |
+
# alphas_mask = alphas > 0.5
|
231 |
+
# feats = feats * alphas_mask
|
232 |
+
|
233 |
+
# quantize [0.0, 1.0] to [0, 255]
|
234 |
+
feats = feats.cpu().numpy()
|
235 |
+
feats = (feats * 255).astype(np.uint8)
|
236 |
+
|
237 |
+
# alphas = alphas.cpu().numpy()
|
238 |
+
# alphas = (alphas * 255).astype(np.uint8)
|
239 |
+
|
240 |
+
### NN search as an antialiasing ...
|
241 |
+
mask = mask.cpu().numpy()
|
242 |
+
|
243 |
+
inpaint_region = binary_dilation(mask, iterations=3)
|
244 |
+
inpaint_region[mask] = 0
|
245 |
+
|
246 |
+
search_region = mask.copy()
|
247 |
+
not_search_region = binary_erosion(search_region, iterations=2)
|
248 |
+
search_region[not_search_region] = 0
|
249 |
+
|
250 |
+
search_coords = np.stack(np.nonzero(search_region), axis=-1)
|
251 |
+
inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
|
252 |
+
|
253 |
+
knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
|
254 |
+
_, indices = knn.kneighbors(inpaint_coords)
|
255 |
+
|
256 |
+
feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
|
257 |
+
|
258 |
+
# do ssaa after the NN search, in numpy
|
259 |
+
feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
|
260 |
+
|
261 |
+
if ssaa > 1:
|
262 |
+
# alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST)
|
263 |
+
feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
|
264 |
+
|
265 |
+
# cv2.imwrite(os.path.join(path, f'alpha.png'), alphas)
|
266 |
+
cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
|
267 |
+
|
268 |
+
# save obj (v, vt, f /)
|
269 |
+
obj_file = os.path.join(path, f'{name}mesh.obj')
|
270 |
+
mtl_file = os.path.join(path, f'{name}mesh.mtl')
|
271 |
+
|
272 |
+
print(f'[INFO] writing obj mesh to {obj_file}')
|
273 |
+
with open(obj_file, "w") as fp:
|
274 |
+
fp.write(f'mtllib {name}mesh.mtl \n')
|
275 |
+
|
276 |
+
print(f'[INFO] writing vertices {v_np.shape}')
|
277 |
+
for v in v_np:
|
278 |
+
fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
|
279 |
+
|
280 |
+
print(f'[INFO] writing vertices texture coords {vt_np.shape}')
|
281 |
+
for v in vt_np:
|
282 |
+
fp.write(f'vt {v[0]} {1 - v[1]} \n')
|
283 |
+
|
284 |
+
print(f'[INFO] writing faces {f_np.shape}')
|
285 |
+
fp.write(f'usemtl mat0 \n')
|
286 |
+
for i in range(len(f_np)):
|
287 |
+
fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
|
288 |
+
|
289 |
+
with open(mtl_file, "w") as fp:
|
290 |
+
fp.write(f'newmtl mat0 \n')
|
291 |
+
fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
|
292 |
+
fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
|
293 |
+
fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
|
294 |
+
fp.write(f'Tr 1.000000 \n')
|
295 |
+
fp.write(f'illum 1 \n')
|
296 |
+
fp.write(f'Ns 0.000000 \n')
|
297 |
+
fp.write(f'map_Kd {name}albedo.png \n')
|
298 |
+
|
299 |
+
_export(v, f)
|
300 |
+
|
301 |
+
def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
|
302 |
+
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
303 |
+
# bg_color: [BN, 3] in range [0, 1]
|
304 |
+
# return: image: [B, N, 3], depth: [B, N]
|
305 |
+
|
306 |
+
prefix = rays_o.shape[:-1]
|
307 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
308 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
309 |
+
|
310 |
+
N = rays_o.shape[0] # N = B * N, in fact
|
311 |
+
device = rays_o.device
|
312 |
+
|
313 |
+
results = {}
|
314 |
+
|
315 |
+
# choose aabb
|
316 |
+
aabb = self.aabb_train if self.training else self.aabb_infer
|
317 |
+
|
318 |
+
# sample steps
|
319 |
+
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
|
320 |
+
nears.unsqueeze_(-1)
|
321 |
+
fars.unsqueeze_(-1)
|
322 |
+
|
323 |
+
# random sample light_d if not provided
|
324 |
+
if light_d is None:
|
325 |
+
# gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
|
326 |
+
light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
|
327 |
+
light_d = safe_normalize(light_d)
|
328 |
+
|
329 |
+
#print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
|
330 |
+
|
331 |
+
z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
|
332 |
+
z_vals = z_vals.expand((N, num_steps)) # [N, T]
|
333 |
+
z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
|
334 |
+
|
335 |
+
# perturb z_vals
|
336 |
+
sample_dist = (fars - nears) / num_steps
|
337 |
+
if perturb:
|
338 |
+
z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
|
339 |
+
#z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
|
340 |
+
|
341 |
+
# generate xyzs
|
342 |
+
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.

        #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

        # query SDF and RGB
        density_outputs = self.density(xyzs.reshape(-1, 3))

        #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():

                deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
                deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

                alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
                alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
                weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
        rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]

        #print(xyzs.shape, 'valid_rgb:', mask.sum().item())

        # orientation loss
        if normals is not None:
            normals = normals.view(N, -1, 3)
            # print(weights.shape, normals.shape, dirs.shape)
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1) # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
            bg_color = self.background(rays_d.reshape(-1, 3)) # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results
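    # For reference, the quadrature that run() above and the CUDA path below both
    # implement (standard NeRF volume rendering, restated from the code):
    #
    #   alpha_i = 1 - exp(-sigma_i * delta_i)
    #   w_i     = alpha_i * prod_{j < i} (1 - alpha_j)
    #   C       = sum_i w_i * c_i + (1 - sum_i w_i) * c_bg
    #   D       = sum_i w_i * z_i
    #
    # The 1e-15 added to (1 - alpha) before the cumprod keeps the accumulated
    # transmittance strictly positive once a sample saturates at alpha = 1.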
    def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0] # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)

        # random sample light_d if not provided
        if light_d is None:
            # gaussian noise around the ray origin, so the light always faces the view dir (avoids dark faces)
            light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
            light_d = safe_normalize(light_d)

        results = {}

        if self.training:
            # setup counter
            counter = self.step_counter[self.local_step % 16]
            counter.zero_() # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

            #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

            # orientation loss
            if normals is not None:
                weights = 1 - torch.exp(-sigmas)
                loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
                results['loss_orient'] = loss_orient.mean()

        else:

            # allocate outputs
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
            rays_t = nears.clone() # [N]

            step = 0

            while step < max_steps: # hard-coded max step

                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)

                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                rays_alive = rays_alive[rays_alive >= 0]
                #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
            bg_color = self.background(rays_d) # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        image = image.view(*prefix, 3)

        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        depth = depth.view(*prefix)

        weights_sum = weights_sum.reshape(*prefix)

        mask = (nears < fars).reshape(*prefix)

        results['image'] = image
        results['depth'] = depth
        results['weights_sum'] = weights_sum
        results['mask'] = mask

        return results
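    # Note on the inference loop above: n_step = max(min(N // n_alive, 8), 1)
    # trades kernel launches against wasted work. For example, with N = 4096 rays,
    # marching advances 1 step per launch while all 4096 rays are alive, but once
    # only 512 remain it advances the capped maximum of 8 steps per launch, so
    # nearly-finished images need far fewer launches.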
    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        # call before each epoch to update extra states.

        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long() # [N]
                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_xyzs = xyzs * (bound - half_grid_size)
                        # add noise in [-hgs, hgs]
                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                        # query density
                        sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                        # assign
                        tmp_grid[cas, indices] = sigmas

        # ema update
        valid_mask = self.density_grid >= 0
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
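    # The occupancy grid above is maintained as an exponential moving maximum:
    # on valid cells, grid <- max(decay * grid, sigma_new), and a cell is marked
    # occupied in the bitfield iff grid > min(mean_density, density_thresh).
    # Taking the min lets early training, when all densities are still small,
    # keep enough cells alive for the ray marcher to sample.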
    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: pred_rgb: [B, N, 3]

        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)
            weights_sum = torch.empty((B, N), device=device)

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
                    depth[b:b+1, head:tail] = results_['depth']
                    weights_sum[b:b+1, head:tail] = results_['weights_sum']
                    image[b:b+1, head:tail] = results_['image']
                    head += max_ray_batch

            results = {}
            results['depth'] = depth
            results['image'] = image
            results['weights_sum'] = weights_sum

        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results
nerf/sd.py
ADDED
@@ -0,0 +1,203 @@
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

class StableDiffusion(nn.Module):
    def __init__(self, device):
        super().__init__()

        try:
            self.token = os.environ['TOKEN']
            print(f'[INFO] loaded hugging face access token from environment variable TOKEN')
        except KeyError:  # os.environ raises KeyError (not FileNotFoundError) when TOKEN is unset
            self.token = True
            print(f'[INFO] trying to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')

        self.device = device
        self.num_train_timesteps = 1000
        self.min_step = int(self.num_train_timesteps * 0.02)
        self.max_step = int(self.num_train_timesteps * 0.98)

        print(f'[INFO] loading stable diffusion...')

        # 1. Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=self.token).to(self.device)

        # 2. Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)

        # 3. The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=self.token).to(self.device)

        # 4. Create a scheduler for inference
        self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=self.num_train_timesteps)
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

        print(f'[INFO] loaded stable diffusion!')
    def get_text_embeds(self, prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings
        uncond_input = self.tokenizer([''] * len(prompt), padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')

        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings
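    # Classifier-free guidance note: stacking [uncond, cond] embeddings lets one
    # UNet batch serve both predictions, later recombined (see the chunk(2) calls
    # in train_step and produce_latents) as
    #
    #   eps_hat = eps_uncond + s * (eps_cond - eps_uncond).
    #
    # train_step defaults to the very large s = 100 used for score distillation,
    # while plain image sampling in produce_latents uses the usual s = 7.5.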
    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100):

        # interp to 512x512 to be fed into vae.

        # _t = time.time()
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)

        # encode image into latents with vae, requires grad!
        # _t = time.time()
        latents = self.encode_imgs(pred_rgb_512)
        # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')

        # predict the noise residual with unet, NO grad!
        # _t = time.time()
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')

        # perform guidance (high scale from paper!)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
        grad = w * (noise_pred - noise)

        # clip grad for stable training?
        # grad = grad.clamp(-1, 1)

        # manually backward, since we omitted an item in grad and cannot simply autodiff.
        # _t = time.time()
        latents.backward(gradient=grad, retain_graph=True)
        # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')

        return 0 # dummy loss value
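    # What train_step computes is the score distillation sampling (SDS) gradient
    # from DreamFusion, with w(t) = 1 - alpha_bar_t:
    #
    #   grad_theta L_SDS = E_{t, eps} [ w(t) * (eps_hat(z_t; y, t) - eps) * dz/dtheta ]
    #
    # The U-Net jacobian is dropped (the "omitted item" mentioned above), so the
    # gradient is injected directly via latents.backward(gradient=grad) instead of
    # autodiffing a scalar loss; the returned 0 is only a logging placeholder.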
    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if latents is None:
            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents

    def decode_latents(self, latents):

        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215

        return latents

    def prompt_to_img(self, prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if isinstance(prompts, str):
            prompts = [prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts) # [2, 77, 768]

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents) # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs


if __name__ == '__main__':

    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--steps', type=int, default=50)
    opt = parser.parse_args()

    device = torch.device('cuda')

    sd = StableDiffusion(device)

    imgs = sd.prompt_to_img(opt.prompt, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()
nerf/utils.py
ADDED
@@ -0,0 +1,950 @@
import os
import glob
import tqdm
import math
import imageio
import random
import warnings
import tensorboardX

import numpy as np
import pandas as pd

import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader

import trimesh
from rich.console import Console
from torch_ema import ExponentialMovingAverage

from packaging import version as pver

def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')

def safe_normalize(x, eps=1e-20):
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    ''' get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: [4]
        H, W, N: int
        error_map: [B, 128 * 128], sample probability based on training error
    Returns:
        rays_o, rays_d: [B, N, 3]
        inds: [B, N]
    '''

    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
    i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
    j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H*W)

        if error_map is None:
            inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
            inds = inds.expand([B, N])
        else:

            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)

            # map to the original resolution with random perturb.
            inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
            sx, sy = H / 128, W / 128
            inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
            inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
            inds = inds_x * W + inds_y

            results['inds_coarse'] = inds_coarse # need this when updating error_map

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results['inds'] = inds

    else:
        inds = torch.arange(H*W, device=device).expand([B, H*W])

    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = safe_normalize(directions)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)

    rays_o = poses[..., :3, 3] # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    return results
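# A minimal usage sketch of get_rays (illustrative sketch; the pose and
# intrinsics below are made-up values, not from this repo's dataloader):
def _demo_get_rays():
    pose = torch.eye(4).unsqueeze(0)                   # [1, 4, 4] identity cam2world
    intrinsics = np.array([100.0, 100.0, 64.0, 64.0])  # hypothetical fx, fy, cx, cy
    rays = get_rays(pose, intrinsics, H=128, W=128, N=4096)
    # rays_o / rays_d are both [1, 4096, 3]: directions are the rotated, normalized
    # pinhole vectors ((i - cx) / fx, (j - cy) / fy, 1); origins all equal the camera center.
    return rays['rays_o'], rays['rays_d']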
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True


def torch_vis_2d(x, renormalize=False):
    # x: [3, H, W] or [1, H, W] or [H, W]
    # (plt, np and torch are already imported at module level)

    if isinstance(x, torch.Tensor):
        if len(x.shape) == 3:
            x = x.permute(1,2,0).squeeze()
        x = x.detach().cpu().numpy()

    print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')

    x = x.astype(np.float32)

    # renormalize
    if renormalize:
        x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)

    plt.imshow(x)
    plt.show()

@torch.jit.script
def linear_to_srgb(x):
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)


@torch.jit.script
def srgb_to_linear(x):
    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
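# For reference, these are the standard sRGB transfer functions: the encoding is
# s = 12.92 * l for l <= 0.0031308, and s = 1.055 * l^(1/2.4) - 0.055 otherwise
# (the 0.41666 exponent above is 1/2.4); srgb_to_linear is its exact inverse.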
class Trainer(object):
    def __init__(self,
                 name, # name of this experiment
                 opt, # extra conf
                 model, # network
                 guidance, # guidance network
                 criterion=None, # loss function, if None, assume inline implementation in train_step
                 optimizer=None, # optimizer
                 ema_decay=None, # if use EMA, set the decay
                 lr_scheduler=None, # scheduler
                 metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
                 local_rank=0, # which GPU am I
                 world_size=1, # total num of GPUs
                 device=None, # device to use, usually setting to None is OK. (auto choose device)
                 mute=False, # whether to mute all print
                 fp16=False, # amp optimize level
                 eval_interval=1, # eval once every $ epoch
                 max_keep_ckpt=2, # max num of saved ckpts in disk
                 workspace='workspace', # workspace to save logs & ckpts
                 best_mode='min', # the smaller/larger result, the better
                 use_loss_as_metric=True, # use loss as the first metric
                 report_metric_at_train=False, # also report metrics at training
                 use_checkpoint="latest", # which ckpt to use at init time
                 use_tensorboardX=True, # whether to use tensorboard for logging
                 scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
                 ):

        self.name = name
        self.opt = opt
        self.mute = mute
        self.metrics = metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.eval_interval = eval_interval
        self.use_checkpoint = use_checkpoint
        self.use_tensorboardX = use_tensorboardX
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.console = Console()

        model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        self.model = model

        # guide model
        self.guidance = guidance

        # text prompt
        if self.guidance is not None:

            for p in self.guidance.parameters():
                p.requires_grad = False

            self.prepare_text_embeddings()

        else:
            self.text_z = None

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        if optimizer is None:
            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
        else:
            self.optimizer = optimizer(self.model)

        if lr_scheduler is None:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
        else:
            self.lr_scheduler = lr_scheduler(self.optimizer)

        if ema_decay is not None:
            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
        else:
            self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.epoch = 0
        self.global_step = 0
        self.local_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [], # metrics[0], or valid_loss
            "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }

        # auto fix
        if len(metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = 'min'

        # workspace prepare
        self.log_ptr = None
        if self.workspace is not None:
            os.makedirs(self.workspace, exist_ok=True)
            self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
            self.log_ptr = open(self.log_path, "a+")

            self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
            self.best_path = f"{self.ckpt_path}/{self.name}.pth"
            os.makedirs(self.ckpt_path, exist_ok=True)

        self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
        self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')

        if self.workspace is not None:
            if self.use_checkpoint == "scratch":
                self.log("[INFO] Training from scratch ...")
            elif self.use_checkpoint == "latest":
                self.log("[INFO] Loading latest checkpoint ...")
                self.load_checkpoint()
            elif self.use_checkpoint == "latest_model":
                self.log("[INFO] Loading latest checkpoint (model only)...")
                self.load_checkpoint(model_only=True)
            elif self.use_checkpoint == "best":
                if os.path.exists(self.best_path):
                    self.log("[INFO] Loading best checkpoint ...")
                    self.load_checkpoint(self.best_path)
                else:
                    self.log(f"[INFO] {self.best_path} not found, loading latest ...")
                    self.load_checkpoint()
            else: # path to ckpt
                self.log(f"[INFO] Loading {self.use_checkpoint} ...")
                self.load_checkpoint(self.use_checkpoint)
    # calculate the text embs.
    def prepare_text_embeddings(self):

        if self.opt.text is None:
            self.log(f"[WARN] text prompt is not provided.")
            self.text_z = None
            return

        if not self.opt.dir_text:
            self.text_z = self.guidance.get_text_embeds([self.opt.text])
        else:
            self.text_z = []
            for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
                text = f"{self.opt.text}, {d} view"
                text_z = self.guidance.get_text_embeds([text])
                self.text_z.append(text_z)

    def __del__(self):
        if self.log_ptr:
            self.log_ptr.close()


    def log(self, *args, **kwargs):
        if self.local_rank == 0:
            if not self.mute:
                #print(*args)
                self.console.print(*args, **kwargs)
            if self.log_ptr:
                print(*args, file=self.log_ptr)
                self.log_ptr.flush() # write immediately to file

    ### ------------------------------
    def train_step(self, data):

        rays_o = data['rays_o'] # [B, N, 3]
        rays_d = data['rays_d'] # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        # TODO: shading is not working right now...
        if self.global_step < self.opt.albedo_iters:
            shading = 'albedo'
            ambient_ratio = 1.0
        else:
            rand = random.random()
            if rand > 0.8:
                shading = 'albedo'
                ambient_ratio = 1.0
            # elif rand > 0.4:
            #     shading = 'textureless'
            #     ambient_ratio = 0.1
            else:
                shading = 'lambertian'
                ambient_ratio = 0.1

        # _t = time.time()
        bg_color = torch.rand((B * N, 3), device=rays_o.device) # pixel-wise random
        outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
        # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')

        # print(shading)
        # torch_vis_2d(pred_rgb[0])

        # text embeddings
        if self.opt.dir_text:
            dirs = data['dir'] # [B,]
            text_z = self.text_z[dirs]
        else:
            text_z = self.text_z

        # encode pred_rgb to latents
        # _t = time.time()
        loss = self.guidance.train_step(text_z, pred_rgb)
        # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')

        # occupancy loss
        pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)

        if self.opt.lambda_opacity > 0:
            loss_opacity = (pred_ws ** 2).mean()
            loss = loss + self.opt.lambda_opacity * loss_opacity

        if self.opt.lambda_entropy > 0:
            alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
            # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

            loss = loss + self.opt.lambda_entropy * loss_entropy

        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
            loss_orient = outputs['loss_orient']
            loss = loss + self.opt.lambda_orient * loss_orient

        return pred_rgb, pred_ws, loss
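    # The two regularizers above, restated: the opacity loss is mean(w^2) over the
    # per-pixel accumulated weight w, pulling empty space toward zero opacity, and
    # the entropy loss is the binary entropy
    #
    #   H(a) = -a * log2(a) - (1 - a) * log2(1 - a),  a = clamp(w, 1e-5, 1 - 1e-5),
    #
    # which is minimized at a in {0, 1}, pushing each pixel to be either fully
    # transparent (background) or fully opaque (object).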
    def eval_step(self, data):

        rays_o = data['rays_o'] # [B, N, 3]
        rays_d = data['rays_d'] # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        shading = data['shading'] if 'shading' in data else 'albedo'
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W)
        pred_ws = outputs['weights_sum'].reshape(B, H, W)
        # mask_ws = outputs['mask'].reshape(B, H, W) # near < far

        # loss_ws = pred_ws.sum() / mask_ws.sum()
        # loss_ws = pred_ws.mean()

        alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
        # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
        loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

        loss = self.opt.lambda_entropy * loss_entropy

        return pred_rgb, pred_depth, loss

    def test_step(self, data, bg_color=None, perturb=False):
        rays_o = data['rays_o'] # [B, N, 3]
        rays_d = data['rays_d'] # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        if bg_color is not None:
            bg_color = bg_color.to(rays_o.device)
        else:
            bg_color = torch.ones(3, device=rays_o.device) # [3]

        shading = data['shading'] if 'shading' in data else 'albedo'
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))

        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W)

        return pred_rgb, pred_depth


    def save_mesh(self, save_path=None, resolution=128):

        if save_path is None:
            save_path = os.path.join(self.workspace, 'mesh')

        self.log(f"==> Saving mesh to {save_path}")

        os.makedirs(save_path, exist_ok=True)

        self.model.export_mesh(save_path, resolution=resolution)

        self.log(f"==> Finished saving mesh.")

    ### ------------------------------
    def train(self, train_loader, valid_loader, max_epochs):

        assert self.text_z is not None, 'Training must provide a text prompt!'

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))

        start_t = time.time()

        for epoch in range(self.epoch + 1, max_epochs + 1):
            self.epoch = epoch

            self.train_one_epoch(train_loader)

            if self.workspace is not None and self.local_rank == 0:
                self.save_checkpoint(full=True, best=False)

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(valid_loader)
                self.save_checkpoint(full=False, best=True)

        end_t = time.time()

        self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def evaluate(self, loader, name=None):
        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
        self.evaluate_one_epoch(loader, name)
        self.use_tensorboardX = use_tensorboardX

    def test(self, loader, save_path=None, name=None, write_video=True):

        if save_path is None:
            save_path = os.path.join(self.workspace, 'results')

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        os.makedirs(save_path, exist_ok=True)

        self.log(f"==> Start Test, save results to {save_path}")

        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
        self.model.eval()

        if write_video:
            all_preds = []
            all_preds_depth = []

        with torch.no_grad():

            for i, data in enumerate(loader):

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth = self.test_step(data)

                pred = preds[0].detach().cpu().numpy()
                pred = (pred * 255).astype(np.uint8)

                pred_depth = preds_depth[0].detach().cpu().numpy()
                pred_depth = (pred_depth * 255).astype(np.uint8)

                if write_video:
                    all_preds.append(pred)
                    all_preds_depth.append(pred_depth)
                else:
                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)

                pbar.update(loader.batch_size)

        if write_video:
            all_preds = np.stack(all_preds, axis=0)
            all_preds_depth = np.stack(all_preds_depth, axis=0)

            imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
            imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)

        self.log(f"==> Finished Test.")
    # [GUI] train text step.
    def train_gui(self, train_loader, step=16):

        self.model.train()

        total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)

        loader = iter(train_loader)

        for _ in range(step):

            # mimic an infinite-loop dataloader (in case the total dataset is smaller than step)
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(train_loader)
                data = next(loader)

            # update grid every 16 steps
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                pred_rgbs, pred_ws, loss = self.train_step(data)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            total_loss += loss.detach()

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss.item() / step

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        outputs = {
            'loss': average_loss,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

        return outputs


    # [GUI] test on a single image
    def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):

        # render resolution (may need downscaling for a better frame rate)
        rH = int(H * downscale)
        rW = int(W * downscale)
        intrinsics = intrinsics * downscale

        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)

        rays = get_rays(pose, intrinsics, rH, rW, -1)

        # from degree theta/phi to 3D normalized vec
        light_d = np.deg2rad(light_d)
        light_d = np.array([
            np.sin(light_d[0]) * np.sin(light_d[1]),
            np.cos(light_d[0]),
            np.sin(light_d[0]) * np.cos(light_d[1]),
        ], dtype=np.float32)
        light_d = torch.from_numpy(light_d).to(self.device)

        data = {
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'H': rH,
            'W': rW,
            'light_d': light_d,
            'ambient_ratio': ambient_ratio,
            'shading': shading,
        }

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=self.fp16):
                # here spp is used as perturb random seed!
                preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)

        if self.ema is not None:
            self.ema.restore()

        # interpolate back to the original resolution
        if downscale != 1:
            # have to permute twice with torch...
            preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
            preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

        outputs = {
            'image': preds[0].detach().cpu().numpy(),
            'depth': preds_depth[0].detach().cpu().numpy(),
        }

        return outputs
    def train_one_epoch(self, loader):
        self.log(f"==> Start Training {self.workspace} Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")

        total_loss = 0
        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model.train()

        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
        # ref: https://pytorch.org/docs/stable/data.html
        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        self.local_step = 0

        for data in loader:

            # update grid every 16 steps
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            self.local_step += 1
            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                pred_rgbs, pred_ws, loss = self.train_step(data)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            loss_val = loss.item()
            total_loss += loss_val

            if self.local_rank == 0:
                # if self.report_metric_at_train:
                #     for metric in self.metrics:
                #         metric.update(preds, truths)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", loss_val, self.global_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)

                if self.scheduler_update_every_step:
                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
                else:
                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                pbar.update(loader.batch_size)

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        self.log(f"==> Finished Epoch {self.epoch}.")
    def evaluate_one_epoch(self, loader, name=None):
        self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        with torch.no_grad():
            self.local_step = 0

            for data in loader:
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth, loss = self.eval_step(data)

                # all_gather/reduce the statistics (NCCL only support all_*)
                if self.world_size > 1:
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size

                    preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_list, preds)
                    preds = torch.cat(preds_list, dim=0)

                    preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_depth_list, preds_depth)
                    preds_depth = torch.cat(preds_depth_list, dim=0)

                loss_val = loss.item()
                total_loss += loss_val

                # only rank = 0 will perform evaluation.
                if self.local_rank == 0:

                    # save image
                    save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
                    save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')

                    #self.log(f"==> Saving validation image to {save_path}")
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                    pred = preds[0].detach().cpu().numpy()
                    pred = (pred * 255).astype(np.uint8)

                    pred_depth = preds_depth[0].detach().cpu().numpy()
                    pred_depth = (pred_depth * 255).astype(np.uint8)

                    cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(save_path_depth, pred_depth)

                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                    pbar.update(loader.batch_size)


        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
            else:
                self.stats["results"].append(average_loss) # if no metric, choose best by min loss

            for metric in self.metrics:
                self.log(metric.report(), style="blue")
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()

        if self.ema is not None:
            self.ema.restore()

        self.log(f"++> Evaluate epoch {self.epoch} Finished.")
def save_checkpoint(self, name=None, full=False, best=False):
|
830 |
+
|
831 |
+
if name is None:
|
832 |
+
name = f'{self.name}_ep{self.epoch:04d}'
|
833 |
+
|
834 |
+
state = {
|
835 |
+
'epoch': self.epoch,
|
836 |
+
'global_step': self.global_step,
|
837 |
+
'stats': self.stats,
|
838 |
+
}
|
839 |
+
|
840 |
+
if self.model.cuda_ray:
|
841 |
+
state['mean_count'] = self.model.mean_count
|
842 |
+
state['mean_density'] = self.model.mean_density
|
843 |
+
|
844 |
+
if full:
|
845 |
+
state['optimizer'] = self.optimizer.state_dict()
|
846 |
+
state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
847 |
+
state['scaler'] = self.scaler.state_dict()
|
848 |
+
if self.ema is not None:
|
849 |
+
state['ema'] = self.ema.state_dict()
|
850 |
+
|
851 |
+
if not best:
|
852 |
+
|
853 |
+
state['model'] = self.model.state_dict()
|
854 |
+
|
855 |
+
file_path = f"{name}.pth"
|
856 |
+
|
857 |
+
self.stats["checkpoints"].append(file_path)
|
858 |
+
|
859 |
+
if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
|
860 |
+
old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
|
861 |
+
if os.path.exists(old_ckpt):
|
862 |
+
os.remove(old_ckpt)
|
863 |
+
|
864 |
+
torch.save(state, os.path.join(self.ckpt_path, file_path))
|
865 |
+
|
866 |
+
else:
|
867 |
+
if len(self.stats["results"]) > 0:
|
868 |
+
if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
|
869 |
+
self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
|
870 |
+
self.stats["best_result"] = self.stats["results"][-1]
|
871 |
+
|
872 |
+
# save ema results
|
873 |
+
if self.ema is not None:
|
874 |
+
self.ema.store()
|
875 |
+
self.ema.copy_to()
|
876 |
+
|
877 |
+
state['model'] = self.model.state_dict()
|
878 |
+
|
879 |
+
if self.ema is not None:
|
880 |
+
self.ema.restore()
|
881 |
+
|
882 |
+
torch.save(state, self.best_path)
|
883 |
+
else:
|
884 |
+
self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
|
885 |
+
|
886 |
+
def load_checkpoint(self, checkpoint=None, model_only=False):
|
887 |
+
if checkpoint is None:
|
888 |
+
checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
|
889 |
+
if checkpoint_list:
|
890 |
+
checkpoint = checkpoint_list[-1]
|
891 |
+
self.log(f"[INFO] Latest checkpoint is {checkpoint}")
|
892 |
+
else:
|
893 |
+
self.log("[WARN] No checkpoint found, model randomly initialized.")
|
894 |
+
return
|
895 |
+
|
896 |
+
checkpoint_dict = torch.load(checkpoint, map_location=self.device)
|
897 |
+
|
898 |
+
if 'model' not in checkpoint_dict:
|
899 |
+
self.model.load_state_dict(checkpoint_dict)
|
900 |
+
self.log("[INFO] loaded model.")
|
901 |
+
return
|
902 |
+
|
903 |
+
missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
|
904 |
+
self.log("[INFO] loaded model.")
|
905 |
+
if len(missing_keys) > 0:
|
906 |
+
self.log(f"[WARN] missing keys: {missing_keys}")
|
907 |
+
if len(unexpected_keys) > 0:
|
908 |
+
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
|
909 |
+
|
910 |
+
if self.ema is not None and 'ema' in checkpoint_dict:
|
911 |
+
try:
|
912 |
+
self.ema.load_state_dict(checkpoint_dict['ema'])
|
913 |
+
self.log("[INFO] loaded EMA.")
|
914 |
+
except:
|
915 |
+
self.log("[WARN] failed to loaded EMA.")
|
916 |
+
|
917 |
+
if self.model.cuda_ray:
|
918 |
+
if 'mean_count' in checkpoint_dict:
|
919 |
+
self.model.mean_count = checkpoint_dict['mean_count']
|
920 |
+
if 'mean_density' in checkpoint_dict:
|
921 |
+
self.model.mean_density = checkpoint_dict['mean_density']
|
922 |
+
|
923 |
+
if model_only:
|
924 |
+
return
|
925 |
+
|
926 |
+
self.stats = checkpoint_dict['stats']
|
927 |
+
self.epoch = checkpoint_dict['epoch']
|
928 |
+
self.global_step = checkpoint_dict['global_step']
|
929 |
+
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
|
930 |
+
|
931 |
+
if self.optimizer and 'optimizer' in checkpoint_dict:
|
932 |
+
try:
|
933 |
+
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
934 |
+
self.log("[INFO] loaded optimizer.")
|
935 |
+
except:
|
936 |
+
self.log("[WARN] Failed to load optimizer.")
|
937 |
+
|
938 |
+
if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
|
939 |
+
try:
|
940 |
+
self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
|
941 |
+
self.log("[INFO] loaded scheduler.")
|
942 |
+
except:
|
943 |
+
self.log("[WARN] Failed to load scheduler.")
|
944 |
+
|
945 |
+
if self.scaler and 'scaler' in checkpoint_dict:
|
946 |
+
try:
|
947 |
+
self.scaler.load_state_dict(checkpoint_dict['scaler'])
|
948 |
+
self.log("[INFO] loaded scaler.")
|
949 |
+
except:
|
950 |
+
self.log("[WARN] Failed to load scaler.")
|
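A small hedged sketch (not part of the diff; the checkpoint path below is hypothetical — real files are written as f'{name}.pth' under self.ckpt_path) showing what a checkpoint saved with full=True contains:

import torch
ckpt = torch.load('trial/checkpoints/df_ep0100.pth', map_location='cpu')
print(ckpt['epoch'], ckpt['global_step'])
print(sorted(ckpt.keys()))  # e.g. ['epoch', 'global_step', 'lr_scheduler', 'model', 'optimizer', 'scaler', 'stats', ...]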
optimizer.py
ADDED
@@ -0,0 +1,470 @@
import numpy as np
import torch
import enum
import itertools
from dataclasses import dataclass
import torch.optim as optim


@torch.no_grad()
def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
  """Power iteration.
  Compute the maximum eigenvalue of mat, for scaling.
  v is a random vector with values in (-1, 1)
  Args:
    mat_g: the symmetric PSD matrix.
    error_tolerance: Iterative exit condition.
    num_iters: Number of iterations.
  Returns:
    eigen value, eigen vector, num_iters
  """
  v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
  error = 1
  iters = 0
  singular_val = 0
  while error > error_tolerance and iters < num_iters:
    v = v / torch.norm(v)
    mat_v = torch.mv(mat_g, v)
    s_v = torch.dot(v, mat_v)
    error = torch.abs(s_v - singular_val)
    v = mat_v
    singular_val = s_v
    iters += 1
  return singular_val, v / torch.norm(v), iters
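A minimal sketch (not from the original files; assumes a CUDA tensor, since PowerIter allocates v on mat_g.get_device()) checking the estimate against an exact eigensolve:

import torch
A = torch.randn(64, 32, device='cuda')
mat_g = A @ A.T  # symmetric PSD
max_ev, v, iters = PowerIter(mat_g)
ref = torch.linalg.eigvalsh(mat_g)[-1]  # exact largest eigenvalue
print(float(max_ev), float(ref), iters)  # the two values should be close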
@torch.no_grad()
def MatPower(mat_m, p):
  """Computes mat_m^p, for p a positive integer.
  Args:
    mat_m: a square matrix
    p: a positive integer
  Returns:
    mat_m^p
  """
  if p in [1, 2, 4, 8, 16, 32]:
    p_done = 1
    res = mat_m
    while p_done < p:
      res = torch.matmul(res, res)
      p_done *= 2
    return res

  power = None
  while p > 0:
    if p % 2 == 1:
      power = torch.matmul(mat_m, power) if power is not None else mat_m
    p //= 2
    mat_m = torch.matmul(mat_m, mat_m)
  return power
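A quick sanity sketch (not part of the diff): the general branch is binary exponentiation, so for a positive integer exponent MatPower should agree with torch.linalg.matrix_power:

import torch
m = torch.randn(8, 8) / 3  # scale down to keep high powers well-conditioned
assert torch.allclose(MatPower(m, 5), torch.linalg.matrix_power(m, 5), atol=1e-5)
assert torch.allclose(MatPower(m, 8), torch.linalg.matrix_power(m, 8), atol=1e-5)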
@torch.no_grad()
def ComputePower(mat_g, p,
                 iter_count=100,
                 error_tolerance=1e-6,
                 ridge_epsilon=1e-6):
  """A method to compute G^{-1/p} using a coupled Newton iteration.
  See for example equation 3.2 on page 9 of:
  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3 : pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
  Args:
    mat_g: A square positive semidefinite matrix
    p: a positive integer
    iter_count: Stop iterating after this many rounds.
    error_tolerance: Threshold for stopping iteration
    ridge_epsilon: We add this times I to G, to make it positive definite.
      For scaling, we multiply it by the largest eigenvalue of G.
  Returns:
    (mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
  """
  shape = list(mat_g.shape)
  if len(shape) == 1:
    return torch.pow(mat_g + ridge_epsilon, -1/p)
  identity = torch.eye(shape[0], device=mat_g.get_device())
  if shape[0] == 1:
    return identity
  alpha = -1.0/p
  max_ev, _, _ = PowerIter(mat_g)
  ridge_epsilon *= max_ev
  mat_g += ridge_epsilon * identity
  z = (1 + p) / (2 * torch.norm(mat_g))
  # The best value for z is
  # (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
  #            (c_max^{1+1/p} - c_min^{1+1/p})
  # where c_max and c_min are the largest and smallest singular values of
  # mat_g.
  # The above estimate assumes that c_max > c_min * 2^p
  # Can replace above line by the one below, but it is less accurate,
  # hence needs more iterations to converge.
  # z = (1 + p) / tf.trace(mat_g)
  # If we want the method to always converge, use z = 1 / norm(mat_g)
  # or z = 1 / tf.trace(mat_g), but these can result in many
  # extra iterations.

  mat_root = identity * torch.pow(z, 1.0/p)
  mat_m = mat_g * z
  error = torch.max(torch.abs(mat_m - identity))
  count = 0
  while error > error_tolerance and count < iter_count:
    tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
    new_mat_root = torch.matmul(mat_root, tmp_mat_m)
    mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
    new_error = torch.max(torch.abs(mat_m - identity))
    if new_error > error * 1.2:
      break
    mat_root = new_mat_root
    error = new_error
    count += 1
  return mat_root
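A hedged verification sketch (not part of the diff; assumes a CUDA device, since ComputePower allocates its identity on mat_g.get_device()): the result should approximate (G + rI)^{-1/p} computed by eigendecomposition. Note that ComputePower modifies mat_g in place (mat_g += ridge_epsilon * identity), hence the clone:

import torch
A = torch.randn(16, 8, device='cuda')
G = A @ A.T  # PSD, rank-deficient on purpose
root = ComputePower(G.clone(), 4, ridge_epsilon=1e-6)
r = 1e-6 * torch.linalg.eigvalsh(G)[-1]  # same ridge, but with the exact top eigenvalue
evals, evecs = torch.linalg.eigh(G + r * torch.eye(16, device='cuda'))
ref = evecs @ torch.diag(evals.pow(-0.25)) @ evecs.T  # (G + rI)^{-1/4}
print((root - ref).abs().max())  # should be small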
# Grafting is a technique to fix the layerwise scale of the Shampoo optimizer.
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
# allows us to plug the Shampoo optimizer into settings where SGD/AdaGrad
# is already well tuned. Grafting onto Shampoo means taking the Shampoo direction,
# but using the step magnitude from the grafted optimizer such as Adagrad or SGD.
class LayerwiseGrafting(enum.IntEnum):
  NONE = 0
  SGD = 1
  ADAGRAD = 2


@dataclass
class ShampooHyperParams:
  """Shampoo hyper parameters."""
  beta2: float = 0.9
  diagonal_eps: float = 1e-6
  matrix_eps: float = 1e-12
  weight_decay: float = 0.0
  inverse_exponent_override: int = 2  # fixed exponent for preconditioner, if >0
  start_preconditioning_step: int = 1
  # Performance tuning params for controlling memory and compute requirements.
  # How often to compute the preconditioner.
  preconditioning_compute_steps: int = 1
  # How often to compute statistics.
  statistics_compute_steps: int = 1
  # Block size for large layers (if > 0).
  # Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
  # Block size should be as large as feasible under memory/time constraints.
  block_size: int = 128
  # Automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in
  # 12 x [1024, 512] L and R statistics). When disabled, Shampoo constructs
  # statistics per dimension as-is: [4, 4], [3, 3], [1024, 1024], [512, 512].
  best_effort_shape_interpretation: bool = True
  # Type of grafting (SGD or AdaGrad).
  # https://arxiv.org/pdf/2002.11803.pdf
  graft_type: int = LayerwiseGrafting.ADAGRAD
  # Nesterov momentum
  nesterov: bool = True
class Graft:
  """Base class to perform grafting onto Shampoo. This class does no grafting.
  """

  def __init__(self, hps, unused_var):
    self.hps = hps

  def add_statistics(self, grad):
    pass

  def precondition_gradient(self, grad):
    return grad

  def update_momentum(self, update, unused_beta1):
    return update


class SGDGraft(Graft):
  """Graft using SGD+momentum.
  momentum maintains an exponentially weighted moving average of gradients.
  """

  def __init__(self, hps, var):
    super(SGDGraft, self).__init__(hps, var)
    self.momentum = torch.zeros_like(var.data, device=var.get_device())

  def update_momentum(self, update, beta1):
    self.momentum.mul_(beta1).add_(update)
    return self.momentum


class AdagradGraft(SGDGraft):
  """Graft using Adagrad.
  Essentially an implementation of Adagrad with momentum.
  """

  def __init__(self, hps, var):
    super(AdagradGraft, self).__init__(hps, var)
    self.statistics = torch.zeros_like(var.data, device=var.get_device())

  def add_statistics(self, grad):
    self.statistics.add_(grad * grad)

  def precondition_gradient(self, grad):
    return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
class BlockPartitioner:
  """Partitions a tensor into smaller tensors for preconditioning.
  For example, if a variable has shape (4096, 512), we might split the
  4096 into 4 blocks, so we effectively have 4 variables of size
  (1024, 512) each.
  """

  def __init__(self, var, hps):
    self._shape = var.shape
    self._splits = []
    self._split_sizes = []
    split_sizes = []
    # We split var into smaller blocks. Here we store the metadata to make
    # that split.
    for i, d in enumerate(var.shape):
      if hps.block_size > 0 and d > hps.block_size:
        # d-1, otherwise split appends a 0-size array.
        nsplit = (d-1) // hps.block_size
        indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
        sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
        sizes[-1] = d - indices[-1]
        self._splits.append((i, indices))
        self._split_sizes.append((i, sizes))
        split_sizes.append(sizes)
      else:
        split_sizes.append(np.array([d], dtype=np.int32))
    self._num_splits = len(split_sizes)
    self._preconditioner_shapes = []
    for t in itertools.product(*split_sizes):
      self._preconditioner_shapes.extend([[d, d] for d in t])

  def shapes_for_preconditioners(self):
    return self._preconditioner_shapes

  def num_splits(self):
    return self._num_splits

  def partition(self, tensor):
    """Partition tensor into blocks."""

    assert tensor.shape == self._shape
    tensors = [tensor]
    for (i, sizes) in self._split_sizes:
      tensors_local = []
      for t in tensors:
        tensors_local.extend(
            torch.split(t, tuple(sizes), dim=i))
      tensors = tensors_local
    return tensors

  def merge_partitions(self, partitions):
    """Merge partitions back to original shape."""

    for (i, indices) in reversed(self._splits):
      n = len(indices) + 1
      partial_merged_tensors = []
      ind = 0
      while ind < len(partitions):
        partial_merged_tensors.append(
            torch.cat(partitions[ind:ind + n], axis=i))
        ind += n
      partitions = partial_merged_tensors
    assert len(partitions) == 1
    return partitions[0]
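Usage sketch (not part of the diff): a (300, 70) variable with block_size=128 splits only its first dimension into sizes [128, 128, 44]; each of the 3 blocks then gets one (d, d) preconditioner shape per dimension:

import torch
hps = ShampooHyperParams(block_size=128)
var = torch.zeros(300, 70)
part = BlockPartitioner(var, hps)
blocks = part.partition(var)
print([tuple(b.shape) for b in blocks])   # [(128, 70), (128, 70), (44, 70)]
print(part.shapes_for_preconditioners())  # [[128, 128], [70, 70], [128, 128], [70, 70], [44, 44], [70, 70]]
assert part.merge_partitions(blocks).shape == var.shape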
def _merge_small_dims(shape_to_merge, max_dim):
  """Merge small dimensions.
  If there are some small dimensions, we collapse them:
  e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
       [1, 2, 768, 1, 2048] --> [2, 768, 2048]
  Args:
    shape_to_merge: Shape to merge small dimensions.
    max_dim: Maximal dimension of output shape used in merging.
  Returns:
    Merged shape.
  """
  resulting_shape = []
  product = 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d
    else:
      if product > 1:
        resulting_shape.append(product)
      product = d
  if product > 1:
    resulting_shape.append(product)
  return resulting_shape
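A quick check (not part of the diff) that the docstring examples hold:

assert _merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024) == [1024, 2048, 12]
assert _merge_small_dims([1, 2, 768, 1, 2048], 1024) == [2, 768, 2048]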
class Preconditioner:
  """Compute statistics/shape from gradients for preconditioning."""

  def __init__(self, var, hps):
    self._hps = hps
    self._original_shape = var.shape
    self._transformed_shape = var.shape
    if hps.best_effort_shape_interpretation:
      self._transformed_shape = _merge_small_dims(
          self._original_shape, hps.block_size)

    reshaped_var = torch.reshape(var, self._transformed_shape)
    self._partitioner = BlockPartitioner(reshaped_var, hps)
    shapes = self._partitioner.shapes_for_preconditioners()
    rank = len(self._transformed_shape)
    device = var.get_device()
    if rank <= 1:
      self.statistics = []
      self.preconditioners = []
    else:
      eps = self._hps.matrix_eps
      self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
      self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]

  def add_statistics(self, grad):
    """Compute statistics from gradients and add to the correct state entries.
    Args:
      grad: Gradient to compute statistics from.
    """
    if not self.statistics: return
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    w1 = self._hps.beta2
    w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
    rank = len(self._transformed_shape)
    for j, grad in enumerate(partitioned_grads):
      for i in range(rank):
        axes = list(range(i)) + list(range(i + 1, rank))
        stat = torch.tensordot(grad, grad, [axes, axes])
        self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)

  def exponent_for_preconditioner(self):
    """Returns exponent to use for inverse-pth root M^{-1/p}."""
    if self._hps.inverse_exponent_override > 0:
      return self._hps.inverse_exponent_override
    return 2 * len(self._transformed_shape)

  def compute_preconditioners(self):
    """Compute L^{-1/exp} for each stats matrix L."""
    exp = self.exponent_for_preconditioner()
    eps = self._hps.matrix_eps
    for i, stat in enumerate(self.statistics):
      self.preconditioners[i] = ComputePower(
          stat, exp, ridge_epsilon=eps)

  def preconditioned_grad(self, grad):
    """Precondition the gradient.
    Args:
      grad: A gradient tensor to precondition.
    Returns:
      A preconditioned gradient.
    """
    if not self.preconditioners: return grad
    reshaped_grad = torch.reshape(grad, self._transformed_shape)
    partitioned_grads = self._partitioner.partition(reshaped_grad)
    preconditioned_partitioned_grads = []
    num_splits = self._partitioner.num_splits()
    for i, grad in enumerate(partitioned_grads):
      preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
                                                      num_splits]
      rank = len(grad.shape)
      precond_grad = grad
      for j in range(rank):
        preconditioner = preconditioners_for_grad[j]
        precond_grad = torch.tensordot(
            precond_grad, preconditioner, [[0], [0]])
      preconditioned_partitioned_grads.append(precond_grad)
    merged_grad = self._partitioner.merge_partitions(
        preconditioned_partitioned_grads)
    return torch.reshape(merged_grad, self._original_shape)
STEP = 'step'
MOMENTUM = 'momentum'
PRECONDITIONER = 'preconditioner'
GRAFT = 'graft'


class Shampoo(optim.Optimizer):
  """The Shampoo optimizer."""

  def __init__(self,
               params,
               lr=1.0,
               momentum=0.9,
               hyperparams=ShampooHyperParams()):
    defaults = dict(lr=lr, momentum=momentum)
    self.hps = hyperparams
    super(Shampoo, self).__init__(params, defaults)

  def init_var_state(self, var, state):
    """Initialize the PyTorch state for a single variable."""
    state[STEP] = 0
    state[MOMENTUM] = torch.zeros_like(var.data, device=var.get_device())
    state[PRECONDITIONER] = Preconditioner(var, self.hps)
    if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
      state[GRAFT] = AdagradGraft(self.hps, var)
    elif self.hps.graft_type == LayerwiseGrafting.SGD:
      state[GRAFT] = SGDGraft(self.hps, var)
    else:
      state[GRAFT] = Graft(self.hps, var)

  def step(self, closure=None):
    hps = self.hps
    for group in self.param_groups:
      lr = group['lr']
      for p in group['params']:
        if p.grad is None: continue
        grad = p.grad.data
        if grad.is_sparse:
          raise RuntimeError('Shampoo does not support sparse yet')
        state = self.state[p]
        if not state:
          self.init_var_state(p, state)
        state[STEP] += 1

        preconditioner = state[PRECONDITIONER]
        graft = state[GRAFT]

        # Gather statistics, compute preconditioners
        graft.add_statistics(grad)
        if state[STEP] % hps.statistics_compute_steps == 0:
          preconditioner.add_statistics(grad)
        if state[STEP] % hps.preconditioning_compute_steps == 0:
          preconditioner.compute_preconditioners()

        # Precondition gradients
        graft_grad = graft.precondition_gradient(grad)
        shampoo_grad = grad
        if state[STEP] >= self.hps.start_preconditioning_step:
          shampoo_grad = preconditioner.preconditioned_grad(grad)

        # Grafting
        graft_norm = torch.norm(graft_grad)
        shampoo_norm = torch.norm(shampoo_grad)
        shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

        # Weight decay
        if self.hps.weight_decay != 0.0:
          shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
          graft_grad.add_(p.data, alpha=self.hps.weight_decay)

        # Momentum and Nesterov momentum, if needed
        state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
        graft_momentum = graft.update_momentum(grad, group['momentum'])

        if state[STEP] >= self.hps.start_preconditioning_step:
          momentum_update = state[MOMENTUM]
          wd_update = shampoo_grad
        else:
          momentum_update = graft_momentum
          wd_update = graft_grad

        if hps.nesterov:
          momentum_update.mul_(group['momentum']).add_(wd_update)

        # Final update
        p.data.add_(momentum_update, alpha=-lr)
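A minimal usage sketch (not part of the diff; assumes a CUDA device, since the optimizer state is allocated with var.get_device(), and the hyperparameters are illustrative rather than tuned):

import torch
import torch.nn as nn

model = nn.Linear(512, 256).cuda()
opt = Shampoo(model.parameters(), lr=1e-3, momentum=0.9,
              hyperparams=ShampooHyperParams(block_size=128,
                                             graft_type=LayerwiseGrafting.ADAGRAD))
x = torch.randn(32, 512, device='cuda')
y = torch.randn(32, 256, device='cuda')
loss = ((model(x) - y) ** 2).mean()
loss.backward()
opt.step()       # preconditioned, grafted update
opt.zero_grad()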
raymarching/__init__.py
ADDED
@@ -0,0 +1 @@
from .raymarching import *
raymarching/backend.py
ADDED
@@ -0,0 +1,40 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_raymarching',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'raymarching.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
raymarching/raymarching.py
ADDED
@@ -0,0 +1,373 @@
import numpy as np
import time

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _raymarching as _backend
except ImportError:
    from .backend import _backend


# ----------------------------------------
# utils
# ----------------------------------------
class _near_far_from_aabb(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
        ''' near_far_from_aabb, CUDA implementation
        Calculate rays' intersection time (near and far) with aabb
        Args:
            rays_o: float, [N, 3]
            rays_d: float, [N, 3]
            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
            min_near: float, scalar
        Returns:
            nears: float, [N]
            fars: float, [N]
        '''
        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0] # num rays

        nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
        fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)

        _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)

        return nears, fars

near_far_from_aabb = _near_far_from_aabb.apply
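Usage sketch (not part of the diff; needs the compiled extension and a GPU): a ray starting at z = -3 looking down +z through the cube [-1, 1]^3 should give near ≈ 2 and far ≈ 4. min_near is passed positionally since Function.apply does not take keyword arguments:

import torch
aabb = torch.tensor([-1., -1., -1., 1., 1., 1.], device='cuda')
rays_o = torch.tensor([[0., 0., -3.]], device='cuda')
rays_d = torch.tensor([[0., 0., 1.]], device='cuda')
nears, fars = near_far_from_aabb(rays_o, rays_d, aabb, 0.2)
print(nears, fars)  # expect roughly tensor([2.]), tensor([4.])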
class _sph_from_ray(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, radius):
        ''' sph_from_ray, CUDA implementation
        get spherical coordinate on the background sphere from rays.
        Assume rays_o are inside the Sphere(radius).
        Args:
            rays_o: [N, 3]
            rays_d: [N, 3]
            radius: scalar, float
        Return:
            coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
        '''
        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0] # num rays

        coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)

        _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)

        return coords

sph_from_ray = _sph_from_ray.apply
class _morton3D(Function):
    @staticmethod
    def forward(ctx, coords):
        ''' morton3D, CUDA implementation
        Args:
            coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
            TODO: check if the coord range is valid! (current 128 is safe)
        Returns:
            indices: [N], int32, in [0, 128^3)
        '''
        if not coords.is_cuda: coords = coords.cuda()

        N = coords.shape[0]

        indices = torch.empty(N, dtype=torch.int32, device=coords.device)

        _backend.morton3D(coords.int(), N, indices)

        return indices

morton3D = _morton3D.apply

class _morton3D_invert(Function):
    @staticmethod
    def forward(ctx, indices):
        ''' morton3D_invert, CUDA implementation
        Args:
            indices: [N], int32, in [0, 128^3)
        Returns:
            coords: [N, 3], int32, in [0, 128)
        '''
        if not indices.is_cuda: indices = indices.cuda()

        N = indices.shape[0]

        coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)

        _backend.morton3D_invert(indices.int(), N, coords)

        return coords

morton3D_invert = _morton3D_invert.apply
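A pure-Python reference (not part of the diff) mirroring the bit tricks in src/raymarching.cu below, handy for checking Morton codes on CPU without the extension:

def _expand_bits(v: int) -> int:
    # insert two zero bits after each of the low 10 bits of v
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3D_py(x: int, y: int, z: int) -> int:
    # interleave bits of x, y, z; valid for coords in [0, 1024), so [0, 128) is safe
    return _expand_bits(x) | (_expand_bits(y) << 1) | (_expand_bits(z) << 2)

assert morton3D_py(1, 0, 0) == 0b001
assert morton3D_py(0, 1, 0) == 0b010
assert morton3D_py(0, 0, 1) == 0b100
assert morton3D_py(3, 3, 3) == 0b111111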
class _packbits(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, thresh, bitfield=None):
        ''' packbits, CUDA implementation
        Pack up the density grid into a bit field to accelerate ray marching.
        Args:
            grid: float, [C, H * H * H], assume H % 2 == 0
            thresh: float, threshold
        Returns:
            bitfield: uint8, [C, H * H * H / 8]
        '''
        if not grid.is_cuda: grid = grid.cuda()
        grid = grid.contiguous()

        C = grid.shape[0]
        H3 = grid.shape[1]
        N = C * H3 // 8

        if bitfield is None:
            bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)

        _backend.packbits(grid, N, thresh, bitfield)

        return bitfield

packbits = _packbits.apply
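Usage sketch (not part of the diff; needs the compiled extension): compress a density grid into one bit per cell. With C = 1 cascade and H = 128, the bitfield is H^3 / 8 = 262144 bytes:

import torch
C, H = 1, 128
grid = torch.rand(C, H ** 3, device='cuda')
bitfield = packbits(grid, 0.01)
print(bitfield.shape, bitfield.dtype)  # torch.Size([262144]) torch.uint8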
# ----------------------------------------
# train functions
# ----------------------------------------

class _march_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
        ''' march rays to generate points (forward only)
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            nears/fars: float, [N]
            step_counter: int32, (2), used to count the actual number of generated points.
            mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.)
            perturb: bool
            align: int, pad output so its size is divisible by align, set to -1 to disable.
            force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally leads to worse performance)
            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concatenated, need to use `rays` to extract points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0] # num rays
        M = N * max_steps # init max points number in total

        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
        # It estimates the max points number to enable faster training, but will lead to randomly ignored rays if underestimated.
        if not force_all_rays and mean_count > 0:
            if align > 0:
                mean_count += align - mean_count % align
            M = mean_count

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
        rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps

        if step_counter is None:
            step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter

        if perturb:
            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

        _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number

        #print(step_counter, M)

        # only used at the first (few) epochs.
        if force_all_rays or mean_count <= 0:
            m = step_counter[0].item() # D2H copy
            if align > 0:
                m += align - m % align
            xyzs = xyzs[:m]
            dirs = dirs[:m]
            deltas = deltas[:m]

            torch.cuda.empty_cache()

        return xyzs, dirs, deltas, rays

march_rays_train = _march_rays_train.apply
class _composite_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
        ''' composite rays' rgbs, according to the ray marching formula.
        Args:
            rgbs: float, [M, 3]
            sigmas: float, [M,]
            deltas: float, [M, 2]
            rays: int32, [N, 3]
        Returns:
            weights_sum: float, [N,], the alpha channel
            depth: float, [N, ], the Depth
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''

        sigmas = sigmas.contiguous()
        rgbs = rgbs.contiguous()

        M = sigmas.shape[0]
        N = rays.shape[0]

        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)

        _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)

        ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
        ctx.dims = [M, N, T_thresh]

        return weights_sum, depth, image

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights_sum, grad_depth, grad_image):

        # NOTE: grad_depth is not used now! It won't be propagated to sigmas.

        grad_weights_sum = grad_weights_sum.contiguous()
        grad_image = grad_image.contiguous()

        sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
        M, N, T_thresh = ctx.dims

        grad_sigmas = torch.zeros_like(sigmas)
        grad_rgbs = torch.zeros_like(rgbs)

        _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)

        return grad_sigmas, grad_rgbs, None, None, None


composite_rays_train = _composite_rays_train.apply
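A hedged end-to-end sketch (not part of the diff; needs the compiled extension and a GPU) of the training path: intersect rays with the aabb, march them into point samples, then composite per-ray color/depth. Random sigmas/rgbs stand in for the NeRF network so the gradient path can be exercised without one:

import torch
import torch.nn.functional as F

device = 'cuda'
N, C, H, bound = 1024, 1, 128, 1.0
rays_o = torch.zeros(N, 3, device=device)  # camera at the origin, inside the aabb
rays_d = F.normalize(torch.randn(N, 3, device=device), dim=-1)
aabb = torch.tensor([-bound] * 3 + [bound] * 3, dtype=torch.float32, device=device)
nears, fars = near_far_from_aabb(rays_o, rays_d, aabb, 0.05)

grid = torch.rand(C, H ** 3, device=device)  # stand-in density grid
density_bitfield = packbits(grid, 0.01)

xyzs, dirs, deltas, rays = march_rays_train(
    rays_o, rays_d, bound, density_bitfield, C, H, nears, fars)

sigmas = torch.rand(xyzs.shape[0], device=device, requires_grad=True)
rgbs = torch.rand(xyzs.shape[0], 3, device=device, requires_grad=True)
weights_sum, depth, image = composite_rays_train(sigmas, rgbs, deltas, rays)
image.sum().backward()  # gradients flow back into sigmas and rgbs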
# ----------------------------------------
# infer functions
# ----------------------------------------

class _march_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
        ''' march rays to generate points (forward only, for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
            rays_t: float, [N], the alive rays' time, we only use the first n_alive.
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            nears/fars: float, [N]
            align: int, pad output so its size is divisible by align, set to -1 to disable.
            perturb: bool/int, int > 0 is used as the random seed.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally leads to worse performance)
            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
        Returns:
            xyzs: float, [n_alive * n_step, 3], all generated points' coords
            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        M = n_alive * n_step

        if align > 0:
            M += align - (M % align)

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth

        if perturb:
            # torch.manual_seed(perturb) # test_gui uses spp index as seed
            noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)

        _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)

        return xyzs, dirs, deltas

march_rays = _march_rays.apply
class _composite_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
        ''' composite rays' rgbs, according to the ray marching formula. (for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
            rays_t: float, [N], the alive rays' time
            sigmas: float, [n_alive * n_step,]
            rgbs: float, [n_alive * n_step, 3]
            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
        In-place Outputs:
            weights_sum: float, [N,], the alpha channel
            depth: float, [N,], the depth value
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''
        _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
        return tuple()


composite_rays = _composite_rays.apply
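A hedged sketch of the inference loop these two kernels are built for, roughly following how nerf/renderer.py drives them (not a verbatim copy); it reuses the tensors from the training sketch above, and `field` is a stand-in for the real network query:

field = lambda x, d: (torch.rand(x.shape[0], device=device),
                      torch.rand(x.shape[0], 3, device=device))

weights_sum = torch.zeros(N, device=device)
depth = torch.zeros(N, device=device)
image = torch.zeros(N, 3, device=device)

rays_alive = torch.arange(N, dtype=torch.int32, device=device)
rays_t = nears.clone()

step = 0
while step < 1024:
    n_alive = rays_alive.shape[0]
    if n_alive <= 0:
        break
    # march more steps per round as fewer rays remain alive
    n_step = max(min(N // n_alive, 8), 1)
    xyzs, dirs, deltas = march_rays(n_alive, n_step, rays_alive, rays_t,
                                    rays_o, rays_d, bound, density_bitfield,
                                    C, H, nears, fars)
    sigmas, rgbs = field(xyzs, dirs)
    composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas,
                   weights_sum, depth, image)
    rays_alive = rays_alive[rays_alive >= 0]  # drop rays the kernel flagged as terminated
    step += n_step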
raymarching/setup.py
ADDED
@@ -0,0 +1,62 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

'''
Usage:

python setup.py build_ext --inplace # build extensions locally, do not install (can only be used from the parent directory)

python setup.py install # build extensions and install (copy) to PATH.
pip install . # ditto but better (e.g., dependency & metadata handling)

python setup.py develop # build extensions and install (symbolic) to PATH.
pip install -e . # ditto but better (e.g., dependency & metadata handling)

'''
setup(
    name='raymarching', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_raymarching', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'raymarching.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
raymarching/src/bindings.cpp
ADDED
@@ -0,0 +1,19 @@
#include <torch/extension.h>

#include "raymarching.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // utils
    m.def("packbits", &packbits, "packbits (CUDA)");
    m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
    m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
    m.def("morton3D", &morton3D, "morton3D (CUDA)");
    m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
    // train
    m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
    m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
    m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
    // infer
    m.def("march_rays", &march_rays, "march rays (CUDA)");
    m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
}
raymarching/src/raymarching.cu
ADDED
@@ -0,0 +1,914 @@
1 |
+
#include <cuda.h>
|
2 |
+
#include <cuda_fp16.h>
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <torch/torch.h>
|
7 |
+
|
8 |
+
#include <cstdio>
|
9 |
+
#include <stdint.h>
|
10 |
+
#include <stdexcept>
|
11 |
+
#include <limits>
|
12 |
+
|
13 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
14 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
15 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
16 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
17 |
+
|
18 |
+
|
19 |
+
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
|
20 |
+
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
|
21 |
+
inline constexpr __device__ float PI() { return 3.141592653589793f; }
|
22 |
+
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
|
23 |
+
|
24 |
+
|
25 |
+
template <typename T>
|
26 |
+
inline __host__ __device__ T div_round_up(T val, T divisor) {
|
27 |
+
return (val + divisor - 1) / divisor;
|
28 |
+
}
|
29 |
+
|
30 |
+
inline __host__ __device__ float signf(const float x) {
|
31 |
+
return copysignf(1.0, x);
|
32 |
+
}
|
33 |
+
|
34 |
+
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
|
35 |
+
return fminf(max, fmaxf(min, x));
|
36 |
+
}
|
37 |
+
|
38 |
+
inline __host__ __device__ void swapf(float& a, float& b) {
|
39 |
+
float c = a; a = b; b = c;
|
40 |
+
}
|
41 |
+
|
42 |
+
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
|
43 |
+
const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
|
44 |
+
int exponent;
|
45 |
+
frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
|
46 |
+
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
47 |
+
}
|
48 |
+
|
49 |
+
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
|
50 |
+
const float mx = dt * H * 0.5;
|
51 |
+
int exponent;
|
52 |
+
frexpf(mx, &exponent);
|
53 |
+
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
54 |
+
}
|
55 |
+
|
56 |
+
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
|
57 |
+
{
|
58 |
+
v = (v * 0x00010001u) & 0xFF0000FFu;
|
59 |
+
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
60 |
+
v = (v * 0x00000011u) & 0xC30C30C3u;
|
61 |
+
v = (v * 0x00000005u) & 0x49249249u;
|
62 |
+
return v;
|
63 |
+
}
|
64 |
+
|
65 |
+
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
|
66 |
+
{
|
67 |
+
uint32_t xx = __expand_bits(x);
|
68 |
+
uint32_t yy = __expand_bits(y);
|
69 |
+
uint32_t zz = __expand_bits(z);
|
70 |
+
return xx | (yy << 1) | (zz << 2);
|
71 |
+
}
|
72 |
+
|
73 |
+
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
|
74 |
+
{
|
75 |
+
x = x & 0x49249249;
|
76 |
+
x = (x | (x >> 2)) & 0xc30c30c3;
|
77 |
+
x = (x | (x >> 4)) & 0x0f00f00f;
|
78 |
+
x = (x | (x >> 8)) & 0xff0000ff;
|
79 |
+
x = (x | (x >> 16)) & 0x0000ffff;
|
80 |
+
return x;
|
81 |
+
}
|
82 |
+
|
83 |
+
|
84 |
+
////////////////////////////////////////////////////
|
85 |
+
///////////// utils /////////////
|
86 |
+
////////////////////////////////////////////////////
|
87 |
+
|
88 |
+
// rays_o/d: [N, 3]
|
89 |
+
// nears/fars: [N]
|
90 |
+
// scalar_t should always be float in use.
|
91 |
+
template <typename scalar_t>
|
92 |
+
__global__ void kernel_near_far_from_aabb(
|
93 |
+
const scalar_t * __restrict__ rays_o,
|
94 |
+
const scalar_t * __restrict__ rays_d,
|
95 |
+
const scalar_t * __restrict__ aabb,
|
96 |
+
const uint32_t N,
|
97 |
+
const float min_near,
|
98 |
+
scalar_t * nears, scalar_t * fars
|
99 |
+
) {
|
100 |
+
// parallel per ray
|
101 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
102 |
+
if (n >= N) return;
|
103 |
+
|
104 |
+
// locate
|
105 |
+
rays_o += n * 3;
|
106 |
+
rays_d += n * 3;
|
107 |
+
|
108 |
+
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
109 |
+
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
110 |
+
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
111 |
+
|
112 |
+
// get near far (assume cube scene)
|
113 |
+
float near = (aabb[0] - ox) * rdx;
|
114 |
+
float far = (aabb[3] - ox) * rdx;
|
115 |
+
if (near > far) swapf(near, far);
|
116 |
+
|
117 |
+
float near_y = (aabb[1] - oy) * rdy;
|
118 |
+
float far_y = (aabb[4] - oy) * rdy;
|
119 |
+
if (near_y > far_y) swapf(near_y, far_y);
|
120 |
+
|
121 |
+
if (near > far_y || near_y > far) {
|
122 |
+
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
123 |
+
return;
|
124 |
+
}
|
125 |
+
|
126 |
+
if (near_y > near) near = near_y;
|
127 |
+
if (far_y < far) far = far_y;
|
128 |
+
|
129 |
+
float near_z = (aabb[2] - oz) * rdz;
|
130 |
+
float far_z = (aabb[5] - oz) * rdz;
|
131 |
+
if (near_z > far_z) swapf(near_z, far_z);
|
132 |
+
|
133 |
+
if (near > far_z || near_z > far) {
|
134 |
+
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
135 |
+
return;
|
136 |
+
}
|
137 |
+
|
138 |
+
if (near_z > near) near = near_z;
|
139 |
+
if (far_z < far) far = far_z;
|
140 |
+
|
141 |
+
if (near < min_near) near = min_near;
|
142 |
+
|
143 |
+
nears[n] = near;
|
144 |
+
fars[n] = far;
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
|
149 |
+
|
150 |
+
static constexpr uint32_t N_THREAD = 128;
|
151 |
+
|
152 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
153 |
+
rays_o.scalar_type(), "near_far_from_aabb", ([&] {
|
154 |
+
kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
|
155 |
+
}));
|
156 |
+
}


// rays_o/d: [N, 3]
// radius: float
// coords: [N, 2]
template <typename scalar_t>
__global__ void kernel_sph_from_ray(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const float radius,
    const uint32_t N,
    scalar_t * coords
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    rays_o += n * 3;
    rays_d += n * 3;
    coords += n * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    // solve t from || o + td || = radius
    const float A = dx * dx + dy * dy + dz * dz;
    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
    const float C = ox * ox + oy * oy + oz * oz - radius * radius;

    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)

    // solve theta, phi (assume y is the up axis)
    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI]
    const float phi = atan2(z, x); // [-PI, PI]

    // normalize to [-1, 1]
    coords[0] = 2 * theta * RPI() - 1;
    coords[1] = phi * RPI();
}
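
// Derivation: ||o + t*d||^2 = radius^2 expands to A*t^2 + 2*B*t + C = 0 with
// B = o.d (hence the "in fact B / 2" note above), whose roots are
// t = (-B +/- sqrtf(B*B - A*C)) / A. Taking the '+' root assumes the ray
// origin lies inside the sphere, so exactly one intersection has t > 0.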


void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "sph_from_ray", ([&] {
        kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
    }));
}


// coords: int32, [N, 3]
// indices: int32, [N]
__global__ void kernel_morton3D(
    const int * __restrict__ coords,
    const uint32_t N,
    int * indices
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;
    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}


void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
}


// indices: int32, [N]
// coords: int32, [N, 3]
__global__ void kernel_morton3D_invert(
    const int * __restrict__ indices,
    const uint32_t N,
    int * coords
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;

    const int ind = indices[n];

    coords[0] = __morton3D_invert(ind >> 0);
    coords[1] = __morton3D_invert(ind >> 1);
    coords[2] = __morton3D_invert(ind >> 2);
}


void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
}


// grid: float, [C, H, H, H]
// N: int, C * H * H * H / 8
// density_thresh: float
// bitfield: uint8, [N]
template <typename scalar_t>
__global__ void kernel_packbits(
    const scalar_t * __restrict__ grid,
    const uint32_t N,
    const float density_thresh,
    uint8_t * bitfield
) {
    // parallel per byte
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    grid += n * 8;

    uint8_t bits = 0;

    #pragma unroll
    for (uint8_t i = 0; i < 8; i++) {
        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
    }

    bitfield[n] = bits;
}


void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grid.scalar_type(), "packbits", ([&] {
        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
    }));
}
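
// Example: with density_thresh = 0.5 and one group of 8 grid cells
// {0.7, 0.1, 0.1, 0.9, 0.0, 0.0, 0.0, 0.6}, cells 0, 3 and 7 pass the
// threshold, so the packed byte is (1 << 0) | (1 << 3) | (1 << 7) = 0x89.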

////////////////////////////////////////////////////
/////////////         training         /////////////
////////////////////////////////////////////////////

// rays_o/d: [N, 3]
// grid: [CHHH / 8]
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
// rays: [N, 3], idx, offset, num_steps
template <typename scalar_t>
__global__ void kernel_march_rays_train(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const uint8_t * __restrict__ grid,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
    int * rays,
    int * counter,
    const scalar_t* __restrict__ noises
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    rays_o += n * 3;
    rays_d += n * 3;

    // ray marching
    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;
    const float H3 = H * H * H;

    const float near = nears[n];
    const float far = fars[n];
    const float noise = noises[n];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    float t0 = near;

    // perturb
    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;

    // first pass: estimation of num_steps
    float t = t0;
    uint32_t num_steps = 0;

    //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);

    while (t < far && num_steps < max_steps) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and count it
        //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);

        if (occ) {
            num_steps++;
            t += dt;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;

            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }

    //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);

    // second pass: really locate and write points & dirs
    uint32_t point_index = atomicAdd(counter, num_steps);
    uint32_t ray_index = atomicAdd(counter + 1, 1);

    //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);

    // write rays
    rays[ray_index * 3] = n;
    rays[ray_index * 3 + 1] = point_index;
    rays[ray_index * 3 + 2] = num_steps;

    if (num_steps == 0) return;
    if (point_index + num_steps > M) return;

    xyzs += point_index * 3;
    dirs += point_index * 3;
    deltas += point_index * 2;

    t = t0;
    uint32_t step = 0;

    float last_t = t;

    while (t < far && step < num_steps) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        // query grid
        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t; // used to calc depth
            last_t = t;
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays_train", ([&] {
        kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
    }));
}
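
// Note on the step schedule above: dt = clamp(t * dt_gamma, dt_min, dt_max)
// grows geometrically with distance for dt_gamma > 0, so space near the
// camera is sampled densely and far space coarsely. dt_min = 2*SQRT3()/max_steps
// appears chosen so that max_steps minimal steps can traverse the full
// diagonal (length 2*sqrt(3)) of the [-1, 1]^3 unit-bound cube.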


// sigmas: [M]
// rgbs: [M, 3]
// deltas: [M, 2]
// rays: [N, 3], idx, offset, num_steps
// weights_sum: [N], final pixel alpha
// depth: [N,]
// image: [N, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_forward(
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const uint32_t M, const uint32_t N, const float T_thresh,
    scalar_t * weights_sum,
    scalar_t * depth,
    scalar_t * image
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    // empty ray, or a ray that exceeds the max step count.
    if (num_steps == 0 || offset + num_steps > M) {
        weights_sum[index] = 0;
        depth[index] = 0;
        image[index * 3] = 0;
        image[index * 3 + 1] = 0;
        image[index * 3 + 2] = 0;
        return;
    }

    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;

    // accumulate
    uint32_t step = 0;

    scalar_t T = 1.0f;
    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;

    while (step < num_steps) {

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        t += deltas[1]; // real delta
        d += weight * t;

        ws += weight;

        T *= 1.0f - alpha;

        // minimal remaining transmittance
        if (T < T_thresh) break;

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;

        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // write
    weights_sum[index] = ws; // weights_sum
    depth[index] = d;
    image[index * 3] = r;
    image[index * 3 + 1] = g;
    image[index * 3 + 2] = b;
}
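
// Note: this is the standard volume rendering quadrature used by NeRF:
// alpha_i = 1 - exp(-sigma_i * dt_i), w_i = alpha_i * T_i with
// T_i = prod_{j<i} (1 - alpha_j), so color = sum_i w_i * rgb_i and
// depth = sum_i w_i * t_i. The T < T_thresh early-out skips samples that
// can no longer contribute visibly.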


void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
        kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}


// grad_weights_sum: [N,]
// grad_image: [N, 3]
// sigmas: [M]
// rgbs: [M, 3]
// deltas: [M, 2]
// rays: [N, 3], idx, offset, num_steps
// weights_sum: [N,], from the forward pass
// image: [N, 3]
// grad_sigmas: [M]
// grad_rgbs: [M, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_backward(
    const scalar_t * __restrict__ grad_weights_sum,
    const scalar_t * __restrict__ grad_image,
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const scalar_t * __restrict__ weights_sum,
    const scalar_t * __restrict__ image,
    const uint32_t M, const uint32_t N, const float T_thresh,
    scalar_t * grad_sigmas,
    scalar_t * grad_rgbs
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    if (num_steps == 0 || offset + num_steps > M) return;

    grad_weights_sum += index;
    grad_image += index * 3;
    weights_sum += index;
    image += index * 3;
    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;
    grad_sigmas += offset;
    grad_rgbs += offset * 3;

    // accumulate
    uint32_t step = 0;

    scalar_t T = 1.0f;
    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
    scalar_t r = 0, g = 0, b = 0, ws = 0;

    while (step < num_steps) {

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;

        T *= 1.0f - alpha;

        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
        // write grad_rgbs
        grad_rgbs[0] = grad_image[0] * weight;
        grad_rgbs[1] = grad_image[1] * weight;
        grad_rgbs[2] = grad_image[2] * weight;

        // write grad_sigmas
        grad_sigmas[0] = deltas[0] * (
            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
            grad_weights_sum[0] * (1 - ws_final)
        );

        //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
        // minimal remaining transmittance
        if (T < T_thresh) break;

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        grad_sigmas++;
        grad_rgbs += 3;

        step++;
    }
}
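
// Note: the closed form above follows from differentiating
// color = sum_i w_i * c_i with respect to sigma_k. Using
// d(alpha_k)/d(sigma_k) = delta_k * (1 - alpha_k), this gives
//   d(color)/d(sigma_k) = delta_k * (T_{k+1} * c_k - sum_{i>k} w_i * c_i),
// where T_{k+1} is the transmittance *after* sample k (T is updated before
// the gradient writes) and the suffix sum is (r_final - r) etc. The
// grad_weights_sum term is the same identity applied to ws = sum_i w_i,
// which collapses to (1 - ws_final) since T_{k+1} = 1 - ws.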


void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
        kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
    }));
}


////////////////////////////////////////////////////
/////////////        inference         /////////////
////////////////////////////////////////////////////

template <typename scalar_t>
__global__ void kernel_march_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    const scalar_t* __restrict__ rays_t,
    const scalar_t* __restrict__ rays_o,
    const scalar_t* __restrict__ rays_d,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t C, const uint32_t H,
    const uint8_t * __restrict__ grid,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
    const scalar_t* __restrict__ noises
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id
    const float noise = noises[n];

    // locate
    rays_o += index * 3;
    rays_d += index * 3;
    xyzs += n * n_step * 3;
    dirs += n * n_step * 3;
    deltas += n * n_step * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;
    const float H3 = H * H * H;

    float t = rays_t[index]; // current ray's t
    const float near = nears[index], far = fars[index];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    // march for n_step steps, record points
    uint32_t step = 0;

    // introduce some randomness
    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;

    float last_t = t;

    while (t < far && step < n_step) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf(scalbnf(1, level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            // calc dt
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t; // used to calc depth
            last_t = t;
            // step
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;

        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}


void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays", ([&] {
        kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
    }));
}


template <typename scalar_t>
__global__ void kernel_composite_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const float T_thresh,
    int* rays_alive,
    scalar_t* rays_t,
    const scalar_t* __restrict__ sigmas,
    const scalar_t* __restrict__ rgbs,
    const scalar_t* __restrict__ deltas,
    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id

    // locate
    sigmas += n * n_step;
    rgbs += n * n_step * 3;
    deltas += n * n_step * 2;

    rays_t += index;
    weights_sum += index;
    depth += index;
    image += index * 3;

    scalar_t t = rays_t[0]; // current ray's t

    scalar_t weight_sum = weights_sum[0];
    scalar_t d = depth[0];
    scalar_t r = image[0];
    scalar_t g = image[1];
    scalar_t b = image[2];

    // accumulate
    uint32_t step = 0;
    while (step < n_step) {

        // ray is terminated if delta == 0
        if (deltas[0] == 0) break;

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);

        /*
        T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
        w_i = alpha_i * T_i
        -->
        T_i = 1 - \sum_{j=0}^{i-1} w_j
        */
        const scalar_t T = 1 - weight_sum;
        const scalar_t weight = alpha * T;
        weight_sum += weight;

        t += deltas[1]; // real delta
        d += weight * t;
        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // ray is terminated if T is too small
        // use a larger bound to further accelerate inference
        if (T < T_thresh) break;

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // rays_alive = -1 means ray is terminated early.
    if (step < n_step) {
        rays_alive[n] = -1;
    } else {
        rays_t[0] = t;
    }

    weights_sum[0] = weight_sum; // this is the thing I needed!
    depth[0] = d;
    image[0] = r;
    image[1] = g;
    image[2] = b;
}
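
// Note: inference composites incrementally. weights_sum/depth/image hold the
// running accumulation across repeated march_rays/composite_rays rounds,
// rays_t carries each surviving ray's current position, and the transmittance
// is recovered on the fly as T = 1 - weight_sum instead of being stored.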


void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    image.scalar_type(), "composite_rays", ([&] {
        kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}

raymarching/src/raymarching.h
ADDED
@@ -0,0 +1,18 @@

#pragma once

#include <stdint.h>
#include <torch/torch.h>


void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);

void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);

requirements.txt
ADDED
@@ -0,0 +1,21 @@

torch-ema
ninja
trimesh
opencv-python
tensorboardX
torch
numpy
pandas
tqdm
matplotlib
PyMCubes
rich
dearpygui
scipy
huggingface_hub
diffusers
transformers
xatlas
scikit-learn
imageio
imageio-ffmpeg

scripts/install_ext.sh
ADDED
@@ -0,0 +1,4 @@

pip install ./raymarching
pip install ./shencoder
pip install ./freqencoder
pip install ./gridencoder

scripts/run.sh
ADDED
@@ -0,0 +1,5 @@

#! /bin/bash

CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of cthulhu" --workspace trial_cthulhu
CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel" --workspace trial_squirrel
CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a cat lying on its side batting at a ball of yarn" --workspace trial_cat_lying

shencoder/__init__.py
ADDED
@@ -0,0 +1 @@

from .sphere_harmonics import SHEncoder

shencoder/backend.py
ADDED
@@ -0,0 +1,40 @@

import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_sh_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'shencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']

shencoder/setup.py
ADDED
@@ -0,0 +1,50 @@

import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='shencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_shencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'shencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)

shencoder/sphere_harmonics.py
ADDED
@@ -0,0 +1,87 @@

import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _shencoder as _backend
except ImportError:
    from .backend import _backend

class _sh_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, input_dim], float in [-1, 1]
        # RETURN: [B, F], float

        inputs = inputs.contiguous()
        B, input_dim = inputs.shape # batch size, coord dim
        output_dim = degree ** 2

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
        else:
            dy_dx = None

        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)

        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = [B, input_dim, degree]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, C * C]

        inputs, dy_dx = ctx.saved_tensors

        if dy_dx is not None:
            grad = grad.contiguous()
            B, input_dim, degree = ctx.dims
            grad_inputs = torch.zeros_like(inputs)
            _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
            return grad_inputs, None, None
        else:
            return None, None, None


sh_encode = _sh_encoder.apply


class SHEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim # coord dims, must be 3
        self.degree = degree # 1 ~ 8
        self.output_dim = degree ** 2

        assert self.input_dim == 3, "SH encoder only supports input dim == 3"
        assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        # inputs: [..., input_dim], normalized real-world positions in [-size, size]
        # return: [..., degree^2]

        inputs = inputs / size # [-1, 1]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
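# Example (hypothetical usage): with the default degree=4, encoding a batch of
# unit directions yields degree^2 = 16 SH features per point:
#   enc = SHEncoder(input_dim=3, degree=4)
#   feats = enc(dirs)  # dirs: [N, 3] in [-1, 1] -> feats: [N, 16]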

shencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@

#include <torch/extension.h>

#include "shencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
    m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
}
shencoder/src/shencoder.cu
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdint.h>
|
2 |
+
|
3 |
+
#include <cuda.h>
|
4 |
+
#include <cuda_fp16.h>
|
5 |
+
#include <cuda_runtime.h>
|
6 |
+
|
7 |
+
#include <ATen/cuda/CUDAContext.h>
|
8 |
+
#include <torch/torch.h>
|
9 |
+
|
10 |
+
#include <algorithm>
|
11 |
+
#include <stdexcept>
|
12 |
+
|
13 |
+
#include <cstdio>
|
14 |
+
|
15 |
+
|
16 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
17 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
18 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
19 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
20 |
+
|
21 |
+
|
22 |
+
template <typename T>
|
23 |
+
__host__ __device__ T div_round_up(T val, T divisor) {
|
24 |
+
return (val + divisor - 1) / divisor;
|
25 |
+
}
|
26 |
+
|
27 |
+
template <typename scalar_t>
|
28 |
+
__global__ void kernel_sh(
|
29 |
+
const scalar_t * __restrict__ inputs,
|
30 |
+
scalar_t * outputs,
|
31 |
+
uint32_t B, uint32_t D, uint32_t C,
|
32 |
+
scalar_t * dy_dx
|
33 |
+
) {
|
34 |
+
const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
|
35 |
+
if (b >= B) return;
|
36 |
+
|
37 |
+
const uint32_t C2 = C * C;
|
38 |
+
|
39 |
+
// locate
|
40 |
+
inputs += b * D;
|
41 |
+
outputs += b * C2;
|
42 |
+
|
43 |
+
scalar_t x = inputs[0], y = inputs[1], z = inputs[2];
|
44 |
+
|
45 |
+
scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
|
46 |
+
scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
|
47 |
+
scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;
|
48 |
+
|
49 |
+
auto write_sh = [&]() {
|
50 |
+
outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi))
|
51 |
+
if (C <= 1) { return; }
|
52 |
+
outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi))
|
53 |
+
outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi))
|
54 |
+
outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi))
|
55 |
+
if (C <= 2) { return; }
|
56 |
+
outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi))
|
57 |
+
outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi))
|
58 |
+
outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
|
59 |
+
outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi))
|
60 |
+
outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
|
61 |
+
if (C <= 3) { return; }
|
62 |
+
outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
|
63 |
+
outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi))
|
64 |
+
outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
|
65 |
+
outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
|
66 |
+
outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
|
67 |
+
outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
|
68 |
+
outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
|
69 |
+
if (C <= 4) { return; }
|
70 |
+
outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
|
71 |
+
outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
|
72 |
+
outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
|
73 |
+
outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
|
74 |
+
outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
|
75 |
+
outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
|
76 |
+
outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
|
77 |
+
outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
|
78 |
+
outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
79 |
+
if (C <= 5) { return; }
|
80 |
+
outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
81 |
+
outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
|
82 |
+
outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
|
83 |
+
outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
|
84 |
+
outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
85 |
+
outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
|
86 |
+
outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
87 |
+
outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
|
88 |
+
outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
|
89 |
+
outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
90 |
+
outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
91 |
+
if (C <= 6) { return; }
|
92 |
+
outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
93 |
+
outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
94 |
+
outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
|
95 |
+
outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
|
96 |
+
outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
|
97 |
+
outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
98 |
+
outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
|
99 |
+
outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
100 |
+
outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
|
101 |
+
outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
|
102 |
+
outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
|
103 |
+
outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
104 |
+
outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
|
105 |
+
if (C <= 7) { return; }
|
106 |
+
outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
|
107 |
+
outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
108 |
+
outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
|
109 |
+
outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
|
110 |
+
outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
|
111 |
+
outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
|
112 |
+
outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
113 |
+
outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
|
114 |
+
outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
115 |
+
outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
|
116 |
+
outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
|
117 |
+
outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
|
118 |
+
outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
|
119 |
+
outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
|
120 |
+
outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
|
121 |
+
};
|
122 |
+
|
123 |
+
write_sh();
|
124 |
+
|
125 |
+
if (dy_dx) {
|
126 |
+
scalar_t *dx = dy_dx + b * D * C2;
|
127 |
+
scalar_t *dy = dx + C2;
|
128 |
+
scalar_t *dz = dy + C2;
|
129 |
+
|
130 |
+
auto write_sh_dx = [&]() {
|
131 |
+
dx[0] = 0.0f ; // 0
|
132 |
+
            if (C <= 1) { return; }
            dx[1] = 0.0f; // 0
            dx[2] = 0.0f; // 0
            dx[3] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi))
            if (C <= 2) { return; }
            dx[4] = 1.0925484305920792f*y; // sqrt(15)*y/(2*sqrt(pi))
            dx[5] = 0.0f; // 0
            dx[6] = 0.0f; // 0
            dx[7] = -1.0925484305920792f*z; // -sqrt(15)*z/(2*sqrt(pi))
            dx[8] = 1.0925484305920792f*x; // sqrt(15)*x/(2*sqrt(pi))
            if (C <= 3) { return; }
            dx[9] = -3.5402615395598609f*xy; // -3*sqrt(70)*xy/(4*sqrt(pi))
            dx[10] = 2.8906114426405538f*yz; // sqrt(105)*yz/(2*sqrt(pi))
            dx[11] = 0.0f; // 0
            dx[12] = 0.0f; // 0
            dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
            dx[14] = 2.8906114426405538f*xz; // sqrt(105)*xz/(2*sqrt(pi))
            dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
            if (C <= 4) { return; }
            dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2); // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
            dx[17] = -10.620784618679583f*xy*z; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
            dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f); // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
            dx[19] = 0.0f; // 0
            dx[20] = 0.0f; // 0
            dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
            dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
            dx[23] = 5.3103923093397913f*z*(-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
            dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
            if (C <= 5) { return; }
            dx[25] = 13.127641136803401f*xy*(-x2 + y2); // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
            dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2); // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
            dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2); // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
            dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f); // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
            dx[29] = 0.0f; // 0
            dx[30] = 0.0f; // 0
            dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
            dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
            dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
            dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
            dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            if (C <= 6) { return; }
            dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4); // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
            dx[37] = 47.332383244635047f*xy*z*(-x2 + y2); // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
            dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
            dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2); // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
            dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dx[41] = 0.0f; // 0
            dx[42] = 0.0f; // 0
            dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
            dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
            dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
            dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            if (C <= 7) { return; }
            dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4); // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
            dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4); // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
            dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
            dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
            dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f); // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
            dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dx[55] = 0.0f; // 0
            dx[56] = 0.0f; // 0
            dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
            dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
            dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
            dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4); // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
            dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
        };

        auto write_sh_dy = [&]() {
            dy[0] = 0.0f; // 0
            if (C <= 1) { return; }
            dy[1] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi))
            dy[2] = 0.0f; // 0
            dy[3] = 0.0f; // 0
            if (C <= 2) { return; }
            dy[4] = 1.0925484305920792f*x; // sqrt(15)*x/(2*sqrt(pi))
            dy[5] = -1.0925484305920792f*z; // -sqrt(15)*z/(2*sqrt(pi))
            dy[6] = 0.0f; // 0
            dy[7] = 0.0f; // 0
            dy[8] = -1.0925484305920792f*y; // -sqrt(15)*y/(2*sqrt(pi))
            if (C <= 3) { return; }
            dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
            dy[10] = 2.8906114426405538f*xz; // sqrt(105)*xz/(2*sqrt(pi))
            dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
            dy[12] = 0.0f; // 0
            dy[13] = 0.0f; // 0
            dy[14] = -2.8906114426405538f*yz; // -sqrt(105)*yz/(2*sqrt(pi))
            dy[15] = 3.5402615395598609f*xy; // 3*sqrt(70)*xy/(4*sqrt(pi))
            if (C <= 4) { return; }
            dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
            dy[17] = 5.3103923093397913f*z*(-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
            dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
            dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
            dy[20] = 0.0f; // 0
            dy[21] = 0.0f; // 0
            dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2); // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
            dy[23] = 10.620784618679583f*xy*z; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
            dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2); // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
            if (C <= 5) { return; }
            dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
            dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f); // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
            dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
            dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
            dy[30] = 0.0f; // 0
            dy[31] = 0.0f; // 0
            dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2); // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
            dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f); // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
            dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2); // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
            dy[35] = 13.127641136803401f*xy*(x2 - y2); // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
            if (C <= 6) { return; }
            dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
            dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
            dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
            dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
            dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
            dy[42] = 0.0f; // 0
            dy[43] = 0.0f; // 0
            dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f); // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
            dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
            dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
            dy[47] = 47.332383244635047f*xy*z*(x2 - y2); // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
            dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            if (C <= 7) { return; }
            dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
            dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
            dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4); // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
            dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
            dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f); // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
            dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
            dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
            dy[56] = 0.0f; // 0
            dy[57] = 0.0f; // 0
            dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f); // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
            dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
            dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
            dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
        };

        auto write_sh_dz = [&]() {
            dz[0] = 0.0f; // 0
            if (C <= 1) { return; }
            dz[1] = 0.0f; // 0
            dz[2] = 0.48860251190291992f; // sqrt(3)/(2*sqrt(pi))
            dz[3] = 0.0f; // 0
            if (C <= 2) { return; }
            dz[4] = 0.0f; // 0
            dz[5] = -1.0925484305920792f*y; // -sqrt(15)*y/(2*sqrt(pi))
            dz[6] = 1.8923493915151202f*z; // 3*sqrt(5)*z/(2*sqrt(pi))
            dz[7] = -1.0925484305920792f*x; // -sqrt(15)*x/(2*sqrt(pi))
            dz[8] = 0.0f; // 0
            if (C <= 3) { return; }
            dz[9] = 0.0f; // 0
            dz[10] = 2.8906114426405538f*xy; // sqrt(105)*xy/(2*sqrt(pi))
            dz[11] = -4.5704579946446566f*yz; // -5*sqrt(42)*yz/(4*sqrt(pi))
            dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
            dz[13] = -4.5704579946446566f*xz; // -5*sqrt(42)*xz/(4*sqrt(pi))
            dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
            dz[15] = 0.0f; // 0
            if (C <= 4) { return; }
            dz[16] = 0.0f; // 0
            dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2); // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
            dz[18] = 13.246445740605839f*xy*z; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
            dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2); // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
            dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z; // (105*z**3 - 45*z)/(4*sqrt(pi))
            dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2); // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
            dz[22] = 6.6232228703029197f*z*(x2 - y2); // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
            dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2); // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
            dz[24] = 0.0f; // 0
            if (C <= 5) { return; }
            dz[25] = 0.0f; // 0
            dz[26] = 8.3026492595241645f*xy*(x2 - y2); // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
            dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2); // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
            dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f); // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
            dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2); // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
            dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
            dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2); // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
            dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f); // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
            dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2); // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
            dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
            dz[35] = 0.0f; // 0
            if (C <= 6) { return; }
            dz[36] = 0.0f; // 0
            dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4); // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dz[38] = 44.401711264127719f*xy*z*(x2 - y2); // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
            dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f); // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
            dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
            dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f); // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
            dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f); // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
            dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f); // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
            dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f); // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
            dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f); // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
            dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4); // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
            dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4); // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
            dz[48] = 0.0f; // 0
            if (C <= 7) { return; }
            dz[49] = 0.0f; // 0
            dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4); // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
            dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4); // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
            dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f); // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
            dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f); // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
            dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f); // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
            dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
            dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f); // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
            dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f); // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
            dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f); // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
            dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4); // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
            dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4); // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
            dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
            dz[63] = 0.0f; // 0
        };

        write_sh_dx();
        write_sh_dy();
        write_sh_dz();
    }
}

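The three lambdas above hard-code the analytic partial derivatives of the real spherical-harmonic basis up to degree 8 (C <= 8), keeping the symbolic form of each entry in its trailing comment. Any individual constant can be sanity-checked against the forward basis with a central finite difference. A minimal standalone sketch, where sh4 is a hypothetical helper mirroring the forward kernel's degree-2 output sh[4] = sqrt(15)/(2*sqrt(pi)) * x*y (not a function of this extension):

#include <cstdio>

// Hypothetical helper mirroring the forward kernel's sh[4]:
// Y_2^{-2} = sqrt(15)/(2*sqrt(pi)) * x * y.
float sh4(float x, float y) { return 1.0925484305920792f * x * y; }

int main() {
    const float x = 0.3f, y = -0.7f, h = 1e-3f;
    // Central-difference approximation of d(sh4)/dx.
    const float fd = (sh4(x + h, y) - sh4(x - h, y)) / (2.0f * h);
    // The dx[4] entry above evaluated at the same point.
    const float analytic = 1.0925484305920792f * y;
    std::printf("finite diff = %f, analytic = %f\n", fd, analytic);
    return 0;
}

Since sh[4] is bilinear in x and y, the two values agree to rounding error; higher-degree entries can be spot-checked the same way.
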
template <typename scalar_t>
__global__ void kernel_sh_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t C,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * grad_inputs
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    const uint32_t b = t / D;
    if (b >= B) return;

    const uint32_t d = t - b * D;
    const uint32_t C2 = C * C;

    // locate
    grad += b * C2;
    dy_dx += b * D * C2 + d * C2;

    for (int ch = 0; ch < C2; ch++) {
        grad_inputs[t] += grad[ch] * dy_dx[ch];
        //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
    }
}

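kernel_sh_backward assigns one thread per (batch, input-dimension) pair: thread t owns input d of batch b and reduces the upstream gradient against the matching C*C-wide slice of the Jacobian dy_dx written by the forward kernel. A CPU reference of the same chain-rule reduction, useful for unit-testing the kernel (names here are illustrative, not part of the extension):

#include <cstdint>
#include <vector>

// CPU reference of the reduction in kernel_sh_backward:
//   grad:    [B, C*C]    upstream gradient w.r.t. the SH outputs
//   dy_dx:   [B, D, C*C] Jacobian written by kernel_sh
//   grad_in: [B, D]      accumulated gradient w.r.t. the inputs
void sh_backward_cpu(const std::vector<float>& grad, const std::vector<float>& dy_dx,
                     uint32_t B, uint32_t D, uint32_t C, std::vector<float>& grad_in) {
    const uint32_t C2 = C * C;
    for (uint32_t b = 0; b < B; b++)
        for (uint32_t d = 0; d < D; d++)
            for (uint32_t ch = 0; ch < C2; ch++)
                grad_in[b * D + d] += grad[b * C2 + ch] * dy_dx[(b * D + d) * C2 + ch];
}
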
// inputs: [B, D], float, in [0, 1]
// outputs: [B, C*C], float (C*C SH coefficients per batch element, matching the C2 stride used by the kernels)
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
    static constexpr uint32_t N_THREADS = 256;
    kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
}


template <typename scalar_t>
|
394 |
+
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
|
395 |
+
static constexpr uint32_t N_THREADS = 256;
|
396 |
+
kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
|
397 |
+
}
|
398 |
+
|
399 |
+
|
400 |
+
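Both launchers flatten the work into a 1-D grid of 256-thread blocks: one thread per batch element for the forward pass, one per (batch, dimension) pair for the backward pass. div_round_up is defined elsewhere in the extension's sources; its usual definition, assumed here, is integer ceiling division:

// Assumed definition of the helper used above: integer ceil(val / divisor).
// e.g. div_round_up(1000u, 256u) == 4, so 4 blocks of 256 threads cover 1000 items.
template <typename T>
constexpr T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
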
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
        sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
    }));
}

void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(dy_dx);
    CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
        sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
    }));
}
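These two host entry points are what shencoder/src/bindings.cpp exposes to Python for the sphere_harmonics.py autograd wrapper. A minimal sketch of the usual torch-extension binding pattern (assumed; the actual bindings.cpp may differ in wording):

#include <torch/extension.h>
#include "shencoder.h"

// Minimal binding sketch (assumed pattern; see bindings.cpp for the real file).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
    m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
}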