ashawkey
commited on
Commit
•
904ef7d
0
Parent(s):
init
Browse files- .gitignore +11 -0
- LICENSE +21 -0
- activation.py +18 -0
- assets/gallery.md +0 -0
- assets/update_logs.md +5 -0
- encoding.py +78 -0
- freqencoder/__init__.py +1 -0
- freqencoder/backend.py +41 -0
- freqencoder/freq.py +77 -0
- freqencoder/setup.py +51 -0
- freqencoder/src/bindings.cpp +8 -0
- freqencoder/src/freqencoder.cu +129 -0
- freqencoder/src/freqencoder.h +10 -0
- gridencoder/__init__.py +1 -0
- gridencoder/backend.py +40 -0
- gridencoder/grid.py +154 -0
- gridencoder/setup.py +50 -0
- gridencoder/src/bindings.cpp +8 -0
- gridencoder/src/gridencoder.cu +479 -0
- gridencoder/src/gridencoder.h +15 -0
- loss.py +11 -0
- main_nerf.py +137 -0
- nerf/clip.py +45 -0
- nerf/gui.py +465 -0
- nerf/network.py +184 -0
- nerf/network_grid.py +186 -0
- nerf/network_tcnn.py +189 -0
- nerf/provider.py +197 -0
- nerf/renderer.py +638 -0
- nerf/sd.py +201 -0
- nerf/utils.py +935 -0
- optimizer.py +470 -0
- raymarching/__init__.py +1 -0
- raymarching/backend.py +40 -0
- raymarching/raymarching.py +373 -0
- raymarching/setup.py +62 -0
- raymarching/src/bindings.cpp +19 -0
- raymarching/src/raymarching.cu +914 -0
- raymarching/src/raymarching.h +18 -0
- readme.md +91 -0
- requirements.txt +17 -0
- scripts/install_ext.sh +4 -0
- scripts/run.sh +5 -0
- shencoder/__init__.py +1 -0
- shencoder/backend.py +40 -0
- shencoder/setup.py +50 -0
- shencoder/sphere_harmonics.py +87 -0
- shencoder/src/bindings.cpp +8 -0
- shencoder/src/shencoder.cu +439 -0
- shencoder/src/shencoder.h +10 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
build/
|
3 |
+
*.egg-info/
|
4 |
+
*.so
|
5 |
+
|
6 |
+
tmp*
|
7 |
+
data/
|
8 |
+
trial*/
|
9 |
+
.vs/
|
10 |
+
|
11 |
+
TOKEN
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 hawkey
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
activation.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Function
|
3 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
4 |
+
|
5 |
+
class _trunc_exp(Function):
|
6 |
+
@staticmethod
|
7 |
+
@custom_fwd(cast_inputs=torch.float)
|
8 |
+
def forward(ctx, x):
|
9 |
+
ctx.save_for_backward(x)
|
10 |
+
return torch.exp(x)
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
@custom_bwd
|
14 |
+
def backward(ctx, g):
|
15 |
+
x = ctx.saved_tensors[0]
|
16 |
+
return g * torch.exp(x.clamp(-15, 15))
|
17 |
+
|
18 |
+
trunc_exp = _trunc_exp.apply
|
assets/gallery.md
ADDED
File without changes
|
assets/update_logs.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 2022.10.5
|
2 |
+
* Basic reproduction finished.
|
3 |
+
* Non --cuda_ray, --tcnn are not working, need to fix.
|
4 |
+
* Shading is not working, disabled in utils.py for now. Surface normals are bad.
|
5 |
+
* Use an entropy loss to regularize weights_sum (alpha), the original L2 reg always leads to degenerated geometry...
|
encoding.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
class FreqEncoder(nn.Module):
|
6 |
+
def __init__(self, input_dim, max_freq_log2, N_freqs,
|
7 |
+
log_sampling=True, include_input=True,
|
8 |
+
periodic_fns=(torch.sin, torch.cos)):
|
9 |
+
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.input_dim = input_dim
|
13 |
+
self.include_input = include_input
|
14 |
+
self.periodic_fns = periodic_fns
|
15 |
+
|
16 |
+
self.output_dim = 0
|
17 |
+
if self.include_input:
|
18 |
+
self.output_dim += self.input_dim
|
19 |
+
|
20 |
+
self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
|
21 |
+
|
22 |
+
if log_sampling:
|
23 |
+
self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
|
24 |
+
else:
|
25 |
+
self.freq_bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
|
26 |
+
|
27 |
+
self.freq_bands = self.freq_bands.numpy().tolist()
|
28 |
+
|
29 |
+
def forward(self, input, **kwargs):
|
30 |
+
|
31 |
+
out = []
|
32 |
+
if self.include_input:
|
33 |
+
out.append(input)
|
34 |
+
|
35 |
+
for i in range(len(self.freq_bands)):
|
36 |
+
freq = self.freq_bands[i]
|
37 |
+
for p_fn in self.periodic_fns:
|
38 |
+
out.append(p_fn(input * freq))
|
39 |
+
|
40 |
+
out = torch.cat(out, dim=-1)
|
41 |
+
|
42 |
+
|
43 |
+
return out
|
44 |
+
|
45 |
+
def get_encoder(encoding, input_dim=3,
|
46 |
+
multires=6,
|
47 |
+
degree=4,
|
48 |
+
num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
|
49 |
+
**kwargs):
|
50 |
+
|
51 |
+
if encoding == 'None':
|
52 |
+
return lambda x, **kwargs: x, input_dim
|
53 |
+
|
54 |
+
elif encoding == 'frequency':
|
55 |
+
#encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
|
56 |
+
from freqencoder import FreqEncoder
|
57 |
+
encoder = FreqEncoder(input_dim=input_dim, degree=multires)
|
58 |
+
|
59 |
+
elif encoding == 'sphere_harmonics':
|
60 |
+
from shencoder import SHEncoder
|
61 |
+
encoder = SHEncoder(input_dim=input_dim, degree=degree)
|
62 |
+
|
63 |
+
elif encoding == 'hashgrid':
|
64 |
+
from gridencoder import GridEncoder
|
65 |
+
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)
|
66 |
+
|
67 |
+
elif encoding == 'tiledgrid':
|
68 |
+
from gridencoder import GridEncoder
|
69 |
+
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)
|
70 |
+
|
71 |
+
elif encoding == 'ash':
|
72 |
+
from ashencoder import AshEncoder
|
73 |
+
encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution)
|
74 |
+
|
75 |
+
else:
|
76 |
+
raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
|
77 |
+
|
78 |
+
return encoder, encoder.output_dim
|
freqencoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .freq import FreqEncoder
|
freqencoder/backend.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torch.utils.cpp_extension import load
|
3 |
+
|
4 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
|
6 |
+
nvcc_flags = [
|
7 |
+
'-O3', '-std=c++14',
|
8 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
+
'-use_fast_math'
|
10 |
+
]
|
11 |
+
|
12 |
+
if os.name == "posix":
|
13 |
+
c_flags = ['-O3', '-std=c++14']
|
14 |
+
elif os.name == "nt":
|
15 |
+
c_flags = ['/O2', '/std:c++17']
|
16 |
+
|
17 |
+
# find cl.exe
|
18 |
+
def find_cl_path():
|
19 |
+
import glob
|
20 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
+
if paths:
|
23 |
+
return paths[0]
|
24 |
+
|
25 |
+
# If cl.exe is not on path, try to find it.
|
26 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
+
cl_path = find_cl_path()
|
28 |
+
if cl_path is None:
|
29 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
+
os.environ["PATH"] += ";" + cl_path
|
31 |
+
|
32 |
+
_backend = load(name='_freqencoder',
|
33 |
+
extra_cflags=c_flags,
|
34 |
+
extra_cuda_cflags=nvcc_flags,
|
35 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
36 |
+
'freqencoder.cu',
|
37 |
+
'bindings.cpp',
|
38 |
+
]],
|
39 |
+
)
|
40 |
+
|
41 |
+
__all__ = ['_backend']
|
freqencoder/freq.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.autograd.function import once_differentiable
|
7 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
+
|
9 |
+
try:
|
10 |
+
import _freqencoder as _backend
|
11 |
+
except ImportError:
|
12 |
+
from .backend import _backend
|
13 |
+
|
14 |
+
|
15 |
+
class _freq_encoder(Function):
|
16 |
+
@staticmethod
|
17 |
+
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
|
18 |
+
def forward(ctx, inputs, degree, output_dim):
|
19 |
+
# inputs: [B, input_dim], float
|
20 |
+
# RETURN: [B, F], float
|
21 |
+
|
22 |
+
if not inputs.is_cuda: inputs = inputs.cuda()
|
23 |
+
inputs = inputs.contiguous()
|
24 |
+
|
25 |
+
B, input_dim = inputs.shape # batch size, coord dim
|
26 |
+
|
27 |
+
outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
|
28 |
+
|
29 |
+
_backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
|
30 |
+
|
31 |
+
ctx.save_for_backward(inputs, outputs)
|
32 |
+
ctx.dims = [B, input_dim, degree, output_dim]
|
33 |
+
|
34 |
+
return outputs
|
35 |
+
|
36 |
+
@staticmethod
|
37 |
+
#@once_differentiable
|
38 |
+
@custom_bwd
|
39 |
+
def backward(ctx, grad):
|
40 |
+
# grad: [B, C * C]
|
41 |
+
|
42 |
+
grad = grad.contiguous()
|
43 |
+
inputs, outputs = ctx.saved_tensors
|
44 |
+
B, input_dim, degree, output_dim = ctx.dims
|
45 |
+
|
46 |
+
grad_inputs = torch.zeros_like(inputs)
|
47 |
+
_backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
|
48 |
+
|
49 |
+
return grad_inputs, None, None
|
50 |
+
|
51 |
+
|
52 |
+
freq_encode = _freq_encoder.apply
|
53 |
+
|
54 |
+
|
55 |
+
class FreqEncoder(nn.Module):
|
56 |
+
def __init__(self, input_dim=3, degree=4):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.input_dim = input_dim
|
60 |
+
self.degree = degree
|
61 |
+
self.output_dim = input_dim + input_dim * 2 * degree
|
62 |
+
|
63 |
+
def __repr__(self):
|
64 |
+
return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
|
65 |
+
|
66 |
+
def forward(self, inputs, **kwargs):
|
67 |
+
# inputs: [..., input_dim]
|
68 |
+
# return: [..., ]
|
69 |
+
|
70 |
+
prefix_shape = list(inputs.shape[:-1])
|
71 |
+
inputs = inputs.reshape(-1, self.input_dim)
|
72 |
+
|
73 |
+
outputs = freq_encode(inputs, self.degree, self.output_dim)
|
74 |
+
|
75 |
+
outputs = outputs.reshape(prefix_shape + [self.output_dim])
|
76 |
+
|
77 |
+
return outputs
|
freqencoder/setup.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from setuptools import setup
|
3 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
+
|
5 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
|
7 |
+
nvcc_flags = [
|
8 |
+
'-O3', '-std=c++14',
|
9 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
+
'-use_fast_math'
|
11 |
+
]
|
12 |
+
|
13 |
+
if os.name == "posix":
|
14 |
+
c_flags = ['-O3', '-std=c++14']
|
15 |
+
elif os.name == "nt":
|
16 |
+
c_flags = ['/O2', '/std:c++17']
|
17 |
+
|
18 |
+
# find cl.exe
|
19 |
+
def find_cl_path():
|
20 |
+
import glob
|
21 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
22 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
23 |
+
if paths:
|
24 |
+
return paths[0]
|
25 |
+
|
26 |
+
# If cl.exe is not on path, try to find it.
|
27 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
28 |
+
cl_path = find_cl_path()
|
29 |
+
if cl_path is None:
|
30 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
31 |
+
os.environ["PATH"] += ";" + cl_path
|
32 |
+
|
33 |
+
setup(
|
34 |
+
name='freqencoder', # package name, import this to use python API
|
35 |
+
ext_modules=[
|
36 |
+
CUDAExtension(
|
37 |
+
name='_freqencoder', # extension name, import this to use CUDA API
|
38 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
39 |
+
'freqencoder.cu',
|
40 |
+
'bindings.cpp',
|
41 |
+
]],
|
42 |
+
extra_compile_args={
|
43 |
+
'cxx': c_flags,
|
44 |
+
'nvcc': nvcc_flags,
|
45 |
+
}
|
46 |
+
),
|
47 |
+
],
|
48 |
+
cmdclass={
|
49 |
+
'build_ext': BuildExtension,
|
50 |
+
}
|
51 |
+
)
|
freqencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include "freqencoder.h"
|
4 |
+
|
5 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
+
m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
|
7 |
+
m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
|
8 |
+
}
|
freqencoder/src/freqencoder.cu
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdint.h>
|
2 |
+
|
3 |
+
#include <cuda.h>
|
4 |
+
#include <cuda_fp16.h>
|
5 |
+
#include <cuda_runtime.h>
|
6 |
+
|
7 |
+
#include <ATen/cuda/CUDAContext.h>
|
8 |
+
#include <torch/torch.h>
|
9 |
+
|
10 |
+
#include <algorithm>
|
11 |
+
#include <stdexcept>
|
12 |
+
|
13 |
+
#include <cstdio>
|
14 |
+
|
15 |
+
|
16 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
17 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
18 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
19 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
20 |
+
|
21 |
+
inline constexpr __device__ float PI() { return 3.141592653589793f; }
|
22 |
+
|
23 |
+
template <typename T>
|
24 |
+
__host__ __device__ T div_round_up(T val, T divisor) {
|
25 |
+
return (val + divisor - 1) / divisor;
|
26 |
+
}
|
27 |
+
|
28 |
+
// inputs: [B, D]
|
29 |
+
// outputs: [B, C], C = D + D * deg * 2
|
30 |
+
__global__ void kernel_freq(
|
31 |
+
const float * __restrict__ inputs,
|
32 |
+
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
|
33 |
+
float * outputs
|
34 |
+
) {
|
35 |
+
// parallel on per-element
|
36 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
37 |
+
if (t >= B * C) return;
|
38 |
+
|
39 |
+
// get index
|
40 |
+
const uint32_t b = t / C;
|
41 |
+
const uint32_t c = t - b * C; // t % C;
|
42 |
+
|
43 |
+
// locate
|
44 |
+
inputs += b * D;
|
45 |
+
outputs += t;
|
46 |
+
|
47 |
+
// write self
|
48 |
+
if (c < D) {
|
49 |
+
outputs[0] = inputs[c];
|
50 |
+
// write freq
|
51 |
+
} else {
|
52 |
+
const uint32_t col = c / D - 1;
|
53 |
+
const uint32_t d = c % D;
|
54 |
+
const uint32_t freq = col / 2;
|
55 |
+
const float phase_shift = (col % 2) * (PI() / 2);
|
56 |
+
outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
|
57 |
+
}
|
58 |
+
}
|
59 |
+
|
60 |
+
// grad: [B, C], C = D + D * deg * 2
|
61 |
+
// outputs: [B, C]
|
62 |
+
// grad_inputs: [B, D]
|
63 |
+
__global__ void kernel_freq_backward(
|
64 |
+
const float * __restrict__ grad,
|
65 |
+
const float * __restrict__ outputs,
|
66 |
+
uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
|
67 |
+
float * grad_inputs
|
68 |
+
) {
|
69 |
+
// parallel on per-element
|
70 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
71 |
+
if (t >= B * D) return;
|
72 |
+
|
73 |
+
const uint32_t b = t / D;
|
74 |
+
const uint32_t d = t - b * D; // t % D;
|
75 |
+
|
76 |
+
// locate
|
77 |
+
grad += b * C;
|
78 |
+
outputs += b * C;
|
79 |
+
grad_inputs += t;
|
80 |
+
|
81 |
+
// register
|
82 |
+
float result = grad[d];
|
83 |
+
grad += D;
|
84 |
+
outputs += D;
|
85 |
+
|
86 |
+
for (uint32_t f = 0; f < deg; f++) {
|
87 |
+
result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
|
88 |
+
grad += 2 * D;
|
89 |
+
outputs += 2 * D;
|
90 |
+
}
|
91 |
+
|
92 |
+
// write
|
93 |
+
grad_inputs[0] = result;
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
|
98 |
+
CHECK_CUDA(inputs);
|
99 |
+
CHECK_CUDA(outputs);
|
100 |
+
|
101 |
+
CHECK_CONTIGUOUS(inputs);
|
102 |
+
CHECK_CONTIGUOUS(outputs);
|
103 |
+
|
104 |
+
CHECK_IS_FLOATING(inputs);
|
105 |
+
CHECK_IS_FLOATING(outputs);
|
106 |
+
|
107 |
+
static constexpr uint32_t N_THREADS = 128;
|
108 |
+
|
109 |
+
kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
|
110 |
+
}
|
111 |
+
|
112 |
+
|
113 |
+
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
|
114 |
+
CHECK_CUDA(grad);
|
115 |
+
CHECK_CUDA(outputs);
|
116 |
+
CHECK_CUDA(grad_inputs);
|
117 |
+
|
118 |
+
CHECK_CONTIGUOUS(grad);
|
119 |
+
CHECK_CONTIGUOUS(outputs);
|
120 |
+
CHECK_CONTIGUOUS(grad_inputs);
|
121 |
+
|
122 |
+
CHECK_IS_FLOATING(grad);
|
123 |
+
CHECK_IS_FLOATING(outputs);
|
124 |
+
CHECK_IS_FLOATING(grad_inputs);
|
125 |
+
|
126 |
+
static constexpr uint32_t N_THREADS = 128;
|
127 |
+
|
128 |
+
kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
|
129 |
+
}
|
freqencoder/src/freqencoder.h
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pragma once
|
2 |
+
|
3 |
+
#include <stdint.h>
|
4 |
+
#include <torch/torch.h>
|
5 |
+
|
6 |
+
// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
|
7 |
+
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
|
8 |
+
|
9 |
+
// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
|
10 |
+
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
|
gridencoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .grid import GridEncoder
|
gridencoder/backend.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torch.utils.cpp_extension import load
|
3 |
+
|
4 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
|
6 |
+
nvcc_flags = [
|
7 |
+
'-O3', '-std=c++14',
|
8 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
+
]
|
10 |
+
|
11 |
+
if os.name == "posix":
|
12 |
+
c_flags = ['-O3', '-std=c++14']
|
13 |
+
elif os.name == "nt":
|
14 |
+
c_flags = ['/O2', '/std:c++17']
|
15 |
+
|
16 |
+
# find cl.exe
|
17 |
+
def find_cl_path():
|
18 |
+
import glob
|
19 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
20 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
21 |
+
if paths:
|
22 |
+
return paths[0]
|
23 |
+
|
24 |
+
# If cl.exe is not on path, try to find it.
|
25 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
26 |
+
cl_path = find_cl_path()
|
27 |
+
if cl_path is None:
|
28 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
29 |
+
os.environ["PATH"] += ";" + cl_path
|
30 |
+
|
31 |
+
_backend = load(name='_grid_encoder',
|
32 |
+
extra_cflags=c_flags,
|
33 |
+
extra_cuda_cflags=nvcc_flags,
|
34 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
35 |
+
'gridencoder.cu',
|
36 |
+
'bindings.cpp',
|
37 |
+
]],
|
38 |
+
)
|
39 |
+
|
40 |
+
__all__ = ['_backend']
|
gridencoder/grid.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.autograd.function import once_differentiable
|
7 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
+
|
9 |
+
try:
|
10 |
+
import _gridencoder as _backend
|
11 |
+
except ImportError:
|
12 |
+
from .backend import _backend
|
13 |
+
|
14 |
+
_gridtype_to_id = {
|
15 |
+
'hash': 0,
|
16 |
+
'tiled': 1,
|
17 |
+
}
|
18 |
+
|
19 |
+
class _grid_encode(Function):
|
20 |
+
@staticmethod
|
21 |
+
@custom_fwd
|
22 |
+
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
|
23 |
+
# inputs: [B, D], float in [0, 1]
|
24 |
+
# embeddings: [sO, C], float
|
25 |
+
# offsets: [L + 1], int
|
26 |
+
# RETURN: [B, F], float
|
27 |
+
|
28 |
+
inputs = inputs.contiguous()
|
29 |
+
|
30 |
+
B, D = inputs.shape # batch size, coord dim
|
31 |
+
L = offsets.shape[0] - 1 # level
|
32 |
+
C = embeddings.shape[1] # embedding dim for each level
|
33 |
+
S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
|
34 |
+
H = base_resolution # base resolution
|
35 |
+
|
36 |
+
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
|
37 |
+
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
|
38 |
+
if torch.is_autocast_enabled() and C % 2 == 0:
|
39 |
+
embeddings = embeddings.to(torch.half)
|
40 |
+
|
41 |
+
# L first, optimize cache for cuda kernel, but needs an extra permute later
|
42 |
+
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
|
43 |
+
|
44 |
+
if calc_grad_inputs:
|
45 |
+
dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
|
46 |
+
else:
|
47 |
+
dy_dx = None
|
48 |
+
|
49 |
+
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners)
|
50 |
+
|
51 |
+
# permute back to [B, L * C]
|
52 |
+
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
|
53 |
+
|
54 |
+
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
|
55 |
+
ctx.dims = [B, D, C, L, S, H, gridtype]
|
56 |
+
ctx.align_corners = align_corners
|
57 |
+
|
58 |
+
return outputs
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
#@once_differentiable
|
62 |
+
@custom_bwd
|
63 |
+
def backward(ctx, grad):
|
64 |
+
|
65 |
+
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
|
66 |
+
B, D, C, L, S, H, gridtype = ctx.dims
|
67 |
+
align_corners = ctx.align_corners
|
68 |
+
|
69 |
+
# grad: [B, L * C] --> [L, B, C]
|
70 |
+
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
|
71 |
+
|
72 |
+
grad_embeddings = torch.zeros_like(embeddings)
|
73 |
+
|
74 |
+
if dy_dx is not None:
|
75 |
+
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
|
76 |
+
else:
|
77 |
+
grad_inputs = None
|
78 |
+
|
79 |
+
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners)
|
80 |
+
|
81 |
+
if dy_dx is not None:
|
82 |
+
grad_inputs = grad_inputs.to(inputs.dtype)
|
83 |
+
|
84 |
+
return grad_inputs, grad_embeddings, None, None, None, None, None, None
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
grid_encode = _grid_encode.apply
|
89 |
+
|
90 |
+
|
91 |
+
class GridEncoder(nn.Module):
|
92 |
+
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
# the finest resolution desired at the last level, if provided, overridee per_level_scale
|
96 |
+
if desired_resolution is not None:
|
97 |
+
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
|
98 |
+
|
99 |
+
self.input_dim = input_dim # coord dims, 2 or 3
|
100 |
+
self.num_levels = num_levels # num levels, each level multiply resolution by 2
|
101 |
+
self.level_dim = level_dim # encode channels per level
|
102 |
+
self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
|
103 |
+
self.log2_hashmap_size = log2_hashmap_size
|
104 |
+
self.base_resolution = base_resolution
|
105 |
+
self.output_dim = num_levels * level_dim
|
106 |
+
self.gridtype = gridtype
|
107 |
+
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
|
108 |
+
self.align_corners = align_corners
|
109 |
+
|
110 |
+
# allocate parameters
|
111 |
+
offsets = []
|
112 |
+
offset = 0
|
113 |
+
self.max_params = 2 ** log2_hashmap_size
|
114 |
+
for i in range(num_levels):
|
115 |
+
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
|
116 |
+
params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
|
117 |
+
params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
|
118 |
+
offsets.append(offset)
|
119 |
+
offset += params_in_level
|
120 |
+
offsets.append(offset)
|
121 |
+
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
|
122 |
+
self.register_buffer('offsets', offsets)
|
123 |
+
|
124 |
+
self.n_params = offsets[-1] * level_dim
|
125 |
+
|
126 |
+
# parameters
|
127 |
+
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
|
128 |
+
|
129 |
+
self.reset_parameters()
|
130 |
+
|
131 |
+
def reset_parameters(self):
|
132 |
+
std = 1e-4
|
133 |
+
self.embeddings.data.uniform_(-std, std)
|
134 |
+
|
135 |
+
def __repr__(self):
|
136 |
+
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
|
137 |
+
|
138 |
+
def forward(self, inputs, bound=1):
|
139 |
+
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
|
140 |
+
# return: [..., num_levels * level_dim]
|
141 |
+
|
142 |
+
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
|
143 |
+
|
144 |
+
#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
|
145 |
+
|
146 |
+
prefix_shape = list(inputs.shape[:-1])
|
147 |
+
inputs = inputs.view(-1, self.input_dim)
|
148 |
+
|
149 |
+
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
|
150 |
+
outputs = outputs.view(prefix_shape + [self.output_dim])
|
151 |
+
|
152 |
+
#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
|
153 |
+
|
154 |
+
return outputs
|
gridencoder/setup.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from setuptools import setup
|
3 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
+
|
5 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
|
7 |
+
nvcc_flags = [
|
8 |
+
'-O3', '-std=c++14',
|
9 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
+
]
|
11 |
+
|
12 |
+
if os.name == "posix":
|
13 |
+
c_flags = ['-O3', '-std=c++14']
|
14 |
+
elif os.name == "nt":
|
15 |
+
c_flags = ['/O2', '/std:c++17']
|
16 |
+
|
17 |
+
# find cl.exe
|
18 |
+
def find_cl_path():
|
19 |
+
import glob
|
20 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
+
if paths:
|
23 |
+
return paths[0]
|
24 |
+
|
25 |
+
# If cl.exe is not on path, try to find it.
|
26 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
+
cl_path = find_cl_path()
|
28 |
+
if cl_path is None:
|
29 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
+
os.environ["PATH"] += ";" + cl_path
|
31 |
+
|
32 |
+
setup(
|
33 |
+
name='gridencoder', # package name, import this to use python API
|
34 |
+
ext_modules=[
|
35 |
+
CUDAExtension(
|
36 |
+
name='_gridencoder', # extension name, import this to use CUDA API
|
37 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
38 |
+
'gridencoder.cu',
|
39 |
+
'bindings.cpp',
|
40 |
+
]],
|
41 |
+
extra_compile_args={
|
42 |
+
'cxx': c_flags,
|
43 |
+
'nvcc': nvcc_flags,
|
44 |
+
}
|
45 |
+
),
|
46 |
+
],
|
47 |
+
cmdclass={
|
48 |
+
'build_ext': BuildExtension,
|
49 |
+
}
|
50 |
+
)
|
gridencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include "gridencoder.h"
|
4 |
+
|
5 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
+
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
|
7 |
+
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
|
8 |
+
}
|
gridencoder/src/gridencoder.cu
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cuda.h>
|
2 |
+
#include <cuda_fp16.h>
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <torch/torch.h>
|
7 |
+
|
8 |
+
#include <algorithm>
|
9 |
+
#include <stdexcept>
|
10 |
+
|
11 |
+
#include <stdint.h>
|
12 |
+
#include <cstdio>
|
13 |
+
|
14 |
+
|
15 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
16 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
17 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
18 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
19 |
+
|
20 |
+
|
21 |
+
// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
|
22 |
+
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
|
23 |
+
// requires CUDA >= 10 and ARCH >= 70
|
24 |
+
// this is very slow compared to float or __half2, and never used.
|
25 |
+
//return atomicAdd(reinterpret_cast<__half*>(address), val);
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
template <typename T>
|
30 |
+
static inline __host__ __device__ T div_round_up(T val, T divisor) {
|
31 |
+
return (val + divisor - 1) / divisor;
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
template <uint32_t D>
|
36 |
+
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
|
37 |
+
static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
|
38 |
+
|
39 |
+
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
|
40 |
+
// and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
|
41 |
+
// coordinates.
|
42 |
+
constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
|
43 |
+
|
44 |
+
uint32_t result = 0;
|
45 |
+
#pragma unroll
|
46 |
+
for (uint32_t i = 0; i < D; ++i) {
|
47 |
+
result ^= pos_grid[i] * primes[i];
|
48 |
+
}
|
49 |
+
|
50 |
+
return result;
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
template <uint32_t D, uint32_t C>
|
55 |
+
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
|
56 |
+
uint32_t stride = 1;
|
57 |
+
uint32_t index = 0;
|
58 |
+
|
59 |
+
#pragma unroll
|
60 |
+
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
|
61 |
+
index += pos_grid[d] * stride;
|
62 |
+
stride *= align_corners ? resolution: (resolution + 1);
|
63 |
+
}
|
64 |
+
|
65 |
+
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
|
66 |
+
// gridtype: 0 == hash, 1 == tiled
|
67 |
+
if (gridtype == 0 && stride > hashmap_size) {
|
68 |
+
index = fast_hash<D>(pos_grid);
|
69 |
+
}
|
70 |
+
|
71 |
+
return (index % hashmap_size) * C + ch;
|
72 |
+
}
|
73 |
+
|
74 |
+
|
75 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
76 |
+
__global__ void kernel_grid(
|
77 |
+
const float * __restrict__ inputs,
|
78 |
+
const scalar_t * __restrict__ grid,
|
79 |
+
const int * __restrict__ offsets,
|
80 |
+
scalar_t * __restrict__ outputs,
|
81 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
82 |
+
scalar_t * __restrict__ dy_dx,
|
83 |
+
const uint32_t gridtype,
|
84 |
+
const bool align_corners
|
85 |
+
) {
|
86 |
+
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
87 |
+
|
88 |
+
if (b >= B) return;
|
89 |
+
|
90 |
+
const uint32_t level = blockIdx.y;
|
91 |
+
|
92 |
+
// locate
|
93 |
+
grid += (uint32_t)offsets[level] * C;
|
94 |
+
inputs += b * D;
|
95 |
+
outputs += level * B * C + b * C;
|
96 |
+
|
97 |
+
// check input range (should be in [0, 1])
|
98 |
+
bool flag_oob = false;
|
99 |
+
#pragma unroll
|
100 |
+
for (uint32_t d = 0; d < D; d++) {
|
101 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
102 |
+
flag_oob = true;
|
103 |
+
}
|
104 |
+
}
|
105 |
+
// if input out of bound, just set output to 0
|
106 |
+
if (flag_oob) {
|
107 |
+
#pragma unroll
|
108 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
109 |
+
outputs[ch] = 0;
|
110 |
+
}
|
111 |
+
if (dy_dx) {
|
112 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
113 |
+
#pragma unroll
|
114 |
+
for (uint32_t d = 0; d < D; d++) {
|
115 |
+
#pragma unroll
|
116 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
117 |
+
dy_dx[d * C + ch] = 0;
|
118 |
+
}
|
119 |
+
}
|
120 |
+
}
|
121 |
+
return;
|
122 |
+
}
|
123 |
+
|
124 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
125 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
126 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
127 |
+
|
128 |
+
// calculate coordinate
|
129 |
+
float pos[D];
|
130 |
+
uint32_t pos_grid[D];
|
131 |
+
|
132 |
+
#pragma unroll
|
133 |
+
for (uint32_t d = 0; d < D; d++) {
|
134 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
135 |
+
pos_grid[d] = floorf(pos[d]);
|
136 |
+
pos[d] -= (float)pos_grid[d];
|
137 |
+
}
|
138 |
+
|
139 |
+
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
140 |
+
|
141 |
+
// interpolate
|
142 |
+
scalar_t results[C] = {0}; // temp results in register
|
143 |
+
|
144 |
+
#pragma unroll
|
145 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
146 |
+
float w = 1;
|
147 |
+
uint32_t pos_grid_local[D];
|
148 |
+
|
149 |
+
#pragma unroll
|
150 |
+
for (uint32_t d = 0; d < D; d++) {
|
151 |
+
if ((idx & (1 << d)) == 0) {
|
152 |
+
w *= 1 - pos[d];
|
153 |
+
pos_grid_local[d] = pos_grid[d];
|
154 |
+
} else {
|
155 |
+
w *= pos[d];
|
156 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
157 |
+
}
|
158 |
+
}
|
159 |
+
|
160 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
161 |
+
|
162 |
+
// writing to register (fast)
|
163 |
+
#pragma unroll
|
164 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
165 |
+
results[ch] += w * grid[index + ch];
|
166 |
+
}
|
167 |
+
|
168 |
+
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
169 |
+
}
|
170 |
+
|
171 |
+
// writing to global memory (slow)
|
172 |
+
#pragma unroll
|
173 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
174 |
+
outputs[ch] = results[ch];
|
175 |
+
}
|
176 |
+
|
177 |
+
// prepare dy_dx
|
178 |
+
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
179 |
+
if (dy_dx) {
|
180 |
+
|
181 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
182 |
+
|
183 |
+
#pragma unroll
|
184 |
+
for (uint32_t gd = 0; gd < D; gd++) {
|
185 |
+
|
186 |
+
scalar_t results_grad[C] = {0};
|
187 |
+
|
188 |
+
#pragma unroll
|
189 |
+
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
190 |
+
float w = scale;
|
191 |
+
uint32_t pos_grid_local[D];
|
192 |
+
|
193 |
+
#pragma unroll
|
194 |
+
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
195 |
+
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
196 |
+
|
197 |
+
if ((idx & (1 << nd)) == 0) {
|
198 |
+
w *= 1 - pos[d];
|
199 |
+
pos_grid_local[d] = pos_grid[d];
|
200 |
+
} else {
|
201 |
+
w *= pos[d];
|
202 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
203 |
+
}
|
204 |
+
}
|
205 |
+
|
206 |
+
pos_grid_local[gd] = pos_grid[gd];
|
207 |
+
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
208 |
+
pos_grid_local[gd] = pos_grid[gd] + 1;
|
209 |
+
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
210 |
+
|
211 |
+
#pragma unroll
|
212 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
213 |
+
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
|
214 |
+
}
|
215 |
+
}
|
216 |
+
|
217 |
+
#pragma unroll
|
218 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
219 |
+
dy_dx[gd * C + ch] = results_grad[ch];
|
220 |
+
}
|
221 |
+
}
|
222 |
+
}
|
223 |
+
}
|
224 |
+
|
225 |
+
|
226 |
+
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
227 |
+
__global__ void kernel_grid_backward(
|
228 |
+
const scalar_t * __restrict__ grad,
|
229 |
+
const float * __restrict__ inputs,
|
230 |
+
const scalar_t * __restrict__ grid,
|
231 |
+
const int * __restrict__ offsets,
|
232 |
+
scalar_t * __restrict__ grad_grid,
|
233 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
234 |
+
const uint32_t gridtype,
|
235 |
+
const bool align_corners
|
236 |
+
) {
|
237 |
+
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
238 |
+
if (b >= B) return;
|
239 |
+
|
240 |
+
const uint32_t level = blockIdx.y;
|
241 |
+
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
242 |
+
|
243 |
+
// locate
|
244 |
+
grad_grid += offsets[level] * C;
|
245 |
+
inputs += b * D;
|
246 |
+
grad += level * B * C + b * C + ch; // L, B, C
|
247 |
+
|
248 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
249 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
250 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
251 |
+
|
252 |
+
// check input range (should be in [0, 1])
|
253 |
+
#pragma unroll
|
254 |
+
for (uint32_t d = 0; d < D; d++) {
|
255 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
256 |
+
return; // grad is init as 0, so we simply return.
|
257 |
+
}
|
258 |
+
}
|
259 |
+
|
260 |
+
// calculate coordinate
|
261 |
+
float pos[D];
|
262 |
+
uint32_t pos_grid[D];
|
263 |
+
|
264 |
+
#pragma unroll
|
265 |
+
for (uint32_t d = 0; d < D; d++) {
|
266 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
267 |
+
pos_grid[d] = floorf(pos[d]);
|
268 |
+
pos[d] -= (float)pos_grid[d];
|
269 |
+
}
|
270 |
+
|
271 |
+
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
272 |
+
#pragma unroll
|
273 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
274 |
+
grad_cur[c] = grad[c];
|
275 |
+
}
|
276 |
+
|
277 |
+
// interpolate
|
278 |
+
#pragma unroll
|
279 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
280 |
+
float w = 1;
|
281 |
+
uint32_t pos_grid_local[D];
|
282 |
+
|
283 |
+
#pragma unroll
|
284 |
+
for (uint32_t d = 0; d < D; d++) {
|
285 |
+
if ((idx & (1 << d)) == 0) {
|
286 |
+
w *= 1 - pos[d];
|
287 |
+
pos_grid_local[d] = pos_grid[d];
|
288 |
+
} else {
|
289 |
+
w *= pos[d];
|
290 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
291 |
+
}
|
292 |
+
}
|
293 |
+
|
294 |
+
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
|
295 |
+
|
296 |
+
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
|
297 |
+
// TODO: use float which is better than __half, if N_C % 2 != 0
|
298 |
+
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
299 |
+
#pragma unroll
|
300 |
+
for (uint32_t c = 0; c < N_C; c += 2) {
|
301 |
+
// process two __half at once (by interpreting as a __half2)
|
302 |
+
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
303 |
+
atomicAdd((__half2*)&grad_grid[index + c], v);
|
304 |
+
}
|
305 |
+
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
306 |
+
} else {
|
307 |
+
#pragma unroll
|
308 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
309 |
+
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
310 |
+
}
|
311 |
+
}
|
312 |
+
}
|
313 |
+
}
|
314 |
+
|
315 |
+
|
316 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
317 |
+
__global__ void kernel_input_backward(
|
318 |
+
const scalar_t * __restrict__ grad,
|
319 |
+
const scalar_t * __restrict__ dy_dx,
|
320 |
+
scalar_t * __restrict__ grad_inputs,
|
321 |
+
uint32_t B, uint32_t L
|
322 |
+
) {
|
323 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
324 |
+
if (t >= B * D) return;
|
325 |
+
|
326 |
+
const uint32_t b = t / D;
|
327 |
+
const uint32_t d = t - b * D;
|
328 |
+
|
329 |
+
dy_dx += b * L * D * C;
|
330 |
+
|
331 |
+
scalar_t result = 0;
|
332 |
+
|
333 |
+
# pragma unroll
|
334 |
+
for (int l = 0; l < L; l++) {
|
335 |
+
# pragma unroll
|
336 |
+
for (int ch = 0; ch < C; ch++) {
|
337 |
+
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
338 |
+
}
|
339 |
+
}
|
340 |
+
|
341 |
+
grad_inputs[t] = result;
|
342 |
+
}
|
343 |
+
|
344 |
+
|
345 |
+
template <typename scalar_t, uint32_t D>
|
346 |
+
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
|
347 |
+
static constexpr uint32_t N_THREAD = 512;
|
348 |
+
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
349 |
+
switch (C) {
|
350 |
+
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
351 |
+
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
352 |
+
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
353 |
+
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break;
|
354 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
355 |
+
}
|
356 |
+
}
|
357 |
+
|
358 |
+
// inputs: [B, D], float, in [0, 1]
|
359 |
+
// embeddings: [sO, C], float
|
360 |
+
// offsets: [L + 1], uint32_t
|
361 |
+
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
362 |
+
// H: base resolution
|
363 |
+
// dy_dx: [B, L * D * C]
|
364 |
+
template <typename scalar_t>
|
365 |
+
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
|
366 |
+
switch (D) {
|
367 |
+
case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
|
368 |
+
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
|
369 |
+
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
|
370 |
+
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
|
371 |
+
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
|
372 |
+
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
|
373 |
+
}
|
374 |
+
|
375 |
+
}
|
376 |
+
|
377 |
+
template <typename scalar_t, uint32_t D>
|
378 |
+
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
379 |
+
static constexpr uint32_t N_THREAD = 256;
|
380 |
+
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
|
381 |
+
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
|
382 |
+
switch (C) {
|
383 |
+
case 1:
|
384 |
+
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
385 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
386 |
+
break;
|
387 |
+
case 2:
|
388 |
+
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
389 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
390 |
+
break;
|
391 |
+
case 4:
|
392 |
+
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
393 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
394 |
+
break;
|
395 |
+
case 8:
|
396 |
+
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
|
397 |
+
if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
398 |
+
break;
|
399 |
+
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
400 |
+
}
|
401 |
+
}
|
402 |
+
|
403 |
+
|
404 |
+
// grad: [L, B, C], float
|
405 |
+
// inputs: [B, D], float, in [0, 1]
|
406 |
+
// embeddings: [sO, C], float
|
407 |
+
// offsets: [L + 1], uint32_t
|
408 |
+
// grad_embeddings: [sO, C]
|
409 |
+
// H: base resolution
|
410 |
+
template <typename scalar_t>
|
411 |
+
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
412 |
+
switch (D) {
|
413 |
+
case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
|
414 |
+
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
|
415 |
+
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
|
416 |
+
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
|
417 |
+
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
|
418 |
+
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
|
419 |
+
}
|
420 |
+
}
|
421 |
+
|
422 |
+
|
423 |
+
|
424 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners) {
|
425 |
+
CHECK_CUDA(inputs);
|
426 |
+
CHECK_CUDA(embeddings);
|
427 |
+
CHECK_CUDA(offsets);
|
428 |
+
CHECK_CUDA(outputs);
|
429 |
+
// CHECK_CUDA(dy_dx);
|
430 |
+
|
431 |
+
CHECK_CONTIGUOUS(inputs);
|
432 |
+
CHECK_CONTIGUOUS(embeddings);
|
433 |
+
CHECK_CONTIGUOUS(offsets);
|
434 |
+
CHECK_CONTIGUOUS(outputs);
|
435 |
+
// CHECK_CONTIGUOUS(dy_dx);
|
436 |
+
|
437 |
+
CHECK_IS_FLOATING(inputs);
|
438 |
+
CHECK_IS_FLOATING(embeddings);
|
439 |
+
CHECK_IS_INT(offsets);
|
440 |
+
CHECK_IS_FLOATING(outputs);
|
441 |
+
// CHECK_IS_FLOATING(dy_dx);
|
442 |
+
|
443 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
444 |
+
embeddings.scalar_type(), "grid_encode_forward", ([&] {
|
445 |
+
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
|
446 |
+
}));
|
447 |
+
}
|
448 |
+
|
449 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
450 |
+
CHECK_CUDA(grad);
|
451 |
+
CHECK_CUDA(inputs);
|
452 |
+
CHECK_CUDA(embeddings);
|
453 |
+
CHECK_CUDA(offsets);
|
454 |
+
CHECK_CUDA(grad_embeddings);
|
455 |
+
// CHECK_CUDA(dy_dx);
|
456 |
+
// CHECK_CUDA(grad_inputs);
|
457 |
+
|
458 |
+
CHECK_CONTIGUOUS(grad);
|
459 |
+
CHECK_CONTIGUOUS(inputs);
|
460 |
+
CHECK_CONTIGUOUS(embeddings);
|
461 |
+
CHECK_CONTIGUOUS(offsets);
|
462 |
+
CHECK_CONTIGUOUS(grad_embeddings);
|
463 |
+
// CHECK_CONTIGUOUS(dy_dx);
|
464 |
+
// CHECK_CONTIGUOUS(grad_inputs);
|
465 |
+
|
466 |
+
CHECK_IS_FLOATING(grad);
|
467 |
+
CHECK_IS_FLOATING(inputs);
|
468 |
+
CHECK_IS_FLOATING(embeddings);
|
469 |
+
CHECK_IS_INT(offsets);
|
470 |
+
CHECK_IS_FLOATING(grad_embeddings);
|
471 |
+
// CHECK_IS_FLOATING(dy_dx);
|
472 |
+
// CHECK_IS_FLOATING(grad_inputs);
|
473 |
+
|
474 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
475 |
+
grad.scalar_type(), "grid_encode_backward", ([&] {
|
476 |
+
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners);
|
477 |
+
}));
|
478 |
+
|
479 |
+
}
|
gridencoder/src/gridencoder.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef _HASH_ENCODE_H
|
2 |
+
#define _HASH_ENCODE_H
|
3 |
+
|
4 |
+
#include <stdint.h>
|
5 |
+
#include <torch/torch.h>
|
6 |
+
|
7 |
+
// inputs: [B, D], float, in [0, 1]
|
8 |
+
// embeddings: [sO, C], float
|
9 |
+
// offsets: [L + 1], uint32_t
|
10 |
+
// outputs: [B, L * C], float
|
11 |
+
// H: base resolution
|
12 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners);
|
13 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners);
|
14 |
+
|
15 |
+
#endif
|
loss.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def mape_loss(pred, target):
|
6 |
+
# pred, target: [B, 1], torch tenspr
|
7 |
+
difference = (pred - target).abs()
|
8 |
+
scale = 1 / (target.abs() + 1e-2)
|
9 |
+
loss = difference * scale
|
10 |
+
|
11 |
+
return loss.mean()
|
main_nerf.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
from nerf.provider import NeRFDataset
|
5 |
+
from nerf.utils import *
|
6 |
+
from optimizer import Shampoo
|
7 |
+
|
8 |
+
from nerf.sd import StableDiffusion
|
9 |
+
from nerf.clip import CLIP
|
10 |
+
from nerf.gui import NeRFGUI
|
11 |
+
|
12 |
+
# torch.autograd.set_detect_anomaly(True)
|
13 |
+
|
14 |
+
if __name__ == '__main__':
|
15 |
+
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--text', help="text prompt")
|
18 |
+
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --preload")
|
19 |
+
parser.add_argument('--test', action='store_true', help="test mode")
|
20 |
+
parser.add_argument('--workspace', type=str, default='workspace')
|
21 |
+
parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
|
22 |
+
parser.add_argument('--seed', type=int, default=0)
|
23 |
+
|
24 |
+
### training options
|
25 |
+
parser.add_argument('--iters', type=int, default=15000, help="training iters")
|
26 |
+
parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
|
27 |
+
parser.add_argument('--ckpt', type=str, default='latest')
|
28 |
+
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
|
29 |
+
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
|
30 |
+
parser.add_argument('--num_steps', type=int, default=256, help="num steps sampled per ray (only valid when not using --cuda_ray)")
|
31 |
+
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
|
32 |
+
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
|
33 |
+
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
|
34 |
+
parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters")
|
35 |
+
# model options
|
36 |
+
parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
|
37 |
+
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
|
38 |
+
# network backbone
|
39 |
+
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
|
40 |
+
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
|
41 |
+
# rendering resolution in training
|
42 |
+
parser.add_argument('--w', type=int, default=64, help="render width for CLIP training (<=224)")
|
43 |
+
parser.add_argument('--h', type=int, default=64, help="render height for CLIP training (<=224)")
|
44 |
+
|
45 |
+
### dataset options
|
46 |
+
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
|
47 |
+
parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
|
48 |
+
parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
|
49 |
+
parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
|
50 |
+
parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
|
51 |
+
parser.add_argument('--dir_text', action='store_true', help="direction encoded text prompt")
|
52 |
+
|
53 |
+
### GUI options
|
54 |
+
parser.add_argument('--gui', action='store_true', help="start a GUI")
|
55 |
+
parser.add_argument('--W', type=int, default=800, help="GUI width")
|
56 |
+
parser.add_argument('--H', type=int, default=800, help="GUI height")
|
57 |
+
parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
|
58 |
+
parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
|
59 |
+
parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction")
|
60 |
+
parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction")
|
61 |
+
parser.add_argument('--max_spp', type=int, default=64, help="GUI rendering max sample per pixel")
|
62 |
+
|
63 |
+
opt = parser.parse_args()
|
64 |
+
|
65 |
+
if opt.O:
|
66 |
+
opt.fp16 = True
|
67 |
+
opt.cuda_ray = True
|
68 |
+
opt.dir_text = True
|
69 |
+
|
70 |
+
if opt.backbone == 'vanilla':
|
71 |
+
from nerf.network import NeRFNetwork
|
72 |
+
elif opt.backbone == 'tcnn':
|
73 |
+
from nerf.network_tcnn import NeRFNetwork
|
74 |
+
elif opt.backbone == 'grid':
|
75 |
+
from nerf.network_grid import NeRFNetwork
|
76 |
+
else:
|
77 |
+
raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')
|
78 |
+
|
79 |
+
print(opt)
|
80 |
+
|
81 |
+
seed_everything(opt.seed)
|
82 |
+
|
83 |
+
model = NeRFNetwork(opt)
|
84 |
+
|
85 |
+
print(model)
|
86 |
+
|
87 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
88 |
+
|
89 |
+
if opt.test:
|
90 |
+
guidance = None # do not load guidance at test
|
91 |
+
|
92 |
+
trainer = Trainer('ngp', opt, model, guidance, device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
|
93 |
+
|
94 |
+
if opt.gui:
|
95 |
+
gui = NeRFGUI(opt, trainer)
|
96 |
+
gui.render()
|
97 |
+
|
98 |
+
else:
|
99 |
+
test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
|
100 |
+
trainer.test(test_loader)
|
101 |
+
trainer.save_mesh(resolution=256)
|
102 |
+
|
103 |
+
else:
|
104 |
+
|
105 |
+
if opt.guidance == 'stable-diffusion':
|
106 |
+
guidance = StableDiffusion(device)
|
107 |
+
elif opt.guidance == 'clip':
|
108 |
+
guidance = CLIP(device)
|
109 |
+
else:
|
110 |
+
raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
|
111 |
+
|
112 |
+
optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
|
113 |
+
# optimizer = lambda model: Shampoo(model.get_params(opt.lr))
|
114 |
+
|
115 |
+
train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
|
116 |
+
|
117 |
+
# decay to 0.01 * init_lr at last iter step
|
118 |
+
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.01 ** min(iter / opt.iters, 1))
|
119 |
+
|
120 |
+
trainer = Trainer('ngp', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=1)
|
121 |
+
|
122 |
+
if opt.gui:
|
123 |
+
trainer.train_loader = train_loader # attach dataloader to trainer
|
124 |
+
|
125 |
+
gui = NeRFGUI(opt, trainer)
|
126 |
+
gui.render()
|
127 |
+
|
128 |
+
else:
|
129 |
+
valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
|
130 |
+
|
131 |
+
max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
|
132 |
+
trainer.train(train_loader, valid_loader, max_epoch)
|
133 |
+
|
134 |
+
# also test
|
135 |
+
test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
|
136 |
+
trainer.test(test_loader)
|
137 |
+
trainer.save_mesh(resolution=256)
|
nerf/clip.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import torchvision.transforms as T
|
5 |
+
import torchvision.transforms.functional as TF
|
6 |
+
|
7 |
+
import clip
|
8 |
+
|
9 |
+
class CLIP(nn.Module):
|
10 |
+
def __init__(self, device):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.device = device
|
14 |
+
|
15 |
+
self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
|
16 |
+
|
17 |
+
# image augmentation
|
18 |
+
self.aug = T.Compose([
|
19 |
+
T.Resize((224, 224)),
|
20 |
+
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
21 |
+
])
|
22 |
+
|
23 |
+
# self.gaussian_blur = T.GaussianBlur(15, sigma=(0.1, 10))
|
24 |
+
|
25 |
+
|
26 |
+
def get_text_embeds(self, prompt):
|
27 |
+
|
28 |
+
text = clip.tokenize(prompt).to(self.device)
|
29 |
+
text_z = self.clip_model.encode_text(text)
|
30 |
+
text_z = text_z / text_z.norm(dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
return text_z
|
33 |
+
|
34 |
+
|
35 |
+
def train_step(self, text_z, pred_rgb):
|
36 |
+
|
37 |
+
pred_rgb = self.aug(pred_rgb)
|
38 |
+
|
39 |
+
image_z = self.clip_model.encode_image(pred_rgb)
|
40 |
+
image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
|
41 |
+
|
42 |
+
loss = - (image_z * text_z).sum(-1).mean()
|
43 |
+
|
44 |
+
return loss
|
45 |
+
|
nerf/gui.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import dearpygui.dearpygui as dpg
|
5 |
+
from scipy.spatial.transform import Rotation as R
|
6 |
+
|
7 |
+
from nerf.utils import *
|
8 |
+
|
9 |
+
|
10 |
+
class OrbitCamera:
|
11 |
+
def __init__(self, W, H, r=2, fovy=60):
|
12 |
+
self.W = W
|
13 |
+
self.H = H
|
14 |
+
self.radius = r # camera distance from center
|
15 |
+
self.fovy = fovy # in degree
|
16 |
+
self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
|
17 |
+
self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention)
|
18 |
+
self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
|
19 |
+
|
20 |
+
# pose
|
21 |
+
@property
|
22 |
+
def pose(self):
|
23 |
+
# first move camera to radius
|
24 |
+
res = np.eye(4, dtype=np.float32)
|
25 |
+
res[2, 3] -= self.radius
|
26 |
+
# rotate
|
27 |
+
rot = np.eye(4, dtype=np.float32)
|
28 |
+
rot[:3, :3] = self.rot.as_matrix()
|
29 |
+
res = rot @ res
|
30 |
+
# translate
|
31 |
+
res[:3, 3] -= self.center
|
32 |
+
return res
|
33 |
+
|
34 |
+
# intrinsics
|
35 |
+
@property
|
36 |
+
def intrinsics(self):
|
37 |
+
focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2))
|
38 |
+
return np.array([focal, focal, self.W // 2, self.H // 2])
|
39 |
+
|
40 |
+
def orbit(self, dx, dy):
|
41 |
+
# rotate along camera up/side axis!
|
42 |
+
side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
|
43 |
+
rotvec_x = self.up * np.radians(-0.1 * dx)
|
44 |
+
rotvec_y = side * np.radians(-0.1 * dy)
|
45 |
+
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
|
46 |
+
|
47 |
+
def scale(self, delta):
|
48 |
+
self.radius *= 1.1 ** (-delta)
|
49 |
+
|
50 |
+
def pan(self, dx, dy, dz=0):
|
51 |
+
# pan in camera coordinate system (careful on the sensitivity!)
|
52 |
+
self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])
|
53 |
+
|
54 |
+
|
55 |
+
class NeRFGUI:
|
56 |
+
def __init__(self, opt, trainer, debug=True):
|
57 |
+
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
58 |
+
self.W = opt.W
|
59 |
+
self.H = opt.H
|
60 |
+
self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
|
61 |
+
self.debug = debug
|
62 |
+
self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
|
63 |
+
self.training = False
|
64 |
+
self.step = 0 # training step
|
65 |
+
|
66 |
+
self.trainer = trainer
|
67 |
+
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
|
68 |
+
self.need_update = True # camera moved, should reset accumulation
|
69 |
+
self.spp = 1 # sample per pixel
|
70 |
+
self.light_dir = np.array([opt.light_theta, opt.light_phi])
|
71 |
+
self.ambient_ratio = 1.0
|
72 |
+
self.mode = 'image' # choose from ['image', 'depth']
|
73 |
+
self.shading = 'albedo'
|
74 |
+
|
75 |
+
self.dynamic_resolution = True
|
76 |
+
self.downscale = 1
|
77 |
+
self.train_steps = 16
|
78 |
+
|
79 |
+
dpg.create_context()
|
80 |
+
self.register_dpg()
|
81 |
+
self.test_step()
|
82 |
+
|
83 |
+
|
84 |
+
def __del__(self):
|
85 |
+
dpg.destroy_context()
|
86 |
+
|
87 |
+
|
88 |
+
def train_step(self):
|
89 |
+
|
90 |
+
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
91 |
+
starter.record()
|
92 |
+
|
93 |
+
outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps)
|
94 |
+
|
95 |
+
ender.record()
|
96 |
+
torch.cuda.synchronize()
|
97 |
+
t = starter.elapsed_time(ender)
|
98 |
+
|
99 |
+
self.step += self.train_steps
|
100 |
+
self.need_update = True
|
101 |
+
|
102 |
+
dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
|
103 |
+
dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
|
104 |
+
|
105 |
+
# dynamic train steps
|
106 |
+
# max allowed train time per-frame is 500 ms
|
107 |
+
full_t = t / self.train_steps * 16
|
108 |
+
train_steps = min(16, max(4, int(16 * 500 / full_t)))
|
109 |
+
if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
|
110 |
+
self.train_steps = train_steps
|
111 |
+
|
112 |
+
|
113 |
+
def prepare_buffer(self, outputs):
|
114 |
+
if self.mode == 'image':
|
115 |
+
return outputs['image']
|
116 |
+
else:
|
117 |
+
return np.expand_dims(outputs['depth'], -1).repeat(3, -1)
|
118 |
+
|
119 |
+
|
120 |
+
def test_step(self):
|
121 |
+
|
122 |
+
if self.need_update or self.spp < self.opt.max_spp:
|
123 |
+
|
124 |
+
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
125 |
+
starter.record()
|
126 |
+
|
127 |
+
outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)
|
128 |
+
|
129 |
+
ender.record()
|
130 |
+
torch.cuda.synchronize()
|
131 |
+
t = starter.elapsed_time(ender)
|
132 |
+
|
133 |
+
# update dynamic resolution
|
134 |
+
if self.dynamic_resolution:
|
135 |
+
# max allowed infer time per-frame is 200 ms
|
136 |
+
full_t = t / (self.downscale ** 2)
|
137 |
+
downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
|
138 |
+
if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
|
139 |
+
self.downscale = downscale
|
140 |
+
|
141 |
+
if self.need_update:
|
142 |
+
self.render_buffer = self.prepare_buffer(outputs)
|
143 |
+
self.spp = 1
|
144 |
+
self.need_update = False
|
145 |
+
else:
|
146 |
+
self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
|
147 |
+
self.spp += 1
|
148 |
+
|
149 |
+
dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
|
150 |
+
dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
|
151 |
+
dpg.set_value("_log_spp", self.spp)
|
152 |
+
dpg.set_value("_texture", self.render_buffer)
|
153 |
+
|
154 |
+
|
155 |
+
def register_dpg(self):
|
156 |
+
|
157 |
+
### register texture
|
158 |
+
|
159 |
+
with dpg.texture_registry(show=False):
|
160 |
+
dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
|
161 |
+
|
162 |
+
### register window
|
163 |
+
|
164 |
+
# the rendered image, as the primary window
|
165 |
+
with dpg.window(tag="_primary_window", width=self.W, height=self.H):
|
166 |
+
|
167 |
+
# add the texture
|
168 |
+
dpg.add_image("_texture")
|
169 |
+
|
170 |
+
dpg.set_primary_window("_primary_window", True)
|
171 |
+
|
172 |
+
# control window
|
173 |
+
with dpg.window(label="Control", tag="_control_window", width=400, height=300):
|
174 |
+
|
175 |
+
# text prompt
|
176 |
+
if self.opt.text is not None:
|
177 |
+
dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")
|
178 |
+
|
179 |
+
# button theme
|
180 |
+
with dpg.theme() as theme_button:
|
181 |
+
with dpg.theme_component(dpg.mvButton):
|
182 |
+
dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
|
183 |
+
dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
|
184 |
+
dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
|
185 |
+
dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
|
186 |
+
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
|
187 |
+
|
188 |
+
# time
|
189 |
+
if not self.opt.test:
|
190 |
+
with dpg.group(horizontal=True):
|
191 |
+
dpg.add_text("Train time: ")
|
192 |
+
dpg.add_text("no data", tag="_log_train_time")
|
193 |
+
|
194 |
+
with dpg.group(horizontal=True):
|
195 |
+
dpg.add_text("Infer time: ")
|
196 |
+
dpg.add_text("no data", tag="_log_infer_time")
|
197 |
+
|
198 |
+
with dpg.group(horizontal=True):
|
199 |
+
dpg.add_text("SPP: ")
|
200 |
+
dpg.add_text("1", tag="_log_spp")
|
201 |
+
|
202 |
+
# train button
|
203 |
+
if not self.opt.test:
|
204 |
+
with dpg.collapsing_header(label="Train", default_open=True):
|
205 |
+
with dpg.group(horizontal=True):
|
206 |
+
dpg.add_text("Train: ")
|
207 |
+
|
208 |
+
def callback_train(sender, app_data):
|
209 |
+
if self.training:
|
210 |
+
self.training = False
|
211 |
+
dpg.configure_item("_button_train", label="start")
|
212 |
+
else:
|
213 |
+
self.training = True
|
214 |
+
dpg.configure_item("_button_train", label="stop")
|
215 |
+
|
216 |
+
dpg.add_button(label="start", tag="_button_train", callback=callback_train)
|
217 |
+
dpg.bind_item_theme("_button_train", theme_button)
|
218 |
+
|
219 |
+
def callback_reset(sender, app_data):
|
220 |
+
@torch.no_grad()
|
221 |
+
def weight_reset(m: nn.Module):
|
222 |
+
reset_parameters = getattr(m, "reset_parameters", None)
|
223 |
+
if callable(reset_parameters):
|
224 |
+
m.reset_parameters()
|
225 |
+
self.trainer.model.apply(fn=weight_reset)
|
226 |
+
self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
|
227 |
+
self.need_update = True
|
228 |
+
|
229 |
+
dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
|
230 |
+
dpg.bind_item_theme("_button_reset", theme_button)
|
231 |
+
|
232 |
+
|
233 |
+
with dpg.group(horizontal=True):
|
234 |
+
dpg.add_text("Checkpoint: ")
|
235 |
+
|
236 |
+
def callback_save(sender, app_data):
|
237 |
+
self.trainer.save_checkpoint(full=True, best=False)
|
238 |
+
dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
|
239 |
+
self.trainer.epoch += 1 # use epoch to indicate different calls.
|
240 |
+
|
241 |
+
dpg.add_button(label="save", tag="_button_save", callback=callback_save)
|
242 |
+
dpg.bind_item_theme("_button_save", theme_button)
|
243 |
+
|
244 |
+
dpg.add_text("", tag="_log_ckpt")
|
245 |
+
|
246 |
+
# save mesh
|
247 |
+
with dpg.group(horizontal=True):
|
248 |
+
dpg.add_text("Marching Cubes: ")
|
249 |
+
|
250 |
+
def callback_mesh(sender, app_data):
|
251 |
+
self.trainer.save_mesh(resolution=256, threshold=10)
|
252 |
+
dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
|
253 |
+
self.trainer.epoch += 1 # use epoch to indicate different calls.
|
254 |
+
|
255 |
+
dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
|
256 |
+
dpg.bind_item_theme("_button_mesh", theme_button)
|
257 |
+
|
258 |
+
dpg.add_text("", tag="_log_mesh")
|
259 |
+
|
260 |
+
with dpg.group(horizontal=True):
|
261 |
+
dpg.add_text("", tag="_log_train_log")
|
262 |
+
|
263 |
+
|
264 |
+
# rendering options
|
265 |
+
with dpg.collapsing_header(label="Options", default_open=True):
|
266 |
+
|
267 |
+
# dynamic rendering resolution
|
268 |
+
with dpg.group(horizontal=True):
|
269 |
+
|
270 |
+
def callback_set_dynamic_resolution(sender, app_data):
|
271 |
+
if self.dynamic_resolution:
|
272 |
+
self.dynamic_resolution = False
|
273 |
+
self.downscale = 1
|
274 |
+
else:
|
275 |
+
self.dynamic_resolution = True
|
276 |
+
self.need_update = True
|
277 |
+
|
278 |
+
dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
|
279 |
+
dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
|
280 |
+
|
281 |
+
# mode combo
|
282 |
+
def callback_change_mode(sender, app_data):
|
283 |
+
self.mode = app_data
|
284 |
+
self.need_update = True
|
285 |
+
|
286 |
+
dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
|
287 |
+
|
288 |
+
# bg_color picker
|
289 |
+
def callback_change_bg(sender, app_data):
|
290 |
+
self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
|
291 |
+
self.need_update = True
|
292 |
+
|
293 |
+
dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
|
294 |
+
|
295 |
+
# fov slider
|
296 |
+
def callback_set_fovy(sender, app_data):
|
297 |
+
self.cam.fovy = app_data
|
298 |
+
self.need_update = True
|
299 |
+
|
300 |
+
dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
|
301 |
+
|
302 |
+
# dt_gamma slider
|
303 |
+
def callback_set_dt_gamma(sender, app_data):
|
304 |
+
self.opt.dt_gamma = app_data
|
305 |
+
self.need_update = True
|
306 |
+
|
307 |
+
dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
|
308 |
+
|
309 |
+
# max_steps slider
|
310 |
+
def callback_set_max_steps(sender, app_data):
|
311 |
+
self.opt.max_steps = app_data
|
312 |
+
self.need_update = True
|
313 |
+
|
314 |
+
dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
|
315 |
+
|
316 |
+
# aabb slider
|
317 |
+
def callback_set_aabb(sender, app_data, user_data):
|
318 |
+
# user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
|
319 |
+
self.trainer.model.aabb_infer[user_data] = app_data
|
320 |
+
|
321 |
+
# also change train aabb ? [better not...]
|
322 |
+
#self.trainer.model.aabb_train[user_data] = app_data
|
323 |
+
|
324 |
+
self.need_update = True
|
325 |
+
|
326 |
+
dpg.add_separator()
|
327 |
+
dpg.add_text("Axis-aligned bounding box:")
|
328 |
+
|
329 |
+
with dpg.group(horizontal=True):
|
330 |
+
dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
|
331 |
+
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
|
332 |
+
|
333 |
+
with dpg.group(horizontal=True):
|
334 |
+
dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
|
335 |
+
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
|
336 |
+
|
337 |
+
with dpg.group(horizontal=True):
|
338 |
+
dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
|
339 |
+
dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
|
340 |
+
|
341 |
+
# light dir
|
342 |
+
def callback_set_light_dir(sender, app_data, user_data):
|
343 |
+
self.light_dir[user_data] = app_data
|
344 |
+
self.need_update = True
|
345 |
+
|
346 |
+
dpg.add_separator()
|
347 |
+
dpg.add_text("Plane Light Direction:")
|
348 |
+
|
349 |
+
with dpg.group(horizontal=True):
|
350 |
+
dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)
|
351 |
+
|
352 |
+
with dpg.group(horizontal=True):
|
353 |
+
dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)
|
354 |
+
|
355 |
+
# ambient ratio
|
356 |
+
def callback_set_abm_ratio(sender, app_data):
|
357 |
+
self.ambient_ratio = app_data
|
358 |
+
self.need_update = True
|
359 |
+
|
360 |
+
dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)
|
361 |
+
|
362 |
+
# shading mode
|
363 |
+
def callback_change_shading(sender, app_data):
|
364 |
+
self.shading = app_data
|
365 |
+
self.need_update = True
|
366 |
+
|
367 |
+
dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)
|
368 |
+
|
369 |
+
|
370 |
+
# debug info
|
371 |
+
if self.debug:
|
372 |
+
with dpg.collapsing_header(label="Debug"):
|
373 |
+
# pose
|
374 |
+
dpg.add_separator()
|
375 |
+
dpg.add_text("Camera Pose:")
|
376 |
+
dpg.add_text(str(self.cam.pose), tag="_log_pose")
|
377 |
+
|
378 |
+
|
379 |
+
### register camera handler
|
380 |
+
|
381 |
+
def callback_camera_drag_rotate(sender, app_data):
|
382 |
+
|
383 |
+
if not dpg.is_item_focused("_primary_window"):
|
384 |
+
return
|
385 |
+
|
386 |
+
dx = app_data[1]
|
387 |
+
dy = app_data[2]
|
388 |
+
|
389 |
+
self.cam.orbit(dx, dy)
|
390 |
+
self.need_update = True
|
391 |
+
|
392 |
+
if self.debug:
|
393 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
394 |
+
|
395 |
+
|
396 |
+
def callback_camera_wheel_scale(sender, app_data):
|
397 |
+
|
398 |
+
if not dpg.is_item_focused("_primary_window"):
|
399 |
+
return
|
400 |
+
|
401 |
+
delta = app_data
|
402 |
+
|
403 |
+
self.cam.scale(delta)
|
404 |
+
self.need_update = True
|
405 |
+
|
406 |
+
if self.debug:
|
407 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
408 |
+
|
409 |
+
|
410 |
+
def callback_camera_drag_pan(sender, app_data):
|
411 |
+
|
412 |
+
if not dpg.is_item_focused("_primary_window"):
|
413 |
+
return
|
414 |
+
|
415 |
+
dx = app_data[1]
|
416 |
+
dy = app_data[2]
|
417 |
+
|
418 |
+
self.cam.pan(dx, dy)
|
419 |
+
self.need_update = True
|
420 |
+
|
421 |
+
if self.debug:
|
422 |
+
dpg.set_value("_log_pose", str(self.cam.pose))
|
423 |
+
|
424 |
+
|
425 |
+
with dpg.handler_registry():
|
426 |
+
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
|
427 |
+
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
|
428 |
+
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan)
|
429 |
+
|
430 |
+
|
431 |
+
dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)
|
432 |
+
|
433 |
+
# TODO: seems dearpygui doesn't support resizing texture...
|
434 |
+
# def callback_resize(sender, app_data):
|
435 |
+
# self.W = app_data[0]
|
436 |
+
# self.H = app_data[1]
|
437 |
+
# # how to reload texture ???
|
438 |
+
|
439 |
+
# dpg.set_viewport_resize_callback(callback_resize)
|
440 |
+
|
441 |
+
### global theme
|
442 |
+
with dpg.theme() as theme_no_padding:
|
443 |
+
with dpg.theme_component(dpg.mvAll):
|
444 |
+
# set all padding to 0 to avoid scroll bar
|
445 |
+
dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
|
446 |
+
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
|
447 |
+
dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
|
448 |
+
|
449 |
+
dpg.bind_item_theme("_primary_window", theme_no_padding)
|
450 |
+
|
451 |
+
dpg.setup_dearpygui()
|
452 |
+
|
453 |
+
#dpg.show_metrics()
|
454 |
+
|
455 |
+
dpg.show_viewport()
|
456 |
+
|
457 |
+
|
458 |
+
def render(self):
|
459 |
+
|
460 |
+
while dpg.is_dearpygui_running():
|
461 |
+
# update texture every frame
|
462 |
+
if self.training:
|
463 |
+
self.train_step()
|
464 |
+
self.test_step()
|
465 |
+
dpg.render_dearpygui_frame()
|
nerf/network.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from activation import trunc_exp
|
6 |
+
from .renderer import NeRFRenderer
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from encoding import get_encoder
|
10 |
+
|
11 |
+
from .utils import safe_normalize
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
|
15 |
+
super().__init__()
|
16 |
+
self.dim_in = dim_in
|
17 |
+
self.dim_out = dim_out
|
18 |
+
self.dim_hidden = dim_hidden
|
19 |
+
self.num_layers = num_layers
|
20 |
+
|
21 |
+
net = []
|
22 |
+
for l in range(num_layers):
|
23 |
+
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
|
24 |
+
|
25 |
+
self.net = nn.ModuleList(net)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
for l in range(self.num_layers):
|
29 |
+
x = self.net[l](x)
|
30 |
+
if l != self.num_layers - 1:
|
31 |
+
x = F.relu(x, inplace=True)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class NeRFNetwork(NeRFRenderer):
|
36 |
+
def __init__(self,
|
37 |
+
opt,
|
38 |
+
num_layers=5,
|
39 |
+
hidden_dim=128,
|
40 |
+
num_layers_bg=3,
|
41 |
+
hidden_dim_bg=128,
|
42 |
+
):
|
43 |
+
|
44 |
+
super().__init__(opt)
|
45 |
+
|
46 |
+
self.num_layers = num_layers
|
47 |
+
self.hidden_dim = hidden_dim
|
48 |
+
|
49 |
+
self.encoder, self.in_dim = get_encoder('frequency', input_dim=3)
|
50 |
+
|
51 |
+
self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
|
52 |
+
|
53 |
+
# background network
|
54 |
+
if self.bg_radius > 0:
|
55 |
+
self.num_layers_bg = num_layers_bg
|
56 |
+
self.hidden_dim_bg = hidden_dim_bg
|
57 |
+
|
58 |
+
self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2)
|
59 |
+
|
60 |
+
self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
|
61 |
+
|
62 |
+
else:
|
63 |
+
self.bg_net = None
|
64 |
+
|
65 |
+
def gaussian(self, x):
|
66 |
+
# x: [B, N, 3]
|
67 |
+
|
68 |
+
d = (x ** 2).sum(-1)
|
69 |
+
g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
|
70 |
+
|
71 |
+
return g
|
72 |
+
|
73 |
+
def common_forward(self, x):
|
74 |
+
# x: [N, 3], in [-bound, bound]
|
75 |
+
|
76 |
+
# sigma
|
77 |
+
h = self.encoder(x, bound=self.bound)
|
78 |
+
|
79 |
+
h = self.sigma_net(h)
|
80 |
+
|
81 |
+
sigma = trunc_exp(h[..., 0] + self.gaussian(x))
|
82 |
+
albedo = torch.sigmoid(h[..., 1:])
|
83 |
+
|
84 |
+
return sigma, albedo
|
85 |
+
|
86 |
+
# ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
|
87 |
+
def finite_differnce_normal(self, x, epsilon=5e-4):
|
88 |
+
# x: [N, 3]
|
89 |
+
dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
90 |
+
dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
91 |
+
dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
92 |
+
dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
93 |
+
dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
|
94 |
+
dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
|
95 |
+
|
96 |
+
normal = torch.stack([
|
97 |
+
0.5 * (dx_pos - dx_neg) / epsilon,
|
98 |
+
0.5 * (dy_pos - dy_neg) / epsilon,
|
99 |
+
0.5 * (dz_pos - dz_neg) / epsilon
|
100 |
+
], dim=-1)
|
101 |
+
|
102 |
+
return normal
|
103 |
+
|
104 |
+
def forward(self, x, d, l=None, ratio=1, shading='albedo'):
|
105 |
+
# x: [N, 3], in [-bound, bound]
|
106 |
+
# d: [N, 3], view direction, nomalized in [-1, 1]
|
107 |
+
# l: [3], plane light direction, nomalized in [-1, 1]
|
108 |
+
# ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
|
109 |
+
|
110 |
+
if shading == 'albedo':
|
111 |
+
# no need to query normal
|
112 |
+
sigma, color = self.common_forward(x)
|
113 |
+
normal = None
|
114 |
+
|
115 |
+
else:
|
116 |
+
# query normal
|
117 |
+
|
118 |
+
# sigma, albedo = self.common_forward(x)
|
119 |
+
# normal = self.finite_differnce_normal(x)
|
120 |
+
|
121 |
+
with torch.enable_grad():
|
122 |
+
x.requires_grad_(True)
|
123 |
+
sigma, albedo = self.common_forward(x)
|
124 |
+
# query gradient
|
125 |
+
normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
|
126 |
+
|
127 |
+
# normalize...
|
128 |
+
normal = safe_normalize(normal)
|
129 |
+
normal[torch.isnan(normal)] = 0
|
130 |
+
|
131 |
+
# light direction (random if not provided)
|
132 |
+
if l is None:
|
133 |
+
l = torch.randn(3, device=x.device, dtype=torch.float)
|
134 |
+
l = safe_normalize(l)
|
135 |
+
|
136 |
+
# lambertian shading
|
137 |
+
lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
|
138 |
+
|
139 |
+
if shading == 'textureless':
|
140 |
+
color = lambertian.unsqueeze(-1).repeat(1, 3)
|
141 |
+
elif shading == 'normal':
|
142 |
+
color = (normal + 1) / 2
|
143 |
+
else: # 'lambertian'
|
144 |
+
color = albedo * lambertian.unsqueeze(-1)
|
145 |
+
|
146 |
+
return sigma, color, normal
|
147 |
+
|
148 |
+
|
149 |
+
def density(self, x):
|
150 |
+
# x: [N, 3], in [-bound, bound]
|
151 |
+
|
152 |
+
sigma, albedo = self.common_forward(x)
|
153 |
+
|
154 |
+
return {
|
155 |
+
'sigma': sigma,
|
156 |
+
'albedo': albedo,
|
157 |
+
}
|
158 |
+
|
159 |
+
|
160 |
+
def background(self, x, d):
|
161 |
+
# x: [N, 2], in [-1, 1]
|
162 |
+
|
163 |
+
h = self.encoder_bg(x) # [N, C]
|
164 |
+
|
165 |
+
h = self.bg_net(h)
|
166 |
+
|
167 |
+
# sigmoid activation for rgb
|
168 |
+
rgbs = torch.sigmoid(h)
|
169 |
+
|
170 |
+
return rgbs
|
171 |
+
|
172 |
+
# optimizer utils
|
173 |
+
def get_params(self, lr):
|
174 |
+
|
175 |
+
params = [
|
176 |
+
# {'params': self.encoder.parameters(), 'lr': lr * 10},
|
177 |
+
{'params': self.sigma_net.parameters(), 'lr': lr},
|
178 |
+
]
|
179 |
+
|
180 |
+
if self.bg_radius > 0:
|
181 |
+
# params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
|
182 |
+
params.append({'params': self.bg_net.parameters(), 'lr': lr})
|
183 |
+
|
184 |
+
return params
|
nerf/network_grid.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from activation import trunc_exp
|
6 |
+
from .renderer import NeRFRenderer
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from encoding import get_encoder
|
10 |
+
|
11 |
+
from .utils import safe_normalize
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
|
15 |
+
super().__init__()
|
16 |
+
self.dim_in = dim_in
|
17 |
+
self.dim_out = dim_out
|
18 |
+
self.dim_hidden = dim_hidden
|
19 |
+
self.num_layers = num_layers
|
20 |
+
|
21 |
+
net = []
|
22 |
+
for l in range(num_layers):
|
23 |
+
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
|
24 |
+
|
25 |
+
self.net = nn.ModuleList(net)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
for l in range(self.num_layers):
|
29 |
+
x = self.net[l](x)
|
30 |
+
if l != self.num_layers - 1:
|
31 |
+
x = F.relu(x, inplace=True)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class NeRFNetwork(NeRFRenderer):
|
36 |
+
def __init__(self,
|
37 |
+
opt,
|
38 |
+
num_layers=3,
|
39 |
+
hidden_dim=64,
|
40 |
+
num_layers_bg=2,
|
41 |
+
hidden_dim_bg=64,
|
42 |
+
):
|
43 |
+
|
44 |
+
super().__init__(opt)
|
45 |
+
|
46 |
+
self.num_layers = num_layers
|
47 |
+
self.hidden_dim = hidden_dim
|
48 |
+
|
49 |
+
self.encoder, self.in_dim = get_encoder('tiledgrid', input_dim=3, desired_resolution=2048 * self.bound)
|
50 |
+
|
51 |
+
self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
|
52 |
+
|
53 |
+
# background network
|
54 |
+
if self.bg_radius > 0:
|
55 |
+
self.num_layers_bg = num_layers_bg
|
56 |
+
self.hidden_dim_bg = hidden_dim_bg
|
57 |
+
|
58 |
+
# use a very simple network to avoid it learning the prompt...
|
59 |
+
# self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048)
|
60 |
+
self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=2)
|
61 |
+
|
62 |
+
self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
|
63 |
+
|
64 |
+
else:
|
65 |
+
self.bg_net = None
|
66 |
+
|
67 |
+
def gaussian(self, x):
|
68 |
+
# x: [B, N, 3]
|
69 |
+
|
70 |
+
d = (x ** 2).sum(-1)
|
71 |
+
g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
|
72 |
+
|
73 |
+
return g
|
74 |
+
|
75 |
+
def common_forward(self, x):
|
76 |
+
# x: [N, 3], in [-bound, bound]
|
77 |
+
|
78 |
+
# sigma
|
79 |
+
h = self.encoder(x, bound=self.bound)
|
80 |
+
|
81 |
+
h = self.sigma_net(h)
|
82 |
+
|
83 |
+
sigma = trunc_exp(h[..., 0] + self.gaussian(x))
|
84 |
+
albedo = torch.sigmoid(h[..., 1:])
|
85 |
+
|
86 |
+
return sigma, albedo
|
87 |
+
|
88 |
+
# ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
|
89 |
+
def finite_differnce_normal(self, x, epsilon=5e-4):
|
90 |
+
# x: [N, 3]
|
91 |
+
dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
92 |
+
dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
93 |
+
dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
94 |
+
dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
|
95 |
+
dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
|
96 |
+
dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
|
97 |
+
|
98 |
+
normal = torch.stack([
|
99 |
+
0.5 * (dx_pos - dx_neg) / epsilon,
|
100 |
+
0.5 * (dy_pos - dy_neg) / epsilon,
|
101 |
+
0.5 * (dz_pos - dz_neg) / epsilon
|
102 |
+
], dim=-1)
|
103 |
+
|
104 |
+
return normal
|
105 |
+
|
106 |
+
def forward(self, x, d, l=None, ratio=1, shading='albedo'):
|
107 |
+
# x: [N, 3], in [-bound, bound]
|
108 |
+
# d: [N, 3], view direction, nomalized in [-1, 1]
|
109 |
+
# l: [3], plane light direction, nomalized in [-1, 1]
|
110 |
+
# ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
|
111 |
+
|
112 |
+
if shading == 'albedo':
|
113 |
+
# no need to query normal
|
114 |
+
sigma, color = self.common_forward(x)
|
115 |
+
normal = None
|
116 |
+
|
117 |
+
else:
|
118 |
+
# query normal
|
119 |
+
|
120 |
+
sigma, albedo = self.common_forward(x)
|
121 |
+
normal = self.finite_differnce_normal(x)
|
122 |
+
|
123 |
+
# with torch.enable_grad():
|
124 |
+
# x.requires_grad_(True)
|
125 |
+
# sigma, albedo = self.common_forward(x)
|
126 |
+
# # query gradient
|
127 |
+
# normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
|
128 |
+
|
129 |
+
# normalize...
|
130 |
+
normal = safe_normalize(normal)
|
131 |
+
normal[torch.isnan(normal)] = 0
|
132 |
+
|
133 |
+
# light direction (random if not provided)
|
134 |
+
if l is None:
|
135 |
+
l = torch.randn(3, device=x.device, dtype=torch.float)
|
136 |
+
l = safe_normalize(l)
|
137 |
+
|
138 |
+
# lambertian shading
|
139 |
+
lambertian = ratio + (1 - ratio) * (normal @ -l).clamp(min=0) # [N,]
|
140 |
+
|
141 |
+
if shading == 'textureless':
|
142 |
+
color = lambertian.unsqueeze(-1).repeat(1, 3)
|
143 |
+
elif shading == 'normal':
|
144 |
+
color = (normal + 1) / 2
|
145 |
+
else: # 'lambertian'
|
146 |
+
color = albedo * lambertian.unsqueeze(-1)
|
147 |
+
|
148 |
+
return sigma, color, normal
|
149 |
+
|
150 |
+
|
151 |
+
def density(self, x):
|
152 |
+
# x: [N, 3], in [-bound, bound]
|
153 |
+
|
154 |
+
sigma, albedo = self.common_forward(x)
|
155 |
+
|
156 |
+
return {
|
157 |
+
'sigma': sigma,
|
158 |
+
'albedo': albedo,
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
def background(self, x, d):
|
163 |
+
# x: [N, 2], in [-1, 1]
|
164 |
+
|
165 |
+
h = self.encoder_bg(x) # [N, C]
|
166 |
+
|
167 |
+
h = self.bg_net(h)
|
168 |
+
|
169 |
+
# sigmoid activation for rgb
|
170 |
+
rgbs = torch.sigmoid(h)
|
171 |
+
|
172 |
+
return rgbs
|
173 |
+
|
174 |
+
# optimizer utils
|
175 |
+
def get_params(self, lr):
|
176 |
+
|
177 |
+
params = [
|
178 |
+
{'params': self.encoder.parameters(), 'lr': lr * 10},
|
179 |
+
{'params': self.sigma_net.parameters(), 'lr': lr},
|
180 |
+
]
|
181 |
+
|
182 |
+
if self.bg_radius > 0:
|
183 |
+
params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
|
184 |
+
params.append({'params': self.bg_net.parameters(), 'lr': lr})
|
185 |
+
|
186 |
+
return params
|
nerf/network_tcnn.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from activation import trunc_exp
|
6 |
+
from .renderer import NeRFRenderer
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import tinycudann as tcnn
|
10 |
+
|
11 |
+
class MLP(nn.Module):
|
12 |
+
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
|
13 |
+
super().__init__()
|
14 |
+
self.dim_in = dim_in
|
15 |
+
self.dim_out = dim_out
|
16 |
+
self.dim_hidden = dim_hidden
|
17 |
+
self.num_layers = num_layers
|
18 |
+
|
19 |
+
net = []
|
20 |
+
for l in range(num_layers):
|
21 |
+
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
|
22 |
+
|
23 |
+
self.net = nn.ModuleList(net)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
for l in range(self.num_layers):
|
27 |
+
x = self.net[l](x)
|
28 |
+
if l != self.num_layers - 1:
|
29 |
+
x = F.relu(x, inplace=True)
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
class NeRFNetwork(NeRFRenderer):
|
34 |
+
def __init__(self,
|
35 |
+
opt,
|
36 |
+
num_layers=3,
|
37 |
+
hidden_dim=64,
|
38 |
+
num_layers_bg=2,
|
39 |
+
hidden_dim_bg=64,
|
40 |
+
):
|
41 |
+
|
42 |
+
super().__init__(opt)
|
43 |
+
|
44 |
+
self.num_layers = num_layers
|
45 |
+
self.hidden_dim = hidden_dim
|
46 |
+
|
47 |
+
per_level_scale = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1))
|
48 |
+
|
49 |
+
self.encoder = tcnn.Encoding(
|
50 |
+
n_input_dims=3,
|
51 |
+
encoding_config={
|
52 |
+
"otype": "HashGrid",
|
53 |
+
"n_levels": 16,
|
54 |
+
"n_features_per_level": 2,
|
55 |
+
"log2_hashmap_size": 19,
|
56 |
+
"base_resolution": 16,
|
57 |
+
"per_level_scale": per_level_scale,
|
58 |
+
},
|
59 |
+
)
|
60 |
+
|
61 |
+
self.sigma_net = MLP(32, 4, hidden_dim, num_layers, bias=True)
|
62 |
+
|
63 |
+
# background network
|
64 |
+
if self.bg_radius > 0:
|
65 |
+
self.num_layers_bg = num_layers_bg
|
66 |
+
self.hidden_dim_bg = hidden_dim_bg
|
67 |
+
|
68 |
+
self.encoder_bg = tcnn.Encoding(
|
69 |
+
n_input_dims=2,
|
70 |
+
encoding_config={
|
71 |
+
"otype": "HashGrid",
|
72 |
+
"n_levels": 4,
|
73 |
+
"n_features_per_level": 2,
|
74 |
+
"log2_hashmap_size": 16,
|
75 |
+
"base_resolution": 16,
|
76 |
+
"per_level_scale": 1.5,
|
77 |
+
},
|
78 |
+
)
|
79 |
+
|
80 |
+
self.bg_net = MLP(8, 3, hidden_dim_bg, num_layers_bg, bias=True)
|
81 |
+
|
82 |
+
else:
|
83 |
+
self.bg_net = None
|
84 |
+
|
85 |
+
def gaussian(self, x):
|
86 |
+
# x: [B, N, 3]
|
87 |
+
|
88 |
+
d = (x ** 2).sum(-1)
|
89 |
+
g = 5 * torch.exp(-d / (2 * 0.2 ** 2))
|
90 |
+
|
91 |
+
return g
|
92 |
+
|
93 |
+
def common_forward(self, x):
|
94 |
+
# x: [N, 3], in [-bound, bound]
|
95 |
+
|
96 |
+
# sigma
|
97 |
+
h = (x + self.bound) / (2 * self.bound) # to [0, 1]
|
98 |
+
h = self.encoder(h)
|
99 |
+
|
100 |
+
h = self.sigma_net(h)
|
101 |
+
|
102 |
+
sigma = trunc_exp(h[..., 0] + self.gaussian(x))
|
103 |
+
albedo = torch.sigmoid(h[..., 1:])
|
104 |
+
|
105 |
+
return sigma, albedo
|
106 |
+
|
107 |
+
|
108 |
+
def forward(self, x, d, l=None, ratio=1, shading='albedo'):
|
109 |
+
# x: [N, 3], in [-bound, bound]
|
110 |
+
# d: [N, 3], view direction, nomalized in [-1, 1]
|
111 |
+
# l: [3], plane light direction, nomalized in [-1, 1]
|
112 |
+
# ratio: scalar, ambient ratio, 1 == no shading (albedo only)
|
113 |
+
|
114 |
+
if shading == 'albedo':
|
115 |
+
# no need to query normal
|
116 |
+
sigma, color = self.common_forward(x)
|
117 |
+
normal = None
|
118 |
+
|
119 |
+
else:
|
120 |
+
# query normal
|
121 |
+
has_grad = torch.is_grad_enabled()
|
122 |
+
|
123 |
+
with torch.enable_grad():
|
124 |
+
x.requires_grad_(True)
|
125 |
+
sigma, albedo = self.common_forward(x)
|
126 |
+
# query gradient
|
127 |
+
normal = torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
|
128 |
+
|
129 |
+
# normalize...
|
130 |
+
normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-9)
|
131 |
+
normal[torch.isnan(normal)] = 0
|
132 |
+
|
133 |
+
if not has_grad:
|
134 |
+
normal = normal.detach()
|
135 |
+
|
136 |
+
# light direction (random if not provided)
|
137 |
+
if l is None:
|
138 |
+
l = torch.randn(3, device=x.device, dtype=torch.float)
|
139 |
+
l = l / (torch.norm(l, dim=-1, keepdim=True) + 1e-9)
|
140 |
+
|
141 |
+
# lambertian shading
|
142 |
+
lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,]
|
143 |
+
|
144 |
+
if shading == 'textureless':
|
145 |
+
color = lambertian.unsqueeze(-1).repeat(1, 3)
|
146 |
+
elif shading == 'normal':
|
147 |
+
color = (normal + 1) / 2
|
148 |
+
else: # 'lambertian'
|
149 |
+
color = albedo * lambertian.unsqueeze(-1)
|
150 |
+
|
151 |
+
return sigma, color, normal
|
152 |
+
|
153 |
+
|
154 |
+
def density(self, x):
|
155 |
+
# x: [N, 3], in [-bound, bound]
|
156 |
+
|
157 |
+
sigma, _ = self.common_forward(x)
|
158 |
+
|
159 |
+
return {
|
160 |
+
'sigma': sigma
|
161 |
+
}
|
162 |
+
|
163 |
+
|
164 |
+
def background(self, x, d):
|
165 |
+
# x: [N, 2], in [-1, 1]
|
166 |
+
|
167 |
+
h = (x + 1) / (2 * 1) # to [0, 1]
|
168 |
+
h = self.encoder_bg(h) # [N, C]
|
169 |
+
|
170 |
+
h = self.bg_net(h)
|
171 |
+
|
172 |
+
# sigmoid activation for rgb
|
173 |
+
rgbs = torch.sigmoid(h)
|
174 |
+
|
175 |
+
return rgbs
|
176 |
+
|
177 |
+
# optimizer utils
|
178 |
+
def get_params(self, lr):
|
179 |
+
|
180 |
+
params = [
|
181 |
+
{'params': self.encoder.parameters(), 'lr': lr * 10},
|
182 |
+
{'params': self.sigma_net.parameters(), 'lr': lr},
|
183 |
+
]
|
184 |
+
|
185 |
+
if self.bg_radius > 0:
|
186 |
+
params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
|
187 |
+
params.append({'params': self.bg_net.parameters(), 'lr': lr})
|
188 |
+
|
189 |
+
return params
|
nerf/provider.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import json
|
5 |
+
import tqdm
|
6 |
+
import random
|
7 |
+
import numpy as np
|
8 |
+
from scipy.spatial.transform import Slerp, Rotation
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
from .utils import get_rays, safe_normalize
|
16 |
+
|
17 |
+
def visualize_poses(poses, size=0.1):
|
18 |
+
# poses: [B, 4, 4]
|
19 |
+
|
20 |
+
axes = trimesh.creation.axis(axis_length=4)
|
21 |
+
sphere = trimesh.creation.icosphere(radius=1)
|
22 |
+
objects = [axes, sphere]
|
23 |
+
|
24 |
+
for pose in poses:
|
25 |
+
# a camera is visualized with 8 line segments.
|
26 |
+
pos = pose[:3, 3]
|
27 |
+
a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
|
28 |
+
b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
|
29 |
+
c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
|
30 |
+
d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
|
31 |
+
|
32 |
+
segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
|
33 |
+
segs = trimesh.load_path(segs)
|
34 |
+
objects.append(segs)
|
35 |
+
|
36 |
+
trimesh.Scene(objects).show()
|
37 |
+
|
38 |
+
def get_view_direction(thetas, phis):
|
39 |
+
# phis [B,]; thetas: [B,]
|
40 |
+
# front = 0 0-90
|
41 |
+
# side (left) = 1 90-180
|
42 |
+
# back = 2 180-270
|
43 |
+
# side (right) = 3 270-360
|
44 |
+
# top = 4 0-30
|
45 |
+
# bottom = 5 150-180
|
46 |
+
res = torch.zeros(phis.shape[0], dtype=torch.long)
|
47 |
+
# first determine by phis
|
48 |
+
res[(phis < (np.pi / 2))] = 0
|
49 |
+
res[(phis >= (np.pi / 2)) & (phis < np.pi)] = 1
|
50 |
+
res[(phis >= np.pi) & (phis < (3 * np.pi / 2))] = 2
|
51 |
+
res[(phis >= (3 * np.pi / 2)) & (phis < (2 * np.pi))] = 3
|
52 |
+
# override by thetas
|
53 |
+
res[thetas < (np.pi / 6)] = 4
|
54 |
+
res[thetas >= (5 * np.pi / 6)] = 5
|
55 |
+
return res
|
56 |
+
|
57 |
+
|
58 |
+
def rand_poses(size, device, return_dirs=False, radius_range=[1, 1.5], theta_range=[0, 4 * np.pi / 6], phi_range=[0, 2*np.pi]):
|
59 |
+
''' generate random poses from an orbit camera
|
60 |
+
Args:
|
61 |
+
size: batch size of generated poses.
|
62 |
+
device: where to allocate the output.
|
63 |
+
radius: camera radius
|
64 |
+
theta_range: [min, max], should be in [0, \pi]
|
65 |
+
phi_range: [min, max], should be in [0, 2\pi]
|
66 |
+
Return:
|
67 |
+
poses: [size, 4, 4]
|
68 |
+
'''
|
69 |
+
|
70 |
+
radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
|
71 |
+
thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
|
72 |
+
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
|
73 |
+
|
74 |
+
centers = torch.stack([
|
75 |
+
radius * torch.sin(thetas) * torch.sin(phis),
|
76 |
+
radius * torch.cos(thetas),
|
77 |
+
radius * torch.sin(thetas) * torch.cos(phis),
|
78 |
+
], dim=-1) # [B, 3]
|
79 |
+
|
80 |
+
# jitters
|
81 |
+
centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
|
82 |
+
targets = torch.randn_like(centers) * 0.2
|
83 |
+
|
84 |
+
# lookat
|
85 |
+
forward_vector = safe_normalize(targets - centers)
|
86 |
+
up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
|
87 |
+
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
88 |
+
|
89 |
+
up_noise = torch.randn_like(up_vector) * 0.02
|
90 |
+
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
|
91 |
+
|
92 |
+
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
|
93 |
+
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
|
94 |
+
poses[:, :3, 3] = centers
|
95 |
+
|
96 |
+
if return_dirs:
|
97 |
+
dirs = get_view_direction(thetas, phis)
|
98 |
+
else:
|
99 |
+
dirs = None
|
100 |
+
|
101 |
+
return poses, dirs
|
102 |
+
|
103 |
+
|
104 |
+
def circle_poses(device, return_dirs=False, radius=1.25, theta=np.pi/2, phi=0):
|
105 |
+
|
106 |
+
thetas = torch.FloatTensor([theta]).to(device)
|
107 |
+
phis = torch.FloatTensor([phi]).to(device)
|
108 |
+
|
109 |
+
centers = torch.stack([
|
110 |
+
radius * torch.sin(thetas) * torch.sin(phis),
|
111 |
+
radius * torch.cos(thetas),
|
112 |
+
radius * torch.sin(thetas) * torch.cos(phis),
|
113 |
+
], dim=-1) # [B, 3]
|
114 |
+
|
115 |
+
# lookat
|
116 |
+
forward_vector = - safe_normalize(centers)
|
117 |
+
up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0)
|
118 |
+
right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
|
119 |
+
up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
|
120 |
+
|
121 |
+
poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0)
|
122 |
+
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
|
123 |
+
poses[:, :3, 3] = centers
|
124 |
+
|
125 |
+
if return_dirs:
|
126 |
+
dirs = get_view_direction(thetas, phis)
|
127 |
+
else:
|
128 |
+
dirs = None
|
129 |
+
|
130 |
+
return poses, dirs
|
131 |
+
|
132 |
+
|
133 |
+
class NeRFDataset:
|
134 |
+
def __init__(self, opt, device, type='train', H=256, W=256, size=100):
|
135 |
+
super().__init__()
|
136 |
+
|
137 |
+
self.opt = opt
|
138 |
+
self.device = device
|
139 |
+
self.type = type # train, val, test
|
140 |
+
|
141 |
+
self.H = H
|
142 |
+
self.W = W
|
143 |
+
self.radius_range = opt.radius_range
|
144 |
+
self.fovy_range = opt.fovy_range
|
145 |
+
self.size = size
|
146 |
+
|
147 |
+
self.training = self.type in ['train', 'all']
|
148 |
+
|
149 |
+
self.cx = self.H / 2
|
150 |
+
self.cy = self.W / 2
|
151 |
+
|
152 |
+
# [debug] visualize poses
|
153 |
+
# poses, dirs = rand_poses(100, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
|
154 |
+
# visualize_poses(poses.detach().cpu().numpy())
|
155 |
+
|
156 |
+
|
157 |
+
def collate(self, index):
|
158 |
+
|
159 |
+
B = len(index) # always 1
|
160 |
+
|
161 |
+
if self.training:
|
162 |
+
# random pose on the fly
|
163 |
+
poses, dirs = rand_poses(B, self.device, return_dirs=self.opt.dir_text, radius_range=self.radius_range)
|
164 |
+
|
165 |
+
# random focal
|
166 |
+
fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
|
167 |
+
focal = self.H / (2 * np.tan(np.radians(fov) / 2))
|
168 |
+
intrinsics = np.array([focal, focal, self.cx, self.cy])
|
169 |
+
else:
|
170 |
+
# circle pose
|
171 |
+
phi = (index[0] / self.size) * 2 * np.pi
|
172 |
+
poses, dirs = circle_poses(self.device, return_dirs=self.opt.dir_text, radius=self.radius_range[1], theta=np.pi/2, phi=phi)
|
173 |
+
|
174 |
+
# fixed focal
|
175 |
+
fov = (self.fovy_range[1] + self.fovy_range[0]) / 2
|
176 |
+
focal = self.H / (2 * np.tan(np.radians(fov) / 2))
|
177 |
+
intrinsics = np.array([focal, focal, self.cx, self.cy])
|
178 |
+
|
179 |
+
|
180 |
+
# sample a low-resolution but full image for CLIP
|
181 |
+
rays = get_rays(poses, intrinsics, self.H, self.W, -1)
|
182 |
+
|
183 |
+
data = {
|
184 |
+
'H': self.H,
|
185 |
+
'W': self.W,
|
186 |
+
'rays_o': rays['rays_o'],
|
187 |
+
'rays_d': rays['rays_d'],
|
188 |
+
'dir': dirs,
|
189 |
+
}
|
190 |
+
|
191 |
+
return data
|
192 |
+
|
193 |
+
|
194 |
+
def dataloader(self):
|
195 |
+
loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
|
196 |
+
loader._data = self # an ugly fix... we need to access dataset in trainer.
|
197 |
+
return loader
|
nerf/renderer.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import cv2
|
4 |
+
import trimesh
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
import mcubes
|
12 |
+
import raymarching
|
13 |
+
from .utils import custom_meshgrid, safe_normalize
|
14 |
+
|
15 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
16 |
+
# This implementation is from NeRF
|
17 |
+
# bins: [B, T], old_z_vals
|
18 |
+
# weights: [B, T - 1], bin weights.
|
19 |
+
# return: [B, n_samples], new_z_vals
|
20 |
+
|
21 |
+
# Get pdf
|
22 |
+
weights = weights + 1e-5 # prevent nans
|
23 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
24 |
+
cdf = torch.cumsum(pdf, -1)
|
25 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
26 |
+
# Take uniform samples
|
27 |
+
if det:
|
28 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
|
29 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
30 |
+
else:
|
31 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
|
32 |
+
|
33 |
+
# Invert CDF
|
34 |
+
u = u.contiguous()
|
35 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
36 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
37 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
38 |
+
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
|
39 |
+
|
40 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
41 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
42 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
43 |
+
|
44 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
45 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
46 |
+
t = (u - cdf_g[..., 0]) / denom
|
47 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
48 |
+
|
49 |
+
return samples
|
50 |
+
|
51 |
+
|
52 |
+
def plot_pointcloud(pc, color=None):
|
53 |
+
# pc: [N, 3]
|
54 |
+
# color: [N, 3/4]
|
55 |
+
print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
|
56 |
+
pc = trimesh.PointCloud(pc, color)
|
57 |
+
# axis
|
58 |
+
axes = trimesh.creation.axis(axis_length=4)
|
59 |
+
# sphere
|
60 |
+
sphere = trimesh.creation.icosphere(radius=1)
|
61 |
+
trimesh.Scene([pc, axes, sphere]).show()
|
62 |
+
|
63 |
+
|
64 |
+
class NeRFRenderer(nn.Module):
|
65 |
+
def __init__(self, opt):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
self.opt = opt
|
69 |
+
self.bound = opt.bound
|
70 |
+
self.cascade = 1 + math.ceil(math.log2(opt.bound))
|
71 |
+
self.grid_size = 128
|
72 |
+
self.cuda_ray = opt.cuda_ray
|
73 |
+
self.min_near = opt.min_near
|
74 |
+
self.density_thresh = opt.density_thresh
|
75 |
+
self.bg_radius = opt.bg_radius
|
76 |
+
|
77 |
+
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
|
78 |
+
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
|
79 |
+
aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
|
80 |
+
aabb_infer = aabb_train.clone()
|
81 |
+
self.register_buffer('aabb_train', aabb_train)
|
82 |
+
self.register_buffer('aabb_infer', aabb_infer)
|
83 |
+
|
84 |
+
# extra state for cuda raymarching
|
85 |
+
if self.cuda_ray:
|
86 |
+
# density grid
|
87 |
+
density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
|
88 |
+
density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
|
89 |
+
self.register_buffer('density_grid', density_grid)
|
90 |
+
self.register_buffer('density_bitfield', density_bitfield)
|
91 |
+
self.mean_density = 0
|
92 |
+
self.iter_density = 0
|
93 |
+
# step counter
|
94 |
+
step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
|
95 |
+
self.register_buffer('step_counter', step_counter)
|
96 |
+
self.mean_count = 0
|
97 |
+
self.local_step = 0
|
98 |
+
|
99 |
+
|
100 |
+
def forward(self, x, d):
|
101 |
+
raise NotImplementedError()
|
102 |
+
|
103 |
+
def density(self, x):
|
104 |
+
raise NotImplementedError()
|
105 |
+
|
106 |
+
def color(self, x, d, mask=None, **kwargs):
|
107 |
+
raise NotImplementedError()
|
108 |
+
|
109 |
+
def reset_extra_state(self):
|
110 |
+
if not self.cuda_ray:
|
111 |
+
return
|
112 |
+
# density grid
|
113 |
+
self.density_grid.zero_()
|
114 |
+
self.mean_density = 0
|
115 |
+
self.iter_density = 0
|
116 |
+
# step counter
|
117 |
+
self.step_counter.zero_()
|
118 |
+
self.mean_count = 0
|
119 |
+
self.local_step = 0
|
120 |
+
|
121 |
+
@torch.no_grad()
|
122 |
+
def export_mesh(self, path, resolution=None, S=128):
|
123 |
+
|
124 |
+
if resolution is None:
|
125 |
+
resolution = self.grid_size
|
126 |
+
|
127 |
+
density_thresh = min(self.mean_density, self.density_thresh)
|
128 |
+
|
129 |
+
sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
130 |
+
|
131 |
+
# query
|
132 |
+
X = torch.linspace(-1, 1, resolution).split(S)
|
133 |
+
Y = torch.linspace(-1, 1, resolution).split(S)
|
134 |
+
Z = torch.linspace(-1, 1, resolution).split(S)
|
135 |
+
|
136 |
+
for xi, xs in enumerate(X):
|
137 |
+
for yi, ys in enumerate(Y):
|
138 |
+
for zi, zs in enumerate(Z):
|
139 |
+
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
140 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
|
141 |
+
val = self.density(pts.to(self.density_bitfield.device))
|
142 |
+
sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
|
143 |
+
|
144 |
+
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
145 |
+
|
146 |
+
vertices = vertices / (resolution - 1.0) * 2 - 1
|
147 |
+
vertices = vertices.astype(np.float32)
|
148 |
+
triangles = triangles.astype(np.int32)
|
149 |
+
|
150 |
+
v = torch.from_numpy(vertices).to(self.density_bitfield.device)
|
151 |
+
f = torch.from_numpy(triangles).int().to(self.density_bitfield.device)
|
152 |
+
|
153 |
+
# mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
|
154 |
+
# mesh.export(os.path.join(path, f'mesh.ply'))
|
155 |
+
|
156 |
+
# texture?
|
157 |
+
def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
|
158 |
+
# v, f: torch Tensor
|
159 |
+
device = v.device
|
160 |
+
v_np = v.cpu().numpy() # [N, 3]
|
161 |
+
f_np = f.cpu().numpy() # [M, 3]
|
162 |
+
|
163 |
+
print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
|
164 |
+
|
165 |
+
# unwrap uvs
|
166 |
+
import xatlas
|
167 |
+
import nvdiffrast.torch as dr
|
168 |
+
from sklearn.neighbors import NearestNeighbors
|
169 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
170 |
+
|
171 |
+
glctx = dr.RasterizeGLContext()
|
172 |
+
|
173 |
+
atlas = xatlas.Atlas()
|
174 |
+
atlas.add_mesh(v_np, f_np)
|
175 |
+
chart_options = xatlas.ChartOptions()
|
176 |
+
chart_options.max_iterations = 0 # disable merge_chart for faster unwrap...
|
177 |
+
atlas.generate(chart_options=chart_options)
|
178 |
+
vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
|
179 |
+
|
180 |
+
# vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
|
181 |
+
|
182 |
+
vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
|
183 |
+
ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
|
184 |
+
|
185 |
+
# render uv maps
|
186 |
+
uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
|
187 |
+
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
|
188 |
+
|
189 |
+
if ssaa > 1:
|
190 |
+
h = int(h0 * ssaa)
|
191 |
+
w = int(w0 * ssaa)
|
192 |
+
else:
|
193 |
+
h, w = h0, w0
|
194 |
+
|
195 |
+
rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4]
|
196 |
+
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
|
197 |
+
mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
|
198 |
+
|
199 |
+
# masked query
|
200 |
+
xyzs = xyzs.view(-1, 3)
|
201 |
+
mask = (mask > 0).view(-1)
|
202 |
+
|
203 |
+
sigmas = torch.zeros(h * w, device=device, dtype=torch.float32)
|
204 |
+
feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
|
205 |
+
|
206 |
+
if mask.any():
|
207 |
+
xyzs = xyzs[mask] # [M, 3]
|
208 |
+
|
209 |
+
# batched inference to avoid OOM
|
210 |
+
all_sigmas = []
|
211 |
+
all_feats = []
|
212 |
+
head = 0
|
213 |
+
while head < xyzs.shape[0]:
|
214 |
+
tail = min(head + 640000, xyzs.shape[0])
|
215 |
+
results_ = self.density(xyzs[head:tail])
|
216 |
+
all_sigmas.append(results_['sigma'].float())
|
217 |
+
all_feats.append(results_['albedo'].float())
|
218 |
+
head += 640000
|
219 |
+
|
220 |
+
sigmas[mask] = torch.cat(all_sigmas, dim=0)
|
221 |
+
feats[mask] = torch.cat(all_feats, dim=0)
|
222 |
+
|
223 |
+
sigmas = sigmas.view(h, w, 1)
|
224 |
+
feats = feats.view(h, w, -1)
|
225 |
+
mask = mask.view(h, w)
|
226 |
+
|
227 |
+
### alpha mask
|
228 |
+
# deltas = 2 * np.sqrt(3) / 1024
|
229 |
+
# alphas = 1 - torch.exp(-sigmas * deltas)
|
230 |
+
# alphas_mask = alphas > 0.5
|
231 |
+
# feats = feats * alphas_mask
|
232 |
+
|
233 |
+
# quantize [0.0, 1.0] to [0, 255]
|
234 |
+
feats = feats.cpu().numpy()
|
235 |
+
feats = (feats * 255).astype(np.uint8)
|
236 |
+
|
237 |
+
# alphas = alphas.cpu().numpy()
|
238 |
+
# alphas = (alphas * 255).astype(np.uint8)
|
239 |
+
|
240 |
+
### NN search as an antialiasing ...
|
241 |
+
mask = mask.cpu().numpy()
|
242 |
+
|
243 |
+
inpaint_region = binary_dilation(mask, iterations=3)
|
244 |
+
inpaint_region[mask] = 0
|
245 |
+
|
246 |
+
search_region = mask.copy()
|
247 |
+
not_search_region = binary_erosion(search_region, iterations=2)
|
248 |
+
search_region[not_search_region] = 0
|
249 |
+
|
250 |
+
search_coords = np.stack(np.nonzero(search_region), axis=-1)
|
251 |
+
inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
|
252 |
+
|
253 |
+
knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
|
254 |
+
_, indices = knn.kneighbors(inpaint_coords)
|
255 |
+
|
256 |
+
feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
|
257 |
+
|
258 |
+
# do ssaa after the NN search, in numpy
|
259 |
+
feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
|
260 |
+
|
261 |
+
if ssaa > 1:
|
262 |
+
# alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST)
|
263 |
+
feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
|
264 |
+
|
265 |
+
# cv2.imwrite(os.path.join(path, f'alpha.png'), alphas)
|
266 |
+
cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
|
267 |
+
|
268 |
+
# save obj (v, vt, f /)
|
269 |
+
obj_file = os.path.join(path, f'{name}mesh.obj')
|
270 |
+
mtl_file = os.path.join(path, f'{name}mesh.mtl')
|
271 |
+
|
272 |
+
print(f'[INFO] writing obj mesh to {obj_file}')
|
273 |
+
with open(obj_file, "w") as fp:
|
274 |
+
fp.write(f'mtllib {name}.mtl \n')
|
275 |
+
|
276 |
+
print(f'[INFO] writing vertices {v_np.shape}')
|
277 |
+
for v in v_np:
|
278 |
+
fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
|
279 |
+
|
280 |
+
print(f'[INFO] writing vertices texture coords {vt_np.shape}')
|
281 |
+
for v in vt_np:
|
282 |
+
fp.write(f'vt {v[0]} {1 - v[1]} \n')
|
283 |
+
|
284 |
+
print(f'[INFO] writing faces {f_np.shape}')
|
285 |
+
fp.write(f'usemtl mat0 \n')
|
286 |
+
for i in range(len(f_np)):
|
287 |
+
fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
|
288 |
+
|
289 |
+
with open(mtl_file, "w") as fp:
|
290 |
+
fp.write(f'newmtl mat0 \n')
|
291 |
+
fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
|
292 |
+
fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
|
293 |
+
fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
|
294 |
+
fp.write(f'Tr 1.000000 \n')
|
295 |
+
fp.write(f'illum 1 \n')
|
296 |
+
fp.write(f'Ns 0.000000 \n')
|
297 |
+
fp.write(f'map_Kd {name}albedo.png \n')
|
298 |
+
|
299 |
+
_export(v, f)
|
300 |
+
|
301 |
+
def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
|
302 |
+
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
303 |
+
# bg_color: [BN, 3] in range [0, 1]
|
304 |
+
# return: image: [B, N, 3], depth: [B, N]
|
305 |
+
|
306 |
+
prefix = rays_o.shape[:-1]
|
307 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
308 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
309 |
+
|
310 |
+
N = rays_o.shape[0] # N = B * N, in fact
|
311 |
+
device = rays_o.device
|
312 |
+
|
313 |
+
results = {}
|
314 |
+
|
315 |
+
# choose aabb
|
316 |
+
aabb = self.aabb_train if self.training else self.aabb_infer
|
317 |
+
|
318 |
+
# sample steps
|
319 |
+
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
|
320 |
+
nears.unsqueeze_(-1)
|
321 |
+
fars.unsqueeze_(-1)
|
322 |
+
|
323 |
+
#print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
|
324 |
+
|
325 |
+
z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
|
326 |
+
z_vals = z_vals.expand((N, num_steps)) # [N, T]
|
327 |
+
z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
|
328 |
+
|
329 |
+
# perturb z_vals
|
330 |
+
sample_dist = (fars - nears) / num_steps
|
331 |
+
if perturb:
|
332 |
+
z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
|
333 |
+
#z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
|
334 |
+
|
335 |
+
# generate xyzs
|
336 |
+
xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
|
337 |
+
xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
|
338 |
+
|
339 |
+
#plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
|
340 |
+
|
341 |
+
# query SDF and RGB
|
342 |
+
density_outputs = self.density(xyzs.reshape(-1, 3))
|
343 |
+
|
344 |
+
#sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
|
345 |
+
for k, v in density_outputs.items():
|
346 |
+
density_outputs[k] = v.view(N, num_steps, -1)
|
347 |
+
|
348 |
+
# upsample z_vals (nerf-like)
|
349 |
+
if upsample_steps > 0:
|
350 |
+
with torch.no_grad():
|
351 |
+
|
352 |
+
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
|
353 |
+
deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
|
354 |
+
|
355 |
+
alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
|
356 |
+
alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
|
357 |
+
weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
|
358 |
+
|
359 |
+
# sample new z_vals
|
360 |
+
z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
|
361 |
+
new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]
|
362 |
+
|
363 |
+
new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
|
364 |
+
new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
|
365 |
+
|
366 |
+
# only forward new points to save computation
|
367 |
+
new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
|
368 |
+
#new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
|
369 |
+
for k, v in new_density_outputs.items():
|
370 |
+
new_density_outputs[k] = v.view(N, upsample_steps, -1)
|
371 |
+
|
372 |
+
# re-order
|
373 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
|
374 |
+
z_vals, z_index = torch.sort(z_vals, dim=1)
|
375 |
+
|
376 |
+
xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
|
377 |
+
xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
|
378 |
+
|
379 |
+
for k in density_outputs:
|
380 |
+
tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
|
381 |
+
density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
|
382 |
+
|
383 |
+
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
|
384 |
+
deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
|
385 |
+
alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
|
386 |
+
alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
|
387 |
+
weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
|
388 |
+
|
389 |
+
dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
|
390 |
+
for k, v in density_outputs.items():
|
391 |
+
density_outputs[k] = v.view(-1, v.shape[-1])
|
392 |
+
|
393 |
+
sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
|
394 |
+
rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
|
395 |
+
|
396 |
+
#print(xyzs.shape, 'valid_rgb:', mask.sum().item())
|
397 |
+
# orientation loss
|
398 |
+
if normals is not None:
|
399 |
+
normals = normals.view(N, -1, 3)
|
400 |
+
# print(weights.shape, normals.shape, dirs.shape)
|
401 |
+
loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
|
402 |
+
results['loss_orient'] = loss_orient.mean()
|
403 |
+
|
404 |
+
# calculate weight_sum (mask)
|
405 |
+
weights_sum = weights.sum(dim=-1) # [N]
|
406 |
+
|
407 |
+
# calculate depth
|
408 |
+
ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
|
409 |
+
depth = torch.sum(weights * ori_z_vals, dim=-1)
|
410 |
+
|
411 |
+
# calculate color
|
412 |
+
image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
|
413 |
+
|
414 |
+
# mix background color
|
415 |
+
if self.bg_radius > 0:
|
416 |
+
# use the bg model to calculate bg_color
|
417 |
+
sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
|
418 |
+
bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
|
419 |
+
elif bg_color is None:
|
420 |
+
bg_color = 1
|
421 |
+
|
422 |
+
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
|
423 |
+
|
424 |
+
image = image.view(*prefix, 3)
|
425 |
+
depth = depth.view(*prefix)
|
426 |
+
|
427 |
+
mask = (nears < fars).reshape(*prefix)
|
428 |
+
|
429 |
+
results['image'] = image
|
430 |
+
results['depth'] = depth
|
431 |
+
results['weights_sum'] = weights_sum
|
432 |
+
results['mask'] = mask
|
433 |
+
|
434 |
+
return results
|
435 |
+
|
436 |
+
|
437 |
+
def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
|
438 |
+
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
439 |
+
# return: image: [B, N, 3], depth: [B, N]
|
440 |
+
|
441 |
+
prefix = rays_o.shape[:-1]
|
442 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
443 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
444 |
+
|
445 |
+
N = rays_o.shape[0] # N = B * N, in fact
|
446 |
+
device = rays_o.device
|
447 |
+
|
448 |
+
# pre-calculate near far
|
449 |
+
nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)
|
450 |
+
|
451 |
+
results = {}
|
452 |
+
|
453 |
+
if self.training:
|
454 |
+
# setup counter
|
455 |
+
counter = self.step_counter[self.local_step % 16]
|
456 |
+
counter.zero_() # set to 0
|
457 |
+
self.local_step += 1
|
458 |
+
|
459 |
+
xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)
|
460 |
+
|
461 |
+
#plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
|
462 |
+
|
463 |
+
sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
|
464 |
+
|
465 |
+
#print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})')
|
466 |
+
|
467 |
+
weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
|
468 |
+
|
469 |
+
# orientation loss
|
470 |
+
if normals is not None:
|
471 |
+
weights = 1 - torch.exp(-sigmas)
|
472 |
+
loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
|
473 |
+
results['loss_orient'] = loss_orient.mean()
|
474 |
+
|
475 |
+
else:
|
476 |
+
|
477 |
+
# allocate outputs
|
478 |
+
dtype = torch.float32
|
479 |
+
|
480 |
+
# fix light for all samples if not provided
|
481 |
+
if light_d is None:
|
482 |
+
light_d = torch.randn(3, device=device, dtype=torch.float)
|
483 |
+
light_d = safe_normalize(light_d)
|
484 |
+
|
485 |
+
weights_sum = torch.zeros(N, dtype=dtype, device=device)
|
486 |
+
depth = torch.zeros(N, dtype=dtype, device=device)
|
487 |
+
image = torch.zeros(N, 3, dtype=dtype, device=device)
|
488 |
+
|
489 |
+
n_alive = N
|
490 |
+
rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
|
491 |
+
rays_t = nears.clone() # [N]
|
492 |
+
|
493 |
+
step = 0
|
494 |
+
|
495 |
+
while step < max_steps: # hard coded max step
|
496 |
+
|
497 |
+
# count alive rays
|
498 |
+
n_alive = rays_alive.shape[0]
|
499 |
+
|
500 |
+
# exit loop
|
501 |
+
if n_alive <= 0:
|
502 |
+
break
|
503 |
+
|
504 |
+
# decide compact_steps
|
505 |
+
n_step = max(min(N // n_alive, 8), 1)
|
506 |
+
|
507 |
+
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)
|
508 |
+
|
509 |
+
sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
|
510 |
+
|
511 |
+
raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)
|
512 |
+
|
513 |
+
rays_alive = rays_alive[rays_alive >= 0]
|
514 |
+
#print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
|
515 |
+
|
516 |
+
step += n_step
|
517 |
+
|
518 |
+
# mix background color
|
519 |
+
if self.bg_radius > 0:
|
520 |
+
|
521 |
+
# use the bg model to calculate bg_color
|
522 |
+
sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
|
523 |
+
bg_color = self.background(sph, rays_d) # [N, 3]
|
524 |
+
|
525 |
+
elif bg_color is None:
|
526 |
+
bg_color = 1
|
527 |
+
|
528 |
+
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
|
529 |
+
image = image.view(*prefix, 3)
|
530 |
+
|
531 |
+
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
|
532 |
+
depth = depth.view(*prefix)
|
533 |
+
|
534 |
+
weights_sum = weights_sum.reshape(*prefix)
|
535 |
+
|
536 |
+
mask = (nears < fars).reshape(*prefix)
|
537 |
+
|
538 |
+
results['image'] = image
|
539 |
+
results['depth'] = depth
|
540 |
+
results['weights_sum'] = weights_sum
|
541 |
+
results['mask'] = mask
|
542 |
+
|
543 |
+
return results
|
544 |
+
|
545 |
+
|
546 |
+
@torch.no_grad()
|
547 |
+
def update_extra_state(self, decay=0.95, S=128):
|
548 |
+
# call before each epoch to update extra states.
|
549 |
+
|
550 |
+
if not self.cuda_ray:
|
551 |
+
return
|
552 |
+
|
553 |
+
### update density grid
|
554 |
+
tmp_grid = - torch.ones_like(self.density_grid)
|
555 |
+
|
556 |
+
X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
557 |
+
Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
558 |
+
Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
|
559 |
+
|
560 |
+
for xs in X:
|
561 |
+
for ys in Y:
|
562 |
+
for zs in Z:
|
563 |
+
|
564 |
+
# construct points
|
565 |
+
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
566 |
+
coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
|
567 |
+
indices = raymarching.morton3D(coords).long() # [N]
|
568 |
+
xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
|
569 |
+
|
570 |
+
# cascading
|
571 |
+
for cas in range(self.cascade):
|
572 |
+
bound = min(2 ** cas, self.bound)
|
573 |
+
half_grid_size = bound / self.grid_size
|
574 |
+
# scale to current cascade's resolution
|
575 |
+
cas_xyzs = xyzs * (bound - half_grid_size)
|
576 |
+
# add noise in [-hgs, hgs]
|
577 |
+
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
|
578 |
+
# query density
|
579 |
+
sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
|
580 |
+
# assign
|
581 |
+
tmp_grid[cas, indices] = sigmas
|
582 |
+
|
583 |
+
# ema update
|
584 |
+
valid_mask = self.density_grid >= 0
|
585 |
+
self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
|
586 |
+
self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
|
587 |
+
self.iter_density += 1
|
588 |
+
|
589 |
+
# convert to bitfield
|
590 |
+
density_thresh = min(self.mean_density, self.density_thresh)
|
591 |
+
self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
|
592 |
+
|
593 |
+
### update step counter
|
594 |
+
total_step = min(16, self.local_step)
|
595 |
+
if total_step > 0:
|
596 |
+
self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
|
597 |
+
self.local_step = 0
|
598 |
+
|
599 |
+
# print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
|
600 |
+
|
601 |
+
|
602 |
+
def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
|
603 |
+
# rays_o, rays_d: [B, N, 3], assumes B == 1
|
604 |
+
# return: pred_rgb: [B, N, 3]
|
605 |
+
|
606 |
+
if self.cuda_ray:
|
607 |
+
_run = self.run_cuda
|
608 |
+
else:
|
609 |
+
_run = self.run
|
610 |
+
|
611 |
+
B, N = rays_o.shape[:2]
|
612 |
+
device = rays_o.device
|
613 |
+
|
614 |
+
# never stage when cuda_ray
|
615 |
+
if staged and not self.cuda_ray:
|
616 |
+
depth = torch.empty((B, N), device=device)
|
617 |
+
image = torch.empty((B, N, 3), device=device)
|
618 |
+
weights_sum = torch.empty((B, N), device=device)
|
619 |
+
|
620 |
+
for b in range(B):
|
621 |
+
head = 0
|
622 |
+
while head < N:
|
623 |
+
tail = min(head + max_ray_batch, N)
|
624 |
+
results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
|
625 |
+
depth[b:b+1, head:tail] = results_['depth']
|
626 |
+
weights_sum[b:b+1, head:tail] = results_['weights_sum']
|
627 |
+
image[b:b+1, head:tail] = results_['image']
|
628 |
+
head += max_ray_batch
|
629 |
+
|
630 |
+
results = {}
|
631 |
+
results['depth'] = depth
|
632 |
+
results['image'] = image
|
633 |
+
results['weights_sum'] = weights_sum
|
634 |
+
|
635 |
+
else:
|
636 |
+
results = _run(rays_o, rays_d, **kwargs)
|
637 |
+
|
638 |
+
return results
|
nerf/sd.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
2 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
3 |
+
|
4 |
+
# suppress partial model loading warning
|
5 |
+
logging.set_verbosity_error()
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
import time
|
12 |
+
|
13 |
+
class StableDiffusion(nn.Module):
|
14 |
+
def __init__(self, device):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
try:
|
18 |
+
with open('./TOKEN', 'r') as f:
|
19 |
+
self.token = f.read()
|
20 |
+
print(f'[INFO] successfully loaded hugging face user token!')
|
21 |
+
except FileNotFoundError as e:
|
22 |
+
print(e)
|
23 |
+
print(f'[INFO] Please first create a file called TOKEN and copy your hugging face access token into it to download stable diffusion checkpoints.')
|
24 |
+
|
25 |
+
self.device = device
|
26 |
+
self.num_train_timesteps = 1000
|
27 |
+
self.min_step = int(self.num_train_timesteps * 0.02)
|
28 |
+
self.max_step = int(self.num_train_timesteps * 0.98)
|
29 |
+
|
30 |
+
print(f'[INFO] loading stable diffusion...')
|
31 |
+
|
32 |
+
# 1. Load the autoencoder model which will be used to decode the latents into image space.
|
33 |
+
self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=self.token).to(self.device)
|
34 |
+
|
35 |
+
# 2. Load the tokenizer and text encoder to tokenize and encode the text.
|
36 |
+
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
37 |
+
self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
|
38 |
+
|
39 |
+
# 3. The UNet model for generating the latents.
|
40 |
+
self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=self.token).to(self.device)
|
41 |
+
|
42 |
+
# 4. Create a scheduler for inference
|
43 |
+
self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=self.num_train_timesteps)
|
44 |
+
|
45 |
+
print(f'[INFO] loaded stable diffusion!')
|
46 |
+
|
47 |
+
def get_text_embeds(self, prompt):
|
48 |
+
# Tokenize text and get embeddings
|
49 |
+
text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
50 |
+
|
51 |
+
with torch.no_grad():
|
52 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
53 |
+
|
54 |
+
# Do the same for unconditional embeddings
|
55 |
+
uncond_input = self.tokenizer([''] * len(prompt), padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
59 |
+
|
60 |
+
# Cat for final embeddings
|
61 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
62 |
+
return text_embeddings
|
63 |
+
|
64 |
+
|
65 |
+
def train_step(self, text_embeddings, pred_rgb, guidance_scale=100):
|
66 |
+
|
67 |
+
# interp to 512x512 to be fed into vae.
|
68 |
+
|
69 |
+
# _t = time.time()
|
70 |
+
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
|
71 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
|
72 |
+
|
73 |
+
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
74 |
+
t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device)
|
75 |
+
|
76 |
+
# encode image into latents with vae, requires grad!
|
77 |
+
# _t = time.time()
|
78 |
+
latents = self.encode_imgs(pred_rgb_512)
|
79 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
|
80 |
+
|
81 |
+
# predict the noise residual with unet, NO grad!
|
82 |
+
# _t = time.time()
|
83 |
+
with torch.no_grad():
|
84 |
+
# add noise
|
85 |
+
noise = torch.randn_like(latents)
|
86 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
87 |
+
# pred noise
|
88 |
+
latent_model_input = torch.cat([latents_noisy] * 2)
|
89 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
90 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
|
91 |
+
|
92 |
+
# perform guidance (high scale from paper!)
|
93 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
94 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
95 |
+
|
96 |
+
# w(t), one_minus_alpha_prod, i.e., sigma^2
|
97 |
+
w = (1 - self.scheduler.alphas_cumprod[t]).to(self.device)
|
98 |
+
grad = w * (noise_pred - noise)
|
99 |
+
|
100 |
+
# clip grad for stable training?
|
101 |
+
# grad = grad.clamp(-1, 1)
|
102 |
+
|
103 |
+
# manually backward, since we omitted an item in grad and cannot simply autodiff.
|
104 |
+
# _t = time.time()
|
105 |
+
latents.backward(gradient=grad, retain_graph=True)
|
106 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
|
107 |
+
|
108 |
+
return 0 # fake loss value
|
109 |
+
|
110 |
+
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
|
111 |
+
|
112 |
+
if latents is None:
|
113 |
+
latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
|
114 |
+
|
115 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
116 |
+
|
117 |
+
with torch.autocast('cuda'):
|
118 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
119 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
120 |
+
latent_model_input = torch.cat([latents] * 2)
|
121 |
+
|
122 |
+
# predict the noise residual
|
123 |
+
with torch.no_grad():
|
124 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
|
125 |
+
|
126 |
+
# perform guidance
|
127 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
128 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
129 |
+
|
130 |
+
# compute the previous noisy sample x_t -> x_t-1
|
131 |
+
latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
|
132 |
+
|
133 |
+
return latents
|
134 |
+
|
135 |
+
def decode_latents(self, latents):
|
136 |
+
|
137 |
+
latents = 1 / 0.18215 * latents
|
138 |
+
|
139 |
+
with torch.no_grad():
|
140 |
+
imgs = self.vae.decode(latents).sample
|
141 |
+
|
142 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
143 |
+
|
144 |
+
return imgs
|
145 |
+
|
146 |
+
def encode_imgs(self, imgs):
|
147 |
+
# imgs: [B, 3, H, W]
|
148 |
+
|
149 |
+
imgs = 2 * imgs - 1
|
150 |
+
|
151 |
+
posterior = self.vae.encode(imgs).latent_dist
|
152 |
+
latents = posterior.sample() * 0.18215
|
153 |
+
|
154 |
+
return latents
|
155 |
+
|
156 |
+
def prompt_to_img(self, prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
|
157 |
+
|
158 |
+
if isinstance(prompts, str):
|
159 |
+
prompts = [prompts]
|
160 |
+
|
161 |
+
# Prompts -> text embeds
|
162 |
+
text_embeds = self.get_text_embeds(prompts) # [2, 77, 768]
|
163 |
+
|
164 |
+
# Text embeds -> img latents
|
165 |
+
latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
|
166 |
+
|
167 |
+
# Img latents -> imgs
|
168 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
169 |
+
|
170 |
+
# Img to Numpy
|
171 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
172 |
+
imgs = (imgs * 255).round().astype('uint8')
|
173 |
+
|
174 |
+
return imgs
|
175 |
+
|
176 |
+
|
177 |
+
if __name__ == '__main__':
|
178 |
+
|
179 |
+
import argparse
|
180 |
+
import matplotlib.pyplot as plt
|
181 |
+
|
182 |
+
parser = argparse.ArgumentParser()
|
183 |
+
parser.add_argument('prompt', type=str)
|
184 |
+
parser.add_argument('-H', type=int, default=512)
|
185 |
+
parser.add_argument('-W', type=int, default=512)
|
186 |
+
parser.add_argument('--steps', type=int, default=50)
|
187 |
+
opt = parser.parse_args()
|
188 |
+
|
189 |
+
device = torch.device('cuda')
|
190 |
+
|
191 |
+
sd = StableDiffusion(device)
|
192 |
+
|
193 |
+
imgs = sd.prompt_to_img(opt.prompt, opt.H, opt.W, opt.steps)
|
194 |
+
|
195 |
+
# visualize image
|
196 |
+
plt.imshow(imgs[0])
|
197 |
+
plt.show()
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
nerf/utils.py
ADDED
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import tqdm
|
4 |
+
import math
|
5 |
+
import imageio
|
6 |
+
import random
|
7 |
+
import warnings
|
8 |
+
import tensorboardX
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
import time
|
14 |
+
from datetime import datetime
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.optim as optim
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torch.distributed as dist
|
24 |
+
from torch.utils.data import Dataset, DataLoader
|
25 |
+
|
26 |
+
import trimesh
|
27 |
+
from rich.console import Console
|
28 |
+
from torch_ema import ExponentialMovingAverage
|
29 |
+
|
30 |
+
from packaging import version as pver
|
31 |
+
|
32 |
+
def custom_meshgrid(*args):
|
33 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
34 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
35 |
+
return torch.meshgrid(*args)
|
36 |
+
else:
|
37 |
+
return torch.meshgrid(*args, indexing='ij')
|
38 |
+
|
39 |
+
def safe_normalize(x, eps=1e-20):
|
40 |
+
return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
|
41 |
+
|
42 |
+
@torch.cuda.amp.autocast(enabled=False)
|
43 |
+
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
|
44 |
+
''' get rays
|
45 |
+
Args:
|
46 |
+
poses: [B, 4, 4], cam2world
|
47 |
+
intrinsics: [4]
|
48 |
+
H, W, N: int
|
49 |
+
error_map: [B, 128 * 128], sample probability based on training error
|
50 |
+
Returns:
|
51 |
+
rays_o, rays_d: [B, N, 3]
|
52 |
+
inds: [B, N]
|
53 |
+
'''
|
54 |
+
|
55 |
+
device = poses.device
|
56 |
+
B = poses.shape[0]
|
57 |
+
fx, fy, cx, cy = intrinsics
|
58 |
+
|
59 |
+
i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
|
60 |
+
i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
|
61 |
+
j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
|
62 |
+
|
63 |
+
results = {}
|
64 |
+
|
65 |
+
if N > 0:
|
66 |
+
N = min(N, H*W)
|
67 |
+
|
68 |
+
if error_map is None:
|
69 |
+
inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
|
70 |
+
inds = inds.expand([B, N])
|
71 |
+
else:
|
72 |
+
|
73 |
+
# weighted sample on a low-reso grid
|
74 |
+
inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)
|
75 |
+
|
76 |
+
# map to the original resolution with random perturb.
|
77 |
+
inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
|
78 |
+
sx, sy = H / 128, W / 128
|
79 |
+
inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
|
80 |
+
inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
|
81 |
+
inds = inds_x * W + inds_y
|
82 |
+
|
83 |
+
results['inds_coarse'] = inds_coarse # need this when updating error_map
|
84 |
+
|
85 |
+
i = torch.gather(i, -1, inds)
|
86 |
+
j = torch.gather(j, -1, inds)
|
87 |
+
|
88 |
+
results['inds'] = inds
|
89 |
+
|
90 |
+
else:
|
91 |
+
inds = torch.arange(H*W, device=device).expand([B, H*W])
|
92 |
+
|
93 |
+
zs = torch.ones_like(i)
|
94 |
+
xs = (i - cx) / fx * zs
|
95 |
+
ys = (j - cy) / fy * zs
|
96 |
+
directions = torch.stack((xs, ys, zs), dim=-1)
|
97 |
+
directions = safe_normalize(directions)
|
98 |
+
rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
|
99 |
+
|
100 |
+
rays_o = poses[..., :3, 3] # [B, 3]
|
101 |
+
rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
|
102 |
+
|
103 |
+
results['rays_o'] = rays_o
|
104 |
+
results['rays_d'] = rays_d
|
105 |
+
|
106 |
+
return results
|
107 |
+
|
108 |
+
|
109 |
+
def seed_everything(seed):
|
110 |
+
random.seed(seed)
|
111 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
112 |
+
np.random.seed(seed)
|
113 |
+
torch.manual_seed(seed)
|
114 |
+
torch.cuda.manual_seed(seed)
|
115 |
+
#torch.backends.cudnn.deterministic = True
|
116 |
+
#torch.backends.cudnn.benchmark = True
|
117 |
+
|
118 |
+
|
119 |
+
def torch_vis_2d(x, renormalize=False):
|
120 |
+
# x: [3, H, W] or [1, H, W] or [H, W]
|
121 |
+
import matplotlib.pyplot as plt
|
122 |
+
import numpy as np
|
123 |
+
import torch
|
124 |
+
|
125 |
+
if isinstance(x, torch.Tensor):
|
126 |
+
if len(x.shape) == 3:
|
127 |
+
x = x.permute(1,2,0).squeeze()
|
128 |
+
x = x.detach().cpu().numpy()
|
129 |
+
|
130 |
+
print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
|
131 |
+
|
132 |
+
x = x.astype(np.float32)
|
133 |
+
|
134 |
+
# renormalize
|
135 |
+
if renormalize:
|
136 |
+
x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
|
137 |
+
|
138 |
+
plt.imshow(x)
|
139 |
+
plt.show()
|
140 |
+
|
141 |
+
@torch.jit.script
|
142 |
+
def linear_to_srgb(x):
|
143 |
+
return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
|
144 |
+
|
145 |
+
|
146 |
+
@torch.jit.script
|
147 |
+
def srgb_to_linear(x):
|
148 |
+
return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
|
149 |
+
|
150 |
+
|
151 |
+
class Trainer(object):
|
152 |
+
def __init__(self,
|
153 |
+
name, # name of this experiment
|
154 |
+
opt, # extra conf
|
155 |
+
model, # network
|
156 |
+
guidance, # guidance network
|
157 |
+
criterion=None, # loss function, if None, assume inline implementation in train_step
|
158 |
+
optimizer=None, # optimizer
|
159 |
+
ema_decay=None, # if use EMA, set the decay
|
160 |
+
lr_scheduler=None, # scheduler
|
161 |
+
metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
|
162 |
+
local_rank=0, # which GPU am I
|
163 |
+
world_size=1, # total num of GPUs
|
164 |
+
device=None, # device to use, usually setting to None is OK. (auto choose device)
|
165 |
+
mute=False, # whether to mute all print
|
166 |
+
fp16=False, # amp optimize level
|
167 |
+
eval_interval=1, # eval once every $ epoch
|
168 |
+
max_keep_ckpt=2, # max num of saved ckpts in disk
|
169 |
+
workspace='workspace', # workspace to save logs & ckpts
|
170 |
+
best_mode='min', # the smaller/larger result, the better
|
171 |
+
use_loss_as_metric=True, # use loss as the first metric
|
172 |
+
report_metric_at_train=False, # also report metrics at training
|
173 |
+
use_checkpoint="latest", # which ckpt to use at init time
|
174 |
+
use_tensorboardX=True, # whether to use tensorboard for logging
|
175 |
+
scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
|
176 |
+
):
|
177 |
+
|
178 |
+
self.name = name
|
179 |
+
self.opt = opt
|
180 |
+
self.mute = mute
|
181 |
+
self.metrics = metrics
|
182 |
+
self.local_rank = local_rank
|
183 |
+
self.world_size = world_size
|
184 |
+
self.workspace = workspace
|
185 |
+
self.ema_decay = ema_decay
|
186 |
+
self.fp16 = fp16
|
187 |
+
self.best_mode = best_mode
|
188 |
+
self.use_loss_as_metric = use_loss_as_metric
|
189 |
+
self.report_metric_at_train = report_metric_at_train
|
190 |
+
self.max_keep_ckpt = max_keep_ckpt
|
191 |
+
self.eval_interval = eval_interval
|
192 |
+
self.use_checkpoint = use_checkpoint
|
193 |
+
self.use_tensorboardX = use_tensorboardX
|
194 |
+
self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
|
195 |
+
self.scheduler_update_every_step = scheduler_update_every_step
|
196 |
+
self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
|
197 |
+
self.console = Console()
|
198 |
+
|
199 |
+
# text prompt
|
200 |
+
ref_text = self.opt.text
|
201 |
+
|
202 |
+
model.to(self.device)
|
203 |
+
if self.world_size > 1:
|
204 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
205 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
|
206 |
+
self.model = model
|
207 |
+
|
208 |
+
# guide model
|
209 |
+
self.guidance = guidance
|
210 |
+
|
211 |
+
if self.guidance is not None:
|
212 |
+
|
213 |
+
for p in self.guidance.parameters():
|
214 |
+
p.requires_grad = False
|
215 |
+
|
216 |
+
if not self.opt.dir_text:
|
217 |
+
self.text_z = self.guidance.get_text_embeds([ref_text])
|
218 |
+
else:
|
219 |
+
self.text_z = []
|
220 |
+
for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
|
221 |
+
text = f"{ref_text}, {d} view"
|
222 |
+
text_z = self.guidance.get_text_embeds([text])
|
223 |
+
self.text_z.append(text_z)
|
224 |
+
|
225 |
+
else:
|
226 |
+
self.text_z = None
|
227 |
+
|
228 |
+
if isinstance(criterion, nn.Module):
|
229 |
+
criterion.to(self.device)
|
230 |
+
self.criterion = criterion
|
231 |
+
|
232 |
+
if optimizer is None:
|
233 |
+
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
|
234 |
+
else:
|
235 |
+
self.optimizer = optimizer(self.model)
|
236 |
+
|
237 |
+
if lr_scheduler is None:
|
238 |
+
self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
|
239 |
+
else:
|
240 |
+
self.lr_scheduler = lr_scheduler(self.optimizer)
|
241 |
+
|
242 |
+
if ema_decay is not None:
|
243 |
+
self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
|
244 |
+
else:
|
245 |
+
self.ema = None
|
246 |
+
|
247 |
+
self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
|
248 |
+
|
249 |
+
# variable init
|
250 |
+
self.epoch = 0
|
251 |
+
self.global_step = 0
|
252 |
+
self.local_step = 0
|
253 |
+
self.stats = {
|
254 |
+
"loss": [],
|
255 |
+
"valid_loss": [],
|
256 |
+
"results": [], # metrics[0], or valid_loss
|
257 |
+
"checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
|
258 |
+
"best_result": None,
|
259 |
+
}
|
260 |
+
|
261 |
+
# auto fix
|
262 |
+
if len(metrics) == 0 or self.use_loss_as_metric:
|
263 |
+
self.best_mode = 'min'
|
264 |
+
|
265 |
+
# workspace prepare
|
266 |
+
self.log_ptr = None
|
267 |
+
if self.workspace is not None:
|
268 |
+
os.makedirs(self.workspace, exist_ok=True)
|
269 |
+
self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
|
270 |
+
self.log_ptr = open(self.log_path, "a+")
|
271 |
+
|
272 |
+
self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
|
273 |
+
self.best_path = f"{self.ckpt_path}/{self.name}.pth"
|
274 |
+
os.makedirs(self.ckpt_path, exist_ok=True)
|
275 |
+
|
276 |
+
self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
|
277 |
+
self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
|
278 |
+
|
279 |
+
if self.workspace is not None:
|
280 |
+
if self.use_checkpoint == "scratch":
|
281 |
+
self.log("[INFO] Training from scratch ...")
|
282 |
+
elif self.use_checkpoint == "latest":
|
283 |
+
self.log("[INFO] Loading latest checkpoint ...")
|
284 |
+
self.load_checkpoint()
|
285 |
+
elif self.use_checkpoint == "latest_model":
|
286 |
+
self.log("[INFO] Loading latest checkpoint (model only)...")
|
287 |
+
self.load_checkpoint(model_only=True)
|
288 |
+
elif self.use_checkpoint == "best":
|
289 |
+
if os.path.exists(self.best_path):
|
290 |
+
self.log("[INFO] Loading best checkpoint ...")
|
291 |
+
self.load_checkpoint(self.best_path)
|
292 |
+
else:
|
293 |
+
self.log(f"[INFO] {self.best_path} not found, loading latest ...")
|
294 |
+
self.load_checkpoint()
|
295 |
+
else: # path to ckpt
|
296 |
+
self.log(f"[INFO] Loading {self.use_checkpoint} ...")
|
297 |
+
self.load_checkpoint(self.use_checkpoint)
|
298 |
+
|
299 |
+
def __del__(self):
|
300 |
+
if self.log_ptr:
|
301 |
+
self.log_ptr.close()
|
302 |
+
|
303 |
+
|
304 |
+
def log(self, *args, **kwargs):
|
305 |
+
if self.local_rank == 0:
|
306 |
+
if not self.mute:
|
307 |
+
#print(*args)
|
308 |
+
self.console.print(*args, **kwargs)
|
309 |
+
if self.log_ptr:
|
310 |
+
print(*args, file=self.log_ptr)
|
311 |
+
self.log_ptr.flush() # write immediately to file
|
312 |
+
|
313 |
+
### ------------------------------
|
314 |
+
|
315 |
+
def train_step(self, data):
|
316 |
+
|
317 |
+
rays_o = data['rays_o'] # [B, N, 3]
|
318 |
+
rays_d = data['rays_d'] # [B, N, 3]
|
319 |
+
|
320 |
+
B, N = rays_o.shape[:2]
|
321 |
+
H, W = data['H'], data['W']
|
322 |
+
|
323 |
+
# TODO: shading is not working right now...
|
324 |
+
if self.global_step < self.opt.albedo_iters:
|
325 |
+
shading = 'albedo'
|
326 |
+
ambient_ratio = 1.0
|
327 |
+
else:
|
328 |
+
rand = random.random()
|
329 |
+
if rand > 0.8:
|
330 |
+
shading = 'albedo'
|
331 |
+
ambient_ratio = 1.0
|
332 |
+
elif rand > 0.4:
|
333 |
+
shading = 'lambertian'
|
334 |
+
ambient_ratio = 0.1
|
335 |
+
else:
|
336 |
+
shading = 'textureless'
|
337 |
+
ambient_ratio = 0.1
|
338 |
+
|
339 |
+
# _t = time.time()
|
340 |
+
bg_color = torch.rand((B * N, 3), device=rays_o.device) # pixel-wise random
|
341 |
+
outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
|
342 |
+
pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [1, 3, H, W]
|
343 |
+
# torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')
|
344 |
+
|
345 |
+
# text embeddings
|
346 |
+
if self.opt.dir_text:
|
347 |
+
dirs = data['dir'] # [B,]
|
348 |
+
text_z = self.text_z[dirs]
|
349 |
+
else:
|
350 |
+
text_z = self.text_z
|
351 |
+
|
352 |
+
# encode pred_rgb to latents
|
353 |
+
# _t = time.time()
|
354 |
+
loss_guidance = self.guidance.train_step(text_z, pred_rgb)
|
355 |
+
# torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')
|
356 |
+
|
357 |
+
# occupancy loss
|
358 |
+
pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
|
359 |
+
# mask_ws = outputs['mask'].reshape(B, 1, H, W) # near < far
|
360 |
+
|
361 |
+
# loss_ws = (pred_ws ** 2 + 0.01).sqrt().mean()
|
362 |
+
|
363 |
+
alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
|
364 |
+
# alphas = alphas ** 2 # skewed entropy, favors 0 over 1
|
365 |
+
loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
|
366 |
+
|
367 |
+
loss = loss_guidance + 1e-3 * loss_entropy
|
368 |
+
|
369 |
+
if 'loss_orient' in outputs:
|
370 |
+
loss_orient = outputs['loss_orient']
|
371 |
+
loss = loss + 1e-2 * loss_orient
|
372 |
+
|
373 |
+
return pred_rgb, pred_ws, loss
|
374 |
+
|
375 |
+
def eval_step(self, data):
|
376 |
+
|
377 |
+
rays_o = data['rays_o'] # [B, N, 3]
|
378 |
+
rays_d = data['rays_d'] # [B, N, 3]
|
379 |
+
|
380 |
+
B, N = rays_o.shape[:2]
|
381 |
+
H, W = data['H'], data['W']
|
382 |
+
|
383 |
+
shading = data['shading'] if 'shading' in data else 'albedo'
|
384 |
+
ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
|
385 |
+
light_d = data['light_d'] if 'light_d' in data else None
|
386 |
+
|
387 |
+
outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
|
388 |
+
pred_rgb = outputs['image'].reshape(B, H, W, 3)
|
389 |
+
pred_depth = outputs['depth'].reshape(B, H, W)
|
390 |
+
pred_ws = outputs['weights_sum'].reshape(B, H, W)
|
391 |
+
# mask_ws = outputs['mask'].reshape(B, H, W) # near < far
|
392 |
+
|
393 |
+
# loss_ws = pred_ws.sum() / mask_ws.sum()
|
394 |
+
# loss_ws = pred_ws.mean()
|
395 |
+
|
396 |
+
alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
|
397 |
+
# alphas = alphas ** 2 # skewed entropy, favors 0 over 1
|
398 |
+
loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
|
399 |
+
|
400 |
+
loss = 1e-3 * loss_entropy
|
401 |
+
|
402 |
+
return pred_rgb, pred_depth, loss
|
403 |
+
|
404 |
+
# moved out bg_color and perturb for more flexible control...
|
405 |
+
def test_step(self, data, bg_color=None, perturb=False):
|
406 |
+
rays_o = data['rays_o'] # [B, N, 3]
|
407 |
+
rays_d = data['rays_d'] # [B, N, 3]
|
408 |
+
|
409 |
+
B, N = rays_o.shape[:2]
|
410 |
+
H, W = data['H'], data['W']
|
411 |
+
|
412 |
+
if bg_color is not None:
|
413 |
+
bg_color = bg_color.to(rays_o.device)
|
414 |
+
else:
|
415 |
+
bg_color = torch.ones(3, device=rays_o.device) # [3]
|
416 |
+
|
417 |
+
shading = data['shading'] if 'shading' in data else 'albedo'
|
418 |
+
ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
|
419 |
+
light_d = data['light_d'] if 'light_d' in data else None
|
420 |
+
|
421 |
+
outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))
|
422 |
+
|
423 |
+
pred_rgb = outputs['image'].reshape(B, H, W, 3)
|
424 |
+
pred_depth = outputs['depth'].reshape(B, H, W)
|
425 |
+
|
426 |
+
return pred_rgb, pred_depth
|
427 |
+
|
428 |
+
|
429 |
+
def save_mesh(self, save_path=None, resolution=128):
|
430 |
+
|
431 |
+
if save_path is None:
|
432 |
+
save_path = os.path.join(self.workspace, 'mesh')
|
433 |
+
|
434 |
+
self.log(f"==> Saving mesh to {save_path}")
|
435 |
+
|
436 |
+
os.makedirs(save_path, exist_ok=True)
|
437 |
+
|
438 |
+
self.model.export_mesh(save_path, resolution=resolution)
|
439 |
+
|
440 |
+
self.log(f"==> Finished saving mesh.")
|
441 |
+
|
442 |
+
### ------------------------------
|
443 |
+
|
444 |
+
def train(self, train_loader, valid_loader, max_epochs):
|
445 |
+
if self.use_tensorboardX and self.local_rank == 0:
|
446 |
+
self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
|
447 |
+
|
448 |
+
start_t = time.time()
|
449 |
+
|
450 |
+
for epoch in range(self.epoch + 1, max_epochs + 1):
|
451 |
+
self.epoch = epoch
|
452 |
+
|
453 |
+
self.train_one_epoch(train_loader)
|
454 |
+
|
455 |
+
if self.workspace is not None and self.local_rank == 0:
|
456 |
+
self.save_checkpoint(full=True, best=False)
|
457 |
+
|
458 |
+
if self.epoch % self.eval_interval == 0:
|
459 |
+
self.evaluate_one_epoch(valid_loader)
|
460 |
+
self.save_checkpoint(full=False, best=True)
|
461 |
+
|
462 |
+
end_t = time.time()
|
463 |
+
|
464 |
+
self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")
|
465 |
+
|
466 |
+
if self.use_tensorboardX and self.local_rank == 0:
|
467 |
+
self.writer.close()
|
468 |
+
|
469 |
+
def evaluate(self, loader, name=None):
|
470 |
+
self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
|
471 |
+
self.evaluate_one_epoch(loader, name)
|
472 |
+
self.use_tensorboardX = use_tensorboardX
|
473 |
+
|
474 |
+
def test(self, loader, save_path=None, name=None, write_video=True):
|
475 |
+
|
476 |
+
if save_path is None:
|
477 |
+
save_path = os.path.join(self.workspace, 'results')
|
478 |
+
|
479 |
+
if name is None:
|
480 |
+
name = f'{self.name}_ep{self.epoch:04d}'
|
481 |
+
|
482 |
+
os.makedirs(save_path, exist_ok=True)
|
483 |
+
|
484 |
+
self.log(f"==> Start Test, save results to {save_path}")
|
485 |
+
|
486 |
+
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
487 |
+
self.model.eval()
|
488 |
+
|
489 |
+
if write_video:
|
490 |
+
all_preds = []
|
491 |
+
all_preds_depth = []
|
492 |
+
|
493 |
+
with torch.no_grad():
|
494 |
+
|
495 |
+
for i, data in enumerate(loader):
|
496 |
+
|
497 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
498 |
+
preds, preds_depth = self.test_step(data)
|
499 |
+
|
500 |
+
pred = preds[0].detach().cpu().numpy()
|
501 |
+
pred = (pred * 255).astype(np.uint8)
|
502 |
+
|
503 |
+
pred_depth = preds_depth[0].detach().cpu().numpy()
|
504 |
+
pred_depth = (pred_depth * 255).astype(np.uint8)
|
505 |
+
|
506 |
+
if write_video:
|
507 |
+
all_preds.append(pred)
|
508 |
+
all_preds_depth.append(pred_depth)
|
509 |
+
else:
|
510 |
+
cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
|
511 |
+
cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)
|
512 |
+
|
513 |
+
pbar.update(loader.batch_size)
|
514 |
+
|
515 |
+
if write_video:
|
516 |
+
all_preds = np.stack(all_preds, axis=0)
|
517 |
+
all_preds_depth = np.stack(all_preds_depth, axis=0)
|
518 |
+
|
519 |
+
imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
|
520 |
+
imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)
|
521 |
+
|
522 |
+
self.log(f"==> Finished Test.")
|
523 |
+
|
524 |
+
# [GUI] train text step.
|
525 |
+
def train_gui(self, train_loader, step=16):
|
526 |
+
|
527 |
+
self.model.train()
|
528 |
+
|
529 |
+
total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
|
530 |
+
|
531 |
+
loader = iter(train_loader)
|
532 |
+
|
533 |
+
for _ in range(step):
|
534 |
+
|
535 |
+
# mimic an infinite loop dataloader (in case the total dataset is smaller than step)
|
536 |
+
try:
|
537 |
+
data = next(loader)
|
538 |
+
except StopIteration:
|
539 |
+
loader = iter(train_loader)
|
540 |
+
data = next(loader)
|
541 |
+
|
542 |
+
# update grid every 16 steps
|
543 |
+
if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
|
544 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
545 |
+
self.model.update_extra_state()
|
546 |
+
|
547 |
+
self.global_step += 1
|
548 |
+
|
549 |
+
self.optimizer.zero_grad()
|
550 |
+
|
551 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
552 |
+
pred_rgbs, pred_ws, loss = self.train_step(data)
|
553 |
+
|
554 |
+
self.scaler.scale(loss).backward()
|
555 |
+
self.scaler.step(self.optimizer)
|
556 |
+
self.scaler.update()
|
557 |
+
|
558 |
+
if self.scheduler_update_every_step:
|
559 |
+
self.lr_scheduler.step()
|
560 |
+
|
561 |
+
total_loss += loss.detach()
|
562 |
+
|
563 |
+
if self.ema is not None:
|
564 |
+
self.ema.update()
|
565 |
+
|
566 |
+
average_loss = total_loss.item() / step
|
567 |
+
|
568 |
+
if not self.scheduler_update_every_step:
|
569 |
+
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
570 |
+
self.lr_scheduler.step(average_loss)
|
571 |
+
else:
|
572 |
+
self.lr_scheduler.step()
|
573 |
+
|
574 |
+
outputs = {
|
575 |
+
'loss': average_loss,
|
576 |
+
'lr': self.optimizer.param_groups[0]['lr'],
|
577 |
+
}
|
578 |
+
|
579 |
+
return outputs
|
580 |
+
|
581 |
+
|
582 |
+
# [GUI] test on a single image
|
583 |
+
def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
|
584 |
+
|
585 |
+
# render resolution (may need downscale to for better frame rate)
|
586 |
+
rH = int(H * downscale)
|
587 |
+
rW = int(W * downscale)
|
588 |
+
intrinsics = intrinsics * downscale
|
589 |
+
|
590 |
+
pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
|
591 |
+
|
592 |
+
rays = get_rays(pose, intrinsics, rH, rW, -1)
|
593 |
+
|
594 |
+
# from degree theta/phi to 3D normalized vec
|
595 |
+
light_d = np.deg2rad(light_d)
|
596 |
+
light_d = np.array([
|
597 |
+
np.sin(light_d[0]) * np.sin(light_d[1]),
|
598 |
+
np.cos(light_d[0]),
|
599 |
+
np.sin(light_d[0]) * np.cos(light_d[1]),
|
600 |
+
], dtype=np.float32)
|
601 |
+
light_d = torch.from_numpy(light_d).to(self.device)
|
602 |
+
|
603 |
+
data = {
|
604 |
+
'rays_o': rays['rays_o'],
|
605 |
+
'rays_d': rays['rays_d'],
|
606 |
+
'H': rH,
|
607 |
+
'W': rW,
|
608 |
+
'light_d': light_d,
|
609 |
+
'ambient_ratio': ambient_ratio,
|
610 |
+
'shading': shading,
|
611 |
+
}
|
612 |
+
|
613 |
+
self.model.eval()
|
614 |
+
|
615 |
+
if self.ema is not None:
|
616 |
+
self.ema.store()
|
617 |
+
self.ema.copy_to()
|
618 |
+
|
619 |
+
with torch.no_grad():
|
620 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
621 |
+
# here spp is used as perturb random seed!
|
622 |
+
preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)
|
623 |
+
|
624 |
+
if self.ema is not None:
|
625 |
+
self.ema.restore()
|
626 |
+
|
627 |
+
# interpolation to the original resolution
|
628 |
+
if downscale != 1:
|
629 |
+
# have to permute twice with torch...
|
630 |
+
preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
|
631 |
+
preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
|
632 |
+
|
633 |
+
outputs = {
|
634 |
+
'image': preds[0].detach().cpu().numpy(),
|
635 |
+
'depth': preds_depth[0].detach().cpu().numpy(),
|
636 |
+
}
|
637 |
+
|
638 |
+
return outputs
|
639 |
+
|
640 |
+
def train_one_epoch(self, loader):
|
641 |
+
self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
|
642 |
+
|
643 |
+
total_loss = 0
|
644 |
+
if self.local_rank == 0 and self.report_metric_at_train:
|
645 |
+
for metric in self.metrics:
|
646 |
+
metric.clear()
|
647 |
+
|
648 |
+
self.model.train()
|
649 |
+
|
650 |
+
# distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
|
651 |
+
# ref: https://pytorch.org/docs/stable/data.html
|
652 |
+
if self.world_size > 1:
|
653 |
+
loader.sampler.set_epoch(self.epoch)
|
654 |
+
|
655 |
+
if self.local_rank == 0:
|
656 |
+
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
657 |
+
|
658 |
+
self.local_step = 0
|
659 |
+
|
660 |
+
for data in loader:
|
661 |
+
|
662 |
+
# update grid every 16 steps
|
663 |
+
if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
|
664 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
665 |
+
self.model.update_extra_state()
|
666 |
+
|
667 |
+
self.local_step += 1
|
668 |
+
self.global_step += 1
|
669 |
+
|
670 |
+
self.optimizer.zero_grad()
|
671 |
+
|
672 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
673 |
+
pred_rgbs, pred_ws, loss = self.train_step(data)
|
674 |
+
|
675 |
+
self.scaler.scale(loss).backward()
|
676 |
+
self.scaler.step(self.optimizer)
|
677 |
+
self.scaler.update()
|
678 |
+
|
679 |
+
if self.scheduler_update_every_step:
|
680 |
+
self.lr_scheduler.step()
|
681 |
+
|
682 |
+
loss_val = loss.item()
|
683 |
+
total_loss += loss_val
|
684 |
+
|
685 |
+
if self.local_rank == 0:
|
686 |
+
# if self.report_metric_at_train:
|
687 |
+
# for metric in self.metrics:
|
688 |
+
# metric.update(preds, truths)
|
689 |
+
|
690 |
+
if self.use_tensorboardX:
|
691 |
+
self.writer.add_scalar("train/loss", loss_val, self.global_step)
|
692 |
+
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
|
693 |
+
|
694 |
+
if self.scheduler_update_every_step:
|
695 |
+
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
|
696 |
+
else:
|
697 |
+
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
|
698 |
+
pbar.update(loader.batch_size)
|
699 |
+
|
700 |
+
if self.ema is not None:
|
701 |
+
self.ema.update()
|
702 |
+
|
703 |
+
average_loss = total_loss / self.local_step
|
704 |
+
self.stats["loss"].append(average_loss)
|
705 |
+
|
706 |
+
if self.local_rank == 0:
|
707 |
+
pbar.close()
|
708 |
+
if self.report_metric_at_train:
|
709 |
+
for metric in self.metrics:
|
710 |
+
self.log(metric.report(), style="red")
|
711 |
+
if self.use_tensorboardX:
|
712 |
+
metric.write(self.writer, self.epoch, prefix="train")
|
713 |
+
metric.clear()
|
714 |
+
|
715 |
+
if not self.scheduler_update_every_step:
|
716 |
+
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
717 |
+
self.lr_scheduler.step(average_loss)
|
718 |
+
else:
|
719 |
+
self.lr_scheduler.step()
|
720 |
+
|
721 |
+
self.log(f"==> Finished Epoch {self.epoch}.")
|
722 |
+
|
723 |
+
|
724 |
+
def evaluate_one_epoch(self, loader, name=None):
|
725 |
+
self.log(f"++> Evaluate at epoch {self.epoch} ...")
|
726 |
+
|
727 |
+
if name is None:
|
728 |
+
name = f'{self.name}_ep{self.epoch:04d}'
|
729 |
+
|
730 |
+
total_loss = 0
|
731 |
+
if self.local_rank == 0:
|
732 |
+
for metric in self.metrics:
|
733 |
+
metric.clear()
|
734 |
+
|
735 |
+
self.model.eval()
|
736 |
+
|
737 |
+
if self.ema is not None:
|
738 |
+
self.ema.store()
|
739 |
+
self.ema.copy_to()
|
740 |
+
|
741 |
+
if self.local_rank == 0:
|
742 |
+
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
743 |
+
|
744 |
+
with torch.no_grad():
|
745 |
+
self.local_step = 0
|
746 |
+
|
747 |
+
for data in loader:
|
748 |
+
self.local_step += 1
|
749 |
+
|
750 |
+
with torch.cuda.amp.autocast(enabled=self.fp16):
|
751 |
+
preds, preds_depth, loss = self.eval_step(data)
|
752 |
+
|
753 |
+
# all_gather/reduce the statistics (NCCL only support all_*)
|
754 |
+
if self.world_size > 1:
|
755 |
+
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
|
756 |
+
loss = loss / self.world_size
|
757 |
+
|
758 |
+
preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
|
759 |
+
dist.all_gather(preds_list, preds)
|
760 |
+
preds = torch.cat(preds_list, dim=0)
|
761 |
+
|
762 |
+
preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...]
|
763 |
+
dist.all_gather(preds_depth_list, preds_depth)
|
764 |
+
preds_depth = torch.cat(preds_depth_list, dim=0)
|
765 |
+
|
766 |
+
loss_val = loss.item()
|
767 |
+
total_loss += loss_val
|
768 |
+
|
769 |
+
# only rank = 0 will perform evaluation.
|
770 |
+
if self.local_rank == 0:
|
771 |
+
|
772 |
+
# save image
|
773 |
+
save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
|
774 |
+
save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')
|
775 |
+
|
776 |
+
#self.log(f"==> Saving validation image to {save_path}")
|
777 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
778 |
+
|
779 |
+
pred = preds[0].detach().cpu().numpy()
|
780 |
+
pred = (pred * 255).astype(np.uint8)
|
781 |
+
|
782 |
+
pred_depth = preds_depth[0].detach().cpu().numpy()
|
783 |
+
pred_depth = (pred_depth * 255).astype(np.uint8)
|
784 |
+
|
785 |
+
cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
|
786 |
+
cv2.imwrite(save_path_depth, pred_depth)
|
787 |
+
|
788 |
+
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
|
789 |
+
pbar.update(loader.batch_size)
|
790 |
+
|
791 |
+
|
792 |
+
average_loss = total_loss / self.local_step
|
793 |
+
self.stats["valid_loss"].append(average_loss)
|
794 |
+
|
795 |
+
if self.local_rank == 0:
|
796 |
+
pbar.close()
|
797 |
+
if not self.use_loss_as_metric and len(self.metrics) > 0:
|
798 |
+
result = self.metrics[0].measure()
|
799 |
+
self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
|
800 |
+
else:
|
801 |
+
self.stats["results"].append(average_loss) # if no metric, choose best by min loss
|
802 |
+
|
803 |
+
for metric in self.metrics:
|
804 |
+
self.log(metric.report(), style="blue")
|
805 |
+
if self.use_tensorboardX:
|
806 |
+
metric.write(self.writer, self.epoch, prefix="evaluate")
|
807 |
+
metric.clear()
|
808 |
+
|
809 |
+
if self.ema is not None:
|
810 |
+
self.ema.restore()
|
811 |
+
|
812 |
+
self.log(f"++> Evaluate epoch {self.epoch} Finished.")
|
813 |
+
|
814 |
+
def save_checkpoint(self, name=None, full=False, best=False):
|
815 |
+
|
816 |
+
if name is None:
|
817 |
+
name = f'{self.name}_ep{self.epoch:04d}'
|
818 |
+
|
819 |
+
state = {
|
820 |
+
'epoch': self.epoch,
|
821 |
+
'global_step': self.global_step,
|
822 |
+
'stats': self.stats,
|
823 |
+
}
|
824 |
+
|
825 |
+
if self.model.cuda_ray:
|
826 |
+
state['mean_count'] = self.model.mean_count
|
827 |
+
state['mean_density'] = self.model.mean_density
|
828 |
+
|
829 |
+
if full:
|
830 |
+
state['optimizer'] = self.optimizer.state_dict()
|
831 |
+
state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
832 |
+
state['scaler'] = self.scaler.state_dict()
|
833 |
+
if self.ema is not None:
|
834 |
+
state['ema'] = self.ema.state_dict()
|
835 |
+
|
836 |
+
if not best:
|
837 |
+
|
838 |
+
state['model'] = self.model.state_dict()
|
839 |
+
|
840 |
+
file_path = f"{name}.pth"
|
841 |
+
|
842 |
+
self.stats["checkpoints"].append(file_path)
|
843 |
+
|
844 |
+
if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
|
845 |
+
old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
|
846 |
+
if os.path.exists(old_ckpt):
|
847 |
+
os.remove(old_ckpt)
|
848 |
+
|
849 |
+
torch.save(state, os.path.join(self.ckpt_path, file_path))
|
850 |
+
|
851 |
+
else:
|
852 |
+
if len(self.stats["results"]) > 0:
|
853 |
+
if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
|
854 |
+
self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
|
855 |
+
self.stats["best_result"] = self.stats["results"][-1]
|
856 |
+
|
857 |
+
# save ema results
|
858 |
+
if self.ema is not None:
|
859 |
+
self.ema.store()
|
860 |
+
self.ema.copy_to()
|
861 |
+
|
862 |
+
state['model'] = self.model.state_dict()
|
863 |
+
|
864 |
+
if self.ema is not None:
|
865 |
+
self.ema.restore()
|
866 |
+
|
867 |
+
torch.save(state, self.best_path)
|
868 |
+
else:
|
869 |
+
self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
|
870 |
+
|
871 |
+
def load_checkpoint(self, checkpoint=None, model_only=False):
|
872 |
+
if checkpoint is None:
|
873 |
+
checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
|
874 |
+
if checkpoint_list:
|
875 |
+
checkpoint = checkpoint_list[-1]
|
876 |
+
self.log(f"[INFO] Latest checkpoint is {checkpoint}")
|
877 |
+
else:
|
878 |
+
self.log("[WARN] No checkpoint found, model randomly initialized.")
|
879 |
+
return
|
880 |
+
|
881 |
+
checkpoint_dict = torch.load(checkpoint, map_location=self.device)
|
882 |
+
|
883 |
+
if 'model' not in checkpoint_dict:
|
884 |
+
self.model.load_state_dict(checkpoint_dict)
|
885 |
+
self.log("[INFO] loaded model.")
|
886 |
+
return
|
887 |
+
|
888 |
+
missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
|
889 |
+
self.log("[INFO] loaded model.")
|
890 |
+
if len(missing_keys) > 0:
|
891 |
+
self.log(f"[WARN] missing keys: {missing_keys}")
|
892 |
+
if len(unexpected_keys) > 0:
|
893 |
+
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
|
894 |
+
|
895 |
+
if self.ema is not None and 'ema' in checkpoint_dict:
|
896 |
+
try:
|
897 |
+
self.ema.load_state_dict(checkpoint_dict['ema'])
|
898 |
+
self.log("[INFO] loaded EMA.")
|
899 |
+
except:
|
900 |
+
self.log("[WARN] failed to loaded EMA.")
|
901 |
+
|
902 |
+
if self.model.cuda_ray:
|
903 |
+
if 'mean_count' in checkpoint_dict:
|
904 |
+
self.model.mean_count = checkpoint_dict['mean_count']
|
905 |
+
if 'mean_density' in checkpoint_dict:
|
906 |
+
self.model.mean_density = checkpoint_dict['mean_density']
|
907 |
+
|
908 |
+
if model_only:
|
909 |
+
return
|
910 |
+
|
911 |
+
self.stats = checkpoint_dict['stats']
|
912 |
+
self.epoch = checkpoint_dict['epoch']
|
913 |
+
self.global_step = checkpoint_dict['global_step']
|
914 |
+
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
|
915 |
+
|
916 |
+
if self.optimizer and 'optimizer' in checkpoint_dict:
|
917 |
+
try:
|
918 |
+
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
919 |
+
self.log("[INFO] loaded optimizer.")
|
920 |
+
except:
|
921 |
+
self.log("[WARN] Failed to load optimizer.")
|
922 |
+
|
923 |
+
if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
|
924 |
+
try:
|
925 |
+
self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
|
926 |
+
self.log("[INFO] loaded scheduler.")
|
927 |
+
except:
|
928 |
+
self.log("[WARN] Failed to load scheduler.")
|
929 |
+
|
930 |
+
if self.scaler and 'scaler' in checkpoint_dict:
|
931 |
+
try:
|
932 |
+
self.scaler.load_state_dict(checkpoint_dict['scaler'])
|
933 |
+
self.log("[INFO] loaded scaler.")
|
934 |
+
except:
|
935 |
+
self.log("[WARN] Failed to load scaler.")
|
optimizer.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import enum
|
4 |
+
import itertools
|
5 |
+
from dataclasses import dataclass
|
6 |
+
import torch.optim as optim
|
7 |
+
|
8 |
+
@torch.no_grad()
|
9 |
+
def PowerIter(mat_g, error_tolerance=1e-6, num_iters=100):
|
10 |
+
"""Power iteration.
|
11 |
+
Compute the maximum eigenvalue of mat, for scaling.
|
12 |
+
v is a random vector with values in (-1, 1)
|
13 |
+
Args:
|
14 |
+
mat_g: the symmetric PSD matrix.
|
15 |
+
error_tolerance: Iterative exit condition.
|
16 |
+
num_iters: Number of iterations.
|
17 |
+
Returns:
|
18 |
+
eigen vector, eigen value, num_iters
|
19 |
+
"""
|
20 |
+
v = torch.rand(list(mat_g.shape)[0], device=mat_g.get_device()) * 2 - 1
|
21 |
+
error = 1
|
22 |
+
iters = 0
|
23 |
+
singular_val = 0
|
24 |
+
while error > error_tolerance and iters < num_iters:
|
25 |
+
v = v / torch.norm(v)
|
26 |
+
mat_v = torch.mv(mat_g, v)
|
27 |
+
s_v = torch.dot(v, mat_v)
|
28 |
+
error = torch.abs(s_v - singular_val)
|
29 |
+
v = mat_v
|
30 |
+
singular_val = s_v
|
31 |
+
iters += 1
|
32 |
+
return singular_val, v / torch.norm(v), iters
|
33 |
+
|
34 |
+
|
35 |
+
@torch.no_grad()
|
36 |
+
def MatPower(mat_m, p):
|
37 |
+
"""Computes mat_m^p, for p a positive integer.
|
38 |
+
Args:
|
39 |
+
mat_m: a square matrix
|
40 |
+
p: a positive integer
|
41 |
+
Returns:
|
42 |
+
mat_m^p
|
43 |
+
"""
|
44 |
+
if p in [1, 2, 4, 8, 16, 32]:
|
45 |
+
p_done = 1
|
46 |
+
res = mat_m
|
47 |
+
while p_done < p:
|
48 |
+
res = torch.matmul(res, res)
|
49 |
+
p_done *= 2
|
50 |
+
return res
|
51 |
+
|
52 |
+
power = None
|
53 |
+
while p > 0:
|
54 |
+
if p % 2 == 1:
|
55 |
+
power = torch.matmul(mat_m, power) if power is not None else mat_m
|
56 |
+
p //= 2
|
57 |
+
mat_m = torch.matmul(mat_m, mat_m)
|
58 |
+
return power
|
59 |
+
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def ComputePower(mat_g, p,
|
63 |
+
iter_count=100,
|
64 |
+
error_tolerance=1e-6,
|
65 |
+
ridge_epsilon=1e-6):
|
66 |
+
"""A method to compute G^{-1/p} using a coupled Newton iteration.
|
67 |
+
See for example equation 3.2 on page 9 of:
|
68 |
+
A Schur-Newton Method for the Matrix p-th Root and its Inverse
|
69 |
+
by Chun-Hua Guo and Nicholas J. Higham
|
70 |
+
SIAM Journal on Matrix Analysis and Applications,
|
71 |
+
2006, Vol. 28, No. 3 : pp. 788-804
|
72 |
+
https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
|
73 |
+
Args:
|
74 |
+
mat_g: A square positive semidefinite matrix
|
75 |
+
p: a positive integer
|
76 |
+
iter_count: Stop iterating after this many rounds.
|
77 |
+
error_tolerance: Threshold for stopping iteration
|
78 |
+
ridge_epsilon: We add this times I to G, to make is positive definite.
|
79 |
+
For scaling, we multiply it by the largest eigenvalue of G.
|
80 |
+
Returns:
|
81 |
+
(mat_g + rI)^{-1/p} (r = ridge_epsilon * max_eigenvalue of mat_g).
|
82 |
+
"""
|
83 |
+
shape = list(mat_g.shape)
|
84 |
+
if len(shape) == 1:
|
85 |
+
return torch.pow(mat_g + ridge_epsilon, -1/p)
|
86 |
+
identity = torch.eye(shape[0], device=mat_g.get_device())
|
87 |
+
if shape[0] == 1:
|
88 |
+
return identity
|
89 |
+
alpha = -1.0/p
|
90 |
+
max_ev, _, _ = PowerIter(mat_g)
|
91 |
+
ridge_epsilon *= max_ev
|
92 |
+
mat_g += ridge_epsilon * identity
|
93 |
+
z = (1 + p) / (2 * torch.norm(mat_g))
|
94 |
+
# The best value for z is
|
95 |
+
# (1 + p) * (c_max^{1/p} - c_min^{1/p}) /
|
96 |
+
# (c_max^{1+1/p} - c_min^{1+1/p})
|
97 |
+
# where c_max and c_min are the largest and smallest singular values of
|
98 |
+
# mat_g.
|
99 |
+
# The above estimate assumes that c_max > c_min * 2^p
|
100 |
+
# Can replace above line by the one below, but it is less accurate,
|
101 |
+
# hence needs more iterations to converge.
|
102 |
+
# z = (1 + p) / tf.trace(mat_g)
|
103 |
+
# If we want the method to always converge, use z = 1 / norm(mat_g)
|
104 |
+
# or z = 1 / tf.trace(mat_g), but these can result in many
|
105 |
+
# extra iterations.
|
106 |
+
|
107 |
+
mat_root = identity * torch.pow(z, 1.0/p)
|
108 |
+
mat_m = mat_g * z
|
109 |
+
error = torch.max(torch.abs(mat_m - identity))
|
110 |
+
count = 0
|
111 |
+
while error > error_tolerance and count < iter_count:
|
112 |
+
tmp_mat_m = (1 - alpha) * identity + alpha * mat_m
|
113 |
+
new_mat_root = torch.matmul(mat_root, tmp_mat_m)
|
114 |
+
mat_m = torch.matmul(MatPower(tmp_mat_m, p), mat_m)
|
115 |
+
new_error = torch.max(torch.abs(mat_m - identity))
|
116 |
+
if new_error > error * 1.2:
|
117 |
+
break
|
118 |
+
mat_root = new_mat_root
|
119 |
+
error = new_error
|
120 |
+
count += 1
|
121 |
+
return mat_root
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
# Grafting is a technique to fix the layerwise scale of Shampoo optimizer.
|
126 |
+
# https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
|
127 |
+
# allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
|
128 |
+
# is already well tuned. Grafting onto Shampoo means take the Shampoo direction,
|
129 |
+
# but use the step magnitude from the grafted optimizer such as Adagrad or SGD.
|
130 |
+
class LayerwiseGrafting(enum.IntEnum):
|
131 |
+
NONE = 0
|
132 |
+
SGD = 1
|
133 |
+
ADAGRAD = 2
|
134 |
+
|
135 |
+
|
136 |
+
@dataclass
|
137 |
+
class ShampooHyperParams:
|
138 |
+
"""Shampoo hyper parameters."""
|
139 |
+
beta2: float = 0.9
|
140 |
+
diagonal_eps: float = 1e-6
|
141 |
+
matrix_eps: float = 1e-12
|
142 |
+
weight_decay: float = 0.0
|
143 |
+
inverse_exponent_override: int = 2 # fixed exponent for preconditioner, if >0
|
144 |
+
start_preconditioning_step: int = 1
|
145 |
+
# Performance tuning params for controlling memory and compute requirements.
|
146 |
+
# How often to compute preconditioner.
|
147 |
+
preconditioning_compute_steps: int = 1
|
148 |
+
# How often to compute statistics.
|
149 |
+
statistics_compute_steps: int = 1
|
150 |
+
# Block size for large layers (if > 0).
|
151 |
+
# Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
|
152 |
+
# Block size should be as large as feasible under memory/time constraints.
|
153 |
+
block_size: int = 128
|
154 |
+
# Automatic shape interpretation (for eg: [4, 3, 1024, 512] would result in
|
155 |
+
# 12 x [1024, 512] L and R statistics. Disabled by default which results in
|
156 |
+
# Shampoo constructing statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
|
157 |
+
best_effort_shape_interpretation: bool = True
|
158 |
+
# Type of grafting (SGD or AdaGrad).
|
159 |
+
# https://arxiv.org/pdf/2002.11803.pdf
|
160 |
+
graft_type: int = LayerwiseGrafting.ADAGRAD
|
161 |
+
# Nesterov momentum
|
162 |
+
nesterov: bool = True
|
163 |
+
|
164 |
+
|
165 |
+
class Graft:
|
166 |
+
"""Base class to perform grafting onto Shampoo. This class does no grafting.
|
167 |
+
"""
|
168 |
+
|
169 |
+
def __init__(self, hps, unused_var):
|
170 |
+
self.hps = hps
|
171 |
+
|
172 |
+
def add_statistics(self, grad):
|
173 |
+
pass
|
174 |
+
|
175 |
+
def precondition_gradient(self, grad):
|
176 |
+
return grad
|
177 |
+
|
178 |
+
def update_momentum(self, update, unused_beta1):
|
179 |
+
return update
|
180 |
+
|
181 |
+
|
182 |
+
class SGDGraft(Graft):
|
183 |
+
"""Graft using SGD+momentum.
|
184 |
+
momentum maintains an exponentially weighted moving average of gradients.
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, hps, var):
|
188 |
+
super(SGDGraft, self).__init__(hps, var)
|
189 |
+
self.momentum = torch.zeros_like(var.data, device=var.get_device())
|
190 |
+
|
191 |
+
def update_momentum(self, update, beta1):
|
192 |
+
self.momentum.mul_(beta1).add_(update)
|
193 |
+
return self.momentum
|
194 |
+
|
195 |
+
|
196 |
+
class AdagradGraft(SGDGraft):
|
197 |
+
"""Graft using Adagrad.
|
198 |
+
Essentially an implementation of Adagrad with momentum.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, hps, var):
|
202 |
+
super(AdagradGraft, self).__init__(hps, var)
|
203 |
+
self.statistics = torch.zeros_like(var.data, device=var.get_device())
|
204 |
+
|
205 |
+
def add_statistics(self, grad):
|
206 |
+
self.statistics.add_(grad * grad)
|
207 |
+
|
208 |
+
def precondition_gradient(self, grad):
|
209 |
+
return grad / (torch.sqrt(self.statistics) + self.hps.diagonal_eps)
|
210 |
+
|
211 |
+
|
212 |
+
class BlockPartitioner:
|
213 |
+
"""Partitions a tensor into smaller tensors for preconditioning.
|
214 |
+
For example, if a variable has shape (4096, 512), we might split the
|
215 |
+
4096 into 4 blocks, so we effectively have 4 variables of size
|
216 |
+
(1024, 512) each.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, var, hps):
|
220 |
+
self._shape = var.shape
|
221 |
+
self._splits = []
|
222 |
+
self._split_sizes = []
|
223 |
+
split_sizes = []
|
224 |
+
# We split var into smaller blocks. Here we store the metadata to make
|
225 |
+
# that split.
|
226 |
+
for i, d in enumerate(var.shape):
|
227 |
+
if hps.block_size > 0 and d > hps.block_size:
|
228 |
+
# d-1, otherwise split appends a 0-size array.
|
229 |
+
nsplit = (d-1) // hps.block_size
|
230 |
+
indices = (np.arange(nsplit, dtype=np.int32) + 1) * hps.block_size
|
231 |
+
sizes = np.ones(nsplit + 1, dtype=np.int32) * hps.block_size
|
232 |
+
sizes[-1] = d - indices[-1]
|
233 |
+
self._splits.append((i, indices))
|
234 |
+
self._split_sizes.append((i, sizes))
|
235 |
+
split_sizes.append(sizes)
|
236 |
+
else:
|
237 |
+
split_sizes.append(np.array([d], dtype=np.int32))
|
238 |
+
self._num_splits = len(split_sizes)
|
239 |
+
self._preconditioner_shapes = []
|
240 |
+
for t in itertools.product(*split_sizes):
|
241 |
+
self._preconditioner_shapes.extend([[d, d] for d in t])
|
242 |
+
|
243 |
+
def shapes_for_preconditioners(self):
|
244 |
+
return self._preconditioner_shapes
|
245 |
+
|
246 |
+
def num_splits(self):
|
247 |
+
return self._num_splits
|
248 |
+
|
249 |
+
def partition(self, tensor):
|
250 |
+
"""Partition tensor into blocks."""
|
251 |
+
|
252 |
+
assert tensor.shape == self._shape
|
253 |
+
tensors = [tensor]
|
254 |
+
for (i, sizes) in self._split_sizes:
|
255 |
+
tensors_local = []
|
256 |
+
for t in tensors:
|
257 |
+
tensors_local.extend(
|
258 |
+
torch.split(t, tuple(sizes), dim=i))
|
259 |
+
tensors = tensors_local
|
260 |
+
return tensors
|
261 |
+
|
262 |
+
def merge_partitions(self, partitions):
|
263 |
+
"""Merge partitions back to original shape."""
|
264 |
+
|
265 |
+
for (i, indices) in reversed(self._splits):
|
266 |
+
n = len(indices) + 1
|
267 |
+
partial_merged_tensors = []
|
268 |
+
ind = 0
|
269 |
+
while ind < len(partitions):
|
270 |
+
partial_merged_tensors.append(
|
271 |
+
torch.cat(partitions[ind:ind + n], axis=i))
|
272 |
+
ind += n
|
273 |
+
partitions = partial_merged_tensors
|
274 |
+
assert len(partitions) == 1
|
275 |
+
return partitions[0]
|
276 |
+
|
277 |
+
|
278 |
+
def _merge_small_dims(shape_to_merge, max_dim):
|
279 |
+
"""Merge small dimensions.
|
280 |
+
If there are some small dimensions, we collapse them:
|
281 |
+
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
|
282 |
+
[1, 2, 768, 1, 2048] --> [2, 768, 2048]
|
283 |
+
Args:
|
284 |
+
shape_to_merge: Shape to merge small dimensions.
|
285 |
+
max_dim: Maximal dimension of output shape used in merging.
|
286 |
+
Returns:
|
287 |
+
Merged shape.
|
288 |
+
"""
|
289 |
+
resulting_shape = []
|
290 |
+
product = 1
|
291 |
+
for d in shape_to_merge:
|
292 |
+
if product * d <= max_dim:
|
293 |
+
product *= d
|
294 |
+
else:
|
295 |
+
if product > 1:
|
296 |
+
resulting_shape.append(product)
|
297 |
+
product = d
|
298 |
+
if product > 1:
|
299 |
+
resulting_shape.append(product)
|
300 |
+
return resulting_shape
|
301 |
+
|
302 |
+
|
303 |
+
class Preconditioner:
|
304 |
+
"""Compute statistics/shape from gradients for preconditioning."""
|
305 |
+
|
306 |
+
def __init__(self, var, hps):
|
307 |
+
self._hps = hps
|
308 |
+
self._original_shape = var.shape
|
309 |
+
self._transformed_shape = var.shape
|
310 |
+
if hps.best_effort_shape_interpretation:
|
311 |
+
self._transformed_shape = _merge_small_dims(
|
312 |
+
self._original_shape, hps.block_size)
|
313 |
+
|
314 |
+
reshaped_var = torch.reshape(var, self._transformed_shape)
|
315 |
+
self._partitioner = BlockPartitioner(reshaped_var, hps)
|
316 |
+
shapes = self._partitioner.shapes_for_preconditioners()
|
317 |
+
rank = len(self._transformed_shape)
|
318 |
+
device = var.get_device()
|
319 |
+
if rank <= 1:
|
320 |
+
self.statistics = []
|
321 |
+
self.preconditioners = []
|
322 |
+
else:
|
323 |
+
eps = self._hps.matrix_eps
|
324 |
+
self.statistics = [eps * torch.eye(s[0], device=device) for s in shapes]
|
325 |
+
self.preconditioners = [torch.eye(s[0], device=device) for s in shapes]
|
326 |
+
|
327 |
+
def add_statistics(self, grad):
|
328 |
+
"""Compute statistics from gradients and add to the correct state entries.
|
329 |
+
Args:
|
330 |
+
grad: Gradient to compute statistics from.
|
331 |
+
"""
|
332 |
+
if not self.statistics: return
|
333 |
+
reshaped_grad = torch.reshape(grad, self._transformed_shape)
|
334 |
+
partitioned_grads = self._partitioner.partition(reshaped_grad)
|
335 |
+
w1 = self._hps.beta2
|
336 |
+
w2 = 1.0 if w1 == 1.0 else (1.0 - w1)
|
337 |
+
rank = len(self._transformed_shape)
|
338 |
+
for j, grad in enumerate(partitioned_grads):
|
339 |
+
for i in range(rank):
|
340 |
+
axes = list(range(i)) + list(range(i + 1, rank))
|
341 |
+
stat = torch.tensordot(grad, grad, [axes, axes])
|
342 |
+
self.statistics[j*rank + i].mul_(w1).add_(stat, alpha=w2)
|
343 |
+
|
344 |
+
def exponent_for_preconditioner(self):
|
345 |
+
"""Returns exponent to use for inverse-pth root M^{-1/p}."""
|
346 |
+
if self._hps.inverse_exponent_override > 0:
|
347 |
+
return self._hps.inverse_exponent_override
|
348 |
+
return 2 * len(self._transformed_shape)
|
349 |
+
|
350 |
+
def compute_preconditioners(self):
|
351 |
+
"""Compute L^{-1/exp} for each stats matrix L."""
|
352 |
+
exp = self.exponent_for_preconditioner()
|
353 |
+
eps = self._hps.matrix_eps
|
354 |
+
for i, stat in enumerate(self.statistics):
|
355 |
+
self.preconditioners[i] = ComputePower(
|
356 |
+
stat, exp, ridge_epsilon=eps)
|
357 |
+
|
358 |
+
def preconditioned_grad(self, grad):
|
359 |
+
"""Precondition the gradient.
|
360 |
+
Args:
|
361 |
+
grad: A gradient tensor to precondition.
|
362 |
+
Returns:
|
363 |
+
A preconditioned gradient.
|
364 |
+
"""
|
365 |
+
if not self.preconditioners: return grad
|
366 |
+
reshaped_grad = torch.reshape(grad, self._transformed_shape)
|
367 |
+
partitioned_grads = self._partitioner.partition(reshaped_grad)
|
368 |
+
preconditioned_partitioned_grads = []
|
369 |
+
num_splits = self._partitioner.num_splits()
|
370 |
+
for i, grad in enumerate(partitioned_grads):
|
371 |
+
preconditioners_for_grad = self.preconditioners[i * num_splits:(i + 1) *
|
372 |
+
num_splits]
|
373 |
+
rank = len(grad.shape)
|
374 |
+
precond_grad = grad
|
375 |
+
for j in range(rank):
|
376 |
+
preconditioner = preconditioners_for_grad[j]
|
377 |
+
precond_grad = torch.tensordot(
|
378 |
+
precond_grad, preconditioner, [[0], [0]])
|
379 |
+
preconditioned_partitioned_grads.append(precond_grad)
|
380 |
+
merged_grad = self._partitioner.merge_partitions(
|
381 |
+
preconditioned_partitioned_grads)
|
382 |
+
return torch.reshape(merged_grad, self._original_shape)
|
383 |
+
|
384 |
+
|
385 |
+
STEP = 'step'
|
386 |
+
MOMENTUM = 'momentum'
|
387 |
+
PRECONDITIONER = 'preconditioner'
|
388 |
+
GRAFT = 'graft'
|
389 |
+
|
390 |
+
|
391 |
+
class Shampoo(optim.Optimizer):
|
392 |
+
"""The Shampoo optimizer."""
|
393 |
+
|
394 |
+
def __init__(self,
|
395 |
+
params,
|
396 |
+
lr=1.0,
|
397 |
+
momentum=0.9,
|
398 |
+
hyperparams=ShampooHyperParams()):
|
399 |
+
defaults = dict(lr=lr, momentum=momentum)
|
400 |
+
self.hps = hyperparams
|
401 |
+
super(Shampoo, self).__init__(params, defaults)
|
402 |
+
|
403 |
+
def init_var_state(self, var, state):
|
404 |
+
"""Initialize the PyTorch state of for a single variable."""
|
405 |
+
state[STEP] = 0
|
406 |
+
state[MOMENTUM] = torch.zeros_like(var.data, device=var.get_device())
|
407 |
+
state[PRECONDITIONER] = Preconditioner(var, self.hps)
|
408 |
+
if self.hps.graft_type == LayerwiseGrafting.ADAGRAD:
|
409 |
+
state[GRAFT] = AdagradGraft(self.hps, var)
|
410 |
+
elif self.hps.graft_type == LayerwiseGrafting.SGD:
|
411 |
+
state[GRAFT] = SGDGraft(self.hps, var)
|
412 |
+
else:
|
413 |
+
state[GRAFT] = Graft(self.hps, var)
|
414 |
+
|
415 |
+
def step(self, closure=None):
|
416 |
+
hps = self.hps
|
417 |
+
for group in self.param_groups:
|
418 |
+
lr = group['lr']
|
419 |
+
for p in group['params']:
|
420 |
+
if p.grad is None: continue
|
421 |
+
grad = p.grad.data
|
422 |
+
if grad.is_sparse:
|
423 |
+
raise RuntimeError('Shampoo does not support sparse yet')
|
424 |
+
state = self.state[p]
|
425 |
+
if not state:
|
426 |
+
self.init_var_state(p, state)
|
427 |
+
state[STEP] += 1
|
428 |
+
|
429 |
+
preconditioner = state[PRECONDITIONER]
|
430 |
+
graft = state[GRAFT]
|
431 |
+
|
432 |
+
# Gather statistics, compute preconditioners
|
433 |
+
graft.add_statistics(grad)
|
434 |
+
if state[STEP] % hps.statistics_compute_steps == 0:
|
435 |
+
preconditioner.add_statistics(grad)
|
436 |
+
if state[STEP] % hps.preconditioning_compute_steps == 0:
|
437 |
+
preconditioner.compute_preconditioners()
|
438 |
+
|
439 |
+
# Precondition gradients
|
440 |
+
graft_grad = graft.precondition_gradient(grad)
|
441 |
+
shampoo_grad = grad
|
442 |
+
if state[STEP] >= self.hps.start_preconditioning_step:
|
443 |
+
shampoo_grad = preconditioner.preconditioned_grad(grad)
|
444 |
+
|
445 |
+
# Grafting
|
446 |
+
graft_norm = torch.norm(graft_grad)
|
447 |
+
shampoo_norm = torch.norm(shampoo_grad)
|
448 |
+
shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
|
449 |
+
|
450 |
+
# Weight decay
|
451 |
+
if self.hps.weight_decay != 0.0:
|
452 |
+
shampoo_grad.add_(p.data, alpha=self.hps.weight_decay)
|
453 |
+
graft_grad.add_(p.data, alpha=self.hps.weight_decay)
|
454 |
+
|
455 |
+
# Momentum and Nesterov momentum, if needed
|
456 |
+
state[MOMENTUM].mul_(group['momentum']).add_(shampoo_grad)
|
457 |
+
graft_momentum = graft.update_momentum(grad, group['momentum'])
|
458 |
+
|
459 |
+
if state[STEP] >= self.hps.start_preconditioning_step:
|
460 |
+
momentum_update = state[MOMENTUM]
|
461 |
+
wd_update = shampoo_grad
|
462 |
+
else:
|
463 |
+
momentum_update = graft_momentum
|
464 |
+
wd_update = graft_grad
|
465 |
+
|
466 |
+
if hps.nesterov:
|
467 |
+
momentum_update.mul_(group['momentum']).add_(wd_update)
|
468 |
+
|
469 |
+
# Final update
|
470 |
+
p.data.add_(momentum_update, alpha=-lr)
|
raymarching/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .raymarching import *
|
raymarching/backend.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torch.utils.cpp_extension import load
|
3 |
+
|
4 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
|
6 |
+
nvcc_flags = [
|
7 |
+
'-O3', '-std=c++14',
|
8 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
+
]
|
10 |
+
|
11 |
+
if os.name == "posix":
|
12 |
+
c_flags = ['-O3', '-std=c++14']
|
13 |
+
elif os.name == "nt":
|
14 |
+
c_flags = ['/O2', '/std:c++17']
|
15 |
+
|
16 |
+
# find cl.exe
|
17 |
+
def find_cl_path():
|
18 |
+
import glob
|
19 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
20 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
21 |
+
if paths:
|
22 |
+
return paths[0]
|
23 |
+
|
24 |
+
# If cl.exe is not on path, try to find it.
|
25 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
26 |
+
cl_path = find_cl_path()
|
27 |
+
if cl_path is None:
|
28 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
29 |
+
os.environ["PATH"] += ";" + cl_path
|
30 |
+
|
31 |
+
_backend = load(name='_raymarching',
|
32 |
+
extra_cflags=c_flags,
|
33 |
+
extra_cuda_cflags=nvcc_flags,
|
34 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
35 |
+
'raymarching.cu',
|
36 |
+
'bindings.cpp',
|
37 |
+
]],
|
38 |
+
)
|
39 |
+
|
40 |
+
__all__ = ['_backend']
|
raymarching/raymarching.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.autograd import Function
|
7 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
+
|
9 |
+
try:
|
10 |
+
import _raymarching as _backend
|
11 |
+
except ImportError:
|
12 |
+
from .backend import _backend
|
13 |
+
|
14 |
+
|
15 |
+
# ----------------------------------------
|
16 |
+
# utils
|
17 |
+
# ----------------------------------------
|
18 |
+
|
19 |
+
class _near_far_from_aabb(Function):
|
20 |
+
@staticmethod
|
21 |
+
@custom_fwd(cast_inputs=torch.float32)
|
22 |
+
def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
|
23 |
+
''' near_far_from_aabb, CUDA implementation
|
24 |
+
Calculate rays' intersection time (near and far) with aabb
|
25 |
+
Args:
|
26 |
+
rays_o: float, [N, 3]
|
27 |
+
rays_d: float, [N, 3]
|
28 |
+
aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
|
29 |
+
min_near: float, scalar
|
30 |
+
Returns:
|
31 |
+
nears: float, [N]
|
32 |
+
fars: float, [N]
|
33 |
+
'''
|
34 |
+
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
35 |
+
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
36 |
+
|
37 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
38 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
39 |
+
|
40 |
+
N = rays_o.shape[0] # num rays
|
41 |
+
|
42 |
+
nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
43 |
+
fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
|
44 |
+
|
45 |
+
_backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
|
46 |
+
|
47 |
+
return nears, fars
|
48 |
+
|
49 |
+
near_far_from_aabb = _near_far_from_aabb.apply
|
50 |
+
|
51 |
+
|
52 |
+
class _sph_from_ray(Function):
|
53 |
+
@staticmethod
|
54 |
+
@custom_fwd(cast_inputs=torch.float32)
|
55 |
+
def forward(ctx, rays_o, rays_d, radius):
|
56 |
+
''' sph_from_ray, CUDA implementation
|
57 |
+
get spherical coordinate on the background sphere from rays.
|
58 |
+
Assume rays_o are inside the Sphere(radius).
|
59 |
+
Args:
|
60 |
+
rays_o: [N, 3]
|
61 |
+
rays_d: [N, 3]
|
62 |
+
radius: scalar, float
|
63 |
+
Return:
|
64 |
+
coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
|
65 |
+
'''
|
66 |
+
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
67 |
+
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
68 |
+
|
69 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
70 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
71 |
+
|
72 |
+
N = rays_o.shape[0] # num rays
|
73 |
+
|
74 |
+
coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
|
75 |
+
|
76 |
+
_backend.sph_from_ray(rays_o, rays_d, radius, N, coords)
|
77 |
+
|
78 |
+
return coords
|
79 |
+
|
80 |
+
sph_from_ray = _sph_from_ray.apply
|
81 |
+
|
82 |
+
|
83 |
+
class _morton3D(Function):
|
84 |
+
@staticmethod
|
85 |
+
def forward(ctx, coords):
|
86 |
+
''' morton3D, CUDA implementation
|
87 |
+
Args:
|
88 |
+
coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
|
89 |
+
TODO: check if the coord range is valid! (current 128 is safe)
|
90 |
+
Returns:
|
91 |
+
indices: [N], int32, in [0, 128^3)
|
92 |
+
|
93 |
+
'''
|
94 |
+
if not coords.is_cuda: coords = coords.cuda()
|
95 |
+
|
96 |
+
N = coords.shape[0]
|
97 |
+
|
98 |
+
indices = torch.empty(N, dtype=torch.int32, device=coords.device)
|
99 |
+
|
100 |
+
_backend.morton3D(coords.int(), N, indices)
|
101 |
+
|
102 |
+
return indices
|
103 |
+
|
104 |
+
morton3D = _morton3D.apply
|
105 |
+
|
106 |
+
class _morton3D_invert(Function):
|
107 |
+
@staticmethod
|
108 |
+
def forward(ctx, indices):
|
109 |
+
''' morton3D_invert, CUDA implementation
|
110 |
+
Args:
|
111 |
+
indices: [N], int32, in [0, 128^3)
|
112 |
+
Returns:
|
113 |
+
coords: [N, 3], int32, in [0, 128)
|
114 |
+
|
115 |
+
'''
|
116 |
+
if not indices.is_cuda: indices = indices.cuda()
|
117 |
+
|
118 |
+
N = indices.shape[0]
|
119 |
+
|
120 |
+
coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
|
121 |
+
|
122 |
+
_backend.morton3D_invert(indices.int(), N, coords)
|
123 |
+
|
124 |
+
return coords
|
125 |
+
|
126 |
+
morton3D_invert = _morton3D_invert.apply
|
127 |
+
|
128 |
+
|
129 |
+
class _packbits(Function):
|
130 |
+
@staticmethod
|
131 |
+
@custom_fwd(cast_inputs=torch.float32)
|
132 |
+
def forward(ctx, grid, thresh, bitfield=None):
|
133 |
+
''' packbits, CUDA implementation
|
134 |
+
Pack up the density grid into a bit field to accelerate ray marching.
|
135 |
+
Args:
|
136 |
+
grid: float, [C, H * H * H], assume H % 2 == 0
|
137 |
+
thresh: float, threshold
|
138 |
+
Returns:
|
139 |
+
bitfield: uint8, [C, H * H * H / 8]
|
140 |
+
'''
|
141 |
+
if not grid.is_cuda: grid = grid.cuda()
|
142 |
+
grid = grid.contiguous()
|
143 |
+
|
144 |
+
C = grid.shape[0]
|
145 |
+
H3 = grid.shape[1]
|
146 |
+
N = C * H3 // 8
|
147 |
+
|
148 |
+
if bitfield is None:
|
149 |
+
bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
|
150 |
+
|
151 |
+
_backend.packbits(grid, N, thresh, bitfield)
|
152 |
+
|
153 |
+
return bitfield
|
154 |
+
|
155 |
+
packbits = _packbits.apply
|
156 |
+
|
157 |
+
# ----------------------------------------
|
158 |
+
# train functions
|
159 |
+
# ----------------------------------------
|
160 |
+
|
161 |
+
class _march_rays_train(Function):
|
162 |
+
@staticmethod
|
163 |
+
@custom_fwd(cast_inputs=torch.float32)
|
164 |
+
def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024):
|
165 |
+
''' march rays to generate points (forward only)
|
166 |
+
Args:
|
167 |
+
rays_o/d: float, [N, 3]
|
168 |
+
bound: float, scalar
|
169 |
+
density_bitfield: uint8: [CHHH // 8]
|
170 |
+
C: int
|
171 |
+
H: int
|
172 |
+
nears/fars: float, [N]
|
173 |
+
step_counter: int32, (2), used to count the actual number of generated points.
|
174 |
+
mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
|
175 |
+
perturb: bool
|
176 |
+
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
177 |
+
force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
|
178 |
+
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
179 |
+
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
180 |
+
Returns:
|
181 |
+
xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
|
182 |
+
dirs: float, [M, 3], all generated points' view dirs.
|
183 |
+
deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)
|
184 |
+
rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]
|
185 |
+
'''
|
186 |
+
|
187 |
+
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
188 |
+
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
189 |
+
if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
|
190 |
+
|
191 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
192 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
193 |
+
density_bitfield = density_bitfield.contiguous()
|
194 |
+
|
195 |
+
N = rays_o.shape[0] # num rays
|
196 |
+
M = N * max_steps # init max points number in total
|
197 |
+
|
198 |
+
# running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
|
199 |
+
# It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.
|
200 |
+
if not force_all_rays and mean_count > 0:
|
201 |
+
if align > 0:
|
202 |
+
mean_count += align - mean_count % align
|
203 |
+
M = mean_count
|
204 |
+
|
205 |
+
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
206 |
+
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
207 |
+
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
|
208 |
+
rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
|
209 |
+
|
210 |
+
if step_counter is None:
|
211 |
+
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
|
212 |
+
|
213 |
+
if perturb:
|
214 |
+
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
|
215 |
+
else:
|
216 |
+
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
|
217 |
+
|
218 |
+
_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number
|
219 |
+
|
220 |
+
#print(step_counter, M)
|
221 |
+
|
222 |
+
# only used at the first (few) epochs.
|
223 |
+
if force_all_rays or mean_count <= 0:
|
224 |
+
m = step_counter[0].item() # D2H copy
|
225 |
+
if align > 0:
|
226 |
+
m += align - m % align
|
227 |
+
xyzs = xyzs[:m]
|
228 |
+
dirs = dirs[:m]
|
229 |
+
deltas = deltas[:m]
|
230 |
+
|
231 |
+
torch.cuda.empty_cache()
|
232 |
+
|
233 |
+
return xyzs, dirs, deltas, rays
|
234 |
+
|
235 |
+
march_rays_train = _march_rays_train.apply
|
236 |
+
|
237 |
+
|
238 |
+
class _composite_rays_train(Function):
|
239 |
+
@staticmethod
|
240 |
+
@custom_fwd(cast_inputs=torch.float32)
|
241 |
+
def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
|
242 |
+
''' composite rays' rgbs, according to the ray marching formula.
|
243 |
+
Args:
|
244 |
+
rgbs: float, [M, 3]
|
245 |
+
sigmas: float, [M,]
|
246 |
+
deltas: float, [M, 2]
|
247 |
+
rays: int32, [N, 3]
|
248 |
+
Returns:
|
249 |
+
weights_sum: float, [N,], the alpha channel
|
250 |
+
depth: float, [N, ], the Depth
|
251 |
+
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
252 |
+
'''
|
253 |
+
|
254 |
+
sigmas = sigmas.contiguous()
|
255 |
+
rgbs = rgbs.contiguous()
|
256 |
+
|
257 |
+
M = sigmas.shape[0]
|
258 |
+
N = rays.shape[0]
|
259 |
+
|
260 |
+
weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
261 |
+
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
|
262 |
+
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
|
263 |
+
|
264 |
+
_backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
|
265 |
+
|
266 |
+
ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
|
267 |
+
ctx.dims = [M, N, T_thresh]
|
268 |
+
|
269 |
+
return weights_sum, depth, image
|
270 |
+
|
271 |
+
@staticmethod
|
272 |
+
@custom_bwd
|
273 |
+
def backward(ctx, grad_weights_sum, grad_depth, grad_image):
|
274 |
+
|
275 |
+
# NOTE: grad_depth is not used now! It won't be propagated to sigmas.
|
276 |
+
|
277 |
+
grad_weights_sum = grad_weights_sum.contiguous()
|
278 |
+
grad_image = grad_image.contiguous()
|
279 |
+
|
280 |
+
sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
|
281 |
+
M, N, T_thresh = ctx.dims
|
282 |
+
|
283 |
+
grad_sigmas = torch.zeros_like(sigmas)
|
284 |
+
grad_rgbs = torch.zeros_like(rgbs)
|
285 |
+
|
286 |
+
_backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)
|
287 |
+
|
288 |
+
return grad_sigmas, grad_rgbs, None, None, None
|
289 |
+
|
290 |
+
|
291 |
+
composite_rays_train = _composite_rays_train.apply
|
292 |
+
|
293 |
+
# ----------------------------------------
|
294 |
+
# infer functions
|
295 |
+
# ----------------------------------------
|
296 |
+
|
297 |
+
class _march_rays(Function):
|
298 |
+
@staticmethod
|
299 |
+
@custom_fwd(cast_inputs=torch.float32)
|
300 |
+
def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024):
|
301 |
+
''' march rays to generate points (forward only, for inference)
|
302 |
+
Args:
|
303 |
+
n_alive: int, number of alive rays
|
304 |
+
n_step: int, how many steps we march
|
305 |
+
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
|
306 |
+
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
|
307 |
+
rays_o/d: float, [N, 3]
|
308 |
+
bound: float, scalar
|
309 |
+
density_bitfield: uint8: [CHHH // 8]
|
310 |
+
C: int
|
311 |
+
H: int
|
312 |
+
nears/fars: float, [N]
|
313 |
+
align: int, pad output so its size is dividable by align, set to -1 to disable.
|
314 |
+
perturb: bool/int, int > 0 is used as the random seed.
|
315 |
+
dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
|
316 |
+
max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
|
317 |
+
Returns:
|
318 |
+
xyzs: float, [n_alive * n_step, 3], all generated points' coords
|
319 |
+
dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
|
320 |
+
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
321 |
+
'''
|
322 |
+
|
323 |
+
if not rays_o.is_cuda: rays_o = rays_o.cuda()
|
324 |
+
if not rays_d.is_cuda: rays_d = rays_d.cuda()
|
325 |
+
|
326 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
327 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
328 |
+
|
329 |
+
M = n_alive * n_step
|
330 |
+
|
331 |
+
if align > 0:
|
332 |
+
M += align - (M % align)
|
333 |
+
|
334 |
+
xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
335 |
+
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
|
336 |
+
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
|
337 |
+
|
338 |
+
if perturb:
|
339 |
+
# torch.manual_seed(perturb) # test_gui uses spp index as seed
|
340 |
+
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
341 |
+
else:
|
342 |
+
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
|
343 |
+
|
344 |
+
_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)
|
345 |
+
|
346 |
+
return xyzs, dirs, deltas
|
347 |
+
|
348 |
+
march_rays = _march_rays.apply
|
349 |
+
|
350 |
+
|
351 |
+
class _composite_rays(Function):
|
352 |
+
@staticmethod
|
353 |
+
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
|
354 |
+
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
|
355 |
+
''' composite rays' rgbs, according to the ray marching formula. (for inference)
|
356 |
+
Args:
|
357 |
+
n_alive: int, number of alive rays
|
358 |
+
n_step: int, how many steps we march
|
359 |
+
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
|
360 |
+
rays_t: float, [N], the alive rays' time
|
361 |
+
sigmas: float, [n_alive * n_step,]
|
362 |
+
rgbs: float, [n_alive * n_step, 3]
|
363 |
+
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
|
364 |
+
In-place Outputs:
|
365 |
+
weights_sum: float, [N,], the alpha channel
|
366 |
+
depth: float, [N,], the depth value
|
367 |
+
image: float, [N, 3], the RGB channel (after multiplying alpha!)
|
368 |
+
'''
|
369 |
+
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
|
370 |
+
return tuple()
|
371 |
+
|
372 |
+
|
373 |
+
composite_rays = _composite_rays.apply
|
raymarching/setup.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from setuptools import setup
|
3 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
+
|
5 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
|
7 |
+
nvcc_flags = [
|
8 |
+
'-O3', '-std=c++14',
|
9 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
+
]
|
11 |
+
|
12 |
+
if os.name == "posix":
|
13 |
+
c_flags = ['-O3', '-std=c++14']
|
14 |
+
elif os.name == "nt":
|
15 |
+
c_flags = ['/O2', '/std:c++17']
|
16 |
+
|
17 |
+
# find cl.exe
|
18 |
+
def find_cl_path():
|
19 |
+
import glob
|
20 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
+
if paths:
|
23 |
+
return paths[0]
|
24 |
+
|
25 |
+
# If cl.exe is not on path, try to find it.
|
26 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
+
cl_path = find_cl_path()
|
28 |
+
if cl_path is None:
|
29 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
+
os.environ["PATH"] += ";" + cl_path
|
31 |
+
|
32 |
+
'''
|
33 |
+
Usage:
|
34 |
+
|
35 |
+
python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
|
36 |
+
|
37 |
+
python setup.py install # build extensions and install (copy) to PATH.
|
38 |
+
pip install . # ditto but better (e.g., dependency & metadata handling)
|
39 |
+
|
40 |
+
python setup.py develop # build extensions and install (symbolic) to PATH.
|
41 |
+
pip install -e . # ditto but better (e.g., dependency & metadata handling)
|
42 |
+
|
43 |
+
'''
|
44 |
+
setup(
|
45 |
+
name='raymarching', # package name, import this to use python API
|
46 |
+
ext_modules=[
|
47 |
+
CUDAExtension(
|
48 |
+
name='_raymarching', # extension name, import this to use CUDA API
|
49 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
50 |
+
'raymarching.cu',
|
51 |
+
'bindings.cpp',
|
52 |
+
]],
|
53 |
+
extra_compile_args={
|
54 |
+
'cxx': c_flags,
|
55 |
+
'nvcc': nvcc_flags,
|
56 |
+
}
|
57 |
+
),
|
58 |
+
],
|
59 |
+
cmdclass={
|
60 |
+
'build_ext': BuildExtension,
|
61 |
+
}
|
62 |
+
)
|
raymarching/src/bindings.cpp
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include "raymarching.h"
|
4 |
+
|
5 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
+
// utils
|
7 |
+
m.def("packbits", &packbits, "packbits (CUDA)");
|
8 |
+
m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
|
9 |
+
m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
|
10 |
+
m.def("morton3D", &morton3D, "morton3D (CUDA)");
|
11 |
+
m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
|
12 |
+
// train
|
13 |
+
m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
|
14 |
+
m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
|
15 |
+
m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
|
16 |
+
// infer
|
17 |
+
m.def("march_rays", &march_rays, "march rays (CUDA)");
|
18 |
+
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
|
19 |
+
}
|
raymarching/src/raymarching.cu
ADDED
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cuda.h>
|
2 |
+
#include <cuda_fp16.h>
|
3 |
+
#include <cuda_runtime.h>
|
4 |
+
|
5 |
+
#include <ATen/cuda/CUDAContext.h>
|
6 |
+
#include <torch/torch.h>
|
7 |
+
|
8 |
+
#include <cstdio>
|
9 |
+
#include <stdint.h>
|
10 |
+
#include <stdexcept>
|
11 |
+
#include <limits>
|
12 |
+
|
13 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
14 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
15 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
16 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
17 |
+
|
18 |
+
|
19 |
+
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
|
20 |
+
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
|
21 |
+
inline constexpr __device__ float PI() { return 3.141592653589793f; }
|
22 |
+
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
|
23 |
+
|
24 |
+
|
25 |
+
template <typename T>
|
26 |
+
inline __host__ __device__ T div_round_up(T val, T divisor) {
|
27 |
+
return (val + divisor - 1) / divisor;
|
28 |
+
}
|
29 |
+
|
30 |
+
inline __host__ __device__ float signf(const float x) {
|
31 |
+
return copysignf(1.0, x);
|
32 |
+
}
|
33 |
+
|
34 |
+
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
|
35 |
+
return fminf(max, fmaxf(min, x));
|
36 |
+
}
|
37 |
+
|
38 |
+
inline __host__ __device__ void swapf(float& a, float& b) {
|
39 |
+
float c = a; a = b; b = c;
|
40 |
+
}
|
41 |
+
|
42 |
+
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
|
43 |
+
const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
|
44 |
+
int exponent;
|
45 |
+
frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
|
46 |
+
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
47 |
+
}
|
48 |
+
|
49 |
+
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
|
50 |
+
const float mx = dt * H * 0.5;
|
51 |
+
int exponent;
|
52 |
+
frexpf(mx, &exponent);
|
53 |
+
return fminf(max_cascade - 1, fmaxf(0, exponent));
|
54 |
+
}
|
55 |
+
|
56 |
+
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
|
57 |
+
{
|
58 |
+
v = (v * 0x00010001u) & 0xFF0000FFu;
|
59 |
+
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
60 |
+
v = (v * 0x00000011u) & 0xC30C30C3u;
|
61 |
+
v = (v * 0x00000005u) & 0x49249249u;
|
62 |
+
return v;
|
63 |
+
}
|
64 |
+
|
65 |
+
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
|
66 |
+
{
|
67 |
+
uint32_t xx = __expand_bits(x);
|
68 |
+
uint32_t yy = __expand_bits(y);
|
69 |
+
uint32_t zz = __expand_bits(z);
|
70 |
+
return xx | (yy << 1) | (zz << 2);
|
71 |
+
}
|
72 |
+
|
73 |
+
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
|
74 |
+
{
|
75 |
+
x = x & 0x49249249;
|
76 |
+
x = (x | (x >> 2)) & 0xc30c30c3;
|
77 |
+
x = (x | (x >> 4)) & 0x0f00f00f;
|
78 |
+
x = (x | (x >> 8)) & 0xff0000ff;
|
79 |
+
x = (x | (x >> 16)) & 0x0000ffff;
|
80 |
+
return x;
|
81 |
+
}
|
82 |
+
|
83 |
+
|
84 |
+
////////////////////////////////////////////////////
|
85 |
+
///////////// utils /////////////
|
86 |
+
////////////////////////////////////////////////////
|
87 |
+
|
88 |
+
// rays_o/d: [N, 3]
|
89 |
+
// nears/fars: [N]
|
90 |
+
// scalar_t should always be float in use.
|
91 |
+
template <typename scalar_t>
|
92 |
+
__global__ void kernel_near_far_from_aabb(
|
93 |
+
const scalar_t * __restrict__ rays_o,
|
94 |
+
const scalar_t * __restrict__ rays_d,
|
95 |
+
const scalar_t * __restrict__ aabb,
|
96 |
+
const uint32_t N,
|
97 |
+
const float min_near,
|
98 |
+
scalar_t * nears, scalar_t * fars
|
99 |
+
) {
|
100 |
+
// parallel per ray
|
101 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
102 |
+
if (n >= N) return;
|
103 |
+
|
104 |
+
// locate
|
105 |
+
rays_o += n * 3;
|
106 |
+
rays_d += n * 3;
|
107 |
+
|
108 |
+
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
109 |
+
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
110 |
+
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
111 |
+
|
112 |
+
// get near far (assume cube scene)
|
113 |
+
float near = (aabb[0] - ox) * rdx;
|
114 |
+
float far = (aabb[3] - ox) * rdx;
|
115 |
+
if (near > far) swapf(near, far);
|
116 |
+
|
117 |
+
float near_y = (aabb[1] - oy) * rdy;
|
118 |
+
float far_y = (aabb[4] - oy) * rdy;
|
119 |
+
if (near_y > far_y) swapf(near_y, far_y);
|
120 |
+
|
121 |
+
if (near > far_y || near_y > far) {
|
122 |
+
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
123 |
+
return;
|
124 |
+
}
|
125 |
+
|
126 |
+
if (near_y > near) near = near_y;
|
127 |
+
if (far_y < far) far = far_y;
|
128 |
+
|
129 |
+
float near_z = (aabb[2] - oz) * rdz;
|
130 |
+
float far_z = (aabb[5] - oz) * rdz;
|
131 |
+
if (near_z > far_z) swapf(near_z, far_z);
|
132 |
+
|
133 |
+
if (near > far_z || near_z > far) {
|
134 |
+
nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
|
135 |
+
return;
|
136 |
+
}
|
137 |
+
|
138 |
+
if (near_z > near) near = near_z;
|
139 |
+
if (far_z < far) far = far_z;
|
140 |
+
|
141 |
+
if (near < min_near) near = min_near;
|
142 |
+
|
143 |
+
nears[n] = near;
|
144 |
+
fars[n] = far;
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
|
149 |
+
|
150 |
+
static constexpr uint32_t N_THREAD = 128;
|
151 |
+
|
152 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
153 |
+
rays_o.scalar_type(), "near_far_from_aabb", ([&] {
|
154 |
+
kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
|
155 |
+
}));
|
156 |
+
}
|
157 |
+
|
158 |
+
|
159 |
+
// rays_o/d: [N, 3]
|
160 |
+
// radius: float
|
161 |
+
// coords: [N, 2]
|
162 |
+
template <typename scalar_t>
|
163 |
+
__global__ void kernel_sph_from_ray(
|
164 |
+
const scalar_t * __restrict__ rays_o,
|
165 |
+
const scalar_t * __restrict__ rays_d,
|
166 |
+
const float radius,
|
167 |
+
const uint32_t N,
|
168 |
+
scalar_t * coords
|
169 |
+
) {
|
170 |
+
// parallel per ray
|
171 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
172 |
+
if (n >= N) return;
|
173 |
+
|
174 |
+
// locate
|
175 |
+
rays_o += n * 3;
|
176 |
+
rays_d += n * 3;
|
177 |
+
coords += n * 2;
|
178 |
+
|
179 |
+
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
180 |
+
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
181 |
+
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
182 |
+
|
183 |
+
// solve t from || o + td || = radius
|
184 |
+
const float A = dx * dx + dy * dy + dz * dz;
|
185 |
+
const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
|
186 |
+
const float C = ox * ox + oy * oy + oz * oz - radius * radius;
|
187 |
+
|
188 |
+
const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
|
189 |
+
|
190 |
+
// solve theta, phi (assume y is the up axis)
|
191 |
+
const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
|
192 |
+
const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
|
193 |
+
const float phi = atan2(z, x); // [-PI, PI)
|
194 |
+
|
195 |
+
// normalize to [-1, 1]
|
196 |
+
coords[0] = 2 * theta * RPI() - 1;
|
197 |
+
coords[1] = phi * RPI();
|
198 |
+
}
|
199 |
+
|
200 |
+
|
201 |
+
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
|
202 |
+
|
203 |
+
static constexpr uint32_t N_THREAD = 128;
|
204 |
+
|
205 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
206 |
+
rays_o.scalar_type(), "sph_from_ray", ([&] {
|
207 |
+
kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
|
208 |
+
}));
|
209 |
+
}
|
210 |
+
|
211 |
+
|
212 |
+
// coords: int32, [N, 3]
|
213 |
+
// indices: int32, [N]
|
214 |
+
__global__ void kernel_morton3D(
|
215 |
+
const int * __restrict__ coords,
|
216 |
+
const uint32_t N,
|
217 |
+
int * indices
|
218 |
+
) {
|
219 |
+
// parallel
|
220 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
221 |
+
if (n >= N) return;
|
222 |
+
|
223 |
+
// locate
|
224 |
+
coords += n * 3;
|
225 |
+
indices[n] = __morton3D(coords[0], coords[1], coords[2]);
|
226 |
+
}
|
227 |
+
|
228 |
+
|
229 |
+
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
|
230 |
+
static constexpr uint32_t N_THREAD = 128;
|
231 |
+
kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
|
232 |
+
}
|
233 |
+
|
234 |
+
|
235 |
+
// indices: int32, [N]
|
236 |
+
// coords: int32, [N, 3]
|
237 |
+
__global__ void kernel_morton3D_invert(
|
238 |
+
const int * __restrict__ indices,
|
239 |
+
const uint32_t N,
|
240 |
+
int * coords
|
241 |
+
) {
|
242 |
+
// parallel
|
243 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
244 |
+
if (n >= N) return;
|
245 |
+
|
246 |
+
// locate
|
247 |
+
coords += n * 3;
|
248 |
+
|
249 |
+
const int ind = indices[n];
|
250 |
+
|
251 |
+
coords[0] = __morton3D_invert(ind >> 0);
|
252 |
+
coords[1] = __morton3D_invert(ind >> 1);
|
253 |
+
coords[2] = __morton3D_invert(ind >> 2);
|
254 |
+
}
|
255 |
+
|
256 |
+
|
257 |
+
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
|
258 |
+
static constexpr uint32_t N_THREAD = 128;
|
259 |
+
kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
|
260 |
+
}
|
261 |
+
|
262 |
+
|
263 |
+
// grid: float, [C, H, H, H]
|
264 |
+
// N: int, C * H * H * H / 8
|
265 |
+
// density_thresh: float
|
266 |
+
// bitfield: uint8, [N]
|
267 |
+
template <typename scalar_t>
|
268 |
+
__global__ void kernel_packbits(
|
269 |
+
const scalar_t * __restrict__ grid,
|
270 |
+
const uint32_t N,
|
271 |
+
const float density_thresh,
|
272 |
+
uint8_t * bitfield
|
273 |
+
) {
|
274 |
+
// parallel per byte
|
275 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
276 |
+
if (n >= N) return;
|
277 |
+
|
278 |
+
// locate
|
279 |
+
grid += n * 8;
|
280 |
+
|
281 |
+
uint8_t bits = 0;
|
282 |
+
|
283 |
+
#pragma unroll
|
284 |
+
for (uint8_t i = 0; i < 8; i++) {
|
285 |
+
bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
|
286 |
+
}
|
287 |
+
|
288 |
+
bitfield[n] = bits;
|
289 |
+
}
|
290 |
+
|
291 |
+
|
292 |
+
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
|
293 |
+
|
294 |
+
static constexpr uint32_t N_THREAD = 128;
|
295 |
+
|
296 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
297 |
+
grid.scalar_type(), "packbits", ([&] {
|
298 |
+
kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
|
299 |
+
}));
|
300 |
+
}
|
301 |
+
|
302 |
+
////////////////////////////////////////////////////
|
303 |
+
///////////// training /////////////
|
304 |
+
////////////////////////////////////////////////////
|
305 |
+
|
306 |
+
// rays_o/d: [N, 3]
|
307 |
+
// grid: [CHHH / 8]
|
308 |
+
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
|
309 |
+
// dirs: [M, 3]
|
310 |
+
// rays: [N, 3], idx, offset, num_steps
|
311 |
+
template <typename scalar_t>
|
312 |
+
__global__ void kernel_march_rays_train(
|
313 |
+
const scalar_t * __restrict__ rays_o,
|
314 |
+
const scalar_t * __restrict__ rays_d,
|
315 |
+
const uint8_t * __restrict__ grid,
|
316 |
+
const float bound,
|
317 |
+
const float dt_gamma, const uint32_t max_steps,
|
318 |
+
const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
|
319 |
+
const scalar_t* __restrict__ nears,
|
320 |
+
const scalar_t* __restrict__ fars,
|
321 |
+
scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
|
322 |
+
int * rays,
|
323 |
+
int * counter,
|
324 |
+
const scalar_t* __restrict__ noises
|
325 |
+
) {
|
326 |
+
// parallel per ray
|
327 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
328 |
+
if (n >= N) return;
|
329 |
+
|
330 |
+
// locate
|
331 |
+
rays_o += n * 3;
|
332 |
+
rays_d += n * 3;
|
333 |
+
|
334 |
+
// ray marching
|
335 |
+
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
336 |
+
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
337 |
+
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
338 |
+
const float rH = 1 / (float)H;
|
339 |
+
const float H3 = H * H * H;
|
340 |
+
|
341 |
+
const float near = nears[n];
|
342 |
+
const float far = fars[n];
|
343 |
+
const float noise = noises[n];
|
344 |
+
|
345 |
+
const float dt_min = 2 * SQRT3() / max_steps;
|
346 |
+
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
|
347 |
+
|
348 |
+
float t0 = near;
|
349 |
+
|
350 |
+
// perturb
|
351 |
+
t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
|
352 |
+
|
353 |
+
// first pass: estimation of num_steps
|
354 |
+
float t = t0;
|
355 |
+
uint32_t num_steps = 0;
|
356 |
+
|
357 |
+
//if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
|
358 |
+
|
359 |
+
while (t < far && num_steps < max_steps) {
|
360 |
+
// current point
|
361 |
+
const float x = clamp(ox + t * dx, -bound, bound);
|
362 |
+
const float y = clamp(oy + t * dy, -bound, bound);
|
363 |
+
const float z = clamp(oz + t * dz, -bound, bound);
|
364 |
+
|
365 |
+
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
366 |
+
|
367 |
+
// get mip level
|
368 |
+
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
369 |
+
|
370 |
+
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
371 |
+
const float mip_rbound = 1 / mip_bound;
|
372 |
+
|
373 |
+
// convert to nearest grid position
|
374 |
+
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
375 |
+
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
376 |
+
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
377 |
+
|
378 |
+
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
379 |
+
const bool occ = grid[index / 8] & (1 << (index % 8));
|
380 |
+
|
381 |
+
// if occpuied, advance a small step, and write to output
|
382 |
+
//if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
|
383 |
+
|
384 |
+
if (occ) {
|
385 |
+
num_steps++;
|
386 |
+
t += dt;
|
387 |
+
// else, skip a large step (basically skip a voxel grid)
|
388 |
+
} else {
|
389 |
+
// calc distance to next voxel
|
390 |
+
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
391 |
+
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
392 |
+
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
393 |
+
|
394 |
+
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
395 |
+
// step until next voxel
|
396 |
+
do {
|
397 |
+
t += clamp(t * dt_gamma, dt_min, dt_max);
|
398 |
+
} while (t < tt);
|
399 |
+
}
|
400 |
+
}
|
401 |
+
|
402 |
+
//printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
|
403 |
+
|
404 |
+
// second pass: really locate and write points & dirs
|
405 |
+
uint32_t point_index = atomicAdd(counter, num_steps);
|
406 |
+
uint32_t ray_index = atomicAdd(counter + 1, 1);
|
407 |
+
|
408 |
+
//printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
|
409 |
+
|
410 |
+
// write rays
|
411 |
+
rays[ray_index * 3] = n;
|
412 |
+
rays[ray_index * 3 + 1] = point_index;
|
413 |
+
rays[ray_index * 3 + 2] = num_steps;
|
414 |
+
|
415 |
+
if (num_steps == 0) return;
|
416 |
+
if (point_index + num_steps > M) return;
|
417 |
+
|
418 |
+
xyzs += point_index * 3;
|
419 |
+
dirs += point_index * 3;
|
420 |
+
deltas += point_index * 2;
|
421 |
+
|
422 |
+
t = t0;
|
423 |
+
uint32_t step = 0;
|
424 |
+
|
425 |
+
float last_t = t;
|
426 |
+
|
427 |
+
while (t < far && step < num_steps) {
|
428 |
+
// current point
|
429 |
+
const float x = clamp(ox + t * dx, -bound, bound);
|
430 |
+
const float y = clamp(oy + t * dy, -bound, bound);
|
431 |
+
const float z = clamp(oz + t * dz, -bound, bound);
|
432 |
+
|
433 |
+
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
434 |
+
|
435 |
+
// get mip level
|
436 |
+
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
437 |
+
|
438 |
+
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
|
439 |
+
const float mip_rbound = 1 / mip_bound;
|
440 |
+
|
441 |
+
// convert to nearest grid position
|
442 |
+
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
443 |
+
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
444 |
+
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
445 |
+
|
446 |
+
// query grid
|
447 |
+
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
448 |
+
const bool occ = grid[index / 8] & (1 << (index % 8));
|
449 |
+
|
450 |
+
// if occpuied, advance a small step, and write to output
|
451 |
+
if (occ) {
|
452 |
+
// write step
|
453 |
+
xyzs[0] = x;
|
454 |
+
xyzs[1] = y;
|
455 |
+
xyzs[2] = z;
|
456 |
+
dirs[0] = dx;
|
457 |
+
dirs[1] = dy;
|
458 |
+
dirs[2] = dz;
|
459 |
+
t += dt;
|
460 |
+
deltas[0] = dt;
|
461 |
+
deltas[1] = t - last_t; // used to calc depth
|
462 |
+
last_t = t;
|
463 |
+
xyzs += 3;
|
464 |
+
dirs += 3;
|
465 |
+
deltas += 2;
|
466 |
+
step++;
|
467 |
+
// else, skip a large step (basically skip a voxel grid)
|
468 |
+
} else {
|
469 |
+
// calc distance to next voxel
|
470 |
+
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
471 |
+
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
472 |
+
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
473 |
+
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
474 |
+
// step until next voxel
|
475 |
+
do {
|
476 |
+
t += clamp(t * dt_gamma, dt_min, dt_max);
|
477 |
+
} while (t < tt);
|
478 |
+
}
|
479 |
+
}
|
480 |
+
}
|
481 |
+
|
482 |
+
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
|
483 |
+
|
484 |
+
static constexpr uint32_t N_THREAD = 128;
|
485 |
+
|
486 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
487 |
+
rays_o.scalar_type(), "march_rays_train", ([&] {
|
488 |
+
kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
|
489 |
+
}));
|
490 |
+
}
|
491 |
+
|
492 |
+
|
493 |
+
// sigmas: [M]
|
494 |
+
// rgbs: [M, 3]
|
495 |
+
// deltas: [M, 2]
|
496 |
+
// rays: [N, 3], idx, offset, num_steps
|
497 |
+
// weights_sum: [N], final pixel alpha
|
498 |
+
// depth: [N,]
|
499 |
+
// image: [N, 3]
|
500 |
+
template <typename scalar_t>
|
501 |
+
__global__ void kernel_composite_rays_train_forward(
|
502 |
+
const scalar_t * __restrict__ sigmas,
|
503 |
+
const scalar_t * __restrict__ rgbs,
|
504 |
+
const scalar_t * __restrict__ deltas,
|
505 |
+
const int * __restrict__ rays,
|
506 |
+
const uint32_t M, const uint32_t N, const float T_thresh,
|
507 |
+
scalar_t * weights_sum,
|
508 |
+
scalar_t * depth,
|
509 |
+
scalar_t * image
|
510 |
+
) {
|
511 |
+
// parallel per ray
|
512 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
513 |
+
if (n >= N) return;
|
514 |
+
|
515 |
+
// locate
|
516 |
+
uint32_t index = rays[n * 3];
|
517 |
+
uint32_t offset = rays[n * 3 + 1];
|
518 |
+
uint32_t num_steps = rays[n * 3 + 2];
|
519 |
+
|
520 |
+
// empty ray, or ray that exceed max step count.
|
521 |
+
if (num_steps == 0 || offset + num_steps > M) {
|
522 |
+
weights_sum[index] = 0;
|
523 |
+
depth[index] = 0;
|
524 |
+
image[index * 3] = 0;
|
525 |
+
image[index * 3 + 1] = 0;
|
526 |
+
image[index * 3 + 2] = 0;
|
527 |
+
return;
|
528 |
+
}
|
529 |
+
|
530 |
+
sigmas += offset;
|
531 |
+
rgbs += offset * 3;
|
532 |
+
deltas += offset * 2;
|
533 |
+
|
534 |
+
// accumulate
|
535 |
+
uint32_t step = 0;
|
536 |
+
|
537 |
+
scalar_t T = 1.0f;
|
538 |
+
scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
|
539 |
+
|
540 |
+
while (step < num_steps) {
|
541 |
+
|
542 |
+
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
543 |
+
const scalar_t weight = alpha * T;
|
544 |
+
|
545 |
+
r += weight * rgbs[0];
|
546 |
+
g += weight * rgbs[1];
|
547 |
+
b += weight * rgbs[2];
|
548 |
+
|
549 |
+
t += deltas[1]; // real delta
|
550 |
+
d += weight * t;
|
551 |
+
|
552 |
+
ws += weight;
|
553 |
+
|
554 |
+
T *= 1.0f - alpha;
|
555 |
+
|
556 |
+
// minimal remained transmittence
|
557 |
+
if (T < T_thresh) break;
|
558 |
+
|
559 |
+
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
|
560 |
+
|
561 |
+
// locate
|
562 |
+
sigmas++;
|
563 |
+
rgbs += 3;
|
564 |
+
deltas += 2;
|
565 |
+
|
566 |
+
step++;
|
567 |
+
}
|
568 |
+
|
569 |
+
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
|
570 |
+
|
571 |
+
// write
|
572 |
+
weights_sum[index] = ws; // weights_sum
|
573 |
+
depth[index] = d;
|
574 |
+
image[index * 3] = r;
|
575 |
+
image[index * 3 + 1] = g;
|
576 |
+
image[index * 3 + 2] = b;
|
577 |
+
}
|
578 |
+
|
579 |
+
|
580 |
+
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
|
581 |
+
|
582 |
+
static constexpr uint32_t N_THREAD = 128;
|
583 |
+
|
584 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
585 |
+
sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
|
586 |
+
kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
|
587 |
+
}));
|
588 |
+
}
|
589 |
+
|
590 |
+
|
591 |
+
// grad_weights_sum: [N,]
|
592 |
+
// grad: [N, 3]
|
593 |
+
// sigmas: [M]
|
594 |
+
// rgbs: [M, 3]
|
595 |
+
// deltas: [M, 2]
|
596 |
+
// rays: [N, 3], idx, offset, num_steps
|
597 |
+
// weights_sum: [N,], weights_sum here
|
598 |
+
// image: [N, 3]
|
599 |
+
// grad_sigmas: [M]
|
600 |
+
// grad_rgbs: [M, 3]
|
601 |
+
template <typename scalar_t>
|
602 |
+
__global__ void kernel_composite_rays_train_backward(
|
603 |
+
const scalar_t * __restrict__ grad_weights_sum,
|
604 |
+
const scalar_t * __restrict__ grad_image,
|
605 |
+
const scalar_t * __restrict__ sigmas,
|
606 |
+
const scalar_t * __restrict__ rgbs,
|
607 |
+
const scalar_t * __restrict__ deltas,
|
608 |
+
const int * __restrict__ rays,
|
609 |
+
const scalar_t * __restrict__ weights_sum,
|
610 |
+
const scalar_t * __restrict__ image,
|
611 |
+
const uint32_t M, const uint32_t N, const float T_thresh,
|
612 |
+
scalar_t * grad_sigmas,
|
613 |
+
scalar_t * grad_rgbs
|
614 |
+
) {
|
615 |
+
// parallel per ray
|
616 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
617 |
+
if (n >= N) return;
|
618 |
+
|
619 |
+
// locate
|
620 |
+
uint32_t index = rays[n * 3];
|
621 |
+
uint32_t offset = rays[n * 3 + 1];
|
622 |
+
uint32_t num_steps = rays[n * 3 + 2];
|
623 |
+
|
624 |
+
if (num_steps == 0 || offset + num_steps > M) return;
|
625 |
+
|
626 |
+
grad_weights_sum += index;
|
627 |
+
grad_image += index * 3;
|
628 |
+
weights_sum += index;
|
629 |
+
image += index * 3;
|
630 |
+
sigmas += offset;
|
631 |
+
rgbs += offset * 3;
|
632 |
+
deltas += offset * 2;
|
633 |
+
grad_sigmas += offset;
|
634 |
+
grad_rgbs += offset * 3;
|
635 |
+
|
636 |
+
// accumulate
|
637 |
+
uint32_t step = 0;
|
638 |
+
|
639 |
+
scalar_t T = 1.0f;
|
640 |
+
const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
|
641 |
+
scalar_t r = 0, g = 0, b = 0, ws = 0;
|
642 |
+
|
643 |
+
while (step < num_steps) {
|
644 |
+
|
645 |
+
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
646 |
+
const scalar_t weight = alpha * T;
|
647 |
+
|
648 |
+
r += weight * rgbs[0];
|
649 |
+
g += weight * rgbs[1];
|
650 |
+
b += weight * rgbs[2];
|
651 |
+
ws += weight;
|
652 |
+
|
653 |
+
T *= 1.0f - alpha;
|
654 |
+
|
655 |
+
// check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
|
656 |
+
// write grad_rgbs
|
657 |
+
grad_rgbs[0] = grad_image[0] * weight;
|
658 |
+
grad_rgbs[1] = grad_image[1] * weight;
|
659 |
+
grad_rgbs[2] = grad_image[2] * weight;
|
660 |
+
|
661 |
+
// write grad_sigmas
|
662 |
+
grad_sigmas[0] = deltas[0] * (
|
663 |
+
grad_image[0] * (T * rgbs[0] - (r_final - r)) +
|
664 |
+
grad_image[1] * (T * rgbs[1] - (g_final - g)) +
|
665 |
+
grad_image[2] * (T * rgbs[2] - (b_final - b)) +
|
666 |
+
grad_weights_sum[0] * (1 - ws_final)
|
667 |
+
);
|
668 |
+
|
669 |
+
//printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
|
670 |
+
// minimal remained transmittence
|
671 |
+
if (T < T_thresh) break;
|
672 |
+
|
673 |
+
// locate
|
674 |
+
sigmas++;
|
675 |
+
rgbs += 3;
|
676 |
+
deltas += 2;
|
677 |
+
grad_sigmas++;
|
678 |
+
grad_rgbs += 3;
|
679 |
+
|
680 |
+
step++;
|
681 |
+
}
|
682 |
+
}
|
683 |
+
|
684 |
+
|
685 |
+
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
|
686 |
+
|
687 |
+
static constexpr uint32_t N_THREAD = 128;
|
688 |
+
|
689 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
690 |
+
grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
|
691 |
+
kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
|
692 |
+
}));
|
693 |
+
}
|
694 |
+
|
695 |
+
|
696 |
+
////////////////////////////////////////////////////
|
697 |
+
///////////// infernce /////////////
|
698 |
+
////////////////////////////////////////////////////
|
699 |
+
|
700 |
+
template <typename scalar_t>
|
701 |
+
__global__ void kernel_march_rays(
|
702 |
+
const uint32_t n_alive,
|
703 |
+
const uint32_t n_step,
|
704 |
+
const int* __restrict__ rays_alive,
|
705 |
+
const scalar_t* __restrict__ rays_t,
|
706 |
+
const scalar_t* __restrict__ rays_o,
|
707 |
+
const scalar_t* __restrict__ rays_d,
|
708 |
+
const float bound,
|
709 |
+
const float dt_gamma, const uint32_t max_steps,
|
710 |
+
const uint32_t C, const uint32_t H,
|
711 |
+
const uint8_t * __restrict__ grid,
|
712 |
+
const scalar_t* __restrict__ nears,
|
713 |
+
const scalar_t* __restrict__ fars,
|
714 |
+
scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
|
715 |
+
const scalar_t* __restrict__ noises
|
716 |
+
) {
|
717 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
718 |
+
if (n >= n_alive) return;
|
719 |
+
|
720 |
+
const int index = rays_alive[n]; // ray id
|
721 |
+
const float noise = noises[n];
|
722 |
+
|
723 |
+
// locate
|
724 |
+
rays_o += index * 3;
|
725 |
+
rays_d += index * 3;
|
726 |
+
xyzs += n * n_step * 3;
|
727 |
+
dirs += n * n_step * 3;
|
728 |
+
deltas += n * n_step * 2;
|
729 |
+
|
730 |
+
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
|
731 |
+
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
|
732 |
+
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
|
733 |
+
const float rH = 1 / (float)H;
|
734 |
+
const float H3 = H * H * H;
|
735 |
+
|
736 |
+
float t = rays_t[index]; // current ray's t
|
737 |
+
const float near = nears[index], far = fars[index];
|
738 |
+
|
739 |
+
const float dt_min = 2 * SQRT3() / max_steps;
|
740 |
+
const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
|
741 |
+
|
742 |
+
// march for n_step steps, record points
|
743 |
+
uint32_t step = 0;
|
744 |
+
|
745 |
+
// introduce some randomness
|
746 |
+
t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
|
747 |
+
|
748 |
+
float last_t = t;
|
749 |
+
|
750 |
+
while (t < far && step < n_step) {
|
751 |
+
// current point
|
752 |
+
const float x = clamp(ox + t * dx, -bound, bound);
|
753 |
+
const float y = clamp(oy + t * dy, -bound, bound);
|
754 |
+
const float z = clamp(oz + t * dz, -bound, bound);
|
755 |
+
|
756 |
+
const float dt = clamp(t * dt_gamma, dt_min, dt_max);
|
757 |
+
|
758 |
+
// get mip level
|
759 |
+
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
|
760 |
+
|
761 |
+
const float mip_bound = fminf(scalbnf(1, level), bound);
|
762 |
+
const float mip_rbound = 1 / mip_bound;
|
763 |
+
|
764 |
+
// convert to nearest grid position
|
765 |
+
const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
766 |
+
const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
767 |
+
const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
|
768 |
+
|
769 |
+
const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
|
770 |
+
const bool occ = grid[index / 8] & (1 << (index % 8));
|
771 |
+
|
772 |
+
// if occpuied, advance a small step, and write to output
|
773 |
+
if (occ) {
|
774 |
+
// write step
|
775 |
+
xyzs[0] = x;
|
776 |
+
xyzs[1] = y;
|
777 |
+
xyzs[2] = z;
|
778 |
+
dirs[0] = dx;
|
779 |
+
dirs[1] = dy;
|
780 |
+
dirs[2] = dz;
|
781 |
+
// calc dt
|
782 |
+
t += dt;
|
783 |
+
deltas[0] = dt;
|
784 |
+
deltas[1] = t - last_t; // used to calc depth
|
785 |
+
last_t = t;
|
786 |
+
// step
|
787 |
+
xyzs += 3;
|
788 |
+
dirs += 3;
|
789 |
+
deltas += 2;
|
790 |
+
step++;
|
791 |
+
|
792 |
+
// else, skip a large step (basically skip a voxel grid)
|
793 |
+
} else {
|
794 |
+
// calc distance to next voxel
|
795 |
+
const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
|
796 |
+
const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
|
797 |
+
const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
|
798 |
+
const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
|
799 |
+
// step until next voxel
|
800 |
+
do {
|
801 |
+
t += clamp(t * dt_gamma, dt_min, dt_max);
|
802 |
+
} while (t < tt);
|
803 |
+
}
|
804 |
+
}
|
805 |
+
}
|
806 |
+
|
807 |
+
|
808 |
+
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
|
809 |
+
static constexpr uint32_t N_THREAD = 128;
|
810 |
+
|
811 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
812 |
+
rays_o.scalar_type(), "march_rays", ([&] {
|
813 |
+
kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
|
814 |
+
}));
|
815 |
+
}
|
816 |
+
|
817 |
+
|
818 |
+
template <typename scalar_t>
|
819 |
+
__global__ void kernel_composite_rays(
|
820 |
+
const uint32_t n_alive,
|
821 |
+
const uint32_t n_step,
|
822 |
+
const float T_thresh,
|
823 |
+
int* rays_alive,
|
824 |
+
scalar_t* rays_t,
|
825 |
+
const scalar_t* __restrict__ sigmas,
|
826 |
+
const scalar_t* __restrict__ rgbs,
|
827 |
+
const scalar_t* __restrict__ deltas,
|
828 |
+
scalar_t* weights_sum, scalar_t* depth, scalar_t* image
|
829 |
+
) {
|
830 |
+
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
|
831 |
+
if (n >= n_alive) return;
|
832 |
+
|
833 |
+
const int index = rays_alive[n]; // ray id
|
834 |
+
|
835 |
+
// locate
|
836 |
+
sigmas += n * n_step;
|
837 |
+
rgbs += n * n_step * 3;
|
838 |
+
deltas += n * n_step * 2;
|
839 |
+
|
840 |
+
rays_t += index;
|
841 |
+
weights_sum += index;
|
842 |
+
depth += index;
|
843 |
+
image += index * 3;
|
844 |
+
|
845 |
+
scalar_t t = rays_t[0]; // current ray's t
|
846 |
+
|
847 |
+
scalar_t weight_sum = weights_sum[0];
|
848 |
+
scalar_t d = depth[0];
|
849 |
+
scalar_t r = image[0];
|
850 |
+
scalar_t g = image[1];
|
851 |
+
scalar_t b = image[2];
|
852 |
+
|
853 |
+
// accumulate
|
854 |
+
uint32_t step = 0;
|
855 |
+
while (step < n_step) {
|
856 |
+
|
857 |
+
// ray is terminated if delta == 0
|
858 |
+
if (deltas[0] == 0) break;
|
859 |
+
|
860 |
+
const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
|
861 |
+
|
862 |
+
/*
|
863 |
+
T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
|
864 |
+
w_i = alpha_i * T_i
|
865 |
+
-->
|
866 |
+
T_i = 1 - \sum_{j=0}^{i-1} w_j
|
867 |
+
*/
|
868 |
+
const scalar_t T = 1 - weight_sum;
|
869 |
+
const scalar_t weight = alpha * T;
|
870 |
+
weight_sum += weight;
|
871 |
+
|
872 |
+
t += deltas[1]; // real delta
|
873 |
+
d += weight * t;
|
874 |
+
r += weight * rgbs[0];
|
875 |
+
g += weight * rgbs[1];
|
876 |
+
b += weight * rgbs[2];
|
877 |
+
|
878 |
+
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
|
879 |
+
|
880 |
+
// ray is terminated if T is too small
|
881 |
+
// use a larger bound to further accelerate inference
|
882 |
+
if (T < T_thresh) break;
|
883 |
+
|
884 |
+
// locate
|
885 |
+
sigmas++;
|
886 |
+
rgbs += 3;
|
887 |
+
deltas += 2;
|
888 |
+
step++;
|
889 |
+
}
|
890 |
+
|
891 |
+
//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
|
892 |
+
|
893 |
+
// rays_alive = -1 means ray is terminated early.
|
894 |
+
if (step < n_step) {
|
895 |
+
rays_alive[n] = -1;
|
896 |
+
} else {
|
897 |
+
rays_t[0] = t;
|
898 |
+
}
|
899 |
+
|
900 |
+
weights_sum[0] = weight_sum; // this is the thing I needed!
|
901 |
+
depth[0] = d;
|
902 |
+
image[0] = r;
|
903 |
+
image[1] = g;
|
904 |
+
image[2] = b;
|
905 |
+
}
|
906 |
+
|
907 |
+
|
908 |
+
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
|
909 |
+
static constexpr uint32_t N_THREAD = 128;
|
910 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
911 |
+
image.scalar_type(), "composite_rays", ([&] {
|
912 |
+
kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
|
913 |
+
}));
|
914 |
+
}
|
raymarching/src/raymarching.h
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <stdint.h>
|
4 |
+
#include <torch/torch.h>
|
5 |
+
|
6 |
+
|
7 |
+
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
|
8 |
+
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
|
9 |
+
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
|
10 |
+
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
|
11 |
+
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
|
12 |
+
|
13 |
+
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
|
14 |
+
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
|
15 |
+
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
|
16 |
+
|
17 |
+
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises);
|
18 |
+
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
|
readme.md
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Stable-Dreamfusion
|
2 |
+
|
3 |
+
A pytorch implementation of the text-to-3D model **Dreamfusion**, powered by the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) text-to-2D model.
|
4 |
+
|
5 |
+
The original paper's project page: [_DreamFusion: Text-to-3D using 2D Diffusion_](https://dreamfusion3d.github.io/).
|
6 |
+
|
7 |
+
Example of "a squierrel" and "a hamburger":
|
8 |
+
|
9 |
+
### [Gallery](assets/gallery.md) | [Update Logs](assets/update_logs.md)
|
10 |
+
|
11 |
+
# Important Notice
|
12 |
+
This project is a **work-in-progress**, and contains lots of differences from the paper. Also, many features are still not implmented now. The current generation quality cannot match the results from the original paper, and still fail badly for many prompts.
|
13 |
+
|
14 |
+
## Notable differences from the paper
|
15 |
+
* Since the Imagen model is not publicly available, we use [Stable Diffusion](https://github.com/CompVis/stable-diffusion) to replace it (implementation from [diffusers](https://github.com/huggingface/diffusers)). Different from Imagen, Stable-Diffusion is a latent diffusion model, which diffuses in a latent space instead of the original image space. Therefore, we need the loss to propagate back from the VAE's encoder part too, which introduces extra time cost in training. Currently, 15000 training steps take about 5 hours to train on a V100.
|
16 |
+
* We use the [multi-resolution grid encoder](https://github.com/NVlabs/instant-ngp/) to implement the NeRF backbone (implementation from [torch-ngp](https://github.com/ashawkey/torch-ngp)), which enables much faster rendering (~10FPS at 800x800).
|
17 |
+
* We use the Adam optimizer with a larger initial learning rate.
|
18 |
+
|
19 |
+
|
20 |
+
## TODOs
|
21 |
+
* The shading part & normal evaluation.
|
22 |
+
* Exporting colored mesh.
|
23 |
+
|
24 |
+
|
25 |
+
# Install
|
26 |
+
|
27 |
+
```bash
|
28 |
+
git clone https://github.com/ashawkey/stable-dreamfusion.git
|
29 |
+
cd stable-dreamfusion
|
30 |
+
```
|
31 |
+
|
32 |
+
**Important**: To download the Stable Diffusion model checkpoint, you should create a file under this directory called `TOKEN` and copy your hugging face [access token](https://huggingface.co/docs/hub/security-tokens) into it.
|
33 |
+
|
34 |
+
### Install with pip
|
35 |
+
```bash
|
36 |
+
pip install -r requirements.txt
|
37 |
+
|
38 |
+
# (optional) install the tcnn backbone if using --tcnn
|
39 |
+
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
|
40 |
+
|
41 |
+
# (optional) install CLIP guidance for the dreamfield setting
|
42 |
+
pip install git+https://github.com/openai/CLIP.git
|
43 |
+
|
44 |
+
# (optional) install nvdiffrast for exporting textured mesh
|
45 |
+
pip install git+https://github.com/NVlabs/nvdiffrast/
|
46 |
+
```
|
47 |
+
|
48 |
+
### Build extension (optional)
|
49 |
+
By default, we use [`load`](https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load) to build the extension at runtime.
|
50 |
+
We also provide the `setup.py` to build each extension:
|
51 |
+
```bash
|
52 |
+
# install all extension modules
|
53 |
+
bash scripts/install_ext.sh
|
54 |
+
|
55 |
+
# if you want to install manually, here is an example:
|
56 |
+
pip install ./raymarching # install to python path (you still need the raymarching/ folder, since this only install the built extension.)
|
57 |
+
```
|
58 |
+
|
59 |
+
### Tested environments
|
60 |
+
* Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100.
|
61 |
+
|
62 |
+
|
63 |
+
# Usage
|
64 |
+
|
65 |
+
First time running will take some time to compile the CUDA extensions.
|
66 |
+
|
67 |
+
```bash
|
68 |
+
### stable-dreamfusion setting
|
69 |
+
# train with text prompt
|
70 |
+
# `-O` equals `--cuda_ray --fp16 --dir_text`
|
71 |
+
python main_nerf.py --text "a hamburger" --workspace trial -O
|
72 |
+
|
73 |
+
# test (exporting 360 video)
|
74 |
+
python main_nerf.py --text "a hamburger" --workspace trial -O --test
|
75 |
+
|
76 |
+
# test with a GUI (free view control!)
|
77 |
+
python main_nerf.py --text "a hamburger" --workspace trial -O --test --gui
|
78 |
+
|
79 |
+
### dreamfields (CLIP) setting
|
80 |
+
python main_nerf.py --text "a hamburger" --workspace trial_clip -O --guidance clip
|
81 |
+
python main_nerf.py --text "a hamburger" --workspace trial_clip -O --test --gui --guidance clip
|
82 |
+
```
|
83 |
+
|
84 |
+
# Acknowledgement
|
85 |
+
|
86 |
+
* The amazing original work: [_DreamFusion: Text-to-3D using 2D Diffusion_](https://dreamfusion3d.github.io/).
|
87 |
+
|
88 |
+
* Huge thanks to the [Stable Diffusion](https://github.com/CompVis/stable-diffusion) and the [diffusers](https://github.com/huggingface/diffusers) library.
|
89 |
+
|
90 |
+
|
91 |
+
* The GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui).
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch-ema
|
2 |
+
ninja
|
3 |
+
trimesh
|
4 |
+
opencv-python
|
5 |
+
tensorboardX
|
6 |
+
torch
|
7 |
+
numpy
|
8 |
+
pandas
|
9 |
+
tqdm
|
10 |
+
matplotlib
|
11 |
+
PyMCubes
|
12 |
+
rich
|
13 |
+
pysdf
|
14 |
+
dearpygui
|
15 |
+
scipy
|
16 |
+
diffusers
|
17 |
+
xatlas
|
scripts/install_ext.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pip install ./raymarching
|
2 |
+
pip install ./shencoder
|
3 |
+
pip install ./freqencoder
|
4 |
+
pip install ./gridencoder
|
scripts/run.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of cthulhu" --workspace trial_cthulhu
|
4 |
+
CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of a squirrel" --workspace trial_squirrel
|
5 |
+
CUDA_VISIBLE_DEVICES=1 python main_nerf.py -O --text "a DSLR photo of a cat lying on its side batting at a ball of yarn" --workspace trial_cat_lying
|
shencoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .sphere_harmonics import SHEncoder
|
shencoder/backend.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torch.utils.cpp_extension import load
|
3 |
+
|
4 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
5 |
+
|
6 |
+
nvcc_flags = [
|
7 |
+
'-O3', '-std=c++14',
|
8 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
9 |
+
]
|
10 |
+
|
11 |
+
if os.name == "posix":
|
12 |
+
c_flags = ['-O3', '-std=c++14']
|
13 |
+
elif os.name == "nt":
|
14 |
+
c_flags = ['/O2', '/std:c++17']
|
15 |
+
|
16 |
+
# find cl.exe
|
17 |
+
def find_cl_path():
|
18 |
+
import glob
|
19 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
20 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
21 |
+
if paths:
|
22 |
+
return paths[0]
|
23 |
+
|
24 |
+
# If cl.exe is not on path, try to find it.
|
25 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
26 |
+
cl_path = find_cl_path()
|
27 |
+
if cl_path is None:
|
28 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
29 |
+
os.environ["PATH"] += ";" + cl_path
|
30 |
+
|
31 |
+
_backend = load(name='_sh_encoder',
|
32 |
+
extra_cflags=c_flags,
|
33 |
+
extra_cuda_cflags=nvcc_flags,
|
34 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
35 |
+
'shencoder.cu',
|
36 |
+
'bindings.cpp',
|
37 |
+
]],
|
38 |
+
)
|
39 |
+
|
40 |
+
__all__ = ['_backend']
|
shencoder/setup.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from setuptools import setup
|
3 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
4 |
+
|
5 |
+
_src_path = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
|
7 |
+
nvcc_flags = [
|
8 |
+
'-O3', '-std=c++14',
|
9 |
+
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
10 |
+
]
|
11 |
+
|
12 |
+
if os.name == "posix":
|
13 |
+
c_flags = ['-O3', '-std=c++14']
|
14 |
+
elif os.name == "nt":
|
15 |
+
c_flags = ['/O2', '/std:c++17']
|
16 |
+
|
17 |
+
# find cl.exe
|
18 |
+
def find_cl_path():
|
19 |
+
import glob
|
20 |
+
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
21 |
+
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
|
22 |
+
if paths:
|
23 |
+
return paths[0]
|
24 |
+
|
25 |
+
# If cl.exe is not on path, try to find it.
|
26 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
27 |
+
cl_path = find_cl_path()
|
28 |
+
if cl_path is None:
|
29 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
30 |
+
os.environ["PATH"] += ";" + cl_path
|
31 |
+
|
32 |
+
setup(
|
33 |
+
name='shencoder', # package name, import this to use python API
|
34 |
+
ext_modules=[
|
35 |
+
CUDAExtension(
|
36 |
+
name='_shencoder', # extension name, import this to use CUDA API
|
37 |
+
sources=[os.path.join(_src_path, 'src', f) for f in [
|
38 |
+
'shencoder.cu',
|
39 |
+
'bindings.cpp',
|
40 |
+
]],
|
41 |
+
extra_compile_args={
|
42 |
+
'cxx': c_flags,
|
43 |
+
'nvcc': nvcc_flags,
|
44 |
+
}
|
45 |
+
),
|
46 |
+
],
|
47 |
+
cmdclass={
|
48 |
+
'build_ext': BuildExtension,
|
49 |
+
}
|
50 |
+
)
|
shencoder/sphere_harmonics.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.autograd.function import once_differentiable
|
7 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
8 |
+
|
9 |
+
try:
|
10 |
+
import _shencoder as _backend
|
11 |
+
except ImportError:
|
12 |
+
from .backend import _backend
|
13 |
+
|
14 |
+
class _sh_encoder(Function):
|
15 |
+
@staticmethod
|
16 |
+
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
|
17 |
+
def forward(ctx, inputs, degree, calc_grad_inputs=False):
|
18 |
+
# inputs: [B, input_dim], float in [-1, 1]
|
19 |
+
# RETURN: [B, F], float
|
20 |
+
|
21 |
+
inputs = inputs.contiguous()
|
22 |
+
B, input_dim = inputs.shape # batch size, coord dim
|
23 |
+
output_dim = degree ** 2
|
24 |
+
|
25 |
+
outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
|
26 |
+
|
27 |
+
if calc_grad_inputs:
|
28 |
+
dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
|
29 |
+
else:
|
30 |
+
dy_dx = None
|
31 |
+
|
32 |
+
_backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)
|
33 |
+
|
34 |
+
ctx.save_for_backward(inputs, dy_dx)
|
35 |
+
ctx.dims = [B, input_dim, degree]
|
36 |
+
|
37 |
+
return outputs
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
#@once_differentiable
|
41 |
+
@custom_bwd
|
42 |
+
def backward(ctx, grad):
|
43 |
+
# grad: [B, C * C]
|
44 |
+
|
45 |
+
inputs, dy_dx = ctx.saved_tensors
|
46 |
+
|
47 |
+
if dy_dx is not None:
|
48 |
+
grad = grad.contiguous()
|
49 |
+
B, input_dim, degree = ctx.dims
|
50 |
+
grad_inputs = torch.zeros_like(inputs)
|
51 |
+
_backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
|
52 |
+
return grad_inputs, None, None
|
53 |
+
else:
|
54 |
+
return None, None, None
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
sh_encode = _sh_encoder.apply
|
59 |
+
|
60 |
+
|
61 |
+
class SHEncoder(nn.Module):
|
62 |
+
def __init__(self, input_dim=3, degree=4):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.input_dim = input_dim # coord dims, must be 3
|
66 |
+
self.degree = degree # 0 ~ 4
|
67 |
+
self.output_dim = degree ** 2
|
68 |
+
|
69 |
+
assert self.input_dim == 3, "SH encoder only support input dim == 3"
|
70 |
+
assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"
|
71 |
+
|
72 |
+
def __repr__(self):
|
73 |
+
return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"
|
74 |
+
|
75 |
+
def forward(self, inputs, size=1):
|
76 |
+
# inputs: [..., input_dim], normalized real world positions in [-size, size]
|
77 |
+
# return: [..., degree^2]
|
78 |
+
|
79 |
+
inputs = inputs / size # [-1, 1]
|
80 |
+
|
81 |
+
prefix_shape = list(inputs.shape[:-1])
|
82 |
+
inputs = inputs.reshape(-1, self.input_dim)
|
83 |
+
|
84 |
+
outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
|
85 |
+
outputs = outputs.reshape(prefix_shape + [self.output_dim])
|
86 |
+
|
87 |
+
return outputs
|
shencoder/src/bindings.cpp
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include "shencoder.h"
|
4 |
+
|
5 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
6 |
+
m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)");
|
7 |
+
m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)");
|
8 |
+
}
|
shencoder/src/shencoder.cu
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <stdint.h>
|
2 |
+
|
3 |
+
#include <cuda.h>
|
4 |
+
#include <cuda_fp16.h>
|
5 |
+
#include <cuda_runtime.h>
|
6 |
+
|
7 |
+
#include <ATen/cuda/CUDAContext.h>
|
8 |
+
#include <torch/torch.h>
|
9 |
+
|
10 |
+
#include <algorithm>
|
11 |
+
#include <stdexcept>
|
12 |
+
|
13 |
+
#include <cstdio>
|
14 |
+
|
15 |
+
|
16 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
17 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
18 |
+
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
19 |
+
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
20 |
+
|
21 |
+
|
22 |
+
template <typename T>
|
23 |
+
__host__ __device__ T div_round_up(T val, T divisor) {
|
24 |
+
return (val + divisor - 1) / divisor;
|
25 |
+
}
|
26 |
+
|
27 |
+
template <typename scalar_t>
|
28 |
+
__global__ void kernel_sh(
|
29 |
+
const scalar_t * __restrict__ inputs,
|
30 |
+
scalar_t * outputs,
|
31 |
+
uint32_t B, uint32_t D, uint32_t C,
|
32 |
+
scalar_t * dy_dx
|
33 |
+
) {
|
34 |
+
const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;
|
35 |
+
if (b >= B) return;
|
36 |
+
|
37 |
+
const uint32_t C2 = C * C;
|
38 |
+
|
39 |
+
// locate
|
40 |
+
inputs += b * D;
|
41 |
+
outputs += b * C2;
|
42 |
+
|
43 |
+
scalar_t x = inputs[0], y = inputs[1], z = inputs[2];
|
44 |
+
|
45 |
+
scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
|
46 |
+
scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;
|
47 |
+
scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;
|
48 |
+
|
49 |
+
auto write_sh = [&]() {
|
50 |
+
outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi))
|
51 |
+
if (C <= 1) { return; }
|
52 |
+
outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi))
|
53 |
+
outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi))
|
54 |
+
outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi))
|
55 |
+
if (C <= 2) { return; }
|
56 |
+
outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi))
|
57 |
+
outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi))
|
58 |
+
outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
|
59 |
+
outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi))
|
60 |
+
outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi))
|
61 |
+
if (C <= 3) { return; }
|
62 |
+
outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
|
63 |
+
outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi))
|
64 |
+
outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
|
65 |
+
outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
|
66 |
+
outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
|
67 |
+
outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
|
68 |
+
outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
|
69 |
+
if (C <= 4) { return; }
|
70 |
+
outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))
|
71 |
+
outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))
|
72 |
+
outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))
|
73 |
+
outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))
|
74 |
+
outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))
|
75 |
+
outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))
|
76 |
+
outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))
|
77 |
+
outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))
|
78 |
+
outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
79 |
+
if (C <= 5) { return; }
|
80 |
+
outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
81 |
+
outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))
|
82 |
+
outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
|
83 |
+
outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))
|
84 |
+
outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
85 |
+
outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))
|
86 |
+
outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
87 |
+
outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))
|
88 |
+
outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))
|
89 |
+
outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
90 |
+
outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
91 |
+
if (C <= 6) { return; }
|
92 |
+
outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
93 |
+
outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
94 |
+
outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
|
95 |
+
outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
|
96 |
+
outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
|
97 |
+
outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
98 |
+
outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))
|
99 |
+
outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
100 |
+
outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))
|
101 |
+
outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))
|
102 |
+
outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
|
103 |
+
outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
104 |
+
outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
|
105 |
+
if (C <= 7) { return; }
|
106 |
+
outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))
|
107 |
+
outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
108 |
+
outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))
|
109 |
+
outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
|
110 |
+
outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
|
111 |
+
outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
|
112 |
+
outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
113 |
+
outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))
|
114 |
+
outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
115 |
+
outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))
|
116 |
+
outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
|
117 |
+
outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
|
118 |
+
outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))
|
119 |
+
outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
|
120 |
+
outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))
|
121 |
+
};
|
122 |
+
|
123 |
+
write_sh();
|
124 |
+
|
125 |
+
if (dy_dx) {
|
126 |
+
scalar_t *dx = dy_dx + b * D * C2;
|
127 |
+
scalar_t *dy = dx + C2;
|
128 |
+
scalar_t *dz = dy + C2;
|
129 |
+
|
130 |
+
auto write_sh_dx = [&]() {
|
131 |
+
dx[0] = 0.0f ; // 0
|
132 |
+
if (C <= 1) { return; }
|
133 |
+
dx[1] = 0.0f ; // 0
|
134 |
+
dx[2] = 0.0f ; // 0
|
135 |
+
dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
|
136 |
+
if (C <= 2) { return; }
|
137 |
+
dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi))
|
138 |
+
dx[5] = 0.0f ; // 0
|
139 |
+
dx[6] = 0.0f ; // 0
|
140 |
+
dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
|
141 |
+
dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
|
142 |
+
if (C <= 3) { return; }
|
143 |
+
dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi))
|
144 |
+
dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi))
|
145 |
+
dx[11] = 0.0f ; // 0
|
146 |
+
dx[12] = 0.0f ; // 0
|
147 |
+
dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
|
148 |
+
dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
|
149 |
+
dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
|
150 |
+
if (C <= 4) { return; }
|
151 |
+
dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))
|
152 |
+
dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi))
|
153 |
+
dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))
|
154 |
+
dx[19] = 0.0f ; // 0
|
155 |
+
dx[20] = 0.0f ; // 0
|
156 |
+
dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
|
157 |
+
dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
|
158 |
+
dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
|
159 |
+
dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
|
160 |
+
if (C <= 5) { return; }
|
161 |
+
dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))
|
162 |
+
dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))
|
163 |
+
dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))
|
164 |
+
dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))
|
165 |
+
dx[29] = 0.0f ; // 0
|
166 |
+
dx[30] = 0.0f ; // 0
|
167 |
+
dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
168 |
+
dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
|
169 |
+
dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))
|
170 |
+
dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
|
171 |
+
dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
|
172 |
+
if (C <= 6) { return; }
|
173 |
+
dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
|
174 |
+
dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))
|
175 |
+
dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
|
176 |
+
dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))
|
177 |
+
dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
|
178 |
+
dx[41] = 0.0f ; // 0
|
179 |
+
dx[42] = 0.0f ; // 0
|
180 |
+
dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
181 |
+
dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
|
182 |
+
dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
|
183 |
+
dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
|
184 |
+
dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
|
185 |
+
dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
|
186 |
+
if (C <= 7) { return; }
|
187 |
+
dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))
|
188 |
+
dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))
|
189 |
+
dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
|
190 |
+
dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
|
191 |
+
dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))
|
192 |
+
dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
|
193 |
+
dx[55] = 0.0f ; // 0
|
194 |
+
dx[56] = 0.0f ; // 0
|
195 |
+
dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
196 |
+
dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
|
197 |
+
dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))
|
198 |
+
dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
|
199 |
+
dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))
|
200 |
+
dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
|
201 |
+
dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
|
202 |
+
};
|
203 |
+
|
204 |
+
auto write_sh_dy = [&]() {
|
205 |
+
dy[0] = 0.0f ; // 0
|
206 |
+
if (C <= 1) { return; }
|
207 |
+
dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi))
|
208 |
+
dy[2] = 0.0f ; // 0
|
209 |
+
dy[3] = 0.0f ; // 0
|
210 |
+
if (C <= 2) { return; }
|
211 |
+
dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi))
|
212 |
+
dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi))
|
213 |
+
dy[6] = 0.0f ; // 0
|
214 |
+
dy[7] = 0.0f ; // 0
|
215 |
+
dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
|
216 |
+
if (C <= 3) { return; }
|
217 |
+
dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))
|
218 |
+
dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi))
|
219 |
+
dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))
|
220 |
+
dy[12] = 0.0f ; // 0
|
221 |
+
dy[13] = 0.0f ; // 0
|
222 |
+
dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi))
|
223 |
+
dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi))
|
224 |
+
if (C <= 4) { return; }
|
225 |
+
dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))
|
226 |
+
dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))
|
227 |
+
dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))
|
228 |
+
dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))
|
229 |
+
dy[20] = 0.0f ; // 0
|
230 |
+
dy[21] = 0.0f ; // 0
|
231 |
+
dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))
|
232 |
+
dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi))
|
233 |
+
dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))
|
234 |
+
if (C <= 5) { return; }
|
235 |
+
dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
|
236 |
+
dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))
|
237 |
+
dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))
|
238 |
+
dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))
|
239 |
+
dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))
|
240 |
+
dy[30] = 0.0f ; // 0
|
241 |
+
dy[31] = 0.0f ; // 0
|
242 |
+
dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))
|
243 |
+
dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))
|
244 |
+
dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))
|
245 |
+
dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))
|
246 |
+
if (C <= 6) { return; }
|
247 |
+
dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
|
248 |
+
dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))
|
249 |
+
dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))
|
250 |
+
dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))
|
251 |
+
dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))
|
252 |
+
dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))
|
253 |
+
dy[42] = 0.0f ; // 0
|
254 |
+
dy[43] = 0.0f ; // 0
|
255 |
+
dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))
|
256 |
+
dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))
|
257 |
+
dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))
|
258 |
+
dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))
|
259 |
+
dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
260 |
+
if (C <= 7) { return; }
|
261 |
+
dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))
|
262 |
+
dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))
|
263 |
+
dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))
|
264 |
+
dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))
|
265 |
+
dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))
|
266 |
+
dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))
|
267 |
+
dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))
|
268 |
+
dy[56] = 0.0f ; // 0
|
269 |
+
dy[57] = 0.0f ; // 0
|
270 |
+
dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
|
271 |
+
dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
|
272 |
+
dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))
|
273 |
+
dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))
|
274 |
+
dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
275 |
+
dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
276 |
+
};
|
277 |
+
|
278 |
+
auto write_sh_dz = [&]() {
|
279 |
+
dz[0] = 0.0f ; // 0
|
280 |
+
if (C <= 1) { return; }
|
281 |
+
dz[1] = 0.0f ; // 0
|
282 |
+
dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi))
|
283 |
+
dz[3] = 0.0f ; // 0
|
284 |
+
if (C <= 2) { return; }
|
285 |
+
dz[4] = 0.0f ; // 0
|
286 |
+
dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi))
|
287 |
+
dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi))
|
288 |
+
dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi))
|
289 |
+
dz[8] = 0.0f ; // 0
|
290 |
+
if (C <= 3) { return; }
|
291 |
+
dz[9] = 0.0f ; // 0
|
292 |
+
dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi))
|
293 |
+
dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi))
|
294 |
+
dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))
|
295 |
+
dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi))
|
296 |
+
dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi))
|
297 |
+
dz[15] = 0.0f ; // 0
|
298 |
+
if (C <= 4) { return; }
|
299 |
+
dz[16] = 0.0f ; // 0
|
300 |
+
dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
|
301 |
+
dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi))
|
302 |
+
dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))
|
303 |
+
dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi))
|
304 |
+
dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))
|
305 |
+
dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))
|
306 |
+
dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
|
307 |
+
dz[24] = 0.0f ; // 0
|
308 |
+
if (C <= 5) { return; }
|
309 |
+
dz[25] = 0.0f ; // 0
|
310 |
+
dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))
|
311 |
+
dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))
|
312 |
+
dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))
|
313 |
+
dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))
|
314 |
+
dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))
|
315 |
+
dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))
|
316 |
+
dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))
|
317 |
+
dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))
|
318 |
+
dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
319 |
+
dz[35] = 0.0f ; // 0
|
320 |
+
if (C <= 6) { return; }
|
321 |
+
dz[36] = 0.0f ; // 0
|
322 |
+
dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
323 |
+
dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))
|
324 |
+
dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))
|
325 |
+
dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))
|
326 |
+
dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
|
327 |
+
dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))
|
328 |
+
dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))
|
329 |
+
dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))
|
330 |
+
dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))
|
331 |
+
dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))
|
332 |
+
dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
333 |
+
dz[48] = 0.0f ; // 0
|
334 |
+
if (C <= 7) { return; }
|
335 |
+
dz[49] = 0.0f ; // 0
|
336 |
+
dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
|
337 |
+
dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
|
338 |
+
dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
|
339 |
+
dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
|
340 |
+
dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
|
341 |
+
dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
|
342 |
+
dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
|
343 |
+
dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
|
344 |
+
dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
|
345 |
+
dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
|
346 |
+
dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
|
347 |
+
dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
|
348 |
+
dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
|
349 |
+
dz[63] = 0.0f ; // 0
|
350 |
+
};
|
351 |
+
write_sh_dx();
|
352 |
+
write_sh_dy();
|
353 |
+
write_sh_dz();
|
354 |
+
}
|
355 |
+
}
|
356 |
+
|
357 |
+
|
358 |
+
template <typename scalar_t>
|
359 |
+
__global__ void kernel_sh_backward(
|
360 |
+
const scalar_t * __restrict__ grad,
|
361 |
+
const scalar_t * __restrict__ inputs,
|
362 |
+
uint32_t B, uint32_t D, uint32_t C,
|
363 |
+
const scalar_t * __restrict__ dy_dx,
|
364 |
+
scalar_t * grad_inputs
|
365 |
+
) {
|
366 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
367 |
+
const uint32_t b = t / D;
|
368 |
+
if (b >= B) return;
|
369 |
+
|
370 |
+
const uint32_t d = t - b * D;
|
371 |
+
const uint32_t C2 = C * C;
|
372 |
+
|
373 |
+
// locate
|
374 |
+
grad += b * C2;
|
375 |
+
dy_dx += b * D * C2 + d * C2;
|
376 |
+
|
377 |
+
for (int ch = 0; ch < C2; ch++) {
|
378 |
+
grad_inputs[t] += grad[ch] * dy_dx[ch];
|
379 |
+
//printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
|
380 |
+
}
|
381 |
+
|
382 |
+
}
|
383 |
+
|
384 |
+
// inputs: [B, D], float, in [0, 1]
|
385 |
+
// outputs: [B, L * C], float
|
386 |
+
template <typename scalar_t>
|
387 |
+
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) {
|
388 |
+
static constexpr uint32_t N_THREADS = 256;
|
389 |
+
kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx);
|
390 |
+
}
|
391 |
+
|
392 |
+
|
393 |
+
template <typename scalar_t>
|
394 |
+
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
|
395 |
+
static constexpr uint32_t N_THREADS = 256;
|
396 |
+
kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
|
397 |
+
}
|
398 |
+
|
399 |
+
|
400 |
+
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) {
|
401 |
+
CHECK_CUDA(inputs);
|
402 |
+
CHECK_CUDA(outputs);
|
403 |
+
// CHECK_CUDA(dy_dx);
|
404 |
+
|
405 |
+
CHECK_CONTIGUOUS(inputs);
|
406 |
+
CHECK_CONTIGUOUS(outputs);
|
407 |
+
// CHECK_CONTIGUOUS(dy_dx);
|
408 |
+
|
409 |
+
CHECK_IS_FLOATING(inputs);
|
410 |
+
CHECK_IS_FLOATING(outputs);
|
411 |
+
// CHECK_IS_FLOATING(dy_dx);
|
412 |
+
|
413 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
414 |
+
inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
|
415 |
+
sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr);
|
416 |
+
}));
|
417 |
+
}
|
418 |
+
|
419 |
+
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
|
420 |
+
CHECK_CUDA(grad);
|
421 |
+
CHECK_CUDA(inputs);
|
422 |
+
CHECK_CUDA(dy_dx);
|
423 |
+
CHECK_CUDA(grad_inputs);
|
424 |
+
|
425 |
+
CHECK_CONTIGUOUS(grad);
|
426 |
+
CHECK_CONTIGUOUS(inputs);
|
427 |
+
CHECK_CONTIGUOUS(dy_dx);
|
428 |
+
CHECK_CONTIGUOUS(grad_inputs);
|
429 |
+
|
430 |
+
CHECK_IS_FLOATING(grad);
|
431 |
+
CHECK_IS_FLOATING(inputs);
|
432 |
+
CHECK_IS_FLOATING(dy_dx);
|
433 |
+
CHECK_IS_FLOATING(grad_inputs);
|
434 |
+
|
435 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
436 |
+
grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
|
437 |
+
sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
|
438 |
+
}));
|
439 |
+
}
|
shencoder/src/shencoder.h
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pragma once
|
2 |
+
|
3 |
+
#include <stdint.h>
|
4 |
+
#include <torch/torch.h>
|
5 |
+
|
6 |
+
// inputs: [B, D], float, in [-1, 1]
|
7 |
+
// outputs: [B, F], float
|
8 |
+
|
9 |
+
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx);
|
10 |
+
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);
|